Example #1
def set_sharding_annotations_v1(task_p: InstantiableParams,
                                mesh_shape: Sequence[int]) -> None:
  """Sets the sharding annotations in the task config for the given mesh.

  Args:
    task_p: The task parameters to update with sharding annotations.
    mesh_shape: A 3D sequence representing the mesh shape.
  """
  model_p = task_p.model
  asserts.eq(len(mesh_shape), 3)
  device_count = np.prod(mesh_shape)
  device_ids_mesh = np.arange(device_count).reshape(mesh_shape)
  model_p.device_mesh = device_ids_mesh
  replica_axis = 'replica'
  data_axis = 'data'
  mdl_axis = 'mdl'
  mesh_axis_names = [replica_axis, data_axis, mdl_axis]
  task_p.train.inputs_split_mapping = NestedMap(
      map_1d=((replica_axis, data_axis),),
      map_2d=((replica_axis, data_axis), None))
  model_p.mesh_axis_names = mesh_axis_names
  if hasattr(model_p, 'lm'):
    model_p.lm = model_p.lm.cls.set_sharding_params_v1(
        model_p.lm,
        replica_axis=replica_axis,
        data_axis=data_axis,
        mdl_axis=mdl_axis,
        device_ids_mesh=device_ids_mesh,
        mesh_axis_names=mesh_axis_names)
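
A standalone sketch (not part of the library) of how the device-id mesh above is constructed, using a hypothetical 3D mesh_shape of (1, 2, 4):

import numpy as np

# Illustration only: the mesh_shape values are made up.
mesh_shape = (1, 2, 4)  # replica x data x mdl, 8 devices in total
device_ids_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
print(device_ids_mesh.shape)  # (1, 2, 4)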
Example #2
  def fprop(self,
            inputs: JTensor,
            paddings: Optional[JTensor] = None) -> JTensor:
    """Apply batch normalization.

    Args:
      inputs: The inputs JTensor. Shaped [..., dim].
      paddings: The paddings JTensor. Shaped [..., 1].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    p = self.params
    inputs, paddings = self._cast_to_fprop_dtype((inputs, paddings))
    if paddings is None:
      paddings = self._get_default_paddings(inputs)

    asserts.eq(inputs.ndim, paddings.ndim)
    asserts.eq(paddings.shape[-1], 1)

    norm_mean, norm_variance, beta, gamma = self.compute_and_update_moments(
        inputs, paddings)

    inv = gamma / jnp.sqrt(norm_variance + self._epsilon)
    bn_output = (inputs - norm_mean) * inv + beta

    if p.set_padded_output_to_zero:
      bn_output *= 1.0 - paddings

    return bn_output
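
A standalone sketch of the normalization arithmetic above, with the moments and affine parameters computed from made-up data instead of compute_and_update_moments:

import jax.numpy as jnp

# Illustration only: inputs, gamma and beta are hypothetical.
epsilon = 1e-3
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # [batch, dim]
norm_mean = jnp.mean(x, axis=0, keepdims=True)
norm_variance = jnp.var(x, axis=0, keepdims=True)
gamma = jnp.ones((1, 2))
beta = jnp.zeros((1, 2))

inv = gamma / jnp.sqrt(norm_variance + epsilon)
bn_output = (x - norm_mean) * inv + beta  # same shape as x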
Example #3
  def fprop(
      self,
      inputs: JTensor,
      paddings: Optional[JTensor] = None
  ) -> Union[JTensor, Tuple[JTensor, JTensor]]:
    """Applies group normalization.

    Args:
      inputs: The inputs JTensor. Shaped [batch_size, height, width, channel]
        if p.input_rank == 4, else [batch_size, height, channel].
      paddings: The paddings JTensor. Shaped [batch_size, height]. Intended to
        be used for sequence processing where `height` is `time`.

    Returns:
      Output after applying group normalization, with the same shape as
      'inputs', or an (output, output_paddings) pair if the input paddings are
      not None.
    """
    p = self.params
    inputs, paddings = self._cast_to_fprop_dtype((inputs, paddings))
    asserts.eq(inputs.ndim, p.input_rank)

    x = jnp.reshape(
        inputs,
        list(inputs.shape[:-1]) + [self.num_groups, self.group_size])
    expanded_rank = p.input_rank + 1
    all_dims = list(range(expanded_rank))
    if paddings is None or not p.cumulative:
      # Skips batch and num_groups.
      reduce_over_dims = all_dims[1:-2] + all_dims[-1:]
    else:
      # Skips batch, seqlen and num_groups.
      reduce_over_dims = all_dims[2:-2] + all_dims[-1:]

    if paddings is None and not p.cumulative:
      group_mean = jnp.mean(x, axis=reduce_over_dims, keepdims=True)
      group_variance = jnp.mean(
          jnp.square(x - jax.lax.stop_gradient(group_mean)),
          axis=reduce_over_dims,
          keepdims=True)
    else:
      expanded_paddings = jnp.reshape(
          paddings,
          list(inputs.shape[:2]) + [1] * (expanded_rank - 2))
      group_mean, group_variance = compute_moments(
          x,
          expanded_paddings,
          reduce_over_dims,
          cumulative_axis=1,
          enable_cross_replica_sum_on_tpu=p.enable_cross_replica_sum_on_tpu,
          keepdims=True)

    outputs = self._normalize(x, group_mean, group_variance)

    if paddings is None:
      return outputs
    else:
      return outputs, paddings
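
A standalone sketch of the channel-grouping reshape used above, with made-up sizes (batch=2, time=5, channels=8 split into 4 groups of 2):

import jax.numpy as jnp

# Illustration only: shapes are hypothetical.
inputs = jnp.zeros((2, 5, 8))
num_groups, group_size = 4, 2
x = jnp.reshape(inputs, list(inputs.shape[:-1]) + [num_groups, group_size])
print(x.shape)  # (2, 5, 4, 2)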
Example #4
def compute_moments(
    inputs: JTensor,
    padding: JTensor,
    reduce_over_dims: List[int],
    cumulative_axis: Optional[int] = None,
    enable_cross_replica_sum_on_tpu: bool = False,
    keepdims: bool = False,
) -> Tuple[JTensor, JTensor]:
  """Computes mean and variance over the valid data points in inputs.

  Args:
    inputs: The inputs JTensor.
    padding: The paddings JTensor.
    reduce_over_dims: A sequence of ints for dimensions to reduce `inputs` over.
    cumulative_axis: An optional int for the axis along which to compute a
      cumulative sum. If None, no cumulative sum is applied.
    enable_cross_replica_sum_on_tpu: A boolean indicating whether to use an
      all-reduce sum over the 'batch' axis.
    keepdims: A boolean indicating whether the reduced axes should be left in
      the result as dimensions with size one.

  Returns:
    Tuple of (mean, variance).
  """
  asserts.eq(inputs.ndim, padding.ndim)
  rank = inputs.ndim
  for dim in reduce_over_dims:
    asserts.between(dim, 0, rank, left_strict=False, right_strict=True)
  mask = 1.0 - padding
  sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=keepdims)
  count_v = jnp.sum(
      jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=keepdims)
  if cumulative_axis is not None:
    sum_v = jnp.cumsum(sum_v, axis=cumulative_axis)
    count_v = jnp.cumsum(count_v, axis=cumulative_axis)

  if enable_cross_replica_sum_on_tpu:
    # TODO(shafey, yonghui): Fetch axis_name from globals.
    sum_v = jax.lax.psum(sum_v, axis_name='batch')
    count_v = jax.lax.psum(count_v, axis_name='batch')

  count_v = jnp.maximum(count_v, 1.0)
  mean = sum_v / count_v
  sum_vv = jnp.sum(
      (inputs - mean) * (inputs - mean) * mask,
      axis=reduce_over_dims,
      keepdims=keepdims)
  if cumulative_axis is not None:
    sum_vv = jnp.cumsum(sum_vv, axis=cumulative_axis)

  if enable_cross_replica_sum_on_tpu:
    # TODO(shafey, yonghui): Fetch axis_name from globals.
    sum_vv = jax.lax.psum(sum_vv, axis_name='batch')

  variance = sum_vv / count_v
  return mean, variance
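
A standalone check of the masked-moment idea, assuming a [B=1, T=3, D=1] batch whose last frame is padding (the values are made up):

import jax.numpy as jnp

inputs = jnp.array([[[1.0], [3.0], [100.0]]])
padding = jnp.array([[[0.0], [0.0], [1.0]]])  # the 100.0 frame is padding
mask = 1.0 - padding
count = jnp.maximum(jnp.sum(mask, axis=1, keepdims=True), 1.0)
mean = jnp.sum(inputs * mask, axis=1, keepdims=True) / count
variance = jnp.sum(jnp.square(inputs - mean) * mask, axis=1,
                   keepdims=True) / count
# mean == 2.0 and variance == 1.0 over the two valid frames.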
Example #5
def save_checkpoint(
    train_state: train_states.TrainState,
    checkpoint_dir: str,
    overwrite: bool = False,
    unreplicate: bool = True,
    checkpoint_type: CheckpointType = CheckpointType.CHECKPOINT_FLAX,
    state_specs: Optional[train_states.TrainState] = None,
) -> None:
  """Saves a checkpoint into the provided base directory.

  This is typically called on a replicated TrainState instance.

  Args:
    train_state: The TrainState instance to save.
    checkpoint_dir: The base directory in which to save the checkpoint.
    overwrite: Whether to overwrite existing checkpoint files if a checkpoint
      at the current or a later step already exists.
    unreplicate: Whether to unreplicate variables. When using SPMD sharding,
      this should be set to False.
    checkpoint_type: The checkpoint type (implementation) to save. Either
      `CHECKPOINT_FLAX`, `CHECKPOINT_MULTI_HOST_FLAX`, `CHECKPOINT_GDA` or
      `CHECKPOINT_PERSISTENCE`.
    state_specs: Currently unused.

  Raises:
    ValueError: If the global step has an unexpected shape or if
      `checkpoint_type` is invalid.
  """
  del state_specs

  if jax.config.jax_parallel_functions_output_gda:
    asserts.eq(checkpoint_type, CheckpointType.CHECKPOINT_GDA)
    step = int(jax.device_get(py_utils.maybe_unreplicate_gda(train_state.step)))
    _save_checkpoint_gda(train_state, checkpoint_dir, overwrite, step)
    return

  if train_state.step.ndim == 0:
    step = jax.device_get(train_state.step)
  elif train_state.step.ndim == 1:
    step = jax.device_get(train_state.step[0])
  else:
    raise ValueError(
        f'Expecting a replicated 1D global step (got `{train_state.step.ndim}`).'
    )

  if checkpoint_type in {
      CheckpointType.CHECKPOINT_FLAX, CheckpointType.CHECKPOINT_MULTI_HOST_FLAX
  }:
    use_multi_host = (
        checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX)
    _save_checkpoint_flax(train_state, checkpoint_dir, overwrite, unreplicate,
                          step, use_multi_host)
  else:
    raise ValueError(f'Unexpected checkpoint_type `{checkpoint_type}`.')
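
A standalone sketch of the replicated-step handling above: a pmap-replicated global step shows up as a 1D array with one entry per local device, so the scalar step is read from index 0 (the step value here is made up):

import jax
import jax.numpy as jnp

replicated_step = jnp.full((jax.local_device_count(),), 1000)
step = int(jax.device_get(replicated_step[0]))
print(step)  # 1000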
Example #6
def retrieve_checkpoint_type(multi_host_checkpointing: bool,
                             maybe_use_persistence_checkpointing: bool,
                             task_p: InstantiableParams) -> CheckpointType:
  """Retrieves the CheckpointType given the input arguments."""
  if jax.config.jax_parallel_functions_output_gda:
    asserts.eq(multi_host_checkpointing, True)
    checkpoint_type = CheckpointType.CHECKPOINT_GDA
  elif (maybe_use_persistence_checkpointing and
        task_p.model.device_mesh is not None):
    asserts.eq(multi_host_checkpointing, False)
    checkpoint_type = CheckpointType.CHECKPOINT_PERSISTENCE
  else:  # Flax-based checkpointing
    if multi_host_checkpointing:
      checkpoint_type = CheckpointType.CHECKPOINT_MULTI_HOST_FLAX
    else:
      checkpoint_type = CheckpointType.CHECKPOINT_FLAX
  return checkpoint_type
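
A minimal standalone sketch of the selection order above, with the configuration inputs mocked as plain booleans and a stand-in enum (the real function reads them from jax.config and task_p):

from enum import Enum, auto

class _CheckpointType(Enum):  # stand-in for the real CheckpointType enum
  CHECKPOINT_FLAX = auto()
  CHECKPOINT_MULTI_HOST_FLAX = auto()
  CHECKPOINT_GDA = auto()
  CHECKPOINT_PERSISTENCE = auto()

def _pick(gda_output, multi_host, use_persistence, has_device_mesh):
  if gda_output:
    return _CheckpointType.CHECKPOINT_GDA
  if use_persistence and has_device_mesh:
    return _CheckpointType.CHECKPOINT_PERSISTENCE
  return (_CheckpointType.CHECKPOINT_MULTI_HOST_FLAX
          if multi_host else _CheckpointType.CHECKPOINT_FLAX)

print(_pick(False, True, False, False))  # CHECKPOINT_MULTI_HOST_FLAX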
Example #7
  def __init__(self, params):
    """Initializes GroupNorm layer and checks parameters."""
    super().__init__(params)
    p = self.params
    asserts.not_none(p.name)
    asserts.gt(p.num_groups, 0)
    asserts.gt(p.min_group_size, 0)
    asserts.le(p.min_group_size, p.dim)
    asserts.eq(p.dim % p.min_group_size, 0)

    if p.dim >= p.num_groups:
      asserts.eq(
          p.dim % p.num_groups,
          0,
          msg='p.dim({0}) is not divisible by p.num_groups({1})'.format(
              p.dim, p.num_groups))

    asserts.in_set(p.input_rank, (3, 4))
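
A standalone illustration of the parameter constraints checked above, using plain asserts and made-up values instead of the asserts helpers:

dim, num_groups, min_group_size, input_rank = 8, 4, 2, 3
assert dim % min_group_size == 0
if dim >= num_groups:
  assert dim % num_groups == 0, (
      f'p.dim({dim}) is not divisible by p.num_groups({num_groups})')
assert input_rank in (3, 4)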
Example #8
def restore_checkpoint(
    train_state: Optional[train_states.TrainState],
    checkpoint_dir: str,
    global_mesh: Optional[maps.Mesh] = None,
    checkpoint_type: CheckpointType = CheckpointType.CHECKPOINT_FLAX,
    state_specs: Optional[train_states.TrainState] = None,
    step: Optional[int] = None) -> train_states.TrainState:
  """Restores a checkpoint from the provided base directory.

  This is typically called on an unreplicated TrainState instance.

  Args:
    train_state: The TrainState instance to restore.
    checkpoint_dir: The base directory from where to retrieve checkpoints.
    global_mesh: The global mesh representing devices across multiple processes.
    checkpoint_type: The checkpoint type (implementation) to restore. Either
      `CHECKPOINT_FLAX`, `CHECKPOINT_MULTI_HOST_FLAX`, `CHECKPOINT_GDA` or
      `CHECKPOINT_PERSISTENCE`.
    state_specs: If using a GDA-based checkpoint, the partition specs
      corresponding to this TrainState instance to restore.
    step: Step number to load a checkpoint from or None to load the latest.

  Returns:
    A restored `TrainState` instance.

  Raises:
    ValueError: When a mismatch between the current checkpoint structure and
      the saved checkpoint structure is detected.
  """
  if jax.config.jax_parallel_functions_output_gda:
    asserts.eq(checkpoint_type, CheckpointType.CHECKPOINT_GDA)
    return _restore_checkpoint_gda(train_state, checkpoint_dir, global_mesh,
                                   state_specs, step)

  if train_state is not None and train_state.step.ndim != 0:
    raise ValueError('Expecting an unreplicated scalar global step (got '
                     f'`{train_state.step.ndim}`).')

  if checkpoint_type in {
      CheckpointType.CHECKPOINT_FLAX, CheckpointType.CHECKPOINT_MULTI_HOST_FLAX
  }:
    return _restore_checkpoint_flax(train_state, checkpoint_dir, step)
  else:
    raise ValueError(f'Unexpected checkpoint_type `{checkpoint_type}`.')
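
A tiny standalone sketch of the step-shape expectation above: restore works on an unreplicated, scalar (0-d) global step, unlike the replicated 1D step seen at save time (the value is made up):

import jax.numpy as jnp

unreplicated_step = jnp.asarray(1000)
assert unreplicated_step.ndim == 0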
Example #9
  def fprop(self, state0: NestedMap,
            inputs: NestedMap) -> Tuple[NestedMap, NestedMap]:
    """Forward function.

    `_reset_state` is optionally applied if `reset_cell_state` is True. The RNN
    layer should provide `reset_mask` inputs in addition to other inputs.
    `reset_mask` inputs are expected to be 0 at timesteps where state0 should
    be reset to its default (zeros) before running `fprop`, and 1 otherwise.
    This is meant to support use cases like packed inputs, where multiple
    samples are fed in a single input example sequence and need to be masked
    from each other. For example, if the two examples packed together are
    ['good', 'day'] -> ['guten-tag'] and ['thanks'] -> ['danke'], producing
    ['good', 'day', 'thanks'] -> ['guten-tag', 'danke'], the source
    reset_masks would be [1, 1, 0] and the target reset_masks would be [1, 0].
    These masks keep the computations of different examples isolated from each
    other.

    Args:
      state0: The previous recurrent state.
      inputs: The inputs to the cell.

    Returns:
      state1: The next recurrent state.
      extras: Intermediate results to facilitate backprop.
    """
    asserts.instance(inputs.act, list)
    asserts.eq(self.params.inputs_arity, len(inputs.act))

    state0 = self._maybe_reset_state(state0, inputs)

    concat = jnp.concatenate(inputs.act + [state0.m], 1)
    wm = self.local_theta().wm
    xmw = jnp.einsum('bd,dc->bc', concat, wm)

    i_i, i_g, f_g, o_g = self._retrieve_and_split_gates(xmw)
    state1 = self._gates_internal(state0, inputs, i_i, i_g, f_g, o_g)
    return state1, NestedMap()
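
A standalone sketch of the reset-mask idea described in the docstring: rows with reset_mask == 0 have their previous state zeroed before the step (the state values are made up):

import jax.numpy as jnp

state0_m = jnp.array([[0.5, 0.5], [0.7, 0.7], [0.9, 0.9]])  # [B, D]
reset_mask = jnp.array([[1.0], [1.0], [0.0]])               # [B, 1]
masked_state = state0_m * reset_mask  # third row becomes zeros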
Example #10
  def test_eq_raises(self, value1, value2):
    with self.assertRaisesRegex(
        ValueError,
        f'`value1={value1}` must be equal to `value2={value2}`.$'):
      asserts.eq(value1, value2)
    with self.assertRaisesRegex(
        ValueError,
        f'`custom_value={value1}` must be equal to `value2={value2}`.$'):
      asserts.eq(value1, value2, value_str1=f'custom_value={value1}')
    with self.assertRaisesRegex(
        ValueError,
        f'`value1={value1}` must be equal to `custom_value={value2}`.$'):
      asserts.eq(value1, value2, value_str2=f'custom_value={value2}')
    custom_error_msg = 'This is a custom error message.'
    with self.assertRaisesRegex(ValueError, f'{custom_error_msg}$'):
      asserts.eq(value1, value2, msg=custom_error_msg)
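
A minimal sketch of an `eq` helper that is consistent with the error messages this test expects; the real `asserts.eq` implementation may differ:

def eq(value1, value2, *, value_str1=None, value_str2=None, msg=None):
  """Raises ValueError if value1 != value2 (illustrative sketch only)."""
  if value1 == value2:
    return
  if msg is not None:
    raise ValueError(msg)
  left = value_str1 or f'value1={value1}'
  right = value_str2 or f'value2={value2}'
  raise ValueError(f'`{left}` must be equal to `{right}`.')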
Example #11
  def test_eq(self, value1):
    value2 = value1
    asserts.eq(value1, value2)
Example #12
def ctc_loss(logits: JTensor,
             logitpaddings: JTensor,
             labels: JTensor,
             labelpaddings: JTensor,
             blank_id: int = 0) -> Tuple[JTensor, Mapping[str, JTensor]]:
  """Computes CTC loss.

  This function performs forward computation over an FSA with `N * 2` states
  where `N` is the max number of labels. The states are split into two groups:
  phi states and emission states. A phi state accepts repetitions of the
  phi (blank) symbol and transits to an emission state when the correct label
  is observed. An emission state accepts repetitions of the label and transits
  to the next phi state at any time (the so-called epsilon-transition).
  Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
  and `N` denotes the time steps in `labels`.

  Args:
    logits: (B, T, K)-array containing log-probabilities of each class.
    logitpaddings: (B, T)-array. Padding indicators for `logits`.
    labels: (B, N)-array containing reference integer labels.
    labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
      `labels` must be right-padded, i.e. each row of `labelpaddings` must be a
      run of zeros followed by a run of ones.
    blank_id: Id for blank token.

  Returns:
    A pair of `(per_seq_loss, aux)`.
    per_seq_loss:
      (B,)-array containing loss values for each sequence in the batch.
    aux: Dictionary containing interim variables used for computing losses.
      aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
        phi-state corresponding to the n-th label.
      aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
        emission-state corresponding to the n-th label.
      aux['logprobs_phi']: (T, B, 1)-array. Log-probability of the phi-symbol
        at each time frame.
      aux['logprobs_emit']: (T, B, N)-array. Log-probability of the n-th label
        at each time frame.
  """
  batchsize, unused_maxinputlen, num_classes = logits.shape
  batchsize_, maxlabellen = labels.shape
  asserts.eq(batchsize, batchsize_)

  logprobs = jax.nn.log_softmax(logits)
  labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)

  # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
  repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
  repeat = jnp.pad(repeat, ((0, 0), (0, 1)))

  logprobs_phi = logprobs[:, :, blank_id:blank_id + 1]  # [B, T, 1]
  logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2))  # [T, B, 1]

  one_hot = jax.nn.one_hot(labels, num_classes=num_classes)  # [B, N, K]
  logprobs_emit = jnp.einsum('btk,bnk->btn', logprobs, one_hot)
  logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2))  # [T, B, N]

  logalpha_phi_init = jnp.ones(
      (batchsize, maxlabellen + 1)) * _LOGEPSILON  # [B, N+1]
  logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
  logalpha_emit_init = jnp.ones(
      (batchsize, maxlabellen)) * _LOGEPSILON  # [B, N]

  def loop_body(prev, x):
    prev_phi, prev_emit = prev
    # emit-to-phi epsilon transition, except if the next label is repetition
    prev_phi_orig = prev_phi
    prev_phi = prev_phi.at[:, 1:].set(
        jnp.logaddexp(prev_phi[:, 1:], prev_emit + _LOGEPSILON * repeat))

    logprob_emit, logprob_phi, pad = x

    # phi-to-emit transition
    next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit,
                              prev_emit + logprob_emit)
    # self-loop transition
    next_phi = prev_phi + logprob_phi
    # emit-to-phi blank transition only when the next label is repetition
    next_phi = next_phi.at[:, 1:].set(
        jnp.logaddexp(next_phi[:, 1:],
                      prev_emit + logprob_phi + _LOGEPSILON * (1.0 - repeat)))

    pad = pad.reshape((batchsize, 1))
    next_emit = pad * prev_emit + (1.0 - pad) * next_emit
    next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

    return (next_phi, next_emit), (next_phi, next_emit)

  xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
  _, (logalpha_phi,
      logalpha_emit) = jax.lax.scan(loop_body,
                                    (logalpha_phi_init, logalpha_emit_init), xs)

  # last row needs to be updated with the last epsilon transition
  logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(
      jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
  logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)

  # extract per_seq_loss
  one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1)  # [B, N+1]
  per_seq_loss = -jnp.einsum('bn,bn->b', logalpha_phi_last, one_hot)

  return per_seq_loss, {
      'logalpha_phi': logalpha_phi,
      'logalpha_emit': logalpha_emit,
      'logprobs_phi': logprobs_phi,
      'logprobs_emit': logprobs_emit,
  }
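
A hypothetical call layout for ctc_loss above (B=2 sequences, T=10 frames, K=5 classes, N=4 max labels); the inputs are random and only meant to show the expected array shapes:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (2, 10, 5))        # [B, T, K]
logitpaddings = jnp.zeros((2, 10))                 # [B, T]
labels = jnp.array([[1, 2, 3, 4], [1, 1, 2, 0]])   # [B, N]
labelpaddings = jnp.array([[0., 0., 0., 0.],
                           [0., 0., 0., 1.]])      # right-padded, [B, N]
# per_seq_loss, aux = ctc_loss(logits, logitpaddings, labels, labelpaddings)
# per_seq_loss.shape == (2,); aux['logalpha_phi'].shape == (10, 2, 5)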