Code example #1
    def test_jointppotrainer_cartpole(self):
        """Test-runs joint PPO on CartPole."""

        task = rl_task.RLTask('CartPole-v0',
                              initial_trajectories=0,
                              max_steps=2)
        joint_model = functools.partial(
            models.PolicyAndValue,
            body=lambda mode: tl.Serial(tl.Dense(2), tl.Relu()),
        )
        lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
            h,
            constant=1e-2,
            warmup_steps=100,
            factors='constant * linear_warmup')

        trainer = actor_critic_joint.PPOJointTrainer(
            task,
            joint_model=joint_model,
            optimizer=opt.Adam,
            lr_schedule=lr,
            batch_size=4,
            train_steps_per_epoch=2,
            n_trajectories_per_epoch=5)
        trainer.run(2)
        self.assertEqual(2, trainer.current_epoch)
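
Note that this and the other snippets on this page are methods of a test class excerpted from Trax's RL tests, so they do not run on their own. The sketch below shows the kind of imports and scaffolding they assume; the module paths and the test-class name are assumptions inferred from the aliases used in the snippets (tl, models, opt, lr_schedules, rl_task, actor_critic, actor_critic_joint, advantages, training) and may differ between Trax versions.

import functools
import math

from absl.testing import absltest

from trax import layers as tl
from trax import lr_schedules   # assumed location of MultifactorSchedule; version-dependent
from trax import models
from trax import optimizers as opt
from trax.rl import actor_critic
from trax.rl import actor_critic_joint
from trax.rl import advantages
from trax.rl import rl_task
from trax.rl import training


class TrainerExamplesTest(absltest.TestCase):  # hypothetical class name
    # Each code example on this page is one method of a class like this.
    ...
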
Code example #2
 def test_sanity_a2ctrainer_cartpole(self):
     """Test-runs a2c on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=0,
                           max_steps=2)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-4,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.A2CTrainer(task,
                                       n_shared_layers=1,
                                       value_model=value_model,
                                       value_optimizer=opt.Adam,
                                       value_lr_schedule=lr,
                                       value_batch_size=2,
                                       value_train_steps_per_epoch=2,
                                       policy_model=policy_model,
                                       policy_optimizer=opt.Adam,
                                       policy_lr_schedule=lr,
                                       policy_batch_size=2,
                                       policy_train_steps_per_epoch=2,
                                       n_trajectories_per_epoch=2)
     trainer.run(2)
     self.assertEqual(2, trainer.current_epoch)
Code example #3
File: training_test.py  Project: ulrikSebastienR/trax
 def test_policytrainer_cartpole(self):
     """Trains a policy on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1,
                           max_steps=200)
     model = functools.partial(
         models.Policy,
         body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
             tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()),
     )
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = training.PolicyGradientTrainer(
         task,
         policy_model=model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=128,
         policy_train_steps_per_epoch=1,
         n_trajectories_per_epoch=2)
     # Assert that we get to 200 at some point and then exit so the test is as
     # fast as possible.
     for ep in range(200):
         trainer.run(1)
         self.assertEqual(trainer.current_epoch, ep + 1)
         if trainer.avg_returns[-1] == 200.0:
             return
     self.fail('The expected score of 200 has not been reached. '
               'Maximum was {}.'.format(max(trainer.avg_returns)))
Code example #4
 def test_sampling_awrtrainer_mountain_car(self):
     """Test-runs Sampling AWR on MountainCarContinuous."""
     task = rl_task.RLTask('MountainCarContinuous-v0',
                           initial_trajectories=0,
                           max_steps=2)
     body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.SamplingAWRTrainer(
         task,
         n_shared_layers=0,
         added_policy_slice_length=1,
         value_model=value_model,
         value_optimizer=opt.Adam,
         value_lr_schedule=lr,
         value_batch_size=2,
         value_train_steps_per_epoch=2,
         policy_model=policy_model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=2,
         policy_train_steps_per_epoch=2,
         n_trajectories_per_epoch=2,
         advantage_estimator=advantages.monte_carlo,
         advantage_normalization=False,
         q_value_n_samples=3,
     )
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
Code example #5
 def test_awrtrainer_cartpole(self):
     """Test-runs AWR on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.AWRTrainer(
         task,
         n_shared_layers=0,
         added_policy_slice_length=1,
         value_model=value_model,
         value_optimizer=opt.Adam,
         value_lr_schedule=lr,
         value_batch_size=32,
         value_train_steps_per_epoch=200,
         policy_model=policy_model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=32,
         policy_train_steps_per_epoch=200,
         n_trajectories_per_epoch=10,
         advantage_estimator=advantages.monte_carlo,
         advantage_normalization=False,
     )
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
     self.assertGreater(trainer.avg_returns[-1], 35.0)
Code example #6
    def test_sanity_ppo_cartpole(self):
        """Run PPO and check whether it correctly runs for 2 epochs.s."""
        task = rl_task.RLTask('CartPole-v1',
                              initial_trajectories=0,
                              max_steps=200)

        lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
            h,
            constant=1e-3,
            warmup_steps=100,
            factors='constant * linear_warmup')

        body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
        policy_model = functools.partial(models.Policy, body=body)
        value_model = functools.partial(models.Value, body=body)
        trainer = actor_critic.PPOTrainer(task,
                                          n_shared_layers=1,
                                          value_model=value_model,
                                          value_optimizer=opt.Adam,
                                          value_lr_schedule=lr,
                                          value_batch_size=128,
                                          value_train_steps_per_epoch=10,
                                          policy_model=policy_model,
                                          policy_optimizer=opt.Adam,
                                          policy_lr_schedule=lr,
                                          policy_batch_size=128,
                                          policy_train_steps_per_epoch=10,
                                          n_trajectories_per_epoch=10)

        trainer.run(2)
        self.assertEqual(2, trainer.current_epoch)
Code example #7
 def test_awrtrainer_cartpole_shared(self):
     """Test-runs AWR on cartpole with shared layers."""
     # This test is flaky, and this is the simplest way to retry in OSS.
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     # pylint: disable=g-long-lambda
     lr = (lambda h: lr_schedules.MultifactorSchedule(
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup'))
     # pylint: enable=g-long-lambda
     max_avg_returns = -math.inf
     for _ in range(5):
         trainer = actor_critic.AWRTrainer(
             task,
             n_shared_layers=1,
             added_policy_slice_length=1,
             value_model=value_model,
             value_optimizer=opt.Adam,
             value_lr_schedule=lr,
             value_batch_size=32,
             value_train_steps_per_epoch=200,
             policy_model=policy_model,
             policy_optimizer=opt.Adam,
             policy_lr_schedule=lr,
             policy_batch_size=32,
             policy_train_steps_per_epoch=200,
             n_trajectories_per_epoch=10,
             advantage_estimator=advantages.monte_carlo,
             advantage_normalization=False,
         )
         trainer.run(1)
         self.assertEqual(1, trainer.current_epoch)
          max_avg_returns = max(max_avg_returns, trainer.avg_returns[-1])
         if trainer.avg_returns[-1] > 35.0:
             return
     self.fail(
         f'We did not reach a score > 35.0, max was {max_avg_returns}.')
Code example #8
 def test_sanity_awrtrainer_transformer_cartpole(self):
     """Test-runs AWR on cartpole with Transformer."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=2,
                           max_steps=2)
     body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
         d_model=2,
         d_ff=2,
         n_layers=1,
         n_heads=1,
         mode=mode)
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.AWRTrainer(task,
                                       n_shared_layers=0,
                                       max_slice_length=2,
                                       added_policy_slice_length=1,
                                       value_model=value_model,
                                       value_optimizer=opt.Adam,
                                       value_lr_schedule=lr,
                                       value_batch_size=2,
                                       value_train_steps_per_epoch=2,
                                       policy_model=policy_model,
                                       policy_optimizer=opt.Adam,
                                       policy_lr_schedule=lr,
                                       policy_batch_size=2,
                                       policy_train_steps_per_epoch=2,
                                       n_trajectories_per_epoch=1,
                                       n_eval_episodes=1)
     trainer.run(2)
     self.assertEqual(2, trainer.current_epoch)
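
One last remark: every example above re-declares the same learning-rate lambda around lr_schedules.MultifactorSchedule. If you adapt these snippets, a small factory keeps that in one place; this is only a sketch of a possible refactor (warmup_lr is a hypothetical helper, not part of the original tests).

def warmup_lr(constant, warmup_steps=100):
    """Returns an lr-schedule factory equivalent to the lambdas in the tests above."""
    return lambda history: lr_schedules.MultifactorSchedule(
        history,
        constant=constant,
        warmup_steps=warmup_steps,
        factors='constant * linear_warmup')

With this helper, lr = warmup_lr(1e-2) replaces the repeated, pylint-suppressed lambdas.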