def _get_symbols_to_logits_fn(self, max_decode_length, training):
  """Builds the per-step function that maps decoded ids to next-token logits.

  The position encoding (for `max_decode_length + 1` positions, cast to
  `params["dtype"]`) and the decoder self-attention bias (for
  `max_decode_length`) are computed once here and captured by the returned
  closure.

  Args:
    max_decode_length: Decode length used to size the position encoding and
      the decoder self-attention bias.
    training: Forwarded to the decoder stack as its `training` argument.

  Returns:
    A function `symbols_to_logits_fn(ids, i, cache)` returning a
    `(logits, cache)` tuple.
  """
  position_signal = model_utils.get_position_encoding(
      max_decode_length + 1, self.params["hidden_size"])
  position_signal = tf.cast(position_signal, self.params["dtype"])
  full_self_attention_bias = model_utils.get_decoder_self_attention_bias(
      max_decode_length, dtype=self.params["dtype"])

  # TODO(b/139770046): Refactor code with better naming of i.
  def symbols_to_logits_fn(ids, i, cache):
    """Computes logits for the next candidate IDs at loop index `i`.

    Args:
      ids: Current decoded sequences. int tensor with shape
        [batch_size * beam_size, i + 1].
      i: Loop index.
      cache: dictionary of values storing the encoder output, encoder-decoder
        attention bias, and previous decoder attention values.

    Returns:
      Tuple of
        (logits with shape [batch_size * beam_size, vocab_size],
         updated cache values)
    """
    # Only the most recently generated ID of each sequence is fed in.
    newest_ids = ids[:, -1:]
    step_input = self.embedding_softmax_layer(newest_ids)

    padded_decode = self.params["padded_decode"]
    if not padded_decode:
      # Plain Python slicing of the position signal and the bias at step i.
      step_input = step_input + position_signal[i:i + 1]
      step_bias = full_self_attention_bias[:, :, i:i + 1, :i + 1]
      loop_step = None
    else:
      # Padded decode: use tf.slice with sizes taken from the static shapes,
      # and pass the loop index on to the decoder stack.
      signal_dims = position_signal.shape.as_list()
      step_input = step_input + tf.slice(
          position_signal, [i, 0], [1, signal_dims[1]])
      bias_dims = full_self_attention_bias.shape.as_list()
      step_bias = tf.slice(
          full_self_attention_bias,
          [0, 0, i, 0],
          [bias_dims[0], bias_dims[1], 1, bias_dims[3]])
      loop_step = i

    decoder_outputs = self.decoder_stack(
        step_input,
        cache.get("encoder_outputs"),
        step_bias,
        cache.get("encoder_decoder_attention_bias"),
        training=training,
        cache=cache,
        decode_loop_step=loop_step)
    step_logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
    # Drop axis 1 from the logits before returning them.
    return tf.squeeze(step_logits, axis=[1]), cache

  return symbols_to_logits_fn
def encode(self, inputs, attention_bias):
  """Generate continuous representation for inputs.

  The inputs are embedded, a position encoding is added, dropout is applied
  when `self.train` is truthy, and the result is run through the encoder
  stack.

  Args:
    inputs: int tensor with shape [batch_size, input_length].
    attention_bias: float tensor with shape [batch_size, 1, 1, input_length]

  Returns:
    float tensor with shape [batch_size, input_length, hidden_size]
  """
  with tf.name_scope("encode"):
    # Prepare inputs to the layer stack by adding positional encodings and
    # applying dropout.
    embedded_inputs = self.embedding_softmax_layer(inputs)
    inputs_padding = model_utils.get_padding(inputs)

    with tf.name_scope("add_pos_encoding"):
      length = tf.shape(embedded_inputs)[1]
      pos_encoding = model_utils.get_position_encoding(
          length, self.params["hidden_size"])
      encoder_inputs = embedded_inputs + pos_encoding

    if self.train:
      # Pass the drop probability by keyword. The second positional argument
      # of tf.nn.dropout is `keep_prob` in TF1 but `rate` in TF2, so passing
      # `1 - dropout` positionally would drop the wrong fraction under TF2.
      # This also matches how `decode` calls tf.nn.dropout.
      encoder_inputs = tf.nn.dropout(
          encoder_inputs, rate=self.params["layer_postprocess_dropout"])

    return self.encoder_stack(encoder_inputs, attention_bias, inputs_padding)
