Example #1
            # swap observation
            observation = observation_

            # break the while loop when the episode ends
            if done:
                break
            step += 1

    # end of game
    print('game over')


if __name__ == "__main__":
    # set up the EVs environment and the DQN agent
    env = EVs()
    RL = DeepQNetwork(
        env.n_actions,
        env.n_features,
        learning_rate=0.01,
        reward_decay=0.9,
        e_greedy=0.9,
        replace_target_iter=200,
        memory_size=2000,
        # output_graph=True
    )
    run_EVs()
    # env.after(100, run_EVs)
    # env.mainloop()
    # print(env.actions[100:110])
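The excerpt above begins partway through run_EVs(), so the head of the function is missing. A minimal sketch of that missing part, assuming EVs exposes reset()/step() and the agent exposes the choose_action()/store_transition()/learn() methods seen elsewhere in these examples:

def run_EVs():
    step = 0
    for episode in range(300):  # the episode count is an assumption
        # initial observation
        observation = env.reset()

        while True:
            # RL chooses an action based on the current observation
            action = RL.choose_action(observation)

            # take the action, observe the next state and the reward
            observation_, reward, done = env.step(action)

            RL.store_transition(observation, action, reward, observation_)

            # learn every few steps once the replay memory has warmed up
            if (step > 200) and (step % 5 == 0):
                RL.learn()

The excerpt then continues from this point with the observation swap and the done check.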
Example #2
import gym
from brain import DeepQNetwork

env = gym.make('MountainCar-v0')
env = env.unwrapped

print(env.action_space)
print(env.observation_space)
print(env.observation_space.high)
print(env.observation_space.low)

RL = DeepQNetwork(
    n_actions=3,
    n_features=2,
    learning_rate=0.001,
    e_greedy=0.9,
    replace_target_iter=300,
    memory_size=3000,
    e_greedy_increment=0.0002,
)

total_steps = 0

for i_episode in range(10):

    observation = env.reset()
    ep_r = 0
    while True:
        env.render()

        action = RL.choose_action(observation)
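The example is cut off inside the episode loop. A minimal continuation, assuming the classic four-value Gym step API and the store_transition()/learn() methods of the brain.DeepQNetwork used in the other examples:

        observation_, reward, done, info = env.step(action)

        RL.store_transition(observation, action, reward, observation_)

        ep_r += reward
        if total_steps > 1000:  # start learning once the memory holds some transitions
            RL.learn()

        if done:
            print('episode:', i_episode, 'ep_r:', round(ep_r, 2))
            break

        observation = observation_
        total_steps += 1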
Example #3
    args = parser.parse_args()

    if args.method == 'Q-learning':
        RL = QLearningTable(range(0, env.action_space.n))
    elif args.method == 'Sarsa':
        RL = SarsaTable(range(0, env.action_space.n))
    elif args.method == 'SarsaLambda':
        RL = SarsaLambdaTable(range(0, env.action_space.n))
    elif args.method == 'DQN':
        if args.test == 'True':
            RL = DeepQNetwork(env.action_space.n,
                              2,
                              lr=0.1,
                              batch_size=128,
                              reward_decay=0.9,
                              e_greedy=0.9,
                              replace_target_iter=300,
                              memory_size=3000,
                              e_greedy_increment=0.0001,
                              path='./model/model',
                              test=True)
        else:
            RL = DeepQNetwork(
                env.action_space.n,
                2,
                lr=0.1,
                batch_size=128,
                reward_decay=0.9,
                e_greedy=0.9,
                replace_target_iter=300,
                memory_size=3000,
                e_greedy_increment=0.0001,
                path='./model/model')  # assumed to mirror the test branch above; test defaults to False
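The argparse setup that produces args is missing from this excerpt. A plausible reconstruction from the flags the code actually reads (the defaults and help strings are assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--method',
                    default='DQN',
                    help='Q-learning | Sarsa | SarsaLambda | DQN')
parser.add_argument('--test',
                    default='False',
                    help='whether to load a saved model and run in test mode')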
Example #4

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--reuse',
                        default='False',
                        help='whether to reuse a saved model (testing mode)')
    args = parser.parse_args()

    RL = DeepQNetwork(
        3,
        2,
        args.reuse,
        learning_rate=0.1,
        reward_decay=0.9,
        e_greedy=0.9,
        replace_target_iter=200,
        memory_size=2000,
        # output_graph=True
    )
    run()

    import numpy as np
    import matplotlib.pyplot as plt
    plt.plot(np.arange(len(steps)), steps)
    plt.ylabel('steps cost')
    plt.xlabel('episode')
    plt.savefig('steps_picture.png')
    plt.show()
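run() and the steps list it fills are defined earlier in the file and not shown. A minimal sketch consistent with the training loops in the neighboring examples (the env name and the episode count are assumptions):

steps = []  # length of each episode, plotted after training

def run():
    for episode in range(100):  # assumed episode count
        observation = env.reset()
        step = 0
        while True:
            action = RL.choose_action(observation)
            observation_, reward, done = env.step(action)
            RL.store_transition(observation, action, reward, observation_)
            if step > 200 and step % 5 == 0:
                RL.learn()
            observation = observation_
            if done:
                break
            step += 1
        steps.append(step)  # record how long this episode took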
Example #5

def run():
    step = 0
    for episode in range(nEpisodes):
        state = game.reset()
        while True:
            action = RL.chooseAction(state)
            observation, reward, done = game.step(action)
            RL.storeTransition(state, action, reward, observation)

            # learn every 5 steps once the replay memory has warmed up
            if step > 200 and step % 5 == 0:
                RL.learn()

            state = observation

            if done:
                break
            step += 1
        print("score : ", game.score, game.board.getState())
    print('done')


if __name__ == '__main__':
    game = Game()
    RL = DeepQNetwork(game.nActions,
                      game.nFeatures,
                      learningRate=0.01,
                      replaceTarget=200,
                      memorySize=2000)
    run()
Example #6
    evaluation()


def evaluation():
    """Evaluate the performance of AI.
    """
    pass


if __name__ == '__main__':
    # get the DeepQNetwork Agent
    RL = DeepQNetwork(
        N_ACTIONS,
        N_FEATURES,
        learning_rate=0.001,
        reward_decay=0.9,
        e_greedy=0.9,
        replace_target_iter=300,
        memory_size=2000,
        e_greedy_increment=None,
        e_policy_threshold=REPLY_START_THRESHOLD,
    )

    # Calculate running time
    start_time = time.time()
    run()
    end_time = time.time()
    running_time = (end_time - start_time) / 60
    print('running_time: ' + str(running_time) + ' min')