Example #1
def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
    """Initialize reversible blocks and the loss layer and place weights on CPU.

  Args:
    blocks: List of reversible blocks (pairs of layer lists).
    loss_layer: The final loss layer to initialize.
    input_signature: The signature of the input to the blocks.
    rng: Random key used to initialize the layers.
  """
    sig_stack = input_signature
    process = psutil.Process(os.getpid())
    mem_use = process.memory_info().rss
    for (std_layers, rev_layers) in blocks:
        rngs = fastmath.random.split(rng,
                                     len(std_layers) + len(rev_layers) + 1)
        rng = rngs[0]
        for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]):
            sig = cb.inputs_from_stack(sig_stack, layer.n_in)
            layer.init(sig, rng=layer_rng)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            layer.state = tl.on_cpu(layer.state)  # store state in cpu memory
            logging.info('init: layer %s\nadded cpu memory (MB): %.2f',
                         str(layer), (process.memory_info().rss - mem_use) /
                         float(1024 * 1024))
            mem_use = process.memory_info().rss
            logging.info('init: cpu memory use (MB): %.2f',
                         mem_use / float(1024 * 1024))
            out_sig = layer.output_signature(sig)
            sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in)
    loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng)
    loss_layer.weights = tl.on_cpu(loss_layer.weights)
    loss_layer.state = tl.on_cpu(loss_layer.state)
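A minimal usage sketch for the function above (the layer sizes, vocabulary, and the single block are illustrative, not taken from the source): build one (standard layers, reversible layers) block plus a loss layer, and initialize everything with weights pinned to CPU.

import numpy as np
from trax import fastmath, layers as tl, shapes

d_model = 64
std_layers = [tl.Serial(tl.Embedding(32, d_model), tl.Dup())]
rev_layers = [tl.ReversibleHalfResidual(tl.Dense(d_model)), tl.ReversibleSwap()]
loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(32), tl.LogSoftmax(),
                       tl.CrossEntropyLoss())
# Signature of a labeled batch: (inputs, targets, per-example weights).
int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
input_signature = (int_sig, int_sig, int_sig)
rng = fastmath.random.get_prng(0)

init_reversible_blocks([(std_layers, rev_layers)], loss_layer, input_signature, rng)
# All layer weights (and states) now live in host memory via tl.on_cpu.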
Example #2
File: training.py Project: RJK99/trax
 def new_rng(self):
   """Returns a new single-use random number generator (JAX PRNG key)."""
   self._rng, rng = fastmath.random.split(self._rng)
   if self._use_memory_efficient_trainer:
     self._rng = tl.on_cpu(self._rng)
     rng = tl.on_cpu(rng)
   return rng
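The same split-and-pin pattern as a standalone sketch (the seed and variable names are illustrative): split a JAX PRNG key and keep both the carried key and the single-use key in host memory.

from trax import fastmath, layers as tl

rng = tl.on_cpu(fastmath.random.get_prng(0))      # carried key, pinned to CPU
rng, rng_single_use = fastmath.random.split(rng)  # returns two fresh keys
rng, rng_single_use = tl.on_cpu(rng), tl.on_cpu(rng_single_use)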
Example #3
  def load_checkpoint(self, directory=None, filename=None):
    """Loads model weights and step from a checkpoint on disk.

    Args:
      directory: Directory with the checkpoint (self._output_dir by default).
      filename: Checkpoint file name (model.pkl.gz by default).
    """
    directory = directory or self._output_dir
    if directory is None:
      _log('Not loading as both directory and output_dir are None.',
           stdout=False)
      return
    filename = filename or 'model.pkl.gz'
    path = os.path.join(directory, filename)
    if not tf.io.gfile.exists(path):
      _log(f'Not loading as checkpoint file does not exist: {path}.',
           stdout=False)
      return
    d = unpickle_from_file(path, gzip=True)
    # For large models, load weights from sharded files.
    if self._use_memory_efficient_trainer:
      weights = []
      n_shards = d['flat_weights']  # We store the number of shards in d here.
      for i in range(n_shards):
        w = unpickle_from_file(path + '.shard%d' % i, gzip=True)
        w = self._from_bits(w)  # bit-casting may put w on accelerator, go back
        weights.extend([tl.on_cpu(x) for x in w])
      d['flat_weights'] = weights
    else:
      d['flat_weights'] = self._from_bits(d['flat_weights'])
    self._step = d['step']
    if 'slots' in d:
      if len(self._tasks) != 1:
        raise ValueError(
            'Can\'t load a single-task checkpoint into a multitask Loop.'
        )
      d['slots_per_task'] = [d['slots']]
    if self._use_memory_efficient_trainer:
      for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
        trainer.slots = slots
    else:
      for (task, slots) in zip(self._tasks, d['slots_per_task']):
        task.optimizer.slots = slots
    # This is self._model.init_from_file but optimized to not re-read.
    input_signature = d['input_signature']
    weights_and_state_sig = self._model.weights_and_state_signature(
        input_signature)
    weights, state = tl.unflatten_weights_and_state(
        d['flat_weights'], d['flat_state'], weights_and_state_sig)
    self._model.state = state
    self._model.weights = weights
    self._eval_model.weights = self._model.weights
    # Restore eval model state; note: it's not always the same as train state.
    if 'flat_eval_state' in d:
      flat_eval_state = d['flat_eval_state']
    else:  # It wasn't saved in old checkpoints; remove this branch once ported.
      flat_eval_state = d['flat_state']
    _, eval_state = tl.unflatten_weights_and_state(
        d['flat_weights'], flat_eval_state, weights_and_state_sig)
    self._eval_model.state = eval_state
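A usage sketch, assuming `loop` is an already-constructed training.Loop (the directory is an illustrative placeholder): restore either from the default output directory or from an explicit checkpoint file.

loop.load_checkpoint()                         # reads <output_dir>/model.pkl.gz if present
loop.load_checkpoint(directory='/tmp/my_run',  # hypothetical directory
                     filename='model.pkl.gz')
print(loop.step)  # step counter restored from the checkpoint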
Example #4
File: training.py Project: RJK99/trax
 def _init_evaluator(self, eval_task):
   """Initializes the per-task evaluator."""
   model_with_metrics = _model_with_metrics(
       self._eval_model, eval_task)
   if self._use_memory_efficient_trainer:
     return _Evaluator(
         weights=tl.on_cpu(model_with_metrics.weights[1]),
         state=tl.on_cpu(model_with_metrics.state[1]),
         metrics_fn=_accelerate_model_with_metrics(model_with_metrics, 0)
     )
   else:
     return _Evaluator(
         # Replicate the eval part of weights and state.
         weights=self._for_n_devices(model_with_metrics.weights[1]),
         state=self._for_n_devices(model_with_metrics.state[1]),
         metrics_fn=_accelerate_model_with_metrics(
             model_with_metrics, self.n_devices)
     )
Example #5
    def _replicate_cpu(self, x):
        # TODO(lukaszkaiser): move it to layers/acceleration to be together with
        #   tl.for_n_devices and other functions like that, possibly refactor them.
        def f(x):
            if self._n_devices > 1:
                return np.broadcast_to(x, (self._n_devices, ) +
                                       np.asarray(x).shape)
            else:
                return x

        return tl.on_cpu(fastmath.nested_map(f, x))
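A toy illustration of the replication rule above (the array values are made up): with more than one device the input gains a leading device axis, with a single device it is returned unchanged, and tl.on_cpu then pins the result to host memory.

import numpy as np

n_devices = 2
x = np.array([1.0, 2.0, 3.0])
replicated = np.broadcast_to(x, (n_devices,) + np.asarray(x).shape)  # shape (2, 3)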
Example #6
File: trainer.py Project: wangdongya/trax
def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
  """Initialize reversible blocks and the loss layer and place weights on CPU.

  Args:
    blocks: List of reversible blocks (pairs of layer lists).
    loss_layer: The final loss layer to initialize.
    input_signature: The signature of the input to the blocks.
    rng: Random key used to initialize the layers.
  """
  sig_stack = input_signature
  for (std_layers, rev_layers) in blocks:
    rngs = fastmath.random.split(rng, len(std_layers) + len(rev_layers) + 1)
    rng = rngs[0]
    for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]):
      sig = cb.inputs_from_stack(sig_stack, layer.n_in)
      layer.init(sig, rng=layer_rng)
      layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
      out_sig = layer.output_signature(sig)
      sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in)
  loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng)
  loss_layer.weights = tl.on_cpu(loss_layer.weights)
Example #7
    def test_run_reversible_large_weights(self):
        """Runs the reversible trainer with a lot of weights to test memory use."""
        # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU
        # and CPU when you run it locally, but it's too big for unit-testing.
        ram_limited = True  # Set to False to run this test locally.
        if fastmath.global_device_count() == 1 and ram_limited:
            return

        # Create inputs and rngs.
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
        rng_init = fastmath.random.get_prng(12)
        rng_step = fastmath.random.get_prng(13)

        # Initialize layers.
        first_layer.init(labeled_batch, rng=rng_init)
        n_layers = 18  # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram
        rev_layers = []
        int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
        shape = shapes.ShapeDtype((2, 4, 16 * 1024))
        sig = (shape, shape)
        for _ in range(n_layers):
            layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
            layer.init(sig, rng=rng_init)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            rev_layers.append(layer)
            rev_layers.append(tl.ReversibleSwap())
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        loss_layer.init((shape, shape, int_shape, int_shape))
        optimizer_fn = optimizers.Adafactor

        # Make a step with reversible trainer.
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer, rev_layers)], loss_layer, optimizer_fn)
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        # Set to true to run again, e.g., for profiling.
        run_twice = False
        if run_twice:
            t = time.time()
            loss, _ = trainer.one_step(labeled_batch, rng_step)
            self.assertLess(float(loss.sum()),
                            10000.0)  # Just to get the loss.
            print('Took %.3f seconds to run, loss %s' %
                  (time.time() - t, loss))
Example #8
    def test_run_reversible_weights_trainsfer_xprof(self):
        """Runs the reversible trainer and profiles weight transfer stats."""
        run_this_test = False  # We only run this test manually.
        if not run_this_test or fastmath.global_device_count() == 1:  # TPU only
            return

        # Create inputs and rngs.
        inputs_batch = np.ones((1024, 128), dtype=np.int32)
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup())
        rng_init = fastmath.random.get_prng(12)
        rng_step = fastmath.random.get_prng(13)

        # Initialize layers.
        first_layer.init(labeled_batch, rng=rng_init)
        n_layers = 6
        rev_layers = []
        int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32)
        shape = shapes.ShapeDtype((1024, 128, 1024))
        sig = (shape, shape)
        for _ in range(n_layers):
            layer = tl.ReversibleHalfResidual(tl.Dense(1024))
            layer.init(sig, rng=rng_init)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            rev_layers.append(layer)
            rev_layers.append(tl.ReversibleSwap())
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        loss_layer.init((shape, shape, int_shape, int_shape))
        optimizer_fn = optimizers.SGD

        # Make a step with reversible trainer.
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer, rev_layers)], loss_layer, optimizer_fn)
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        # We profile here.
        t = time.time()
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
Example #9
  def test_run_reversible_large_weights(self):
    """Runs the reversible trainer with a lot of weights to test memory use."""
    # This test requires > 20GB RAM, only run on TPUs. It does pass on GPU
    # and CPU when you run it locally, but it's too big for unit-testing.
    ram_limited = True  # Set to False to run this test locally.
    if fastmath.device_count() == 1 and ram_limited:
      return

    # Create inputs and rngs.
    inputs_batch = np.arange(8).reshape((2, 4))
    targets_batch = inputs_batch
    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
    first_layer = tl.Serial(tl.Embedding(9, 16*1024), tl.Dup())
    rng_init = fastmath.random.get_prng(12)
    rng_step = fastmath.random.get_prng(13)

    # Initialize layers.
    first_layer.init(labeled_batch, rng=rng_init)
    n_layers = 20  # 20 layers each 16K x 16K = 256M weights ~= 1GB, 20GB ram
    rev_layers = []
    int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
    shape = shapes.ShapeDtype((2, 4, 16*1024))
    sig = (shape, shape)
    for _ in range(n_layers):
      layer = tl.ReversibleHalfResidual(tl.Dense(16*1024))
      layer.init(sig, rng=rng_init)
      layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
      rev_layers.append(layer)
      rev_layers.append(tl.ReversibleSwap())
    loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9),
                           tl.LogSoftmax(), tl.CrossEntropyLoss())
    loss_layer.init((shape, shape, int_shape, int_shape))
    optimizer_fn = optimizers.Adafactor

    # Make a step with reversible trainer.
    trainer = optimizers.ReversibleSerialTrainer(
        first_layer, rev_layers, loss_layer, optimizer_fn)
    trainer.one_step(labeled_batch, rng_step)
Example #10
    def __init__(self,
                 model_with_loss,
                 optimizer,
                 n_devices=None,
                 adasum=False):
        self._model_with_loss = model_with_loss
        self._optimizer = optimizer
        self._n_devices = n_devices or fastmath.local_device_count()
        self._adasum = adasum

        # optimizer slots and opt_params may need to be replicated
        self._slots, self._opt_params = tl.on_cpu(
            tl.for_n_devices(
                (self._optimizer.slots, self._optimizer.opt_params),
                self._n_devices))

        # accelerated version of model+loss to replicate weights and state
        self._accelerated_model_with_loss = tl.Accelerate(model_with_loss,
                                                          n_devices=n_devices)

        # Signature:
        # (batch, weights, state, rng) -> ((loss, state), gradients)
        self._forward_and_backward_fn = (
            fastmath.value_and_grad(
                model_with_loss.pure_fn,
                argnums=1,  # arg1 of pure_fn: weights
                has_aux=True))  # return (loss, state), gradients

        # Signature:
        # (weights, slots), step, opt_params, batch, state, rng ->
        # (weights, slots), state, stats
        self._accelerated_update_fn = (_accelerate_update_fn(
            self._forward_and_backward_fn,
            self._optimizer,
            n_devices=self._n_devices,
            accelerate=True,
            adasum=self._adasum))
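A standalone sketch of the value_and_grad signature used above, with a toy loss standing in for the model's pure_fn (the function and its arguments are illustrative): argnums=1 differentiates with respect to the weights argument, and has_aux=True lets the function return (loss, state) while only the loss is differentiated.

from trax import fastmath

jnp = fastmath.numpy

def toy_pure_fn(batch, weights, state, rng):
  del rng  # unused in this sketch
  loss = jnp.sum((batch * weights) ** 2)
  return loss, state  # loss is differentiated, state is returned as aux

forward_and_backward_fn = fastmath.value_and_grad(
    toy_pure_fn, argnums=1, has_aux=True)
(loss, new_state), gradients = forward_and_backward_fn(
    jnp.ones(3), 0.5 * jnp.ones(3), {}, None)  # gradients w.r.t. the weights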
Example #11
File: trainer.py Project: wangdongya/trax
 def _unreplicate(self, x):
   if self._n_devices == 1:
     return tl.on_cpu(x)
   return tl.on_cpu(fastmath.nested_map(lambda x: x[0], x))
Example #12
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Updates layers weights/state and optimizers slots by running one step.

    Args:
      batch: Batch of data to use for optimization.
      rng: Random number generator to use for running this step.
      step: Which step of the training are we running.
      learning_rate: Learning rate to use instead of the default one.

    Returns:
      Tuple (loss, stats) with new values from one step
      of training, where stats are all optimizer statistics.
    """
        # Update the learning rate if needed.
        if learning_rate is not None:
            self._replicated_loss_opt_params['learning_rate'] = (
                self._replicate_cpu(learning_rate))
            for (std_op, rev_ops) in self._replicated_opt_params:
                std_op['learning_rate'] = self._replicate_cpu(learning_rate)
                for op in rev_ops:
                    op['learning_rate'] = self._replicate_cpu(learning_rate)

        # The batch needs to be split across the local devices -- the difference
        # between _for_n_devices and _reshape_by_device is that the latter splits
        # the batch dimension into batch // n_devices, whereas _for_n_devices
        # broadcasts/replicates along a new n_devices dimension.
        step_int = step
        if self._n_devices > 1:
            batch = tl.reshape_by_device(batch, self._n_devices, pure_np=True)
            step = np.repeat(step, self._n_devices)

        # Create separate rng for each device and layer.
        if self._n_devices == 1:
            rngs = fastmath.random.split(rng, self._n_layers)
        else:
            # JIT the function and run it on CPU to avoid memory fragmentation.
            rngs = self._jit_per_device_rngs(tl.on_cpu(rng))
        # Group rngs by layer blocks.
        rng_blocks, rng_i = [], 0
        for _, rev_layers in self._blocks:
            l = len(rev_layers)
            rng_blocks.append((rngs[rng_i], rngs[rng_i + 1:rng_i + l + 1]))
            rng_i += l + 1

        # Run the layers forward up to the loss layer.
        if self._do_free:
            self._free_accelerators()
        process = psutil.Process(os.getpid())
        if isinstance(batch, (list, tuple)):
            batch_shapes = [x.shape for x in batch]
        else:
            batch_shapes = batch.shape
        logging.info('running step %d on shapes %s', step_int,
                     str(batch_shapes))
        if step_int % self._n_steps_per_log == 1:
            logging.info('run fwd: cpu memory use (MB): %.2f',
                         process.memory_info().rss / float(1024 * 1024))

        stack = batch
        block_inputs_states = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
            std_rng, rev_rngs = rng_blocks[i]
            # Run the standard layer.
            stack, std_inputs, std_state = self._run_forward_standard(
                stack, std_layer, acc_std_layer_fn, std_rng, step_int)

            # Run the reversible layers and collect old and new states.
            stack, rev_old_states, rev_new_states = self._run_forward_reversible(
                stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int)
            block_inputs_states.append(
                tl.on_cpu(((std_inputs, std_state), (rev_old_states,
                                                     rev_new_states))))

        # Run the loss layer forward and backward with optimizer update.
        if step_int % self._n_steps_per_log == 1:
            logging.info('run loss: cpu memory use (MB): %.2f',
                         process.memory_info().rss / float(1024 * 1024))
        loss_state = self._replicate(self._loss_layer.state)
        loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
        loss_stats, grad_stack = self._run_backward_standard(
            None, step, self._loss_layer, loss_inputs, loss_state,
            self._loss_fbo, rngs[-1], self._loss_opt,
            self._replicated_loss_opt_params)
        self._collect_weights(self._loss_layer)
        stats = [tl.on_cpu(loss_stats)]

        # De-fragment memory.
        if self._do_free:
            stack, grad_stack = tl.on_cpu(stack), tl.on_cpu(grad_stack)
            self._free_accelerators()

        # Run the layers backward and run optimizer updates.
        if step_int % self._n_steps_per_log == 1:
            logging.info('run bwd: cpu memory use (MB): %.2f',
                         process.memory_info().rss / float(1024 * 1024))
        for i in range(len(self._blocks) - 1, -1, -1):
            std_layer, rev_layers = self._blocks[i]
            (std_inputs, std_state), (rev_old_states,
                                      rev_new_states) = block_inputs_states[i]
            std_fbo, rev_fbos = self._fbos[i]
            std_opt, rev_opts = self._optimizers[i]
            std_rng, rev_rngs = rng_blocks[i]
            repl_std_opt_params, repl_rev_opts_params = (
                self._replicated_opt_params[i])

            # Run reversible layers backward with optimizer update.
            stack, grad_stack, new_stats = self._run_backward_reversible(
                stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
                rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
            stats.extend(tl.on_cpu(new_stats))

            # Run the standard layer forward-and-backward pass and optimizer update.
            std_layer_stats, grad_stack = self._run_backward_standard(
                grad_stack, step, std_layer, std_inputs, std_state, std_fbo,
                std_rng, std_opt, repl_std_opt_params)
            stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
                std_inputs, stack, std_layer.n_out)
            stats.append(tl.on_cpu(std_layer_stats))

            # Collect lazily unreplicated layer weights.
            for rev_layer_id in range(self._n_async_layers):
                self._collect_weights(rev_layers[rev_layer_id])
            self._collect_weights(std_layer)

        # Join stats from different optimizers into one.
        joint_stats = {}
        for i, stat in enumerate(reversed(stats)):
            for k, v in stat.items():
                joint_stats[f'layer{i}/' + k] = v
        return stats[0]['loss'], joint_stats
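A toy sketch of the per-block rng grouping performed above (the block sizes are made up): each block consumes one key for its standard layer plus one key per reversible layer, in order.

from trax import fastmath

rev_layers_per_block = [2, 4]   # illustrative numbers of reversible layers per block
n_rngs = sum(l + 1 for l in rev_layers_per_block)
rngs = fastmath.random.split(fastmath.random.get_prng(0), n_rngs)
rng_blocks, rng_i = [], 0
for l in rev_layers_per_block:
  rng_blocks.append((rngs[rng_i], rngs[rng_i + 1:rng_i + l + 1]))
  rng_i += l + 1
# rng_blocks[0] is (one key for the standard layer, two keys for the reversible layers).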
Example #13
 def _make_optimizer(layer):
     opt = optimizer_fn()
     opt.tree_init(layer.weights)
     opt.slots = tl.on_cpu(opt.slots)
     return opt
Example #14
 def slots(self, slots):
     """Sets the slots of the optimizers and this class (replicated)."""
     self._optimizer.slots = slots
     self._slots = tl.on_cpu(tl.for_n_devices(slots, self._n_devices))
Example #15
File: training.py Project: RJK99/trax
  def __init__(
      self,
      model,
      tasks,
      eval_model=None,
      eval_tasks=None,
      output_dir=None,
      checkpoint_at=None,
      permanent_checkpoint_at=None,
      eval_at=None,
      which_task=None,
      n_devices=None,
      random_seed=None,
      loss_chunk_size=0,
      use_memory_efficient_trainer=False,
      callbacks=None,
  ):
    """Configures a training ``Loop``, including a random initialization.

    Args:
      model: Trax layer, representing the core model to be trained. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      tasks: List of :py:class:`TrainTask` instances, which define the training
          data, loss function, and optimizer to be used in respective tasks in
          this training loop. It can also be a single :py:class:`TrainTask`
          instance which is treated in the same way as a singleton list.
      eval_model: Optional Trax layer, representing model used for evaluation,
          e.g., with dropout turned off. If ``None``, the training model (model)
          will be used.
      eval_tasks: List of :py:class:`EvalTask` instances which define how to
          evaluate the model: which validation data to use and which metrics to
          report. Evaluation on each of the tasks will run and be reported
          separately, which allows scoring a model on different subtasks. This
          argument can also be ``None``, in which case no evals will be run, or
          a single :py:class:`EvalTask`, which will be treated in the same way
          as a singleton list.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be ``None`` if both ``eval_tasks`` and ``checkpoint_at`` are
          ``None``.
      checkpoint_at: Function (integer --> boolean) telling, for step n, whether
          that step should have its checkpoint saved. If ``None``, the default
          is periodic checkpointing at ``task.n_steps_per_checkpoint``.
      permanent_checkpoint_at: Function (integer --> boolean) telling,
          for step n, whether that step should have its checkpoint saved
          permanently. If ``None``, the default is periodic checkpointing at
          ``task.n_steps_per_permanent_checkpoint``.
      eval_at: Function (integer --> boolean) that says, for training step n,
          whether that step should run evals. If ``None``, run when
          checkpointing.
      which_task: Function (integer --> integer) indicating which task should be
          used at which training step. Can be set to ``None`` in single-task
          training.
      n_devices: integer or ``None``, the number of devices for this
          computation.
      random_seed: the random seed to use; time/os dependent if ``None``
          (default).
      loss_chunk_size: int, if > 0 use chunks of this size to make loss
        computation more memory-efficient.
      use_memory_efficient_trainer: whether to use a special memory-efficient
        trainer; if set to 2, the memory efficiency is very aggressive.
      callbacks: List of subclasses of StepCallback to call on training
        steps.
    """
    self._is_chief, self._n_hosts, self._n_devices, self._rng = (
        init_host_and_devices(n_devices, random_seed))
    if use_memory_efficient_trainer:
      self._rng = tl.on_cpu(self._rng)

    # Handle single task case without lists too.
    if not isinstance(tasks, (list, tuple)):
      tasks = [tasks]

    if not tasks:
      raise ValueError('Must provide at least one training task.')
    if eval_tasks is None:
      eval_tasks = []
      eval_at = _never
    else:
      if not isinstance(eval_tasks, (list, tuple)):
        eval_tasks = [eval_tasks]

    self._tasks = tasks
    self._model = model
    self._eval_model = eval_model or model

    self._use_memory_efficient_trainer = use_memory_efficient_trainer
    self._loss_chunk_size = loss_chunk_size
    # TODO(lukaszkaiser): can we have different eval models and save memory?
    if use_memory_efficient_trainer:
      assert len(tasks) == 1, 'only single task supported for now'
      self._eval_model = model

    default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint)
    permanent_default_at = _at_step_1_and_every_nth_step(
        tasks[0].n_steps_per_permanent_checkpoint)
    if output_dir is not None:
      self._output_dir = os.path.expanduser(output_dir)
      tf.io.gfile.makedirs(self._output_dir)
      inputs.load_data_counters(self._output_dir)
    else:
      self._output_dir = None

    # Prepare training components.
    self._step = 0
    self._history = trax_history.History()
    self._checkpoint_at = checkpoint_at or default_at
    self._permanent_checkpoint_at = (
        permanent_checkpoint_at or permanent_default_at)
    if which_task is None:
      # If which_task is not passed, we cycle through the tasks in order.
      # If len(tasks) == 1, then which_task is a constant function equal to 0.
      which_task = lambda n: n % len(tasks)
    self._which_task = which_task

    # Initialize using the given random seed.
    # NOTE: If random_seed is None then self._rng will be different on
    # different hosts, leading to different weights on the different hosts.
    self._batch_signature = shapes.signature(tasks[0].sample_batch)
    self._model.rng = self.new_rng()
    # In the memory-efficient case, we initialize in init_trainer.
    if not use_memory_efficient_trainer:
      if _is_uninitialized(self._model):
        self._model.init(self._batch_signature)
      self._eval_model.rng = self.new_rng()
      if _is_uninitialized(self._eval_model):
        self._eval_model.init(self._batch_signature)

    # To handle the above case (i.e. random_seed = None), we psum the weights
    # and state and average them.
    # NOTE: This adds time (how much?) so we prefer not to do it if it is
    # unnecessary, i.e. random_seed was set.
    # NOTE: Averaging the weights across devices can screw up the initial weight
    # statistics.
    # TODO(pkozakowski): Broadcast from one of the devices instead?
    # TODO(lukaszkaiser): make it work for the memory-efficient trainer too.
    if (random_seed is None and self._n_hosts > 1 and
        not use_memory_efficient_trainer):
      logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
      self._sync_weights_and_state_across_hosts()

    # Create the optimizer for the training loss function.
    self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    self.load_checkpoint()

    # Prepare eval components.
    self._eval_at = eval_at or default_at
    self._eval_tasks = eval_tasks
    loss_names = [task.loss_name for task in self._tasks]
    metric_names = [
        name  # pylint: disable=g-complex-comprehension
        for eval_task in self._eval_tasks
        for name in eval_task.metric_names
    ]
    self._rjust_len = max(map(len, loss_names + metric_names))
    self._evaluator_per_task = tuple(
        self._init_evaluator(eval_task) for eval_task in self._eval_tasks)

    if self._output_dir is None:
      _log('Will not write evaluation metrics, because output_dir is None.')

    def task_output_dir(task_index, task_list):
      if self._output_dir is not None:
        if len(task_list) < 2:
          output_dir = self._output_dir
        else:
          output_dir = os.path.join(self._output_dir, str(task_index))
        tf.io.gfile.makedirs(output_dir)
        return output_dir
      else:
        return None
    self._output_dir_per_eval_task = [
        task_output_dir(i, eval_tasks) for i in range(len(eval_tasks))]
    self._output_dir_per_train_task = [
        task_output_dir(i, tasks) for i in range(len(tasks))]

    callbacks = callbacks or []
    self._callbacks = [
        callback_class(self) for callback_class in callbacks
    ]