Did some sick shit with plots
This commit is contained in:
73
main.py
73
main.py
@@ -9,7 +9,6 @@ import numpy as np
|
||||
|
||||
import matplotlib as mpl
|
||||
mpl.use('TkAgg') # fixes my macOS bug
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
@@ -99,28 +98,24 @@ def init_global(maze_filename):
|
||||
|
||||
|
||||
def plot_j_policy_on_maze(j, policy):
|
||||
j_norm = (j - j.min()) / (j.max() - j.min()) + 1e-50
|
||||
j_log = np.log10(j_norm)
|
||||
print(j)
|
||||
print(j_norm)
|
||||
print(j_log)
|
||||
print('-' * 50)
|
||||
heatmap = np.full(MAZE.shape, np.nan)
|
||||
heatmap[S_TO_IJ[:, 0], S_TO_IJ[:, 1]] = j
|
||||
cmap = mpl.cm.get_cmap('coolwarm')
|
||||
cmap.set_bad(color='black')
|
||||
plt.imshow(heatmap, cmap=cmap)
|
||||
plt.colorbar()
|
||||
# plt.colorbar()
|
||||
# quiver has some weird behavior, the arrow y component must be flipped
|
||||
plt.quiver(S_TO_IJ[:, 1], S_TO_IJ[:, 0], A2[policy, 1], -A2[policy, 0])
|
||||
plt.gca().get_xaxis().set_visible(False)
|
||||
plt.gca().get_yaxis().set_visible(False)
|
||||
plt.tick_params(axis='y', which='both', left=False, labelleft=False)
|
||||
|
||||
|
||||
def plot_cost_history(hist):
|
||||
error = np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1))
|
||||
plt.xlabel('Number of iterations')
|
||||
plt.ylabel('Cost function RMSE')
|
||||
error = np.log10(
|
||||
np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1))
|
||||
)
|
||||
plt.xticks(np.arange(0, len(error), len(error) // 5))
|
||||
plt.yticks(np.linspace(error.min(), error.max(), 5))
|
||||
plt.plot(error)
|
||||
|
||||
|
||||
@@ -143,13 +138,13 @@ def _evaluate_policy(policy, g):
|
||||
return np.linalg.solve(np.eye(SN) - ALPHA*M, G)
|
||||
|
||||
|
||||
def value_iteration(j, g):
|
||||
def value_iteration(g, j, **_):
|
||||
return _policy_improvement(j, g)
|
||||
|
||||
|
||||
def policy_iteration(j, g):
|
||||
policy, _ = _policy_improvement(j, g)
|
||||
def policy_iteration(g, policy, **_):
|
||||
j = _evaluate_policy(policy, g)
|
||||
policy, _ = _policy_improvement(j, g)
|
||||
return policy, j
|
||||
|
||||
|
||||
@@ -164,12 +159,12 @@ def _terminate_vi(j, j_old, policy, policy_old):
|
||||
|
||||
def dynamic_programming(optimizer_step, g, terminator, return_history=False):
|
||||
j = np.zeros(SN, dtype=np.float64)
|
||||
policy = None
|
||||
policy = np.full(SN, -1, dtype=np.int32) # idle policy
|
||||
history = []
|
||||
while True:
|
||||
j_old = j
|
||||
policy_old = policy
|
||||
policy, j = optimizer_step(j, g)
|
||||
policy, j = optimizer_step(g, j=j, policy=policy)
|
||||
if return_history:
|
||||
history.append(j)
|
||||
if terminator(j, j_old, policy, policy_old):
|
||||
@@ -177,7 +172,13 @@ def dynamic_programming(optimizer_step, g, terminator, return_history=False):
|
||||
if not return_history:
|
||||
return j, policy
|
||||
else:
|
||||
return np.array(history)
|
||||
history = np.array(history)
|
||||
|
||||
# cover some edgy cases
|
||||
if (history[-1] == history[-2]).all():
|
||||
history = history[:-1]
|
||||
|
||||
return history
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -198,8 +199,9 @@ if __name__ == '__main__':
|
||||
'Policy Iteration': _terminate_pi}
|
||||
|
||||
for a in [0.9, 0.5, 0.01]:
|
||||
plt.figure(figsize=(9, 6))
|
||||
plt.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
|
||||
plt.figure(figsize=(9, 7))
|
||||
plt.subplots_adjust(top=0.9, bottom=0.05, left=0.1, right=0.95,
|
||||
wspace=0.1)
|
||||
plt.suptitle('DISCOUNT = ' + str(a))
|
||||
i = 1
|
||||
for opt in ['Value Iteration', 'Policy Iteration']:
|
||||
@@ -209,27 +211,42 @@ if __name__ == '__main__':
|
||||
j, policy = dynamic_programming(optimizers[opt], costs[cost],
|
||||
terminators[opt])
|
||||
plt.subplot(2, 2, i)
|
||||
plt.gca().set_title(name)
|
||||
plot_j_policy_on_maze(j, policy)
|
||||
if i <= 2:
|
||||
plt.gca().set_title('Cost: {}'.format(cost),
|
||||
fontsize='x-large')
|
||||
if (i - 1) % 2 == 0:
|
||||
plt.ylabel(opt, fontsize='x-large')
|
||||
i += 1
|
||||
|
||||
# Error graphs
|
||||
for opt in ['Value Iteration', 'Policy Iteration']:
|
||||
plt.figure(figsize=(9, 6))
|
||||
plt.subplots_adjust(wspace=0.4, hspace=0.4)
|
||||
plt.figure(figsize=(7, 10))
|
||||
plt.figtext(0.5, 0.04, 'Number of iterations', ha='center',
|
||||
fontsize='large')
|
||||
plt.figtext(0.05, 0.5, 'Logarithm of cost RMSE', va='center',
|
||||
rotation='vertical', fontsize='large')
|
||||
plt.subplots_adjust(wspace=0.38, hspace=0.35, left=0.205, right=0.92,
|
||||
top=0.9)
|
||||
plt.suptitle(opt)
|
||||
i = 1
|
||||
for cost in ['g1', 'g2']:
|
||||
for a in [0.99, 0.7, 0.5]:
|
||||
name = 'Cost: {}, discount: {}'.format(cost, a)
|
||||
for a in [0.99, 0.7, 0.1]:
|
||||
for cost in ['g1', 'g2']:
|
||||
# name = 'Cost: {}, discount: {}'.format(cost, a)
|
||||
ALPHA = a
|
||||
history = dynamic_programming(optimizers[opt], costs[cost],
|
||||
terminators[opt],
|
||||
return_history=True)
|
||||
plt.subplot(2, 3, i)
|
||||
plt.gca().set_title(name)
|
||||
plt.subplot(3, 2, i)
|
||||
# plt.gca().set_title(name)
|
||||
plot_cost_history(history)
|
||||
if i <= 2:
|
||||
plt.gca().set_title('Cost: {}'.format(cost))
|
||||
if (i - 1) % 2 == 0:
|
||||
plt.ylabel('Discount: {}'.format(a), fontsize='large')
|
||||
|
||||
i += 1
|
||||
|
||||
|
||||
print('I ran in {} seconds'.format(time() - start))
|
||||
plt.show()
|
||||
|
||||
Reference in New Issue
Block a user