def test_wrap_dqn(self):
    """The DQN wrapper with ``wrap_ndarray=True`` must produce ndarray
    observations shaped (84, 84, 4) — the classic DQN preprocessing of
    84x84 grayscale frames stacked 4 deep.
    """
    wrapped_env = wrap_dqn(
        gym.make("SpaceInvadersNoFrameskip-v4"), wrap_ndarray=True)
    first_obs = wrapped_env.reset()
    # Exact-type check on purpose: presumably the wrapper converts a lazy
    # frame-stack object into a real np.ndarray — confirm in wrap_dqn.
    self.assertEqual(type(first_obs), np.ndarray)
    self.assertEqual(first_obs.shape, (84, 84, 4))
# Entry point: train a DQN agent on an Atari game with DeepMind-paper
# hyperparameters. NOTE(review): this fragment is truncated — the DQN(...)
# constructor call below is cut mid-argument-list and the trainer setup that
# presumably follows is not visible here.
if __name__ == '__main__':
    # Compose CLI arguments from the Trainer and the DQN policy.
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument('--env-name', type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    # Evaluation protocol defaults: 108000 frames = 30 minutes of emulator
    # time at 60 Hz, as in the Nature DQN paper.
    parser.set_defaults(episode_max_steps=108000)
    parser.set_defaults(test_interval=10000)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(save_model_interval=500000)
    parser.set_defaults(gpu=0)
    parser.set_defaults(show_test_images=True)
    parser.set_defaults(memory_capacity=int(1e6))
    args = parser.parse_args()
    # Training env clips rewards (DQN convention); the test env keeps raw
    # rewards so evaluation scores are comparable to published results.
    env = wrap_dqn(gym.make(args.env_name))
    test_env = wrap_dqn(gym.make(args.env_name), reward_clipping=False)
    # Following parameters are equivalent to DeepMind DQN paper
    # https://www.nature.com/articles/nature14236
    policy = DQN(
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        lr=0.0000625,  # This value is from Rainbow
        adam_eps=1.5e-4,  # This value is from Rainbow
        n_warmup=50000,
        target_replace_interval=10000,
        batch_size=32,
        memory_capacity=args.memory_capacity,
        # NOTE(review): source truncated here — remaining DQN kwargs and the
        # Trainer construction/run are outside the visible fragment.
# NOTE(review): this fragment starts mid-script — `parser`, `env`, and the
# argument definitions backing `args.memory_capacity` / `args.auto_alpha`
# are created before the visible span; verify against the full file.
test_env = gym.make(args.env_name)
if is_atari_env(env):
    # Parameters come from Appendix.B in original paper.
    # See https://arxiv.org/abs/1910.07207
    parser.set_defaults(episode_max_steps=108000)  # 30 min of play at 60 Hz
    parser.set_defaults(test_interval=int(1e5))
    parser.set_defaults(show_test_images=True)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(target_update_interval=8000)
    parser.set_defaults(n_warmup=int(2e4))
    # Re-parse so the Atari-specific defaults take effect.
    args = parser.parse_args()
    if args.gpu == -1:
        print("Are you sure you're trying to solve Atari without GPU?")
    # Training env clips rewards; evaluation env keeps raw scores so results
    # stay comparable to published numbers.
    env = wrap_dqn(env, wrap_ndarray=True)
    test_env = wrap_dqn(test_env, wrap_ndarray=True, reward_clipping=False)
    policy = SACDiscrete(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        discount=0.99,
        # Atari-specific network constructors (CNN-based Q and actor heads).
        critic_fn=AtariQFunc,
        actor_fn=AtariCategoricalActor,
        lr=3e-4,
        memory_capacity=args.memory_capacity,
        batch_size=64,
        n_warmup=args.n_warmup,
        update_interval=4,  # one gradient step every 4 env steps
        target_update_interval=args.target_update_interval,
        auto_alpha=args.auto_alpha,  # automatic entropy-temperature tuning
        gpu=args.gpu)