def test_sampling_awrtrainer_mountain_car(self):
  """Test-runs Sampling AWR on MountainCarContinuous."""
  task = rl_task.RLTask(
      'MountainCarContinuous-v0', initial_trajectories=0, max_steps=2)
  # Share the same tiny body between the policy and value networks.
  body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.SamplingAWRTrainer(
      task,
      n_shared_layers=0,
      added_policy_slice_length=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2,
      advantage_estimator=advantages.monte_carlo,
      advantage_normalization=False,
      q_value_n_samples=3,
  )
  # Run a single epoch as a smoke test and check the epoch counter advanced.
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
def test_sampling_awrtrainer_cartpole_sample_all_discrete(self):
  """Test-runs Sampling AWR on CartPole, sampling all discrete actions."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=20)
  # Share the same tiny body between the policy and value networks.
  body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.SamplingAWRTrainer(
      task,
      n_shared_layers=0,
      added_policy_slice_length=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2,
      advantage_estimator=advantages.monte_carlo,
      advantage_normalization=False,
      q_value_n_samples=2,  # Equals CartPole's n_actions, so all are sampled.
      q_value_aggregate_max=True,
      reweight=False,
  )
  # Run a single epoch as a smoke test and check the epoch counter advanced.
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)