Example #1
import math

import torch
import torch.nn as nn

import onmt
# PrePostProcessing, MultiHeadAttention, RelPartialLearnableMultiHeadAttn,
# FeedForward, FeedForwardSwish, MaxOut and Bottle are assumed to be defined
# in the surrounding onmt codebase (not shown in this excerpt).


class RelativeTransformerDecoderLayer(nn.Module):
    def __init__(self,
                 h,
                 d_model,
                 p,
                 d_ff,
                 attn_p=0.1,
                 version=1.0,
                 ignore_source=False,
                 variational=False,
                 death_rate=0.0):
        super(RelativeTransformerDecoderLayer, self).__init__()
        self.version = version
        self.ignore_source = ignore_source
        self.variational = variational
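        # probability of skipping this whole layer during training (stochastic depth)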
        self.death_rate = death_rate

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model,
                                                  p,
                                                  sequence='da',
                                                  variational=self.variational)

        if not self.ignore_source:
            self.preprocess_src_attn = PrePostProcessing(d_model,
                                                         p,
                                                         sequence='n')
            self.postprocess_src_attn = PrePostProcessing(
                d_model, p, sequence='da', variational=self.variational)
            self.multihead_src = MultiHeadAttention(h,
                                                    d_model,
                                                    attn_p=attn_p,
                                                    share=2)

        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model,
                                                 p,
                                                 sequence='da',
                                                 variational=self.variational)

        d_head = d_model // h
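        # relative-position multi-head self-attention (Transformer-XL style)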
        self.multihead_tgt = RelPartialLearnableMultiHeadAttn(h,
                                                              d_model,
                                                              d_head,
                                                              dropatt=attn_p)
        # self.multihead_tgt = MultiHeadAttention(h, d_model, attn_p=attn_p, share=1)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model,
                                      d_ff,
                                      ff_p,
                                      variational=self.variational)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        elif onmt.constants.activation_layer == 'linear_swish_linear':
            ff_p = p
            feedforward = FeedForwardSwish(d_model, d_ff, ff_p)
        else:
            raise NotImplementedError
        self.feedforward = Bottle(feedforward)

    # def forward(self, input, context, pos_emb, r_w_bias, r_r_bias, mask_tgt, mask_src):
    def forward(self,
                input,
                context,
                pos_emb,
                mask_tgt,
                mask_src,
                incremental=False,
                incremental_cache=None,
                reuse_source=True,
                mems=None):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """

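        # stochastic depth: during training, skip this whole layer with probability death_rate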
        coin = True
        if self.training and self.death_rate > 0:
            coin = (torch.rand(1)[0].item() >= self.death_rate)

        if coin:
            # input and context should be time first ?
            if mems is not None and mems.size(0) > 0:
                mems = self.preprocess_attn(mems)
            else:
                mems = None

            query = self.preprocess_attn(input)

            # out, _ = self.multihead_tgt(query, pos_emb, r_w_bias, r_r_bias, attn_mask=mask_tgt)
            # print(query.size(), pos_emb.size(), mask_tgt.size(), mems.size() if mems is not None else 0)
            out, _, incremental_cache = self.multihead_tgt(
                query,
                pos_emb,
                attn_mask=mask_tgt,
                mems=mems,
                incremental=incremental,
                incremental_cache=incremental_cache)

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_attn(out, input)
            """ Context Attention layer 
                layernorm > attn > dropout > residual
            """
            if not self.ignore_source:
                query = self.preprocess_src_attn(input)
                incremental_source = incremental and reuse_source
                out, coverage, incremental_cache = self.multihead_src(
                    query,
                    context,
                    context,
                    mask_src,
                    incremental=incremental_source,
                    incremental_cache=incremental_cache)

                # rescaling before residual
                if self.training and self.death_rate > 0:
                    out = out / (1 - self.death_rate)

                input = self.postprocess_src_attn(out, input)
            else:
                coverage = None
            """ Feed forward layer 
                layernorm > ffn > dropout > residual
            """
            out = self.feedforward(self.preprocess_ffn(input))

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_ffn(out, input)
        else:
            coverage = None

        return input, coverage, incremental_cache

    def step(self, input, context, pos_emb, mask_tgt, mask_src, buffer=None):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """

        query = self.preprocess_attn(input)

        out, _, buffer = self.multihead_tgt.step(query,
                                                 pos_emb,
                                                 attn_mask=mask_tgt,
                                                 buffer=buffer)

        input = self.postprocess_attn(out, input)
        """ Context Attention layer
            layernorm > attn > dropout > residual
        """
        if not self.ignore_source:
            query = self.preprocess_src_attn(input)
            out, coverage, buffer = self.multihead_src.step(query,
                                                            context,
                                                            context,
                                                            mask_src,
                                                            buffer=buffer)
            input = self.postprocess_src_attn(out, input)
        else:
            coverage = None
        """ Feed forward layer
            layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input))
        input = self.postprocess_ffn(out, input)

        return input, coverage, buffer
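
The `death_rate` branch in `forward` above implements stochastic depth (LayerDrop): during training the whole layer body is skipped with probability `death_rate`, and when it does run, its output is rescaled by `1 / (1 - death_rate)` before the residual connection so that the expected activation matches inference, where the layer always runs. Below is a minimal, self-contained sketch of that pattern; the `StochasticResidual` wrapper, its `sublayer` argument and the toy dimensions are illustrative placeholders, not part of the original code.

import torch
import torch.nn as nn


class StochasticResidual(nn.Module):
    """Residual wrapper that drops its sublayer with probability `death_rate`
    during training and rescales the surviving output, mirroring the
    `coin` / `out / (1 - death_rate)` logic in the decoder layer above."""

    def __init__(self, sublayer, death_rate=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.death_rate = death_rate

    def forward(self, x):
        if self.training and self.death_rate > 0:
            # coin flip: skip the whole sublayer for this forward pass
            if torch.rand(1).item() < self.death_rate:
                return x
            # rescale before the residual so the expected output matches inference
            return x + self.sublayer(x) / (1 - self.death_rate)
        return x + self.sublayer(x)


# toy usage with a placeholder feed-forward sublayer
layer = StochasticResidual(nn.Sequential(nn.LayerNorm(16), nn.Linear(16, 16)),
                           death_rate=0.5)
x = torch.randn(4, 16)
layer.train()
print(layer(x).shape)   # torch.Size([4, 16])
layer.eval()
print(layer(x).shape)   # torch.Size([4, 16])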
Example #2
class LMDecoderLayer(nn.Module):
    """Wraps multi-head attentions and position-wise feed forward into one layer of decoder
    
    Args:
        h:       number of heads
        d_model: dimension of model
        p:       dropout probabolity 
        d_ff:    dimension of feed forward
        
    Params:
        multihead_tgt:  multi-head self attentions layer
        multihead_src:  multi-head encoder-decoder attentions layer        
        feedforward:    feed forward layer
    
    Input Shapes:
        query:    batch_size x len_query x d_model 
        key:      batch_size x len_key x d_model   
        value:    batch_size x len_key x d_model
        context:  batch_size x len_src x d_model
        mask_tgt: batch_size x len_query x len_key or broadcastable 
        mask_src: batch_size x len_query x len_src or broadcastable 
    
    Output Shapes:
        out:      batch_size x len_query x d_model
        coverage: batch_size x len_query x len_key
        
    """
    def __init__(
        self,
        h,
        d_model,
        p,
        d_ff,
        attn_p=0.1,
    ):
        super(LMDecoderLayer, self).__init__()

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model,
                                                  p,
                                                  sequence='da',
                                                  static=onmt.constants.static)

        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model,
                                                 p,
                                                 sequence='da',
                                                 static=onmt.constants.static)

        self.multihead_tgt = MultiHeadAttention(h,
                                                d_model,
                                                attn_p=attn_p,
                                                static=onmt.constants.static,
                                                share=1)

        ff_p = p
        feedforward = FeedForward(d_model,
                                  d_ff,
                                  ff_p,
                                  static=onmt.constants.static)
        self.feedforward = Bottle(feedforward)

    def forward(self, input, mask_tgt):
        """ Self attention layer 
            layernorm > attn > dropout > residual
        """

        # input and context should be time first ?

        query = self.preprocess_attn(input)

        self_context = query

        out, _ = self.multihead_tgt(query, self_context, self_context,
                                    mask_tgt)

        input = self.postprocess_attn(out, input)
        """ Feed forward layer 
            layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input))
        input = self.postprocess_ffn(out, input)

        coverage = None
        return input, coverage

    def step(self, input, mask_tgt, buffer=None):
        """ Self attention layer 
            layernorm > attn > dropout > residual
        """

        query = self.preprocess_attn(input)

        out, _, buffer = self.multihead_tgt.step(query,
                                                 query,
                                                 query,
                                                 mask_tgt,
                                                 buffer=buffer)

        input = self.postprocess_attn(out, input)

        coverage = None
        """ Feed forward layer 
            layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input))

        input = self.postprocess_ffn(out, input)

        return input, coverage, buffer
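
Both example layers follow the same pre-norm residual pattern spelled out in their docstrings, layernorm > attn/ffn > dropout > residual, with `PrePostProcessing(..., sequence='n')` supplying the layer normalization and `PrePostProcessing(..., sequence='da')` the dropout-and-add step. Below is a minimal sketch of that pattern using plain PyTorch modules; `PreNormResidualBlock` and the toy dimensions are illustrative stand-ins, not the actual `PrePostProcessing` implementation.

import torch
import torch.nn as nn


class PreNormResidualBlock(nn.Module):
    """Pre-norm residual block: layernorm > sublayer > dropout > residual.
    A stand-in for the preprocess ('n') / postprocess ('da') pair above."""

    def __init__(self, d_model, sublayer, p=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)      # preprocess: sequence='n'
        self.sublayer = sublayer
        self.dropout = nn.Dropout(p)           # postprocess: sequence='da'

    def forward(self, x):
        out = self.sublayer(self.norm(x))
        return x + self.dropout(out)           # dropout, then add the residual


# toy usage: wrap a feed-forward sublayer, batch_size x len x d_model input
block = PreNormResidualBlock(32,
                             nn.Sequential(nn.Linear(32, 64),
                                           nn.ReLU(),
                                           nn.Linear(64, 32)),
                             p=0.1)
x = torch.randn(2, 7, 32)
print(block(x).shape)   # torch.Size([2, 7, 32])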