def test_sanity_awrtrainer_transformer_cartpole(self):
  """Smoke-tests AWR on CartPole-v0 with minimal Transformer decoder bodies.

  Only checks that two epochs complete; no return threshold is asserted.
  """

  def body(mode):
    # Every size is kept at the minimum (1 layer, 1 head, width 2).
    return models.TransformerDecoder(
        d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode)

  def lr():
    return lr_schedules.multifactor(
        constant=1e-2, warmup_steps=100,
        factors='constant * linear_warmup')

  task = rl_task.RLTask('CartPole-v0', initial_trajectories=2, max_steps=2)

  # The value and policy networks are configured symmetrically.
  value_kwargs = dict(
      value_model=functools.partial(models.Value, body=body),
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
  )
  policy_kwargs = dict(
      policy_model=functools.partial(models.Policy, body=body),
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
  )

  trainer = actor_critic.AWRTrainer(
      task,
      n_shared_layers=0,
      max_slice_length=2,
      added_policy_slice_length=1,
      n_trajectories_per_epoch=1,
      n_eval_episodes=1,
      **value_kwargs,
      **policy_kwargs,
  )

  n_epochs = 2
  trainer.run(n_epochs)
  self.assertEqual(n_epochs, trainer.current_epoch)
def test_awrtrainer_cartpole(self):
  """Trains AWR for one epoch on CartPole-v0 and checks the average return.

  The policy and value networks do not share layers (n_shared_layers=0);
  the last average return must exceed 35.
  """
  # NOTE(review): a later definition with this same method name is visible
  # in this file; if both live in the same class, that one shadows this
  # test and it never runs -- confirm and rename one of them if so.

  def body(mode):
    del mode  # Unused.
    return tl.Serial(tl.Dense(64), tl.Relu())

  def lr(h):
    return lr_schedules.MultifactorSchedule(
        h, constant=1e-2, warmup_steps=100,
        factors='constant * linear_warmup')

  task = rl_task.RLTask(
      'CartPole-v0', initial_trajectories=1000, max_steps=200)

  value_kwargs = dict(
      value_model=functools.partial(models.Value, body=body),
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=32,
      value_train_steps_per_epoch=200,
  )
  policy_kwargs = dict(
      policy_model=functools.partial(models.Policy, body=body),
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=32,
      policy_train_steps_per_epoch=200,
  )

  trainer = actor_critic.AWRTrainer(
      task,
      n_shared_layers=0,
      added_policy_slice_length=1,
      n_trajectories_per_epoch=10,
      advantage_estimator=advantages.monte_carlo,
      advantage_normalization=False,
      **value_kwargs,
      **policy_kwargs,
  )

  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
  final_avg_return = trainer.avg_returns[-1]
  self.assertGreater(final_avg_return, 35.0)
def test_awrtrainer_cartpole(self):
  """Trains AWR for one epoch on CartPole-v0 with hand-built MLP models.

  The policy and value networks are built directly from layers rather than
  through model helpers; the last average return must exceed 180.
  """
  # NOTE(review): an earlier definition with this same method name is
  # visible in this file; if both live in the same class, this one replaces
  # it -- confirm and rename one of them if so.

  def policy_model(mode):
    del mode  # Unused.
    return tl.Serial(
        tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())

  def value_model(mode):
    del mode  # Unused.
    return tl.Serial(tl.Dense(64), tl.Relu(), tl.Dense(1))

  def lr(h):
    return lr_schedules.MultifactorSchedule(
        h, constant=1e-2, warmup_steps=100,
        factors='constant * linear_warmup')

  task = rl_task.RLTask(
      'CartPole-v0', initial_trajectories=1000, max_steps=200)

  value_kwargs = dict(
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=32,
      value_train_steps_per_epoch=1000,
  )
  policy_kwargs = dict(
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=32,
      policy_train_steps_per_epoch=1000,
  )

  trainer = actor_critic.AWRTrainer(
      task,
      n_shared_layers=0,
      collect_per_epoch=10,
      **value_kwargs,
      **policy_kwargs,
  )

  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
  final_avg_return = trainer.avg_returns[-1]
  self.assertGreater(final_avg_return, 180.0)
def test_awrtrainer_cartpole_shared(self):
  """Trains AWR on CartPole-v0 with one layer shared by policy and value.

  A fresh trainer is built and run for one epoch up to five times; the test
  passes as soon as one attempt reaches an average return above 35 and
  fails, reporting the best return seen, if none does.
  """

  def body(mode):
    del mode  # Unused.
    return tl.Serial(tl.Dense(64), tl.Relu())

  def lr(h):
    return lr_schedules.MultifactorSchedule(
        h, constant=1e-2, warmup_steps=100,
        factors='constant * linear_warmup')

  task = rl_task.RLTask(
      'CartPole-v0', initial_trajectories=1000, max_steps=200)

  trainer_kwargs = dict(
      n_shared_layers=1,
      added_policy_slice_length=1,
      value_model=functools.partial(models.Value, body=body),
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=32,
      value_train_steps_per_epoch=200,
      policy_model=functools.partial(models.Policy, body=body),
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=32,
      policy_train_steps_per_epoch=200,
      n_trajectories_per_epoch=10,
      advantage_estimator=advantages.monte_carlo,
      advantage_normalization=False,
  )

  # The test is flaky, so it is retried in-process: this is the simplest
  # retry mechanism available in the open-source build.
  max_avg_returns = -math.inf
  for _ in range(5):
    trainer = actor_critic.AWRTrainer(task, **trainer_kwargs)
    trainer.run(1)
    self.assertEqual(1, trainer.current_epoch)

    latest = trainer.avg_returns[-1]
    # Argument order matters: max() keeps its first argument unless a later
    # one is strictly greater, so the running maximum only survives when it
    # strictly exceeds the latest value.
    max_avg_returns = max(latest, max_avg_returns)
    if latest > 35.0:
      return

  self.fail(
      f'We did not reach a score > 35.0, max was {max_avg_returns}.')