Implemented graphs, kinda ugly for now
This commit is contained in:
83
main.py
83
main.py
@@ -2,8 +2,8 @@ from __future__ import print_function
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import itertools
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
from time import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -14,8 +14,8 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
|
|
||||||
P = 0.1
|
P = 0.1
|
||||||
ALPHA = 0.80
|
ALPHA = 0.90
|
||||||
EPSILON = 0.0001 # Convergence criterium
|
EPSILON = 1e-8 # Convergence criterium
|
||||||
|
|
||||||
# Global state
|
# Global state
|
||||||
MAZE = None # Map of the environment
|
MAZE = None # Map of the environment
|
||||||
@@ -79,11 +79,6 @@ def _valid_target(target):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _cartesian(arr):
|
|
||||||
"""Return cartesian product of sets."""
|
|
||||||
return itertools.product(*arr)
|
|
||||||
|
|
||||||
|
|
||||||
def u_of_x(x):
|
def u_of_x(x):
|
||||||
"""Return a list of allowed actions for the given state x."""
|
"""Return a list of allowed actions for the given state x."""
|
||||||
return [u for u in ACTIONS if _valid_target(_move(x, ACTIONS[u]))]
|
return [u for u in ACTIONS if _valid_target(_move(x, ACTIONS[u]))]
|
||||||
@@ -132,7 +127,15 @@ def plot_j_policy_on_maze(j, policy):
|
|||||||
plt.quiver(S_TO_IJ[:,1], S_TO_IJ[:,0],
|
plt.quiver(S_TO_IJ[:,1], S_TO_IJ[:,0],
|
||||||
[ACTIONS[u][1] for u in policy],
|
[ACTIONS[u][1] for u in policy],
|
||||||
[-ACTIONS[u][0] for u in policy])
|
[-ACTIONS[u][0] for u in policy])
|
||||||
plt.show()
|
plt.gca().get_xaxis().set_visible(False)
|
||||||
|
plt.gca().get_yaxis().set_visible(False)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_cost_history(hist):
|
||||||
|
error = [((h - hist[-1])**2).sum()**0.5 for h in hist[:-1]]
|
||||||
|
plt.xlabel('Number of iterations')
|
||||||
|
plt.ylabel('Cost function error')
|
||||||
|
plt.plot(error)
|
||||||
|
|
||||||
|
|
||||||
def _policy_improvement(j, g):
|
def _policy_improvement(j, g):
|
||||||
@@ -158,8 +161,9 @@ def _evaluate_policy(policy, g):
|
|||||||
return np.linalg.solve(np.eye(len(S_TO_IJ)) - ALPHA*M, G)
|
return np.linalg.solve(np.eye(len(S_TO_IJ)) - ALPHA*M, G)
|
||||||
|
|
||||||
|
|
||||||
def value_iteration(g):
|
def value_iteration(g, return_history=False):
|
||||||
j = np.random.randn(len(S_TO_IJ))
|
j = np.random.randn(len(S_TO_IJ))
|
||||||
|
history = [j]
|
||||||
while True:
|
while True:
|
||||||
policy = _policy_improvement(j, g)
|
policy = _policy_improvement(j, g)
|
||||||
j_new = []
|
j_new = []
|
||||||
@@ -167,28 +171,41 @@ def value_iteration(g):
|
|||||||
j_new.append(h_function(x, u, j, g))
|
j_new.append(h_function(x, u, j, g))
|
||||||
j_old = j
|
j_old = j
|
||||||
j = np.array(j_new)
|
j = np.array(j_new)
|
||||||
|
if return_history:
|
||||||
|
history.append(j)
|
||||||
if max(abs(j - j_old)) < EPSILON:
|
if max(abs(j - j_old)) < EPSILON:
|
||||||
break
|
break
|
||||||
return j, policy
|
if not return_history:
|
||||||
|
return j, policy
|
||||||
|
else:
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
def policy_iteration(g):
|
def policy_iteration(g, return_history=False):
|
||||||
j = None
|
j = None
|
||||||
policy = [np.random.choice(u_of_x(x)) for x in S_TO_IJ]
|
policy = [np.random.choice(u_of_x(x)) for x in S_TO_IJ]
|
||||||
|
history = []
|
||||||
while True:
|
while True:
|
||||||
j_old = j
|
j_old = j
|
||||||
j = _evaluate_policy(policy, g)
|
j = _evaluate_policy(policy, g)
|
||||||
|
history.append(j)
|
||||||
if j_old is not None and max(abs(j - j_old)) < EPSILON:
|
if j_old is not None and max(abs(j - j_old)) < EPSILON:
|
||||||
break
|
break
|
||||||
policy = _policy_improvement(j, g)
|
policy = _policy_improvement(j, g)
|
||||||
return j, policy
|
if not return_history:
|
||||||
|
return j, policy
|
||||||
|
else:
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# Argument Parsing
|
||||||
ap = ArgumentParser()
|
ap = ArgumentParser()
|
||||||
ap.add_argument('maze_file', help='Path to maze file')
|
ap.add_argument('maze_file', help='Path to maze file')
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
start = time()
|
||||||
|
# Initialization
|
||||||
MAZE = np.genfromtxt(
|
MAZE = np.genfromtxt(
|
||||||
args.maze_file,
|
args.maze_file,
|
||||||
dtype=str,
|
dtype=str,
|
||||||
@@ -196,6 +213,40 @@ if __name__ == '__main__':
|
|||||||
STATE_MASK = (MAZE != '1')
|
STATE_MASK = (MAZE != '1')
|
||||||
S_TO_IJ = np.indices(MAZE.shape).transpose(1, 2, 0)[STATE_MASK]
|
S_TO_IJ = np.indices(MAZE.shape).transpose(1, 2, 0)[STATE_MASK]
|
||||||
|
|
||||||
j, policy = value_iteration(cost_treasure)
|
# J / policy for both algorithms for both cost functions for 3 alphas
|
||||||
print(j)
|
costs = {'g1': cost_treasure, 'g2': cost_energy}
|
||||||
plot_j_policy_on_maze(j, policy)
|
optimizers = {'Value Iteration': value_iteration,
|
||||||
|
'Policy Iteration': policy_iteration}
|
||||||
|
|
||||||
|
for a in [0.9, 0.5, 0.01]:
|
||||||
|
plt.figure()
|
||||||
|
plt.suptitle('DISCOUNT = ' + str(a))
|
||||||
|
i = 1
|
||||||
|
for opt in ['Value Iteration', 'Policy Iteration']:
|
||||||
|
for g in ['g1', 'g2']:
|
||||||
|
name = ' / '.join([opt, g])
|
||||||
|
ALPHA = a
|
||||||
|
j, policy = optimizers[opt](costs[g])
|
||||||
|
plt.subplot(2, 2, i)
|
||||||
|
plt.gca().set_title(name)
|
||||||
|
plot_j_policy_on_maze(j, policy)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# plt.show()
|
||||||
|
# Error graphs
|
||||||
|
for opt in ['Value Iteration', 'Policy Iteration']:
|
||||||
|
plt.figure()
|
||||||
|
plt.suptitle(opt)
|
||||||
|
i = 1
|
||||||
|
for g in ['g1', 'g2']:
|
||||||
|
for a in [0.9, 0.8, 0.7]:
|
||||||
|
name = 'Cost: {}, discount: {}'.format(g, a)
|
||||||
|
ALPHA = a
|
||||||
|
history = optimizers[opt](costs[g], return_history=True)
|
||||||
|
plt.subplot(2, 3, i)
|
||||||
|
plt.gca().set_title(name)
|
||||||
|
plot_cost_history(history)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
print('I ran in {} seconds'.format(time() - start))
|
||||||
|
plt.show()
|
||||||
|
|||||||
Reference in New Issue
Block a user