Example #1
0
def test_deepq():
    """
    test DeepQ on atari, using the CnnPolicy

    Builds the atari environment identified by ENV_ID, wraps it with
    bench.Monitor and wrap_atari_dqn, then trains a prioritized-replay DeepQ
    model for NUM_TIMESTEPS timesteps. The environment is closed even when
    model construction or training fails.
    """
    logger.configure()
    set_global_seeds(SEED)
    env = make_atari(ENV_ID)
    try:
        # `env` is rebound at each wrapping step, so the `finally` clause
        # always closes the outermost wrapper that was successfully created.
        env = bench.Monitor(env, logger.get_dir())
        env = wrap_atari_dqn(env)

        model = DeepQ(env=env,
                      policy=CnnPolicy,
                      learning_rate=1e-4,
                      buffer_size=10000,
                      exploration_fraction=0.1,
                      exploration_final_eps=0.01,
                      train_freq=4,
                      learning_starts=10000,
                      target_network_update_freq=1000,
                      gamma=0.99,
                      prioritized_replay=True,
                      prioritized_replay_alpha=0.6,
                      checkpoint_freq=10000)
        model.learn(total_timesteps=NUM_TIMESTEPS)
    finally:
        # release the environment on both the success and the failure path
        env.close()

    del model, env
Example #2
0
def test_deepq():
    """
    test DeepQ on atari, using a dueling Q-function built by cnn_to_mlp

    Builds the atari environment identified by ENV_ID, wraps it with
    bench.Monitor and wrap_atari_dqn, then trains a prioritized-replay DeepQ
    model for NUM_TIMESTEPS timesteps. The environment is closed even when
    model construction or training fails.
    """
    logger.configure()
    set_global_seeds(SEED)
    env = make_atari(ENV_ID)
    try:
        # `env` is rebound at each wrapping step, so the `finally` clause
        # always closes the outermost wrapper that was successfully created.
        env = bench.Monitor(env, logger.get_dir())
        env = wrap_atari_dqn(env)
        q_func = deepq_models.cnn_to_mlp(convs=[(32, 8, 4), (64, 4, 2),
                                                (64, 3, 1)],
                                         hiddens=[256],
                                         dueling=True)

        model = DeepQ(env=env,
                      policy=q_func,
                      learning_rate=1e-4,
                      buffer_size=10000,
                      exploration_fraction=0.1,
                      exploration_final_eps=0.01,
                      train_freq=4,
                      learning_starts=10000,
                      target_network_update_freq=1000,
                      gamma=0.99,
                      prioritized_replay=True,
                      prioritized_replay_alpha=0.6,
                      checkpoint_freq=10000)
        model.learn(total_timesteps=NUM_TIMESTEPS)
    finally:
        # release the environment on both the success and the failure path
        env.close()

    del model, env
Example #3
0
def main():
    """
    run the atari test

    Parses the command line options (environment ID, seed, prioritized
    replay and dueling switches, number of timesteps, checkpoint frequency
    and path), builds the wrapped atari environment and trains a DeepQ model
    on it. The environment is closed even when model construction or
    training fails.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    try:
        # `env` is rebound at each wrapping step, so the `finally` clause
        # always closes the outermost wrapper that was successfully created.
        env = bench.Monitor(env, logger.get_dir())
        env = wrap_atari_dqn(env)
        q_func = deepq_models.cnn_to_mlp(
            convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            hiddens=[256],
            # the flag is parsed as an int (0/1) and converted to a bool here
            dueling=bool(args.dueling),
        )

        model = DeepQ(
            env=env,
            policy=q_func,
            learning_rate=1e-4,
            buffer_size=10000,
            exploration_fraction=0.1,
            exploration_final_eps=0.01,
            train_freq=4,
            learning_starts=10000,
            target_network_update_freq=1000,
            gamma=0.99,
            # the flag is parsed as an int (0/1) and converted to a bool here
            prioritized_replay=bool(args.prioritized),
            prioritized_replay_alpha=args.prioritized_replay_alpha,
            checkpoint_freq=args.checkpoint_freq,
            checkpoint_path=args.checkpoint_path,
        )
        model.learn(total_timesteps=args.num_timesteps)
    finally:
        # release the environment on both the success and the failure path
        env.close()
Example #4
0
def main(args):
    """
    train and save the DeepQ model, for the cartpole problem

    The model is trained on "CartPole-v0" and written to
    "cartpole_model.pkl". The environment is closed once training and saving
    are done, or as soon as either of them fails.

    :param args: (object) the parsed input arguments; only
        ``args.max_timesteps`` is read, as the number of training timesteps
    """
    env = gym.make("CartPole-v0")
    try:
        model = DeepQ(
            env=env,
            policy=MlpPolicy,
            learning_rate=1e-3,
            buffer_size=50000,
            exploration_fraction=0.1,
            exploration_final_eps=0.02,
        )
        # `callback` is resolved from the enclosing module scope
        model.learn(total_timesteps=args.max_timesteps, callback=callback)

        print("Saving model to cartpole_model.pkl")
        model.save("cartpole_model.pkl")
    finally:
        # release the environment on both the success and the failure path
        env.close()
Example #5
0
def main(args):
    """
    train and save the DeepQ model, for the mountain car problem

    The model is trained on "MountainCar-v0" with parameter noise enabled and
    written to "mountaincar_model.pkl". The environment is closed once
    training and saving are done, or as soon as either of them fails.

    :param args: (object) the parsed input arguments; only
        ``args.max_timesteps`` is read, as the number of training timesteps
    """
    env = gym.make("MountainCar-v0")
    try:
        # using layer norm policy here is important for parameter space noise!
        model = DeepQ(policy=CustomPolicy,
                      env=env,
                      learning_rate=1e-3,
                      buffer_size=50000,
                      exploration_fraction=0.1,
                      exploration_final_eps=0.1,
                      param_noise=True)
        model.learn(total_timesteps=args.max_timesteps)

        print("Saving model to mountaincar_model.pkl")
        model.save("mountaincar_model.pkl")
    finally:
        # release the environment on both the success and the failure path
        env.close()