def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
    """Concats sequence `x` with sequence `y`.

    This function is length aware (based off the paddings): for every example
    the valid prefix of `y` is written directly after the valid prefix of `x`.

    Only the paddings decide which entries of `x` are valid. Whatever values
    `x` holds past its valid length (e.g. <eos> fill instead of `pad`) are
    discarded, and valid tokens of `x` are kept even when they equal `pad`.

    Args:
      x: A sequence of tokens of shape [batch_size, x_len_max].
      x_paddings: The paddings of `x`.
      y: A sequence of tokens of shape [batch_size, y_len_max].
      y_paddings: The paddings of `y`.
      pad: The <pad> token to fill the concatenated sequence (of type integer).

    Returns:
      A tuple.
        - Concatenation of `x` and `y` of shape
          [batch_size, x_len_max + y_len_max].
        - Paddings of the concatenation of shape
          [batch_size, x_len_max + y_len_max].
    """
    # Get the length (w/ eos).
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)
    xy_len = x_len + y_len

    batch_size = py_utils.GetShape(x)[0]
    y_len_max = py_utils.GetShape(y)[1]

    # Extend `x` along the length dimension to make room for `y`.
    x = tf.concat([x, tf.zeros(py_utils.GetShape(y), x.dtype)], 1)
    xy_shape = py_utils.GetShape(x)

    # Zero every entry past the valid length of `x`. `y` is merged in by
    # addition below, so any residue left here (non-`pad` fill values) would be
    # summed into the `y` tokens. Masking by length, instead of comparing
    # against `pad`, also keeps valid tokens that happen to equal `pad`.
    x *= tf.sequence_mask(x_len, xy_shape[1], x.dtype)

    # Compute the write indices of `y` in `xy`: row b, column x_len[b] + j.
    indices = tf.stack([
        tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
        (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
         tf.expand_dims(x_len, 1)),
    ], 2)

    # The padded tail of `y` is scattered too, but it always lands at columns
    # >= x_len + y_len, which are overwritten with `pad` right below.
    xy = x + tf.scatter_nd(indices, y, xy_shape)

    # We need to remap all <pad> to `pad`.
    is_valid = tf.less(tf.expand_dims(tf.range(xy_shape[1]), 0),
                       tf.expand_dims(xy_len, 1))
    xy = tf.where(is_valid, xy, tf.fill(xy_shape, tf.cast(pad, xy.dtype)))
    xy_paddings = 1 - tf.sequence_mask(xy_len, xy_shape[1], x_paddings.dtype)
    return xy, xy_paddings
Beispiel #2
0
    def _ProcessBeamSearchDecodeOut(self, input_batch, encoder_outputs,
                                    decoder_outs):
        """Turns the flat beam search outputs into a dict keyed by name.

        Args:
          input_batch: The input batch. `tgt.ids`, `tgt.paddings` and
            `eval_weight` are read from it.
          encoder_outputs: Not used by this method.
          decoder_outs: An indexable sequence. Entries 0-2 are stored on `self`
            as `r1_shape`, `r2_shape` and `r3_shape` (and logged); entries 3-10
            are, in order: hyps, prev_hyps, done_hyps, scores, atten_probs,
            eos_scores, eos_atten_probs and source_seq_lengths.

        Returns:
          A dict holding the target ids (first column dropped), the eval
          weight, the target length minus one (`tlen`), and the eight decoder
          entries listed above under their names.
        """
        self.r1_shape, self.r2_shape, self.r3_shape = (decoder_outs[0],
                                                       decoder_outs[1],
                                                       decoder_outs[2])
        tf.logging.info('r1_shape: %s', self.r1_shape)
        tf.logging.info('r2_shape: %s', self.r2_shape)
        tf.logging.info('r3_shape: %s', self.r3_shape)

        # Number of non-padded target positions, minus one.
        num_valid = tf.reduce_sum(1.0 - input_batch.tgt.paddings, 1)
        tlen = tf.cast(tf.round(num_valid - 1.0), tf.int32)

        ret_dict = {
            'target_ids': input_batch.tgt.ids[:, 1:],
            'eval_weight': input_batch.eval_weight,
            'tlen': tlen,
        }
        # The remaining entries sit in `decoder_outs` starting at index 3.
        decoder_fields = ('hyps', 'prev_hyps', 'done_hyps', 'scores',
                          'atten_probs', 'eos_scores', 'eos_atten_probs',
                          'source_seq_lengths')
        for position, field in enumerate(decoder_fields, start=3):
            ret_dict[field] = decoder_outs[position]
        return ret_dict
