コード例 #1
0
  def __init__(self, dtype, shape, send_device, recv_device, name=None):
    """Builds a point-to-point tensor channel between two devices.

    Args:
      dtype: dtype of every tensor that travels over the channel.
      shape: shape of every tensor that travels over the channel. Required;
        when both endpoints are TPU cores it must be fully defined.
      send_device: fully-specified TensorFlow device of the sender.
      recv_device: fully-specified TensorFlow device of the receiver.
      name: optional base name; "channel" is used when it is empty or None.
        The final name is made unique within the default graph.
    """
    graph = tf.get_default_graph()
    assert graph, "A channel is scoped within a tf.Graph"
    self._dtype = dtype
    self._send_device = send_device
    self._recv_device = recv_device
    base_name = name or "channel"
    self._name = graph.unique_name(base_name)

    assert shape is not None
    self._shape = tf.TensorShape(shape)
    send_core = _TpuCore(send_device)
    recv_core = _TpuCore(recv_device)
    self._send_tpu_core = send_core
    self._recv_tpu_core = recv_core
    self._send_called = False
    self._recv_op = None

    # Both endpoints must agree on being TPU or not; -1 marks a non-TPU device.
    send_is_tpu = send_core != -1
    recv_is_tpu = recv_core != -1
    assert send_is_tpu == recv_is_tpu, (
        "Mixing TPU and non-TPU: %s and %s" % (send_device, recv_device))
    if send_core >= 0:
      # TPU-only constraints: static shape and distinct cores.
      assert self._shape.is_fully_defined(), (
          "TPU channel must have fully defined shape. Name: %s, shape: %s" %
          (self._name, self._shape))
      assert send_core != recv_core, (
          "TPU send/recv must be cross-core: %s and %s" %
          (send_device, recv_device))
コード例 #2
0
  def FProp(self, theta, *args):
    """Runs p.repeat copies of self.body.FProp independently.

    Args:
      theta: Layer model parameters. The shape of each variable in theta is
        always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of the
        i-th copy of self.body.
      *args: Input arguments. The shape of each tensor in args is always
        [p.repeat, ....]. And the list [arg[i] for arg in args] becomes inputs
        to the i-th copy of self.body.FProp. None entries are passed through.

    Returns:
      The accumulated outputs of self.body.FProp. Each output tensor t has
      shape [p.repeat, ...] and the tuple (t[i] for t in outputs) is the return
      tuple of the i-th self.body.FProp. A single tensor is returned when the
      body produces exactly one output; otherwise a tuple. The number of
      outputs need not equal the number of inputs.
    """
    p = self.params
    # Keep the tensors returned by py_utils.HasShape: discarding the return
    # value would leave whatever check HasShape attaches to it out of the
    # computation.
    args = tuple(
        arg if arg is None else py_utils.HasShape(arg, [p.repeat], ndims=1)
        for arg in args)
    # Used to pick the dtype of the zero initial state below.
    non_none_args = [arg for arg in args if arg is not None]

    theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)
    inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
    # Infer out_shapes from FPropMeta.
    out_shapes = self._InferOutShapes(args)

    def _CellFn(unused_theta, unused_state0, inputs):
      """Recurrent cell function wrapper of body.FProp."""
      # Sets shapes for both theta and inputs to self.body.FProp, dropping the
      # leading [p.repeat] dimension of the stacked sources.
      for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                          list(args) + theta_stack.Flatten()):
        if src is not None:
          dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

      # Runs the actual body.FProp
      fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
      fprop_outputs = _ToTuple(fprop_outputs)
      assert len(fprop_outputs) == len(out_shapes)
      # Passes fprop outputs to the next layer through state.
      state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
      return state1, py_utils.NestedMap()

    with tf.name_scope(p.name):
      # Initiate state0 with inferred output shapes. The dtype comes from the
      # first non-None input (args[0] itself may be None).
      # NOTE(review): assumes the body's outputs share that input dtype.
      state0 = py_utils.NestedMap(outputs=[
          tf.zeros(shape, non_none_args[0].dtype) for shape in out_shapes
      ])
      # Runs body.FProp p.repeat times using Recurrent.
      acc_states, _ = recurrent.Recurrent(
          theta=py_utils.NestedMap(),
          state0=state0,
          inputs=inputs,
          cell_fn=_CellFn)

      # Retrieves fprop outputs from state1 and sets shapes.
      output_tensors = tuple(acc_states.outputs)
      for out_tensor, out_shape in zip(output_tensors, out_shapes):
        out_tensor.set_shape(tf.TensorShape([p.repeat] + out_shape.as_list()))

      # Decide on the number of OUTPUTS (not inputs): a single output is
      # returned bare, several are returned as a tuple.
      return output_tensors[0] if len(output_tensors) == 1 else output_tensors
