def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
  """Return predicted sequence."""
  batch_size = tf.shape(encoder_outputs)[0]
  input_length = tf.shape(encoder_outputs)[1]
  max_decode_length = input_length + self.params["extra_decode_length"]
  encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                           self.params["dtype"])

  symbols_to_logits_fn = self._get_symbols_to_logits_fn(
      max_decode_length, training)

  # Create initial set of IDs that will be passed into symbols_to_logits_fn.
  initial_ids = tf.zeros([batch_size], dtype=tf.int32)

  # Create cache storing decoder attention values for each layer.
  # pylint: disable=g-complex-comprehension
  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
                        dtype=self.params["dtype"]),
          "v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
                        dtype=self.params["dtype"])
      } for layer in range(self.params["num_hidden_layers"])
  }
  # pylint: enable=g-complex-comprehension

  # Add encoder output and attention bias to the cache.
  cache["encoder_outputs"] = encoder_outputs
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  # Use beam search to find the top beam_size sequences and scores.
  decoded_ids, scores = beam_search.sequence_beam_search(
      symbols_to_logits_fn=symbols_to_logits_fn,
      initial_ids=initial_ids,
      initial_cache=cache,
      vocab_size=self.params["vocab_size"],
      beam_size=self.params["beam_size"],
      alpha=self.params["alpha"],
      max_decode_length=max_decode_length,
      eos_id=EOS_ID,
      dtype=self.params["dtype"])

  # Get the top sequence for each batch element.
  top_decoded_ids = decoded_ids[:, 0, 1:]
  top_scores = scores[:, 0]

  return {"outputs": top_decoded_ids, "scores": top_scores}
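For reference, beam_search.sequence_beam_search drives decoding by repeatedly calling the closure returned by self._get_symbols_to_logits_fn with the ids decoded so far, the step index, and the cache built above. The toy closure below is only a hedged sketch of that (ids, i, cache) -> (logits, cache) contract; its name and its uniform logits are illustrative assumptions, not the model's real decoder step.

import tensorflow as tf


def make_toy_symbols_to_logits_fn(vocab_size):
  """Stand-in closure with the interface the beam search call above expects."""

  def symbols_to_logits_fn(ids, i, cache):
    # ids:   [batch * beam, i + 1] ids decoded so far (column 0 holds the
    #        zero-valued initial_ids created in predict()).
    # i:     scalar index of the current decode step.
    # cache: the dict built in predict(); a real decoder step would read the
    #        encoder entries and append this step's keys/values per layer.
    del i  # Unused in this toy version.
    batch = tf.shape(ids)[0]
    # A real implementation embeds ids[:, -1:], runs the decoder stack with
    # the cache, and projects to the vocabulary; here we emit uniform logits.
    logits = tf.zeros([batch, vocab_size])
    return logits, cache

  return symbols_to_logits_fn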
def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
  """Return predicted sequence."""
  encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
  if self.params["padded_decode"]:
    batch_size = encoder_outputs.shape.as_list()[0]
    input_length = encoder_outputs.shape.as_list()[1]
  else:
    batch_size = tf.shape(encoder_outputs)[0]
    input_length = tf.shape(encoder_outputs)[1]
  max_decode_length = input_length + self.params["extra_decode_length"]
  encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                           self.params["dtype"])

  symbols_to_logits_fn = self._get_symbols_to_logits_fn(
      max_decode_length, training)

  # Create initial set of IDs that will be passed into symbols_to_logits_fn.
  initial_ids = tf.zeros([batch_size], dtype=tf.int32)

  # Create cache storing decoder attention values for each layer.
  # pylint: disable=g-complex-comprehension
  init_decode_length = (
      max_decode_length if self.params["padded_decode"] else 0)
  num_heads = self.params["num_heads"]
  dim_per_head = self.params["hidden_size"] // num_heads

  ### old
  # cache = {
  #     "layer_%d" % layer: {
  #         "k": tf.zeros(
  #             [batch_size, init_decode_length, num_heads, dim_per_head],
  #             dtype=self.params["dtype"]),
  #         "v": tf.zeros(
  #             [batch_size, init_decode_length, num_heads, dim_per_head],
  #             dtype=self.params["dtype"])
  #     } for layer in range(self.params["num_hidden_layers"])
  # }

  cache = {
      "layer_0": {
          "k": tf.zeros(
              [batch_size, init_decode_length, num_heads, dim_per_head],
              dtype=self.params["dtype"], name="k0"),
          "v": tf.zeros(
              [batch_size, init_decode_length, num_heads, dim_per_head],
              dtype=self.params["dtype"], name="v0")
      },
      "layer_1": {
          "k": tf.zeros(
              [batch_size, init_decode_length, num_heads, dim_per_head],
              dtype=self.params["dtype"], name="k1"),
          "v": tf.zeros(
              [batch_size, init_decode_length, num_heads, dim_per_head],
              dtype=self.params["dtype"], name="v1")
      }
  }
  # pylint: enable=g-complex-comprehension

  # Add encoder output and attention bias to the cache.
  cache["encoder_outputs"] = encoder_outputs
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  # Use beam search to find the top beam_size sequences and scores.
  decoded_ids, scores = beam_search.sequence_beam_search(
      symbols_to_logits_fn=symbols_to_logits_fn,
      initial_ids=initial_ids,
      initial_cache=cache,
      vocab_size=self.params["vocab_size"],
      beam_size=self.params["beam_size"],
      alpha=self.params["alpha"],
      max_decode_length=max_decode_length,
      eos_id=EOS_ID,
      padded_decode=self.params["padded_decode"],
      dtype=self.params["dtype"])

  # Get the top sequence for each batch element.
  # top_decoded_ids = decoded_ids[:, 0, 1:]
  # top_scores = scores[:, 0]

  # Keep all beams instead of only the top one.
  top_decoded_ids = decoded_ids[:, :, 1:]
  top_scores = scores[:, :]

  return {"outputs": top_decoded_ids, "scores": top_scores}
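The main change in this second version is the padded_decode path: when it is enabled, batch and length are read from the static shape rather than tf.shape, and the cache is pre-allocated to max_decode_length instead of growing from length 0, so every tensor in the decode loop keeps a fully known shape (the mode typically targeted by padded decoding, e.g. TPU/XLA execution). The standalone sketch below only contrasts the two ways of reading a shape; the tensor and its dimensions are illustrative assumptions, not values from the source.

import tensorflow as tf

# Static vs. dynamic shape, as used in the padded_decode branch above.
encoder_outputs = tf.zeros([2, 7, 512])

static_batch = encoder_outputs.shape.as_list()[0]  # Python int: 2
dynamic_batch = tf.shape(encoder_outputs)[0]       # tf.Tensor resolved at run time

print(static_batch, dynamic_batch)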