def _get_symbols_to_logits_fn(self, max_decode_length, training):
  """Returns a decoding function that calculates logits of the next tokens."""
  pos_layer = position_embedding.RelativePositionEmbedding(
      hidden_size=self.params["hidden_size"], length=max_decode_length + 1)
  timing_signal = pos_layer(None)
  timing_signal = tf.cast(timing_signal, self.params["dtype"])
  decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
      max_decode_length, dtype=self.params["dtype"])

  # TODO(b/139770046): Refactor code with better naming of i.
  def symbols_to_logits_fn(ids, i, cache):
    """Generate logits for next potential IDs.

    Args:
      ids: Current decoded sequences. int tensor with shape
        [batch_size * beam_size, i + 1].
      i: Loop index.
      cache: dictionary of values storing the encoder output, encoder-decoder
        attention bias, and previous decoder attention values.

    Returns:
      Tuple of (logits with shape [batch_size * beam_size, vocab_size],
        updated cache values)
    """
    # Set decoder input to the last generated IDs.
    decoder_input = ids[:, -1:]

    # Preprocess decoder input by getting embeddings and adding timing signal.
    decoder_input = self.embedding_softmax_layer(decoder_input)

    if self.params["padded_decode"]:
      timing_signal_shape = timing_signal.shape.as_list()
      decoder_input += tf.slice(timing_signal, [i, 0],
                                [1, timing_signal_shape[1]])

      bias_shape = decoder_self_attention_bias.shape.as_list()
      self_attention_bias = tf.slice(
          decoder_self_attention_bias, [0, 0, i, 0],
          [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
    else:
      decoder_input += timing_signal[i:i + 1]
      self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    decoder_outputs = self.decoder_stack(
        decoder_input,
        cache.get("encoder_outputs"),
        self_attention_bias,
        cache.get("encoder_decoder_attention_bias"),
        training=training,
        cache=cache,
        decode_loop_step=i if self.params["padded_decode"] else None)
    logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
    logits = tf.squeeze(logits, axis=[1])
    return logits, cache

  return symbols_to_logits_fn

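# The closure above follows the (ids, i, cache) -> (logits, cache) contract
# expected by autoregressive decoding loops. Below is a minimal sketch of how
# a greedy caller might drive such a function; the uniform-logits stand-in is
# hypothetical and only illustrates the calling convention, not the model.
import tensorflow as tf

def _toy_symbols_to_logits_fn(ids, i, cache):
  # Stand-in for symbols_to_logits_fn: returns uniform logits over a
  # 7-symbol vocabulary, ignoring the decoded prefix and the cache.
  del i  # Loop index unused by this toy function.
  batch = tf.shape(ids)[0]
  return tf.zeros([batch, 7]), cache

ids = tf.zeros([2, 1], dtype=tf.int32)  # [batch, 1]: start symbols.
cache = {}
for step in range(4):
  logits, cache = _toy_symbols_to_logits_fn(ids, step, cache)
  next_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
  ids = tf.concat([ids, next_ids[:, None]], axis=1)  # Append greedily.
print(ids.shape)  # (2, 5): start symbol plus 4 decoded steps.
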
def decode(self, targets: tf.Tensor, encoder_outputs: tf.Tensor,
           attention_bias: tf.Tensor, training: bool) -> tf.Tensor:
  """Generates logits for each value in the target sequence."""
  with tf.name_scope("decode"):
    # Prepare inputs to decoder layers by adding positional encoding and
    # applying dropout.
    decoder_inputs = self.embedding_softmax_layer(targets)
    decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
    attention_bias = tf.cast(attention_bias, self.params["dtype"])
    with tf.name_scope("add_pos_encoding"):
      length = tf.shape(decoder_inputs)[1]
      pos_encoding = self.position_embedding(decoder_inputs)
      pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
      decoder_inputs += pos_encoding
    if training:
      decoder_inputs = tf.nn.dropout(
          decoder_inputs, rate=self.params["layer_postprocess_dropout"])

    # Run the inputs through the decoder stack.
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
        length, dtype=self.params["dtype"])
    outputs = self.decoder_stack(
        decoder_inputs,
        encoder_outputs,
        decoder_self_attention_bias,
        attention_bias,
        training=training)
    logits = self.embedding_softmax_layer(outputs, mode="linear")
    logits = tf.cast(logits, tf.float32)
    return logits

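# decode() above computes in params["dtype"] (possibly float16 under mixed
# precision) but casts the final logits back to float32. A minimal sketch of
# why, assuming a float16 compute dtype; the toy values are illustrative only.
import tensorflow as tf

logits16 = tf.constant([[20.0, -20.0]], dtype=tf.float16)
# Softmax and cross-entropy are numerically safer in float32: exp() saturates
# quickly in float16, so losses are computed after an explicit upcast.
probs = tf.nn.softmax(tf.cast(logits16, tf.float32))
print(probs.dtype)  # float32
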
def test_get_decoder_self_attention_bias(self):
  length = 5
  bias = model_utils.get_decoder_self_attention_bias(length)
  self.assertAllEqual(
      [[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
         [0, 0, NEG_INF, NEG_INF, NEG_INF],
         [0, 0, 0, NEG_INF, NEG_INF],
         [0, 0, 0, 0, NEG_INF],
         [0, 0, 0, 0, 0]]]], bias)

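# A sketch of the value get_decoder_self_attention_bias produces, reconstructed
# in NumPy from the expected tensor in the test above (NEG_INF is assumed to be
# a large negative constant such as -1e9):
import numpy as np

NEG_INF = -1e9
length = 5
# A lower-triangular matrix of ones marks the positions each step may attend to.
valid = np.tril(np.ones((length, length)))
# Zero bias on valid positions, NEG_INF on future positions; the leading
# [1, 1, ...] axes broadcast over batch and heads.
bias = (1.0 - valid).reshape(1, 1, length, length) * NEG_INF
print(bias[0, 0])  # Matches the matrix asserted in the test.
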
def decode(self, targets, encoder_outputs, attention_bias, training):
  """Generate logits for each value in the target sequence.

  Args:
    targets: target values for the output sequence. int tensor with shape
      [batch_size, target_length]
    encoder_outputs: continuous representation of input sequence. float
      tensor with shape [batch_size, input_length, hidden_size]
    attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
    training: boolean, whether in training mode or not.

  Returns:
    float32 tensor with shape [batch_size, target_length, vocab_size]
  """
  with tf.name_scope("decode"):
    # Prepare inputs to decoder layers by shifting targets, adding positional
    # encoding and applying dropout.
    decoder_inputs = self.embedding_softmax_layer(targets)
    decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
    attention_bias = tf.cast(attention_bias, self.params["dtype"])
    with tf.name_scope("shift_targets"):
      # Shift targets to the right, and remove the last element.
      decoder_inputs = tf.pad(decoder_inputs,
                              [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    with tf.name_scope("add_pos_encoding"):
      length = tf.shape(decoder_inputs)[1]
      pos_layer = position_embedding.RelativePositionEmbedding(
          hidden_size=self.params["hidden_size"])
      pos_encoding = pos_layer(decoder_inputs)
      pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
      decoder_inputs += pos_encoding
    if training:
      decoder_inputs = tf.nn.dropout(
          decoder_inputs, rate=self.params["layer_postprocess_dropout"])

    # Run the inputs through the decoder stack.
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
        length, dtype=self.params["dtype"])
    outputs = self.decoder_stack(
        decoder_inputs,
        encoder_outputs,
        decoder_self_attention_bias,
        attention_bias,
        training=training)
    logits = self.embedding_softmax_layer(outputs, mode="linear")
    logits = tf.cast(logits, tf.float32)
    return logits

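# A small sketch of the "shift_targets" step above: pad one position of zeros
# at the front of the time axis and drop the last position, so the decoder
# predicts token t from tokens < t (teacher forcing). Toy values only:
import tensorflow as tf

targets = tf.constant([[[1.], [2.], [3.]]])  # [batch=1, length=3, hidden=1]
shifted = tf.pad(targets, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
print(shifted.numpy().squeeze())  # [0. 1. 2.]: inputs trail targets by one step.
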
def _get_symbols_to_logits_fn(self, max_decode_length):
  """Returns a decoding function that calculates logits of the next tokens."""
  timing_signal = model_utils.get_position_encoding(
      max_decode_length + 1, self.params["hidden_size"])
  decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
      max_decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Generate logits for next potential IDs.

    Args:
      ids: Current decoded sequences. int tensor with shape
        [batch_size * beam_size, i + 1]
      i: Loop index
      cache: dictionary of values storing the encoder output, encoder-decoder
        attention bias, and previous decoder attention values.

    Returns:
      Tuple of (logits with shape [batch_size * beam_size, vocab_size],
        updated cache values)
    """
    # Set decoder input to the last generated IDs.
    decoder_input = ids[:, -1:]

    # Preprocess decoder input by getting embeddings and adding timing signal.
    decoder_input = self.embedding_softmax_layer(decoder_input)
    decoder_input += timing_signal[i:i + 1]

    self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    decoder_outputs = self.decoder_stack(
        decoder_input, cache.get("encoder_outputs"), self_attention_bias,
        cache.get("encoder_decoder_attention_bias"), cache)
    logits = self.embedding_softmax_layer.linear(decoder_outputs)
    logits = tf.squeeze(logits, axis=[1])
    return logits, cache

  return symbols_to_logits_fn

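# At step i the incremental decode slices a single row of the causal bias:
# decoder_self_attention_bias[:, :, i:i + 1, :i + 1]. A NumPy sketch showing
# that this row is all zeros, i.e. the new token may attend to every token
# decoded so far (NEG_INF assumed to be -1e9):
import numpy as np

NEG_INF = -1e9
length, i = 5, 2
full_bias = (1.0 - np.tril(np.ones((length, length)))).reshape(
    1, 1, length, length) * NEG_INF
print(full_bias[:, :, i:i + 1, :i + 1])  # [[[[0. 0. 0.]]]]: 0..i all visible.
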
def decode(self, targets, encoder_outputs, attention_bias):
  """Generate logits for each value in the target sequence.

  Args:
    targets: target values for the output sequence. int tensor with shape
      [batch_size, target_length]
    encoder_outputs: continuous representation of input sequence. float
      tensor with shape [batch_size, input_length, hidden_size]
    attention_bias: float tensor with shape [batch_size, 1, 1, input_length]

  Returns:
    float32 tensor with shape [batch_size, target_length, vocab_size]
  """
  with tf.name_scope("decode"):
    # Prepare inputs to decoder layers by shifting targets, adding positional
    # encoding and applying dropout.
    decoder_inputs = self.embedding_softmax_layer(targets)
    with tf.name_scope("shift_targets"):
      # Shift targets to the right, and remove the last element.
      decoder_inputs = tf.pad(decoder_inputs,
                              [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    with tf.name_scope("add_pos_encoding"):
      length = tf.shape(decoder_inputs)[1]
      decoder_inputs += model_utils.get_position_encoding(
          length, self.params["hidden_size"])
    if self.train:
      # TF1-style dropout: the second positional argument is keep_prob,
      # hence 1 - dropout_rate.
      decoder_inputs = tf.nn.dropout(
          decoder_inputs, 1 - self.params["layer_postprocess_dropout"])

    # Run the inputs through the decoder stack.
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
        length)
    outputs = self.decoder_stack(decoder_inputs, encoder_outputs,
                                 decoder_self_attention_bias, attention_bias)
    logits = self.embedding_softmax_layer.linear(outputs)
    return logits

def get_attention_bias(input_tensor,
                       bias_type,
                       padding_value=0,
                       max_length=None):
  """A helper function to get various attention bias tensors."""
  if bias_type not in ("single_cross", "multi_cross", "decoder_self"):
    raise ValueError("Invalid attention bias type: %s" % bias_type)
  if bias_type == "single_cross":
    length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
    bias = transformer_utils.get_padding_bias(
        input_tensor, padding_value=padding_value)
  elif bias_type == "multi_cross":
    length = tf_utils.get_shape_list(input_tensor, expected_rank=3)[2]
    padding = transformer_utils.get_padding(
        input_tensor, padding_value=padding_value)
    bias = padding * -1e9
  else:
    if max_length is not None:
      length = max_length
    else:
      length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
    bias = transformer_utils.get_decoder_self_attention_bias(length)

  return tf.where(bias < 0, tf.zeros_like(bias), tf.ones_like(bias))

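# The final tf.where above converts an additive bias (0 for visible positions,
# a large negative value for masked ones) into a multiplicative 0/1 mask, the
# convention the mask-based layers in this file expect. A minimal sketch:
import tensorflow as tf

additive_bias = tf.constant([0.0, -1e9, 0.0])
mask = tf.where(additive_bias < 0,
                tf.zeros_like(additive_bias),
                tf.ones_like(additive_bias))
print(mask.numpy())  # [1. 0. 1.]: 1 = attend, 0 = masked.
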
def _get_symbols_to_logits_fn(self, max_decode_length):
  """Returns a decoding function that calculates logits of the next tokens."""
  timing_signal = self.position_embedding(
      inputs=None, length=max_decode_length + 1)
  timing_signal = tf.cast(timing_signal, self._dtype)
  decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
      max_decode_length, dtype=self._dtype)

  def symbols_to_logits_fn(ids, i, cache):
    """Generate logits for next potential IDs.

    Args:
      ids: Current decoded sequences. int tensor with shape
        [batch_size * beam_size, i + 1].
      i: Loop index.
      cache: dictionary of values storing the encoder output, encoder-decoder
        attention bias, and previous decoder attention values.

    Returns:
      Tuple of (logits with shape [batch_size * beam_size, vocab_size],
        updated cache values)
    """
    # Set decoder input to the last generated IDs.
    decoder_input = ids[:, -1:]

    # Preprocess decoder input by getting embeddings and adding timing signal.
    # decoder_input = self.embedding_softmax_layer(decoder_input)
    source_decoder_input = decoder_input
    decoder_input = self.embedding_lookup(decoder_input)
    embedding_mask = tf.cast(
        tf.not_equal(source_decoder_input, 0),
        self.embedding_lookup.embeddings.dtype)
    decoder_input *= tf.expand_dims(embedding_mask, -1)

    if self._padded_decode:
      timing_signal_shape = timing_signal.shape.as_list()
      decoder_input += tf.slice(timing_signal, [i, 0],
                                [1, timing_signal_shape[1]])

      bias_shape = decoder_self_attention_bias.shape.as_list()
      self_attention_bias = tf.slice(
          decoder_self_attention_bias, [0, 0, i, 0],
          [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
    else:
      decoder_input += timing_signal[i:i + 1]
      self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
    batch_size = decoder_shape[0]
    decoder_length = decoder_shape[1]

    attention_bias = cache.get("encoder_decoder_attention_bias")
    attention_bias = tf.where(attention_bias < 0,
                              tf.zeros_like(attention_bias),
                              tf.ones_like(attention_bias))
    attention_bias = tf.squeeze(attention_bias, axis=[1])
    attention_mask = tf.tile(attention_bias, [1, decoder_length, 1])

    self_attention_bias = tf.where(self_attention_bias < 0,
                                   tf.zeros_like(self_attention_bias),
                                   tf.ones_like(self_attention_bias))
    self_attention_bias = tf.squeeze(self_attention_bias, axis=[1])
    self_attention_mask = tf.tile(self_attention_bias, [batch_size, 1, 1])

    decoder_outputs = self.decoder_layer(
        decoder_input,
        cache.get("encoder_outputs"),
        memory_mask=self_attention_mask,
        target_mask=attention_mask,
        cache=cache,
        decode_loop_step=i if self._padded_decode else None)

    logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                    decoder_outputs)
    logits = tf.squeeze(logits, axis=[1])
    return logits, cache

  return symbols_to_logits_fn

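# The embedding_mask logic above zeroes out the embeddings of padded positions
# (token id 0) so that padding contributes nothing downstream. A toy sketch of
# that step in isolation, with made-up values:
import tensorflow as tf

ids = tf.constant([[4, 0]])                      # [batch=1, length=2]; 0 = pad.
embeddings = tf.ones([1, 2, 3])                  # Stand-in embedding lookup.
mask = tf.cast(tf.not_equal(ids, 0), embeddings.dtype)
masked = embeddings * tf.expand_dims(mask, -1)   # Padded row becomes zeros.
print(masked.numpy()[0])  # [[1. 1. 1.], [0. 0. 0.]]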