コード例 #3
0
def _GetShapes(tensors, none_shapes=False):
    """Util for getting nested structure of shapes from structure of tensors.

    Args:
      tensors: Structure of Tensors to get shapes for. Leaves that are not
        `tf.Tensor` (e.g. None) map to an unknown-rank `TensorShape`.
      none_shapes: If true, every known-rank shape is replaced by a shape of
        the same rank whose dimensions are all None; unknown-rank shapes stay
        unknown-rank.

    Returns:
      The same structure as tensors but of corresponding `TensorShape` objects.
    """
    shapes = []
    for t in tf.nest.flatten(tensors):
        shape = t.get_shape() if isinstance(t, tf.Tensor) else None
        if shape is None or shape.ndims is None:
            # Non-tensor leaf, or a tensor whose rank is unknown.
            shapes.append(tf.TensorShape(None))
        elif none_shapes:
            shapes.append(tf.TensorShape([None] * shape.ndims))
        else:
            shapes.append(tf.TensorShape(shape))

    # pack_sequence_as already rebuilds the input's container type; wrapping
    # it again in type(tensors)(...) breaks namedtuples and non-nested inputs.
    return tf.nest.pack_sequence_as(tensors, shapes)
コード例 #4
0
    def _CellFn(unused_theta, unused_state0, inputs):
      """Wraps self.body.FProp as a Recurrent cell function.

      Each step receives one slice of the stacked inputs and theta. Static
      shapes are restored from the stacked sources (first dimension dropped)
      before the body runs, and the body's outputs are handed on as state1.
      """
      step_tensors = inputs.args + inputs.theta.Flatten()
      stacked_tensors = list(args) + theta_stack.Flatten()
      for step_tensor, stacked in zip(step_tensors, stacked_tensors):
        if stacked is None:
          continue
        step_tensor.set_shape(tf.TensorShape(stacked.shape.as_list()[1:]))

      outputs = _ToTuple(self.body.FProp(inputs.theta, *inputs.args))
      assert len(outputs) == len(out_shapes)
      # The outputs travel to the next layer through the recurrent state.
      return py_utils.NestedMap(outputs=list(outputs)), py_utils.NestedMap()
コード例 #5
0
 def _AssignVar(self, var_op):
     """Places a variable on the least-loaded ps device and charges its size.

     Args:
       var_op: The variable op; its 'dtype' and 'shape' attrs are read.

     Returns:
       The device that had the smallest (allocated, device) entry in
       self._var_space_pq. Its running total is increased by the variable's
       byte size, or by a fixed estimate when the shape is not static.
     """
     size = var_op.get_attr('dtype').size
     shape = tf.TensorShape(var_op.get_attr('shape'))
     assert self._var_space_pq, ('No ps devices to use.')
     num_elements = shape.num_elements()
     if num_elements is None:
         assert var_op.name.endswith(
             'wb/var'), 'Unexpected name pattern: %s' % var_op.name
         # CuDNN RNN vars shape aren't known statically, decide to make a constant
         # estimate to avoid introducing more complexities.
         num_elements = 10 * 1024**2
     # Pop only after all checks passed, so a failed assertion cannot drop a
     # device from the priority queue.
     allocated, device = heapq.heappop(self._var_space_pq)
     allocated += num_elements * size
     heapq.heappush(self._var_space_pq, (allocated, device))
     tf.logging.info('Place variable %s on %s %d', var_op.name, device,
                     allocated)
     return device
