def _get_symbols_to_logits_fn(self, max_decode_length):
  """Returns a decoding function that calculates logits of the next tokens."""
  timing_signal = positional_encoding(max_decode_length + 1,
                                      self.params.hidden_size)
  decoder_self_attention_bias = get_target_mask(max_decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Generate logits for next potential IDs.

    Args:
      ids: Current decoded sequences. int tensor with shape
        [batch_size * beam_size, i + 1]
      i: Loop index.
      cache: dictionary of values storing the encoder output, encoder-decoder
        attention bias, and previous decoder attention values.

    Returns:
      Tuple of (logits with shape [batch_size * beam_size, vocab_size],
      updated cache values)
    """
    # Set decoder input to the last generated IDs.
    decoder_input = ids[:, -1:]

    # Preprocess decoder input by getting embeddings and adding timing signal.
    decoder_input = self.embedding_softmax_layer(decoder_input)
    decoder_input += timing_signal[i:i + 1]

    self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
    decoder_outputs = self.decoder_stack(
        decoder_input,
        features=cache.get("encoder_outputs"),
        target_mask=self_attention_bias,
        input_mask=cache.get("encoder_decoder_attention_bias"),
        cache=cache)
    logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
    logits = tf.squeeze(logits, axis=[1])
    return logits, cache

  return symbols_to_logits_fn
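
# A minimal sketch (not part of the original model) of how the closure returned
# above could drive a plain greedy decoding loop. The start-token convention
# (all zeros), the method name, and the exact cache layout are assumptions; the
# real model also stores per-layer key/value caches and typically hands the
# closure to a beam-search routine instead.
def _greedy_decode_sketch(self, encoder_outputs, encoder_decoder_attention_bias,
                          max_decode_length):
  """Hypothetical greedy decoder built on symbols_to_logits_fn."""
  batch_size = tf.shape(encoder_outputs)[0]
  symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

  # Cache keys assumed from the cache.get(...) calls in symbols_to_logits_fn.
  cache = {
      "encoder_outputs": encoder_outputs,
      "encoder_decoder_attention_bias": encoder_decoder_attention_bias,
  }

  # Start from an all-zeros (start-of-sequence) token and greedily append the
  # argmax token at each step.
  ids = tf.zeros([batch_size, 1], dtype=tf.int32)
  for i in range(max_decode_length):
    logits, cache = symbols_to_logits_fn(ids, i, cache)
    next_id = tf.argmax(logits, axis=-1, output_type=tf.int32)[:, tf.newaxis]
    ids = tf.concat([ids, next_id], axis=1)
  return ids[:, 1:]  # Drop the start token.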
def decode(self, targets, encoder_outputs, attention_bias):
  """Generate logits for each value in the target sequence.

  Args:
    targets: target values for the output sequence. int tensor with shape
      [batch_size, target_length]
    encoder_outputs: continuous representation of input sequence. float tensor
      with shape [batch_size, input_length, hidden_size]
    attention_bias: float tensor with shape [batch_size, 1, 1, input_length]

  Returns:
    float32 tensor with shape [batch_size, target_length, vocab_size]
  """
  with tf.name_scope("decode"):
    # Prepare inputs to decoder layers by shifting targets, adding positional
    # encoding and applying dropout.
    decoder_inputs = self.embedding_softmax_layer(targets)
    with tf.name_scope("shift_targets"):
      # Shift targets to the right, and remove the last element.
      decoder_inputs = tf.pad(
          decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    with tf.name_scope("add_pos_encoding"):
      length = tf.shape(decoder_inputs)[1]
      decoder_inputs += positional_encoding(length, self.params.hidden_size)
    if self.is_train:
      decoder_inputs = tf.nn.dropout(
          decoder_inputs, rate=1 - self.params.keep_prob)

    # Run values through the decoder stack.
    decoder_self_attention_bias = get_target_mask(length)
    outputs = self.decoder_stack(
        decoder_inputs,
        features=encoder_outputs,
        input_mask=attention_bias,
        target_mask=decoder_self_attention_bias,
    )
    logits = self.embedding_softmax_layer(outputs, mode="linear")
    return logits
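
# Sketches (assumptions, not necessarily this repository's actual helpers) of
# the two utilities both methods rely on: positional_encoding produces the
# standard sinusoidal timing signal of shape [length, hidden_size], and
# get_target_mask builds the causal additive bias of shape
# [1, 1, length, length] that blocks attention to future positions. The exact
# implementations in this codebase may differ in details such as timescales.
import math

import tensorflow as tf


def positional_encoding(length, hidden_size, min_timescale=1.0,
                        max_timescale=1.0e4):
  """Sinusoidal position signal: sin channels followed by cos channels."""
  position = tf.cast(tf.range(length), tf.float32)
  num_timescales = hidden_size // 2
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      max(num_timescales - 1, 1))
  inv_timescales = min_timescale * tf.exp(
      tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
  scaled_time = position[:, tf.newaxis] * inv_timescales[tf.newaxis, :]
  # Shape [length, hidden_size]; broadcasts over the batch dimension.
  return tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)


def get_target_mask(length):
  """Lower-triangular causal mask expressed as a large negative bias."""
  valid = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
  bias = -1e9 * (1.0 - valid)
  return bias[tf.newaxis, tf.newaxis, :, :]  # [1, 1, length, length]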