From 9127b93d10b66cfda38a1de280fbb9f05683989d Mon Sep 17 00:00:00 2001 From: Pavel Lutskov Date: Sun, 16 Dec 2018 17:33:44 +0100 Subject: [PATCH] Implemented graphs, kinda ugly for now --- main.py | 83 ++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 67 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index 4a4559a..084d934 100644 --- a/main.py +++ b/main.py @@ -2,8 +2,8 @@ from __future__ import print_function from __future__ import division from __future__ import unicode_literals -import itertools from argparse import ArgumentParser +from time import time import numpy as np @@ -14,8 +14,8 @@ import matplotlib.pyplot as plt P = 0.1 -ALPHA = 0.80 -EPSILON = 0.0001 # Convergence criterium +ALPHA = 0.90 +EPSILON = 1e-8 # Convergence criterium # Global state MAZE = None # Map of the environment @@ -79,11 +79,6 @@ def _valid_target(target): ) -def _cartesian(arr): - """Return cartesian product of sets.""" - return itertools.product(*arr) - - def u_of_x(x): """Return a list of allowed actions for the given state x.""" return [u for u in ACTIONS if _valid_target(_move(x, ACTIONS[u]))] @@ -132,7 +127,15 @@ def plot_j_policy_on_maze(j, policy): plt.quiver(S_TO_IJ[:,1], S_TO_IJ[:,0], [ACTIONS[u][1] for u in policy], [-ACTIONS[u][0] for u in policy]) - plt.show() + plt.gca().get_xaxis().set_visible(False) + plt.gca().get_yaxis().set_visible(False) + + +def plot_cost_history(hist): + error = [((h - hist[-1])**2).sum()**0.5 for h in hist[:-1]] + plt.xlabel('Number of iterations') + plt.ylabel('Cost function error') + plt.plot(error) def _policy_improvement(j, g): @@ -158,8 +161,9 @@ def _evaluate_policy(policy, g): return np.linalg.solve(np.eye(len(S_TO_IJ)) - ALPHA*M, G) -def value_iteration(g): +def value_iteration(g, return_history=False): j = np.random.randn(len(S_TO_IJ)) + history = [j] while True: policy = _policy_improvement(j, g) j_new = [] @@ -167,28 +171,41 @@ def value_iteration(g): j_new.append(h_function(x, u, j, g)) j_old = j j = np.array(j_new) + if return_history: + history.append(j) if max(abs(j - j_old)) < EPSILON: break - return j, policy + if not return_history: + return j, policy + else: + return history -def policy_iteration(g): +def policy_iteration(g, return_history=False): j = None policy = [np.random.choice(u_of_x(x)) for x in S_TO_IJ] + history = [] while True: j_old = j j = _evaluate_policy(policy, g) + history.append(j) if j_old is not None and max(abs(j - j_old)) < EPSILON: break policy = _policy_improvement(j, g) - return j, policy + if not return_history: + return j, policy + else: + return history if __name__ == '__main__': + # Argument Parsing ap = ArgumentParser() ap.add_argument('maze_file', help='Path to maze file') args = ap.parse_args() + start = time() + # Initialization MAZE = np.genfromtxt( args.maze_file, dtype=str, @@ -196,6 +213,40 @@ if __name__ == '__main__': STATE_MASK = (MAZE != '1') S_TO_IJ = np.indices(MAZE.shape).transpose(1, 2, 0)[STATE_MASK] - j, policy = value_iteration(cost_treasure) - print(j) - plot_j_policy_on_maze(j, policy) + # J / policy for both algorithms for both cost functions for 3 alphas + costs = {'g1': cost_treasure, 'g2': cost_energy} + optimizers = {'Value Iteration': value_iteration, + 'Policy Iteration': policy_iteration} + + for a in [0.9, 0.5, 0.01]: + plt.figure() + plt.suptitle('DISCOUNT = ' + str(a)) + i = 1 + for opt in ['Value Iteration', 'Policy Iteration']: + for g in ['g1', 'g2']: + name = ' / '.join([opt, g]) + ALPHA = a + j, policy = optimizers[opt](costs[g]) + plt.subplot(2, 2, i) + plt.gca().set_title(name) + plot_j_policy_on_maze(j, policy) + i += 1 + + # plt.show() + # Error graphs + for opt in ['Value Iteration', 'Policy Iteration']: + plt.figure() + plt.suptitle(opt) + i = 1 + for g in ['g1', 'g2']: + for a in [0.9, 0.8, 0.7]: + name = 'Cost: {}, discount: {}'.format(g, a) + ALPHA = a + history = optimizers[opt](costs[g], return_history=True) + plt.subplot(2, 3, i) + plt.gca().set_title(name) + plot_cost_history(history) + i += 1 + + print('I ran in {} seconds'.format(time() - start)) + plt.show()