Ejemplo n.º 1
0
  def forward(self, inputs):
    """Executes this layer as part of a forward pass through the model.

    The input stack is split into scanned inputs and its last `self._n_carry`
    elements (the initial carry). `self.sublayer` is then applied step by step
    along `self._axis` via `_scan`. Each step receives the carry, the sublayer
    state and a step counter, which is used to fold a per-step rng.

    In 'predict' mode with a carry, the final carry is stored in the second
    element of the layer state. On a later call it is used instead of the
    carry taken from `inputs`, if one is present.

    Args:
      inputs: Tuple or list forming the input stack.

    Returns:
      The scanned outputs with the final carry appended, or just the outputs
      when there is no carry.
    """
    weights = self.weights[0]
    if isinstance(inputs, list):
      inputs = tuple(inputs)  # so that inputs structure matches outputs
    n_carry = self._n_carry

    def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
      carry, state, i = carry_and_state
      x_and_carry = x + carry if n_carry > 0 else x
      # A distinct rng per scan step, derived from the step counter `i`.
      rng = fastmath.random.fold_in(self.rng, i)
      res, new_state = self.sublayer.pure_fn(
          x_and_carry, weights, state, rng, use_cache=True)
      if n_carry > 0:
        return (res[:-n_carry], (res[-n_carry:], new_state, i+1))
      else:
        return (res, ([], new_state, i+1))

    if n_carry > 0:
      xs = inputs[:-n_carry]  # Split input stack into inputs and carry.
      xs_carry = inputs[-n_carry:]
      # This method stores `()` as the saved carry unless it runs in predict
      # mode with a carry, in which case it stores the (tuple) final carry.
      # Compare by value rather than `is not ()`. An identity test against a
      # literal relies on interning of the empty tuple and is flagged as a
      # literal comparison.
      if self._mode == 'predict' and self._state[1] != ():
        xs_carry = self._state[1]
      init = (xs_carry, self.state[0], jnp.array(0, dtype=jnp.int32))
    else:
      xs, init = inputs, ([], self.state[0], jnp.array(0, dtype=jnp.int32))
    ys, (carry, new_state, _) = _scan(scannable_fn, xs, init,
                                      axis=self._axis, remat=self._remat)
    res = ys + carry if n_carry > 0 else ys
    state_carry = carry if self._mode == 'predict' and n_carry > 0 else ()
    self.state = (new_state, state_carry)
    return res  # Put outputs and carry back on stack.
Ejemplo n.º 2
0
    def forward(self, inputs):
        """Runs the sublayer as a scan along `self._axis` (forward pass).

        The trailing `self._n_carry` elements of the input stack are the
        initial carry; the remaining elements are scanned over. The carry, the
        sublayer state and a step index are threaded through every scan step.

        Args:
            inputs: Input stack, given as a tuple or a list.

        Returns:
            The scanned outputs with the final carry appended, or just the
            outputs when there is no carry.
        """
        layer_weights = self.weights[0]
        if isinstance(inputs, list):
            # Use a tuple so the input structure matches the outputs.
            inputs = tuple(inputs)
        num_carry = self._n_carry
        has_carry = num_carry > 0

        def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
            step_carry, step_state, step_idx = carry_and_state
            step_input = x + step_carry if has_carry else x
            step_rng = fastmath.random.fold_in(self.rng, step_idx)
            out, updated_state = self.sublayer.pure_fn(step_input,
                                                       layer_weights,
                                                       step_state,
                                                       step_rng,
                                                       use_cache=True)
            if not has_carry:
                return (out, ([], updated_state, step_idx + 1))
            return (out[:-num_carry],
                    (out[-num_carry:], updated_state, step_idx + 1))

        if has_carry:
            # Split the input stack into scanned inputs and initial carry.
            scanned = inputs[:-num_carry]
            initial = (inputs[-num_carry:], self.state[0],
                       jnp.array(0, dtype=jnp.int32))
        else:
            scanned = inputs
            initial = ([], self.state[0], jnp.array(0, dtype=jnp.int32))
        outputs, (final_carry, final_state, _) = _scan(scannable_fn,
                                                       scanned,
                                                       initial,
                                                       axis=self._axis,
                                                       remat=self._remat)
        # Put outputs and carry back on the stack.
        stacked = outputs + final_carry if has_carry else outputs
        self.state = (final_state, )
        return stacked
