Example #1
from trax import fastmath
from trax.fastmath import numpy as jnp


def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts."""
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
    n = fastmath.psum(jnp.array(1.0), 'batch')  # number of devices on all hosts
    if not adasum:
        return fastmath.nested_map(lambda g: g / n, gradients_psum)
    # This implements an approximation of the Adasum algorithm from the
    # following paper: https://arxiv.org/pdf/2006.02924.pdf
    # Since implementing pairwise halving-and-averaging is tricky, we combine
    # across all hosts at once and use the all-device sum as the point of
    # comparison for each gradient. For 2 devices this algorithm matches the
    # paper, but with more devices it does a different kind of averaging. It
    # keeps the key property that orthogonal gradients are summed while
    # identical ones are averaged, as postulated in the paper.
    adasum_numerator = fastmath.nested_map_multiarg(
        lambda g, q: jnp.vdot(g, q),  # pylint: disable=unnecessary-lambda
        gradients, gradients_psum)
    grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)
    # If all devices have identical gradients, the numerator equals
    # n * grad_norm; if they are orthogonal, the numerator equals grad_norm.
    scaled_grads = fastmath.nested_map_multiarg(
        lambda g, numerator, g_norm: g * (1 - (numerator - g_norm) /
                                          (n * g_norm)),
        gradients, adasum_numerator, grad_norm)
    return fastmath.psum(scaled_grads, 'batch')
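
The identical-versus-orthogonal property is easy to verify numerically. Below is a minimal plain-numpy sketch (hypothetical helper `simulated_adasum`; a Python list stands in for the devices and the built-in `sum` stands in for `fastmath.psum`) that reproduces the scaling above:

import numpy as np

def simulated_adasum(per_device_grads):
    """Simulates the scaled psum above; a list stands in for the devices."""
    n = len(per_device_grads)
    g_sum = sum(per_device_grads)  # stands in for fastmath.psum(gradients)
    scaled = [g * (1 - (np.vdot(g, g_sum) - np.vdot(g, g)) /
                   (n * np.vdot(g, g)))
              for g in per_device_grads]
    return sum(scaled)  # stands in for the final psum

g = np.array([1.0, 2.0])
print(simulated_adasum([g, g]))  # identical gradients are averaged: [1. 2.]
print(simulated_adasum([np.array([1.0, 0.0]),
                        np.array([0.0, 1.0])]))  # orthogonal: summed, [1. 1.]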
Example #2
def _l2_norm(self, flat_list):
  """Returns the aggregate L2 norm of a list of tensors."""
  if fastmath.is_backend(fastmath.Backend.JAX):
    norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list))
  else:  # TODO(lukaszkaiser): add vdot to TF-numpy
    norm = jnp.sqrt(sum(jnp.sum(x * x) for x in flat_list))
  return norm
Example #3
from trax import fastmath
from trax.fastmath import numpy as jnp


def l2_norm(tree):
    """Returns an L2 norm computed over all elements of all tensors in `tree`.

    Args:
      tree: Tree-structured collection of tensors, e.g., model weights matching
          the model's layer structure.

    Returns:
      A scalar value computed as if all the tensors in `tree` were combined
      and flattened into a single vector, and then the L2 norm of that vector
      was calculated.
    """
    leaves = fastmath.tree_flatten(tree)
    return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
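
A quick usage sketch (assuming `trax` is installed and the function above is in scope; the nested list of tensors below is a made-up stand-in for model weights):

weights = [jnp.ones((2, 3)), (jnp.zeros(3), 2.0 * jnp.ones(4))]
# Same as flattening everything into one vector and taking its norm:
# sqrt(6 * 1^2 + 3 * 0^2 + 4 * 2^2) = sqrt(22) ~= 4.690
print(l2_norm(weights))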
Example #4
def _l2_norm(self, flat_list):
    """Returns an L2-like norm of all elements of all tensors in `flat_list`.

    Args:
      flat_list: Collection of tensors as a flat list (rather than, e.g., a
          tree).

    Returns:
      A scalar value computed as if all the tensors in `flat_list` were joined
      and flattened into a single vector, and then the L2 norm of that vector
      was calculated.
    """
    if fastmath.is_backend(fastmath.Backend.JAX):
        norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list))
    else:  # TODO(lukaszkaiser): add vdot to TF-numpy
        norm = jnp.sqrt(sum(jnp.sum(x * x) for x in flat_list))
    return norm
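
For real-valued tensors the two branches compute the same quantity, since `jnp.vdot(x, x)` flattens its arguments and reduces to a sum of squares. A plain-numpy check of that equivalence:

import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
assert np.isclose(np.vdot(x, x), np.sum(x * x))  # both equal 55.0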
Example #5
from trax import fastmath
from trax.fastmath import numpy as jnp


def l2_norm(tree):
  """Computes the L2 norm of a pytree of arrays. Useful for weight decay."""
  leaves = fastmath.tree_flatten(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
Example #6
from trax.fastmath import numpy as jnp


def _adasum_merge(g1, g2):
    """Adasum gradient composition, see https://arxiv.org/pdf/2006.02924.pdf."""
    # Project each gradient onto the other; the 1e-30 term guards against
    # division by zero for all-zero gradients.
    frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30)
    frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30)
    # Identical gradients give frac = 1/2 (average); orthogonal ones give
    # frac = 0 (sum).
    return (1 - frac1) * g1 + (1 - frac2) * g2
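
A quick numpy transcription (hypothetical `adasum_merge_np`, mirroring the function above) showing the two limiting cases:

import numpy as np

def adasum_merge_np(g1, g2):  # numpy mirror of _adasum_merge above
    frac1 = np.vdot(g1, g2) / (2 * np.vdot(g1, g1) + 1e-30)
    frac2 = np.vdot(g1, g2) / (2 * np.vdot(g2, g2) + 1e-30)
    return (1 - frac1) * g1 + (1 - frac2) * g2

g = np.array([3.0, 4.0])
print(adasum_merge_np(g, g))  # identical gradients -> their average: [3. 4.]
print(adasum_merge_np(np.array([1.0, 0.0]),
                      np.array([0.0, 1.0])))  # orthogonal -> their sum: [1. 1.]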