def test_sanity_a2ctrainer_cartpole(self):
  """Test-runs A2C on CartPole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=2)
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.A2C(
      task,
      n_shared_layers=1,  # Share the first layer between policy and value.
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)

def test_a2ctrainer_save_restore(self):
  """Checks saving and restoring of the A2C trainer."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=20)
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  tmp_dir = self.create_tempdir().full_path
  trainer1 = actor_critic.A2C(
      task,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_batch_size=2,
      value_train_steps_per_epoch=1,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2,
      n_shared_layers=1,
      output_dir=tmp_dir)
  trainer1.run(2)
  self.assertEqual(trainer1.current_epoch, 2)
  # 2 epochs x 1 value step each, and 2 epochs x 2 policy steps each.
  self.assertEqual(trainer1._value_trainer.step, 2)
  self.assertEqual(trainer1._policy_trainer.step, 4)
  # Trainer 2 restores from the checkpoint in output_dir, so it starts
  # where trainer 1 stopped.
  trainer2 = actor_critic.A2C(
      task,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_batch_size=2,
      value_train_steps_per_epoch=1,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2,
      n_shared_layers=1,
      output_dir=tmp_dir)
  trainer2.run(1)
  self.assertEqual(trainer2.current_epoch, 3)
  self.assertEqual(trainer2._value_trainer.step, 3)
  self.assertEqual(trainer2._policy_trainer.step, 6)
  trainer1.close()
  trainer2.close()
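
# The methods above assume the usual module-level scaffolding of this test
# file. A minimal sketch is given below; module paths follow the Trax
# repository layout and may differ across versions, so treat this as an
# assumption rather than the file's exact header:
#
#   import functools
#   from absl.testing import absltest
#   from trax import layers as tl
#   from trax import lr_schedules
#   from trax import models
#   from trax import optimizers as opt
#   from trax.rl import actor_critic
#   from trax.rl import task as rl_task
#
#   class ActorCriticTest(absltest.TestCase):
#     ...
#
# and the standard absltest entry point at the end of the file:
#
#   if __name__ == '__main__':
#     absltest.main()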