Example #1
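Both examples below are test methods taken from Trax's RL test code and rely on module-level imports that are not shown here. A minimal sketch of the imports they appear to assume (module paths follow the Trax source tree; the exact location of lr_schedules varies between Trax versions):

import functools

from trax import layers as tl
from trax import models
from trax import optimizers as opt
from trax.rl import actor_critic
from trax.rl import advantages
from trax.rl import task as rl_task
from trax.supervised import lr_schedules  # older versions: from trax import lr_schedules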
 def test_sampling_awrtrainer_mountain_car(self):
   """Test-runs Sampling AWR on MountainCarContinuous."""
   task = rl_task.RLTask('MountainCarContinuous-v0',
                         initial_trajectories=0,
                         max_steps=2)
   body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
   policy_model = functools.partial(models.Policy, body=body)
   value_model = functools.partial(models.Value, body=body)
   # Learning-rate schedule shared by the policy and value optimizers.
   lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
       h,
       constant=1e-2,
       warmup_steps=100,
       factors='constant * linear_warmup')
   trainer = actor_critic.SamplingAWRTrainer(
       task,
       n_shared_layers=0,
       added_policy_slice_length=1,
       value_model=value_model,
       value_optimizer=opt.Adam,
       value_lr_schedule=lr,
       value_batch_size=2,
       value_train_steps_per_epoch=2,
       policy_model=policy_model,
       policy_optimizer=opt.Adam,
       policy_lr_schedule=lr,
       policy_batch_size=2,
       policy_train_steps_per_epoch=2,
       n_trajectories_per_epoch=2,
       advantage_estimator=advantages.monte_carlo,
       advantage_normalization=False,
       # Number of action samples drawn per state when estimating Q-values.
       q_value_n_samples=3,
   )
   trainer.run(1)
   self.assertEqual(1, trainer.current_epoch)
Example #2
 def test_sampling_awrtrainer_cartpole_sample_all_discrete(self):
   """Test-runs AWR on cartpole with Transformer, n_actions = n_samples."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=0,
                         max_steps=20)
   body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
   policy_model = functools.partial(models.Policy, body=body)
   value_model = functools.partial(models.Value, body=body)
   lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
       constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
   trainer = actor_critic.SamplingAWRTrainer(
       task,
       n_shared_layers=0,
       added_policy_slice_length=1,
       value_model=value_model,
       value_optimizer=opt.Adam,
       value_lr_schedule=lr,
       value_batch_size=2,
       value_train_steps_per_epoch=2,
       policy_model=policy_model,
       policy_optimizer=opt.Adam,
       policy_lr_schedule=lr,
       policy_batch_size=2,
       policy_train_steps_per_epoch=2,
       n_trajectories_per_epoch=2,
       advantage_estimator=advantages.monte_carlo,
       advantage_normalization=False,
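       # CartPole-v0 has 2 discrete actions, matching q_value_n_samples below.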
       q_value_n_samples=2,
       q_value_aggregate_max=True,
       reweight=False,
   )
   trainer.run(1)
   self.assertEqual(1, trainer.current_epoch)
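Both methods are written as unittest-style test cases (they take self and call self.assertEqual), so they need to live inside a test class to run. A minimal, hypothetical scaffold using absltest (the class name ActorCriticTest is an assumption, not necessarily the name used in Trax's own test file):

from absl.testing import absltest


class ActorCriticTest(absltest.TestCase):
  # Paste the test methods from the examples above here, indented one level.
  pass


if __name__ == '__main__':
  absltest.main()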