Example #1
    def body_infer(time, inputs, caches, outputs_tas, finished, log_probs,
                   lengths, infer_status_ta):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of inputs Tensors.
          caches: A list of decoder cache dicts (one per decoder).
          outputs_tas: A list of TensorArrays.
          finished: A bool tensor (keeping track of what's finished).
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          infer_status_ta: A structure of TensorArrays holding the beam
            search status at each timestep.

        Returns:
          `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_infer_status_ta)`.
        """
        # step decoder
        outputs = []
        next_caches = []
        for dec, inp, cache in zip(decoders, inputs, caches):
            with tf.variable_scope(dec.name):
                out, next_cache = dec.step(inp, cache)
                outputs.append(out)
                next_caches.append(next_cache)
        next_outputs_tas = []
        for out_ta, out, rem in zip(outputs_tas, outputs,
                                    decoder_output_removers):
            ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    out_ta, rem.apply(out))
            next_outputs_tas.append(ta)
        logits = []
        for dec, modality, out in zip(decoders, target_modalities, outputs):
            logits.append(_compute_logits(dec, modality, out))
        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)

        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"],
                                                 beam_ids)

        infer_status = BeamSearchStateSpec(log_probs=next_log_probs,
                                           predicted_ids=sample_ids,
                                           beam_ids=beam_ids,
                                           lengths=next_lengths)
        infer_status_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), infer_status_ta, infer_status)
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = nest.map_structure(
            lambda modality: _embed_words(modality, next_input_symbols, time + 1),
            target_modalities)
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, infer_status_ta
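
The body above is meant to be fed to `tf.while_loop` together with a condition
that stops once every entry of `finished` is True (examples #2 and #3 below
show the actual call). A minimal, self-contained sketch of that
while_loop/TensorArray pattern, assuming only the TensorFlow 1.x API and using
a toy loop body rather than the real decoder step:

import tensorflow as tf

def _toy_body(time, acc_ta):
    # Mimics the `ta.write(time, out)` bookkeeping in body_infer: write one
    # value per step into a dynamically sized TensorArray.
    acc_ta = acc_ta.write(time, tf.cast(time, tf.float32))
    return time + 1, acc_ta

init_time = tf.constant(0, dtype=tf.int32)
init_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
_, final_ta = tf.while_loop(lambda time, ta: time < 5, _toy_body,
                            loop_vars=[init_time, init_ta])
stacked = final_ta.stack()  # shape [5]; analogous to the final ta.stack() calls below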
Example #2
def dynamic_ensemble_decode(decoders,
                            encoder_outputs,
                            bridges,
                            target_modalities,
                            helper,
                            parallel_iterations=32,
                            swap_memory=False):
    """ Performs dynamic decoding with `decoders`.

    Calls prepare() once and step() repeatedly on each `Decoder` object.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s from each
          corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        target_modalities: A list of `Modality` instances.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.

    Returns: The results of inference, a `BeamSearchStateSpec`
      (a `collections.namedtuple`) whose fields hold the stacked beam search
      status for every timestep.
    """
    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        return tf.TensorArray(dtype=d,
                              clear_after_read=False,
                              size=0,
                              dynamic_size=True)

    decoder_output_removers = nest.map_structure(
        lambda dec: DecoderOutputRemover(dec.mode, dec.output_dtype._fields,
                                         dec.output_ignore_fields), decoders)

    # initialize first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_input_symbols_embed = nest.map_structure(
        lambda modality: _embed_words(modality, initial_input_symbols,
                                      initial_time), target_modalities)

    inputs_preprocessing_fns = []
    inputs_postprocessing_fns = []
    initial_inputs = []
    initial_decoder_states = []
    decoding_params = []
    for dec, enc_out, bri, inp in zip(decoders, encoder_outputs, bridges,
                                      initial_input_symbols_embed):
        with tf.variable_scope(dec.name):
            inputs_preprocessing_fn, inputs_postprocessing_fn = \
                dec.inputs_prepost_processing_fn()
            inputs = inputs_postprocessing_fn(None, inp)
            # prepare decoder
            dec_states, dec_params = dec.prepare(enc_out, bri, helper)
            dec_states = stack_beam_size(dec_states, helper.beam_size)
            dec_params = stack_beam_size(dec_params, helper.beam_size)
            # add to list
            inputs_preprocessing_fns.append(inputs_preprocessing_fn)
            inputs_postprocessing_fns.append(inputs_postprocessing_fn)
            initial_inputs.append(inputs)
            initial_decoder_states.append(dec_states)
            decoding_params.append(dec_params)

    initial_outputs_tas = nest.map_structure(
        lambda dec_out_rem, dec: nest.map_structure(
            _create_ta, dec_out_rem.apply(dec.output_dtype)),
        decoder_output_removers, decoders)

    def body_infer(time, inputs, decoder_states, outputs_tas, finished,
                   log_probs, lengths, infer_status_ta):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of inputs Tensors.
          decoder_states: A list of decoder states.
          outputs_tas: A list of TensorArrays.
          finished: A bool tensor (keeping track of what's finished).
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          infer_status_ta: A structure of TensorArrays holding the beam
            search status at each timestep.

        Returns:
          `(time + 1, next_inputs, next_decoder_states, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_infer_status_ta)`.
        """
        # step decoder
        outputs = []
        cur_inputs = []
        next_decoder_states = []
        for dec, inp, pre_fn, stat, dec_params in \
                zip(decoders, inputs, inputs_preprocessing_fns, decoder_states, decoding_params):
            with tf.variable_scope(dec.name):
                inp = pre_fn(time, inp)
                out, next_stat = dec.step(inp, stat, dec_params)
                cur_inputs.append(inp)
                outputs.append(out)
                next_decoder_states.append(next_stat)
        next_outputs_tas = []
        for out_ta, out, rem in zip(outputs_tas, outputs,
                                    decoder_output_removers):
            ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    out_ta, rem.apply(out))
            next_outputs_tas.append(ta)
        logits = []
        for dec, modality, out in zip(decoders, target_modalities, outputs):
            logits.append(_compute_logits(dec, modality, out))
        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)
        gathered_states = []
        for next_stat in next_decoder_states:
            gathered_states.append(gather_states(next_stat, beam_ids))
        cur_inputs = nest.map_structure(
            lambda inp: gather_states(inp, beam_ids), cur_inputs)
        infer_status = BeamSearchStateSpec(log_probs=next_log_probs,
                                           predicted_ids=sample_ids,
                                           beam_ids=beam_ids,
                                           lengths=next_lengths)
        infer_status_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), infer_status_ta, infer_status)
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs_embed = nest.map_structure(
            lambda modality: _embed_words(modality, next_input_symbols, time + 1),
            target_modalities)
        next_finished = tf.logical_or(next_finished, finished)
        next_inputs = []
        for dec, cur_inp, next_inp, post_fn in zip(decoders, cur_inputs,
                                                   next_inputs_embed,
                                                   inputs_postprocessing_fns):
            with tf.variable_scope(dec.name):
                next_inputs.append(post_fn(cur_inp, next_inp))
        return time + 1, next_inputs, gathered_states, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, infer_status_ta

    initial_log_probs = tf.zeros_like(initial_input_symbols, dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_infer_status_ta = nest.map_structure(_create_ta,
                                                 BeamSearchStateSpec.dtypes())
    loop_vars = [
        initial_time,
        initial_inputs,
        initial_decoder_states,
        initial_outputs_tas,
        initial_finished,
        # infer vars
        initial_log_probs,
        initial_lengths,
        initial_infer_status_ta
    ]

    res = tf.while_loop(lambda *args: tf.logical_not(tf.reduce_all(args[4])),
                        body_infer,
                        loop_vars=loop_vars,
                        parallel_iterations=parallel_iterations,
                        swap_memory=swap_memory)

    final_infer_status = nest.map_structure(lambda ta: ta.stack(), res[-1])
    return final_infer_status
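
After pruning, every recurrent state has to be reordered so that each surviving
hypothesis continues from the cache of its parent beam; this is what the
`gather_states` calls above do. That helper is not shown in this snippet, so
here is a plausible sketch (hypothetical code, not the project's actual
implementation), assuming the states are nested structures of Tensors with a
leading `[batch * beam]` axis:

import tensorflow as tf
from tensorflow.python.util import nest

def gather_states_sketch(states, beam_ids):
    # beam_ids: int Tensor of shape [batch * beam]; entry i names the row of
    # the previous step's states that hypothesis i should continue from.
    return nest.map_structure(lambda t: tf.gather(t, beam_ids), states)

# e.g. gather_states_sketch(next_stat, beam_ids) reorders every Tensor in the
# decoder state along axis 0, matching how `gathered_states` is built above.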
Example #3
def dynamic_ensemble_decode(decoders,
                            encoder_outputs,
                            bridges,
                            target_modalities,
                            helper,
                            parallel_iterations=32,
                            swap_memory=False,
                            **kwargs):
    """ Performs dynamic decoding with `decoders`.

    Calls prepare() once and step() repeatedly on each `Decoder` object.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s from each
          corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        target_modalities: A list of `Modality` instances.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        kwargs: Additional keyword arguments; must contain `beam_size`.

    Returns: A dict of inference results with keys `beam_ids`, `log_probs`,
      `decoding_length`, and `hypothesis` (the predicted id matrix with the
      start symbol column stripped).
    """
    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        return tf.TensorArray(dtype=d,
                              clear_after_read=False,
                              size=0,
                              dynamic_size=True)

    decoder_output_removers = nest.map_structure(
        lambda dec: DecoderOutputRemover(dec.mode, dec.output_dtype._fields,
                                         dec.output_ignore_fields), decoders)

    # initialize first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_inputs = nest.map_structure(
        lambda modality: _embed_words(modality, initial_input_symbols,
                                      initial_time), target_modalities)
    assert "beam_size" in kwargs
    beam_size = kwargs["beam_size"]
    initial_caches = []
    for dec, enc_out, bri in zip(decoders, encoder_outputs, bridges):
        with tf.variable_scope(dec.name):
            init_cache = dec.prepare(enc_out, bri, helper)  # prepare decoder
            init_cache = stack_beam_size(init_cache, beam_size)
            initial_caches.append(init_cache)

    initial_outputs_tas = nest.map_structure(
        lambda dec_out_rem, dec: nest.map_structure(
            _create_ta, dec_out_rem.apply(dec.output_dtype)),
        decoder_output_removers, decoders)

    def body_infer(time, inputs, caches, outputs_tas, finished, log_probs,
                   lengths, bs_stat_ta, predicted_ids):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of inputs Tensors.
          caches: A list of decoder cache dicts (one per decoder).
          outputs_tas: A list of TensorArrays.
          finished: A bool tensor (keeping track of what's finished).
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          bs_stat_ta: A structure of TensorArrays holding the beam search
            status at each timestep.
          predicted_ids: A flattened int Tensor of the ids predicted so far.

        Returns:
          `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_bs_stat_ta,
          next_predicted_ids)`.
        """
        # step decoder
        outputs = []
        next_caches = []
        for dec, inp, cache in zip(decoders, inputs, caches):
            with tf.variable_scope(dec.name):
                out, next_cache = dec.step(inp, cache)
                outputs.append(out)
                next_caches.append(next_cache)
        next_outputs_tas = []
        for out_ta, out, rem in zip(outputs_tas, outputs,
                                    decoder_output_removers):
            ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    out_ta, rem.apply(out))
            next_outputs_tas.append(ta)
        logits = []
        for dec, modality, out in zip(decoders, target_modalities, outputs):
            logits.append(_compute_logits(dec, modality, out))
        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)

        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"],
                                                 beam_ids)

        infer_status = BeamSearchStateSpec(log_probs=next_log_probs,
                                           beam_ids=beam_ids)
        bs_stat_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                        bs_stat_ta, infer_status)
        predicted_ids = gather_states(
            tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
        next_predicted_ids = tf.concat(
            [predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
        next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
        next_predicted_ids.set_shape([None])
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = nest.map_structure(
            lambda modality: _embed_words(modality, next_input_symbols, time + 1),
            target_modalities)
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, bs_stat_ta, \
               next_predicted_ids

    initial_log_probs = tf.zeros_like(initial_input_symbols, dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_bs_stat_ta = nest.map_structure(_create_ta,
                                            BeamSearchStateSpec.dtypes())
    initial_input_symbols.set_shape([None])
    loop_vars = [
        initial_time,
        initial_inputs,
        initial_caches,
        initial_outputs_tas,
        initial_finished,
        # infer vars
        initial_log_probs,
        initial_lengths,
        initial_bs_stat_ta,
        initial_input_symbols
    ]

    res = tf.while_loop(lambda *args: tf.logical_not(tf.reduce_all(args[4])),
                        body_infer,
                        loop_vars=loop_vars,
                        parallel_iterations=parallel_iterations,
                        swap_memory=swap_memory)

    timesteps = res[0] + 1
    log_probs, length, bs_stat, predicted_ids = res[-4:]
    final_bs_stat = nest.map_structure(lambda ta: ta.stack(), bs_stat)
    return {
        "beam_ids": final_bs_stat.beam_ids,
        "log_probs": final_bs_stat.log_probs,
        "decoding_length": length,
        "hypothesis": tf.reshape(predicted_ids, [-1, timesteps])[:, 1:]
    }
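
The hypothesis bookkeeping in this last variant keeps `predicted_ids` flattened
between iterations (so `tf.while_loop` sees a rank-1 loop variable), reshapes it
to `[batch * beam, time + 1]` each step, reorders the rows by `beam_ids`, and
appends the newly sampled column. A toy numeric illustration of those
reshape/gather/concat steps, using `tf.gather` in place of the project's
`gather_states` helper:

import tensorflow as tf

time = 1
predicted_ids = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])  # flattened: batch*beam = 4 slots, time+1 = 2 ids each
beam_ids = tf.constant([1, 0, 3, 2])                   # reordering chosen by the beam search step
sample_ids = tf.constant([9, 10, 11, 12])              # ids sampled at this step

rows = tf.gather(tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
next_predicted_ids = tf.reshape(
    tf.concat([rows, tf.expand_dims(sample_ids, axis=1)], axis=1), [-1])
# next_predicted_ids evaluates to [3, 4, 9, 1, 2, 10, 7, 8, 11, 5, 6, 12]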