Example 1
    def test_train_save_restore_sharded(self):
        """Saves and restores a sharded checkpoint to check for equivalence."""
        if fastmath.local_device_count() < 2:
            return  # multi-accelerator only
        base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
        train_data = data.Serial(lambda _: _very_simple_data(2, 2),
                                 data.CountAndSkip('simple_data'))
        task = training.TrainTask(train_data(), tl.L2Loss(),
                                  optimizers.Adam(.0001))
        eval_task = training.EvalTask(
            _very_simple_data(2, 2),  # deliberately re-using training data
            [tl.L2Loss()],
            metric_names=['SGD.L2Loss'])
        tmp_dir = self.create_tempdir().full_path

        def _make_model_and_session():
            m = tl.Serial(tl.Dense(2))
            ts = training.Loop(m, [task],
                               eval_tasks=[eval_task],
                               eval_at=lambda step_n: step_n % 2 == 0,
                               output_dir=tmp_dir)
            return m, ts

        _, training_session = _make_model_and_session()
        self.assertEqual(0, training_session.step)
        training_session.run(n_steps=1)
        training_session.save_checkpoint('model')
        _, training_session2 = _make_model_and_session()
        training_session2.run(n_steps=1)
        base.N_WEIGHTS_SHARDS = 1
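A caveat worth flagging in this example: `base.N_WEIGHTS_SHARDS` is a module-level global, so a failing assertion before the final line would leave it set to the device count for every later test. A minimal sketch of a guard, assuming the same trax imports as above; the helper name `n_weights_shards` is hypothetical, not part of trax:

import contextlib

from trax.layers import base


@contextlib.contextmanager
def n_weights_shards(n):
    """Temporarily sets base.N_WEIGHTS_SHARDS, restoring it on exit."""
    old = base.N_WEIGHTS_SHARDS
    base.N_WEIGHTS_SHARDS = n
    try:
        yield
    finally:
        base.N_WEIGHTS_SHARDS = old  # restored even if the test body raises

Used as `with n_weights_shards(fastmath.local_device_count()): ...` around the test body, this makes the manual reset at the end unnecessary.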
Example 2
    def test_autoregressive_sample_transformerlm_tfnp(self):
        """Samples from a small TransformerLM on the TF-numpy backend."""
        with fastmath.use_backend(fastmath.Backend.TFNP):
            model = models.TransformerLM(10,
                                         d_model=32,
                                         d_ff=64,
                                         n_layers=1,
                                         n_heads=2,
                                         mode='predict')
            model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
            s1 = decoding.autoregressive_sample(model,
                                                batch_size=1,
                                                eos_id=-1,
                                                max_length=10)
            self.assertEqual(s1.shape[0], 1)
            self.assertEqual(s1.shape[1], 10)
            batch_per_device = 2 // fastmath.local_device_count()
            model.init(shapes.ShapeDtype((batch_per_device, 1),
                                         dtype=np.int32))
            s2 = decoding.autoregressive_sample(model,
                                                batch_size=2,
                                                max_length=10)
            self.assertEqual(s2.shape[0], 2)
            self.assertLess(s2.shape[1], 11)
            model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
            prefix = np.array([[1, 2, 3]])
            s3 = decoding.autoregressive_sample(model,
                                                prefix,
                                                eos_id=-1,
                                                max_length=10,
                                                batch_size=1)
            self.assertEqual(s3.shape[0], 1)
            self.assertEqual(s3.shape[1], 10)
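The `eos_id=-1` argument above guarantees that no sampled token ever matches EOS, so sampling runs for the full `max_length`. When deterministic output is wanted instead, `decoding.autoregressive_sample` also takes a `temperature` argument; a minimal sketch outside the test harness, using the same small model:

import numpy as np

from trax import models, shapes
from trax.supervised import decoding

model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1,
                             n_heads=2, mode='predict')
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
# temperature=0.0 makes sampling greedy (argmax), hence deterministic;
# eos_id=-1 keeps it running for all 10 positions.
tokens = decoding.autoregressive_sample(model, batch_size=1, temperature=0.0,
                                        eos_id=-1, max_length=10)
assert tokens.shape == (1, 10)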
Example 3
    def __init__(self, layer, n_devices=None):
        """Wraps `layer` so that its `pure_fn` runs jitted across local devices."""
        super().__init__(n_in=layer.n_in, n_out=layer.n_out)
        self._sublayers = [layer]
        self._n_devices = n_devices or fastmath.local_device_count()
        self._jit_pure_fn = jit_forward(layer.pure_fn,
                                        self._n_devices,
                                        do_mean=False)
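This constructor matches trax's `tl.Accelerate` layer wrapper, which runs the wrapped layer's `pure_fn` jitted (and pmapped across devices when more than one is available). Assuming that is the class in question, a minimal usage sketch:

import numpy as np

from trax import layers as tl
from trax import shapes

layer = tl.Dense(8)
x = np.zeros((16, 4), dtype=np.float32)
layer.init(shapes.signature(x))    # initialize the wrapped layer first
fast_layer = tl.Accelerate(layer)  # jit/pmap the layer's pure_fn
y = fast_layer(x)                  # accelerated forward pass
assert y.shape == (16, 8)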
Example 4
    def _run_value_model(self, obs, use_eval_model=True):
        """Runs the value model."""
        n_devices = fastmath.local_device_count()
        if use_eval_model:
            weights = tl.for_n_devices(self._value_eval_model.weights,
                                       n_devices)
            state = tl.for_n_devices(self._value_eval_model.state, n_devices)
            rng = self._value_eval_model.rng
        else:
            # TODO(henrykm): this strange-looking solution addresses the problem
            # that value_batches_stream calls _run_value_model _once_ before
            # the trainer is initialized.
            try:
                weights = tl.for_n_devices(self._value_trainer.model_weights,
                                           n_devices)
                state = tl.for_n_devices(self._value_trainer.model_state,
                                         n_devices)
                rng = self._value_trainer._rng  # pylint: disable=protected-access
            except AttributeError:
                weights = tl.for_n_devices(self._value_eval_model.weights,
                                           n_devices)
                state = tl.for_n_devices(self._value_eval_model.state,
                                         n_devices)
                rng = self._value_eval_model.rng
        # TODO(henrykm): the line below fails on TPU with the error
        # ValueError: Number of devices (8) does not evenly divide batch size (1).
        obs_batch = obs.shape[0]
        if n_devices > obs_batch:
            obs = jnp.repeat(obs, int(n_devices / obs_batch), axis=0)
        values, _ = self._value_eval_jit(obs, weights, state, rng)
        values = values[:obs_batch]
        values *= self._value_network_scale
        return values
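The final lines work around pmap's requirement that the batch size be divisible by the device count: a batch smaller than `n_devices` is repeated up to size, and the padding is sliced off after the call. A standalone numpy illustration of that arithmetic (the device count of 8 is an assumption, and plain multiplication stands in for the jitted value model):

import numpy as np

n_devices = 8
obs = np.ones((1, 4), dtype=np.float32)  # a batch of one observation
obs_batch = obs.shape[0]
if n_devices > obs_batch:
    # Repeat the batch so every device receives one copy.
    obs = np.repeat(obs, n_devices // obs_batch, axis=0)
assert obs.shape == (8, 4)
values = obs * 2.0                       # stand-in for self._value_eval_jit
values = values[:obs_batch]              # drop the padded rows again
assert values.shape == (1, 4)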
Example 5
    def test_run_sharded_terraformer(self):
        """Runs Terraformer with sharded weights (only on 2+-device systems)."""
        if fastmath.local_device_count() == 1:
            return
        base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
        inputs_batch = np.arange(8).reshape((2, 4)) + 1
        targets_batch = 2 * inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
        input_sig = (int_sig, int_sig, int_sig)
        # We want to test rng propagation too, so we include dropout layers.
        model = terraformer.ConfigurableTerraformer(
            20,
            d_model=8,
            d_ff=32,
            n_heads=1,
            dropout=0.0,
            n_encoder_layers=2,
            n_decoder_layers=2,
            ff_sparsity=(4, 8, 0.0, 1.0),
            encoder_attention_type=tl.Attention,
            encoder_decoder_attention_type=tl.CausalAttention,
            pos_type=None,
            reversible_encoder=True)
        loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
        model_with_loss = tl.Serial(model, loss)
        rng_init = fastmath.random.get_prng(12)
        model_with_loss.init(input_sig, rng=rng_init)

        # Make a step with the trainer.
        optimizer = optimizers.Adafactor(0.01)
        split_w = fastmath.nested_map(
            lambda x: x[0],
            tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS))
        optimizer.tree_init(split_w)
        trainer = optimizers.Trainer(model_with_loss, optimizer)
        rng_step1 = fastmath.random.get_prng(7)
        trainer.one_step(labeled_batch, rng_step1)
        # Reset shards back to default.
        base.N_WEIGHTS_SHARDS = 1
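The `lambda x: x[0]` above relies on `tl.shard` returning each weight with a leading per-device axis (see Example 6), so indexing with 0 yields single-shard shapes for `optimizer.tree_init`. A plain-numpy illustration of the shapes involved; the split axis here is an assumption, since the real code picks it heuristically:

import numpy as np

n_shards = 2
w = np.arange(16.).reshape(4, 4)                           # a toy weight matrix
sharded = np.stack(np.split(w, n_shards, axis=0), axis=0)  # (n_shards, 2, 4)
one_shard = sharded[0]                                     # what `lambda x: x[0]` extracts
assert one_shard.shape == (2, 4)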
Example 6
def shard(tensors, n_shards=None):
    """Shard tensors across n_shards."""
    n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards
    indices = _axis_index(np.zeros(fastmath.local_device_count()))

    def _shard_fn(x):
        axis = _axis_to_shard_heuristic(x.shape)
        if int(x.shape[axis]) % n_shards != 0:
            raise ValueError(
                f'Cannot split x with shape {x.shape} into {n_shards} shards.')
        split_x = jnp.split(x, n_shards, axis=axis)
        split_x = [split_x[i % n_shards] for i in indices]
        return np.stack(split_x, axis=0)

    return fastmath.nested_map(_shard_fn, tensors)
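Note the `i % n_shards` indexing: when there are more devices than shards, shards are repeated cyclically so every device still gets one. A standalone numpy trace of `_shard_fn` under that condition (numpy arrays stand in for `_axis_index` and for the heuristic axis choice):

import numpy as np

n_shards = 2
n_devices = 4                                       # more devices than shards
indices = np.arange(n_devices)                      # stand-in for _axis_index
x = np.arange(8.).reshape(4, 2)
split_x = np.split(x, n_shards, axis=0)             # two (2, 2) pieces
split_x = [split_x[i % n_shards] for i in indices]  # shards repeat cyclically
stacked = np.stack(split_x, axis=0)                 # leading device axis
assert stacked.shape == (4, 2, 2)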
Example 7
    def __init__(self,
                 model_with_loss,
                 optimizer,
                 n_devices=None,
                 adasum=False):
        self._model_with_loss = model_with_loss
        self._optimizer = optimizer
        self._n_devices = n_devices or fastmath.local_device_count()
        self._adasum = adasum

        # Optimizer slots and opt_params may need to be replicated.
        self._slots, self._opt_params = tl.on_cpu(
            tl.for_n_devices(
                (self._optimizer.slots, self._optimizer.opt_params),
                self._n_devices))

        # Accelerated version of model+loss to replicate weights and state.
        self._accelerated_model_with_loss = tl.Accelerate(model_with_loss,
                                                          n_devices=n_devices)

        # Signature:
        # (batch, weights, state, rng) -> ((loss, state), gradients)
        self._forward_and_backward_fn = (
            fastmath.value_and_grad(
                model_with_loss.pure_fn,
                argnums=1,  # arg1 of pure_fn: weights
                has_aux=True))  # return (loss, state), gradients

        # Signature:
        # (weights, slots), step, opt_params, batch, state, rng ->
        # (weights, slots), state, stats
        self._accelerated_update_fn = (_accelerate_update_fn(
            self._forward_and_backward_fn,
            self._optimizer,
            n_devices=self._n_devices,
            accelerate=True,
            adasum=self._adasum))
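End to end, this Trainer is driven exactly as in Example 5: build a model that ends in a loss, initialize it, give the optimizer the weight tree, then call `one_step` per batch. A condensed sketch with a toy model; the layer sizes and batch shapes are assumptions:

import numpy as np

from trax import fastmath, shapes
from trax import layers as tl
from trax import optimizers

model_with_loss = tl.Serial(tl.Dense(2), tl.L2Loss())
batch = (np.zeros((4, 2), dtype=np.float32),  # inputs
         np.zeros((4, 2), dtype=np.float32),  # targets
         np.ones((4, 2), dtype=np.float32))   # per-element weights
model_with_loss.init(shapes.signature(batch))
optimizer = optimizers.Adam(0.0001)
optimizer.tree_init(model_with_loss.weights)  # create optimizer slots
trainer = optimizers.Trainer(model_with_loss, optimizer)
trainer.one_step(batch, fastmath.random.get_prng(0))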
Example 8
    def __init__(self,
                 blocks,
                 loss_layer,
                 optimizer_fn,
                 n_devices=None,
                 memoize_jit=True,
                 free_accelerators_on_step=False,
                 adasum=False):
        """Creates a ReversibleSerialTrainer and the needed optimizers.

    This trainer performs updates equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
      memoize_jit: Whether to memoize JITed functions; this significantly speeds
        up XLA compilation of larger models, but it uses `repr(layer)` as keys
        to memoize so it could fail if two layers with different functionality
        had the same string representaion. We have not encountered such case
        yet so this is turned on by default, but consider turning it off or
        reviewing your model if you use custom layers and encounter a problem.
      free_accelerators_on_step: If true, frees memory on accelerators when
        starting a step. All layers and arguments must be on host for that,
        otherwise it can lead to failures. Can prevent memory fragmentation.
      adasum: if True, use adaptive summation to gather multi-device gradients.
    """
        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        self._n_devices = n_devices or fastmath.local_device_count()
        self._adasum = adasum
        self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
        self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
        self._n_async_layers = 1  # How many layers to run asynchronously.
        self._jit_memory = {} if memoize_jit else None
        self._do_free = free_accelerators_on_step
        self._jit_per_device_rngs = fastmath.jit(self._per_device_rngs,
                                                 backend='cpu')

        # Create accelerated versions of layers as pmapped/jitted pure_fn.
        self._accelerated_layer_fns = fastmath.nested_map(
            lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'),
            self._blocks)

        # Create per-layer optimizers and replicate opt_params.
        def _make_optimizer(layer):
            opt = optimizer_fn()
            opt.tree_init(layer.weights)
            opt.slots = tl.on_cpu(opt.slots)
            return opt

        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
        self._replicated_opt_params = fastmath.nested_map(
            lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers)

        self._loss_opt = _make_optimizer(loss_layer)
        self._replicated_loss_opt_params = self._replicate_cpu(
            self._loss_opt.opt_params)

        # Forward + backward + optimizer-update functions for all layers.
        # We call them FBO for short: "Forward + Backward + Optimizer update".
        # Reversible layers define a reverse_and_fbo function that also reverses.

        self._fbos = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            (std_opt, rev_opts) = self._optimizers[i]
            std_fbo = _fbo_with_layer_and_opt(std_layer,
                                              std_opt,
                                              self._n_devices,
                                              adasum=self._adasum)
            rev_and_fbos = []
            for layer, opt in zip(rev_layers, rev_opts):
                rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
                    layer, opt, self._n_devices, self._adasum)
                # The donated args are (outputs, weights, grads); we can donate
                # them because weights and grads are immediately replaced and,
                # in the case of reversible layers, the outputs are never used
                # again.
                rev_and_fbos.append(
                    self._pjit(rev_and_fbo,
                               f'rev+bwd {repr(layer)}',
                               donate_argnums=(0, 1, 2)))
            # In standard layers, the inputs cannot be donated as they may be used
            # as outputs for the reversible block below, but weights and grads can.
            jit_std_fbo = self._pjit(std_fbo,
                                     f'bwd {repr(std_layer)}',
                                     donate_argnums=(1, 2))
            self._fbos.append((jit_std_fbo, rev_and_fbos))

        loss_fbo = _fbo_with_layer_and_opt(self._loss_layer, self._loss_opt,
                                           self._n_devices, 'loss',
                                           self._adasum)
        self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))
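For reference, a minimal construction of such a trainer, loosely following the trax test suite; the layer widths and batch shapes here are assumptions:

import numpy as np

from trax import fastmath, shapes
from trax import layers as tl
from trax import optimizers

std = [tl.Dense(4), tl.Dup()]  # standard layers; Dup creates the two reversible streams
rev = [tl.ReversibleHalfResidual(tl.Dense(4)), tl.ReversibleSwap()]
loss = tl.Serial(tl.Concatenate(), tl.Dense(4), tl.LogSoftmax(),
                 tl.CrossEntropyLoss())
batch = (np.zeros((2, 4), dtype=np.float32),  # inputs
         np.zeros((2,), dtype=np.int32),      # integer labels in [0, 4)
         np.ones((2,), dtype=np.float32))     # per-example weights

# Initializing the plain Serial model also initializes the shared layer
# objects, which the trainer expects to already have weights.
tl.Serial(std + rev + [loss]).init(shapes.signature(batch))
trainer = optimizers.ReversibleSerialTrainer(
    [(std, rev)], loss, optimizers.Adam)      # blocks, loss_layer, optimizer_fn
trainer.one_step(batch, fastmath.random.get_prng(0))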
Example 9
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False,
          adasum=False,
          init_checkpoint=None,
          callbacks=None,
          additional_train_tasks=None,
          additional_eval_tasks=None,
          additional_eval_streams=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint for
      each training step in the list.
    eval_steps: int, number of steps per evaluation. If None or 0, eval is
      disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval is disabled.
    permanent_checkpoint_frequency: int, how often to save permanent checkpoints
      (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save a checkpoint whenever this metric reaches its
      highest value so far.
    checkpoint_lowest: save a checkpoint whenever this metric reaches its
      lowest value so far.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.
    adasum: if True, use adaptive summation for multi-device gradients.
    init_checkpoint: a checkpoint for fine tuning.
    callbacks: a list of callbacks to call during training.
    additional_train_tasks: additional tasks which should be performed during
      training.
    additional_eval_tasks: additional tasks which should be performed during
      evaluation.
    additional_eval_streams: List[NamedStream], additional data streams that
      should be used during evaluation. Can be provided independently of
      additional_eval_tasks.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
  if (permanent_checkpoint_frequency is not None
      and permanent_checkpoints_at is not None):
    raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                     '"permanent_checkpoints_at"] should be set.')
  if use_loop:
    n_devices = num_devices() or fastmath.local_device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    opt = optimizer if use_memory_efficient_trainer else optimizer()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=opt,
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency,
        n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

    if additional_train_tasks is None:
      additional_train_tasks = []

    # Prepare the evaluation.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    names, metrics = zip(*metrics_dict.items())
    eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                  metrics,
                                  metric_names=names,
                                  n_eval_batches=eval_steps)

    if additional_eval_tasks is None:
      additional_eval_tasks = []

    additional_eval_tasks_from_streams = []
    if additional_eval_streams is not None:
      for stream in additional_eval_streams:
        additional_eval_tasks_from_streams.append(
            training.EvalTask(stream.stream,
                              metrics,
                              metric_names=names,
                              n_eval_batches=eval_steps,
                              export_prefix=stream.name))

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at
    permanent_checkpoint_at = None
    if permanent_checkpoints_at is not None:
      permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    if init_checkpoint:
      model_train.init_from_file(init_checkpoint, weights_only=True)
      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
    loop = training.Loop(
        model_train, [train_task] + additional_train_tasks,
        eval_model=model_predict_eval,
        eval_tasks=[eval_task] +
        additional_eval_tasks + additional_eval_tasks_from_streams,
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        permanent_checkpoint_at=permanent_checkpoint_at,
        n_devices=n_devices,
        loss_chunk_size=loss_chunk_size,
        use_memory_efficient_trainer=use_memory_efficient_trainer,
        adasum=adasum,
        random_seed=random_seed,
        callbacks=callbacks,
    )

    steps_to_go = steps - loop.step
    if steps_to_go <= 0:
      log('Stopping training: already reached the total of %d training steps.'
          % steps)
      return loop

    # Train and return the loop.
    loop.run(steps_to_go)
    return loop

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest,
                          init_checkpoint=init_checkpoint)

  epoch_steps = [steps]  # Only training (no eval) if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Bookkeeping we do at the first step.
      if trainer.step == 1:
        # Save the computation graph (single-device only for now).
        if save_graphs and fastmath.is_backend(fastmath.Backend.JAX):
          trainer.save_computation_graphs()

        # Save the Gin config.
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    trainer.close()
  return trainer.state
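Called directly rather than through gin, the Loop path of this function needs little more than an output directory, a model constructor, and an inputs function. A minimal sketch; the toy model, the toy stream, and the output path are all hypothetical:

import numpy as np

from trax import layers as tl
from trax.data import inputs as trax_inputs


def _toy_stream(n_devices):
    del n_devices
    while True:  # endless stream of (inputs, targets, weights) batches
        x = np.random.randint(0, 10, size=(8, 4))
        y = np.random.randint(0, 2, size=(8,))
        yield (x, y, np.ones_like(y, dtype=np.float32))


def _toy_model(mode='train'):
    del mode
    return tl.Serial(tl.Embedding(10, 8), tl.Mean(axis=1), tl.Dense(2))


loop = train(output_dir='/tmp/trax_train_demo',
             model=_toy_model,
             inputs=lambda: trax_inputs.Inputs(train_stream=_toy_stream,
                                               eval_stream=_toy_stream),
             steps=10,
             eval_steps=1,
             eval_frequency=5)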
Example 10
    def __init__(
            self,
            task,
            value_body=None,
            value_optimizer=None,
            value_lr_schedule=lr.multifactor,
            value_batch_size=64,
            value_train_steps_per_epoch=500,
            value_evals_per_epoch=1,
            value_eval_steps=1,
            exploration_rate=functools.partial(
                lr.multifactor,
                factors='constant * decay_every',
                constant=1.,  # pylint: disable=redefined-outer-name
                decay_factor=0.99,
                steps_per_decay=1,
                minimum=0.1),
            n_eval_episodes=0,
            only_eval=False,
            n_replay_epochs=1,
            max_slice_length=1,
            sync_freq=1000,
            scale_value_targets=True,
            output_dir=None,
            **kwargs):
        """Configures the value trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      value_body: Trax layer, representing the body of the value model.
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      value_optimizer: the optimizer to use to train the policy model.
      value_lr_schedule: learning rate schedule to use to train the policy.
      value_batch_size: batch size used to train the policy model.
      value_train_steps_per_epoch: how long to train policy in each RL epoch.
      value_evals_per_epoch: number of policy trainer evaluations per RL epoch
          - only affects metric reporting.
      value_eval_steps: number of policy trainer steps per evaluation - only
          affects metric reporting.
      exploration_rate: exploration rate schedule - used in the policy method.
      n_eval_episodes: number of episodes to play with policy at
        temperature 0 in each epoch -- used for evaluation only
      only_eval: If set to True, then trajectories are collected only for
        for evaluation purposes, but they are not recorded.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms.
      max_slice_length: the maximum length of trajectory slices to use; it is
          the second dimenions of the value network output:
          (batch, max_slice_length, number of actions)
          Higher max_slice_length implies that the network has to predict more
          values into the future.
      sync_freq: frequency when to synchronize the target
        network with the trained network. This is necessary for training the
        network on bootstrapped targets, e.g. using n-step returns.
      scale_value_targets: If `True`, scale value function targets by
          `1 / (1 - gamma)`. We are trying to fix the problem with very large
          returns in some games in a way which does not introduce an additional
          hyperparameters.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      **kwargs: arguments for the superclass RLTrainer.
    """
        super(ValueAgent, self).__init__(task,
                                         n_eval_episodes=n_eval_episodes,
                                         output_dir=output_dir,
                                         **kwargs)
        self._value_batch_size = value_batch_size
        self._value_train_steps_per_epoch = value_train_steps_per_epoch
        self._value_evals_per_epoch = value_evals_per_epoch
        self._value_eval_steps = value_eval_steps
        self._only_eval = only_eval
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.action_space)
        self._n_replay_epochs = n_replay_epochs

        self._exploration_rate = exploration_rate()
        self._sync_at = (lambda step: step % sync_freq == 0)

        if scale_value_targets:
            self._value_network_scale = 1 / (1 - self._task.gamma)
        else:
            self._value_network_scale = 1

        value_model = functools.partial(models.Quality,
                                        body=value_body,
                                        n_actions=self.task.action_space.n)

        self._value_eval_model = value_model(mode='eval')
        self._value_eval_model.init(self._value_model_signature)
        self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                              fastmath.local_device_count(),
                                              do_mean=False)

        # Inputs to the value model are produced by self.value_batches_stream.
        self._inputs = data.inputs.Inputs(
            train_stream=lambda _: self.value_batches_stream())

        # This is the value Trainer that will be used to train the value model.
        # * inputs to the trainer come from self.value_batches_stream
        # * outputs, targets and weights are passed to self.value_loss
        self._value_trainer = supervised.Trainer(
            model=value_model,
            optimizer=value_optimizer,
            lr_schedule=value_lr_schedule(),
            loss_fn=self.value_loss,
            inputs=self._inputs,
            output_dir=output_dir,
            metrics={
                'value_loss': self.value_loss,
                'value_mean': self.value_mean,
                'returns_mean': self.returns_mean
            })
        value_batch = next(self.value_batches_stream())
        self._eval_model = tl.Accelerate(value_model(mode='collect'),
                                         n_devices=1)
        self._eval_model.init(shapes.signature(value_batch))
        if self._task._initial_trajectories == 0:
            self._task.remove_epoch(0)
            self._collect_trajectories()
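In trax, the concrete subclass of this agent is `rl.DQN`, with `value_body` supplying the network torso. A construction sketch, not verified end to end; the environment, the body, and the hyperparameter values are assumptions:

from trax import layers as tl
from trax import optimizers
from trax import rl
from trax.rl import task as rl_task

task = rl_task.RLTask('CartPole-v0', initial_trajectories=100, max_steps=200)
agent = rl.DQN(                        # DQN is a concrete ValueAgent in trax
    task,
    value_body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()),
    value_optimizer=optimizers.Adam,
    value_batch_size=32,
    value_train_steps_per_epoch=10,
    n_replay_epochs=2,
    max_slice_length=2)
agent.run(n_epochs=1)  # one epoch: collect trajectories, then train the value model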