Example #1
 def test_sanity_a2ctrainer_cartpole(self):
     """Test-runs a2c on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=0,
                           max_steps=2)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
         constant=1e-4,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.A2C(task,
                                n_shared_layers=1,
                                value_model=value_model,
                                value_optimizer=opt.Adam,
                                value_lr_schedule=lr,
                                value_batch_size=2,
                                value_train_steps_per_epoch=2,
                                policy_model=policy_model,
                                policy_optimizer=opt.Adam,
                                policy_lr_schedule=lr,
                                policy_batch_size=2,
                                policy_train_steps_per_epoch=2,
                                n_trajectories_per_epoch=2)
     trainer.run(2)
     self.assertEqual(2, trainer.current_epoch)
Example #2
 def test_policytrainer_cartpole(self):
   """Trains a policy on cartpole."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=1,
                         max_steps=200)
   model = functools.partial(
       models.Policy,
       body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
           tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()
       ),
   )
   lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
       constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
   max_avg_returns = -math.inf
   for _ in range(5):
     trainer = training.PolicyGradient(
         task,
         policy_model=model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=128,
         policy_train_steps_per_epoch=1,
         n_trajectories_per_epoch=2)
     # Assert that we get to 200 at some point and then exit so the test is as
     # fast as possible.
     for ep in range(200):
       trainer.run(1)
       self.assertEqual(trainer.current_epoch, ep + 1)
       if trainer.avg_returns[-1] == 200.0:
         return
     max_avg_returns = max(max_avg_returns, trainer.avg_returns[-1])
   self.fail(
       'The expected score of 200 has not been reached. '
       'Maximum at end was {}.'.format(max_avg_returns)
   )
Example #3
 def test_sanity_awrtrainer_transformer_cartpole(self):
     """Test-runs AWR on cartpole with Transformer."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=2,
                           max_steps=2)
     body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
         d_model=2,
         d_ff=2,
         n_layers=1,
         n_heads=1,
         mode=mode)
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.AWR(task,
                                n_shared_layers=0,
                                max_slice_length=2,
                                added_policy_slice_length=1,
                                value_model=value_model,
                                value_optimizer=opt.Adam,
                                value_lr_schedule=lr,
                                value_batch_size=2,
                                value_train_steps_per_epoch=2,
                                policy_model=policy_model,
                                policy_optimizer=opt.Adam,
                                policy_lr_schedule=lr,
                                policy_batch_size=2,
                                policy_train_steps_per_epoch=2,
                                n_trajectories_per_epoch=1,
                                n_eval_episodes=1)
     trainer.run(2)
     self.assertEqual(2, trainer.current_epoch)
Example #4
 def test_sampling_awrtrainer_mountain_car(self):
     """Test-runs Sampling AWR on MountainCarContinuous."""
     task = rl_task.RLTask('MountainCarContinuous-v0',
                           initial_trajectories=0,
                           max_steps=2)
     body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = actor_critic.SamplingAWR(
         task,
         n_shared_layers=0,
         added_policy_slice_length=1,
         value_model=value_model,
         value_optimizer=opt.Adam,
         value_lr_schedule=lr,
         value_batch_size=2,
         value_train_steps_per_epoch=2,
         policy_model=policy_model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=2,
         policy_train_steps_per_epoch=2,
         n_trajectories_per_epoch=2,
         advantage_estimator=advantages.monte_carlo,
         advantage_normalization=False,
         q_value_n_samples=3,
     )
     trainer.run(1)
     self.assertEqual(1, trainer.current_epoch)
Example #5
  def test_no_int32_or_uint32_returned(self):
    """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

    TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
    computation to copy int32/uint32 outputs to CPU. This test makes sure that
    won't happen.
    """
    if xla_bridge.device_count() > 1:
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with fastmath.use_backend(fastmath.Backend.TFNP), \
          self.tmp_dir() as output_dir:
      n_classes = 1001
      model_fn = functools.partial(models.Resnet50,
                                   n_output_classes=n_classes)
      inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))
      trainer = trainer_lib.Trainer(
          model=model_fn,
          loss_fn=tl.CrossEntropyLoss(),
          optimizer=trax_opt.SM3,
          lr_schedule=lr.multifactor(),
          inputs=inputs,
      )
      trainer.reset(output_dir)
      trainer.train_epoch(1, 0)
      # These are the values returned by Trainer._jit_update_fn.
      arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
                trainer._model_state, trainer._rngs)
      arrays = tf.nest.flatten(arrays)
      for x in arrays:
        if isinstance(x, jnp.ndarray) and (x.dtype == jnp.int32 or
                                           x.dtype == jnp.uint32):
          raise ValueError('Found an array of int32 or uint32: %s' % x)
Example #6
    def test_sanity_ppo_cartpole(self):
        """Run PPO and check whether it correctly runs for 2 epochs.s."""
        task = rl_task.RLTask('CartPole-v1',
                              initial_trajectories=0,
                              max_steps=200)

        lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
            constant=1e-3,
            warmup_steps=100,
            factors='constant * linear_warmup')

        body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
        policy_model = functools.partial(models.Policy, body=body)
        value_model = functools.partial(models.Value, body=body)
        trainer = actor_critic.PPO(task,
                                   n_shared_layers=1,
                                   value_model=value_model,
                                   value_optimizer=opt.Adam,
                                   value_lr_schedule=lr,
                                   value_batch_size=128,
                                   value_train_steps_per_epoch=10,
                                   policy_model=policy_model,
                                   policy_optimizer=opt.Adam,
                                   policy_lr_schedule=lr,
                                   policy_batch_size=128,
                                   policy_train_steps_per_epoch=10,
                                   n_trajectories_per_epoch=10)

        trainer.run(2)
        self.assertEqual(2, trainer.current_epoch)
Example #7
    def test_reset_twice(self, backend):
        if xla_bridge.device_count() > 1 and backend == fastmath.Backend.TFNP:
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with fastmath.use_backend(backend):
            n_classes = 4
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = _test_inputs(n_classes)

            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=tl.CrossEntropyLoss(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )

            output_dir1 = self.create_tempdir(name='output_dir1').full_path
            trainer.reset(output_dir1)
            trainer.evaluate(1)
            output_dir2 = self.create_tempdir(name='output_dir2').full_path
            trainer.reset(output_dir2)
            trainer.evaluate(1)
Example #8
 def test_awrtrainer_cartpole(self):
   """Test-runs AWR on cartpole."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=1000,
                         max_steps=200)
   body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
   policy_model = functools.partial(models.Policy, body=body)
   value_model = functools.partial(models.Value, body=body)
   lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
       constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
   trainer = actor_critic.AWRTrainer(
       task,
       n_shared_layers=0,
       added_policy_slice_length=1,
       value_model=value_model,
       value_optimizer=opt.Adam,
       value_lr_schedule=lr,
       value_batch_size=32,
       value_train_steps_per_epoch=200,
       policy_model=policy_model,
       policy_optimizer=opt.Adam,
       policy_lr_schedule=lr,
       policy_batch_size=32,
       policy_train_steps_per_epoch=200,
       n_trajectories_per_epoch=10,
       advantage_estimator=advantages.monte_carlo,
       advantage_normalization=False,
   )
   trainer.run(1)
   self.assertEqual(1, trainer.current_epoch)
   self.assertGreater(trainer.avg_returns[-1], 35.0)
Example #9
 def test_policy_gradient_cartpole(self):
     """Trains a policy on cartpole."""
     task = rl_task.RLTask('CartPole-v0', max_steps=200)
     lr = lambda: lr_schedules.multifactor(constant=1e-2,
                                           factors='constant')
     max_avg_returns = -math.inf
     for _ in range(2):
         agent = training.PolicyGradient(
             task,
             model_fn=self._model_fn,
             optimizer=opt.Adam,
             lr_schedule=lr,
             batch_size=128,
             n_trajectories_per_epoch=2,
         )
         # Assert that we get to 200 at some point and then exit so the test is as
         # fast as possible.
         for ep in range(200):
             agent.run(1)
             self.assertEqual(agent.current_epoch, ep + 1)
             if agent.avg_returns[-1] == 200.0:
                 return
         max_avg_returns = max(max_avg_returns, agent.avg_returns[-1])
     self.fail('The expected score of 200 has not been reached. '
               'Maximum at end was {}.'.format(max_avg_returns))
Example #10
    def test_jointppotrainer_cartpole(self):
        """Test-runs joint PPO on CartPole."""

        task = rl_task.RLTask('CartPole-v0',
                              initial_trajectories=0,
                              max_steps=2)
        joint_model = functools.partial(
            models.PolicyAndValue,
            body=lambda mode: tl.Serial(tl.Dense(2), tl.Relu()),
        )
        lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
            constant=1e-2,
            warmup_steps=100,
            factors='constant * linear_warmup')

        trainer = actor_critic_joint.PPOJointTrainer(
            task,
            joint_model=joint_model,
            optimizer=opt.Adam,
            lr_schedule=lr,
            batch_size=4,
            train_steps_per_epoch=2,
            n_trajectories_per_epoch=5)
        trainer.run(2)
        self.assertEqual(2, trainer.current_epoch)
Example #11
 def test_awrtrainer_cartpole_shared(self):
     """Test-runs AWR on cartpole with shared layers."""
     # This test is flaky, and this is the simplest way to retry in OSS.
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1000,
                           max_steps=200)
     body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
     policy_model = functools.partial(models.Policy, body=body)
     value_model = functools.partial(models.Value, body=body)
     # pylint: disable=g-long-lambda
      lr = (lambda: lr_schedules.multifactor(
          constant=1e-2,
          warmup_steps=100,
          factors='constant * linear_warmup'))
     # pylint: enable=g-long-lambda
     max_avg_returns = -math.inf
     for _ in range(5):
         trainer = actor_critic.AWRTrainer(
             task,
             n_shared_layers=1,
             added_policy_slice_length=1,
             value_model=value_model,
             value_optimizer=opt.Adam,
             value_lr_schedule=lr,
             value_batch_size=32,
             value_train_steps_per_epoch=200,
             policy_model=policy_model,
             policy_optimizer=opt.Adam,
             policy_lr_schedule=lr,
             policy_batch_size=32,
             policy_train_steps_per_epoch=200,
             n_trajectories_per_epoch=10,
             advantage_estimator=advantages.monte_carlo,
             advantage_normalization=False,
         )
         trainer.run(1)
         self.assertEqual(1, trainer.current_epoch)
          max_avg_returns = max(max_avg_returns, trainer.avg_returns[-1])
         if trainer.avg_returns[-1] > 35.0:
             return
     self.fail(
         f'We did not reach a score > 35.0, max was {max_avg_returns}.')
Example #12
    def test_reset_twice(self, backend):
        with fastmath.use_backend(backend):
            n_classes = 4
            model_fn = functools.partial(models.MLP,
                                         layer_widths=(16, 16, n_classes))
            inputs = _test_inputs(n_classes)

            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=tl.WeightedCategoryCrossEntropy(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )

            output_dir1 = self.create_tempdir(name='output_dir1').full_path
            trainer.reset(output_dir1)
            trainer.evaluate(1)
            output_dir2 = self.create_tempdir(name='output_dir2').full_path
            trainer.reset(output_dir2)
            trainer.evaluate(1)
Example #13
    def test_dqntrainer_cartpole(self):
        """Test-runs joint PPO on CartPole."""

        task = rl_task.RLTask('CartPole-v0',
                              initial_trajectories=0,
                              max_steps=2)
        value_body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())

        lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
            constant=1e-2,
            warmup_steps=100,
            factors='constant * linear_warmup')

        trainer = training.DQN(task,
                               value_body=value_body,
                               value_optimizer=opt.Adam,
                               value_lr_schedule=lr,
                               value_batch_size=4,
                               value_train_steps_per_epoch=2,
                               n_trajectories_per_epoch=5)
        trainer.run(2)
        self.assertEqual(2, trainer.current_epoch)
Example #14
    def test_reset_twice(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with fastmath.use_backend(backend_name), self.tmp_dir() as output_dir1, \
              self.tmp_dir() as output_dir2:
            n_classes = 4
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = _test_inputs(n_classes)

            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=layers.CrossEntropyLoss(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )

            trainer.reset(output_dir1)
            trainer.evaluate(1)
            trainer.reset(output_dir2)
            trainer.evaluate(1)
Example #15
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.multifactor(),
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          custom_train_fn=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train, as a callable returning 2 callables: an init_fn
      and an apply_fn.
    loss_fn: Callable with signature (weights, trax.inputs.Inputs, model, state,
      rng) -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save a separate checkpoint for the highest value of this
      metric seen so far.
    checkpoint_lowest: save a separate checkpoint for the lowest value of this
      metric seen so far.
    custom_train_fn: custom train function to call, entirely bypassing this one.

  Returns:
    trax.TrainerState
  """
  if custom_train_fn is not None:
    return custom_train_fn(output_dir, model=model)

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if save_graphs and fastmath.backend_name() == 'jax':
          trainer.save_computation_graphs()

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    trainer.close()
  return trainer.state
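
A minimal usage sketch of train(), for orientation only. The model constructor and
input pipeline below (my_model_fn, my_inputs) are hypothetical placeholders in the
style of the earlier examples, not part of the function above; the exact models.MLP
keyword arguments differ between versions (compare the d_hidden and layer_widths
variants shown earlier), so treat them as assumptions.

# Hypothetical usage sketch -- my_model_fn and my_inputs are placeholders,
# not defined by the train() snippet above.
my_model_fn = functools.partial(models.MLP, layer_widths=(128, 64, 10))
my_inputs = trax_inputs.batcher  # any callable returning trax.inputs.Inputs

final_state = train(
    output_dir='/tmp/train_dir',
    model=my_model_fn,
    inputs=my_inputs,
    loss_fn=tl.CrossEntropyLoss(),
    optimizer=trax_opt.Adafactor,
    steps=1000,
    eval_steps=10,
    eval_frequency=100,
)
# train() returns a trax.TrainerState; logs and checkpoints are written to output_dir.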