perfect PI terminator, nicer graphs

commit b469086b4b
parent e214a3cc7d
Date: 2018-12-18 14:19:09 +01:00

main.py | 23 +++++++++++++++--------

@@ -15,7 +15,8 @@ import matplotlib.pyplot as plt
 P = 0.1
 ALPHA = 0.90
-EPSILON = 1e-12  # Convergence criterium
+EPSILON = 1e-12
+# EPSILON = 1e-12  # Convergence criterium
 
 A2 = np.array([  # Action index to action mapping
     [-1, 0],  # Up
     [ 1, 0],  # Down
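
Note: EPSILON stays defined here, but the sup-norm convergence test it fed is commented out in _terminate below. For reference, a minimal sketch of the retired epsilon-based stopping rule (the function name converged is ours, not the repo's; the bound is standard contraction-mapping reasoning):

    import numpy as np

    EPSILON = 1e-12
    ALPHA = 0.90

    def converged(j, j_old, eps=EPSILON, alpha=ALPHA):
        # For an alpha-contraction, ||j - j*||_inf <= alpha / (1 - alpha) * ||j - j_old||_inf,
        # so stopping here bounds the remaining error by eps * alpha / (1 - alpha).
        return np.abs(j - j_old).max() < eps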
@@ -56,7 +57,7 @@ def init_global(maze_filename):
     # Basic maze structure initialization
     MAZE = np.genfromtxt(
         maze_filename,
-        dtype=str,
+        dtype='|S1',
     )
 
     state_mask = (MAZE != '1')
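
Note: '|S1' is NumPy's fixed-width one-byte string dtype, so every maze cell parses to exactly one byte; plain dtype=str leaves the field width up to genfromtxt. A minimal sketch of the parsing this sets up (the maze layout is illustrative, not from the repo; on Python 3 the cells come back as bytes, hence the b'1' literal):

    import numpy as np
    from io import BytesIO

    # Illustrative maze: '1' wall, '0' free, 'S' start, 'G' goal, 'T' trap.
    maze_txt = b"1 1 1 1\nS 0 T 1\n1 G 1 1"
    MAZE = np.genfromtxt(BytesIO(maze_txt), dtype='|S1')
    print(MAZE.shape)    # (3, 4)
    print(MAZE != b'1')  # element-wise mask of non-wall cells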
@@ -72,7 +73,7 @@ def init_global(maze_filename):
     maze_cost[MAZE == 'T'] = 50
     maze_cost[MAZE == 'G'] = -1
     G1_X = maze_cost.copy()[state_mask]
-    maze_cost[maze_cost < 1] += 1  # assert np.nan < whatever == False
+    maze_cost[(MAZE == '0') | (MAZE == 'S') | (MAZE == 'G')] += 1
     G2_X = maze_cost.copy()[state_mask]
 
     # Actual environment modelling
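
Note: the replaced line leaned on IEEE-754 semantics, as its old comment says: every comparison against NaN is False, so wall cells (presumably stored as NaN in maze_cost) silently fell out of the maze_cost < 1 mask, while the trap's cost of 50 kept it out as well. The new mask selects the '0', 'S' and 'G' cells by name instead. A small demonstration of the retired trick, with illustrative values:

    import numpy as np

    # wall, free, trap, goal -- trap/goal costs as set above; NaN for walls is an assumption
    maze_cost = np.array([np.nan, 0.0, 50.0, -1.0])
    print(maze_cost < 1)           # [False  True False  True]: NaN compares False
    maze_cost[maze_cost < 1] += 1
    print(maze_cost)               # [nan  1. 50.  0.]: wall and trap untouched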
@@ -146,20 +147,23 @@ def policy_iteration(j, g):
     return policy, j
 
 
-def _terminate(j, j_old):
-    # TODO: DIS
-    return np.abs(j - j_old).max() < EPSILON
+def _terminate(j, j_old, policy, policy_old):
+    # eps = EPSILON
+    # return np.abs(j - j_old).max() < eps
+    return np.all(policy == policy_old)
 
 
 def dynamic_programming(optimizer_step, g, return_history=False):
     j = np.zeros(SN, dtype=np.float64)
+    policy = None
     history = []
     while True:
         j_old = j
+        policy_old = policy
         policy, j = optimizer_step(j, g)
         if return_history:
             history.append(j)
-        if _terminate(j, j_old):
+        if _terminate(j, j_old, policy, policy_old):
             break
     if not return_history:
         return j, policy
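
Note: this is the 'perfect PI terminator' of the commit title. Policy iteration searches a finite set of policies, and each evaluation/improvement round either strictly improves the policy or reproduces it, so the first time two consecutive policies coincide, the current policy is optimal and the loop can stop exactly; the epsilon test on j, by contrast, can stop early or run long. A self-contained sketch of the same criterion under assumed shapes (T is an (A, S, S) transition tensor, g an (S,) stage-cost vector; none of these names come from the repo):

    import numpy as np

    def policy_iteration_sketch(T, g, alpha=0.90):
        S = T.shape[1]
        policy = np.zeros(S, dtype=int)           # start from an arbitrary policy
        while True:
            # Exact policy evaluation: solve (I - alpha * T_pi) j = g.
            T_pi = T[policy, np.arange(S)]        # row s is T[policy[s], s, :]
            j = np.linalg.solve(np.eye(S) - alpha * T_pi, g)
            # Greedy improvement via one-step lookahead.
            q = g[None, :] + alpha * (T @ j)      # q[a, s]
            new_policy = q.argmin(axis=0)
            if np.all(new_policy == policy):      # the exact terminator from this commit
                return j, policy
            policy = new_policy

With exact evaluation the loop visits at most A**S distinct policies, so termination in finitely many iterations is guaranteed, which is why no tolerance is needed.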
@@ -191,7 +195,9 @@ if __name__ == '__main__':
             name = ' / '.join([opt, cost])
             ALPHA = a
             j, policy = dynamic_programming(optimizers[opt], costs[cost])
-            print(name, j)
+            print(name)
+            print(j)
+            # print(name, j)
             plt.subplot(2, 2, i)
             plt.gca().set_title(name)
             plot_j_policy_on_maze(j, policy)
@@ -200,6 +206,7 @@ if __name__ == '__main__':
     # Error graphs
     for opt in ['Value Iteration', 'Policy Iteration']:
         plt.figure()
+        plt.subplots_adjust(wspace=0.45, hspace=0.45)
         plt.suptitle(opt)
         i = 1
         for cost in ['g1', 'g2']:
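
Note: subplots_adjust reserves padding between panels; wspace and hspace are fractions of the average subplot width and height, so 0.45 roughly doubles matplotlib's 0.2 defaults and keeps the per-panel titles clear of each other: presumably the 'nicer graphs' half of the commit title. A standalone sketch of the layout (placeholder data, not the repo's error curves):

    import numpy as np
    import matplotlib.pyplot as plt

    plt.figure()
    plt.subplots_adjust(wspace=0.45, hspace=0.45)  # widen inter-panel gaps
    plt.suptitle('Policy Iteration')
    for i in range(1, 5):
        plt.subplot(2, 2, i)
        plt.gca().set_title('panel %d' % i)
        plt.plot(np.arange(10), np.random.rand(10))  # placeholder curve
    plt.show()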