Example #1
    def testBiEncoderForwardPassWithTransparent(self):
        with self.session(use_gpu=False):
            tf.random.set_seed(8372749040)
            p = self._BiEncoderParams()
            p.is_transparent = True
            mt_enc = encoder.MTEncoderBiRNN(p)
            batch = py_utils.NestedMap()
            batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
            batch.paddings = tf.zeros([2, 4])
            enc_out = mt_enc.FPropDefaultTheta(batch).encoded

            self.evaluate(tf.global_variables_initializer())
            actual_enc_out = enc_out.eval()
            expected_enc_out = [[[-7.4536911e-05, 8.8465633e-05],
                                 [2.8940600e-05, 3.2297492e-05]],
                                [[-1.9775725e-05, 9.8312848e-05],
                                 [5.1837378e-05, 1.2998647e-05]],
                                [[4.5528584e-05, -6.8125606e-05],
                                 [1.0955606e-04, -2.1024598e-04]],
                                [[8.5454740e-05, -1.8263397e-04],
                                 [5.2042866e-05, -1.6407830e-04]]]
            self.assertAllClose(expected_enc_out, actual_enc_out)
Example #2
    def testMaxCanvasSizeUnderUniformRollinPolicy(self):
        """Tests for valid canvas size."""
        with self.session(use_gpu=True) as sess:
            params = insertion.SymbolInsertionLayer.Params()
            params.name = 'insertion'
            params.rollin_policy = 'oracle'
            params.oracle_policy = 'uniform'

            insertion_layer = insertion.SymbolInsertionLayer(params)

            batch_size = 4
            time_dim = 10

            inputs = tf.tile(tf.expand_dims(tf.range(time_dim), 0),
                             [batch_size, 1])
            inputs_len = tf.random.uniform([batch_size], 0, time_dim, tf.int32)
            paddings = 1 - tf.sequence_mask(inputs_len, time_dim, tf.int32)
            spec = insertion_layer.FProp(None,
                                         inputs,
                                         paddings,
                                         force_sample_last_token=False)

            canvas_with_max_length = False
            for _ in range(1000):
                canvas_max_len, canvas, canvas_paddings = sess.run(
                    [inputs_len, spec.canvas, spec.canvas_paddings])

                for b in range(batch_size):
                    max_len = canvas_max_len[b]
                    length = np.sum(1 - canvas_paddings[b, :]).astype(np.int32)
                    canvas_with_max_length |= length == max_len
                    self.assertLessEqual(length, max_len)
                    # Invalid entries of canvas should be 0.
                    self.assertAllEqual(canvas[b, length:],
                                        [0] * (canvas.shape[1] - length))

            # With high probability, there should be at least one canvas that is
            # of the same size as the maximum canvas size.
            self.assertEqual(canvas_with_max_length, True)
Example #3
 def testFarthestPointSamplerGatherPoints(self):
     points = tf.constant([
         [[0, 1, 1], [1, 1, 1], [2, 1, 1], [3, 1, 1], [4, 1, 1], [5, 1, 1]],
         [[0, 2, 1], [1, 2, 1], [2, 2, 1], [3, 2, 1], [4, 2, 1], [5, 2, 1]],
         [[0, 2, 3], [1, 2, 3], [2, 2, 3], [3, 2, 3], [4, 2, 3], [5, 2, 3]],
         [[0, 2, 1], [1, 2, 1], [2, 2, 1], [3, 2, 1], [4, 2, 1], [5, 2, 1]],
     ],
                          dtype=tf.float32)  # pyformat: disable
     padding = tf.zeros((4, 6), dtype=tf.float32)
     n = 4
     num_points = 3
     selected_idx, _ = car_lib.FarthestPointSampler(points, padding,
                                                    num_points)
     gather_indices = tf.stack([
         tf.tile(tf.expand_dims(tf.range(n), 1), [1, num_points]),
         selected_idx
     ],
                               axis=2)
     sampled_points = tf.gather_nd(points, gather_indices)
     with self.session() as sess:
         sampled_points = sess.run(sampled_points)
         self.assertEqual(sampled_points.shape, (n, num_points, 3))
Example #4
 def testSpectrumAugmenterWithDynamicTimeWarping(self):
   with self.session(use_gpu=False, graph=tf.Graph()):
     tf.random.set_seed(1234)
     inputs = tf.broadcast_to(tf.cast(tf.range(10), dtype=tf.float32), (3, 10))
     inputs = tf.expand_dims(tf.expand_dims(inputs, -1), -1)
     paddings = []
     for i in range(3):
       paddings.append(
           tf.concat([tf.zeros([1, 2 * i + 5]),
                      tf.ones([1, 5 - 2 * i])],
                     axis=1))
     paddings = tf.concat(paddings, axis=0)
     p = spectrum_augmenter.SpectrumAugmenter.Params()
     p.name = 'specAug_layers'
     p.freq_mask_max_bins = 0
     p.time_mask_max_frames = 0
     p.time_warp_max_ratio = 0.5
     p.time_warp_bound = 'dynamic'
     p.random_seed = 34567
     specaug_layer = p.Instantiate()
     # pyformat: disable
     # pylint: disable=bad-whitespace,bad-continuation
     expected_output = np.array(
         [[[[0.0000000]], [[1.0000000]], [[2.0000000]], [[3.0000000]],
           [[4.0000000]], [[5.0000000]], [[6.0000000]], [[7.0000000]],
           [[8.0000000]], [[9.0000000]]],
          [[[0.0000000]], [[0.8333333]], [[1.6666666]], [[2.5000000]],
           [[3.3333333]], [[4.1666665]], [[5.0000000]], [[7.0000000]],
           [[8.0000000]], [[9.0000000]]],
          [[[0.0000000]], [[2.0000000]], [[2.8750000]], [[3.7500000]],
           [[4.6250000]], [[5.5000000]], [[6.3750000]], [[7.2500000]],
           [[8.1250000]], [[9.0000000]]]])
     # pylint: enable=bad-whitespace,bad-continuation
     # pyformat: enable
     h, _ = specaug_layer.FPropDefaultTheta(inputs, paddings)
     actual_layer_output = self.evaluate(h)
     print(np.array_repr(actual_layer_output))
     self.assertAllClose(actual_layer_output, expected_output)
Example #5
  def testIntraModalLabels(self):
    # Simulate a batch of 4 examples with 2 items each in the 'text' modality.
    batch_size = 4
    items_per_example = 2
    modality = 'text'
    modality_shape = tf.TensorShape([batch_size, items_per_example])
    inputs = label_lib.ExamplePairs.WithinBatch(
        batch=dict(some_feature=tf.range(batch_size)),
        query_modality=modality,
        result_modality=modality)

    def example_pair_labeler(_):
      return tf.constant([
          [1, 0, 0, X],
          [0, 1, 0, 0],
          [0, 0, 1, 0],
          [X, 0, 0, 1],
      ])

    labeler = label_lib.MultiItemExampleWrapper(
        example_pair_labeler, modality_batch_shapes={modality: modality_shape})
    labels = labeler(inputs)
    self.assertEqual(modality_shape + modality_shape, labels.shape)
    # The pairwise labels actually have rank 4 (twice the rank of ids), but we
    # compare them in matrix form for easier inspection. There are 8 items
    # total. Each should have a positive label for every other item from the
    # same example. Self-pairs should be ignored (they are neither positive
    # nor negative pairs), as well as pairs from duplicated examples.
    self.assertAllEqual([
        [X, 1, 0, 0, 0, 0, X, X],
        [1, X, 0, 0, 0, 0, X, X],
        [0, 0, X, 1, 0, 0, 0, 0],
        [0, 0, 1, X, 0, 0, 0, 0],
        [0, 0, 0, 0, X, 1, 0, 0],
        [0, 0, 0, 0, 1, X, 0, 0],
        [X, X, 0, 0, 0, 0, X, 1],
        [X, X, 0, 0, 0, 0, 1, X],
    ], tf.reshape(labels, [8, 8]))
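    # Note: X above denotes the ignore-pair label (presumably
    # label_lib.IGNORE_PAIR_LABEL in lingvo); ignored pairs contribute
    # neither a positive nor a negative training signal.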
Example #6
  def testForwardPass(self):
    with self.session(use_gpu=False):
      tf.random.set_seed(8372749040)
      p = self._EncoderParams()
      mt_enc = encoder.MTEncoderV1(p)
      batch = py_utils.NestedMap()
      batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
      batch.paddings = tf.zeros([2, 4])
      enc_out = mt_enc.FPropDefaultTheta(batch).encoded

      self.evaluate(tf.global_variables_initializer())
      actual_enc_out = enc_out.eval()
      expected_enc_out = [
          [[1.5309354e-06, -1.7816075e-07, 3.8047763e-06, -5.6422067e-07],
           [1.9017770e-06, -2.9778969e-06, -4.5083775e-06, -1.7054812e-06]],
          [[-2.1852782e-06, -1.8208171e-06, -1.4747930e-06, -5.8206351e-06],
           [6.7667429e-07, -3.6828042e-06, -1.0916860e-05, -3.2522742e-06]],
          [[-3.2333378e-07, 3.2147584e-06, 5.0556650e-07, -7.0188378e-07],
           [-6.5340635e-07, 1.9502845e-06, -9.2459632e-06, 5.1955390e-06]],
          [[2.0232728e-06, 4.9331529e-06, 1.1346837e-06, 7.5571520e-06],
           [-5.8475212e-07, 3.5547487e-06, -3.9037773e-06, 8.9575424e-06]]
      ]
      self.assertAllClose(expected_enc_out, actual_enc_out)
Example #7
def _BatchScatter(default_tensor, indices, values):
    """Performs tf.tensor_scatter_nd_update for each batch item.

  Args:
    default_tensor: A float tensor of shape [batch, vocab] that contains the
      default values.
    indices: An int tensor of shape [batch, k] that represents the k indices of
      `default_tensor` to update.
    values: A float tensor of shape [batch, k] that represents the value to
      replace with for each corresponding element of `indices`.

  Returns:
    A tensor like `default_tensor` where the (i, indices[i][j]) element has been
    replaced with values[i][j].
  """
    batch_size = tf.shape(default_tensor)[0]
    # Prepend batch indices to `indices`.
    batch_indices = tf.range(batch_size, dtype=indices.dtype)
    batch_indices = tf.expand_dims(batch_indices, 1)
    batch_indices = tf.broadcast_to(batch_indices, tf.shape(indices))
    batch_indices = tf.stack([batch_indices, indices], axis=2)

    return tf.tensor_scatter_nd_update(default_tensor, batch_indices, values)
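A quick usage sketch (not from the original code; toy values, assuming eager TF2):

import tensorflow as tf

default_tensor = tf.constant([[1., 2., 3.],
                              [4., 5., 6.]])  # [batch=2, vocab=3]
indices = tf.constant([[0], [2]])             # [batch=2, k=1]
values = tf.constant([[9.], [8.]])            # [batch=2, k=1]
result = _BatchScatter(default_tensor, indices, values)
# result == [[9., 2., 3.],
#            [4., 5., 8.]]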
Example #8
    def testUniEncoderForwardPass(self):
        with self.session(use_gpu=False):
            tf.random.set_seed(8372749040)
            p = self._UniEncoderParams()
            mt_enc = encoder.MTEncoderUniRNN(p)
            batch = py_utils.NestedMap()
            batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
            batch.paddings = tf.zeros([2, 4])
            enc_out = mt_enc.FPropDefaultTheta(batch).encoded

            self.evaluate(tf.global_variables_initializer())
            actual_enc_out = enc_out.eval()
            tf.logging.info('testUniEncoderForwardPass actual_enc_out %r' %
                            actual_enc_out)
            expected_enc_out = [[[-4.3304257e-07, 5.4100457e-07],
                                 [-4.0170832e-07, -2.6441572e-07]],
                                [[-1.7024040e-07, -1.8555815e-07],
                                 [-6.4563977e-07, -3.7835261e-07]],
                                [[-2.4001852e-07, 5.1114228e-07],
                                 [-3.4349023e-07, -1.0049351e-06]],
                                [[1.8068013e-07, -6.8982729e-08],
                                 [3.3005003e-07, -8.8834116e-07]]]
            self.assertAllClose(expected_enc_out, actual_enc_out)
Example #9
    def testBiEncoderForwardPass(self):
        with self.session(use_gpu=False):
            tf.random.set_seed(8372749040)
            p = self._BiEncoderParams()
            mt_enc = encoder.MTEncoderBiRNN(p)
            batch = py_utils.NestedMap()
            batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
            batch.paddings = tf.zeros([2, 4])
            enc_out = mt_enc.FPropDefaultTheta(batch).encoded

            self.evaluate(tf.global_variables_initializer())
            actual_enc_out = enc_out.eval()
            tf.logging.info('testBiEncoderForwardPass actual_enc_out %r' %
                            actual_enc_out)
            expected_enc_out = [[[-2.47998378e-06, 7.36457878e-06],
                                 [7.89248020e-07, -2.67464316e-06]],
                                [[-2.98803275e-06, 8.20233890e-06],
                                 [1.00139073e-06, -2.24554151e-06]],
                                [[-5.06675951e-06, 1.15983785e-05],
                                 [-4.58391014e-07, -2.99553108e-07]],
                                [[-4.34937465e-06, 8.58816838e-06],
                                 [-1.74859031e-06, 3.99598093e-06]]]
            self.assertAllClose(expected_enc_out, actual_enc_out)
Example #10
  def testBiEncoderForwardPassWithDropout(self):
    with self.session(use_gpu=False):
      tf.random.set_seed(8372749040)
      p = self._BiEncoderParams()
      p.dropout_prob = 0.5
      mt_enc = encoder.MTEncoderBiRNN(p)
      batch = py_utils.NestedMap()
      batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
      batch.paddings = tf.zeros([2, 4])
      enc_out = mt_enc.FPropDefaultTheta(batch).encoded

      self.evaluate(tf.global_variables_initializer())
      actual_enc_out = enc_out.eval()
      print('bi_enc_actual_enc_out_with_dropout', np.array_repr(actual_enc_out))
      expected_enc_out = [[[1.60383240e-06, 1.22550023e-06],
                           [-7.21660126e-06, 1.05704457e-05]],
                          [[1.42539475e-05, -2.06075638e-05],
                           [-4.98754298e-06, 1.51066461e-05]],
                          [[-7.15192800e-06, -6.44075908e-06],
                           [5.02962678e-07, -3.40795486e-06]],
                          [[-6.54424548e-06, 9.88359807e-06],
                           [1.42836643e-06, -1.68607176e-06]]]
      self.assertAllClose(expected_enc_out, actual_enc_out)
Example #11
    def testMultipleResultsPerExample(self):

        # Simple batch of 3 examples with 2 items per example in the result
        # modality.
        batch_size = 3
        results_per_example = 2

        inputs = label_lib.ExamplePairs.WithinBatch(
            batch=dict(some_feature=tf.range(batch_size)),
            query_modality='q',
            result_modality='r')

        def example_pair_labeler(_):
            return tf.constant([
                [1, 0, 0],
                [0, 1, X],
                [0, X, 1],
            ],
                               dtype=tf.int64)

        multi_item_labeler = label_lib.MultiItemExampleWrapper(
            example_pair_labeler,
            modality_batch_shapes=dict(q=tf.TensorShape([None]),
                                       r=tf.TensorShape(
                                           [None, results_per_example])))
        labels = multi_item_labeler(inputs)
        self.assertEqual([batch_size, batch_size, results_per_example],
                         labels.shape.as_list())
        # [3, 3, 2]
        expected_labels = [
            # pyformat: disable
            [[1, 1], [0, 0], [0, 0]],
            [[0, 0], [1, 1], [X, X]],
            [[0, 0], [X, X], [1, 1]]
            # pyformat: enable
        ]
        self.assertAllEqual(expected_labels, labels)
Example #12
    def testBiEncoderForwardPassWithDropout(self):
        with self.session(use_gpu=False):
            tf.random.set_seed(8372749040)
            p = self._BiEncoderParams()
            p.dropout_prob = 0.5
            mt_enc = encoder.MTEncoderBiRNN(p)
            batch = py_utils.NestedMap()
            batch.ids = tf.transpose(tf.reshape(tf.range(0, 8, 1), [4, 2]))
            batch.paddings = tf.zeros([2, 4])
            enc_out = mt_enc.FPropDefaultTheta(batch).encoded

            self.evaluate(tf.global_variables_initializer())
            actual_enc_out = enc_out.eval()
            print('bi_enc_actual_enc_out_with_dropout',
                  np.array_repr(actual_enc_out))
            expected_enc_out = [[[-1.8358192e-05, 1.2103478e-05],
                                 [2.9347059e-06, -3.0652325e-06]],
                                [[-8.1282624e-06, 4.5443494e-06],
                                 [3.0826509e-06, -5.2950490e-06]],
                                [[-4.6669629e-07, 2.4246765e-05],
                                 [-1.5221613e-06, -1.9654153e-06]],
                                [[-1.1511075e-05, 1.9061190e-05],
                                 [-5.7250163e-06, 9.2785704e-06]]]
            self.assertAllClose(expected_enc_out, actual_enc_out)
Example #13
 def testSpectrumAugmenterWarpMatrixConstructor(self):
   with self.session(use_gpu=False, graph=tf.Graph()):
     inputs = tf.broadcast_to(tf.cast(tf.range(10), dtype=tf.float32), (4, 10))
     origin = tf.cast([2, 4, 4, 5], dtype=tf.float32)
     destination = tf.cast([3, 2, 6, 8], dtype=tf.float32)
     choose_range = tf.cast([4, 8, 8, 10], dtype=tf.float32)
     outputs = []
     for p in [
         spectrum_augmenter.SpectrumAugmenter.Params(),
         spectrum_augmenter_on_device.SpectrumAugmenterOnDevice.Params()
     ]:
       p.name = 'specAug_layers'
       specaug_layer = p.Instantiate()
       warp_matrix = specaug_layer._ConstructWarpMatrix(
           batch_size=4,
           matrix_size=10,
           origin=origin,
           destination=destination,
           choose_range=choose_range,
           dtype=tf.float32)
       output = tf.einsum('bij,bj->bi', warp_matrix, inputs)
       outputs.append(output)
     layer_output, layer_output_on_device = self.evaluate(outputs)
     self.assertAllClose(layer_output, layer_output_on_device)
Example #14
    def testContiguousCanvasUnderUniformRollinPolicy(self):
        """Tests for valid canvas size."""
        with self.session(use_gpu=True) as sess:
            params = insertion.SymbolInsertionLayer.Params()
            params.name = 'insertion'
            params.rollin_policy = 'oracle'
            params.oracle_policy = 'uniform'

            insertion_layer = insertion.SymbolInsertionLayer(params)

            batch_size = 4
            time_dim = 10

            inputs = tf.tile(
                tf.expand_dims(tf.range(time_dim), 0) + 100, [batch_size, 1])
            inputs_len = tf.random.uniform([batch_size], 0, time_dim, tf.int32)
            paddings = 1 - tf.sequence_mask(inputs_len, time_dim, tf.int32)
            spec = insertion_layer.FProp(None,
                                         inputs,
                                         paddings,
                                         force_sample_last_token=False)

            for _ in range(1000):
                canvas, canvas_paddings = sess.run(
                    [spec.canvas, spec.canvas_paddings])

                for b in range(batch_size):
                    length = np.sum(1 - canvas_paddings[b, :]).astype(np.int32)
                    # Check for valid part of the canvas and padding.
                    for l in range(length):
                        self.assertEqual(canvas_paddings[b, l], 0)
                        self.assertNotEqual(canvas[b, l], 0)
                    # Check for invalid part of the canvas and padding.
                    for l in range(length, canvas.shape[1]):
                        self.assertEqual(canvas_paddings[b, l], 1)
                        self.assertEqual(canvas[b, l], 0)
Example #15
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to the total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ],
                                               tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6),
      scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
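The batch-wise gather above follows a common pattern: pair each per-beam index
with its batch index, then tf.gather_nd. A minimal standalone sketch of just
that pattern (toy shapes, not lingvo code):

import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.5],
                      [0.7, 0.2, 0.4]])  # [source_batch=2, num_hyps=3]
_, top_indices = tf.nn.top_k(scores, k=2)
batch_ids = tf.tile(tf.expand_dims(tf.range(2), -1), [1, 2])
gather_indices = tf.stack([batch_ids, top_indices], axis=-1)  # [2, 2, 2]
top_scores = tf.gather_nd(scores, gather_indices)
# top_scores == [[0.9, 0.5],
#                [0.7, 0.4]]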
Example #16
def ComputeSparseAttention(q, k, v, sparsity_indices, paddings=None):
  """Computes attention according to a sparsity pattern.

  We use the following capital letters to denote shape parameters:
    B = batch size
    S = length of the source sequence
    T = length of the target sequence
    N = number of attention heads
    H = dimensions of each attention head
    K = number of clusters
    W = attention window (W <= S)

  'sparsity_indices' is a tensor of integral type whose last dimension
  contains, for each query position, the W indices (W is the attention window)
  of the positions in 'k' along S that the query is allowed to attend to.

  For example, if sparsity_indices[batch_idx, target_time_step, head_idx] =
  [1, 7, 8], the query token at that target time step attends to the values
  with indices 1, 7, and 8, and the attention window here is 3.

  The valid values in 'sparsity_indices' lie in [-1, S-1]. Note that the value
  -1 is reserved to mean padding and, unlike Python indexing, does not alias to
  index (S-1).

  For example, if W=S and 'sparsity_indices' contains range(S) on the last
  dimension, this degenerates to the original full attention.

  We require that 'sparsity_indices' does not contain duplicates (except for -1
  to indicate paddings), but we do not require 'sparsity_indices' to be sorted.

  Note that this implementation is flexible and generic but is not optimized for
  time or space complexity. Please consider grouping queries that attend to the
  same subset of values first for efficiency.

  Args:
    q: (projected) queries, [B, T, N, H];
    k: (projected) keys, [B, S, N, H];
    v: (projected) values, [B, S, N, H];
    sparsity_indices: [B, T, N, W], where W is the attention window;
    paddings: paddings for keys, [B, S] if not None.

  Returns:
    output: the encoded output, [B, T, N, H].
    atten_probs: the attention weights, [B, T, N, S].
  """
  q = tf.convert_to_tensor(q)
  k = tf.convert_to_tensor(k)
  v = tf.convert_to_tensor(v)
  sparsity_indices = tf.convert_to_tensor(sparsity_indices)

  k = py_utils.HasRank(k, 4)
  _, source_length, _, dim_per_head = py_utils.GetShape(k, 4)
  sparsity_indices = py_utils.HasRank(sparsity_indices, 4)
  batch_size, target_length, num_heads, attention_window = py_utils.GetShape(
      sparsity_indices, 4)
  py_utils.assert_less_equal(
      attention_window, source_length,
      'The provided sparsity_indices has attention window '
      '> source length. This is likely an error.')

  # To prepare for gathering the relevant vectors from 'k', we prepare
  # gather_idx of shape [B, T, N, W, 3] where the last dimension corresponds to
  # slices in 'k' indexed by (batch index, source time step, head index),
  # where the source length index comes from the original W dimension in
  # 'sparsity_indices'.
  seq_idx = tf.expand_dims(sparsity_indices, axis=-1)
  # Overwrite the paddings -1 with valid gather indices (zeros). We will
  # fix the logits with -inf in these positions later.
  seq_idx = tf.where(seq_idx < 0, tf.zeros_like(seq_idx), seq_idx)
  batch_idx = tf.reshape(
      tf.range(0, batch_size, dtype=sparsity_indices.dtype),
      [batch_size, 1, 1, 1, 1])
  batch_idx = tf.tile(batch_idx,
                      [1, target_length, num_heads, attention_window, 1])
  head_idx = tf.reshape(
      tf.range(0, num_heads, dtype=sparsity_indices.dtype),
      [1, 1, num_heads, 1, 1])
  head_idx = tf.tile(head_idx,
                     [batch_size, target_length, 1, attention_window, 1])
  # [B, T, N, W, 3], where last dimension is (batch index, source length index,
  # head index).
  gather_idx = tf.concat([batch_idx, seq_idx, head_idx], axis=-1)

  # Both the gathered k and v have shape [B, T, N, W, H]
  k = tf.gather_nd(k, gather_idx)
  v = tf.gather_nd(v, gather_idx)

  if paddings is None:
    paddings = tf.zeros([batch_size, source_length])
  paddings = tf.convert_to_tensor(paddings)
  paddings = tf.expand_dims(paddings, axis=-1)
  # [B, S, N]
  paddings = tf.tile(paddings, [1, 1, num_heads])
  # [B, T, N, W]
  paddings = tf.gather_nd(paddings, gather_idx)

  logits = tf.einsum('BTNH, BTNWH -> BTNW', q, k)
  logits *= tf.math.rsqrt(tf.cast(dim_per_head, q.dtype))

  very_negative_logits = (
      tf.ones_like(logits) * logits.dtype.max *
      tf.constant(-0.7, dtype=logits.dtype))
  padded_logits = tf.where(
      tf.math.logical_or(sparsity_indices < 0, paddings > 0.0),
      very_negative_logits, logits)

  # [B, T, N, W]
  atten_probs = tf.nn.softmax(padded_logits, name='attention_weights')
  atten_probs = tf.where(sparsity_indices < 0, tf.zeros_like(logits),
                         atten_probs)
  output = tf.einsum('BTNW, BTNWH -> BTNH', atten_probs, v)

  # Scatter 'atten_probs' back into the original source length.
  # [B, T, N, W, 1]
  batch_idx = tf.tile(
      tf.range(batch_size)[:, None, None, None, None],
      [1, target_length, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  target_seq_idx = tf.tile(
      tf.range(target_length)[None, :, None, None, None],
      [batch_size, 1, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  head_idx = tf.tile(
      tf.range(num_heads)[None, None, :, None, None],
      [batch_size, target_length, 1, attention_window, 1])
  # seq_idx: [B, T, N, W, 1]
  # [B, T, N, W, 4]
  scatter_idx = tf.concat([batch_idx, target_seq_idx, head_idx, seq_idx], -1)
  # [B, T, N, S]
  scattered_probs = tf.scatter_nd(
      scatter_idx, atten_probs,
      [batch_size, target_length, num_heads, source_length])
  return output, scattered_probs
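A hedged sanity-check sketch (assumes the function above is importable): with
W == S and sparsity_indices covering range(S), the result should reduce to
full attention.

import tensorflow as tf

B, T, S, N, H = 2, 3, 3, 1, 4
q = tf.random.normal([B, T, N, H])
k = tf.random.normal([B, S, N, H])
v = tf.random.normal([B, S, N, H])
# Full window: every query position may attend to all S key positions.
sparsity_indices = tf.broadcast_to(
    tf.range(S)[tf.newaxis, tf.newaxis, tf.newaxis, :], [B, T, N, S])
output, atten_probs = ComputeSparseAttention(q, k, v, sparsity_indices)
# output: [B, T, N, H]; atten_probs: [B, T, N, S], each row summing to 1.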
Example #17
    def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
        """Loop body for farthest point sampler."""
        def _GetRandomRealPoint():
            """Select the first point.

      For the first point, we want any random real (non-padded) point, so we
      create random values per point, and then set all padded ones to some
      large value (more than the maxval). We then take the min per batch
      element to get the first point.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
            random_values = tf.random.uniform((batch_size, num_points),
                                              minval=0,
                                              maxval=1,
                                              dtype=tf.float32,
                                              seed=random_seed)
            random_values = tf.where(tf.equal(padding, 0.0), random_values,
                                     padding * 10)
            return tf.argmin(random_values, axis=1, output_type=tf.int32)

        def _GetFurthestPoint():
            """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
            # Set padded points distance to negative so they aren't selected.
            padding_masked_distance_to_selected = tf.where(
                tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
                    (batch_size, num_points), dtype=tf.float32))
            # But only do this when we still have valid points left.
            padding_masked_distance_to_selected = tf.where(
                tf.less(curr_idx, num_valid_points),
                padding_masked_distance_to_selected, distance_to_selected)
            return tf.argmax(padding_masked_distance_to_selected,
                             axis=-1,
                             output_type=tf.int32)

        def _GetSeededPoint():
            """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
            return tf.ones((batch_size, ), dtype=tf.int32) * curr_idx

        # Select indices for this loop iteration.
        def _Seeded():
            return tf.cond(tf.less(curr_idx, num_seeded_points),
                           _GetSeededPoint, _GetFurthestPoint)

        def _Real():
            return tf.cond(tf.equal(curr_idx, 0), _GetRandomRealPoint,
                           _GetFurthestPoint)

        new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded,
                               _Real)
        sampled_idx = sampled_idx.write(curr_idx, new_selected)

        # Extract the distance to the latest point selected to update
        # distance_to_selected.
        new_selected_gather_idx = tf.stack(
            [tf.range(batch_size), new_selected], axis=1)
        if precomputed_squared_distance is not None:
            new_distance = tf.gather_nd(precomputed_squared_distance,
                                        new_selected_gather_idx)
        else:
            new_points = tf.reshape(
                tf.gather_nd(points, new_selected_gather_idx),
                [batch_size, 1, dims])
            new_distance = tf.reshape(
                SquaredDistanceMatrix(points, new_points),
                [batch_size, num_points])

        is_newly_closest = tf.less(new_distance, distance_to_selected)
        distance_to_selected = tf.minimum(distance_to_selected, new_distance)

        # Track the index to the closest selected point.
        new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
        closest_idx = tf.cond(
            tf.equal(curr_idx, 0),
            # At the first loop iteration, the init points are the closest.
            lambda: new_selected_tiled,
            # Otherwise, update with the new points based on the distances.
            lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx)
        )
        return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
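For intuition, a minimal NumPy sketch of the greedy update this loop body
performs (illustrative only: no padding, seeding, or precomputed distances):

import numpy as np

def simple_fps(points, num_samples):
  """Greedy farthest point sampling over points of shape [num_points, dims]."""
  selected = [0]  # Start from an arbitrary point.
  dist = np.full(points.shape[0], np.inf)
  for _ in range(num_samples - 1):
    # Squared distance to the most recently selected point.
    new_dist = np.sum((points - points[selected[-1]])**2, axis=-1)
    # Keep, per point, the distance to its closest selected point so far.
    dist = np.minimum(dist, new_dist)
    # The next sample is the point farthest from all selected points.
    selected.append(int(np.argmax(dist)))
  return selected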
Example #18
    def ComputeAndUpdateMoments(self, theta, inputs, paddings=None, **kwargs):
        """Computes moments and updates state.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor.
      **kwargs: Additional inputs.

    Returns:
      Tuple of (mean, variance, beta, gamma).
    """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        inputs = py_utils.with_dependencies([
            py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1]),
        ], inputs)
        with tf.name_scope(p.name):
            if self.do_eval or p.freeze_bn_stats:
                # The mean and variance used for normalization.
                norm_mean, norm_variance = (self.vars.moving_mean,
                                            self.vars.moving_variance)
            else:
                rank = tf.rank(paddings)
                reduce_over_dims = tf.range(0, rank - 1)
                mean, variance = ComputeMoments(
                    inputs, paddings, reduce_over_dims, None,
                    p.enable_cross_replica_sum_on_tpu)

                py_utils.UpdateBatchNormVars(self.vars.moving_mean, mean,
                                             self._decay)
                py_utils.UpdateBatchNormVars(self.vars.moving_variance,
                                             variance, self._decay)
                # Add some summaries for visualization.
                summary_utils.histogram('%s_mean' % p.name,
                                        tf.cast(mean, tf.float32))
                summary_utils.histogram('%s_variance' % p.name,
                                        tf.cast(variance, tf.float32))
                summary_utils.histogram(
                    '%s_moving_mean' % p.name,
                    tf.cast(self.vars.moving_mean, tf.float32))
                summary_utils.histogram(
                    '%s_moving_variance' % p.name,
                    tf.cast(self.vars.moving_variance, tf.float32))
                summary_utils.histogram(
                    '%s_mean_diff' % p.name,
                    tf.cast(
                        tf.cast(mean, self.vars.moving_mean.dtype.base_dtype) -
                        self.vars.moving_mean, tf.float32))
                summary_utils.histogram(
                    '%s_variance_diff' % p.name,
                    tf.cast(
                        tf.cast(variance,
                                self.vars.moving_variance.dtype.base_dtype) -
                        self.vars.moving_variance, tf.float32))
                if p.use_moving_avg_in_training:
                    # Use the global statistics for normalization.
                    # Control dependencies on mean and variance make sure
                    # moving_mean and moving_variance are updated every
                    # training step.
                    norm_mean = py_utils.with_dependencies(
                        [mean], self.vars.moving_mean)
                    norm_variance = py_utils.with_dependencies(
                        [variance], self.vars.moving_variance)
                else:
                    # Use the batch statistics for normalization.
                    norm_mean = mean
                    norm_variance = variance

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            beta, gamma = self._GetBetaGamma(theta, inputs, **kwargs)
            return norm_mean, norm_variance, beta, gamma
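For reference, a hedged sketch of padding-aware moment computation in the
spirit of ComputeMoments (an assumption for illustration, not the lingvo
implementation; reduces over all but the channel dimension):

import tensorflow as tf

def masked_moments(inputs, paddings):
  # inputs: [..., dim]; paddings: [..., 1] where 1.0 marks padded positions.
  mask = 1.0 - paddings
  reduce_dims = list(range(inputs.shape.rank - 1))
  count = tf.reduce_sum(mask, axis=reduce_dims)                     # [1]
  mean = tf.reduce_sum(inputs * mask, axis=reduce_dims) / count     # [dim]
  mean_sq = tf.reduce_sum(tf.square(inputs) * mask,
                          axis=reduce_dims) / count
  variance = mean_sq - tf.square(mean)                              # [dim]
  return mean, variance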
Example #19
def MakeLocalMask(seq_len,
                  block_size,
                  left_context,
                  right_context,
                  dtype=tf.float32):
  """Makes the mask tensor for a full sequence.

  The returned mask reflects the given context sizes, where position i
  attends to tokens in the range [i - (left_context-1), i + right_context].

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.
    dtype: tf.dtype, default is tf.float32.

  Returns:
    A tensor of [num_blocks, block_size, context_size] taking values in {0, 1},
    where context_size = block_size + (left_context - 1) + right_context.
    Element b, i, j is non-zero only if, in the b-th block, the i-th frame
    may attend to the j-th frame in the context.
  """
  seq_len = py_utils.with_dependencies([
      py_utils.assert_greater_equal(
          seq_len, 1, message='seq_len must be at least 1')
  ], seq_len)

  num_blocks = (seq_len + block_size - 1) // block_size
  context_size = block_size + (left_context - 1) + right_context

  # [num_blocks, block_size]: source positions in the original sequence.
  src_positions = tf.reshape(
      tf.range(num_blocks * block_size), [num_blocks, block_size])
  # [num_blocks,]: source positions at the start of each block.
  block_start_positions = tf.range(0, num_blocks * block_size, block_size)
  # [context_size]:  positions relative to the block start.
  relative_context_positions = tf.range(context_size) - (left_context - 1)

  # [num_blocks, context_size]: target positions in the original sequence.
  tgt_positions = (
      block_start_positions[:, tf.newaxis] +
      relative_context_positions[tf.newaxis, :])
  # [num_blocks, block_size, context_size]: position differences between source-
  # target pairs.
  position_diff = src_positions[:, :, tf.newaxis] - tgt_positions[:,
                                                                  tf.newaxis, :]
  # [num_blocks, block_size, context_size]: if attention is allowed between
  # source-target pairs.
  valid_atten = tf.math.logical_and(-right_context <= position_diff,
                                    position_diff < left_context)

  # [num_blocks, block_size]: if the source position is valid, not padded.
  valid_src = src_positions < seq_len
  # [num_blocks, context_size]: if the target position is valid, not padded.
  valid_tgt = tf.math.logical_and(0 <= tgt_positions, tgt_positions < seq_len)

  valid_atten &= tf.math.logical_and(valid_src[:, :, tf.newaxis],
                                     valid_tgt[:, tf.newaxis, :])

  return tf.cast(valid_atten, dtype=dtype)
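A tiny worked example (assuming eager TF2 for easy inspection): with
seq_len=4, block_size=2, left_context=2 and right_context=1, there are 2
blocks and context_size = 2 + 1 + 1 = 4.

mask = MakeLocalMask(seq_len=4, block_size=2, left_context=2, right_context=1)
# mask.shape == [2, 2, 4]  (num_blocks, block_size, context_size)
# In block 0, frame 0 may attend to positions {0, 1}; frame 1 to {0, 1, 2}.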
Example #20
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # penalize all hyps except the first
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # rename prefix->pfx for shorter variables
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ P P P P P P - - - - - - P* - - - ]   ^
        # [ P P P P P P P P P P - - P* - - - ]   | batch
        # [ P - - - - - - - - - - - P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() call with the full prefix (minus P*),
        # which will populate the initial dec_state
        # (for transformer models -- the self-attention key/value cache).
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_pad = tf.cast(
            tf.less(tf.expand_dims(pfx_time, 0),
                    tf.expand_dims(pfx_len - 1, 1)), tf.int32)
        pfx_id = pfx * pfx_pad
        pfx_last = einsum_i32(
            'BT,BT->B', pfx, tf.one_hot(pfx_len - 1,
                                        pfx_max,
                                        dtype=fprop_dtype))

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_time * pfx_pad

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun): Here we make the initial 't' constant, determined by
        # the shape of the prefix tensor ('pfx_max'). It is possible to make it
        # dynamic, as t ~ max(pfx_len) / beam_size, which would allow more beam
        # search steps; however, 'max' results in a very slow all-to-all on
        # 16x16, and a variable number of decoder steps may result in bad
        # latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
            t = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
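A hedged sketch of the dec_callback contract assumed by flat_beam_search
(names here are hypothetical, not lingvo's API): during the loop it receives
[batch, beam]-shaped ids/positions, a [batch, beam, buf_size] mask, the
decoder state, and the step t, and returns per-step logits plus updated state.

def my_dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t):
  # Hypothetical model step; must return logits of shape [batch, beam, vocab]
  # (or [batch, prefix_len, vocab] for the single prefix call).
  logits, dec_state = my_model_step(tgt_id, tgt_segment_id, tgt_pos, tgt_mask,
                                    dec_state, t)
  return logits, dec_state

loop_vars, dec_state, (topk_ids, topk_len, topk_score) = flat_beam_search(
    batch_size=8, beam_size=4, max_steps=32,
    dec_callback=my_dec_callback, dec_state=my_initial_state)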
    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state
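
For reference, length_norm above is the GNMT-style length penalty ((len + 5) / 5) ** alpha; a standalone sketch (the alpha value here is purely illustrative):

def length_norm(hyp_len, alpha=0.8):
    # A larger hyp_len gives a larger divisor, making scores comparable
    # across hypothesis lengths.
    return ((hyp_len + 5.0) / 5.0) ** alpha

print(length_norm(1))   # ~1.157
print(length_norm(10))  # ~2.408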
Example #22
0
  def _Extract(self, features):
    p = self.params
    ri_outputs = {}
    outputs = {}
    frame_pose = tf.reshape(_Dense(features['pose']), [4, 4])
    for laser in p.cbr_laser_names + p.gbr_laser_names:
      # Extract range images.
      for returns in p.returns:
        ri_shape = tf.reshape(
            _Dense(features['%s_%s_shape' % (laser, returns)]), [-1])
        range_image = tf.reshape(
            _Dense(features['%s_%s' % (laser, returns)]), ri_shape)

        shape_to_check = (
            p.cbr_ri_shape if laser in p.cbr_laser_names else p.gbr_ri_shape)
        range_image = py_utils.HasShape(range_image, shape_to_check)

        ri_outputs['%s_%s' % (laser, returns)] = range_image

      # Extract beam inclinations and extrinsics
      outputs['%s_extrinsics' % laser] = tf.reshape(
          _Dense(features['%s_extrinsics' % laser]), [4, 4])

    # CBRs have uniform inclination
    for laser in p.cbr_laser_names:
      beam_inclination_min = tf.reshape(
          _Dense(features['%s_beam_inclination_min' % laser]), [])
      beam_inclination_max = tf.reshape(
          _Dense(features['%s_beam_inclination_max' % laser]), [])
      outputs['%s_beam_inclinations' % laser] = tf.stack(
          [beam_inclination_min, beam_inclination_max], axis=0)

    # GBRs have non-uniform inclinations defined by 64 floats.
    for laser in p.gbr_laser_names:
      outputs['%s_beam_inclinations' % laser] = tf.reshape(
          _Dense(features['%s_beam_inclinations' % laser]), [64])

    # Embed xyz onto each range image pixel.
    for laser in p.cbr_laser_names + p.gbr_laser_names:
      extrinsics = outputs['%s_extrinsics' % laser]
      inclinations = outputs['%s_beam_inclinations' % laser]
      if laser in p.cbr_laser_names:
        ri_shape = p.cbr_ri_shape

        # Convert the 2-tuple inclination range to one inclination per image
        # row via linear interpolation.
        #
        # CBR lasers currently always have uniform inclinations specified by a
        # length-2 vector.
        height = ri_shape[0]
        min_inclination = inclinations[0]
        max_inclination = inclinations[1]
        diff = max_inclination - min_inclination
        ratio = (.5 + tf.cast(tf.range(0, height), tf.float32)) / tf.cast(
            height, tf.float32)
        # interpolate from min to max inclination.
        inclinations = (ratio * diff) + min_inclination
      else:
        ri_shape = p.gbr_ri_shape

      pixel_pose = None
      if laser in p.gbr_laser_names:
        pixel_pose = tf.reshape(
            _Dense(features['%s_pose' % laser]),
            shape=p.gbr_ri_shape[0:2] + [4, 4])
        outputs['%s_pose' % laser] = pixel_pose

      for returns in p.returns:
        range_image = ri_outputs['%s_%s' % (laser, returns)]
        range_image = tf.reshape(range_image, ri_shape)
        range_image_mask = range_image[..., 0] >= 0
        ri_xyz = tf.cast(
            self._XYZFromRangeImage(range_image, range_image_mask, extrinsics,
                                    inclinations, pixel_pose, frame_pose),
            tf.float32)

        # Produce the NestedMap of xyz, features, mask.
        ri_result = py_utils.NestedMap({
            'xyz': ri_xyz,
            'features': range_image,
            'mask': tf.cast(range_image_mask, tf.float32),
        })

        outputs['%s_%s' % (laser, returns)] = ri_result

    return py_utils.NestedMap(outputs)
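
A small NumPy sketch of the uniform-inclination interpolation applied to the CBR lasers above: the [min, max] pair is expanded to one beam angle per image row, sampled at row centers.

import numpy as np

def expand_inclinations(min_incl, max_incl, height):
    # Row centers: (0.5 + i) / height for i in [0, height).
    ratio = (0.5 + np.arange(height)) / height
    return ratio * (max_incl - min_incl) + min_incl

print(expand_inclinations(-0.3, 0.3, 4))  # -> [-0.225 -0.075  0.075  0.225]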
Example #23
0
 def _WordsToIds(i, words, ids_ta):
     encoded_ids = self._EncodeToIds(BOW_STR + words[i])
     ids_ta = ids_ta.scatter(
         tf.range(ids_ta.size(),
                  ids_ta.size() + tf.size(encoded_ids)), encoded_ids)
     return i + 1, words, ids_ta
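
The scatter call above appends encoded_ids at the current end of the TensorArray; a minimal standalone sketch of the same append pattern (TF2 eager execution assumed):

import tensorflow as tf

ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
new = tf.constant([3, 4, 5])
# Write tf.size(new) elements starting at the array's current size.
ta = ta.scatter(tf.range(ta.size(), ta.size() + tf.size(new)), new)
print(ta.stack().numpy())  # -> [3 4 5]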
Example #24
0
    def _ConstructWarpMatrix(self, batch_size, matrix_size, origin,
                             destination, choose_range, dtype):
        """Returns warp matrices according to origin, destination and choose_range.

    This function constructs a batch of warp matrices which maps the batch
    of origin points to the batch of destination points with fixed boundary
    coordinates at 0 and choose_range.

    The warping function, defined by the origin anchor point `origin`,
    the destination of the origin anchor point `destination` and the
    length of the domain in the warping axis `choose_range` is a piecewise
    linear map that fixes the points 0 and `choose_range` and maps
    `origin` to `destination`.

    For the warping matrix to be non-singular, destination must lie in the
    range 1 <= destination <= choose_range - 1, so a destination outside this
    range is clamped into this range before the warping matrix is
    constructed.

    The warping map can be explicitly written by first defining the slopes:
      1) slope_0 = origin / destination.
      2) slope_1 = (choose_range - origin) / (choose_range - destination).
      3) slope_2 = 1.0.

    Then the origin point orig_i of the mapped coordinate i is given by:
      1) i < destination: orig_i = slope_0 * i.
      2) destination <= i < choose_range:
         orig_i = slope_1 * i - (slope_1 - slope_0) * destination.
      3) i >= choose_range: orig_i = i.

    Denoting n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
      1) j = n_i: 1 - n_i + orig_i.
      2) j = n_i - 1: n_i - orig_i.
      3) Otherwise: 0.

    Applying the warp matrix to an array of pixels, i.e.,
    warped_pixel[i] = sum_j warp[i][j] * pixel[j], one would get
    warped_pixel[i] = (n_i-orig_i) pixel[n_i-1] + (1-n_i+orig_i) pixel[n_i].

    Args:
      batch_size: Batch size. Integer number.
      matrix_size: Dimension of the vector space the warp matrix is applied to.
        Integer number.
      origin: Origin anchor point for warping. Tensor of shape (batch_size,) and
        data type dtype.
      destination: Destination of the origin anchor point upon warping. Tensor
        of shape (batch_size,) and data type dtype.
      choose_range: Range within which the warp reference points must lie.
        Tensor of shape (batch_size,) data type dtype.
      dtype: Data type of origin, destination, choose_range and the output warp
        matrix.

    Returns:
      warp_matrix: An array of fixed size warp matrices with shape
      (batch_size, matrix_size, matrix_size).
    """
        p = self.params

        # Entries of destination must be in the range
        # 1 <= destination <= choose_range - 1
        # for warp matrix to have non-singular values.
        destination = tf.minimum(tf.maximum(destination, 1.0),
                                 choose_range - 1.0)

        # Construct piece-wise linear function fixing boundary points
        # specified by zero, choose_range and matrix size and maps
        # the origin anchor point to the destination.
        destination_bc = tf.broadcast_to(destination,
                                         (matrix_size, batch_size))
        destination_bc = tf.transpose(destination_bc)
        choose_range_bc = tf.broadcast_to(choose_range,
                                          (matrix_size, batch_size))
        choose_range_bc = tf.transpose(choose_range_bc)

        # Slopes of piece-wise linear function.
        slope_0 = origin / destination
        slope_1 = (choose_range - origin) / (choose_range - destination)
        slope_2 = 1.0

        # x is a batch of origin matrices.
        # The origin matrix is the matrix such that
        # origin[i][j] = Origin coordinate of coordinate i for the warp map.
        # Denoting the destination of the origin anchor point in the
        # warp map as "dest," the origin coordinate of point i is given by:
        # 1) i < dest: slope_0 * i.
        # 2) dest <= i < choose_range: slope_1 * i - (slope_1 - slope_0) * dest.
        # 3) i >= choose_range: i.
        x = tf.broadcast_to(tf.cast(tf.range(matrix_size), dtype=dtype),
                            (batch_size, matrix_size))
        x = (self.EinsumBBmBm(slope_0, x) + self.EinsumBBmBm(
            slope_1 - slope_0, tf.nn.relu(x - destination_bc)) +
             self.EinsumBBmBm(slope_2 - slope_1,
                              tf.nn.relu(x - choose_range_bc)))
        x = tf.broadcast_to(x, (matrix_size, batch_size, matrix_size))
        x = tf.transpose(x, perm=[1, 2, 0])

        # y is a batch of coordinate matrices.
        # A coordinate matrix is a matrix such that
        # coordinate[i][j] = j.
        y = tf.broadcast_to(tf.cast(tf.range(matrix_size), dtype=dtype),
                            (batch_size, matrix_size, matrix_size))
        # Warp matrix is obtained by applying hat function element-wise to (x-y).
        # Denoting the origin point of i under the warp map as orig_i,
        # and n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
        # 1) j = n_i: 1 - n_i + orig_i.
        # 2) j = n_i - 1: n_i - orig_i.
        # 3) Otherwise: 0.
        # Applying the warp matrix to pixels, i.e.,
        # warped_pixel[i] = sum_j warp[i][j] * original_pixel[j], one would get
        # warped_pixel[i] = (n_i - orig_i) * original_pixel[n_i-1]
        #                   + (1 - n_i + orig_i) * original_pixel[n_i].
        warp_matrix = x - y
        warp_matrix = _hat(warp_matrix)
        if p.fprop_dtype is not None and p.fprop_dtype != dtype:
            warp_matrix = tf.cast(warp_matrix, p.fprop_dtype)

        return warp_matrix
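
A NumPy sketch of the final hat-function step, assuming _hat is the triangular hat function hat(x) = max(0, 1 - |x|), which matches the interpolation weights described in the comments: each output pixel i puts weight on the two input pixels neighboring orig_i.

import numpy as np

def hat(x):
    return np.maximum(0.0, 1.0 - np.abs(x))

orig = np.array([0.0, 0.7, 1.4])         # origin coordinate of each output pixel
coords = np.arange(3, dtype=np.float64)  # coordinate matrix columns j
warp = hat(orig[:, None] - coords[None, :])
pixels = np.array([1.0, 2.0, 3.0])
print(warp @ pixels)  # -> [1.  1.7 2.4]; e.g. 0.3 * pixel[0] + 0.7 * pixel[1]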
Example #25
0
  def _XYZFromRangeImage(self,
                         lidar_image,
                         lidar_image_mask,
                         extrinsics,
                         inclinations,
                         pixel_pose=None,
                         frame_pose=None):
    """Extract the cartesian coordinates from the range image.

    Args:
       lidar_image: [H, W, C] range image Tensor.
       lidar_image_mask: [H, W] boolean indicating which 2d coordinates in the
         lidar image are present.
       extrinsics: [4, 4] float matrix representing transformation matrix to
         world coordinates.
       inclinations: [V] beam inclinations vector.
       pixel_pose: [64, 2650, 4, 4] tensor representing per pixel pose of GBR.
       frame_pose: [4, 4] matrix representing vehicle to world transformation.

    Returns:
      [H, W, 3] range image cartesian coordinates.
    """
    height, width, channels = py_utils.GetShape(lidar_image, 3)

    conversion_dtype = tf.float32
    lidar_image = tf.cast(lidar_image, conversion_dtype)
    extrinsics = tf.cast(extrinsics, conversion_dtype)
    inclinations = tf.cast(inclinations, conversion_dtype)
    inclinations = tf.reverse(inclinations, axis=[-1])

    az_correction = py_utils.HasShape(
        tf.atan2(extrinsics[1, 0], extrinsics[0, 0]), [])
    ratios = (tf.cast(tf.range(width, 0, -1), dtype=conversion_dtype) -
              .5) / tf.cast(width, conversion_dtype)
    ratios = py_utils.HasShape(ratios, [width])

    azimuth = (ratios * 2. - 1.) * np.pi - az_correction[..., tf.newaxis]
    azimuth = py_utils.HasShape(azimuth, [width])

    lidar_image_mask = lidar_image_mask[..., tf.newaxis]
    lidar_image_mask = tf.tile(lidar_image_mask, [1, 1, channels])
    lidar_image = tf.where(lidar_image_mask, lidar_image,
                           tf.zeros_like(lidar_image))
    lidar_image_range = lidar_image[..., 0]

    azimuth = py_utils.HasShape(azimuth[tf.newaxis, ...], [1, width])
    inclinations = py_utils.HasShape(inclinations[..., tf.newaxis], [height, 1])

    cos_azimuth = tf.cos(azimuth)
    sin_azimuth = tf.sin(azimuth)
    cos_incl = tf.cos(inclinations)
    sin_incl = tf.sin(inclinations)

    x = cos_azimuth * cos_incl * lidar_image_range
    y = sin_azimuth * cos_incl * lidar_image_range
    z = sin_incl * lidar_image_range

    lidar_image_points = tf.stack([x, y, z], -1)
    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    rotation = extrinsics[0:3, 0:3]
    translation = extrinsics[0:3, 3][tf.newaxis, ...]

    # Transform the image points in cartesian coordinates to
    # the world coordinate system using the extrinsics matrix.
    #
    # We first flatten the points, apply rotation, then
    # reshape to restore the original input and then apply
    # translation.
    lidar_image_points = tf.matmul(
        tf.reshape(lidar_image_points, [-1, 3]), rotation, transpose_b=True)
    lidar_image_points = tf.reshape(lidar_image_points, [height, width, 3])
    lidar_image_points += translation

    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    # GBR uses per pixel pose.
    if pixel_pose is not None:
      pixel_pose_rotation = pixel_pose[..., 0:3, 0:3]
      pixel_pose_translation = pixel_pose[..., 0:3, 3]
      lidar_image_points = tf.einsum(
          'hwij,hwj->hwi', pixel_pose_rotation,
          lidar_image_points) + pixel_pose_translation
      if frame_pose is None:
        raise ValueError('frame_pose must be set when pixel_pose is set.')
      # To vehicle frame corresponding to the given frame_pose
      # [4, 4]
      world_to_vehicle = tf.linalg.inv(frame_pose)
      world_to_vehicle_rotation = world_to_vehicle[0:3, 0:3]
      world_to_vehicle_translation = world_to_vehicle[0:3, 3]
      # [H, W, 3]
      lidar_image_points = tf.einsum(
          'ij,hwj->hwi', world_to_vehicle_rotation,
          lidar_image_points) + world_to_vehicle_translation[tf.newaxis,
                                                             tf.newaxis, :]

    return lidar_image_points
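
A NumPy sketch of the polar-to-cartesian step above for a single lidar return: a range r at azimuth a and inclination i maps to (x, y, z) before any extrinsic or pose transforms.

import numpy as np

def polar_to_cartesian(r, azimuth, inclination):
    x = np.cos(azimuth) * np.cos(inclination) * r
    y = np.sin(azimuth) * np.cos(inclination) * r
    z = np.sin(inclination) * r
    return x, y, z

print(polar_to_cartesian(10.0, 0.0, 0.0))  # -> (10.0, 0.0, 0.0)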
Example #26
0
 def _Slicing():
     # Choose a random set of indices.
     indices = tf.range(actual_num)
     indices = tf.random_shuffle(indices, seed=seed)[:num_points_out]
     return [tf.gather(t, indices, axis=0) for t in tensor_list]
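
A usage sketch of the same consistent-subsampling pattern (tf.random_shuffle is the TF1 name; this sketch assumes TF2 eager execution and uses tf.random.shuffle): shuffling one index vector and gathering with it keeps corresponding rows aligned across tensors.

import tensorflow as tf

points = tf.reshape(tf.range(12, dtype=tf.float32), [4, 3])
labels = tf.constant([10, 11, 12, 13])
indices = tf.random.shuffle(tf.range(4))[:2]  # choose 2 of 4 rows
points_out, labels_out = [tf.gather(t, indices, axis=0)
                          for t in (points, labels)]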
Example #27
0
    def _GetMask(self,
                 batch_size,
                 choose_range,
                 mask_size,
                 global_seed,
                 max_length=None,
                 masks_per_frame=0.0,
                 multiplicity=1,
                 dtype=tf.float32,
                 max_ratio=1.0):
        """Returns fixed size multi-masks starting from random positions.

    A multi-mask is a mask obtained by applying multiple masks.

    When max_length is given:
      1) Sample random mask lengths less than max_length with shape
         (batch_size, multiplicity).
      2) Truncate lengths to a max of (choose_range * max_ratio),
         so that each mask is fully contained within the corresponding sequence.
      3) Randomly sample start points of shape (batch_size, multiplicity)
         within (choose_range - lengths).
      4) For each batch, multiple masks (whose number is given by the
         multiplicity) are constructed.
      5) Return a mask of shape (batch_size, mask_size) where masks are
         obtained by composing the masks constructed in step 4).
         If masks_per_frame > 0, the number is given by
         min(masks_per_frame * choose_range, multiplicity).
         If not, all the masks are composed. The masked regions are set to zero.

    When max_length is not given:
      1) Sample random mask lengths less than (choose_range * max_ratio)
         with shape (batch_size, multiplicity).
      2) Proceed to steps 3), 4) and 5) of the above.

    Args:
      batch_size: Batch size. Integer number.
      choose_range: Range within which the masked entries must lie. Tensor of
        shape (batch_size,).
      mask_size: Size of the mask. Integer number.
      global_seed: an integer seed tensor for stateless random ops.
      max_length: Maximum number of allowed consecutive masked entries. Integer
        number or None.
      masks_per_frame: Number of masks per frame. Float number. If > 0, the
        multiplicity of the mask is set to be masks_per_frame * choose_range.
      multiplicity: Maximum number of total masks. Integer number.
      dtype: Data type.
      max_ratio: Maximum portion of the entire range allowed to be masked. Float
        number.

    Returns:
      mask: a fixed size multi-mask starting from a random position with shape
      (batch_size, mask_size).
    """
        p = self.params
        # Non-empty random seed values are only used for testing or when using
        # stateless random ops. seed_1 and seed_2 are set separately to avoid
        # correlation of mask size and mask position.
        if p.use_input_dependent_random_seed:
            seed_1 = global_seed + 1
            seed_2 = global_seed + 2
        elif p.random_seed:
            seed_1 = p.random_seed + 1
            seed_2 = 2 * p.random_seed
        else:
            seed_1 = p.random_seed
            seed_2 = p.random_seed
        # Sample lengths for multiple masks.
        if max_length and max_length > 0:
            max_length = tf.broadcast_to(tf.cast(max_length, dtype),
                                         (batch_size, ))
        else:
            max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
        random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)
        masked_portion = random_uniform(shape=(batch_size, multiplicity),
                                        minval=0.0,
                                        maxval=1.0,
                                        dtype=dtype,
                                        seed=seed_1)
        masked_frame_size = self.EinsumBBmBm(max_length, masked_portion)
        masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)
        # Make sure the sampled length is at most length_bound
        # (= max_ratio * choose_range). Note that sampling in this way is
        # biased: shorter sequences may be over-masked.
        choose_range = tf.expand_dims(choose_range, -1)
        choose_range = tf.tile(choose_range, [1, multiplicity])
        length_bound = tf.cast(choose_range, dtype=dtype)
        length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
        length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))

        # Choose starting point.
        random_start = random_uniform(shape=(batch_size, multiplicity),
                                      maxval=1.0,
                                      seed=seed_2)
        start_with_in_valid_range = random_start * tf.cast(
            (choose_range - length + 1), dtype=dtype)
        start = tf.cast(start_with_in_valid_range, tf.int32)
        end = start + length - 1

        # Shift starting and end point by small value.
        delta = tf.constant(0.1)
        start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
        start = tf.tile(start, [1, 1, mask_size])
        end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
        end = tf.tile(end, [1, 1, mask_size])

        # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
        diagonal = tf.expand_dims(
            tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
        diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
        pre_mask = tf.cast(tf.math.logical_and(diagonal < end,
                                               diagonal > start),
                           dtype=dtype)

        # Sum masks with appropriate multiplicity.
        if masks_per_frame > 0:
            multiplicity_weights = tf.tile(
                tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
                [batch_size, 1])
            multiplicity_tensor = masks_per_frame * tf.cast(choose_range,
                                                            dtype=dtype)
            multiplicity_weights = tf.cast(
                multiplicity_weights < multiplicity_tensor, dtype=dtype)
            pre_mask = self.EinsumBmtBmBt(pre_mask, multiplicity_weights)
        else:
            pre_mask = tf.reduce_sum(pre_mask, 1)
        mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)

        if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
            mask = tf.cast(mask, p.fprop_dtype)

        return mask
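
A NumPy sketch of the core masking step in _GetMask: a sampled span [start, end] becomes a 0/1 pre-mask by comparing a coordinate row against the slightly shifted endpoints, and the final mask is its complement so masked entries are zero.

import numpy as np

mask_size, start, end = 8, 2, 4
coords = np.arange(mask_size)
# delta = 0.1 shifts the endpoints so integer coordinates compare cleanly.
pre_mask = (coords > start - 0.1) & (coords < end + 0.1)
mask = 1.0 - pre_mask.astype(np.float32)
print(mask)  # -> [1. 1. 0. 0. 0. 1. 1. 1.]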
Example #28
0
  def __init__(self, params):
    super(MlPerfInput, self).__init__(params)
    p = self.params

    self.natural_order_model = p.natural_order_model

    (
        self._src_ids,
        self._src_paddings,
        self._tgt_ids,
        self._tgt_paddings,
        self._tgt_labels,
        self._tgt_weights,
        self._src_seg_pos,
        self._src_seg_ids,
        self._tgt_seg_pos,
        self._tgt_seg_ids,
    ), self._bucket_keys = self._BuildDataSource()

    if p.pad_to_max_seq_length:
      assert p.source_max_length

      if min(self.infeed_bucket_batch_limit) == max(
          self.infeed_bucket_batch_limit):
        source_shape = [
            min(self.infeed_bucket_batch_limit), p.source_max_length
        ]
        target_shape = [
            min(self.infeed_bucket_batch_limit), p.target_max_length
        ]
      else:
        source_shape = None
        target_shape = None
      self._src_ids = py_utils.PadSequenceDimension(self._src_ids,
                                                    p.source_max_length, 0,
                                                    source_shape)
      self._src_paddings = py_utils.PadSequenceDimension(
          self._src_paddings, p.source_max_length, 1, source_shape)
      self._tgt_ids = py_utils.PadSequenceDimension(self._tgt_ids,
                                                    p.target_max_length, 0,
                                                    target_shape)
      self._tgt_paddings = py_utils.PadSequenceDimension(
          self._tgt_paddings, p.target_max_length, 1, target_shape)
      self._tgt_labels = py_utils.PadSequenceDimension(self._tgt_labels,
                                                       p.target_max_length, 0,
                                                       target_shape)
      self._tgt_weights = py_utils.PadSequenceDimension(self._tgt_weights,
                                                        p.target_max_length, 0,
                                                        target_shape)

      self._src_seg_ids = py_utils.PadSequenceDimension(self._src_seg_ids,
                                                        p.source_max_length, 0,
                                                        source_shape)
      self._src_seg_pos = py_utils.PadSequenceDimension(self._src_seg_pos,
                                                        p.source_max_length, 0,
                                                        source_shape)
      self._tgt_seg_ids = py_utils.PadSequenceDimension(self._tgt_seg_ids,
                                                        p.target_max_length, 0,
                                                        target_shape)
      self._tgt_seg_pos = py_utils.PadSequenceDimension(self._tgt_seg_pos,
                                                        p.target_max_length, 0,
                                                        target_shape)

    self._sample_ids = tf.range(0, self.InfeedBatchSize(), 1)
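
A simplified stand-in for the pad-to-fixed-length step above, using plain tf.pad (TF2 eager assumed; py_utils.PadSequenceDimension as used here also takes an optional target shape): ids are padded with 0 and the padding indicator with 1.

import tensorflow as tf

ids = tf.constant([[3, 5, 7]])
paddings = tf.constant([[0.0, 0.0, 0.0]])
max_len = 5
pad_amt = max_len - ids.shape[1]  # static sequence length assumed known
ids = tf.pad(ids, [[0, 0], [0, pad_amt]], constant_values=0)
paddings = tf.pad(paddings, [[0, 0], [0, pad_amt]], constant_values=1.0)
print(ids.numpy())       # -> [[3 5 7 0 0]]
print(paddings.numpy())  # -> [[0. 0. 0. 1. 1.]]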
Example #29
0
File: model.py Project: wgfi110/lingvo
  def _CreateCanvasAndTargets(self, batch):
    # pyformat: disable
    """Create the canvas and targets.

    Args:
      batch: A `.NestedMap`.

        - src: A `.NestedMap`.
          - ids: The source ids, ends in <eos>.
          - paddings: The source paddings.

        - tgt: A `.NestedMap`.
          - ids: The target ids, ends in <eos>.
          - paddings: The target paddings.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim].
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices (i.e., use these indices to
          tf.gather_nd the log-probs). Optional, only during training.
        - target_weights: The target weights. Optional, only during training.
    """
    # pyformat: enable
    p = self.params

    if not p.is_eval:
      # Sample our src and tgt canvas.
      src_descriptor = self._SampleCanvasAndTargets(batch.src.ids,
                                                    batch.src.paddings)
      tgt_descriptor = self._SampleCanvasAndTargets(batch.tgt.ids,
                                                    batch.tgt.paddings)

      # Offset the src ids (to unshare embeddings between src/tgt). Note, we
      # only offset the canvas ids, but we do not offset the vocab ids. This
      # will result in unshared embeddings, but shared softmax. This is due to
      # GPU/TPU memory limitations, empirically it is known that unsharing
      # everything results in better performance.
      vocab_size = p.decoder.softmax.num_classes
      src_descriptor.canvas = tf.where(
          tf.equal(src_descriptor.canvas_paddings, 0),
          src_descriptor.canvas + vocab_size, src_descriptor.canvas)

      # Offset the tgt indices (need shift according to src length).
      batch_size = py_utils.GetShape(batch.src.ids)[0]
      # `target_batch` is a [num_targets, batch_size] tensor where each row
      # identifies which batch the target belongs to. Note that
      # tf.reduce_sum(target_batch, 1) == 1 for every row.
      target_batch = tf.cast(
          tf.equal(
              tf.expand_dims(tf.range(batch_size), 0),
              tf.expand_dims(tgt_descriptor.target_indices[:, 0], 1)), tf.int32)
      src_lens = tf.cast(
          tf.reduce_sum(1 - src_descriptor.canvas_paddings, 1), tf.int32)
      # `tgt_offset` is shape [num_targets] where each entry corresponds to the
      # offset needed for that target (due to the source length).
      tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
      # We shift the tgt slot without touching the batch or vocab.
      tgt_descriptor.target_indices += tf.concat(
          [tf.zeros_like(tgt_offset), tgt_offset,
           tf.zeros_like(tgt_offset)], 1)

      # The canvas is simply the sequence-level concat of the src and tgt.
      canvas, canvas_paddings = insertion.SequenceConcat(
          src_descriptor.canvas, src_descriptor.canvas_paddings,
          tgt_descriptor.canvas, tgt_descriptor.canvas_paddings)
      target_indices = tf.concat(
          [src_descriptor.target_indices, tgt_descriptor.target_indices], 0)
      target_weights = tf.concat(
          [src_descriptor.target_weights, tgt_descriptor.target_weights], 0)

      return py_utils.NestedMap(
          canvas=canvas,
          canvas_paddings=canvas_paddings,
          target_indices=target_indices,
          target_weights=target_weights)
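
A NumPy sketch of the one-hot offset trick above: multiplying the "which batch does this target belong to" one-hot matrix by the per-batch source lengths hands every target row the offset of its own batch.

import numpy as np

target_batch_ids = np.array([0, 0, 1, 2])  # batch id of each target row
src_lens = np.array([4, 2, 7])             # source canvas length per batch
one_hot = (target_batch_ids[:, None] == np.arange(3)[None, :]).astype(np.int32)
tgt_offset = one_hot @ src_lens[:, None]
print(tgt_offset.ravel())  # -> [4 4 2 7]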
Example #30
0
  def testConv2DLayerStridedWithPaddingFProp(self, seq_len):
    """Check strided convs get the same values for different length dim."""
    # TODO(isaace): THIS TEST SHOWS THAT THERE IS A BUG IN THE CODE.
    with self.session(use_gpu=True):
      batch_size = 3
      expected_seq_len = 3

      params = conv_layers.Conv2DLayerWithPadding.Params()
      params.weight_norm = False
      params.filter_stride = [2, 2]
      params.name = 'conv'
      params.filter_shape = [3, 3, 1, 1]
      params.params_init = py_utils.WeightInit.Constant(1.0)
      conv_layer = params.Instantiate()

      # Set up the padding for the sequence length. (starting at 5).
      in_padding = tf.constant([
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 1],
          [0, 0, 0, 1, 1],
      ], tf.float32)
      in_padding = tf.pad(
          in_padding, [[0, 0], [0, seq_len - 5]], constant_values=1.0)

      inputs = 1.0 + tf.tile(
          tf.reshape(tf.range(seq_len, dtype=tf.float32), [1, seq_len, 1, 1]),
          [batch_size, 1, 3, 1])
      inputs = py_utils.ApplyPadding(
          tf.reshape(in_padding, [batch_size, seq_len, 1, 1]), inputs)

      inputs = py_utils.Debug(inputs)

      output, out_padding = conv_layer.FPropDefaultTheta(inputs, in_padding)

      output = py_utils.Debug(output)
      out_padding = py_utils.Debug(out_padding)

      self.evaluate(tf.global_variables_initializer())
      output, out_padding = self.evaluate([output, out_padding])

      self.assertEqual((batch_size, expected_seq_len, 2, 1), output.shape)
      self.assertAllClose([
          [0, 0, 1],
          [0, 0, 1],
          [0, 1, 1],
      ], out_padding)

      # This here shows a bug in the implementation; the output should be the
      # same. Also there are bugs with the output not having the correct
      # padding.
      if seq_len == 5:
        self.assertAllClose([
            [[[6], [6]], [[18], [18]], [[18], [18]]],
            [[[6], [6]], [[18], [18]], [[8], [8]]],
            [[[6], [6]], [[10], [10]], [[0], [0]]],
        ], output)
      elif seq_len == 6:
        self.assertAllClose([
            [[[12], [12]], [[24], [24]], [[10], [10]]],
            [[[12], [12]], [[14], [14]], [[0], [0]]],
            [[[12], [12]], [[6], [6]], [[0], [0]]],
        ], output)
      else:
        raise ValueError(f'Test does not handle length {seq_len}')
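
A quick arithmetic check of the expected sequence length in this test: with stride 2 and SAME-style padding, the output length is ceil(in_len / stride), so both tested lengths map to 3.

import math

for seq_len in (5, 6):
    print(seq_len, '->', math.ceil(seq_len / 2))  # 5 -> 3, 6 -> 3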