Example #1
class TransformerDecoderLayer(IncrementalModule):
    """
    Wraps multi-head self-attention, encoder-decoder attention and a position-wise
    feed-forward network into a single decoder layer.

    Layers:
        (1)
         Layer norm
         Multi-head self-attention
         Dropout
         Residual with (1)
        (2)
         Layer norm
         Multi-head query-context attention
         Dropout
         Residual with (2)
        (3)
         Layer norm
         Feed-forward
         Dropout
         Residual with (3)

    Feed-Forward:
        Configurable between linear -> ReLU -> linear and Maxout

    Args:
        model_dim:            dimension of model
        num_heads:            number of heads
        feed_forward_dim:     dimension of feed forward
        feed_forward_dropout: dropout probability in the feed forward
        attention_dropout:    dropout probability in attention
        residual_dropout:     dropout probability for the residual layers
        weight_norm:          whether to use weight normalization on the feed forward layers
        masked_layers:        whether to use masking for layer norm and feed forward. Useful for sparse masks
        gated_residuals:      whether to use gated residuals
        batch_first:          whether input (and output) should be batch dimension first or sequence
                              length dimension first
        feed_forward_type:    Which type of feed forward to use. Currently supports 'linear_relu_linear'
                              and 'maxout'
        ignore_context:       If True, do not use the context input at all
        encoder_to_share:     Instance of TransformerEncoderLayer to share parameters with

    Input Shapes:
        inputs:              len_query x batch_size x model_dim  or  batch_size x len_query x model_dim
        context:             len_context x batch_size x model_dim  or  batch_size x len_context x model_dim
        input_mask:          batch_size x len_query  or  len_query x batch_size
        context_mask:        batch_size x len_context  or  len_context x batch_size
        self_attention_mask: batch_size x len_query x len_query or broadcastable, regardless of batch_first

    Output Shapes:
        out:      len_query x batch_size x model_dim  or  batch_size x len_query x model_dim
    """

    _version = 2

    def __init__(self,
                 *,
                 model_dim=512,
                 num_heads=8,
                 feed_forward_dim=2048,
                 feed_forward_dropout=0.1,
                 attention_dropout=0.1,
                 residual_dropout=0.1,
                 weight_norm=False,
                 masked_layers=False,
                 gated_residuals=False,
                 batch_first=False,
                 feed_forward_type='linear_relu_linear',
                 ignore_context=False,
                 encoder_to_share=None):
        super().__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.feed_forward_dim = feed_forward_dim
        self.feed_forward_dropout = feed_forward_dropout
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.weight_norm = weight_norm
        self.masked_layers = masked_layers
        self.gated_residuals = gated_residuals
        self.batch_first = batch_first
        self.feed_forward_type = feed_forward_type
        self.ignore_context = ignore_context

        if encoder_to_share is None:
            self.build_self_attention()
            self.build_feed_forward()
        else:
            # share the self-attention and feed-forward layers with the encoder
            self.share_feed_forward(encoder_to_share)
            self.share_self_attention(encoder_to_share)

        if not ignore_context:
            self.build_encoder_attention()

        self._register_load_state_dict_pre_hook(self._update_names)

    def get_preprocessing_module(self):
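        # 'n': presumably a pre-sublayer layer norm (per the Layers overview in the class docstring)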
        return PrePostProcessing(self.model_dim,
                                 'n',
                                 masking=self.masked_layers)

    def get_postprocessing_module(self):
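        # 'da': presumably dropout followed by the residual add with the sublayer input (per the Layers overview)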
        return PrePostProcessing(self.model_dim,
                                 'da',
                                 self.residual_dropout,
                                 gated_residuals=self.gated_residuals)

    # noinspection PyAttributeOutsideInit
    def build_self_attention(self):
        self.preprocess_self_attn = self.get_preprocessing_module()
        self.self_attention = MultiHeadAttention(
            self.model_dim,
            self.num_heads,
            self.attention_dropout,
            masked_layers=self.masked_layers,
            batch_first=self.batch_first)
        self.postprocess_self_attn = self.get_postprocessing_module()

    # noinspection PyAttributeOutsideInit
    def share_self_attention(self, encoder):
        self.preprocess_self_attn = encoder.preprocess_attn
        self.postprocess_self_attn = encoder.postprocess_attn
        self.self_attention = encoder.attention

    def self_attention_layer(self,
                             inputs,
                             input_mask=None,
                             self_attention_bias=None):
        query = self.preprocess_self_attn(inputs, mask=input_mask)
        self_attention_out, _ = self.self_attention(query, query, query,
                                                    self_attention_bias,
                                                    input_mask)
        self_attention_out = self.postprocess_self_attn(
            self_attention_out, inputs)
        return self_attention_out

    def self_attention_step(self,
                            inputs,
                            incremental_state,
                            input_mask=None,
                            self_attention_bias=None):
        query = self.preprocess_self_attn(inputs, mask=input_mask)
        self_attention_out, _ = self.self_attention.step(
            query, query, query, incremental_state, self_attention_bias,
            input_mask)
        self_attention_out = self.postprocess_self_attn(
            self_attention_out, inputs)
        return self_attention_out

    # noinspection PyAttributeOutsideInit
    def build_encoder_attention(self):
        self.preprocess_enc_attn = self.get_preprocessing_module()
        self.enc_attention = MultiHeadAttention(
            self.model_dim,
            self.num_heads,
            self.attention_dropout,
            masked_layers=self.masked_layers,
            batch_first=self.batch_first)
        self.postprocess_enc_attn = self.get_postprocessing_module()

    def encoder_attention_layer(self,
                                inputs,
                                encoder_outputs,
                                input_mask=None,
                                context_mask=None,
                                encoder_attention_bias=None):
        query = self.preprocess_enc_attn(inputs, mask=input_mask)
        enc_attention_out, attention_weights = self.enc_attention(
            query, encoder_outputs, encoder_outputs, encoder_attention_bias,
            input_mask, context_mask)
        enc_attention_out = self.postprocess_enc_attn(enc_attention_out,
                                                      inputs)
        return enc_attention_out, attention_weights

    def encoder_attention_step(self,
                               inputs,
                               encoder_outputs,
                               incremental_state,
                               input_mask=None,
                               context_mask=None,
                               encoder_attention_bias=None):
        query = self.preprocess_enc_attn(inputs, mask=input_mask)
        enc_attention_out, attention_weights = self.enc_attention.step(
            query,
            encoder_outputs,
            encoder_outputs,
            incremental_state,
            encoder_attention_bias,
            input_mask,
            context_mask,
            static_kv=True)
        enc_attention_out = self.postprocess_enc_attn(enc_attention_out,
                                                      inputs)
        return enc_attention_out, attention_weights

    # noinspection PyAttributeOutsideInit
    def build_feed_forward(self):
        self.preprocess_ffn = self.get_preprocessing_module()
        self.feed_forward = MaskedFunction(
            get_feed_forward(self.feed_forward_type, self.model_dim,
                             self.feed_forward_dim, self.feed_forward_dropout,
                             self.weight_norm))
        self.postprocess_ffn = self.get_postprocessing_module()

    # noinspection PyAttributeOutsideInit
    def share_feed_forward(self, encoder):
        self.preprocess_ffn = encoder.preprocess_ffn
        self.postprocess_ffn = encoder.postprocess_ffn
        self.feed_forward = encoder.feed_forward

    def feed_forward_layer(self, inputs, input_mask=None):
        out = self.preprocess_ffn(inputs, mask=input_mask)
        out = self.feed_forward(
            out, mask=input_mask if self.masked_layers else None)
        out = self.postprocess_ffn(out, inputs)
        return out

    def feed_forward_step(self, inputs, input_mask):
        return self.feed_forward_layer(inputs, input_mask)

    def forward(self,
                inputs,
                context,
                input_mask=None,
                context_mask=None,
                self_attention_bias=None,
                encoder_attention_bias=None):
        self_attention_out = self.self_attention_layer(inputs, input_mask,
                                                       self_attention_bias)

        if not self.ignore_context:
            context_attention_out, attention_weights = self.encoder_attention_layer(
                self_attention_out, context, input_mask, context_mask,
                encoder_attention_bias)
        else:
            context_attention_out = self_attention_out
            attention_weights = None

        out = self.feed_forward_layer(context_attention_out, input_mask)
        return out, attention_weights

    def _step(self,
              inputs,
              encoder_outputs,
              incremental_state,
              input_mask=None,
              context_mask=None,
              self_attention_bias=None,
              encoder_attention_bias=None):
        self_attention_out = self.self_attention_step(inputs,
                                                      incremental_state,
                                                      input_mask,
                                                      self_attention_bias)

        if not self.ignore_context:
            enc_attention_out, attention_weights = self.encoder_attention_step(
                self_attention_out, encoder_outputs, incremental_state,
                input_mask, context_mask, encoder_attention_bias)
        else:
            enc_attention_out = self_attention_out
            attention_weights = None

        out = self.feed_forward_step(enc_attention_out, input_mask)
        return out, attention_weights

    def _update_names(self, state_dict, prefix, local_metadata, strict,
                      missing_keys, unexpected_keys, error_msgs):
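        # Version-1 checkpoints stored these submodules under different names
        # (preprocess_attn, attention_tgt, ...); rename the keys so they still load.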
        version = local_metadata.get('version', 1)
        if version == 1 and prefix + 'version' not in state_dict:
            for key in self.preprocess_self_attn.state_dict().keys():
                state_dict[prefix + 'preprocess_self_attn.' +
                           key] = state_dict.pop(prefix + 'preprocess_attn.' +
                                                 key)
            for key in self.preprocess_enc_attn.state_dict().keys():
                state_dict[prefix + 'preprocess_enc_attn.' +
                           key] = state_dict.pop(prefix +
                                                 'preprocess_src_attn.' + key)
            for key in self.self_attention.state_dict().keys():
                state_dict[prefix + 'self_attention.' +
                           key] = state_dict.pop(prefix + 'attention_tgt.' +
                                                 key)
            for key in self.enc_attention.state_dict().keys():
                state_dict[prefix + 'enc_attention.' +
                           key] = state_dict.pop(prefix + 'attention_src.' +
                                                 key)
        elif version == 1:
            del state_dict[prefix + 'version']
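
A minimal usage sketch for the layer above (not part of the original example): the import path is hypothetical, and it assumes the layer's dependencies (IncrementalModule, MultiHeadAttention, PrePostProcessing, MaskedFunction, get_feed_forward) resolve within the same package. Tensor shapes follow the docstring with the default batch_first=False.

import torch

from mypackage.transformer import TransformerDecoderLayer  # hypothetical import path

layer = TransformerDecoderLayer(model_dim=512, num_heads=8)
layer.eval()  # disable dropout for a deterministic shape check

len_query, len_context, batch_size = 7, 11, 3
inputs = torch.randn(len_query, batch_size, 512)       # len_query x batch_size x model_dim
context = torch.randn(len_context, batch_size, 512)    # len_context x batch_size x model_dim
input_mask = torch.ones(batch_size, len_query, dtype=torch.bool)
context_mask = torch.ones(batch_size, len_context, dtype=torch.bool)

out, attention_weights = layer(inputs, context, input_mask, context_mask)
assert out.shape == (len_query, batch_size, 512)
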
Example #2
class NMTDecoder(IncrementalDecoder):
    """Wraps a Decoder and adds embedding and projection"""
    def __init__(self,
                 decoder,
                 embedding,
                 dropout,
                 linear,
                 *,
                 copy_decoder=False,
                 batch_first=False,
                 extra_attention=False,
                 masked_layers=False,
                 attention_dropout=0.1,
                 language_embedding=None):
        super().__init__()
        self.decoder = decoder
        self.embedded_dropout = EmbeddingDropout(embedding, dropout)
        self.linear = linear
        self.copy_decoder = copy_decoder
        self.batch_first = batch_first
        self.extra_attention = extra_attention

        if self.copy_decoder:
            model_dim = linear.weight.size(1)
            self.gate_layer = XavierLinear(model_dim, 1)
            if extra_attention:
                self.attention = MultiHeadAttention(model_dim, 1,
                                                    attention_dropout,
                                                    batch_first, masked_layers)

            self._register_load_state_dict_pre_hook(
                self._load_nmt_model_compatibility)

        if language_embedding is not None:
            self.language_embedding = language_embedding
            model_dim = self.embedded_dropout.embedding.weight.size(1)
            emb_dim = language_embedding.weight.size(1)
            self.merge_layer = XavierLinear(model_dim + emb_dim, model_dim)
        else:
            self.language_embedding = None

    def forward(self,
                decoder_inputs,
                encoder_outputs,
                decoder_mask=None,
                encoder_mask=None):
        if self.language_embedding is not None:
            indices, language_id = decoder_inputs

            emb = torch.cat((self.embedded_dropout(indices),
                             self.language_embedding(language_id)),
                            dim=-1)
            emb = self.merge_layer(emb)
        else:
            emb = self.embedded_dropout(decoder_inputs)

        out, attention_weights = self.decoder(emb, encoder_outputs,
                                              decoder_mask, encoder_mask)

        if self.copy_decoder:
            if self.extra_attention:
                source_attention_bias = self.get_encoder_attention_bias(
                    encoder_outputs, self.batch_first, encoder_mask)
                _, attention_weights = self.attention(out, encoder_outputs,
                                                      encoder_outputs,
                                                      source_attention_bias,
                                                      decoder_mask,
                                                      encoder_mask)

            gates = torch.sigmoid(self.gate_layer(out)).squeeze(-1)

        if self.training and decoder_mask is not None:
            # Optimize the projection by computing it only at those positions where
            # the input was not padding
            nonpad_indices = torch.nonzero(decoder_mask.view(-1)).squeeze(1)
            out = out.view(-1, out.size(-1))
            out = out.index_select(0, nonpad_indices)

            # If the attention weights come from multi-head attention, their batch dimension is
            # larger than decoder_mask's, so the result of this selection is garbage in that case
            if attention_weights is not None:
                attention_weights = attention_weights.view(
                    -1, attention_weights.size(-1))
                attention_weights = attention_weights.index_select(
                    0, nonpad_indices)
            if self.copy_decoder:
                gates = gates.masked_select(decoder_mask)

        if self.copy_decoder:
            attention_weights = {'attn': attention_weights, 'gates': gates}

        return self.linear(out), attention_weights

    def _step(self,
              decoder_inputs,
              encoder_outputs,
              incremental_state,
              decoder_mask=None,
              encoder_mask=None):
        emb = self.embedded_dropout(decoder_inputs)
        out, attention_weights = self.decoder.step(emb, encoder_outputs,
                                                   incremental_state,
                                                   decoder_mask, encoder_mask)

        if self.copy_decoder:
            if self.extra_attention:
                source_attention_bias = self.get_encoder_attention_bias(
                    encoder_outputs, self.batch_first, encoder_mask)
                _, attention_weights = self.attention(out, encoder_outputs,
                                                      encoder_outputs,
                                                      source_attention_bias,
                                                      decoder_mask,
                                                      encoder_mask)

            gates = torch.sigmoid(self.gate_layer(out)).squeeze(-1)
            attention_weights = {'attn': attention_weights, 'gates': gates}

        return self.linear(out), attention_weights

    def get_normalized_probs(self,
                             decoder_outputs,
                             attention_weights,
                             encoder_inputs=None,
                             encoder_mask=None,
                             decoder_mask=None,
                             log_probs=False):
        decoder_probs = self.decoder.get_normalized_probs(
            decoder_outputs, attention_weights, encoder_inputs, encoder_mask,
            decoder_mask, log_probs)

        if not self.copy_decoder:
            return decoder_probs

        attention_weights, gates = attention_weights[
            'attn'], attention_weights['gates']
        gates = gates.unsqueeze(-1)

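        # If forward() was given a decoder_mask during training, decoder_outputs and the
        # attention weights were flattened to the non-padded positions only (2-dim outputs).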
        optimized = decoder_outputs.dim() == 2
        if not self.batch_first:
            encoder_inputs = encoder_inputs.transpose(0, 1).unsqueeze(
                0)  # (1, batch, src_len)
        if optimized:
            # (batch, tgt_len, src_len) | (tgt_len, batch, src_len)
            new_size = list(decoder_mask.size()) + [encoder_inputs.size(-1)]
            nonpad_indices = torch.nonzero(decoder_mask.view(-1)).squeeze(1)
            encoder_inputs = encoder_inputs.expand(new_size).contiguous() \
                .view(-1, encoder_inputs.size(-1)) \
                .index_select(0, nonpad_indices)
            # encoder_inputs is now (decoder_outputs.size(0), src_len)
        else:
            encoder_inputs = encoder_inputs.expand_as(attention_weights)

        assert encoder_inputs.size() == attention_weights.size()

        encoder_probs = decoder_probs.new_full(decoder_probs.size(), 1e-20)
        encoder_probs.scatter_add_(1 if optimized else 2, encoder_inputs,
                                   attention_weights)

        if log_probs:
            encoder_probs.log_()
            encoder_probs.add_(torch.log(gates))
            decoder_probs.add_(torch.log(1 - gates))
            # Very important to have it this way around, otherwise we will add -inf + inf = NaN
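            # With a = log(gates) + log(p_enc) and b = log(1 - gates) + log(p_dec) this computes
            # log(exp(a) + exp(b)) = b + log1p(exp(a - b)), the log of the mixed distribution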
            res = decoder_probs + torch.log1p(
                torch.exp(encoder_probs - decoder_probs))
            return res
        else:
            return gates * encoder_probs + (1 - gates) * decoder_probs

    def reorder_incremental_state(self, incremental_state, new_order):
        self.decoder.reorder_incremental_state(incremental_state, new_order)
        if self.extra_attention:
            self.attention.reorder_incremental_state(incremental_state,
                                                     new_order)

    def _load_nmt_model_compatibility(self, state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs):
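        # If the checkpoint already contains copy-decoder parameters there is nothing to do;
        # otherwise insert the module's current (freshly initialized) gate and extra-attention
        # parameters so strict loading succeeds.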
        if prefix + 'gate_layer.weight' in state_dict:
            return

        logger.info('Augmenting NMTModel with a copy decoder')
        items = self.gate_layer.state_dict(prefix=prefix +
                                           'gate_layer.').items()
        if self.extra_attention:
            items = itertools.chain(
                items,
                self.attention.state_dict(prefix=prefix +
                                          'attention.').items())
        for key, value in items:
            assert key not in state_dict
            state_dict[key] = value
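
A standalone check of the copy-gate mixing used in get_normalized_probs above (plain torch; the toy shapes and names are assumptions, not from the example): the log-space branch should reproduce the direct probability mixture.

import torch

batch_size, vocab_size = 2, 6
decoder_probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
encoder_probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
gates = torch.sigmoid(torch.randn(batch_size, 1))

# Direct mixture, as in the non-log branch of get_normalized_probs
direct = gates * encoder_probs + (1 - gates) * decoder_probs

# Log-space mixture, mirroring the log_probs branch:
# log(exp(a) + exp(b)) = b + log1p(exp(a - b))
a = encoder_probs.log() + gates.log()
b = decoder_probs.log() + (1 - gates).log()
log_mix = b + torch.log1p(torch.exp(a - b))

print(torch.allclose(log_mix.exp(), direct, atol=1e-6))  # True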