Example #1
0
 def per_device_rngs(rng):
   """Builds one stacked array of per-device rngs for every layer.

   `rng` is first split into one key per device; each device key is then
   split into one key per layer. The result is regrouped by layer.

   Args:
     rng: Key to split.

   Returns:
     A list with `self._n_layers` entries; entry `i` stacks the i-th layer key
     of every device.
   """
   # A function to JIT to not fragment memory.
   layer_keys_by_device = []
   for device_key in fastmath.random.split(rng, self._n_devices):
     layer_keys_by_device.append(
         fastmath.random.split(device_key, self._n_layers))

   stacked_by_layer = []
   for layer_idx in range(self._n_layers):
     keys_for_layer = [keys[layer_idx] for keys in layer_keys_by_device]
     stacked_by_layer.append(jnp.stack(keys_for_layer))
   return stacked_by_layer
Example #2
0
 def predict(x, weights, state, rng):
     """Runs `model_predict` across devices and merges the per-device results.

     The input is reshaped by device and `rng` is split into `n_devices` keys
     before calling `model_predict`; its result is passed through
     `_combine_devices`. When `do_mean` is set, every leaf of the combined
     result is averaged over axis 0.

     Returns:
       Tuple (result, state).
     """
     sharded_x = reshape_by_device(x, n_devices)
     device_rngs = jnp.stack(fastmath.random.split(rng, n_devices))
     per_device_out = model_predict(sharded_x, weights, state, device_rngs)
     res, state = _combine_devices(per_device_out)
     if not do_mean:
         return res, state
     averaged = fastmath.nested_map(lambda leaf: jnp.mean(leaf, axis=0), res)
     return averaged, state
Example #3
0
    def forward(self, inputs):
        """Returns the input activations, with added positional information.

        Outside of 'predict' mode, a window of the positional weights that is
        as long as axis 1 of `inputs` is added to `inputs`. The window starts
        at position 0, except in 'train' mode with `start_from_zero_prob`
        below 1.0: there a random start offset is drawn and used unless a
        uniform draw falls below `start_from_zero_prob`. With a non-zero
        dropout rate, a Bernoulli keep-mask (rescaled by the keep probability)
        is applied to the positional window before it is added.

        In 'predict' mode, the layer state stores the current position, which
        is advanced by the number of positions consumed on every call.

        Args:
          inputs: Activations to which positional information is added; axis 1
            is treated as the position axis.

        Returns:
          `inputs` plus the selected positional weights.

        Raises:
          ValueError: If the layer is in 'predict' mode with non-zero dropout.
        """
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            # Key for the dropout mask. It is only assigned below when the
            # layer rng gets split for the random start offset; otherwise the
            # layer rng is still unused and serves as the dropout key itself.
            dropout_rng = None
            if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
                px = self.weights[:, :symbol_size, :]
            else:
                # Split into three keys so that the dropout mask never draws
                # from a key that has already been split (key reuse).
                rng1, rng2, dropout_rng = fastmath.random.split(self.rng, 3)
                start = fastmath.random.randint(rng1, (), 0,
                                                self._max_offset_to_add)
                start_from_zero = fastmath.random.uniform(
                    rng2, (), jnp.float32, 0, 1)
                # Fall back to offset 0 with probability start_from_zero_prob.
                start = jnp.where(start_from_zero < self._start_from_zero_prob,
                                  jnp.zeros((), dtype=jnp.int32), start)
                px = fastmath.dynamic_slice_in_dim(self.weights,
                                                   start,
                                                   symbol_size,
                                                   axis=1)
            if self._dropout == 0:
                return x + px
            else:
                if dropout_rng is None:
                    dropout_rng = self.rng
                # Dimensions listed in _dropout_broadcast_dims share one mask
                # value, so their size in the noise shape is collapsed to 1.
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(dropout_rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer needs to store the index of the current
            # position then and increment it on each call -- that's how state is used
            # and updated below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                # Several positions at once: slice a window of weights per
                # batch element, starting at that element's current position.
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        fastmath.dynamic_slice_in_dim(self.weights[0],
                                                      state[i],
                                                      inputs.shape[1],
                                                      axis=0))
                self.state = state + inputs.shape[1]
                res = inputs + jnp.stack(emb, 0)
                return res
