def fast_gather(values,
                ids,
                ids_size,
                max_value=None,
                axis=0,
                batch_major_state=True):
    """Gathers entries of `values` at `ids` with a TPU-friendly gather.

    The work is delegated to a `_Gatherer` built from `ids` and `ids_size`
    inside a "fast_gather" name scope.

    Args:
      values: Values to gather from (anything `tf.convert_to_tensor` accepts).
      ids: Ids (rows, or entries along `axis`) to gather.
      ids_size: Size of the id space.
      max_value: Optional hint on the maximum value for int32 data that allows
        the gather operation to be sped up.
      axis: Axis to gather on. Defaults to 0 (rows).
      batch_major_state: Whether the values to gather from use a batch-major
        layout. Defaults to True.

    Returns:
      The gathered values.

    Raises:
      ValueError: If `values` is of type int64.
    """
    values_tensor = tf.convert_to_tensor(values)
    ids_tensor = tf.convert_to_tensor(ids)
    with tf.name_scope("fast_gather"):
        gatherer = _Gatherer(ids_tensor, ids_size)
        return gatherer(
            values_tensor,
            max_value=max_value,
            axis=axis,
            batch_major_state=batch_major_state,
        )
Example #2
0
    def FProp(self, theta, inputs, paddings=None):
        """Apply global spatial pooling to inputs.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers. Not read by this layer.
          inputs: The inputs tensor. It is expected to be of shape [batch, time,
            frequency, channel]. The time dimension corresponds to the height
            dimension as in images and the frequency dimension corresponds to
            the width dimension as in images.
          paddings: The paddings tensor. It is expected to be of shape [batch,
            time]. Defaults to None, which means there are no paddings.

        Returns:
          outputs, out_paddings pair.
           - outputs: has shape [batch, 1, 1, channel].
           - out_paddings: None or has shape [batch, 1].
        """
        p = self.params
        assert p.pooling_type in ['MAX', 'AVG'], p.pooling_type
        b, t, f = py_utils.GetShape(inputs, ndims=3)

        if paddings is None:
            # No paddings: every time step is treated as valid.
            mask = tf.ones([b, t, 1, 1], p.dtype)
        else:
            paddings = py_utils.HasShape(paddings, [b, t])
            # [b, t, 1, 1]; 1.0 for valid time steps and 0.0 for padded ones
            # (assuming 0/1 paddings), broadcast over frequency and channel.
            mask = 1.0 - paddings[..., tf.newaxis, tf.newaxis]

        if p.pooling_type == 'AVG':
            # Sum of the unpadded inputs over time and frequency.
            global_sum = tf.reduce_sum(inputs * mask,
                                       axis=[1, 2],
                                       keepdims=True)
            f = tf.cast(tf.convert_to_tensor(f), p.dtype)
            # Number of valid (time, frequency) positions per example.
            count = f * tf.reduce_sum(mask, axis=[1, 2], keepdims=True)
            # Clamp the divisor so fully padded examples do not divide by 0.
            out_feature = global_sum / tf.maximum(1.0, count)
        elif p.pooling_type == 'MAX':
            # Large but finite negative fill value so padded positions never
            # win the max.
            large_negative = (tf.ones_like(inputs) * p.dtype.max *
                              tf.constant(-0.7, dtype=p.dtype))
            padded_inputs = tf.where_v2(mask > 0.0, inputs, large_negative)
            out_feature = tf.reduce_max(padded_inputs,
                                        axis=[1, 2],
                                        keepdims=True)
        if paddings is None:
            out_paddings = None
        else:
            # An output is marked padded only when all of its input time steps
            # are padded; the pooled feature of such examples is zeroed out.
            out_paddings = tf.reduce_min(paddings, axis=1, keepdims=True)
            out_feature *= 1.0 - out_paddings[..., tf.newaxis, tf.newaxis]
        return out_feature, out_paddings
