示例#1
0
    def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
        """Computes mean and variance over the valid data points in inputs.

        Moments are reduced over every axis except the last one, with each
        element of `inputs` weighted by `mask`.

        Args:
          inputs: The input tensor. Must have the same rank as `mask`.
          mask: Non-negative weights that broadcast against `inputs` (e.g. 1
            for valid and 0 for padded positions). Its dtype may differ from
            the dtype of `inputs`.
          enable_cross_replica_sum_on_tpu: If True and running on TPU, the
            partial sums and counts are summed across replicas.

        Returns:
          A (mean, variance) tuple in the dtype of `inputs`, each reduced over
          all but the last axis.
        """
        inputs = py_utils.with_dependencies([
            py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
            py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
        ], inputs)
        # Mask in the dtype of `inputs`, used wherever it multiplies `inputs`.
        # The element count below stays in the mask dtype so it keeps full
        # precision when `inputs` has a low-precision dtype.
        weights = tf.cast(mask, inputs.dtype)
        rank = tf.rank(mask)
        reduce_over_dims = tf.range(0, rank - 1)
        sum_v = tf.reduce_sum(inputs * weights, reduce_over_dims)
        count_v = tf.reduce_sum(mask, reduce_over_dims)
        # Input shape is guaranteed to be a multiple of mask shape because the
        # inputs * mask op above was successfully broadcasted.
        mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
        count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_v = tf.tpu.cross_replica_sum(sum_v)
            count_v = tf.tpu.cross_replica_sum(count_v)

        # Clamp to avoid dividing by zero when nothing is valid. TF does not
        # promote mixed dtypes, so the count is cast to the dtype of the sums
        # it divides (a no-op when mask and inputs share a dtype).
        count_v = tf.cast(tf.maximum(count_v, 1.0), inputs.dtype)
        mean = sum_v / count_v
        sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * weights,
                               reduce_over_dims)

        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_vv = tf.tpu.cross_replica_sum(sum_vv)

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
        ], sum_vv / count_v)
        return mean, variance
示例#2
0
    def _MaybeExpandPaddings(self, inputs, paddings):
        """Appends trailing size-1 axes so `paddings` reaches the rank of `inputs`.

        The rank of `inputs` is asserted to be equal to, or exactly one more
        than, the rank of `paddings`.

        Args:
          inputs: Tensor whose rank `paddings` is brought up to.
          paddings: The paddings tensor.

        Returns:
          `paddings` reshaped with `rank(inputs) - rank(paddings)` extra
          trailing axes of size 1, i.e. unchanged in content when the ranks
          already match.
        """
        extra_dims = tf.rank(inputs) - tf.rank(paddings)
        rank_checks = [
            py_utils.assert_less_equal(extra_dims, 1),
            py_utils.assert_greater_equal(extra_dims, 0)
        ]
        paddings = py_utils.with_dependencies(rank_checks, paddings)

        # Target shape: the current shape followed by `extra_dims` ones.
        current_shape = tf.shape(paddings)
        trailing_ones = tf.tile([1], [extra_dims])
        target_shape = tf.concat([current_shape, trailing_ones], axis=0)
        return tf.reshape(paddings, target_shape)