Example #4
0
 def _per_device_rngs(self, rng):
     """Derives, for every layer, a stacked array holding one rng per device.

     Args:
       rng: Key to split.

     Returns:
       A list with `self._n_layers` entries; entry `i` stacks the i-th layer
       key of every device.
     """
     # Splitting by device first to be identical with default trainer.
     layer_keys_by_device = []
     for device_key in fastmath.random.split(rng, self._n_devices):
         layer_keys_by_device.append(
             fastmath.random.split(device_key, self._n_layers))

     # Regroup from [device][layer] to [layer] -> stacked over devices.
     stacked_by_layer = []
     for layer_idx in range(self._n_layers):
         keys_for_layer = [keys[layer_idx] for keys in layer_keys_by_device]
         stacked_by_layer.append(jnp.stack(keys_for_layer))
     return stacked_by_layer
Example #5
0
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Runs a single optimization step on the loss layer.

        Reads the current weights and state from the accelerated loss layer,
        calls the accelerated update function and stores the resulting
        weights, state and optimizer slots back on this object.

        Args:
          batch: Batch of data to use for optimization.
          rng: Random number generator to use for running this step.
          step: Which step of the training are we running.
          learning_rate: Learning rate to use instead of the default one.

        Returns:
          Tuple (loss, stats) with new values from one step of training,
          where stats are current optimizer statistics.
        """
        n_devices = self._n_devices

        # Override the stored learning rate only when one is given.
        if learning_rate is not None:
            self._opt_params['learning_rate'] = tl.for_n_devices(
                learning_rate, n_devices)

        if n_devices > 1:
            # reshape_by_device splits the batch dim to batch // n_devices,
            # whereas for_n_devices broadcasts/replicates to n_devices.
            batch = tl.reshape_by_device(batch, n_devices)
            # Every device gets its own rng.
            rng = jnp.stack(fastmath.random.split(rng, n_devices))

        loss_layer = self._accelerated_loss_layer
        weights, state = loss_layer.weights, loss_layer.state

        # Debug output is emitted only when `step & (step - 1)` is zero,
        # i.e. every power of two (and step 0), if debugging is enabled.
        debug_this_step = logging.vlog_is_on(1) and (step & (step - 1)) == 0
        if debug_this_step:
            logging.info('step[%d]', step)
            logging.info('opt_params[%s]', self._opt_params)
            logging.info('slots[%s]', self._slots)
            logging.info('weights[%s]', weights)
            logging.info('state[%s]', state)

        # NOTE: stats is a replicated dictionary of key to jnp arrays.
        updated, new_state, stats = self._accelerated_update_fn(
            (weights, self._slots), step, self._opt_params, batch, state, rng)
        new_weights, new_slots = updated

        if debug_this_step:
            logging.info('updated weights[%s]', new_weights)
            logging.info('stats[%s]', stats)

        loss_layer.weights = new_weights
        loss_layer.state = new_state
        self._slots = new_slots
        self._optimizer.slots = self._unreplicate(new_slots)
        return stats['loss'], stats
Example #6
0
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Performs one training step, updating model and optimizer parameters.

        Args:
          batch: Batch of labeled training data.
          rng: Single-use random number generator (JAX PRNG key).
          step: Training step number.
          learning_rate: Learning rate for the optimizer; if None, the
            learning rate already stored in the optimizer parameters is kept.

        Returns:
          Tuple of (loss, optimizer_stats), with the newly computed loss and
          updated stats as reported by the optimizer.
        """
        device_count = self._n_devices

        if learning_rate is not None:
            replicated_lr = tl.for_n_devices(learning_rate, device_count)
            self._opt_params['learning_rate'] = replicated_lr

        if device_count > 1:
            # Split the batch across devices
            # (batch_dim --> batch_dim // n_devices) and create one rng per
            # device.
            batch = tl.reshape_by_device(batch, device_count)
            per_device_keys = fastmath.random.split(rng, device_count)
            rng = jnp.stack(per_device_keys)

        model = self._accelerated_model_with_loss
        weights = model.weights
        state = model.state

        # Prints every power of two, if debugging is enabled.
        verbose = logging.vlog_is_on(1) and (step & (step - 1)) == 0
        if verbose:
            logging.info('step[%d]', step)
            logging.info('opt_params[%s]', self._opt_params)
            logging.info('slots[%s]', self._slots)
            logging.info('weights[%s]', weights)
            logging.info('state[%s]', state)

        # NOTE: stats is a replicated dictionary of key to jnp arrays.
        weights_and_slots, new_state, stats = self._accelerated_update_fn(
            (weights, self._slots), step, self._opt_params, batch, state, rng)
        new_weights, new_slots = weights_and_slots

        if verbose:
            logging.info('updated weights[%s]', new_weights)
            logging.info('stats[%s]', stats)

        # Store the results of the step back on the model and the optimizer.
        model.weights = new_weights
        model.state = new_state
        self._slots = new_slots
        self._optimizer.slots = self._unreplicate(new_slots)
        return stats['loss'], stats