def SequenceAppendToken(x, x_paddings, token, extend=False):
    """Writes `token` right after the last valid entry of each sequence in `x`.

    Args:
      x: A sequence of tokens of shape [batch_size, x_len_max].
      x_paddings: The paddings of `x`.
      token: The token to append (of type integer).
      extend: Whether to grow `x` by one position along the length dimension
        first. This must be true whenever some sequence in `x` already has
        length `x_len_max`, or else an invalid sequence will be emitted.

    Returns:
      A tuple.
        - The new sequence, of shape [batch_size, x_len_max], or
          [batch_size, x_len_max + 1] when `extend` is set.
        - The new paddings, of the same shape as the new sequence.
    """
    num_rows = py_utils.GetShape(x)[0]
    valid_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    if extend:
        x = tf.pad(x, [[0, 0], [0, 1]])
    out_shape = py_utils.GetShape(x)

    # Zero the invalid entries so the addition below places exactly `token`.
    keep = tf.sequence_mask(valid_len, out_shape[1], x.dtype)
    x = x * keep

    # One write per row, at column `valid_len` of that row.
    write_at = tf.stack([tf.range(num_rows), valid_len], axis=1)
    appended = tf.cast(tf.fill([num_rows], token), x.dtype)
    x = x + tf.scatter_nd(write_at, appended, out_shape)

    # Every sequence is now one element longer.
    new_paddings = 1 - tf.sequence_mask(valid_len + 1, out_shape[1],
                                        x_paddings.dtype)
    return x, new_paddings
