예제 #1
0
파일: main.py 프로젝트: Uranium2/DRL_4IABD
def test_line_iterative_policy_evaluation():
    """Evaluate a uniform random policy on the line world, then animate a greedy walk.

    Builds a 15-state line world whose end states 0 and 14 are terminal, with
    reward -1 attached to state 0 and +1 to state 14. It runs
    ``iterative_policy_evaluation`` on a uniform random policy and prints the
    elapsed time and the value table ``V``. It then opens a pygame window in
    which an agent, starting from ``reset_line``, moves greedily with respect
    to ``V`` until it reaches a terminal state, redrawing once per second.
    """
    pygame.init()

    num_states = 15
    rewards = ((0, -1),
               (14, 1))

    terminal = [0, 14]

    S, A, T, P = create_line_world(num_states, rewards, terminal)
    Pi = tabular_uniform_random_policy(S.shape[0], A.shape[0])

    start_time = time()
    V = iterative_policy_evaluation(S, A, P, T, Pi)
    print("--- %s seconds ---" % (time() - start_time))

    print(V)

    # Reward attached to each state listed in ``rewards``; every other state pays 0.
    # NOTE(review): assumes the reward is collected on entering that state —
    # confirm against create_line_world.
    reward_at = dict(rewards)

    win = pygame.display.set_mode((num_states * 100, 100))

    st = reset_line(num_states)

    while not is_terminal(st, T):
        display_line(win, num_states)
        event_loop()
        display_reward_line(win, rewards, num_states)
        display_mouse_line(win, st, num_states)
        sleep(1)

        # One-step lookahead: score each neighbour by the reward for entering it
        # plus its value. Comparing bare neighbour values would make a terminal's
        # value of 0 look attractive (walking from state 1 into the -1 terminal).
        # It would also leave no action chosen when the two neighbours tie.
        # Action 0 is taken to mean left (st - 1) and action 1 right (st + 1),
        # matching how the neighbour comparison selects them. Ties go right.
        left_score = reward_at.get(st - 1, 0) + V[st - 1]
        right_score = reward_at.get(st + 1, 0) + V[st + 1]
        a = 1 if right_score >= left_score else 0
        st, _, _ = step(st, a, T, S, P)

    # Draw the final (terminal) position once more before returning.
    display_line(win, num_states)
    display_reward_line(win, rewards, num_states)
    display_mouse_line(win, st, num_states)
    sleep(1)
예제 #2
0
파일: main.py 프로젝트: Uranium2/DRL_4IABD
def test_line_sarsa():
    """Learn a line-world controller with tabular SARSA, then animate the greedy walk.

    Builds a 15-state line world whose end states 0 and 14 are terminal, with
    reward -1 attached to state 0 and +1 to state 14. It trains
    ``tabular_sarsa_control`` for 10000 episodes of at most 100 steps each and
    prints the training time and the rows of the learned action-value table
    ``Q``. It then opens a pygame window in which an agent, starting from
    ``reset_line``, follows ``argmax Q`` until it reaches a terminal state,
    redrawing once per second.
    """
    pygame.init()

    num_states = 15
    rewards = ((0, -1),
               (14, 1))

    terminal = [0, 14]

    S, A, T, P = create_line_world(num_states, rewards, terminal)

    start_time = time()

    # Only the action-value table is used below; the returned policy is discarded.
    Q, _ = tabular_sarsa_control(T, S, P, len(S), len(A), reset_line, is_terminal, step,
                                 episodes_count=10000, max_steps_per_episode=100)
    print("--- %s seconds ---" % (time() - start_time))
    for i in range(num_states):
        print(Q[i], end=" ")
    # Terminate the space-separated Q dump so later output starts on a fresh line.
    print()

    win = pygame.display.set_mode((num_states * 100, 100))

    st = reset_line(num_states)

    while not is_terminal(st, T):
        display_line(win, num_states)
        event_loop()
        display_reward_line(win, rewards, num_states)
        display_mouse_line(win, st, num_states)
        sleep(1)

        # Act greedily w.r.t. the learned Q; np.argmax picks the first index on ties.
        a = np.argmax(Q[st])

        st, _, _ = step(st, a, T, S, P)

    # Draw the final (terminal) position once more before returning.
    display_line(win, num_states)
    display_reward_line(win, rewards, num_states)
    display_mouse_line(win, st, num_states)
    sleep(1)