示例#3
0
def ComputeMoments(inputs,
                   padding,
                   reduce_over_dims,
                   cumulative_axis=None,
                   enable_cross_replica_sum_on_tpu=False,
                   keepdims=False):
    """Computes mean and variance over the valid data points in inputs.

    Each element of `inputs` is weighted by `1.0 - padding`.

    Args:
      inputs: The input tensor. Must have the same rank as `padding`.
      padding: Padding tensor that broadcasts against `inputs`; `1.0 - padding`
        must be non-negative. Its dtype may differ from the dtype of `inputs`.
      reduce_over_dims: Axes to reduce over.
      cumulative_axis: If not None, the reduced sums and counts are
        accumulated with a cumulative sum along this axis.
      enable_cross_replica_sum_on_tpu: If True and running on TPU, sums and
        counts are summed across replicas.
      keepdims: Passed to `tf.reduce_sum`; keeps reduced axes with length 1.

    Returns:
      A (mean, variance) tuple in the dtype of `inputs`.
    """
    mask = 1.0 - padding
    inputs = py_utils.with_dependencies([
        py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
        py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
    ], inputs)
    # Mask in the dtype of `inputs`, used wherever it multiplies `inputs`.
    # The valid-element count below stays in the mask dtype so it keeps full
    # precision when `inputs` has a low-precision dtype.
    input_mask = tf.cast(mask, inputs.dtype)
    sum_v = tf.reduce_sum(inputs * input_mask,
                          reduce_over_dims,
                          keepdims=keepdims)
    count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=keepdims)

    if cumulative_axis is not None:
        sum_v = tf.math.cumsum(sum_v, axis=cumulative_axis)
        count_v = tf.math.cumsum(count_v, axis=cumulative_axis)
    # Input shape is guaranteed to be a multiple of mask shape because the
    # inputs * mask op above was successfully broadcasted, so each counted
    # mask element stands for `mask_multiplier` input elements.
    input_size_on_reduced_dims = tf.reduce_prod(
        tf.gather(tf.shape(inputs), reduce_over_dims))
    mask_size_on_reduced_dims = tf.reduce_prod(
        tf.gather(tf.shape(mask), reduce_over_dims))
    mask_multiplier = tf.math.truediv(input_size_on_reduced_dims,
                                      mask_size_on_reduced_dims)
    count_v *= tf.cast(mask_multiplier, count_v.dtype)
    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        sum_v = tf.tpu.cross_replica_sum(sum_v)
        count_v = tf.tpu.cross_replica_sum(count_v)

    # Clamp to avoid dividing by zero when nothing is valid. TF does not
    # promote mixed dtypes, so the count is cast to the dtype of the sums it
    # divides (a no-op when padding and inputs share a dtype).
    count_v = tf.cast(tf.maximum(count_v, 1.0), inputs.dtype)
    mean = sum_v / count_v
    # NOTE(review): with keepdims=False, `inputs - mean` only broadcasts as
    # intended when reduce_over_dims are the leading axes -- confirm callers.
    sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * input_mask,
                           reduce_over_dims,
                           keepdims=keepdims)
    if cumulative_axis is not None:
        sum_vv = tf.math.cumsum(sum_vv, axis=cumulative_axis)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        sum_vv = tf.tpu.cross_replica_sum(sum_vv)

    variance = py_utils.with_dependencies([
        py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
    ], sum_vv / count_v)
    return mean, variance
示例#4
0
 def _Slice(tensor):
   """Returns the slice of `tensor` at time index `state0.t`.

   The time axis is `self.params.axis`; it is squeezed out of the result.
   """
   time_axis = self.params.axis
   dims = py_utils.GetShape(tensor)
   # Start offsets: zero on every axis except `state0.t` on the time axis,
   # e.g. [0, t, 0, 0, ...] when the time axis is 1.
   start = tf.one_hot(time_axis, tf.rank(tensor), on_value=state0.t)
   # Extent: the full shape, except a single step on the time axis,
   # e.g. [dims[0], 1, dims[2], dims[3], ...] when the time axis is 1.
   before_time = dims[0:time_axis]
   after_time = dims[time_axis + 1:]
   extent = tf.concat(
       [before_time, tf.constant([1], dtype=tf.int32), after_time], axis=0)
   one_step = tf.slice(tensor, start, extent)
   # Drop the now length-1 time axis.
   return tf.squeeze(one_step, axis=time_axis)
