diff --git a/main.py b/main.py index fdac15f..b96d226 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,6 @@ import numpy as np import matplotlib as mpl mpl.use('TkAgg') # fixes my macOS bug - import matplotlib.pyplot as plt @@ -99,28 +98,24 @@ def init_global(maze_filename): def plot_j_policy_on_maze(j, policy): - j_norm = (j - j.min()) / (j.max() - j.min()) + 1e-50 - j_log = np.log10(j_norm) - print(j) - print(j_norm) - print(j_log) - print('-' * 50) heatmap = np.full(MAZE.shape, np.nan) heatmap[S_TO_IJ[:, 0], S_TO_IJ[:, 1]] = j cmap = mpl.cm.get_cmap('coolwarm') cmap.set_bad(color='black') plt.imshow(heatmap, cmap=cmap) - plt.colorbar() + # plt.colorbar() # quiver has some weird behavior, the arrow y component must be flipped plt.quiver(S_TO_IJ[:, 1], S_TO_IJ[:, 0], A2[policy, 1], -A2[policy, 0]) plt.gca().get_xaxis().set_visible(False) - plt.gca().get_yaxis().set_visible(False) + plt.tick_params(axis='y', which='both', left=False, labelleft=False) def plot_cost_history(hist): - error = np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1)) - plt.xlabel('Number of iterations') - plt.ylabel('Cost function RMSE') + error = np.log10( + np.sqrt(np.square(hist[:-1] - hist[-1]).mean(axis=1)) + ) + plt.xticks(np.arange(0, len(error), len(error) // 5)) + plt.yticks(np.linspace(error.min(), error.max(), 5)) plt.plot(error) @@ -143,13 +138,13 @@ def _evaluate_policy(policy, g): return np.linalg.solve(np.eye(SN) - ALPHA*M, G) -def value_iteration(j, g): +def value_iteration(g, j, **_): return _policy_improvement(j, g) -def policy_iteration(j, g): - policy, _ = _policy_improvement(j, g) +def policy_iteration(g, policy, **_): j = _evaluate_policy(policy, g) + policy, _ = _policy_improvement(j, g) return policy, j @@ -164,12 +159,12 @@ def _terminate_vi(j, j_old, policy, policy_old): def dynamic_programming(optimizer_step, g, terminator, return_history=False): j = np.zeros(SN, dtype=np.float64) - policy = None + policy = np.full(SN, -1, dtype=np.int32) # idle policy history = [] while True: j_old = j policy_old = policy - policy, j = optimizer_step(j, g) + policy, j = optimizer_step(g, j=j, policy=policy) if return_history: history.append(j) if terminator(j, j_old, policy, policy_old): @@ -177,7 +172,13 @@ def dynamic_programming(optimizer_step, g, terminator, return_history=False): if not return_history: return j, policy else: - return np.array(history) + history = np.array(history) + + # cover some edgy cases + if (history[-1] == history[-2]).all(): + history = history[:-1] + + return history if __name__ == '__main__': @@ -198,8 +199,9 @@ if __name__ == '__main__': 'Policy Iteration': _terminate_pi} for a in [0.9, 0.5, 0.01]: - plt.figure(figsize=(9, 6)) - plt.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95) + plt.figure(figsize=(9, 7)) + plt.subplots_adjust(top=0.9, bottom=0.05, left=0.1, right=0.95, + wspace=0.1) plt.suptitle('DISCOUNT = ' + str(a)) i = 1 for opt in ['Value Iteration', 'Policy Iteration']: @@ -209,27 +211,42 @@ if __name__ == '__main__': j, policy = dynamic_programming(optimizers[opt], costs[cost], terminators[opt]) plt.subplot(2, 2, i) - plt.gca().set_title(name) plot_j_policy_on_maze(j, policy) + if i <= 2: + plt.gca().set_title('Cost: {}'.format(cost), + fontsize='x-large') + if (i - 1) % 2 == 0: + plt.ylabel(opt, fontsize='x-large') i += 1 # Error graphs for opt in ['Value Iteration', 'Policy Iteration']: - plt.figure(figsize=(9, 6)) - plt.subplots_adjust(wspace=0.4, hspace=0.4) + plt.figure(figsize=(7, 10)) + plt.figtext(0.5, 0.04, 'Number of iterations', ha='center', + fontsize='large') + plt.figtext(0.05, 0.5, 'Logarithm of cost RMSE', va='center', + rotation='vertical', fontsize='large') + plt.subplots_adjust(wspace=0.38, hspace=0.35, left=0.205, right=0.92, + top=0.9) plt.suptitle(opt) i = 1 - for cost in ['g1', 'g2']: - for a in [0.99, 0.7, 0.5]: - name = 'Cost: {}, discount: {}'.format(cost, a) + for a in [0.99, 0.7, 0.1]: + for cost in ['g1', 'g2']: + # name = 'Cost: {}, discount: {}'.format(cost, a) ALPHA = a history = dynamic_programming(optimizers[opt], costs[cost], terminators[opt], return_history=True) - plt.subplot(2, 3, i) - plt.gca().set_title(name) + plt.subplot(3, 2, i) + # plt.gca().set_title(name) plot_cost_history(history) + if i <= 2: + plt.gca().set_title('Cost: {}'.format(cost)) + if (i - 1) % 2 == 0: + plt.ylabel('Discount: {}'.format(a), fontsize='large') + i += 1 + print('I ran in {} seconds'.format(time() - start)) plt.show()