Beispiel #1
0
  def _Pack(self, batch):
    """Adds segment information to `batch`, packing it when enabled.

    Note that packing may change the batch size.

    Per-example lengths are obtained by summing `.ids_indicator` along axis 1
    and are reported as histogram summaries. When `params.packing_factor` is
    not set, `.segment_ids` is set to `.ids_indicator` and `.segment_pos` is
    derived from it via `_GetSegmentPos`, with no packing performed. Otherwise
    the layout returned by `ops.pack_sequences` is applied to every field of
    `batch.src` and `batch.tgt` through `ops.apply_packing`.

    Args:
      batch: a `.NestedMap` of input tensors with `src` and `tgt` fields. It is
        modified in place: both fields gain `.segment_ids` and `.segment_pos`.
    """
    params = self.params
    src_lens = tf.math.reduce_sum(
        tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
    tgt_lens = tf.math.reduce_sum(
        tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
    summary_utils.histogram('source_seq_lengths', src_lens)
    summary_utils.histogram('target_seq_lengths', tgt_lens)

    if not params.packing_factor:
      # No packing requested: the indicator itself serves as the segment id.
      for side in (batch.src, batch.tgt):
        side.segment_ids = side.ids_indicator
        side.segment_pos = _GetSegmentPos(side.ids_indicator)
      return

    (src_seg_ids, src_seg_pos, src_input_idx, tgt_seg_ids, tgt_seg_pos,
     tgt_input_idx) = ops.pack_sequences(src_lens, tgt_lens,
                                         self._ScaledBatchSize(),
                                         params.source_max_length,
                                         params.target_max_length)

    # Report the lengths of the distinct input indices referenced by the
    # packed layout.
    uniq_src_idx = tf.unique(tf.reshape(src_input_idx, [-1])).y
    uniq_tgt_idx = tf.unique(tf.reshape(tgt_input_idx, [-1])).y
    summary_utils.histogram('packed_source_seq_lengths',
                            tf.gather(src_lens, uniq_src_idx, axis=0))
    summary_utils.histogram('packed_target_seq_lengths',
                            tf.gather(tgt_lens, uniq_tgt_idx, axis=0))

    # Adding .paddings was deferred; only its complement .ids_indicator is
    # present, which lets every field be packed with a padding value of 0.
    def _MakePacker(segment_ids, indices_in_input):
      """Returns a per-tensor packing function bound to one packed layout."""

      def _PackTensor(x):
        # String tensors are padded with a tab, everything else with 0.
        pad_value = '\t' if x.dtype == tf.string else 0
        return ops.apply_packing(x, pad_value, segment_ids, indices_in_input)

      return _PackTensor

    batch.src = batch.src.Transform(_MakePacker(src_seg_ids, src_input_idx))
    batch.src.segment_ids = tf.cast(src_seg_ids, tf.float32)
    batch.src.segment_pos = src_seg_pos

    batch.tgt = batch.tgt.Transform(_MakePacker(tgt_seg_ids, tgt_input_idx))
    batch.tgt.segment_ids = tf.cast(tgt_seg_ids, tf.float32)
    batch.tgt.segment_pos = tgt_seg_pos
Beispiel #2
0
    def ComputeAndUpdateMoments(self, theta, inputs, paddings=None, **kwargs):
        """Selects the normalization moments and refreshes the moving stats.

        In eval mode, or when `p.freeze_bn_stats` is set, the stored moving
        mean and variance are used directly. Otherwise the moments of
        `inputs` are computed with `ComputeMomentsWithPadding`, handed to
        `py_utils.UpdateBatchNormVars` (with `self._decay`) for both moving
        variables, and reported as histogram summaries.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          inputs: The inputs tensor.  Shaped [..., dim].
          paddings: The paddings tensor.  Shaped [..., 1], with the same rank
            as the input tensor. When None, `self._GetDefaultPaddings(inputs)`
            is used.
          **kwargs: Additional inputs, forwarded to `self._GetBetaGamma`.

        Returns:
          Tuple of (mean, variance, beta, gamma).
        """
        params = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        # The last dimension of paddings must be 1.
        paddings_check = py_utils.assert_shape_match(
            [tf.shape(paddings)[-1]], [1])
        inputs = py_utils.with_dependencies([paddings_check], inputs)

        def _Histogram(tag_format, value):
            """Emits `value` as a float32 histogram tagged with the layer name."""
            summary_utils.histogram(tag_format % params.name,
                                    tf.cast(value, tf.float32))

        with tf.name_scope(params.name):
            use_stored_stats = self.do_eval or params.freeze_bn_stats
            if use_stored_stats:
                # The mean and variance used for normalization.
                norm_mean = self.vars.moving_mean
                norm_variance = self.vars.moving_variance
            else:
                # Reduce over every dimension except the last one.
                reduce_over_dims = tf.range(0, tf.rank(paddings) - 1)
                mean, variance = ComputeMomentsWithPadding(
                    inputs, paddings, reduce_over_dims, None,
                    params.enable_cross_replica_sum_on_tpu)

                moving_mean = self.vars.moving_mean
                moving_variance = self.vars.moving_variance
                py_utils.UpdateBatchNormVars(moving_mean, mean, self._decay)
                py_utils.UpdateBatchNormVars(moving_variance, variance,
                                             self._decay)

                # Summaries for visualization.
                _Histogram('%s_mean', mean)
                _Histogram('%s_variance', variance)
                _Histogram('%s_moving_mean', moving_mean)
                _Histogram('%s_moving_variance', moving_variance)
                _Histogram(
                    '%s_mean_diff',
                    tf.cast(mean, moving_mean.dtype.base_dtype) - moving_mean)
                _Histogram(
                    '%s_variance_diff',
                    tf.cast(variance, moving_variance.dtype.base_dtype) -
                    moving_variance)

                if params.use_moving_avg_in_training:
                    # Normalize with the global statistics. The control
                    # dependencies on mean and variance make sure moving_mean
                    # and variance will be updated for every training step.
                    norm_mean = py_utils.with_dependencies([mean],
                                                           moving_mean)
                    norm_variance = py_utils.with_dependencies(
                        [variance], moving_variance)
                else:
                    # Normalize with the batch statistics.
                    norm_mean, norm_variance = mean, variance

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % params.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance,
                'variance of %s failed numeric check' % params.name)

            beta, gamma = self._GetBetaGamma(theta, inputs, **kwargs)
            return norm_mean, norm_variance, beta, gamma
Beispiel #3
0
    def FProp(self, theta, x, paddings=None, update=False):
        """Computes distances of the given input 'x' to all centroids.

        When `p.apply_layer_norm` is set, layer normalization is applied on
        'x' internally first, and the returned 'dists' is computed using the
        normalized 'x'. Otherwise the squared norms of 'x' and of the
        centroids are added so that 'dists' is the squared Euclidean distance.

        When `update` is True, `self.vars.means` is additionally updated from
        the inputs assigned to each centroid (through the EMA variables
        `self.vars.ema_count` / `self.vars.ema_means` when `p.use_ema` is
        set), and the returned 'dists' carries a control dependency on that
        assignment.

        Args:
          theta: A `.NestedMap` of weights' values of this layer.
          x: A tensor of shape [B, L, N, H].
          paddings: If not None, a tensor of shape [B, L]. When None, no
            position is treated as padded.
          update: bool, whether to update centroids using x.

        Returns:
          dists: "distances" of the given input 'x' to all centroids.
                 Shape [B, L, N, K].
          k_means_loss: the average squared Euclidean distances to the closest
                        centroid, a scalar.
        """
        p = self.params
        x = tf.cast(x, theta.means.dtype)
        if paddings is None:
            paddings = tf.zeros_like(x[:, :, 0, 0])
        # Shape [B, L, 1, 1]
        paddings_4d = paddings[:, :, None, None]

        if p.apply_layer_norm:
            x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

        # 'x' is normalized (but theta.means is not), we use negative dot product to
        # approximate the Euclidean distance here.
        dists = -2 * tf.einsum('BLNH, NKH -> BLNK', x, theta.means)
        if not p.apply_layer_norm:
            # If entries are not normalized, compute norms here.
            # x_norm_sq: [B, L, N, 1]; means_norm_sq: [N, K] -> [1, 1, N, K].
            x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
            means_norm_sq = tf.reduce_sum(tf.square(theta.means),
                                          axis=-1,
                                          keepdims=False)
            means_norm_sq = tf.expand_dims(means_norm_sq, axis=0)
            means_norm_sq = tf.expand_dims(means_norm_sq, axis=0)
            dists += x_norm_sq + means_norm_sq

        # For padded positions we update the distances to very large numbers.
        very_large_dists = tf.ones_like(dists) * tf.constant(
            0.1, dtype=dists.dtype) * dists.dtype.max
        paddings_tiled = tf.tile(paddings_4d,
                                 [1, 1, p.num_heads, p.num_clusters])
        dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

        # Shape [B, L, N, K], the same as 'dists' above.
        nearest_one_hot = tf.one_hot(tf.math.argmin(dists, axis=-1),
                                     p.num_clusters,
                                     dtype=theta.means.dtype)
        # Same shape as the input 'x'.
        nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                     theta.means)
        diff = tf.math.squared_difference(x,
                                          tf.stop_gradient(nearest_centroid))
        diff = py_utils.ApplyPadding(paddings_4d, diff)
        # Average over the N dimension.
        diff = tf.math.reduce_mean(diff, axis=2)

        # The commitment loss which, when back-propagated, encourages the 'x'
        # values to commit to their chosen centroids.
        diff = tf.cast(diff, tf.float32)
        paddings = tf.cast(paddings, tf.float32)
        k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 -
                                                                     paddings)
        summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

        # TODO(zhouwk): investigate normalizing theta.means after each update.
        # Per-centroid L2 norms, shape [N, K]. The norm must be taken over the
        # last (H) axis only: without `axis`, tf.norm collapses the whole
        # tensor into one scalar and the min/mean summaries below would
        # always report the same value.
        means_norm = tf.norm(theta.means, axis=-1)
        summary_utils.scalar('k_means/centroid_l2_norm/min',
                             tf.math.reduce_min(means_norm))
        summary_utils.scalar('k_means/centroid_l2_norm/mean',
                             tf.math.reduce_mean(means_norm))

        if not update:
            return dists, k_means_loss

        # To update the centroids (self.vars.means), we apply gradient descent on
        # the mini-batch of input 'x', which yields the following:
        #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
        # where x_mean is the average over all the input vectors closest to this
        # centroid.
        #
        # Note that this approach is equivalent with backprop via
        #    loss = tf.math.reduce_mean(
        #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
        # , except that here the learning rate is independently set via 'decay'.

        # Ensure that the padded positions are not used to update the centroids.
        nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

        # Sum away batch and sequence length dimensions to get per cluster count.
        # Shape: [N, K]
        per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
        summary_utils.histogram('k_means/per_cluster_vec_count',
                                per_cluster_count)

        # Sum of the input 'x' per each closest centroid. Shape: [N, K, H]
        sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

        if py_utils.use_tpu():
            # Aggregate the statistics across replicas.
            per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
            sum_x = tf.tpu.cross_replica_sum(sum_x)

        if p.use_ema:
            updated_ema_count = moving_averages.assign_moving_average(
                self.vars.ema_count,
                tf.cast(per_cluster_count, self.vars.ema_count.dtype),
                p.decay,
                zero_debias=False)
            updated_ema_means = moving_averages.assign_moving_average(
                self.vars.ema_means,
                tf.cast(sum_x, self.vars.ema_means.dtype),
                p.decay,
                zero_debias=False)
            # Smooth the counts with p.epsilon while keeping their per-head
            # total equal to n, so the division below never hits a zero count.
            n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
            updated_ema_count = ((updated_ema_count + p.epsilon) /
                                 (n + p.num_clusters * p.epsilon) * n)
            updated_ema_means = updated_ema_means / tf.expand_dims(
                updated_ema_count, axis=-1)
            updated_ema_means = tf.cast(updated_ema_means,
                                        self.vars.means.dtype)
            means = tf.cast(theta.means, updated_ema_means.dtype)
            update_means_diff = updated_ema_means - means
        else:
            # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
            # cluster's position will always be 0, hence 'sum_x' in that dimension
            # will be 0.
            new_means = sum_x / tf.maximum(
                tf.constant(1.0, dtype=per_cluster_count.dtype),
                tf.expand_dims(per_cluster_count, axis=-1))
            # Note that we intentionally do not normalize the means after this update
            # as empirically this works better.
            update_means_diff = tf.cast(
                (1.0 - p.decay) * (new_means - theta.means),
                self.vars.means.dtype)
        # Tie the centroid update to 'dists' so it runs whenever 'dists' is
        # evaluated.
        return py_utils.with_dependencies(
            [tf.assign_add(self.vars.means, update_means_diff)],
            dists), k_means_loss
