Example #1
0
def dynamic_decode(decoder,
                   encoder_output,
                   bridge,
                   helper,
                   target_modality,
                   parallel_iterations=32,
                   swap_memory=False,
                   **kwargs):
    """ Performs dynamic decoding with `decoder`.

    Calls `prepare()` once and then `step()` repeatedly (inside a
    `tf.while_loop`) on the `Decoder` object. The loop runs until every
    entry of the `finished` flags is True.

    Args:
        decoder: An instance of `Decoder`.
        encoder_output: An instance of `collections.namedtuple`
          from `Encoder.encode()`.
        bridge: An instance of `Bridge` that initializes the
          decoder states.
        helper: An instance of `Feedback` that samples next
          symbols from logits.
        target_modality: An instance of `Modality`, that deals
          with transformations from symbols to tensors or from
          tensors to symbols (the decoder top and bottom layer).
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        kwargs: Extra options. `beam_size` is required when
          decoder.mode=INFER and is used to stack the initial cache;
          it is not read in any other mode.

    Returns: A tuple `(decoder_output, decoder_status)` for
      decoder.mode=INFER, where `decoder_status` is a dict with the keys
      "beam_ids", "log_probs", "decoding_length" and "hypothesis".
      Only `decoder_output` for every other mode.

    Raises:
        ValueError: if decoder.mode=INFER and `beam_size` is not given
          in `kwargs`.
    """
    is_infer = decoder.mode == ModeKeys.INFER
    if is_infer and "beam_size" not in kwargs:
        # Explicit check rather than `assert`, which is stripped when
        # Python runs with -O.
        raise ValueError("`beam_size` must be provided as a keyword "
                         "argument when decoder.mode is INFER.")

    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        # Dynamically sized array: the number of decoding steps is not
        # known when the graph is built.
        return tf.TensorArray(dtype=d,
                              clear_after_read=False,
                              size=0,
                              dynamic_size=True)

    decoder_output_remover = DecoderOutputRemover(decoder.mode,
                                                  decoder.output_dtype._fields,
                                                  decoder.output_ignore_fields)

    # initialize first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_inputs = _embed_words(target_modality, initial_input_symbols,
                                  initial_time)

    with tf.variable_scope(decoder.name):
        initial_cache = decoder.prepare(encoder_output, bridge,
                                        helper)  # prepare decoder
        if is_infer:
            initial_cache = stack_beam_size(initial_cache,
                                            kwargs["beam_size"])

    initial_outputs_ta = nest.map_structure(
        _create_ta, decoder_output_remover.apply(decoder.output_dtype))

    def body_traininfer(time, inputs, cache, outputs_ta, finished, *args):
        """Internal while_loop body.

        Args:
          time: scalar int32 Tensor.
          inputs: The inputs Tensor.
          cache: The decoder states.
          outputs_ta: structure of TensorArray.
          finished: A bool tensor (keeping track of what's finished).
          args: `(log_probs, lengths, bs_stat_ta, predicted_ids)` for
            mode==INFER, empty otherwise.
        Returns:
          `[time + 1, next_inputs, next_cache, outputs_ta,
          next_finished]`, followed by `[next_log_probs, next_lengths,
          bs_stat_ta, next_predicted_ids]` for mode==INFER.
        """
        with tf.variable_scope(decoder.name):
            outputs, next_cache = decoder.step(inputs, cache)
        outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                        outputs_ta,
                                        decoder_output_remover.apply(outputs))
        # Slots 1, 2 and 4 are filled in at the end of the body.
        inner_loop_vars = [time + 1, None, None, outputs_ta, None]
        # Stays None outside INFER mode; the helper then has to pick the
        # next symbols without sampled ids.
        sample_ids = None
        if is_infer:
            log_probs, lengths, bs_stat_ta, predicted_ids = args[:4]
            logits = _compute_logits(decoder, target_modality, outputs)
            # sample next symbols
            sample_ids, beam_ids, next_log_probs, next_lengths \
                = helper.sample_symbols(logits, log_probs, finished, lengths,
                                        time=time)
            # `predicted_ids` travels through the loop flattened to 1-D;
            # view it as [-1, time + 1] to gather it with `beam_ids`.
            predicted_ids = gather_states(
                tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)

            next_cache["decoding_states"] = gather_states(
                next_cache["decoding_states"], beam_ids)
            bs_stat = BeamSearchStateSpec(log_probs=next_log_probs,
                                          beam_ids=beam_ids)
            bs_stat_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), bs_stat_ta, bs_stat)
            # Append this step's symbols as a new column, then flatten
            # again for the next iteration.
            next_predicted_ids = tf.concat(
                [predicted_ids,
                 tf.expand_dims(sample_ids, axis=1)], axis=1)
            next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
            next_predicted_ids.set_shape([None])
            inner_loop_vars.extend(
                [next_log_probs, next_lengths, bs_stat_ta, next_predicted_ids])

        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = _embed_words(target_modality, next_input_symbols,
                                   time + 1)

        # Once an entry is finished it stays finished.
        next_finished = tf.logical_or(next_finished, finished)
        inner_loop_vars[1] = next_inputs
        inner_loop_vars[2] = next_cache
        inner_loop_vars[4] = next_finished
        return inner_loop_vars

    loop_vars = [
        initial_time, initial_inputs, initial_cache, initial_outputs_ta,
        initial_finished
    ]

    if is_infer:  # add inference-specific parameters
        initial_log_probs = tf.zeros_like(initial_input_symbols,
                                          dtype=tf.float32)
        initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
        initial_bs_stat_ta = nest.map_structure(_create_ta,
                                                BeamSearchStateSpec.dtypes())
        # to process hypothesis
        initial_input_symbols.set_shape([None])
        loop_vars.extend([
            initial_log_probs, initial_lengths, initial_bs_stat_ta,
            initial_input_symbols
        ])

    # Keep looping while at least one entry of `finished` (loop var 4)
    # is still False.
    res = tf.while_loop(lambda *args: tf.logical_not(tf.reduce_all(args[4])),
                        body_traininfer,
                        loop_vars=loop_vars,
                        parallel_iterations=parallel_iterations,
                        swap_memory=swap_memory)

    final_outputs_ta = res[3]
    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

    if is_infer:
        # One column for the initial symbols plus one per executed step.
        timesteps = res[0] + 1
        # The final log_probs loop variable is not needed: the per-step
        # values are already recorded in `bs_stat`.
        _, lengths, bs_stat, predicted_ids = res[-4:]
        final_bs_stat = nest.map_structure(lambda ta: ta.stack(), bs_stat)
        return final_outputs, \
               {"beam_ids": final_bs_stat.beam_ids,
                "log_probs": final_bs_stat.log_probs,
                "decoding_length": lengths,
                # drop the first column (the initial input symbols)
                "hypothesis": tf.reshape(predicted_ids, [-1, timesteps])[:, 1:]}

    return final_outputs