Example #7
0
  def _run_one_step(self, weights, state, slots, opt_params):
    """Runs one training step and returns the updated training values.

    The learning rate for the current step is written into `opt_params`
    (the dictionary passed in is modified), a batch is fetched from the task
    and the accelerated update function is applied.

    Args:
      weights: Weights from model being trained.
      state: State (non-weight parameters) from model being trained.
      slots: Updatable weights for the optimizer in this training loop.
      opt_params: Dictionary of optimizer (hyper)parameters,
        e.g. learning rate, momentum.

    Returns:
      Tuple (loss, weights, state, slots, stats) with new values from one step
      of training, where stats are current optimizer statistics.
    """
    step = self.step

    # Refresh the learning rate for this step.
    current_lr = self._task.learning_rate(step)
    opt_params['learning_rate'] = self._for_n_devices(current_lr)

    # _reshape_by_device splits the batch dim to batch // n_devices, whereas
    # _for_n_devices broadcasts/replicates to the n_devices dimension.
    batch = self._reshape_by_device(self._task.next_batch())

    rng = self.new_rng()
    n_devices = self.n_devices
    if n_devices > 1:
      rng = jnp.stack(jax_random.split(rng, n_devices))

    # Prints every power of two, if debugging is enabled.
    verbose = logging.vlog_is_on(1) and (step & (step - 1)) == 0
    if verbose:
      logging.info('step[%d]', step)
      logging.info('opt_params[%s]', opt_params)
      logging.info('weights[%s]', weights)

    # NOTE: stats is a replicated dictionary of key to jnp arrays.
    weights_and_slots, state, stats = self._accelerated_update_fn(
        (weights, slots), step, opt_params, batch, state, rng)
    weights, slots = weights_and_slots

    if verbose:
      logging.info('updated weights[%s]', weights)
      logging.info('stats[%s]', stats)

    loss = stats['loss']
    return loss, weights, state, slots, stats
