Ejemplo n.º 1
0
 def test_awrjoint_save_restore(self):
     """Check save and restore of joint AWR trainer.

     A first trainer is run with ``run(2)`` while ``output_dir`` points at a
     temporary directory. A second trainer, built with the same arguments and
     the same ``output_dir``, is then run with ``run(1)`` and is expected to
     report epoch 3 / step 3, i.e. to continue from the first trainer's
     counters rather than start from zero.
     """
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=2,
                           max_steps=2)
     joint_model = functools.partial(
         models.PolicyAndValue,
         body=lambda mode: tl.Serial(tl.Dense(4), tl.Relu()),
     )
     tmp_dir = self.create_tempdir().full_path

     def make_trainer():
         # Both trainers must be configured identically and share `task` and
         # `tmp_dir`; a single factory keeps the two call sites in sync.
         return actor_critic_joint.AWRJoint(task,
                                            joint_model=joint_model,
                                            optimizer=opt.Adam,
                                            batch_size=4,
                                            train_steps_per_epoch=1,
                                            n_trajectories_per_epoch=2,
                                            output_dir=tmp_dir)

     trainer1 = make_trainer()
     # Register the close right away so the trainer is released even when a
     # later run() call or assertion fails (cleanups run in LIFO order).
     self.addCleanup(trainer1.close)
     trainer1.run(2)
     self.assertEqual(trainer1.current_epoch, 2)
     self.assertEqual(trainer1._trainer.step, 2)
     # Agent 2 starts where agent 1 stopped.
     trainer2 = make_trainer()
     self.addCleanup(trainer2.close)
     trainer2.run(1)
     self.assertEqual(trainer2.current_epoch, 3)
     self.assertEqual(trainer2._trainer.step, 3)
Ejemplo n.º 2
0
 def test_jointawrtrainer_cartpole_transformer(self):
   """Runs joint AWR on CartPole with a Transformer decoder as the body.

   Calls ``run(2)`` on the trainer and checks that ``current_epoch`` is 2.
   """

   def decoder_body(mode):
     # Minimal decoder configuration: one layer, one head, width 4.
     return models.TransformerDecoder(
         d_model=4, d_ff=4, n_layers=1, n_heads=1, mode=mode)

   cartpole = rl_task.RLTask(
       'CartPole-v0', initial_trajectories=1, max_steps=2)
   policy_and_value = functools.partial(
       models.PolicyAndValue, body=decoder_body)
   agent = actor_critic_joint.AWRJoint(
       cartpole,
       joint_model=policy_and_value,
       optimizer=opt.Adam,
       batch_size=4,
       train_steps_per_epoch=2,
       n_trajectories_per_epoch=2,
       max_slice_length=2,
   )
   agent.run(2)
   self.assertEqual(2, agent.current_epoch)
Ejemplo n.º 3
0
 def test_jointawrtrainer_cartpole(self):
   """Runs joint AWR on CartPole with a dense body and a warmup LR schedule.

   Calls ``run(2)`` on the trainer and checks that ``current_epoch`` is 2.
   """

   def dense_body(mode):
     # `mode` is accepted to match the body signature but is not used here.
     return tl.Serial(tl.Dense(64), tl.Relu())

   def warmup_schedule():
     # Constant factor 1e-2 combined with a 100-step linear warmup.
     return lr_schedules.multifactor(
         constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')

   cartpole = rl_task.RLTask(
       'CartPole-v0', initial_trajectories=1, max_steps=2)
   policy_and_value = functools.partial(
       models.PolicyAndValue, body=dense_body)
   agent = actor_critic_joint.AWRJoint(
       cartpole,
       joint_model=policy_and_value,
       optimizer=opt.Adam,
       lr_schedule=warmup_schedule,
       batch_size=4,
       train_steps_per_epoch=2,
       n_trajectories_per_epoch=5,
   )
   agent.run(2)
   self.assertEqual(2, agent.current_epoch)