示例#5
0
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          input_batch: A `.NestedMap` with fields:

            - ids: The inputs tensor. It is expected to be of shape [batch, time].
            - paddings: The paddings tensor. Expected shape [batch, time].
            - task_ids: If p.task_emb is provided, must contain per-token task
              ids of shape [batch, time].
            - segment_ids, segment_pos: Read only when p.packed_input is set.

        Returns:
          A NestedMap containing

          - encoded: The encoded features, either a tensor of shape
            [time, batch, depth], or a list of tensors if is_transparent is set
            in transformer_stack.
          - padding: of shape [time, batch]
          - segment_id: [time, batch] if packed inputs are supported by the
            model (and all layers), or None otherwise.
          - embedded_inputs: [time, batch, depth] embedded inputs tokens without
            positional encodings.
        """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            # Per-token task ids; only read when task embeddings are enabled.
            task_ids = input_batch.task_ids if p.task_emb else None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                # max_seq_length is the largest number of valid steps in any
                # sequence. The assertion checks that every later step is
                # padding, so the truncation below only drops padding.
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :
                                                             max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :
                                                              max_seq_length]
                if p.task_emb:
                    # Task ids are per token, so they are truncated together
                    # with the ids; otherwise their embeddings would not match
                    # the time dimension of the truncated token embeddings.
                    task_ids = task_ids[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(
                    theta.token_emb, tf.reshape(input_ids, [-1]))
            else:
                input_embs = self.softmax.EmbLookup(
                    theta.softmax, tf.reshape(input_ids, [-1]))

            # Restore the [batch, time, dim] layout of the flat lookup.
            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs
            if p.task_emb:
                input_embs += self.task_emb.EmbLookup(theta.task_emb, task_ids)

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        if not self.do_eval and p.apply_source_mask:
            # Augment padding for masked source word positions.
            dtype = paddings.dtype
            source_mask = tf.where(tf.equal(input_ids, p.source_mask_id),
                                   tf.ones_like(input_ids, dtype=dtype),
                                   tf.zeros_like(input_ids, dtype=dtype))
            # source_mask follows input_ids ([batch, time]) and is transposed
            # to match paddings; clipping keeps padding between 0 and 1.
            paddings = tf.clip_by_value(paddings + tf.transpose(source_mask),
                                        0.0, 1.0)

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
示例#6
0
    def ComputeAndUpdateMoments(self, theta, inputs, paddings=None, **kwargs):
        """Returns normalization moments, updating the moving stats in training.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers.
          inputs: The inputs tensor.  Shaped [..., dim].
          paddings: The paddings tensor.  Shaped [..., 1], with the same rank as
            the input tensor. When None, `self._GetDefaultPaddings(inputs)` is
            used.
          **kwargs: Additional inputs, forwarded to `self._GetBetaGamma`.

        Returns:
          Tuple of (mean, variance, beta, gamma). In eval mode or when
          `p.freeze_bn_stats` is set, mean and variance are the moving
          statistics. Otherwise batch moments are computed over all but the
          last axis, passed to `py_utils.UpdateBatchNormVars`, and either the
          moving statistics (when `p.use_moving_avg_in_training`) or the batch
          moments are returned.
        """
        p = self.params
        paddings = (self._GetDefaultPaddings(inputs)
                    if paddings is None else paddings)
        # The paddings must have a trailing axis of size 1.
        padding_check = py_utils.assert_shape_match(
            [tf.shape(paddings)[-1]], [1])
        inputs = py_utils.with_dependencies([padding_check], inputs)

        def _Histogram(tag, value):
            # Records a histogram summary of `value`, cast to float32.
            summary_utils.histogram(tag, tf.cast(value, tf.float32))

        with tf.name_scope(p.name):
            moving_mean = self.vars.moving_mean
            moving_variance = self.vars.moving_variance
            if self.do_eval or p.freeze_bn_stats:
                # Normalize with the stored moving statistics.
                norm_mean, norm_variance = moving_mean, moving_variance
            else:
                # Batch moments over every axis except the last one.
                reduce_over_dims = tf.range(0, tf.rank(paddings) - 1)
                mean, variance = ComputeMoments(
                    inputs, paddings, reduce_over_dims, None,
                    p.enable_cross_replica_sum_on_tpu)

                py_utils.UpdateBatchNormVars(moving_mean, mean, self._decay)
                py_utils.UpdateBatchNormVars(moving_variance, variance,
                                             self._decay)
                # Summaries for visualization.
                _Histogram('%s_mean' % p.name, mean)
                _Histogram('%s_variance' % p.name, variance)
                _Histogram('%s_moving_mean' % p.name, moving_mean)
                _Histogram('%s_moving_variance' % p.name, moving_variance)
                _Histogram(
                    '%s_mean_diff' % p.name,
                    tf.cast(mean, moving_mean.dtype.base_dtype) - moving_mean)
                _Histogram(
                    '%s_variance_diff' % p.name,
                    tf.cast(variance, moving_variance.dtype.base_dtype) -
                    moving_variance)
                if p.use_moving_avg_in_training:
                    # Global statistics. The control dependencies on mean and
                    # variance make sure the moving statistics are updated on
                    # every training step.
                    norm_mean = py_utils.with_dependencies([mean], moving_mean)
                    norm_variance = py_utils.with_dependencies(
                        [variance], moving_variance)
                else:
                    # Batch statistics.
                    norm_mean, norm_variance = mean, variance

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            beta, gamma = self._GetBetaGamma(theta, inputs, **kwargs)
            return norm_mean, norm_variance, beta, gamma
示例#7
0
  def FProp(self, theta, input_batch, interpolation_batch=None, lambdas=None):
    # pyformat: disable
    """Interpolates source ids in input_batch and interpolation_batch.

    Refer to Eq. (4) in paper https://arxiv.org/abs/2106.04060.
    It is a standard Transformer Encoder if interpolation_batch is None.

    Args:
      theta: A `.NestedMap` object containing weights values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
        - embs: Optional. When interpolation_batch is given and this field is
          present and not None, it replaces the looked-up embeddings of ids.
      interpolation_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
        - embs: Optional embeddings of ids; when present and not None they
          replace the looked-up embeddings of ids.
      lambdas: A pair of tensors to combine embeddings of ids in input_batch and
        interpolation_batch. A trailing axis of size 1 is appended to each
        before it scales the embeddings. Required when interpolation_batch is
        given.

    Returns:
      A NestedMap of

        - encoded: The encoded features, either a tensor of shape
          [time, batch, depth], or a list of tensors if is_transparent is set in
          transformer_stack.
        - padding: of shape [time, batch]
        - segment_id: [time, batch] if packed inputs are supported by the model
          (and all layers), or None otherwise.
        - embedded_inputs: [batch, time, depth] embedded input tokens (the
          interpolated embeddings when interpolation_batch is given), without
          task or positional embeddings.
    """
    # pyformat: enable

    p = self.params
    with tf.name_scope(p.name):
      src_segment_id = None
      src_segment_pos = None
      # Per-token task ids of input_batch; only read when task embeddings are
      # enabled.
      task_ids = input_batch.task_ids if p.task_emb else None
      input_ids = py_utils.with_dependencies([
          py_utils.assert_shape_match(
              tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
          py_utils.assert_equal(tf.rank(input_batch.ids), 2)
      ], input_batch.ids)

      max_seq_length = None
      if (not py_utils.use_tpu() and
          FLAGS.transformer_encoder_truncates_inputs):
        # max_seq_length is the largest number of valid steps in any sequence.
        # The assertion checks that every later step is padding, so the
        # truncation below only drops padding.
        max_seq_length = tf.cast(
            tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
            tf.int32)
        paddings = py_utils.with_dependencies([
            py_utils.assert_equal(
                tf.constant(True, tf.bool),
                tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
        ], input_batch.paddings)
        input_ids = input_ids[:, :max_seq_length]
        paddings = paddings[:, :max_seq_length]
        if p.packed_input:
          src_segment_id = input_batch.segment_ids[:, :max_seq_length]
          src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
        if p.task_emb:
          # Task ids are per token and must stay aligned with the truncated ids.
          task_ids = task_ids[:, :max_seq_length]
        # NOTE(review): interpolation_batch is not truncated; mixing it with the
        # truncated input_batch assumes its time dimension already matches.
      else:
        paddings = input_batch.paddings
        if p.packed_input:
          src_segment_id = input_batch.segment_ids
          src_segment_pos = input_batch.segment_pos

      max_time = tf.shape(input_ids)[1]
      embedding_dim = p.token_emb.embedding_dim

      def _EmbLookup(ids):
        # Looks up embeddings of the flattened `ids` (from the softmax weights
        # when p.shared_emb is set); callers restore the [batch, time] layout.
        flat_ids = tf.reshape(ids, [-1])
        if p.shared_emb:
          return self.softmax.EmbLookup(theta.softmax, flat_ids)
        return self.token_emb.EmbLookup(theta.token_emb, flat_ids)

      # Input token embeddings + positional embeddings
      input_embs = _EmbLookup(input_ids)

      if interpolation_batch is not None:
        other_input_ids = interpolation_batch.ids
        other_input_embs = _EmbLookup(other_input_ids)
        lambdas = [tf.expand_dims(a, -1) for a in lambdas]
        # Either batch may supply precomputed embeddings. Each side is handled
        # independently, so a side without `embs` always has its flat lookup
        # reshaped to [batch, time, dim] before mixing.
        if 'embs' in input_batch and input_batch.embs is not None:
          input_embs = input_batch.embs
        else:
          input_embs = tf.reshape(
              input_embs, [-1, tf.shape(input_ids)[1], embedding_dim])
        if 'embs' in interpolation_batch and interpolation_batch.embs is not None:
          other_input_embs = interpolation_batch.embs
        else:
          other_input_embs = tf.reshape(
              other_input_embs,
              [-1, tf.shape(other_input_ids)[1], embedding_dim])
        input_embs = lambdas[0] * input_embs + lambdas[1] * other_input_embs
        # For 0/1 paddings, a mixed position is padding only when it is
        # padding in both batches.
        paddings = paddings + interpolation_batch.paddings - 1.0
        paddings = tf.clip_by_value(paddings, 0.0, 1.0)

      input_embs = tf.reshape(input_embs, [-1, max_time, embedding_dim])

      # [batch, time, dim], captured before task and position embeddings.
      orig_input_embs = input_embs
      if p.task_emb:
        if interpolation_batch is None:
          input_embs += self.task_emb.EmbLookup(theta.task_emb, task_ids)
        else:
          task_embs = self.task_emb.EmbLookup(theta.task_emb, task_ids)
          other_task_embs = self.task_emb.EmbLookup(
              theta.task_emb, interpolation_batch.task_ids)
          task_embs = lambdas[0] * task_embs + lambdas[1] * other_task_embs
          input_embs += task_embs

      if p.packed_input:
        position_embs = self.position_emb.FPropWithPosition(
            theta.position_emb, src_segment_pos)
      else:
        position_embs = self.position_emb.FProp(theta.position_emb, max_time)
        position_embs = tf.reshape(position_embs,
                                   [1, max_time, embedding_dim])
      input_embs += position_embs

      if p.model_dim != p.token_emb.embedding_dim:
        input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

      paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
      if p.packed_input:
        src_segment_id = tf.transpose(src_segment_id)

      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

      # [time, batch, dim]
      transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # source_mask follows input_ids ([batch, time]) and is transposed to
      # match paddings; clipping keeps padding between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)