diff --git a/main.py b/main.py index ae447b8..ffd3381 100644 --- a/main.py +++ b/main.py @@ -92,6 +92,16 @@ def init_global(maze_filename): PW_OF_X_U[ix, iu, -1] = 1 - PW_OF_X_U[ix, iu].sum() F_X_U_W[ix, iu, -1] = ij_to_s[tuple(x + u)] + # Forbid to leave the goal state + # (could've been done through cost function though) + goal_idx = ij_to_s[np.where(MAZE == b'G')][0] + U_OF_X[goal_idx] = False + U_OF_X[goal_idx, -1] = True + PW_OF_X_U[goal_idx] = 0 + PW_OF_X_U[goal_idx, -1, -1] = 1 + F_X_U_W[goal_idx] = 0 + F_X_U_W[goal_idx, -1, -1] = goal_idx + def h_matrix(j, g): h_x_u = (PW_OF_X_U * (g[F_X_U_W] + ALPHA*j[F_X_U_W])).sum(axis=2)