Example #1
def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts."""
    n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS
    if adasum:
        # This implements a version of the Adasum algorithm from the following
        # paper: https://arxiv.org/pdf/2006.02924.pdf
        lg = max([i for i in range(20) if 2**i <= n])
        for lg_i in range(lg):
            shift = 2**lg_i
            perm = []
            for i in range(n):
                block_i = i % (2 * shift)  # we do blocks of 2*shift size
                if block_i < shift:
                    perm.append((i, i + shift))
                else:
                    perm.append((i, i - shift))
            perm_grad = jax.lax.ppermute(gradients,
                                         perm=perm,
                                         axis_name='batch')
            gradients = fastmath.nested_map_multiarg(_adasum_merge, gradients,
                                                     perm_grad)
    if base.N_WEIGHTS_SHARDS > 1:  # only sum gradients from matching shards
        groups = [[base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))]
                  for d in range(base.N_WEIGHTS_SHARDS)]
        gradients_psum = fastmath.psum(gradients,
                                       'batch',
                                       axis_index_groups=groups)
    else:
        gradients_psum = fastmath.psum(gradients, 'batch')  # sum all gradients
    n = jnp.array(n, dtype=jnp.float32)
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
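
The inner loop builds a butterfly-style pairing: at level lg_i, device i exchanges gradients with the device `shift` positions away inside its block of size `2 * shift`, and `_adasum_merge` (not part of this excerpt) combines each gradient with its permuted counterpart. Below is a minimal standalone sketch, plain Python with no JAX required, that reproduces the (source, destination) pairs passed to `jax.lax.ppermute` for a hypothetical n = 8 devices:

def butterfly_pairs(n):
    """Reproduces the permutation pairs built in the Adasum loop above."""
    lg = max(i for i in range(20) if 2**i <= n)
    levels = []
    for lg_i in range(lg):
        shift = 2**lg_i
        perm = []
        for i in range(n):
            block_i = i % (2 * shift)  # blocks of 2*shift devices
            if block_i < shift:
                perm.append((i, i + shift))
            else:
                perm.append((i, i - shift))
        levels.append(perm)
    return levels

# For n = 8 this prints three levels: neighbors, then pairs two apart, then four apart.
for level in butterfly_pairs(8):
    print(level)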
Example #2
def _n_weights_per_core(weights):  # pylint: disable=invalid-name
    """Calculates the number of weights per core.

  In multi-device settings, gradients and losses are averaged over all devices.
  When loss is weighted and the number of weights can differ by device, e.g.,
  when the weights represent the number of tokens in a batch of sentences (which
  can differ from device to device), we want to make sure each token on each
  device is weighted in the same way. This function ensures that by reporting
  the number of weights per core in multi-core settings (and simply
  np.sum(weights) in a single-core setting).

  Args:
    weights: tensor with arbitrary shape

  Returns:
    a scalar equal to np.sum(weights) in 1-machine settings and to the sum
    of weights over all cores divided by the number of cores otherwise
  """
    weights_sum = jnp.sum(weights)
    if fastmath.device_count() < 2:
        return weights_sum
    else:
        try:
            n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
            return fastmath.psum(weights_sum, 'batch') / n_devices_total
        except (NameError,
                ValueError):  # running outside of pmap, e.g., on init
            return weights_sum  # fall back to the sum
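
The try/except matters because a psum over a named axis is only defined inside a pmap (or similar) that binds that axis; outside of it the call fails, which is why the code falls back to the plain sum. A rough standalone sketch in plain JAX (hypothetical `n_weights_per_core` using jax.lax directly instead of trax.fastmath, single host) showing both paths:

import jax
import jax.numpy as jnp

def n_weights_per_core(weights):
    weights_sum = jnp.sum(weights)
    try:
        n_devices_total = jax.lax.psum(jnp.array(1.0), 'batch')
        return jax.lax.psum(weights_sum, 'batch') / n_devices_total
    except (NameError, ValueError):  # 'batch' is not bound, i.e., outside pmap
        return weights_sum

n = jax.local_device_count()
per_device_weights = jnp.ones((n, 4))  # e.g., token counts for each device's batch
# Inside pmap, 'batch' is bound and the psum branch runs on every device.
print(jax.pmap(n_weights_per_core, axis_name='batch')(per_device_weights))
# Outside pmap, the same call falls back to a plain sum.
print(n_weights_per_core(jnp.ones((4,))))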
Example #3
def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts."""
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
    n = fastmath.psum(jnp.array(1.0),
                      'batch')  # number of devices on all hosts
    if not adasum:
        return fastmath.nested_map(lambda g: g / n, gradients_psum)
    # This implements an approximation of the Adasum algorithm from the following
    # paper: https://arxiv.org/pdf/2006.02924.pdf
    # Since implementing halving and averaging half-by-half is tricky, we first
    # average all hosts, so we use the sum as a point of comparison for gradients.
    # So for 2 devices, this algorithm is the same as in the paper, but with more
    # devices it does a different kind of averaging. It still has the property
    # that orthogonal gradients will result in a sum while identical ones will
    # be averaged, as postulated in the paper.
    adasum_nominator = fastmath.nested_map_multiarg(
        lambda g, q: jnp.vdot(g, q),  # pylint: disable=unnecessary-lambda
        gradients,
        gradients_psum)
    grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)
    # If all devices have identical gradients, then the nominator is equal
    # to n * grad_norm; if they're orthogonal, then nominator = grad_norm.
    scaled_grads = fastmath.nested_map_multiarg(
        lambda g, nominator, g_norm: g * (1 - (nominator - g_norm) /
                                          (n * g_norm)), gradients,
        adasum_nominator, grad_norm)
    return fastmath.psum(scaled_grads, 'batch')
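
The scaling factor 1 - (nominator - g_norm) / (n * g_norm) is what gives the property described in the comments. A small NumPy check (no pmap, just two hypothetical per-device gradients, so n = 2) showing that identical gradients come out averaged while orthogonal ones come out summed:

import numpy as np

def adasum_combine(per_device_grads):
    """Applies the scaling above on the host and sums, for illustration only."""
    n = len(per_device_grads)
    g_sum = sum(per_device_grads)  # stands in for fastmath.psum(gradients, 'batch')
    combined = np.zeros_like(g_sum)
    for g in per_device_grads:
        nominator = np.vdot(g, g_sum)
        g_norm = np.vdot(g, g)
        combined += g * (1 - (nominator - g_norm) / (n * g_norm))
    return combined

g = np.array([1.0, 0.0])
print(adasum_combine([g, g]))                     # identical  -> [1. 0.], the average
print(adasum_combine([g, np.array([0.0, 1.0])]))  # orthogonal -> [1. 1.], the sum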
Example #4
def _average_multidevice_gradients(gradients):
  """Averages gradients over all the devices across different hosts."""
  # Sum gradients over all devices across all hosts.
  gradients = fastmath.psum(gradients, 'batch')
  # Calculate the total number of devices.
  # Note: the usual n_devices is only the number of devices at this host,
  # here we are calculating the number of all devices across all hosts.
  n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
  # Average across hosts.
  return fastmath.nested_map(lambda g: g / n_devices_total, gradients)
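
All of these helpers are meant to be called from inside a pmap with axis_name='batch', where the psum actually has devices to reduce over. A minimal single-host sketch in plain JAX (hypothetical `average_gradients` and gradient pytree) of the same sum-then-divide pattern:

import jax
import jax.numpy as jnp

def average_gradients(grads):
  grads = jax.lax.psum(grads, 'batch')                     # sum over all devices
  n_devices_total = jax.lax.psum(jnp.array(1.0), 'batch')  # count the devices
  return jax.tree_util.tree_map(lambda g: g / n_devices_total, grads)

n = jax.local_device_count()
# A different gradient on every device; after averaging, each device holds the mean.
per_device_grads = {'w': jnp.arange(float(n)).reshape((n, 1))}
print(jax.pmap(average_gradients, axis_name='batch')(per_device_grads))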
Example #5
  def _multi_device_update_fn(
      weights_and_slots, step, opt_params, batch, state, rng):
    # We assume all tensors have the first dimension = n_devices.
    weights, slots = weights_and_slots
    (loss, state), gradients = forward_and_backward_fn(
        batch, weights, state, rng)

    # Gradients now need to be summed over all the devices across different host
    # machines; n_devices is only the number of devices on *this* host machine.
    gradients = fastmath.psum(gradients, 'batch')
    n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
    # Average across hosts.
    gradients = jax.tree_util.tree_map(lambda g: g / n_devices_total, gradients)

    weights, slots, stats = optimizer.tree_update(
        step, gradients, weights, slots, opt_params)
    stats['loss'] = loss
    return (weights, slots), state, stats
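
A function with this signature only does something useful once it is mapped over the device axis and handed replicated weights plus per-device batches. A rough sketch in plain JAX (hypothetical stand-ins for the model and optimizer, single host) of how such an update step is typically wrapped:

import jax
import jax.numpy as jnp

def forward_and_backward(batch, weights):
  # Toy stand-in for forward_and_backward_fn: squared error of a linear model.
  return jax.value_and_grad(lambda w: jnp.sum((batch @ w) ** 2))(weights)

def update_fn(weights, batch):
  loss, grads = forward_and_backward(batch, weights)
  # Average gradients over every device on every host, as in the examples above.
  grads = jax.lax.psum(grads, 'batch') / jax.lax.psum(jnp.array(1.0), 'batch')
  return weights - 0.01 * grads, loss  # toy SGD step in place of optimizer.tree_update

n = jax.local_device_count()
weights = jnp.zeros((n, 3, 1)) + 1.0                           # replicated across devices
batches = jax.random.normal(jax.random.PRNGKey(0), (n, 8, 3))  # one batch slice per device
new_weights, losses = jax.pmap(update_fn, axis_name='batch')(weights, batches)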
Example #6
def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = n_devices.
  weights, slots = weights_and_slots
  rng, subrng = jax_random.split(rng)
  grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, rng)
  # We do a psum(1.0) here instead of `n_devices` because `n_devices` is just
  # the number of devices on this host machine, whereas psum reduces over all
  # devices on all hosts (e.g., a TPU pod), and we need to average over all
  # of them.
  grads = jax.tree_util.tree_map(
      lambda g: (  # pylint: disable=g-long-lambda
          fastmath.psum(g, 'batch') / fastmath.psum(np.array(1.0), 'batch')),
      grads)
  new_weights, new_slots, stats = optimizer.tree_update(
      i, grads, weights, slots, opt_params)
  return (new_weights, new_slots), stats, state, subrng
Example #7
def _make_weights_and_state_same_across_hosts(weights_and_state):
  """Makes train and eval model's weights and state the same across hosts."""

  # We assume that they have already been replicated, i.e., the leading axis is
  # self._n_devices.

  # This is the total number of devices across all hosts.
  n_devices_total = fastmath.psum(jnp.array(1.0), 'devices')

  # This sums up the weights and state across all devices.
  # NOTE: There will not be any leading axis remaining because we psum
  # over it.
  weights_and_state = fastmath.psum(weights_and_state, 'devices')

  # We finally take the average over all devices.
  weights_and_state = jax.tree_util.tree_map(
      lambda ws: ws / n_devices_total, weights_and_state)

  return weights_and_state
Example #8
def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts."""
    n = fastmath.global_device_count()
    if adasum:
        # This implements a version of the Adasum algorithm from the following
        # paper: https://arxiv.org/pdf/2006.02924.pdf
        lg = max([i for i in range(20) if 2**i <= n])
        for lg_i in range(lg):
            shift = 2**lg_i
            perm = []
            for i in range(n):
                block_i = i % (2 * shift)  # we do blocks of 2*shift size
                if block_i < shift:
                    perm.append((i, i + shift))
                else:
                    perm.append((i, i - shift))
            perm_grad = jax.lax.ppermute(gradients,
                                         perm=perm,
                                         axis_name='batch')
            gradients = fastmath.nested_map_multiarg(_adasum_merge, gradients,
                                                     perm_grad)
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
    n = jnp.array(n, dtype=jnp.float32)
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
Example #9
def _get_conditionally_synced_rng(self):
    if self._sync and fastmath.global_device_count() > 1:
        return fastmath.psum(self.rng, 'batch')
    else:
        return self.rng
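
Summing the per-device rngs with psum is a cheap way to keep them synchronized: every device receives the same combined value, so subsequent random choices agree across devices. A small sketch in plain JAX (assuming old-style uint32 PRNG keys, as trax uses) showing that the result is identical on every device:

import jax

def synced_rng(rng):
  # psum gives every device the same summed key, so all devices stay in step.
  return jax.lax.psum(rng, 'batch')

n = jax.local_device_count()
per_device_rngs = jax.random.split(jax.random.PRNGKey(0), n)     # different key per device
print(jax.pmap(synced_rng, axis_name='batch')(per_device_rngs))  # identical rows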
Example #10
  def forward(self, xs):
    self._validate_forward_inputs(xs)
    (step, layers_state) = self.state
    # Get N+1 rngs, N for running layers and one extra.
    rngs = _split_rngs(self.rng, self._n_layers + 1)
    rng0, rngs = rngs[0], rngs[1:]
    if not self.sublayers:  # No-op: leave args unchanged.
      self.state = (step + 1, layers_state)
      return xs

    # Prepare the stack and do some safety checks as in the parent class.
    stack = xs
    new_state = []
    n_layers = self._n_layers
    weights = self.weights
    if n_layers != 1 and len(weights) != n_layers:
      raise ValueError('number of weights ({}) not equal to number of layers '
                       '({})'.format(len(weights), n_layers))
    if n_layers != 1 and len(layers_state) != n_layers:
      raise ValueError('length of state ({}) not equal to number of layers '
                       '({})'.format(len(layers_state), n_layers))

    # TODO(chowdhery): try different strategies, also try running not all
    # layers backwards by using fastmath.stop_gradient where needed.

    # Calculate how many layers to run forward.
    if self._mode == 'train':
      # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
      w_steps = float(self._skipping_warmup_steps)
      f_step = step.astype(jnp.float32)
      warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
      # low is the minimum number of layers to *not* skip, from n_layers to 0
      low = warmup * float(n_layers)
      # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
      # because (high - n_layers) / high is the probability we're not skipping
      # (after warmup); so high - n_layers = high - high * skip_fraction
      high = float(n_layers) / self._skip_fraction
      # We want the same rng0 on all cores.
      if fastmath.device_count() > 1:
        rng0 = fastmath.psum(rng0, 'batch')
      n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
    else:
      n_forward_layers = float(n_layers)
    # Run layers skipping after a certain number.
    cur_layer_idx = 0.0
    for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
      inputs = _inputs_from_stack(layer, stack)
      def CondF(t):
        return layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop
      def PassF(t):
        return t[0], t[2]
      outputs, s = fastmath.cond(
          fastmath.lt(cur_layer_idx, n_forward_layers),
          CondF,
          PassF,
          (inputs, p, s, rng)
      )
      stack = _outputs_onto_stack(layer, outputs, stack)
      new_state.append(s)
      cur_layer_idx += 1.0
    self.state = (step + 1, new_state)
    return stack
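
The warmup and skip-fraction arithmetic above is easiest to see with concrete numbers. A small standalone sketch (hypothetical settings: 12 layers, skip_fraction = 0.25, 1000 warmup steps) that computes the (low, high) bounds used when sampling n_forward_layers:

def skip_bounds(step, n_layers=12, skip_fraction=0.25, skipping_warmup_steps=1000):
  """Returns the (low, high) range that n_forward_layers is drawn from above."""
  warmup = max(0.0, (skipping_warmup_steps - step) / skipping_warmup_steps)
  low = warmup * float(n_layers)          # minimum number of layers that must run
  high = float(n_layers) / skip_fraction  # so that P(no layer skipped) = 1 - skip_fraction
  return low, high

print(skip_bounds(0))     # (12.0, 48.0): during warmup every layer always runs
print(skip_bounds(2000))  # (0.0, 48.0): after warmup, skipping happens 25% of the time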
Example #11
def _average_multidevice_gradients(gradients):
    """Averages gradients over all the devices across different hosts."""
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
    n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
    return fastmath.nested_map(lambda g: g / n_devices_total, gradients_psum)