コード例 #6
0
    def _CellFn(unused_theta, state0, theta_i):
      """Runs one repetition of self.body.FProp as a Recurrent cell.

      The previous repetition's outputs arrive through state0 and become the
      positional inputs of the body; theta_i is this step's slice of the
      stacked theta. The body must return as many outputs as it receives
      inputs, because they form the next step's state.
      """
      fprop_inputs = _StateToArgs(state0)

      # Restore static shapes on this step's theta from the stacked theta,
      # dropping the leading (stacking) dimension.
      per_step_vars = theta_i.Flatten()
      stacked_vars = theta_stack.Flatten()
      for per_step, stacked in zip(per_step_vars, stacked_vars):
        if stacked is not None:
          per_step.set_shape(tf.TensorShape(stacked.shape.as_list()[1:]))

      fprop_outputs = _ToTuple(self.body.FProp(theta_i, *fprop_inputs))
      assert len(fprop_outputs) == len(fprop_inputs)

      # The outputs are fed to the next repetition through the state.
      return _ArgsToState(fprop_outputs), py_utils.NestedMap()
コード例 #7
0
    def _get_input_shapes(self, *args):
        """Computes the per-micro-batch shapes of the FProp inputs.

        The mini-batch size is read at `p.batch_dim` from the first input that
        is not None. If `p.micro_batch_size` is unset, the micro-batch size is
        derived as mini_batch_size // p.num_micro_batches, after clamping
        `p.num_micro_batches` (in place, on self.params) to at most the
        mini-batch size.

        Args:
          *args: The FProp inputs. When `p.nested_map_fprop` is set, exactly one
            `py_utils.NestedMap` whose flattened leaves are the inputs.

        Returns:
          For each input, its `tf.TensorShape` with the batch dimension replaced
          by the micro-batch size (None for None inputs). A tuple, or packed
          back into the NestedMap structure when `p.nested_map_fprop` is set.

        Raises:
          ValueError: If micro_batch_size * num_micro_batches differs from the
            mini-batch size.
        """
        p = self.params
        if p.nested_map_fprop:
            assert len(args) == 1
            assert isinstance(args[0], py_utils.NestedMap)
            input_tensors = py_utils.Flatten(args[0])
        else:
            input_tensors = _ToTuple(args)

        # Get batch size from the first tensor which is not None.
        mini_batch_size = None
        for input_tensor in input_tensors:
            if input_tensor is not None:
                mini_batch_size = input_tensor.get_shape().as_list()[
                    p.batch_dim]
                break
        assert mini_batch_size is not None

        micro_batch_size = p.micro_batch_size
        if not micro_batch_size:
            # Never ask for more micro-batches than there are examples.
            if p.num_micro_batches > mini_batch_size:
                p.num_micro_batches = mini_batch_size
            micro_batch_size = mini_batch_size // p.num_micro_batches
        if micro_batch_size * p.num_micro_batches != mini_batch_size:
            raise ValueError(
                'micro_batch_size * num_micro_batches != batch_size.')

        input_shapes = []
        for input_tensor in input_tensors:
            if input_tensor is None:
                input_shapes.append(None)
                continue
            input_shape = input_tensor.get_shape().as_list()
            input_shape[p.batch_dim] = micro_batch_size
            input_shapes.append(tf.TensorShape(input_shape))
        input_shapes = tuple(input_shapes)

        if p.nested_map_fprop:
            input_shapes = py_utils.Pack(args[0], input_shapes)
        return input_shapes
