I may as well submit this version
This commit is contained in:
12
main.py
12
main.py
@@ -162,7 +162,6 @@ def dynamic_programming(optimizer_step, g, terminator, return_history=False):
|
||||
|
||||
|
||||
def plot_j_policy_on_maze(j, policy, normalize=True):
|
||||
|
||||
heatmap = np.full(MAZE.shape, np.nan, dtype=np.float64)
|
||||
if normalize:
|
||||
# Non-linear, but a discrete representation of different costs
|
||||
@@ -221,7 +220,7 @@ if __name__ == '__main__':
|
||||
plt.figure(figsize=(9, 7))
|
||||
plt.subplots_adjust(top=0.9, bottom=0.05, left=0.1, right=0.95,
|
||||
wspace=0.1)
|
||||
plt.suptitle('DISCOUNT: {}'.format(a) +
|
||||
plt.suptitle('Discount: {}'.format(a) +
|
||||
('\nNormalized view' if normalize else ''))
|
||||
i = 1
|
||||
for opt in ['Value Iteration', 'Policy Iteration']:
|
||||
@@ -242,24 +241,22 @@ if __name__ == '__main__':
|
||||
|
||||
# Error graphs
|
||||
for opt in ['Value Iteration', 'Policy Iteration']:
|
||||
plt.figure(figsize=(7, 10))
|
||||
plt.figure(figsize=(6, 10))
|
||||
plt.figtext(0.5, 0.04, 'Number of iterations', ha='center',
|
||||
fontsize='large')
|
||||
plt.figtext(0.05, 0.5, 'Logarithm of cost RMSE', va='center',
|
||||
plt.figtext(0.01, 0.5, 'Logarithm of cost RMSE', va='center',
|
||||
rotation='vertical', fontsize='large')
|
||||
plt.subplots_adjust(wspace=0.38, hspace=0.35, left=0.205, right=0.92,
|
||||
plt.subplots_adjust(wspace=0.38, hspace=0.35, left=0.205, right=0.98,
|
||||
top=0.9)
|
||||
plt.suptitle(opt)
|
||||
i = 1
|
||||
for a in [0.99, 0.7, 0.1]:
|
||||
for cost in ['g1', 'g2']:
|
||||
# name = 'Cost: {}, discount: {}'.format(cost, a)
|
||||
ALPHA = a
|
||||
history = dynamic_programming(optimizers[opt], costs[cost],
|
||||
terminators[opt],
|
||||
return_history=True)
|
||||
plt.subplot(3, 2, i)
|
||||
# plt.gca().set_title(name)
|
||||
plot_cost_history(history)
|
||||
if i <= 2:
|
||||
plt.gca().set_title('Cost: {}'.format(cost))
|
||||
@@ -268,6 +265,5 @@ if __name__ == '__main__':
|
||||
|
||||
i += 1
|
||||
|
||||
|
||||
print('I ran in {} seconds'.format(time() - start))
|
||||
plt.show()
|
||||
|
||||
Reference in New Issue
Block a user