Esempio n. 1
0
 def test_awrjoint_save_restore(self):
     """Check save and restore of joint AWR trainer.

     A first trainer runs two epochs with a temporary output directory. A
     second trainer, built with identical settings and the same output
     directory, is expected to continue where the first stopped. After one
     more epoch, both its epoch counter and its inner trainer step must be 3.
     """
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=100,
                           max_steps=200)
     joint_model = functools.partial(
         models.PolicyAndValue,
         body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()),
     )
     tmp_dir = self.create_tempdir().full_path

     def make_trainer():
       # Both trainers must share every hyperparameter and the output
       # directory; building them in one place keeps the configs identical.
       trainer = actor_critic_joint.AWRJointTrainer(
           task,
           joint_model=joint_model,
           optimizer=opt.Adam,
           batch_size=4,
           train_steps_per_epoch=1,
           n_trajectories_per_epoch=2,
           output_dir=tmp_dir)
       # Close via a cleanup so the trainer is released even if an assertion
       # below fails (previously close() ran only on the success path).
       self.addCleanup(trainer.close)
       return trainer

     trainer1 = make_trainer()
     trainer1.run(2)
     self.assertEqual(trainer1.current_epoch, 2)
     self.assertEqual(trainer1._trainer.step, 2)  # pylint: disable=protected-access
     # Trainer 2 starts where trainer 1 stopped.
     trainer2 = make_trainer()
     trainer2.run(1)
     self.assertEqual(trainer2.current_epoch, 3)
     self.assertEqual(trainer2._trainer.step, 3)  # pylint: disable=protected-access
Esempio n. 2
0
 def test_jointawrtrainer_cartpole(self):
     """Test-runs joint AWR on cartpole.

     The trainer is assembled from three parts:
     - a shared 64-unit Dense+ReLU torso;
     - a 2-unit Dense + LogSoftmax policy head;
     - a 1-unit Dense value head.
     It runs a single epoch, and the test checks that the epoch counter
     reads 1.
     """
     cartpole = rl_task.RLTask('CartPole-v0',
                               initial_trajectories=1000,
                               max_steps=200)

     def torso(mode):  # pylint: disable=unused-argument
       return tl.Serial(tl.Dense(64), tl.Relu())

     def policy_head(mode):  # pylint: disable=unused-argument
       return tl.Serial(tl.Dense(2), tl.LogSoftmax())

     def value_head(mode):  # pylint: disable=unused-argument
       return tl.Dense(1)

     def schedule(h):
       # Constant 1e-2 scaled by a linear warmup over 100 steps, per the
       # factor string passed to MultifactorSchedule.
       return lr_schedules.MultifactorSchedule(
           h, constant=1e-2, warmup_steps=100,
           factors='constant * linear_warmup')

     agent = actor_critic_joint.AWRJointTrainer(
         cartpole,
         shared_model=torso,
         policy_top=policy_head,
         value_top=value_head,
         optimizer=opt.Adam,
         lr_schedule=schedule,
         batch_size=32,
         train_steps_per_epoch=1000,
         collect_per_epoch=10)
     agent.run(1)
     self.assertEqual(1, agent.current_epoch)
Esempio n. 3
0
 def test_jointawrtrainer_cartpole_transformer(self):
   """Test-runs joint AWR on cartpole with Transformer.

   The policy/value model uses a one-layer, one-head Transformer decoder
   (d_model=32, d_ff=32) as its body. The test runs two epochs with slices
   of up to 128 steps and checks that the epoch counter reads 2.
   """
   cartpole = rl_task.RLTask(
       'CartPole-v0', initial_trajectories=100, max_steps=200)

   def decoder_body(mode):
     # Deliberately tiny decoder, presumably to keep the test fast.
     return models.TransformerDecoder(
         d_model=32, d_ff=32, n_layers=1, n_heads=1, mode=mode)

   agent = actor_critic_joint.AWRJointTrainer(
       cartpole,
       joint_model=functools.partial(models.PolicyAndValue,
                                     body=decoder_body),
       optimizer=opt.Adam,
       batch_size=4,
       train_steps_per_epoch=2,
       collect_per_epoch=2,
       max_slice_length=128)
   agent.run(2)
   self.assertEqual(2, agent.current_epoch)
Esempio n. 4
0
 def test_jointawrtrainer_cartpole(self):
   """Test-runs joint AWR on cartpole.

   The policy/value model wraps a 64-unit Dense+ReLU body. The learning
   rate uses a warmup schedule. The test runs two epochs and checks that
   the epoch counter reads 2.
   """
   cartpole = rl_task.RLTask(
       'CartPole-v0', initial_trajectories=100, max_steps=200)

   def dense_body(mode):  # pylint: disable=unused-argument
     return tl.Serial(tl.Dense(64), tl.Relu())

   def warmup_schedule(h):
     # Constant 1e-2 scaled by a linear warmup over 100 steps, per the
     # factor string passed to MultifactorSchedule.
     return lr_schedules.MultifactorSchedule(
         h, constant=1e-2, warmup_steps=100,
         factors='constant * linear_warmup')

   agent = actor_critic_joint.AWRJointTrainer(
       cartpole,
       joint_model=functools.partial(models.PolicyAndValue, body=dense_body),
       optimizer=opt.Adam,
       lr_schedule=warmup_schedule,
       batch_size=4,
       train_steps_per_epoch=2,
       collect_per_epoch=5)
   agent.run(2)
   self.assertEqual(2, agent.current_epoch)