refactored into numpy, 10x faster

2018-12-17 18:47:00 +01:00
parent 9127b93d10
commit 29a2d7feb5
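
The gist of the change: _init_global precomputes the whole environment as index tables (F_X_U_W for successor states, PW_OF_X_U for disturbance probabilities, U_OF_X for admissible actions), and h_matrix then evaluates the Bellman backup for every state-action pair with numpy fancy indexing instead of per-state Python loops. Below is a minimal standalone sketch of that pattern; the toy sizes and random tables are only stand-ins for the actual maze model built by the diff.

    import numpy as np

    SN, UN, WN = 50, 5, 5       # toy number of states / actions / disturbances
    ALPHA = 0.9                 # discount factor

    rng = np.random.default_rng(0)
    F = rng.integers(0, SN, size=(SN, UN, WN))      # stand-in for F_X_U_W: f(x, u, w)
    PW = rng.dirichlet(np.ones(WN), size=(SN, UN))  # stand-in for PW_OF_X_U: p(w | x, u)
    U_OK = rng.random((SN, UN)) > 0.2               # stand-in for U_OF_X: admissible actions
    U_OK[:, -1] = True                              # keep at least one action per state
    g = rng.standard_normal(SN)                     # stage cost, looked up at the successor
    j = np.zeros(SN)                                # current cost-to-go estimate

    # h_matrix pattern: H[x, u] = sum_w p(w | x, u) * (g(f(x, u, w)) + ALPHA * J(f(x, u, w)))
    H = (PW * (g[F] + ALPHA * j[F])).sum(axis=2)
    H[~U_OK] = np.inf           # rule out disallowed actions

    policy = H.argmin(axis=1)   # greedy policy improvement
    j_new = H.min(axis=1)       # one value-iteration sweep

Masking disallowed actions with np.inf is what lets argmin / min over axis 1 do policy improvement and the value-iteration sweep in one shot, which is the (policy, j_new) pair the new _policy_improvement returns.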

main.py (142 changed lines)

@@ -21,6 +21,21 @@ EPSILON = 1e-8 # Convergence criterium
 MAZE = None # Map of the environment
 STATE_MASK = None # Fields of maze belonging to state space
 S_TO_IJ = None # Mapping of state vector to coordinates
+IJ_TO_S = None # Mapping of coordinates to state vector
+U_OF_X = None # The allowed action space matrix representation
+PW_OF_X_U = None # The probability distribution of disturbance
+G1_X = None # The cost function vector representation (depends only on state)
+G2_X = None # The second cost function vector representation
+F_X_U_W = None # The state function
+SN = None # Number of states
+A2 = np.array([
+    [-1, 0],
+    [1, 0],
+    [0, -1],
+    [0, 1],
+    [0, 0]
+])

 ACTIONS = {
     'UP': (-1, 0),
@@ -35,6 +50,7 @@ def _ij_to_s(ij):
     return np.argwhere(np.all(ij == S_TO_IJ, axis=1)).flatten()[0]


+# TODO: for all x and u in one go
 def h_function(x, u, j, g):
     """Return E_pi_w[g(x, pi(x), w) + alpha*J(f(x, pi(x), w))]."""
     pw = pw_of_x_u(x, u)
@@ -45,6 +61,12 @@ def h_function(x, u, j, g):
     return expectation


+def h_matrix(j, g):
+    result = (PW_OF_X_U * (g[F_X_U_W] + ALPHA*j[F_X_U_W])).sum(axis=2)
+    result[~U_OF_X] = np.inf # discard invalid policies
+    return result
+
+
 def f(x, u, w):
     return _move(_move(x, ACTIONS[u]), ACTIONS[w])
@@ -75,10 +97,57 @@ def _valid_target(target):
     return (
         0 <= target[0] < MAZE.shape[0] and
         0 <= target[1] < MAZE.shape[1] and
-        MAZE[target] != '1'
+        MAZE[tuple(target)] != '1'
     )


+def _init_global(maze_file):
+    global MAZE, STATE_MASK, SN, S_TO_IJ, IJ_TO_S
+    global U_OF_X, PW_OF_X_U, F_X_U_W, G1_X, G2_X
+    # Basic maze structure initialization
+    MAZE = np.genfromtxt(
+        maze_file,
+        dtype=str,
+    )
+    STATE_MASK = (MAZE != '1')
+    S_TO_IJ = np.indices(MAZE.shape).transpose(1, 2, 0)[STATE_MASK]
+    SN = len(S_TO_IJ)
+    IJ_TO_S = np.zeros(MAZE.shape, dtype=np.int32)
+    IJ_TO_S[STATE_MASK] = np.arange(SN)
+    # One step cost functions initialization
+    maze_cost = np.zeros(MAZE.shape)
+    maze_cost[MAZE == '1'] = np.nan
+    maze_cost[(MAZE == '0') | (MAZE == 'S')] = 0
+    maze_cost[MAZE == 'T'] = 50
+    maze_cost[MAZE == 'G'] = -1
+    G1_X = maze_cost.copy()[STATE_MASK]
+    maze_cost[maze_cost < 1] += 1 # assert np.nan < whatever == True
+    G2_X = maze_cost.copy()[STATE_MASK]
+    # Actual environment modelling
+    U_OF_X = np.zeros((SN, len(A2)), dtype=np.bool)
+    PW_OF_X_U = np.zeros((SN, len(A2), len(A2)))
+    F_X_U_W = np.zeros(PW_OF_X_U.shape, dtype=np.int32)
+    for ix, x in enumerate(S_TO_IJ):
+        for iu, u in enumerate(A2):
+            if _valid_target(x + u):
+                U_OF_X[ix, iu] = True
+                if iu in (0, 1):
+                    possible_iw = [2, 3]
+                elif iu in (2, 3):
+                    possible_iw = [0, 1]
+                for iw in possible_iw:
+                    if _valid_target(x + u + A2[iw]):
+                        PW_OF_X_U[ix, iu, iw] = P
+                        F_X_U_W[ix, iu, iw] = IJ_TO_S[tuple(x + u + A2[iw])]
+                # IDLE w is always possible
+                PW_OF_X_U[ix, iu, -1] = 1 - PW_OF_X_U[ix, iu].sum()
+                F_X_U_W[ix, iu, -1] = IJ_TO_S[tuple(x + u)]
+
+
 def u_of_x(x):
     """Return a list of allowed actions for the given state x."""
     return [u for u in ACTIONS if _valid_target(_move(x, ACTIONS[u]))]
@@ -125,8 +194,7 @@ def plot_j_policy_on_maze(j, policy):
     plt.imshow(heatmap, cmap=cmap)
     plt.colorbar()
     plt.quiver(S_TO_IJ[:,1], S_TO_IJ[:,0],
-               [ACTIONS[u][1] for u in policy],
-               [-ACTIONS[u][0] for u in policy])
+               A2[policy, 1], -A2[policy, 0])
     plt.gca().get_xaxis().set_visible(False)
     plt.gca().get_yaxis().set_visible(False)
@@ -139,41 +207,43 @@ def plot_cost_history(hist):
 def _policy_improvement(j, g):
-    policy = []
-    for x in S_TO_IJ:
-        policy.append(min(
-            u_of_x(x), key=lambda u: h_function(x, u, j, g)
-        ))
-    return policy
+    h_mat = h_matrix(j, g)
+    return np.argmin(h_mat, axis=1), h_mat.min(axis=1)


 def _evaluate_policy(policy, g):
-    G = []
-    M = np.zeros((len(S_TO_IJ), len(S_TO_IJ)))
-    for x, u in zip(S_TO_IJ, policy):
-        pw = pw_of_x_u(x, u)
-        G.append(sum(pw[w] * g(x, u, w) for w in pw))
-        targets = [(_ij_to_s(f(x, u, w)), pw[w]) for w in pw]
-        iox = _ij_to_s(x)
-        for t, pww in targets:
-            M[iox, t] = pww
-    G = np.array(G)
-    return np.linalg.solve(np.eye(len(S_TO_IJ)) - ALPHA*M, G)
+    pw_pi = PW_OF_X_U[np.arange(SN), policy] # p(w) given policy for all x
+    targs = F_X_U_W[np.arange(SN), policy] # all f(x, u(x))
+    G = (pw_pi * g[targs]).sum(axis=1)
+
+    M = np.zeros((SN, SN)) # Markov matrix for given determ policy
+    x_from = [x_ff for x_f, nz in
+              zip(np.arange(SN), np.count_nonzero(pw_pi, axis=1))
+              for x_ff in [x_f] * nz]
+    M[x_from, targs[pw_pi > 0]] = pw_pi[pw_pi > 0]
+    # M[np.arange(SN), F_X_U_W[PW_OF_X_U > 0]] = PW_OF_X_U[PW_OF_X_U > 0]
+    # for x, u in zip(S_TO_IJ, policy):
+    #     pw = pw_of_x_u(x, u)
+    #     G.append(sum(pw[w] * g(x, u, w) for w in pw))
+    #     targets = [(_ij_to_s(f(x, u, w)), pw[w]) for w in pw]
+    #     iox = _ij_to_s(x)
+    #     for t, pww in targets:
+    #         M[iox, t] = pww
+    # G = np.array(G)
+    return np.linalg.solve(np.eye(SN) - ALPHA*M, G)


 def value_iteration(g, return_history=False):
-    j = np.random.randn(len(S_TO_IJ))
+    j = np.zeros(SN)
     history = [j]
     while True:
-        policy = _policy_improvement(j, g)
-        j_new = []
-        for x, u in zip(S_TO_IJ, policy):
-            j_new.append(h_function(x, u, j, g))
+        # print(j)
+        policy, j_new = _policy_improvement(j, g)
         j_old = j
-        j = np.array(j_new)
+        j = j_new
         if return_history:
             history.append(j)
-        if max(abs(j - j_old)) < EPSILON:
+        if np.abs(j - j_old).max() < EPSILON:
             break
     if not return_history:
         return j, policy
@@ -183,7 +253,7 @@ def value_iteration(g, return_history=False):
 def policy_iteration(g, return_history=False):
     j = None
-    policy = [np.random.choice(u_of_x(x)) for x in S_TO_IJ]
+    policy = np.full(SN, len(A2) - 1)
     history = []
     while True:
         j_old = j
@@ -191,7 +261,7 @@ def policy_iteration(g, return_history=False):
         history.append(j)
         if j_old is not None and max(abs(j - j_old)) < EPSILON:
             break
-        policy = _policy_improvement(j, g)
+        policy, _ = _policy_improvement(j, g)
     if not return_history:
         return j, policy
     else:
@@ -204,17 +274,13 @@ if __name__ == '__main__':
     ap.add_argument('maze_file', help='Path to maze file')
     args = ap.parse_args()
-    start = time()
+    # start = time()

     # Initialization
-    MAZE = np.genfromtxt(
-        args.maze_file,
-        dtype=str,
-    )
-    STATE_MASK = (MAZE != '1')
-    S_TO_IJ = np.indices(MAZE.shape).transpose(1, 2, 0)[STATE_MASK]
+    start = time()
+    _init_global(args.maze_file)

     # J / policy for both algorithms for both cost functions for 3 alphas
-    costs = {'g1': cost_treasure, 'g2': cost_energy}
+    costs = {'g1': G1_X, 'g2': G2_X}
     optimizers = {'Value Iteration': value_iteration,
                   'Policy Iteration': policy_iteration}
@@ -227,12 +293,12 @@ if __name__ == '__main__':
             name = ' / '.join([opt, g])
             ALPHA = a
             j, policy = optimizers[opt](costs[g])
-            print(name, j)
             plt.subplot(2, 2, i)
             plt.gca().set_title(name)
             plot_j_policy_on_maze(j, policy)
             i += 1
+    # plt.show()

     # Error graphs
     for opt in ['Value Iteration', 'Policy Iteration']:
         plt.figure()