コード例 #8
0
    def GreedySearchDecode(self,
                           theta,
                           encoder_outputs,
                           init_beam_search_state=None,
                           pre_beam_search_step_callback=None,
                           post_beam_search_step_callback=None,
                           max_steps=None):
        """Performs greedy-search based decoding (one hyp per beam).

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap containing encoder outputs to be passed
            to the callbacks.
          init_beam_search_state: The `InitBeamSearchState` callback; it is
            called with num_hyps_per_beam=1. Please refer to the class header
            comments for more details.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum number of decoding steps. If None, use
            self.params.target_seq_len.

        Returns:
          A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is same as
          src_batch_size.

            - hyp_ids: [num_hyps, max_steps]. Hyps end with <eos> token if the
              <eos> token is encountered during search.
            - hyp_lens: [num_hyps].
            - done_hyps: [num_hyps], whether or not an eos is encountered.
        """
        p = self.params
        steps_limit = p.target_seq_len if max_steps is None else max_steps

        # Greedy decoding keeps exactly one hypothesis per beam.
        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, 1)
        num_hyps = tf.shape(initial_results.log_probs)[0]

        if 'step_ids' not in initial_results:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))
        else:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])

        cur_step = tf.constant(0, dtype=tf.int32)
        done_hyps = inplace_ops.empty(
            shape=[num_hyps], dtype=tf.bool, init=True, name='done_hyps')
        hyp_lens = inplace_ops.empty(
            shape=[num_hyps], dtype=tf.int32, init=True, name='hyp_lens')
        hyp_ids = inplace_ops.empty(
            shape=[steps_limit, num_hyps], dtype=tf.int32, init=True,
            name='hyp_ids')

        def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                         unused_hyp_lens, done_hyps, unused_other_states_list):
            # Keep going while under the step limit and some hyp is unfinished.
            within_limit = cur_step < steps_limit
            return tf.math.logical_and(
                within_limit, tf.math.logical_not(tf.reduce_all(done_hyps)))

        def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                     other_states_list):
            (next_step, next_step_ids, next_hyp_ids, next_hyp_lens,
             next_done_hyps, next_other_states) = self._GreedySearchStep(
                 theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
                 done_hyps, other_states.Pack(other_states_list),
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            # The while_loop carries the other states in flattened form.
            return (next_step, next_step_ids, next_hyp_ids, next_hyp_lens,
                    next_done_hyps, next_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        invariants = (tf.TensorShape(cur_step.get_shape()),
                      tf.TensorShape(step_ids.get_shape()),
                      tf.TensorShape(hyp_ids.get_shape()),
                      tf.TensorShape(hyp_lens.get_shape()),
                      tf.TensorShape(done_hyps.get_shape()),
                      _GetShapes(flat_other_states, none_shapes=True))
        (_, _, final_hyp_ids, final_hyp_lens, final_done_hyps,
         _) = tf.while_loop(
             LoopContinue,
             LoopBody,
             loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                        flat_other_states),
             parallel_iterations=10,
             back_prop=False,
             swap_memory=False,
             shape_invariants=invariants)

        # hyp_ids is accumulated as [max_steps, num_hyps]; transpose it so the
        # layout matches BeamSearchDecode's output.
        return tf.transpose(final_hyp_ids), final_hyp_lens, final_done_hyps
コード例 #9
0
    def BeamSearchDecode(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam_override=0,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
        """Performs beam-search based decoding.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap containing encoder outputs to be passed
            to the callbacks. Mostly opaque to BeamSearchHelper, except that it
            should contain either a 'seq_lengths' field of shape
            [source_batch_size] or a 'padding' field of shape
            [source_max_lengths, source_batch_size]. 'padding' may also be a
            NestedMap, in which case its first flattened tensor is used.
          num_hyps_per_beam_override: If set to a value <= 0, this parameter is
            ignored. If set to a value > 0, then this value will be used to
            override `p.num_hyps_per_beam`.
          init_beam_search_state: The `InitBeamSearchState` callback. Please
            refer to the class header comments for more details.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback`
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum beam search steps. If None, use
            self.params.target_seq_len.

        Returns:
          A `BeamSearchDecodeOutput`, built positionally from (final_done_hyps,
          topk_hyps, topk_ids, topk_lens, topk_scores, None,
          final_other_states).
        """
        p = self.params
        num_hyps_per_beam = p.num_hyps_per_beam
        if num_hyps_per_beam_override > 0:
            num_hyps_per_beam = num_hyps_per_beam_override
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        min_score = -1e36
        best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
        in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
        in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
        bs_atten_probs = tf.zeros(
            [max_steps, num_hyps,
             tf.shape(initial_results.atten_probs)[1]],
            dtype=p.dtype)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        # Index 5 (in_done_hyps) is read back after the loop.
        core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                          in_prev_hyps, in_done_hyps, bs_atten_probs)

        def LoopContinue(cur_step, all_done, unused_step_ids,
                         unused_core_bs_states, unused_other_states_list):
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                     other_states_list):
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta, encoder_outputs, cur_step, step_ids, core_bs_states,
                 other_states.Pack(other_states_list), num_hyps_per_beam,
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(all_done.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              _GetShapes(core_bs_states),
                              _GetShapes(flat_other_states, none_shapes=True)))
        # [target_seq_len, num_beams * num_hyps_per_beam].
        final_done_hyps = final_bs_states[5]
        final_other_states = other_states.Pack(flat_final_other_states)

        # Assume that `padding` has shape [source_max_lengths, source_batch_size]
        # by default, and compute `encoded_seq_lengths` accordingly. This can be
        # overridden by directly passing `seq_lengths` in the `encoder_outputs`
        # NestedMap.
        encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
        if encoded_seq_lengths is None:
            source_paddings = encoder_outputs.padding
            if isinstance(source_paddings, py_utils.NestedMap):
                # Only the first flattened padding tensor is used.
                source_paddings = source_paddings.Flatten()[0]
            encoded_seq_lengths = tf.cast(
                tf.round(
                    tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
                tf.int32)

        # [num_beams, num_hyps_per_beam].
        topk_hyps = ops.top_k_terminated_hyps(
            final_done_hyps,
            encoded_seq_lengths,
            k=num_hyps_per_beam,
            num_hyps_per_beam=num_hyps_per_beam,
            length_normalization=p.length_normalization,
            coverage_penalty=p.coverage_penalty,
            target_seq_length_ratio=p.target_seq_length_ratio,
            eoc_id=p.target_eoc_id,
            merge_paths=p.merge_paths)
        # [num_beams * num_hyps_per_beam, ...].
        # A dynamic (Tensor) max_steps cannot be passed as a static length.
        max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
        topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
            tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
        # [num_beams, num_hyps_per_beam].
        topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

        return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                      topk_lens, topk_scores, None,
                                      final_other_states)