Example #8
0
    def forward(self, inputs):
        """Adds positional weights to `inputs` and returns the sum.

        In 'predict' mode the layer state stores the current position and is
        advanced by the number of positions consumed by the call. In all
        other modes the first `inputs.shape[1]` positional weights are added,
        optionally masked by dropout.

        Raises:
          ValueError: If the layer is in 'predict' mode with non-zero dropout.
        """
        if self._mode == 'predict':
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State is only used for fast inference, where the model is called
            # on consecutive elements position-by-position: it holds the index
            # of the current position and is incremented on each call.
            state = self.state
            n_positions = inputs.shape[1]
            if n_positions == 1:
                self.state = state + 1
                current = self.weights[0, state, :]
                return inputs + jnp.expand_dims(current, 1)

            # Several positions at once: take one window of weights per batch
            # element, starting at that element's stored position.
            windows = [
                jax.lax.dynamic_slice_in_dim(self.weights[0],
                                             state[row],
                                             inputs.shape[1],
                                             axis=0)
                for row in range(inputs.shape[0])
            ]
            self.state = state + inputs.shape[1]
            return inputs + jnp.stack(windows, 0)

        x = inputs
        n_symbols = jnp.shape(x)[1]
        px = self.weights[:, :n_symbols, :]
        if self._dropout == 0:
            return x + px

        # Dimensions listed in _dropout_broadcast_dims share a single mask
        # value, so their extent in the mask shape is set to 1.
        mask_shape = list(px.shape)
        for dim in self._dropout_broadcast_dims:
            mask_shape[dim] = 1
        keep_prob = 1.0 - self._dropout
        if fastmath.is_backend(fastmath.Backend.JAX):
            # NOTE(review): jax.lax.tie_in looks deprecated in newer JAX
            # releases -- confirm against the JAX version this file targets.
            keep_prob = jax.lax.tie_in(
                x, jnp.full((), keep_prob, dtype=x.dtype))
        keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                         tuple(mask_shape))
        scale = keep.astype(x.dtype) / keep_prob
        return x + px * scale
Example #9
0
  def _get_embeddings(self, t):
    """Get embeddings float[..., num_features].

    Args:
      t: int[...] position (i.e. jnp.arange(..., jnp.int32))

    Returns:
      embeddings: float[..., num_features]
    """
    inter_bin_idx, intra_bin_idx = divmod(t, self._time_bin_length)
    bin_parity = inter_bin_idx % 2
    bin_fraction = intra_bin_idx / self._time_bin_length
    embeddings = jnp.stack([
        1 / (1 + inter_bin_idx),
        bin_fraction,
        bin_parity.astype(jnp.float32),
    ], -1)

    assert embeddings.shape == t.shape + (self.num_features,), embeddings.shape
    return embeddings
Example #10
0
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
    """Splits a key into a stream of random keys.

    This uses the little-endian counter mode: counter `i` is encoded as the
    uint32 pair `[i, 0]` and passed, together with `key`, to
    `threefry_2x32_prf`.

    Args:
      key: uint32[2] the key to split
      lo: the range to start extracting from
      hi: the range to stop extracting from

    Returns:
      keys: uint32[hi - lo, 2] the split keys

    Raises:
      ValueError: If `key` is not a uint32 array of shape (2,).
      NotImplementedError: If `hi` is 2**32 or larger.
    """
    if key.shape != (2, ) or key.dtype != jnp.uint32:
        raise ValueError('key must be uint32[2]')
    if hi >= 2**32:
        # You shouldn't really be using more than half the key size anyways.
        raise NotImplementedError('only 32-bit sizes are supported')
    # Create a 64-bit counter: the low word counts, the high word stays zero.
    low_words = jnp.arange(lo, hi, dtype=jnp.uint32)
    high_words = jnp.zeros_like(low_words)
    counters = jnp.stack([low_words, high_words], axis=-1)
    return threefry_2x32_prf(key, counters)