Beispiel #4
0
  def ComputePredictions(self, theta, source_encs, source_paddings, targets,
                         src_segment_id):
    """Runs the decoder over `targets`, attending to the encoded source.

    The target ids are embedded, passed through the attention-coupled first
    recurrent layer (`self.frnn_with_atten`) and then through each layer in
    `self.frnn`, with residual connections from `p.residual_start` onwards.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_encs: source encoding, of shape [time, batch, depth].
      source_paddings: source encoding's padding, of shape [time, batch].
      targets: A dict of string to tensors representing the targets one try to
        predict. Each tensor in targets is of shape [batch, time].
      src_segment_id: source segment id, of shape [time, batch].

    Returns:
      A Tensor with shape [time, batch, params.softmax.input_dim].
    """
    params = self.params
    src_time, src_batch = py_utils.GetShape(source_paddings, 2)
    source_encs = py_utils.HasShape(
        source_encs, [src_time, src_batch, params.source_dim])
    with tf.name_scope(params.name):
      # Targets are [batch, time]; transpose them to match the source layout.
      target_ids = tf.transpose(targets.ids)
      target_paddings = tf.expand_dims(
          tf.transpose(py_utils.HasRank(targets.paddings, 2)), 2)
      if params.packed_input:
        target_segment_id = tf.expand_dims(
            tf.transpose(targets.segment_ids), 2)
      else:
        target_segment_id = tf.zeros_like(target_paddings)

      emb_device = (
          self.cluster.WorkerDeviceInModelSplit(0)
          if py_utils.use_tpu() else '')
      with tf.device(emb_device):
        embedded = self.emb.EmbLookup(theta.emb, target_ids)
        embedded = self.ApplyClipping(theta, embedded)
        summary_utils.histogram('input_emb', embedded)
        embedded = self.ApplyDropout(embedded)
        self._emb_out = embedded

        # Layer 0 is intertwined with attention.
        atten_ctxs, xs, atten_probs, _ = self.frnn_with_atten.FProp(
            theta.frnn_with_atten,
            source_encs,
            source_paddings,
            embedded,
            target_paddings,
            src_segment_id=src_segment_id,
            segment_id=target_segment_id)
        self._AddAttenProbsSummary(source_paddings, targets, [atten_probs])

        atten_ctxs = self.ApplyClipping(theta, atten_ctxs)
        summary_utils.histogram('atten_ctxs', atten_ctxs)

        stacked_layers = zip(self.frnn, theta.frnn)
        for layer_idx, (rnn_layer, rnn_theta) in enumerate(stacked_layers):
          # This is Layer-(layer_idx + 1); Layer-0 was handled above.
          layer_in = tf.concat([xs, atten_ctxs], 2)
          ys, _ = rnn_layer.FProp(
              rnn_theta,
              layer_in,
              target_paddings,
              segment_id=target_segment_id)
          ys = self.ApplyDropout(ys)
          if layer_idx + 1 >= params.residual_start:
            # Residual skip connection, followed by clipping.
            xs = self.ApplyClipping(theta, xs + ys)
          else:
            xs = ys
          summary_utils.histogram('layer_out_%s' % layer_idx, xs)

        if params.feed_attention_context_vec_to_softmax:
          xs = tf.concat([xs, atten_ctxs], 2)

        return xs