コード例 #10
0
    def _BeamSearchDecodeIds(self,
                             theta,
                             encoder_outputs,
                             num_hyps_per_beam,
                             init_beam_search_state=None,
                             pre_beam_search_step_callback=None,
                             post_beam_search_step_callback=None,
                             max_steps=None):
        """Performs beam-search based decoding, returning flattened tensors.

        Args:
          theta: A NestedMap object containing weights' values of the decoder
            layer and its children layers.
          encoder_outputs: A NestedMap computed by encoder. Its `padding` field
            (or, if that is a NestedMap, its first flattened tensor) is used to
            compute the source sequence lengths.
          num_hyps_per_beam: Number of hyps per beam.
          init_beam_search_state: The InitBeamSearchState callback. Please refer
            to the class header comments for more details.
          pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
            Please refer to the class header comments for more details.
          post_beam_search_step_callback: The PostBeamSearchStepCallback
            callback. Please refer to the class header comments for more
            details.
          max_steps: maximum beam search steps. If None, use
            self.params.target_seq_len.

        Returns:
          A flat tuple made of, in order:

          - Three shapes (from `py_utils.GetShape`): those of
            source_seq_lengths, hyps and atten_probs, i.e. one rank-1, one
            rank-2 and one rank-3 shape. Presumably used by the caller to
            restore the flattened tensors that follow.
          - These eight tensors, each reshaped to rank 1 ([-1]):
            hyps: [time, b * k] ids of the token selected.
            prev_hyps: [time, b * k] index to the previous hyps which was
              selected.
            done_hyps: boolean [time, b * k], whether the hyp was terminated.
            scores: [time, b * k] scores of the token selected.
            atten_probs: [time, b * k, seq_len] attention probabilities over
              the source words against word in the previous hyps.
            eos_scores: [time, b * k] scores of the eos token selected.
            eos_atten_probs: [time, b * k, seq_len] attention probabilities
              over the source words against word in the previous hyps.
            source_seq_lengths: [source_batch_size] source sequence lengths,
              assuming `padding` is [source_max_lengths, source_batch_size].
          - flat_final_other_states: the flattened final other states.
        """
        p = self.params
        if max_steps is None:
            # Documented default; without it the TensorArray sizes and the
            # loop bounds below would be None.
            max_steps = p.target_seq_len
        source_paddings = encoder_outputs.padding

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        # We cache the NestedMap as member variable so that we can use it to
        # pack the final outputs. Tpu rewrite methods forces us to strictly pass
        # in Tensors, and output Tensors
        self._other_states = other_states

        step_ids = tf.fill([num_hyps, 1],
                           tf.constant(p.target_sos_id, dtype=tf.int32))
        min_score = -1e36
        fprop_dtype = py_utils.FPropDtype(p)
        best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) +
                       min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
        histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
        in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        # States for beam search that are inputs into Beam search step.
        accum_bs_states = [best_scores, cumulative_scores, histories]
        # States that are not accumulators.
        non_accum_bs_states = [
            in_scores,
            in_hyps,
            in_prev_hyps,
            in_done_hyps,
            in_atten_probs,
            in_eos_scores,
            in_eos_atten_probs,
        ]
        # Indices 3..9 of core_bs_states are the TensorArrays above, in order.
        core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

        flat_other_states = other_states.Flatten()

        # If there is an optimized implementation for short sequence, LoopBodyShort
        # will run first for short_seq_limit steps (after which the
        # LoopBodyShort does not have performance benefit). Then LoopBodyLong (the
        # default implementation) is used to continue the rest of the steps. For
        # decoders which do not have the short sequence specific implementation,
        # only the LoopBodyLong (the default implementation) will run.

        if p.short_seq_limit > 0:

            def LoopContinueShort(cur_step, all_done, unused_step_ids,
                                  unused_core_bs_states,
                                  unused_other_states_list):
                """Use short_seq optimization when cur_step is smaller than limit."""
                return tf.math.logical_and(cur_step < p.short_seq_limit,
                                           tf.math.logical_not(all_done))

            def LoopBodyShort(cur_step, unused_all_done, step_ids,
                              core_bs_states, other_states_list):
                """Loop body of short_seq optimization.

                Instead of doing computation for the entire padded sequence,
                while loop with early exit is used within each _BeamSearchStep
                to do computation for only the actual sequence
                (seq_length <= cur_step). use_short_seq_opt is used as the flag
                to pass this information down to the decoder implementation.

                Args:
                  cur_step: A scalar int tensor, the current time step, 0-based.
                  unused_all_done: A tf.bool, indicating whether the decoding
                    finishes.
                  step_ids: An int32 tensor of shape [num_hyps, 1]. The input
                    ids to the current search step.
                  core_bs_states: A tuple of core beam search states.
                  other_states_list: A flattened NestedMap of other beam search
                    states.

                Returns:
                  The updated input tuple, with the same shape.
                """
                (cur_step, all_done, new_step_ids, new_bs_states,
                 new_other_states) = self._BeamSearchStep(
                     theta,
                     encoder_outputs,
                     cur_step,
                     step_ids,
                     core_bs_states,
                     other_states.Pack(other_states_list),
                     num_hyps_per_beam,
                     pre_beam_search_step_callback,
                     post_beam_search_step_callback,
                     use_short_seq_opt=True)
                return (cur_step, all_done, new_step_ids, new_bs_states,
                        new_other_states.Flatten())

            (cur_step, all_done, step_ids, core_bs_states,
             flat_other_states) = tf.while_loop(
                 LoopContinueShort,
                 LoopBodyShort,
                 loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                            flat_other_states),
                 parallel_iterations=10,
                 back_prop=False,
                 swap_memory=False,
                 shape_invariants=(
                     tf.TensorShape(cur_step.get_shape()),
                     tf.TensorShape(all_done.get_shape()),
                     tf.TensorShape(step_ids.get_shape()),
                     tuple(
                         list(_GetShapes(accum_bs_states)) +
                         list(_GetShapes(non_accum_bs_states,
                                         none_shapes=True))),
                     _GetShapes(flat_other_states, none_shapes=True)),
                 maximum_iterations=max_steps)

        def LoopContinueLong(cur_step, all_done, unused_step_ids,
                             unused_core_bs_states, unused_other_states_list):
            """Continue default implementation until decoding finishes."""
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                         other_states_list):
            """Loop body of default long_seq implementation."""
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta,
                 encoder_outputs,
                 cur_step,
                 step_ids,
                 core_bs_states,
                 other_states.Pack(other_states_list),
                 num_hyps_per_beam,
                 pre_beam_search_step_callback,
                 post_beam_search_step_callback,
                 use_short_seq_opt=False)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        # NOTE(review): unlike the short loop, other-state shapes are kept
        # as-is here (none_shapes=False); presumably deliberate (static shapes
        # for TPU) — confirm before changing.
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinueLong,
            LoopBodyLong,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(
                tf.TensorShape(cur_step.get_shape()),
                tf.TensorShape(all_done.get_shape()),
                tf.TensorShape(step_ids.get_shape()),
                tuple(
                    list(_GetShapes(accum_bs_states)) +
                    list(_GetShapes(non_accum_bs_states, none_shapes=True))),
                _GetShapes(flat_other_states, none_shapes=False)),
            maximum_iterations=max_steps)

        if isinstance(source_paddings, py_utils.NestedMap):
            # Only the first flattened padding tensor is used.
            source_paddings = source_paddings.Flatten()[0]
        source_seq_lengths = tf.cast(
            tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
            dtype=tf.int32)

        # Concatenate all outputs on axis=0.
        scores = final_bs_states[3].stack()
        hyps = final_bs_states[4].stack()
        prev_hyps = final_bs_states[5].stack()
        done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
        atten_probs = final_bs_states[7].stack()
        eos_scores = final_bs_states[8].stack()
        eos_atten_probs = final_bs_states[9].stack()
        rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
                eos_atten_probs, source_seq_lengths)

        # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
        # b/111131551 is resolved.
        # Canonical shapes for tensors of various ranks.
        r_shapes = [
            py_utils.GetShape(source_seq_lengths),
            py_utils.GetShape(hyps),
            py_utils.GetShape(atten_probs)
        ]
        # Reshape all tensors to [-1] to avoid cost of copy due to padding.
        rets_r1 = [tf.reshape(r, [-1]) for r in rets]

        return tuple(r_shapes) + tuple(rets_r1) + tuple(
            flat_final_other_states)