Example #1
0
    def predict(self, encoder_outputs, encoder_decoder_attention_bias,
                training):
        """Return predicted sequence.

        Runs beam search over the decoder, seeded with the encoder outputs,
        and returns the single best sequence and score per batch element.
        """
        batch_size = tf.shape(encoder_outputs)[0]
        input_length = tf.shape(encoder_outputs)[1]
        # Allow the decoded sequence to run a fixed margin past the input.
        max_decode_length = input_length + self.params["extra_decode_length"]
        encoder_decoder_attention_bias = tf.cast(
            encoder_decoder_attention_bias, self.params["dtype"])

        symbols_to_logits_fn = self._get_symbols_to_logits_fn(
            max_decode_length, training)

        # Every sequence starts from token id 0; this is what
        # symbols_to_logits_fn receives on the first step.
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)

        # Per-layer decoder self-attention cache. The k/v tensors start
        # empty along the time axis and grow as tokens are decoded.
        hidden_size = self.params["hidden_size"]
        dtype = self.params["dtype"]
        cache = {}
        for layer_idx in range(self.params["num_hidden_layers"]):
            cache["layer_%d" % layer_idx] = {
                "k": tf.zeros([batch_size, 0, hidden_size], dtype=dtype),
                "v": tf.zeros([batch_size, 0, hidden_size], dtype=dtype),
            }

        # Encoder results are consulted by every decoder step, so they ride
        # along in the cache as well.
        cache["encoder_outputs"] = encoder_outputs
        cache["encoder_decoder_attention_bias"] = (
            encoder_decoder_attention_bias)

        # Use beam search to find the top beam_size sequences and scores.
        decoded_ids, scores = beam_search.sequence_beam_search(
            symbols_to_logits_fn=symbols_to_logits_fn,
            initial_ids=initial_ids,
            initial_cache=cache,
            vocab_size=self.params["vocab_size"],
            beam_size=self.params["beam_size"],
            alpha=self.params["alpha"],
            max_decode_length=max_decode_length,
            eos_id=EOS_ID,
            dtype=self.params["dtype"])

        # Keep only the best beam for each batch element and drop the
        # initial id at position 0.
        best_ids = decoded_ids[:, 0, 1:]
        best_scores = scores[:, 0]

        return {"outputs": best_ids, "scores": best_scores}
Example #2
0
    def predict(self, encoder_outputs, encoder_decoder_attention_bias,
                training):
        """Return predicted sequences and scores for ALL beams.

        Unlike the usual top-beam-only variant, this returns every beam
        (outputs shaped [batch, beam_size, length]); the `[:, :, 1:]` slice
        below is intentional, per the original "get all beam" note.
        """
        encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
        if self.params["padded_decode"]:
            # Padded decode needs static shapes (presumably for TPU/XLA —
            # TODO confirm against caller).
            batch_size = encoder_outputs.shape.as_list()[0]
            input_length = encoder_outputs.shape.as_list()[1]
        else:
            batch_size = tf.shape(encoder_outputs)[0]
            input_length = tf.shape(encoder_outputs)[1]
        max_decode_length = input_length + self.params["extra_decode_length"]
        encoder_decoder_attention_bias = tf.cast(
            encoder_decoder_attention_bias, self.params["dtype"])

        symbols_to_logits_fn = self._get_symbols_to_logits_fn(
            max_decode_length, training)

        # Create initial set of IDs that will be passed into symbols_to_logits_fn.
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)

        # Create cache storing decoder attention values for each layer.
        # With padded decode the k/v buffers are pre-allocated to full
        # length; otherwise they start empty and grow per step.
        init_decode_length = (max_decode_length
                              if self.params["padded_decode"] else 0)
        num_heads = self.params["num_heads"]
        dim_per_head = self.params["hidden_size"] // num_heads
        # pylint: disable=g-complex-comprehension
        # BUG FIX: the cache was hard-coded to exactly "layer_0"/"layer_1",
        # silently breaking whenever num_hidden_layers != 2. Build one
        # entry per layer, keeping the explicit tensor names ('k0', 'v0',
        # ...) the hard-coded version introduced.
        cache = {
            "layer_%d" % layer: {
                "k":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self.params["dtype"],
                    name="k%d" % layer),
                "v":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self.params["dtype"],
                    name="v%d" % layer)
            }
            for layer in range(self.params["num_hidden_layers"])
        }
        # pylint: enable=g-complex-comprehension

        # Add encoder output and attention bias to the cache.
        cache["encoder_outputs"] = encoder_outputs
        cache[
            "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

        # Use beam search to find the top beam_size sequences and scores.
        decoded_ids, scores = beam_search.sequence_beam_search(
            symbols_to_logits_fn=symbols_to_logits_fn,
            initial_ids=initial_ids,
            initial_cache=cache,
            vocab_size=self.params["vocab_size"],
            beam_size=self.params["beam_size"],
            alpha=self.params["alpha"],
            max_decode_length=max_decode_length,
            eos_id=EOS_ID,
            padded_decode=self.params["padded_decode"],
            dtype=self.params["dtype"])

        # Return all beams (not just the top one); strip the initial id at
        # position 0 of each sequence.
        top_decoded_ids = decoded_ids[:, :, 1:]
        top_scores = scores[:, :]

        return {"outputs": top_decoded_ids, "scores": top_scores}