Example 1
  def fbo(inputs, weights, state, slots, opt_params, rng, step, grads):
    """FBO of the layer."""
    # We need a layer pure_fn but only for inputs and weights.
    def pure_fn_without_state_and_rng(x, w):
      return layer.pure_fn(x, w, state, rng)

    # Calculate the vector-Jacobian product of the reduced pure fn.
    activations, vjp_fn, new_state = fastmath.vjp(
        pure_fn_without_state_and_rng, inputs, weights, has_aux=True)

    # In the loss layer there is no incoming gradient, so seed the backward
    # pass with a gradient of 1 in the dtype of the activations (the loss).
    if grads is None and stats_name is not None:
      grads = jnp.ones((), dtype=activations.dtype)

    # The vjp function returns gradients with respect to inputs and weights.
    grads_inputs, grads_weights = vjp_fn(grads)

    # For layers with no trainable weights, skip the optimizer and return the
    # values computed so far.
    if _is_empty_tuple(weights):
      stats = {}
      if stats_name is not None:
        stats[stats_name] = activations
      return weights, new_state, slots, grads_inputs, stats

    # In multi-device setting, average gradients from multiple devices.
    if n_devices > 1:
      grads_weights = _average_multidevice_gradients(grads_weights)

    # Run the optimizer.
    new_weights, new_slots, stats = optimizer.tree_update(
        step, grads_weights, weights, slots, opt_params)
    if stats_name is not None:
      stats[stats_name] = activations
    return new_weights, new_state, new_slots, grads_inputs, stats
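
The free variables above (`layer`, `optimizer`, `n_devices`, `stats_name`, `_is_empty_tuple`, `_average_multidevice_gradients`) are bound by an enclosing factory that builds this closure. Below is a minimal, self-contained sketch of the same vjp-with-aux pattern in plain JAX; the toy layer and the plain SGD step standing in for `optimizer.tree_update` are hypothetical, not the library implementation.

import jax
import jax.numpy as jnp


def toy_pure_fn(x, w, state, rng):
    """Toy stand-in for layer.pure_fn: returns (activations, new_state)."""
    del rng
    return jnp.dot(x, w), state


def toy_fbo(inputs, weights, state, rng, grads):
    # Reduce the pure function to one of (inputs, weights) only.
    def fn(x, w):
        return toy_pure_fn(x, w, state, rng)

    # With has_aux=True, vjp returns (primal_out, vjp_fn, aux).
    activations, vjp_fn, new_state = jax.vjp(fn, inputs, weights, has_aux=True)

    # Backpropagate the incoming cotangent to inputs and weights.
    grads_inputs, grads_weights = vjp_fn(grads)

    # A plain SGD step stands in for optimizer.tree_update.
    new_weights = weights - 0.01 * grads_weights
    return new_weights, new_state, grads_inputs, activations


x, w = jnp.ones((2, 3)), jnp.ones((3, 4))
cotangent = jnp.ones((2, 4))      # matches the shape of the activations
new_w, _, dx, acts = toy_fbo(x, w, state=(), rng=None, grads=cotangent)
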
Example 2
  def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(),
                       rng=None):
    """Backward pass: computes the inverse of a layer and propagates gradients.

    While you may choose to implement only `reverse`, some layers implement
    this function directly because computation may be shared between reversing
    and computing gradients.

    Args:
      output: Output activations; can be a (possibly nested) tuple.
      grad: gradient signal (cotangent) computed based on subsequent layers.
        The structure and shape must match the output.
      weights: layer weights
      state: start state
      new_state: updated state computed by the forward pass
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input,
      x_grad is the gradient signal for the input, and weights_grad is the
      gradient signal for the weights.
    """
    def _do_forward(x, weights):
      old_weights, old_state, old_rng = self.weights, self.state, self._rng
      self.state, self._rng = state, rng
      self.weights = weights
      res = self.forward(x)
      self.weights, self.state, self._rng = old_weights, old_state, old_rng
      return res

    reconstructed_x = self.reverse(output, weights, state, new_state, rng)
    _, vjpfun = fastmath.vjp(_do_forward, reconstructed_x, weights)
    x_weights_grad = vjpfun(grad)
    return reconstructed_x, x_weights_grad
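
The pattern inside `_do_forward` — reconstruct the input with `reverse`, then take a VJP of the forward computation at the reconstructed point — can be seen in isolation with a toy invertible layer. The sketch below is a hypothetical stand-in (plain JAX, made-up names); the real method additionally swaps the layer's weights, state, and rng in and out, as shown above.

import jax
import jax.numpy as jnp


def forward(x, w):
    return x + w                  # toy invertible "layer": shift by the weights


def reverse(y, w):
    return y - w                  # exact inverse of forward


def reverse_and_grad(y, ct, w):
    x = reverse(y, w)             # reconstruct the input from the output
    _, vjp_fn = jax.vjp(forward, x, w)
    x_ct, w_ct = vjp_fn(ct)       # gradients w.r.t. input and weights
    return x, (x_ct, w_ct)


y, w, ct = jnp.arange(4.0), jnp.ones(4), jnp.ones(4)
x, (x_ct, w_ct) = reverse_and_grad(y, ct, w)
assert jnp.allclose(forward(x, w), y)   # reconstruction is consistent
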
Example 3
        def loss_fbo(inputs, weights, state, slots, opt_params, rng, step):
            """FBO of the final loss layer."""

            # We need a loss layer pure_fn but only for inputs and weights.
            def loss_pure_fn_without_state_and_rng(x, w):
                return loss_layer.pure_fn(x, w, state, rng)

            # Calculate the vector-Jacobian product of the reduced loss pure fn.
            loss, vjp_fn, new_state = fastmath.vjp(
                loss_pure_fn_without_state_and_rng,
                inputs,
                weights,
                has_aux=True)

            # The vjp function returns gradients with respect to inputs and weights.
            # Since the loss is a scalar and there are no subsequent layers,
            # seed the backward pass with a cotangent of 1.0.
            grads_inputs, grads_weights = vjp_fn(jnp.ones((),
                                                          dtype=loss.dtype))

            # In multi-device setting, average gradients from multiple devices.
            if self._n_devices > 1:
                grads_weights = _average_multidevice_gradients(grads_weights)

            # Run the loss optimizer, which is the last one since it's the last layer.
            new_weights, new_slots, stats = self._optimizers[-1].tree_update(
                step, grads_weights, weights, slots, opt_params)
            stats['loss'] = loss
            return new_weights, new_state, new_slots, grads_inputs, stats
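
Seeding the VJP of a scalar loss with a cotangent of 1.0 is equivalent to taking an ordinary gradient. A quick self-contained check in plain JAX, with a made-up toy loss:

import jax
import jax.numpy as jnp


def toy_loss(weights, inputs):
    return jnp.sum((inputs @ weights) ** 2)


weights, inputs = jnp.ones((3, 2)), jnp.ones((4, 3))
loss, vjp_fn = jax.vjp(lambda w: toy_loss(w, inputs), weights)
(g_from_vjp,) = vjp_fn(jnp.ones((), dtype=loss.dtype))
g_from_grad = jax.grad(lambda w: toy_loss(w, inputs))(weights)
assert jnp.allclose(g_from_vjp, g_from_grad)
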
Example 4
        def first_fbo(inputs, weights, state, slots, opt_params, rng, step,
                      grads):
            """FBO of the first layer."""

            # We need the first layer's pure_fn but only for inputs and weights.
            def first_layer_pure_fn_without_state_and_rng(x, w):
                return first_layer.pure_fn(x, w, state, rng)

            # Calculate vector-Jacobian product of the reduced first layer pure fn.
            activations_after_first_layer, vjp_fn, new_state = fastmath.vjp(
                first_layer_pure_fn_without_state_and_rng,
                inputs,
                weights,
                has_aux=True)
            del activations_after_first_layer  # unused

            # The vjp function returns gradients with respect to inputs and weights.
            _, grads_weights = vjp_fn(grads)

            # In multi-device setting, average gradients from multiple devices.
            if self._n_devices > 1:
                grads_weights = _average_multidevice_gradients(grads_weights)

            # Run the optimizer of the first layer (the first one in the list).
            new_weights, new_slots, stats = self._optimizers[0].tree_update(
                step, grads_weights, weights, slots, opt_params)
            return new_weights, new_state, new_slots, stats
Example 5
    def forward_and_or_backward(inputs,
                                weights,
                                state,
                                rng,
                                output_grad=None,
                                compute_output=True,
                                update_state=True):
        """Performs batched forward and/or backward passes.

    Args:
      inputs: inputs to the attention layer
      weights: weights for the attention layer
      state: state of the attention layer
      rng: PRNG key for the layer (shared across all examples and heads)
      output_grad: gradient of the loss wrt the output of the layer, or None.
          This function performs the backward pass iff `output_grad` is not
          None.
      compute_output: bool: whether to return the output of the forward pass
          (for example, a pure backwards pass does not need to return the
          output).
      update_state: bool: whether to return an updated layer state.

    Returns:
      A tuple (output, new_state, inputs_grad, weights_grad).
      - output is not None iff compute_output is True
      - new_state is not None iff update_state is True
      - inputs_grad & weights_grad are not None iff output_grad is not None
    """

        # We need a layer pure_fn but only for inputs and weights.
        def pure_fn_without_state_and_rng(x, w):
            return layer.pure_fn(x, w, state, rng)

        # Calculate the vector-Jacobian product of the layer pure_fn.
        output, vjp_fn, new_state = fastmath.vjp(pure_fn_without_state_and_rng,
                                                 inputs,
                                                 weights,
                                                 has_aux=True)
        output = output if compute_output else None
        new_state = new_state if update_state else None

        # The vjp function returns gradients with respect to inputs and weights.
        if output_grad is not None:
            grads_inputs, grads_weights = vjp_fn(output_grad)
        else:
            grads_inputs, grads_weights = None, None

        return (output, new_state, grads_inputs, grads_weights)
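
A closure like this supports forward-only calls (no `output_grad`), backward-only calls (`compute_output=False`, `update_state=False`), or both at once. A self-contained sketch of those call modes with a toy layer standing in for the attention layer (all names hypothetical):

import jax
import jax.numpy as jnp


def toy_pure_fn(x, w, state, rng):
    del rng
    return jnp.tanh(x @ w), state          # (output, new_state)


def toy_forward_and_or_backward(inputs, weights, state, rng, output_grad=None,
                                compute_output=True, update_state=True):
    output, vjp_fn, new_state = jax.vjp(
        lambda x, w: toy_pure_fn(x, w, state, rng),
        inputs, weights, has_aux=True)
    output = output if compute_output else None
    new_state = new_state if update_state else None
    if output_grad is not None:
        grads_inputs, grads_weights = vjp_fn(output_grad)
    else:
        grads_inputs, grads_weights = None, None
    return output, new_state, grads_inputs, grads_weights


x, w = jnp.ones((2, 3)), jnp.ones((3, 3))
# Forward-only: the gradients come back as None.
y, new_state, _, _ = toy_forward_and_or_backward(x, w, state=(), rng=None)
# Backward-only: skip the output and the state, keep only the gradients.
_, _, dx, dw = toy_forward_and_or_backward(
    x, w, state=(), rng=None, output_grad=jnp.ones((2, 3)),
    compute_output=False, update_state=False)
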
Example 6
  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       rng=None):
    rngs = _split_rngs(rng, len(self.sublayers))

    accumulator_output, *context = output
    context = tuple(context)
    accumulator_output_ct, *context_ct = ct
    context_ct = tuple(context_ct)

    # Forward pass through self.compute_residual. Outputs that will not receive
    # a gradient signal from subsequent layers are moved to aux.
    def call_compute_residual(x, weights):
      res, _ = self.compute_residual.pure_fn(
          x, weights=weights, state=state[0], rng=rngs[0])
      if not isinstance(res, (tuple, list)):
        return res, None
      else:
        n_differentiable = 1
        if self.attention_layer is not None:
          n_differentiable = min(len(res), self.attention_layer.n_in)
        return res[:n_differentiable], res[n_differentiable:]

    stack = context
    inputs = cb.inputs_from_stack(stack, self.compute_residual.n_in)
    outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
        call_compute_residual, inputs, weights[0], has_aux=True)
    if outputs_aux is not None:
      n_differentiable_outputs = len(outputs)
      outputs = outputs + outputs_aux
    stack = cb.outputs_onto_stack(outputs, stack, self.compute_residual.n_in)

    stack_ct = accumulator_output_ct
    if self.attention_layer is None:
      residual = stack[0] if isinstance(stack, (tuple, list)) else stack
    else:
      inputs = cb.inputs_from_stack(stack, self.attention_layer.n_in)
      (residual, _, attn_inputs_ct, attn_weights_ct
      ) = self._forward_and_or_backward(
          inputs, weights[1], new_state[1], rngs[1],
          output_grad=accumulator_output_ct,
          compute_output=True, update_state=False)
      stack_ct = cb.outputs_onto_stack(
          attn_inputs_ct, stack_ct, self.attention_layer.n_out)

    compute_residual_ct = cb.inputs_from_stack(
        stack_ct, self.compute_residual.n_out)
    if outputs_aux is not None:
      if not isinstance(compute_residual_ct, (tuple, list)):
        compute_residual_ct = (compute_residual_ct,)
      compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
      assert len(compute_residual_ct) == n_differentiable_outputs
    (compute_residual_inputs_ct, compute_residual_weights_ct
    ) = compute_residual_vjpfun(compute_residual_ct)
    stack_ct = cb.outputs_onto_stack(
        compute_residual_inputs_ct, stack_ct, self.compute_residual.n_out)
    if not isinstance(stack_ct, (tuple, list)):
      stack_ct = (stack_ct,)
    stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
        lambda x, y: x+y, context_ct[:len(stack_ct)], stack_ct
        ) + context_ct[len(stack_ct):]

    reconstructed_x = accumulator_output - residual
    stack = (reconstructed_x,) + context
    if self.attention_layer is None:
      weights_ct = (compute_residual_weights_ct,)
    else:
      weights_ct = (compute_residual_weights_ct, attn_weights_ct)
    return stack, (stack_ct, weights_ct)
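
Stripped of the stack, weight, and rng bookkeeping, the identity this reverse pass relies on is: the block computed accumulator_output = accumulator_input + residual(context), so the input is recovered by recomputing the residual from the unchanged context and subtracting it, while the accumulator cotangent flows both straight through the skip connection and back through the residual branch. A self-contained toy sketch in plain JAX (all names hypothetical):

import jax
import jax.numpy as jnp


def residual_fn(context, w):
    return jnp.tanh(context @ w)


def block_forward(acc, context, w):
    return acc + residual_fn(context, w), context


def block_reverse_and_grad(acc_out, context, acc_ct, context_ct, w):
    # Recompute the residual branch and capture its VJP.
    residual, vjp_fn = jax.vjp(residual_fn, context, w)
    acc_in = acc_out - residual              # reconstruct the accumulator input
    ctx_branch_ct, w_ct = vjp_fn(acc_ct)     # cotangent through the residual
    # acc_ct passes through the skip connection unchanged; the context cotangent
    # is its incoming cotangent plus the residual-branch contribution.
    return (acc_in, context), ((acc_ct, context_ct + ctx_branch_ct), (w_ct,))


acc, context, w = jnp.ones((2, 4)), jnp.ones((2, 4)), jnp.eye(4)
acc_out, _ = block_forward(acc, context, w)
(acc_in, _), _ = block_reverse_and_grad(
    acc_out, context, jnp.ones_like(acc_out), jnp.zeros_like(context), w)
assert jnp.allclose(acc_in, acc)
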
Example 7
    def reverse_and_grad(self,
                         output,
                         ct,
                         weights=(),
                         state=(),
                         new_state=(),
                         rng=None):
        rngs = _split_rngs(rng, len(self.sublayers))

        accumulator_output, *context = output
        context = tuple(context)
        accumulator_output_ct, *context_ct = ct
        context_ct = tuple(context_ct)

        # Forward pass through self._compute_residual. Outputs that will not receive
        # a gradient signal from subsequent layers are moved to aux.
        def call_compute_residual(x, weights):
            state_to_pass = state[0]  # old_state

            # _replace_second_time is currently used exclusively in _RememberInReverse
            # layer to combat numerical instability in Reformer2 when quantizing
            # the mask in SparseFF.
            def _replace_second_time(stt, nstt):
                if (isinstance(stt, tuple) and len(stt) == 2
                        and isinstance(stt[1], dict)
                        and 'running_second_time' in stt[1]):
                    return (nstt[0], {'running_second_time_yes': ()})
                elif isinstance(stt, (tuple, list)):
                    assert isinstance(nstt,
                                      (tuple, list)) and len(nstt) == len(stt)
                    return type(stt)([
                        _replace_second_time(s, ns)
                        for s, ns in zip(stt, nstt)
                    ])
                else:
                    return stt

            state_to_pass = _replace_second_time(state_to_pass, new_state[0])
            res, _ = self._compute_residual.pure_fn(x,
                                                    weights=weights,
                                                    state=state_to_pass,
                                                    rng=rngs[0])
            if not isinstance(res, (tuple, list)):
                return res, None
            else:
                n_differentiable = 1
                if self._attention_layer is not None:
                    n_differentiable = min(len(res),
                                           self._attention_layer.n_in)
                return res[:n_differentiable], res[n_differentiable:]

        stack = context
        inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in)
        outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
            call_compute_residual, inputs, weights[0], has_aux=True)
        if outputs_aux is not None:
            n_differentiable_outputs = len(outputs)
            outputs = outputs + outputs_aux
        stack = cb.outputs_onto_stack(outputs, stack,
                                      self._compute_residual.n_in)

        stack_ct = accumulator_output_ct
        if self._attention_layer is None:
            residual = stack[0] if isinstance(stack, (tuple, list)) else stack
        else:
            inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in)
            (residual, _, attn_inputs_ct,
             attn_weights_ct) = self._forward_and_or_backward(
                 inputs,
                 weights[1],
                 new_state[1],
                 rngs[1],
                 output_grad=accumulator_output_ct,
                 compute_output=True,
                 update_state=False)
            stack_ct = cb.outputs_onto_stack(attn_inputs_ct, stack_ct,
                                             self._attention_layer.n_out)

        compute_residual_ct = cb.inputs_from_stack(
            stack_ct, self._compute_residual.n_out)
        if outputs_aux is not None:
            if not isinstance(compute_residual_ct, (tuple, list)):
                compute_residual_ct = (compute_residual_ct, )
            compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
            assert len(compute_residual_ct) == n_differentiable_outputs
        (compute_residual_inputs_ct, compute_residual_weights_ct
         ) = compute_residual_vjpfun(compute_residual_ct)
        stack_ct = cb.outputs_onto_stack(compute_residual_inputs_ct, stack_ct,
                                         self._compute_residual.n_out)
        if not isinstance(stack_ct, (tuple, list)):
            stack_ct = (stack_ct, )

        def _add(x, y):
            # `None` is for TFNP backend, which uses `None` as the gradient of
            # int/bool instead of an array of dtype `float0`.
            if x is None or x.dtype == jax.float0:
                return y
            if y is None or y.dtype == jax.float0:
                return x
            return x + y

        stack_ct = (accumulator_output_ct, ) + fastmath.nested_map_multiarg(
            _add, context_ct[:len(stack_ct)],
            stack_ct) + context_ct[len(stack_ct):]

        reconstructed_x = accumulator_output - residual
        stack = (reconstructed_x, ) + context
        if self._attention_layer is None:
            weights_ct = (compute_residual_weights_ct, )
        else:
            weights_ct = (compute_residual_weights_ct, attn_weights_ct)
        return stack, (stack_ct, weights_ct)
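
A note on `_add`: JAX represents the cotangent of an integer or boolean input as an array of the special zero-byte dtype float0 (while the TFNP backend uses None), and such values cannot be added to ordinary float arrays, so the sum must simply keep the other operand. A small self-contained illustration of the guard with made-up values, using `jax.dtypes.float0` to name the dtype:

import numpy as np
import jax
import jax.numpy as jnp


def _add(x, y):
    if x is None or x.dtype == jax.dtypes.float0:
        return y
    if y is None or y.dtype == jax.dtypes.float0:
        return x
    return x + y


float_ct = jnp.ones(3)
int_ct = np.zeros(3, dtype=jax.dtypes.float0)   # cotangent of an int/bool input
print(_add(float_ct, int_ct))                   # keeps the float cotangent
print(_add(float_ct, None))                     # TFNP-style None gradient
print(_add(float_ct, 2.0 * float_ct))           # ordinary elementwise sum
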