Did some sick shit with plots
This commit is contained in:
73
main.py
73
main.py
@@ -9,7 +9,6 @@ import numpy as np
|
|||||||
|
|
||||||
import matplotlib as mpl
|
import matplotlib as mpl
|
||||||
mpl.use('TkAgg') # fixes my macOS bug
|
mpl.use('TkAgg') # fixes my macOS bug
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
@@ -99,28 +98,24 @@ def init_global(maze_filename):
|
|||||||
|
|
||||||
|
|
||||||
def plot_j_policy_on_maze(j, policy):
|
def plot_j_policy_on_maze(j, policy):
|
||||||
j_norm = (j - j.min()) / (j.max() - j.min()) + 1e-50
|
|
||||||
j_log = np.log10(j_norm)
|
|
||||||
print(j)
|
|
||||||
print(j_norm)
|
|
||||||
print(j_log)
|
|
||||||
print('-' * 50)
|
|
||||||
heatmap = np.full(MAZE.shape, np.nan)
|
heatmap = np.full(MAZE.shape, np.nan)
|
||||||
heatmap[S_TO_IJ[:, 0], S_TO_IJ[:, 1]] = j
|
heatmap[S_TO_IJ[:, 0], S_TO_IJ[:, 1]] = j
|
||||||
cmap = mpl.cm.get_cmap('coolwarm')
|
cmap = mpl.cm.get_cmap('coolwarm')
|
||||||
cmap.set_bad(color='black')
|
cmap.set_bad(color='black')
|
||||||
plt.imshow(heatmap, cmap=cmap)
|
plt.imshow(heatmap, cmap=cmap)
|
||||||
plt.colorbar()
|
# plt.colorbar()
|
||||||
# quiver has some weird behavior, the arrow y component must be flipped
|
# quiver has some weird behavior, the arrow y component must be flipped
|
||||||
plt.quiver(S_TO_IJ[:, 1], S_TO_IJ[:, 0], A2[policy, 1], -A2[policy, 0])
|
plt.quiver(S_TO_IJ[:, 1], S_TO_IJ[:, 0], A2[policy, 1], -A2[policy, 0])
|
||||||
plt.gca().get_xaxis().set_visible(False)
|
plt.gca().get_xaxis().set_visible(False)
|
||||||
plt.gca().get_yaxis().set_visible(False)
|
plt.tick_params(axis='y', which='both', left=False, labelleft=False)
|
||||||
|
|
||||||
|
|
||||||
def plot_cost_history(hist):
|
def plot_cost_history(hist):
|
||||||
error = np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1))
|
error = np.log10(
|
||||||
plt.xlabel('Number of iterations')
|
np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1))
|
||||||
plt.ylabel('Cost function RMSE')
|
)
|
||||||
|
plt.xticks(np.arange(0, len(error), len(error) // 5))
|
||||||
|
plt.yticks(np.linspace(error.min(), error.max(), 5))
|
||||||
plt.plot(error)
|
plt.plot(error)
|
||||||
|
|
||||||
|
|
||||||
@@ -143,13 +138,13 @@ def _evaluate_policy(policy, g):
|
|||||||
return np.linalg.solve(np.eye(SN) - ALPHA*M, G)
|
return np.linalg.solve(np.eye(SN) - ALPHA*M, G)
|
||||||
|
|
||||||
|
|
||||||
def value_iteration(j, g):
|
def value_iteration(g, j, **_):
|
||||||
return _policy_improvement(j, g)
|
return _policy_improvement(j, g)
|
||||||
|
|
||||||
|
|
||||||
def policy_iteration(j, g):
|
def policy_iteration(g, policy, **_):
|
||||||
policy, _ = _policy_improvement(j, g)
|
|
||||||
j = _evaluate_policy(policy, g)
|
j = _evaluate_policy(policy, g)
|
||||||
|
policy, _ = _policy_improvement(j, g)
|
||||||
return policy, j
|
return policy, j
|
||||||
|
|
||||||
|
|
||||||
@@ -164,12 +159,12 @@ def _terminate_vi(j, j_old, policy, policy_old):
|
|||||||
|
|
||||||
def dynamic_programming(optimizer_step, g, terminator, return_history=False):
|
def dynamic_programming(optimizer_step, g, terminator, return_history=False):
|
||||||
j = np.zeros(SN, dtype=np.float64)
|
j = np.zeros(SN, dtype=np.float64)
|
||||||
policy = None
|
policy = np.full(SN, -1, dtype=np.int32) # idle policy
|
||||||
history = []
|
history = []
|
||||||
while True:
|
while True:
|
||||||
j_old = j
|
j_old = j
|
||||||
policy_old = policy
|
policy_old = policy
|
||||||
policy, j = optimizer_step(j, g)
|
policy, j = optimizer_step(g, j=j, policy=policy)
|
||||||
if return_history:
|
if return_history:
|
||||||
history.append(j)
|
history.append(j)
|
||||||
if terminator(j, j_old, policy, policy_old):
|
if terminator(j, j_old, policy, policy_old):
|
||||||
@@ -177,7 +172,13 @@ def dynamic_programming(optimizer_step, g, terminator, return_history=False):
|
|||||||
if not return_history:
|
if not return_history:
|
||||||
return j, policy
|
return j, policy
|
||||||
else:
|
else:
|
||||||
return np.array(history)
|
history = np.array(history)
|
||||||
|
|
||||||
|
# cover some edgy cases
|
||||||
|
if (history[-1] == history[-2]).all():
|
||||||
|
history = history[:-1]
|
||||||
|
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -198,8 +199,9 @@ if __name__ == '__main__':
|
|||||||
'Policy Iteration': _terminate_pi}
|
'Policy Iteration': _terminate_pi}
|
||||||
|
|
||||||
for a in [0.9, 0.5, 0.01]:
|
for a in [0.9, 0.5, 0.01]:
|
||||||
plt.figure(figsize=(9, 6))
|
plt.figure(figsize=(9, 7))
|
||||||
plt.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
|
plt.subplots_adjust(top=0.9, bottom=0.05, left=0.1, right=0.95,
|
||||||
|
wspace=0.1)
|
||||||
plt.suptitle('DISCOUNT = ' + str(a))
|
plt.suptitle('DISCOUNT = ' + str(a))
|
||||||
i = 1
|
i = 1
|
||||||
for opt in ['Value Iteration', 'Policy Iteration']:
|
for opt in ['Value Iteration', 'Policy Iteration']:
|
||||||
@@ -209,27 +211,42 @@ if __name__ == '__main__':
|
|||||||
j, policy = dynamic_programming(optimizers[opt], costs[cost],
|
j, policy = dynamic_programming(optimizers[opt], costs[cost],
|
||||||
terminators[opt])
|
terminators[opt])
|
||||||
plt.subplot(2, 2, i)
|
plt.subplot(2, 2, i)
|
||||||
plt.gca().set_title(name)
|
|
||||||
plot_j_policy_on_maze(j, policy)
|
plot_j_policy_on_maze(j, policy)
|
||||||
|
if i <= 2:
|
||||||
|
plt.gca().set_title('Cost: {}'.format(cost),
|
||||||
|
fontsize='x-large')
|
||||||
|
if (i - 1) % 2 == 0:
|
||||||
|
plt.ylabel(opt, fontsize='x-large')
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
# Error graphs
|
# Error graphs
|
||||||
for opt in ['Value Iteration', 'Policy Iteration']:
|
for opt in ['Value Iteration', 'Policy Iteration']:
|
||||||
plt.figure(figsize=(9, 6))
|
plt.figure(figsize=(7, 10))
|
||||||
plt.subplots_adjust(wspace=0.4, hspace=0.4)
|
plt.figtext(0.5, 0.04, 'Number of iterations', ha='center',
|
||||||
|
fontsize='large')
|
||||||
|
plt.figtext(0.05, 0.5, 'Logarithm of cost RMSE', va='center',
|
||||||
|
rotation='vertical', fontsize='large')
|
||||||
|
plt.subplots_adjust(wspace=0.38, hspace=0.35, left=0.205, right=0.92,
|
||||||
|
top=0.9)
|
||||||
plt.suptitle(opt)
|
plt.suptitle(opt)
|
||||||
i = 1
|
i = 1
|
||||||
for cost in ['g1', 'g2']:
|
for a in [0.99, 0.7, 0.1]:
|
||||||
for a in [0.99, 0.7, 0.5]:
|
for cost in ['g1', 'g2']:
|
||||||
name = 'Cost: {}, discount: {}'.format(cost, a)
|
# name = 'Cost: {}, discount: {}'.format(cost, a)
|
||||||
ALPHA = a
|
ALPHA = a
|
||||||
history = dynamic_programming(optimizers[opt], costs[cost],
|
history = dynamic_programming(optimizers[opt], costs[cost],
|
||||||
terminators[opt],
|
terminators[opt],
|
||||||
return_history=True)
|
return_history=True)
|
||||||
plt.subplot(2, 3, i)
|
plt.subplot(3, 2, i)
|
||||||
plt.gca().set_title(name)
|
# plt.gca().set_title(name)
|
||||||
plot_cost_history(history)
|
plot_cost_history(history)
|
||||||
|
if i <= 2:
|
||||||
|
plt.gca().set_title('Cost: {}'.format(cost))
|
||||||
|
if (i - 1) % 2 == 0:
|
||||||
|
plt.ylabel('Discount: {}'.format(a), fontsize='large')
|
||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
print('I ran in {} seconds'.format(time() - start))
|
print('I ran in {} seconds'.format(time() - start))
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
Reference in New Issue
Block a user