Example #1
def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts."""
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
    n = fastmath.psum(jnp.array(1.0),
                      'batch')  # number of devices on all hosts
    if not adasum:
        return fastmath.nested_map(lambda g: g / n, gradients_psum)
    # This implements an approximation of the Adasum algorithm from the following
    # paper: https://arxiv.org/pdf/2006.02924.pdf
    # Since implementing halving and averaging half-by-half is tricky, we first
    # average all hosts, so we use the sum as a point of comparison for gradients.
    # So for 2 devices, this algorithm is the same as in the paper, but with more
    # devices it does a different kind of averaging. It still has the property
    # that orthogonal gradients will result in a sum while identical ones will
    # be averaged, as postulated in the paper.
    adasum_nominator = fastmath.nested_map_multiarg(
        lambda g, q: jnp.vdot(g, q),  # pylint: disable=unnecessary-lambda
        gradients,
        gradients_psum)
    grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)
    # If all devices have identical gradients, then the nominator is equal
    # to n * grad_norm; if they're orthogonal, then nominator = grad_norm.
    scaled_grads = fastmath.nested_map_multiarg(
        lambda g, nominator, g_norm: g * (1 - (nominator - g_norm) /
                                          (n * g_norm)), gradients,
        adasum_nominator, grad_norm)
    return fastmath.psum(scaled_grads, 'batch')
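
A quick NumPy check of the property stated in the comments above, simulated in a single process for two "devices" (this is not part of the Trax source): identical per-device gradients come out averaged, while orthogonal ones come out summed.

import numpy as np

def adasum_approx(per_device_grads):
    """Mirrors the scaling above: psum, per-device rescale, psum again."""
    n = len(per_device_grads)
    g_psum = sum(per_device_grads)          # stands in for psum over 'batch'
    scaled = []
    for g in per_device_grads:
        numerator = np.vdot(g, g_psum)
        g_norm = np.vdot(g, g)
        scaled.append(g * (1 - (numerator - g_norm) / (n * g_norm)))
    return sum(scaled)                      # the final psum

g = np.array([1.0, 2.0])
print(adasum_approx([g, g]))                # [1. 2.]  -> averaged
g1, g2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
print(adasum_approx([g1, g2]))              # [1. 1.]  -> summed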
Example #2
    def _assert_all_equal(self, t1, t2, tol=1e-5):
        def eq(x1, x2):
            diff = np.maximum(np.abs(x1 - x2) - tol, 0.0)
            self.assertLessEqual(np.sum(diff),
                                 0.0,
                                 msg=f'\n{x1}\n !=\n{x2}\n diff:\n{x1-x2}')

        fastmath.nested_map_multiarg(eq, t1, t2)
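
Since every snippet on this page uses fastmath.nested_map_multiarg, here is a minimal standalone usage sketch (assuming trax is installed; the toy tuples t1 and t2 are made up). The function is applied leaf by leaf across structures that share the same nesting.

import numpy as np
from trax import fastmath

t1 = (np.ones(2), (np.zeros(3), np.full(2, 2.0)))
t2 = (np.full(2, 3.0), (np.ones(3), np.ones(2)))

# Adds the matching leaves of t1 and t2 while preserving the tuple structure.
sums = fastmath.nested_map_multiarg(lambda x, y: x + y, t1, t2)
print(sums[0])      # [4. 4.]
print(sums[1][1])   # [3. 3.]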
Example #3
    def _test_equivalence_to_reference_code(self, model_cls, inp,
                                            input_signature, common_kwargs,
                                            *test_kwargs):
        ref_model = model_cls(use_reference_code=True, **common_kwargs)
        rng = fastmath.random.get_prng(123)
        weights, state = ref_model.init(input_signature, rng)

        ref_all = self._run_forward_and_backward(ref_model, inp, weights,
                                                 state)
        ref_out, ref_state, ref_inp_grad, ref_weights_grad = ref_all

        for kwargs in test_kwargs:
            test_model = model_cls(**common_kwargs, **kwargs)
            state = test_model.init(input_signature, rng)[1]
            test_all = self._run_forward_and_backward(test_model, inp, weights,
                                                      state)
            test_out, test_state, test_inp_grad, test_weights_grad = test_all

            self.assertEqual(jax.tree_structure(ref_out),
                             jax.tree_structure(test_out))
            self.assertEqual(jax.tree_structure(ref_state),
                             jax.tree_structure(test_state))
            self.assertEqual(jax.tree_structure(ref_inp_grad),
                             jax.tree_structure(test_inp_grad))
            self.assertEqual(jax.tree_structure(ref_weights_grad),
                             jax.tree_structure(test_weights_grad))

            check_close = lambda x, y: self.assertAllClose(
                x, y, rtol=1e-3, atol=1e-3)
            fastmath.nested_map_multiarg(check_close, ref_out, test_out)
            fastmath.nested_map_multiarg(check_close, ref_state, test_state)
            fastmath.nested_map_multiarg(check_close, ref_inp_grad,
                                         test_inp_grad)
            fastmath.nested_map_multiarg(check_close, ref_weights_grad,
                                         test_weights_grad)
Example #4
def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts."""
    n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS
    if adasum:
        # This implements a version of the Adasum algorithm from the following
        # paper: https://arxiv.org/pdf/2006.02924.pdf
        lg = max([i for i in range(20) if 2**i <= n])
        for lg_i in range(lg):
            shift = 2**lg_i
            perm = []
            for i in range(n):
                block_i = i % (2 * shift)  # we do blocks of 2*shift size
                if block_i < shift:
                    perm.append((i, i + shift))
                else:
                    perm.append((i, i - shift))
            perm_grad = jax.lax.ppermute(gradients,
                                         perm=perm,
                                         axis_name='batch')
            gradients = fastmath.nested_map_multiarg(_adasum_merge, gradients,
                                                     perm_grad)
    if base.N_WEIGHTS_SHARDS > 1:  # only sum gradients from matching shards
        groups = [[base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))]
                  for d in range(base.N_WEIGHTS_SHARDS)]
        gradients_psum = fastmath.psum(gradients,
                                       'batch',
                                       axis_index_groups=groups)
    else:
        gradients_psum = fastmath.psum(gradients, 'batch')  # sum all gradients
    n = jnp.array(n, dtype=jnp.float32)
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
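
The permutation schedule built inside the adasum branch pairs devices that are 1, then 2, then 4, ... positions apart, in a butterfly pattern. A plain-Python illustration for a hypothetical n = 8 devices (it only replays the loop above and prints the pairs):

n = 8
lg = max(i for i in range(20) if 2 ** i <= n)
for lg_i in range(lg):
    shift = 2 ** lg_i
    perm = []
    for i in range(n):
        block_i = i % (2 * shift)      # blocks of size 2 * shift
        if block_i < shift:
            perm.append((i, i + shift))
        else:
            perm.append((i, i - shift))
    print(f'round {lg_i}: {perm}')
# round 0 pairs (0,1), (1,0), (2,3), ...; round 1 pairs (0,2), (1,3), ...;
# round 2 pairs (0,4), (1,5), ...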
Example #5
def mock_training_step(x, weights, state, rng):
  def compute_mock_loss(weights):
    logits, new_state = model.pure_fn(x, weights, state, rng)
    loss = fastmath.numpy.mean(logits[..., 0])
    return loss, (new_state, logits)
  gradients, (new_state, logits) = fastmath.grad(
      compute_mock_loss, has_aux=True)(weights)
  new_weights = fastmath.nested_map_multiarg(
      lambda w, g: w - 1e-4 * g, weights, gradients)
  return new_weights, new_state, logits
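
The same update pattern as a standalone sketch in plain JAX, so it runs without a Trax model: a made-up linear model stands in for model.pure_fn, and jax.tree_util.tree_map plays the role of fastmath.nested_map_multiarg.

import jax
import jax.numpy as jnp

def toy_training_step(x, weights):
    def compute_loss(weights):
        logits = x @ weights['w'] + weights['b']
        loss = jnp.mean(logits[..., 0])
        return loss, logits               # loss plus auxiliary output
    gradients, logits = jax.grad(compute_loss, has_aux=True)(weights)
    # Elementwise SGD step over the weight tree, as in the snippet above.
    new_weights = jax.tree_util.tree_map(
        lambda w, g: w - 1e-4 * g, weights, gradients)
    return new_weights, logits

weights = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}
new_weights, logits = toy_training_step(jnp.ones((4, 3)), weights)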
Example #6
def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts."""
    n = fastmath.global_device_count()  # number of devices on all hosts
    if adasum:
        # This implements a version of the Adasum algorithm from the following
        # paper: https://arxiv.org/pdf/2006.02924.pdf
        lg = max([i for i in range(20) if 2**i <= n])
        for lg_i in range(lg):
            shift = 2**lg_i
            perm = []
            for i in range(n):
                block_i = i % (2 * shift)  # we do blocks of 2*shift size
                if block_i < shift:
                    perm.append((i, i + shift))
                else:
                    perm.append((i, i - shift))
            perm_grad = jax.lax.ppermute(gradients,
                                         perm=perm,
                                         axis_name='batch')
            gradients = fastmath.nested_map_multiarg(_adasum_merge, gradients,
                                                     perm_grad)
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
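
The helper _adasum_merge is referenced above but not shown. A pairwise merge consistent with the Adasum paper (https://arxiv.org/pdf/2006.02924.pdf) could look like the sketch below; the exact Trax implementation may differ.

import jax.numpy as jnp

def _adasum_merge(g1, g2):
    # Shrink each gradient by its (halved) projection onto the other, so that
    # identical gradients get averaged and orthogonal gradients get summed.
    frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30)
    frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30)
    return (1 - frac1) * g1 + (1 - frac2) * g2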
Example #7
  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       rng=None):
    rngs = _split_rngs(rng, len(self.sublayers))

    accumulator_output, *context = output
    context = tuple(context)
    accumulator_output_ct, *context_ct = ct
    context_ct = tuple(context_ct)

    # Forward pass through self.compute_residual. Outputs that will not receive
    # a gradient signal from subsequent layers are moved to aux.
    def call_compute_residual(x, weights):
      res, _ = self.compute_residual.pure_fn(
          x, weights=weights, state=state[0], rng=rngs[0])
      if not isinstance(res, (tuple, list)):
        return res, None
      else:
        n_differentiable = 1
        if self.attention_layer is not None:
          n_differentiable = min(len(res), self.attention_layer.n_in)
        return res[:n_differentiable], res[n_differentiable:]

    stack = context
    inputs = cb.inputs_from_stack(stack, self.compute_residual.n_in)
    outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
        call_compute_residual, inputs, weights[0], has_aux=True)
    if outputs_aux is not None:
      n_differentiable_outputs = len(outputs)
      outputs = outputs + outputs_aux
    stack = cb.outputs_onto_stack(outputs, stack, self.compute_residual.n_in)

    stack_ct = accumulator_output_ct
    if self.attention_layer is None:
      residual = stack[0] if isinstance(stack, (tuple, list)) else stack
    else:
      inputs = cb.inputs_from_stack(stack, self.attention_layer.n_in)
      (residual, _, attn_inputs_ct, attn_weights_ct
      ) = self._forward_and_or_backward(
          inputs, weights[1], new_state[1], rngs[1],
          output_grad=accumulator_output_ct,
          compute_output=True, update_state=False)
      stack_ct = cb.outputs_onto_stack(
          attn_inputs_ct, stack_ct, self.attention_layer.n_out)

    compute_residual_ct = cb.inputs_from_stack(
        stack_ct, self.compute_residual.n_out)
    if outputs_aux is not None:
      if not isinstance(compute_residual_ct, (tuple, list)):
        compute_residual_ct = (compute_residual_ct,)
      compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
      assert len(compute_residual_ct) == n_differentiable_outputs
    (compute_residual_inputs_ct, compute_residual_weights_ct
    ) = compute_residual_vjpfun(compute_residual_ct)
    stack_ct = cb.outputs_onto_stack(
        compute_residual_inputs_ct, stack_ct, self.compute_residual.n_out)
    if not isinstance(stack_ct, (tuple, list)):
      stack_ct = (stack_ct,)
    stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
        lambda x, y: x+y, context_ct[:len(stack_ct)], stack_ct
        ) + context_ct[len(stack_ct):]

    reconstructed_x = accumulator_output - residual
    stack = (reconstructed_x,) + context
    if self.attention_layer is None:
      weights_ct = (compute_residual_weights_ct,)
    else:
      weights_ct = (compute_residual_weights_ct, attn_weights_ct)
    return stack, (stack_ct, weights_ct)
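
The key trick in reverse_and_grad is that the residual is recomputed from context, which is still on the stack, so the accumulator input never has to be stored. A tiny NumPy sketch of that identity (toy f, made-up shapes):

import numpy as np

def f(context):                    # stands in for compute_residual (+ attention)
    return np.tanh(context)

x = np.random.randn(4)
context = np.random.randn(4)
accumulator_output = x + f(context)                 # forward pass
reconstructed_x = accumulator_output - f(context)   # recomputed in reverse
assert np.allclose(reconstructed_x, x)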
Example #8
    def reverse_and_grad(self,
                         output,
                         ct,
                         weights=(),
                         state=(),
                         new_state=(),
                         rng=None):
        rngs = _split_rngs(rng, len(self.sublayers))

        accumulator_output, *context = output
        context = tuple(context)
        accumulator_output_ct, *context_ct = ct
        context_ct = tuple(context_ct)

        # Forward pass through self._compute_residual. Outputs that will not receive
        # a gradient signal from subsequent layers are moved to aux.
        def call_compute_residual(x, weights):
            state_to_pass = state[0]  # old_state

            # _replace_second_time is currently used exclusively in _RememberInReverse
            # layer to combat numerical instability in Reformer2 when quantizing
            # the mask in SparseFF.
            def _replace_second_time(stt, nstt):
                if (isinstance(stt, tuple) and len(stt) == 2
                        and isinstance(stt[1], dict)
                        and 'running_second_time' in stt[1]):
                    return (nstt[0], {'running_second_time_yes': ()})
                elif isinstance(stt, (tuple, list)):
                    assert isinstance(nstt,
                                      (tuple, list)) and len(nstt) == len(stt)
                    return type(stt)([
                        _replace_second_time(s, ns)
                        for s, ns in zip(stt, nstt)
                    ])
                else:
                    return stt

            state_to_pass = _replace_second_time(state_to_pass, new_state[0])
            res, _ = self._compute_residual.pure_fn(x,
                                                    weights=weights,
                                                    state=state_to_pass,
                                                    rng=rngs[0])
            if not isinstance(res, (tuple, list)):
                return res, None
            else:
                n_differentiable = 1
                if self._attention_layer is not None:
                    n_differentiable = min(len(res),
                                           self._attention_layer.n_in)
                return res[:n_differentiable], res[n_differentiable:]

        stack = context
        inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in)
        outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
            call_compute_residual, inputs, weights[0], has_aux=True)
        if outputs_aux is not None:
            n_differentiable_outputs = len(outputs)
            outputs = outputs + outputs_aux
        stack = cb.outputs_onto_stack(outputs, stack,
                                      self._compute_residual.n_in)

        stack_ct = accumulator_output_ct
        if self._attention_layer is None:
            residual = stack[0] if isinstance(stack, (tuple, list)) else stack
        else:
            inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in)
            (residual, _, attn_inputs_ct,
             attn_weights_ct) = self._forward_and_or_backward(
                 inputs,
                 weights[1],
                 new_state[1],
                 rngs[1],
                 output_grad=accumulator_output_ct,
                 compute_output=True,
                 update_state=False)
            stack_ct = cb.outputs_onto_stack(attn_inputs_ct, stack_ct,
                                             self._attention_layer.n_out)

        compute_residual_ct = cb.inputs_from_stack(
            stack_ct, self._compute_residual.n_out)
        if outputs_aux is not None:
            if not isinstance(compute_residual_ct, (tuple, list)):
                compute_residual_ct = (compute_residual_ct, )
            compute_residual_ct = (
                compute_residual_ct[:n_differentiable_outputs])
            assert len(compute_residual_ct) == n_differentiable_outputs
        (compute_residual_inputs_ct, compute_residual_weights_ct
         ) = compute_residual_vjpfun(compute_residual_ct)
        stack_ct = cb.outputs_onto_stack(compute_residual_inputs_ct, stack_ct,
                                         self._compute_residual.n_out)
        if not isinstance(stack_ct, (tuple, list)):
            stack_ct = (stack_ct, )

        def _add(x, y):
            # `None` is for TFNP backend, which uses `None` as the gradient of
            # int/bool instead of an array of dtype `float0`.
            if x is None or x.dtype == jax.float0:
                return y
            if y is None or y.dtype == jax.float0:
                return x
            return x + y

        stack_ct = (accumulator_output_ct, ) + fastmath.nested_map_multiarg(
            _add, context_ct[:len(stack_ct)],
            stack_ct) + context_ct[len(stack_ct):]

        reconstructed_x = accumulator_output - residual
        stack = (reconstructed_x, ) + context
        if self._attention_layer is None:
            weights_ct = (compute_residual_weights_ct, )
        else:
            weights_ct = (compute_residual_weights_ct, attn_weights_ct)
        return stack, (stack_ct, weights_ct)