Example #11
0
  def one_step(self, batch, rng, step=0, learning_rate=None):
    """Runs one training step over all blocks of layers and the loss layer.

    The blocks are first run forward up to the loss layer. The loss layer and
    then the blocks, in reverse order, are run backward together with their
    optimizer updates.

    Args:
      batch: Batch of data to use for optimization.
      rng: Random number generator to use for running this step.
      step: Which step of the training are we running.
      learning_rate: Learning rate to use instead of the default one.

    Returns:
      Tuple (loss, stats) with new values from one step
      of training, where stats are all optimizer statistics.
    """
    n_devices = self._n_devices

    # Override the learning rate of every optimizer, if one is given.
    if learning_rate is not None:
      def replicated_lr():
        return tl.for_n_devices(learning_rate, n_devices)

      self._replicated_loss_opt_params['learning_rate'] = replicated_lr()
      for std_params, rev_params_list in self._replicated_opt_params:
        std_params['learning_rate'] = replicated_lr()
        for rev_params in rev_params_list:
          rev_params['learning_rate'] = replicated_lr()

    # reshape_by_device splits the batch dim to batch // n_devices, whereas
    # for_n_devices broadcasts/replicates to the n_devices dimension.
    if n_devices > 1:
      batch = tl.reshape_by_device(batch, n_devices)
      step = jnp.repeat(step, n_devices)

    # One rng per layer; with several devices each entry stacks one rng per
    # device.
    if n_devices == 1:
      rngs = fastmath.random.split(rng, self._n_layers)
    else:
      # Splitting by device first to be identical with default trainer.
      device_keys = fastmath.random.split(rng, n_devices)
      layer_keys_by_device = [
          fastmath.random.split(key, self._n_layers) for key in device_keys]
      rngs = [jnp.stack([keys[layer_idx] for keys in layer_keys_by_device])
              for layer_idx in range(self._n_layers)]

    # Group rngs by block: one for the standard layer, followed by one per
    # reversible layer of that block.
    rng_blocks = []
    offset = 0
    for _, rev_layers in self._blocks:
      n_rev = len(rev_layers)
      std_key = rngs[offset]
      rev_keys = rngs[offset + 1: offset + n_rev + 1]
      rng_blocks.append((std_key, rev_keys))
      offset += n_rev + 1

    # Forward pass up to the loss layer, remembering what the backward pass
    # needs for every block.
    stack = batch
    block_inputs_states = []
    for block_idx, (std_layer, rev_layers) in enumerate(self._blocks):
      std_fn, rev_fns = self._accelerated_layer_fns[block_idx]
      std_rng, rev_rngs = rng_blocks[block_idx]

      # Standard layer first ...
      stack, std_inputs, std_state = self._run_forward_standard(
          stack, std_layer, std_fn, std_rng)

      # ... then the reversible layers, collecting old and new states.
      stack, rev_old_states, rev_new_states = self._run_forward_reversible(
          stack, rev_layers, rev_fns, rev_rngs)

      std_record = (std_inputs, std_state)
      rev_record = (rev_old_states, rev_new_states)
      block_inputs_states.append((std_record, rev_record))

    # Loss layer: forward and backward with optimizer update.
    loss_layer = self._loss_layer
    loss_state = self._replicate(loss_layer.state)
    loss_inputs = cb.inputs_from_stack(stack, loss_layer.n_in)
    loss_stats, grad_stack = self._run_backward_standard(
        None, step, loss_layer, loss_inputs,
        loss_state, self._loss_fbo, rngs[-1], self._loss_opt,
        self._replicated_loss_opt_params)
    stats = [loss_stats]

    # Backward pass with optimizer updates, from the last block to the first.
    for block_idx in reversed(range(len(self._blocks))):
      std_layer, rev_layers = self._blocks[block_idx]
      std_record, rev_record = block_inputs_states[block_idx]
      std_inputs, std_state = std_record
      rev_old_states, rev_new_states = rev_record
      std_fbo, rev_fbos = self._fbos[block_idx]
      std_opt, rev_opts = self._optimizers[block_idx]
      std_rng, rev_rngs = rng_blocks[block_idx]
      std_opt_params, rev_opt_params = self._replicated_opt_params[block_idx]

      # Reversible layers backward with optimizer update.
      stack, grad_stack, rev_stats = self._run_backward_reversible(
          stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
          rev_new_states, rev_rngs, rev_opts, rev_opt_params)
      stats.extend(rev_stats)

      # Standard layer forward-and-backward pass and optimizer update.
      std_stats, grad_stack = self._run_backward_standard(
          grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
          std_opt, std_opt_params)
      # Put layer inputs on the stack.
      stack = cb.outputs_onto_stack(std_inputs, stack, std_layer.n_out)
      stats.append(std_stats)

    # Join stats from different optimizers into one dictionary; stats were
    # collected from the loss layer backwards, hence the reversal.
    joint_stats = {
        f'layer{layer_no}/' + name: value
        for layer_no, layer_stats in enumerate(reversed(stats))
        for name, value in layer_stats.items()
    }
    return stats[0]['loss'], joint_stats
