Exemplo n.º 1
0
    def __init__(self, params):
        """Builds the input tensors and optionally pads them to fixed lengths.

        Unpacks the ten tensors returned by `_BuildDataSource()` (ids,
        paddings, labels, weights and segment ids/positions for source and
        target) onto `self`. When `p.pad_to_max_seq_length` is set, the
        source-side tensors are padded to `p.source_max_length` and the
        target-side tensors to `p.target_max_length` with
        `py_utils.PadSequenceDimension`.

        Args:
          params: Params for this input generator.
        """
        super(MlPerfInput, self).__init__(params)
        p = self.params

        self.natural_order_model = p.natural_order_model

        (
            self._src_ids,
            self._src_paddings,
            self._tgt_ids,
            self._tgt_paddings,
            self._tgt_labels,
            self._tgt_weights,
            self._src_seg_pos,
            self._src_seg_ids,
            self._tgt_seg_pos,
            self._tgt_seg_ids,
        ), self._bucket_keys = self._BuildDataSource()

        if p.pad_to_max_seq_length:
            assert p.source_max_length
            # Target tensors are padded to target_max_length below; fail early
            # with a clear assertion rather than inside the padding op.
            assert p.target_max_length

            batch_limits = self.scaled_bucket_batch_limit
            if min(batch_limits) == max(batch_limits):
                # All buckets share one batch size, so a full [batch, length]
                # shape can be handed to PadSequenceDimension.
                source_shape = [min(batch_limits), p.source_max_length]
                target_shape = [min(batch_limits), p.target_max_length]
            else:
                source_shape = None
                target_shape = None

            def _PadSource(x, pad_value):
                return py_utils.PadSequenceDimension(
                    x, p.source_max_length, pad_value, source_shape)

            def _PadTarget(x, pad_value):
                return py_utils.PadSequenceDimension(
                    x, p.target_max_length, pad_value, target_shape)

            # Padding tensors are extended with 1, everything else with 0.
            self._src_ids = _PadSource(self._src_ids, 0)
            self._src_paddings = _PadSource(self._src_paddings, 1)
            self._tgt_ids = _PadTarget(self._tgt_ids, 0)
            self._tgt_paddings = _PadTarget(self._tgt_paddings, 1)
            self._tgt_labels = _PadTarget(self._tgt_labels, 0)
            self._tgt_weights = _PadTarget(self._tgt_weights, 0)

            self._src_seg_ids = _PadSource(self._src_seg_ids, 0)
            self._src_seg_pos = _PadSource(self._src_seg_pos, 0)
            self._tgt_seg_ids = _PadTarget(self._tgt_seg_ids, 0)
            self._tgt_seg_pos = _PadTarget(self._tgt_seg_pos, 0)

        # The batch size is taken dynamically from the leading dimension of
        # the source ids; sample ids are simply 0..batch_size-1.
        self._input_batch_size = tf.shape(self._src_ids)[0]
        self._sample_ids = tf.range(0, self._input_batch_size, 1)