Example #2
0
def dynamic_ensemble_decode(decoders,
                            encoder_outputs,
                            bridges,
                            target_modalities,
                            helper,
                            parallel_iterations=32,
                            swap_memory=False):
    """ Runs ensemble decoding over several `decoders` in one while_loop.

    Every decoder is prepared once via `prepare()`; afterwards `step()` is
    invoked on each of them at every loop iteration, and `helper` samples
    the next symbols from the list of per-decoder logits.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s, one per
          decoder, from the corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        target_modalities: A list of `Modality` instances.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.

    Returns: The results of inference, an instance of `collections.namedtuple`
      whose element types are defined by `BeamSearchStateSpec`, indicating
      the status of beam search. Each field is the stacked content of the
      TensorArray written once per decoding step.
    """
    scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop
    if scope.caching_device is None:
        scope.set_caching_device(lambda op: op.device)

    def _new_tensor_array(dtype):
        # Grows on demand: the number of steps is unknown at graph build.
        return tf.TensorArray(dtype=dtype,
                              size=0,
                              dynamic_size=True,
                              clear_after_read=False)

    def _make_remover(dec):
        return DecoderOutputRemover(dec.mode, dec.output_dtype._fields,
                                    dec.output_ignore_fields)

    def _stack_all(tensor_array):
        return tensor_array.stack()

    def _not_all_finished(*loop_args):
        # Loop variable 4 is the `finished` flag tensor.
        return tf.logical_not(tf.reduce_all(loop_args[4]))

    decoder_output_removers = nest.map_structure(_make_remover, decoders)

    # first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_embeddings = nest.map_structure(
        lambda modality: _embed_words(modality, initial_input_symbols,
                                      initial_time),
        target_modalities)

    # Per-decoder setup; each list below is index-aligned with `decoders`.
    pre_fns = []
    post_fns = []
    initial_inputs = []
    initial_decoder_states = []
    decoding_params = []
    for dec, enc_out, bri, embedded in zip(decoders, encoder_outputs,
                                           bridges, initial_embeddings):
        with tf.variable_scope(dec.name):
            pre_fn, post_fn = dec.inputs_prepost_processing_fn()
            first_input = post_fn(None, embedded)
            states, params = dec.prepare(enc_out, bri, helper)
            states = stack_beam_size(states, helper.beam_size)
            params = stack_beam_size(params, helper.beam_size)
        pre_fns.append(pre_fn)
        post_fns.append(post_fn)
        initial_inputs.append(first_input)
        initial_decoder_states.append(states)
        decoding_params.append(params)

    initial_outputs_tas = nest.map_structure(
        lambda remover, dec: nest.map_structure(
            _new_tensor_array, remover.apply(dec.output_dtype)),
        decoder_output_removers, decoders)

    def body_infer(time, inputs, decoder_states, outputs_tas, finished,
                   log_probs, lengths, infer_status_ta):
        """One iteration of the decoding while_loop.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of inputs Tensors, one per decoder.
          decoder_states: A list of decoder states, one per decoder.
          outputs_tas: A list of TensorArray structures, one per decoder.
          finished: A bool tensor (keeping track of what's finished).
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          infer_status_ta: structure of TensorArray.

        Returns:
          `(time + 1, next_inputs, next_decoder_states, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_infer_status_ta)`.
        """
        # 1) run every decoder for one step inside its own variable scope
        stepped_inputs = []
        step_outputs = []
        stepped_states = []
        for dec, inp, pre_fn, state, params in zip(
                decoders, inputs, pre_fns, decoder_states, decoding_params):
            with tf.variable_scope(dec.name):
                processed = pre_fn(time, inp)
                step_out, new_state = dec.step(processed, state, params)
            stepped_inputs.append(processed)
            step_outputs.append(step_out)
            stepped_states.append(new_state)

        # 2) record the (filtered) outputs of this step
        next_outputs_tas = [
            nest.map_structure(lambda ta, value: ta.write(time, value),
                               out_ta, remover.apply(step_out))
            for out_ta, step_out, remover in zip(outputs_tas, step_outputs,
                                                 decoder_output_removers)
        ]

        # 3) per-decoder logits, then sample next symbols
        logits = [
            _compute_logits(dec, modality, step_out)
            for dec, modality, step_out in zip(decoders, target_modalities,
                                               step_outputs)
        ]
        sample_ids, beam_ids, next_log_probs, next_lengths = \
            helper.sample_symbols(logits, log_probs, finished, lengths,
                                  time=time)

        # 4) gather states and current inputs with the chosen beam ids
        next_decoder_states = [
            gather_states(state, beam_ids) for state in stepped_states
        ]
        gathered_inputs = nest.map_structure(
            lambda inp: gather_states(inp, beam_ids), stepped_inputs)

        # 5) log the beam search status of this step
        step_status = BeamSearchStateSpec(log_probs=next_log_probs,
                                          predicted_ids=sample_ids,
                                          beam_ids=beam_ids,
                                          lengths=next_lengths)
        next_infer_status_ta = nest.map_structure(
            lambda ta, value: ta.write(time, value),
            infer_status_ta, step_status)

        # 6) build the inputs of the next step
        step_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_embeddings = nest.map_structure(
            lambda modality: _embed_words(modality, next_input_symbols,
                                          time + 1),
            target_modalities)
        # a finished entry never becomes unfinished again
        next_finished = tf.logical_or(step_finished, finished)
        next_inputs = []
        for dec, cur_inp, next_inp, post_fn in zip(decoders, gathered_inputs,
                                                   next_embeddings, post_fns):
            with tf.variable_scope(dec.name):
                next_inputs.append(post_fn(cur_inp, next_inp))

        return (time + 1, next_inputs, next_decoder_states, next_outputs_tas,
                next_finished, next_log_probs, next_lengths,
                next_infer_status_ta)

    initial_log_probs = tf.zeros_like(initial_input_symbols, dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_infer_status_ta = nest.map_structure(
        _new_tensor_array, BeamSearchStateSpec.dtypes())

    loop_vars = [
        initial_time, initial_inputs, initial_decoder_states,
        initial_outputs_tas, initial_finished,
        # infer vars
        initial_log_probs, initial_lengths, initial_infer_status_ta
    ]

    loop_result = tf.while_loop(_not_all_finished,
                                body_infer,
                                loop_vars=loop_vars,
                                parallel_iterations=parallel_iterations,
                                swap_memory=swap_memory)

    # Only the status TensorArrays (last loop variable) are returned.
    return nest.map_structure(_stack_all, loop_result[-1])
Example #3
0
def dynamic_ensemble_decode(decoders,
                            encoder_outputs,
                            bridges,
                            helper,
                            target_to_embedding_fns,
                            outputs_to_logits_fns,
                            parallel_iterations=32,
                            swap_memory=False,
                            **kwargs):
    """ Performs dynamic decoding with `decoders`.

    Calls prepare() once and step() repeatedly on each `Decoder` object
    inside a `tf.while_loop`. The loop runs until every entry of the
    `finished` flags is True.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s from each
          corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        target_to_embedding_fns: A list of callables, converts target ids to
          embeddings.
        outputs_to_logits_fns: A list of callables, converts decoder outputs
          to logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        kwargs: Extra options. `beam_size` is required; it is used to
          stack the initial cache of every decoder.

    Returns: The results of inference, a dict with the keys
      "beam_ids" and "log_probs" (the per-step `BeamSearchStateSpec`
      fields stacked over all steps), "decoding_length" (the final
      `lengths` loop variable) and "hypothesis" (the predicted ids
      reshaped to `[-1, timesteps]`, without the initial input symbols).

    Raises:
        ValueError: if `beam_size` is not given in `kwargs`.
    """
    num_models = len(decoders)
    if "beam_size" not in kwargs:
        # Explicit check rather than `assert`, which is stripped when
        # Python runs with -O.
        raise ValueError(
            "`beam_size` must be provided as a keyword argument.")
    beam_size = kwargs["beam_size"]

    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        # Dynamically sized array: the number of decoding steps is not
        # known when the graph is built.
        return tf.TensorArray(dtype=d,
                              clear_after_read=False,
                              size=0,
                              dynamic_size=True)

    # NOTE(review): `repeat_n_times` is defined elsewhere; from its use here
    # it appears to apply the callable(s) once per model and collect the
    # results per output position -- confirm against its definition.
    decoder_output_removers = repeat_n_times(
        num_models, lambda dec: DecoderOutputRemover(
            dec.mode, dec.output_dtype._fields, dec.output_ignore_fields),
        decoders)

    # initialize first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                    initial_input_symbols, initial_time)

    def _create_cache(_decoder, _encoder_output, _bridge):
        with tf.variable_scope(_decoder.name):
            _init_cache = _decoder.prepare(_encoder_output, _bridge, helper)
            _init_cache = stack_beam_size(_init_cache, beam_size)
        return _init_cache

    initial_caches = repeat_n_times(num_models, _create_cache, decoders,
                                    encoder_outputs, bridges)

    initial_outputs_tas = [
        nest.map_structure(
            _create_ta, _decoder_output_remover.apply(_decoder.output_dtype))
        for _decoder_output_remover, _decoder in zip(decoder_output_removers,
                                                     decoders)
    ]

    def body_infer(time, inputs, caches, outputs_tas, finished, log_probs,
                   lengths, bs_stat_ta, predicted_ids):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of inputs Tensors.
          caches: A list of decoder caches (one dict per decoder, each
            holding a "decoding_states" entry).
          outputs_tas: A list of TensorArrays.
          finished: A bool tensor (keeping track of what's finished).
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          bs_stat_ta: structure of TensorArray.
          predicted_ids: A Tensor, the ids predicted so far, flattened
            to 1-D.

        Returns:
          `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_bs_stat_ta,
          next_predicted_ids)`.
        """

        # step decoder
        def _decoding(_decoder, _input, _cache, _decoder_output_remover,
                      _outputs_ta, _outputs_to_logits_fn):
            with tf.variable_scope(_decoder.name):
                _output, _next_cache = _decoder.step(_input, _cache)
                _decoder_top_features = _decoder.merge_top_features(_output)
            _ta = nest.map_structure(
                lambda _ta_ms, _output_ms: _ta_ms.write(time, _output_ms),
                _outputs_ta, _decoder_output_remover.apply(_output))
            _logit = _outputs_to_logits_fn(_decoder_top_features)
            return _output, _next_cache, _ta, _logit

        # The raw decoder outputs are not needed past this point; they are
        # already written into the TensorArrays.
        _, next_caches, next_outputs_tas, logits = repeat_n_times(
            num_models, _decoding, decoders, inputs, caches,
            decoder_output_removers, outputs_tas, outputs_to_logits_fns)

        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths,
                                    time=time)

        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"],
                                                 beam_ids)

        infer_status = BeamSearchStateSpec(log_probs=next_log_probs,
                                           beam_ids=beam_ids)
        next_bs_stat_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), bs_stat_ta, infer_status)
        # `predicted_ids` travels through the loop flattened to 1-D; view
        # it as [-1, time + 1] to gather it with `beam_ids`.
        gathered_ids = gather_states(
            tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
        # Append this step's symbols as a new column, then flatten again
        # for the next iteration.
        next_predicted_ids = tf.concat(
            [gathered_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
        next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
        next_predicted_ids.set_shape([None])
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                     next_input_symbols, time + 1)
        # Once an entry is finished it stays finished.
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, next_bs_stat_ta, \
               next_predicted_ids

    initial_log_probs = tf.zeros_like(initial_input_symbols, dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_bs_stat_ta = nest.map_structure(_create_ta,
                                            BeamSearchStateSpec.dtypes())
    initial_input_symbols.set_shape([None])
    loop_vars = [
        initial_time,
        initial_inputs,
        initial_caches,
        initial_outputs_tas,
        initial_finished,
        # infer vars
        initial_log_probs,
        initial_lengths,
        initial_bs_stat_ta,
        initial_input_symbols
    ]

    # Keep looping while at least one entry of `finished` (loop var 4)
    # is still False.
    res = tf.while_loop(lambda *args: tf.logical_not(tf.reduce_all(args[4])),
                        body_infer,
                        loop_vars=loop_vars,
                        parallel_iterations=parallel_iterations,
                        swap_memory=swap_memory)

    # One column for the initial symbols plus one per executed step.
    timesteps = res[0] + 1
    # The final log_probs loop variable is not needed: the per-step values
    # are already recorded in `bs_stat`.
    _, lengths, bs_stat, predicted_ids = res[-4:]
    final_bs_stat = nest.map_structure(lambda ta: ta.stack(), bs_stat)
    return {
        "beam_ids": final_bs_stat.beam_ids,
        "log_probs": final_bs_stat.log_probs,
        "decoding_length": lengths,
        # drop the first column (the initial input symbols)
        "hypothesis": tf.reshape(predicted_ids, [-1, timesteps])[:, 1:]
    }
Example #4
0
def dynamic_ensemble_decode(
        decoders,
        encoder_outputs,
        bridges,
        helper,
        target_to_embedding_fns,
        outputs_to_logits_fns,
        parallel_iterations=32,
        swap_memory=False,
        **kwargs):
    """ Performs dynamic decoding with `decoders`.

    Calls prepare() once and step() repeatedly on each `Decoder` object
    inside a `tf.while_loop`. The loop runs until every entry of the
    `finished` flags is True.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s from each
          corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        target_to_embedding_fns: A list of callables, converts target ids to
          embeddings.
        outputs_to_logits_fns: A list of callables, converts decoder outputs
          to logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        kwargs: Extra options. `beam_size` is required; it is used to
          stack the initial cache of every decoder.

    Returns: The results of inference, a dict with the keys
      "beam_ids" and "log_probs" (the per-step `BeamSearchStateSpec`
      fields stacked over all steps), "decoding_length" (the final
      `lengths` loop variable) and "hypothesis" (the predicted ids
      reshaped to `[-1, timesteps]`, without the initial input symbols).

    Raises:
        ValueError: if `beam_size` is not given in `kwargs`.
    """
    num_models = len(decoders)
    if "beam_size" not in kwargs:
        # Explicit check rather than `assert`, which is stripped when
        # Python runs with -O.
        raise ValueError(
            "`beam_size` must be provided as a keyword argument.")
    beam_size = kwargs["beam_size"]

    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        # Dynamically sized array: the number of decoding steps is not
        # known when the graph is built.
        return tf.TensorArray(
            dtype=d, clear_after_read=False,
            size=0, dynamic_size=True)

    # NOTE(review): `repeat_n_times` is defined elsewhere; from its use here
    # it appears to apply the callable(s) once per model and collect the
    # results per output position -- confirm against its definition.
    decoder_output_removers = repeat_n_times(
        num_models, lambda dec: DecoderOutputRemover(
            dec.mode, dec.output_dtype._fields, dec.output_ignore_fields), decoders)

    # initialize first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_inputs = repeat_n_times(
        num_models, target_to_embedding_fns,
        initial_input_symbols, initial_time)

    def _create_cache(_decoder, _encoder_output, _bridge):
        with tf.variable_scope(_decoder.name):
            _init_cache = _decoder.prepare(_encoder_output, _bridge, helper)
            _init_cache = stack_beam_size(_init_cache, beam_size)
        return _init_cache

    initial_caches = repeat_n_times(
        num_models, _create_cache,
        decoders, encoder_outputs, bridges)

    initial_outputs_tas = [
        nest.map_structure(
            _create_ta, _decoder_output_remover.apply(_decoder.output_dtype))
        for _decoder_output_remover, _decoder
        in zip(decoder_output_removers, decoders)]

    def body_infer(time, inputs, caches, outputs_tas, finished,
                   log_probs, lengths, bs_stat_ta, predicted_ids):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of inputs Tensors.
          caches: A list of decoder caches (one dict per decoder, each
            holding a "decoding_states" entry).
          outputs_tas: A list of TensorArrays.
          finished: A bool tensor (keeping track of what's finished).
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          bs_stat_ta: structure of TensorArray.
          predicted_ids: A Tensor, the ids predicted so far, flattened
            to 1-D.

        Returns:
          `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_bs_stat_ta,
          next_predicted_ids)`.
        """

        # step decoder
        def _decoding(_decoder, _input, _cache, _decoder_output_remover,
                      _outputs_ta, _outputs_to_logits_fn):
            with tf.variable_scope(_decoder.name):
                _output, _next_cache = _decoder.step(_input, _cache)
                _decoder_top_features = _decoder.merge_top_features(_output)
            _ta = nest.map_structure(
                lambda _ta_ms, _output_ms: _ta_ms.write(time, _output_ms),
                _outputs_ta, _decoder_output_remover.apply(_output))
            _logit = _outputs_to_logits_fn(_decoder_top_features)
            return _output, _next_cache, _ta, _logit

        # The raw decoder outputs are not needed past this point; they are
        # already written into the TensorArrays.
        _, next_caches, next_outputs_tas, logits = repeat_n_times(
            num_models, _decoding,
            decoders, inputs, caches, decoder_output_removers,
            outputs_tas, outputs_to_logits_fns)

        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)

        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"], beam_ids)

        infer_status = BeamSearchStateSpec(
            log_probs=next_log_probs,
            beam_ids=beam_ids)
        next_bs_stat_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out),
            bs_stat_ta, infer_status)
        # `predicted_ids` travels through the loop flattened to 1-D; view
        # it as [-1, time + 1] to gather it with `beam_ids`.
        gathered_ids = gather_states(
            tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
        # Append this step's symbols as a new column, then flatten again
        # for the next iteration.
        next_predicted_ids = tf.concat(
            [gathered_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
        next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
        next_predicted_ids.set_shape([None])
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                     next_input_symbols, time + 1)
        # Once an entry is finished it stays finished.
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, next_bs_stat_ta, \
               next_predicted_ids

    initial_log_probs = tf.zeros_like(initial_input_symbols, dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_bs_stat_ta = nest.map_structure(_create_ta, BeamSearchStateSpec.dtypes())
    initial_input_symbols.set_shape([None])
    loop_vars = [initial_time, initial_inputs, initial_caches,
                 initial_outputs_tas, initial_finished,
                 # infer vars
                 initial_log_probs, initial_lengths, initial_bs_stat_ta,
                 initial_input_symbols]

    # Keep looping while at least one entry of `finished` (loop var 4)
    # is still False.
    res = tf.while_loop(
        lambda *args: tf.logical_not(tf.reduce_all(args[4])),
        body_infer,
        loop_vars=loop_vars,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    # One column for the initial symbols plus one per executed step.
    timesteps = res[0] + 1
    # The final log_probs loop variable is not needed: the per-step values
    # are already recorded in `bs_stat`.
    _, lengths, bs_stat, predicted_ids = res[-4:]
    final_bs_stat = nest.map_structure(lambda ta: ta.stack(), bs_stat)
    return {"beam_ids": final_bs_stat.beam_ids,
            "log_probs": final_bs_stat.log_probs,
            "decoding_length": lengths,
            # drop the first column (the initial input symbols)
            "hypothesis": tf.reshape(predicted_ids, [-1, timesteps])[:, 1:]}