Example #12
0
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):
        """Sets up models, optimizer, metrics and jitted functions for training.

        Args:
          model: Callable that builds the model; it is called with
            `mode='train'` and with `mode='eval'`.
          loss_fn: Loss layer; it is combined with the training model through
            `tl.Serial`.
          optimizer: Callable that builds the optimizer; it is called with
            `learning_rate=0.0`.
          lr_schedule: Stored as `self._lr_schedule`; not used otherwise in
            this constructor.
          inputs: Either an Inputs instance or a function that returns it.
          output_dir: Passed to `reset()`. If it is empty or None, saving
            checkpoints and writing summaries are both switched off.
          random_seed: Passed to `training.init_host_and_devices`.
          n_devices: Passed to `training.init_host_and_devices`.
          checkpoints_at: Stored as `self._checkpoints_at`; None becomes an
            empty list.
          should_save_checkpoints: Whether to save checkpoints; only takes
            effect on the chief host and with an output directory.
          should_write_summaries: Whether to write summaries; only takes
            effect with an output directory.
          metrics: Dictionary from metric name to metric layer; if None,
            `_DEFAULT_METRICS` is used.
          checkpoint_highest: Stored as `self._checkpoint_highest`.
          checkpoint_lowest: Stored as `self._checkpoint_lowest`.
        """

        self._is_chief, _, self._n_devices, rng = (
            training.init_host_and_devices(n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
        self._should_write_summaries = should_write_summaries
        # Without an output directory nothing can be written.
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(
                inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        # init_rng is reserved for initialization; the remaining rng is split
        # into one rng per device.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if fastmath.is_backend(fastmath.Backend.JAX):
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = (
                fastmath.jit(new_opt_state_and_model_state))
        # Zero-argument wrapper that always initializes from init_rng.
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        # Metric names are kept sorted; the layers are listed in that order.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        # Built from the same example shapes and dtypes as input_signature.
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        # The metrics layers are initialized on the eval model's outputs.
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Example #13
0
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Runs one training step through the first, reversible and loss layers.

        The first layer and the reversible layers are run forward. The loss
        layer, the reversible layers (in reverse order) and finally the first
        layer are then run backward with their optimizer updates. Updated
        weights, states and optimizer slots are written back to the layers
        and optimizers.

        Args:
          batch: Batch of data to use for optimization.
          rng: Random number generator to use for running this step.
          step: Which step of the training are we running.
          learning_rate: Learning rate to use instead of the default one.

        Returns:
          Tuple (loss, stats) with new values from one step
          of training, where stats are all optimizer statistics.
        """
        n_devices = self._n_devices

        # Override the learning rate of every optimizer, if one is given.
        if learning_rate is not None:
            for params in self._replicated_opt_params:
                params['learning_rate'] = tl.for_n_devices(
                    learning_rate, n_devices)

        # reshape_by_device splits the batch dim to batch // n_devices,
        # whereas for_n_devices broadcasts/replicates to n_devices.
        if n_devices > 1:
            batch = tl.reshape_by_device(batch, n_devices)
            step = jnp.repeat(step, n_devices)

        # One rng for the first layer, one per reversible layer and one for
        # the loss layer; with several devices every entry stacks one rng
        # per device.
        n_keys = len(self._reversible_layers) + 2
        if n_devices == 1:
            rngs = fastmath.random.split(rng, n_keys)
        else:
            # Splitting by device first to be identical with default trainer.
            device_keys = fastmath.random.split(rng, n_devices)
            keys_by_device = [
                fastmath.random.split(key, n_keys) for key in device_keys
            ]
            rngs = [
                jnp.stack([keys[key_idx] for keys in keys_by_device])
                for key_idx in range(n_keys)
            ]

        # ---- Forward pass up to the loss layer. ----
        stack = batch

        # First layer.
        first_layer = self._first_layer
        first_layer_inputs = _inputs_from_stack(first_layer, stack)
        first_layer_weights = self._replicate(first_layer.weights)
        first_layer_state = self._replicate(first_layer.state)
        outputs, first_layer_new_state = self._accelerated_first_layer_fn(
            first_layer_inputs, first_layer_weights, first_layer_state,
            rngs[0])
        stack = _outputs_onto_stack(first_layer, outputs, stack)

        # Reversible layers; old and new states are kept for the backward pass.
        old_states = []
        new_states = []
        for idx, layer in enumerate(self._reversible_layers):
            # Replicating also copies cpu -> accelerator.
            layer_weights = self._replicate(layer.weights)
            layer_state = self._replicate(layer.state)
            old_states.append(layer_state)
            layer_inputs = _inputs_from_stack(layer, stack)
            layer_fn = self._accelerated_reversible_layers_fns[idx]
            outputs, layer_new_state = layer_fn(
                layer_inputs, layer_weights, layer_state, rngs[idx + 1])
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_states.append(layer_new_state)

        # ---- Loss layer: forward and backward with optimizer update. ----
        loss_layer = self._loss_layer
        loss_opt = self._optimizers[-1]
        loss_weights = self._replicate(loss_layer.weights)
        loss_state = self._replicate(loss_layer.state)
        loss_inputs = _inputs_from_stack(loss_layer, stack)
        loss_slots = self._replicate(loss_opt.slots)
        (new_weights, new_state, new_slots, grad_stack,
         loss_stats) = self._loss_fbo(
             loss_inputs, loss_weights, loss_state, loss_slots,
             self._replicated_opt_params[-1], rngs[-1], step)
        stats = [loss_stats]
        # Unreplicating copies accelerator -> cpu.
        loss_layer.weights = self._unreplicate(new_weights)
        loss_layer.state = self._unreplicate(new_state)
        loss_opt.slots = self._unreplicate(new_slots)

        # ---- Reversible layers backward with optimizer update. ----
        per_layer = list(
            zip(self._reversible_layers, self._reverse_and_fbos, old_states,
                new_states, rngs[1:-1]))
        for back_idx, entry in enumerate(reversed(per_layer)):
            layer, reverse_and_fbo, old_state, new_state, layer_rng = entry
            # Index -1 belongs to the loss layer, so the last reversible
            # layer uses -2, the one before it -3, and so on.
            opt_idx = -(back_idx + 2)
            layer_opt = self._optimizers[opt_idx]

            # We are running backwards and reversing, so we get *outputs*
            # from stack.
            outputs = _inputs_from_stack(layer, stack, layer.n_out)
            grads = _inputs_from_stack(layer, grad_stack, layer.n_out)
            slots = self._replicate(layer_opt.slots)
            layer_opt_params = self._replicated_opt_params[opt_idx]
            weights = self._replicate(layer.weights)  # cpu -> accelerator
            (new_weights, new_slots, inputs, grads,
             layer_stats) = reverse_and_fbo(
                 outputs, weights, old_state, new_state, slots,
                 layer_opt_params, layer_rng, step, grads)
            layer.weights = self._unreplicate(
                new_weights)  # accelerator -> cpu
            layer.state = self._unreplicate(new_state)
            layer_opt.slots = self._unreplicate(new_slots)
            stats.append(layer_stats)
            stack = _outputs_onto_stack(layer, inputs, stack, layer.n_out,
                                        layer.n_in)
            grad_stack = _outputs_onto_stack(layer, grads, grad_stack,
                                             layer.n_out, layer.n_in)

        # ---- First layer: forward-and-backward pass and optimizer update. ----
        first_opt = self._optimizers[0]
        grads = _inputs_from_stack(first_layer, grad_stack, first_layer.n_out)
        slots = self._replicate(first_opt.slots)
        new_weights, new_state, new_slots, first_layer_stats = self._first_fbo(
            first_layer_inputs, first_layer_weights, first_layer_new_state,
            slots, self._replicated_opt_params[0], rngs[0], step, grads)
        stats.append(first_layer_stats)
        first_layer.weights = self._unreplicate(new_weights)
        first_layer.state = self._unreplicate(new_state)
        first_opt.slots = self._unreplicate(new_slots)

        return stats[0]['loss'], stats