def SequenceLength(padding):
  """Returns the number of non-padded positions of every sequence.

  Args:
    padding: A tensor of binary paddings shaped [batch, seqlen].

  Returns:
    An int32 tensor of shape [batch] holding, for each batch element, the
    rounded sum of `1 - padding` along axis 1.
  """
  num_valid = tf.reduce_sum(1 - padding, axis=1)
  lengths = tf.cast(tf.round(num_valid), tf.int32)
  # Reshape to exactly [batch], which drops any extra dimensions.
  return tf.reshape(lengths, [tf.shape(padding)[0]], name='seq_lens')
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

        We take in a `x`, which represents the groundtruth sequence (i.e.,
        English sequence). We return a sampled rollin (observed) canvas (i.e.,
        random subset of the English sequence), as well as the target (indices)
        for an insertion-based model (i.e., the targets given the random
        observed subset).

        Only the 'uniform' rollin and oracle policies are supported. A rollin
        policy of 'oracle' is an alias for the configured oracle policy.

        Args:
          theta: Ignored, this can be None.
          x: The symbol ids of shape `[batch_size, time_dim]`.
          x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]`
            where 0 is valid and 1 is invalid. If None, every position is
            treated as valid.
          eos_id: The <eos> token id to represent end-of-slot.
          force_sample_last_token: Set True to force sample the last token of
            `x`.

        Returns:
          A `NestedMap`.
            - canvas: The canvas (based off of the `rollin_policy`) of shape
              [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need
              not be equal.
            - canvas_indices: The canvas indices (into `x`).
            - canvas_paddings: The paddings of `canvas_indices`.
            - target_indices: The target indices of shape [num_targets, 3].
              `num_targets` is the number of total targets in the entire batch.
              [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
              captures the token. Each row [batch, slot, vocab] represents the
              indices of the target -- i.e., the batch, slot and vocab
              combination of the target. Typical usage of these indices is to
              tf.gather_nd the log-probs (from the softmax layer).
            - target_weights: The target weights.

        Raises:
          ValueError: If the rollin or oracle policy is not 'uniform'.
        """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        # Without paddings, treat every position of `x` as valid.
        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        # A rollin policy of 'oracle' means "use whatever the oracle policy is".
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        # Number of valid tokens per example.
        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        # NOTE(review): this op and the Gumbel noise below are both given the
        # same `p.random_seed` -- confirm the two draws are meant to share it.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            # Canvas length in [1, x_len] when x_len >= 1: at least the last
            # token is always observed.
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            # Canvas length in [0, x_len]: the canvas may be empty.
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example. Positions at or
        # beyond `x_len` get a logit of -1e9 so they rank below every valid
        # position.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
            # accomplish this by add +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        # Requesting the top `time_dim` entries ranks every position by its
        # noisy logit; the values themselves are not needed.
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas. Sorting puts the selected positions back in
        # left-to-right order; the gather reads x[b, c_indices[b, j]].
        c_indices = tf.sort(c_indices)
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        # Zero the canvas entries that came from the invalid filler index.
        # Assumes `x` (and therefore `c`) is int32 -- TODO confirm.
        c *= tf.cast(1 - c_paddings, tf.int32)

        # (batch, position) pairs for every canvas entry, fillers included.
        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        # NOTE(review): filler entries all point at `time_dim - 1`, so that
        # column is marked observed for any row with c_len < c_len_max.
        # Presumably harmless when that position is padded or force-sampled;
        # confirm for force_sample_last_token=False with full-length rows.
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        # Whether the token to the left is observed; the first position is
        # treated as having an observed left neighbour.
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        # Flatten to [batch_size * time_dim] to line up with the flat targets.
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed to <eos>, note some of these need a zero weight
        # (or else there would be <eos> and valid token in the same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero. That is the case when the token is observed but
        # its left neighbour is not.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices, giving rows of [batch, slot, token].
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
    def BeamSearchDecode(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam_override=0,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
        """Runs beam search and returns the top terminated hypotheses.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap containing encoder outputs to be passed
            to the callbacks. It must provide either `seq_lengths` or
            `padding`; `padding` is transposed before it is summed along axis
            1, and may itself be a NestedMap, in which case its first
            flattened entry is used.
          num_hyps_per_beam_override: If > 0, used instead of
            `p.num_hyps_per_beam`; otherwise ignored.
          init_beam_search_state: The `InitBeamSearchState` callback. Please
            refer to the class header comments for more details.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum beam search steps. If None, use
            self.params.target_seq_len.

        Returns:
          A `BeamSearchDecodeOutput`.
        """
        p = self.params
        num_hyps_per_beam = (num_hyps_per_beam_override
                             if num_hyps_per_beam_override > 0 else
                             p.num_hyps_per_beam)
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        # Initial input ids of shape [num_hyps, 1]: taken from the callback's
        # results when it supplies them, otherwise <sos> for every hyp.
        if 'step_ids' not in initial_results:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))
        else:
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])

        # Core beam search state; the per-step buffers are
        # [max_steps, num_hyps, ...].
        min_score = -1e36
        per_step_shape = [max_steps, num_hyps]
        best_scores = tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
        in_scores = tf.zeros(per_step_shape, dtype=p.dtype)
        in_hyps = tf.zeros(per_step_shape, dtype=tf.int32)
        in_prev_hyps = tf.zeros(per_step_shape, dtype=tf.int32)
        in_done_hyps = tf.zeros(per_step_shape, dtype=tf.string)
        atten_len = tf.shape(initial_results.atten_probs)[1]
        bs_atten_probs = tf.zeros([max_steps, num_hyps, atten_len],
                                  dtype=p.dtype)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                          in_prev_hyps, in_done_hyps, bs_atten_probs)

        def _KeepGoing(cur_step, all_done, unused_step_ids,
                       unused_core_bs_states, unused_other_states_list):
            # Stop at the step budget or once every beam is finished.
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def _OneStep(cur_step, unused_all_done, step_ids, core_bs_states,
                     other_states_list):
            # The loop carries `other_states` flattened; repack it for the
            # step and flatten the result again.
            step_out = self._BeamSearchStep(
                theta, encoder_outputs, cur_step, step_ids, core_bs_states,
                other_states.Pack(other_states_list), num_hyps_per_beam,
                pre_beam_search_step_callback, post_beam_search_step_callback)
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = step_out
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        loop_invariants = (tf.TensorShape(cur_step.get_shape()),
                           tf.TensorShape(all_done.get_shape()),
                           tf.TensorShape(step_ids.get_shape()),
                           _GetShapes(core_bs_states),
                           _GetShapes(flat_other_states, none_shapes=True))
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            _KeepGoing,
            _OneStep,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=loop_invariants)
        # [target_seq_len, num_beams * num_hyps_per_beam].
        final_done_hyps = final_bs_states[5]
        final_other_states = other_states.Pack(flat_final_other_states)

        # Prefer `seq_lengths` when the encoder outputs carry it. Otherwise
        # derive the lengths from `padding`, assumed by default to have shape
        # [source_max_lengths, source_batch_size].
        encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
        if encoded_seq_lengths is None:
            source_paddings = encoder_outputs.padding
            if isinstance(source_paddings, py_utils.NestedMap):
                source_paddings = source_paddings.Flatten()[0]
            num_valid = tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)
            encoded_seq_lengths = tf.cast(tf.round(num_valid), tf.int32)

        # [num_beams, num_hyps_per_beam].
        topk_hyps = ops.top_k_terminated_hyps(
            final_done_hyps,
            encoded_seq_lengths,
            k=num_hyps_per_beam,
            num_hyps_per_beam=num_hyps_per_beam,
            length_normalization=p.length_normalization,
            coverage_penalty=p.coverage_penalty,
            target_seq_length_ratio=p.target_seq_length_ratio,
            eoc_id=p.target_eoc_id,
            merge_paths=p.merge_paths)
        # [num_beams * num_hyps_per_beam, ...]. A tensor-valued `max_steps` is
        # passed to `unpack_hyp` as 0.
        if isinstance(max_steps, tf.Tensor):
            max_seq_length = 0
        else:
            max_seq_length = max_steps
        flat_topk_hyps = tf.reshape(topk_hyps, [-1])
        topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
            flat_topk_hyps, max_seq_length=max_seq_length)
        # [num_beams, num_hyps_per_beam].
        topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

        return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                      topk_lens, topk_scores, None,
                                      final_other_states)
Beispiel #7
0
  def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
    """Encodes a batch of strings into fixed-length id and padding tensors.

    Each string is encoded on its own with the WPM encoder, optionally gets
    the end-of-sentence id appended, and is then padded or trimmed to
    `max_length`.

    Args:
      strs: a string Tensor.
      max_length: a python integer. The second dimension of the returned arrays.
        All sequences are padded or truncated to that length.
      append_eos: a python bool. See `BaseTokenizer` for explanation. When
        None, `p.append_eos` is used.
      languages: A vector of strings with the same length as `strs`. Not read
        by this implementation.

    Returns:
      A tuple of 3 tensors:

      - token_ids: a tensor of sequences of WPM ids starting with SOS. Sequences
        always end with EOS unless the sequence exceeds the maximum length.
        Always padded with EOS.
      - target_ids: a tensor of sequences of WPM ids not starting with SOS
        but ending with EOS. Always padded with EOS.
      - paddings: a tensor of floats indicating, at each position, whether
        the corresponding position is padded.

      Unless `p.pad_to_max_length` is set, all three are cut along the second
      dimension to the longest non-padded length in the batch.
    """
    p = self.params
    if append_eos is None:
      append_eos = p.append_eos

    batch_size = py_utils.GetShape(strs)[0]
    token_ids_ta = tf.TensorArray(tf.int32, batch_size)
    target_ids_ta = tf.TensorArray(tf.int32, batch_size)
    paddings_ta = tf.TensorArray(tf.float32, batch_size)

    def _HasMoreRows(row, *unused_loop_vars):
      return row < batch_size

    def _EncodeRow(row, strs, sos_rows, eos_rows, pad_rows):
      """Encodes sentence `row` and writes its three output rows."""
      ids, _ = self._wpm_encoder.Encode(strs[row])

      if append_eos:
        ids = tf.concat([ids, [self.eos_id]], axis=0)

      # Trimming happens after the eos is added, so a sentence that is too
      # long loses its </s>.
      with_sos = tf.concat([[self.sos_id], ids], axis=0)
      sos_rows = sos_rows.write(
          row, py_utils.PadOrTrimTo(with_sos, [max_length], self.eos_id))
      eos_rows = eos_rows.write(
          row, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
      # 0 over the encoded ids, 1 over everything added by the padding.
      not_padded = tf.zeros_like(ids, dtype=tf.float32)
      pad_rows = pad_rows.write(
          row, py_utils.PadOrTrimTo(not_padded, [max_length], 1.))

      return row + 1, strs, sos_rows, eos_rows, pad_rows

    _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
        _HasMoreRows,
        _EncodeRow,
        loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta, target_ids_ta,
                   paddings_ta),
        parallel_iterations=30,
        back_prop=False)

    token_ids = token_ids_ta.stack()
    target_ids = target_ids_ta.stack()
    paddings = paddings_ta.stack()

    if p.pad_to_max_length:
      return token_ids, target_ids, paddings

    # Cut the batch down to its longest non-padded sequence.
    valid_per_row = tf.reduce_sum(1.0 - paddings, axis=1)
    maxlen = tf.cast(tf.round(tf.reduce_max(valid_per_row)), tf.int32)
    return token_ids[:, :maxlen], target_ids[:, :maxlen], paddings[:, :maxlen]
Beispiel #8
0
    def _BeamSearchDecodeIds(self,
                             theta,
                             encoder_outputs,
                             num_hyps_per_beam,
                             init_beam_search_state=None,
                             pre_beam_search_step_callback=None,
                             post_beam_search_step_callback=None,
                             max_steps=None):
        """Performs beam-search based decoding.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap computed by encoder. Its `padding` field
            is read to derive the source sequence lengths.
          num_hyps_per_beam: Number of hyps per beam.
          init_beam_search_state: The InitBeamSearchState callback. Please
            refer to the class header comments for more details.
          pre_beam_search_step_callback: The PreBeamSearchStepCallback
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The PostBeamSearchStepCallback
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum beam search steps. If None, use
            self.params.target_seq_len.

        Returns:
          A flat tuple of tensors, in this order:

          - Three shapes: the shape of `source_seq_lengths`, the shape of
            `hyps` and the shape of `atten_probs`.
          - The eight tensors below, each reshaped to rank 1 (`[-1]`); the
            shapes given are those before the reshape:

            - hyps: A tensor of shape [time, b * k] with ids of the token
              selected.
            - prev_hyps: A tensor of shape [time, b * k] with index to the
              previous hyps which was selected.
            - done_hyps: A boolean tensor of shape [time, b * k] where value
              indicates if hyps was terminated.
            - scores: A tensor of shape [time, b * k] with scores of the token
              selected.
            - atten_probs: A tensor of shape [time, b * k, seq_len] which
              contain the attention probabilities over the source words
              against word in the previous hyps.
            - eos_scores: A tensor of shape [time, b * k] with scores of the
              eos token selected.
            - eos_atten_probs: A tensor of shape [time, b * k, seq_len] which
              contain the attention probabilities over the source words
              against word in the previous hyps.
            - source_seq_lengths: The source sequence lengths, computed from
              the transposed source paddings.

          - flat_final_other_states: the tensors that are part of other
            states, flattened.
        """
        p = self.params
        # Honour the documented default. Without it a None `max_steps` would
        # reach the TensorArray sizes and the loop conditions below.
        if max_steps is None:
            max_steps = p.target_seq_len
        source_paddings = encoder_outputs.padding

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        # We cache the NestedMap as member variable so that we can use it to
        # pack the final outputs. Tpu rewrite methods forces us to strictly pass
        # in Tensors, and output Tensors
        self._other_states = other_states

        step_ids = tf.fill([num_hyps, 1],
                           tf.constant(p.target_sos_id, dtype=tf.int32))
        min_score = -1e36
        fprop_dtype = py_utils.FPropDtype(p)
        best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) +
                       min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
        histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
        in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        # Done flags are accumulated as int32 and cast to bool after stacking.
        in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        # States for beam search that are inputs into Beam search step.
        accum_bs_states = [best_scores, cumulative_scores, histories]
        # States that are not accumulators.
        non_accum_bs_states = [
            in_scores,
            in_hyps,
            in_prev_hyps,
            in_done_hyps,
            in_atten_probs,
            in_eos_scores,
            in_eos_atten_probs,
        ]
        # Indices 0-2 are the accumulators, 3-9 the per-step TensorArrays.
        core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

        flat_other_states = other_states.Flatten()

        # If there is an optimized implementation for short sequence, LoopBodyShort
        # will run first for short_seq_limit steps (after which the
        # LoopBodyShort does not have performance benefit). Then LoopBodyLong (the
        # default implementation) is used to continue the rest of the steps. For
        # decoders which do not have the short sequence specific implementation,
        # only the LoopBodyLong (the default implementation) will run.

        if p.short_seq_limit > 0:

            def LoopContinueShort(cur_step, all_done, unused_step_ids,
                                  unused_core_bs_states,
                                  unused_other_states_list):
                """Use short_seq optimization when cur_step is smaller than limit."""
                return tf.math.logical_and(cur_step < p.short_seq_limit,
                                           tf.math.logical_not(all_done))

            def LoopBodyShort(cur_step, unused_all_done, step_ids,
                              core_bs_states, other_states_list):
                """Loop body of short_seq optimization.

                Instead of doing computation for the entire padded sequence,
                while loop with early exit is used within each _BeamSearchStep
                to do computation for only the actual sequence
                (seq_length <= cur_step). use_short_seq_opt is used as the flag
                to pass this information down to the decoder implementation.

                Args:
                  cur_step: A scalar int tensor, the current time step, 0-based.
                  unused_all_done: A tf.bool, indicating whether the decoding
                    finishes.
                  step_ids: An int32 tensor of shape [num_hyps, 1]. The input
                    ids to the current search step.
                  core_bs_states: A tuple of core beam search states.
                  other_states_list: A flattened NestedMap of other beam search
                    states.

                Returns:
                  The updated input tuple, with the same shape.
                """
                (cur_step, all_done, new_step_ids, new_bs_states,
                 new_other_states) = self._BeamSearchStep(
                     theta,
                     encoder_outputs,
                     cur_step,
                     step_ids,
                     core_bs_states,
                     other_states.Pack(other_states_list),
                     num_hyps_per_beam,
                     pre_beam_search_step_callback,
                     post_beam_search_step_callback,
                     use_short_seq_opt=True)
                return (cur_step, all_done, new_step_ids, new_bs_states,
                        new_other_states.Flatten())

            # The short loop's final state seeds the long loop below.
            (cur_step, all_done, step_ids, core_bs_states,
             flat_other_states) = tf.while_loop(
                 LoopContinueShort,
                 LoopBodyShort,
                 loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                            flat_other_states),
                 parallel_iterations=10,
                 back_prop=False,
                 swap_memory=False,
                 shape_invariants=(
                     tf.TensorShape(cur_step.get_shape()),
                     tf.TensorShape(all_done.get_shape()),
                     tf.TensorShape(step_ids.get_shape()),
                     tuple(
                         list(_GetShapes(accum_bs_states)) +
                         list(_GetShapes(non_accum_bs_states,
                                         none_shapes=True))),
                     _GetShapes(flat_other_states, none_shapes=True)),
                 maximum_iterations=max_steps)

        def LoopContinueLong(cur_step, all_done, unused_step_ids,
                             unused_core_bs_states, unused_other_states_list):
            """Continue default implementation until decoding finishes."""
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                         other_states_list):
            """Loop body of default long_seq implementation."""
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta,
                 encoder_outputs,
                 cur_step,
                 step_ids,
                 core_bs_states,
                 other_states.Pack(other_states_list),
                 num_hyps_per_beam,
                 pre_beam_search_step_callback,
                 post_beam_search_step_callback,
                 use_short_seq_opt=False)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinueLong,
            LoopBodyLong,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(
                tf.TensorShape(cur_step.get_shape()),
                tf.TensorShape(all_done.get_shape()),
                tf.TensorShape(step_ids.get_shape()),
                tuple(
                    list(_GetShapes(accum_bs_states)) +
                    list(_GetShapes(non_accum_bs_states, none_shapes=True))),
                _GetShapes(flat_other_states, none_shapes=False)),
            maximum_iterations=max_steps)

        # The paddings may be wrapped in a NestedMap, in which case the first
        # flattened entry is used. They are transposed before the per-row sum.
        if isinstance(source_paddings, py_utils.NestedMap):
            source_paddings = source_paddings.Flatten()[0]
        source_seq_lengths = tf.cast(tf.round(
            tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
                                     dtype=tf.int32)

        # Concatenate all outputs on axis=0.
        scores = final_bs_states[3].stack()
        hyps = final_bs_states[4].stack()
        prev_hyps = final_bs_states[5].stack()
        done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
        atten_probs = final_bs_states[7].stack()
        eos_scores = final_bs_states[8].stack()
        eos_atten_probs = final_bs_states[9].stack()
        rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
                eos_atten_probs, source_seq_lengths)

        # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
        # b/111131551 is resolved.
        # Canonical shapes for tensors of various ranks.
        r_shapes = [
            py_utils.GetShape(source_seq_lengths),
            py_utils.GetShape(hyps),
            py_utils.GetShape(atten_probs)
        ]
        # Reshape all tensors to [-1] to avoid cost of copy due to padding.
        rets_r1 = [tf.reshape(r, [-1]) for r in rets]

        return tuple(r_shapes) + tuple(rets_r1) + tuple(
            flat_final_other_states)