Code example #1
    def test_boxing(self):
        """Test-runs PPO on Boxing."""
        env = environments.load_from_settings(platform='atari',
                                              settings={
                                                  'levelName': 'boxing',
                                                  'interleaved_pixels': True,
                                                  'zero_indexed_actions': True
                                              })
        env = atari_wrapper.AtariWrapper(environment=env, num_stacked_frames=1)

        task = rl_task.RLTask(env,
                              initial_trajectories=20,
                              dm_suite=True,
                              max_steps=200)

        body = lambda mode: atari_cnn.AtariCnnBody()

        policy_model = functools.partial(models.Policy, body=body)
        value_model = functools.partial(models.Value, body=body)

        # pylint: disable=g-long-lambda
        lr_value = lambda h: lr_schedules.MultifactorSchedule(
            h,
            constant=3e-3,
            warmup_steps=100,
            factors='constant * linear_warmup')
        lr_policy = lambda h: lr_schedules.MultifactorSchedule(
            h,
            constant=1e-3,
            warmup_steps=100,
            factors='constant * linear_warmup')
        # pylint: enable=g-long-lambda

        trainer = actor_critic.PPOTrainer(task,
                                          n_shared_layers=0,
                                          value_model=value_model,
                                          value_optimizer=opt.Adam,
                                          value_lr_schedule=lr_value,
                                          value_batch_size=1,
                                          value_train_steps_per_epoch=1,
                                          policy_model=policy_model,
                                          policy_optimizer=opt.Adam,
                                          policy_lr_schedule=lr_policy,
                                          policy_batch_size=1,
                                          policy_train_steps_per_epoch=1,
                                          collect_per_epoch=10)
        trainer.run(2)
        # Make sure that we test everywhere for at least 2 epochs, because
        # the first epoch is different.
        self.assertEqual(2, trainer.current_epoch)
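
These snippets are individual test methods lifted from trax RL test suites, so the imports and the enclosing test class are not shown. Below is a rough import sketch for the modules the CartPole examples reference; the module paths are an assumption based on the trax layout these tests target and may differ between versions, and the Atari example additionally needs the environments, atari_wrapper and atari_cnn modules, whose locations are not guessed here.

import functools
import math

from absl.testing import absltest

from trax import layers as tl
from trax import models
from trax import optimizers as opt
# Depending on the trax version, lr_schedules lives either at
# trax.lr_schedules or trax.supervised.lr_schedules.
from trax import lr_schedules
from trax.rl import actor_critic
from trax.rl import actor_critic_joint
from trax.rl import advantages
from trax.rl import rl_task
from trax.rl import training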
Code example #2
File: training_test.py Project: victorustc/trax
 def test_policytrainer_cartpole(self):
     """Trains a policy on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1,
                           max_steps=200)
     model = functools.partial(
         models.Policy,
         body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
             tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()),
     )
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = training.PolicyGradientTrainer(
         task,
         policy_model=model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=128,
         policy_train_steps_per_epoch=1,
         collect_per_epoch=2)
     # Assert that we get to 200 at some point and then exit so the test is as
     # fast as possible.
     for ep in range(200):
         trainer.run(1)
         self.assertEqual(trainer.current_epoch, ep + 1)
         if trainer.avg_returns[-1] == 200.0:
             return
     self.fail('The expected score of 200 has not been reached. '
               'Maximum was {}.'.format(max(trainer.avg_returns)))
Code example #3
File: actor_critic_test.py Project: zhaoqiuye/trax
 def test_sanity_a2ctrainer_cartpole(self):
     """Test-runs a2c on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=0,
                           max_steps=2)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-4,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.A2CTrainer(task,
                                       n_shared_layers=1,
                                       value_model=value_model,
                                       value_optimizer=opt.Adam,
                                       value_lr_schedule=lr,
                                       value_batch_size=2,
                                       value_train_steps_per_epoch=2,
                                       policy_model=policy_model,
                                       policy_optimizer=opt.Adam,
                                       policy_lr_schedule=lr,
                                       policy_batch_size=2,
                                       policy_train_steps_per_epoch=2,
                                       n_trajectories_per_epoch=2)
     trainer.run(2)
     self.assertEqual(2, trainer.current_epoch)
Code example #4
File: actor_critic_test.py Project: zhaoqiuye/trax
 def test_sampling_awrtrainer_mountain_car(self):
     """Test-runs Sampling AWR on MountainCarContinuous."""
     task = rl_task.RLTask('MountainCarContinuous-v0',
                           initial_trajectories=0,
                           max_steps=2)
     body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.SamplingAWRTrainer(
         task,
         n_shared_layers=0,
         added_policy_slice_length=1,
         value_model=value_model,
         value_optimizer=opt.Adam,
         value_lr_schedule=lr,
         value_batch_size=2,
         value_train_steps_per_epoch=2,
         policy_model=policy_model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=2,
         policy_train_steps_per_epoch=2,
         n_trajectories_per_epoch=2,
         advantage_estimator=advantages.monte_carlo,
         advantage_normalization=False,
         q_value_n_samples=3,
     )
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
Code example #5
File: actor_critic_test.py Project: zhaoqiuye/trax
 def test_awrtrainer_cartpole(self):
     """Test-runs AWR on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.AWRTrainer(
         task,
         n_shared_layers=0,
         added_policy_slice_length=1,
         value_model=value_model,
         value_optimizer=opt.Adam,
         value_lr_schedule=lr,
         value_batch_size=32,
         value_train_steps_per_epoch=200,
         policy_model=policy_model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=32,
         policy_train_steps_per_epoch=200,
         n_trajectories_per_epoch=10,
         advantage_estimator=advantages.monte_carlo,
         advantage_normalization=False,
     )
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
     self.assertGreater(trainer.avg_returns[-1], 35.0)
Code example #6
File: actor_critic_test.py Project: zhaoqiuye/trax
    def test_sanity_ppo_cartpole(self):
        """Run PPO and check whether it correctly runs for 2 epochs.s."""
        task = rl_task.RLTask('CartPole-v1',
                              initial_trajectories=0,
                              max_steps=200)

        lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
            h,
            constant=1e-3,
            warmup_steps=100,
            factors='constant * linear_warmup')

        body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
        policy_model = functools.partial(models.Policy, body=body)
        value_model = functools.partial(models.Value, body=body)
        trainer = actor_critic.PPOTrainer(task,
                                          n_shared_layers=1,
                                          value_model=value_model,
                                          value_optimizer=opt.Adam,
                                          value_lr_schedule=lr,
                                          value_batch_size=128,
                                          value_train_steps_per_epoch=10,
                                          policy_model=policy_model,
                                          policy_optimizer=opt.Adam,
                                          policy_lr_schedule=lr,
                                          policy_batch_size=128,
                                          policy_train_steps_per_epoch=10,
                                          n_trajectories_per_epoch=10)

        trainer.run(2)
        self.assertEqual(2, trainer.current_epoch)
Code example #7
 def test_jointawrtrainer_cartpole(self):
     """Test-runs joint AWR on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
     shared_model = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_top = lambda mode: tl.Serial(tl.Dense(2), tl.LogSoftmax())
     value_top = lambda mode: tl.Dense(1)
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic_joint.AWRJointTrainer(
         task,
         shared_model=shared_model,
         policy_top=policy_top,
         value_top=value_top,
         optimizer=opt.Adam,
         lr_schedule=lr,
         batch_size=32,
         train_steps_per_epoch=1000,
         collect_per_epoch=10)
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
Code example #8
File: actor_critic_test.py Project: victorustc/trax
 def test_sanity_awrtrainer_transformer_cartpole(self):
   """Test-runs AWR on cartpole with Transformer."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=100,
                         max_steps=200)
   body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
       d_model=32, d_ff=32, n_layers=1, n_heads=1, mode=mode)
   policy_model = functools.partial(models.Policy, body=body)
   value_model = functools.partial(models.Value, body=body)
   lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
       h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
   trainer = actor_critic.AWRTrainer(
       task,
       n_shared_layers=0,
       max_slice_length=127,
       added_policy_slice_length=1,
       value_model=value_model,
       value_optimizer=opt.Adam,
       value_lr_schedule=lr,
       value_batch_size=4,
       value_train_steps_per_epoch=200,
       policy_model=policy_model,
       policy_optimizer=opt.Adam,
       policy_lr_schedule=lr,
       policy_batch_size=4,
       policy_train_steps_per_epoch=200,
       collect_per_epoch=10)
   trainer.run(2)
   self.assertEqual(2, trainer.current_epoch)
Code example #9
File: actor_critic_test.py Project: u03410050/trax
 def test_a2ctrainer_cartpole(self):
     """Test-runs a2c on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1,
                           max_steps=2)
     policy_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
         tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
     value_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
         tl.Dense(64), tl.Relu(), tl.Dense(1))
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-4,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.AdvantageActorCriticTrainer(
         task,
         n_shared_layers=1,
         value_model=value_model,
         value_optimizer=opt.Adam,
         value_lr_schedule=lr,
         value_batch_size=2,
         value_train_steps_per_epoch=2,
         policy_model=policy_model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=2,
         policy_train_steps_per_epoch=2,
         collect_per_epoch=2)
     trainer.run(2)
     self.assertEqual(2, trainer.current_epoch)
Code example #10
File: actor_critic_test.py Project: u03410050/trax
 def test_awrtrainer_cartpole(self):
     """Test-runs AWR on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
     policy_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
         tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
     value_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
         tl.Dense(64), tl.Relu(), tl.Dense(1))
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.AWRTrainer(task,
                                       n_shared_layers=0,
                                       value_model=value_model,
                                       value_optimizer=opt.Adam,
                                       value_lr_schedule=lr,
                                       value_batch_size=32,
                                       value_train_steps_per_epoch=1000,
                                       policy_model=policy_model,
                                       policy_optimizer=opt.Adam,
                                       policy_lr_schedule=lr,
                                       policy_batch_size=32,
                                       policy_train_steps_per_epoch=1000,
                                       collect_per_epoch=10)
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
     self.assertGreater(trainer.avg_returns[-1], 180.0)
Code example #11
    def test_jointppotrainer_cartpole(self):
        """Test-runs joint PPO on CartPole."""

        task = rl_task.RLTask('CartPole-v0',
                              initial_trajectories=0,
                              max_steps=2)
        joint_model = functools.partial(
            models.PolicyAndValue,
            body=lambda mode: tl.Serial(tl.Dense(2), tl.Relu()),
        )
        lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
            h,
            constant=1e-2,
            warmup_steps=100,
            factors='constant * linear_warmup')

        trainer = actor_critic_joint.PPOJointTrainer(
            task,
            joint_model=joint_model,
            optimizer=opt.Adam,
            lr_schedule=lr,
            batch_size=4,
            train_steps_per_epoch=2,
            n_trajectories_per_epoch=5)
        trainer.run(2)
        self.assertEqual(2, trainer.current_epoch)
Code example #12
File: training_test.py Project: zsunpku/trax
 def test_policytrainer_cartpole(self):
   """Trains a policy on cartpole."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=100,
                         max_steps=200)
   model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
        # CartPole-v0 has 2 discrete actions, so the policy head needs 2 logits.
        tl.Dense(32), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
   lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
       h, constant=1e-3, warmup_steps=100, factors='constant * linear_warmup')
   trainer = training.ExamplePolicyTrainer(task, model, opt.Adam, lr)
   trainer.run(1)
   self.assertEqual(1, trainer.current_epoch)
Code example #13
 def test_policytrainer_cartpole(self):
   """Trains a policy on cartpole."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=750,
                         max_steps=200)
   model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
       tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
   lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
       h, constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
   trainer = training.PolicyGradient(
       task, model, opt.Adam, lr_schedule=lr, batch_size=128,
       train_steps_per_epoch=700, collect_per_epoch=50)
   trainer.run(1)
    # This assertion should *mostly* pass, which means this test can be flaky.
   self.assertGreater(trainer.avg_returns[-1], 35.0)
   self.assertEqual(1, trainer.current_epoch)
Code example #14
File: actor_critic_test.py Project: zhaoqiuye/trax
 def test_awrtrainer_cartpole_shared(self):
     """Test-runs AWR on cartpole with shared layers."""
     # This test is flaky, and this is the simplest way to retry in OSS.
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     # pylint: disable=g-long-lambda
     lr = (lambda h: lr_schedules.MultifactorSchedule(
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup'))
     # pylint: enable=g-long-lambda
     max_avg_returns = -math.inf
     for _ in range(5):
         trainer = actor_critic.AWRTrainer(
             task,
             n_shared_layers=1,
             added_policy_slice_length=1,
             value_model=value_model,
             value_optimizer=opt.Adam,
             value_lr_schedule=lr,
             value_batch_size=32,
             value_train_steps_per_epoch=200,
             policy_model=policy_model,
             policy_optimizer=opt.Adam,
             policy_lr_schedule=lr,
             policy_batch_size=32,
             policy_train_steps_per_epoch=200,
             n_trajectories_per_epoch=10,
             advantage_estimator=advantages.monte_carlo,
             advantage_normalization=False,
         )
         trainer.run(1)
         self.assertEqual(1, trainer.current_epoch)
         max_avg_returns = (max_avg_returns
                            if max_avg_returns > trainer.avg_returns[-1]
                            else trainer.avg_returns[-1])
         if trainer.avg_returns[-1] > 35.0:
             return
     self.fail(
         f'We did not reach a score > 35.0, max was {max_avg_returns}.')
Code example #15
 def test_policytrainer_cartpole(self):
   """Trains a policy on cartpole."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=1,
                         max_steps=2)
   # TODO(pkozakowski): Use Distribution.n_inputs to initialize the action
   # head.
   model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
       tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
   lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
       h, constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
   trainer = training.PolicyGradientTrainer(
       task,
       policy_model=model,
       policy_optimizer=opt.Adam,
       policy_lr_schedule=lr,
       policy_batch_size=2,
       policy_train_steps_per_epoch=2,
       collect_per_epoch=2)
   trainer.run(1)
   self.assertEqual(1, trainer.current_epoch)
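
Each example above is an instance method, so running one standalone requires wrapping it in a test case. A minimal harness sketch follows; the class name is illustrative rather than taken from trax, and it assumes the imports sketched after code example #1 are in scope.

from absl.testing import absltest


class RLTrainerSmokeTest(absltest.TestCase):

  def test_policytrainer_cartpole(self):
    # Paste the body of any example above here; each one builds an RLTask,
    # a model, and a trainer, calls trainer.run(n_epochs), and asserts on
    # trainer.current_epoch or trainer.avg_returns[-1].
    ...


if __name__ == '__main__':
  absltest.main()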