Ejemplo n.º 3
0
    def train_step(self, batch):
        """Runs one training step and updates `self._opt_state`.

        The current learning rate is written into the optimizer parameters.
        The jitted update function is then applied, and its results are
        stored: new weights/slots, model state and rngs. Training stats are
        logged when due, and the step counter is incremented.

        Args:
            batch: Training batch passed through to the update function.
        """
        # Refresh the learning rate inside the optimizer parameters.
        current_lr = np.array(self.learning_rate)
        lr_update = self._for_n_devices({'learning_rate': current_lr})
        opt_state = self._opt_state
        opt_state.opt_params.update(lr_update)

        # Run the update.
        weights, slots, opt_params = opt_state
        update_result = self._jit_update_fn((weights, slots), self._step,
                                            opt_params, batch,
                                            self._model_state, self._rngs)
        (new_weights, new_slots), stat, self._model_state, self._rngs = (
            update_result)
        self._opt_state = opt_state._replace(weights=new_weights,
                                             slots=new_slots)
        if self._should_log_now():
            for stat_name, stat_value in stat.items():
                # TODO(afrozm): value is a scalar, but sometimes JAX is crashing here
                # with a device put array error complaining that it should be an array.
                # On multiple devices, take the mean.
                mean_value = np.mean(np.array(stat_value))
                self._train_sw.scalar('training/' + stat_name,
                                      mean_value,
                                      step=self._step)
        self._step += 1
Ejemplo n.º 4
0
 def _UpdateRow(x):
   """Takes the slice of `row_ed` that starts where the encoder part ends.

   `x` is a triple `(row_ed, row_e, _)`. The start row is the number of
   non-zero entries in `row_e`. From there an (L2, H) slice of `row_ed` is
   returned, with `L2` and `H` read from the enclosing scope. The original
   shape note was (L, H), (L1, H) & (L2, H) for the three elements; TODO
   confirm.
   """
   row_ed, row_e, _ = x
   # Number of non-zero (real) encoder entries = first decoder row index.
   enc_len = jnp.sum(row_e != 0, dtype=jnp.int32)
   index_dtype = enc_len.dtype
   # All start indices and sizes share one dtype (avoids int32/int64 mismatch).
   start = (enc_len, jnp.array(0, dtype=index_dtype))
   size = (jnp.array(L2, dtype=index_dtype), jnp.array(H, dtype=index_dtype))
   return fastmath.dynamic_slice(row_ed, start, size)
Ejemplo n.º 5
0
    def __init__(self,
                 learning_rate=0.01,
                 clip_grad_norm=None,
                 **init_opt_params):
        """Sets initial hyperparameter values for this optimizer.

        Hyperparameters are passed as keyword arguments. Their values may
        change over the course of training, e.g. under a learning rate
        schedule.

        Subclasses that want their hyperparameters configurable through gin
        should override this constructor with explicitly named keyword
        arguments; `momentum.Momentum.__init__` is one such example.

        Args:
          learning_rate: Learning rate for the optimizer; it can be changed
              during training by a learning rate schedule.
          clip_grad_norm: If given, the whole gradient tree is treated as one
              vector and scaled back when needed so that its L2 norm does not
              exceed this value. If None, gradients are not clipped.
          **init_opt_params: Initial values of any additional optimizer
              parameters.
        """
        init_opt_params['learning_rate'] = learning_rate
        self._init_opt_params = dict(
            (param_name, jnp.array(param_value))
            for param_name, param_value in init_opt_params.items())
        self._slots = None
        # Clipping uses the norm of the entire gradient tree, so it is done
        # here for the whole tree rather than inside single-slot updates.
        self._clip_grad_norm = clip_grad_norm
