def decode_at_train(self, target_ids, enc_output, cross_attn_mask):
    """ Returns unnormalised logits over target-side tokens conditioned
    on the output of the encoder; performs decoding in parallel at
    training time. """

    def _decode_all(target_embeddings):
        """ Decodes the encoder-generated representations into
        target-side logits in parallel. """
        # Apply input dropout
        dec_input = \
            tf.layers.dropout(target_embeddings,
                              rate=self.config.transformer_dropout_embeddings,
                              training=self.training)
        # Propagate inputs through the decoder stack
        dec_output = dec_input
        for layer_id in range(1, self.config.transformer_dec_depth + 1):
            dec_output, _ = \
                self.decoder_stack[layer_id]['self_attn'].forward(
                    dec_output, None, self_attn_mask)
            dec_output, _ = \
                self.decoder_stack[layer_id]['cross_attn'].forward(
                    dec_output, enc_output, cross_attn_mask)
            dec_output = \
                self.decoder_stack[layer_id]['ffn'].forward(dec_output)
        return dec_output

    def _prepare_targets():
        """ Pre-processes target token ids before they're passed on as
        input to the decoder for parallel decoding. """
        # Embed target_ids and add positional information
        target_embeddings = self._embed(target_ids)
        target_embeddings += positional_signal
        if self.config.transformer_dropout_embeddings > 0:
            target_embeddings = \
                tf.layers.dropout(target_embeddings,
                                  rate=self.config.transformer_dropout_embeddings,
                                  training=self.training)
        return target_embeddings

    def _decoding_function():
        """ Generates logits for target-side tokens. """
        # Embed the full target sequence (teacher forcing); add positional
        # information and apply dropout
        target_embeddings = _prepare_targets()
        # Pass encoder context and decoder embeddings through the decoder
        dec_output = _decode_all(target_embeddings)
        # Project decoder stack outputs onto the target vocabulary,
        # yielding un-normalized logits
        full_logits = self.softmax_projection_layer.project(dec_output)
        return full_logits

    with tf.variable_scope(self.name):
        # Transpose encoder information in hybrid models
        if self.from_rnn:
            enc_output = tf.transpose(enc_output, [1, 0, 2])
            cross_attn_mask = tf.transpose(cross_attn_mask, [3, 1, 2, 0])

        self_attn_mask = get_right_context_mask(tf.shape(target_ids)[-1])
        positional_signal = get_positional_signal(tf.shape(target_ids)[-1],
                                                  self.config.embedding_size,
                                                  FLOAT_DTYPE)
        logits = _decoding_function()
    return logits
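
# --- Illustrative sketch, not part of the original module ----------------
# One plausible implementation of the `get_right_context_mask` helper used
# above (the name and broadcast layout are assumptions based on how the
# mask is consumed by the self-attention layers): an additive mask that is
# 0 where a position may attend and a large negative constant where it may
# not, so that masked logits vanish after the attention softmax.
def _sketch_right_context_mask(time_steps):
    """ Additive mask blocking attention to future (right-context) tokens. """
    # Lower-triangular matrix of ones: position i may attend to j <= i
    lower_triangle = tf.matrix_band_part(
        tf.ones([time_steps, time_steps]), -1, 0)
    # 0.0 where attention is allowed, -1e9 where it is blocked
    mask = -1e9 * (1.0 - lower_triangle)
    # Add singleton batch and head dimensions so the mask broadcasts over
    # [batch, heads, time, time] attention logits
    return tf.expand_dims(tf.expand_dims(mask, 0), 1)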
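
# --- Illustrative sketch, not part of the original module ----------------
# A plausible implementation of `get_positional_signal`, assuming it
# computes the sinusoidal encodings of Vaswani et al. (2017); the exact
# timescale spacing in transformer_layers may differ, and an even `depth`
# is assumed for simplicity. Because the signal is additive, it broadcasts
# over the embedded targets in _prepare_targets() without reshaping.
def _sketch_positional_signal(length, depth, float_dtype):
    """ Returns a [1, length, depth] tensor of sinusoidal position encodings. """
    import math
    positions = tf.cast(tf.range(length), float_dtype)
    # One frequency per sin/cos channel pair, geometrically spaced
    num_timescales = depth // 2
    log_increment = math.log(10000.0) / max(num_timescales - 1, 1)
    inv_timescales = tf.exp(
        tf.cast(tf.range(num_timescales), float_dtype) * -log_increment)
    # Outer product of positions and frequencies: [length, num_timescales]
    scaled_time = \
        tf.expand_dims(positions, 1) * tf.expand_dims(inv_timescales, 0)
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    # Add a singleton batch dimension for broadcasting over
    # [batch, time, depth] target embeddings
    return tf.expand_dims(signal, 0)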