def train_sac(ctxt=None):
    # `gym_env`, `n_eps`, `MyGymEnv`, `LoggedSAC` and `self` are resolved from
    # the enclosing scope in which this helper is defined.
    trainer = Trainer(ctxt)
    env = MyGymEnv(gym_env, max_episode_length=100)
    policy = CategoricalGRUPolicy(name='policy',
                                  env_spec=env.spec,
                                  state_include_action=False).to(global_device())
    qf1 = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    qf2 = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    sampler = LocalSampler(agents=policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker)
    self.algo = LoggedSAC(env=env,
                          env_spec=env.spec,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          sampler=sampler,
                          gradient_steps_per_itr=1000,
                          max_episode_length_eval=100,
                          replay_buffer=replay_buffer,
                          min_buffer_size=1e4,
                          target_update_tau=5e-3,
                          discount=0.99,
                          buffer_batch_size=256,
                          reward_scale=1.,
                          steps_per_epoch=1)
    trainer.setup(self.algo, env)
    trainer.train(n_epochs=n_eps, batch_size=4000)
    return self.algo.rew_chkpts
def setup():
    set_seed(24)
    n_epochs = 11
    steps_per_epoch = 10
    sampler_batch_size = 512
    num_timesteps = 100 * steps_per_epoch * sampler_batch_size
    env = GymEnv('CartPole-v0')
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=num_timesteps,
                                             max_epsilon=1.0,
                                             min_epsilon=0.01,
                                             decay_ratio=0.4)
    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
               target_update_freq=30,
               buffer_batch_size=64)
    return algo, env, replay_buffer, n_epochs, sampler_batch_size
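# The setup() helper above only wires the DQN components together; a minimal
# smoke test built on top of it might look like the sketch below.  The garage
# imports used throughout this file are assumed to be in scope, and the
# checked attribute names (action_space.n, n_transitions_stored) are assumed
# garage/akro API, not part of the original test file.
def test_dqn_setup_smoke():
    algo, env, replay_buffer, n_epochs, sampler_batch_size = setup()
    assert n_epochs == 11
    assert sampler_batch_size == 512
    assert env.spec.action_space.n == 2  # CartPole-v0 has two discrete actions
    assert replay_buffer.n_transitions_stored == 0  # buffer starts empty
    env.close()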
def test_forward(batch_size):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones([batch_size, obs_dim], dtype=torch.float32)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=(2, 2))
    qvals = qf(obs)
    policy = DiscreteQFArgmaxPolicy(qf, env_spec)
    assert (policy(obs) == torch.argmax(qvals, dim=1)).all()
    assert policy(obs).shape == (batch_size, )
def test_output_shape(batch_size, hidden_sizes):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones(batch_size, obs_dim, dtype=torch.float32)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=hidden_sizes,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_)
    output = qf(obs)
    assert output.shape == (batch_size, env_spec.action_space.flat_dim)
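# The two tests above receive `batch_size` / `hidden_sizes` from pytest.  The
# sketch below shows that parametrization pattern on a self-contained check;
# the value sets and the test name are illustrative assumptions, not taken
# from the original test file.
import pytest


@pytest.mark.parametrize('batch_size, hidden_sizes', [(1, (1, )),
                                                      (5, (2, 2)),
                                                      (32, (4, 4))])
def test_qvals_batch_dim(batch_size, hidden_sizes):
    """Q-values keep one row per observation for any hidden layout."""
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs = torch.ones(batch_size,
                     env_spec.observation_space.flat_dim,
                     dtype=torch.float32)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=hidden_sizes)
    assert qf(obs).shape[0] == batch_size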
def test_get_action():
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones([obs_dim], dtype=torch.float32)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=(2, 2))
    qvals = qf(obs.unsqueeze(0))
    policy = DiscreteQFArgmaxPolicy(qf, env_spec)
    action, _ = policy.get_action(obs.numpy())
    assert action == torch.argmax(qvals, dim=1).numpy()
    assert action.shape == ()
def test_is_pickleable(batch_size):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones([batch_size, obs_dim], dtype=torch.float32)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=(2, 2))
    policy = DiscreteQFArgmaxPolicy(qf, env_spec)
    output1 = policy.get_actions(obs.numpy())[0]
    p = pickle.dumps(policy)
    policy_pickled = pickle.loads(p)
    output2 = policy_pickled.get_actions(obs.numpy())[0]
    assert np.array_equal(output1, output2)
def dqn_cartpole(ctxt=None, seed=24):
    """Train DQN with CartPole-v0 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    runner = Trainer(ctxt)

    n_epochs = 100
    steps_per_epoch = 10
    sampler_batch_size = 512
    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

    env = GymEnv('CartPole-v0')
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=num_timesteps,
                                             max_epsilon=1.0,
                                             min_epsilon=0.01,
                                             decay_ratio=0.4)
    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker)
    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               sampler=sampler,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
               target_update_freq=30,
               buffer_batch_size=64)

    runner.setup(algo, env)
    runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
    env.close()
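# Launcher sketch for the example above: in garage, `ctxt` is normally
# supplied by wrapping the function with wrap_experiment, which creates the
# experiment directory and snapshotter before invoking it.  This entry point
# is an assumption, not part of the original example file.
from garage import wrap_experiment

if __name__ == '__main__':
    wrap_experiment(dqn_cartpole)(seed=24)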
def test_forward(hidden_sizes):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones(obs_dim, dtype=torch.float32).unsqueeze(0)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=hidden_sizes,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_)
    output = qf(obs)
    expected_output = torch.full([1, 1],
                                 fill_value=obs_dim * np.prod(hidden_sizes),
                                 dtype=torch.float32)
    assert torch.eq(output, expected_output).all()
def test_is_pickleable(hidden_sizes):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones(obs_dim, dtype=torch.float32).unsqueeze(0)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=hidden_sizes,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_)
    output1 = qf(obs)
    p = pickle.dumps(qf)
    qf_pickled = pickle.loads(p)
    output2 = qf_pickled(obs)
    assert torch.eq(output1, output2).all()
def train_dqn(ctxt=None):
    # `seed`, `gym_env`, `n_eps`, `MyGymEnv`, `LoggedDQN` and `self` are
    # resolved from the enclosing scope in which this helper is defined.
    set_seed(seed)
    trainer = Trainer(ctxt)
    env = MyGymEnv(gym_env, max_episode_length=100)
    steps_per_epoch = 10
    sampler_batch_size = 4000
    num_timesteps = n_eps * steps_per_epoch * sampler_batch_size
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=num_timesteps,
                                             max_epsilon=1.0,
                                             min_epsilon=0.01,
                                             decay_ratio=0.4)
    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker)
    self.algo = LoggedDQN(env=env,
                          env_spec=env.spec,
                          policy=policy,
                          qf=qf,
                          exploration_policy=exploration_policy,
                          replay_buffer=replay_buffer,
                          sampler=sampler,
                          steps_per_epoch=steps_per_epoch,
                          qf_lr=5e-5,
                          discount=0.99,
                          min_buffer_size=int(1e4),
                          n_train_steps=500,
                          target_update_freq=30,
                          buffer_batch_size=64)
    trainer.setup(self.algo, env)
    trainer.train(n_epochs=n_eps, batch_size=sampler_batch_size)
    return self.algo.rew_chkpts
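# train_dqn and train_sac above reference names (`self`, `seed`, `gym_env`,
# `n_eps`, `MyGymEnv`, `LoggedDQN`, `LoggedSAC`) that only resolve when the
# helpers are nested inside an enclosing test/benchmark method.  The skeleton
# below is a hedged sketch of that arrangement; the class name, method name
# and default values are assumptions, not taken from the original file.
class DQNBenchmark:

    def run(self, gym_env, seed=24, n_eps=10):
        self.algo = None

        def train_dqn(ctxt=None):
            ...  # body as defined above, closing over self, seed, gym_env, n_eps

        # wrap_experiment supplies the `ctxt` argument when the helper runs.
        wrap_experiment(train_dqn)()
        # The helper stores the trained algorithm on self.algo, so its logged
        # reward checkpoints can be returned to the caller.
        return self.algo.rew_chkpts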