Example #1
  def __init__(self, task, policy_model=None, policy_optimizer=None,
               policy_lr_schedule=lr.MultifactorSchedule, policy_batch_size=64,
               policy_train_steps_per_epoch=500, policy_evals_per_epoch=1,
               policy_eval_steps=1, collect_per_epoch=50,
               max_slice_length=1, output_dir=None):
    """Configures the policy trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      policy_model: Trax layer, representing the policy model. Loss functions
          and eval functions (a.k.a. metrics) are considered to be outside the
          core model, taking core model output and data labels as their two
          inputs.
      policy_optimizer: the optimizer to use to train the policy model.
      policy_lr_schedule: learning rate schedule to use to train the policy.
      policy_batch_size: batch size used to train the policy model.
      policy_train_steps_per_epoch: how long to train policy in each RL epoch.
      policy_evals_per_epoch: number of policy trainer evaluations per RL epoch
          - only affects metric reporting.
      policy_eval_steps: number of policy trainer steps per evaluation - only
          affects metric reporting.
      collect_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
    """
    super(PolicyTrainer, self).__init__(
        task, collect_per_epoch=collect_per_epoch, output_dir=output_dir)
    self._policy_batch_size = policy_batch_size
    self._policy_train_steps_per_epoch = policy_train_steps_per_epoch
    self._policy_evals_per_epoch = policy_evals_per_epoch
    self._policy_eval_steps = policy_eval_steps
    self._collect_per_epoch = collect_per_epoch
    self._max_slice_length = max_slice_length
    self._policy_dist = distributions.create_distribution(task.action_space)

    # Inputs to the policy model are produced by self._policy_batches_stream.
    self._policy_inputs = supervised.Inputs(
        train_stream=lambda _: self.policy_batches_stream())

    policy_model = functools.partial(
        policy_model,
        policy_distribution=self._policy_dist,
    )

    # This is the policy Trainer that will be used to train the policy model.
    # * inputs to the trainer come from self.policy_batches_stream
    # * we are using has_weights=True to allow inputs to set weights
    # * outputs, targets and weights are passed to self.policy_loss
    self._policy_trainer = supervised.Trainer(
        model=policy_model,
        optimizer=policy_optimizer,
        lr_schedule=policy_lr_schedule,
        loss_fn=self.policy_loss,
        inputs=self._policy_inputs,
        output_dir=output_dir,
        metrics={'policy_loss': self.policy_loss},
        has_weights=True)
    self._policy_eval_model = policy_model(mode='eval')
    policy_batch = next(self.policy_batches_stream())
    self._policy_eval_model.init(policy_batch)
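
The distributions.create_distribution(task.action_space) call above builds the parametric distribution whose inputs the policy model must produce. Below is a hedged, standalone sketch of that mapping (it assumes gym is installed and that a Discrete space yields a categorical distribution with one input logit per action, which is consistent with the shape test in Example #4):

import gym
from trax.rl import distributions

# For a discrete action space, the distribution object knows how many inputs
# (logits) the policy network has to output per time step.
dist = distributions.create_distribution(gym.spaces.Discrete(4))
print(dist.n_inputs)  # expected: 4, one logit per action
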
Example #2
    def __init__(self,
                 task,
                 joint_model=None,
                 optimizer=None,
                 lr_schedule=lr.MultifactorSchedule,
                 batch_size=64,
                 train_steps_per_epoch=500,
                 collect_per_epoch=50,
                 max_slice_length=1,
                 output_dir=None):
        """Configures the joint trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      joint_model: Trax layer, representing the joint policy and value model.
      optimizer: the optimizer to use to train the joint model.
      lr_schedule: learning rate schedule to use to train the joint model.
      batch_size: batch size used to train the joint model.
      train_steps_per_epoch: how long to train the joint model in each RL epoch.
      collect_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
    """
        super(ActorCriticJointTrainer,
              self).__init__(task,
                             collect_per_epoch=collect_per_epoch,
                             output_dir=output_dir)
        self._batch_size = batch_size
        self._train_steps_per_epoch = train_steps_per_epoch
        self._collect_per_epoch = collect_per_epoch
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.env.action_space)

        # Inputs to the joint model are produced by self.batches_stream.
        self._inputs = supervised.Inputs(
            train_stream=lambda _: self.batches_stream())

        joint_model = functools.partial(
            joint_model,
            policy_distribution=self._policy_dist,
        )

        # This is the joint Trainer that will be used to train the joint model.
        # * inputs to the trainer come from self.batches_stream
        # * outputs are passed to self.joint_loss
        self._trainer = supervised.Trainer(
            model=joint_model,
            optimizer=optimizer,
            lr_schedule=lr_schedule,
            loss_fn=self.joint_loss,
            inputs=self._inputs,
            output_dir=output_dir,
            # TODO(lukaszkaiser): log policy and value losses too.
            metrics={'joint_loss': self.joint_loss})
        self._eval_model = joint_model(mode='eval')
        example_batch = next(self.batches_stream())
        self._eval_model.init(example_batch)
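
Both constructors above pre-bind the policy distribution with functools.partial so that the model factory can later be called with just a mode. A minimal standalone sketch of that pattern (the names below are hypothetical placeholders, not trax APIs):

import functools

def make_model(mode, policy_distribution):
  # Stand-in for a Trax model factory such as joint_model above.
  return 'model(mode=%s, dist=%s)' % (mode, policy_distribution)

bound = functools.partial(make_model, policy_distribution='categorical')
print(bound(mode='train'))  # -> model(mode=train, dist=categorical)
print(bound(mode='eval'))   # -> model(mode=eval, dist=categorical)
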
Example #3
  def test_integration_with_policy_tasks(self):
    # Integration test for policy + value training and eval.
    optimizer = opt.Adam()
    lr_schedule = lr_schedules.constant(1e-3)
    advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1)
    policy_dist = distributions.create_distribution(self._task.action_space)
    body = lambda mode: tl.Dense(64)
    train_model = models.PolicyAndValue(policy_dist, body=body)
    eval_model = models.PolicyAndValue(policy_dist, body=body)

    head_selector = tl.Select([1])
    value_train_task = value_tasks.ValueTrainTask(
        self._trajectory_batch_stream,
        optimizer,
        lr_schedule,
        advantage_estimator,
        model=train_model,
        target_model=eval_model,
        head_selector=head_selector,
    )
    value_eval_task = value_tasks.ValueEvalTask(
        value_train_task, head_selector=head_selector
    )

    # Drop the value head - just tl.Select([0]) would pass it, and it would
    # override the targets.
    head_selector = tl.Select([0], n_in=2)
    policy_train_task = policy_tasks.PolicyTrainTask(
        self._trajectory_batch_stream,
        optimizer,
        lr_schedule,
        policy_dist,
        advantage_estimator,
        # Plug a trained critic as our value estimate.
        value_fn=value_train_task.value,
        head_selector=head_selector,
    )
    policy_eval_task = policy_tasks.PolicyEvalTask(
        policy_train_task, head_selector=head_selector
    )

    loop = training.Loop(
        model=train_model,
        eval_model=eval_model,
        tasks=[policy_train_task, value_train_task],
        eval_tasks=[policy_eval_task, value_eval_task],
        eval_at=(lambda _: True),
        # Switch the task every step.
        which_task=(lambda step: step % 2),
    )
    # Run for a couple of steps to make sure there are a few task switches.
    loop.run(n_steps=10)
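
The two head selectors are what route the joint model's (policy_head, value_head) output pair to the right task. The snippet below is a small trax-free mimic of that selection idea (a sketch only; tl.Select operates on a layer stack rather than on Python tuples):

def select(outputs, indices):
  # Mimics the effect of tl.Select(indices, n_in=len(outputs)): keep only the
  # chosen entries of the model's outputs, in the given order.
  kept = [outputs[i] for i in indices]
  return kept[0] if len(kept) == 1 else tuple(kept)

policy_head, value_head = 'log-probs', 'state-values'  # stand-ins for arrays
# The policy tasks drop the value head so it cannot override the targets:
assert select((policy_head, value_head), [0]) == 'log-probs'
# The value tasks keep only the value head:
assert select((policy_head, value_head), [1]) == 'state-values'
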
Example #4
    def test_shapes(self, space, gin_config):
        gin.parse_config(gin_config)

        batch_shape = (2, 3)
        distribution = distributions.create_distribution(space)
        inputs = np.random.random(batch_shape + (distribution.n_inputs, ))
        point = distribution.sample(inputs)
        self.assertEqual(point.shape, batch_shape + space.shape)
        # Check that the datatypes are compatible, i.e. either both floating
        # or both integral. Note: numpy dtypes are not instances of `float`,
        # so np.issubdtype is needed for a meaningful check.
        self.assertEqual(np.issubdtype(point.dtype, np.floating),
                         np.issubdtype(space.dtype, np.floating))
        log_prob = distribution.log_prob(inputs, point)
        self.assertEqual(log_prob.shape, batch_shape)
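
For concreteness, here is the shape arithmetic behind the assertions above, assuming a gym.spaces.Discrete(4) space (so the categorical distribution has n_inputs == 4 and the space's sample shape is empty):

batch_shape = (2, 3)
n_inputs, space_shape = 4, ()                   # assumed Discrete(4) space
assert batch_shape + (n_inputs,) == (2, 3, 4)   # inputs.shape
assert batch_shape + space_shape == (2, 3)      # point.shape
# log_prob has shape batch_shape == (2, 3): one log-probability per sample.
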
Example #5
  def __init__(self, task,
               joint_model=None,
               optimizer=None,
               lr_schedule=lr.multifactor,
               batch_size=64,
               train_steps_per_epoch=500,
               supervised_evals_per_epoch=1,
               supervised_eval_steps=1,
               n_trajectories_per_epoch=50,
               max_slice_length=1,
               normalize_advantages=True,
               output_dir=None,
               n_replay_epochs=1):
    """Configures the joint trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      joint_model: Trax layer, representing the joint policy and value model.
      optimizer: the optimizer to use to train the joint model.
      lr_schedule: learning rate schedule to use to train the joint model.
      batch_size: batch size used to train the joint model.
      train_steps_per_epoch: how long to train the joint model in each RL epoch.
      supervised_evals_per_epoch: number of value trainer evaluations per RL
          epoch - only affects metric reporting.
      supervised_eval_steps: number of value trainer steps per evaluation -
          only affects metric reporting.
      n_trajectories_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      normalize_advantages: if True, then normalize advantages - currently
          implemented only in PPO.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      n_replay_epochs: how many last epochs to take into the replay buffer;
           > 1 only makes sense for off-policy algorithms.
    """
    super().__init__(
        task,
        n_trajectories_per_epoch=n_trajectories_per_epoch,
        output_dir=output_dir,
    )
    self._batch_size = batch_size
    self._train_steps_per_epoch = train_steps_per_epoch
    self._supervised_evals_per_epoch = supervised_evals_per_epoch
    self._supervised_eval_steps = supervised_eval_steps
    self._n_trajectories_per_epoch = n_trajectories_per_epoch
    self._max_slice_length = max_slice_length
    self._policy_dist = distributions.create_distribution(task.action_space)
    self._lr_schedule = lr_schedule()
    self._optimizer = optimizer
    self._normalize_advantages = normalize_advantages
    self._n_replay_epochs = n_replay_epochs
    self._task.set_n_replay_epochs(n_replay_epochs)

    # Inputs to the joint model are produced by self.batches_stream.
    self._inputs = data.inputs.Inputs(
        train_stream=lambda _: self.batches_stream())

    self._joint_model = functools.partial(
        joint_model,
        policy_distribution=self._policy_dist,
    )

    # This is the joint Trainer that will be used to train the joint model.
    # * inputs to the trainer come from self.batches_stream
    # * outputs are passed to self.joint_loss
    self._trainer = supervised.Trainer(
        model=self._joint_model,
        optimizer=self._optimizer,
        lr_schedule=self._lr_schedule,
        loss_fn=self.joint_loss,
        inputs=self._inputs,
        output_dir=output_dir,
        metrics={'joint_loss': self.joint_loss,
                 'advantage_mean': self.advantage_mean,
                 'advantage_norm': self.advantage_norm,
                 'value_loss': self.value_loss,
                 'explained_variance': self.explained_variance,
                 'log_probs_mean': self.log_probs_mean,
                 'preferred_move': self.preferred_move})
    self._eval_model = tl.Accelerate(
        self._joint_model(mode='eval'), n_devices=1)
    example_batch = next(self.batches_stream())
    self._eval_model.init(example_batch)
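
The normalize_advantages flag above is documented but its formula is not shown; a common choice, sketched below as an assumption, is standardization to zero mean and unit variance:

import numpy as np

def normalize_advantages(advantages, epsilon=1e-8):
  # Standardize the advantages; epsilon guards against division by zero.
  return (advantages - advantages.mean()) / (advantages.std() + epsilon)

print(normalize_advantages(np.array([1.0, 2.0, 3.0, 4.0])))
# -> approximately [-1.34 -0.45  0.45  1.34]
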
Example #6
    def __init__(self,
                 task,
                 policy_model=None,
                 policy_optimizer=None,
                 policy_lr_schedule=lr.multifactor,
                 policy_batch_size=64,
                 policy_train_steps_per_epoch=500,
                 policy_evals_per_epoch=1,
                 policy_eval_steps=1,
                 n_eval_episodes=0,
                 only_eval=False,
                 max_slice_length=1,
                 output_dir=None,
                 **kwargs):
        """Configures the policy trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      policy_model: Trax layer, representing the policy model. Loss functions
          and eval functions (a.k.a. metrics) are considered to be outside the
          core model, taking core model output and data labels as their two
          inputs.
      policy_optimizer: the optimizer to use to train the policy model.
      policy_lr_schedule: learning rate schedule to use to train the policy.
      policy_batch_size: batch size used to train the policy model.
      policy_train_steps_per_epoch: how long to train policy in each RL epoch.
      policy_evals_per_epoch: number of policy trainer evaluations per RL epoch
          - only affects metric reporting.
      policy_eval_steps: number of policy trainer steps per evaluation - only
          affects metric reporting.
      n_eval_episodes: number of episodes to play with the policy at
        temperature 0 in each epoch -- used for evaluation only.
      only_eval: If set to True, then trajectories are collected only for
        evaluation purposes, but they are not recorded.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      **kwargs: arguments for the superclass RLTrainer.
    """
        super().__init__(task,
                         n_eval_episodes=n_eval_episodes,
                         output_dir=output_dir,
                         **kwargs)
        self._policy_batch_size = policy_batch_size
        self._policy_train_steps_per_epoch = policy_train_steps_per_epoch
        self._policy_evals_per_epoch = policy_evals_per_epoch
        self._policy_eval_steps = policy_eval_steps
        self._only_eval = only_eval
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.action_space)

        # Inputs to the policy model are produced by self._policy_batches_stream.
        self._policy_inputs = data.inputs.Inputs(
            train_stream=lambda _: self.policy_batches_stream())

        policy_model = functools.partial(
            policy_model,
            policy_distribution=self._policy_dist,
        )

        # This is the policy Trainer that will be used to train the policy model.
        # * inputs to the trainer come from self.policy_batches_stream
        # * outputs, targets and weights are passed to self.policy_loss
        self._policy_trainer = supervised.Trainer(
            model=policy_model,
            optimizer=policy_optimizer,
            lr_schedule=policy_lr_schedule(),
            loss_fn=self.policy_loss,
            inputs=self._policy_inputs,
            output_dir=output_dir,
            metrics=self.policy_metrics,
        )
        self._policy_collect_model = tl.Accelerate(
            policy_model(mode='collect'), n_devices=1)
        policy_batch = next(self.policy_batches_stream())
        self._policy_collect_model.init(shapes.signature(policy_batch))
        self._policy_eval_model = tl.Accelerate(
            policy_model(mode='eval'), n_devices=1)  # Not collecting stats
        self._policy_eval_model.init(shapes.signature(policy_batch))
        if self._task._initial_trajectories == 0:
            self._task.remove_epoch(0)
            self._collect_trajectories()
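
The docstring above mentions playing the policy "at temperature 0" for evaluation. As an illustration (not the trainer's actual sampling code), temperature-0 sampling collapses the policy distribution to the arg-max action:

import numpy as np

log_probs = np.log(np.array([0.1, 0.7, 0.2]))  # hypothetical policy output
greedy_action = int(np.argmax(log_probs))      # temperature-0 "sampling"
print(greedy_action)  # -> 1
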
Example #7
  def __init__(self, task,
               value_body=None,
               value_optimizer=None,
               value_lr_schedule=lr.multifactor,
               value_batch_size=64,
               value_train_steps_per_epoch=500,
               value_evals_per_epoch=1,
               value_eval_steps=1,
               exploration_rate=functools.partial(
                   lr.multifactor,
                   factors='constant * decay_every',
                   constant=1.,  # pylint: disable=redefined-outer-name
                   decay_factor=0.99,
                   steps_per_decay=1,
                   minimum=0.1),
               n_eval_episodes=0,
               only_eval=False,
               n_replay_epochs=1,
               max_slice_length=1,
               sync_freq=1000,
               scale_value_targets=True,
               output_dir=None,
               **kwargs):
    """Configures the value trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      value_body: Trax layer, representing the body of the value model. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      value_optimizer: the optimizer to use to train the value model.
      value_lr_schedule: learning rate schedule to use to train the value model.
      value_batch_size: batch size used to train the value model.
      value_train_steps_per_epoch: how long to train the value model in each
          RL epoch.
      value_evals_per_epoch: number of value trainer evaluations per RL epoch
          - only affects metric reporting.
      value_eval_steps: number of value trainer steps per evaluation - only
          affects metric reporting.
      exploration_rate: exploration rate schedule - used in the policy method.
      n_eval_episodes: number of episodes to play with the policy at
        temperature 0 in each epoch -- used for evaluation only.
      only_eval: If set to True, then trajectories are collected only for
        evaluation purposes, but they are not recorded.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms.
      max_slice_length: the maximum length of trajectory slices to use; it is
          the second dimension of the value network output:
          (batch, max_slice_length, number of actions).
          A higher max_slice_length implies that the network has to predict
          more values into the future.
      sync_freq: how often (in training steps) to synchronize the target
        network with the trained network. This is necessary for training the
        network on bootstrapped targets, e.g. using n-step returns.
      scale_value_targets: If `True`, scale value function targets by
          `1 / (1 - gamma)`. This mitigates the problem of very large returns
          in some games without introducing additional hyperparameters.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      **kwargs: arguments for the superclass RLTrainer.
    """
    super(ValueAgent, self).__init__(
        task,
        n_eval_episodes=n_eval_episodes,
        output_dir=output_dir,
        **kwargs
    )
    self._value_batch_size = value_batch_size
    self._value_train_steps_per_epoch = value_train_steps_per_epoch
    self._value_evals_per_epoch = value_evals_per_epoch
    self._value_eval_steps = value_eval_steps
    self._only_eval = only_eval
    self._max_slice_length = max_slice_length
    self._policy_dist = distributions.create_distribution(task.action_space)
    self._n_replay_epochs = n_replay_epochs

    self._exploration_rate = exploration_rate()
    self._sync_at = (lambda step: step % sync_freq == 0)

    if scale_value_targets:
      self._value_network_scale = 1 / (1 - self._task.gamma)
    else:
      self._value_network_scale = 1

    value_model = functools.partial(
        models.Quality,
        body=value_body,
        n_actions=self.task.action_space.n)

    self._value_eval_model = value_model(mode='eval')
    self._value_eval_model.init(self._value_model_signature)
    self._value_eval_jit = tl.jit_forward(
        self._value_eval_model.pure_fn, fastmath.device_count(), do_mean=False)

    # Inputs to the value model are produced by self._values_batches_stream.
    self._inputs = data.inputs.Inputs(
        train_stream=lambda _: self.value_batches_stream())

    # This is the value Trainer that will be used to train the value model.
    # * inputs to the trainer come from self.value_batches_stream
    # * outputs, targets and weights are passed to self.value_loss
    self._value_trainer = supervised.Trainer(
        model=value_model,
        optimizer=value_optimizer,
        lr_schedule=value_lr_schedule(),
        loss_fn=self.value_loss,
        inputs=self._inputs,
        output_dir=output_dir,
        metrics={'value_loss': self.value_loss,
                 'value_mean': self.value_mean,
                 'returns_mean': self.returns_mean}
    )
    value_batch = next(self.value_batches_stream())
    self._eval_model = tl.Accelerate(
        value_model(mode='collect'), n_devices=1)
    self._eval_model.init(shapes.signature(value_batch))
    if self._task._initial_trajectories == 0:
      self._task.remove_epoch(0)
      self._collect_trajectories()
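
The scale_value_targets arithmetic above is worth spelling out: with per-step rewards bounded by 1, the discounted return is bounded by the geometric-series sum 1 / (1 - gamma), so dividing the targets by that scale keeps them roughly in [0, 1]. A tiny worked example:

gamma = 0.99
value_network_scale = 1 / (1 - gamma)          # == 100.0 (up to float error)
# A constant reward of 1 per step yields a return approaching this bound:
print(sum(gamma ** t for t in range(10000)))   # ~= 100.0
print(value_network_scale)
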
Example #8
  def __init__(
      self, task, model_fn,
      optimizer=adam.Adam,
      lr_schedule=lr.multifactor,
      batch_size=64,
      network_eval_at=None,
      n_eval_batches=1,
      max_slice_length=1,
      **kwargs
  ):
    """Initializes PolicyGradient.

    Args:
      task: Instance of trax.rl.task.RLTask.
      model_fn: Function (policy_distribution, mode) -> policy_model.
      optimizer: Optimizer for network training.
      lr_schedule: Learning rate schedule for network training.
      batch_size: Batch size for network training.
      network_eval_at: Function step -> bool indicating the training steps at
        which network evaluation should be performed.
      n_eval_batches: Number of batches to run during network evaluation.
      max_slice_length: The length of trajectory slices to run the network on.
      **kwargs: Keyword arguments passed to the superclass.
    """
    super().__init__(task, **kwargs)

    self._max_slice_length = max_slice_length
    trajectory_batch_stream = task.trajectory_batch_stream(
        batch_size,
        epochs=[-1],
        max_slice_length=self._max_slice_length,
        sample_trajectories_uniformly=True,
    )
    self._policy_dist = distributions.create_distribution(task.action_space)
    train_task = policy_tasks.PolicyTrainTask(
        trajectory_batch_stream,
        optimizer(),
        lr_schedule(),
        self._policy_dist,
        # Policy gradient uses the MC estimator. No need for margin - the MC
        # estimator only uses empirical returns.
        advantage_estimator=advantages.monte_carlo(task.gamma, margin=0),
        value_fn=self._value_fn,
    )
    eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches)
    model_fn = functools.partial(
        model_fn,
        policy_distribution=self._policy_dist,
    )

    if self._output_dir is not None:
      policy_output_dir = os.path.join(self._output_dir, 'policy')
    else:
      policy_output_dir = None
    # Checkpoint every epoch. We do one step per epoch, so that's every step.
    checkpoint_at = lambda _: True
    self._loop = supervised.training.Loop(
        model=model_fn(mode='train'),
        tasks=[train_task],
        eval_model=model_fn(mode='eval'),
        eval_tasks=[eval_task],
        output_dir=policy_output_dir,
        eval_at=network_eval_at,
        checkpoint_at=checkpoint_at,
    )
    self._collect_model = model_fn(mode='collect')
    self._collect_model.init(shapes.signature(train_task.sample_batch))

    # Validate the restored checkpoints. The number of network training steps
    # (self.loop.step) should be equal to the number of epochs (self._epoch),
    # because we do exactly one gradient step per epoch.
    # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
    if self.loop.step != self._epoch:
      raise ValueError(
          'The number of Loop steps must equal the number of Agent epochs, '
          'got {} and {}.'.format(self.loop.step, self._epoch)
      )
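
Since the comment above leans on the Monte Carlo estimator using only empirical returns, here is a minimal numpy sketch of that computation (an illustration of the concept, not trax's advantages.monte_carlo); the advantage is then the return minus the baseline given by value_fn:

import numpy as np

def monte_carlo_returns(rewards, gamma):
  # Discounted returns G_t = r_t + gamma * G_{t+1}, accumulated right-to-left.
  returns = np.zeros(len(rewards))
  future = 0.0
  for t in reversed(range(len(rewards))):
    future = rewards[t] + gamma * future
    returns[t] = future
  return returns

print(monte_carlo_returns(np.array([1.0, 1.0, 1.0]), gamma=0.99))
# -> [2.9701 1.99   1.    ]
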
Example #9
  def __init__(self, task, model=None, optimizer=None,
               lr_schedule=lr.MultifactorSchedule, batch_size=64,
               train_steps_per_epoch=500, collect_per_epoch=50,
               max_slice_length=1, output_dir=None):
    """Configures the Reinforce loop.

    Args:
      task: RLTask instance, which defines the environment to train on.
      model: Trax layer, representing the policy model. Loss functions and
          eval functions (a.k.a. metrics) are considered to be outside the
          core model, taking core model output and data labels as their two
          inputs.
      optimizer: the optimizer to use to train the model.
      lr_schedule: learning rate schedule to use to train the model.
      batch_size: batch size used to train the model.
      train_steps_per_epoch: how long to train in each RL epoch.
      collect_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be None if both `eval_task` and `checkpoint_at` are None.
    """
    super(PolicyGradientTrainer, self).__init__(
        task, collect_per_epoch=collect_per_epoch, output_dir=output_dir)
    self._batch_size = batch_size
    self._train_steps_per_epoch = train_steps_per_epoch
    self._collect_per_epoch = collect_per_epoch
    self._max_slice_length = max_slice_length
    self._epoch = 0
    self._eval_model = model(mode='eval')
    example_batch = next(self._batches_stream())
    self._eval_model.init(example_batch)
    self._policy_dist = distributions.create_distribution(task.env.action_space)

    # Inputs to the policy model are produced by self._batches_stream.
    # As you can see below, the stream returns (observation, action, return)
    # from the RLTask, which the model uses as (inputs, targets, loss weights).
    self._inputs = supervised.Inputs(
        train_stream=lambda _: self._batches_stream())

    # This is the main Trainer that will be used to train the policy using
    # a policy gradient loss. Note a few of the choices here:
    #
    # * this is a policy trainer, so:
    #     inputs are states and targets are actions + loss weights (see below)
    # * we are using LogLoss
    #     This is because we are training a policy model, so targets are
    #     actions and they are points sampled from the policy distribution --
    #     LogLoss will calculate the log probability of each action in the
    #     state, log pi(s, a).
    # * we are using has_weights=True
    #     We set has_weights=True because log pi(s, a) will be multiplied by
    #     a number -- a factor that can change depending on which policy
    #     gradient algorithm you use; here, we just use the return from
    #     this state and action, but many other variants can be tried.
    loss = functools.partial(
        distributions.LogLoss, distribution=self._policy_dist
    )
    self._trainer = supervised.Trainer(
        model=model, optimizer=optimizer, lr_schedule=lr_schedule,
        loss_fn=loss, inputs=self._inputs, output_dir=output_dir,
        metrics={'loss': loss}, has_weights=True)
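
To make the comments above concrete, here is the weighted log-loss with made-up numbers (an illustration only, not the LogLoss layer itself): each sampled action's log-probability is weighted by the return observed from that state and action, and the loss is the negative mean.

import numpy as np

log_probs = np.log(np.array([0.2, 0.5, 0.3]))  # log pi(s, a) for 3 sampled steps
returns = np.array([1.0, 0.0, 2.0])            # per-example loss weights
loss = -np.mean(returns * log_probs)           # weighted policy-gradient loss
print(loss)  # ~= 1.34
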
Example #10
    def __init__(self,
                 task,
                 model_fn,
                 value_fn,
                 weight_fn,
                 n_replay_epochs,
                 n_train_steps_per_epoch,
                 optimizer=adam.Adam,
                 lr_schedule=lr.multifactor,
                 batch_size=64,
                 network_eval_at=None,
                 n_eval_batches=1,
                 max_slice_length=1,
                 **kwargs):
        """Initializes LoopPolicyAgent.

    Args:
      task: Instance of trax.rl.task.RLTask.
      model_fn: Function (policy_distribution, mode) -> policy_model.
      value_fn: Function TimeStepBatch -> array (batch_size, seq_len)
        calculating the baseline for advantage calculation.
      weight_fn: Function float -> float to apply to advantages when calculating
        policy loss.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
        only makes sense for off-policy algorithms.
      n_train_steps_per_epoch: Number of steps to train the policy network for
        in each epoch.
      optimizer: Optimizer for network training.
      lr_schedule: Learning rate schedule for network training.
      batch_size: Batch size for network training.
      network_eval_at: Function step -> bool indicating the training steps at
        which network evaluation should be performed.
      n_eval_batches: Number of batches to run during network evaluation.
      max_slice_length: The length of trajectory slices to run the network on.
      **kwargs: Keyword arguments passed to the superclass.
    """
        self._n_train_steps_per_epoch = n_train_steps_per_epoch
        super().__init__(task, **kwargs)

        task.set_n_replay_epochs(n_replay_epochs)
        self._max_slice_length = max_slice_length
        trajectory_batch_stream = task.trajectory_batch_stream(
            batch_size,
            epochs=[-(ep + 1) for ep in range(n_replay_epochs)],
            max_slice_length=self._max_slice_length,
            sample_trajectories_uniformly=True,
        )
        self._policy_dist = distributions.create_distribution(
            task.action_space)
        train_task = policy_tasks.PolicyTrainTask(
            trajectory_batch_stream,
            optimizer(),
            lr_schedule(),
            self._policy_dist,
            # Without a value network it doesn't make a lot of sense to use
            # a better advantage estimator than MC.
            advantage_estimator=advantages.monte_carlo(task.gamma, margin=0),
            value_fn=value_fn,
            weight_fn=weight_fn,
        )
        eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches)
        model_fn = functools.partial(
            model_fn,
            policy_distribution=self._policy_dist,
        )

        if self._output_dir is not None:
            policy_output_dir = os.path.join(self._output_dir, 'policy')
        else:
            policy_output_dir = None
        # Checkpoint every epoch.
        checkpoint_at = lambda step: step % n_train_steps_per_epoch == 0
        self._loop = supervised.training.Loop(
            model=model_fn(mode='train'),
            tasks=[train_task],
            eval_model=model_fn(mode='eval'),
            eval_tasks=[eval_task],
            output_dir=policy_output_dir,
            eval_at=network_eval_at,
            checkpoint_at=checkpoint_at,
        )
        self._collect_model = model_fn(mode='collect')
        self._collect_model.init(shapes.signature(train_task.sample_batch))

        # Validate the restored checkpoints.
        # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
        if self._loop.step != self._epoch * self._n_train_steps_per_epoch:
            raise ValueError(
                'The number of Loop steps must equal the number of Agent epochs '
                'times the number of steps per epoch, got {}, {} and {}.'.
                format(self._loop.step, self._epoch,
                       self._n_train_steps_per_epoch))
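
The weight_fn hook above is not pinned down in the docstring; two hypothetical choices (assumptions for illustration, not the agent's defaults) are the identity, which keeps a plain advantage-weighted policy gradient, and an exponentiated advantage, which gives an AWR-style weighting:

import numpy as np

def identity_weight_fn(advantage):
  return advantage

def awr_weight_fn(advantage, beta=1.0):
  # Exponentiated-advantage weighting, as in Advantage-Weighted Regression.
  return np.exp(advantage / beta)

print(identity_weight_fn(0.5))  # -> 0.5
print(awr_weight_fn(0.5))       # -> ~1.65
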