Example 1
    def decode_at_train(self, target_ids, enc_output, cross_attn_mask):
        """ Returns the probability distribution over target-side tokens conditioned on the output of the encoder;
         performs decoding in parallel at training time. """
        def _decode_all(target_embeddings):
            """ Decodes the encoder-generated representations into target-side logits in parallel. """
            # Apply input dropout
            dec_input = tf.layers.dropout(
                target_embeddings,
                rate=self.config.transformer_dropout_embeddings,
                training=self.training)
            # Propagate inputs through the decoder stack
            dec_output = dec_input
            for layer_id in range(1, self.config.transformer_dec_depth + 1):
                dec_output, _ = self.decoder_stack[layer_id]['self_attn'].forward(
                    dec_output, None, self_attn_mask)
                dec_output, _ = self.decoder_stack[layer_id]['cross_attn'].forward(
                    dec_output, enc_output, cross_attn_mask)
                dec_output = self.decoder_stack[layer_id]['ffn'].forward(dec_output)
            return dec_output

        def _prepare_targets():
            """ Pre-processes target token ids before they're passed on as input to the decoder
            for parallel decoding. """
            # Embed target_ids
            target_embeddings = self._embed(target_ids)
            target_embeddings += positional_signal
            if self.config.transformer_dropout_embeddings > 0:
                target_embeddings = tf.layers.dropout(
                    target_embeddings,
                    rate=self.config.transformer_dropout_embeddings,
                    training=self.training)
            return target_embeddings

        def _decoding_function():
            """ Generates logits for target-side tokens. """
            # Embed the full gold target sequence and add positional information
            target_embeddings = _prepare_targets()
            # Pass encoder context and decoder embeddings through the decoder
            dec_output = _decode_all(target_embeddings)
            # Project decoder stack outputs into vocabulary-sized logits
            full_logits = self.softmax_projection_layer.project(dec_output)
            return full_logits

        with tf.variable_scope(self.name):
            # Transpose encoder information in hybrid models
            if self.from_rnn:
                enc_output = tf.transpose(enc_output, [1, 0, 2])
                cross_attn_mask = tf.transpose(cross_attn_mask, [3, 1, 2, 0])

            self_attn_mask = get_right_context_mask(tf.shape(target_ids)[-1])
            positional_signal = get_positional_signal(
                tf.shape(target_ids)[-1], self.config.embedding_size,
                FLOAT_DTYPE)
            logits = _decoding_function()
        return logits
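
The helper get_right_context_mask used above is only called here, not defined. As a rough sketch, a right-context (causal) mask of this kind is typically an additive bias built from a lower-triangular matrix, so each target position can attend to itself and to earlier positions but not to later ones. The function below is a hypothetical stand-in written against the TF1 API used in these examples; the [1, 1, time, time] shape convention and the -1e9 masking value are assumptions, not the actual Nematus implementation.

import tensorflow as tf

def right_context_mask_sketch(time_steps):
    """ Hypothetical causal mask: position i may attend only to positions <= i.
    Returns an additive mask of shape [1, 1, time_steps, time_steps] holding 0.0
    at allowed positions and -1e9 at future positions (assumed convention). """
    # Lower-triangular matrix of ones: row i has ones up to and including column i
    lower_triangle = tf.matrix_band_part(tf.ones([time_steps, time_steps]), -1, 0)
    # Turn disallowed (upper-triangle) positions into a large negative bias
    additive_mask = (1.0 - lower_triangle) * -1e9
    # Add broadcast dimensions for the batch and attention heads
    return tf.reshape(additive_mask, [1, 1, time_steps, time_steps])
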
Example 2
    def decode_at_train(self, target_ids, enc_output, cross_attn_mask):
        """ Returns the probability distribution over target-side tokens conditioned on the output of the encoder;
         performs decoding in parallel at training time. """

        def _decode_all(target_embeddings):
            """ Decodes the encoder-generated representations into target-side logits in parallel. """
            # Apply input dropout
            dec_input = tf.layers.dropout(
                target_embeddings,
                rate=self.config.transformer_dropout_embeddings,
                training=self.training)
            # Propagate inputs through the decoder stack
            dec_output = dec_input
            for layer_id in range(1, self.config.transformer_dec_depth + 1):
                dec_output, _ = self.decoder_stack[layer_id]['self_attn'].forward(dec_output, None, self_attn_mask)
                dec_output, _ = \
                    self.decoder_stack[layer_id]['cross_attn'].forward(dec_output, enc_output, cross_attn_mask)
                dec_output = self.decoder_stack[layer_id]['ffn'].forward(dec_output)
            return dec_output

        def _prepare_targets():
            """ Pre-processes target token ids before they're passed on as input to the decoder
            for parallel decoding. """
            # Embed target_ids
            target_embeddings = self._embed(target_ids)
            target_embeddings += positional_signal
            if self.config.transformer_dropout_embeddings > 0:
                target_embeddings = tf.layers.dropout(target_embeddings,
                                                      rate=self.config.transformer_dropout_embeddings, training=self.training)
            return target_embeddings

        def _decoding_function():
            """ Generates logits for target-side tokens. """
            # Embed the full gold target sequence and add positional information
            target_embeddings = _prepare_targets()
            # Pass encoder context and decoder embeddings through the decoder
            dec_output = _decode_all(target_embeddings)
            # Project decoder stack outputs into vocabulary-sized logits
            full_logits = self.softmax_projection_layer.project(dec_output)
            return full_logits

        with tf.variable_scope(self.name):
            # Transpose encoder information in hybrid models
            if self.from_rnn:
                enc_output = tf.transpose(enc_output, [1, 0, 2])
                cross_attn_mask = tf.transpose(cross_attn_mask, [3, 1, 2, 0])

            self_attn_mask = get_right_context_mask(tf.shape(target_ids)[-1])
            positional_signal = get_positional_signal(tf.shape(target_ids)[-1],
                                                      self.config.embedding_size,
                                                      self.float_dtype)
            logits = _decoding_function()
        return logits
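
Likewise, get_positional_signal is referenced but not shown. Assuming it produces the standard sinusoidal Transformer encoding (sines and cosines at geometrically spaced frequencies), a minimal stand-in compatible with the call sites above could look like the following; the exact frequency range and channel layout are assumptions rather than the library's actual code.

import numpy as np
import tensorflow as tf

def positional_signal_sketch(time_steps, depth, float_dtype,
                             min_timescale=1.0, max_timescale=1.0e4):
    """ Hypothetical sinusoidal positional encoding of shape [1, time_steps, depth],
    following the standard Transformer formulation; layout and scaling are assumed. """
    positions = tf.cast(tf.range(time_steps), float_dtype)
    num_timescales = depth // 2
    # Geometric progression of inverse frequencies, one per sin/cos pair
    log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), float_dtype) * -log_increment)
    # Outer product of positions and frequencies: [time_steps, num_timescales]
    scaled_time = tf.expand_dims(positions, 1) * tf.expand_dims(inv_timescales, 0)
    # Concatenate sine and cosine channels and add a broadcastable batch dimension
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    return tf.expand_dims(signal, 0)
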