def _hash32(x, seed=1):
    """A simple int32 -> int32 hash function.

    Args:
      x: int32 Tensor, the value to hash.
      seed: int or int32 Tensor, the seed(s) values. Must be broadcastable to
        the shape of 'x'.

    Returns:
      An int32 Tensor (broadcast shape of `x` and `seed`) of hashed values.
    """
    x = tf.convert_to_tensor(x)
    assert x.dtype == tf.int32
    # TODO(austinwaters): Change to an int64 valued hash once fast_gather
    # supports that type. For int64, the corresponding CityHash prime (k_mul) is
    # 0x9ddfea08eb382d69, and we should right shift by 32 bits in the last line.
    # Large prime borrowed from CityHash (0xcc9e2d51). That literal does not
    # fit in a signed int32, so converting it to an int32 constant can be
    # rejected as out of range. Use its two's-complement int32 value instead:
    # the wrapped int32 product is bit-for-bit identical.
    k_mul = 0xcc9e2d51 - (1 << 32)
    m = (seed + x) * k_mul
    return tf.bitwise.bitwise_xor(m, tf.bitwise.right_shift(m, 16))
def reorder_tensor(reorder_mode,
                   values,
                   num_shards,
                   shard_size,
                   max_value=None,
                   axis=0):
    """Permutes `values` along `axis` between two sharding layouts.

    The permutation is applied with `fast_gather` (matmul-based, for TPU
    efficiency). For an output position `s` in [0, num_shards * shard_size):

      * "mod_to_div": reads input index
          s // shard_size + (s % shard_size) * num_shards
      * "div_to_mod": reads input index
          s // num_shards + (s % num_shards) * shard_size

    Args:
      reorder_mode: Either "mod_to_div" or "div_to_mod".
      values: Tensor to reorder.
      num_shards: Number of shards.
      shard_size: Size of each shard.
      max_value: If dtype=tf.int32 and the maximum of the values is known, the
        gather can be implemented efficiently as matmuls.
      axis: Axis to reorder along. Defaults to 0 (rows).

    Returns:
      A tensor of the same shape as `values`, with entries along `axis`
      reordered.

    Raises:
      NotImplementedError: If `reorder_mode` is not one of the two modes.
    """
    values = tf.convert_to_tensor(values)
    with tf.name_scope("reorder_tensor_" + reorder_mode):
        total_ids = num_shards * shard_size
        # Both modes share one formula: pos // group_len + (pos % group_len)
        # * stride; only the roles of num_shards and shard_size swap.
        if reorder_mode == "mod_to_div":
            group_len, stride = shard_size, num_shards
        elif reorder_mode == "div_to_mod":
            group_len, stride = num_shards, shard_size
        else:
            raise NotImplementedError(
                "Reorder mode: {} not implemented.".format(reorder_mode))
        positions = tf.range(total_ids)
        source_ids = positions // group_len + (positions % group_len) * stride
        return fast_gather(values, source_ids, total_ids, max_value, axis=axis)
 def PostTrainingStepUpdate(self, global_step):
   """Updates moving_mean, moving_variance after each training step.

   Reads the sufficient statistics gathered in `self.accumulators` (counts,
   mean_ss, variance_ss), converts them into a batch mean and variance with
   `tf.nn.normalize_moments`, and moves the moving statistics toward them:

     moving -= (moving - batch_stat) * (1 - p.decay)

   i.e. moving = p.decay * moving + (1 - p.decay) * batch_stat. Where
   `counts <= 0.5` (nothing accumulated) the subtracted amount is zero, so the
   moving statistic is unchanged there. The accumulators are then reset.

   Args:
     global_step: Not read by this implementation.

   Returns:
     A `tf.group` of the moving-mean and moving-variance `assign_sub` ops.
   """
   p = self.params
   # Get sufficient stats that accumulates over microbatches.
   counts = self.accumulators.counts.GetValue()
   mean_ss = self.accumulators.mean_ss.GetValue()
   variance_ss = self.accumulators.variance_ss.GetValue()
   # Compute batch mean and batch variance from sufficient stats
   mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
   # Weight of the new batch statistic in the exponential moving average.
   decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
   # Update moving_mean, moving_variance from  batch mean and batch variance.
   with tf.name_scope(p.name) as scope:
     # Create each update op on the same device as the variable it modifies.
     with tf.ops.colocate_with(self.vars.moving_mean):
       mean_update = tf.assign_sub(
           self.vars.moving_mean,
           tf.where(
               tf.greater(counts, 0.5),
               (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
               tf.zeros_like(self.vars.moving_mean)),
           name='moving_mean_update')
     with tf.ops.colocate_with(self.vars.moving_variance):
       var_update = tf.assign_sub(
           self.vars.moving_variance,
           tf.where(
               tf.greater(counts, 0.5),
               (self.vars.moving_variance - tf.cast(variance, p.dtype)) *
               decay, tf.zeros_like(self.vars.moving_variance)),
           name='moving_variance_update')
     # NOTE(review): the return values of py_utils.CheckNumerics are discarded
     # and the checks read the variables rather than mean_update/var_update.
     # If CheckNumerics returns a checked tensor (like
     # tf.debugging.check_numerics), these checks may never execute in graph
     # mode -- confirm.
     py_utils.CheckNumerics(
         self.vars.moving_mean,
         'moving mean of {} failed numeric check'.format(scope))
     py_utils.CheckNumerics(
         self.vars.moving_variance,
         'moving variance of {} failed numeric check'.format(scope))
   # Clear the accumulated statistics for the next step.
   # NOTE(review): no control dependency orders these resets after the reads
   # above within this method; confirm the accumulator implementation
   # guarantees the intended ordering.
   self.accumulators.counts.Reset()
   self.accumulators.mean_ss.Reset()
   self.accumulators.variance_ss.Reset()
   return tf.group(mean_update, var_update)
 def replicate_var(name):
   """Returns a float32 tensor holding `batch_dim` copies of a private var.

   Args:
     name: Key into `self._private_vars` (closed over from the enclosing
       scope, as is `batch_dim`).
   """
   single_value = self._private_vars[name]
   copies = [single_value] * batch_dim
   return tf.convert_to_tensor(copies, dtype=tf.float32)
def beam_search_step(in_scores,
                     in_atten_probs,
                     in_best_scores,
                     in_cumulative_scores,
                     in_histories,
                     cur_step,
                     eos_id,
                     num_beams,
                     beam_size,
                     num_hyps_per_beam,
                     valid_eos_max_logit_delta=5.0,
                     local_eos_threshold=-100.0,
                     merge_paths=False,
                     is_last_chunk=None,
                     eoc_id=0):
    """A single step of beam search.

    Let "b" be the number of beams and "k" the number of hyps in each beam.
    This function supports values with dtypes tf.float32 or tf.bfloat16.

    All [b * k, ...] per-hyp tensors (inputs and outputs) use a hyp-major
    layout: row i holds the j-th hyp of the n-th beam, where j = i // b and
    n = i % b. Internally the step switches to a beam-major layout
    (row = n * k + j) with `reorder_tensor("mod_to_div", ...)` and restores the
    hyp-major layout with `reorder_tensor("div_to_mod", ...)` before returning.

    The search state (best scores, cumulative scores, histories, ...) is
    allocated before the first decoding step and passed along from the current
    step to the next one.

    Args:
      in_scores: A tensor of shape [b * k, vocab_size], where [i, ...] is the
        token score of the j-th hyp of the n-th beam (j = i // b, n = i % b).
      in_atten_probs: A tensor of shape [b * k, s_len], where
        in_atten_probs[i, ...] is the attention probabilities over the source
        words of the j-th hyp of the n-th beam (where j and n are derived as
        above).
      in_best_scores: A vector of size [b], best scores of terminated hyps so
        far in each of the beams.
      in_cumulative_scores: A vector of size [b * k]. The cumulative score of
        each active hyp before the current step.
      in_histories: An int32 vector of size [b * k] containing hashes of the
        histories of each active hyp. If 'merge_paths' is enabled, the
        histories are used to identify hypotheses that are identical modulo
        epsilons (e.g. "a <eps> b" and "a b <eps>") and merge them. See
        'update_histories' docstring for details.
      cur_step: Current step id.
      eos_id: Token id of the special end of sequence token.
      num_beams: Number of beams.
      beam_size: Search terminates once no active hyp has a cumulative score
        greater than the best terminated score of its beam minus beam_size.
      num_hyps_per_beam: Number of hyps in a beam.
      valid_eos_max_logit_delta: We allow </s> to terminate a hyp only if its
        logit is no more than 'valid_eos_max_logit_delta' away from the logit
        of the best candidate.
      local_eos_threshold: We allow </s> to terminate a hyp if the local score
        for </s> is greater than local_eos_threshold.
      merge_paths: If true, hyps which are identical when epsilons are removed
        will be combined into a single hyp. The probability for that combined
        hyp will be the sum of the probabilities of the component hyps. This
        can only be applied for epsilon-emitting models (RNN-T and NT).
      is_last_chunk: A tensor of shape [b * k, 1]. Used by neural transducer,
        determines whether the current hypothesis reaches the last chunk and
        should treat the next end-of-chunk symbol as end-of-sentence.
      eoc_id: int, the id of the end of chunk (a.k.a epsilon) token used by
        neural transducer models. Only relevant if 'merge_paths' is True or
        'is_last_chunk' is provided.

    Returns:
      out_best_scores: A tensor of shape [b] of updated best scores for each of
        the beams.
      out_cumulative_scores: A tensor of shape [b * k]. The cumulative score of
        the new hyps after the current decoding step.
      out_scores: A tensor of shape [b * k] with scores of the token selected.
      out_eos_scores: A tensor of shape [b * k] with token scores for the EOS,
        in case the hyp was terminated, otherwise 0.0.
      out_hyps: A tensor of shape [b * k] with ids of the token selected.
      out_prev_hyps: A tensor of shape [b * k] with index to the previous hyps
        which was selected.
      out_done_hyps: A boolean tensor of shape [b * k] where value indicates
        if hyps was terminated.
      out_atten_probs: A tensor of shape [b * k, seq_len] which contain the
        attention probabilities over the source words against word in the
        previous hyps.
      out_eos_atten_probs: A tensor of shape [b * k, seq_len] which contains
        the attention probabilities over the source against word in the
        current hyp which was terminated.
      out_all_done: A scalar, whether decoding should terminate for all beams.
      out_histories: A tensor of shape [b * k] containing new history hashes
        for the active hypotheses. See 'update_histories' docstring for
        details.

    Raises:
      ValueError: if inputs are invalid.
    """
    num_hyps_per_beam = int(num_hyps_per_beam)

    if num_hyps_per_beam <= 0:
        raise ValueError("num_hyps_per_beam = {} and must be > 0.".format(
            num_hyps_per_beam))

    in_scores = tf.convert_to_tensor(in_scores)
    in_scores.shape.assert_has_rank(2)
    num_classes = in_scores.get_shape()[1]

    in_atten_probs = tf.convert_to_tensor(in_atten_probs)
    in_atten_probs.shape.assert_has_rank(2)

    in_best_scores = tf.convert_to_tensor(in_best_scores)
    in_best_scores.shape.assert_has_rank(1)

    in_cumulative_scores = tf.convert_to_tensor(in_cumulative_scores)
    in_cumulative_scores.shape.assert_has_rank(1)

    in_histories = tf.convert_to_tensor(in_histories)
    in_histories.shape.assert_has_rank(1)

    with tf.name_scope("beam_search_step"):
        # For k = num_hyps_per_beam
        # First step of beam search is to find the top tokens based on its score.
        # Normally we select k+1, where the extra +1 is to make sure we have k
        # non-eos tokens to select if EOS token is in the top-k. If path merging is
        # on, we actually need to select k+2; this ensures there are k+1 tokens left
        # after the merge, at least k of which are not EOS.
        # TODO(b/118644069): Avoid casts when there is a XLA op available that takes
        # in bfloat16.
        num_candidates_per_input_hyp = (num_hyps_per_beam + 2 if merge_paths
                                        else num_hyps_per_beam + 1)
        # [b * k, num_candidates_per_input_hyp]
        local_score_values, local_indices = xla_ops.top_k_with_unique(
            tf.cast(in_scores, tf.float32), k=num_candidates_per_input_hyp)
        local_score_values = tf.cast(local_score_values, in_scores.dtype)

        # Compute the global score which is sum of the local score, and the
        # cumulative scores for each of the hyps.
        # [b * k, num_candidates_per_input_hyp]
        global_score_values = local_score_values + tf.expand_dims(
            in_cumulative_scores, 1)

        values_dtype = local_score_values.dtype
        is_first_step = tf.cast(tf.equal(cur_step, 0), values_dtype)

        # Preprocessing to reorder the tensor from `mod` sharding (hyp-major) to
        # `div` (beam-major) so that we can use matrix/vector operations to
        # complete the beam search.
        # [b * k, num_candidates_per_input_hyp]
        global_score_values = reorder_tensor("mod_to_div", global_score_values,
                                             num_beams, num_hyps_per_beam)
        local_score_values = reorder_tensor("mod_to_div", local_score_values,
                                            num_beams, num_hyps_per_beam)
        local_indices = reorder_tensor("mod_to_div",
                                       local_indices,
                                       num_beams,
                                       num_hyps_per_beam,
                                       max_value=num_classes - 1)
        # [b * k, 1]
        histories = reorder_tensor("mod_to_div",
                                   tf.expand_dims(in_histories, 1), num_beams,
                                   num_hyps_per_beam)
        if is_last_chunk is None:
            is_last_chunk = tf.zeros([num_beams * num_hyps_per_beam, 1],
                                     tf.bool)
        else:
            is_last_chunk = tf.cast(
                reorder_tensor(
                    "mod_to_div",
                    tf.reshape(is_last_chunk,
                               [num_beams * num_hyps_per_beam, 1]), num_beams,
                    num_hyps_per_beam), tf.bool)

        # For the first step mask everything but the first row.
        # [num_hyps_per_beam]
        per_example_mask = tf.concat([
            tf.constant([1.0], dtype=values_dtype),
            tf.zeros([num_hyps_per_beam - 1], dtype=values_dtype)
        ], 0)
        # [num_hyps_per_beam] tiled num_beams times => [b*k, 1]
        mask = tf.reshape(
            tf.tile(per_example_mask, tf.expand_dims(num_beams, 0)),
            [-1, 1]) * is_first_step + (1.0 - is_first_step)
        local_score_values *= mask
        global_score_values *= mask

        # We add a large negative value for the unmasked values.
        per_example_additive_mask = tf.concat([
            tf.constant([0.0], dtype=values_dtype),
            tf.constant(BEST_SCORES_INIT,
                        shape=[num_hyps_per_beam - 1],
                        dtype=values_dtype)
        ], 0)
        additive_mask = tf.reshape(
            tf.tile(per_example_additive_mask, tf.expand_dims(num_beams, 0)),
            [-1, 1]) * is_first_step
        local_score_values += additive_mask
        global_score_values += additive_mask

        if merge_paths:
            with tf.name_scope("merge_paths"):
                # Compute new history hashes for each hypothesis + new token.
                # [b * k, num_candidates_per_input_hyp]
                histories = update_histories(histories,
                                             local_indices,
                                             mask,
                                             epsilon_id=eoc_id)
                global_score_values, histories = merge_hyps(
                    global_score_values, histories, mask, num_beams,
                    num_hyps_per_beam)

        # As we keep num_candidates_per_input_hyp, we have a total of
        # num_candidates_per_input_hyp * k hyps active per example.
        num_candidate_hyps = num_candidates_per_input_hyp * num_hyps_per_beam
        batch_shape = [-1, num_candidate_hyps]

        # Reshape score values so that each row corresponds to a particular example.
        # [num_beams, num_candidate_hyps]
        global_score_values_batch = tf.reshape(global_score_values,
                                               batch_shape)

        # First for each beam: Find the top 2 * num_hyps_per_beam candidates.
        # The factor of 2 is to be able to process non EOS token ids in the case
        # where top scoring token for each hyps is EOS token.
        # [b, 2 * num_hyps_per_beam]
        _, candidates_indices_in_top_k = xla_ops.top_k_with_unique(
            tf.cast(global_score_values_batch, tf.float32),
            k=2 * num_hyps_per_beam)
        # Find the previous hyps of the candidate. We divide here by
        # num_candidates_per_input_hyp to identify which hyps (within the beam)
        # this token came from.
        hyps_id = candidates_indices_in_top_k // num_candidates_per_input_hyp

        # Add in offset so that we can get the candidate index in the flattened
        # [b * num_candidate_hyps] space.
        offset = tf.expand_dims(tf.range(num_beams) * num_candidate_hyps, 1)
        flat_candidates_indices_in_top_k = tf.reshape(
            candidates_indices_in_top_k + offset, [-1])

        flat_local_indices = tf.reshape(local_indices, [1, -1])
        flat_token_scores = tf.reshape(local_score_values, [-1, 1])
        flat_global_scores = tf.reshape(global_score_values, [-1, 1])

        # Gather the token scores for each of 2*k candidates. We use tf.one_hot()
        # followed by a tf.matmul() to speedup gather on TPUs.
        total_num_candidates = num_beams * num_candidate_hyps
        token_scores_for_beam = tf.reshape(
            fast_gather(flat_token_scores, flat_candidates_indices_in_top_k,
                        total_num_candidates),
            [num_beams, 2 * num_hyps_per_beam])
        token_scores_for_beam_shape = tf.shape(token_scores_for_beam)

        global_scores_for_beam = tf.reshape(
            fast_gather(flat_global_scores, flat_candidates_indices_in_top_k,
                        total_num_candidates), token_scores_for_beam_shape)

        # Local index values are in [0, vocab_size - 1]; pass that bound as the
        # max_value hint to fast_gather.
        token_ids_for_beam = tf.reshape(
            fast_gather(flat_local_indices,
                        flat_candidates_indices_in_top_k,
                        total_num_candidates,
                        max_value=num_classes - 1,
                        axis=1), token_scores_for_beam_shape)

        # We have access to 2*num_hyps_per_beam hyps per beam.
        # We shrink back to num_hyps_per_beam that does not include EOS, and move
        # EOS that occurs in top-num_hyps_per_beam to the EOS done matrix.

        # To determine the threshold at which eos is allowed to terminate a hyp,
        # we need to know the maximum global score for that hyp with any additional
        # token. If path merging is *not* enabled, the global_score_values are
        # by construction in sorted order, so we can just look at its 0th column. If
        # path merging is enabled, the global scores of deleted (merged) hyps break
        # the sorted order, which means we have to do a full reduce_max.
        if merge_paths:
            max_global_score_per_input_hyp = tf.reduce_max(global_score_values,
                                                           axis=1,
                                                           keepdims=True)
        else:
            max_global_score_per_input_hyp = global_score_values[:, 0:1]
        # [num_beams * num_hyps_per_beam, 1]
        global_eos_threshold = (max_global_score_per_input_hyp -
                                valid_eos_max_logit_delta)
        local_eos_threshold_tensor = local_eos_threshold * tf.ones_like(
            global_eos_threshold)

        # Find EOS in top num_hyps_per_beam token ids. We also treat EOC as EOS if
        # the model has indicated this is the last chunk.
        local_index_is_eos = tf.equal(local_indices, eos_id)
        local_index_is_last_chunk_eoc = tf.math.logical_and(
            tf.equal(local_indices, eoc_id), is_last_chunk)
        eos_mask = tf.math.logical_and(
            tf.math.logical_and(
                tf.math.logical_and(
                    tf.greater(
                        local_score_values,
                        tf.tile(local_eos_threshold_tensor,
                                [1, num_candidates_per_input_hyp])),
                    tf.greater(
                        global_score_values,
                        tf.tile(global_eos_threshold,
                                [1, num_candidates_per_input_hyp]))),
                tf.math.logical_or(local_index_is_eos,
                                   local_index_is_last_chunk_eoc)),
            tf.cast(mask, tf.bool))
        end_hyps_bool_mask = tf.reshape(tf.reduce_any(eos_mask, 1), [-1, 1])

        end_hyps_bool_mask = reorder_tensor("div_to_mod", end_hyps_bool_mask,
                                            num_beams, num_hyps_per_beam)

        eos_atten_probs = in_atten_probs * tf.cast(end_hyps_bool_mask,
                                                   in_atten_probs.dtype)
        eos_atten_probs = tf.reshape(eos_atten_probs,
                                     [num_beams * num_hyps_per_beam, -1])
        # A boolean tensor of shape [b * k] where value indicates if hyps was
        # terminated.
        out_done_hyps = tf.reshape(end_hyps_bool_mask, [-1])

        # Scores for EOS token.
        eos_float_mask = tf.cast(eos_mask, values_dtype)
        eos_local_scores = eos_float_mask * local_score_values
        eos_additive_float_mask = (1.0 - eos_float_mask) * BEST_SCORES_INIT
        eos_local_scores += eos_additive_float_mask
        out_eos_scores = tf.reshape(tf.reduce_max(eos_local_scores, 1),
                                    [-1, 1])
        out_eos_scores = tf.reshape(
            reorder_tensor("div_to_mod", out_eos_scores, num_beams,
                           num_hyps_per_beam), [-1])
        # A tensor of shape [b] of updated best scores for each of the beams.
        eos_global_scores = eos_float_mask * global_score_values
        eos_global_scores += eos_additive_float_mask
        best_scores = tf.reduce_max(
            tf.reshape(eos_global_scores, [num_beams, -1]), 1)

        # Following operations are to finds the top num_hyps_per_beam that are
        # active.

        # Active ones are the ones that do not correspond to EOS termination.
        # We keep num_hyps_per_beam * 2 in case every hyps is terminated by EOS id.
        # Top K with eos removed.
        # NOTE(review): only eos_id is excluded here; an eoc_id token that
        # eos_mask treated as end-of-sentence on the last chunk can still be
        # selected as an active hyp. Confirm this is intended.
        non_eos_mask = tf.not_equal(token_ids_for_beam, eos_id)
        # Total number of top-(2 * k) candidates over all beams; valid
        # positions in the flattened [b, 2 * k] tensors are
        # [0, num_top_candidates), and num_top_candidates itself is used as the
        # "masked out" sentinel.
        num_top_candidates = num_hyps_per_beam * 2 * num_beams
        index = tf.where(
            non_eos_mask,
            tf.reshape(tf.range(num_top_candidates, dtype=tf.int32),
                       token_scores_for_beam_shape),
            num_top_candidates *
            tf.ones(dtype=tf.int32, shape=token_scores_for_beam_shape))

        # Unrolled TopK.
        sorted_indices = []
        # Finds the first num_hyps_per_beam unmasked indexes and stores them in
        # concated_index (shape: [num_beams, num_hyps_per_beam])
        # This is done by iteratively record the min index in each row, and reset
        # it to the max, so that next iteration reduce_min returns the 2nd minimum
        # index.
        for _ in range(num_hyps_per_beam):
            min_index = tf.reshape(tf.reduce_min(index, [1]), [num_beams, 1])
            sorted_indices.append(min_index)
            # Replace position with num_top_candidates value.
            index = tf.where(
                tf.equal(index, min_index),
                num_top_candidates *
                tf.ones(dtype=tf.int32, shape=token_scores_for_beam_shape),
                index)

        # Post processing ops to output expected tensors.
        concated_sorted_indices = tf.concat(sorted_indices, 1)
        flat_sorted_indices = tf.reshape(concated_sorted_indices, [-1])

        # A tensor of shape [b * k] with scores of the token selected.
        out_scores = tf.reshape(
            fast_gather(tf.reshape(token_scores_for_beam, [-1, 1]),
                        flat_sorted_indices, num_top_candidates), [-1, 1])
        out_scores = tf.reshape(
            reorder_tensor("div_to_mod", out_scores, num_beams,
                           num_hyps_per_beam), [-1])

        # Gather the updated histories of selected hypotheses if path merging is
        # enabled. Otherwise, the histories are unused, so just output in_histories.
        if merge_paths:
            flat_histories = tf.reshape(histories, [-1, 1])
            # [num_beams, 2 * num_hyps_per_beam]
            histories_for_beam = tf.reshape(
                fast_gather(flat_histories, flat_candidates_indices_in_top_k,
                            total_num_candidates), token_scores_for_beam_shape)
            out_histories = tf.reshape(
                fast_gather(tf.reshape(histories_for_beam, [-1, 1]),
                            flat_sorted_indices, num_top_candidates), [-1, 1])
            out_histories = tf.reshape(
                reorder_tensor("div_to_mod", out_histories, num_beams,
                               num_hyps_per_beam), [-1])
        else:
            out_histories = in_histories

        # Index of each selected hyp's predecessor in the hyp-major [b * k]
        # input layout: hyp_within_beam * num_beams + beam.
        prev_hyps_ids = tf.reshape(
            tf.reshape(
                fast_gather(tf.reshape(hyps_id, [1, -1]),
                            flat_sorted_indices,
                            num_top_candidates,
                            max_value=num_hyps_per_beam,
                            axis=1), [num_beams, -1]) * num_beams +
            tf.expand_dims(tf.range(num_beams), 1), [-1, 1])

        prev_hyps_ids = reorder_tensor("div_to_mod",
                                       prev_hyps_ids,
                                       num_beams,
                                       num_hyps_per_beam,
                                       max_value=num_hyps_per_beam)
        # A tensor of shape [b * k] with index to the previous hyps which was
        # selected.
        out_prev_hyps = tf.reshape(prev_hyps_ids, [-1])

        # A tensor of shape [b * k, seq_len] which contain the attention
        # probabilities over the source words against word in the previous hyps.
        out_atten_probs = tf.reshape(
            fast_gather(in_atten_probs, out_prev_hyps,
                        num_beams * num_hyps_per_beam),
            [num_beams * num_hyps_per_beam, -1])

        sorted_top_k_ids = fast_gather(tf.reshape(token_ids_for_beam, [1, -1]),
                                       flat_sorted_indices,
                                       num_top_candidates,
                                       max_value=num_classes - 1,
                                       axis=1)
        sorted_top_k_ids = reorder_tensor("div_to_mod",
                                          sorted_top_k_ids,
                                          num_beams,
                                          num_hyps_per_beam,
                                          max_value=num_classes - 1,
                                          axis=1)

        # A tensor of shape [b * k] with ids of the token selected.
        out_hyps = tf.reshape(sorted_top_k_ids, [-1])

        # A tensor of shape [b * k]. The cumulative score of the selected hyps after
        # the current decoding step.
        out_cumulative_scores = tf.reshape(
            fast_gather(tf.reshape(global_scores_for_beam, [-1, 1]),
                        flat_sorted_indices, num_top_candidates), [-1, 1])

        out_cumulative_scores = tf.reshape(
            reorder_tensor("div_to_mod", out_cumulative_scores, num_beams,
                           num_hyps_per_beam), [-1])
        out_best_scores = tf.maximum(best_scores, in_best_scores)

        # A scalar, whether decoding should terminate for all beams: true when
        # no active hyp has a cumulative score above its own beam's best
        # terminated score minus beam_size.
        # out_cumulative_scores is hyp-major ([k, b] flattened, after the
        # div_to_mod reorder above) while out_best_scores is indexed by beam,
        # so the per-beam threshold is tiled as [k, b] before flattening. This
        # pairs every hyp with the threshold of its own beam.
        per_hyp_threshold = tf.reshape(
            tf.tile(tf.reshape(out_best_scores - beam_size, [1, -1]),
                    [num_hyps_per_beam, 1]), [-1])
        out_all_done = tf.reshape(
            tf.math.logical_not(
                tf.reduce_any(
                    tf.greater(out_cumulative_scores, per_hyp_threshold))),
            [])

        return (out_best_scores, out_cumulative_scores, out_scores,
                out_eos_scores, out_hyps, out_prev_hyps, out_done_hyps,
                out_atten_probs, eos_atten_probs, out_all_done, out_histories)