def test_line_iterative_policy_evaluation():
    """Evaluate a uniform random policy on the line world and animate a greedy walk.

    Builds a 15-state line world with rewards ``((0, -1), (14, 1))``
    (presumably ``(state, reward)`` pairs) and terminal states 0 and 14.
    Evaluates the uniform random policy with ``iterative_policy_evaluation``
    and prints the elapsed time and the value vector ``V``.

    It then opens a pygame window and walks an agent from ``reset_line``
    towards the higher-valued neighbour. The scene is redrawn once per second
    until a terminal state is reached or ``max_rollout_steps`` moves have
    been made.
    """
    pygame.init()
    num_states = 15
    rewards = ((0, -1), (14, 1))
    terminal = [0, 14]
    # Safety cap: a greedy walk that oscillates must not hang the demo forever.
    max_rollout_steps = 100

    S, A, T, P = create_line_world(num_states, rewards, terminal)
    Pi = tabular_uniform_random_policy(S.shape[0], A.shape[0])

    start_time = time()
    V = iterative_policy_evaluation(S, A, P, T, Pi)
    print("--- %s seconds ---" % (time() - start_time))
    print(V)

    win = pygame.display.set_mode((num_states * 100, 100))
    st = reset_line(num_states)
    steps = 0
    while not is_terminal(st, T) and steps < max_rollout_steps:
        display_line(win, num_states)
        event_loop()
        display_reward_line(win, rewards, num_states)
        display_mouse_line(win, st, num_states)
        sleep(1)
        # Greedy choice between the st + 1 neighbour (a = 1) and the st - 1
        # neighbour (a = 0) -- presumably right/left; confirm against
        # create_line_world. A neighbour value of exactly 0 on the st + 1 side
        # also selects a = 1 (assumes terminal states have V == 0 -- confirm).
        # Ties pick a = 1 so that `a` is always bound.
        a = 1 if V[st + 1] >= V[st - 1] or V[st + 1] == 0 else 0
        st, _, _ = step(st, a, T, S, P)
        steps += 1

    # Draw the final state once more so the end position stays visible.
    display_line(win, num_states)
    display_reward_line(win, rewards, num_states)
    display_mouse_line(win, st, num_states)
    sleep(1)
def test_line_sarsa():
    """Learn Q with tabular SARSA on the line world and animate the greedy policy.

    Builds a 15-state line world with rewards ``((0, -1), (14, 1))``
    (presumably ``(state, reward)`` pairs) and terminal states 0 and 14.
    Runs ``tabular_sarsa_control`` for 10000 episodes of at most 100 steps,
    then prints the elapsed time and the per-state rows ``Q[0..14]`` on one
    line.

    It then opens a pygame window and follows ``argmax(Q[st])`` from
    ``reset_line``, redrawing once per second. The walk stops at a terminal
    state or after ``max_steps`` moves.
    """
    pygame.init()
    num_states = 15
    rewards = ((0, -1), (14, 1))
    terminal = [0, 14]
    # Shared step limit: per training episode, and as a cap on the greedy
    # rollout so a Q-table that oscillates between two states cannot hang
    # the demo forever.
    max_steps = 100

    S, A, T, P = create_line_world(num_states, rewards, terminal)

    start_time = time()
    Q, _Pi = tabular_sarsa_control(
        T, S, P, len(S), len(A), reset_line, is_terminal, step,
        episodes_count=10000, max_steps_per_episode=max_steps,
    )
    print("--- %s seconds ---" % (time() - start_time))
    # Space-separated and newline-terminated so subsequent output does not
    # run onto the same line.
    print(*(Q[i] for i in range(num_states)))

    win = pygame.display.set_mode((num_states * 100, 100))
    st = reset_line(num_states)
    steps = 0
    while not is_terminal(st, T) and steps < max_steps:
        display_line(win, num_states)
        event_loop()
        display_reward_line(win, rewards, num_states)
        display_mouse_line(win, st, num_states)
        sleep(1)
        # Greedy action w.r.t. the learned Q; argmax returns the first index on ties.
        a = np.argmax(Q[st])
        st, _, _ = step(st, a, T, S, P)
        steps += 1

    # Draw the final state once more so the end position stays visible.
    display_line(win, num_states)
    display_reward_line(win, rewards, num_states)
    display_mouse_line(win, st, num_states)
    sleep(1)