Example #1
 def extend_step_fn(states, ids):
   with base_layer.JaxContext.new_context(
       prng_key=base_layer.next_prng_key(),
       global_step=global_step) as jax_context:
     jax_context.bind(self.model, self.model.vars_to_flax_vars(model_theta),
                      [base_layer.SCOPE_AUX_LOSS])
     new_states, xent = self.model.extend_step(states, ids)
     return new_states, xent.logits
Example #2
 def _dropout(self, inputs: JTensor,
              noise_shape: Optional[List[int]]) -> JTensor:
     p = self.params
     if noise_shape is None:
         noise_shape = inputs.shape
     prng_key = base_layer.next_prng_key()
     keep_prob = p.keep_prob
     assert keep_prob > 0.0
     random_nums = keep_prob + jax.random.uniform(
         prng_key, noise_shape, inputs.dtype, minval=0.0, maxval=1.0)
     binary_mask = jnp.floor(random_nums)
     return inputs * binary_mask / keep_prob
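
A note on Example #2: `jnp.floor(keep_prob + uniform)` is the usual inverted-dropout trick; the floor is 1 with probability `keep_prob` and 0 otherwise, and dividing by `keep_prob` keeps the expected activation unchanged so no rescaling is needed at inference. A minimal self-contained sketch of the same idea in plain JAX (the function and names below are illustrative, not part of the library):

import jax
import jax.numpy as jnp


def dropout_sketch(prng_key, inputs, keep_prob=0.9):
  # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
  random_nums = keep_prob + jax.random.uniform(
      prng_key, inputs.shape, inputs.dtype, minval=0.0, maxval=1.0)
  binary_mask = jnp.floor(random_nums)
  # Scale by 1 / keep_prob so the expected value equals the input.
  return inputs * binary_mask / keep_prob


x = jnp.ones((4, 8))
y = dropout_sketch(jax.random.PRNGKey(0), x)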
Example #3
  def decode(self, input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:
    """Greedy decodes the input_batch.

    Args:
      input_batch: The input batch, with fields like `.ids`.

    Returns:
      - metrics, a NestedMap containing str keys and (metrics, weight) pairs.
      - A NestedMap like `input_batch`, with `.prefix_lengths` (vector of
        randomly generated ints indicating the lengths of prefixes for each
        row), and `.output_ids` (matrix of int ids with the decoded output).
    """
    p = self.params
    if p.decoder.seqlen <= 0:
      raise ValueError('Must set p.decoder.seqlen > 0, current value = '
                       f'{p.decoder.seqlen}')
    batch_size = input_batch.ids.shape[0]
    maxval = jnp.sum(1 - input_batch.paddings, axis=1).astype(jnp.int32)
    minval = jnp.minimum(maxval, p.decoder.min_prefix_len)
    prefix_lengths = jax.random.randint(base_layer.next_prng_key(),
                                        [batch_size], minval, maxval + 1,
                                        input_batch.ids.dtype)
    decoder_state = self.lm.init_states(
        target_batch_size=batch_size,
        target_max_length=p.decoder.seqlen)

    global_step = base_layer.cur_global_step()

    lm_theta = self.lm.local_theta()
    def extend_step_fn(states, ids):
      with base_layer.JaxContext.new_context(
          prng_key=base_layer.next_prng_key(),
          global_step=global_step) as jax_context:
        jax_context.bind(self.lm, self.lm.vars_to_flax_vars(lm_theta),
                         [base_layer.SCOPE_AUX_LOSS])
        new_states, xent = self.lm.extend_step(states, ids)
        return new_states, xent.logits

    result = greedy_decode(
        extend_step_fn,
        decoder_state,
        input_batch.ids,
        input_batch.paddings,
        p.decoder.seqlen,
        max_decode_steps=p.decoder.max_decode_steps,
        prefix_lengths=prefix_lengths,
        eos_id=p.decoder.eos_id)
    result.update(input_batch)

    metrics = NestedMap(
        num_decoded=(jnp.array(0.0, jnp.float32),
                     jnp.array(batch_size, jnp.float32)))
    return metrics, result
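
A note on Example #3: the prefix lengths are sampled uniformly per row, between `min_prefix_len` (clipped to the row's unpadded length) and the unpadded length inclusive. A standalone sketch of just that sampling step, with made-up paddings and an illustrative `min_prefix_len`:

import jax
import jax.numpy as jnp

paddings = jnp.array([[0., 0., 0., 1.],
                      [0., 0., 1., 1.]])  # illustrative input paddings
min_prefix_len = 2

maxval = jnp.sum(1 - paddings, axis=1).astype(jnp.int32)  # unpadded lengths
minval = jnp.minimum(maxval, min_prefix_len)
# randint's upper bound is exclusive, hence maxval + 1.
prefix_lengths = jax.random.randint(
    jax.random.PRNGKey(0), [paddings.shape[0]], minval, maxval + 1, jnp.int32)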
Example #4
    def fprop(self, inputs: JTensor,
              paddings: JTensor) -> Tuple[JTensor, JTensor]:
        """Applies data augmentation by randomly masking values in the spectrum.

    Args:
      inputs: A tensor of shape [batch, length, channels].
      paddings: A 0/1 tensor of shape [batch, length].

    Returns:
      A pair <new_inputs, mask>:
      new_inputs: A tensor of shape [batch, length, channels].
      paddings: A 0/1 tensor of shape [batch, length].
    """
        lengths = jnp.einsum('bh->b', 1 - paddings).astype(jnp.int32)

        prng_key = base_layer.next_prng_key()
        inputs = self._time_mask(inputs, lengths, global_seed=prng_key)
        prng_key = base_layer.next_prng_key()
        inputs = self._frequency_mask(inputs, global_seed=prng_key)

        return inputs, paddings
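
A note on Example #4: `_time_mask` and `_frequency_mask` are not shown above. As a rough, simplified illustration of what a single SpecAugment-style time mask does (a sketch under assumed shapes, not the library's `_time_mask`):

import jax
import jax.numpy as jnp


def single_time_mask_sketch(prng_key, inputs, lengths, max_mask_len=10):
  # inputs: [batch, time, channels]; zero out one random time span per row.
  start_key, len_key = jax.random.split(prng_key)
  batch, time = inputs.shape[0], inputs.shape[1]
  mask_len = jax.random.randint(len_key, [batch], 0, max_mask_len + 1)
  start = jax.random.randint(
      start_key, [batch], 0, jnp.maximum(lengths - mask_len, 1))
  t = jnp.arange(time)[jnp.newaxis, :]
  keep = (t < start[:, None]) | (t >= (start + mask_len)[:, None])
  return inputs * keep[:, :, None].astype(inputs.dtype)


x = jnp.ones((2, 20, 4))
lengths = jnp.array([20, 15])
masked = single_time_mask_sketch(jax.random.PRNGKey(0), x, lengths)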
Example #5
    def _apply_zoneout(self, state0: NestedMap, inputs: NestedMap,
                       new_c: JTensor, new_m: JTensor) -> NestedMap:
        """Apply Zoneout and returns the updated states."""
        p = self.params

        if p.zo_prob > 0.0:
            c_random_uniform = jax.random.uniform(base_layer.next_prng_key(),
                                                  new_c.shape)
            m_random_uniform = jax.random.uniform(base_layer.next_prng_key(),
                                                  new_m.shape)
        else:
            c_random_uniform = None
            m_random_uniform = None

        new_c = self._zoneout_internal(state0.c, new_c, inputs.padding,
                                       p.zo_prob, self.do_eval,
                                       c_random_uniform)
        new_m = self._zoneout_internal(state0.m, new_m, inputs.padding,
                                       p.zo_prob, self.do_eval,
                                       m_random_uniform)

        return NestedMap(m=new_m, c=new_c)
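
A note on Example #5: `_zoneout_internal` is not shown above. A rough sketch of the standard zoneout rule it is built around (keep the previous state with probability `zo_prob` during training, use the expectation at eval), ignoring the padding handling of the real helper:

import jax.numpy as jnp


def zoneout_sketch(old_state, new_state, zo_prob, is_eval, random_uniform):
  if is_eval:
    # Eval uses the expected value of the zoneout mask.
    return zo_prob * old_state + (1.0 - zo_prob) * new_state
  # Training: keep the old state with probability zo_prob, element-wise.
  keep_old = (random_uniform < zo_prob).astype(new_state.dtype)
  return keep_old * old_state + (1.0 - keep_old) * new_state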
Example #6
    def body_fprop(self, per_stage_inputs: JTensor, *per_stage_args,
                   **per_stage_kwargs) -> NestedJTensor:
        """Runs the fprop function of the stages."""
        p = self.params
        if p.mesh_axis_names is not None:

            def annotate(x):
                unconstrained_dims = list(range(1, x.ndim))
                dims_mapping = (p.weight_split_dims_mapping.stages + [None] *
                                (x.ndim - 1))
                return base_layer.maybe_shard(x, dims_mapping,
                                              p.mesh_axis_names,
                                              unconstrained_dims)

            per_stage_inputs = jax.tree_map(annotate, per_stage_inputs)
            per_stage_args = jax.tree_map(annotate, per_stage_args)
            per_stage_kwargs = jax.tree_map(annotate, per_stage_kwargs)

        prng_key = base_layer.next_prng_key()
        global_step = base_layer.cur_global_step()

        # vmap self.body.fprop over the leading stage dimension of the per-stage
        # inputs and args.
        def _wrapped_fn(theta, per_stage_inputs, *per_stage_args,
                        **per_stage_kwargs):
            with base_layer.JaxContext.new_context(
                    prng_key=prng_key, global_step=global_step) as jax_ctx:
                jax_ctx.bind(self.body, self.body.vars_to_flax_vars(theta),
                             [base_layer.SCOPE_AUX_LOSS])
                res = self.body.fprop(per_stage_inputs, *per_stage_args,
                                      **per_stage_kwargs)
                summaries = base_layer.all_summaries()
                return res, summaries

        res, summaries = jax.vmap(_wrapped_fn)(self.body.local_theta(),
                                               per_stage_inputs,
                                               *per_stage_args,
                                               **per_stage_kwargs)

        self._forward_summary(summaries)
        return res
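
A note on Example #6: the pipeline relies on both the stacked per-stage weights and the per-stage inputs having a leading `stage` dimension, so `jax.vmap` can map the single-stage `fprop` over all stages at once. A toy standalone version of that pattern (the dense-layer body here is purely illustrative):

import jax
import jax.numpy as jnp

num_stages, batch, dim = 4, 2, 8
stacked_w = jnp.ones((num_stages, dim, dim))           # per-stage weights
per_stage_inputs = jnp.ones((num_stages, batch, dim))  # per-stage inputs


def one_stage_fprop(w, x):
  return jnp.tanh(x @ w)


# vmap over the leading stage dimension of both weights and inputs.
per_stage_outputs = jax.vmap(one_stage_fprop)(stacked_w, per_stage_inputs)
assert per_stage_outputs.shape == (num_stages, batch, dim)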
Example #7
    def _drop_connect(self, inputs: JTensor) -> JTensor:
        """Drops the entire residual layer with given survival probability.

    Args:
      inputs: input `.JTensor` which is on the residual branch which is dropped.

    Returns:
      Dropped out inputs.
    """
        if self.do_eval:
            return inputs

        # Compute a per-example random binary tensor.
        prng_key = base_layer.next_prng_key()
        batch_size = inputs.shape[0]
        shape = [batch_size] + [1] * (len(inputs.shape) - 1)
        random_tensor = self.params.survival_prob + jax.random.uniform(
            prng_key, shape, dtype=inputs.dtype)
        binary_tensor = jnp.floor(random_tensor)
        # Unlike the conventional approach of multiplying by survival_prob at test
        # time, here we divide by survival_prob at training time, so that no
        # additional compute is needed at test time.
        output = inputs / self.params.survival_prob * binary_tensor
        return output
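
A note on Example #7: the binary tensor is 1 with probability `survival_prob`, so dividing by `survival_prob` during training keeps `E[output]` equal to `inputs` and no rescaling is needed at test time. A standalone sketch of the same per-example stochastic-depth mask (names are illustrative):

import jax
import jax.numpy as jnp


def drop_connect_sketch(prng_key, inputs, survival_prob=0.8):
  # One 0/1 keep decision per example, broadcast over the remaining dims.
  shape = [inputs.shape[0]] + [1] * (inputs.ndim - 1)
  binary = jnp.floor(
      survival_prob + jax.random.uniform(prng_key, shape, dtype=inputs.dtype))
  return inputs / survival_prob * binary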
Example #8
def recurrent_func(theta: NestedMap, states_0: NestedMap, inputs: NestedMap,
                   cell_fn: Callable[[NestedMap, NestedMap, NestedMap],
                                     NestedMap]):
    """Computes a recurrent neural net.

  Args:
    theta: weights. A `.NestedMap`.
    states_0: initial state. A `.NestedMap`.
    inputs: inputs. A `.NestedMap`.
    cell_fn: A python function which computes::
        states_1 = cell_fn(theta, states_0, inputs[t, :])

  Returns:
    `accumulate_state` and the final state.
  """
    input_seq_len = inputs.Flatten()[0].shape[0]

    def assert_not_none(x):
        assert x is not None

    tf.nest.map_structure(assert_not_none, states_0)
    tf.nest.map_structure(assert_not_none, inputs)
    tf.nest.map_structure(assert_not_none, theta)

    def new_cum_state(x):
        x1 = jnp.expand_dims(x, 0)
        # +1 so that we can store initial_states at position 0.
        return jnp.tile(x1, [input_seq_len + 1] + [1] * x.ndim)

    cumulative_states = states_0.Transform(new_cum_state)

    prng_key = base_layer.next_prng_key()
    global_step = base_layer.cur_global_step()

    start_time = jnp.array(0, dtype=jnp.uint32)
    fwd_initial_loop_vars = NestedMap(cur_time=start_time,
                                      theta=theta,
                                      states_0=states_0,
                                      cumulative_states=cumulative_states,
                                      inputs=inputs)

    def same_type_shape(x, y):
        assert x.dtype == y.dtype, (x.dtype, y.dtype)
        assert x.shape == y.shape, (x.shape, y.shape)

    def wrapped_cell_fn(fn_in):
        # fn_in is NestedMap containing the following elements:
        #    - t
        #    - theta
        #    - states_0
        #    - inputs_t
        # Start a chain of prng keys that also takes the time step into account.
        t = fn_in.t
        theta = fn_in.theta
        states_0 = fn_in.states_0
        inputs_t = fn_in.inputs_t
        prng_key_t = jax.random.fold_in(prng_key, t)
        with base_layer.JaxContext.new_context(prng_key=prng_key_t,
                                               global_step=global_step):
            # NO side-effect ops are allowed as the enclosing JaxContext is not bound
            # to any layer.
            states_1 = cell_fn(theta, states_0, inputs_t)

            tf.nest.assert_same_structure(states_0, states_1)
            tf.nest.map_structure(same_type_shape, states_0, states_1)
        return states_1

    def wrapped_cell_fn_grad(fn_in, d_fn_out):
        # This is roughly the following:
        #
        # fn_out = wrapped_cell_fn(fn_in)
        # d_fn_in = tf.gradient(fn_out, fn_in, d_fn_out)
        # return d_fn_in
        #
        assert isinstance(fn_in, NestedMap)
        fn_out, vjp_fn = jax.vjp(wrapped_cell_fn, fn_in)
        del fn_out
        d_fn_in = vjp_fn(d_fn_out)
        assert isinstance(d_fn_in, tuple)
        assert len(d_fn_in) == 1
        d_fn_in_0 = d_fn_in[0]
        # Over-write gradient for t, the time step.
        d_fn_in_0.t = jnp.zeros_like(fn_in.t)
        tf.nest.assert_same_structure(fn_in, d_fn_in_0)
        tf.nest.map_structure(same_type_shape, fn_in, d_fn_in_0)
        return d_fn_in_0

    def fwd_comp_fn(loop_vars):
        # loop_vars is a NestedMap containing the following elements:
        #   - cur_time
        #   - theta
        #   - inputs
        #   - cumulative_states
        #   - states_0
        t = loop_vars.cur_time
        theta = loop_vars.theta
        inputs = loop_vars.inputs
        cumulative_states = loop_vars.cumulative_states
        states_0 = loop_vars.states_0
        inputs_t = inputs.Transform(lambda x: x[t])

        states_1 = wrapped_cell_fn(
            NestedMap(t=t, theta=theta, states_0=states_0, inputs_t=inputs_t))

        def set_t(x, x_t):
            return x.at[t + 1].set(x_t)

        cumulative_states = tf.nest.map_structure(set_t, cumulative_states,
                                                  states_1)
        loop_out = NestedMap(cur_time=t + 1,
                             theta=theta,
                             inputs=inputs,
                             states_0=states_1,
                             cumulative_states=cumulative_states)
        return loop_out

    def fwd_continue_fn(loop_vars):
        return loop_vars.cur_time < input_seq_len

    # This custom_vjp implementation follows examples here:
    # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
    @jax.custom_vjp
    def fwd_loop(loop_vars):
        final_loop_vars = jax.lax.while_loop(fwd_continue_fn, fwd_comp_fn,
                                             loop_vars)
        return NestedMap(final_states=final_loop_vars.states_0,
                         cumulative_states=final_loop_vars.cumulative_states)

    def loop_fn_vjp_fwd(loop_vars):
        loop_fn_out = fwd_loop(loop_vars)
        return loop_fn_out, (loop_vars, loop_fn_out.cumulative_states)

    def loop_fn_vjp_bwd(res, d_out):
        fwd_loop_vars, cumulative_states = res
        d_final_states = d_out.final_states
        d_cumulative_states = d_out.cumulative_states

        start_time = input_seq_len - 1
        d_states_1 = tf.nest.map_structure(lambda x, y: x[start_time + 1] + y,
                                           d_cumulative_states, d_final_states)
        bwd_loop_vars = NestedMap(
            cur_time=start_time,
            theta=fwd_loop_vars.theta,
            inputs=fwd_loop_vars.inputs,
            cumulative_states=cumulative_states,
            d_cumulative_states=d_cumulative_states,
            d_theta=fwd_loop_vars.theta.Transform(jnp.zeros_like),
            d_inputs=fwd_loop_vars.inputs.Transform(jnp.zeros_like),
            d_states_1=d_states_1)

        def bwd_comp_fn(loop_vars):
            t = loop_vars.cur_time
            inputs = loop_vars.inputs
            inputs_t = inputs.Transform(lambda x: x[t])
            states_0 = loop_vars.cumulative_states.Transform(lambda x: x[t])
            d_cell_in = wrapped_cell_fn_grad(
                NestedMap(t=t,
                          theta=loop_vars.theta,
                          states_0=states_0,
                          inputs_t=inputs_t), loop_vars.d_states_1)
            d_theta = tf.nest.map_structure(lambda x, y: x + y,
                                            loop_vars.d_theta, d_cell_in.theta)
            d_states_0 = tf.nest.map_structure(lambda x, y: x + y[t],
                                               d_cell_in.states_0,
                                               loop_vars.d_cumulative_states)

            def set_t(x, x_t):
                return x.at[t].set(x_t)

            d_inputs = tf.nest.map_structure(set_t, loop_vars.d_inputs,
                                             d_cell_in.inputs_t)
            loop_vars_out = loop_vars.Transform(lambda x: x)
            loop_vars_out.d_inputs = d_inputs
            loop_vars_out.d_states_1 = d_states_0
            loop_vars_out.d_theta = d_theta
            loop_vars_out.cur_time = t - 1
            return loop_vars_out

        def bwd_continue_fn(loop_vars):
            return loop_vars.cur_time >= 0

        bwd_final_loop_vars = jax.lax.while_loop(bwd_continue_fn, bwd_comp_fn,
                                                 bwd_loop_vars)
        d_out = fwd_loop_vars.Transform(jnp.zeros_like)

        tf.nest.map_structure(same_type_shape, d_out.states_0,
                              bwd_final_loop_vars.d_states_1)
        tf.nest.map_structure(same_type_shape, d_out.theta,
                              bwd_final_loop_vars.d_theta)
        tf.nest.map_structure(same_type_shape, d_out.inputs,
                              bwd_final_loop_vars.d_inputs)

        d_out.states_0 = bwd_final_loop_vars.d_states_1
        d_out.theta = bwd_final_loop_vars.d_theta
        d_out.inputs = bwd_final_loop_vars.d_inputs
        return (d_out, )

    fwd_loop.defvjp(loop_fn_vjp_fwd, loop_fn_vjp_bwd)

    # Finally, let's simply run the forward loop fn.
    fwd_final_loop_vars = fwd_loop(fwd_initial_loop_vars)
    fwd_cumulative_states = fwd_final_loop_vars.cumulative_states.Transform(
        lambda x: x[1:])
    return fwd_final_loop_vars.final_states, fwd_cumulative_states
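
A note on Example #8: inside the loop, each step derives its own PRNG key by folding the time step into a single base key, so the whole computation stays a deterministic function of one key while every step sees different randomness. A standalone sketch of that keying scheme (the base key and shapes are illustrative):

import jax

base_key = jax.random.PRNGKey(0)
# fold_in(key, t) deterministically derives a distinct key per time step.
step_keys = [jax.random.fold_in(base_key, t) for t in range(4)]
per_step_samples = [jax.random.uniform(k, (2,)) for k in step_keys]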
Example #9
def scan(carry_init: NestedMap,
         xs: NestedMap,
         fn: Callable[[NestedMap, NestedMap], Tuple[NestedMap, NestedMap]],
         root_layer: Optional[base_layer.BaseLayer] = None,
         checkpoint_policy: AutodiffCheckpointType = (
             AutodiffCheckpointType.SAVE_NOTHING)):
    """A simple wrap around jax.lax.scan.

  Back-prop is availale through auto-diff.

  Args:
    carry_init: initial state. A `.NestedMap`.
    xs: inputs. A `.NestedMap`. All inputs in time-major.
    fn: A python function which computes:
        carry, ys[t] = fn(carry, xs[t, :])
    root_layer: The root layer within which this jax.lax.scan based while_loop
      is carried out. If root_layer is provided, some basic-effort check is
      performed to make sure fn is side-effect free. Otherwise, no such checks
      are performed.
    checkpoint_policy: A AutodiffCheckpointType. How to checkpoint for BProp:
      SAVE_NOTHING, SAVE_DOT_ONLY, SAVE_DOT_WITH_NO_BATCH_DIM.

  Returns:
    (final 'carry', 'ys', stacked summaries).
  """
    del root_layer
    assert isinstance(carry_init, py_utils.NestedMap)
    assert isinstance(xs, py_utils.NestedMap)
    # Make a copy of carry_init structure.
    carry_init = tf.nest.map_structure(lambda x: x, carry_init)
    # "carry" will be augmented with the following three tensors, so make sure
    # they don't already exist in the NestedMap.
    assert 'time_step' not in carry_init
    assert 'prng_key' not in carry_init
    assert 'global_step' not in carry_init

    def custom_policy(checkpoint_policy: AutodiffCheckpointType):

        # TODO(zhangqiaorjc): Configure custom checkpoint policy in expt config
        # without introducing enum.
        if checkpoint_policy == AutodiffCheckpointType.SAVE_EVERYTHING:
            return jax.checkpoint_policies.everything_saveable
        if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_ONLY:
            return jax.checkpoint_policies.checkpoint_dots
        if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_WITH_NO_BATCH_DIM:
            return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
        if checkpoint_policy == AutodiffCheckpointType.SAVE_QKV_OUT_PROJ:
            return jax.checkpoint_policies.save_only_these_names(
                'combined_qkv_proj', 'out_proj')
        if checkpoint_policy == AutodiffCheckpointType.SAVE_CONTEXT:
            return jax.checkpoint_policies.save_only_these_names('context')
        if checkpoint_policy == AutodiffCheckpointType.SAVE_OUT_PROJ:
            return jax.checkpoint_policies.save_only_these_names('out_proj')
        if checkpoint_policy == AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ:
            return jax.checkpoint_policies.save_only_these_names(
                'context', 'out_proj')
        if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_FOR_MLPERF_200B:
            return jax.checkpoint_policies.save_only_these_names(
                'combined_qkv_proj', 'query_proj', 'value_proj', 'key_proj',
                'context', 'out_proj')
        assert checkpoint_policy == AutodiffCheckpointType.SAVE_NOTHING
        return jax.checkpoint_policies.nothing_saveable

    @functools.partial(ad_checkpoint.checkpoint,
                       prevent_cse=False,
                       policy=custom_policy(checkpoint_policy))
    def fn_wrap(carry, xs_t):
        # carry is augmented with three additional tensors (time_step, prng_key,
        # global_step) to make fn_wrap fully functional.
        # Start a new prng_key branch that also depends on the time step.
        prng_key_t = jax.random.fold_in(carry.prng_key, carry.time_step)
        with base_layer.JaxContext.new_context(prng_key=prng_key_t,
                                               global_step=carry.global_step):

            carry_new, ys_t = fn(carry, xs_t)
            carry_new.time_step = carry.time_step + 1
            # copy over prng_key and global_step
            carry_new.prng_key = carry.prng_key
            carry_new.global_step = carry.global_step

            tf.nest.assert_same_structure(carry_new, carry)
            summaries = base_layer.all_summaries()

        return carry_new, (ys_t, summaries)

    # The initial time step.
    time_step = jnp.array(0, dtype=jnp.uint32)
    prng_key = base_layer.next_prng_key()
    global_step = base_layer.cur_global_step()
    carry_init.time_step = time_step
    carry_init.prng_key = prng_key
    carry_init.global_step = global_step

    carry_final, (ys, summaries) = jax.lax.scan(fn_wrap, carry_init, xs)

    del carry_final.time_step
    del carry_final.global_step
    del carry_final.prng_key
    return carry_final, ys, summaries
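
A note on Example #9: `fn_wrap` works because the carry itself transports `time_step`, `prng_key`, and `global_step`, which keeps the scanned function fully functional. A minimal standalone version of that augmented-carry pattern (the toy `fn` body below is illustrative):

import jax
import jax.numpy as jnp


def fn_wrap_sketch(carry, x_t):
  # Per-step key derived from the carried base key and time step.
  step_key = jax.random.fold_in(carry['prng_key'], carry['time_step'])
  noise = jax.random.normal(step_key, x_t.shape)
  new_carry = dict(
      state=carry['state'] + x_t + noise,
      time_step=carry['time_step'] + 1,
      prng_key=carry['prng_key'])
  return new_carry, new_carry['state']


xs = jnp.ones((5, 3))  # time-major inputs: [seq_len, ...]
carry_init = dict(
    state=jnp.zeros((3,)),
    time_step=jnp.array(0, jnp.uint32),
    prng_key=jax.random.PRNGKey(0))
carry_final, ys = jax.lax.scan(fn_wrap_sketch, carry_init, xs)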
Example #10
def recurrent_static(theta: NestedMap,
                     states_0: NestedMap,
                     inputs: NestedMap,
                     cell_fn: Callable[[NestedMap, NestedMap, NestedMap],
                                       NestedMap],
                     root_layer: Optional[base_layer.BaseLayer] = None):
    """A simpler form of Recurrent where num of steps is known statically.

  Back-prop is availale through auto-diff.

  'padding' in inputs is used to skip certain steps dynamically. If the
  'padding' tensor exists, it is expected of a binary 0/1 tensor.

  Args:
    theta: weights. A `.NestedMap`.
    states_0: initial state. A `.NestedMap`.
    inputs: inputs. A `.NestedMap`. All inputs in time-major.
    cell_fn: A python function which computes::
        states_1 = cell_fn(theta, states_0, inputs[t, :])
    root_layer: The root layer within which this recurrent_static recurrent loop
      is carried out.

  Returns:
    `accumulate_state` and the final state.
  """

    assert 'time_step' not in states_0
    # The initial time step.
    time_step = jnp.array(0, dtype=jnp.uint32)
    # Make a copy of states_0 structure.
    states_0 = tf.nest.map_structure(lambda x: x, states_0)
    states_0.time_step = time_step

    prng_key = base_layer.next_prng_key()
    global_step = base_layer.cur_global_step()

    # TODO(zhangqiaorjc): Switch to ad_checkpoint.checkpoint after mattjj bug fix.
    @jax.checkpoint
    def comp_fn(states_0, inputs_t):
        # Start a new prng_key branch that also depends on the time step.
        if root_layer is not None:
            forward_updated_vars_before = tf.nest.map_structure(
                lambda x: x, root_layer.forward_updated_vars)
        prng_key_t = jax.random.fold_in(prng_key, states_0.time_step)
        with base_layer.JaxContext.new_context(prng_key=prng_key_t,
                                               global_step=global_step):
            # NO side-effect ops are allowed as the enclosing JaxContext is not bound
            # to any layer.
            #
            # Whether or not we should skip this time step.
            if 'padding' in inputs_t:
                # We skip if all are padded steps.
                skip = jnp.all(inputs_t.padding > 0.5)
            else:
                skip = jnp.array(False)

            def carry_over(args):
                states_0, inputs_t = args
                del inputs_t
                # We simply carry over the states for this time step.
                states_1 = tf.nest.map_structure(lambda x: x, states_0)
                states_1.time_step = states_0.time_step + 1
                return states_1

            def do_compute(args):
                states_0, inputs_t = args
                # Actually carry out the computation.
                states_1 = cell_fn(theta, states_0, inputs_t)
                states_1.time_step = states_0.time_step + 1
                return states_1

            if 'padding' in inputs_t:
                states_1 = jax.lax.cond(skip, carry_over, do_compute,
                                        (states_0, inputs_t))
            else:
                states_1 = do_compute((states_0, inputs_t))
            tf.nest.assert_same_structure(states_0, states_1)

            if root_layer is not None:
                forward_updated_vars_after = tf.nest.map_structure(
                    lambda x: x, root_layer.forward_updated_vars)

                def assert_no_change(x, y):
                    assert (x is None and y is None) or (x is not None
                                                         and y is not None)

                tf.nest.map_structure(assert_no_change,
                                      forward_updated_vars_before,
                                      forward_updated_vars_after)

            return states_1, states_1

    final_states, cumulative_states = jax.lax.scan(comp_fn, states_0, inputs)
    del final_states.time_step
    del cumulative_states.time_step
    return final_states, cumulative_states
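
A note on Example #10: when a time step is entirely padding, `jax.lax.cond` selects the carry-over branch and the previous state passes through unchanged. A standalone sketch of just that branch structure (the toy state and inputs are illustrative):

import jax
import jax.numpy as jnp


def step_sketch(state, inputs_t):
  # Skip the step if every position is padded.
  skip = jnp.all(inputs_t['padding'] > 0.5)

  def carry_over(args):
    prev_state, _ = args
    return prev_state

  def do_compute(args):
    prev_state, inp = args
    return prev_state + inp['x']

  return jax.lax.cond(skip, carry_over, do_compute, (state, inputs_t))


state = jnp.zeros((3,))
inputs_t = dict(x=jnp.ones((3,)), padding=jnp.array([1.0, 0.0, 1.0]))
new_state = step_sketch(state, inputs_t)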
Example #11
 def _UniformSample(sample_p: float) -> JTensor:
     prng_key = base_layer.next_prng_key()
     rnd_sample = jax.random.uniform(prng_key, inputs.shape)
     return (rnd_sample < sample_p).astype(fprop_dtype)
Example #12
    def fprop(self, inputs: JTensor,
              paddings: JTensor) -> Tuple[JTensor, JTensor]:
        """Applies data augmentation by randomly masking/replacing tokens in inputs.

    Args:
      inputs: An int32 tensor of shape [batch, length].
      paddings: A 0/1 tensor of shape [batch, length].

    Returns:
      A pair <new_inputs, mask>:
      new_inputs: An int32 tensor of shape [batch, length]. The new token ids
        after data augmentation.
      mask: A 0/1 tensor. A "1" indicates the corresponding token at that
        position had undergone the data augmentation process.
    """
        p = self.params
        assert p.vocab_size > 0
        assert p.mask_token_id >= 0
        assert p.mask_prob + p.random_prob + p.same_prob < 1.0
        assert p.mask_prob + p.random_prob + p.same_prob > 0.0

        fprop_dtype = self.fprop_dtype

        def _UniformSample(sample_p: float) -> JTensor:
            prng_key = base_layer.next_prng_key()
            rnd_sample = jax.random.uniform(prng_key, inputs.shape)
            return (rnd_sample < sample_p).astype(fprop_dtype)

        total_replacement_prob = p.mask_prob + p.random_prob + p.same_prob
        # valid_tokens == 1.0 if the corresponding position is a valid token.
        valid_tokens = 1.0 - paddings.astype(fprop_dtype)
        # replacement == 1.0 if the corresponding token is to be replaced by
        # something else (mask, random, self).
        replacement_pos = valid_tokens * _UniformSample(total_replacement_prob)
        no_replacement = 1.0 - replacement_pos

        # First sample the token positions to be masked out.
        remaining_prob = total_replacement_prob
        remaining_pos = replacement_pos
        mask_prob = p.mask_prob / remaining_prob
        # mask_pos == 1.0 if the corresponding token should be masked.
        mask_pos = remaining_pos * _UniformSample(mask_prob)

        # Next sample the token positions to be replaced by random tokens.
        remaining_prob -= p.mask_prob
        remaining_pos -= mask_pos
        assert remaining_prob > 0.0
        random_prob = p.random_prob / remaining_prob
        random_pos = remaining_pos * _UniformSample(random_prob)

        # Lastly, token positions to be replaced by self.
        self_pos = remaining_pos - random_pos

        random_tokens = jax.random.randint(base_layer.next_prng_key(),
                                           inputs.shape, 0, p.vocab_size,
                                           inputs.dtype)
        mask_tokens = jnp.zeros_like(inputs) + p.mask_token_id

        input_dtype = inputs.dtype
        augmented = (inputs * no_replacement.astype(input_dtype) +
                     mask_tokens * mask_pos.astype(input_dtype) +
                     random_tokens * random_pos.astype(input_dtype) +
                     inputs * self_pos.astype(input_dtype))

        return augmented, replacement_pos
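
A note on Example #12: the staged sampling turns a single per-token replacement decision into mask / random / keep buckets with the right conditional probabilities. A short worked example with the usual BERT-style settings (illustrative values, not taken from the code above):

mask_prob, random_prob, same_prob = 0.12, 0.015, 0.015  # illustrative values
total = mask_prob + random_prob + same_prob       # 0.15: P(token is selected)
p_mask = mask_prob / total                        # 0.8, drawn among selected
p_random = random_prob / (total - mask_prob)      # 0.5, drawn among the rest
# Selected tokens therefore end up 80% [MASK], 10% random, 10% unchanged.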