Did some sick shit with plots

This commit is contained in:
2019-01-26 18:21:08 +01:00
parent abb3649194
commit 0be126e57b

73
main.py
View File

@@ -9,7 +9,6 @@ import numpy as np
import matplotlib as mpl
mpl.use('TkAgg') # fixes my macOS bug
import matplotlib.pyplot as plt
@@ -99,28 +98,24 @@ def init_global(maze_filename):
def plot_j_policy_on_maze(j, policy):
j_norm = (j - j.min()) / (j.max() - j.min()) + 1e-50
j_log = np.log10(j_norm)
print(j)
print(j_norm)
print(j_log)
print('-' * 50)
heatmap = np.full(MAZE.shape, np.nan)
heatmap[S_TO_IJ[:, 0], S_TO_IJ[:, 1]] = j
cmap = mpl.cm.get_cmap('coolwarm')
cmap.set_bad(color='black')
plt.imshow(heatmap, cmap=cmap)
plt.colorbar()
# plt.colorbar()
# quiver has some weird behavior, the arrow y component must be flipped
plt.quiver(S_TO_IJ[:, 1], S_TO_IJ[:, 0], A2[policy, 1], -A2[policy, 0])
plt.gca().get_xaxis().set_visible(False)
plt.gca().get_yaxis().set_visible(False)
plt.tick_params(axis='y', which='both', left=False, labelleft=False)
def plot_cost_history(hist):
error = np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1))
plt.xlabel('Number of iterations')
plt.ylabel('Cost function RMSE')
error = np.log10(
np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1))
)
plt.xticks(np.arange(0, len(error), len(error) // 5))
plt.yticks(np.linspace(error.min(), error.max(), 5))
plt.plot(error)
@@ -143,13 +138,13 @@ def _evaluate_policy(policy, g):
return np.linalg.solve(np.eye(SN) - ALPHA*M, G)
def value_iteration(j, g):
def value_iteration(g, j, **_):
return _policy_improvement(j, g)
def policy_iteration(j, g):
policy, _ = _policy_improvement(j, g)
def policy_iteration(g, policy, **_):
j = _evaluate_policy(policy, g)
policy, _ = _policy_improvement(j, g)
return policy, j
@@ -164,12 +159,12 @@ def _terminate_vi(j, j_old, policy, policy_old):
def dynamic_programming(optimizer_step, g, terminator, return_history=False):
j = np.zeros(SN, dtype=np.float64)
policy = None
policy = np.full(SN, -1, dtype=np.int32) # idle policy
history = []
while True:
j_old = j
policy_old = policy
policy, j = optimizer_step(j, g)
policy, j = optimizer_step(g, j=j, policy=policy)
if return_history:
history.append(j)
if terminator(j, j_old, policy, policy_old):
@@ -177,7 +172,13 @@ def dynamic_programming(optimizer_step, g, terminator, return_history=False):
if not return_history:
return j, policy
else:
return np.array(history)
history = np.array(history)
# cover some edgy cases
if (history[-1] == history[-2]).all():
history = history[:-1]
return history
if __name__ == '__main__':
@@ -198,8 +199,9 @@ if __name__ == '__main__':
'Policy Iteration': _terminate_pi}
for a in [0.9, 0.5, 0.01]:
plt.figure(figsize=(9, 6))
plt.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
plt.figure(figsize=(9, 7))
plt.subplots_adjust(top=0.9, bottom=0.05, left=0.1, right=0.95,
wspace=0.1)
plt.suptitle('DISCOUNT = ' + str(a))
i = 1
for opt in ['Value Iteration', 'Policy Iteration']:
@@ -209,27 +211,42 @@ if __name__ == '__main__':
j, policy = dynamic_programming(optimizers[opt], costs[cost],
terminators[opt])
plt.subplot(2, 2, i)
plt.gca().set_title(name)
plot_j_policy_on_maze(j, policy)
if i <= 2:
plt.gca().set_title('Cost: {}'.format(cost),
fontsize='x-large')
if (i - 1) % 2 == 0:
plt.ylabel(opt, fontsize='x-large')
i += 1
# Error graphs
for opt in ['Value Iteration', 'Policy Iteration']:
plt.figure(figsize=(9, 6))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
plt.figure(figsize=(7, 10))
plt.figtext(0.5, 0.04, 'Number of iterations', ha='center',
fontsize='large')
plt.figtext(0.05, 0.5, 'Logarithm of cost RMSE', va='center',
rotation='vertical', fontsize='large')
plt.subplots_adjust(wspace=0.38, hspace=0.35, left=0.205, right=0.92,
top=0.9)
plt.suptitle(opt)
i = 1
for cost in ['g1', 'g2']:
for a in [0.99, 0.7, 0.5]:
name = 'Cost: {}, discount: {}'.format(cost, a)
for a in [0.99, 0.7, 0.1]:
for cost in ['g1', 'g2']:
# name = 'Cost: {}, discount: {}'.format(cost, a)
ALPHA = a
history = dynamic_programming(optimizers[opt], costs[cost],
terminators[opt],
return_history=True)
plt.subplot(2, 3, i)
plt.gca().set_title(name)
plt.subplot(3, 2, i)
# plt.gca().set_title(name)
plot_cost_history(history)
if i <= 2:
plt.gca().set_title('Cost: {}'.format(cost))
if (i - 1) % 2 == 0:
plt.ylabel('Discount: {}'.format(a), fontsize='large')
i += 1
print('I ran in {} seconds'.format(time() - start))
plt.show()