Example #1
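All four snippets below are test methods that exercise trax.rl.actor_critic.AWRTrainer on CartPole, so self refers to the enclosing test case. They assume roughly the imports sketched here; the exact module paths (notably for rl_task and lr_schedules) have moved between Trax versions, so treat this as an illustrative guess rather than the original file header:

import functools
import math

from trax import layers as tl
from trax import models
from trax import optimizers as opt
from trax.rl import actor_critic
from trax.rl import advantages
from trax.rl import task as rl_task
from trax.supervised import lr_schedules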
 def test_sanity_awrtrainer_transformer_cartpole(self):
   """Test-runs AWR on cartpole with Transformer."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=2,
                         max_steps=2)
   body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
       d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode)
   policy_model = functools.partial(models.Policy, body=body)
   value_model = functools.partial(models.Value, body=body)
   lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
       constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
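   # Tiny dimensions and step counts throughout: this sanity test only checks that
   # the Transformer-based policy/value wiring runs end to end, not that it learns.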
   trainer = actor_critic.AWRTrainer(
       task,
       n_shared_layers=0,
       max_slice_length=2,
       added_policy_slice_length=1,
       value_model=value_model,
       value_optimizer=opt.Adam,
       value_lr_schedule=lr,
       value_batch_size=2,
       value_train_steps_per_epoch=2,
       policy_model=policy_model,
       policy_optimizer=opt.Adam,
       policy_lr_schedule=lr,
       policy_batch_size=2,
       policy_train_steps_per_epoch=2,
       n_trajectories_per_epoch=1,
       n_eval_episodes=1)
   trainer.run(2)
   self.assertEqual(2, trainer.current_epoch)
Example #2
 def test_awrtrainer_cartpole(self):
     """Test-runs AWR on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
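     # Policy and value networks are built from the same small Dense(64)+Relu body;
     # n_shared_layers=0 below keeps their weights independent.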
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.AWRTrainer(
         task,
         n_shared_layers=0,
         added_policy_slice_length=1,
         value_model=value_model,
         value_optimizer=opt.Adam,
         value_lr_schedule=lr,
         value_batch_size=32,
         value_train_steps_per_epoch=200,
         policy_model=policy_model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=32,
         policy_train_steps_per_epoch=200,
         n_trajectories_per_epoch=10,
         advantage_estimator=advantages.monte_carlo,
         advantage_normalization=False,
     )
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
     self.assertGreater(trainer.avg_returns[-1], 35.0)
Example #3
 def test_awrtrainer_cartpole(self):
     """Test-runs AWR on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
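     # The policy head emits log-probabilities over CartPole's two actions;
     # the value head emits a single scalar return estimate.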
     policy_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
         tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
     value_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
         tl.Dense(64), tl.Relu(), tl.Dense(1))
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.AWRTrainer(task,
                                       n_shared_layers=0,
                                       value_model=value_model,
                                       value_optimizer=opt.Adam,
                                       value_lr_schedule=lr,
                                       value_batch_size=32,
                                       value_train_steps_per_epoch=1000,
                                       policy_model=policy_model,
                                       policy_optimizer=opt.Adam,
                                       policy_lr_schedule=lr,
                                       policy_batch_size=32,
                                       policy_train_steps_per_epoch=1000,
                                       collect_per_epoch=10)
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
     self.assertGreater(trainer.avg_returns[-1], 180.0)
Example #4
 def test_awrtrainer_cartpole_shared(self):
     """Test-runs AWR on cartpole with shared layers."""
     # This test is flaky, and this is the simplest way to retry in OSS.
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     # pylint: disable=g-long-lambda
     lr = (lambda h: lr_schedules.MultifactorSchedule(
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup'))
     # pylint: enable=g-long-lambda
     max_avg_returns = -math.inf
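     # Retry up to 5 times and pass as soon as one run clears an average return of 35.0.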
     for _ in range(5):
         trainer = actor_critic.AWRTrainer(
             task,
             n_shared_layers=1,
             added_policy_slice_length=1,
             value_model=value_model,
             value_optimizer=opt.Adam,
             value_lr_schedule=lr,
             value_batch_size=32,
             value_train_steps_per_epoch=200,
             policy_model=policy_model,
             policy_optimizer=opt.Adam,
             policy_lr_schedule=lr,
             policy_batch_size=32,
             policy_train_steps_per_epoch=200,
             n_trajectories_per_epoch=10,
             advantage_estimator=advantages.monte_carlo,
             advantage_normalization=False,
         )
         trainer.run(1)
         self.assertEqual(1, trainer.current_epoch)
         max_avg_returns = (max_avg_returns
                            if max_avg_returns > trainer.avg_returns[-1]
                            else trainer.avg_returns[-1])
         if trainer.avg_returns[-1] > 35.0:
             return
     self.fail(
         f'We did not reach a score > 35.0, max was {max_avg_returns}.')