Ejemplo n.º 6
0
  def init_weights_and_state(self, input_signature):
    """Initializes the (trainable) sinusoidal positional encoding vectors.

    Builds a (max_len, d_feature) table with sin values in the even columns
    and cos values in the odd columns. Odd `d_feature` is supported. When
    `self._d_feature` is set, the table has that width, and a Glorot-uniform
    matrix of shape (d_feature, input depth) is added as a second weight.
    In 'predict' mode the state is an int32 scalar zero.

    Args:
      input_signature: :py:class:`ShapeDtype` instance characterizing the input
          this layer should compute on.
    """
    input_depth = input_signature.shape[-1]
    d_feature = self._d_feature if self._d_feature is not None else input_depth
    pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
    position = np.arange(0, self._max_len)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
    pe[:, 0::2] = np.sin(position * div_term)
    # There are ceil(d/2) even columns but only floor(d/2) odd ones, so for an
    # odd d_feature only the first d_feature // 2 frequencies feed the cos part
    # (previously a broadcasting ValueError).
    pe[:, 1::2] = np.cos(position * div_term[:d_feature // 2])
    if self._use_bfloat16:
      pe = pe.astype(jnp.bfloat16)
    w = jnp.array(pe)  # Trainable parameters, initialized above.
    if self._d_feature is not None:
      ff = init.GlorotUniformInitializer()(
          (d_feature, input_depth), self.rng)
      self.weights = w, ff
    else:
      self.weights = w
    if self._mode == 'predict':
      self.state = jnp.zeros((), dtype=jnp.int32)
Ejemplo n.º 7
0
def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts.

    Expected to run where the 'batch' axis name is bound, since every
    reduction below is a psum over that axis.

    Args:
      gradients: Gradient pytree of the current device.
      adasum: If True, combine gradients with an approximation of Adasum
        instead of taking the plain mean.

    Returns:
      A pytree with the same structure as `gradients`.
    """
    gradients_psum = fastmath.psum(gradients, 'batch')  # sum over all devices
    n = fastmath.psum(jnp.array(1.0),
                      'batch')  # number of devices on all hosts
    if not adasum:
        return fastmath.nested_map(lambda g: g / n, gradients_psum)
    # This implements an approximation of the Adasum algorithm from the following
    # paper: https://arxiv.org/pdf/2006.02924.pdf
    # Since implementing halving and averaging half-by-half is tricky, we first
    # average all hosts, so we use the sum as a point of comparison for gradients.
    # So for 2 devices, this algorithm is the same as in the paper, but with more
    # devices it does a different kind of averaging. It still has the property
    # that orthogonal gradients will result in a sum while identical ones will
    # be averaged, as postulated in the paper.
    adasum_numerator = fastmath.nested_map_multiarg(
        lambda g, q: jnp.vdot(g, q),  # pylint: disable=unnecessary-lambda
        gradients,
        gradients_psum)
    grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)

    def _scale(g, numerator, g_norm):
        # If all devices have identical gradients, numerator == n * g_norm and
        # the scale is 1 / n; if they are orthogonal, numerator == g_norm and
        # the scale is 1. A leaf whose local gradient is all zeros has
        # g_norm == 0, so the ratio would be 0 / 0 = NaN and would poison the
        # psum on every device. Use a safe divisor there: g is zero anyway, so
        # the scaled leaf stays zero.
        safe_denominator = jnp.where(g_norm > 0, n * g_norm, 1.0)
        return g * (1 - (numerator - g_norm) / safe_denominator)

    scaled_grads = fastmath.nested_map_multiarg(
        _scale, gradients, adasum_numerator, grad_norm)
    return fastmath.psum(scaled_grads, 'batch')
Ejemplo n.º 8
0
 def init_weights_and_state(self, input_signature):
   """In 'predict' mode, creates a zero cache and a zero position counter.

   The cache has shape (batch_size, 2 * self._total_kv_pooling, d_feature) and
   the input's dtype. Nothing is initialized outside 'predict' mode.
   """
   if self._mode != 'predict':
     return
   (batch_size, _, d_feature), dtype = input_signature.as_tuple()
   cache_shape = (batch_size, 2 * self._total_kv_pooling, d_feature)
   self.state = jnp.zeros(cache_shape, dtype=dtype), jnp.array(0)
Ejemplo n.º 9
0
  def __init__(self, learning_rate=0.01, clip_grad_norm=None,
               **init_opt_params):
    """Sets initial hyperparameter values for this optimizer.

    Initial optimizer parameters are given as keyword arguments; they may be
    changed between training steps, e.g. by learning rate schedules.

    Subclasses exposing hyperparameters to gin should override this
    constructor with explicitly named keyword arguments; see
    `momentum.Momentum.__init__` for an example.

    Args:
      learning_rate: The initial learning rate.
      clip_grad_norm: float or None; the value to which gradients will be
          clipped. It is stored as-is here.
      **init_opt_params: Initial values of any additional optimizer parameters.
    """
    init_opt_params['learning_rate'] = learning_rate
    opt_params = {}
    for param_name, param_value in init_opt_params.items():
      opt_params[param_name] = jnp.array(param_value)
    self._init_opt_params = opt_params
    self._slots = None
    # The whole gradient tree is clipped by its global norm in this class,
    # so the clipping value is not passed on to single-slot updates.
    self._clip_grad_norm = clip_grad_norm
Ejemplo n.º 10
0
def _average_multidevice_gradients(gradients, adasum=False):
    """Averages gradients over all the devices across different hosts.

    The collectives below (ppermute / psum) use the 'batch' axis name, so this
    is expected to run where that axis is bound.

    Args:
      gradients: Gradient pytree of the current device.
      adasum: If True, gradients are first combined pairwise with
        `_adasum_merge` over floor(log2(n)) butterfly rounds before the final
        sum-and-divide.

    Returns:
      A pytree like `gradients`: the sum over devices holding the same weight
      shard, divided by n = global_device_count() // base.N_WEIGHTS_SHARDS.
    """
    # n is the number of devices that hold the same weight shard (see the
    # `groups` construction below).
    n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS
    if adasum:
        # This implements a version of the Adasum algorithm from the following
        # paper: https://arxiv.org/pdf/2006.02924.pdf
        # lg = floor(log2(n)) (capped at 19): number of butterfly rounds.
        lg = max([i for i in range(20) if 2**i <= n])
        for lg_i in range(lg):
            # In round lg_i each device i is paired with the device at distance
            # `shift` inside its block of 2*shift devices.
            # NOTE(review): if n is not a power of two, i + shift can be >= n
            # for devices in the last partial block (e.g. n=3: pair (2, 3)),
            # giving an out-of-range ppermute pair. Confirm n is a power of 2.
            # NOTE(review): indices here range over 0..n-1 only; with
            # N_WEIGHTS_SHARDS > 1 that is a subset of all devices. Confirm
            # this is intended.
            shift = 2**lg_i
            perm = []
            for i in range(n):
                block_i = i % (2 * shift)  # we do blocks of 2*shift size
                if block_i < shift:
                    perm.append((i, i + shift))
                else:
                    perm.append((i, i - shift))
            perm_grad = jax.lax.ppermute(gradients,
                                         perm=perm,
                                         axis_name='batch')
            gradients = fastmath.nested_map_multiarg(_adasum_merge, gradients,
                                                     perm_grad)
    if base.N_WEIGHTS_SHARDS > 1:  # only sum gradients from matching shards
        # Group d holds devices d, S + d, 2S + d, ... (S = N_WEIGHTS_SHARDS).
        groups = [[base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))]
                  for d in range(base.N_WEIGHTS_SHARDS)]
        gradients_psum = fastmath.psum(gradients,
                                       'batch',
                                       axis_index_groups=groups)
    else:
        gradients_psum = fastmath.psum(gradients, 'batch')  # sum all gradients
    # Rebind n as a float32 array for the division below.
    n = jnp.array(n, dtype=jnp.float32)
    return fastmath.nested_map(lambda g: g / n, gradients_psum)
Ejemplo n.º 11
0
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

    Args:
      x: Integer array of category indices.
      n_categories: Size of the new trailing axis.
      dtype: Dtype of the returned array.

    Returns:
      Array of shape `x.shape + (n_categories,)` that is 1 where the trailing
      index equals the corresponding value of `x`, and 0 elsewhere.
    """
    categories = jnp.arange(n_categories)
    if fastmath.is_backend(fastmath.Backend.JAX):
        # Work around a jax broadcasting issue.
        categories = jax.lax.tie_in(x, categories)
    matches = x[..., jnp.newaxis] == categories
    return jnp.array(matches, dtype)
Ejemplo n.º 12
0
def _n_weights_per_core(weights):  # pylint: disable=invalid-name
    """Calculates the number of weights per core.

    In multi-device settings, gradients and losses are averaged over all
    devices. The weights may differ per device, e.g. when they count the tokens
    of the sentences in each device's batch. In that case every token on every
    device should carry the same weight. This function supports that by
    returning the average number of weights per core in multi-core settings,
    and simply the sum of `weights` on a single core.

    Args:
      weights: tensor with arbitrary shape

    Returns:
      A scalar: the sum of `weights` with fewer than 2 devices (or when the
      'batch' psum fails, e.g. outside of pmap). Otherwise, the sum over all
      cores divided by the number of cores.
    """
    total = jnp.sum(weights)
    if fastmath.device_count() >= 2:
        try:
            n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
            return fastmath.psum(total, 'batch') / n_devices_total
        except (NameError, ValueError):
            # Running outside of pmap (e.g. on init): fall back to the sum.
            pass
    return total
Ejemplo n.º 13
0
def _multi_device_put(x, devices=None):
  """Memory efficient multi-device replication / broadcast in JAX.

  JAX's ShardedDeviceArray holds one device buffer per device for use by
  pmap'd computations, which avoids needless inter-device transfers between
  them. The JAX API has no multi-device 'put' that copies a buffer onto N
  devices memory-efficiently, so this helper provides one.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto; defaults to jax.local_devices() when empty or None.
      Should match the list passed to any pmaps ingesting the replicated array.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
  # Anything that is not exactly a DeviceArray (e.g. _FilledConstants, which
  # lack a device_buffer) is converted with jnp.array first.
  if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
    x = jnp.array(x)
  target_devices = devices if devices else jax.local_devices()
  # One reference to x per target device; each is placed on its own device.
  return jax.api.device_put_sharded([x] * len(target_devices), target_devices)
Ejemplo n.º 14
0
 def test_autoregressive_sample_transformerlm(self):
     """Checks autoregressive sampling from a small predict-mode TransformerLM.

     Covers three cases: a single sample with EOS disabled, a batch of two
     with the default eos_id, and sampling that continues a given prefix.
     """
     model = models.TransformerLM(
         10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode='predict')

     # Single sequence, eos_id=-1: expect exactly max_length tokens.
     model.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
     s1 = trainer_lib.autoregressive_sample(
         model, batch_size=1, eos_id=-1, max_length=10)
     self.assertEqual(s1.shape[0], 1)
     self.assertEqual(s1.shape[1], 10)

     # Batch of 2 split across devices; length is at most max_length.
     batch_per_device = 2 // fastmath.device_count()
     model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=jnp.int32))
     s2 = trainer_lib.autoregressive_sample(model, batch_size=2, max_length=10)
     self.assertEqual(s2.shape[0], 2)
     self.assertLess(s2.shape[1], 11)

     # With a prefix, the sample starts with the prefix tokens.
     model.init(shapes.ShapeDtype((1, 1), dtype=jnp.int32))
     prefix = jnp.array([[1, 2, 3]])
     s3 = trainer_lib.autoregressive_sample(
         model, eos_id=-1, max_length=10, batch_size=1, prefix=prefix)
     self.assertEqual(s3.shape[0], 1)
     for position, token in enumerate([1, 2, 3]):
       self.assertEqual(int(s3[0][position]), token)
Ejemplo n.º 15
0
def _fast_inference_init_state(input_signature,
                               buffer_length,
                               predict_mask=None):
    """Returns an initial state for causal attention layer fast inference.

    Args:
      input_signature: Sequence of ShapeDtype-like objects. Element 0 supplies
        the batch size; elements 1 and 2 supply the depth and dtype of the key
        and value buffers.
      buffer_length: Number of positions in each buffer.
      predict_mask: If not None, an all-False boolean mask of shape
        (buffer_length,) is prepended to the state.

    Returns:
      `(k, v, 0)`, or `(mask, k, v, 0)` when `predict_mask` is given. Here k
      and v are zero arrays of shape (batch, buffer_length, depth).
    """
    batch_size = input_signature[0].shape[0]

    def _zero_buffer(signature):
        shape, dtype = signature.as_tuple()
        return jnp.zeros((batch_size, buffer_length, shape[-1]), dtype=dtype)

    keys = _zero_buffer(input_signature[1])
    values = _zero_buffer(input_signature[2])
    if predict_mask is None:
        return (keys, values, jnp.array(0))
    mask_for_predict = jnp.zeros((buffer_length, )) != 0
    return (mask_for_predict, keys, values, jnp.array(0))
Ejemplo n.º 16
0
 def _UpdateRow(x):
     """Writes the decoder row right after the real encoder tokens.

     `x` is `(row_e, row_d, row_mask_e)`. The original notes give their shapes
     as (L1, H), (L2, H) and (L1,). The result is `row_e` followed by zeros of
     `row_d`'s shape, with `row_d` written from row `sum(row_mask_e)` onward.
     """
     row_e, row_d, row_mask_e = x
     # Encoder row followed by a zero block the size of the decoder row.
     padded = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)
     # Index just past the last real encoder token/vector.
     start = jnp.sum(row_mask_e, dtype=jnp.int32)
     col = jnp.array(0, dtype=start.dtype)  # avoid int32/int64 mismatch
     return fastmath.dynamic_update_slice(padded, row_d, (start, col))
Ejemplo n.º 17
0
def _average_multidevice_gradients(gradients):
  """Averages gradients over all the devices across different hosts.

  Both the gradient sum and the device count come from psum over the 'batch'
  axis. The count therefore covers devices on every host, not only this one.
  """
  summed = fastmath.psum(gradients, 'batch')
  total_devices = fastmath.psum(jnp.array(1.0), 'batch')
  return fastmath.nested_map(lambda grad: grad / total_devices, summed)
Ejemplo n.º 18
0
 def single_device_update_fn(
     weights_and_slots, step, opt_params, batch, state, rng):
   """Single-device training update.

   Runs the forward and backward pass, applies the optimizer tree update, and
   records the loss in the returned stats.

   Returns:
     `((weights, slots), state, stats)` after the update.
   """
   step = jnp.array(step, dtype=jnp.int32)  # Needed in TFNP backend.
   old_weights, old_slots = weights_and_slots
   (loss, new_state), grads = forward_and_backward_fn(
       batch, old_weights, state, rng)
   new_weights, new_slots, stats = optimizer.tree_update(
       step, grads, old_weights, old_slots, opt_params)
   stats['loss'] = loss
   return (new_weights, new_slots), new_state, stats
Ejemplo n.º 19
0
    def _fast_inference_init_state(self, input_signature):
        """Returns an initial state for causal attention layer fast inference.

        The state is `(k, v, 0)`. k and v are zero buffers of shape
        (batch, self._max_len, depth); the depth and dtype of k come from
        `input_signature[0]` and those of v from `input_signature[1]`.
        """
        batch_size = input_signature[0].shape[0]
        buffers = []
        for signature in (input_signature[0], input_signature[1]):
            shape, dtype = signature.as_tuple()
            buffers.append(
                jnp.zeros((batch_size, self._max_len, shape[-1]), dtype=dtype))
        k, v = buffers
        return k, v, jnp.array(0)
Ejemplo n.º 20
0
    def init_weights_and_state(self, input_signature):
        """Initializes this layer with no weights and a single zero state.

        Weights are the empty tuple; the state is a dict mapping `self._name`
        to a float scalar 0.

        Args:
          input_signature: `ShapeDtype` instance characterizing the input this
              layer should compute on. Unused.
        """
        del input_signature  # Unused.
        initial_value = jnp.array(0.)
        self.weights = ()
        self.state = {self._name: initial_value}
Ejemplo n.º 21
0
def _fast_inference_init_state(input_signature, buffer_length):
    """Returns an initial state for causal attention layer fast inference.

    Returns:
      `(k, v, mask, 0)`:
        - k, v: zero buffers of shape (batch, buffer_length, depth), with depth
          and dtype from `input_signature[1]` and `input_signature[2]`.
        - mask: a zero array of shape (batch, 1, buffer_length).
        - 0: a zero counter.
    """
    batch_size = input_signature[0].shape[0]

    def _buffer(signature):
        shape, dtype = signature.as_tuple()
        return jnp.zeros((batch_size, buffer_length, shape[-1]), dtype=dtype)

    k, v = _buffer(input_signature[1]), _buffer(input_signature[2])
    mask = jnp.zeros((batch_size, 1, buffer_length))
    return (k, v, mask, jnp.array(0))
Ejemplo n.º 22
0
 def init_weights_and_state(self, input_signature):
   """Initializes sinusoidal positional encodings and, in predict mode, state.

   The weights are a (1, max_len, d_feature) array, where d_feature is the
   input's last dimension. Even feature columns hold sin values and odd ones
   hold cos values; odd d_feature is supported. In 'predict' mode the state is
   an int32 zero vector with one entry per batch element
   (input_signature.shape[0]).

   Args:
     input_signature: ShapeDtype-like object; only its `shape` is read.
   """
   d_feature = input_signature.shape[-1]
   pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
   position = np.arange(0, self._max_len)[:, np.newaxis]
   div_term = np.exp(
       np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
   pe[:, 0::2] = np.sin(position * div_term)
   # There are ceil(d/2) even columns but only floor(d/2) odd ones, so for an
   # odd d_feature only the first d_feature // 2 frequencies feed the cos part
   # (previously a broadcasting ValueError).
   pe[:, 1::2] = np.cos(position * div_term[:d_feature // 2])
   pe = pe[np.newaxis, :, :]  # [1, self._max_len, d_feature]
   self.weights = jnp.array(pe)  # Trainable parameters, initialized above.
   if self._mode == 'predict':
     batch_size = input_signature.shape[0]
     self.state = jnp.zeros((batch_size,), dtype=jnp.int32)
Ejemplo n.º 23
0
def next_symbol(cur_output_tokens, model, d_model):
    """Returns the next symbol for a given sentence.

    The tokens are right-padded with zeros to the smallest power of two
    strictly greater than their count. They are fed to `model` as an
    (inputs, targets) pair with a batch dimension of 1. The argmax of the
    model output at position `len(cur_output_tokens)` is returned.

    Args:
        cur_output_tokens (list): tokenized sentence with EOS and PAD tokens at the end.
        model (trax.layers.combinators.Serial): The transformer model.
        d_model (int): Model depth; the padded length may not exceed
            8 * d_model (assuming 8-bit chars).

    Returns:
        int: tokenized symbol.

    Raises:
        ValueError: If the padded length exceeds 8 * d_model.
    """
    token_length = len(cur_output_tokens)
    padded_length = 2**int(np.ceil(np.log2(token_length + 1)))
    # An explicit check: `assert(cond, msg)` asserts a non-empty tuple, which
    # is always truthy, so such a check would never fire.
    max_length = d_model * 8  # assuming 8bit char
    if padded_length > max_length:
        raise ValueError(
            f'padded length {padded_length} exceeds the maximum of '
            f'{max_length} (8 * d_model)')
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    # Don't replace this 'None'! This is a way of setting the batch dim.
    padded_with_batch = np.array(padded)[None, :]

    output, _ = model((jnp.array(padded_with_batch), jnp.array(padded_with_batch)))
    log_probs = output[0, token_length, :]

    return int(np.argmax(log_probs))
Ejemplo n.º 24
0
  def _fast_inference_init_state(self, input_signature):
    """Returns an initial state for causal attention layer fast inference.

    The key/value buffers hold `self._chunk_len` positions when that is set,
    and `self._max_len` positions otherwise. The depth and dtype of k come
    from `input_signature[0]` and those of v from `input_signature[1]`.
    """
    batch_size = input_signature[0].shape[0]
    if self._chunk_len is None:
      n_tokens = self._max_len
    else:
      n_tokens = self._chunk_len

    def zeros_for_shape(shape_dtype):
      shape, dtype = shape_dtype.as_tuple()
      return jnp.zeros((batch_size, n_tokens, shape[-1]), dtype=dtype)

    k = zeros_for_shape(input_signature[0])
    v = zeros_for_shape(input_signature[1])
    return k, v, jnp.array(0)
Ejemplo n.º 25
0
    def __init__(self,
                 mode=None,
                 learn_epsilon=False,
                 init_epsilon=1e-6,
                 init_learnt_epsilon=1e-4):
        """Configures the epsilon used by this layer.

        If `learn_epsilon` is True, epsilon = init_epsilon + |learnt_value|,
        where learnt_value is initialized to `init_learnt_epsilon`. Otherwise
        epsilon is just `init_epsilon`.

        Args:
          mode: Ignored.
          learn_epsilon: Whether part of epsilon is learnt.
          init_epsilon: Fixed part of epsilon; must be positive.
          init_learnt_epsilon: Initial value of the learnt part; must be
            positive.

        Raises:
          ValueError: If `init_epsilon` or `init_learnt_epsilon` is not
            positive.
        """
        super().__init__()

        del mode

        # NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`.
        self._learn_epsilon = learn_epsilon

        # Explicit checks instead of `assert`, which is stripped under -O.
        # `not x > 0` also rejects NaN, like the assertions did.
        if not init_epsilon > 0:
            raise ValueError(
                f'init_epsilon must be positive; got {init_epsilon}.')
        if not init_learnt_epsilon > 0:
            raise ValueError(
                f'init_learnt_epsilon must be positive; got '
                f'{init_learnt_epsilon}.')

        self._init_epsilon = jnp.array(init_epsilon, dtype=jnp.float32)
        self._init_learnt_epsilon = jnp.array(init_learnt_epsilon,
                                              dtype=jnp.float32)
Ejemplo n.º 26
0
def predict(num_chars, prefix):
    """Samples up to `num_chars` characters continuing `prefix`.

    Uses the module-level `model` and `gumbel_sample`. Characters are encoded
    with `ord`, and the input is zero-padded to len(prefix) + num_chars.
    Sampling stops early when token 1 (EOS) is produced.

    Returns:
      The prefix followed by the sampled characters.
    """
    tokens = [ord(ch) for ch in prefix]
    chars = list(prefix)
    max_len = len(prefix) + num_chars
    for _ in range(num_chars):
        padded = np.array(tokens + [0] * (max_len - len(tokens)))
        logits = model(padded[None, :])  # Add batch dim.
        sampled = gumbel_sample(logits[0, len(tokens)])
        tokens.append(int(sampled))
        if tokens[-1] == 1:
            break  # EOS
        chars.append(chr(int(sampled)))
    return "".join(chars)
Ejemplo n.º 27
0
  def init_weights_and_state(self, input_signature):
    """Initializes each sublayer in order on its slice of the data stack.

    The (abstract) stack is walked through the sublayers. Each sublayer is
    initialized on its inputs with `use_cache=True`, and its abstract outputs
    are put back on the stack for the next sublayer. The collected values
    (named as possible cache markers in the original, TODO confirm) become
    `self.weights` and the second element of `self.state`.
    """
    # `stack`, `inputs` and `outputs` are abstract (shapes and dtypes), while
    # the weights and states gathered here are actual values.
    sublayer_weights, sublayer_states = [], []
    stack = input_signature
    for sublayer in self.sublayers:
      inputs = _inputs_from_stack(sublayer, stack)
      w_or_marker, s_or_marker = sublayer.init(inputs, use_cache=True)
      sublayer_weights.append(w_or_marker)
      sublayer_states.append(s_or_marker)
      outputs, _ = sublayer._forward_abstract(inputs)  # pylint: disable=protected-access
      stack = _outputs_onto_stack(sublayer, outputs, stack)
    self.state = (jnp.array(0, dtype=jnp.int32), sublayer_states)
    self.weights = sublayer_weights
Ejemplo n.º 28
0
 def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
   """Multi-device version of the single-device update function.

   Computes gradients of `model_and_loss_call`, averages them over all devices
   of all hosts, and applies `optimizer.tree_update`.

   Returns:
     `((new_weights, new_slots), stats, state, subrng)`.
   """
   # We assume all tensors have the first dimension = n_devices.
   weights, slots = weights_and_slots
   rng, subrng = jax_random.split(rng)
   grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
   grads, state = grad_fn(weights, batch, state, rng)

   def _global_mean(g):
     # psum(1.0) counts the devices of all hosts (e.g. a TPU pod), whereas
     # `n_devices` would only count the devices on this host.
     return fastmath.psum(g, 'batch') / fastmath.psum(np.array(1.0), 'batch')

   grads = jax.tree_util.tree_map(_global_mean, grads)
   new_weights, new_slots, stats = optimizer.tree_update(
       i, grads, weights, slots, opt_params)
   return (new_weights, new_slots), stats, state, subrng
Ejemplo n.º 29
0
  def _multi_device_update_fn(
      weights_and_slots, step, opt_params, batch, state, rng):
    """Multi-device training update.

    Gradients are averaged over all devices on all hosts before the optimizer
    tree update; the loss is recorded in the returned stats.

    Returns:
      `((weights, slots), state, stats)` after the update.
    """
    # We assume all tensors have the first dimension = n_devices.
    weights, slots = weights_and_slots
    (loss, state), gradients = forward_and_backward_fn(
        batch, weights, state, rng)

    # Sum over every device on every host. n_devices would only cover *this*
    # host, so the device count is obtained with psum as well.
    summed = fastmath.psum(gradients, 'batch')
    device_total = fastmath.psum(jnp.array(1.0), 'batch')
    mean_gradients = jax.tree_util.tree_map(lambda g: g / device_total, summed)

    weights, slots, stats = optimizer.tree_update(
        step, mean_gradients, weights, slots, opt_params)
    stats['loss'] = loss
    return (weights, slots), state, stats
Ejemplo n.º 30
0
def _make_weights_and_state_same_across_hosts(weights_and_state):
  """Makes train and eval model's weights and state the same across hosts.

  Every leaf is summed over the 'devices' axis and divided by the total
  device count (psum of 1.0). Each leaf is thus replaced by its mean over all
  devices of all hosts. The input is assumed to be already replicated, with a
  leading axis of self._n_devices.
  """
  # Total number of devices across all hosts.
  device_count = fastmath.psum(jnp.array(1.0), 'devices')
  # psum reduces over the per-device axis, so no leading axis remains.
  summed = fastmath.psum(weights_and_state, 'devices')
  # NOTE(review): integer leaves (e.g. counters) become floats after this
  # division; confirm that is acceptable for the state being synchronized.
  return jax.tree_util.tree_map(lambda leaf: leaf / device_count, summed)