def test_awrjoint_save_restore(self):
    """Check save and restore of joint AWR trainer.

    Trains one AWRJointTrainer for two epochs with `output_dir` set. It then
    builds a second trainer with identical settings on the same directory.
    The second trainer must continue from epoch/step 2 rather than from zero.
    """
    task = rl_task.RLTask('CartPole-v0', initial_trajectories=100,
                          max_steps=200)
    joint_model = functools.partial(
        models.PolicyAndValue,
        body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()),
    )
    tmp_dir = self.create_tempdir().full_path

    def make_trainer():
        # Both trainers are built from this single set of settings and share
        # `tmp_dir`. The restore check relies on their configs being identical.
        trainer = actor_critic_joint.AWRJointTrainer(
            task,
            joint_model=joint_model,
            optimizer=opt.Adam,
            batch_size=4,
            train_steps_per_epoch=1,
            n_trajectories_per_epoch=2,
            output_dir=tmp_dir)
        # Register close as a cleanup so the trainer is released even when an
        # assertion below fails.
        self.addCleanup(trainer.close)
        return trainer

    trainer1 = make_trainer()
    trainer1.run(2)
    self.assertEqual(trainer1.current_epoch, 2)
    self.assertEqual(trainer1._trainer.step, 2)

    # Trainer 2 starts where trainer 1 stopped.
    trainer2 = make_trainer()
    trainer2.run(1)
    self.assertEqual(trainer2.current_epoch, 3)
    self.assertEqual(trainer2._trainer.step, 3)
# NOTE: a duplicate `test_jointawrtrainer_cartpole` was removed from this spot.
# A later method with the same name in this class rebinds that attribute, so
# this copy was never collected or run (pylint: function-redefined). It also
# passed `shared_model` / `policy_top` / `value_top` to AWRJointTrainer, which
# no other test here uses. If separate policy/value heads need coverage, add a
# uniquely named test built on `joint_model=` like its siblings.
def test_jointawrtrainer_cartpole_transformer(self):
    """Smoke-test joint AWR on CartPole using a Transformer-decoder body.

    Runs a short training loop and checks that the trainer's epoch counter
    matches the number of epochs requested.
    """
    n_epochs = 2
    task = rl_task.RLTask('CartPole-v0', initial_trajectories=100,
                          max_steps=200)

    def transformer_body(mode):
        # Small decoder: one layer, one head, model/feed-forward width 32.
        return models.TransformerDecoder(
            d_model=32, d_ff=32, n_layers=1, n_heads=1, mode=mode)

    # NOTE(review): test_awrjoint_save_restore passes n_trajectories_per_epoch
    # where this test passes collect_per_epoch; confirm which keyword
    # AWRJointTrainer currently accepts.
    trainer = actor_critic_joint.AWRJointTrainer(
        task,
        joint_model=functools.partial(models.PolicyAndValue,
                                      body=transformer_body),
        optimizer=opt.Adam,
        batch_size=4,
        train_steps_per_epoch=2,
        collect_per_epoch=2,
        max_slice_length=128)
    trainer.run(n_epochs)
    self.assertEqual(n_epochs, trainer.current_epoch)
def test_jointawrtrainer_cartpole(self):
    """Smoke-test joint AWR on CartPole using a small dense body.

    Trains with a 'constant * linear_warmup' multifactor learning-rate
    schedule and checks that the trainer reports the requested epoch count.
    """
    n_epochs = 2
    task = rl_task.RLTask('CartPole-v0', initial_trajectories=100,
                          max_steps=200)

    def mlp_body(mode):
        del mode  # Unused: the dense body is built without a mode argument.
        return tl.Serial(tl.Dense(64), tl.Relu())

    def lr_schedule(h):
        return lr_schedules.MultifactorSchedule(
            h, constant=1e-2, warmup_steps=100,
            factors='constant * linear_warmup')

    # NOTE(review): test_awrjoint_save_restore passes n_trajectories_per_epoch
    # where this test passes collect_per_epoch; confirm which keyword
    # AWRJointTrainer currently accepts.
    trainer_kwargs = dict(
        joint_model=functools.partial(models.PolicyAndValue, body=mlp_body),
        optimizer=opt.Adam,
        lr_schedule=lr_schedule,
        batch_size=4,
        train_steps_per_epoch=2,
        collect_per_epoch=5,
    )
    trainer = actor_critic_joint.AWRJointTrainer(task, **trainer_kwargs)
    trainer.run(n_epochs)
    self.assertEqual(n_epochs, trainer.current_epoch)