Example #1
 def ComputeLoss(self, theta, predictions, input_batch):
   loss = tf.reduce_sum(predictions['result'])
   return {'loss': (loss, 1)}, {}
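For context, a lingvo-style ComputeLoss returns a metrics dict that maps each
metric name to a (value, weight) pair. Below is a minimal sketch of how such
dicts might be aggregated into weighted means; aggregate_metrics is a
hypothetical helper and assumes plain Python floats:

def aggregate_metrics(metrics_list):
  """Hypothetical helper: combine {name: (value, weight)} dicts into one
  weighted mean per metric."""
  totals = {}
  for metrics in metrics_list:
    for name, (value, weight) in metrics.items():
      v, w = totals.get(name, (0.0, 0.0))
      totals[name] = (v + value * weight, w + weight)
  return {name: v / max(w, 1e-8) for name, (v, w) in totals.items()}

print(aggregate_metrics([{'loss': (2.0, 1.0)}, {'loss': (4.0, 3.0)}]))
# {'loss': 3.5}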
Example #2
    def testForwardPass(self):
        with self.session(use_gpu=False) as sess:
            bs = 2
            sl = 21
            tf.set_random_seed(8372749040)
            p = self._EncoderParams()
            mt_enc = encoder.TransformerEncoder(p)
            batch = py_utils.NestedMap()
            batch.ids = tf.constant(
                np.random.randint(low=0,
                                  high=63,
                                  size=[bs, sl],
                                  dtype=np.int32))
            batch.paddings = tf.zeros([bs, sl])
            out = mt_enc.FPropDefaultTheta(batch)
            enc_out_sum = tf.reduce_sum(out.encoded, 0)
            emb_out_sum = tf.reduce_sum(out.embedded_inputs, 0)
            enc_padding = out.padding

            tf.global_variables_initializer().run()
            actual_enc_out, actual_enc_out_sum, actual_emb_out_sum, \
                actual_padding = sess.run(
                    [out.encoded, enc_out_sum, emb_out_sum, enc_padding])

            # pyformat: disable
            # pylint: disable=bad-whitespace
            expected_enc_out = [[
                49.45291519, -31.5743885, 39.43684387, -47.67513275,
                35.39754105, 14.41970444, 29.58752823, -43.06747055,
                24.09403419, -7.62717247, 18.48112106, 20.42408371, 5.1519866,
                -19.66542244, 29.81095314, 56.90407944
            ],
                                [
                                    55.26333618, -30.39743614, 29.68314743,
                                    -37.61392975, 43.02292252, 13.88345146,
                                    15.73033905, -24.68696213, 24.70776558,
                                    -29.18026161, 15.41469955, 27.77672577,
                                    -5.36326742, -22.78984642, 22.15843391,
                                    22.7237072
                                ]]
            expected_emb_out_sum = [[
                3.11785889, 1.33086884, -1.96904886, -4.81911993, 1.25389254,
                1.52582073, 0.79906291, 4.07078457, -1.20546532, -2.97308111,
                0.22460097, 2.99702668, -2.29453254, 6.06631422, 1.68836212,
                5.35728741
            ],
                                    [
                                        1.41723049, -1.39409399, -1.49569404,
                                        -0.24654561, 1.09658146, 4.51638842,
                                        2.72023368, -0.45651400, 3.46091199,
                                        -0.43925080, 1.02091551, 3.89704037,
                                        1.87841535, -0.27947778, -0.91630745,
                                        1.34230828
                                    ]]
            # pylint: enable=bad-whitespace
            # pyformat: enable
            self.assertAllEqual(actual_enc_out.shape, [sl, bs, p.model_dim])
            self.assertAllEqual(actual_padding.shape, [sl, bs])
            self.assertAllClose(expected_enc_out,
                                actual_enc_out_sum,
                                rtol=1e-05,
                                atol=1e-05)
            self.assertAllClose(expected_emb_out_sum,
                                actual_emb_out_sum,
                                rtol=1e-05,
                                atol=1e-05)
Example #3
    def testTargetSequenceSamplerWithEOC(self, use_recurrent):
        with self.session(use_gpu=False):
            np.random.seed(9384758)
            tf.random.set_seed(8274758)
            vocab_size = 4
            src_len = 5
            tgt_len = 20
            batch_size = 2
            p = target_sequence_sampler.TargetSequenceSampler.Params().Set(
                name='bsh',
                target_seq_len=tgt_len,
                target_eoc_id=0,
                use_recurrent=use_recurrent)
            seq_sampler = p.Instantiate()

            def InitBeamSearchCallBack(unused_theta, unused_encoder_outputs,
                                       num_hyps_per_beam):
                self.assertEqual(1, num_hyps_per_beam)
                logits = tf.zeros((batch_size, vocab_size), dtype=tf.float32)
                is_last_chunk = tf.constant(False, shape=[batch_size])
                result = py_utils.NestedMap(log_probs=logits,
                                            is_last_chunk=is_last_chunk)
                states = py_utils.NestedMap(step=tf.constant(0),
                                            src_step=tf.zeros([batch_size],
                                                              dtype=tf.int32))
                return result, states

            def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                          unused_step_ids, states,
                                          num_hyps_per_beam, unused_cur_step):
                self.assertEqual(1, num_hyps_per_beam)
                logits = tf.random.stateless_normal([batch_size, vocab_size],
                                                    seed=[8273747, 9])
                # Make it never predict <eos>.
                logits -= tf.one_hot([p.target_eos_id], vocab_size, 1e30)
                is_last_chunk = tf.equal(states.src_step, src_len - 1)
                result = py_utils.NestedMap(log_probs=logits,
                                            is_last_chunk=is_last_chunk)
                return result, states

            def PostBeamSearchStepCallback(unused_theta,
                                           unused_encoder_outputs,
                                           new_step_ids, states):
                return py_utils.NestedMap(
                    step=states.step + 1,
                    src_step=states.src_step +
                    tf.cast(tf.equal(new_step_ids, p.target_eoc_id),
                            dtype=tf.int32))

            src_enc = tf.random.stateless_normal([src_len, batch_size, 8],
                                                 seed=[982774838, 9])
            src_enc_padding = tf.constant(
                [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                dtype=tf.float32)
            encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                                 padding=src_enc_padding)

            theta = py_utils.NestedMap()
            random_seed = tf.constant(123)
            decoder_output = seq_sampler.Sample(theta, encoder_outputs,
                                                random_seed,
                                                InitBeamSearchCallBack,
                                                PreBeamSearchStepCallback,
                                                PostBeamSearchStepCallback)

            ids, lens = self.evaluate([
                decoder_output.ids,
                tf.reduce_sum(1 - decoder_output.paddings, 1),
            ])
            print(np.array_repr(ids))
            print(np.array_repr(lens))
            expected_ids = [
                [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                [0, 0, 3, 3, 1, 0, 3, 0, 1, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            ]
            expected_lens = [5, 11]
            self.assertAllEqual(expected_ids, ids)
            self.assertAllEqual(expected_lens, lens)

            # Now do the same, except with use_stop_fn=True.
            p = target_sequence_sampler.TargetSequenceSampler.Params().Set(
                name='bsh',
                target_seq_len=tgt_len,
                target_eoc_id=0,
                use_stop_fn=True)
            seq_sampler = p.Instantiate()
            decoder_output = seq_sampler.Sample(theta, encoder_outputs,
                                                random_seed,
                                                InitBeamSearchCallBack,
                                                PreBeamSearchStepCallback,
                                                PostBeamSearchStepCallback)

            ids, lens = self.evaluate([
                decoder_output.ids,
                tf.reduce_sum(1 - decoder_output.paddings, 1),
            ])
            print(np.array_repr(ids))
            print(np.array_repr(lens))
            expected_ids = [
                [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                [0, 0, 3, 3, 1, 0, 3, 0, 1, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            ]
            expected_lens = [5, 11]
            self.assertAllEqual(expected_ids, ids)
            self.assertAllEqual(expected_lens, lens)
Example #4
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
  """Samples num_sampled_points from points using farthest point sampling.

  Algorithm:
  1. Start by selecting a random point and adding it to the selected set.
  2. Among the remaining points, find the one furthest from those selected.
  3. Add that furthest point to the selected set.
  4. Repeat 2-3 until num_sampled_points are selected.

  More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

  The output of this function can be used with tf.array_ops.batch_gather to
  extract the desired points, for example:
  tf.array_ops.batch_gather(points, sampled_idx)

  Args:
    points: floating point tf.Tensor of shape [N, P1, dims]
    padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point is
      real, and 1 otherwise.
    num_sampled_points: integer number of points to sample.
    precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
      squared distances between each pair of points. If None, distances will
      be computed on the fly.
    num_seeded_points: If num_seeded_points > 0, then the first
      num_seeded_points in points are considered to be seeded in the FPS
      sampling. Note that we assume that these points are *not* padded, and do
      not check padding when seeding them.
    random_seed: optional integer random seed to use with all the random ops.

  Returns:
    A tuple of tf.Tensors (sampled_idx, closest_idx) of types
    (tf.int32, tf.int32).

    sampled_idx is of shape [N, num_sampled_points] representing the indices
    selected using the sampler. Its values lie in the range [0, P1).

    closest_idx is of shape [N, P1] representing the indices of the closest
    sampled points for each input point. closest_idx is used in PCNN as part of
    the pooling operation: each point is assigned to the closest sampled point
    and a max is taken over them. Its values lie in the range
    [0, num_sampled_points), indexing the closest sampled point that remains.
  """
  points = py_utils.HasRank(points, 3)
  batch_size, num_points, dims = py_utils.GetShape(points, 3)

  points = py_utils.with_dependencies(
      [py_utils.assert_greater_equal(num_points, num_sampled_points)], points)

  # Add a tiny bit of noise to the distance matrix or points so all
  # points are unique. This also ensures that truly repeated points (such as
  # padded points) are only selected after all valid points are selected.
  if precomputed_squared_distance is not None:
    precomputed_squared_distance = py_utils.HasShape(
        precomputed_squared_distance, [batch_size, num_points, num_points])
    precomputed_squared_distance += tf.random.uniform(
        (batch_size, num_points, 1),
        minval=1e-6,
        maxval=1e-5,
        dtype=tf.float32,
        seed=random_seed)
  else:
    points += tf.random.uniform((batch_size, num_points, dims),
                                minval=1e-6,
                                maxval=1e-5,
                                dtype=tf.float32,
                                seed=random_seed)

  # TensorArray to store the sampled indices in the loop.
  sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

  # Initialize distance_to_selected to inf for all points.
  distance_to_selected = float('inf') * tf.ones((batch_size, num_points))

  # For tracking the index to the closest selected point.
  closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

  # Current loop index counter.
  curr_idx = tf.constant(0, dtype=tf.int32)

  # Get number of valid points (1 is padded, so num_points - num_padded).
  num_valid_points = tf.cast(
      tf.cast(num_points, dtype=tf.float32) - tf.reduce_sum(padding, axis=1),
      dtype=tf.int32)

  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler."""

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non padded) point, so we
      create a random value per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first point.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
              (batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left.
      padding_masked_distance_to_selected = tf.where(
          tf.less(curr_idx, num_valid_points),
          padding_masked_distance_to_selected, distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
      return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                       axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index to the closest selected point.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

  _, _, sampled_idx, closest_idx = tf.while_loop(
      lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
      _BodyFn,
      loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
      back_prop=False,
      maximum_iterations=num_sampled_points)

  sampled_idx = sampled_idx.stack()  # num_sampled_points x n
  sampled_idx = tf.transpose(sampled_idx, [1, 0])

  if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
    sampled_idx.set_shape((batch_size, num_sampled_points))

  return sampled_idx, closest_idx
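A minimal usage sketch of the sampler above, assuming a small random point
cloud with no padding. tf.gather with batch_dims=1 is used as a stand-in for
the tf.array_ops.batch_gather mentioned in the docstring:

import tensorflow as tf

points = tf.random.uniform([4, 128, 3])  # [N, P1, dims]
padding = tf.zeros([4, 128])             # all points are real
sampled_idx, closest_idx = FarthestPointSampler(
    points, padding, num_sampled_points=32, random_seed=123)
# Extract the sampled coordinates: [4, 32, 3].
sampled_points = tf.gather(points, sampled_idx, batch_dims=1)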
Example #5
    def BeamSearchDecode(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam_override=0,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
        """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks. Mostly opaque to BeamSearchHelper, except that it should
        contain either a 'seq_lengths' field of shape [source_batch_size] or
        a 'paddings' field of shape [source_max_lengths, source_batch_size].
      num_hyps_per_beam_override: If set to a value <= 0, this parameter is
        ignored. If set to a value > 0, then this value will be used to override
        `p.num_hyps_per_beam`.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A `BeamSearchDecodeOutput`.
    """
        p = self.params
        num_hyps_per_beam = p.num_hyps_per_beam
        if num_hyps_per_beam_override > 0:
            num_hyps_per_beam = num_hyps_per_beam_override
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        min_score = -1e36
        best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
        in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
        in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
        bs_atten_probs = tf.zeros(
            [max_steps, num_hyps,
             tf.shape(initial_results.atten_probs)[1]],
            dtype=p.dtype)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                          in_prev_hyps, in_done_hyps, bs_atten_probs)

        def LoopContinue(cur_step, all_done, unused_step_ids,
                         unused_core_bs_states, unused_other_states_list):
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                     other_states_list):
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta, encoder_outputs, cur_step, step_ids, core_bs_states,
                 other_states.Pack(other_states_list), num_hyps_per_beam,
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(all_done.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              _GetShapes(core_bs_states),
                              _GetShapes(flat_other_states, none_shapes=True)))
        # [target_seq_len, num_beams * num_hyps_per_beam].
        final_done_hyps = final_bs_states[5]
        final_other_states = other_states.Pack(flat_final_other_states)

        # Assume that `paddings` has shape [source_max_lengths, source_batch_size]
        # by default, and compute `encoded_seq_lengths` accordingly. This can be
        # overridden by directly passing `seq_lengths` in the `encoder_outputs`
        # NestedMap.
        encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
        if encoded_seq_lengths is None:
            source_paddings = encoder_outputs.padding
            if isinstance(source_paddings, py_utils.NestedMap):
                encoded_seq_lengths = tf.cast(
                    tf.round(
                        tf.reduce_sum(
                            1.0 - tf.transpose(source_paddings.Flatten()[0]),
                            1)), tf.int32)
            else:
                encoded_seq_lengths = tf.cast(
                    tf.round(
                        tf.reduce_sum(
                            1.0 -
                            tf.cast(tf.transpose(source_paddings), tf.float32),
                            1)), tf.int32)

        # [num_beams, num_hyps_per_beam].
        topk_hyps = ops.top_k_terminated_hyps(
            final_done_hyps,
            encoded_seq_lengths,
            k=num_hyps_per_beam,
            num_hyps_per_beam=num_hyps_per_beam,
            length_normalization=p.length_normalization,
            coverage_penalty=p.coverage_penalty,
            target_seq_length_ratio=p.target_seq_length_ratio,
            eoc_id=p.target_eoc_id,
            merge_paths=p.merge_paths)
        # [num_beams * num_hyps_per_beam, ...].
        max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
        topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
            tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
        # [num_beams, num_hyps_per_beam].
        topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

        return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                      topk_lens, topk_scores, None,
                                      final_other_states)
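A small standalone sketch of the paddings-to-lengths computation above, using
a toy [source_max_lengths, source_batch_size] paddings tensor:

import tensorflow as tf

paddings = tf.constant([[0., 0.], [0., 0.], [0., 1.], [1., 1.]])  # [T=4, B=2]
seq_lengths = tf.cast(
    tf.round(tf.reduce_sum(1.0 - tf.transpose(paddings), 1)), tf.int32)
print(seq_lengths.numpy())  # [3 2]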
Example #6
 def InferenceFn(x):
     return tf.reduce_sum(self.vars.w * x) + self.vars.b
Example #7
  def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
    """Takes a tensor of strings and returns id/padding tensors.

    This generates `token_ids`, `target_ids`, and `paddings` in the format that
    is expected for tokenizers. This performs padding to a fixed length and
    appends the end-of-sentence token as appropriate.

    Args:
      strs: a string Tensor.
      max_length: a python integer. The second dimension of the returned arrays.
        All sequences are padded or truncated to that length.
      append_eos: a python bool. See `BaseTokenizer` for explanation.
      languages: A vector of strings with the same length as `strs`.

    Returns:
      A tuple of 3 tensors:

      - token_ids: a tensor of sequences of WPM ids starting with SOS. Sequences
        always end with EOS unless the sequence exceeds the maximum length.
        Always padded with EOS.
      - target_ids: a tensor of sequences of WPM ids not starting with SOS
        but ending with EOS. Always padded with EOS.
      - paddings: a tensor of floats indicating, at each position, whether
        the corresponding position is padded.
    """
    p = self.params
    if append_eos is None:
      append_eos = p.append_eos

    batch_size = py_utils.GetShape(strs)[0]
    token_ids_ta = tf.TensorArray(tf.int32, batch_size)
    target_ids_ta = tf.TensorArray(tf.int32, batch_size)
    paddings_ta = tf.TensorArray(tf.float32, batch_size)

    def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta, paddings_ta):
      """Tokenizes a single sentence."""
      ids, _ = self._wpm_encoder.Encode(strs[i])

      if append_eos:
        ids = tf.concat([ids, [self.eos_id]], axis=0)

      # This truncates after the eos is added, so some sentences might
      # not have </s> at the end.
      token_ids_ta = token_ids_ta.write(
          i,
          py_utils.PadOrTrimTo(
              tf.concat([[self.sos_id], ids], axis=0), [max_length],
              self.eos_id))
      target_ids_ta = target_ids_ta.write(
          i, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
      paddings_ta = paddings_ta.write(
          i,
          py_utils.PadOrTrimTo(
              tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

      return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

    _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
        lambda i, *_: i < batch_size,
        _TokenizeOneSentence,
        loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta, target_ids_ta,
                   paddings_ta),
        parallel_iterations=30,
        back_prop=False)

    token_ids = token_ids_ta.stack()
    target_ids = target_ids_ta.stack()
    paddings = paddings_ta.stack()

    if not p.pad_to_max_length:
      maxlen = tf.cast(
          tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
          tf.int32)
      token_ids = token_ids[:, :maxlen]
      target_ids = target_ids[:, :maxlen]
      paddings = paddings[:, :maxlen]

    return token_ids, target_ids, paddings
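A sketch of the pad-or-trim behavior the loop relies on, written as a
standalone stand-in for py_utils.PadOrTrimTo (hypothetical; the real helper
handles arbitrary shapes, this one only 1-D tensors):

import tensorflow as tf

def PadOrTrimTo1D(ids, length, pad_value):
  # Trim to `length`, then right-pad with `pad_value` if still short.
  ids = ids[:length]
  pad = tf.fill([length - tf.shape(ids)[0]], pad_value)
  return tf.concat([ids, pad], axis=0)

print(PadOrTrimTo1D(tf.constant([5, 6, 7]), 5, 2).numpy())  # [5 6 7 2 2]
print(PadOrTrimTo1D(tf.constant([5, 6, 7]), 2, 2).numpy())  # [5 6]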
Example #8
 def FPropTower(self, theta, unused_input_batch):
   return py_utils.NestedMap(
       loss=(tf.reduce_sum(theta.w) + theta.b, 1.0),
       loss2=(tf.reduce_sum(theta.w) - theta.b, 1.0)), py_utils.NestedMap()
Example #9
 def KeyFunc(batch):
     key = tf.reduce_min(batch.bucket_keys)
     idx = tf.reduce_sum(
         tf.cast(tf.greater(key, p.bucket_upper_bound), tf.int32))
     return tf.constant(p.bucket_upper_bound, dtype=tf.int64)[idx]
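A worked illustration of the bucket lookup above, with hypothetical bounds
[10, 20, 30]: a key of 15 exceeds exactly one bound, so idx is 1 and the batch
is assigned bucket 20.

import tensorflow as tf

bucket_upper_bound = [10, 20, 30]
key = tf.constant(15, tf.int64)
idx = tf.reduce_sum(
    tf.cast(tf.greater(key, tf.constant(bucket_upper_bound, tf.int64)),
            tf.int32))
print(tf.constant(bucket_upper_bound, dtype=tf.int64)[idx].numpy())  # 20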
Example #10
 def testForwardPassWithStackingAfterFinalLayer(self):
     with self.session(use_gpu=False):
         vn_config = py_utils.VariationalNoiseParams(None, False, False)
         p = self._EncoderParams(vn_config)
         p.stacking_layer_tpl.left_context = 1
         p.stacking_layer_tpl.right_context = 0
         p.stacking_layer_tpl.stride = 2
         p.layer_index_before_stacking = 1
         enc_out = self._ForwardPass(p).encoded
         enc_out_sum = tf.reduce_sum(enc_out, 0)
         tf.global_variables_initializer().run()
         # pyformat: disable
         # pylint: disable=bad-whitespace
         expected_enc_out = [
             [
                 -1.25796525e-02, -2.32883729e-02, 7.40477070e-03,
                 -4.51436592e-03, -5.84740378e-03, 2.30195466e-03,
                 -3.08505213e-03, 4.05658083e-03, -8.12252797e-03,
                 -1.08030904e-02, -4.17955732e-03, -3.73707339e-03,
                 6.97144482e-04, 2.79850606e-03, 8.33133236e-04,
                 -5.75614115e-03, -1.10648498e-02, -1.20132393e-03,
                 -1.69872947e-03, 6.97519444e-03, 2.46211258e-03,
                 -1.28190573e-02, -8.66306946e-05, -6.09322963e-03,
                 7.14540575e-03, -5.67986863e-05, 5.17684873e-03,
                 1.18097477e-02, 1.74862407e-02, 9.13049746e-03,
                 7.31027778e-03, 4.83186450e-05, -1.38104409e-02,
                 -2.56096497e-02, 1.04327593e-02, -5.15327370e-03,
                 -8.69584084e-03, 1.33647269e-03, -1.84873224e-03,
                 5.81806153e-03, -1.17716007e-02, -1.23606063e-02,
                 -2.58761784e-03, -6.46180846e-03, 4.11718246e-03,
                 6.22369815e-03, 4.84800315e-04, -8.21352564e-03,
                 -1.25989169e-02, 6.75740885e-04, -2.09423108e-03,
                 4.02465323e-03, 6.08023722e-03, -1.15798926e-02,
                 -6.19094400e-03, -1.03260633e-02, 8.31142440e-03,
                 3.74771934e-03, 7.58658582e-03, 1.32339774e-02,
                 2.02648211e-02, 8.03512800e-03, 1.21787926e-02,
                 4.27130330e-03
             ],
             [
                 -5.94401825e-03, 4.23503201e-03,
                 -7.39302021e-03, 3.84659087e-03, 2.92047067e-03,
                 -2.28955783e-03, 7.80778937e-05, 7.74920732e-03,
                 -1.29534695e-02, -1.44997425e-02, 3.00848205e-03,
                 -1.33561785e-04, 7.31927902e-03, -2.24683899e-03,
                 -6.27679843e-03, -5.35295857e-03, -5.39031485e-03,
                 -4.90641687e-05, 4.03603073e-03, -1.08133641e-03,
                 9.59445070e-03, 9.81783494e-03, 8.77558347e-03,
                 -5.13678743e-03, 7.19959754e-03, 3.93835502e-03,
                 -6.01979066e-03, 6.13247836e-03, 1.39782019e-03,
                 4.60287556e-04, 1.04263611e-02, -9.61792190e-03,
                 -1.02399308e-02, 8.54056142e-03, -1.22422148e-02,
                 6.58972748e-03, 3.18149826e-03, -2.79453350e-03,
                 -9.98417381e-04, 1.77927073e-02, -2.28664111e-02,
                 -2.73113251e-02, 6.44177478e-03, -5.66864444e-04,
                 1.58752780e-02, 2.18148530e-03, -1.31809842e-02,
                 -9.98921506e-03, -9.63711366e-03, 1.11398206e-03,
                 4.28507291e-03, -3.02007422e-04, 1.06751733e-02,
                 1.15796775e-02, 1.35387452e-02, -1.02765551e-02,
                 1.11750513e-02, 4.31185029e-03, -1.04119312e-02,
                 8.54373723e-03, 4.97616245e-04, -3.82199232e-03,
                 2.10159980e-02, -1.68744288e-02
             ]
         ]
         # pylint: enable=bad-whitespace
         # pyformat: enable
         enc_out_sum_val = enc_out_sum.eval()
         print('expected enc_out_sum_val', enc_out_sum_val)
         self.assertAllClose(expected_enc_out, enc_out_sum_val)
Example #11
    def ComputeLoss(self, theta, input_batch, predictions):
        """Computes loss and other metrics for the given predictions.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      input_batch: The input batch from which we access the groundtruth.
      predictions: The output of `ComputePredictions`, contains: logits - [b,
        nx, ny, nz, na, 7 + num_classes]. na is the number of anchor
        boxes per cell. [..., :7] are (dx, dy, dz, dw, dl, dh, dt).

    Returns:
      Two dicts defined as BaseTask.ComputeLoss.
    """
        p = self.params
        predicted_residuals = py_utils.HasShape(
            predictions.residuals, [-1, -1, -1, -1, p.num_anchors, 7])
        predicted_class_logits = py_utils.HasShape(
            predictions.classification_logits,
            [-1, -1, -1, -1, p.num_anchors, p.num_classes])
        bs, nx, ny, nz, na, _ = py_utils.GetShape(predicted_class_logits, 6)

        # Compute class and regression weights.
        class_weights = input_batch.assigned_cls_mask
        class_weights = py_utils.HasShape(class_weights, [bs, nx, ny, nz, na])
        reg_weights = input_batch.assigned_reg_mask
        reg_weights = py_utils.HasShape(reg_weights, [bs, nx, ny, nz, na])
        reg_weights = tf.expand_dims(reg_weights, -1)

        if p.loss_norm_type == LossNormType.NORM_BY_NUM_POSITIVES:
            # Compute number of positive anchors per example.
            foreground_mask = py_utils.HasShape(input_batch.assigned_reg_mask,
                                                [bs, nx, ny, nz, na])
            # Sum to get the number of foreground anchors for each example.
            loss_normalization = tf.reduce_sum(foreground_mask,
                                               axis=[1, 2, 3, 4])
            loss_normalization = tf.maximum(loss_normalization,
                                            tf.ones_like(loss_normalization))
            # Reshape for broadcasting.
            loss_normalization = tf.reshape(loss_normalization,
                                            [bs, 1, 1, 1, 1, 1])

            class_weights /= loss_normalization
            reg_weights /= loss_normalization

        # Classification loss.
        assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                               [bs, nx, ny, nz, na])
        class_loss = py_utils.SigmoidCrossEntropyFocalLoss(
            logits=predicted_class_logits,
            labels=tf.one_hot(assigned_gt_labels, p.num_classes),
            alpha=p.focal_loss_alpha,
            gamma=p.focal_loss_gamma)
        class_loss *= class_weights[..., tf.newaxis]
        class_loss_sum = tf.reduce_sum(class_loss)

        # Regression loss.
        anchor_localization_residuals = py_utils.HasShape(
            input_batch.anchor_localization_residuals, [bs, nx, ny, nz, na, 7])

        # Location and dimensions loss.
        reg_loc_and_dims_loss = self._utils.ScaledHuberLoss(
            predictions=py_utils.HasShape(predicted_residuals[..., :6],
                                          [bs, nx, ny, nz, na, 6]),
            labels=anchor_localization_residuals[..., :6],
            delta=1 / (3.**2))

        # Rotation loss with SmoothL1(sin(delta)).
        rot_delta = (predicted_residuals[..., 6:] -
                     input_batch.anchor_localization_residuals[..., 6:])
        reg_rot_loss = self._utils.ScaledHuberLoss(
            predictions=tf.sin(rot_delta),
            labels=tf.zeros_like(rot_delta),
            delta=1 / (3.**2))

        # Direction loss
        if p.direction_classifier_weight > 0.0:
            # The target rotations are in the assigned_gt_bbox tensor,
            # which already has assigned a gt bounding box to every anchor.
            rot_target = input_batch.assigned_gt_bbox[..., 6]
            # If rotation is > 0, the class is 1, else it is 0.
            rot_dir = tf.to_int32(rot_target > 0.)

            # Compute one-hot labels as a target.
            rot_dir_onehot = tf.one_hot(rot_dir, 2)

            # Manually handle loss reduction.
            dir_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=rot_dir_onehot,
                logits=predictions.predicted_dir,
                weights=tf.squeeze(reg_weights, axis=-1),
                reduction=tf.losses.Reduction.NONE)
            # Reduce across all dimensions (we'll divide by the batch size below).
            dir_loss_sum = tf.reduce_sum(dir_loss)
        else:
            dir_loss_sum = 0.0

        # Compute loss contribution from location and dimension separately.
        reg_loc_loss = reg_loc_and_dims_loss[..., :3] * reg_weights
        reg_loc_loss_sum = tf.reduce_sum(reg_loc_loss)

        reg_dim_loss = reg_loc_and_dims_loss[..., 3:6] * reg_weights
        reg_dim_loss_sum = tf.reduce_sum(reg_dim_loss)

        # Compute rotation loss contribution.
        reg_rot_loss *= reg_weights
        reg_rot_loss_sum = tf.reduce_sum(reg_rot_loss)

        # Num. predictions.
        # TODO(zhifengc): Consider other normalization factors. E.g., # of bboxes.
        preds = tf.cast(bs, class_loss_sum.dtype)

        # Normalize all of the components by batch size.
        reg_loc_loss = reg_loc_loss_sum / preds
        reg_dim_loss = reg_dim_loss_sum / preds
        reg_rot_loss = reg_rot_loss_sum / preds
        class_loss = class_loss_sum / preds
        dir_loss = dir_loss_sum / preds

        # Compute total localization regression loss.
        reg_loss = (p.location_loss_weight * reg_loc_loss +
                    p.dimension_loss_weight * reg_dim_loss +
                    p.rotation_loss_weight * reg_rot_loss)

        # Apply weights to normalized class losses.
        loss = (class_loss * p.classification_loss_weight +
                reg_loss * p.localization_loss_weight +
                dir_loss * p.direction_classifier_weight)

        metrics_dict = {
            'loss': (loss, preds),
            'loss/class': (class_loss, preds),
            'loss/reg': (reg_loss, preds),
            'loss/reg/rot': (reg_rot_loss, preds),
            'loss/reg/loc': (reg_loc_loss, preds),
            'loss/reg/dim': (reg_dim_loss, preds),
            'loss/dir': (dir_loss, preds),
        }

        per_example_dict = {
            'residuals': predicted_residuals,
            'classification_logits': predicted_class_logits,
        }

        return metrics_dict, per_example_dict
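The classification term above delegates to
py_utils.SigmoidCrossEntropyFocalLoss. A minimal standalone sketch of the
standard sigmoid focal loss (Lin et al., 2017) that such a helper is assumed
to implement; the lingvo version may differ in details:

import tensorflow as tf

def SigmoidFocalLossSketch(logits, labels, alpha=0.25, gamma=2.0):
  # Down-weight easy examples via the (1 - p_t)**gamma modulator and
  # balance classes with alpha.
  ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
  p = tf.sigmoid(logits)
  p_t = labels * p + (1.0 - labels) * (1.0 - p)
  alpha_t = labels * alpha + (1.0 - labels) * (1.0 - alpha)
  return alpha_t * tf.pow(1.0 - p_t, gamma) * ce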
Example #12
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in `x`, which represents the groundtruth sequence (i.e., English
    sequence). We return a sampled rollin (observed) canvas (i.e., random subset
    of the English sequence), as well as the target (indices) for an
    insertion-based model (i.e., the targets given the random observed subset).

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid.
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that `c_dim` <= `time_dim`, but they need not be
          equal.
        - canvas_indices: The canvas indices (into `x`).
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If invalid params.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        x_len = tf.to_int32(tf.round(tf.reduce_sum(1 - x_paddings, 1)))

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
            # accomplish this by adding +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        c_indices = tf.sort(c_indices)
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        c *= tf.cast(1 - c_paddings, tf.int32)

        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed tokens to <eos>; note that some of these need
        # a zero weight (or else there would be an <eos> and a valid token in
        # the same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to zero out the weights of those <eos> targets whose slot
        # actually contains a valid token.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
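The canvas sampling above uses the Gumbel-max trick. A tiny self-contained
sketch: adding Gumbel noise to masked logits and taking top-k yields a uniform
random subset of the valid positions, since the -1e9 mask keeps invalid ones
out:

import tensorflow as tf

logits = tf.constant([0., 0., 0., -1e9, -1e9])  # last two positions invalid
gumbel = -tf.math.log(-tf.math.log(tf.random.uniform([5])))
_, idx = tf.nn.top_k(logits + gumbel, k=3)
print(idx.numpy())  # a random ordering of positions {0, 1, 2}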
Example #13
    def FProp(self, theta, input_batch):
        # pyformat: disable
        """Compute features for the pillars and convert them back to a dense grid.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      input_batch: A `.NestedMap` object containing input tensors. Following
        keys are required:

        - grid_num_points: Integer tensor with shape [batch size, nx, ny, nz],
          where nx, ny, nz correspond to the grid sizes (i.e., the number of
          voxels in each axis dimension).
        - pillar_points: Float tensor with shape [batch size, num_pillars,
          num_points_per_pillar, 3 + num_laser_features]
        - pillar_centers: Float tensor with shape [batch size, num_pillars,
          num_points_per_pillar, 3]
        - pillar_locations: Float tensor with shape [batch size, num_pillars, 3]

    Returns:
      The dense features with shape [b, nx, ny, nz * fdims].
    """
        # pyformat: enable
        p = self.params
        bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)
        # Process points to concatenate a set of fixed features (e.g.,
        # add means, centers, normalize points to means).
        num_features = 3 + p.num_laser_features
        pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                          [bs, -1, -1, num_features])
        _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
        pillar_xyz = pillar_points[..., :3]

        # Compute number of points per pillar and prepare for broadcasting.
        pillar_num_points = tf.gather_nd(input_batch.grid_num_points,
                                         input_batch.pillar_locations,
                                         batch_dims=1)
        pillar_num_points = pillar_num_points[..., tf.newaxis, tf.newaxis]

        # Compute the mean by summing and dividing by the number of points. Clip
        # the denominator to at least 1.0 to gracefully handle empty pillars.
        pillar_sum = tf.reduce_sum(pillar_xyz, axis=2, keepdims=True)
        pillar_means = pillar_sum / tf.maximum(
            tf.cast(pillar_num_points, tf.float32), 1.0)

        pillar_feats = pillar_points[..., 3:]
        pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                           [bs, -1, 1, 3])
        pillar_concat = tf.concat(axis=3,
                                  values=[
                                      pillar_xyz - pillar_means, pillar_feats,
                                      tf.tile(pillar_means,
                                              [1, 1, npoints, 1]),
                                      tf.tile(pillar_centers,
                                              [1, 1, npoints, 1])
                                  ])
        # Featurize pillars.
        pillar_features = self.featurizer.FProp(theta.featurizer,
                                                pillar_concat)

        # Convert back to the dense grid.
        pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                             [bs, npillars, 3])
        dense_features = SparseToDense(grid_shape=(nx, ny, nz),
                                       locations=pillar_locations,
                                       feats=pillar_features)
        return dense_features
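SparseToDense is assumed to scatter per-pillar features back into the dense
grid; a hedged sketch of that scatter (hypothetical reimplementation, static
shapes only):

import tensorflow as tf

def SparseToDenseSketch(grid_shape, locations, feats):
  """locations: int32 [b, npillars, 3]; feats: float32 [b, npillars, fdims]."""
  nx, ny, nz = grid_shape
  b, npillars, fdims = feats.shape[0], feats.shape[1], feats.shape[2]
  batch_idx = tf.tile(tf.reshape(tf.range(b), [b, 1, 1]), [1, npillars, 1])
  scatter_idx = tf.concat([batch_idx, locations], axis=-1)  # [b, npillars, 4]
  dense = tf.scatter_nd(scatter_idx, feats, [b, nx, ny, nz, fdims])
  return tf.reshape(dense, [b, nx, ny, nz * fdims])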
Example #14
    def ComputeLoss(self, theta, predictions, input_batch):
        """Computes loss and other metrics for the given predictions.

    Args:
      theta: A `.NestedMap` object containing variable values of this task.
      predictions: The output of `ComputePredictions`, contains: logits - [b,
        nx, ny, nz, na, 7 + num_classes]. na is the number of anchor
        boxes per cell. [..., :7] are (dx, dy, dz, dw, dl, dh, dt).
      input_batch: The input batch from which we access the groundtruth.

    Returns:
      Two dicts defined as BaseTask.ComputeLoss.
    """
        p = self.params
        predicted_residuals = py_utils.HasShape(
            predictions.residuals, [-1, -1, -1, -1, p.num_anchors, 7])
        predicted_class_logits = py_utils.HasShape(
            predictions.classification_logits,
            [-1, -1, -1, -1, p.num_anchors, p.num_classes])
        bs, nx, ny, nz, na, _ = py_utils.GetShape(predicted_class_logits, 6)

        # Compute class and regression weights.
        class_weights = input_batch.assigned_cls_mask
        class_weights = py_utils.HasShape(class_weights, [bs, nx, ny, nz, na])
        reg_weights = input_batch.assigned_reg_mask
        reg_weights = py_utils.HasShape(reg_weights, [bs, nx, ny, nz, na])
        reg_weights = tf.expand_dims(reg_weights, -1)

        if p.loss_norm_type == LossNormType.NORM_BY_NUM_POSITIVES:
            # Compute number of positive anchors per example.
            foreground_mask = py_utils.HasShape(input_batch.assigned_reg_mask,
                                                [bs, nx, ny, nz, na])
            # Sum to get the number of foreground anchors for each example.
            loss_normalization = tf.reduce_sum(foreground_mask,
                                               axis=[1, 2, 3, 4])
            loss_normalization = tf.maximum(loss_normalization,
                                            tf.ones_like(loss_normalization))
            # Reshape for broadcasting.
            loss_normalization = tf.reshape(loss_normalization,
                                            [bs, 1, 1, 1, 1, 1])

            class_weights /= loss_normalization
            reg_weights /= loss_normalization

        # Classification loss.
        class_loss_sum = self._ComputeClassificationLoss(
            predictions, input_batch, class_weights)

        # Regression loss.
        anchor_localization_residuals = py_utils.HasShape(
            input_batch.anchor_localization_residuals, [bs, nx, ny, nz, na, 7])

        # Location and dimensions loss.
        reg_loc_and_dims_loss = self._utils.ScaledHuberLoss(
            predictions=py_utils.HasShape(predicted_residuals[..., :6],
                                          [bs, nx, ny, nz, na, 6]),
            labels=anchor_localization_residuals[..., :6],
            delta=1 / (3.**2))

        # Rotation loss is computed on a transform of rot_delta. For a direction
        # aware loss, we simply wrap the angles to -pi to pi; for a loss that is
        # symmetric to direction (i.e., rotating by pi), we use a sin transform.
        rot_delta_transform = tf.sin
        if p.direction_aware_rot_loss:
            rot_delta_transform = functools.partial(geometry.WrapAngleRad,
                                                    min_val=-np.pi,
                                                    max_val=np.pi)

        rot_delta = (predicted_residuals[..., 6:] -
                     anchor_localization_residuals[..., 6:])
        reg_rot_loss = self._utils.ScaledHuberLoss(
            predictions=rot_delta_transform(rot_delta),
            labels=tf.zeros_like(rot_delta),
            delta=1 / (3.**2))

        # Direction loss
        if p.direction_classifier_weight > 0.0:
            # The target rotations are in the assigned_gt_bbox tensor,
            # which already has assigned a gt bounding box to every anchor.
            rot_target = input_batch.assigned_gt_bbox[..., 6]
            # If rotation is > 0, the class is 1, else it is 0.
            rot_dir = tf.cast(rot_target > 0., tf.int32)

            # Compute one-hot labels as a target.
            rot_dir_onehot = tf.one_hot(rot_dir, 2)

            # Manually handle loss reduction.
            dir_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=rot_dir_onehot,
                logits=predictions.predicted_dir,
                weights=tf.squeeze(reg_weights, axis=-1),
                reduction=tf.losses.Reduction.NONE)
            # Reduce across all dimensions (we'll divide by the batch size below).
            dir_loss_sum = tf.reduce_sum(dir_loss)
        else:
            dir_loss_sum = 0.0

        # Compute loss contribution from location and dimension separately.
        reg_loc_loss = reg_loc_and_dims_loss[..., :3] * reg_weights
        reg_loc_loss_sum = tf.reduce_sum(reg_loc_loss)

        reg_dim_loss = reg_loc_and_dims_loss[..., 3:6] * reg_weights
        reg_dim_loss_sum = tf.reduce_sum(reg_dim_loss)

        # Compute rotation loss contribution.
        reg_rot_loss *= reg_weights
        reg_rot_loss_sum = tf.reduce_sum(reg_rot_loss)

        # Num. predictions.
        # TODO(zhifengc): Consider other normalization factors. E.g., # of bboxes.
        preds = tf.cast(bs, class_loss_sum.dtype)

        # Normalize all of the components by batch size.
        reg_loc_loss = reg_loc_loss_sum / preds
        reg_dim_loss = reg_dim_loss_sum / preds
        reg_rot_loss = reg_rot_loss_sum / preds
        class_loss = class_loss_sum / preds
        dir_loss = dir_loss_sum / preds

        # Compute total localization regression loss.
        reg_loss = (p.location_loss_weight * reg_loc_loss +
                    p.dimension_loss_weight * reg_dim_loss +
                    p.rotation_loss_weight * reg_rot_loss)

        # Apply weights to normalized class losses.
        loss = (class_loss * p.classification_loss_weight +
                reg_loss * p.localization_loss_weight +
                dir_loss * p.direction_classifier_weight)

        metrics_dict = {
            'loss': (loss, preds),
            'loss/class': (class_loss, preds),
            'loss/reg': (reg_loss, preds),
            'loss/reg/rot': (reg_rot_loss, preds),
            'loss/reg/loc': (reg_loc_loss, preds),
            'loss/reg/dim': (reg_dim_loss, preds),
            'loss/dir': (dir_loss, preds),
        }

        # Calculate dimension errors
        min_angle_rad = -np.pi if p.direction_aware_rot_loss else 0
        gt_bboxes = self._utils_3d.ResidualsToBBoxes(
            input_batch.anchor_bboxes,
            anchor_localization_residuals,
            min_angle_rad=min_angle_rad,
            max_angle_rad=np.pi)
        predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
            input_batch.anchor_bboxes,
            predicted_residuals,
            min_angle_rad=min_angle_rad,
            max_angle_rad=np.pi)
        dimension_errors_dict = self._BBoxDimensionErrors(
            gt_bboxes, predicted_bboxes, reg_weights)
        metrics_dict.update(dimension_errors_dict)

        per_example_dict = {
            'residuals': predicted_residuals,
            'classification_logits': predicted_class_logits,
        }

        return metrics_dict, per_example_dict
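geometry.WrapAngleRad is used above to keep rotation deltas in [-pi, pi). A
plausible standalone sketch of such wrapping (the real implementation may
differ):

import numpy as np
import tensorflow as tf

def WrapAngleRadSketch(angles, min_val=-np.pi, max_val=np.pi):
  # Wrap into [min_val, max_val) using the period max_val - min_val.
  period = max_val - min_val
  return tf.math.floormod(angles - min_val, period) + min_val

print(WrapAngleRadSketch(tf.constant([3.5, -3.5])).numpy())
# ~[-2.78, 2.78]: both angles wrapped back into [-pi, pi)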
Example #15
 def InferenceFn(x):
     return tf.reduce_sum(self._w * x) + self._b
Example #16
    def FProp(self, theta, x, paddings=None, update=False):
        """Computes distances of the given input 'x' to all centroids.

    This implementation applies layer normalization on 'x' internally first,
    and the returned 'dists' is computed using the normalized 'x'.

    Args:
      theta: A `.NestedMap` of weights' values of this layer.
      x: A tensor of shape [B, L, N, H].
      paddings: If not None, a tensor of shape [B, L].
      update: bool, whether to update centroids using x.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      k_means_loss: the average squared Euclidean distances to the closest
                    centroid, a scalar.
    """
        p = self.params
        x = tf.cast(x, theta.means.dtype)
        if paddings is None:
            paddings = tf.zeros_like(x[:, :, 0, 0])
        # Shape [B, L, 1, 1]
        paddings_4d = paddings[:, :, None, None]

        if p.apply_layer_norm:
            x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

        # Since 'x' is normalized (but theta.means is not), we use the negative
        # dot product to approximate the Euclidean distance here.
        dists = -2 * tf.einsum('BLNH, NKH -> BLNK', x, theta.means)
        if not p.apply_layer_norm:
            # If entries are not normalized, compute norms here.
            x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
            means_norm_sq = tf.reduce_sum(tf.square(theta.means),
                                          axis=-1,
                                          keepdims=False)
            means_norm_sq = tf.expand_dims(means_norm_sq, axis=0)
            means_norm_sq = tf.expand_dims(means_norm_sq, axis=0)
            dists += x_norm_sq + means_norm_sq

        # For padded positions we update the distances to very large numbers.
        very_large_dists = tf.ones_like(dists) * tf.constant(
            0.1, dtype=dists.dtype) * dists.dtype.max
        paddings_tiled = tf.tile(paddings_4d,
                                 [1, 1, p.num_heads, p.num_clusters])
        dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

        # Shape [B, L, N, K], the same as 'dists' above.
        nearest_one_hot = tf.one_hot(tf.math.argmin(dists, axis=-1),
                                     p.num_clusters,
                                     dtype=theta.means.dtype)
        # Same shape as the input 'x'.
        nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                     theta.means)
        diff = tf.math.squared_difference(x,
                                          tf.stop_gradient(nearest_centroid))
        diff = py_utils.ApplyPadding(paddings_4d, diff)
        diff = tf.math.reduce_mean(diff, axis=2)

        # The commitment loss, which when backpropagated through encourages the
        # 'x' values to commit to their chosen centroids.
        diff = tf.cast(diff, tf.float32)
        paddings = tf.cast(paddings, tf.float32)
        k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 -
                                                                     paddings)
        summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

        # TODO(zhouwk): investigate normalizing theta.means after each update.
        means_norm = tf.norm(theta.means)
        summary_utils.scalar('k_means/centroid_l2_norm/min',
                             tf.math.reduce_min(means_norm))
        summary_utils.scalar('k_means/centroid_l2_norm/mean',
                             tf.math.reduce_mean(means_norm))

        if not update:
            return dists, k_means_loss

        # To update the centroids (self.vars.means), we apply gradient descent on
        # the mini-batch of input 'x', which yields the following:
        #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
        # where x_mean is the average over all the input vectors closest to this
        # centroid.
        #
        # Note that this approach is equivalent to backprop via
        #    loss = tf.math.reduce_mean(
        #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid))
        # except that here the learning rate is independently set via 'decay'.

        # Ensure that the padded positions are not used to update the centroids.
        nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

        # Sum away batch and sequence length dimensions to get per cluster count.
        # Shape: [N, K]
        per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
        summary_utils.histogram('k_means/per_cluster_vec_count',
                                per_cluster_count)

        # Sum of the input 'x' per each closest centroid.
        sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

        if py_utils.use_tpu():
            per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
            sum_x = tf.tpu.cross_replica_sum(sum_x)

        if p.use_ema:
            updated_ema_count = moving_averages.assign_moving_average(
                self.vars.ema_count,
                tf.cast(per_cluster_count, self.vars.ema_count.dtype),
                p.decay,
                zero_debias=False)
            updated_ema_means = moving_averages.assign_moving_average(
                self.vars.ema_means,
                tf.cast(sum_x, self.vars.ema_means.dtype),
                p.decay,
                zero_debias=False)
            n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
            updated_ema_count = ((updated_ema_count + p.epsilon) /
                                 (n + p.num_clusters * p.epsilon) * n)
            # pylint: disable=g-no-augmented-assignment
            updated_ema_means = updated_ema_means / tf.expand_dims(
                updated_ema_count, axis=-1)
            # pylint: enable=g-no-augmented-assignment
            updated_ema_means = tf.cast(updated_ema_means,
                                        self.vars.means.dtype)
            means = tf.cast(theta.means, updated_ema_means.dtype)
            update_means_diff = updated_ema_means - means
        else:
            # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
            # cluster's position will always be 0, hence 'sum_x' in that dimension
            # will be 0.
            new_means = sum_x / tf.maximum(
                tf.constant(1.0, dtype=per_cluster_count.dtype),
                tf.expand_dims(per_cluster_count, axis=-1))
            # Note that we intentionally do not normalize the means after this update
            # as empirically this works better.
            update_means_diff = tf.cast(
                (1.0 - p.decay) * (new_means - theta.means),
                self.vars.means.dtype)
        return py_utils.with_dependencies(
            [tf.assign_add(self.vars.means, update_means_diff)],
            dists), k_means_loss
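A minimal NumPy sketch (hypothetical values, not the layer's actual code path) of the decay-based centroid update described in the comment above: each centroid moves a (1 - decay) fraction of the way toward the mean of the inputs assigned to it.

import numpy as np

decay = 0.999
centroid = np.array([0.5, -0.2])
# Assumed inputs whose nearest centroid is the one above.
assigned_x = np.array([[0.6, -0.1], [0.4, -0.3], [0.8, 0.0]])
x_mean = assigned_x.mean(axis=0)
new_centroid = centroid + (1 - decay) * (x_mean - centroid)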
Example No. 17
    def _testExtendStep(self, sess, dec, encoder_outputs, tgts, num_hyps):
        p = self._DecoderParams()

        # Infer true source encoder length from the padding.
        src_enc_len = tf.reduce_sum(1 - encoder_outputs.padding, axis=0)
        src_enc_len = dec._ExpandToNumHyps(src_enc_len, num_hyps)

        # Run Fprop
        fprop_out = dec._FProp(dec.theta, encoder_outputs, tgts)
        l_out1 = fprop_out.softmax_input
        attention_map_fprop = fprop_out.attention

        # Run ExtendStep.
        prefix_states = py_utils.NestedMap()
        for i in range(6):
            layer_i_states = py_utils.NestedMap()
            # The first dim is for the decode step (sequence length);
            # 0 here is a placeholder.
            layer_i_states.key = tf.zeros([0, self.tgt_batch, p.model_dim])
            layer_i_states.value = tf.zeros([0, self.tgt_batch, p.model_dim])
            prefix_states['layer_%i' % i] = layer_i_states

        l_out2 = []
        per_step_atten_probs = []
        for i in range(5):
            l_i_out, prefix_states, atten_probs = dec.ExtendStep(
                dec.theta, encoder_outputs, tgts.ids[:, i], i, prefix_states)
            l_out2.append(l_i_out)
            per_step_atten_probs.append(atten_probs)
        l_out2 = tf.stack(l_out2)
        bs_atten_probs = tf.stack(per_step_atten_probs)

        attention_map_bs = py_utils.NestedMap(probs=bs_atten_probs)

        def _TransposeAttentions(x):
            return tf.transpose(x, [1, 0, 2])

        attention_map_bs = attention_map_bs.Transform(_TransposeAttentions)

        tf.global_variables_initializer().run()

        l_out1_v, l_out2_v, attention_map_fprop_v, attention_map_bs_v, src_enc_len_v = sess.run(
            [
                l_out1, l_out2, attention_map_fprop, attention_map_bs,
                src_enc_len
            ])

        # Ensure that the FProp and BeamSearch outputs are the same.
        self.assertAllClose(l_out1_v, l_out2_v, rtol=1e-05, atol=1e-05)

        # Ensure that the FProp and BeamSearch attention matrices are the same.
        self.assertAllClose(attention_map_fprop_v.probs,
                            attention_map_bs_v.probs)

        print('attention map', attention_map_fprop_v.probs)

        # End-to-end test of attention probs: ensure that the EOS symbol and
        # positions after EOS have zero probability.
        for i in range(0, len(src_enc_len_v)):
            pos = int(src_enc_len_v[i]) - 1
            self.assertEqual(
                np.count_nonzero(attention_map_fprop_v.probs[i][:, pos:]), 0)
Example No. 18
    def _Pack(self, batch):
        """Packs a given batch.

    Note that this may change the batch size.

    This function packs the input batch and adds .segment_ids and .segment_pos
    fields to its `src` and `tgt` fields.

    Args:
      batch: a `.NestedMap` of input tensors to be packed. It is modified in
        place.
    """
        src_actual_seq_len = tf.math.reduce_sum(tf.cast(
            batch.src.ids_indicator, tf.int32),
                                                axis=1)
        tgt_actual_seq_len = tf.math.reduce_sum(tf.cast(
            batch.tgt.ids_indicator, tf.int32),
                                                axis=1)
        summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
        summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

        if not self.params.packing_factor:
            # Supply segment_ids and segment_pos with no packing.
            batch.src.segment_ids = batch.src.ids_indicator
            batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
            batch.tgt.segment_ids = batch.tgt.ids_indicator
            batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
            return

        (src_segment_ids, src_segment_pos, src_indices_in_input,
         tgt_segment_ids, tgt_segment_pos,
         tgt_indices_in_input) = ops.pack_sequences(
             src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
             self.params.source_max_length, self.params.target_max_length)

        uniq_src_indices_in_input = tf.unique(
            tf.reshape(src_indices_in_input, [-1])).y
        uniq_tgt_indices_in_input = tf.unique(
            tf.reshape(tgt_indices_in_input, [-1])).y
        summary_utils.histogram(
            'packed_source_seq_lengths',
            tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
        summary_utils.histogram(
            'packed_target_seq_lengths',
            tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

        # Ratio of non-padded tokens retained. If < 1.0, we are dropping
        # input data because p.packing_factor is set too high.
        src_orig_tokens_count = tf.cast(tf.reduce_sum(src_actual_seq_len),
                                        tf.float32)
        src_packed_tokens_count = tf.reduce_sum(
            tf.cast(src_segment_ids > 0, tf.float32))
        summary_utils.scalar('examples/src_packed_token_ratio',
                             src_packed_tokens_count / src_orig_tokens_count)
        tgt_orig_tokens_count = tf.cast(tf.reduce_sum(tgt_actual_seq_len),
                                        tf.float32)
        tgt_packed_tokens_count = tf.reduce_sum(
            tf.cast(tgt_segment_ids > 0, tf.float32))
        summary_utils.scalar('examples/tgt_packed_token_ratio',
                             tgt_packed_tokens_count / tgt_orig_tokens_count)

        # We deferred adding .paddings and used its complement .ids_indicator
        # exclusively, so that the packing can be applied with a padding value
        # of 0 for all fields.
        def ApplyPackingToSource(x):
            if x.dtype == tf.string:
                return ops.apply_packing(x, '\t', src_segment_ids,
                                         src_indices_in_input)
            return ops.apply_packing(x, 0, src_segment_ids,
                                     src_indices_in_input)

        src_paddings = ops.apply_packing(batch.src.paddings, 1,
                                         src_segment_ids, src_indices_in_input)
        batch.src = batch.src.Transform(ApplyPackingToSource)
        batch.src.paddings = src_paddings
        batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
        batch.src.segment_pos = src_segment_pos

        def ApplyPackingToTarget(x):
            if x.dtype == tf.string:
                return ops.apply_packing(x, '\t', tgt_segment_ids,
                                         tgt_indices_in_input)
            return ops.apply_packing(x, 0, tgt_segment_ids,
                                     tgt_indices_in_input)

        tgt_paddings = ops.apply_packing(batch.tgt.paddings, 1,
                                         tgt_segment_ids, tgt_indices_in_input)
        batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
        batch.tgt.paddings = tgt_paddings
        batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
        batch.tgt.segment_pos = tgt_segment_pos
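A hand-written toy illustration (assumed values, not actual output of ops.pack_sequences) of the segment fields that _Pack adds: two source sequences of lengths 3 and 2 packed into a single row of length 6, with one trailing padded position.

ids         = [11, 12, 13, 21, 22,  0]
segment_ids = [ 1,  1,  1,  2,  2,  0]  # which packed example each token belongs to
segment_pos = [ 0,  1,  2,  0,  1,  0]  # position of the token within its own example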
Example No. 19
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].

    Returns:
      A NestedMap containing:
        - encoded: The encoded features, either a tensor of shape [time, batch,
            depth], or a list of tensors if is_transparent is set in
            transformer_stack.
        - padding: of shape [time, batch]
        - segment_id: [time, batch] if packed inputs are supported by the model
            (and all layers), or None otherwise.
        - embedded_inputs: [time, batch, depth] embedded inputs tokens without
            positional encodings.
    """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                  tf.reshape(input_ids, [-1]))
            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.transpose(paddings)
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
Example No. 20
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` object containing:

        - ids: The input ids tensor of shape [batch, time].
        - paddings: The ids' paddings of shape [batch, time].

    Returns:
      A `.NestedMap` object containing:
        encoded - The encoded features of shape [time, batch, dim] or [batch,
          time, dim], depending on p.output_data_format.
        padding - The encoded features' padding of shape [time, batch] or
          [batch, time].
        segment_id - The segmentation of packed inputs of shape [time, batch] or
          [batch, time] if it is supported by the model, or None otherwise.
        embedded_inputs - The embedded inputs tokens without positional
          encodings of shape [time, batch, dim] or [batch, time, dim].
    """

        p = self.params
        with tf.name_scope(p.name):
            # [batch, time]
            input_ids = input_batch.ids
            # [batch, time]
            paddings = input_batch.paddings

            # [batch, time]
            segment_ids = input_batch.segment_ids if p.packed_input else None

            batch = py_utils.GetShape(input_ids)[0]
            time = py_utils.GetShape(input_ids)[1]

            # Embedding layer.
            # [batch, time, dim]
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                      input_ids)
            else:
                input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
            orig_input_embs = input_embs

            # [1, time, dim]
            if p.packed_input:
                positions = input_batch.segment_pos
                position_embs = tf.expand_dims(
                    self.position_emb.FPropWithPosition(
                        theta.position_emb, positions), 0)
            else:
                position_embs = tf.expand_dims(
                    self.position_emb.FProp(theta.position_emb, time), 0)

            # [batch, time, dim]
            input_embs += position_embs

            if p.input_dropout_tpl.fprop_dtype:
                input_embs = tf.cast(input_embs,
                                     p.input_dropout_tpl.fprop_dtype)
                paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)
            # [batch, time, dim]
            transformer_input = input_embs
            # Explicitly set the input shape of the Transformer layers, to
            # avoid an unknown-shape error from tf.einsum on non-TPU devices.
            transformer_input = tf.reshape(transformer_input,
                                           [batch, time, p.model_dim])

            # Compute self-attention segment mask once.
            if p.packed_input:
                segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids, dtype=transformer_input.dtype)
            else:
                segment_mask = tf.zeros([batch, 1, time, time])

            shape = py_utils.GetShape(transformer_input)
            batch_size = shape[0]
            seq_len = shape[1]
            paddings = tf.reshape(paddings, [batch_size, seq_len])
            encoded, padding = self.transformer_stack.FProp(
                theta.transformer_stack, transformer_input, paddings,
                segment_mask)

            if p.final_layer_norm:
                encoded = self.final_ln.FProp(theta.final_ln, encoded)

            seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1),
                                  tf.int32)

            if p.output_data_format == 'TBC':
                encoded = tf.transpose(encoded,
                                       [1, 0, 2])  # [time, batch, dim]
                padding = tf.transpose(padding)  # [time, batch]
                segment_ids = tf.transpose(
                    segment_ids) if p.packed_input else None
                orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

            return py_utils.NestedMap(
                encoded=encoded,
                padding=padding,
                seq_lengths=seq_lengths,  # used by beam_search_helper.
                segment_id=segment_ids,
                embedded_inputs=orig_input_embs)
Example No. 21
    def _CreateCanvasAndTargets(self, batch):
        # pyformat: disable
        """Create the canvas and targets.

    Args:
      batch: A `.NestedMap`.

        - src: A `.NestedMap`.
          - ids: The source ids, ends in <eos>.
          - paddings: The source paddings.

        - tgt: A `.NestedMap`.
          - ids: The target ids, ends in <eos>.
          - paddings: The target paddings.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim].
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices (i.e., use these indices to
          tf.gather_nd the log-probs). Optional, only during training.
        - target_weights: The target weights. Optional, only during training.
    """
        # pyformat: enable
        p = self.params

        if not p.is_eval:
            # Sample our src and tgt canvas.
            src_descriptor = self._SampleCanvasAndTargets(
                batch.src.ids, batch.src.paddings)
            tgt_descriptor = self._SampleCanvasAndTargets(
                batch.tgt.ids, batch.tgt.paddings)

            # Offset the src ids (to unshare embeddings between src/tgt). Note
            # that we only offset the canvas ids, not the vocab ids. This
            # results in unshared embeddings but a shared softmax. This is due
            # to GPU/TPU memory limitations; empirically it is known that
            # unsharing everything results in better performance.
            vocab_size = p.decoder.softmax.num_classes
            src_descriptor.canvas = tf.where(
                tf.equal(src_descriptor.canvas_paddings, 0),
                src_descriptor.canvas + vocab_size, src_descriptor.canvas)

            # Offset the tgt indices (need shift according to src length).
            batch_size = py_utils.GetShape(batch.src.ids)[0]
            # `target_batch` is a [num_targets, batch_size] tensor where each
            # row identifies which batch row the target belongs to. Note that
            # tf.reduce_sum(target_batch, 1) == 1 for all rows.
            target_batch = tf.cast(
                tf.equal(
                    tf.expand_dims(tf.range(batch_size), 0),
                    tf.expand_dims(tgt_descriptor.target_indices[:, 0], 1)),
                tf.int32)
            src_lens = tf.cast(
                tf.reduce_sum(1 - src_descriptor.canvas_paddings, 1), tf.int32)
            # `tgt_offset` has shape [num_targets, 1], where each entry is the
            # offset needed for that target (due to the source length).
            tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
            # We shift the tgt slot without touching the batch or vocab.
            tgt_descriptor.target_indices += tf.concat([
                tf.zeros_like(tgt_offset), tgt_offset,
                tf.zeros_like(tgt_offset)
            ], 1)

            # The canvas is simply the sequence-level concat of the src and tgt.
            canvas, canvas_paddings = insertion.SequenceConcat(
                src_descriptor.canvas, src_descriptor.canvas_paddings,
                tgt_descriptor.canvas, tgt_descriptor.canvas_paddings)
            target_indices = tf.concat(
                [src_descriptor.target_indices, tgt_descriptor.target_indices],
                0)
            target_weights = tf.concat(
                [src_descriptor.target_weights, tgt_descriptor.target_weights],
                0)

            return py_utils.NestedMap(canvas=canvas,
                                      canvas_paddings=canvas_paddings,
                                      target_indices=target_indices,
                                      target_weights=target_weights)
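A hedged standalone sketch (hypothetical values; TF2 eager assumed) of the offset computation above: the one-hot `target_batch` matrix selects each target's batch row, so the matmul with the source lengths gathers the per-target offset.

import tensorflow as tf

src_lens = tf.constant([4, 6])        # assumed source canvas lengths, [batch_size]
target_rows = tf.constant([0, 1, 1])  # assumed batch row of each target
target_batch = tf.cast(
    tf.equal(tf.expand_dims(tf.range(2), 0), tf.expand_dims(target_rows, 1)),
    tf.int32)
tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
# tgt_offset == [[4], [6], [6]]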
Example No. 22
 def ComputePredictions(self, theta, input_batch):
     """sum(m * x) + b."""
     return tf.reduce_sum(theta.m * input_batch.src_ids, axis=1) + theta.b
Example No. 23
 def _NonExpandedSquaredDistanceMatrix(pa, pb):
   diff = tf.expand_dims(pa, axis=2) - tf.expand_dims(pb, axis=1)
   squared_diff = tf.square(diff)
   squared_dis = tf.reduce_sum(squared_diff, axis=3)
   return squared_dis
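A usage sketch (assumed shapes; TF2 eager, with the function above in scope) of the broadcasting trick: expanding pa to [B, N, 1, 3] and pb to [B, 1, M, 3] makes the subtraction broadcast to [B, N, M, 3], and summing squares over the last axis yields all pairwise squared distances at once.

import tensorflow as tf

pa = tf.random.normal([2, 4, 3])  # [batch, n_points_a, 3]
pb = tf.random.normal([2, 5, 3])  # [batch, n_points_b, 3]
dists = _NonExpandedSquaredDistanceMatrix(pa, pb)  # shape [2, 4, 5]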
Example No. 24
def max_assignment(score: tf.Tensor,
                   *,
                   elementwise_upper_bound: tf.Tensor,
                   row_sums: tf.Tensor,
                   col_sums: tf.Tensor,
                   epsilon: float = 0.1,
                   num_iterations: int = 50,
                   use_epsilon_scaling: bool = True):
    """Differentiable max assignment with margin and upper bound constraints.

  Args:
    score: a 3D tensor of size [batch_size, n_rows, n_columns]. score[i, j, k]
      denotes the weight if the assignment on this entry is non-zero.
    elementwise_upper_bound: a 3D tensor of size [batch_size, n_rows,
      n_columns]. Each entry denotes the maximum value assignment[i, j, k] can
      take and must be a non-negative value. For example, upper_bound[i, j,
      k]=1.0 for binary assignment problem.
    row_sums: a 2D tensor of size [batch_size, n_rows]. The row sum constraint.
      The output assignment p[i, j, :] must sum to row_sums[i, j].
    col_sums: a 2D tensor of size [batch_size, n_columns]. The column sum
      constraint. The output assignment p[i, :, k] must sum to col_sums[i, k].
    epsilon: the epsilon coefficient of entropy regularization. The value should
      be within the range (0, 1]. `0.01` might work better than `0.1`. `0.1` may
      not make the assignment close enough to 0 or 1.
    num_iterations: the maximum number of iterations to perform.
    use_epsilon_scaling: whether to use epsilon scaling. In practice, the
      convergence of the iterative algorithm is much better if we start by
      solving the optimization with a larger epsilon value and re-use the
      solution (i.e. dual variables) for the instance with a smaller epsilon.
      This is called the epsilon scaling trick. See [Schmitzer 2019]
      (https://arxiv.org/pdf/1610.06519.pdf) as a reference. Here if
      use_epsilon_scaling=True, after each iteration we decrease the running
      epsilon by a constant factor until it reaches the target epsilon
      value. We found this to work well for gradient backward propagation,
      while the original scaling trick doesn't.

  Returns:
    A tuple with the following values.
      - assignment: a 3D tensor of size [batch_size, n_rows, n_columns].
        The output assignment.
      - used_iter: a scalar tensor indicating the number of iterations used.
      - eps: a scalar tensor indicating the stopping epsilon value.
      - delta: a scalar tensor indicating the stopping delta value (the relative
        change on the margins of assignment p in the last iteration).
  """

    # Check if all shapes are correct
    score_shape = score.shape
    bsz = score_shape[0]
    n = score_shape[1]
    m = score_shape[2]
    score = tf.ensure_shape(score, [bsz, n, m])
    elementwise_upper_bound = tf.ensure_shape(elementwise_upper_bound,
                                              [bsz, n, m])
    row_sums = tf.ensure_shape(tf.expand_dims(row_sums, axis=2), [bsz, n, 1])
    col_sums = tf.ensure_shape(tf.expand_dims(col_sums, axis=1), [bsz, 1, m])

    # The total sum of row sums must equal the total sum of column sums.
    sum_diff = tf.reduce_sum(row_sums, axis=1) - tf.reduce_sum(col_sums,
                                                               axis=2)
    sum_diff = tf.abs(sum_diff)
    tf.Assert(tf.reduce_all(sum_diff < 1e-6), [sum_diff])

    # Convert upper_bound constraint into another margin constraint
    # by adding auxiliary variables & scores. Tensor `a`, `b` and `c`
    # represent the margins (i.e. reduced sum) of 3 axes respectively.
    #
    max_row_sums = tf.reduce_sum(elementwise_upper_bound,
                                 axis=-1,
                                 keepdims=True)
    max_col_sums = tf.reduce_sum(elementwise_upper_bound,
                                 axis=-2,
                                 keepdims=True)
    score_ = tf.stack([score, tf.zeros_like(score)], axis=1)  # (bsz, 2, n, m)
    a = tf.stack([row_sums, max_row_sums - row_sums], axis=1)  # (bsz, 2, n, 1)
    b = tf.stack([col_sums, max_col_sums - col_sums], axis=1)  # (bsz, 2, 1, m)
    c = tf.expand_dims(elementwise_upper_bound, axis=1)  # (bsz, 1, n, m)

    # Clip log(0) to the large negative value -1e+36 to avoid
    # inf or NaN values in the computation. We cannot go any lower
    # because float32 would overflow to `-inf`.
    #
    tf.Assert(tf.reduce_all(a >= 0), [a])
    tf.Assert(tf.reduce_all(b >= 0), [b])
    tf.Assert(tf.reduce_all(c >= 0), [c])
    log_a = tf.maximum(tf.math.log(a), -1e+36)
    log_b = tf.maximum(tf.math.log(b), -1e+36)
    log_c = tf.maximum(tf.math.log(c), -1e+36)

    # Initialize the dual variables of margin constraints
    u = tf.zeros_like(a)
    v = tf.zeros_like(b)
    w = tf.zeros_like(c)

    eps = tf.constant(1.0 if use_epsilon_scaling else epsilon,
                      dtype=score.dtype)
    epsilon = tf.constant(epsilon, dtype=score.dtype)

    def do_updates(cur_iter, eps, u, v, w):  # pylint: disable=unused-argument
        # Epsilon scaling, i.e. gradually decreasing `eps` until it
        # reaches the target `epsilon` value
        cur_iter = tf.cast(cur_iter, u.dtype)
        scaling = tf.minimum(0.6 * 1.04**cur_iter, 0.85)
        eps = tf.maximum(epsilon, eps * scaling)
        score_div_eps = score_ / eps

        # Update u
        log_q_1 = score_div_eps + (w + v) / eps
        log_q_1 = tf.reduce_logsumexp(log_q_1, axis=-1, keepdims=True)
        new_u = (log_a - tf.maximum(log_q_1, -1e+30)) * eps

        # Update v
        log_q_2 = score_div_eps + (w + new_u) / eps
        log_q_2 = tf.reduce_logsumexp(log_q_2, axis=-2, keepdims=True)
        new_v = (log_b - tf.maximum(log_q_2, -1e+30)) * eps

        # Update w
        log_q_3 = score_div_eps + (new_u + new_v) / eps
        log_q_3 = tf.reduce_logsumexp(log_q_3, axis=-3, keepdims=True)
        new_w = (log_c - tf.maximum(log_q_3, -1e+30)) * eps
        return eps, new_u, new_v, new_w

    def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
        prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
        sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

        # Compute the relative changes on margins of P.
        # This will be used for stopping criteria.
        # Note the last update on w would guarantee the
        # margin constraint c is satisfied, so we don't
        # need to check it here.
        p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
        p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
        p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
        delta_a = tf.abs(a - p_a) / (a + 1e-6)
        delta_b = tf.abs(b - p_b) / (b + 1e-6)
        new_delta = tf.reduce_max(delta_a)
        new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

        # Compute the relative changes on assignment solution P.
        # This will be used for stopping criteria.
        delta_p = tf.abs(tf.exp(prev_sum_uvw) -
                         tf.exp(sum_uvw)) / (tf.exp(sum_uvw) + 1e-6)
        new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))
        return new_delta

    for cur_iter in tf.range(num_iterations):
        prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
        eps, u, v, w = do_updates(cur_iter, eps, u, v, w)
    delta = compute_relative_changes(prev_eps, prev_u, prev_v, prev_w, eps, u,
                                     v, w)
    cur_iter = num_iterations
    assignment = tf.exp((score_ + u + v + w) / eps)
    assignment = assignment[:, 0]
    return assignment, cur_iter, eps, delta
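A small hedged usage sketch of max_assignment (illustrative values only): a single-batch 2x2 binary assignment problem where every row and column must sum to one. With entropy regularization the output is only approximately 0/1.

import tensorflow as tf

score = tf.constant([[[2.0, 0.1], [0.2, 1.0]]])  # [batch=1, n_rows=2, n_cols=2]
upper = tf.ones_like(score)                      # binary assignment upper bound
assignment, used_iter, eps, delta = max_assignment(
    score,
    elementwise_upper_bound=upper,
    row_sums=tf.ones([1, 2]),
    col_sums=tf.ones([1, 2]))
# assignment is approximately [[[1., 0.], [0., 1.]]].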
Example No. 25
    def ComputeMetrics(self, decoder_outs, input_batch, ids_to_strings_fn):
        """Computes metrics on output from decoder.

    Args:
      decoder_outs: A `BeamSearchDecodeOutput`, a namedtuple containing the
        decode results.
      input_batch:  A `NestedMap` of tensors representing the source, target,
        and other components of the input batch.
      ids_to_strings_fn: a function of (ids, lens) -> strings, where ids has
        shape [batch, length], lens has shape [batch], and strings has shape
        [batch].

    Returns:
      A dict of Tensors containing decoder output and metrics.
    """
        topk = self.GetTopK(decoder_outs, ids_to_strings_fn=ids_to_strings_fn)
        tgt_batch = tf.shape(topk.scores)[0]
        num_hyps_per_beam = tf.shape(topk.scores)[1]
        tgt = input_batch.tgt
        tgt_lens = tf.cast(tf.round(tf.reduce_sum(1.0 - tgt.paddings, 1)),
                           tf.int32)
        tgt_lens = py_utils.HasShape(tgt_lens, [tgt_batch])
        transcripts = ids_to_strings_fn(tgt.labels, tgt_lens - 1)

        # Filter out all isolated '<noise>' tokens.
        noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
        filtered_refs = tf.strings.regex_replace(transcripts, noise_pattern,
                                                 ' ')
        filtered_hyps = tf.strings.regex_replace(topk.decoded, noise_pattern,
                                                 ' ')
        # Compute translation quality scores for all hyps.
        filtered_refs = tf.tile(tf.reshape(filtered_refs, [-1, 1]),
                                [1, num_hyps_per_beam])
        filtered_hyps = tf.reshape(filtered_hyps, [-1])
        filtered_refs = tf.reshape(filtered_refs, [-1])
        tf.logging.info('filtered_refs=%s', filtered_refs)
        norm_wer_errors, norm_wer_words = self.ComputeNormalizedWER(
            filtered_hyps, filtered_refs, num_hyps_per_beam)

        ret_dict = {
            'target_ids': tgt.ids,
            'target_labels': tgt.labels,
            'target_weights': tgt.weights,
            'target_paddings': tgt.paddings,
            'transcripts': transcripts,
            'topk_decoded': topk.decoded,
            'topk_ids': topk.ids,
            'topk_lens': topk.lens,
            'topk_scores': topk.scores,
            'norm_wer_errors': norm_wer_errors,
            'norm_wer_words': norm_wer_words,
        }

        if not py_utils.use_tpu() and 'sample_ids' in input_batch:
            ret_dict['utt_id'] = input_batch.sample_ids

        ret_dict.update(
            self.AddAdditionalDecoderMetricsToGraph(topk, filtered_hyps,
                                                    filtered_refs, input_batch,
                                                    decoder_outs))
        return ret_dict
Example No. 26
    def testBasicWithAccumulator(self):

        with self.session() as sess:

            p = _SampleAccumulatorLayer.Params()
            p.name = 'sample'
            accum_layer = _SampleAccumulatorLayer(p)
            accum_obj = accum_layer.accumulators[accum_layer.accumulator_name]

            theta = py_utils.NestedMap()
            theta.x = tf.constant(2.0)
            state = py_utils.NestedMap()
            state.value = tf.constant(0.0)
            state.x_power = tf.constant(1.0)
            inputs = py_utils.NestedMap()
            inputs.coeff = tf.constant([1., 2., 3.])

            def _CellFn(theta, state, inputs):
                print('TEST ACCUM WITHIN CellFn = ', accum_obj.GetValue())
                accum_obj.Update(inputs.coeff)
                return _Poly(theta, state, inputs)

            # By doing one accumulation prior to the recurrent call, we ensure
            # that incoming recurrent state is preserved.
            accum_obj.Update(10.)

            # x = 2
            # 1 + 2*x + 3*x^2
            ret = recurrent.Recurrent(theta,
                                      state,
                                      inputs,
                                      _CellFn,
                                      accumulator_layer=accum_layer)

            # Verify bprop.
            y = ret[1].value
            dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])

            # 2 + 6*x
            self.assertAllClose(dx_val, 14.)
            self.assertAllClose(d_coeff_val, [1., 2., 4.])

            # acc = [1, 1+2x, 1+2x+3x^2]
            # sum(acc) = 3 + 4x + 3x^2
            acc = ret[0].value
            dx, d_coeff = tf.gradients(ys=[tf.reduce_sum(acc)],
                                       xs=[theta.x, inputs.coeff])
            dx_val, d_coeff_val = sess.run([dx, d_coeff])
            # 4 + 6*x
            self.assertAllClose(dx_val, 16.)
            self.assertAllClose(d_coeff_val, [3., 4., 4.])

            # Verify fprop.
            (acc, state), accum_obj_value = sess.run(
                (ret, accum_obj.GetValue()))

            # Verify that accumulators don't change fprop results.
            self.assertAllClose(acc.value, [1., 5., 17.])
            self.assertAllClose(acc.x_power, [2., 4., 8.])
            self.assertAllClose(state.value, 17.)
            self.assertAllClose(state.x_power, 8.)

            # Verify accumulator (should be 10 (initial increment) + 1 + 2 + 3).
            self.assertEqual(0, accum_obj._disable_count)
            self.assertAllClose([accum_obj_value], [16.0])
Example No. 27
 def GlobalBatchSize(self):
     """Returns the total number of examples in the current batch."""
     # The number of examples is indicated by the segment_ids of the target.
     num_segments = tf.math.reduce_max(self._batch.tgt.segment_ids, axis=1)
     return tf.reduce_sum(tf.cast(num_segments, dtype=tf.int32))
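A toy sketch (assumed values; TF2 eager) of the counting trick above: with packed inputs, the per-row maximum of tgt.segment_ids equals the number of examples packed into that row, so summing the row maxima gives the global batch size.

import tensorflow as tf

segment_ids = tf.constant([[1., 1., 2., 2., 0.],
                           [1., 1., 1., 0., 0.]])
num_segments = tf.math.reduce_max(segment_ids, axis=1)        # [2., 1.]
total = tf.reduce_sum(tf.cast(num_segments, dtype=tf.int32))  # 3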
Example No. 28
    def _BuildStackedRecurrentElman(self, seqlen, trailing_pad_len, batch,
                                    dims, layers):
        tf.random.set_seed(342462)
        np.random.seed(32540)

        seqlen += trailing_pad_len
        dtype = tf.float64

        def CreateTheta():
            return py_utils.NestedMap(
                w=tf.constant(np.random.uniform(0, 0.2, (2 * dims, dims)),
                              dtype=dtype),
                b=tf.constant(np.random.uniform(0, 0.2, (dims, )),
                              dtype=dtype))

        def CreateState0():
            return py_utils.NestedMap(h=tf.constant(np.random.uniform(
                0, 0.2, (batch, dims)),
                                                    dtype=dtype),
                                      padding=tf.constant([[0]] * batch,
                                                          dtype=dtype))

        devices = ['/cpu:0'] * layers
        cell_fns = [Elman] * layers
        cell_grads = [ElmanGrad] * layers
        cell_outs = [ElmanOut] * layers
        cell_out_grads = [ElmanOutGrad] * layers
        thetas = [CreateTheta() for _ in range(layers)]
        init_states = [CreateState0() for _ in range(layers)]
        padding = np.zeros((seqlen, batch, 1))
        padding[-trailing_pad_len:, :, :] = 1.
        padding[-trailing_pad_len - 3:-trailing_pad_len - 1, :, :] = 1.
        inputs = py_utils.NestedMap(x=tf.constant(np.random.uniform(
            0, 0.2, (seqlen, batch, dims)),
                                                  dtype=dtype),
                                    padding=tf.constant(padding, dtype=dtype))
        output, _ = recurrent.StackedRecurrent(devices=devices,
                                               cell_fns=cell_fns,
                                               cell_grads=cell_grads,
                                               cell_outs=cell_outs,
                                               cell_out_grads=cell_out_grads,
                                               thetas=thetas,
                                               init_states=init_states,
                                               inputs=inputs)
        o = output.x
        if 'padding' in inputs:
            o *= (1 - inputs.padding)
        loss = tf.reduce_sum(tf.square(o))

        xs = py_utils.Flatten(thetas + [py_utils.NestedMap(x=inputs.x)])
        dxs = tf.gradients(ys=loss, xs=xs)

        # Reference implementation using Recurrent().
        ref = inputs
        for i in range(layers):
            ref = ElmanOut(
                recurrent.Recurrent(cell_fn=cell_fns[i],
                                    cell_grad=cell_grads[i],
                                    theta=thetas[i],
                                    state0=init_states[i],
                                    inputs=ref)[0])
        return ref.x, output.x, loss, xs, dxs
    def testTargetSequenceSampler(self, use_recurrent):
        with self.session(use_gpu=False):
            np.random.seed(9384758)
            tf.random.set_seed(8274758)
            vocab_size = 12
            src_len = 5
            tgt_len = 7
            batch_size = 2

            def InitBeamSearchCallBack(unused_theta, unused_encoder_outputs,
                                       num_hyps_per_beam):
                self.assertEqual(1, num_hyps_per_beam)
                logits = tf.zeros((batch_size, vocab_size), dtype=tf.float32)
                return (py_utils.NestedMap(log_probs=logits),
                        py_utils.NestedMap(step=tf.constant(0)))

            def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                          unused_step_ids, states,
                                          num_hyps_per_beam, unused_cur_step):
                self.assertEqual(1, num_hyps_per_beam)
                logits = tf.random.stateless_normal([batch_size, vocab_size],
                                                    seed=[8273747, 9])
                return (py_utils.NestedMap(log_probs=logits),
                        py_utils.NestedMap(step=states.step + 1))

            def PostBeamSearchStepCallback(unused_theta,
                                           unused_encoder_outputs,
                                           unused_new_step_ids, states):
                return states

            src_enc = tf.random.stateless_normal([src_len, batch_size, 8],
                                                 seed=[982774838, 9])
            src_enc_padding = tf.constant(
                [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                dtype=tf.float32)
            encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                                 padding=src_enc_padding)

            theta = py_utils.NestedMap()
            random_seed = tf.constant(123)
            p = target_sequence_sampler.TargetSequenceSampler.Params().Set(
                name='bsh',
                target_seq_len=tgt_len,
                use_recurrent=use_recurrent)
            seq_sampler = p.Instantiate()
            decoder_output = seq_sampler.Sample(theta, encoder_outputs,
                                                random_seed,
                                                InitBeamSearchCallBack,
                                                PreBeamSearchStepCallback,
                                                PostBeamSearchStepCallback)

            ids, lens = self.evaluate([
                decoder_output.ids,
                tf.reduce_sum(1 - decoder_output.paddings, 1),
            ])
            print(np.array_repr(ids))
            print(np.array_repr(lens))
            expected_ids = [[9, 0, 2, 2, 2, 2, 2], [0, 0, 11, 8, 1, 0, 7]]
            expected_lens = [3, 7]
            self.assertAllEqual(expected_ids, ids)
            self.assertAllEqual(expected_lens, lens)

            p = target_sequence_sampler.TargetSequenceSampler.Params().Set(
                name='bsh', target_seq_len=tgt_len, top_k=1)
            seq_sampler = p.Instantiate()
            decoder_output = seq_sampler.Sample(theta, encoder_outputs,
                                                random_seed,
                                                InitBeamSearchCallBack,
                                                PreBeamSearchStepCallback,
                                                PostBeamSearchStepCallback)

            ids, lens = self.evaluate([
                decoder_output.ids,
                tf.reduce_sum(1 - decoder_output.paddings, 1),
            ])
            print(np.array_repr(ids))
            print(np.array_repr(lens))
            expected_ids = [[0, 0, 0, 0, 0, 0, 0], [7, 7, 7, 7, 7, 7, 7]]
            expected_lens = [7, 7]
            self.assertAllEqual(expected_ids, ids)
            self.assertAllEqual(expected_lens, lens)

            p = target_sequence_sampler.TargetSequenceSampler.Params().Set(
                name='bsh', target_seq_len=tgt_len, top_k=5)
            seq_sampler = p.Instantiate()
            decoder_output = seq_sampler.Sample(theta, encoder_outputs,
                                                random_seed,
                                                InitBeamSearchCallBack,
                                                PreBeamSearchStepCallback,
                                                PostBeamSearchStepCallback)

            ids, lens = self.evaluate([
                decoder_output.ids,
                tf.reduce_sum(1 - decoder_output.paddings, 1),
            ])
            print(np.array_repr(ids))
            print(np.array_repr(lens))
            expected_ids = [[5, 0, 0, 0, 8, 0, 6], [7, 7, 10, 0, 7, 7, 0]]
            expected_lens = [7, 7]
            self.assertAllEqual(expected_ids, ids)
            self.assertAllEqual(expected_lens, lens)

            p = target_sequence_sampler.TargetSequenceSampler.Params().Set(
                name='bsh', target_seq_len=tgt_len, temperature=0.2)
            seq_sampler = p.Instantiate()
            decoder_output = seq_sampler.Sample(theta, encoder_outputs,
                                                random_seed,
                                                InitBeamSearchCallBack,
                                                PreBeamSearchStepCallback,
                                                PostBeamSearchStepCallback)

            ids, lens = self.evaluate([
                decoder_output.ids,
                tf.reduce_sum(1 - decoder_output.paddings, 1),
            ])
            print(np.array_repr(ids))
            print(np.array_repr(lens))
            expected_ids = [[0, 0, 0, 0, 0, 0, 9], [0, 0, 11, 7, 1, 0, 7]]
            expected_lens = [7, 7]
            self.assertAllEqual(expected_ids, ids)
            self.assertAllEqual(expected_lens, lens)

            p = target_sequence_sampler.TargetSequenceSampler.Params().Set(
                name='bsh',
                target_seq_len=tgt_len,
                use_recurrent=use_recurrent,
                nucleus_p=0.5)
            seq_sampler = p.Instantiate()
            decoder_output = seq_sampler.Sample(theta, encoder_outputs,
                                                random_seed,
                                                InitBeamSearchCallBack,
                                                PreBeamSearchStepCallback,
                                                PostBeamSearchStepCallback)
            ids, lens = self.evaluate([
                decoder_output.ids,
                tf.reduce_sum(1 - decoder_output.paddings, 1),
            ])
            print(np.array_repr(ids))
            print(np.array_repr(lens))
            expected_ids = [[9, 0, 0, 0, 9, 0, 9], [0, 0, 11, 10, 1, 0, 10]]
            expected_lens = [7, 7]
            self.assertAllEqual(expected_ids, ids)
            self.assertAllEqual(expected_lens, lens)
Example No. 30
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        # Create prefix of start tokens.
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(beam_size - 1, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # Penalize all hyps except the first.
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # The extension buffer stores a (potentially large) set of 'extensions',
        # each consisting of a hypothesis (represented by ext_mask) and a next
        # token (represented by ext_id). At each decoder iteration, the top_k
        # extensions from each hypothesis are added to the buffer and sorted
        # by score.
        #
        # The top beam_size extensions are then removed from the buffer and
        # used in the next decoder iteration, and the top 'ext_size' remaining
        # extensions are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than
        # beam_size terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # Rename prefix -> pfx for shorter variable names.
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ - - - - - - P P P P P P P* - - - ]   ^
        # [ - - P P P P P P P P P P P* - - - ]   | batch
        # [ - - - - - - - - - - - P P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_indexes = pfx_time - pfx_max + tf.expand_dims(pfx_len - 1, 1)
        pfx_pad = tf.cast(tf.greater_equal(pfx_indexes, 0),
                          tf.int32)  # Exclude final pfx token.
        pfx_id = tf.roll(pfx, shift=1, axis=-1) * pfx_pad
        pfx_last = pfx[:, -1]
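        # Hypothetical worked example of the shift above (assumed values):
        # with pfx_max=8, pfx_len=[3] and a right-aligned prefix
        #   pfx      = [[0, 0, 0, 0, 0, 5, 6, 7]]
        # pfx_pad comes out as [[0, 0, 0, 0, 0, 0, 1, 1]], so that
        #   pfx_id   = [[0, 0, 0, 0, 0, 0, 5, 6]]  # prefix minus its last token
        #   pfx_last = [7]                         # becomes the first tgt_id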

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        # Zero out attention to prefix padding positions. Cast to fprop_dtype
        # (not tf.float32) so that the pfx_mask dtype assert below holds.
        assert buf_size > pfx_max
        pfx_pad_long = tf.pad(pfx_pad, [(0, 0), (0, buf_size - pfx_max)],
                              constant_values=1)
        pfx_mask *= tf.cast(tf.expand_dims(pfx_pad_long, axis=1), fprop_dtype)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_indexes * pfx_pad
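        # Continuing the example above: pfx_pos = pfx_indexes * pfx_pad
        # = [0 0 0 0 1 2], i.e. tokens a, b, c get 0-based positions 0, 1, 2,
        # and pfx_segment_id is 1 on fed tokens and 0 on padding.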

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call the decoder with the prefix minus P*:
        # 'dec_state' shall then contain the key/value cache for the prefix
        # tokens (for transformer models), and the resulting 'logits' we can
        # either discard or roll into the initial hyp_score. Discarding is
        # simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' broadcast to [batch, beam] shape.
        # 'tgt_pos' is different for each batch row and is equal to pfx_len - 1.
        # 'tgt_segment_id' is always 1 (no packing).
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1.
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun) Here we make the initial 't' constant, determined by the
        # shape of the prefix tensor via 'pfx_max'. It is possible to make it
        # dynamic as t ~ max(pfx_len) / beam_size, which would allow more beam
        # search steps; however, 'max' results in a very slow all-to-all on
        # 16x16, and a variable number of decoder steps may result in bad
        # latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)
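        # E.g. (illustrative) pfx_max=8, beam_size=4 gives t=2: the first loop
        # iteration writes its one-hot self-attention entries at buffer
        # positions t*beam_size=8..11, right after the prefix region. Since
        # pfx_max is asserted to be a multiple of beam_size, ceil() is exact
        # and t == pfx_mul.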

        # The initial tgt_mask is such that each token P* attends to itself
        # (as usual) and to all non-padding prefix tokens before it.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)
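        # E.g. with length_norm_alpha=0.8 and hyp_len=15 the divisor is
        # ((15+5)/5)**0.8 = 4**0.8 ~= 3.03 (a GNMT-style length penalty).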

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)
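        # After the reshape, flat candidate j*k + i is token i proposed by
        # hyp j; tf.repeat duplicates each hyp's mask k (= beam_size) times so
        # top_mask rows line up with top_id/top_score columns. Note that EOS
        # was pushed out of the top-k above (score -= 1e9) because terminated
        # hyps are routed to the n-best list instead of being extended.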

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)
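            # i1[b, j, c] == 1 iff sorted slot j picked candidate c, so the
            # two einsums act as a batched gather of masks and ids by one-hot
            # matmul (TPU-friendly, no tf.gather); einsum_i32 is presumably
            # the integer einsum helper defined elsewhere in this file.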

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)
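        # The popcount of the selected mask is exactly the 0-based position of
        # the next token: (pfx_len - 1) prefix entries plus one one-hot entry
        # per decoded token so far, including the parent token.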

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
            t = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)
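    # Worked example (illustrative) for unmask(): for one row h = [a b c d]
    # with mask m = [1 0 1 1], cumsum gives t = [0 -1 1 2], so b is dropped
    # (one_hot of -1 is all zeros) and the surviving tokens are compacted to
    # [a c d 0 ...] of length output_len.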

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
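
# A minimal sketch (illustrative, not part of the original source) of the
# dec_callback contract assumed by the search above: it takes token ids,
# segment ids, positions, an attention mask, the opaque decoder state and the
# step counter, and returns per-token logits plus the updated state. The vocab
# size and the random-logits body are assumptions for demonstration only.
def _dummy_dec_callback(ids, segment_ids, pos, mask, dec_state, t):
    del segment_ids, pos, mask, t  # a real model would consume these
    vocab_size = 64  # assumed for illustration
    shape = tf.concat([tf.shape(ids), [vocab_size]], axis=0)
    logits = tf.random.normal(shape)
    # State passthrough; a real transformer would update its key/value cache.
    return logits, dec_state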