Example #1
    def predict(self, encoder_outputs, encoder_decoder_attention_bias):
        """Return predicted sequence."""
        batch_size = tf.shape(encoder_outputs)[0]
        input_length = tf.shape(encoder_outputs)[1]
        max_decode_length = input_length + self.params["extra_decode_length"]

        symbols_to_logits_fn = self._get_symbols_to_logits_fn(
            max_decode_length)

        # Create initial set of IDs that will be passed into symbols_to_logits_fn.
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)

        # Create cache storing decoder attention values for each layer.
        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
                              dtype=encoder_outputs.dtype),
                "v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
                              dtype=encoder_outputs.dtype),
            }
            for layer in range(self.params["num_hidden_layers"])
        }

        # Add encoder output and attention bias to the cache.
        cache["encoder_outputs"] = encoder_outputs
        cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

        # Use beam search to find the top beam_size sequences and scores.
        decoded_ids, scores = beam_search.sequence_beam_search(
            symbols_to_logits_fn=symbols_to_logits_fn,
            initial_ids=initial_ids,
            initial_cache=cache,
            vocab_size=self.params["tgt_vocab_size"],
            beam_size=self.params["beam_size"],
            alpha=self.params["alpha"],
            max_decode_length=max_decode_length,
            eos_id=self.params["EOS_ID"])

        # Get the top sequence for each batch element
        top_decoded_ids = decoded_ids[:, 0, 1:]
        top_scores = scores[:, 0]

        # Re-run the decoder over the selected sequence to recover logits;
        # this isn't particularly efficient.
        logits = self.decode_pass(top_decoded_ids, encoder_outputs,
                                  encoder_decoder_attention_bias)
        return {
            "logits": logits,
            "samples": [top_decoded_ids],
            "final_state": None,
            "final_sequence_lengths": None
        }
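
The _get_symbols_to_logits_fn helper used above is not shown in this example. As a rough illustration of the contract it is assumed to satisfy for beam_search.sequence_beam_search (a closure mapping the IDs decoded so far, the current step index, and the cache to next-token logits plus the cache), here is a minimal, self-contained toy; make_toy_symbols_to_logits_fn is a hypothetical name, and the uniform logits stand in for a real decoder step that would read and update the cache.

import tensorflow as tf

def make_toy_symbols_to_logits_fn(vocab_size):
    """Toy sketch (not the model's real helper) of the assumed closure contract."""
    def symbols_to_logits_fn(ids, i, cache):
        # ids: [batch_size * beam_size, i + 1] tokens decoded so far.
        batch_beam = tf.shape(ids)[0]
        # A real decoder step would attend over cache["encoder_outputs"] and
        # update the per-layer "k"/"v" entries; uniform logits keep the
        # shapes visible without any model weights.
        logits = tf.zeros([batch_beam, vocab_size])
        return logits, cache
    return symbols_to_logits_fn

fn = make_toy_symbols_to_logits_fn(vocab_size=32)
logits, _ = fn(tf.zeros([4, 1], dtype=tf.int32), 0, {})
print(logits.shape)  # (4, 32)
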
    def predict(self, encoder_outputs, encoder_outputs_b,
                inputs_attention_bias):
        """Return predicted sequence."""
        batch_size = tf.shape(encoder_outputs)[0]
        input_length = tf.shape(encoder_outputs)[1]

        max_decode_length = input_length + self.params["extra_decode_length"]

        symbols_to_logits_fn = self._get_symbols_to_logits_fn()

        # Create initial set of IDs that will be passed into symbols_to_logits_fn.
        initial_ids = (tf.zeros([batch_size], dtype=tf.int32) +
                       self.params["GO_SYMBOL"])

        cache = {}
        # Add encoder outputs and attention bias to the cache.
        cache["encoder_outputs"] = encoder_outputs
        cache["encoder_outputs_b"] = encoder_outputs_b
        if inputs_attention_bias is not None:
            cache["inputs_attention_bias"] = inputs_attention_bias

        # Use beam search to find the top beam_size sequences and scores.
        decoded_ids, scores = beam_search.sequence_beam_search(
            symbols_to_logits_fn=symbols_to_logits_fn,
            initial_ids=initial_ids,
            initial_cache=cache,
            vocab_size=self.params["tgt_vocab_size"],
            beam_size=self.params["beam_size"],
            alpha=self.params["alpha"],
            max_decode_length=max_decode_length,
            eos_id=self.params["EOS_ID"])

        # Get the top sequence for each batch element
        top_decoded_ids = decoded_ids[:, 0, :]
        top_scores = scores[:, 0]

        # Re-run the decoder over the selected sequence to recover logits;
        # this isn't particularly efficient.
        logits = self.decode_pass(top_decoded_ids, encoder_outputs,
                                  encoder_outputs_b, inputs_attention_bias)

        return {
            "logits": logits,
            "outputs": [top_decoded_ids],
            "final_state": None,
            "final_sequence_lengths": None
        }
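
For reference on the slicing used at the end of both examples: sequence_beam_search is assumed to return decoded_ids shaped [batch_size, beam_size, decoded_length] with the best beam first and scores shaped [batch_size, beam_size], so index 0 along the beam axis selects the best candidate; the first example additionally drops the leading initial ID. A small self-contained illustration with made-up values:

import tensorflow as tf

# Hypothetical beam-search output for one batch element and two beams.
decoded_ids = tf.constant([[[0, 5, 7, 2], [0, 5, 9, 2]]])  # [1, 2, 4]
scores = tf.constant([[-0.3, -0.9]])                       # [1, 2]

top_with_go = decoded_ids[:, 0, :]   # second example: keep the initial GO symbol
top_no_go = decoded_ids[:, 0, 1:]    # first example: strip the initial ID
top_scores = scores[:, 0]

print(top_with_go.numpy())  # [[0 5 7 2]]
print(top_no_go.numpy())    # [[5 7 2]]
print(top_scores.numpy())   # [-0.3]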