Пример #1
0
def test_deepq():
    """
    Train DeepQ with a dueling CNN Q-function on an Atari environment.

    Configures the logger, seeds the RNGs with ``SEED``, builds the ``ENV_ID``
    environment wrapped in ``bench.Monitor`` (writing to ``logger.get_dir()``)
    and the DQN Atari wrappers, then trains for ``NUM_TIMESTEPS`` steps with
    prioritized replay. The environment is closed even if training raises.
    """
    logger.configure()
    set_global_seeds(SEED)
    env = make_atari(ENV_ID)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)
    # Three conv layers, hiddens=[256], dueling enabled.
    # NOTE(review): conv tuples are presumably (num_outputs, kernel_size,
    # stride) — confirm against deepq_models.cnn_to_mlp.
    q_func = deepq_models.cnn_to_mlp(convs=[(32, 8, 4), (64, 4, 2),
                                            (64, 3, 1)],
                                     hiddens=[256],
                                     dueling=True)

    try:
        model = DeepQ(env=env,
                      policy=q_func,
                      learning_rate=1e-4,
                      buffer_size=10000,
                      exploration_fraction=0.1,
                      exploration_final_eps=0.01,
                      train_freq=4,
                      learning_starts=10000,
                      target_network_update_freq=1000,
                      gamma=0.99,
                      prioritized_replay=True,
                      prioritized_replay_alpha=0.6,
                      checkpoint_freq=10000)
        model.learn(total_timesteps=NUM_TIMESTEPS)
    finally:
        # Release the environment even when construction or training fails.
        env.close()
    del model, env
Пример #2
0
def test_deepq():
    """
    Train DeepQ with ``CnnPolicy`` on an Atari environment.

    Configures the logger, seeds the RNGs with ``SEED``, builds the ``ENV_ID``
    environment wrapped in ``bench.Monitor`` (writing to ``logger.get_dir()``)
    and the DQN Atari wrappers, then trains for ``NUM_TIMESTEPS`` steps with
    prioritized replay. The environment is closed even if training raises.
    """
    logger.configure()
    set_global_seeds(SEED)
    env = make_atari(ENV_ID)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)

    try:
        model = DeepQ(env=env,
                      policy=CnnPolicy,
                      learning_rate=1e-4,
                      buffer_size=10000,
                      exploration_fraction=0.1,
                      exploration_final_eps=0.01,
                      train_freq=4,
                      learning_starts=10000,
                      target_network_update_freq=1000,
                      gamma=0.99,
                      prioritized_replay=True,
                      prioritized_replay_alpha=0.6,
                      checkpoint_freq=10000)
        model.learn(total_timesteps=NUM_TIMESTEPS)
    finally:
        # Release the environment even when construction or training fails.
        env.close()
    del model, env
