Example #1
  def __init__(self, task, policy_model=None, policy_optimizer=None,
               policy_lr_schedule=lr.MultifactorSchedule, policy_batch_size=64,
               policy_train_steps_per_epoch=500, policy_evals_per_epoch=1,
               policy_eval_steps=1, collect_per_epoch=50,
               max_slice_length=1, output_dir=None):
    """Configures the policy trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      policy_model: Trax layer, representing the policy model. Loss functions
          and eval functions (a.k.a. metrics) are considered to be outside
          the core model, taking core model output and data labels as their
          two inputs.
      policy_optimizer: the optimizer to use to train the policy model.
      policy_lr_schedule: learning rate schedule to use to train the policy.
      policy_batch_size: batch size used to train the policy model.
      policy_train_steps_per_epoch: how long to train policy in each RL epoch.
      policy_evals_per_epoch: number of policy trainer evaluations per RL epoch
          - only affects metric reporting.
      policy_eval_steps: number of policy trainer steps per evaluation - only
          affects metric reporting.
      collect_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
    """
    super(PolicyTrainer, self).__init__(
        task, collect_per_epoch=collect_per_epoch, output_dir=output_dir)
    self._policy_batch_size = policy_batch_size
    self._policy_train_steps_per_epoch = policy_train_steps_per_epoch
    self._policy_evals_per_epoch = policy_evals_per_epoch
    self._policy_eval_steps = policy_eval_steps
    self._collect_per_epoch = collect_per_epoch
    self._max_slice_length = max_slice_length
    self._policy_dist = distributions.create_distribution(task.action_space)

    # Inputs to the policy model are produced by self._policy_batches_stream.
    self._policy_inputs = supervised.Inputs(
        train_stream=lambda _: self.policy_batches_stream())

    policy_model = functools.partial(
        policy_model,
        policy_distribution=self._policy_dist,
    )

    # This is the policy Trainer that will be used to train the policy model.
    # * inputs to the trainer come from self.policy_batches_stream
    # * we are using has_weights=True to allow inputs to set weights
    # * outputs, targets and weights are passed to self.policy_loss
    self._policy_trainer = supervised.Trainer(
        model=policy_model,
        optimizer=policy_optimizer,
        lr_schedule=policy_lr_schedule,
        loss_fn=self.policy_loss,
        inputs=self._policy_inputs,
        output_dir=output_dir,
        metrics={'policy_loss': self.policy_loss},
        has_weights=True)
    self._policy_eval_model = policy_model(mode='eval')
    policy_batch = next(self.policy_batches_stream())
    self._policy_eval_model.init(policy_batch)
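The constructor above binds the policy distribution into the model factory with functools.partial and then builds both the trainer and a separate eval-mode model from that factory. Below is a minimal, self-contained sketch of that wiring pattern, using hypothetical stand-ins such as fake_policy_model and fake_batches_stream rather than the real Trax classes.

import functools


def fake_policy_model(mode='train', policy_distribution=None):
  # Stand-in for a Trax policy model factory; returns a plain dict.
  return {'mode': mode, 'distribution': policy_distribution}


def fake_batches_stream():
  # Stand-in for policy_batches_stream(): yields (observation, action, weight).
  while True:
    yield ([0.0, 1.0], 1, 1.0)


# Bind the policy distribution into the factory, as the constructor does.
model_factory = functools.partial(fake_policy_model,
                                  policy_distribution='categorical')

# The eval model comes from the same factory and sees one example batch.
eval_model = model_factory(mode='eval')
first_batch = next(fake_batches_stream())
print(eval_model, first_batch)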
Example #2
    def __init__(self,
                 task,
                 joint_model=None,
                 optimizer=None,
                 lr_schedule=lr.MultifactorSchedule,
                 batch_size=64,
                 train_steps_per_epoch=500,
                 collect_per_epoch=50,
                 max_slice_length=1,
                 output_dir=None):
        """Configures the joint trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      joint_model: Trax layer, representing the joint policy and value model.
      optimizer: the optimizer to use to train the joint model.
      lr_schedule: learning rate schedule to use to train the joint model.
      batch_size: batch size used to train the joint model.
      train_steps_per_epoch: how long to train the joint model in each RL epoch.
      collect_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
    """
        super(ActorCriticJointTrainer,
              self).__init__(task,
                             collect_per_epoch=collect_per_epoch,
                             output_dir=output_dir)
        self._batch_size = batch_size
        self._train_steps_per_epoch = train_steps_per_epoch
        self._collect_per_epoch = collect_per_epoch
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.env.action_space)

        # Inputs to the joint model are produced by self.batches_stream.
        self._inputs = supervised.Inputs(
            train_stream=lambda _: self.batches_stream())

        joint_model = functools.partial(
            joint_model,
            policy_distribution=self._policy_dist,
        )

        # This is the joint Trainer that will be used to train the policy model.
        # * inputs to the trainer come from self.batches_stream
        # * outputs are passed to self._joint_loss
        self._trainer = supervised.Trainer(
            model=joint_model,
            optimizer=optimizer,
            lr_schedule=lr_schedule,
            loss_fn=self.joint_loss,
            inputs=self._inputs,
            output_dir=output_dir,
            # TODO(lukaszkaiser): log policy and value losses too.
            metrics={'joint_loss': self.joint_loss})
        self._eval_model = joint_model(mode='eval')
        example_batch = next(self.batches_stream())
        self._eval_model.init(example_batch)
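One detail worth noticing above: the same bound method, self.joint_loss, is passed both as loss_fn and as a reported metric. The TinyTrainer and TinyJointTrainer classes in the sketch below are hypothetical stand-ins, not the supervised.Trainer API.

class TinyTrainer:
  # Hypothetical stand-in: just stores what it is given.
  def __init__(self, loss_fn, metrics):
    self.loss_fn = loss_fn
    self.metrics = metrics


class TinyJointTrainer:
  def joint_loss(self, outputs, targets):
    # Placeholder loss: squared error on the first element.
    return (outputs[0] - targets[0]) ** 2

  def __init__(self):
    self._trainer = TinyTrainer(loss_fn=self.joint_loss,
                                metrics={'joint_loss': self.joint_loss})


trainer = TinyJointTrainer()
print(trainer._trainer.metrics['joint_loss']((1.0,), (0.5,)))  # 0.25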
Example #3
  def __init__(self, task, model=None, optimizer=None,
               lr_schedule=lr.MultifactorSchedule, batch_size=64,
               train_steps_per_epoch=500, collect_per_epoch=50,
               max_slice_length=1, output_dir=None):
    """Configures the Reinforce loop.

    Args:
      task: RLTask instance, which defines the environment to train on.
      model: Trax layer, representing the policy model. Loss functions and
          eval functions (a.k.a. metrics) are considered to be outside the
          core model, taking core model output and data labels as their two
          inputs.
      optimizer: the optimizer to use to train the model.
      lr_schedule: learning rate schedule to use to train the model.
      batch_size: batch size used to train the model.
      train_steps_per_epoch: how long to train in each RL epoch.
      collect_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be None if both `eval_task` and `checkpoint_at` are None.
    """
    super(ExamplePolicyTrainer, self).__init__(task, output_dir=output_dir)
    self._batch_size = batch_size
    self._train_steps_per_epoch = train_steps_per_epoch
    self._collect_per_epoch = collect_per_epoch
    self._max_slice_length = max_slice_length
    self._epoch = 0
    self._eval_model = model(mode='eval')
    example_batch = next(self._batches_stream())
    self._eval_model.init(example_batch)

    # Inputs to the policy model are produced by self._batches_stream.
    # As you can see below, the stream returns (observation, action, return)
    # from the RLTask, which the model uses as (inputs, targets, loss weights).
    self._inputs = supervised.Inputs(
        train_stream=lambda _: self._batches_stream())

    # This is the main Trainer that will be used to train the policy using
    # a policy gradient loss. Note a few of the choices here:
    #
    # * this is a policy trainer, so:
    #     inputs are states and targets are actions + loss weights (see below)
    # * we are using CrossEntropyLoss
    #     This is because we are training a policy model, so targets are
    #     actions and they are integers -- CrossEntropyLoss will calculate
    #     the probability of each action in the state, pi(s, a).
    # * we are using has_weights=True
    #     We set has_weights = True because pi(s, a) will be multiplied by
    #     a number -- a factor that can change depending on which policy
    #     gradient algorithm you use; here, we just use the return from
    #     this state and action, but many other variants can be tried.
    # * we use id_to_mask=0
    #     This is because we reserved 0 for padding actions - so true actions
    #     start from 1 and we want to remove any loss on the 0 padding.
    self._trainer = supervised.Trainer(
        model=model, optimizer=optimizer, lr_schedule=lr_schedule,
        loss_fn=tl.CrossEntropyLoss, inputs=self._inputs, output_dir=output_dir,
        has_weights=True, id_to_mask=0)
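The comments above describe a return-weighted cross-entropy with action id 0 masked out as padding. Here is a hedged numeric sketch of that idea in plain Python; it is not the tl.CrossEntropyLoss implementation, and the helper name and data are made up.

import math


def return_weighted_xent(log_probs, actions, returns, pad_id=0):
  # log_probs: per-step dicts mapping action id -> log pi(a | s).
  total, count = 0.0, 0
  for lp, action, ret in zip(log_probs, actions, returns):
    if action == pad_id:        # id_to_mask=0: drop loss on padded actions.
      continue
    total += -lp[action] * ret  # -log pi(s, a) scaled by the return.
    count += 1
  return total / max(count, 1)


step_log_probs = [{1: math.log(0.7), 2: math.log(0.3)},
                  {1: math.log(0.4), 2: math.log(0.6)}]
# The second step is padding (action 0), so only the first step contributes.
print(return_weighted_xent(step_log_probs, actions=[1, 0], returns=[1.5, 0.0]))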
Example #4
    def __init__(self,
                 task,
                 value_model=None,
                 value_optimizer=None,
                 value_lr_schedule=lr.MultifactorSchedule,
                 value_batch_size=64,
                 value_train_steps_per_epoch=500,
                 n_shared_layers=0,
                 on_policy=True,
                 **kwargs):  # Arguments of PolicyTrainer come here.
        """Configures the actor-critic Trainer."""
        self._on_policy = on_policy
        self._n_shared_layers = n_shared_layers
        self._value_batch_size = value_batch_size
        self._value_train_steps_per_epoch = value_train_steps_per_epoch

        # The two fields below will be initialized in super().__init__ anyway,
        # but they are needed to construct value batches before the
        # PolicyTrainer init, since policy input creation calls the value
        # model -- hence this code.
        self._task = task
        self._max_slice_length = kwargs.get('max_slice_length', None)

        # Initialize training of the value function.
        value_output_dir = kwargs.get('output_dir', None)
        if value_output_dir is not None:
            value_output_dir = os.path.join(value_output_dir, 'value')
        self._value_inputs = supervised.Inputs(
            train_stream=lambda _: self.value_batches_stream())
        self._value_trainer = supervised.Trainer(
            model=value_model,
            optimizer=value_optimizer,
            lr_schedule=value_lr_schedule,
            loss_fn=tl.L2Loss,
            inputs=self._value_inputs,
            output_dir=value_output_dir,
            metrics={'value_loss': tl.L2Loss},
            has_weights=True)
        self._value_eval_model = value_model(mode='eval')
        value_batch = next(self.value_batches_stream())
        self._value_eval_model.init(value_batch)

        # Initialize policy training.
        super(ActorCriticTrainer, self).__init__(task, **kwargs)
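One fix worth calling out above: the value checkpoints subdirectory must be joined without a leading slash, because os.path.join discards everything before an absolute path component. The paths below are just illustrative.

import os

base = '/tmp/experiment'
print(os.path.join(base, '/value'))  # '/value' on POSIX: the base is dropped.
print(os.path.join(base, 'value'))   # '/tmp/experiment/value'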
Example #5
    def __init__(self,
                 task,
                 value_model=None,
                 value_optimizer=None,
                 value_lr_schedule=lr.MultifactorSchedule,
                 value_batch_size=64,
                 value_train_steps_per_epoch=500,
                 n_shared_layers=0,
                 added_policy_slice_length=0,
                 **kwargs):  # Arguments of PolicyTrainer come here.
        """Configures the actor-critic Trainer.

    Args:
      task: RLTask instance to use
      value_model: the model to use for the value function
      value_optimizer: the optimizer to train the value model
      value_lr_schedule: lr schedule for value model training
      value_batch_size: batch size for value model training
      value_train_steps_per_epoch: how many steps to use to train the value
        model in each epoch
      n_shared_layers: how many layers to share between value and
        policy models
      added_policy_slice_length: how much longer slices of trajectories
        should be for policy than for value training; this is useful for
        TD calculations and only affects the length of elements produced
        for policy batches; value batches have maximum length set by
        max_slice_length in **kwargs
      **kwargs: arguments for the PolicyTrainer superclass
    """
        self._n_shared_layers = n_shared_layers
        self._value_batch_size = value_batch_size
        self._value_train_steps_per_epoch = value_train_steps_per_epoch

        # The two fields below will be initialized in super().__init__ anyway,
        # but they are needed to construct value batches before the
        # PolicyTrainer init, since policy input creation calls the value
        # model -- hence this code.
        self._task = task
        self._max_slice_length = kwargs.get('max_slice_length', None)
        self._added_policy_slice_length = added_policy_slice_length

        # Initialize training of the value function.
        value_output_dir = kwargs.get('output_dir', None)
        if value_output_dir is not None:
            value_output_dir = os.path.join(value_output_dir, 'value')
        self._value_inputs = supervised.Inputs(
            train_stream=lambda _: self.value_batches_stream())
        self._value_trainer = supervised.Trainer(
            model=value_model,
            optimizer=value_optimizer,
            lr_schedule=value_lr_schedule,
            loss_fn=tl.L2Loss,
            inputs=self._value_inputs,
            output_dir=value_output_dir,
            metrics={'value_loss': tl.L2Loss},
            has_weights=True)
        self._value_eval_model = value_model(mode='eval')
        value_batch = next(self.value_batches_stream())
        self._value_eval_model.init(value_batch)

        # Initialize policy training.
        super(ActorCriticTrainer, self).__init__(task, **kwargs)
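A hedged sketch of how the two slice lengths described in the docstring relate: policy batches may use slices longer than value batches by added_policy_slice_length. The helper below is illustrative only.

def slice_lengths(max_slice_length, added_policy_slice_length):
  # Value batches use max_slice_length; policy batches may be longer, which
  # helps TD-style targets that look a few steps ahead.
  value_len = max_slice_length
  policy_len = max_slice_length + added_policy_slice_length
  return value_len, policy_len


print(slice_lengths(max_slice_length=1, added_policy_slice_length=2))  # (1, 3)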
Example #6
    def __init__(self,
                 task,
                 value_model=None,
                 value_optimizer=None,
                 value_lr_schedule=lr.MultifactorSchedule,
                 value_batch_size=64,
                 value_train_steps_per_epoch=500,
                 value_evals_per_epoch=1,
                 value_eval_steps=1,
                 n_shared_layers=0,
                 added_policy_slice_length=0,
                 n_replay_epochs=1,
                 scale_value_targets=False,
                 q_value=False,
                 q_value_aggregate_max=True,
                 q_value_n_samples=1,
                 vocab_size=2,
                 **kwargs):  # Arguments of PolicyTrainer come here.
        """Configures the actor-critic trainer.

    Args:
      task: `RLTask` instance to use.
      value_model: Model to use for the value function.
      value_optimizer: Optimizer to train the value model.
      value_lr_schedule: lr schedule for value model training.
      value_batch_size: Batch size for value model training.
      value_train_steps_per_epoch: Number of steps used to train the value
          model in each epoch.
      value_evals_per_epoch: Number of value trainer evaluations per RL epoch;
          only affects metric reporting.
      value_eval_steps: Number of value trainer steps per evaluation; only
          affects metric reporting.
      n_shared_layers: Number of layers to share between value and policy
          models.
      added_policy_slice_length: How much longer slices of trajectories
          should be for policy than for value training; this is useful for
          TD calculations and only affects the length of elements produced
          for policy batches; value batches have maximum length set by
          `max_slice_length` in `**kwargs`.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms.
      scale_value_targets: If `True`, scale value function targets by
          `1 / (1 - gamma)`.
      q_value: If `True`, use Q-values as baselines.
      q_value_aggregate_max: If `True`, aggregate Q-values with max; otherwise
          with mean.
      q_value_n_samples: Number of samples to average over when calculating
          baselines based on Q-values.
      vocab_size: Embedding vocabulary size (passed to `tl.Embedding`); used
          only with discrete actions and when `q_value` is `True`.
      **kwargs: Arguments for `PolicyTrainer` superclass.
    """
        self._n_shared_layers = n_shared_layers
        self._value_batch_size = value_batch_size
        self._value_train_steps_per_epoch = value_train_steps_per_epoch
        self._value_evals_per_epoch = value_evals_per_epoch
        self._value_eval_steps = value_eval_steps

        # The two fields below will be initialized in super().__init__ anyway,
        # but they are needed to construct value batches before the
        # PolicyTrainer init, since policy input creation calls the value
        # model -- hence this code.
        self._task = task
        self._max_slice_length = kwargs.get('max_slice_length', 1)
        self._added_policy_slice_length = added_policy_slice_length
        self._n_replay_epochs = n_replay_epochs
        task.set_n_replay_epochs(n_replay_epochs)

        if scale_value_targets:
            self._value_network_scale = 1 / (1 - self._task.gamma)
        else:
            self._value_network_scale = 1

        self._q_value = q_value
        self._q_value_aggregate_max = q_value_aggregate_max
        self._q_value_n_samples = q_value_n_samples
        self._vocab_size = vocab_size

        is_discrete = isinstance(self._task.action_space, gym.spaces.Discrete)
        # TODO(henrykm) handle the case other than Discrete/Gaussian

        if q_value:
            value_model = functools.partial(value_model,
                                            inject_actions=True,
                                            is_discrete=is_discrete,
                                            vocab_size=self._vocab_size)
        self._value_eval_model = value_model(mode='eval')
        self._value_eval_model.init(self._value_model_signature)
        self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                              math.device_count(),
                                              do_mean=False)

        # Initialize policy training.
        super(ActorCriticTrainer, self).__init__(task, **kwargs)

        # Initialize training of the value function.
        value_output_dir = kwargs.get('output_dir', None)
        if value_output_dir is not None:
            value_output_dir = os.path.join(value_output_dir, 'value')
            # If needed, create value_output_dir and missing parent directories.
            if not tf.io.gfile.isdir(value_output_dir):
                tf.io.gfile.makedirs(value_output_dir)
        self._value_inputs = supervised.Inputs(
            train_stream=lambda _: self.value_batches_stream())
        self._value_trainer = supervised.Trainer(
            model=value_model,
            optimizer=value_optimizer,
            lr_schedule=value_lr_schedule,
            loss_fn=tl.L2Loss(),
            inputs=self._value_inputs,
            output_dir=value_output_dir,
            metrics={'value_loss': tl.L2Loss()})
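A quick numeric sketch of the scale_value_targets branch above: with discount gamma, the scale 1 / (1 - gamma) grows with the effective horizon. This mirrors only how the constructor computes the scale, not how it is applied to targets.

def value_network_scale(gamma, scale_value_targets):
  # Mirrors the branch in the constructor above.
  return 1 / (1 - gamma) if scale_value_targets else 1


print(value_network_scale(0.99, True))   # 100.0
print(value_network_scale(0.99, False))  # 1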
Example #7
    def __init__(self,
                 task,
                 policy_model=None,
                 policy_optimizer=None,
                 policy_lr_schedule=lr.MultifactorSchedule,
                 policy_batch_size=64,
                 policy_train_steps_per_epoch=500,
                 policy_evals_per_epoch=1,
                 policy_eval_steps=1,
                 n_eval_episodes=0,
                 only_eval=False,
                 max_slice_length=1,
                 output_dir=None,
                 **kwargs):
        """Configures the policy trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      policy_model: Trax layer, representing the policy model. Loss functions
          and eval functions (a.k.a. metrics) are considered to be outside
          the core model, taking core model output and data labels as their
          two inputs.
      policy_optimizer: the optimizer to use to train the policy model.
      policy_lr_schedule: learning rate schedule to use to train the policy.
      policy_batch_size: batch size used to train the policy model.
      policy_train_steps_per_epoch: how long to train policy in each RL epoch.
      policy_evals_per_epoch: number of policy trainer evaluations per RL epoch
          - only affects metric reporting.
      policy_eval_steps: number of policy trainer steps per evaluation - only
          affects metric reporting.
      n_eval_episodes: number of episodes to play with policy at
        temperature 0 in each epoch -- used for evaluation only
      only_eval: If set to True, then trajectories are collected only for
        evaluation purposes, but they are not recorded.
      max_slice_length: the maximum length of trajectory slices to use.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      **kwargs: arguments for the superclass RLTrainer.
    """
        super(PolicyTrainer, self).__init__(task,
                                            n_eval_episodes=n_eval_episodes,
                                            output_dir=output_dir,
                                            **kwargs)
        self._policy_batch_size = policy_batch_size
        self._policy_train_steps_per_epoch = policy_train_steps_per_epoch
        self._policy_evals_per_epoch = policy_evals_per_epoch
        self._policy_eval_steps = policy_eval_steps
        self._only_eval = only_eval
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.action_space)

        # Inputs to the policy model are produced by self._policy_batches_stream.
        self._policy_inputs = supervised.Inputs(
            train_stream=lambda _: self.policy_batches_stream())

        policy_model = functools.partial(
            policy_model,
            policy_distribution=self._policy_dist,
        )

        # This is the policy Trainer that will be used to train the policy model.
        # * inputs to the trainer come from self.policy_batches_stream
        # * outputs, targets and weights are passed to self.policy_loss
        self._policy_trainer = supervised.Trainer(
            model=policy_model,
            optimizer=policy_optimizer,
            lr_schedule=policy_lr_schedule,
            loss_fn=self.policy_loss,
            inputs=self._policy_inputs,
            output_dir=output_dir,
            metrics=self.policy_metrics,
        )
        self._policy_collect_model = policy_model(mode='collect')
        policy_batch = next(self.policy_batches_stream())
        self._policy_collect_model.init(shapes.signature(policy_batch))
        self._policy_eval_model = policy_model(
            mode='eval')  # Not collecting stats
        self._policy_eval_model.init(shapes.signature(policy_batch))
        if self._task._initial_trajectories == 0:
            self._task.remove_epoch(0)
            self._collect_trajectories()
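Below is a stand-in sketch of the two-instances-from-one-factory pattern at the end of the constructor: a 'collect' model and an 'eval' model built from the same partially-applied factory, both initialized from the shape signature of one batch. FakeModel and batch_signature are hypothetical stand-ins, not the Trax shapes API.

import functools


class FakeModel:
  def __init__(self, mode, policy_distribution=None):
    self.mode = mode
    self.input_signature = None

  def init(self, signature):
    self.input_signature = signature


def batch_signature(batch):
  # Stand-in for shapes.signature: record only lengths, not values.
  return tuple(len(part) for part in batch)


factory = functools.partial(FakeModel, policy_distribution='categorical')
batch = ([0.1, 0.2, 0.3], [1], [1.0])
collect_model, eval_model = factory(mode='collect'), factory(mode='eval')
collect_model.init(batch_signature(batch))
eval_model.init(batch_signature(batch))
print(collect_model.input_signature, eval_model.input_signature)
# Prints: (3, 1, 1) (3, 1, 1)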
Example #8
  def __init__(self, task,
               value_model=None,
               value_optimizer=None,
               value_lr_schedule=lr.MultifactorSchedule,
               value_batch_size=64,
               value_train_steps_per_epoch=500,
               value_evals_per_epoch=1,
               value_eval_steps=1,
               n_shared_layers=0,
               added_policy_slice_length=0,
               n_replay_epochs=1,
               scale_value_targets=False,
               **kwargs):  # Arguments of PolicyTrainer come here.
    """Configures the actor-critic Trainer.

    Args:
      task: RLTask instance to use
      value_model: the model to use for the value function
      value_optimizer: the optimizer to train the value model
      value_lr_schedule: lr schedule for value model training
      value_batch_size: batch size for value model training
      value_train_steps_per_epoch: how many steps to use to
        train the value model in each epoch
      value_evals_per_epoch: number of value trainer evaluations per RL
          epoch - only affects metric reporting.
      value_eval_steps: number of value trainer steps per evaluation -
          only affects metric reporting.
      n_shared_layers: how many layers to share between value and
        policy models
      added_policy_slice_length: how much longer slices of trajectories
        should be for policy than for value training; this is useful for
        TD calculations and only affects the length of elements produced
        for policy batches; value batches have maximum length set by
        max_slice_length in **kwargs
      n_replay_epochs: how many last epochs to take into the replay buffer;
        only makes sense for off-policy algorithms
      scale_value_targets: whether to scale targets for the value function by
        1 / (1 - gamma)
      **kwargs: arguments for the PolicyTrainer superclass
    """
    self._n_shared_layers = n_shared_layers
    self._value_batch_size = value_batch_size
    self._value_train_steps_per_epoch = value_train_steps_per_epoch
    self._value_evals_per_epoch = value_evals_per_epoch
    self._value_eval_steps = value_eval_steps

    # The two fields below will be initialized in super().__init__ anyway, but
    # they are needed to construct value batches before the PolicyTrainer init,
    # since policy input creation calls the value model -- hence this code.
    self._task = task
    self._max_slice_length = kwargs.get('max_slice_length', 1)
    self._added_policy_slice_length = added_policy_slice_length
    self._n_replay_epochs = n_replay_epochs

    if scale_value_targets:
      self._value_network_scale = 1 / (1 - self._task.gamma)
    else:
      self._value_network_scale = 1

    self._value_eval_model = value_model(mode='eval')
    self._value_eval_model.init(self._value_model_signature)

    # Initialize training of the value function.
    value_output_dir = kwargs.get('output_dir', None)
    if value_output_dir is not None:
      value_output_dir = os.path.join(value_output_dir, 'value')
      # If needed, create value_output_dir and missing parent directories.
      if not tf.io.gfile.isdir(value_output_dir):
        tf.io.gfile.makedirs(value_output_dir)
    self._value_inputs = supervised.Inputs(
        train_stream=lambda _: self.value_batches_stream())
    self._value_trainer = supervised.Trainer(
        model=value_model,
        optimizer=value_optimizer,
        lr_schedule=value_lr_schedule,
        loss_fn=tl.L2Loss(),
        inputs=self._value_inputs,
        output_dir=value_output_dir,
        metrics={'value_loss': tl.L2Loss()})

    # Initialize policy training.
    super(ActorCriticTrainer, self).__init__(task, **kwargs)
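The directory setup above uses tf.io.gfile so that it also works with remote filesystems. As a hedged, TensorFlow-free sketch of the same local behaviour:

import os
import tempfile

output_dir = tempfile.mkdtemp()                       # stand-in for output_dir
value_output_dir = os.path.join(output_dir, 'value')
os.makedirs(value_output_dir, exist_ok=True)          # create dir and parents
print(os.path.isdir(value_output_dir))                # True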
Example #9
    def __init__(self,
                 task,
                 joint_model=None,
                 optimizer=None,
                 lr_schedule=lr.multifactor,
                 batch_size=64,
                 train_steps_per_epoch=500,
                 supervised_evals_per_epoch=1,
                 supervised_eval_steps=1,
                 n_trajectories_per_epoch=50,
                 max_slice_length=1,
                 normalize_advantages=True,
                 output_dir=None,
                 n_replay_epochs=1):
        """Configures the joint trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      joint_model: Trax layer, representing the joint policy and value model.
      optimizer: the optimizer to use to train the joint model.
      lr_schedule: learning rate schedule to use to train the joint model.
      batch_size: batch size used to train the joint model.
      train_steps_per_epoch: how long to train the joint model in each RL epoch.
      supervised_evals_per_epoch: number of value trainer evaluations per RL
          epoch - only affects metric reporting.
      supervised_eval_steps: number of value trainer steps per evaluation -
          only affects metric reporting.
      n_trajectories_per_epoch: how many trajectories to collect per epoch.
      max_slice_length: the maximum length of trajectory slices to use.
      normalize_advantages: if True, then normalize advantages - currently
          implemented only in PPO.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      n_replay_epochs: how many last epochs to take into the replay buffer;
           > 1 only makes sense for off-policy algorithms.
    """
        super(ActorCriticJointTrainer, self).__init__(
            task,
            n_trajectories_per_epoch=n_trajectories_per_epoch,
            output_dir=output_dir,
        )
        self._batch_size = batch_size
        self._train_steps_per_epoch = train_steps_per_epoch
        self._supervised_evals_per_epoch = supervised_evals_per_epoch
        self._supervised_eval_steps = supervised_eval_steps
        self._n_trajectories_per_epoch = n_trajectories_per_epoch
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.action_space)
        self._lr_schedule = lr_schedule()
        self._optimizer = optimizer
        self._normalize_advantages = normalize_advantages
        self._n_replay_epochs = n_replay_epochs
        self._task.set_n_replay_epochs(n_replay_epochs)

        # Inputs to the joint model are produced by self.batches_stream.
        self._inputs = supervised.Inputs(
            train_stream=lambda _: self.batches_stream())

        self._joint_model = functools.partial(
            joint_model,
            policy_distribution=self._policy_dist,
        )

        # This is the joint Trainer that will be used to train the policy model.
        # * inputs to the trainer come from self.batches_stream
        # * outputs are passed to self._joint_loss
        self._trainer = supervised.Trainer(
            model=self._joint_model,
            optimizer=self._optimizer,
            lr_schedule=self._lr_schedule,
            loss_fn=self.joint_loss,
            inputs=self._inputs,
            output_dir=output_dir,
            metrics={
                'joint_loss': self.joint_loss,
                'advantage_mean': self.advantage_mean,
                'advantage_norm': self.advantage_norm,
                'value_loss': self.value_loss,
                'explained_variance': self.explained_variance,
                'log_probs_mean': self.log_probs_mean,
                'preferred_move': self.preferred_move,
            })
        self._eval_model = tl.Accelerate(self._joint_model(mode='eval'),
                                         n_devices=1)
        example_batch = next(self.batches_stream())
        self._eval_model.init(example_batch)
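The trainer above reports several diagnostics alongside the joint loss. As one hedged illustration, here is a common definition of explained variance between value predictions and observed returns; this is an assumption about the metric's intent, not the exact Trax implementation.

import statistics


def explained_variance(values, returns):
  # 1 - Var(returns - values) / Var(returns); 1.0 means perfect prediction.
  var_returns = statistics.pvariance(returns)
  if var_returns == 0:
    return 0.0
  residuals = [r - v for v, r in zip(values, returns)]
  return 1.0 - statistics.pvariance(residuals) / var_returns


print(explained_variance(values=[1.0, 2.0, 3.0], returns=[1.1, 1.9, 3.2]))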