Exemplo n.º 1
0
def main():
    """Build the full-observation maze on map3 and attach a DQN trainer.

    Side effects: binds the module-level ``env`` and ``RL`` globals, then
    blocks inside the Tk event loop until the window is closed.
    """
    global env, RL
    env = Maze('./env/maps/map3.json', full_observation=True)
    # Collect the hyper-parameters in one place before constructing the agent.
    dqn_config = dict(
        n_actions=4,
        n_features=env.height * env.width,
        restore_path=None,
        # restore_path=base_path + 'model_dqn.ckpt',
        learning_rate=0.00001,
        reward_decay=0.9,
        e_greedy=0.95,
        replace_target_iter=4e4,
        batch_size=64,
        e_greedy_init=0,
        # e_greedy_increment=None,
        e_greedy_increment=1e-3,
        output_graph=False,
    )
    RL = DeepQNetwork(**dqn_config)
    # Schedule training 100 ms after the Tk event loop starts.
    env.after(100, run_maze)
    env.mainloop()
Exemplo n.º 2
0
def main():
    """Open the map3 maze and schedule run_maze on its event loop.

    Binds the module-level ``env`` global, then blocks in mainloop().
    """
    global env
    env = Maze('./env/maps/map3.json')
    # Give Tk 100 ms to come up before training kicks in.
    env.after(100, run_maze)
    env.mainloop()  # Blocks here until the window is closed.
Exemplo n.º 3
0
Arquivo: rl.py Projeto: LDNN97/RLES
    #         action = rl.choose_action(str(observation))
    #         observation_, reward, done = env.step(action)
    #         rl.learn(str(observation), action, reward, str(observation_))
    #         observation = observation_
    #         if done:
    #             break

    # on policy
    # SARSA update loop: unlike the (commented-out) off-policy variant above,
    # the action actually chosen for the next state (action_) is passed into
    # learn(), making the update on-policy.
    rl = SarsaTable(actions=list(range(env.n_actions)))
    for episode in range(100):
        # Each episode restarts the maze and picks the first action up front,
        # since SARSA needs the (s, a, r, s', a') quintuple.
        observation = env.reset()
        action = rl.choose_action(str(observation))
        while True:
            # NOTE(review): observations are stringified here — presumably the
            # table is keyed by the state's string form; confirm in SarsaTable.
            observation_, reward, done = env.step(action)
            action_ = rl.choose_action(str(observation_))
            rl.learn(str(observation), action, reward, str(observation_),
                     action_)
            # Slide the (state, action) window forward one step.
            observation = observation_
            action = action_
            if done:
                break

    print('game over')
    # Tear down the Tk window so the surrounding mainloop() can return.
    env.destroy()


if __name__ == '__main__':
    # Script entry point: build the maze, then run main() 1000 ms after the
    # Tk event loop starts so the window is up before training begins.
    env = Maze()
    env.after(1000, main)
    env.mainloop()
Exemplo n.º 4
0
def main():
    """Compare plain Q-learning against Dyna-Q with n = 5/10/25/50 planning
    steps on map2, recording the wall-clock time of each run and plotting
    the learning curves afterwards.

    Side effects: binds the module-level ``env``, ``RL`` and ``env_model``
    globals; writes two comparison plots under ./logs/; prints the timing
    list. The plotting section reads the .npy logs that the trainers write
    during their runs, so it must only execute after every mainloop() has
    returned.
    """
    global env, RL, env_model

    time_cmp = []  # wall-clock seconds per configuration, Q-learning first

    # -------- Q Learning -------- #
    from brain.Q_learning import QLearningTable
    start = time.time()
    env = Maze('./env/maps/map2.json')
    RL = QLearningTable(actions=list(range(env.n_actions)))
    # BUG FIX: pass the callable itself. The original `env.after(0,
    # update_q())` invoked update_q immediately and registered its return
    # value (None) as the callback, so nothing ran inside the event loop.
    env.after(0, update_q)
    env.mainloop()  # Blocks until this run's window is destroyed.
    time_cmp.append(time.time() - start)

    # ---------- Dyna Q ---------- #
    from brain.dyna_Q import QLearningTable, EnvModel
    for n in [5, 10, 25, 50]:
        start = time.time()
        env = Maze('./env/maps/map2.json')
        RL = QLearningTable(actions=list(range(env.n_actions)))
        env_model = EnvModel(actions=list(range(env.n_actions)))
        print('Dyna-{}'.format(n))
        # Extra positional args to after() are forwarded to the callback,
        # i.e. this schedules update_dyna_q(n).
        env.after(0, update_dyna_q, n)
        env.mainloop()  # mainloop() to run the application.
        time_cmp.append(time.time() - start)

    # This part must run after every env.mainloop() above has returned,
    # because the trainers produce the .npy logs during those runs.
    # Plot all five curves together.
    all_aver_steps = [np.load('./logs/q_learning/q_learning.npy').tolist()]
    for n in [5, 10, 25, 50]:
        all_aver_steps.append(
            np.load('./logs/dyna_q/dyna_q_{}.npy'.format(n)).tolist())
    plot_multi_lines(
        all_aver_steps,
        all_labels=['q_learning', 'dyna_5', 'dyna_10', 'dyna_25', 'dyna_50'],
        save_path='./logs/cmp_all.png')

    # Plot only the Dyna-Q variants, truncated to the first 100 entries so
    # the curves are directly comparable.
    all_aver_steps = []
    for n in [5, 10, 25, 50]:
        all_aver_steps.append(
            np.load('./logs/dyna_q/dyna_q_{}.npy'.format(n))[0:100].tolist())
    plot_multi_lines(all_aver_steps,
                     all_labels=['dyna_5', 'dyna_10', 'dyna_25', 'dyna_50'],
                     save_path='./logs/cmp_all_dyna_Q.png')

    print(time_cmp)
Exemplo n.º 5
0
def main():
    """Start the map1 maze with full observation and schedule run_maze.

    Binds the module-level ``env`` global, then blocks in mainloop().
    """
    global env
    env = Maze('./env/maps/map1.json', full_observation=True)
    # run_maze fires 100 ms after the event loop starts.
    env.after(100, run_maze)
    env.mainloop()  # Blocks here until the window is closed.