Example #1
    def _testPackedInputs(self, dtype=tf.float32):
        p = self._DecoderParams()
        np.random.seed(_NUMPY_RANDOM_SEED)
        src_time = 5
        batch = 2
        emb_dims = 4
        tgt_time = 5
        src_enc = tf.constant(
            np.random.normal(size=[src_time, batch, p.source_dim]),
            dtype=dtype)
        paddings = tf.zeros([src_time, batch], dtype=dtype)
        tgt_ids = tf.constant(np.random.randint(20, size=[batch, tgt_time]),
                              dtype=tf.int32)
        tgt_labels = tf.constant(np.random.randint(20, size=[batch, tgt_time]),
                                 dtype=tf.int32)
        tgt_paddings = tf.zeros([batch, tgt_time], dtype=dtype)
        tgt_weights = 1.0 - tgt_paddings
        tgts = py_utils.NestedMap({
            'ids': tgt_ids,
            'labels': tgt_labels,
            'weights': tgt_weights,
            'paddings': tgt_paddings
        })

        src_enc_packed = tf.transpose(src_enc, [1, 0, 2])
        src_enc_packed = tf.reshape(src_enc_packed, [-1, 1, emb_dims])
        src_enc_padding_packed = tf.reshape(paddings, [-1, 1])
        target_packed = py_utils.NestedMap({
            'ids':
            tf.reshape(tgts.ids, [1, -1]),
            'labels':
            tf.reshape(tgts.labels, [1, -1]),
            'weights':
            tf.reshape(tgts.weights, [1, -1]),
            'paddings':
            tf.reshape(tgts.paddings, [1, -1])
        })
        src_segment_id = tf.transpose(
            tf.constant(np.asarray([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]),
                        dtype=tf.float32))
        target_packed.segment_ids = tf.constant(np.asarray(
            [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]),
                                                dtype=tf.float32)
        target_packed.segment_pos = tf.constant(
            np.asarray([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]))
        return (src_enc, paddings, tgts, src_enc_packed,
                src_enc_padding_packed, src_segment_id, target_packed)
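The packing above is worth seeing in isolation: two length-5 targets collapse into a single row of length 10, with segment_ids marking which original example each position came from and segment_pos restarting the position count at each segment boundary. A minimal standalone sketch (plain TensorFlow; names are illustrative, not part of the test above):

import tensorflow as tf

batch, tgt_time = 2, 5
ids = tf.reshape(tf.range(batch * tgt_time), [batch, tgt_time])
packed_ids = tf.reshape(ids, [1, -1])  # [1, 10]
# [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]
segment_ids = tf.reshape(
    tf.tile(tf.range(batch)[:, None], [1, tgt_time]), [1, -1])
# [[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
segment_pos = tf.tile(tf.range(tgt_time)[None, :], [1, batch])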
Example #2
  def testBiEncoderForwardPassWithTransparent(self):
    with self.session(use_gpu=False):
      tf.random.set_seed(8372749040)
      p = self._BiEncoderParams()
      p.is_transparent = True
      mt_enc = encoder.MTEncoderBiRNN(p)
      batch = py_utils.NestedMap()
      batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
      batch.paddings = tf.zeros([2, 4])
      enc_out = mt_enc.FPropDefaultTheta(batch).encoded

      self.evaluate(tf.global_variables_initializer())
      actual_enc_out = enc_out.eval()
      expected_enc_out = [[[-7.4536911e-05, 8.8465633e-05],
                           [2.8940600e-05, 3.2297492e-05]],
                          [[-1.9775725e-05, 9.8312848e-05],
                           [5.1837378e-05, 1.2998647e-05]],
                          [[4.5528584e-05, -6.8125606e-05],
                           [1.0955606e-04, -2.1024598e-04]],
                          [[8.5454740e-05, -1.8263397e-04],
                           [5.2042866e-05, -1.6407830e-04]]]
      self.assertAllClose(expected_enc_out, actual_enc_out)
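A note on the test inputs used throughout these encoder tests: tf.reshape(tf.range(0, 8, 1), [4, 2]) builds a [time=4, batch=2] matrix, and the transpose turns it into batch-major ids of shape [2, 4]:

import tensorflow as tf

ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
# ids == [[0, 2, 4, 6],
#         [1, 3, 5, 7]]  -> batch of 2 sequences, 4 tokens each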
Example #3
    def _ProcessMASSInput(self, source_id, src):
        """Perform MASS input processing."""
        skip_mass = self.do_eval and not self.params.enable_mass_for_eval
        if skip_mass or self.mass_layer is None:
            # At eval time, we copy src to tgt
            return self._ProcessSingleInput(source_id, src, src)

        _, labels, paddings = self.StringsToIds(tf.reshape(src, [1]),
                                                is_source=True,
                                                key=self._src_tokenizer_key)
        weights = 1 - paddings
        actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
        src_lang_ids, tgt_lang_ids = self._GetLangIds(source_id)

        mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

        features = py_utils.NestedMap()
        features.src = py_utils.NestedMap()
        features.src.ids = mass_out.src.ids
        features.src.paddings = paddings
        features.src.weights = weights
        features.src.task_ids = tf.cast(features.src.weights,
                                        dtype=tf.int32) * src_lang_ids
        features.src.source_ids = tf.cast(features.src.weights,
                                          dtype=tf.int32) * source_id
        features.src.ids_indicator = weights
        features.tgt = py_utils.NestedMap()
        features.tgt.ids = mass_out.tgt.ids
        features.tgt.labels = mass_out.tgt.labels
        features.tgt.paddings = paddings
        features.tgt.weights = mass_out.tgt.weights
        features.tgt.task_ids = tf.cast(weights, dtype=tf.int32) * tgt_lang_ids
        features.tgt.source_ids = tf.cast(weights, dtype=tf.int32) * source_id
        features.tgt.ids_indicator = weights

        if not py_utils.use_tpu():
            features.src.strs = src
            features.tgt.strs = src
        return features.Transform(tf.squeeze)
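NestedMap.Transform applies a function to every leaf tensor, so the final line drops the leading batch-of-1 dimension that StringsToIds introduced on each feature. A tiny illustration (assuming lingvo's py_utils is importable):

import tensorflow as tf
from lingvo.core import py_utils

m = py_utils.NestedMap(ids=tf.zeros([1, 7]), weights=tf.ones([1, 7]))
squeezed = m.Transform(tf.squeeze)  # each [1, 7] leaf becomes [7]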
Example #4
  def _ProcessMASSInput(self, source_id, src):
    """Perform MASS input processing."""
  # TODO(yuancao): By doing so, we assume that for now monolingual eval/dev
  # sets (xx->xx) are in double-column format (since this bypasses the Mass
  # op). Ideally we should add a dedicated eval/dev processing procedure for
  # unsupervised MT cases, so that single-column eval/dev sets are also
  # supported. This should not be handled by any specific ops like Mass, but
  # inside the TextPackedInput class.
    assert not self.do_eval, 'MASS input can only be used for training.'

    _, labels, paddings = self.StringsToIds(
        tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
    weights = 1 - paddings
    actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
    src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

    mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

    features = py_utils.NestedMap()
    features.src = py_utils.NestedMap()
    features.src.ids = mass_out.src.ids
    features.src.paddings = paddings
    features.src.weights = weights
    features.src.task_ids = tf.cast(
        features.src.weights, dtype=tf.int32) * src_lang_ids
    features.src.ids_indicator = weights
    features.tgt = py_utils.NestedMap()
    features.tgt.ids = mass_out.tgt.ids
    features.tgt.labels = mass_out.tgt.labels
    features.tgt.paddings = paddings
    features.tgt.weights = mass_out.tgt.weights
    features.tgt.task_ids = tf.ones_like(
        features.src.task_ids, dtype=tf.int32) * tgt_lang_ids
    features.tgt.ids_indicator = weights

    if not py_utils.use_tpu():
      features.src.strs = src
      features.tgt.strs = src
    return features.Transform(tf.squeeze)
Example #5
    def testIntraModalLabels(self):
        # Simulate a batch of 4 examples with 2 items each in the 'text' modality.
        batch_size = 4
        items_per_example = 2
        modality = 'text'
        modality_shape = tf.TensorShape([batch_size, items_per_example])
        inputs = label_lib.ExamplePairs.WithinBatch(
            batch=dict(some_feature=tf.range(batch_size)),
            query_modality=modality,
            result_modality=modality)

        def example_pair_labeler(_):
            return tf.constant([
                [1, 0, 0, X],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
                [X, 0, 0, 1],
            ])

        labeler = label_lib.MultiItemExampleWrapper(
            example_pair_labeler,
            modality_batch_shapes={modality: modality_shape})
        labels = labeler(inputs)
        self.assertEqual(modality_shape + modality_shape, labels.shape)
        # The pairwise labels actually have rank 4 (twice the rank of ids), but we
        # compare them in matrix form for easier inspection. There are 8 items
        # total. Each should have a positive label for every other item from the
        # same example. Self-pairs should be ignored (they are neither positive
        # nor negative pairs), as well as pairs from duplicated examples.
        self.assertAllEqual([
            [X, 1, 0, 0, 0, 0, X, X],
            [1, X, 0, 0, 0, 0, X, X],
            [0, 0, X, 1, 0, 0, 0, 0],
            [0, 0, 1, X, 0, 0, 0, 0],
            [0, 0, 0, 0, X, 1, 0, 0],
            [0, 0, 0, 0, 1, X, 0, 0],
            [X, X, 0, 0, 0, 0, X, 1],
            [X, X, 0, 0, 0, 0, 1, X],
        ], tf.reshape(labels, [8, 8]))
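The constant X above comes from the surrounding test module and marks pairs that are neither positive nor negative (self-pairs and pairs involving duplicated examples). The 8x8 item-pair matrix is, roughly, the 4x4 example-pair labels with each example expanded to its two items and self-pairs masked; a numpy sketch of that expansion, using -1 as a stand-in for the real sentinel:

import numpy as np

X = -1  # placeholder for the library's "ignore" sentinel
pair = np.array([[1, 0, 0, X],
                 [0, 1, 0, 0],
                 [0, 0, 1, 0],
                 [X, 0, 0, 1]])
item = np.repeat(np.repeat(pair, 2, axis=0), 2, axis=1)  # [8, 8]
np.fill_diagonal(item, X)  # self-pairs are ignored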
Example #6
 def Proc(record):
     """Parses a serialized tf.Example record."""
     # Features: uttid (int64) and frames (float32). The frames are stored
     # as a flattened vector and may be a raw waveform directly.
     features = [
         ('uttid', tf.io.VarLenFeature(tf.int64)),
         # Would like to change this to tf.int16 in the future, if possible.
         ('frames', tf.io.VarLenFeature(tf.float32)),
     ]
     example = tf.io.parse_single_example(record, dict(features))
     fval = {k: v.values for k, v in example.items()}
     # Reshape the flattened vector into its original time-major
     # representation.
     fval['frames'] = tf.reshape(fval['frames'],
                                 shape=[-1, self.params.frame_size])
     # Input duration determines the bucket.
     bucket_key = tf.cast(tf.shape(fval['frames'])[0], tf.int32)
     if self.params.append_eos_frame:
         bucket_key += 1
     src_paddings = tf.zeros([tf.shape(fval['frames'])[0]],
                             dtype=tf.float32)
     return [fval['uttid'], fval['frames'], src_paddings], bucket_key
Example #7
    def testForwardPassPackedInput(self):
        with self.session(use_gpu=False) as sess:
            bs = 2
            sl = 21
            d = 16
            tf.random.set_seed(8372749040)
            p = self._EncoderParams(packed_input=True)

            mt_enc = p.Instantiate()
            batch = py_utils.NestedMap()
            batch.ids = tf.constant(
                np.random.randint(low=0,
                                  high=63,
                                  size=[bs, sl],
                                  dtype=np.int32))

            # Pack these into a single batch
            packed_bs = 1
            packed_sl = 2 * sl
            batch.ids = tf.reshape(batch.ids, [packed_bs, packed_sl])

            batch.paddings = tf.zeros([packed_bs, packed_sl])
            batch.segment_pos = [
                list(range(sl)) + list(range(sl)),
            ]
            batch.segment_ids = [
                [0 for i in range(sl)] + [1 for i in range(sl)],
            ]

            out = mt_enc.FPropDefaultTheta(batch)
            enc_out_sum = tf.reduce_sum(out.encoded)

            tf.global_variables_initializer().run()
            actual_enc_out, actual_enc_out_sum = sess.run(
                [out.encoded, enc_out_sum])

            self.assertAllEqual([packed_sl, packed_bs, d],
                                actual_enc_out.shape)
            self.assertAllClose(306.010132, actual_enc_out_sum)
Example #8
    def Unflatten(self, flat_tensors):
        """The inverse of Flatten(); expands the leading dim to `batch_shape`.

    Args:
      flat_tensors: A tensor or structure of tensors to be reshaped.

    Returns:
      The reshaped tensors, with `batch_shape.rank` - 1 more dimensions, in the
      same format (tensor, list, dict) as the input.
    """
        if self._is_no_op:
            return flat_tensors
        batch_shape = self._batch_shape.as_list()
        if batch_shape[0] is None:
            batch_shape[0] = -1

        unflattened_tensors = [
            tf.reshape(flat_tensor,
                       batch_shape + flat_tensor.shape.as_list()[1:])
            for flat_tensor in tf.nest.flatten(flat_tensors)
        ]
        return tf.nest.pack_sequence_as(flat_tensors, unflattened_tensors)
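Since Flatten merges the leading batch_shape dims into one, Unflatten only needs a reshape whose leading dims are batch_shape (with an unknown batch dim mapped to -1). A self-contained round trip, assuming batch_shape = [2, 3]:

import tensorflow as tf

x = tf.zeros([2, 3, 5])                # batch_shape [2, 3], payload dim 5
flat = tf.reshape(x, [-1, 5])          # Flatten -> [6, 5]
unflat = tf.reshape(flat, [2, 3] + flat.shape.as_list()[1:])  # -> [2, 3, 5]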
Example #9
  def testForwardPass(self):
    with self.session(use_gpu=False):
      tf.set_random_seed(8372749040)
      p = self._EncoderParams()
      mt_enc = encoder.MTEncoderV1(p)
      batch = py_utils.NestedMap()
      batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
      batch.paddings = tf.zeros([2, 4])
      enc_out = mt_enc.FPropDefaultTheta(batch).encoded

      tf.global_variables_initializer().run()
      actual_enc_out = enc_out.eval()
      expected_enc_out = [
          [[-7.51581979e-07, 1.55304758e-06, -3.39117889e-07, 2.79457527e-06],
           [-1.06733505e-05, 7.56898862e-06, -4.18875834e-06, -9.10360086e-06]],
          [[1.58444971e-06, 5.11627661e-07, 1.33408967e-05, 1.81603957e-06],
           [-1.59942228e-05, 1.26068180e-05, 4.49321249e-07, -1.43790385e-05]],
          [[5.56546365e-06, -8.01007627e-06, 8.96620350e-06, 3.96485439e-06],
           [-8.77006005e-06, 4.04282991e-06, -4.79895652e-06, -5.90156833e-06]],
          [[-8.59513818e-07, -7.63760727e-06, -5.57065960e-06, 1.80756274e-06],
           [-2.96017470e-06, -1.51323195e-06, -1.03562079e-05, 1.23328198e-06]]
      ]
      self.assertAllClose(expected_enc_out, actual_enc_out)
Example #10
    def _Extract(self, features):
        """Returns the image Tensor."""
        outputs = py_utils.NestedMap()
        p = self.params
        for camera_name in p.camera_names:
            image_shape = tf.reshape(
                _Dense(features['image_%s_shape' % camera_name]), [-1])
            image_shape = tf.cast(image_shape, tf.int32)

            if p.decode_image:
                image = tf.io.decode_png(
                    tf.strings.reduce_join(
                        _Dense(features['image_%s' % camera_name],
                               default_value='')))
                image = tf.reshape(image, image_shape)
                image = py_utils.PadOrTrimTo(image, p.image_shape)

            intrinsics = tf.reshape(
                _Dense(features['camera_%s_intrinsics' % camera_name]), [9])
            extrinsics = tf.reshape(
                _Dense(features['camera_%s_extrinsics' % camera_name]), [4, 4])
            pose = tf.reshape(_Dense(features['image_%s_pose' % camera_name]),
                              [4, 4])
            velocity = tf.reshape(
                _Dense(features['image_%s_velocity' % camera_name]), [6])

            outputs[camera_name] = py_utils.NestedMap()
            if p.decode_image:
                outputs[camera_name]['image'] = tf.cast(
                    image, p.image_output_dtype)
            outputs[camera_name]['image_shape'] = image_shape
            outputs[camera_name]['intrinsics'] = intrinsics
            outputs[camera_name]['extrinsics'] = extrinsics
            outputs[camera_name]['pose'] = pose
            outputs[camera_name]['velocity'] = velocity
            outputs[camera_name]['rolling_shutter_direction'] = features[
                'camera_%s_rolling_shutter_direction' % camera_name]

            for feat in [
                    'shutter', 'camera_trigger_time',
                    'camera_readout_done_time', 'pose_timestamp'
            ]:
                outputs[camera_name][feat] = features['image_%s_%s' %
                                                      (camera_name, feat)]

        return outputs
Example #11
 def Proc(record):
     """Parses a serialized tf.Example record."""
     features = [
         ('uttid', tf.VarLenFeature(tf.string)),
         ('transcript', tf.VarLenFeature(tf.string)),
         ('frames', tf.VarLenFeature(tf.float32)),
     ]
     example = tf.parse_single_example(record, dict(features))
     fval = {k: v.values for k, v in six.iteritems(example)}
     # Reshape the flattened vector into its original time-major
     # representation.
     fval['frames'] = tf.reshape(fval['frames'],
                                 shape=[-1, self.params.frame_size])
     # Input duration determines the bucket.
     bucket_key = tf.to_int32(tf.shape(fval['frames'])[0])
     if self.params.append_eos_frame:
         bucket_key += 1
     tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
         fval['transcript'])
     src_paddings = tf.zeros([tf.shape(fval['frames'])[0]],
                             dtype=tf.float32)
     return fval['uttid'], tgt_ids, tgt_labels, tgt_paddings, fval[
         'frames'], src_paddings, bucket_key
Example #12
    def testUniEncoderForwardPass(self):
        with self.session(use_gpu=False):
            tf.random.set_seed(8372749040)
            p = self._UniEncoderParams()
            mt_enc = encoder.MTEncoderUniRNN(p)
            batch = py_utils.NestedMap()
            batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
            batch.paddings = tf.zeros([2, 4])
            enc_out = mt_enc.FPropDefaultTheta(batch).encoded

            self.evaluate(tf.global_variables_initializer())
            actual_enc_out = enc_out.eval()
            tf.logging.info('testUniEncoderForwardPass actual_enc_out %r' %
                            actual_enc_out)
            expected_enc_out = [[[-4.3304257e-07, 5.4100457e-07],
                                 [-4.0170832e-07, -2.6441572e-07]],
                                [[-1.7024040e-07, -1.8555815e-07],
                                 [-6.4563977e-07, -3.7835261e-07]],
                                [[-2.4001852e-07, 5.1114228e-07],
                                 [-3.4349023e-07, -1.0049351e-06]],
                                [[1.8068013e-07, -6.8982729e-08],
                                 [3.3005003e-07, -8.8834116e-07]]]
            self.assertAllClose(expected_enc_out, actual_enc_out)
Example #13
  def testBiEncoderForwardPassWithDropout(self):
    with self.session(use_gpu=False):
      tf.random.set_seed(8372749040)
      p = self._BiEncoderParams()
      p.dropout_prob = 0.5
      mt_enc = encoder.MTEncoderBiRNN(p)
      batch = py_utils.NestedMap()
      batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
      batch.paddings = tf.zeros([2, 4])
      enc_out = mt_enc.FPropDefaultTheta(batch).encoded

      self.evaluate(tf.global_variables_initializer())
      actual_enc_out = enc_out.eval()
      print('bi_enc_actual_enc_out_with_dropout', np.array_repr(actual_enc_out))
      expected_enc_out = [[[1.60383240e-06, 1.22550023e-06],
                           [-7.21660126e-06, 1.05704457e-05]],
                          [[1.42539475e-05, -2.06075638e-05],
                           [-4.98754298e-06, 1.51066461e-05]],
                          [[-7.15192800e-06, -6.44075908e-06],
                           [5.02962678e-07, -3.40795486e-06]],
                          [[-6.54424548e-06, 9.88359807e-06],
                           [1.42836643e-06, -1.68607176e-06]]]
      self.assertAllClose(expected_enc_out, actual_enc_out)
Example #14
    def testBiEncoderForwardPass(self):
        with self.session(use_gpu=False):
            tf.random.set_seed(8372749040)
            p = self._BiEncoderParams()
            mt_enc = encoder.MTEncoderBiRNN(p)
            batch = py_utils.NestedMap()
            batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
            batch.paddings = tf.zeros([2, 4])
            enc_out = mt_enc.FPropDefaultTheta(batch).encoded

            self.evaluate(tf.global_variables_initializer())
            actual_enc_out = enc_out.eval()
            tf.logging.info('testBiEncoderForwardPass actual_enc_out %r' %
                            actual_enc_out)
            expected_enc_out = [[[-2.47998378e-06, 7.36457878e-06],
                                 [7.89248020e-07, -2.67464316e-06]],
                                [[-2.98803275e-06, 8.20233890e-06],
                                 [1.00139073e-06, -2.24554151e-06]],
                                [[-5.06675951e-06, 1.15983785e-05],
                                 [-4.58391014e-07, -2.99553108e-07]],
                                [[-4.34937465e-06, 8.58816838e-06],
                                 [-1.74859031e-06, 3.99598093e-06]]]
            self.assertAllClose(expected_enc_out, actual_enc_out)
Example #15
  def FProp(self, theta):
    """Combines the list of input tensors into a single tensor.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.

    Returns:
      A tensor of weights with dropout applied with shape [num_sources].
    """
    p = self.params

    # The constant factor is just meant to support the non-normalized scenario.
    # If softmax is applied, this factor will cancel out.
    w = theta.sum_weight * p.global_weight_scale + (1 / p.num_sources)
    w = tf.reshape(w, [p.num_sources])
    w = self.weighted_merger_dropout.FProp(theta.weighted_merger_dropout, w)
    if p.weighted_merger_softmax:
      residual_weights = p.minimal_prob * p.num_sources
      assert residual_weights >= 0.0
      assert residual_weights < 1.0
      w = tf.nn.softmax(w, axis=0) * (1.0 - residual_weights) + p.minimal_prob
    return w
Example #16
def ConvertToBlocks(x, block_size, padding_val=0.0):
  """Turns a sequence to non overlapping blocks.

  Args:
    x: a tensor of [batch, time, ...].
    block_size: int. Number of time frames in a block.
    padding_val: float. value on the padded frames.

  Returns:
    A tensor of [batch, num_blocks, block_size, ...], with necessary paddings,
    where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
  """
  shape = py_utils.GetShape(x)
  b, t = shape[:2]
  if block_size < 1:
    raise ValueError('block_size must be at least 1, got {}'.format(block_size))
  w = block_size
  # Pad t to be a multiple of w.
  num_blocks = (t + w - 1) // w
  pad_to_length = num_blocks * w
  padded = py_utils.PadSequenceDimension(x, pad_to_length, padding_val)
  reshaped = tf.reshape(padded, [b, num_blocks, w] + shape[2:])
  return reshaped
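For example, with t = 5 and block_size = 2, the sequence is padded to length 6 and split into 3 blocks. A hedged standalone equivalent that uses tf.pad in place of py_utils.PadSequenceDimension:

import tensorflow as tf

x = tf.reshape(tf.range(10, dtype=tf.float32), [1, 5, 2])  # [batch, t, d]
w = 2
num_blocks = (5 + w - 1) // w                              # 3
padded = tf.pad(x, [[0, 0], [0, num_blocks * w - 5], [0, 0]])
blocks = tf.reshape(padded, [1, num_blocks, w, 2])         # [1, 3, 2, 2]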
Example #17
  def testForwardPass(self):
    with self.session(use_gpu=False):
      tf.random.set_seed(8372749040)
      p = self._EncoderParams()
      mt_enc = encoder.MTEncoderV1(p)
      batch = py_utils.NestedMap()
      batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
      batch.paddings = tf.zeros([2, 4])
      enc_out = mt_enc.FPropDefaultTheta(batch).encoded

      self.evaluate(tf.global_variables_initializer())
      actual_enc_out = enc_out.eval()
      expected_enc_out = [
          [[1.5309354e-06, -1.7816075e-07, 3.8047763e-06, -5.6422067e-07],
           [1.9017770e-06, -2.9778969e-06, -4.5083775e-06, -1.7054812e-06]],
          [[-2.1852782e-06, -1.8208171e-06, -1.4747930e-06, -5.8206351e-06],
           [6.7667429e-07, -3.6828042e-06, -1.0916860e-05, -3.2522742e-06]],
          [[-3.2333378e-07, 3.2147584e-06, 5.0556650e-07, -7.0188378e-07],
           [-6.5340635e-07, 1.9502845e-06, -9.2459632e-06, 5.1955390e-06]],
          [[2.0232728e-06, 4.9331529e-06, 1.1346837e-06, 7.5571520e-06],
           [-5.8475212e-07, 3.5547487e-06, -3.9037773e-06, 8.9575424e-06]]
      ]
      self.assertAllClose(expected_enc_out, actual_enc_out)
Example #18
    def testBiEncoderForwardPassWithDropout(self):
        with self.session(use_gpu=False):
            tf.random.set_seed(8372749040)
            p = self._BiEncoderParams()
            p.dropout_prob = 0.5
            mt_enc = encoder.MTEncoderBiRNN(p)
            batch = py_utils.NestedMap()
            batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
            batch.paddings = tf.zeros([2, 4])
            enc_out = mt_enc.FPropDefaultTheta(batch).encoded

            self.evaluate(tf.global_variables_initializer())
            actual_enc_out = enc_out.eval()
            print('bi_enc_actual_enc_out_with_dropout',
                  np.array_repr(actual_enc_out))
            expected_enc_out = [[[-1.8358192e-05, 1.2103478e-05],
                                 [2.9347059e-06, -3.0652325e-06]],
                                [[-8.1282624e-06, 4.5443494e-06],
                                 [3.0826509e-06, -5.2950490e-06]],
                                [[-4.6669629e-07, 2.4246765e-05],
                                 [-1.5221613e-06, -1.9654153e-06]],
                                [[-1.1511075e-05, 1.9061190e-05],
                                 [-5.7250163e-06, 9.2785704e-06]]]
            self.assertAllClose(expected_enc_out, actual_enc_out)
Example #19
  def _CellFeaturizer(self, theta, input_batch):
    """Featurizes each center location."""
    # Validate Shapes
    cell_feature = py_utils.HasRank(input_batch.cell_feature, 4)
    batch_size, num_centers, num_points_per_cell = py_utils.GetShape(
        cell_feature, 3)

    cell_points_xyz = py_utils.HasShape(
        input_batch.cell_points_xyz,
        [batch_size, num_centers, num_points_per_cell, 3])
    cell_center_xyz = py_utils.HasShape(input_batch.cell_center_xyz,
                                        [batch_size, num_centers, 3])

    cell_points_padding = py_utils.HasShape(
        input_batch.cell_points_padding,
        [batch_size, num_centers, num_points_per_cell])

    # Center each cell
    cell_center_xyz = tf.reshape(cell_center_xyz,
                                 [batch_size, num_centers, 1, 3])
    centered_cell_points_xyz = cell_points_xyz - cell_center_xyz
    concat_feature = tf.concat([
        centered_cell_points_xyz, cell_feature
    ], axis=-1)  # pyformat: disable

    # Featurize point clouds at each center.
    point_input = py_utils.NestedMap({
        'points': centered_cell_points_xyz,
        'features': concat_feature,
        'padding': cell_points_padding,
    })
    featurized_cell = self.cell_featurizer.FProp(theta.cell_featurizer,
                                                 point_input)
    featurized_cell = py_utils.HasShape(featurized_cell,
                                        [batch_size, num_centers, -1])
    return featurized_cell
Example #20
  def AttenLogitsRPEOneStep(self, query, key, abs_pos_emb):
    """RPE attention logits for one single target (query) step.

    B: batch size
    S: sequence length
    N: num of attention heads.
    H: per-head attention dimension.

    Args:
      query:       [B, N, H].
      key:         [S, B, N, H] or [S, B, N*H/128, 128].
      abs_pos_emb: [S, 1, N, H].

    Returns:
      A Tensor of shape [S, B, N]
    """
    s, b, _, _ = py_utils.GetShape(key, 4)
    _, n, h = py_utils.GetShape(query, 3)
    key = tf.reshape(key, [s, b, n, h])

    key_emb = key + abs_pos_emb
    query, key_emb = self.ToAqtActActInputs(query, key_emb)
    logits = tf.einsum('BNH,SBNH->SBN', query, key_emb)
    return self.FromAqtActActMatmul(logits)
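The einsum contracts the per-head dimension of a single query step against every key position, yielding one logit per (position, batch, head). A shape-only check in plain TensorFlow, skipping the ToAqtActActInputs/FromAqtActActMatmul quantization hooks:

import tensorflow as tf

s, b, n, h = 7, 2, 4, 8
query = tf.random.normal([b, n, h])
key = tf.random.normal([s, b, n, h])
abs_pos_emb = tf.random.normal([s, 1, n, h])  # broadcasts over batch
logits = tf.einsum('BNH,SBNH->SBN', query, key + abs_pos_emb)
print(logits.shape)  # (7, 2, 4)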
Example #21
    def testMergeBeamSearchOutputs(self):
        with self.session():
            topk_scores_1 = [[1., 3., 5.], [-2., -1., 0.]]
            topk_ids_1 = [[[10, 11, 12], [30, 31, 32], [50, 51, 52]],
                          [[20, 21, 22], [10, 11, 12], [0, 0, 0]]]
            topk_lens_1 = [[3, 3, 2], [3, 3, 0]]
            topk_hyps_1 = [['one', 'three', 'five'],
                           ['minus two', 'minus one', '']]
            topk_1 = beam_search_helper.BeamSearchDecodeOutput(
                None, tf.constant(topk_hyps_1),
                tf.reshape(tf.constant(topk_ids_1), [6, -1]),
                tf.reshape(tf.constant(topk_lens_1), [-1]),
                tf.reshape(tf.constant(topk_scores_1), [-1]), None, None)

            topk_scores_2 = [[2., 4.], [-3., 0.]]
            topk_ids_2 = [[[20, 21, 22], [40, 41, 42]],
                          [[30, 31, 33], [0, 0, 0]]]
            topk_lens_2 = [[3, 2], [3, 0]]
            topk_hyps_2 = [['two', 'four'], ['minus three', '']]
            topk_2 = beam_search_helper.BeamSearchDecodeOutput(
                None, tf.constant(topk_hyps_2),
                tf.reshape(tf.constant(topk_ids_2), [4, -1]),
                tf.reshape(tf.constant(topk_lens_2), [-1]),
                tf.reshape(tf.constant(topk_scores_2), [-1]), None, None)

            topk = beam_search_helper.MergeBeamSearchOutputs(
                3, [topk_1, topk_2])
            self.assertIsNone(topk.done_hyps)
            self.assertIsNone(topk.topk_decoded)
            self.assertAllEqual([5., 4., 3., -1., -2., -3.],
                                topk.topk_scores.eval())
            self.assertAllEqual([2, 2, 3, 3, 3, 3], topk.topk_lens.eval())
            self.assertAllEqual([[50, 51, 52], [40, 41, 42], [30, 31, 32],
                                 [10, 11, 12], [20, 21, 22], [30, 31, 33]],
                                topk.topk_ids.eval())
            self.assertAllEqual([[b'five', b'four', b'three'],
                                 [b'minus one', b'minus two', b'minus three']],
                                topk.topk_hyps.eval())
Example #22
 def _ReshapeTransform(inp):
     """Reshape the transformation tensor to [..., idims, odims]."""
     base_shape = py_utils.GetShape(inp)[:-1]
     out_shape = list(base_shape) + [idims, odims]
     return tf.reshape(inp, out_shape)
Example #23
 def _CombineLastTwoDims(x):
     shape = py_utils.GetShape(x)
     return tf.reshape(x, shape[:-2] + [np.prod(shape[-2:])])
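A quick check of the helper: the trailing two dims are merged, e.g. [2, 3, 4, 5] -> [2, 3, 20]:

import numpy as np
import tensorflow as tf

x = tf.zeros([2, 3, 4, 5])
shape = x.shape.as_list()
y = tf.reshape(x, shape[:-2] + [np.prod(shape[-2:])])
print(y.shape)  # (2, 3, 20)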
Example #24
def NeighborhoodIndices(points,
                        query_points,
                        k,
                        points_padding=None,
                        max_distance=None,
                        sample_neighbors_uniformly=False):
    """Get indices to k-neighbors of query_points in points.

  Padding is returned alongside indices. Non-padded points are guaranteed to
  be unique (non-repeated) points from the original non-padded points.

  Padded points arise either from a lack of points (k exceeds the number of
  original non-padded points) or from points being farther away than
  max_distance.

  Note: Padded point indices may refer to padded points from the original, or
  may be duplicates of the closest point.

  TODO(weihan,jngiam): PointCNN implementation makes an assumption that padded
  points are repeated points from the original points. This behavior is
  maintained here, but we should update PointCNN to respect indices paddings.

  Args:
    points: tensor of shape [N, P1, dims].
    query_points: tensor of shape [N, P2, dims]
    k: Integer.
    points_padding: optional tensor of shape [N, P1] containing True/1.0 iff the
      point is a padded point. If None, then all points are considered real
      points.
    max_distance: float representing the maximum distance that each neighbor can
      be. If there are no points within the distance, then the closest point is
      returned (regardless of distance). If this is set to None, then no
      filtering by distance is performed.
    sample_neighbors_uniformly: boolean specifying whether to sample neighbors
      uniformly if they are within max distance.

  Returns:
    A pair of tensors:

    - indices: tensor of shape [N, P2, k].
    - padding: tensor of shape [N, P2, k] where 1 represents a padded point, and
      0 represents an unpadded (real) point.

  """
    n, p1 = py_utils.GetShape(points, 2)
    query_points = py_utils.HasShape(query_points, [n, -1, -1])
    _, p2 = py_utils.GetShape(query_points, 2)

    # Compute pair-wise squared distances.
    # Note that dist_mat contains the squared distance (without sqrt). Thus, when
    # using max_distance, we will need to square max_distance to make sure it's
    # in the same units.
    dist_mat = SquaredDistanceMatrix(query_points, points)
    dist_mat = py_utils.HasShape(dist_mat, [n, p2, p1])

    # Add a large scalar to the distances for padded points.
    # dist_mat[i, j, k] will be:
    #   if k < valid_num[i]: distance between points[i, k] and query_points[i, j]
    #   otherwise:           a large scalar added to dist_mat[i, j, k]
    if points_padding is not None:
        points_padding = tf.cast(tf.expand_dims(points_padding, 1), tf.float32)
        points_padding = py_utils.HasShape(points_padding, [n, 1, p1])
        large_scalar = tf.reduce_max(dist_mat) + 1
        dist_mat += points_padding * large_scalar

    # To perform sampling neighbors uniformly efficiently, we set all neighbors
    # that are within the distance threshold to have distances be drawn uniformly
    # at random. Using top_k with this enables selecting a random set quickly
    # without replacement.
    if sample_neighbors_uniformly:
        if max_distance is not None:
            mask_by_distance = tf.less_equal(dist_mat, max_distance**2)
            dist_mat = tf.where(
                mask_by_distance,
                tf.square(max_distance) *
                tf.random_uniform(tf.shape(dist_mat)), dist_mat)
        else:
            raise ValueError(
                'Uniform sampling requires specifying max_distance.')

    top_k_dist, indices = tf.nn.top_k(-dist_mat, k=k,
                                      sorted=True)  # N x P2 x K

    # Set padding using top_k_dist; padded points will have distance exceeding
    # the large_scalar.
    if points_padding is not None:
        paddings = tf.greater_equal(-top_k_dist, large_scalar)
    else:
        paddings = tf.zeros_like(top_k_dist, dtype=tf.bool)

    # Filter by max_distances by setting all indices that exceed the max_distance
    # to the closest point.
    if max_distance is not None:
        # Mask is true for points that are further than max_distance.
        mask_by_distance = tf.greater(-top_k_dist, tf.square(max_distance))
        closest_idx = tf.tile(indices[:, :, :1], [1, 1, k])
        indices = tf.where(mask_by_distance, closest_idx, indices)
        paddings |= mask_by_distance

    indices = tf.reshape(indices, [n, p2, k])
    paddings = tf.cast(paddings, tf.float32)

    return indices, paddings
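NeighborhoodIndices relies on a SquaredDistanceMatrix helper that is not shown here. A minimal sketch consistent with its usage above (pairwise squared Euclidean distances from query_points [N, P2, dims] to points [N, P1, dims], via the ||a||^2 - 2*a.b + ||b||^2 expansion):

import tensorflow as tf

def SquaredDistanceMatrix(pa, pb):
  # pa: [N, P_a, dims], pb: [N, P_b, dims] -> [N, P_a, P_b].
  sq_a = tf.reduce_sum(tf.square(pa), axis=-1, keepdims=True)  # [N, P_a, 1]
  sq_b = tf.reduce_sum(tf.square(pb), axis=-1, keepdims=True)  # [N, P_b, 1]
  cross = tf.matmul(pa, pb, transpose_b=True)                  # [N, P_a, P_b]
  return sq_a - 2.0 * cross + tf.linalg.matrix_transpose(sq_b)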
Example #25
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` object containing: ids - The inputs tensor of
        shape [batch, time]. paddings - The ids' paddings of shape [batch,
        time].

    Returns:
      A `.NestedMap` object containing:
        encoded - The encoded features of shape [time, batch, dim] or [batch,
          time, dim], depending on p.output_data_format.
        padding - The encoded features' padding of shape [time, batch] or
          [batch, time].
        segment_id - The segmentation of packed inputs of shape [time, batch] or
          [batch, time] if it is supported by the model, or None otherwise.
        embedded_inputs - The embedded inputs tokens without positional
          encodings of shape [time, batch, dim] or [batch, time, dim].
    """

        p = self.params
        with tf.name_scope(p.name):
            # [batch, time]
            input_ids = input_batch.ids
            # [batch, time]
            paddings = input_batch.paddings

            # [batch, time]
            segment_ids = input_batch.segment_ids if p.packed_input else None

            batch = py_utils.GetShape(input_ids)[0]
            time = py_utils.GetShape(input_ids)[1]

            # Embedding layer.
            # [batch, time, dim]
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                      input_ids)
            else:
                input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
            orig_input_embs = input_embs

            # [1, time, dim]
            if p.packed_input:
                positions = input_batch.segment_pos
                position_embs = tf.expand_dims(
                    self.position_emb.FPropWithPosition(
                        theta.position_emb, positions), 0)
            else:
                position_embs = tf.expand_dims(
                    self.position_emb.FProp(theta.position_emb, time), 0)

            # [batch, time, dim]
            input_embs += position_embs

            if p.input_dropout_tpl.fprop_dtype:
                input_embs = tf.cast(input_embs,
                                     p.input_dropout_tpl.fprop_dtype)
                paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)
            # [batch, time, dim]
            transformer_input = input_embs
            # Explicitly set the input shape of Transformer layers, to avoid
            # unknown shape error occurred to tf.einsum on nonTPU devices.
            transformer_input = tf.reshape(transformer_input,
                                           [batch, time, p.model_dim])

            # Compute self-attention segment mask once.
            if p.packed_input:
                segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids, dtype=transformer_input.dtype)
            else:
                segment_mask = tf.zeros([batch, 1, time, time])

            encoded, padding = self.transformer_stack.FProp(
                theta.transformer_stack, transformer_input, paddings,
                segment_mask)

            if p.final_layer_norm:
                encoded = self.final_ln.FProp(theta.final_ln, encoded)

            seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1),
                                  tf.int32)

            if p.output_data_format == 'TBC':
                encoded = tf.transpose(encoded,
                                       [1, 0, 2])  # [time, batch, dim]
                padding = tf.transpose(padding)  # [time, batch]
                segment_ids = tf.transpose(
                    segment_ids) if p.packed_input else None
                orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

            return py_utils.NestedMap(
                encoded=encoded,
                padding=padding,
                seq_lengths=seq_lengths,  # used by beam_search_helper.
                segment_id=segment_ids,
                embedded_inputs=orig_input_embs)
Example #26
    def FProp(self,
              theta,
              source_input,
              source_paddings,
              target_input=None,
              target_paddings=None,
              source_segment_id=None,
              target_segment_id=None,
              labels=None,
              label_weights=None,
              source_pos_id=None,
              target_pos_id=None,
              source_task_id=None,
              target_task_id=None):
        """Transforms source sequence of Tensors with Transformers layers.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_input:  A sequence of ints indicating source input ids of [time,
        batch] shape or [batch, time] if batch_dim is 0.
      source_paddings: A sequence of 0s and 1s indicating input paddings of
        [time, batch] shape or [batch, time] if batch_dim is 0.
      target_input: A sequence of ints indicating target input ids of [time,
        batch] shape or [batch, time] if batch_dim is 0.
      target_paddings: [target_time, target_batch] or [target_batch,
        target_time] if batch_dim is 0.
      source_segment_id: A sequence of ints indicating source segment ids of
        [time, batch] shape or [batch, time] if batch_dim is 0.
      target_segment_id: A sequence of ints indicating target segment ids of
        [time, batch] shape or [batch, time] if batch_dim is 0.
      labels: A sequence of ints indicating label ids of [time, batch] shape,
        or [batch, time] if batch_dim is 0.
      label_weights: A sequence of floats indicating label weights of [time,
        batch] shape, or [batch, time] if batch_dim is 0.
      source_pos_id: A sequence of ints indicating source position ids of [time,
        batch] shape, or [batch, time] if batch_dim is 0.
      target_pos_id: A sequence of ints indicating target position ids of [time,
        batch] shape, or [batch, time] if batch_dim is 0.
      source_task_id: A sequence of ints indicating source task ids of [time,
        batch] shape, or [batch, time] if batch_dim is 0.
      target_task_id: A sequence of ints indicating target task ids of [time,
        batch] shape, or [batch, time] if batch_dim is 0.

    Returns:
      transformer_output with shape [time, batch, dim] or [batch, time, dim]
      if batch_dim is 0.
    """
        p = self.params
        if p.num_decoder_layers > 0:
            assert target_input is not None
            assert target_paddings is not None
        if p.packed_input:
            assert source_segment_id is not None, (
                'Need to specify src_segment_id if packed input is supported.')
            assert source_pos_id is not None, (
                'Need to specify src_pos_id for packed input and embeddings.')

        logits = super(GPipeTransformerStack,
                       self).FProp(theta, source_input, source_paddings,
                                   target_input, target_paddings,
                                   source_segment_id, target_segment_id,
                                   source_pos_id, target_pos_id,
                                   source_task_id, target_task_id)
        if not p.softmax_tpl:
            return logits
        label_weights = tf.reshape(label_weights, [-1])
        target_probs = None
        if p.label_smoothing:
            if p.batch_dim:  # Time-major
                target_probs = tf.transpose(
                    self.smoother.FProp(theta.smoother,
                                        tf.transpose(target_paddings),
                                        tf.transpose(labels),
                                        target_ids=None), [1, 0, 2])
            else:
                target_probs = self.smoother.FProp(theta.smoother,
                                                   target_paddings,
                                                   labels,
                                                   target_ids=None)
            target_probs = tf.reshape(target_probs,
                                      [-1, p.softmax_tpl.num_classes])
        reshaped_logits = tf.reshape(logits, [-1, p.softmax_tpl.num_classes])
        tgt_labels = tf.reshape(labels, [-1])
        num_splits = len(p.splits)
        softmax = self.children['cell_{}'.format(num_splits - 1)].softmax
        softmax_theta = theta['cell_{}'.format(num_splits - 1)].softmax
        per_example_xent, _ = softmax.XentLossFromLogits(
            softmax_theta,
            reshaped_logits,
            class_weights=tf.reshape(label_weights, [-1]),
            class_ids=tgt_labels,
            class_probabilities=target_probs)
        xent_shape = tf.shape(logits)[:2]
        per_example_xent = tf.reshape(per_example_xent, xent_shape)
        return per_example_xent, logits
Example #27
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies(
        [py_utils.assert_equal(source_batch, tf.shape(output.topk_hyps)[0])],
        tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6),
      scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
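The top-k merge at the end pairs each beam index with its selected hyp indices so tf.gather_nd can pick whole rows; a toy version of that gather pattern:

import tensorflow as tf

scores = tf.constant([[1., 3., 5., 4.], [0., -2., 7., 1.]])  # [beams, hyps]
k = 2
_, top_idx = tf.nn.top_k(scores, k)                          # [[2, 3], [2, 3]]
batch_ids = tf.tile(tf.expand_dims(tf.range(2), -1), [1, k])
gather_idx = tf.stack([batch_ids, top_idx], axis=-1)         # [2, k, 2]
top_scores = tf.gather_nd(scores, gather_idx)                # [[5., 4.], [7., 1.]]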
Example #28
  def BeamSearchDecode(self,
                       theta,
                       encoder_outputs,
                       num_hyps_per_beam_override=0,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
    """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      num_hyps_per_beam_override: If set to a value <= 0, this parameter is
        ignored. If set to a value > 0, then this value will be used to override
        `p.num_hyps_per_beam`.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A `BeamSearchDecodeOutput`.
    """
    p = self.params
    num_hyps_per_beam = p.num_hyps_per_beam
    if num_hyps_per_beam_override > 0:
      num_hyps_per_beam = num_hyps_per_beam_override
    if max_steps is None:
      max_steps = p.target_seq_len

    initial_results, other_states = init_beam_search_state(
        theta, encoder_outputs, num_hyps_per_beam)

    num_hyps = tf.shape(initial_results.log_probs)[0]
    num_beams = num_hyps // num_hyps_per_beam

    if 'step_ids' in initial_results:
      # [num_hyps, 1]
      step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
    else:
      step_ids = tf.fill([num_hyps, 1],
                         tf.constant(p.target_sos_id, dtype=tf.int32))

    min_score = -1e36
    best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
    cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
    in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
    in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
    in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
    in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
    bs_atten_probs = tf.zeros(
        [max_steps, num_hyps,
         tf.shape(initial_results.atten_probs)[1]],
        dtype=p.dtype)
    cur_step = tf.constant(0, dtype=tf.int32)
    all_done = tf.constant(False, dtype=tf.bool)
    core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                      in_prev_hyps, in_done_hyps, bs_atten_probs)

    def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                     unused_other_states_list):
      return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

    def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                 other_states_list):
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta, encoder_outputs, cur_step, step_ids, core_bs_states,
           other_states.Pack(other_states_list), num_hyps_per_beam,
           pre_beam_search_step_callback, post_beam_search_step_callback)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    flat_other_states = other_states.Flatten()
    _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
        LoopContinue,
        LoopBody,
        loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                   flat_other_states),
        parallel_iterations=10,
        back_prop=False,
        swap_memory=False,
        shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                          tf.TensorShape(all_done.get_shape()),
                          tf.TensorShape(step_ids.get_shape()),
                          _GetShapes(core_bs_states),
                          _GetShapes(flat_other_states, none_shapes=True)))
    # [target_seq_len, num_beams * num_hyps_per_beam].
    final_done_hyps = final_bs_states[5]
    final_other_states = other_states.Pack(flat_final_other_states)

    # TODO(rpang): avoid inspecting 'encoder_outputs'.
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      source_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                            1)), tf.int32)
    else:
      source_seq_lengths = tf.cast(
          tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
          tf.int32)

    # [num_beams, num_hyps_per_beam].
    topk_hyps = ops.top_k_terminated_hyps(
        final_done_hyps,
        source_seq_lengths,
        k=num_hyps_per_beam,
        num_hyps_per_beam=num_hyps_per_beam,
        length_normalization=p.length_normalization,
        coverage_penalty=p.coverage_penalty,
        target_seq_length_ratio=p.target_seq_length_ratio,
        eoc_id=p.target_eoc_id,
        merge_paths=p.merge_paths)
    # [num_beams * num_hyps_per_beam, ...].
    max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
    topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
        tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
    # [num_beams, num_hyps_per_beam].
    topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

    return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                  topk_lens, topk_scores, None,
                                  final_other_states)
Example #29
  def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      core_bs_states, other_states, num_hyps_per_beam,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
    """Extend beam search hyps for one step.

      | num_beams = Number of source sequences to be decoded.
      | num_hyps_per_beam = Number of hyps to keep per source sequence.
      | num_hyps = num_beams * num_hyps_per_beam
      | src_seq_len = Number of time steps in the source sequence.
      | src_batch = Number of examples in the source sequence.
      | tgt_seq_len = Maximum allowed time steps in the target sequence.
      | tgt_batch = num_hyps_per_beam * src_batch

    Args:
      theta: A `.NestedMap` object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
        the callbacks.
      cur_step: A scalar int tensor, the current time step, 0-based.
      step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
        current search step.
      core_bs_states: A tuple of core beam search states. This list is
        maintained by this helper class.
      other_states: A `.NestedMap` of other beam search states. This
        `.NestedMap` is managed and updated by the client. It is expected that
        each of its member tensors are of rank >= 1. t[i, ...] is the state of
        the i-th hyp at the beginning of this search step.
      num_hyps_per_beam: Num of hyps to keep per beam.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        See class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        See class header comments for more details.

    Returns:
      A tuple of following elements for the next beam search step,
      (next step, all_done, step_ids, core_bs_states, other_states)
    """
    p = self.params

    bs_results, other_states = pre_beam_search_step_callback(
        theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

    (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
     in_done_hyps, in_atten_probs) = core_bs_states

    (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
     out_prev_hyps, out_done_hyps, out_atten_probs,
     all_done) = ops.beam_search_step(
         bs_results.log_probs,
         bs_results.atten_probs,
         best_scores,
         cumulative_scores,
         in_scores,
         in_hyps,
         in_prev_hyps,
         in_done_hyps,
         in_atten_probs,
         bs_results.is_last_chunk if self._model_uses_eoc_id else [],
         cur_step,
         eoc_id=p.target_eoc_id,
         eos_id=p.target_eos_id,
         beam_size=p.beam_size,
         num_hyps_per_beam=num_hyps_per_beam,
         valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
         merge_paths=p.merge_paths,
         allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
         ensure_full_beam=p.ensure_full_beam,
         force_eos_in_last_step=p.force_eos_in_last_step,
         local_eos_threshold=p.local_eos_threshold)

    new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
    new_step_ids.set_shape(step_ids.get_shape())

    old_hyp_ids = tf.reshape(
        tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

    if p.batch_major_compute:
      # Transform the indices into the key/value cache for fast decoding
      # (prefix_states in other_states): the num_hyps dimension of the cache
      # is laid out as num_beams x num_hyps_per_beam, which differs from the
      # old_hyp_ids layout (num_hyps_per_beam x num_beams). Both a transpose
      # and a recomputation are required to correct the indices.
      num_beams = tf.shape(best_scores)[0]
      old_hyp_ids_in_cache_order = tf.reshape(
          tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
      old_hyp_ids_in_cache_order = (
          (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
          old_hyp_ids_in_cache_order // num_beams)

    new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                     out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs)

    def ReOrderHyps(x_in):
      """Reorders x_in based on prev hyp ids."""
      if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
          x_in.shape.ndims > 0):
        if x_in.shape.ndims > 2 and not p.batch_major_state:
          # Use corrected indices only here for batch major compute as key/value
          # caches are the states being affected.
          correct_old_hyp_ids = (
              old_hyp_ids_in_cache_order
              if p.batch_major_compute else old_hyp_ids)
          x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
        else:
          x_out = tf.gather(x_in, old_hyp_ids)
        x_out.set_shape(x_in.get_shape())
        return x_out
      else:
        return x_in

    new_other_states = other_states.Transform(ReOrderHyps)

    final_other_states = post_beam_search_step_callback(theta, encoder_outputs,
                                                        new_step_ids,
                                                        new_other_states)

    return (cur_step + 1, all_done, new_step_ids, new_bs_states,
            final_other_states)
Example #30
    def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
        """Loop body for farthest point sampler."""
        def _GetRandomRealPoint():
            """Select the first point.

      For the first point, we want any random real (non padded) point, so we
      create a random values per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
            random_values = tf.random.uniform((batch_size, num_points),
                                              minval=0,
                                              maxval=1,
                                              dtype=tf.float32,
                                              seed=random_seed)
            random_values = tf.where(tf.equal(padding, 0.0), random_values,
                                     padding * 10)
            return tf.argmin(random_values, axis=1, output_type=tf.int32)

        def _GetFurthestPoint():
            """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
            # Set padded points distance to negative so they aren't selected.
            padding_masked_distance_to_selected = tf.where(
                tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
                    (batch_size, num_points), dtype=tf.float32))
            # But only do this when we still have valid points left.
            padding_masked_distance_to_selected = tf.where(
                tf.less(curr_idx, num_valid_points),
                padding_masked_distance_to_selected, distance_to_selected)
            return tf.argmax(padding_masked_distance_to_selected,
                             axis=-1,
                             output_type=tf.int32)

        def _GetSeededPoint():
            """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
            return tf.ones((batch_size, ), dtype=tf.int32) * curr_idx

        # Select indices for this loop iteration.
        def _Seeded():
            return tf.cond(tf.less(curr_idx, num_seeded_points),
                           _GetSeededPoint, _GetFurthestPoint)

        def _Real():
            return tf.cond(tf.equal(curr_idx, 0), _GetRandomRealPoint,
                           _GetFurthestPoint)

        new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded,
                               _Real)
        sampled_idx = sampled_idx.write(curr_idx, new_selected)

        # Extract the distance to the latest point selected to update
        # distance_to_selected.
        new_selected_gather_idx = tf.stack(
            [tf.range(batch_size), new_selected], axis=1)
        if precomputed_squared_distance is not None:
            new_distance = tf.gather_nd(precomputed_squared_distance,
                                        new_selected_gather_idx)
        else:
            new_points = tf.reshape(
                tf.gather_nd(points, new_selected_gather_idx),
                [batch_size, 1, dims])
            new_distance = tf.reshape(
                SquaredDistanceMatrix(points, new_points),
                [batch_size, num_points])

        is_newly_closest = tf.less(new_distance, distance_to_selected)
        distance_to_selected = tf.minimum(distance_to_selected, new_distance)

        # Track the index to the closest selected point.
        new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
        closest_idx = tf.cond(
            tf.equal(curr_idx, 0),
            # At the first loop iteration, the init points are the closest.
            lambda: new_selected_tiled,
            # Otherwise, update with the new points based on the distances.
            lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx)
        )
        return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx