Did some sick shit with plots

This commit is contained in:
2019-01-26 18:21:08 +01:00
parent abb3649194
commit 0be126e57b

71
main.py
View File

@@ -9,7 +9,6 @@ import numpy as np
import matplotlib as mpl import matplotlib as mpl
mpl.use('TkAgg') # fixes my macOS bug mpl.use('TkAgg') # fixes my macOS bug
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -99,28 +98,24 @@ def init_global(maze_filename):
def plot_j_policy_on_maze(j, policy): def plot_j_policy_on_maze(j, policy):
j_norm = (j - j.min()) / (j.max() - j.min()) + 1e-50
j_log = np.log10(j_norm)
print(j)
print(j_norm)
print(j_log)
print('-' * 50)
heatmap = np.full(MAZE.shape, np.nan) heatmap = np.full(MAZE.shape, np.nan)
heatmap[S_TO_IJ[:, 0], S_TO_IJ[:, 1]] = j heatmap[S_TO_IJ[:, 0], S_TO_IJ[:, 1]] = j
cmap = mpl.cm.get_cmap('coolwarm') cmap = mpl.cm.get_cmap('coolwarm')
cmap.set_bad(color='black') cmap.set_bad(color='black')
plt.imshow(heatmap, cmap=cmap) plt.imshow(heatmap, cmap=cmap)
plt.colorbar() # plt.colorbar()
# quiver has some weird behavior, the arrow y component must be flipped # quiver has some weird behavior, the arrow y component must be flipped
plt.quiver(S_TO_IJ[:, 1], S_TO_IJ[:, 0], A2[policy, 1], -A2[policy, 0]) plt.quiver(S_TO_IJ[:, 1], S_TO_IJ[:, 0], A2[policy, 1], -A2[policy, 0])
plt.gca().get_xaxis().set_visible(False) plt.gca().get_xaxis().set_visible(False)
plt.gca().get_yaxis().set_visible(False) plt.tick_params(axis='y', which='both', left=False, labelleft=False)
def plot_cost_history(hist): def plot_cost_history(hist):
error = np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1)) error = np.log10(
plt.xlabel('Number of iterations') np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1))
plt.ylabel('Cost function RMSE') )
plt.xticks(np.arange(0, len(error), len(error) // 5))
plt.yticks(np.linspace(error.min(), error.max(), 5))
plt.plot(error) plt.plot(error)
@@ -143,13 +138,13 @@ def _evaluate_policy(policy, g):
return np.linalg.solve(np.eye(SN) - ALPHA*M, G) return np.linalg.solve(np.eye(SN) - ALPHA*M, G)
def value_iteration(j, g): def value_iteration(g, j, **_):
return _policy_improvement(j, g) return _policy_improvement(j, g)
def policy_iteration(j, g): def policy_iteration(g, policy, **_):
policy, _ = _policy_improvement(j, g)
j = _evaluate_policy(policy, g) j = _evaluate_policy(policy, g)
policy, _ = _policy_improvement(j, g)
return policy, j return policy, j
@@ -164,12 +159,12 @@ def _terminate_vi(j, j_old, policy, policy_old):
def dynamic_programming(optimizer_step, g, terminator, return_history=False): def dynamic_programming(optimizer_step, g, terminator, return_history=False):
j = np.zeros(SN, dtype=np.float64) j = np.zeros(SN, dtype=np.float64)
policy = None policy = np.full(SN, -1, dtype=np.int32) # idle policy
history = [] history = []
while True: while True:
j_old = j j_old = j
policy_old = policy policy_old = policy
policy, j = optimizer_step(j, g) policy, j = optimizer_step(g, j=j, policy=policy)
if return_history: if return_history:
history.append(j) history.append(j)
if terminator(j, j_old, policy, policy_old): if terminator(j, j_old, policy, policy_old):
@@ -177,7 +172,13 @@ def dynamic_programming(optimizer_step, g, terminator, return_history=False):
if not return_history: if not return_history:
return j, policy return j, policy
else: else:
return np.array(history) history = np.array(history)
# cover some edgy cases
if (history[-1] == history[-2]).all():
history = history[:-1]
return history
if __name__ == '__main__': if __name__ == '__main__':
@@ -198,8 +199,9 @@ if __name__ == '__main__':
'Policy Iteration': _terminate_pi} 'Policy Iteration': _terminate_pi}
for a in [0.9, 0.5, 0.01]: for a in [0.9, 0.5, 0.01]:
plt.figure(figsize=(9, 6)) plt.figure(figsize=(9, 7))
plt.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95) plt.subplots_adjust(top=0.9, bottom=0.05, left=0.1, right=0.95,
wspace=0.1)
plt.suptitle('DISCOUNT = ' + str(a)) plt.suptitle('DISCOUNT = ' + str(a))
i = 1 i = 1
for opt in ['Value Iteration', 'Policy Iteration']: for opt in ['Value Iteration', 'Policy Iteration']:
@@ -209,27 +211,42 @@ if __name__ == '__main__':
j, policy = dynamic_programming(optimizers[opt], costs[cost], j, policy = dynamic_programming(optimizers[opt], costs[cost],
terminators[opt]) terminators[opt])
plt.subplot(2, 2, i) plt.subplot(2, 2, i)
plt.gca().set_title(name)
plot_j_policy_on_maze(j, policy) plot_j_policy_on_maze(j, policy)
if i <= 2:
plt.gca().set_title('Cost: {}'.format(cost),
fontsize='x-large')
if (i - 1) % 2 == 0:
plt.ylabel(opt, fontsize='x-large')
i += 1 i += 1
# Error graphs # Error graphs
for opt in ['Value Iteration', 'Policy Iteration']: for opt in ['Value Iteration', 'Policy Iteration']:
plt.figure(figsize=(9, 6)) plt.figure(figsize=(7, 10))
plt.subplots_adjust(wspace=0.4, hspace=0.4) plt.figtext(0.5, 0.04, 'Number of iterations', ha='center',
fontsize='large')
plt.figtext(0.05, 0.5, 'Logarithm of cost RMSE', va='center',
rotation='vertical', fontsize='large')
plt.subplots_adjust(wspace=0.38, hspace=0.35, left=0.205, right=0.92,
top=0.9)
plt.suptitle(opt) plt.suptitle(opt)
i = 1 i = 1
for a in [0.99, 0.7, 0.1]:
for cost in ['g1', 'g2']: for cost in ['g1', 'g2']:
for a in [0.99, 0.7, 0.5]: # name = 'Cost: {}, discount: {}'.format(cost, a)
name = 'Cost: {}, discount: {}'.format(cost, a)
ALPHA = a ALPHA = a
history = dynamic_programming(optimizers[opt], costs[cost], history = dynamic_programming(optimizers[opt], costs[cost],
terminators[opt], terminators[opt],
return_history=True) return_history=True)
plt.subplot(2, 3, i) plt.subplot(3, 2, i)
plt.gca().set_title(name) # plt.gca().set_title(name)
plot_cost_history(history) plot_cost_history(history)
if i <= 2:
plt.gca().set_title('Cost: {}'.format(cost))
if (i - 1) % 2 == 0:
plt.ylabel('Discount: {}'.format(a), fontsize='large')
i += 1 i += 1
print('I ran in {} seconds'.format(time() - start)) print('I ran in {} seconds'.format(time() - start))
plt.show() plt.show()