def test_forward(batch_size, hidden_channels, kernel_sizes, strides):
    """Output of an all-ones network on a zero observation is all zeros.

    With ``hidden_w_init`` and ``output_w_init`` forcing every weight to one
    and all nonlinearities disabled, a zero input must propagate to a zero
    Q-value for every action, and the output must have one Q-value per
    discrete action.
    """
    env_spec = GymEnv(DummyBoxEnv(obs_dim=(3, 10, 10))).spec
    obs_shape = env_spec.observation_space.shape
    zero_obs = torch.zeros((batch_size, ) + obs_shape, dtype=torch.float32)
    qf = DiscreteCNNQFunction(env_spec=env_spec,
                              kernel_sizes=kernel_sizes,
                              strides=strides,
                              mlp_hidden_nonlinearity=None,
                              cnn_hidden_nonlinearity=None,
                              hidden_channels=hidden_channels,
                              hidden_sizes=hidden_channels,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_,
                              is_image=False)
    q_values = qf(zero_obs)
    # One Q-value per action, for each observation in the batch.
    assert q_values.shape == (batch_size, env_spec.action_space.flat_dim)
    # Linear layers with ones weights map a zero input to a zero output.
    assert torch.eq(q_values, torch.zeros(q_values.shape)).all()
def test_is_pickleable(batch_size, hidden_channels, kernel_sizes, strides):
    """A pickled-and-restored Q-function produces identical outputs.

    Builds a deterministic network (all-ones weights, no nonlinearities),
    evaluates it on a fixed all-ones observation batch, round-trips the
    module through ``pickle``, and checks the restored copy agrees exactly.
    """
    env_spec = GymEnv(DummyBoxEnv(obs_dim=(3, 10, 10))).spec
    obs_shape = env_spec.observation_space.shape
    ones_obs = torch.ones((batch_size, ) + obs_shape, dtype=torch.float32)
    qf = DiscreteCNNQFunction(env_spec=env_spec,
                              kernel_sizes=kernel_sizes,
                              strides=strides,
                              mlp_hidden_nonlinearity=None,
                              cnn_hidden_nonlinearity=None,
                              hidden_channels=hidden_channels,
                              hidden_sizes=hidden_channels,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_,
                              is_image=False)
    before = qf(ones_obs)
    restored = pickle.loads(pickle.dumps(qf))
    after = restored(ones_obs)
    # The round-trip must preserve the learned (here: deterministic) mapping.
    assert torch.eq(before, after).all()
def dqn_atari(ctxt=None,
              env=None,
              seed=24,
              n_workers=psutil.cpu_count(logical=False),
              max_episode_length=None,
              **kwargs):
    """Train DQN on an Atari environment.

    Applies the standard DeepMind Atari preprocessing stack (no-op starts,
    frame max-and-skip, episodic life, fire reset, grayscale, 84x84 resize,
    reward clipping, 4-frame stacking) before wrapping the environment in
    ``GymEnv``.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        env (str): Name of the atari environment,
            eg. 'PongNoFrameskip-v4'.
        seed (int): Used to seed the random number generator to produce
            determinism.
        n_workers (int): Number of workers to use. Defaults to the number
            of CPU cores available.
        max_episode_length (int): Max length of an episode. If None,
            defaults to the timelimit specific to the environment. Used
            by integration tests.
        kwargs (dict): hyperparameters to be saved to variant.json.

    Raises:
        ValueError: If ``env`` is not given or ``n_workers`` is not a
            positive integer.
    """
    # Validate inputs with explicit exceptions rather than `assert`, which
    # is silently stripped when Python runs with the -O flag.
    if env is None:
        raise ValueError('env must be the name of an Atari environment, '
                         "eg. 'PongNoFrameskip-v4'")
    if n_workers <= 0:
        raise ValueError('n_workers must be a positive integer, got '
                         f'{n_workers}')

    # Standard Atari preprocessing pipeline (order matters).
    env = gym.make(env)
    env = Noop(env, noop_max=30)
    env = MaxAndSkip(env, skip=4)
    env = EpisodicLife(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireReset(env)
    env = Grayscale(env)
    env = Resize(env, 84, 84)
    env = ClipReward(env)
    env = StackFrames(env, 4, axis=0)

    env = GymEnv(env, max_episode_length=max_episode_length, is_image=True)
    # NOTE(review): the seed is set after environment construction, so any
    # randomness during env creation is not covered by it — confirm intended.
    set_seed(seed)
    trainer = Trainer(ctxt)

    # `hyperparams` is a module-level dict defined elsewhere in this file.
    n_epochs = hyperparams['n_epochs']
    steps_per_epoch = hyperparams['steps_per_epoch']
    sampler_batch_size = hyperparams['sampler_batch_size']
    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

    replay_buffer = PathBuffer(
        capacity_in_transitions=hyperparams['buffer_size'])

    qf = DiscreteCNNQFunction(
        env_spec=env.spec,
        image_format='NCHW',
        hidden_channels=hyperparams['hidden_channels'],
        kernel_sizes=hyperparams['kernel_sizes'],
        strides=hyperparams['strides'],
        # Orthogonal init with sqrt(2) gain for the convolutional/hidden
        # layers.
        hidden_w_init=(
            lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))),
        hidden_sizes=hyperparams['hidden_sizes'])

    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    # Epsilon-greedy exploration annealed over the full training run.
    exploration_policy = EpsilonGreedyPolicy(
        env_spec=env.spec,
        policy=policy,
        total_timesteps=num_timesteps,
        max_epsilon=hyperparams['max_epsilon'],
        min_epsilon=hyperparams['min_epsilon'],
        decay_ratio=hyperparams['decay_ratio'])

    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker,
                           n_workers=n_workers)

    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               sampler=sampler,
               steps_per_epoch=steps_per_epoch,
               qf_lr=hyperparams['lr'],
               clip_gradient=hyperparams['clip_gradient'],
               discount=hyperparams['discount'],
               min_buffer_size=hyperparams['min_buffer_size'],
               n_train_steps=hyperparams['n_train_steps'],
               target_update_freq=hyperparams['target_update_freq'],
               buffer_batch_size=hyperparams['buffer_batch_size'])

    # Default to CPU with a single thread; upgrade to GPU when available.
    set_gpu_mode(False)
    torch.set_num_threads(1)
    if torch.cuda.is_available():
        set_gpu_mode(True)
    algo.to()

    trainer.setup(algo, env)
    trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
    env.close()