Example #1
    def reverse_and_grad(self,
                         output,
                         grad,
                         weights=(),
                         state=(),
                         new_state=(),
                         rng=None):
        rngs = (None, ) * self._n_layers
        if rng is not None:
            rngs = fastmath.random.split(rng, self._n_layers)

        stack = output
        stack_grad = grad
        weights_grad = []
        for layer, p, s, ns, rng in reversed(
                list(zip(self.sublayers, weights, state, new_state, rngs))):
            layer_val = cb.inputs_from_stack(stack, layer.n_out)
            layer_ct = cb.inputs_from_stack(stack_grad, layer.n_out)
            layer_val, layer_ct = layer.reverse_and_grad(layer_val,
                                                         layer_ct,
                                                         p,
                                                         s,
                                                         ns,
                                                         rng=rng)
            layer_ct, p_ct = layer_ct
            weights_grad.insert(0, p_ct)
            stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out)
            stack_grad = cb.outputs_onto_stack(layer_ct, stack_grad,
                                               layer.n_out)

        return stack, (stack_grad, tuple(weights_grad))
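
Note: every example on this page leans on the two stack helpers imported as cb (trax.layers.combinators). The following is a minimal illustrative sketch of their semantics as used above, not the library's actual implementation; it just mirrors the behavior the calls rely on (peek the top n items, then pop what a layer consumed and push its outputs in their place).

def inputs_from_stack(stack, n):
    """Return the top `n` items of the data stack (a bare item if n == 1)."""
    if not isinstance(stack, (list, tuple)):
        stack = (stack,)
    return stack[0] if n == 1 else stack[:n]


def outputs_onto_stack(outputs, stack, n):
    """Pop the `n` items a layer consumed and push its outputs on top."""
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
    if not isinstance(stack, (list, tuple)):
        stack = (stack,)
    new_stack = tuple(outputs) + tuple(stack[n:])
    return new_stack[0] if len(new_stack) == 1 else new_stack


# Usage: run a layer with n_in=2, n_out=1 "on the stack".
stack = ('a', 'b', 'c')
layer_in = inputs_from_stack(stack, 2)       # ('a', 'b')
stack = outputs_onto_stack('ab', stack, 2)   # ('ab', 'c')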
Example #2
 def _run_backward_one_reversible(self, layer, stack, grad_stack, step, rng,
                                  optimizer, opt_params, reverse_and_fbo,
                                  old_state, new_state):
     """Run one reversible layer backwards."""
     # We are running backwards and reversing, so we get *outputs* from stack.
     outputs = cb.inputs_from_stack(stack, layer.n_out)
     grads = cb.inputs_from_stack(grad_stack, layer.n_out)
     slots = self._replicate(optimizer.slots)
     weights = self._replicate(layer.weights)  # cpu -> accelerator
     # Ensure all arguments are on accelerator.
     outputs = tl.on_accelerator(outputs)
     grads = tl.on_accelerator(grads)
     old_state = tl.on_accelerator(old_state)
     new_state = tl.on_accelerator(new_state)
     opt_params = tl.on_accelerator(opt_params)
     rng = tl.on_accelerator(rng)
     new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo(
         outputs, weights, grads, old_state, new_state, slots, opt_params,
         rng, step)
     layer.weights = self._lazy_unreplicate(
         new_weights)  # accelerator -> cpu
     layer.state = self._unreplicate(new_state)
     optimizer.slots = self._unreplicate(new_slots)
     stack = cb.outputs_onto_stack(inputs, stack, layer.n_out)
     grad_stack = cb.outputs_onto_stack(grads, grad_stack, layer.n_out)
     return stack, grad_stack, layer_stats
Example #3
 def _run_backward_reversible(self, stack, grad_stack, step,
                              rev_layers, rev_and_fbos,
                              old_states, new_states, rngs,
                              optimizers, replicated_opt_params):
   """Run reversible layers backwards."""
   counter = 0
   stats = []
   for layer, reverse_and_fbo, old_state, new_state, rng in reversed(list(zip(
       rev_layers, rev_and_fbos,
       old_states, new_states, rngs))):
     counter -= 1
     # We are running backwards and reversing, so we get *outputs* from stack.
     outputs = cb.inputs_from_stack(stack, layer.n_out)
     grads = cb.inputs_from_stack(grad_stack, layer.n_out)
     slots = self._replicate(optimizers[counter].slots)
     opt_params = replicated_opt_params[counter]
     weights = self._replicate(layer.weights)  # cpu -> accelerator
     new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo(
         outputs, weights, old_state, new_state,
         slots, opt_params, rng, step, grads)
     layer.weights = self._unreplicate(new_weights)  # accelerator -> cpu
     layer.state = self._unreplicate(new_state)
     optimizers[counter].slots = self._unreplicate(new_slots)
     stats.append(layer_stats)
     stack = cb.outputs_onto_stack(inputs, stack, layer.n_out)
     grad_stack = cb.outputs_onto_stack(grads, grad_stack, layer.n_out)
   return stack, grad_stack, stats
Example #4
def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
    """Initialize reversible blocks and the loss layer and place weights on CPU.

  Args:
    blocks: List of reversible blocks (pairs of layer lists).
    loss_layer: The final loss layer to initialize.
    input_signature: The signature of the input to the blocks.
    rng: Random key used to initialize the layers.
  """
    sig_stack = input_signature
    process = psutil.Process(os.getpid())
    mem_use = process.memory_info().rss
    for (std_layers, rev_layers) in blocks:
        rngs = fastmath.random.split(rng,
                                     len(std_layers) + len(rev_layers) + 1)
        rng = rngs[0]
        for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]):
            sig = cb.inputs_from_stack(sig_stack, layer.n_in)
            layer.init(sig, rng=layer_rng)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            layer.state = tl.on_cpu(layer.state)  # store state in cpu memory
            logging.info('init: layer %s\nadded cpu memory (MB): %.2f',
                         str(layer), (process.memory_info().rss - mem_use) /
                         float(1024 * 1024))
            mem_use = process.memory_info().rss
            logging.info('init: cpu memory use (MB): %.2f',
                         mem_use / float(1024 * 1024))
            out_sig = layer.output_signature(sig)
            sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in)
    loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng)
    loss_layer.weights = tl.on_cpu(loss_layer.weights)
    loss_layer.state = tl.on_cpu(loss_layer.state)
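
Note: Example #4 interleaves initialization with CPU-memory accounting via psutil. The standalone sketch below isolates that pattern (RSS reported in MB, the same conversion as the logging calls above); the 50 MB allocation is only there to make the delta visible.

import os

import psutil

process = psutil.Process(os.getpid())
before = process.memory_info().rss              # resident set size in bytes
big = bytearray(50 * 1024 * 1024)               # allocate ~50 MB to see a delta
added_mb = (process.memory_info().rss - before) / float(1024 * 1024)
print(f'added cpu memory (MB): {added_mb:.2f}')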
Example #5
 def _run_backward_standard(self, grad_stack, step, layer, inp, state,
                            fbo_fn, rng, optimizer, replicated_opt_params):
     """Run reversible layers backwards."""
     step_int = int(step) if self._n_devices < 2 else int(step[0])
     if step_int % self._n_steps_per_log == 1:
         logging.info('running backward standard layer %s', str(layer))
     if grad_stack is not None:
         grads = cb.inputs_from_stack(grad_stack, layer.n_out)
     else:
         grads = None
     slots = self._replicate(optimizer.slots)
     weights = self._replicate(layer.weights)
     # Ensure all arguments are on accelerator.
     state = tl.on_accelerator(state)
     replicated_opt_params = tl.on_accelerator(replicated_opt_params)
     rng = tl.on_accelerator(rng)
     grads = tl.on_accelerator(grads)
     inp = tl.on_accelerator(inp)
     new_weights, new_state, new_slots, new_grads, stats = fbo_fn(
         inp, weights, grads, state, slots, replicated_opt_params, rng,
         step)
     layer.weights = self._lazy_unreplicate(new_weights)
     layer.state = self._unreplicate(new_state)
     optimizer.slots = self._unreplicate(new_slots)
     if grad_stack is not None:
         grad_stack = cb.outputs_onto_stack(new_grads, grad_stack,
                                            layer.n_out)
     else:
         grad_stack = new_grads
     return stats, grad_stack
Example #6
 def _run_forward_standard(self, stack, layer, accelerated_fn, rng):
   """Run standard layer forward."""
   layer_inputs = cb.inputs_from_stack(stack, layer.n_in)
   layer_weights = self._replicate(layer.weights)
   layer_state = self._replicate(layer.state)
   outputs, layer_new_state = accelerated_fn(
       layer_inputs, layer_weights, layer_state, rng)
   stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
   return stack, layer_inputs, layer_new_state
Example #7
 def _run_forward_standard(self, stack, layer, accelerated_fn, rng, step):
     """Run standard layer forward."""
     if step % self._n_steps_per_log == 1:
         logging.info('running forward standard layer %s', str(layer))
     layer_inputs = cb.inputs_from_stack(stack, layer.n_in)
     layer_weights = self._replicate(layer.weights)
     layer_state = self._replicate(layer.state)
     outputs, layer_new_state = accelerated_fn(layer_inputs, layer_weights,
                                               layer_state, rng)
     stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
     return stack, layer_inputs, layer_new_state
Example #8
  def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = fastmath.random.split(rng, self._n_layers)

    stack = output
    for layer, p, s, ns, rng in reversed(list(zip(
        self.sublayers, weights, state, new_state, rngs))):
      layer_val = cb.inputs_from_stack(stack, layer.n_out)
      layer_val = layer.reverse(layer_val, p, s, ns, rng=rng)
      stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out)

    return stack
Example #9
 def _run_forward_reversible(self, stack, rev_layers, accelerated_fns, rngs):
   """Run reversible layers forward, collect states for backwards pass."""
   old_states, new_states = [], []
   for i, layer in enumerate(rev_layers):
     weights = self._replicate(layer.weights)  # also copies cpu -> accelerator
     state = self._replicate(layer.state)
     old_states.append(state)
     inputs = cb.inputs_from_stack(stack, layer.n_in)
     outputs, new_state = accelerated_fns[i](
         inputs, weights, state, rngs[i])
     stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
     new_states.append(new_state)
   return stack, old_states, new_states
Example #10
  def forward(self, xs):
    rngs = _split_rngs(self.rng, len(self.sublayers))
    accumulator, *context = xs
    stack = context = tuple(context)
    new_state = []
    for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs):
      inputs = cb.inputs_from_stack(stack, layer.n_in)
      outputs, s = layer.pure_fn(inputs, w, s, rng)
      stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
      new_state.append(s)
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack

    output = accumulator + residual
    stack = (output,) + context
    self.state = tuple(new_state)
    return stack
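
Note: Examples #8 and #10 rest on the reversible-residual identity: the forward pass computes output = accumulator + residual, so the reverse pass can reconstruct the accumulator as output - residual instead of storing activations. A minimal numpy sketch of that idea with hypothetical helper names (not the trax ReversibleHalfResidual class):

import numpy as np


def half_residual_forward(acc, context, f):
    # The first stream accumulates the residual computed from the context.
    return acc + f(context), context


def half_residual_reverse(out, context, f):
    # Recompute the residual from the (unchanged) context and subtract it.
    return out - f(context), context


f = np.tanh                                  # stand-in residual function
acc, context = np.ones(3), np.full(3, 0.5)
out, context = half_residual_forward(acc, context, f)
rec, _ = half_residual_reverse(out, context, f)
assert np.allclose(rec, acc)                 # original accumulator recovered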
Example #11
  def init_weights_and_state(self, input_signature):
    stack = input_signature[1:]
    if len(stack) == 1:
      stack = stack[0]

    inputs = cb.inputs_from_stack(stack, self.compute_residual.n_in)
    weights, state = self.compute_residual.init(inputs)
    outputs, _ = self.compute_residual._forward_abstract(inputs)
    stack = cb.outputs_onto_stack(outputs, stack, self.compute_residual.n_in)

    if self.attention_layer is None:
      self.state = (state,)
      self.weights = (weights,)
    else:
      inputs = cb.inputs_from_stack(stack, self.attention_layer.n_in)
      attn_weights, attn_state = self.attention_layer.init(inputs)
      self.state = (state, attn_state)
      self.weights = (weights, attn_weights)
Example #12
 def _run_backward_standard(self, grad_stack, step, layer, inp, state,
                            fbo_fn, rng, optimizer, replicated_opt_params):
   """Run reversible layers backwards."""
   if grad_stack is not None:
     grads = cb.inputs_from_stack(grad_stack, layer.n_out)
   else:
     grads = None
   slots = self._replicate(optimizer.slots)
   weights = self._replicate(layer.weights)
   new_weights, new_state, new_slots, new_grads, stats = fbo_fn(
       inp, weights, state, slots, replicated_opt_params, rng, step, grads)
   layer.weights = self._unreplicate(new_weights)
   layer.state = self._unreplicate(new_state)
   optimizer.slots = self._unreplicate(new_slots)
   if grad_stack is not None:
     grad_stack = cb.outputs_onto_stack(new_grads, grad_stack, layer.n_out)
   else:
     grad_stack = new_grads
   return stats, grad_stack
Example #13
  def forward(self, xs):
    rngs = _split_rngs(self.rng, len(self.sublayers))
    accumulator, *context = xs
    stack = context = tuple(context)
    new_state = []
    for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs):
      inputs = cb.inputs_from_stack(stack, layer.n_in)
      if base.N_WEIGHTS_SHARDS > 1:
        # With sharded weights, make sure we don't keep them concatenated
        # in memory on each device by using remat.
        outputs, s = jax.remat(layer.pure_fn)(inputs, w, s, rng)
      else:
        outputs, s = layer.pure_fn(inputs, w, s, rng)
      stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
      new_state.append(s)
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack

    output = accumulator + residual
    stack = (output,) + context
    self.state = tuple(new_state)
    return stack
Example #14
def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
  """Initialize reversible blocks and the loss layer and place weights on CPU.

  Args:
    blocks: List of reversible blocks (pairs of layer lists).
    loss_layer: The final loss layer to initialize.
    input_signature: The signature of the input to the blocks.
    rng: Random key used to initialize the layers.
  """
  sig_stack = input_signature
  for (std_layers, rev_layers) in blocks:
    rngs = fastmath.random.split(rng, len(std_layers) + len(rev_layers) + 1)
    rng = rngs[0]
    for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]):
      sig = cb.inputs_from_stack(sig_stack, layer.n_in)
      layer.init(sig, rng=layer_rng)
      layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
      out_sig = layer.output_signature(sig)
      sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in)
  loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng)
  loss_layer.weights = tl.on_cpu(loss_layer.weights)
Example #15
  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       rng=None):
    rngs = _split_rngs(rng, len(self.sublayers))

    accumulator_output, *context = output
    context = tuple(context)
    accumulator_output_ct, *context_ct = ct
    context_ct = tuple(context_ct)

    # Forward pass through self.compute_residual. Outputs that will not receive
    # a gradient signal from subsequent layers are moved to aux.
    def call_compute_residual(x, weights):
      res, _ = self.compute_residual.pure_fn(
          x, weights=weights, state=state[0], rng=rngs[0])
      if not isinstance(res, (tuple, list)):
        return res, None
      else:
        n_differentiable = 1
        if self.attention_layer is not None:
          n_differentiable = min(len(res), self.attention_layer.n_in)
        return res[:n_differentiable], res[n_differentiable:]

    stack = context
    inputs = cb.inputs_from_stack(stack, self.compute_residual.n_in)
    outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
        call_compute_residual, inputs, weights[0], has_aux=True)
    if outputs_aux is not None:
      n_differentiable_outputs = len(outputs)
      outputs = outputs + outputs_aux
    stack = cb.outputs_onto_stack(outputs, stack, self.compute_residual.n_in)

    stack_ct = accumulator_output_ct
    if self.attention_layer is None:
      residual = stack[0] if isinstance(stack, (tuple, list)) else stack
    else:
      inputs = cb.inputs_from_stack(stack, self.attention_layer.n_in)
      (residual, _, attn_inputs_ct, attn_weights_ct
      ) = self._forward_and_or_backward(
          inputs, weights[1], new_state[1], rngs[1],
          output_grad=accumulator_output_ct,
          compute_output=True, update_state=False)
      stack_ct = cb.outputs_onto_stack(
          attn_inputs_ct, stack_ct, self.attention_layer.n_out)

    compute_residual_ct = cb.inputs_from_stack(
        stack_ct, self.compute_residual.n_out)
    if outputs_aux is not None:
      if not isinstance(compute_residual_ct, (tuple, list)):
        compute_residual_ct = (compute_residual_ct,)
      compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
      assert len(compute_residual_ct) == n_differentiable_outputs
    (compute_residual_inputs_ct, compute_residual_weights_ct
    ) = compute_residual_vjpfun(compute_residual_ct)
    stack_ct = cb.outputs_onto_stack(
        compute_residual_inputs_ct, stack_ct, self.compute_residual.n_out)
    if not isinstance(stack_ct, (tuple, list)):
      stack_ct = (stack_ct,)
    stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
        lambda x, y: x+y, context_ct[:len(stack_ct)], stack_ct
        ) + context_ct[len(stack_ct):]

    reconstructed_x = accumulator_output - residual
    stack = (reconstructed_x,) + context
    if self.attention_layer is None:
      weights_ct = (compute_residual_weights_ct,)
    else:
      weights_ct = (compute_residual_weights_ct, attn_weights_ct)
    return stack, (stack_ct, weights_ct)
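
Note: Examples #15 and #16 use fastmath.vjp to get both the residual's outputs and a pullback that converts the incoming cotangent into gradients for the residual's inputs and weights. A minimal sketch of the same pattern with jax.vjp standing in for fastmath.vjp (trax's fastmath dispatches to JAX under the JAX backend); the toy residual function and shapes are made up for illustration:

import jax
import jax.numpy as jnp


def residual_fn(x, w):
    # Toy stand-in for compute_residual.pure_fn.
    return jnp.tanh(x @ w)


x = jnp.ones((2, 4))
w = jnp.full((4, 4), 0.1)
out, vjp_fn = jax.vjp(residual_fn, x, w)   # forward pass plus pullback closure
out_ct = jnp.ones_like(out)                # cotangent arriving from layers above
x_ct, w_ct = vjp_fn(out_ct)                # gradients w.r.t. inputs and weights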
Example #16
    def reverse_and_grad(self,
                         output,
                         ct,
                         weights=(),
                         state=(),
                         new_state=(),
                         rng=None):
        rngs = _split_rngs(rng, len(self.sublayers))

        accumulator_output, *context = output
        context = tuple(context)
        accumulator_output_ct, *context_ct = ct
        context_ct = tuple(context_ct)

        # Forward pass through self._compute_residual. Outputs that will not receive
        # a gradient signal from subsequent layers are moved to aux.
        def call_compute_residual(x, weights):
            state_to_pass = state[0]  # old_state

            # _replace_second_time is currently used exclusively in _RememberInReverse
            # layer to combat numerical instability in Reformer2 when quantizing
            # the mask in SparseFF.
            def _replace_second_time(stt, nstt):
                if (isinstance(stt, tuple) and len(stt) == 2
                        and isinstance(stt[1], dict)
                        and 'running_second_time' in stt[1]):
                    return (nstt[0], {'running_second_time_yes': ()})
                elif isinstance(stt, (tuple, list)):
                    assert isinstance(nstt,
                                      (tuple, list)) and len(nstt) == len(stt)
                    return type(stt)([
                        _replace_second_time(s, ns)
                        for s, ns in zip(stt, nstt)
                    ])
                else:
                    return stt

            state_to_pass = _replace_second_time(state_to_pass, new_state[0])
            res, _ = self._compute_residual.pure_fn(x,
                                                    weights=weights,
                                                    state=state_to_pass,
                                                    rng=rngs[0])
            if not isinstance(res, (tuple, list)):
                return res, None
            else:
                n_differentiable = 1
                if self._attention_layer is not None:
                    n_differentiable = min(len(res),
                                           self._attention_layer.n_in)
                return res[:n_differentiable], res[n_differentiable:]

        stack = context
        inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in)
        outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
            call_compute_residual, inputs, weights[0], has_aux=True)
        if outputs_aux is not None:
            n_differentiable_outputs = len(outputs)
            outputs = outputs + outputs_aux
        stack = cb.outputs_onto_stack(outputs, stack,
                                      self._compute_residual.n_in)

        stack_ct = accumulator_output_ct
        if self._attention_layer is None:
            residual = stack[0] if isinstance(stack, (tuple, list)) else stack
        else:
            inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in)
            (residual, _, attn_inputs_ct,
             attn_weights_ct) = self._forward_and_or_backward(
                 inputs,
                 weights[1],
                 new_state[1],
                 rngs[1],
                 output_grad=accumulator_output_ct,
                 compute_output=True,
                 update_state=False)
            stack_ct = cb.outputs_onto_stack(attn_inputs_ct, stack_ct,
                                             self._attention_layer.n_out)

        compute_residual_ct = cb.inputs_from_stack(
            stack_ct, self._compute_residual.n_out)
        if outputs_aux is not None:
            if not isinstance(compute_residual_ct, (tuple, list)):
                compute_residual_ct = (compute_residual_ct, )
            compute_residual_ct = (
                compute_residual_ct[:n_differentiable_outputs])
            assert len(compute_residual_ct) == n_differentiable_outputs
        (compute_residual_inputs_ct, compute_residual_weights_ct
         ) = compute_residual_vjpfun(compute_residual_ct)
        stack_ct = cb.outputs_onto_stack(compute_residual_inputs_ct, stack_ct,
                                         self._compute_residual.n_out)
        if not isinstance(stack_ct, (tuple, list)):
            stack_ct = (stack_ct, )

        def _add(x, y):
            # `None` is for TFNP backend, which uses `None` as the gradient of
            # int/bool instead of an array of dtype `float0`.
            if x is None or x.dtype == jax.float0:
                return y
            if y is None or y.dtype == jax.float0:
                return x
            return x + y

        stack_ct = (accumulator_output_ct, ) + fastmath.nested_map_multiarg(
            _add, context_ct[:len(stack_ct)],
            stack_ct) + context_ct[len(stack_ct):]

        reconstructed_x = accumulator_output - residual
        stack = (reconstructed_x, ) + context
        if self._attention_layer is None:
            weights_ct = (compute_residual_weights_ct, )
        else:
            weights_ct = (compute_residual_weights_ct, attn_weights_ct)
        return stack, (stack_ct, weights_ct)
Example #17
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Updates layers weights/state and optimizers slots by running one step.

    Args:
      batch: Batch of data to use for optimization.
      rng: Random number generator to use for running this step.
      step: Which step of the training are we running.
      learning_rate: Learning rate to use instead of the default one.

    Returns:
      Tuple (loss, stats) with new values from one step
      of training, where stats are all optimizer statistics.
    """
        # Update the learning rate if needed.
        if learning_rate is not None:
            self._replicated_loss_opt_params['learning_rate'] = (
                self._replicate_cpu(learning_rate))
            for (std_op, rev_ops) in self._replicated_opt_params:
                std_op['learning_rate'] = self._replicate_cpu(learning_rate)
                for op in rev_ops:
                    op['learning_rate'] = self._replicate_cpu(learning_rate)

        # The batch needs to be split across the local devices -- the difference
        # between _for_n_devices and _reshape_by_device is that the latter splits
        # the batch dimension into batch // n_devices per device, while
        # _for_n_devices broadcasts/replicates along a new n_devices dimension.
        step_int = step
        if self._n_devices > 1:
            batch = tl.reshape_by_device(batch, self._n_devices, pure_np=True)
            step = np.repeat(step, self._n_devices)

        # Create separate rng for each device and layer.
        if self._n_devices == 1:
            rngs = fastmath.random.split(rng, self._n_layers)
        else:
            # JIT the function and run it on CPU to avoid memory fragmentation.
            rngs = self._jit_per_device_rngs(tl.on_cpu(rng))
        # Group rngs by layer blocks.
        rng_blocks, rng_i = [], 0
        for _, rev_layers in self._blocks:
            l = len(rev_layers)
            rng_blocks.append((rngs[rng_i], rngs[rng_i + 1:rng_i + l + 1]))
            rng_i += l + 1

        # Run the layers forward up to the loss layer.
        if self._do_free:
            self._free_accelerators()
        process = psutil.Process(os.getpid())
        if isinstance(batch, (list, tuple)):
            batch_shapes = [x.shape for x in batch]
        else:
            batch_shapes = batch.shape
        logging.info('running step %d on shapes %s', step_int,
                     str(batch_shapes))
        if step_int % self._n_steps_per_log == 1:
            logging.info('run fwd: cpu memory use (MB): %.2f',
                         process.memory_info().rss / float(1024 * 1024))

        stack = batch
        block_inputs_states = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
            std_rng, rev_rngs = rng_blocks[i]
            # Run the standard layer.
            stack, std_inputs, std_state = self._run_forward_standard(
                stack, std_layer, acc_std_layer_fn, std_rng, step_int)

            # Run the reversible layers and collect old and new states.
            stack, rev_old_states, rev_new_states = self._run_forward_reversible(
                stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int)
            block_inputs_states.append(
                tl.on_cpu(((std_inputs, std_state), (rev_old_states,
                                                     rev_new_states))))

        # Run the loss layer forward and backward with optimizer update.
        if step_int % self._n_steps_per_log == 1:
            logging.info('run loss: cpu memory use (MB): %.2f',
                         process.memory_info().rss / float(1024 * 1024))
        loss_state = self._replicate(self._loss_layer.state)
        loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
        loss_stats, grad_stack = self._run_backward_standard(
            None, step, self._loss_layer, loss_inputs, loss_state,
            self._loss_fbo, rngs[-1], self._loss_opt,
            self._replicated_loss_opt_params)
        self._collect_weights(self._loss_layer)
        stats = [tl.on_cpu(loss_stats)]

        # De-fragment memory.
        if self._do_free:
            stack, grad_stack = tl.on_cpu(stack), tl.on_cpu(grad_stack)
            self._free_accelerators()

        # Run the layers backward and run optimizer updates.
        if step_int % self._n_steps_per_log == 1:
            logging.info('run bwd: cpu memory use (MB): %.2f',
                         process.memory_info().rss / float(1024 * 1024))
        for i in range(len(self._blocks) - 1, -1, -1):
            std_layer, rev_layers = self._blocks[i]
            (std_inputs, std_state), (rev_old_states,
                                      rev_new_states) = block_inputs_states[i]
            std_fbo, rev_fbos = self._fbos[i]
            std_opt, rev_opts = self._optimizers[i]
            std_rng, rev_rngs = rng_blocks[i]
            repl_std_opt_params, repl_rev_opts_params = (
                self._replicated_opt_params[i])

            # Run reversible layers backward with optimizer update.
            stack, grad_stack, new_stats = self._run_backward_reversible(
                stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
                rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
            stats.extend(tl.on_cpu(new_stats))

            # Run the standard layer forward-and-backward pass and optimizer update.
            std_layer_stats, grad_stack = self._run_backward_standard(
                grad_stack, step, std_layer, std_inputs, std_state, std_fbo,
                std_rng, std_opt, repl_std_opt_params)
            stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
                std_inputs, stack, std_layer.n_out)
            stats.append(tl.on_cpu(std_layer_stats))

            # Collect lazily unreplicated layer weights.
            for rev_layer_id in range(self._n_async_layers):
                self._collect_weights(rev_layers[rev_layer_id])
            self._collect_weights(std_layer)

        # Join stats from different optimizers into one.
        joint_stats = {}
        for i, stat in enumerate(reversed(stats)):
            for k, v in stat.items():
                joint_stats[f'layer{i}/' + k] = v
        return stats[0]['loss'], joint_stats
Example #18
  def one_step(self, batch, rng, step=0, learning_rate=None):
    """Updates layers weights/state and optimizers slots by running one step.

    Args:
      batch: Batch of data to use for optimization.
      rng: Random number generator to use for running this step.
      step: Which step of the training are we running.
      learning_rate: Learning rate to use instead of the default one.

    Returns:
      Tuple (loss, stats) with new values from one step
      of training, where stats are all optimizer statistics.
    """
    # Update the learning rate if needed.
    if learning_rate is not None:
      self._replicated_loss_opt_params['learning_rate'] = tl.for_n_devices(
          learning_rate, self._n_devices)
      for (std_op, rev_ops) in self._replicated_opt_params:
        std_op['learning_rate'] = tl.for_n_devices(
            learning_rate, self._n_devices)
        for op in rev_ops:
          op['learning_rate'] = tl.for_n_devices(
              learning_rate, self._n_devices)

    # The batch needs to be split across the local devices -- the difference
    # between _for_n_devices and _reshape_by_device is that the latter splits
    # the batch dimension into batch // n_devices per device, while
    # _for_n_devices broadcasts/replicates along a new n_devices dimension.
    if self._n_devices > 1:
      batch = tl.reshape_by_device(batch, self._n_devices)
      step = jnp.repeat(step, self._n_devices)

    # Create separate rng for each device and layer.
    if self._n_devices == 1:
      rngs = fastmath.random.split(rng, self._n_layers)
    else:
      # Splitting by device first to be identical with default trainer.
      per_device_rng = fastmath.random.split(rng, self._n_devices)
      per_device_rngs = [
          fastmath.random.split(r, self._n_layers) for r in per_device_rng]
      rngs = [jnp.stack([r[i] for r in per_device_rngs])
              for i in range(self._n_layers)]
    # Group rngs by layer blocks.
    rng_blocks, rng_i = [], 0
    for _, rev_layers in self._blocks:
      l = len(rev_layers)
      rng_blocks.append((rngs[rng_i], rngs[rng_i + 1: rng_i + l + 1]))
      rng_i += l + 1

    # Run the layers forward up to the loss layer.
    stack = batch
    block_inputs_states = []
    for i, (std_layer, rev_layers) in enumerate(self._blocks):
      acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
      std_rng, rev_rngs = rng_blocks[i]
      # Run the standard layer.
      stack, std_inputs, std_state = self._run_forward_standard(
          stack, std_layer, acc_std_layer_fn, std_rng)

      # Run the reversible layers and collect old and new states.
      stack, rev_old_states, rev_new_states = self._run_forward_reversible(
          stack, rev_layers, acc_rev_layer_fns, rev_rngs)
      block_inputs_states.append(
          ((std_inputs, std_state), (rev_old_states, rev_new_states)))

    # Run the loss layer forward and backward with optimizer update.
    loss_state = self._replicate(self._loss_layer.state)
    loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
    loss_stats, grad_stack = self._run_backward_standard(
        None, step, self._loss_layer, loss_inputs,
        loss_state, self._loss_fbo, rngs[-1], self._loss_opt,
        self._replicated_loss_opt_params)
    stats = [loss_stats]

    # Run the layers backward and run optimizer updates.
    for i in range(len(self._blocks) - 1, -1, -1):
      std_layer, rev_layers = self._blocks[i]
      (std_inputs, std_state), (rev_old_states,
                                rev_new_states) = block_inputs_states[i]
      std_fbo, rev_fbos = self._fbos[i]
      std_opt, rev_opts = self._optimizers[i]
      std_rng, rev_rngs = rng_blocks[i]
      repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]

      # Run reversible layers backward with optimizer update.
      stack, grad_stack, new_stats = self._run_backward_reversible(
          stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
          rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
      stats.extend(new_stats)

      # Run the standard layer forward-and-backward pass and optimizer update.
      std_layer_stats, grad_stack = self._run_backward_standard(
          grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
          std_opt, repl_std_opt_params)
      stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
          std_inputs, stack, std_layer.n_out)
      stats.append(std_layer_stats)

    # Join stats from different optimizers into one.
    joint_stats = {}
    for i, stat in enumerate(reversed(stats)):
      for k, v in stat.items():
        joint_stats[f'layer{i}/' + k] = v
    return stats[0]['loss'], joint_stats
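
Note: Example #18 splits one rng per device, then per layer, and regroups the keys so each layer sees a stacked per-device key. The sketch below reproduces just that regrouping with jax.random directly (fastmath.random dispatches to it under the JAX backend), assuming the classic uint32 key representation where each key has shape (2,); n_devices and n_layers are arbitrary illustration values.

import jax
import jax.numpy as jnp

n_devices, n_layers = 2, 5
rng = jax.random.PRNGKey(0)
per_device_rng = jax.random.split(rng, n_devices)             # one key per device
per_device_rngs = [jax.random.split(r, n_layers) for r in per_device_rng]
# Regroup: rngs[i] stacks layer i's key across devices, shape (n_devices, 2).
rngs = [jnp.stack([r[i] for r in per_device_rngs]) for i in range(n_layers)]
assert rngs[0].shape == (n_devices, 2)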