def decode(self, targets, encoder_outputs, attention_bias, training):
  """Computes logits for every position of the target sequence.

  Targets are embedded, shifted one step to the right, combined with a
  position encoding, optionally passed through dropout, and then run through
  the decoder stack and the output projection.

  Args:
    targets: target values for the output sequence. int tensor with shape
      [batch_size, target_length]
    encoder_outputs: continuous representation of input sequence. float
      tensor with shape [batch_size, input_length, hidden_size]
    attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
    training: boolean, whether in training mode or not.

  Returns:
    float32 tensor with shape [batch_size, target_length, vocab_size]
  """
  with tf.name_scope("decode"):
    # Embed the targets and bring both the embeddings and the
    # encoder-decoder bias to the model dtype.
    embedded_targets = self.embedding_softmax_layer(targets)
    embedded_targets = tf.cast(embedded_targets, self.params["dtype"])
    encoder_decoder_bias = tf.cast(attention_bias, self.params["dtype"])

    with tf.name_scope("shift_targets"):
      # Pad one step at the front of axis 1, then drop the final step, so
      # the sequence is shifted right while keeping its length.
      padded_targets = tf.pad(embedded_targets, [[0, 0], [1, 0], [0, 0]])
      shifted_targets = padded_targets[:, :-1, :]

    with tf.name_scope("add_pos_encoding"):
      length = tf.shape(shifted_targets)[1]
      position_signal = model_utils.get_position_encoding(
          length, self.params["hidden_size"])
      position_signal = tf.cast(position_signal, self.params["dtype"])
      stack_inputs = shifted_targets + position_signal

    if training:
      stack_inputs = tf.nn.dropout(
          stack_inputs, rate=self.params["layer_postprocess_dropout"])

    # Run the decoder stack and project its outputs to logits.
    self_attention_bias = model_utils.get_decoder_self_attention_bias(
        length, dtype=self.params["dtype"])
    decoder_outputs = self.decoder_stack(
        stack_inputs,
        encoder_outputs,
        self_attention_bias,
        encoder_decoder_bias,
        training=training)
    projected = self.embedding_softmax_layer(decoder_outputs, mode="linear")
    # Logits are always returned as float32, whatever the model dtype is.
    return tf.cast(projected, tf.float32)
def _get_symbols_to_logits_fn(self, max_decode_length):
  """Builds the per-step function that maps decoded ids to next-token logits.

  The position encoding (for `max_decode_length + 1` positions) and the
  decoder self-attention bias (for `max_decode_length`) are computed once
  here and captured by the returned closure.

  NOTE(review): another method named `_get_symbols_to_logits_fn`, taking an
  extra `training` argument, is visible in this file. If both belong to the
  same class, the later definition replaces the earlier one. Confirm which
  is intended.

  Args:
    max_decode_length: Decode length used to size the position encoding and
      the decoder self-attention bias.

  Returns:
    A function `symbols_to_logits_fn(ids, i, cache)` returning a
    `(logits, cache)` tuple.
  """
  position_signal = model_utils.get_position_encoding(
      max_decode_length + 1, self.params["hidden_size"])
  full_self_attention_bias = model_utils.get_decoder_self_attention_bias(
      max_decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Computes logits for the next candidate IDs at loop index `i`.

    Args:
      ids: Current decoded sequences.
        int tensor with shape [batch_size * beam_size, i + 1]
      i: Loop index
      cache: dictionary of values storing the encoder output, encoder-decoder
        attention bias, and previous decoder attention values.

    Returns:
      Tuple of
        (logits with shape [batch_size * beam_size, vocab_size],
         updated cache values)
    """
    # Only the most recently generated ID of each sequence is fed in.
    newest_ids = ids[:, -1:]

    # Embed the newest IDs and add the position signal for step i.
    step_embedding = self.embedding_softmax_layer(newest_ids)
    step_input = step_embedding + position_signal[i:i + 1]

    # Take row i of the precomputed bias, limited to columns 0..i.
    step_bias = full_self_attention_bias[:, :, i:i + 1, :i + 1]

    decoder_outputs = self.decoder_stack(
        step_input,
        cache.get("encoder_outputs"),
        step_bias,
        cache.get("encoder_decoder_attention_bias"),
        cache)
    step_logits = self.embedding_softmax_layer.linear(decoder_outputs)
    # Drop axis 1 from the logits before returning them.
    return tf.squeeze(step_logits, axis=[1]), cache

  return symbols_to_logits_fn