Beispiel #5
0
 def _AddAttenProbsHistogramSummary(self, name, atten_probs):
   """Emits a histogram summary of `atten_probs` tagged `<name>/atten_probs`."""
   tag = name + '/atten_probs'
   summary_utils.histogram(tag, atten_probs)
Beispiel #6
0
  def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
    """Selects the normalization moments and refreshes the moving stats.

    In eval mode (`p.is_eval`) the stored moving mean and variance are used
    directly. Otherwise the moments of `inputs` are computed with
    `self._Moments`, handed to `py_utils.UpdateBatchNormVars` (with
    `self._decay`) for both moving variables, and reported as histogram
    summaries.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor. When None, `self._GetDefaultPaddings(inputs)` is used.

    Returns:
      Tuple of (mean, variance, beta, gamma). beta and gamma are the constants
      0.0 and 1.0 when `p.use_moving_avg_in_training` is set, and
      `theta.beta` / `theta.gamma` otherwise.
    """
    params = self.params
    if paddings is None:
      paddings = self._GetDefaultPaddings(inputs)
    # The last dimension must be p.dim for inputs and 1 for paddings.
    shape_checks = [
        py_utils.assert_shape_match([tf.shape(inputs)[-1]], [params.dim]),
        py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1]),
    ]
    inputs = py_utils.with_dependencies(shape_checks, inputs)

    def _Histogram(tag_format, value):
      """Emits `value` as a float32 histogram tagged with the layer name."""
      summary_utils.histogram(tag_format % params.name,
                              tf.cast(value, tf.float32))

    with tf.name_scope(params.name):
      if params.is_eval:
        # The mean and variance used for normalization.
        norm_mean = self._moving_mean
        norm_variance = self._moving_variance
      else:
        mean, variance = self._Moments(
            inputs, 1.0 - paddings, params.enable_cross_replica_sum_on_tpu)

        py_utils.UpdateBatchNormVars(self._moving_mean, mean, self._decay)
        py_utils.UpdateBatchNormVars(self._moving_variance, variance,
                                     self._decay)

        # Summaries for visualization.
        _Histogram('%s_mean', mean)
        _Histogram('%s_variance', variance)
        _Histogram('%s_moving_mean', self._moving_mean)
        _Histogram('%s_moving_variance', self._moving_variance)
        _Histogram('%s_mean_diff', mean - self._moving_mean)
        _Histogram('%s_variance_diff', variance - self._moving_variance)

        if params.use_moving_avg_in_training:
          # Normalize with the global statistics. The control dependencies on
          # mean and variance make sure moving_mean and variance will be
          # updated for every training step.
          norm_mean = py_utils.with_dependencies([mean], self._moving_mean)
          norm_variance = py_utils.with_dependencies([variance],
                                                     self._moving_variance)
        else:
          # Normalize with the batch statistics.
          norm_mean, norm_variance = mean, variance

      norm_mean = py_utils.CheckNumerics(
          norm_mean, 'mean of %s failed numeric check' % params.name)
      norm_variance = py_utils.CheckNumerics(
          norm_variance, 'variance of %s failed numeric check' % params.name)

      beta, gamma = ((0.0, 1.0) if params.use_moving_avg_in_training else
                     (theta.beta, theta.gamma))
      return norm_mean, norm_variance, beta, gamma
