Esempio n. 1
0
 def test_jointa2ctrainer_cartpole_transformer(self):
   """Smoke test: joint A2C on CartPole-v0 with a TransformerDecoder body.

   Builds a PolicyAndValue model around a single-layer, single-head
   TransformerDecoder, runs the joint trainer for two epochs and checks
   that the trainer's epoch counter reports two.
   """
   n_epochs = 2

   def transformer_body(mode):
     # Deliberately tiny decoder configuration (one layer, one head).
     return models.TransformerDecoder(
         d_model=32, d_ff=32, n_layers=1, n_heads=1, mode=mode)

   cartpole = rl_task.RLTask(
       'CartPole-v0', initial_trajectories=100, max_steps=200)
   policy_and_value = functools.partial(
       models.PolicyAndValue, body=transformer_body)
   trainer_kwargs = dict(
       joint_model=policy_and_value,
       optimizer=opt.RMSProp,
       batch_size=4,
       train_steps_per_epoch=2,
       collect_per_epoch=2,
   )
   a2c = actor_critic_joint.A2CJointTrainer(cartpole, **trainer_kwargs)

   a2c.run(n_epochs)
   self.assertEqual(n_epochs, a2c.current_epoch)
Esempio n. 2
0
 def test_jointa2ctrainer_cartpole(self):
   """Smoke test: joint A2C on CartPole-v0 with a Dense+Relu body.

   Uses a Serial(Dense(64), Relu()) body for PolicyAndValue and a
   MultifactorSchedule with factors 'constant * linear_warmup', runs the
   joint trainer for two epochs and checks that the trainer's epoch
   counter reports two.
   """
   n_epochs = 2

   def dense_body(mode):
     # `mode` is accepted to match the body-factory call but is unused.
     return tl.Serial(tl.Dense(64), tl.Relu())

   def warmup_schedule(h):
     return lr_schedules.MultifactorSchedule(
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup',
     )

   cartpole = rl_task.RLTask(
       'CartPole-v0', initial_trajectories=100, max_steps=200)
   policy_and_value = functools.partial(
       models.PolicyAndValue, body=dense_body)
   trainer_kwargs = dict(
       joint_model=policy_and_value,
       optimizer=opt.RMSProp,
       lr_schedule=warmup_schedule,
       batch_size=2,
       train_steps_per_epoch=1,
       collect_per_epoch=1,
   )
   a2c = actor_critic_joint.A2CJointTrainer(cartpole, **trainer_kwargs)

   a2c.run(n_epochs)
   self.assertEqual(n_epochs, a2c.current_epoch)