Example #1
 def _Real():
   return tf.cond(
       tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)
Example #2
 def StopFn(t, theta, state):
     del t, theta  # Unused: this stop function only uses the state ids.
     return tf.equal(state.ids, p.target_eos_id)
Example #3
    def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                                  num_hyps_per_beam, *args, **kwargs):
      """Wrapper for adding bias to _PreBeamSearchStateCallback.

      Biases results.log_probs towards provided encoder_outputs.targets.

      Args:
        theta: a NestedMap of parameters.
        encoder_outputs: a NestedMap computed by encoder.
        step_ids: A tensor of shape [tgt_batch, 1].
        states: A `.NestedMap` of tensors representing states that the clients
          would like to keep track of for each of the active hyps.
        num_hyps_per_beam: Beam size.
        *args: additional arguments to _PreBeamSearchStepCallback.
        **kwargs: additional arguments to _PreBeamSearchStepCallback.

      Returns:
        A tuple (results, out_states).
        results: A `.NestedMap` of beam search results.
          atten_probs: The updated attention probs, of shape
            [tgt_batch, src_len].
          log_probs: Log prob for each of the tokens in the target vocab, of
            shape [tgt_batch, vocab_size].
        out_states: A `.NestedMap` of updated states. The states relevant here
          are:
          time_step: A scalar indicating the current decoder step. Must be
            provided and maintained by the subclass.
          consistent: A boolean vector of shape [tgt_batch] which tracks
            whether each hypothesis has exactly matched encoder_outputs.targets
            so far.
      """
      p = self.params
      time_step = states.time_step
      bs_results, out_states = self._PreBeamSearchStepCallback(
          theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
          **kwargs)
      labels = encoder_outputs.targets.labels
      weights = encoder_outputs.targets.weights

      def ApplyBias():
        """Bias and update log_probs and consistent."""

        def TileForBeamAndFlatten(tensor):
          tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
          tensor = tf.tile(
              tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
          tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
          return tf.reshape(tensor, [tgt_batch])

        # Consistent if step_ids == labels from previous step
        # TODO(navari): Consider updating consistent only if weights > 0. Then
        # re-evaluate the need for bias_only_if_consistent=True.
        # Note that prev_label is incorrect for step 0 but is overridden later.
        prev_label = TileForBeamAndFlatten(
            tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
        is_step0 = tf.equal(time_step, 0)
        local_consistence = tf.math.logical_or(
            is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
        consistent = tf.math.logical_and(states.consistent, local_consistence)

        # get label, weight slices corresponding to current time_step
        label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
        weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
        if p.bias_only_if_consistent:
          weight = weight * tf.cast(consistent, py_utils.FPropDtype(p))

        # convert from dense label to sparse label probs
        vocab_size = tf.shape(bs_results.log_probs)[1]
        # Avoid 0 probs which may cause issues with log.
        uncertainty = tf.constant(1e-10, py_utils.FPropDtype(p))
        label_probs = tf.one_hot(
            label,
            vocab_size,
            on_value=1 - uncertainty,
            off_value=uncertainty /
            tf.cast(vocab_size - 1, py_utils.FPropDtype(p)),
            dtype=py_utils.FPropDtype(p))  # [tgt_batch, vocab_size]
        pred_probs = tf.exp(bs_results.log_probs)

        # interpolate predicted probs and label probs
        weight = tf.expand_dims(weight, 1)
        probs = py_utils.with_dependencies([
            py_utils.assert_less_equal(weight, 1.),
            py_utils.assert_greater_equal(weight, 0.)
        ], (1.0 - weight) * pred_probs + weight * label_probs)
        return tf.math.log(probs), consistent

      def NoApplyBias():
        """No-op. Return original log_probs and consistent."""
        return bs_results.log_probs, states.consistent

      log_probs, consistent = tf.cond(
          tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias, ApplyBias)
      bs_results.log_probs = log_probs
      out_states.consistent = consistent

      return bs_results, out_states
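
The core of ApplyBias is a per-token interpolation between the model's predicted distribution and a smoothed one-hot distribution over the forced target label. A minimal standalone sketch of that interpolation (toy values are hypothetical; `weight` plays the role of the tiled per-token bias weight):

import tensorflow as tf

vocab_size = 5
uncertainty = 1e-10  # avoid exact 0/1 probs before the log
pred_probs = tf.constant([[0.1, 0.2, 0.3, 0.2, 0.2],
                          [0.5, 0.1, 0.1, 0.2, 0.1]])
label = tf.constant([2, 4])           # forced target label per hypothesis
weight = tf.constant([[1.0], [0.0]])  # bias the first hypothesis only

label_probs = tf.one_hot(
    label, vocab_size,
    on_value=1.0 - uncertainty,
    off_value=uncertainty / (vocab_size - 1))
probs = (1.0 - weight) * pred_probs + weight * label_probs
log_probs = tf.math.log(probs)  # row 0 ~ forced onto label 2; row 1 unchanged

With weight == 1 the hypothesis is forced onto the target token; with weight == 0 the model's own distribution is left intact.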
Example #4
  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler."""

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non-padded) point, so we
      create a random value per point, and then set all padded ones to some
      large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
              (batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left.
      padding_masked_distance_to_selected = tf.where(
          tf.less(curr_idx, num_valid_points),
          padding_masked_distance_to_selected, distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
      return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                       axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index to the closest selected point.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
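
_BodyFn above is one iteration of farthest point sampling. A plausible driver loop, assuming the enclosing function provides batch_size, num_points, points, and so on; num_sampled_points and the initial values here are hypothetical stand-ins for the real sampler's arguments:

# Sketch of wiring _BodyFn into a while loop (names are assumptions).
sampled_idx = tf.TensorArray(tf.int32, size=num_sampled_points)
init_distance = tf.fill([batch_size, num_points], 1e30)  # "infinite" distance
init_closest = tf.zeros([batch_size, num_points], dtype=tf.int32)

_, _, sampled_idx, closest_idx = tf.while_loop(
    cond=lambda curr_idx, *unused: tf.less(curr_idx, num_sampled_points),
    body=_BodyFn,
    loop_vars=(tf.constant(0), init_distance, sampled_idx, init_closest))
# [num_sampled_points, batch_size] -> [batch_size, num_sampled_points].
sampled_idx = tf.transpose(sampled_idx.stack())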
Example #5
    def _EncodeToIds(self, word):
        # Below:
        #   * a token is a wordpiece ID.
        #   * the tokens array will be merged in-place.
        #   * the candidates array is an array of size len(tokens) - 1.
        #     It contains the token for the merged wordpiece, if it exists,
        #     NO_TOKEN otherwise. For instance, candidate[3] = id(token[3] + token[4]).
        # First, split into basic UTF-8 characters (letters).
        chars = tf.strings.unicode_split(word, 'UTF-8')
        tokens = self._StringToToken(chars)
        tokens = tf.where(
            tf.equal(tokens, NO_TOKEN),
            # Unseen character.
            tf.broadcast_to(self.unk_id, tf.shape(tokens)),
            tokens)
        # Create initial candidate list.
        candidates = tf.map_fn(self._MergeTokens, (tokens[:-1], tokens[1:]),
                               dtype=tokens.dtype)

        def _ShouldMerge(unused_tokens, candidates):
            """Merge until not possible, or we abort early according to merge_prob."""
            return tf.logical_and(
                tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
                tf.random.uniform([]) < self._merge_prob)

        def _MergeOneToken(tokens, i):
            return tf.expand_dims(self._MergeTokens(
                (tokens[i], tokens[i + 1])),
                                  axis=-1)

        def _MergeCandidates(tokens, candidates):
            """Merge in the reverse binary tree."""
            best_id = tf.argmin(candidates, output_type=tf.int32)
            # Perform the merge at position best_id.
            tokens = tf.concat([
                tokens[:best_id], [candidates[best_id]], tokens[best_id + 2:]
            ],
                               axis=0)
            # Recompute the merge candidates.
            # Only the neighbors of best_id need to be recomputed.
            empty = tf.zeros([0], dtype=candidates.dtype)

            def _MergeLeft():
                return tf.concat([
                    candidates[:best_id - 1],
                    _MergeOneToken(tokens, best_id - 1)
                ],
                                 axis=0)

            left_candidates = tf.cond(tf.equal(best_id, 0), lambda: empty,
                                      _MergeLeft)

            def _MergeRight():
                return tf.concat([
                    _MergeOneToken(tokens, best_id), candidates[best_id + 2:]
                ],
                                 axis=0)

            right_candidates = tf.cond(
                tf.greater_equal(best_id,
                                 tf.size(tokens) - 1), lambda: empty,
                _MergeRight)

            candidates = tf.concat([left_candidates, right_candidates], axis=0)
            return tokens, candidates

        return tf.while_loop(_ShouldMerge,
                             _MergeCandidates, (tokens, candidates),
                             parallel_iterations=1,
                             back_prop=False)[0]
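
Stripped of TF control flow, the merge loop is a greedy BPE-style reduction: repeatedly take the candidate with the smallest merged id and splice it in. A plain-Python sketch with a hypothetical merge table (and without the early-abort merge_prob), assuming NO_TOKEN is a large sentinel so the argmin prefers real merges:

NO_TOKEN = 2 ** 31 - 1  # large sentinel; min() then prefers real merge ids
MERGES = {(1, 2): 5, (5, 3): 6}  # hypothetical pair -> merged wordpiece id

def merge_tokens(pair):
  return MERGES.get(pair, NO_TOKEN)

def encode_to_ids(tokens):
  candidates = [merge_tokens(p) for p in zip(tokens, tokens[1:])]
  while candidates and min(candidates) != NO_TOKEN:
    best_id = candidates.index(min(candidates))
    tokens = tokens[:best_id] + [candidates[best_id]] + tokens[best_id + 2:]
    # The TF version recomputes only the neighbors of best_id; recomputing
    # the whole list is equivalent, just slower.
    candidates = [merge_tokens(p) for p in zip(tokens, tokens[1:])]
  return tokens

print(encode_to_ids([1, 2, 3]))  # (1, 2) -> 5, then (5, 3) -> 6: prints [6]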
Example #6
    def Sample(self, decoder_theta, encoder_outputs, random_seed,
               init_state_callback, pre_step_callback, post_step_callback):
        """Samples target sequences, one target sequence per source sequence.

    (Please see beam_search_helper.py for description of decoder callbacks.)

    Args:
      decoder_theta: A NestedMap object containing weights' values of the
        decoder layer and its children layers, to be passed to decoder
        callbacks.
      encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      random_seed: a scalar int32 tensor representing the random seed.
      init_state_callback: decoder._InitBeamSearchStateCallback.
      pre_step_callback: decoder._PreBeamSearchStepCallback.
      post_step_callback: decoder._PostBeamSearchStepCallback.

    Returns:
      A NestedMap containing the following tensors

      - 'logits': [batch, max_target_length, vocab_size], representing the
        distribution from which target sequences are sampled.
      - 'ids': [batch, max_target_length] of int32, representing the target
        sequence ids, not including target_sos_id, but maybe ending with
        target_eos_id if end-of-sequence is reached before target_seq_len.
      - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
        a padded timestep.
    """
        p = self.params
        assert p.temperature > 0
        # 'recurrent_theta' represents all cross-timestep information used by the
        # recurrent loop below, including layer theta and encoder outputs.
        recurrent_theta = py_utils.NestedMap(theta=decoder_theta,
                                             random_seed=random_seed,
                                             encoder_outputs=encoder_outputs)
        bs_result, bs_state = init_state_callback(recurrent_theta.theta,
                                                  encoder_outputs,
                                                  num_hyps_per_beam=1)
        batch = tf.shape(bs_result.log_probs)[0]
        recurrent_state0 = py_utils.NestedMap(
            timestep=tf.zeros(shape=[], dtype=tf.int32),
            logits=bs_result.log_probs,
            # Start with target_sos_id.
            ids=tf.fill([batch], tf.to_int32(p.target_sos_id)),
            bs_state=bs_state)
        inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    recurrent_theta.theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=1)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs
                # Sample ids from logits. [batch].
                state1.ids = tf.reshape(
                    tf.random.stateless_multinomial(
                        state1.logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        output_dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.logical_and(bs_result.is_last_chunk,
                                       tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    recurrent_theta.theta, recurrent_theta.encoder_outputs,
                    state1.ids, bs_state1)
            return state1, py_utils.NestedMap()

        accumulated_states, _ = recurrent.Recurrent(recurrent_theta,
                                                    recurrent_state0, inputs,
                                                    Step)
        result = py_utils.NestedMap(
            logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
            ids=tf.transpose(accumulated_states.ids))
        result.paddings = tf.cast(
            _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
        # Force ids to be eos_id if the timestep is padded.
        result.ids = tf.where(tf.equal(result.paddings, 0), result.ids,
                              tf.fill(tf.shape(result.ids), p.target_eos_id))
        static_batch_size = bs_result.log_probs.shape[0]
        result.ids.set_shape([static_batch_size, p.target_seq_len])
        result.paddings.set_shape([static_batch_size, p.target_seq_len])
        return result
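
Sampling here is stateless: the seed for each step is the pair (random_seed, timestep), so the same seed always reproduces the same sampled sequence. A minimal sketch of that seeding scheme (tf.random.stateless_categorical is the current name for the stateless_multinomial op used above):

import tensorflow as tf

logits = tf.math.log([[0.1, 0.6, 0.3]])  # [batch=1, vocab=3]
random_seed = tf.constant(42)

def sample_at_step(timestep):
  return tf.random.stateless_categorical(
      logits, num_samples=1, seed=tf.stack([random_seed, timestep]))

# The same (seed, timestep) pair yields the same draw on every call.
a = sample_at_step(tf.constant(0))
b = sample_at_step(tf.constant(0))
assert int(a[0, 0]) == int(b[0, 0])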
Example #7
  def _AddNoise(self, batch):
    """Adding noise the src (see https://arxiv.org/pdf/1711.00043).

    This function implement 3 types of noise (hyparams defined in
    self.params.denoise):
    1) slightly shuffle the sentence following p.shuffle_tok_range
    2) randomly drop tokens with probability p.drop_tok_prob
    3) randomly mask tokens with probability p.blank_tok_prob
    The noises are added to the input with probability p.noise_sent_prob.

    Args:
      batch: a `.NestedMap` of the input batch.
    """

    def IsSpecialExample(task_ids, special_task_ids):
      """A utility function indicates whether inputs belong to specific tasks.

      Args:
        task_ids: Task ids for the input batch. Tensor of shape [batch].
        special_task_ids: A list of specified task ids.

      Returns:
        A tensor of shape [batch] indicating whether each sample in the batch
        belongs to one of the specified tasks.
      """
      batch_size = py_utils.GetShape(task_ids)[0]
      return tf.reduce_any(
          tf.equal(
              tf.expand_dims(task_ids, -1),
              tf.cast(
                  tf.broadcast_to(
                      special_task_ids,
                      [batch_size, len(special_task_ids)]), tf.int32)), -1)

    p = self.params.denoise
    batch_size = tf.shape(batch.src.ids)[0]
    source_max_len = tf.shape(batch.src.ids)[1]

    # Shuffle tokens according to p.shuffle_tok_range
    noise = tf.random.uniform([batch_size, source_max_len], 0,
                              p.shuffle_tok_range + 1)

    # Don't shuffle eos or padding
    shuffle_tok_range = tf.fill([batch_size, source_max_len],
                                float(p.shuffle_tok_range))
    shifted_paddings = tf.pad(
        batch.src.paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
    noise = tf.where(tf.equal(shifted_paddings, 0), noise, shuffle_tok_range)
    indices = tf.broadcast_to(
        tf.range(source_max_len, dtype=tf.int32), [batch_size, source_max_len])
    noisy_indices = tf.cast(indices, dtype=tf.float32) + noise
    permutations = tf.argsort(noisy_indices)
    stacked = tf.stack([batch.src.ids, permutations], axis=1)
    denoise_src_ids = tf.stack(
        tf.map_fn(lambda x: tf.gather(x[0], x[1]), stacked), axis=0)

    # Select tokens to drop with probability=p.drop_tok_prob
    random_drop_tok = tf.random.uniform([batch_size, source_max_len])
    # Don't drop eos token
    is_keep_tok = tf.math.logical_or(
        tf.greater(random_drop_tok, p.drop_tok_prob),
        tf.equal(denoise_src_ids, self._src_tokenizer.eos_id))
    denoise_src_ids = tf.ragged.boolean_mask(denoise_src_ids,
                                             is_keep_tok).to_tensor(
                                                 default_value=0,
                                                 shape=tf.shape(batch.src.ids))
    denoise_src_paddings = tf.ragged.boolean_mask(
        batch.src.paddings, is_keep_tok).to_tensor(
            default_value=1, shape=tf.shape(batch.src.ids))

    # Select tokens to blank with probability=p.blank_tok_prob
    # Don't blank eos token
    random_blank_tok = tf.random.uniform([batch_size, source_max_len])
    shifted_paddings = tf.pad(
        denoise_src_paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
    is_blank_tok = tf.math.logical_and(
        tf.less(random_blank_tok, p.blank_tok_prob),
        tf.equal(shifted_paddings, 0))
    blank_id = tf.fill([batch_size, source_max_len], p.blank_id)
    denoise_src_ids = tf.where(is_blank_tok, blank_id, denoise_src_ids)

    # Select denoising task examples with probability=p.noise_sent_prob
    random_uniform_sent = tf.random.uniform([batch_size])
    is_denoise_sent = tf.math.logical_and(
        tf.less(random_uniform_sent, p.noise_sent_prob),
        IsSpecialExample(
            self._GetTaskIds(batch.src.source_ids[:, 0]), p.task_ids))
    batch.src.ids = tf.where(is_denoise_sent, denoise_src_ids, batch.src.ids)
    batch.src.paddings = tf.where(is_denoise_sent, denoise_src_paddings,
                                  batch.src.paddings)
    batch.src.ids_indicator = 1 - batch.src.paddings
    batch.src.weights = batch.src.ids_indicator
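
The shuffle step implements the paper's local shuffle: add uniform noise in [0, k + 1) to each position index and argsort, so each token stays within about k positions of its origin. The trick in isolation (hypothetical tensors, no padding handling; tf.gather with batch_dims=1 stands in for the map_fn above):

import tensorflow as tf

k = 3  # p.shuffle_tok_range: maximum displacement
ids = tf.constant([[10, 11, 12, 13, 14, 15]])
seq_len = tf.shape(ids)[1]

noise = tf.random.uniform(tf.shape(ids), 0.0, k + 1.0)
positions = tf.cast(tf.range(seq_len), tf.float32)[tf.newaxis, :]
permutation = tf.argsort(positions + noise)           # [1, seq_len]
shuffled = tf.gather(ids, permutation, batch_dims=1)  # locally shuffled ids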
Example #8
 def bucket_fn(num):
   # Drops record if num[0] is odd.
   return tf.cond(
       tf.equal(tf.math.floormod(num[0], 2), 0), lambda: 1,
       lambda: -tf.cast(num[0], tf.int32))
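
In Lingvo's bucketing input pipelines a bucket_fn maps each record to a bucket key, and a negative key drops the record; here even num[0] goes to bucket 1 and odd num[0] yields a negative key. A tiny eager check of the branch logic (values are hypothetical):

import tensorflow as tf

print(int(bucket_fn(tf.constant([4]))))  # even -> bucket 1
print(int(bucket_fn(tf.constant([5]))))  # odd  -> -5, i.e. record dropped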
Example #9
    def __init__(self,
                 learning_rate,
                 momentum=0.0,
                 initial_accumulator_value=0.0,
                 start_preconditioning_steps=1000,
                 statistics_computation_frequency=1,
                 matrix_epsilon=1e-6,
                 synchronous_preconditioning=False,
                 second_moment_averaging=1.0,
                 fallback_to_diagonal_dim=4096,
                 max_any_dim=6656,
                 block_size=4096,
                 block_partition_threshold_size=1000000,
                 global_step=None,
                 exponent_multiplier=1.0,
                 name="DistributedShampoo"):
        """Construct a DistributedShampoo optimizer.

    Args:
      learning_rate: A `Tensor` or a floating point value.  The learning rate.
      momentum: A `Tensor` or a floating point value. Momentum is not applied to
        sparse updates.
      initial_accumulator_value: A floating point value.
      start_preconditioning_steps: An int32 value which indicates when to start
        preconditioning.
      statistics_computation_frequency: An int32 step value which indicates how
        often to compute statistics for preconditioning.
      matrix_epsilon: An epsilon regularizer to make the matrices positive
        definite.
      synchronous_preconditioning: Whether to run preconditioning synchronously.
      second_moment_averaging: 1.0 means sum of gradients squares, while less
        than 1.0 switches to RMSProp style exponential moving averages of the
        second moments.
      fallback_to_diagonal_dim: Fall back to the diagonal version of AFMA if
        any of the dimensions is larger than fallback_to_diagonal_dim.
      max_any_dim: If the maximum value of any dimension is greater than this
        value, we skip preconditioning and fall back to the diagonal.
      block_size: Dimension of the partitioned tensors.
      block_partition_threshold_size: Partitions dimensions beyond this size.
      global_step: Global step for training.
      exponent_multiplier: A multiplier `e` for the exponent of the inverse
        calculation: e * -1/(2*rank). Only applies when calculating inverses
        through SVD.
      name: Optional name prefix for the operations created when applying
        gradients.
    """
        super(DistributedShampoo, self).__init__(False, name)
        self._learning_rate = learning_rate
        self._momentum = momentum
        self._initial_accumulator_value = initial_accumulator_value
        self._start_preconditioning_steps = start_preconditioning_steps
        self._matrix_epsilon = matrix_epsilon
        self._synchronous_preconditioning = synchronous_preconditioning
        self._second_moment_averaging = second_moment_averaging
        self._fallback_to_diagonal_dim = fallback_to_diagonal_dim
        self._max_any_dim = max_any_dim
        self._block_size = block_size
        # NOTE: On XLA - int64 is not handled properly.
        if global_step is not None:
            self._global_step = tf.cast(tf.identity(global_step), tf.int32)
        else:
            self._global_step = tf.cast(
                tf.identity(tf.train.get_or_create_global_step()), tf.int32)
        self._run_nondiagonal_update = tf.greater_equal(
            self._global_step, self._start_preconditioning_steps)
        start_steps_f = tf.cast(self._start_preconditioning_steps, tf.float32)
        global_step_f = tf.cast(self._global_step, tf.float32)
        self._run_nondiagonal_update_warmup = tf.minimum(
            1.0,
            tf.maximum((global_step_f - start_steps_f) / start_steps_f, 0.0))
        # Computes statistics every K steps.
        self._statistics_computation_frequency = statistics_computation_frequency
        self._run_statistics_computation = tf.equal(
            tf.mod(self._global_step, self._statistics_computation_frequency),
            0)
        # All vars that are preconditioned.
        self._all_vars_for_preconditioning = []
        self._exponent_multiplier = exponent_multiplier
        self._partition_info = PartitionConfig(block_partition_threshold_size,
                                               block_size)
        self._partitioner_metadata = {}
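
A minimal construction sketch based only on the signature above (the hyperparameter values are hypothetical; the instance is then used like any TF1-style optimizer):

opt = DistributedShampoo(
    learning_rate=0.1,
    momentum=0.9,
    start_preconditioning_steps=1000,
    statistics_computation_frequency=1,
    block_size=4096)
# train_op = opt.minimize(loss)  # standard tf.train.Optimizer usage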
Example #10
 def _ApplyMass(task_id):
   mass_task_ids = tf.constant(self.params.mass_task_ids, dtype=tf.int32)
   return tf.reduce_any(tf.equal(task_id, mass_task_ids))
Example #11
 def PostBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                new_step_ids, states):
   return py_utils.NestedMap(
       step=states.step + 1,
       src_step=states.src_step + tf.cast(
           tf.equal(new_step_ids, p.target_eoc_id), dtype=tf.int32))
Example #12
def GatherK(selected_pos, values, k, num_devices=1):
  """Gather up to k elements from given tensors at selected pos under SPMD.

  Example::

    # Input
    k = 3

    selected_pos = [
        [0, 0, 1, 1],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],  # topk(k=3) largest indices are selected in this row.
    ]

    value_2d = [
        [1, 3, 5, 7],
        [9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39],
    ]

    # Output:
    output = [
        [0, 5, 7],
        [0, 11, 13],
        [0, 0, 0],
        [25, 27, 29],
        [35, 37, 39],
    ]

    # Output padding:
    output_padding = [
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
    ]

  Args:
    selected_pos: a 0/1 2D tf.int32 tensor of shape [batch, time].
    values: a list of tensors, the rank of each is at least rank=2. [batch,
      time, ...].
    k: a scalar tf.int32 tensor or a Python int. On TPU, k must be a
      compile-time constant.
    num_devices: number of TPU devices used in xla_sharding SPMD.

  Returns:
    A tuple (output, padding).

    - output: a list of tensors of shape [batch, k, ...].
    - padding: a 2D 0/1 tensor of shape [batch, k], '1's are padded locations.
  """
  global_batch, seq_len = py_utils.GetShape(selected_pos, 2)
  if num_devices:
    device_batch = global_batch // num_devices
  else:
    device_batch = global_batch

  for i in range(len(values)):
    # Assert the first 2 dim of values[i] is [global_batch, seq_len]
    values[i] = py_utils.HasShape(values[i], [global_batch, seq_len], 2)
  # indices are 1-based for now, to distinguish between padding and selected
  # locations.
  indices = 1 + tf.range(tf.shape(values[0])[1], dtype=tf.int32)
  # [1, seq_len]
  indices = tf.expand_dims(indices, axis=0)

  # if 0, the position is not selected.
  # [1, seq_len] * [global_batch, seq_len] => [global_batch, seq_len]
  # -- topk --> [global_batch, k]
  topk_indices, _ = tf.math.top_k(
      indices * tf.cast(selected_pos, indices.dtype), k)

  # [global_batch, k], sorted in ascending order.
  indices = tf.reverse(topk_indices, [-1])
  # [global_batch, k], padded positions are '1's.
  padding = tf.cast(tf.equal(indices, 0), values[0].dtype)
  padding = Split(padding, 0, num_devices)

  # [global_batch, k], zero_based_indices
  mp_idx = tf.maximum(0, indices - 1)
  mp_idx = Split(mp_idx, 0, num_devices)

  # [device_batch, k]
  if num_devices > 1 and py_utils.use_tpu():
    mp_idx = xla_sharding.auto_to_manual_spmd_partition(
        mp_idx, xla_sharding.get_op_sharding(mp_idx.op))
  # [device_batch, k, 1]
  mp_idx = tf.expand_dims(mp_idx, -1)

  # [device_batch]
  batch_ids = tf.range(device_batch, dtype=tf.int32)
  # [device_batch, 1, 1]
  batch_ids = tf.reshape(batch_ids, [device_batch, 1, 1])
  # [device_batch, k, 1]
  batch_ids = tf.broadcast_to(batch_ids, [device_batch, k, 1])

  # [device_batch, k, 2]
  final_indices = tf.concat([batch_ids, mp_idx], axis=-1)

  output = []
  for v in values:
    # Begin manually partition gather.
    v = Split(v, 0, num_devices)
    v_shape = v.shape.as_list()
    if num_devices > 1 and py_utils.use_tpu():
      op_sharding = xla_sharding.get_op_sharding(v.op)
      v = xla_sharding.auto_to_manual_spmd_partition(v, op_sharding)
    # Returns [global_batch, k, ...]
    v_out = tf.gather_nd(v, final_indices)

    if num_devices > 1 and py_utils.use_tpu():
      v_shape[1] = k
      v_out = xla_sharding.manual_to_auto_spmd_partition(
          v_out, op_sharding, full_shape=tf.TensorShape(v_shape))
    output.append(v_out)

  return output, padding
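
A usage sketch reproducing the docstring example (single device, so the SPMD branches are not exercised):

selected_pos = tf.constant([[0, 0, 1, 1],
                            [0, 1, 1, 0],
                            [0, 0, 0, 0],
                            [1, 1, 1, 0],
                            [1, 1, 1, 1]], dtype=tf.int32)
value_2d = tf.constant([[1, 3, 5, 7],
                        [9, 11, 13, 15],
                        [17, 19, 21, 23],
                        [25, 27, 29, 31],
                        [33, 35, 37, 39]])
(output,), padding = GatherK(selected_pos, [value_2d], k=3, num_devices=1)
# Padded slots gather position 0; masking them recovers the docstring output:
masked = output * (1 - padding)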
Example #13
def GetSentenceEmbeddings(inputs, segment_id):
  """Returns the average sentence embedding to gate by.

  Example::

    inputs: <tf.Variable 'Variable:0' shape=(10, 3) dtype=float64, numpy=
             array([[0.41258181, 0.61071571, 0.63777673],
                    [0.65571443, 0.54297766, 0.10288261],
                    [0.8577837 , 0.81915847, 0.61996602],
                    [0.46897136, 0.92662692, 0.32942232],
                    [0.60162383, 0.3385829 , 0.3408632 ],
                    [0.40774807, 0.86139635, 0.00927162],
                    [0.56126334, 0.51748817, 0.07791397],
                    [0.06595223, 0.95529216, 0.34458149],
                    [0.1238971 , 0.49897169, 0.25216722],
                    [0.11221774, 0.50284604, 0.84106974]])>
    segment_id: <tf.Variable 'Variable:0' shape=(10,) dtype=int64,
                 numpy=array([1, 1, 2, 0, 0, 3, 3, 3, 3, 0])>

  Args:
    inputs: GSM Tensor.
    segment_id: GS Tensor.

  Returns:
    sentence_embeddings: GSM Tensor that is an average of the input embeddings
    per segment.
  """
  reshaped_inputs = tf.reshape(inputs, [-1, inputs.shape[-1]])

  # We set num_segments to a large value so that shape is known at compile time.
  max_segments = py_utils.GetShape(reshaped_inputs)[0]
  # We move the padding segment from 0 to max_segments - 1 because
  # tf.math.unsorted_segment_mean only accepts segment ids in
  # [0, max_segments), and the last segment is zeroed out below.
  modified_segment_id = tf.cast(
      segment_id + max_segments * tf.cast(
          tf.equal(segment_id, 0), dtype=tf.dtypes.as_dtype(segment_id.dtype)) -
      1,
      dtype=tf.int32)
  reshaped_segment_id = tf.reshape(modified_segment_id, [-1])

  # Takes the mean of all segments, w/ 0s for the padding.
  params = tf.concat([
      tf.math.unsorted_segment_mean(reshaped_inputs, reshaped_segment_id,
                                    max_segments)[:-1],
      tf.zeros([1, reshaped_inputs.shape[-1]], dtype=reshaped_inputs.dtype)
  ],
                     axis=0)
  raw_sentence_embeddings = tf.gather(params, modified_segment_id)

  # sentence_embedding: <tf.Tensor: shape=(10, 3), dtype=float64, numpy=
  #                     array([[0.92657252, 0.40264503, 0.55494457],
  #                            [0.92657252, 0.40264503, 0.55494457],
  #                            [0.08002721, 0.02360659, 0.63688627],
  #                            [0.        , 0.        , 0.        ],
  #                            [0.        , 0.        , 0.        ],
  #                            [0.8138629 , 0.54451293, 0.48802852],
  #                            [0.8138629 , 0.54451293, 0.48802852],
  #                            [0.8138629 , 0.54451293, 0.48802852],
  #                            [0.8138629 , 0.54451293, 0.48802852],
  #                            [0.        , 0.        , 0.        ]])>
  sentence_embeddings = tf.reshape(raw_sentence_embeddings, inputs.shape)

  return sentence_embeddings
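
A quick usage sketch with the docstring's segment layout (random values, so only the structure of the output is meaningful):

import tensorflow as tf

inputs = tf.random.uniform([10, 3], dtype=tf.float64)
segment_id = tf.constant([1, 1, 2, 0, 0, 3, 3, 3, 3, 0], dtype=tf.int64)
emb = GetSentenceEmbeddings(inputs, segment_id)
# Rows sharing a segment id share that segment's mean embedding; rows with
# segment_id == 0 (padding) come out as zeros.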
Example #14
    def Sample(self,
               decoder_theta,
               encoder_outputs,
               random_seed,
               init_state_callback,
               pre_step_callback,
               post_step_callback,
               init_step_ids=None):
        """Samples target sequences, one target sequence per source sequence.

    (Please see beam_search_helper.py for description of decoder callbacks.)

    Args:
      decoder_theta: A NestedMap object containing weights' values of the
        decoder layer and its children layers, to be passed to decoder
        callbacks.
      encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      random_seed: a scalar int32 tensor representing the random seed.
      init_state_callback: decoder._InitBeamSearchStateCallback.
      pre_step_callback: decoder._PreBeamSearchStepCallback.
      post_step_callback: decoder._PostBeamSearchStepCallback.
      init_step_ids: [batch], optional init step ids, default to SOS.

    Returns:
      A NestedMap containing the following tensors

      - 'logits': [batch, max_target_length, vocab_size], representing the
        distribution from which target sequences are sampled.
      - 'ids': [batch, max_target_length] of int32, representing the target
        sequence ids, not including target_sos_id, but maybe ending with
        target_eos_id if end-of-sequence is reached before target_seq_len.
      - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
        a padded timestep.
    """
        p = self.params
        assert p.temperature > 0
        assert p.top_k >= 0
        if getattr(encoder_outputs, 'segment_id', 1) is None:
            # Remove None values, which are not supported by recurrent.
            del encoder_outputs['segment_id']
        # init_state_callback may modify 'encoder_outputs', e.g., by inserting
        # 'packed_src'.
        bs_result, bs_state = init_state_callback(decoder_theta,
                                                  encoder_outputs,
                                                  num_hyps_per_beam=1)
        # 'recurrent_theta' represents all cross-timestep information used by the
        # recurrent loop below, including layer theta and encoder outputs.
        recurrent_theta = py_utils.NestedMap(random_seed=random_seed,
                                             encoder_outputs=encoder_outputs)
        batch = tf.shape(bs_result.log_probs)[0]
        recurrent_state0 = py_utils.NestedMap(
            timestep=tf.zeros(shape=[], dtype=tf.int32),
            logits=bs_result.log_probs,
            # Start with target_sos_id.
            ids=init_step_ids if init_step_ids is not None else tf.fill(
                [batch], tf.cast(p.target_sos_id, tf.int32)),
            bs_state=bs_state)
        inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=1)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs

                if p.top_k > 0:
                    topk_logits, topk_ids = tf.math.top_k(state1.logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits
                else:
                    sample_logits = state1.logits

                # Sample ids from logits. [batch].
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            return state1, py_utils.NestedMap()

        def StopFn(t, theta, state):
            del t, theta  # Unused: this stop function only uses the state ids.
            return tf.equal(state.ids, p.target_eos_id)

        if p.use_stop_fn:
            stop_fn = StopFn
        else:
            stop_fn = None

        accumulated_states, _ = recurrent.Recurrent(
            recurrent_theta,
            recurrent_state0,
            inputs,
            Step,
            stop_fn=stop_fn,
            allow_implicit_capture=True)
        result = py_utils.NestedMap(
            logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
            ids=tf.transpose(accumulated_states.ids))
        result.paddings = tf.cast(
            _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
        # Force ids to be eos_id if the timestep is padded.
        result.ids = tf.where(tf.equal(result.paddings, 0), result.ids,
                              tf.fill(tf.shape(result.ids), p.target_eos_id))
        static_batch_size = bs_result.log_probs.shape[0]
        result.ids.set_shape([static_batch_size, p.target_seq_len])
        result.paddings.set_shape([static_batch_size, p.target_seq_len])
        return result
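
When p.top_k > 0, the categorical draw indexes into the top-k table, and a batched gather maps it back to a vocab id. That remapping in isolation (hypothetical logits):

import tensorflow as tf

logits = tf.constant([[0.1, 2.0, 0.3, 4.0]])        # [batch=1, vocab=4]
topk_logits, topk_ids = tf.math.top_k(logits, k=2)  # topk_ids == [[3, 1]]
sampled = tf.constant([0])                          # index into the top-k table
vocab_ids = tf.gather(topk_ids, sampled, axis=1, batch_dims=1)  # [3]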
Example #15
  def _Extract(self, features):
    p = self.params
    # Label values match the proto enum car.open_dataset.Label.Type. The value
    # range is [1..4] for non-background labels.
    labels = tf.cast(_Dense(features['labels']), tf.int32)
    labels = py_utils.PadOrTrimTo(labels, [p.max_num_objects])
    label_ids = tf.reshape(_Dense(features['label_ids'], ''), [-1])
    label_ids = py_utils.PadOrTrimTo(label_ids, [p.max_num_objects], '')
    bboxes_3d = tf.reshape(_Dense(features['bboxes_3d']), [-1, 7])
    bboxes_3d_mask = tf.ones([tf.shape(bboxes_3d)[0]])
    bboxes_3d_num_points = tf.cast(
        _Dense(features['bboxes_3d_num_points']), tf.int32)
    bboxes_3d = py_utils.PadOrTrimTo(bboxes_3d, [p.max_num_objects, 7])
    bboxes_3d_mask = py_utils.PadOrTrimTo(bboxes_3d_mask, [p.max_num_objects])
    bboxes_3d_num_points = py_utils.PadOrTrimTo(bboxes_3d_num_points,
                                                [p.max_num_objects])
    label_metadata = tf.reshape(_Dense(features['label_metadata']), [-1, 4])
    label_metadata = py_utils.PadOrTrimTo(label_metadata,
                                          [p.max_num_objects, 4])

    detection_difficulties = py_utils.PadOrTrimTo(
        tf.cast(_Dense(features['detection_difficulties']), tf.int32),
        [p.max_num_objects])
    single_frame_detection_difficulties = py_utils.PadOrTrimTo(
        tf.cast(
            _Dense(features['single_frame_detection_difficulties']), tf.int32),
        [p.max_num_objects])
    tracking_difficulties = py_utils.PadOrTrimTo(
        tf.cast(_Dense(features['tracking_difficulties']), tf.int32),
        [p.max_num_objects])
    unfiltered_bboxes_3d_mask = bboxes_3d_mask

    if p.filter_labels:
      valid_labels = tf.constant([p.filter_labels])
      bbox_mask = tf.reduce_any(
          tf.equal(tf.expand_dims(labels, 1), valid_labels), axis=1)
      bboxes_3d_mask *= tf.cast(bbox_mask, tf.float32)

    outputs = {
        'labels': labels,
        'label_ids': label_ids,
        'detection_difficulties': detection_difficulties,
        'single_frame_detection_difficulties':
            single_frame_detection_difficulties,
        'tracking_difficulties': tracking_difficulties,
        'bboxes_3d': bboxes_3d,
        'bboxes_3d_mask': bboxes_3d_mask,
        'bboxes_3d_num_points': bboxes_3d_num_points,
        'unfiltered_bboxes_3d_mask': unfiltered_bboxes_3d_mask,
        'speed': label_metadata[:, :2],
        'acceleration': label_metadata[:, 2:],
    }

    return py_utils.NestedMap(outputs)
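
Most fields above are normalized to a static leading dimension with py_utils.PadOrTrimTo, which pads with a default value or slices as needed. A rough standalone equivalent for the 1-D case (a sketch of the semantics, not Lingvo's implementation):

import tensorflow as tf

def pad_or_trim_to(x, length, pad_val=0):
  """Pads or slices 1-D `x` to exactly `length`."""
  x = x[:length]
  num_pad = tf.maximum(0, length - tf.shape(x)[0])
  return tf.pad(x, [[0, num_pad]], constant_values=pad_val)

print(pad_or_trim_to(tf.constant([1, 2]), 4).numpy())           # [1 2 0 0]
print(pad_or_trim_to(tf.constant([1, 2, 3, 4, 5]), 4).numpy())  # [1 2 3 4]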
Example #16
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
    """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.
  """
    source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
    value_dict = {}
    for output in beam_search_outputs:
        hyps_per_beam = py_utils.with_dependencies([
            py_utils.assert_equal(source_batch, tf.shape(output.topk_hyps)[0]),
        ], tf.shape(output.topk_hyps)[1])
        for k, v in six.iteritems(output._asdict()):
            if v is None:
                continue
            if k == 'done_hyps':
                v = tf.transpose(v)
            if k not in value_dict:
                value_dict[k] = []
            value_dict[k].append(
                tf.reshape(v, [source_batch, hyps_per_beam, -1]))

    # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
    concatenated = {}
    for k, values in six.iteritems(value_dict):
        if len(values) != len(beam_search_outputs):
            raise ValueError('Incomplete values for %s: %s' %
                             (k, beam_search_outputs))
        concatenated[k] = tf.concat(values, axis=1)

    scores = concatenated['topk_scores']
    scores = tf.where(tf.equal(concatenated['topk_lens'], 0),
                      tf.fill(tf.shape(scores), -1e6), scores)
    scores = tf.squeeze(scores, -1)

    # Select top max_hyps_per_beam indices per beam.
    _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
    batch_ids = tf.tile(tf.expand_dims(tf.range(source_batch), -1),
                        [1, max_hyps_per_beam])
    # [source_batch, max_hyps_per_beam, 2]
    gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

    # Gather the merged top hyps according to 'gather_indices'.
    top = beam_search_outputs[0]._asdict()
    total_hyps = source_batch * max_hyps_per_beam
    for k, v in six.iteritems(concatenated):
        v = tf.gather_nd(v, gather_indices)
        if k == 'done_hyps':
            v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
        elif k == 'topk_hyps':
            v = tf.reshape(v, [source_batch, max_hyps_per_beam])
        elif k == 'topk_ids':
            v = tf.reshape(v, [total_hyps, -1])
        elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
            v = tf.reshape(v, [total_hyps])
        else:
            raise ValueError('Unexpected field: %s' % k)
        top[k] = v
    return BeamSearchDecodeOutput(**top)
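
The merged selection builds a stacked (batch_id, hyp_id) index and uses tf.gather_nd. That pattern in isolation (hypothetical scores):

import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.5],
                      [0.7, 0.2, 0.4]])  # [source_batch=2, hyps=3]
k = 2
_, top_indices = tf.nn.top_k(scores, k)  # [[1, 2], [0, 2]]
batch_ids = tf.tile(tf.expand_dims(tf.range(2), -1), [1, k])
gather_indices = tf.stack([batch_ids, top_indices], axis=-1)  # [2, k, 2]
top_scores = tf.gather_nd(scores, gather_indices)  # [[0.9, 0.5], [0.7, 0.4]]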
Example #17
  def _InferenceSubgraph_Default(self):
    """Default inference subgraph.

    Returns:
      (fetches, feeds):

      - fetches: A dictionary of fetches, containing:

        - log_pplx_per_token: A matrix of shape [batch, time]. [i, j]
          is i-th input text's j-th token's log prob.
        - paddings: A matrix of shape [batch, time]. The padding mask.
        - lengths: A vector of shape [batch]. The number of non-padded tokens
          in each input.
        - log_pplx_per_sample: A vector of shape [batch]. [i]
          is i-th input text's log prob.
        - num_oovs_per_sample: A vector of shape [batch] counting the total
          number of out-of-vocabulary tokens in each input.
        - tokens_from_labels: A vector of shape [batch] returning the predicted
          tokens as a sequence after mapping them back to strings from ids using
          the vocabulary.
        - ids: A matrix of shape [batch, time]. [i, j]
          is i-th input text's j-th token's id.

      - feeds: A dictionary of feeds, containing:

        - text: A placeholder for a vector of strings.
    """
    text = tf.placeholder(tf.string, shape=[None])
    # [batch, time]
    ids, labels, paddings = self.input_generator.StringsToIds(text)
    lengths = tf.reduce_sum(tf.cast(1 - paddings, tf.int32), axis=1)
    tokens_from_labels = self.input_generator.IdsToStrings(labels, lengths)
    oovs = tf.equal(labels, self.input_generator.tokenizer.unk_id)
    num_oovs_per_sample = tf.cast(
        tf.round(
            tf.reduce_sum(tf.cast(oovs, tf.float32) * (1 - paddings), axis=1)),
        tf.int32)
    # [time, batch]
    ids, paddings, labels, weights = self._TrimIfPossibleThenTranspose(
        ids, paddings, labels, 1.0 - paddings)
    batch_size = tf.shape(ids)[1]
    xent_output, _ = self.lm.FPropDefaultTheta(
        inputs=ids,
        paddings=paddings,
        state0=self.lm.zero_state(self.theta.lm, batch_size),
        labels=py_utils.NestedMap(class_ids=labels, class_weights=weights))

    per_example_xent = py_utils.HasShape(xent_output.per_example_xent,
                                         tf.shape(ids))
    log_pplx_per_sample = tf.reduce_sum(
        per_example_xent * (1 - paddings), axis=0)
    fetches = {
        'log_pplx_per_token':  # [batch, time]
            tf.transpose(per_example_xent),
        'paddings':  # [batch, time]
            tf.transpose(paddings),
        'lengths':  # [batch]
            lengths,
        'log_pplx_per_sample':  # [batch]
            log_pplx_per_sample,
        'num_oovs_per_sample':  # [batch], int32
            num_oovs_per_sample,
        'tokens_from_labels':  # [batch], string
            tokens_from_labels,
        'ids':  # [batch, time], int32
            ids
    }
    feeds = {'text': text}
    return fetches, feeds
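
A sketch of how such a subgraph would be driven in a TF1 session (hypothetical strings; assumes the graph containing this model has been built and initialized):

fetches, feeds = model._InferenceSubgraph_Default()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  out = sess.run(fetches, feed_dict={feeds['text']: ['hello world']})
  print(out['log_pplx_per_sample'])  # [batch] log prob per input text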
Example #18
    def _CreateCanvasAndTargets(self, batch):
        # pyformat: disable
        """Create the canvas and targets.

    Args:
      batch: A `.NestedMap`.

        - src: A `.NestedMap`.
          - ids: The source ids, ends in <eos>.
          - paddings: The source paddings.

        - tgt: A `.NestedMap`.
          - ids: The target ids, ends in <eos>.
          - paddings: The target paddings.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim].
        - canvas_paddings: The paddings of `canvas`.
        - target_indices: The target indices (i.e., use these indices to
          tf.gather_nd the log-probs). Optional, only during training.
        - target_weights: The target weights. Optional, only during training.
    """
        # pyformat: enable
        p = self.params

        if not p.is_eval:
            # Sample our src and tgt canvas.
            src_descriptor = self._SampleCanvasAndTargets(
                batch.src.ids, batch.src.paddings)
            tgt_descriptor = self._SampleCanvasAndTargets(
                batch.tgt.ids, batch.tgt.paddings)

            # Offset the src ids (to unshare embeddings between src/tgt). Note, we
            # only offset the canvas ids, but we do not offset the vocab ids. This
            # will result in unshared embeddings, but shared softmax. This is due to
            # GPU/TPU memory limitations, empirically it is known that unsharing
            # everything results in better performance.
            vocab_size = p.decoder.softmax.num_classes
            src_descriptor.canvas = tf.where(
                tf.equal(src_descriptor.canvas_paddings, 0),
                src_descriptor.canvas + vocab_size, src_descriptor.canvas)

            # Offset the tgt indices (need shift according to src length).
            batch_size = py_utils.GetShape(batch.src.ids)[0]
            # `target_batch` is a [num_targets, batch_size] tensor where each
            # row identifies which batch element the target belongs to. Note
            # that tf.reduce_sum(target_batch, 1) == 1 for all rows.
            target_batch = tf.cast(
                tf.equal(
                    tf.expand_dims(tf.range(batch_size), 0),
                    tf.expand_dims(tgt_descriptor.target_indices[:, 0], 1)),
                tf.int32)
            src_lens = tf.cast(
                tf.reduce_sum(1 - src_descriptor.canvas_paddings, 1), tf.int32)
            # `tgt_offset` is shape [num_targets] where each entry corresponds to the
            # offset needed for that target (due to the source length).
            tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
            # We shift the tgt slot without touching the batch or vocab.
            tgt_descriptor.target_indices += tf.concat([
                tf.zeros_like(tgt_offset), tgt_offset,
                tf.zeros_like(tgt_offset)
            ], 1)

            # The canvas is simply the sequence-level concat of the src and tgt.
            canvas, canvas_paddings = insertion.SequenceConcat(
                src_descriptor.canvas, src_descriptor.canvas_paddings,
                tgt_descriptor.canvas, tgt_descriptor.canvas_paddings)
            target_indices = tf.concat(
                [src_descriptor.target_indices, tgt_descriptor.target_indices],
                0)
            target_weights = tf.concat(
                [src_descriptor.target_weights, tgt_descriptor.target_weights],
                0)

            return py_utils.NestedMap(canvas=canvas,
                                      canvas_paddings=canvas_paddings,
                                      target_indices=target_indices,
                                      target_weights=target_weights)
Example #19
  def FProp(self, theta, input_batch):
    """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task
            ids of shape [batch, time].

    Returns:
      A NestedMap containing

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set in
        transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
        positional encodings.
    """

    p = self.params
    with tf.name_scope(p.name):
      src_segment_id = None
      src_segment_pos = None
      input_ids = py_utils.with_dependencies([
          py_utils.assert_shape_match(
              tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
          py_utils.assert_equal(tf.rank(input_batch.ids), 2)
      ], input_batch.ids)

      if (not py_utils.use_tpu() and
          tf.flags.FLAGS.transformer_encoder_truncates_inputs):
        max_seq_length = tf.cast(
            tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
            tf.int32)
        paddings = py_utils.with_dependencies([
            py_utils.assert_equal(
                tf.constant(True, tf.bool),
                tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
        ], input_batch.paddings)
        input_ids = input_ids[:, :max_seq_length]
        paddings = paddings[:, :max_seq_length]
        if p.packed_input:
          src_segment_id = input_batch.segment_ids[:, :max_seq_length]
          src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      else:
        paddings = input_batch.paddings
        if p.packed_input:
          src_segment_id = input_batch.segment_ids
          src_segment_pos = input_batch.segment_pos

      max_time = tf.shape(input_ids)[1]

      # Input token embeddings + positional embeddings
      if not p.shared_emb:
        input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                              tf.reshape(input_ids, [-1]))
      else:
        input_embs = self.softmax.EmbLookup(theta.softmax,
                                            tf.reshape(input_ids, [-1]))

      input_embs = tf.reshape(input_embs,
                              [-1, max_time, p.token_emb.embedding_dim])
      # [time, batch, dim]
      orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

      if p.packed_input:
        position_embs = self.position_emb.FPropWithPosition(
            theta.position_emb, src_segment_pos)
      else:
        position_embs = self.position_emb.FProp(theta.position_emb, max_time)
        position_embs = tf.reshape(position_embs,
                                   [1, max_time, p.token_emb.embedding_dim])
      input_embs += position_embs
      if p.task_emb:
        input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)

      if p.model_dim != p.token_emb.embedding_dim:
        input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

      paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
      if p.packed_input:
        src_segment_id = tf.transpose(src_segment_id)
      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

      # [time, batch, dim]
      transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
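
A call sketch under the documented contract (shapes are hypothetical; assumes p.task_emb is unset and `encoder` is an instance of this layer):

input_batch = py_utils.NestedMap(
    ids=tf.zeros([8, 16], dtype=tf.int32),  # [batch, time]
    paddings=tf.zeros([8, 16]))             # [batch, time]
out = encoder.FProp(encoder.theta, input_batch)
# out.encoded: [time, batch, depth]; out.padding: [time, batch]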
Example #20
def _ComputePaddings(ids, eos_id):
    is_eos = tf.to_int32(tf.equal(ids, eos_id))
    # eos_in_prefix[i, j] = any(ids[i, k] == eos_id for k in range(j))
    eos_in_prefix = tf.cumsum(is_eos, axis=-1, exclusive=True)
    return tf.where(tf.equal(eos_in_prefix, 0), tf.zeros_like(ids),
                    tf.ones_like(ids))
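
A quick worked example: padding starts one step after the first eos, so the eos itself remains unpadded (hypothetical ids, eos_id=2):

ids = tf.constant([[5, 2, 7, 2],
                   [5, 6, 7, 8]])
paddings = _ComputePaddings(ids, eos_id=2)
# -> [[0, 0, 1, 1],   everything after the first eos is padding
#     [0, 0, 0, 0]]   no eos: nothing is padded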
Example #21
    def _Extract(self, features):
        p = self.params

        source_id = py_utils.HasShape(features['image/source_id'], [])
        xmin = _Dense(features['object/image/bbox/xmin'])
        xmax = _Dense(features['object/image/bbox/xmax'])
        ymin = _Dense(features['object/image/bbox/ymin'])
        ymax = _Dense(features['object/image/bbox/ymax'])

        # 2d bounding box in image coordinates.
        bboxes = tf.stack([ymin, xmin, ymax, xmax], axis=1)
        bboxes_count = tf.shape(bboxes)[0]
        bboxes = py_utils.PadOrTrimTo(bboxes, [p.max_num_objects, 4])

        bboxes_padding = 1.0 - py_utils.PadOrTrimTo(tf.ones([bboxes_count]),
                                                    [p.max_num_objects])

        dim_xyz = tf.reshape(_Dense(features['object/velo/bbox/dim_xyz']),
                             [-1, 3])
        loc_xyz = tf.reshape(_Dense(features['object/velo/bbox/xyz']), [-1, 3])
        phi = tf.reshape(_Dense(features['object/velo/bbox/phi']), [-1, 1])
        # bboxes_3d is in [x, y, z, dx, dy, dz, phi].
        bboxes_3d = tf.concat([loc_xyz, dim_xyz, phi], axis=1)

        cx, cy, _, dx, dy, _, _ = tf.unstack(bboxes_3d, num=7, axis=-1)
        bboxes_td = tf.stack([
            cy - dy / 2,
            cx - dx / 2,
            cy + dy / 2,
            cx + dx / 2,
        ],
                             axis=-1)  # pyformat: disable
        bboxes_td = py_utils.PadOrTrimTo(bboxes_td, [p.max_num_objects, 4])

        has_3d_info = tf.to_float(_Dense(features['object/has_3d_info']))
        bboxes_3d_mask = py_utils.PadOrTrimTo(has_3d_info, [p.max_num_objects])
        bboxes_td_mask = bboxes_3d_mask

        # Fill in difficulties from bounding box height, truncation and occlusion.
        bb_height = ymax - ymin
        box_image_height = py_utils.PadOrTrimTo(bb_height, [p.max_num_objects])
        box_image_height *= bboxes_3d_mask

        # Occlusion is 0 to 3: 0 means fully visible, 1 means partly occluded,
        # 2 means largely occluded, and 3 means unknown.
        occlusion = tf.reshape(_Dense(features['object/occlusion']), [-1])
        occlusion = tf.to_float(occlusion)
        occlusion = py_utils.PadOrTrimTo(occlusion, [p.max_num_objects])
        occlusion *= bboxes_3d_mask

        # Truncation: 0 -> not truncated, 1.0 -> truncated
        truncation = tf.reshape(_Dense(features['object/truncation']), [-1])
        truncation = py_utils.PadOrTrimTo(truncation, [p.max_num_objects])
        truncation *= bboxes_3d_mask

        difficulties = ComputeKITTIDifficulties(box_image_height, occlusion,
                                                truncation)
        difficulties = py_utils.PadOrTrimTo(difficulties, [p.max_num_objects])

        # Make a batch axis to call BBoxCorners, and take the first result back.
        bbox3d_corners = geometry.BBoxCorners(bboxes_3d[tf.newaxis, ...])[0]

        # Project the 3D bbox to the image plane.
        velo_to_image_plane = features['transform/velo_to_image_plane']
        bboxes3d_proj_to_image_plane = geometry.PointsToImagePlane(
            tf.reshape(bbox3d_corners, [-1, 3]), velo_to_image_plane)

        # Output is [num_objects, 8 corners per object, (x, y)].
        bboxes3d_proj_to_image_plane = tf.reshape(bboxes3d_proj_to_image_plane,
                                                  [-1, 8, 2])
        bboxes3d_proj_to_image_plane = py_utils.PadOrTrimTo(
            bboxes3d_proj_to_image_plane, [p.max_num_objects, 8, 2])

        texts = features['object/label'].values
        labels = ops.static_map_string_int(x=texts,
                                           keys=self.KITTI_CLASS_NAMES)

        labels = py_utils.PadOrTrimTo(labels, [p.max_num_objects])
        texts = py_utils.PadOrTrimTo(texts, [p.max_num_objects])

        # Filter labels by setting bboxes_padding, bboxes_3d_mask, and
        # bboxes_td_mask appropriately.
        if p.filter_labels is not None:
            valid_labels = tf.constant([p.filter_labels])
            bbox_mask = tf.reduce_any(tf.equal(tf.expand_dims(labels, 1),
                                               valid_labels),
                                      axis=1)
            bbox_mask = tf.to_float(bbox_mask)
            bboxes_padding = 1 - bbox_mask * (1 - bboxes_padding)
            filtered_bboxes_3d_mask = bboxes_3d_mask * bbox_mask
            bboxes_td_mask *= bbox_mask
        else:
            filtered_bboxes_3d_mask = bboxes_3d_mask

        # Placeholder for counting the number of laser points that reside within
        # each 3-d bounding box. This must be filled in outside of this function
        # based on the loaded 3-d laser points.
        bboxes_3d_num_points = tf.zeros([p.max_num_objects], dtype=tf.int32)
        bboxes_3d_num_points = py_utils.PadOrTrimTo(bboxes_3d_num_points,
                                                    [p.max_num_objects])

        # Pad bboxes_3d.
        bboxes_3d = py_utils.PadOrTrimTo(bboxes_3d, [p.max_num_objects, 7])

        return py_utils.NestedMap(
            source_id=source_id,
            bboxes_count=bboxes_count,
            bboxes=bboxes,
            bboxes_padding=bboxes_padding,
            bboxes_3d=bboxes_3d,
            bboxes_3d_mask=filtered_bboxes_3d_mask,
            unfiltered_bboxes_3d_mask=bboxes_3d_mask,
            bboxes3d_proj_to_image_plane=bboxes3d_proj_to_image_plane,
            bboxes_td=bboxes_td,
            bboxes_td_mask=bboxes_td_mask,
            bboxes_3d_num_points=bboxes_3d_num_points,
            labels=labels,
            texts=texts,
            box_image_height=box_image_height,
            occlusion=occlusion,
            truncation=truncation,
            difficulties=difficulties)
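
Note that bboxes_td above is axis-aligned in the top-down view: the heading
phi is dropped. A minimal sketch of that conversion with made-up numbers
(plain TensorFlow 2):

import tensorflow as tf

# One box centered at (x=10, y=4) with dx=2, dy=6.
bboxes_3d = tf.constant([[10., 4., 0., 2., 6., 1.5, 0.]])  # [x,y,z,dx,dy,dz,phi]
cx, cy, _, dx, dy, _, _ = tf.unstack(bboxes_3d, num=7, axis=-1)
bboxes_td = tf.stack(
    [cy - dy / 2, cx - dx / 2, cy + dy / 2, cx + dx / 2], axis=-1)
print(bboxes_td.numpy())  # [[ 1.  9.  7. 11.]] == [ymin, xmin, ymax, xmax]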
Example #22
    def AssignAnchors(self,
                      anchor_bboxes,
                      gt_bboxes,
                      gt_bboxes_labels,
                      gt_bboxes_mask,
                      foreground_assignment_threshold=0.5,
                      background_assignment_threshold=0.35,
                      background_class_id=0,
                      force_match=True,
                      similarity_fn=None):
        """Assigns anchors to bboxes using a similarity function (SSD-based).

    Each anchor box is assigned to the top matching ground truth box.
    Ground truth boxes can be assigned to multiple anchor boxes.

    Assignments can result in 3 outcomes:

      - Positive assignment (if score >= foreground_assignment_threshold):
        assigned_gt_labels will reflect the assigned box label and
        assigned_cls_mask will be set to 1.0
      - Background assignment (if score <= background_assignment_threshold):
        assigned_gt_labels will be background_class_id and assigned_cls_mask
        will be set to 1.0
      - Ignore assignment (otherwise):
        assigned_gt_labels will be background_class_id and assigned_cls_mask
        will be set to 0.0

    The detection loss function would usually:

      - Use assigned_cls_mask for weighting the classification loss. The mask
        is set such that the loss applies to foreground and background
        assignments only - ignored anchors will be set to 0.
      - Use assigned_reg_mask for weighting the regression loss. The mask is set
        such that the loss applies to foreground assignments only.

    The thresholds (foreground_assignment_threshold and
    background_assignment_threshold) should be tuned per dataset.

    TODO(jngiam): Consider having a separate threshold for regression boxes; a
    separate threshold is used in PointRCNN.

    Args:
      anchor_bboxes: tf.float32. [A, 7], where [..., :] corresponds to box
        parameters (x, y, z, dx, dy, dz, r).
      gt_bboxes: tf.float32. [G, 7], where [..., :] corresponds to ground truth
        box parameters (x, y, z, dx, dy, dz, r).
      gt_bboxes_labels: tensor with shape [G]. Ground truth labels for each
        bounding box.
      gt_bboxes_mask: tensor with shape [G]. Mask for ground truth boxes, 1 iff
        the gt_bbox is a real bbox.
      foreground_assignment_threshold: Similarity score threshold for assigning
        foreground bounding boxes; scores need to be >=
        foreground_assignment_threshold to be assigned to foreground.
      background_assignment_threshold: Similarity score threshold for assigning
        background bounding boxes; scores need to be <=
        background_assignment_threshold to be assigned to background.
      background_class_id: class id to be assigned to anchors_gt_class if no
        anchor boxes match.
      force_match: Boolean specifying if force matching is enabled. If
        force matching is enabled, then an anchor that is the highest scoring
        match for some ground-truth box is considered a foreground match as
        long as its similarity score > 0.
      similarity_fn: Function that computes a similarity score (e.g., IOU)
        between pairs of bounding boxes. This function should take in two
        tensors corresponding to anchor and ground-truth bboxes, and return a
        matrix [A, G] with the similarity score between each pair of bboxes.
        The score must be non-negative, with greater scores indicating more
        similar boxes. The fore/background_assignment_thresholds will be
        applied to this score to determine whether an anchor is foreground,
        background or ignored. If set to None, the function defaults to
        IOU2DRotatedBoxes.

    Returns:
      NestedMap with the following keys

      - assigned_gt_idx: shape [A] index corresponding to the index of the
        assigned ground truth box. Anchors not assigned to a ground truth box
        will have the index set to -1.
      - assigned_gt_bbox: shape [A, 7] bbox parameters assigned to each anchor.
      - assigned_gt_similarity_score: shape [A] (iou) score between the anchor
        and the gt bbox.
      - assigned_gt_labels: shape [A] label assigned to bbox.
      - assigned_cls_mask: shape [A] mask for classification loss per anchor.
        This should be 1.0 if the anchor has a foreground or background
        assignment; otherwise, it will be assigned to 0.0.
      - assigned_reg_mask: shape [A] mask for regression loss per anchor.
        This should be 1.0 if the anchor has a foreground assignment;
        otherwise, it will be assigned to 0.0.
        Note: background anchors do not have regression targets.
    """
        if similarity_fn is None:
            similarity_fn = self.IOU2DRotatedBoxes

        # Shape validation.
        anchor_bboxes = py_utils.HasShape(anchor_bboxes, [-1, 7])
        num_anchor_bboxes, _ = py_utils.GetShape(anchor_bboxes, 2)
        gt_bboxes = py_utils.HasShape(gt_bboxes, [-1, 7])
        num_gt_bboxes, _ = py_utils.GetShape(gt_bboxes, 2)

        # Compute similarity score and reduce max by anchors and by ground-truth.
        similarity_score = similarity_fn(anchor_bboxes, gt_bboxes)
        similarity_score = py_utils.HasShape(
            similarity_score, [num_anchor_bboxes, num_gt_bboxes])

        # Reduce over ground-truth boxes, so we have the max score per anchor.
        anchor_max_score = tf.reduce_max(similarity_score, axis=1)
        anchor_max_idx = tf.argmax(similarity_score, axis=1)

        if force_match:
            # Reduce over anchors, so we have the max score per ground truth box.
            gt_max_score = tf.reduce_max(similarity_score,
                                         axis=0,
                                         keep_dims=True)

            # Force matches occur when the top matching gt bbox for an anchor is the
            # top matching anchor for the gt bbox. When force matching, we match
            # these boxes as long as their similarity score exceeds 0.
            force_matches = (
                tf.equal(similarity_score, gt_max_score)
                & tf.equal(similarity_score, anchor_max_score[..., tf.newaxis])
                & tf.greater(similarity_score, 0.)
                & tf.cast(gt_bboxes_mask[tf.newaxis, ...], tf.bool))
            force_match_indicator = tf.reduce_any(force_matches, axis=1)
            force_match_idx = tf.argmax(tf.cast(force_matches, tf.int32),
                                        axis=1)

            # In assigning foreground/background anchors later, force_match_indicator
            # is used to determine which anchors are force foreground, and the index
            # assigned will be taken from anchor_max_idx.

            # Force matches must also be the max scoring gt bbox per anchor.
            # We overwrite anchor_max_idx to ensure that the right match is done.
            anchor_max_idx = tf.where(force_match_indicator, force_match_idx,
                                      anchor_max_idx)

        # Ensure that max score boxes are not padded boxes by setting score to 0
        # for boxes that are padded.
        gathered_mask = tf.batch_gather(gt_bboxes_mask, anchor_max_idx)
        anchor_max_score = tf.where(tf.equal(gathered_mask, 1),
                                    anchor_max_score,
                                    tf.zeros_like(anchor_max_score))

        # Boolean tensors corresponding to whether an anchor is background or
        # foreground based on thresholding.
        background_anchors = tf.less_equal(anchor_max_score,
                                           background_assignment_threshold)
        foreground_anchors = tf.greater_equal(anchor_max_score,
                                              foreground_assignment_threshold)
        if force_match:
            # Background anchors are below threshold and not force matches.
            background_anchors &= ~force_match_indicator
            # Foreground anchors are above thresholds or force matches.
            foreground_anchors |= force_match_indicator

        # Add dummy background bbox to gt_boxes to facilitate batch gather.
        dummy_bbox = tf.constant([[0, 0, 0, 1, 1, 1, 0]], dtype=tf.float32)

        # Since we are concatenating the dummy bbox, the index corresponds to the
        # number of boxes.
        dummy_bbox_idx = py_utils.GetShape(gt_bboxes, 1)[0]

        gt_bboxes = tf.concat([gt_bboxes, dummy_bbox], axis=0)
        gt_bboxes_labels = tf.concat([gt_bboxes_labels, [background_class_id]],
                                     axis=0)

        # Gather indices so that all foreground boxes are gathered from gt_bboxes,
        # while all background and ignore boxes gather the dummy_bbox.
        anchor_gather_idx = tf.where(
            foreground_anchors, anchor_max_idx,
            tf.constant(dummy_bbox_idx,
                        shape=py_utils.GetShape(anchor_max_idx),
                        dtype=anchor_max_idx.dtype))

        # Gather the bboxes and weights.
        assigned_gt_bbox = tf.batch_gather(gt_bboxes, anchor_gather_idx)
        assigned_gt_labels = tf.batch_gather(gt_bboxes_labels,
                                             anchor_gather_idx)

        # Set masks for classification and regression losses.
        assigned_cls_mask = tf.cast(background_anchors | foreground_anchors,
                                    tf.float32)
        assigned_reg_mask = tf.cast(foreground_anchors, tf.float32)

        # Set assigned_gt_idx such that dummy boxes have idx = -1.
        assigned_gt_idx = tf.where(tf.equal(anchor_gather_idx, dummy_bbox_idx),
                                   tf.ones_like(anchor_gather_idx) * -1,
                                   anchor_gather_idx)
        assigned_gt_idx = tf.cast(assigned_gt_idx, tf.int32)

        return py_utils.NestedMap(
            assigned_gt_idx=assigned_gt_idx,
            assigned_gt_bbox=assigned_gt_bbox,
            assigned_gt_similarity_score=anchor_max_score,
            assigned_gt_labels=assigned_gt_labels,
            assigned_cls_mask=assigned_cls_mask,
            assigned_reg_mask=assigned_reg_mask)
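
To make the three assignment outcomes concrete, here is a toy sketch of the
thresholding only (scores are made up; force matching and the similarity
function are omitted), using the default thresholds 0.5 and 0.35:

import tensorflow as tf

anchor_max_score = tf.constant([0.7, 0.2, 0.4])
foreground = tf.greater_equal(anchor_max_score, 0.5)   # [True, False, False]
background = tf.less_equal(anchor_max_score, 0.35)     # [False, True, False]
assigned_cls_mask = tf.cast(foreground | background, tf.float32)  # [1. 1. 0.]
assigned_reg_mask = tf.cast(foreground, tf.float32)               # [1. 0. 0.]
# The third anchor (score 0.4) lies between the thresholds, so it is ignored
# by both the classification and the regression loss.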
Example #23
        def Callback(theta, encoder_outputs, step_ids, states,
                     num_hyps_per_beam, cur_step, *args, **kwargs):
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                cur_step, *args, **kwargs)

            def TileForBeamAndFlatten(tensor):
                tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                tensor = tf.tile(
                    tensor,
                    [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
                tgt_batch = tf.shape(step_ids)[
                    0]  # num_hyps_per_beam*src_batch
                return tf.reshape(tensor, [tgt_batch])

            if biased:
                labels = encoder_outputs.targets.labels
                weights = encoder_outputs.targets.weights

                def ApplyBias():
                    """Bias and update log_probs and consistent."""

                    # Consistent if step_ids == labels from previous step
                    # TODO(navari): Consider updating consistent only if weights > 0. Then
                    # re-evaluate the need for bias_only_if_consistent=True.
                    # Note that prev_label is incorrect for step 0 but is
                    # overridden later.
                    prev_label = TileForBeamAndFlatten(
                        tf.gather(labels, tf.maximum(time_step - 1, 0),
                                  axis=1))
                    is_step0 = tf.equal(time_step, 0)
                    local_consistence = tf.math.logical_or(
                        is_step0, tf.equal(prev_label, tf.squeeze(step_ids,
                                                                  1)))
                    consistent = tf.math.logical_and(states.consistent,
                                                     local_consistence)

                    # get label, weight slices corresponding to current time_step
                    label = TileForBeamAndFlatten(
                        tf.gather(labels, time_step, axis=1))
                    weight = TileForBeamAndFlatten(
                        tf.gather(weights, time_step, axis=1))
                    if p.bias_only_if_consistent:
                        weight = weight * tf.cast(consistent,
                                                  py_utils.FPropDtype(p))

                    # convert from dense label to sparse label probs
                    vocab_size = tf.shape(bs_results.log_probs)[1]
                    label_probs = tf.one_hot(label,
                                             vocab_size,
                                             dtype=py_utils.FPropDtype(
                                                 p))  # [tgt_batch, vocab_size]
                    pred_probs = tf.exp(bs_results.log_probs)

                    # interpolate predicted probs and label probs
                    weight = tf.expand_dims(weight, 1)
                    probs = py_utils.with_dependencies([
                        py_utils.assert_less_equal(weight, 1.),
                        py_utils.assert_greater_equal(weight, 0.)
                    ], (1.0 - weight) * pred_probs + weight * label_probs)
                    # Ensure that tf.math.log is applied to positive values.
                    probs = tf.maximum(probs,
                                       tf.constant(1e-12, dtype=probs.dtype))
                    return tf.math.log(probs), consistent

                def NoApplyBias():
                    """No-op. Return original log_probs and consistent."""
                    return bs_results.log_probs, states.consistent

                log_probs, consistent = tf.cond(
                    tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias,
                    ApplyBias)
                bs_results.log_probs = log_probs
                out_states.consistent = consistent

            if stochastic:
                log_probs = bs_results.log_probs

                def PerturbedLogProbs():
                    # STEP 1: Perform top-k filtering. This is a performance
                    # optimization to avoid sorting the entire `log_probs`,
                    # which would be prohibitively slow.
                    top_k = tf.math.top_k(log_probs, k, sorted=True)
                    # shape: [tgt_batch, k]
                    top_k_log_probs = top_k.values
                    # shape: [tgt_batch, k]
                    top_k_ids = top_k.indices

                    # STEP 2: Perform top-p filtering.
                    # shape: [tgt_batch]
                    top_p_threshold = encoder_outputs.stochastic_beam_search.top_p_threshold
                    top_p_threshold = tf.clip_by_value(top_p_threshold, 0., 1.)
                    top_p_threshold = TileForBeamAndFlatten(top_p_threshold)
                    # shape: [tgt_batch, k]
                    filtered_top_k_log_probs = _KeepTopP(
                        top_k_log_probs, top_p_threshold)

                    # STEP 3: Perturb cumulative log-probs.
                    # shape: [tgt_batch, 1]
                    last_cumulative_log_probs = states.cumulative_log_probs
                    # shape: [tgt_batch, 1]
                    last_perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs
                    # Compute cumulative log-probs of the current step.
                    # shape: [tgt_batch, k]
                    cumulative_log_probs = (last_cumulative_log_probs +
                                            filtered_top_k_log_probs)
                    # Perturb cumulative log-probs by Gumbel noises under the condition
                    # that the max of the new perturbed log-probs is equal to
                    # perturbed_cumulative_log_probs of the previous step.
                    # shape: [tgt_batch, k]
                    new_perturbed_cumulative_log_probs = _SampleGumbelWithMax(
                        cumulative_log_probs,
                        last_perturbed_cumulative_log_probs,
                        encoder_outputs.stochastic_beam_search.seed, time_step,
                        encoder_outputs.stochastic_beam_search.src_ids,
                        encoder_outputs.stochastic_beam_search.src_paddings)

                    # STEP 4: Compute updated log_probs. This step is necessary because
                    # the output of PreBeamSearchStepCallback must be "per-step"
                    # log-probs, whereas so far "cumulative" log-probs have been computed.
                    # shape: [tgt_batch, k]
                    updated_top_k_log_probs = (
                        new_perturbed_cumulative_log_probs -
                        last_perturbed_cumulative_log_probs)
                    # Convert to the shape [tgt_batch, vocab_size].
                    updated_log_probs = tf.fill(
                        tf.shape(log_probs),
                        tf.constant(LARGE_NEGATIVE_NUMBER,
                                    dtype=log_probs.dtype))
                    updated_log_probs = _BatchScatter(updated_log_probs,
                                                      top_k_ids,
                                                      updated_top_k_log_probs)

                    return (updated_log_probs,
                            py_utils.NestedMap(
                                new_perturbed_cumulative_log_probs=
                                new_perturbed_cumulative_log_probs,
                                top_k_log_probs=top_k_log_probs,
                                top_k_ids=top_k_ids,
                            ))

                (bs_results.log_probs, out_states.tmp_states) = tf.cond(
                    encoder_outputs.stochastic_beam_search.enable,
                    PerturbedLogProbs,
                    # No-op.
                    lambda: (bs_results.log_probs, states.tmp_states))
                # These states are not updated here but will be updated in
                # PostBeamSearchStepCallback since doing so requires the knowledge of
                # the next step IDs.
                out_states.cumulative_log_probs = states.cumulative_log_probs
                out_states.perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs

            return bs_results, out_states
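
The biasing in ApplyBias is a convex interpolation between the model's
distribution and a one-hot target distribution: with weight 0 the model is
unchanged, and with weight 1 decoding is forced onto the target token. A toy
sketch with made-up numbers (plain TensorFlow 2):

import tensorflow as tf

pred_probs = tf.constant([[0.6, 0.3, 0.1]])  # model distribution
label_probs = tf.one_hot([2], depth=3)       # target token id 2
weight = tf.constant([[0.5]])                # per-token bias weight
probs = (1.0 - weight) * pred_probs + weight * label_probs
print(probs.numpy())  # [[0.3  0.15 0.55]]
log_probs = tf.math.log(tf.maximum(probs, 1e-12))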