Example #1
def test_sanity_a2ctrainer_cartpole(self):
    """Test-runs A2C on CartPole."""
    task = rl_task.RLTask(
        'CartPole-v0', initial_trajectories=0, max_steps=2)
    # Small shared body; the policy and value heads are built on top of it.
    body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
    policy_model = functools.partial(models.Policy, body=body)
    value_model = functools.partial(models.Value, body=body)
    # Constant 1e-4 learning rate with a linear warmup over 100 steps.
    lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
        h,
        constant=1e-4,
        warmup_steps=100,
        factors='constant * linear_warmup')
    trainer = actor_critic.A2CTrainer(
        task,
        n_shared_layers=1,
        value_model=value_model,
        value_optimizer=opt.Adam,
        value_lr_schedule=lr,
        value_batch_size=2,
        value_train_steps_per_epoch=2,
        policy_model=policy_model,
        policy_optimizer=opt.Adam,
        policy_lr_schedule=lr,
        policy_batch_size=2,
        policy_train_steps_per_epoch=2,
        n_trajectories_per_epoch=2)
    # Two epochs of training; each epoch collects 2 trajectories.
    trainer.run(2)
    self.assertEqual(2, trainer.current_epoch)
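
Both examples are test methods taken from Trax's actor-critic test suite, so the module-level imports are not shown. A minimal sketch of the imports they assume, based on the Trax package layout from the era of MultifactorSchedule (module paths may have moved in later Trax releases):

# Imports assumed by both snippets; exact paths may differ across Trax versions.
import functools

from trax import layers as tl
from trax import lr_schedules
from trax import models
from trax import optimizers as opt
from trax.rl import actor_critic
from trax.rl import task as rl_task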
Example #2
 def test_a2ctrainer_save_restore(self):
   """Check save and restore of A2C trainer."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=0,
                         max_steps=20)
   body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
   policy_model = functools.partial(models.Policy, body=body)
   value_model = functools.partial(models.Value, body=body)
    # A shared output_dir lets the second trainer restore the first one's checkpoint.
    tmp_dir = self.create_tempdir().full_path
   trainer1 = actor_critic.A2CTrainer(
       task,
       value_model=value_model,
       value_optimizer=opt.Adam,
       value_batch_size=2,
       value_train_steps_per_epoch=1,
       policy_model=policy_model,
       policy_optimizer=opt.Adam,
       policy_batch_size=2,
       policy_train_steps_per_epoch=2,
       n_trajectories_per_epoch=2,
       n_shared_layers=1,
       output_dir=tmp_dir)
    trainer1.run(2)
    # After 2 epochs: the value trainer did 2 * 1 = 2 steps and the
    # policy trainer did 2 * 2 = 4 steps (steps per epoch set above).
    self.assertEqual(trainer1.current_epoch, 2)
    self.assertEqual(trainer1._value_trainer.step, 2)
    self.assertEqual(trainer1._policy_trainer.step, 4)
   # Trainer 2 starts where trainer 1 stopped.
   trainer2 = actor_critic.A2CTrainer(
       task,
       value_model=value_model,
       value_optimizer=opt.Adam,
       value_batch_size=2,
       value_train_steps_per_epoch=1,
       policy_model=policy_model,
       policy_optimizer=opt.Adam,
       policy_batch_size=2,
       policy_train_steps_per_epoch=2,
       n_trajectories_per_epoch=2,
       n_shared_layers=1,
       output_dir=tmp_dir)
    trainer2.run(1)
    # One more epoch on top of the restored checkpoint: epoch 3,
    # value step 2 + 1 = 3, policy step 4 + 2 = 6.
    self.assertEqual(trainer2.current_epoch, 3)
    self.assertEqual(trainer2._value_trainer.step, 3)
    self.assertEqual(trainer2._policy_trainer.step, 6)
   trainer1.close()
   trainer2.close()
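
The methods also rely on an enclosing absltest.TestCase, which is where self.assertEqual and self.create_tempdir come from. A minimal harness sketch for running them as a standalone test module (the class name here is illustrative, not the one used in Trax):

from absl.testing import absltest


class A2CTrainerExamplesTest(absltest.TestCase):
    """Would hold the two test methods shown above."""


if __name__ == '__main__':
    absltest.main()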