def test_sanity_awrtrainer_transformer_cartpole(self):
  """Smoke-tests the AWR trainer on CartPole with a tiny Transformer body."""
  # Tiny task: 2 seed trajectories, episodes capped at 2 steps.
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=2, max_steps=2)
  # Minimal Transformer decoder used as the shared body of both heads.
  make_body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
      d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode)
  policy_net = functools.partial(models.Policy, body=make_body)
  value_net = functools.partial(models.Value, body=make_body)
  # Warmed-up constant learning-rate schedule shared by both optimizers.
  make_lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.AWR(
      task,
      n_shared_layers=0,
      max_slice_length=2,
      added_policy_slice_length=1,
      value_model=value_net,
      value_optimizer=opt.Adam,
      value_lr_schedule=make_lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_net,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=make_lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=1,
      n_eval_episodes=1)
  # Two epochs is enough to exercise the full collect/train loop.
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)
def test_jointa2ctrainer_cartpole_transformer(self):
  """Smoke-tests the joint A2C trainer on CartPole with a Transformer body."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=100, max_steps=200)
  # Small Transformer decoder serving as the shared policy/value body.
  make_body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
      d_model=32, d_ff=32, n_layers=1, n_heads=1, mode=mode)
  shared_model = functools.partial(models.PolicyAndValue, body=make_body)
  trainer = actor_critic_joint.A2CJointTrainer(
      task,
      joint_model=shared_model,
      optimizer=opt.RMSProp,
      batch_size=4,
      train_steps_per_epoch=2,
      collect_per_epoch=2)
  # Two epochs exercises both the training and collection paths.
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)