Пример #3
0
def main():
    """
    run the atari test

    Parses command-line options, builds an Atari environment (monitored and
    wrapped for DQN), and trains a DeepQ model with a CNN Q-function.

    Options: ``--env`` (environment ID), ``--seed``, ``--prioritized`` and
    ``--dueling`` (integers interpreted as booleans), ``--prioritized-replay-alpha``,
    ``--num-timesteps``, ``--checkpoint-freq`` and ``--checkpoint-path``.
    The environment is closed even if training raises.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = wrap_atari_dqn(env)
    # --dueling / --prioritized are int flags (0 or non-zero); bool() maps them
    # onto the boolean options expected by the model constructors.
    q_func = deepq_models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )

    try:
        model = DeepQ(
            env=env,
            policy=q_func,
            learning_rate=1e-4,
            buffer_size=10000,
            exploration_fraction=0.1,
            exploration_final_eps=0.01,
            train_freq=4,
            learning_starts=10000,
            target_network_update_freq=1000,
            gamma=0.99,
            prioritized_replay=bool(args.prioritized),
            prioritized_replay_alpha=args.prioritized_replay_alpha,
            checkpoint_freq=args.checkpoint_freq,
            checkpoint_path=args.checkpoint_path,
        )
        model.learn(total_timesteps=args.num_timesteps)
    finally:
        # Release the environment even when construction or training fails.
        env.close()
Пример #4
0
def main():
    """
    run a trained model for the pong problem

    Loads ``pong_model.pkl`` and plays Pong episodes forever, rendering each
    step and printing the total reward per episode. The loop only ends via an
    exception (e.g. Ctrl+C); the environment is closed in that case too.
    """
    env = gym.make("PongNoFrameskip-v4")
    env = deepq.wrap_atari_dqn(env)
    try:
        # NOTE(review): DeepQ.load presumably deserializes the file (pickle);
        # only load model files from a trusted source.
        model = DeepQ.load("pong_model.pkl", env)

        while True:
            obs, done = env.reset(), False
            episode_rew = 0
            while not done:
                env.render()
                action, _ = model.predict(obs)
                obs, rew, done, _ = env.step(action)
                episode_rew += rew
            print("Episode reward", episode_rew)
    finally:
        # The infinite loop only exits through an exception; close the env then.
        env.close()


def main(args):
    """
    run a trained model for the mountain car problem

    Loads ``mountaincar_model.pkl`` and plays MountainCar episodes, printing
    the total reward of each. With ``args.no_render`` set, rendering is
    skipped and only a single episode is played. The environment is closed
    on exit, including when the loop is interrupted.

    :param args: (ArgumentParser) the input arguments
    """
    env = gym.make("MountainCar-v0")
    try:
        model = DeepQ.load("mountaincar_model.pkl", env)

        while True:
            obs, done = env.reset(), False
            episode_rew = 0
            while not done:
                if not args.no_render:
                    env.render()
                action, _ = model.predict(obs)
                obs, rew, done, _ = env.step(action)
                episode_rew += rew
            print("Episode reward", episode_rew)
            # No render is only used for automatic testing
            if args.no_render:
                break
    finally:
        # Release the environment on normal exit, break, or interruption.
        env.close()
Пример #6
0
def main(args):
    """
    train and save the DeepQ model, for the cartpole problem

    Trains an MLP-policy DeepQ model on CartPole-v0 for ``args.max_timesteps``
    steps (invoking the module-level ``callback`` during learning), saves it
    to ``cartpole_model.pkl``, and closes the environment afterwards.

    :param args: (ArgumentParser) the input arguments
    """
    env = gym.make("CartPole-v0")
    try:
        model = DeepQ(
            env=env,
            policy=MlpPolicy,
            learning_rate=1e-3,
            buffer_size=50000,
            exploration_fraction=0.1,
            exploration_final_eps=0.02,
        )
        model.learn(total_timesteps=args.max_timesteps, callback=callback)

        print("Saving model to cartpole_model.pkl")
        model.save("cartpole_model.pkl")
    finally:
        # Release the environment whether training/saving succeeds or fails.
        env.close()
Пример #7
0
def main(args):
    """
    train and save the DeepQ model, for the mountain car problem

    Trains a DeepQ model with ``CustomPolicy`` and parameter-space noise on
    MountainCar-v0 for ``args.max_timesteps`` steps, saves it to
    ``mountaincar_model.pkl``, and closes the environment afterwards.

    :param args: (ArgumentParser) the input arguments
    """
    env = gym.make("MountainCar-v0")

    try:
        # using layer norm policy here is important for parameter space noise!
        # NOTE(review): CustomPolicy is defined elsewhere and presumably
        # enables layer normalization — confirm.
        model = DeepQ(policy=CustomPolicy,
                      env=env,
                      learning_rate=1e-3,
                      buffer_size=50000,
                      exploration_fraction=0.1,
                      exploration_final_eps=0.1,
                      param_noise=True)
        model.learn(total_timesteps=args.max_timesteps)

        print("Saving model to mountaincar_model.pkl")
        model.save("mountaincar_model.pkl")
    finally:
        # Release the environment whether training/saving succeeds or fails.
        env.close()
Пример #8
0
from stable_baselines.common.identity_env import IdentityEnv
from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.deepq import models as deepq_models

# One training routine per algorithm. Each entry takes an environment ``e``,
# builds the model with fixed hyper-parameters and returns the result of
# ``.learn(...)`` with seed 0. Names inside the lambdas are looked up only
# when an entry is called.
learn_func_list = [
    lambda e: A2C(
        policy=MlpPolicy,
        learning_rate=1e-3,
        n_steps=1,
        gamma=0.7,
        env=e,
    ).learn(total_timesteps=10000, seed=0),
    lambda e: ACER(
        policy=MlpPolicy,
        env=e,
        n_steps=1,
        replay_ratio=1,
    ).learn(total_timesteps=10000, seed=0),
    lambda e: ACKTR(
        policy=MlpPolicy,
        env=e,
        learning_rate=5e-4,
        n_steps=1,
    ).learn(total_timesteps=20000, seed=0),
    lambda e: DeepQ(
        policy=deepq_models.mlp([32]),
        batch_size=16,
        gamma=0.1,
        exploration_fraction=0.001,
        env=e,
    ).learn(total_timesteps=40000, seed=0),
    lambda e: PPO1(
        policy=MlpPolicy,
        env=e,
        lam=0.7,
        optim_batchsize=16,
        optim_stepsize=1e-3,
    ).learn(total_timesteps=10000, seed=0),
    lambda e: PPO2(
        policy=MlpPolicy,
        env=e,
        learning_rate=1.5e-3,
        lam=0.8,
    ).learn(total_timesteps=20000, seed=0),
    lambda e: TRPO(
        policy=MlpPolicy,
        env=e,
        max_kl=0.05,
        lam=0.7,
    ).learn(total_timesteps=10000, seed=0),
]


@pytest.mark.slow