Exemplo n.º 2
0
    def __init__(self, params):
        """Builds source-only inputs (frames and paddings) for this generator.

        Reads utterance ids, source frames and source paddings from
        `_BuildDataSource()`, appends a trailing channel dimension to the
        frames and, when `p.pad_to_max_seq_length` is set, pads the batch
        dimension to `InfeedBatchSize()` and the sequence dimension to
        `p.source_max_length`. The result is stored in `self._src`; the
        (possibly padded) utterance ids are stored in `self._sample_ids`.

        Args:
          params: Params for this input generator.
        """
        super().__init__(params)
        p = self.params

        (utt_ids, src_frames,
         src_paddings), self._bucket_keys = self._BuildDataSource()

        self._sample_ids = utt_ids

        src_frames, src_paddings = self._MaybePadSourceInputs(
            src_frames, src_paddings)

        # We expect src_inputs to be of shape
        # [batch_size, num_frames, feature_dim, channels].
        src_frames = tf.expand_dims(src_frames, axis=-1)

        if p.pad_to_max_seq_length:
            assert p.source_max_length
            assert p.target_max_length

            if all(x == p.bucket_batch_limit[0] for x in p.bucket_batch_limit):
                # Set the input batch size as an int rather than a tensor.
                src_frames_shape = (self.InfeedBatchSize(),
                                    p.source_max_length, p.frame_size, 1)
                src_paddings_shape = (self.InfeedBatchSize(),
                                      p.source_max_length)
            else:
                # The message needs a %s placeholder: logging formats lazily
                # with `msg % args`, so a bare trailing argument would fail to
                # render and the bucket limits would never be reported.
                tf.logging.warning(
                    'Could not set static input shape since not all bucket batch sizes '
                    'are the same: %s', p.bucket_batch_limit)
                src_frames_shape = None
                src_paddings_shape = None

            src_frames = py_utils.PadBatchDimension(src_frames,
                                                    self.InfeedBatchSize(), 0)
            src_paddings = py_utils.PadBatchDimension(src_paddings,
                                                      self.InfeedBatchSize(),
                                                      1)
            # NOTE(review): tf.Tensor does not appear to expose a `.min`
            # attribute; the pad value was presumably meant to be
            # `self._sample_ids.dtype.min` - confirm before relying on this path.
            self._sample_ids = py_utils.PadBatchDimension(
                self._sample_ids, self.InfeedBatchSize(), self._sample_ids.min)

            src_frames = py_utils.PadSequenceDimension(src_frames,
                                                       p.source_max_length,
                                                       0,
                                                       shape=src_frames_shape)
            src_paddings = py_utils.PadSequenceDimension(
                src_paddings, p.source_max_length, 1, shape=src_paddings_shape)
            self._sample_ids = tf.ensure_shape(self._sample_ids,
                                               self.InfeedBatchSize())

        src = py_utils.NestedMap(src_inputs=src_frames, paddings=src_paddings)

        self._src = src