Beispiel #7
0
    def FProp(self, theta, input_batch):
        """Embeds source ids and runs them through the transformer stack.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          input_batch: A `.NestedMap` with fields:

            - ids: The inputs tensor. It is expected to be of shape
                [batch, time].
            - paddings: The paddings tensor. Expected shape [batch, time].
            - task_ids: If p.task_emb is provided, must contain per-token task
                ids of shape [batch, time].
            - segment_ids, segment_pos: read only when p.packed_input is set.

        Returns:
          A NestedMap containing

          - encoded: The encoded features, either a tensor of shape
            [time, batch, depth], or a list of tensors if is_transparent is
            set in transformer_stack.
          - padding: of shape [time, batch]
          - segment_id: [time, batch] if packed inputs are supported by the
            model (and all layers), or None otherwise.
          - embedded_inputs: [time, batch, depth] embedded inputs tokens
            without positional encodings.
        """
        params = self.params
        with tf.name_scope(params.name):
            src_segment_id = None
            src_segment_pos = None
            # ids and paddings must have the same shape; ids must be rank 2.
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2),
            ]
            input_ids = py_utils.with_dependencies(shape_checks,
                                                   input_batch.ids)

            truncate_inputs = (
                not py_utils.use_tpu()
                and tf.flags.FLAGS.transformer_encoder_truncates_inputs)
            if truncate_inputs:
                # Longest non-padded length found in this batch.
                longest = tf.cast(
                    tf.reduce_max(
                        tf.reduce_sum(1.0 - input_batch.paddings, 1)),
                    tf.int32)
                # Everything beyond that length has to be padding.
                expect_true = tf.constant(True, tf.bool)
                tail_is_padding = tf.reduce_all(
                    input_batch.paddings[:, longest:] > 0.5)
                paddings = py_utils.with_dependencies(
                    [py_utils.assert_equal(expect_true, tail_is_padding)],
                    input_batch.paddings)
                input_ids = input_ids[:, :longest]
                paddings = paddings[:, :longest]
                if params.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :longest]
                    src_segment_pos = input_batch.segment_pos[:, :longest]
            else:
                paddings = input_batch.paddings
                if params.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]
            emb_dim = params.token_emb.embedding_dim

            # Input token embeddings + positional embeddings
            flat_ids = tf.reshape(input_ids, [-1])
            if params.shared_emb:
                input_embs = self.softmax.EmbLookup(theta.softmax, flat_ids)
            else:
                input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                      flat_ids)

            input_embs = tf.reshape(input_embs, [-1, max_time, emb_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if params.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = tf.reshape(
                    self.position_emb.FProp(theta.position_emb, max_time),
                    [1, max_time, emb_dim])
            input_embs = input_embs + position_embs
            if params.task_emb:
                input_embs = input_embs + self.task_emb.EmbLookup(
                    theta.task_emb, input_batch.task_ids)

            summary_utils.histogram('input_embs', input_embs)
            if params.model_dim != emb_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)
                summary_utils.histogram('emb_proj', input_embs)

            paddings = tf.cast(tf.transpose(paddings),
                               py_utils.FPropDtype(params))
            if params.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        if not self.do_eval and params.apply_source_mask:
            # Augment padding for masked source word positions.
            pad_dtype = paddings.dtype
            is_masked = tf.equal(input_ids, params.source_mask_id)
            source_mask = tf.where(is_masked,
                                   tf.ones_like(input_ids, dtype=pad_dtype),
                                   tf.zeros_like(input_ids, dtype=pad_dtype))
            # Make sure padding is between 0 and 1.
            paddings = tf.clip_by_value(paddings + tf.transpose(source_mask),
                                        0.0, 1.0)

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
Beispiel #8
0
  def _ApplyPacking(self, batch):
    """Adds segment information to `batch`, packing it when enabled.

    Note that packing may change the batch size.

    Per-example lengths are obtained by summing `.ids_indicator` along axis 1
    and are reported as histogram summaries. When `params.packing_factor` is
    not set, `.segment_ids` is set to `.ids_indicator` and `.segment_pos` is
    derived from it via `_GetSegmentPos`, with no packing performed. Otherwise
    the layout returned by `ops.pack_sequences` is applied to every field of
    `batch.src` and `batch.tgt` through `ops.apply_packing`, with `.paddings`
    packed separately using a padding value of 1.

    Args:
      batch: a `.NestedMap` of input tensors with `src` and `tgt` fields. It is
        modified in place: both fields gain `.segment_ids` and `.segment_pos`.
    """
    params = self.params
    src_lens = tf.math.reduce_sum(
        tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
    tgt_lens = tf.math.reduce_sum(
        tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
    summary_utils.histogram('source_seq_lengths', src_lens)
    summary_utils.histogram('target_seq_lengths', tgt_lens)

    if not params.packing_factor:
      # No packing requested: the indicator itself serves as the segment id.
      for side in (batch.src, batch.tgt):
        side.segment_ids = side.ids_indicator
        side.segment_pos = _GetSegmentPos(side.ids_indicator)
      return

    (src_seg_ids, src_seg_pos, src_input_idx, tgt_seg_ids, tgt_seg_pos,
     tgt_input_idx) = ops.pack_sequences(src_lens, tgt_lens,
                                         self._ScaledBatchSize(),
                                         params.source_max_length,
                                         params.target_max_length)

    # Report the lengths of the distinct input indices referenced by the
    # packed layout.
    uniq_src_idx = tf.unique(tf.reshape(src_input_idx, [-1])).y
    uniq_tgt_idx = tf.unique(tf.reshape(tgt_input_idx, [-1])).y
    summary_utils.histogram('packed_source_seq_lengths',
                            tf.gather(src_lens, uniq_src_idx, axis=0))
    summary_utils.histogram('packed_target_seq_lengths',
                            tf.gather(tgt_lens, uniq_tgt_idx, axis=0))

    def _ReportPackedTokenRatio(tag, seq_lens, segment_ids):
      """Emits (#positions with segment id > 0) / (#original tokens)."""
      orig_tokens = tf.cast(tf.reduce_sum(seq_lens), tf.float32)
      packed_tokens = tf.reduce_sum(tf.cast(segment_ids > 0, tf.float32))
      summary_utils.scalar(tag, packed_tokens / orig_tokens)

    # Ratio of number of non-padded tokens. A value below 1.0 means input data
    # is being dropped because p.packing_factor is too high.
    _ReportPackedTokenRatio('examples/src_packed_token_ratio', src_lens,
                            src_seg_ids)
    _ReportPackedTokenRatio('examples/tgt_packed_token_ratio', tgt_lens,
                            tgt_seg_ids)

    # Fields are packed with a padding value of 0 ('\t' for strings), which
    # suits .ids_indicator; .paddings needs 1 instead, so it is packed on its
    # own and written back after the generic transform.
    def _MakePacker(segment_ids, indices_in_input):
      """Returns a per-tensor packing function bound to one packed layout."""

      def _PackTensor(x):
        pad_value = '\t' if x.dtype == tf.string else 0
        return ops.apply_packing(x, pad_value, segment_ids, indices_in_input)

      return _PackTensor

    packed_src_paddings = ops.apply_packing(batch.src.paddings, 1,
                                            src_seg_ids, src_input_idx)
    batch.src = batch.src.Transform(_MakePacker(src_seg_ids, src_input_idx))
    batch.src.paddings = packed_src_paddings
    batch.src.segment_ids = tf.cast(src_seg_ids, tf.float32)
    batch.src.segment_pos = src_seg_pos

    packed_tgt_paddings = ops.apply_packing(batch.tgt.paddings, 1,
                                            tgt_seg_ids, tgt_input_idx)
    batch.tgt = batch.tgt.Transform(_MakePacker(tgt_seg_ids, tgt_input_idx))
    batch.tgt.paddings = packed_tgt_paddings
    batch.tgt.segment_ids = tf.cast(tgt_seg_ids, tf.float32)
    batch.tgt.segment_pos = tgt_seg_pos

    # The number of examples is indicated by the segment_ids of the target.
    # Note that this is per infeed value when p.use_per_host_infeed = True.
    segments_per_row = tf.math.reduce_max(batch.tgt.segment_ids, axis=1)
    summary_utils.scalar('examples/num_packed_examples',
                         tf.reduce_sum(segments_per_row))