def call(self, inputs: list, training: bool) -> Any:
  if len(inputs) == 3:
    inputs, memories, targets = inputs[0], inputs[1], inputs[2]
  else:  # Decoding path.
    inputs, memories, targets = inputs[0], inputs[1], None

  # Variance scaling is used here because it seems to work in many problems.
  # Other reasonable initializers may also work just as well.
  with tf.name_scope("MemoryTransformer"):
    # Calculate attention bias for encoder self-attention and decoder
    # multi-headed attention layers.
    encoder_attention_bias = model_utils.get_padding_bias(inputs)

    # Run the inputs through the encoder layer to map the symbol
    # representations to continuous representations.
    encoder_outputs = self.encode(inputs, encoder_attention_bias, training)
    reminder_outputs = self.remind(memories, encoder_outputs,
                                   encoder_attention_bias, training)
    reminder_attention_bias = tf.zeros([1, 1, 1, 1],
                                       dtype=self.params["dtype"])

    # Generate output sequence if targets is None, or return logits if target
    # sequence is known.
    if targets is None:
      return self.predict(reminder_outputs, reminder_attention_bias, training)
    else:
      logits = self.decode(targets, reminder_outputs, reminder_attention_bias,
                           training)
      return logits, reminder_outputs
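# A brief note on the reminder_attention_bias above: an all-zero [1, 1, 1, 1]
# tensor broadcasts over [batch, num_heads, target_len, source_len] attention
# logits without masking anything, i.e. the decoder may attend to every
# reminder position. A quick eager sketch of that broadcasting behavior
# (shapes are illustrative only, not taken from this model's config):
import tensorflow as tf

logits = tf.random.normal([2, 8, 5, 7])                # [batch, heads, tgt_len, src_len]
bias = tf.zeros([1, 1, 1, 1], dtype=logits.dtype)      # same construction as above
print(tf.reduce_all(logits + bias == logits).numpy())  # True -- an all-zero bias is a no-op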
def test_get_padding_bias(self):
  x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
  bias = model_utils.get_padding_bias(x)
  bias_shape = tf.shape(bias)
  flattened_bias = tf.reshape(bias, [3, 5])
  self.assertAllEqual(
      [[0, NEG_INF, NEG_INF, NEG_INF, 0],
       [0, 0, NEG_INF, NEG_INF, NEG_INF],
       [NEG_INF, 0, 0, NEG_INF, 0]],
      flattened_bias)
  self.assertAllEqual([3, 1, 1, 5], bias_shape)
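# The assertions above rely on a NEG_INF constant defined elsewhere in the
# test module; the padding bias uses a large negative value (commonly -1e9 in
# this codebase) so that padded positions receive ~0 probability after the
# softmax. A hedged standalone illustration, assuming that value:
import tensorflow as tf

NEG_INF = -1e9
scores = tf.constant([2.0, 1.0, 0.0]) + tf.constant([0.0, NEG_INF, 0.0])
print(tf.nn.softmax(scores).numpy())  # middle (padded) position gets ~0 probability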
def call(self, inputs, training):
  """Calculate target logits or inferred target sequences.

  Args:
    inputs: input tensor list of size 1 or 2.
      First item, inputs: int tensor with shape [batch_size, input_length].
      Second item (optional), targets: None or int tensor with shape
        [batch_size, target_length].
    training: boolean, whether in training mode or not.

  Returns:
    If targets is defined, then return logits for each word in the target
    sequence. float tensor with shape [batch_size, target_length, vocab_size]
    If targets is None, then generate output sequence one token at a time.
      returns a dictionary {
        outputs: [batch_size, decoded length]
        scores: [batch_size, float]}
    Even when float16 is used, the output tensor(s) are always float32.

  Raises:
    NotImplementedError: If padded decoding is attempted on CPU/GPUs.
  """
  inputs = inputs if isinstance(inputs, list) else [inputs]
  if len(inputs) == 2:
    inputs, targets = inputs[0], inputs[1]
  else:
    # Decoding path.
    inputs, targets = inputs[0], None
    if self.params["padded_decode"]:
      if not self.params["num_replicas"]:
        raise NotImplementedError(
            "Padded decoding on CPU/GPUs is not supported.")
      decode_batch_size = int(self.params["decode_batch_size"] /
                              self.params["num_replicas"])
      inputs.set_shape(
          [decode_batch_size, self.params["decode_max_length"]])

  # Variance scaling is used here because it seems to work in many problems.
  # Other reasonable initializers may also work just as well.
  with tf.name_scope("Transformer"):
    # Calculate attention bias for encoder self-attention and decoder
    # multi-headed attention layers.
    attention_bias = model_utils.get_padding_bias(inputs)

    # Run the inputs through the encoder layer to map the symbol
    # representations to continuous representations.
    encoder_outputs = self.encode(inputs, attention_bias, training)

    # Generate output sequence if targets is None, or return logits if target
    # sequence is known.
    if targets is None:
      return self.predict(encoder_outputs, attention_bias, training)
    else:
      logits = self.decode(targets, encoder_outputs, attention_bias, training)
      return logits
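# Hedged usage sketch for the Keras-style call() above. `transformer` stands
# for an already constructed model instance with its params dict; the token
# ids are illustrative placeholders, so the model calls are left commented.
import tensorflow as tf

src = tf.constant([[12, 45, 3, 1, 0, 0]])   # [batch_size, input_length], 0 = padding
tgt = tf.constant([[7, 9, 22, 1, 0, 0]])    # [batch_size, target_length]

# Training path: pass [inputs, targets] to get per-token logits of shape
# [batch_size, target_length, vocab_size].
# logits = transformer([src, tgt], training=True)

# Inference path: pass [inputs] (or the bare tensor) to run autoregressive
# decoding, which returns {"outputs": ..., "scores": ...}.
# predictions = transformer([src], training=False)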
def __call__(self, inputs, targets=None):
  """Calculate target logits or inferred target sequences.

  Args:
    inputs: int tensor with shape [batch_size, input_length].
    targets: None or int tensor with shape [batch_size, target_length].

  Returns:
    If targets is defined, then return logits for each word in the target
    sequence. float tensor with shape [batch_size, target_length, vocab_size]
    If targets is None, then generate output sequence one token at a time.
      returns a dictionary {
        output: [batch_size, decoded length]
        score: [batch_size, float]}
  """
  # Variance scaling is used here because it seems to work in many problems.
  # Other reasonable initializers may also work just as well.
  initializer = tf.variance_scaling_initializer(
      self.params["initializer_gain"], mode="fan_avg", distribution="uniform")
  with tf.variable_scope("Transformer", initializer=initializer):
    # Calculate attention bias for encoder self-attention and decoder
    # multi-headed attention layers.
    attention_bias = model_utils.get_padding_bias(inputs)

    # Run the inputs through the encoder layer to map the symbol
    # representations to continuous representations.
    encoder_outputs = self.encode(inputs, attention_bias)

    # Generate output sequence if targets is None, or return logits if target
    # sequence is known.
    if targets is None:
      return self.predict(encoder_outputs, attention_bias)
    else:
      logits = self.decode(targets, encoder_outputs, attention_bias)
      return logits
def get_attention_bias(input_tensor,
                       bias_type,
                       padding_value=0,
                       max_length=None):
  """A helper function to get various attention bias tensors."""
  if bias_type not in ("single_cross", "multi_cross", "decoder_self"):
    raise ValueError("Invalid attention bias type: %s" % bias_type)
  if bias_type == "single_cross":
    length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
    bias = transformer_utils.get_padding_bias(input_tensor,
                                              padding_value=padding_value)
  elif bias_type == "multi_cross":
    length = tf_utils.get_shape_list(input_tensor, expected_rank=3)[2]
    padding = transformer_utils.get_padding(input_tensor,
                                            padding_value=padding_value)
    bias = padding * -1e9
  else:
    if max_length is not None:
      length = max_length
    else:
      length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
    bias = transformer_utils.get_decoder_self_attention_bias(length)

  return tf.where(bias < 0, tf.zeros_like(bias), tf.ones_like(bias))
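# A sketch of how the helper above might be called, assuming tf_utils and
# transformer_utils from this repo are importable; the token ids and shapes
# are illustrative only.
import tensorflow as tf

token_ids = tf.constant([[5, 9, 2, 0, 0]])             # [batch, seq_len], 0 = padding
multi_doc_ids = tf.constant([[[5, 9, 0], [4, 0, 0]]])  # [batch, num_docs, doc_len]

# Cross-attention mask over a single source sequence (rank-2 input).
single_mask = get_attention_bias(token_ids, "single_cross")

# Cross-attention mask over multiple source documents (rank-3 input).
multi_mask = get_attention_bias(multi_doc_ids, "multi_cross")

# Causal (lower-triangular) decoder self-attention mask; the input is only
# used for its length unless max_length is given.
causal_mask = get_attention_bias(token_ids, "decoder_self", max_length=5)

# Note that the final tf.where converts the additive bias into a 0/1 mask:
# allowed positions become 1 and masked (large-negative) positions become 0.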
def call(self, inputs):
  """Calculate target logits or inferred target sequences.

  Args:
    inputs: input tensor list of size 1 or 2.
      First item, inputs: int tensor with shape [batch_size, input_length].
      Second item (optional), targets: None or int tensor with shape
        [batch_size, target_length].

  Returns:
    If targets is defined, then return logits for each word in the target
    sequence. float tensor with shape [batch_size, target_length, vocab_size]
    If targets is None, then generate output sequence one token at a time.
      returns a dictionary {
        outputs: [batch_size, decoded length]
        scores: [batch_size, float]}
    Even when float16 is used, the output tensor(s) are always float32.

  Raises:
    NotImplementedError: If padded decoding is attempted on CPU/GPUs.
  """
  if len(inputs) == 2:
    sources, targets = inputs[0], inputs[1]
  else:  # Decoding path.
    sources, targets = inputs[0], None

  attention_bias = model_utils.get_padding_bias(sources)
  attention_bias = tf.cast(attention_bias, self._dtype)

  # Prepare inputs to the layer stack by adding positional encodings and
  # applying dropout.
  embedded_inputs = self.embedding_lookup(sources)
  embedding_mask = tf.cast(
      tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
  embedded_inputs *= tf.expand_dims(embedding_mask, -1)
  embedded_inputs = tf.cast(embedded_inputs, self._dtype)

  # Attention_mask generation.
  input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
  attention_mask = tf.cast(
      tf.reshape(
          tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
      dtype=sources.dtype)
  broadcast_ones = tf.ones(
      shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
  attention_mask = broadcast_ones * attention_mask

  pos_encoding = self.position_embedding(inputs=embedded_inputs)
  pos_encoding = tf.cast(pos_encoding, self._dtype)
  encoder_inputs = embedded_inputs + pos_encoding

  encoder_inputs = self.encoder_dropout(encoder_inputs)

  encoder_outputs = self.encoder_layer(
      encoder_inputs, attention_mask=attention_mask)

  if targets is None:
    encoder_decoder_attention_bias = attention_bias
    encoder_outputs = tf.cast(encoder_outputs, self._dtype)
    if self._padded_decode:
      batch_size = encoder_outputs.shape.as_list()[0]
      max_decode_length = self._decode_max_length
    else:
      batch_size = tf.shape(encoder_outputs)[0]
      max_decode_length = self._decode_max_length or (
          tf.shape(encoder_outputs)[1] + self._extra_decode_length)
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             self._dtype)

    symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

    # Create initial set of IDs that will be passed to symbols_to_logits_fn.
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    # Create cache storing decoder attention values for each layer.
    # pylint: disable=g-complex-comprehension
    init_decode_length = (max_decode_length if self._padded_decode else 0)
    num_heads = self.decoder_layer.num_attention_heads
    dim_per_head = self._embedding_width // num_heads
    cache = {
        str(layer): {
            "key":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self._dtype),
            "value":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self._dtype)
        } for layer in range(self.decoder_layer.num_layers)
    }
    # pylint: enable=g-complex-comprehension

    # Add encoder output and attention bias to the cache.
    cache["encoder_outputs"] = encoder_outputs
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    # Use beam search to find the top beam_size sequences and scores.
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=initial_ids,
        initial_cache=cache,
        vocab_size=self._vocab_size,
        beam_size=self._beam_size,
        alpha=self._alpha,
        max_decode_length=max_decode_length,
        eos_id=EOS_ID,
        padded_decode=self._padded_decode,
        dtype=self._dtype)

    # Get the top sequence for each batch element.
    top_decoded_ids = decoded_ids[:, 0, 1:]
    top_scores = scores[:, 0]

    return {"outputs": top_decoded_ids, "scores": top_scores}

  decoder_inputs = self.embedding_lookup(targets)
  embedding_mask = tf.cast(
      tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
  decoder_inputs *= tf.expand_dims(embedding_mask, -1)
  decoder_inputs = tf.cast(decoder_inputs, self._dtype)

  # Shift targets to the right, and remove the last element.
  decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
  length = tf.shape(decoder_inputs)[1]
  pos_encoding = self.position_embedding(decoder_inputs)
  pos_encoding = tf.cast(pos_encoding, self._dtype)
  decoder_inputs += pos_encoding

  decoder_inputs = self.decoder_dropout(decoder_inputs)

  decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
  batch_size = decoder_shape[0]
  decoder_length = decoder_shape[1]

  self_attention_mask = tf.linalg.band_part(
      tf.ones([length, length], dtype=tf.float32), -1, 0)
  self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
  self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])

  attention_mask = tf.cast(
      tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
  attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

  outputs = self.decoder_layer(
      decoder_inputs,
      encoder_outputs,
      memory_mask=self_attention_mask,
      target_mask=attention_mask)
  logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
  logits = tf.cast(logits, tf.float32)
  return logits
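# Small eager demonstration of the decoder-input shift used above: padding one
# step on the left of the time axis and dropping the last step turns the gold
# targets into teacher-forcing inputs, so position i only sees targets < i.
# Shapes are illustrative, not taken from the model above.
import tensorflow as tf

targets = tf.constant([[[1.], [2.], [3.]]])  # [batch=1, length=3, hidden=1]
shifted = tf.pad(targets, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
print(shifted[0, :, 0].numpy())  # [0. 1. 2.] -- targets shifted right by one step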