예제 #1
0
 def test_sanity_awrtrainer_transformer_cartpole(self):
   """Checks that AWR runs two epochs on CartPole with a Transformer body.

   Every size below (trajectories, steps, model widths, batches, train steps)
   is 1 or 2, so this is a quick sanity run rather than a learning test.
   """
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=2, max_steps=2)

   def transformer_body(mode):
     return models.TransformerDecoder(
         d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode)

   def make_lr_schedule():
     # Handed to the trainer uncalled: it receives this zero-arg factory,
     # not a schedule instance.
     return lr_schedules.multifactor(
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')

   trainer = actor_critic.AWR(
       task,
       n_shared_layers=0,
       max_slice_length=2,
       added_policy_slice_length=1,
       # Value network: same body factory and optimizer as the policy below.
       value_model=functools.partial(models.Value, body=transformer_body),
       value_optimizer=opt.Adam,
       value_lr_schedule=make_lr_schedule,
       value_batch_size=2,
       value_train_steps_per_epoch=2,
       # Policy network.
       policy_model=functools.partial(models.Policy, body=transformer_body),
       policy_optimizer=opt.Adam,
       policy_lr_schedule=make_lr_schedule,
       policy_batch_size=2,
       policy_train_steps_per_epoch=2,
       n_trajectories_per_epoch=1,
       n_eval_episodes=1)

   n_epochs = 2
   trainer.run(n_epochs)
   self.assertEqual(n_epochs, trainer.current_epoch)
예제 #2
0
 def test_jointa2ctrainer_cartpole_transformer(self):
   """Checks that joint A2C runs two epochs on CartPole with a Transformer.

   Policy and value come from a single joint model (`models.PolicyAndValue`)
   whose body is a one-layer, one-head Transformer decoder.
   """
   task = rl_task.RLTask(
       'CartPole-v0', initial_trajectories=100, max_steps=200)

   def transformer_body(mode):
     return models.TransformerDecoder(
         d_model=32, d_ff=32, n_layers=1, n_heads=1, mode=mode)

   trainer = actor_critic_joint.A2CJointTrainer(
       task,
       joint_model=functools.partial(
           models.PolicyAndValue, body=transformer_body),
       optimizer=opt.RMSProp,
       batch_size=4,
       train_steps_per_epoch=2,
       collect_per_epoch=2)

   n_epochs = 2
   trainer.run(n_epochs)
   self.assertEqual(n_epochs, trainer.current_epoch)