def test_sequence_beam_search(self, padded_decode):
  # Shape: [batch_size * beam_size, max_decode_length, vocab_size].
  probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2],
                                [0.1, 0.8, 0.1]],
                               [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3],
                                [0.2, 0.1, 0.7]]])
  # Shape: [batch_size, max_decode_length, num_heads, embed_size per head].
  x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
  cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}

  def _get_test_symbols_to_logits_fn():
    """Test function that returns logits for the next token."""

    def symbols_to_logits_fn(_, i, cache):
      logits = tf.cast(probabilities[:, i, :], tf.float32)
      return logits, cache

    return symbols_to_logits_fn

  predictions, _ = beam_search.sequence_beam_search(
      symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
      initial_ids=tf.zeros([1], dtype=tf.int32),
      initial_cache=cache,
      vocab_size=3,
      beam_size=2,
      alpha=0.6,
      max_decode_length=3,
      eos_id=9,
      padded_decode=padded_decode,
      dtype=tf.float32)
  self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
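# --- Usage sketch (not from the original test file) ---
# A minimal, self-contained sketch of driving sequence_beam_search with a toy
# symbols_to_logits_fn outside a test harness. The import path is an
# assumption (the op lives under official/nlp/modeling/ops in TF Model
# Garden), and the "model" that always prefers token i + 1 is made up for
# illustration.
import tensorflow as tf
from official.nlp.modeling.ops import beam_search  # assumed import path


def _toy_symbols_to_logits_fn(ids, i, cache):
  # Strongly prefer token (i + 1) at decode step i for every live beam.
  vocab_size = 5
  n = tf.shape(ids)[0]  # batch_size * beam_size live sequences
  logits = tf.math.log(tf.one_hot(tf.fill([n], i + 1), vocab_size) + 1e-9)
  return logits, cache


decoded, scores = beam_search.sequence_beam_search(
    symbols_to_logits_fn=_toy_symbols_to_logits_fn,
    initial_ids=tf.zeros([2], dtype=tf.int32),
    initial_cache={},  # assumed OK here: the toy fn never reads the cache
    vocab_size=5,
    beam_size=2,
    alpha=0.6,
    max_decode_length=3,
    eos_id=4,
    padded_decode=False,
    dtype=tf.float32)
# decoded: [batch, beam, max_decode_length + 1]; column 0 is the initial id.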
def predict_decode(self, start_token_ids, cache):
  symbols_to_logits_fn = self._get_symbols_to_logits_fn(self.params.len_title)
  # Use beam search to find the top beam_size sequences and scores.
  decoded_ids, scores = beam_search.sequence_beam_search(
      symbols_to_logits_fn=symbols_to_logits_fn,
      initial_ids=start_token_ids,
      initial_cache=cache,
      vocab_size=self.params.vocab_size,
      beam_size=self.params.beam_size,
      alpha=self.params.alpha,
      max_decode_length=self.params.len_title,
      padded_decode=self.params.get("padded_decode", False),
      eos_id=self.params.end_token_id)
  return decoded_ids, scores
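# --- Usage sketch (not from the original file) ---
# Unlike the call()/predict() methods in this section, predict_decode returns
# every beam. A minimal, self-contained sketch of reducing its output to the
# best hypothesis per example; the ids and scores below are made up, but the
# shapes follow sequence_beam_search, whose beams are returned best-first.
import tensorflow as tf

# Fake beam-search output: batch_size=2, beam_size=2, decoded length 4,
# where column 0 holds the initial start-token id.
decoded_ids = tf.constant([[[0, 5, 6, 2], [0, 5, 7, 2]],
                           [[0, 9, 2, 0], [0, 9, 8, 2]]])
scores = tf.constant([[-0.4, -1.1], [-0.2, -0.9]])

# Take beam 0 (the highest-scoring one) and drop the initial id, mirroring
# the decoded_ids[:, 0, 1:] pattern used elsewhere in this section.
top_decoded_ids = decoded_ids[:, 0, 1:]
top_scores = scores[:, 0]
print(top_decoded_ids.numpy())  # [[5 6 2] [9 2 0]]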
def call(self, inputs):
  """Calculate target logits or inferred target sequences.

  Args:
    inputs: a dictionary of tensors.
      Feature `inputs` (optional): int tensor with shape
        `[batch_size, input_length]`.
      Feature `embedded_inputs` (optional): float tensor with shape
        `[batch_size, input_length, embedding_width]`.
      Feature `targets` (optional): None or int tensor with shape
        `[batch_size, target_length]`.
      Feature `input_masks` (optional): when providing `embedded_inputs`,
        the dictionary must provide a boolean mask marking the filled time
        steps. The shape of the tensor is `[batch_size, input_length]`.
      Either `inputs` or both `embedded_inputs` and `input_masks` must be
      present in the input dictionary. In the second case the projection of
      the integer tokens into the transformer embedding space is skipped,
      and `input_masks` is expected to be present.

  Returns:
    If targets is defined, then return logits for each word in the target
    sequence, which is a float tensor with shape
    `(batch_size, target_length, vocab_size)`. If targets is `None`, then
    generate the output sequence one token at a time and return a dictionary
      {
        outputs: `(batch_size, decoded_length)`,
        scores: `(batch_size, 1)`
      }.
    Even when `float16` is used, the output tensor(s) are always `float32`.

  Raises:
    NotImplementedError: If trying to use the padded decode method on
      CPUs/GPUs.
  """
  # Prepare inputs to the layer stack by adding positional encodings and
  # applying dropout.
  targets = inputs.get("targets", None)
  (embedded_inputs, boolean_mask, input_shape,
   source_dtype) = self._parse_inputs(inputs)
  embedding_mask = tf.cast(boolean_mask, embedded_inputs.dtype)
  embedded_inputs *= tf.expand_dims(embedding_mask, -1)

  # Attention mask generation.
  attention_mask = tf.cast(
      tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
      dtype=source_dtype)
  broadcast_ones = tf.ones(
      shape=[input_shape[0], input_shape[1], 1], dtype=source_dtype)
  attention_mask = broadcast_ones * attention_mask

  pos_encoding = self.position_embedding(embedded_inputs)
  pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
  encoder_inputs = embedded_inputs + pos_encoding
  encoder_inputs = self.encoder_dropout(encoder_inputs)

  encoder_outputs = self.encoder_layer(
      encoder_inputs, attention_mask=attention_mask)

  if targets is None:
    if self._padded_decode:
      max_decode_length = self._decode_max_length
    else:
      max_decode_length = self._decode_max_length or (
          tf.shape(encoder_outputs)[1] + self._extra_decode_length)
    symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

    batch_size = tf.shape(encoder_outputs)[0]
    # Create the initial set of IDs that will be passed to
    # symbols_to_logits_fn.
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    # Create a cache storing decoder attention values for each layer.
    init_decode_length = max_decode_length if self._padded_decode else 0
    num_heads = self.decoder_layer.num_attention_heads
    dim_per_head = self._embedding_width // num_heads

    # Cache dtype needs to match the beam_search dtype.
    # pylint: disable=g-complex-comprehension
    cache = {
        str(layer): {
            "key":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self.compute_dtype),
            "value":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self.compute_dtype)
        } for layer in range(self.decoder_layer.num_layers)
    }
    # pylint: enable=g-complex-comprehension

    # Add the encoder output and attention mask to the cache.
    encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
    attention_mask = tf.cast(
        tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
        dtype=self.compute_dtype)
    cache["encoder_outputs"] = encoder_outputs
    cache["encoder_decoder_attention_mask"] = attention_mask

    # Use beam search to find the top beam_size sequences and scores.
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=initial_ids,
        initial_cache=cache,
        vocab_size=self._vocab_size,
        beam_size=self._beam_size,
        alpha=self._alpha,
        max_decode_length=max_decode_length,
        eos_id=self._eos_id,
        padded_decode=self._padded_decode,
        dtype=self.compute_dtype)

    # Get the top sequence for each batch element.
    top_decoded_ids = decoded_ids[:, 0, 1:]
    top_scores = scores[:, 0]
    return {"outputs": top_decoded_ids, "scores": top_scores}

  # Shift targets to the right, and remove the last element.
  targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
  decoder_inputs = self.embedding_lookup(targets)
  length = tf.shape(decoder_inputs)[1]
  pos_encoding = self.position_embedding(decoder_inputs)
  pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
  decoder_inputs += pos_encoding
  decoder_inputs = self.decoder_dropout(decoder_inputs)

  decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
  batch_size = decoder_shape[0]
  decoder_length = decoder_shape[1]

  self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
  self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
  self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])

  attention_mask = tf.cast(
      tf.expand_dims(boolean_mask, axis=1), dtype=source_dtype)
  attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

  outputs = self.decoder_layer(
      decoder_inputs,
      encoder_outputs,
      self_attention_mask=self_attention_mask,
      cross_attention_mask=attention_mask)
  logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
  # Model outputs should be float32 to avoid numeric issues.
  # https://www.tensorflow.org/guide/mixed_precision#building_the_model
  logits = tf.cast(logits, tf.float32)
  return logits
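# --- Illustration (not from the original file) ---
# The teacher-forcing branch above builds its causal mask with
# tf.linalg.band_part. A self-contained sketch of the lower-triangular mask
# it produces (sizes are arbitrary); position j may only attend to
# positions <= j.
import tensorflow as tf

length, batch_size = 4, 2
self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
print(self_attention_mask.numpy())
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
# Broadcast to one mask per batch element, exactly as call() does above.
self_attention_mask = tf.tile(
    tf.reshape(self_attention_mask, [1, length, length]), [batch_size, 1, 1])
print(self_attention_mask.shape)  # (2, 4, 4)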
def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
  """Return predicted sequence."""
  encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
  if self.params["padded_decode"]:
    batch_size = encoder_outputs.shape.as_list()[0]
    input_length = encoder_outputs.shape.as_list()[1]
  else:
    batch_size = tf.shape(encoder_outputs)[0]
    input_length = tf.shape(encoder_outputs)[1]
  max_decode_length = input_length + self.params["extra_decode_length"]
  encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                           self.params["dtype"])

  symbols_to_logits_fn = self._get_symbols_to_logits_fn(
      max_decode_length, training)

  # Create the initial set of IDs that will be passed into
  # symbols_to_logits_fn.
  initial_ids = tf.zeros([batch_size], dtype=tf.int32)

  # Create a cache storing decoder attention values for each layer.
  # pylint: disable=g-complex-comprehension
  init_decode_length = (
      max_decode_length if self.params["padded_decode"] else 0)
  num_heads = self.params["num_heads"]
  dim_per_head = self.params["hidden_size"] // num_heads
  cache = {
      "layer_%d" % layer: {
          "k":
              tf.zeros(
                  [batch_size, init_decode_length, num_heads, dim_per_head],
                  dtype=self.params["dtype"]),
          "v":
              tf.zeros(
                  [batch_size, init_decode_length, num_heads, dim_per_head],
                  dtype=self.params["dtype"])
      } for layer in range(self.params["num_hidden_layers"])
  }
  # pylint: enable=g-complex-comprehension

  # Add the encoder output and attention bias to the cache.
  cache["encoder_outputs"] = encoder_outputs
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  # Use beam search to find the top beam_size sequences and scores.
  decoded_ids, scores = beam_search.sequence_beam_search(
      symbols_to_logits_fn=symbols_to_logits_fn,
      initial_ids=initial_ids,
      initial_cache=cache,
      vocab_size=self.params["vocab_size"],
      beam_size=self.params["beam_size"],
      alpha=self.params["alpha"],
      max_decode_length=max_decode_length,
      eos_id=EOS_ID,
      padded_decode=self.params["padded_decode"],
      dtype=self.params["dtype"])

  # Get the top sequence for each batch element.
  top_decoded_ids = decoded_ids[:, 0, 1:]
  top_scores = scores[:, 0]
  return {"outputs": top_decoded_ids, "scores": top_scores}
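# --- Illustration (not from the original file) ---
# predict() preallocates full-length key/value tensors only when
# padded_decode is on (e.g. TPU decoding, where every shape must be static);
# otherwise the cache starts empty and grows step by step. A self-contained
# sketch of that layout with made-up sizes:
import tensorflow as tf

batch_size, max_decode_length, num_heads, dim_per_head = 2, 8, 4, 16
padded_decode = True

init_decode_length = max_decode_length if padded_decode else 0
cache = {
    "layer_%d" % layer: {
        "k": tf.zeros([batch_size, init_decode_length, num_heads,
                       dim_per_head]),
        "v": tf.zeros([batch_size, init_decode_length, num_heads,
                       dim_per_head]),
    } for layer in range(2)
}
print(cache["layer_0"]["k"].shape)  # (2, 8, 4, 16); (2, 0, 4, 16) if False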
def call(self, inputs):
  """Calculate target logits or inferred target sequences.

  Args:
    inputs: input tensor list of size 1 or 2.
      First item, inputs: int tensor with shape [batch_size, input_length].
      Second item (optional), targets: None or int tensor with shape
        [batch_size, target_length].

  Returns:
    If targets is defined, then return logits for each word in the target
    sequence, which is a float tensor with shape
    [batch_size, target_length, vocab_size]. If targets is None, then
    generate the output sequence one token at a time and return a dictionary
      {
        outputs: [batch_size, decoded_length],
        scores: [batch_size, float]
      }.
    Even when float16 is used, the output tensor(s) are always float32.

  Raises:
    NotImplementedError: If trying to use the padded decode method on
      CPUs/GPUs.
  """
  if len(inputs) == 2:
    sources, targets = inputs[0], inputs[1]
  else:  # Decoding path.
    sources, targets = inputs[0], None
  attention_bias = model_utils.get_padding_bias(sources)
  attention_bias = tf.cast(attention_bias, self._dtype)

  # Prepare inputs to the layer stack by adding positional encodings and
  # applying dropout.
  embedded_inputs = self.embedding_lookup(sources)
  embedding_mask = tf.cast(
      tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
  embedded_inputs *= tf.expand_dims(embedding_mask, -1)
  embedded_inputs = tf.cast(embedded_inputs, self._dtype)

  # Attention mask generation.
  input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
  attention_mask = tf.cast(
      tf.reshape(
          tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
      dtype=sources.dtype)
  broadcast_ones = tf.ones(
      shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
  attention_mask = broadcast_ones * attention_mask

  pos_encoding = self.position_embedding(inputs=embedded_inputs)
  pos_encoding = tf.cast(pos_encoding, self._dtype)
  encoder_inputs = embedded_inputs + pos_encoding
  encoder_inputs = self.encoder_dropout(encoder_inputs)

  encoder_outputs = self.encoder_layer(
      encoder_inputs, attention_mask=attention_mask)

  if targets is None:
    encoder_decoder_attention_bias = attention_bias
    encoder_outputs = tf.cast(encoder_outputs, self._dtype)
    if self._padded_decode:
      batch_size = encoder_outputs.shape.as_list()[0]
      max_decode_length = self._decode_max_length
    else:
      batch_size = tf.shape(encoder_outputs)[0]
      max_decode_length = self._decode_max_length or (
          tf.shape(encoder_outputs)[1] + self._extra_decode_length)
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             self._dtype)
    symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

    # Create the initial set of IDs that will be passed to
    # symbols_to_logits_fn.
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    # Create a cache storing decoder attention values for each layer.
    # pylint: disable=g-complex-comprehension
    init_decode_length = max_decode_length if self._padded_decode else 0
    num_heads = self.decoder_layer.num_attention_heads
    dim_per_head = self._embedding_width // num_heads
    cache = {
        str(layer): {
            "key":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self._dtype),
            "value":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self._dtype)
        } for layer in range(self.decoder_layer.num_layers)
    }
    # pylint: enable=g-complex-comprehension

    # Add the encoder output and attention bias to the cache.
    cache["encoder_outputs"] = encoder_outputs
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    # Use beam search to find the top beam_size sequences and scores.
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=initial_ids,
        initial_cache=cache,
        vocab_size=self._vocab_size,
        beam_size=self._beam_size,
        alpha=self._alpha,
        max_decode_length=max_decode_length,
        eos_id=EOS_ID,
        padded_decode=self._padded_decode,
        dtype=self._dtype)

    # Get the top sequence for each batch element.
    top_decoded_ids = decoded_ids[:, 0, 1:]
    top_scores = scores[:, 0]
    return {"outputs": top_decoded_ids, "scores": top_scores}

  decoder_inputs = self.embedding_lookup(targets)
  embedding_mask = tf.cast(
      tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
  decoder_inputs *= tf.expand_dims(embedding_mask, -1)
  decoder_inputs = tf.cast(decoder_inputs, self._dtype)

  # Shift targets to the right, and remove the last element.
  decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
  length = tf.shape(decoder_inputs)[1]
  pos_encoding = self.position_embedding(decoder_inputs)
  pos_encoding = tf.cast(pos_encoding, self._dtype)
  decoder_inputs += pos_encoding
  decoder_inputs = self.decoder_dropout(decoder_inputs)

  decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
  batch_size = decoder_shape[0]
  decoder_length = decoder_shape[1]

  self_attention_mask = tf.linalg.band_part(
      tf.ones([length, length], dtype=tf.float32), -1, 0)
  self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
  self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])

  attention_mask = tf.cast(
      tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
  attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

  # Pass the causal mask as the self-attention mask and the source padding
  # mask as the cross-attention mask, matching the decoder call convention
  # used in the other call() implementation above.
  outputs = self.decoder_layer(
      decoder_inputs,
      encoder_outputs,
      self_attention_mask=self_attention_mask,
      cross_attention_mask=attention_mask)
  logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
  # Model outputs should be float32 to avoid numeric issues.
  logits = tf.cast(logits, tf.float32)
  return logits
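# --- Illustration (not from the original file) ---
# Both code paths in call() above derive masks from the convention that token
# id 0 marks padding. A self-contained sketch of the encoder attention_mask
# construction, with toy token ids:
import tensorflow as tf

sources = tf.constant([[7, 3, 0, 0],
                       [5, 2, 9, 0]])  # id 0 is padding
input_shape = sources.shape.as_list()

# [batch, 1, input_length] padding mask, broadcast against a column of ones
# into a full [batch, input_length, input_length] attention mask.
attention_mask = tf.cast(
    tf.reshape(tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
    dtype=sources.dtype)
broadcast_ones = tf.ones(
    shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
attention_mask = broadcast_ones * attention_mask
print(attention_mask[0].numpy())
# [[1 1 0 0]
#  [1 1 0 0]
#  [1 1 0 0]
#  [1 1 0 0]]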