def test_awrjoint_save_restore(self):
  """Check save and restore of joint AWR trainer."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=2, max_steps=2)
  joint_model = functools.partial(
      models.PolicyAndValue,
      body=lambda mode: tl.Serial(tl.Dense(4), tl.Relu()),
  )
  tmp_dir = self.create_tempdir().full_path
  trainer1 = actor_critic_joint.AWRJoint(
      task, joint_model=joint_model, optimizer=opt.Adam, batch_size=4,
      train_steps_per_epoch=1, n_trajectories_per_epoch=2,
      output_dir=tmp_dir)
  trainer1.run(2)
  self.assertEqual(trainer1.current_epoch, 2)
  self.assertEqual(trainer1._trainer.step, 2)
  # Agent 2 starts where agent 1 stopped: it shares output_dir with agent 1,
  # so it restores from the checkpoint written there.
  trainer2 = actor_critic_joint.AWRJoint(
      task, joint_model=joint_model, optimizer=opt.Adam, batch_size=4,
      train_steps_per_epoch=1, n_trajectories_per_epoch=2,
      output_dir=tmp_dir)
  trainer2.run(1)
  self.assertEqual(trainer2.current_epoch, 3)
  self.assertEqual(trainer2._trainer.step, 3)
  trainer1.close()
  trainer2.close()
def test_jointawrtrainer_cartpole_transformer(self):
  """Test-runs joint AWR on cartpole with Transformer."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, max_steps=2)
  body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
      d_model=4, d_ff=4, n_layers=1, n_heads=1, mode=mode)
  joint_model = functools.partial(models.PolicyAndValue, body=body)
  trainer = actor_critic_joint.AWRJoint(
      task, joint_model=joint_model, optimizer=opt.Adam, batch_size=4,
      train_steps_per_epoch=2, n_trajectories_per_epoch=2,
      max_slice_length=2)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)
def test_jointawrtrainer_cartpole(self):
  """Test-runs joint AWR on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, max_steps=2)
  joint_model = functools.partial(
      models.PolicyAndValue,
      body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()),
  )
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic_joint.AWRJoint(
      task, joint_model=joint_model, optimizer=opt.Adam, lr_schedule=lr,
      batch_size=4, train_steps_per_epoch=2, n_trajectories_per_epoch=5)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)