initial commit (job done tho)
This commit is contained in:
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# macOS
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# vim
|
||||||
|
.*.sw*
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
*.txt
|
||||||
201
main.py
Normal file
201
main.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
from __future__ import print_function
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import matplotlib as mpl
|
||||||
|
mpl.use('TkAgg')
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
P = 0.1
|
||||||
|
ALPHA = 0.80
|
||||||
|
EPSILON = 0.0001 # Convergence criterium
|
||||||
|
|
||||||
|
# Global state
|
||||||
|
MAZE = None # Map of the environment
|
||||||
|
STATE_MASK = None # Fields of maze belonging to state space
|
||||||
|
S_TO_IJ = None # Mapping of state vector to coordinates
|
||||||
|
|
||||||
|
ACTIONS = {
|
||||||
|
'UP': (-1, 0),
|
||||||
|
'DOWN': (1, 0),
|
||||||
|
'LEFT': (0, -1),
|
||||||
|
'RIGHT': (0, 1),
|
||||||
|
'IDLE': (0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _ij_to_s(ij):
|
||||||
|
return np.argwhere(np.all(ij == S_TO_IJ, axis=1)).flatten()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def h_function(x, u, j, g):
|
||||||
|
"""Return E_pi_w[g(x, pi(x), w) + alpha*J(f(x, pi(x), w))]."""
|
||||||
|
pw = pw_of_x_u(x, u)
|
||||||
|
expectation = sum(
|
||||||
|
pw[w] * (g(x, u, w) + ALPHA*j[_ij_to_s(f(x, u, w))])
|
||||||
|
for w in pw
|
||||||
|
)
|
||||||
|
return expectation
|
||||||
|
|
||||||
|
|
||||||
|
def f(x, u, w):
|
||||||
|
return _move(_move(x, ACTIONS[u]), ACTIONS[w])
|
||||||
|
|
||||||
|
|
||||||
|
def cost_treasure(x, u, w):
|
||||||
|
xt = f(x, u, w)
|
||||||
|
options = {
|
||||||
|
'T': 50,
|
||||||
|
'G': -1,
|
||||||
|
}
|
||||||
|
return options.get(MAZE[xt], 0)
|
||||||
|
|
||||||
|
|
||||||
|
def cost_energy(x, u, w):
|
||||||
|
xt = f(x, u, w)
|
||||||
|
options = {
|
||||||
|
'T': 50,
|
||||||
|
'G': 0
|
||||||
|
}
|
||||||
|
return options.get(MAZE[xt], 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _move(start, move):
|
||||||
|
return start[0] + move[0], start[1] + move[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_target(target):
|
||||||
|
return (
|
||||||
|
0 <= target[0] < MAZE.shape[0] and
|
||||||
|
0 <= target[1] < MAZE.shape[1] and
|
||||||
|
MAZE[target] != '1'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cartesian(arr):
|
||||||
|
"""Return cartesian product of sets."""
|
||||||
|
return itertools.product(*arr)
|
||||||
|
|
||||||
|
|
||||||
|
def u_of_x(x):
|
||||||
|
"""Return a list of allowed actions for the given state x."""
|
||||||
|
return [u for u in ACTIONS if _valid_target(_move(x, ACTIONS[u]))]
|
||||||
|
|
||||||
|
|
||||||
|
def pw_of_x_u(x, u):
|
||||||
|
"""Calculate probabilities of disturbances given state and action.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : tuple of ints
|
||||||
|
The state coordinate
|
||||||
|
(it is up to user to ensure this is a valid state).
|
||||||
|
u : str
|
||||||
|
The name of the action (again, up to the user to ensure validity).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
A mapping of valid disturbances to their probabilities.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if u in ('LEFT', 'RIGHT'):
|
||||||
|
possible_w = ('UP', 'IDLE', 'DOWN')
|
||||||
|
elif u in ('UP', 'DOWN'):
|
||||||
|
possible_w = ('LEFT', 'IDLE', 'RIGHT')
|
||||||
|
else: # I assume that the IDLE action is deterministic
|
||||||
|
possible_w = ('IDLE',)
|
||||||
|
|
||||||
|
allowed_w = [
|
||||||
|
w for w in possible_w if
|
||||||
|
_valid_target(f(x, u, w))
|
||||||
|
]
|
||||||
|
probs = {w: P for w in allowed_w if w != 'IDLE'}
|
||||||
|
probs['IDLE'] = 1 - sum(probs.values())
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
def plot_j_policy_on_maze(j, policy):
|
||||||
|
heatmap = np.ones(MAZE.shape) * np.nan # Ugly
|
||||||
|
heatmap[STATE_MASK] = j # Even uglier
|
||||||
|
cmap = mpl.cm.get_cmap('coolwarm')
|
||||||
|
cmap.set_bad(color='black')
|
||||||
|
plt.imshow(heatmap, cmap=cmap)
|
||||||
|
plt.colorbar()
|
||||||
|
plt.quiver(S_TO_IJ[:,1], S_TO_IJ[:,0],
|
||||||
|
[ACTIONS[u][1] for u in policy],
|
||||||
|
[-ACTIONS[u][0] for u in policy])
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def _policy_improvement(j, g):
|
||||||
|
policy = []
|
||||||
|
for x in S_TO_IJ:
|
||||||
|
policy.append(min(
|
||||||
|
u_of_x(x), key=lambda u: h_function(x, u, j, g)
|
||||||
|
))
|
||||||
|
return policy
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_policy(policy, g):
|
||||||
|
G = []
|
||||||
|
M = np.zeros((len(S_TO_IJ), len(S_TO_IJ)))
|
||||||
|
for x, u in zip(S_TO_IJ, policy):
|
||||||
|
pw = pw_of_x_u(x, u)
|
||||||
|
G.append(sum(pw[w] * g(x, u, w) for w in pw))
|
||||||
|
targets = [(_ij_to_s(f(x, u, w)), pw[w]) for w in pw]
|
||||||
|
iox = _ij_to_s(x)
|
||||||
|
for t, pww in targets:
|
||||||
|
M[iox, t] = pww
|
||||||
|
G = np.array(G)
|
||||||
|
return np.linalg.solve(np.eye(len(S_TO_IJ)) - ALPHA*M, G)
|
||||||
|
|
||||||
|
|
||||||
|
def value_iteration(g):
|
||||||
|
j = np.random.randn(len(S_TO_IJ))
|
||||||
|
while True:
|
||||||
|
policy = _policy_improvement(j, g)
|
||||||
|
j_new = []
|
||||||
|
for x, u in zip(S_TO_IJ, policy):
|
||||||
|
j_new.append(h_function(x, u, j, g))
|
||||||
|
j_old = j
|
||||||
|
j = np.array(j_new)
|
||||||
|
if max(abs(j - j_old)) < EPSILON:
|
||||||
|
break
|
||||||
|
return j, policy
|
||||||
|
|
||||||
|
|
||||||
|
def policy_iteration(g):
|
||||||
|
j = None
|
||||||
|
policy = [np.random.choice(u_of_x(x)) for x in S_TO_IJ]
|
||||||
|
while True:
|
||||||
|
j_old = j
|
||||||
|
j = _evaluate_policy(policy, g)
|
||||||
|
if j_old is not None and max(abs(j - j_old)) < EPSILON:
|
||||||
|
break
|
||||||
|
policy = _policy_improvement(j, g)
|
||||||
|
return j, policy
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ap = ArgumentParser()
|
||||||
|
ap.add_argument('maze_file', help='Path to maze file')
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
MAZE = np.genfromtxt(
|
||||||
|
args.maze_file,
|
||||||
|
dtype=str,
|
||||||
|
)
|
||||||
|
STATE_MASK = (MAZE != '1')
|
||||||
|
S_TO_IJ = np.indices(MAZE.shape).transpose(1, 2, 0)[STATE_MASK]
|
||||||
|
|
||||||
|
j, policy = value_iteration(cost_treasure)
|
||||||
|
print(j)
|
||||||
|
plot_j_policy_on_maze(j, policy)
|
||||||
Reference in New Issue
Block a user