Example #1
0
    def test_wrap_dqn(self):
        """wrap_dqn with wrap_ndarray=True yields 84x84x4 ndarray observations."""
        raw_env = gym.make("SpaceInvadersNoFrameskip-v4")
        wrapped_env = wrap_dqn(raw_env, wrap_ndarray=True)

        first_obs = wrapped_env.reset()
        # Exact-type check (not isinstance): the wrapper must hand back a
        # plain np.ndarray stacked to the DQN input shape of 84x84x4.
        self.assertEqual(type(first_obs), np.ndarray)
        self.assertEqual(first_obs.shape, (84, 84, 4))
Example #2
0
if __name__ == '__main__':
    # Build the CLI: start from the Trainer's base arguments, then layer on
    # DQN-specific flags and an Atari environment selector.
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument('--env-name',
                        type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    # Defaults below follow the standard Atari evaluation protocol
    # (108000 steps per episode; very long total training horizon).
    parser.set_defaults(episode_max_steps=108000)
    parser.set_defaults(test_interval=10000)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(save_model_interval=500000)
    parser.set_defaults(gpu=0)
    parser.set_defaults(show_test_images=True)
    parser.set_defaults(memory_capacity=int(1e6))
    args = parser.parse_args()

    # Training env keeps the default reward clipping; the evaluation env
    # disables it so reported scores are the raw game scores.
    env = wrap_dqn(gym.make(args.env_name))
    test_env = wrap_dqn(gym.make(args.env_name), reward_clipping=False)
    # Following parameters are equivalent to DeepMind DQN paper
    # https://www.nature.com/articles/nature14236
    # NOTE(review): this call is truncated in this excerpt — the remaining
    # keyword arguments and the closing parenthesis continue past line 42.
    policy = DQN(
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        lr=0.0000625,  # This value is from Rainbow
        adam_eps=1.5e-4,  # This value is from Rainbow
        n_warmup=50000,
        target_replace_interval=10000,
        batch_size=32,
        memory_capacity=args.memory_capacity,
Example #3
0
    # NOTE(review): this fragment starts mid-script — `args`, `parser`, and
    # `env` are presumably created earlier in the file (not shown here).
    # The evaluation env is created unwrapped first; it is wrapped below
    # only if the chosen environment is an Atari game.
    test_env = gym.make(args.env_name)

    if is_atari_env(env):
        # Parameters come from Appendix.B in original paper.
        # See https://arxiv.org/abs/1910.07207
        parser.set_defaults(episode_max_steps=108000)
        parser.set_defaults(test_interval=int(1e5))
        parser.set_defaults(show_test_images=True)
        parser.set_defaults(max_steps=int(1e9))
        parser.set_defaults(target_update_interval=8000)
        parser.set_defaults(n_warmup=int(2e4))
        # Re-parse so the Atari-specific defaults take effect while still
        # letting explicit command-line values override them.
        args = parser.parse_args()
        if args.gpu == -1:
            print("Are you sure you're trying to solve Atari without GPU?")

        # Training env keeps reward clipping on; the evaluation env disables
        # it so reported results are raw game scores.
        env = wrap_dqn(env, wrap_ndarray=True)
        test_env = wrap_dqn(test_env, wrap_ndarray=True, reward_clipping=False)
        # Discrete-action SAC with Atari-sized network constructors.
        policy = SACDiscrete(
            state_shape=env.observation_space.shape,
            action_dim=env.action_space.n,
            discount=0.99,
            critic_fn=AtariQFunc,
            actor_fn=AtariCategoricalActor,
            lr=3e-4,
            memory_capacity=args.memory_capacity,
            batch_size=64,
            n_warmup=args.n_warmup,
            update_interval=4,
            target_update_interval=args.target_update_interval,
            auto_alpha=args.auto_alpha,
            gpu=args.gpu)