Exemplo n.º 3
0
def _PadForLengthCompatibleStridesV2(tensor, stride, padding_algorithm,
                                     constant_values):
    """Pads tensor to make strided convolutions start in the first position.

  Tensorflow strided convolutions and Lingvo paddings are incompatible.
  Strided convolutions always end at the last index of the length dimension.
  Therefore, the output of a Lingvo padded tensor depends on the length
  dimension. Here we remove this dependency by pre-padding the tensor so that
  the first convolution starts in the first position.

  Args:
    tensor: The tensor to prepare for convolution. [batch, time, ...].
    stride: The stride in the length dimension.
    padding_algorithm: 'SAME' or 'VALID'.
    constant_values: Value to pad 0. for data tensor and 1.0 for padding tensor.

  Returns:
    A tuple (tensor, padded_length) where tensor is the potentionally padded
    tensor and padded_length is the number paddings.
  """
    if padding_algorithm == 'VALID':
        return tensor, 0

    input_length = py_utils.GetShape(tensor)[1]
    pad_len = ((input_length // stride) + 1) * stride - 1 - input_length
    if pad_len == 0:
        return tensor, 0
    tensor = py_utils.PadSequenceDimension(tensor, input_length + pad_len,
                                           constant_values)
    return tensor, pad_len
Exemplo n.º 4
0
def ConvertToBlocks(x, block_size, padding_val=0.0):
    """Turns a sequence into non-overlapping blocks.

    Args:
      x: a tensor of [batch, time, ...].
      block_size: int. Number of time frames in a block.
      padding_val: float. Value written to the frames added as padding.

    Returns:
      A tensor of [batch, num_blocks, block_size, ...], with necessary
      paddings, where output[:, i, ...] are
      x[:, i*block_size:(i+1)*block_size, ...].

    Raises:
      ValueError: if block_size is smaller than 1.
    """
    # Validate the argument before querying x so that a bad block_size fails
    # fast, without first emitting shape ops for x.
    if block_size < 1:
        raise ValueError(
            'block_size must be at least 1, got {}'.format(block_size))

    shape = py_utils.GetShape(x)
    batch, time = shape[0], shape[1]

    # Round the time dimension up to the next multiple of block_size.
    num_blocks = (time + block_size - 1) // block_size
    padded = py_utils.PadSequenceDimension(x, num_blocks * block_size,
                                           padding_val)
    return tf.reshape(padded, [batch, num_blocks, block_size] + shape[2:])
Exemplo n.º 5
0
    def _PadSequences(self):
        p = self.params
        assert p.source_max_length

        if min(self.infeed_bucket_batch_limit) == max(
                self.infeed_bucket_batch_limit):
            source_shape = [
                min(self.infeed_bucket_batch_limit), p.source_max_length
            ]
            target_shape = [
                min(self.infeed_bucket_batch_limit), p.target_max_length
            ]
        else:
            source_shape = None
            target_shape = None
        self._src_ids = py_utils.PadSequenceDimension(self._src_ids,
                                                      p.source_max_length, 0,
                                                      source_shape)
        self._src_paddings = py_utils.PadSequenceDimension(
            self._src_paddings, p.source_max_length, 1, source_shape)
        self._tgt_ids = py_utils.PadSequenceDimension(self._tgt_ids,
                                                      p.target_max_length, 0,
                                                      target_shape)
        self._tgt_paddings = py_utils.PadSequenceDimension(
            self._tgt_paddings, p.target_max_length, 1, target_shape)
        self._tgt_labels = py_utils.PadSequenceDimension(
            self._tgt_labels, p.target_max_length, 0, target_shape)
        self._tgt_weights = py_utils.PadSequenceDimension(
            self._tgt_weights, p.target_max_length, 0, target_shape)
Exemplo n.º 6
0
  def _PadSequences(self):
    p = self.params
    assert p.source_max_length

    if min(self.infeed_bucket_batch_limit) == max(
        self.infeed_bucket_batch_limit):
      source_shape = [min(self.infeed_bucket_batch_limit), p.source_max_length]
      target_shape = [min(self.infeed_bucket_batch_limit), p.target_max_length]
    else:
      source_shape = None
      target_shape = None
    self._src_ids = py_utils.PadSequenceDimension(self._src_ids,
                                                  p.source_max_length, 0,
                                                  source_shape)
    self._src_paddings = py_utils.PadSequenceDimension(self._src_paddings,
                                                       p.source_max_length, 1,
                                                       source_shape)
    self._tgt_ids = py_utils.PadSequenceDimension(self._tgt_ids,
                                                  p.target_max_length, 0,
                                                  target_shape)
    self._tgt_paddings = py_utils.PadSequenceDimension(self._tgt_paddings,
                                                       p.target_max_length, 1,
                                                       target_shape)
    self._tgt_labels = py_utils.PadSequenceDimension(self._tgt_labels,
                                                     p.target_max_length, 0,
                                                     target_shape)
    self._tgt_weights = py_utils.PadSequenceDimension(self._tgt_weights,
                                                      p.target_max_length, 0,
                                                      target_shape)

    # TODO(zhifengc): come up more meaningful training sample ids here.
    self._sample_ids = tf.range(0, self.InfeedBatchSize(), 1)
Exemplo n.º 7
0
 def testPadSequenceDimension_2D(self):
   """Pads a [3, 3] tensor to length 6 along dim 1 using zeros."""
   # pyformat: disable
   want = [[0.38615, 2.975221, -0.852826, 0., 0., 0.],
           [-0.571142, -0.432439, 0.413158, 0., 0., 0.],
           [0.255314, -0.985647, 1.461641, 0., 0., 0.]]
   # pyformat: enable
   with self.session(use_gpu=False, graph=tf.Graph()) as sess:
     inputs = tf.random_normal(shape=(3, 3), seed=123456)
     padded = py_utils.PadSequenceDimension(inputs, 6, 0)
     # The static shape must already reflect the padded length.
     self.assertEqual(padded.shape.as_list(), [3, 6])
     got = sess.run(padded)
     self.assertAllClose(want, got)
Exemplo n.º 8
0
 def testPadSequenceDimension_2D_UnknownShape(self):
   """Pads a tensor whose shape is only fed at run time."""
   # pyformat: disable
   want = [[0.38615, 2.975221, -0.852826, 0., 0., 0.],
           [-0.571142, -0.432439, 0.413158, 0., 0., 0.],
           [0.255314, -0.985647, 1.461641, 0., 0., 0.]]
   # pyformat: enable
   with self.session(use_gpu=False, graph=tf.Graph()) as sess:
     # The shape comes from a placeholder, so it is unknown while the graph
     # is being built.
     shape_ph = tf.placeholder(tf.int32)
     inputs = tf.random_normal(shape=shape_ph, seed=123456)
     padded = py_utils.PadSequenceDimension(inputs, 6, 0)
     self.assertEqual(padded.shape, None)
     got = sess.run(padded, feed_dict={shape_ph: [3, 3]})
     self.assertAllClose(want, got)
Exemplo n.º 9
0
 def testPadSequenceDimension_4D(self):
   """Pads a [2, 2, 2, 2] tensor to length 4 along dim 1 using ones."""
   # pyformat: disable
   want = [[[[0.38614973, 2.97522092], [-0.85282576, -0.57114178]],
            [[-0.43243945, 0.41315758], [0.2553139, -0.98564667]],
            [[1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.]]],
           [[[1.46164131, 0.12003655], [-0.0986772, 0.60644895]],
            [[0.03092973, -0.96897006], [-1.27853918, -0.44018385]],
            [[1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.]]]]
   # pyformat: enable
   with self.session(use_gpu=False, graph=tf.Graph()) as sess:
     inputs = tf.random_normal(shape=(2, 2, 2, 2), seed=123456)
     padded = py_utils.PadSequenceDimension(inputs, 4, 1)
     got = sess.run(padded)
     self.assertAllClose(want, got)
Exemplo n.º 10
0
  def __init__(self, params):
    """Builds the NMT input tensors and optionally pads them.

    Unpacks source/target ids and paddings plus target labels and weights
    from `_BuildDataSource()` onto `self`. When `p.pad_to_max_seq_length` is
    set, source tensors are padded to `p.source_max_length` and target
    tensors to `p.target_max_length`. `_sample_ids` is set to
    0..InfeedBatchSize()-1.

    Args:
      params: Params for this input generator.
    """
    super(NmtInput, self).__init__(params)
    p = self.params

    self.natural_order_model = p.natural_order_model

    (self._src_ids, self._src_paddings, self._tgt_ids, self._tgt_paddings,
     self._tgt_labels,
     self._tgt_weights), self._bucket_keys = self._BuildDataSource()

    if p.pad_to_max_seq_length:
      assert p.source_max_length
      # Target tensors are padded to target_max_length below; fail early with
      # a clear assertion rather than inside the padding op.
      assert p.target_max_length

      batch_limits = self.infeed_bucket_batch_limit
      if min(batch_limits) == max(batch_limits):
        # All buckets share one batch size, so a full [batch, length] shape
        # can be handed to PadSequenceDimension.
        source_shape = [min(batch_limits), p.source_max_length]
        target_shape = [min(batch_limits), p.target_max_length]
      else:
        source_shape = None
        target_shape = None

      def _PadSource(x, pad_value):
        return py_utils.PadSequenceDimension(x, p.source_max_length,
                                             pad_value, source_shape)

      def _PadTarget(x, pad_value):
        return py_utils.PadSequenceDimension(x, p.target_max_length,
                                             pad_value, target_shape)

      # Padding tensors are extended with 1, everything else with 0.
      self._src_ids = _PadSource(self._src_ids, 0)
      self._src_paddings = _PadSource(self._src_paddings, 1)
      self._tgt_ids = _PadTarget(self._tgt_ids, 0)
      self._tgt_paddings = _PadTarget(self._tgt_paddings, 1)
      self._tgt_labels = _PadTarget(self._tgt_labels, 0)
      self._tgt_weights = _PadTarget(self._tgt_weights, 0)

    # TODO(zhifengc): come up more meaningful training sample ids here.
    self._sample_ids = tf.range(0, self.InfeedBatchSize(), 1)
Exemplo n.º 11
0
 def testPadSequenceDimension_ShortPaddingLength(self):
   """Requesting a length shorter than the input must raise ValueError."""
   too_short = 6
   inputs = tf.random_normal(shape=(3, 8), seed=123456)
   with self.assertRaisesRegexp(ValueError, 'Paddings must be non-negative'):
     py_utils.PadSequenceDimension(inputs, too_short, 0)
Exemplo n.º 12
0
    def BuildInputBatch(self, batch_size, features_list, bucket_keys=None):
        """Builds an input batch.

        Args:
          batch_size: batch size to use. Not read by this method.
          features_list: Sequence of six tensors, in order: source ids,
            source paddings, target ids, target paddings, target labels and
            target weights.
          bucket_keys: Optional; stored unchanged on the result as
            `bucket_keys`. Presumably bucket_keys[i] is the bucketing key of
            the i-th sample - confirm against the caller.

        Returns:
          py_utils.NestedMap with `bucket_keys`, `src` (ids, paddings) and
          `tgt` (ids, labels, weights, paddings). If `fprop_dtype` is set and
          differs from `dtype`, floating-point tensors are cast to
          `fprop_dtype`.
        """
        p = self.params

        ret = py_utils.NestedMap()
        ret.bucket_keys = bucket_keys

        (src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels,
         tgt_weights) = features_list
        if p.pad_to_max_seq_length:
            assert p.source_max_length
            # Target tensors are padded to target_max_length below; fail early
            # with a clear assertion rather than inside the padding op.
            assert p.target_max_length

            batch_limits = self.infeed_bucket_batch_limit
            if min(batch_limits) == max(batch_limits):
                # All buckets share one batch size, so a full [batch, length]
                # shape can be handed to PadSequenceDimension.
                source_shape = [min(batch_limits), p.source_max_length]
                target_shape = [min(batch_limits), p.target_max_length]
            else:
                source_shape = None
                target_shape = None

            def _PadSource(x, pad_value):
                return py_utils.PadSequenceDimension(
                    x, p.source_max_length, pad_value, source_shape)

            def _PadTarget(x, pad_value):
                return py_utils.PadSequenceDimension(
                    x, p.target_max_length, pad_value, target_shape)

            # Padding tensors are extended with 1, everything else with 0.
            src_ids = _PadSource(src_ids, 0)
            src_paddings = _PadSource(src_paddings, 1)
            tgt_ids = _PadTarget(tgt_ids, 0)
            tgt_paddings = _PadTarget(tgt_paddings, 1)
            tgt_labels = _PadTarget(tgt_labels, 0)
            tgt_weights = _PadTarget(tgt_weights, 0)

        ret.src = py_utils.NestedMap()
        ret.src.ids = tf.cast(src_ids, dtype=tf.int32)
        ret.src.paddings = src_paddings

        ret.tgt = py_utils.NestedMap()
        ret.tgt.ids = tgt_ids
        ret.tgt.labels = tf.cast(tgt_labels, dtype=tf.int32)
        ret.tgt.weights = tgt_weights
        ret.tgt.paddings = tgt_paddings

        # No cast is needed when no separate fprop dtype is requested.
        if p.fprop_dtype is None or p.dtype == p.fprop_dtype:
            return ret

        def _Cast(v):
            # Only floating-point tensors are cast; integer tensors such as
            # ids and labels pass through untouched.
            if not v.dtype.is_floating:
                return v
            return tf.cast(v, p.fprop_dtype)

        return ret.Transform(_Cast)
Exemplo n.º 13
0
  def __init__(self, params):
    """Builds the ASR source and target input tensors.

    Reads utterance ids, target ids/labels/paddings and source
    frames/paddings from `_BuildDataSource()`. Source frames get a trailing
    channel dimension and the target tensors have their dim 1 squeezed out.
    When `p.pad_to_max_seq_length` is set, source tensors are padded to
    `p.source_max_length` and target tensors to `p.target_max_length`.
    Results are stored in `self._src` and `self._tgt`.

    Args:
      params: Params for this input generator.
    """
    super(AsrInput, self).__init__(params)
    p = self.params

    (utt_ids, tgt_ids, tgt_labels, tgt_paddings, src_frames,
     src_paddings), self._bucket_keys = self._BuildDataSource()

    self._input_batch_size = tf.shape(utt_ids)[0]
    self._sample_ids = utt_ids

    src_frames, src_paddings = self._MaybePadSourceInputs(
        src_frames, src_paddings)

    # We expect src_inputs to be of shape
    # [batch_size, num_frames, feature_dim, channels].
    # `axis` replaces the deprecated `dim` keyword of tf.expand_dims.
    src_frames = tf.expand_dims(src_frames, axis=-1)

    # Convert target ids, labels, paddings, and weights from shape [batch_size,
    # 1, num_frames] to [batch_size, num_frames]
    tgt_ids = tf.squeeze(tgt_ids, axis=1)
    tgt_labels = tf.squeeze(tgt_labels, axis=1)
    tgt_paddings = tf.squeeze(tgt_paddings, axis=1)

    if p.pad_to_max_seq_length:
      assert p.source_max_length
      assert p.target_max_length

      if all(x == p.bucket_batch_limit[0] for x in p.bucket_batch_limit):
        # Set the input batch size as an int rather than a tensor.
        self._input_batch_size = self.scaled_bucket_batch_limit[0]
        src_frames_shape = (self._input_batch_size, p.source_max_length,
                            p.frame_size, 1)
        src_paddings_shape = (self._input_batch_size, p.source_max_length)
        tgt_shape = (self._input_batch_size, p.target_max_length)
      else:
        # The message needs a %s placeholder: logging formats lazily with
        # `msg % args`, so a bare trailing argument would fail to render and
        # the bucket limits would never be reported.
        tf.logging.warning(
            'Could not set static input shape since not all bucket batch sizes '
            'are the same: %s', p.bucket_batch_limit)
        src_frames_shape = None
        src_paddings_shape = None
        tgt_shape = None

      # Padding tensors are extended with 1, everything else with 0.
      src_frames = py_utils.PadSequenceDimension(
          src_frames, p.source_max_length, 0, shape=src_frames_shape)
      src_paddings = py_utils.PadSequenceDimension(
          src_paddings, p.source_max_length, 1, shape=src_paddings_shape)
      tgt_ids = py_utils.PadSequenceDimension(
          tgt_ids, p.target_max_length, 0, shape=tgt_shape)
      tgt_labels = py_utils.PadSequenceDimension(
          tgt_labels, p.target_max_length, 0, shape=tgt_shape)
      tgt_paddings = py_utils.PadSequenceDimension(
          tgt_paddings, p.target_max_length, 1, shape=tgt_shape)

    # Weights are the complement of the paddings.
    tgt = py_utils.NestedMap(
        ids=tgt_ids,
        labels=tgt_labels,
        paddings=tgt_paddings,
        weights=1.0 - tgt_paddings)
    src = py_utils.NestedMap(src_inputs=src_frames, paddings=src_paddings)

    self._tgt = tgt
    self._src = src
Exemplo n.º 14
0
    def BuildInputBatch(self, batch_size, features_list, bucket_keys=None):
        """Builds an input batch.

        Args:
          batch_size: batch size to use. Not read by this method.
          features_list: Sequence of six tensors, in order: utterance ids,
            target ids, target labels, target paddings, source frames and
            source paddings.
          bucket_keys: Optional; stored unchanged on the result as
            `bucket_keys`. Presumably bucket_keys[i] is the bucketing key of
            the i-th sample - confirm against the caller.

        Returns:
          py_utils.NestedMap with `bucket_keys`, `src` (src_inputs, paddings)
          and `tgt` (ids, labels, paddings, weights), plus `sample_ids` when
          not running on TPU.
        """
        p = self.params

        batch = py_utils.NestedMap()
        batch.bucket_keys = bucket_keys

        (utt_ids, tgt_ids, tgt_labels, tgt_paddings, src_frames,
         src_paddings) = features_list

        # Utterance ids are only attached to the batch off-TPU.
        if not py_utils.use_tpu():
            batch.sample_ids = utt_ids

        src_frames, src_paddings = self._MaybePadSourceInputs(
            src_frames, src_paddings)

        # We expect src_inputs to be of shape
        # [batch_size, num_frames, feature_dim, channels].
        src_frames = tf.expand_dims(src_frames, axis=-1)

        # Convert target ids, labels, paddings, and weights from shape [batch_size,
        # 1, num_frames] to [batch_size, num_frames]
        tgt_ids = tf.squeeze(tgt_ids, axis=1)
        tgt_labels = tf.squeeze(tgt_labels, axis=1)
        tgt_paddings = tf.squeeze(tgt_paddings, axis=1)

        if p.pad_to_max_seq_length:
            assert p.source_max_length
            assert p.target_max_length

            if all(x == p.bucket_batch_limit[0] for x in p.bucket_batch_limit):
                # Set the input batch size as an int rather than a tensor.
                src_frames_shape = (self.InfeedBatchSize(),
                                    p.source_max_length, p.frame_size, 1)
                src_paddings_shape = (self.InfeedBatchSize(),
                                      p.source_max_length)
                tgt_shape = (self.InfeedBatchSize(), p.target_max_length)
            else:
                # The message needs a %s placeholder: logging formats lazily
                # with `msg % args`, so a bare trailing argument would fail to
                # render and the bucket limits would never be reported.
                tf.logging.warning(
                    'Could not set static input shape since not all bucket batch sizes '
                    'are the same: %s', p.bucket_batch_limit)
                src_frames_shape = None
                src_paddings_shape = None
                tgt_shape = None

            # Padding tensors are extended with 1, everything else with 0.
            src_frames = py_utils.PadSequenceDimension(src_frames,
                                                       p.source_max_length,
                                                       0,
                                                       shape=src_frames_shape)
            src_paddings = py_utils.PadSequenceDimension(
                src_paddings, p.source_max_length, 1, shape=src_paddings_shape)
            tgt_ids = py_utils.PadSequenceDimension(tgt_ids,
                                                    p.target_max_length,
                                                    0,
                                                    shape=tgt_shape)
            tgt_labels = py_utils.PadSequenceDimension(tgt_labels,
                                                       p.target_max_length,
                                                       0,
                                                       shape=tgt_shape)
            tgt_paddings = py_utils.PadSequenceDimension(tgt_paddings,
                                                         p.target_max_length,
                                                         1,
                                                         shape=tgt_shape)

        batch.src = py_utils.NestedMap(src_inputs=src_frames,
                                       paddings=src_paddings)
        # Weights are the complement of the paddings.
        batch.tgt = py_utils.NestedMap(ids=tgt_ids,
                                       labels=tgt_labels,
                                       paddings=tgt_paddings,
                                       weights=1.0 - tgt_paddings)

        return batch
Exemplo n.º 15
0
    def __init__(self, params):
        """Builds ASR inputs together with per-utterance document metadata.

        Reads utterance ids, audio document ids, the number of utterances per
        audio document, target ids/labels/paddings and source frames/paddings
        from `_BuildDataSource()`. When `p.pad_to_max_seq_length` is set,
        every tensor is first padded along the batch dimension to
        `InfeedBatchSize()` and then source/target tensors are padded along
        the sequence dimension to `p.source_max_length` /
        `p.target_max_length`. Results are stored in `self._src`,
        `self._tgt`, `self._sample_ids`, `self._audio_document_ids` and
        `self._num_utterances_in_audio_document`.

        Args:
          params: Params for this input generator.
        """
        super().__init__(params)
        p = self.params

        (utt_ids, audio_document_ids, num_utterances_in_audio_document,
         tgt_ids, tgt_labels, tgt_paddings, src_frames,
         src_paddings), self._bucket_keys = self._BuildDataSource()

        self._sample_ids = utt_ids

        src_frames, src_paddings = self._MaybePadSourceInputs(
            src_frames, src_paddings)

        # We expect src_inputs to be of shape
        # [batch_size, num_frames, feature_dim, channels].
        src_frames = tf.expand_dims(src_frames, axis=-1)

        # Convert target ids, labels, paddings, and weights from shape [batch_size,
        # 1, num_frames] to [batch_size, num_frames]
        tgt_ids = tf.squeeze(tgt_ids, axis=1)
        tgt_labels = tf.squeeze(tgt_labels, axis=1)
        tgt_paddings = tf.squeeze(tgt_paddings, axis=1)

        if p.pad_to_max_seq_length:
            assert p.source_max_length
            assert p.target_max_length

            if all(x == p.bucket_batch_limit[0] for x in p.bucket_batch_limit):
                # Set the input batch size as an int rather than a tensor.
                src_frames_shape = (self.InfeedBatchSize(),
                                    p.source_max_length, p.frame_size, 1)
                src_paddings_shape = (self.InfeedBatchSize(),
                                      p.source_max_length)
                tgt_shape = (self.InfeedBatchSize(), p.target_max_length)
            else:
                # The message needs a %s placeholder: logging formats lazily
                # with `msg % args`, so a bare trailing argument would fail to
                # render and the bucket limits would never be reported.
                tf.logging.warning(
                    'Could not set static input shape since not all bucket batch sizes '
                    'are the same: %s', p.bucket_batch_limit)
                src_frames_shape = None
                src_paddings_shape = None
                tgt_shape = None

            def _PadIdsToBatch(ids):
                """Pads an id tensor along the batch dim with PAD_INDEX."""
                ids = py_utils.PadBatchDimension(ids, self.InfeedBatchSize(),
                                                 type(self).PAD_INDEX)
                # NOTE: per the original author, the result of the padding
                # above has shape [BatchSize, 1] rather than [BatchSize] (the
                # reason was not established), hence the squeeze.
                ids = tf.squeeze(ids, axis=1)
                return tf.ensure_shape(ids, self.InfeedBatchSize())

            # Batch-dimension padding: paddings are extended with 1, data
            # tensors with 0, id-like tensors with PAD_INDEX.
            src_frames = py_utils.PadBatchDimension(src_frames,
                                                    self.InfeedBatchSize(), 0)
            src_paddings = py_utils.PadBatchDimension(src_paddings,
                                                      self.InfeedBatchSize(),
                                                      1)
            tgt_ids = py_utils.PadBatchDimension(tgt_ids,
                                                 self.InfeedBatchSize(), 0)
            tgt_labels = py_utils.PadBatchDimension(tgt_labels,
                                                    self.InfeedBatchSize(), 0)
            tgt_paddings = py_utils.PadBatchDimension(tgt_paddings,
                                                      self.InfeedBatchSize(),
                                                      1)
            self._sample_ids = _PadIdsToBatch(self._sample_ids)
            audio_document_ids = _PadIdsToBatch(audio_document_ids)
            num_utterances_in_audio_document = _PadIdsToBatch(
                num_utterances_in_audio_document)

            # Sequence-dimension padding.
            src_frames = py_utils.PadSequenceDimension(src_frames,
                                                       p.source_max_length,
                                                       0,
                                                       shape=src_frames_shape)
            src_paddings = py_utils.PadSequenceDimension(
                src_paddings, p.source_max_length, 1, shape=src_paddings_shape)
            tgt_ids = py_utils.PadSequenceDimension(tgt_ids,
                                                    p.target_max_length,
                                                    0,
                                                    shape=tgt_shape)
            tgt_labels = py_utils.PadSequenceDimension(tgt_labels,
                                                       p.target_max_length,
                                                       0,
                                                       shape=tgt_shape)
            tgt_paddings = py_utils.PadSequenceDimension(tgt_paddings,
                                                         p.target_max_length,
                                                         1,
                                                         shape=tgt_shape)

        # Weights are the complement of the paddings.
        tgt = py_utils.NestedMap(ids=tgt_ids,
                                 labels=tgt_labels,
                                 paddings=tgt_paddings,
                                 weights=1.0 - tgt_paddings)
        src = py_utils.NestedMap(src_inputs=src_frames, paddings=src_paddings)

        self._tgt = tgt
        self._src = src

        self._audio_document_ids = audio_document_ids
        self._num_utterances_in_audio_document = num_utterances_in_audio_document