Example #1
    def body_infer(time, inputs, caches, outputs_tas, finished,
                   log_probs, lengths, bs_stat_ta, predicted_ids):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of inputs Tensors.
          caches: A dict of decoder states.
          outputs_tas: A list of TensorArrays.
          finished: A bool tensor (keeping track of what's finished).
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          bs_stat_ta: structure of TensorArray.
          predicted_ids: A Tensor.

        Returns:
          `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_infer_status_ta)`.
        """

        # step decoder
        def _decoding(_decoder, _input, _cache, _decoder_output_remover,
                      _outputs_ta, _outputs_to_logits_fn):
            with tf.variable_scope(_decoder.name):
                _output, _next_cache = _decoder.step(_input, _cache)
                _decoder_top_features = _decoder.merge_top_features(_output)
            _ta = nest.map_structure(lambda _ta_ms, _output_ms: _ta_ms.write(time, _output_ms),
                                     _outputs_ta, _decoder_output_remover.apply(_output))
            _logit = _outputs_to_logits_fn(_decoder_top_features)
            return _output, _next_cache, _ta, _logit

        outputs, next_caches, next_outputs_tas, logits = repeat_n_times(
            num_models, _decoding,
            decoders, inputs, caches, decoder_output_removers,
            outputs_tas, outputs_to_logits_fns)

        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)

        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"], beam_ids)

        infer_status = BeamSearchStateSpec(
            log_probs=next_log_probs,
            beam_ids=beam_ids)
        bs_stat_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                        bs_stat_ta, infer_status)
        predicted_ids = gather_states(tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
        next_predicted_ids = tf.concat([predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
        next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
        next_predicted_ids.set_shape([None])
        next_finished, next_input_symbols = helper.next_symbols(time=time, sample_ids=sample_ids)
        next_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                     next_input_symbols, time + 1)
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, bs_stat_ta, \
               next_predicted_ids
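
Every snippet on this page is built around `repeat_n_times`, whose definition is not reproduced here. The call sites imply fairly simple semantics; the following minimal sketch captures them under stated assumptions (the length-`n` indexing heuristic and the tuple transposition are inferred from usage, not taken from the project):

def repeat_n_times(n, fn_or_fns, *args, **kwargs):
    """Sketch only: apply a function (or the i-th of a list of functions)
    n times, once per model/device.

    Positional or keyword arguments that are length-n lists are indexed
    per call; everything else is passed through unchanged. When each call
    returns a tuple, the list of tuples is transposed into a tuple of
    lists, which is what unpackings like
    ``outputs, next_caches, next_outputs_tas, logits = repeat_n_times(...)``
    rely on.
    """
    def _pick(x, i):
        return x[i] if isinstance(x, (list, tuple)) and len(x) == n else x

    results = []
    for i in range(n):
        fn = fn_or_fns[i] if isinstance(fn_or_fns, (list, tuple)) else fn_or_fns
        results.append(fn(*[_pick(a, i) for a in args],
                          **{k: _pick(v, i) for k, v in kwargs.items()}))
    if results and isinstance(results[0], tuple):
        # transpose: a list of k-tuples becomes a k-tuple of lists
        return tuple(list(r) for r in zip(*results))
    return results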
Example #2
    def map_fn(n, d, p):
        # n: name prefix
        # d: data list
        # p: padding symbol
        data[concat_name(n, Constants.IDS_NAME)] = d
        n_samples = len(d)
        n_devices = len(input_fields)
        n_samples_per_gpu = n_samples // n_devices
        if n_samples % n_devices > 0:
            n_samples_per_gpu += 1

        def _feed_batchs(_start_idx, _inpf):
            if _start_idx * n_samples_per_gpu >= n_samples:
                return 0
            x, x_len = padding_batch_data(
                d[_start_idx * n_samples_per_gpu:(_start_idx + 1) *
                  n_samples_per_gpu], p)
            data["feed_dict"][_inpf[concat_name(n, Constants.IDS_NAME)]] = x
            data["feed_dict"][_inpf[concat_name(
                n, Constants.LENGTH_NAME)]] = x_len
            return len(x_len)

        parallels = repeat_n_times(n_devices, _feed_batchs,
                                   list(range(n_devices)), input_fields)
        data["feed_dict"]["parallels"] = parallels
Example #3
File: decode.py Project: KIngpon/NJUNMT-tf
def evaluate_with_attention(sess,
                            loss_op,
                            eval_data,
                            vocab_source,
                            vocab_target,
                            attention_op=None,
                            output_filename_prefix=None):
    """ Evaluates data by loss.

    Args:
        sess: `tf.Session`.
        loss_op: Tensorflow operation, computing the loss.
        eval_data: An iterable instance that each element
          is a packed feeding dictionary for `sess`.
        vocab_source: A `Vocab` instance for source side feature map.
        vocab_target: A `Vocab` instance for target side feature map.
        attention_op: Tensorflow operation for output attention.
        output_filename_prefix: A string.

    Returns: Total loss averaged by number of data samples.
    """
    losses = 0.
    weights = 0.
    num_of_samples = 0
    attentions = {}
    for data in eval_data:
        _n_samples = len(data["feature_ids"])
        parallels = data["feed_dict"].pop("parallels")
        avail = sum(numpy.array(parallels) > 0)
        if attention_op is None:
            loss = _evaluate(sess, data["feed_dict"], loss_op[:avail])
        else:
            loss, atts = _evaluate(sess, data["feed_dict"],
                                   [loss_op[:avail], attention_op[:avail]])
            ss_strs = [
                vocab_source.convert_to_wordlist(ss, bpe_decoding=False)
                for ss in data["feature_ids"]
            ]
            tt_strs = [
                vocab_target.convert_to_wordlist(tt,
                                                 bpe_decoding=False,
                                                 reverse_seq=False)
                for tt in data["label_ids"]
            ]
            _attentions = sum(
                repeat_n_times(avail, select_attention_sample_by_sample, atts),
                [])
            attentions.update(
                pack_batch_attention_dict(num_of_samples, ss_strs, tt_strs,
                                          _attentions))
        data["feed_dict"]["parallels"] = parallels
        losses += sum([_l[0] for _l in loss])
        weights += sum([_l[1] for _l in loss])
        num_of_samples += _n_samples
    loss = losses / weights
    if attention_op is not None:
        dump_attentions(output_filename_prefix, attentions)
    return loss
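
The accumulation loop treats each element returned by `_evaluate` as a `(loss_sum, loss_weight)` pair per device, so the final value is a weight-averaged loss. A tiny worked example with invented numbers:

# Two devices; all numbers are made up purely for illustration.
loss = [(12.0, 8.0), (9.0, 6.0)]     # (summed loss, weight) per device
losses = sum(_l[0] for _l in loss)   # 21.0
weights = sum(_l[1] for _l in loss)  # 14.0
print(losses / weights)              # 1.5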
Example #4
    def _add(prefix):
        nonpadding_tokens_num, total_tokens_num = repeat_n_times(
            len(input_fields), compute_non_padding_num, input_fields, prefix)
        nonpadding_tokens_num = tf.reduce_sum(nonpadding_tokens_num)
        total_tokens_num = tf.reduce_sum(total_tokens_num)
        tf.add_to_collection(Constants.DISPLAY_KEY_COLLECTION_NAME,
                             "input_stats/{}_nonpadding_tokens_num".format(prefix))
        tf.add_to_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME,
                             nonpadding_tokens_num)
        tf.add_to_collection(Constants.DISPLAY_KEY_COLLECTION_NAME,
                             "input_stats/{}_nonpadding_ratio".format(prefix))
        tf.add_to_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME,
                             tf.to_float(nonpadding_tokens_num)
                             / tf.to_float(total_tokens_num))
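
The two collections are kept in lockstep: every key appended to `DISPLAY_KEY_COLLECTION_NAME` has a matching value tensor appended to `DISPLAY_VALUE_COLLECTION_NAME`. A consumer presumably pairs them back up along these lines (a sketch of the assumed pattern, not the project's display hook):

keys = tf.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
values = tf.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
display_ops = dict(zip(keys, values))
# e.g. inside a training hook: displayed = sess.run(display_ops),
# giving a {key: scalar} dict ready for logging.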
Example #5
    def build(self, base_models, vocab_target, input_fields):
        """ Builds the ensemble model.

        Args:
            base_models: A list of `BaseSeq2Seq` instances.
            vocab_target: An instance of `Vocab`.
            input_fields: A dict of placeholders.

        Returns: A dictionary of inference status.
        """
        encoder_outputs = []
        # prepare for decoding of each model
        for index, model in enumerate(base_models):
            with tf.variable_scope(
                    Constants.ENSEMBLE_VARNAME_PREFIX + str(index)):
                with tf.variable_scope(model.name):
                    encoder_output = model._encode(input_fields=input_fields)
                    vs_name = tf.get_variable_scope().name
                    model._decoder.name = os.path.join(vs_name, model._decoder.name)
                    model._target_modality.name = os.path.join(vs_name, model._target_modality.name)
                encoder_outputs.append(encoder_output)

        helper = BeamFeedback(
            vocab=vocab_target,
            batch_size=tf.shape(input_fields[Constants.FEATURE_IDS_NAME])[0],
            maximum_labels_length=self._maximum_labels_length,
            beam_size=self._beam_size,
            alpha=self._length_penalty,
            ensemble_weight=self.get_ensemble_weights(len(base_models)))

        decoders, bridges, target_to_emb_fns, outputs_to_logits_fns = \
            repeat_n_times(
                len(base_models),
                lambda m: (m._decoder, m._encoder_decoder_bridge, m._target_to_embedding_fn, m._outputs_to_logits_fn),
                base_models)

        decoding_result = dynamic_ensemble_decode(
            decoders=decoders,
            encoder_outputs=encoder_outputs,
            bridges=bridges,
            helper=helper,
            target_to_embedding_fns=target_to_emb_fns,
            outputs_to_logits_fns=outputs_to_logits_fns,
            beam_size=self._beam_size)
        predict_out = process_beam_predictions(
            decoding_result=decoding_result,
            beam_size=self._beam_size,
            alpha=self._length_penalty)
        predict_out["source"] = input_fields[Constants.FEATURE_IDS_NAME]
        return predict_out
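
Note the `os.path.join` trick in the loop above: it is used only to build "/"-separated names (on POSIX systems), and "/" happens to be TensorFlow's scope separator, so each base model's decoder and modality names get prefixed with their ensemble scope. The scope names below are made up for illustration:

import os

print(os.path.join("ensemble_0", "my_model", "decoder"))
# -> "ensemble_0/my_model/decoder"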
Example #6
    def build(self, input_fields):
        """ Builds the ensemble model.

        Args:
            input_fields: A dict of placeholders.

        Returns: A dictionary of inference status.
        """
        encoder_outputs = []
        # prepare for decoding of each model
        for index, model in enumerate(self._base_models):
            encoder_output = model._encode(input_fields=input_fields)
            encoder_outputs.append(encoder_output)

        helper = BeamFeedback(
            vocab=self._vocab_target,
            batch_size=tf.shape(input_fields[Constants.FEATURE_IDS_NAME])[0],
            maximum_labels_length=self._maximum_labels_length,
            beam_size=self._beam_size,
            alpha=self._length_penalty,
            ensemble_weight=self.get_ensemble_weights(len(self._base_models)))

        decoders, bridges, target_to_emb_fns, outputs_to_logits_fns = \
            repeat_n_times(
                len(self._base_models),
                lambda m: (m._decoder, m._encoder_decoder_bridge, m._target_to_embedding_fn, m._outputs_to_logits_fn),
                self._base_models)

        decoding_result = dynamic_ensemble_decode(
            decoders=decoders,
            encoder_outputs=encoder_outputs,
            bridges=bridges,
            helper=helper,
            target_to_embedding_fns=target_to_emb_fns,
            outputs_to_logits_fns=outputs_to_logits_fns,
            beam_size=self._beam_size)
        predict_out = process_beam_predictions(
            decoding_result=decoding_result,
            beam_size=self._beam_size,
            alpha=self._length_penalty)
        predict_out["source"] = input_fields[Constants.FEATURE_IDS_NAME]
        return predict_out
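
`get_ensemble_weights` is not shown on this page; it receives only the number of base models and feeds the `ensemble_weight` argument of `BeamFeedback`. A uniform average is one plausible default; the sketch below is an assumption, not the project's code:

def get_ensemble_weights(self, num_models):
    # Sketch: weight every base model equally when combining
    # per-model probabilities during beam search.
    return [1.0 / float(num_models)] * num_models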
Example #7
def dynamic_ensemble_decode(decoders,
                            encoder_outputs,
                            bridges,
                            helper,
                            target_to_embedding_fns,
                            outputs_to_logits_fns,
                            parallel_iterations=32,
                            swap_memory=False,
                            **kwargs):
    """ Performs dynamic decoding with `decoders`.

    Calls prepare() once and step() repeatedly on each `Decoder` object.

    Args:
        decoders: A list of `Decoder` instances.
        encoder_outputs: A list of `collections.namedtuple`s from each
          corresponding `Encoder.encode()`.
        bridges: A list of `Bridge` instances or Nones.
        helper: An instance of `Feedback` that samples next symbols
          from logits.
        target_to_embedding_fns: A list of callables, converts target ids to
          embeddings.
        outputs_to_logits_fns: A list of callables, converts decoder outputs
          to logits.
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        kwargs: Additional keyword arguments; must contain "beam_size".

    Returns: A dict of beam-search results, containing the stacked beam
      ids and log probabilities (the fields of `BeamSearchStateSpec`),
      the decoding lengths, and the predicted id matrix ("hypothesis").
    """
    num_models = len(decoders)
    var_scope = tf.get_variable_scope()
    # Properly cache variable values inside the while_loop
    if var_scope.caching_device is None:
        var_scope.set_caching_device(lambda op: op.device)

    def _create_ta(d):
        return tf.TensorArray(dtype=d,
                              clear_after_read=False,
                              size=0,
                              dynamic_size=True)

    decoder_output_removers = repeat_n_times(
        num_models, lambda dec: DecoderOutputRemover(
            dec.mode, dec.output_dtype._fields, dec.output_ignore_fields),
        decoders)

    # initialize first inputs (start of sentence) with shape [_batch*_beam,]
    initial_finished, initial_input_symbols = helper.init_symbols()
    initial_time = tf.constant(0, dtype=tf.int32)
    initial_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                    initial_input_symbols, initial_time)

    assert "beam_size" in kwargs
    beam_size = kwargs["beam_size"]

    def _create_cache(_decoder, _encoder_output, _bridge):
        with tf.variable_scope(_decoder.name):
            _init_cache = _decoder.prepare(_encoder_output, _bridge, helper)
            _init_cache = stack_beam_size(_init_cache, beam_size)
        return _init_cache

    initial_caches = repeat_n_times(num_models, _create_cache, decoders,
                                    encoder_outputs, bridges)

    initial_outputs_tas = [
        nest.map_structure(
            _create_ta, _decoder_output_remover.apply(_decoder.output_dtype))
        for _decoder_output_remover, _decoder in zip(decoder_output_removers,
                                                     decoders)
    ]

    def body_infer(time, inputs, caches, outputs_tas, finished, log_probs,
                   lengths, bs_stat_ta, predicted_ids):
        """Internal while_loop body.

        Args:
          time: Scalar int32 Tensor.
          inputs: A list of inputs Tensors.
          caches: A dict of decoder states.
          outputs_tas: A list of TensorArrays.
          finished: A bool tensor (keeping track of what's finished).
          log_probs: The log probability Tensor.
          lengths: The decoding length Tensor.
          bs_stat_ta: structure of TensorArray.
          predicted_ids: A Tensor.

        Returns:
          `(time + 1, next_inputs, next_caches, next_outputs_tas,
          next_finished, next_log_probs, next_lengths, next_infer_status_ta)`.
        """

        # step decoder
        def _decoding(_decoder, _input, _cache, _decoder_output_remover,
                      _outputs_ta, _outputs_to_logits_fn):
            with tf.variable_scope(_decoder.name):
                _output, _next_cache = _decoder.step(_input, _cache)
                _decoder_top_features = _decoder.merge_top_features(_output)
            _ta = nest.map_structure(
                lambda _ta_ms, _output_ms: _ta_ms.write(time, _output_ms),
                _outputs_ta, _decoder_output_remover.apply(_output))
            _logit = _outputs_to_logits_fn(_decoder_top_features)
            return _output, _next_cache, _ta, _logit

        outputs, next_caches, next_outputs_tas, logits = repeat_n_times(
            num_models, _decoding, decoders, inputs, caches,
            decoder_output_removers, outputs_tas, outputs_to_logits_fns)

        # sample next symbols
        sample_ids, beam_ids, next_log_probs, next_lengths \
            = helper.sample_symbols(logits, log_probs, finished, lengths, time=time)

        for c in next_caches:
            c["decoding_states"] = gather_states(c["decoding_states"],
                                                 beam_ids)

        infer_status = BeamSearchStateSpec(log_probs=next_log_probs,
                                           beam_ids=beam_ids)
        bs_stat_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                        bs_stat_ta, infer_status)
        predicted_ids = gather_states(
            tf.reshape(predicted_ids, [-1, time + 1]), beam_ids)
        next_predicted_ids = tf.concat(
            [predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1)
        next_predicted_ids = tf.reshape(next_predicted_ids, [-1])
        next_predicted_ids.set_shape([None])
        next_finished, next_input_symbols = helper.next_symbols(
            time=time, sample_ids=sample_ids)
        next_inputs = repeat_n_times(num_models, target_to_embedding_fns,
                                     next_input_symbols, time + 1)
        next_finished = tf.logical_or(next_finished, finished)

        return time + 1, next_inputs, next_caches, next_outputs_tas, \
               next_finished, next_log_probs, next_lengths, bs_stat_ta, \
               next_predicted_ids

    initial_log_probs = tf.zeros_like(initial_input_symbols, dtype=tf.float32)
    initial_lengths = tf.zeros_like(initial_input_symbols, dtype=tf.int32)
    initial_bs_stat_ta = nest.map_structure(_create_ta,
                                            BeamSearchStateSpec.dtypes())
    initial_input_symbols.set_shape([None])
    loop_vars = [
        initial_time,
        initial_inputs,
        initial_caches,
        initial_outputs_tas,
        initial_finished,
        # infer vars
        initial_log_probs,
        initial_lengths,
        initial_bs_stat_ta,
        initial_input_symbols
    ]

    res = tf.while_loop(lambda *args: tf.logical_not(tf.reduce_all(args[4])),
                        body_infer,
                        loop_vars=loop_vars,
                        parallel_iterations=parallel_iterations,
                        swap_memory=swap_memory)

    timesteps = res[0] + 1
    log_probs, length, bs_stat, predicted_ids = res[-4:]
    final_bs_stat = nest.map_structure(lambda ta: ta.stack(), bs_stat)
    return {
        "beam_ids": final_bs_stat.beam_ids,
        "log_probs": final_bs_stat.log_probs,
        "decoding_length": length,
        "hypothesis": tf.reshape(predicted_ids, [-1, timesteps])[:, 1:]
    }
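
`gather_states` is applied at every step to reorder beam-parallel state along axis 0 after `beam_ids` re-ranks the hypotheses. A minimal sketch of the assumed behavior (not the project's implementation):

def gather_states(states, beam_ids):
    # Sketch: for every tensor in the (possibly nested) structure,
    # row i of the result is row beam_ids[i] of the input, so each
    # surviving beam entry inherits the state of its parent hypothesis.
    return nest.map_structure(lambda s: tf.gather(s, beam_ids), states)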
Example #8
File: decode.py Project: KIngpon/NJUNMT-tf
def _infer(sess,
           feed_dict,
           prediction_op,
           batch_size,
           top_k=1,
           output_attention=False):
    """ Infers a batch of samples with beam search.

    Args:
        sess: `tf.Session`
        feed_dict: A dictionary of feeding data.
        prediction_op: Tensorflow operation for inference.
        batch_size: The batch size.
        top_k: An integer, number of predicted sequences will be
          returned.
        output_attention: Whether to output attention.

    Returns: A tuple `(predicted_sequences, attention_scores)`.
      The `predicted_sequences` is a list of hypothesis with
      approx [`top_k` * `batch_sze`, sequence_length].
      The `attention_scores` is None if there is no attention
      related information in `prediction_op`.
    """
    parallels = feed_dict.pop("parallels")
    avail = sum(numpy.array(parallels) > 0)
    extract_keys = ["sorted_hypothesis"]
    if output_attention:
        assert top_k == 1
        extract_keys.extend(["sorted_argidx", "attentions", "beam_ids"])
    brief_pred_op = dict(
        zip(
            extract_keys,
            repeat_n_times(avail,
                           lambda dd: tuple([dd[k] for k in extract_keys]),
                           prediction_op[:avail])))
    predict_out = sess.run(brief_pred_op, feed_dict=feed_dict)
    feed_dict["parallels"] = parallels
    total_samples = sum(
        repeat_n_times(avail, lambda p: p.shape[0],
                       predict_out["sorted_hypothesis"]))
    beam_size = total_samples // batch_size

    def _post_process_hypo(pred, **kwargs):
        _num_samples = pred.shape[0]
        _batch_size = _num_samples // beam_size
        batch_beam_pos = numpy.tile(
            numpy.arange(_batch_size) * beam_size, [beam_size, 1]).transpose()
        # Offset each row by 0..top_k-1; without this, the same (best) beam
        # entry would be selected top_k times whenever top_k > 1.
        batch_beam_topk_add = numpy.tile(
            numpy.arange(top_k), [batch_beam_pos.shape[0], 1])
        batch_beam_pos = numpy.reshape(
            batch_beam_pos[:, :top_k] + batch_beam_topk_add, -1)
        if output_attention:
            atts = postprocess_attention(
                beam_ids=kwargs["beam_ids"],
                attention_dict=kwargs["attentions"],
                gather_idx=kwargs["sorted_argidx"][batch_beam_pos])
            return pred[batch_beam_pos, :].tolist(), atts
        return pred[batch_beam_pos, :].tolist(), []

    hypothesises, attentions = repeat_n_times(
        avail,
        _post_process_hypo,
        predict_out["sorted_hypothesis"],
        beam_ids=predict_out.get("beam_ids", None),
        attentions=predict_out.get("attentions", None),
        sorted_argidx=predict_out.get("sorted_argidx", None))
    hypothesis = sum(hypothesises, [])
    attention = sum(attentions, [])
    return hypothesis, attention
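
The index arithmetic in `_post_process_hypo` is easiest to see with concrete (made-up) sizes. With `_batch_size = 2`, `beam_size = 3` and `top_k = 2`, `pred` holds six rows laid out beam-major per sentence, and the computed positions pick the top two hypotheses of each sentence:

import numpy

beam_size, _batch_size, top_k = 3, 2, 2
batch_beam_pos = numpy.tile(
    numpy.arange(_batch_size) * beam_size, [beam_size, 1]).transpose()
# [[0 0 0]
#  [3 3 3]]
batch_beam_topk_add = numpy.tile(numpy.arange(top_k),
                                 [batch_beam_pos.shape[0], 1])
# [[0 1]
#  [0 1]]
print(numpy.reshape(batch_beam_pos[:, :top_k] + batch_beam_topk_add, -1))
# [0 1 3 4] -> rows 0 and 1 for sentence 0, rows 3 and 4 for sentence 1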