Example #1
    def __init__(self, h, d_model, p, d_ff, attn_p=0.1, version=1.0, ignore_source=False,
                 variational=False, death_rate=0.0):
        super(TransformerXLDecoderLayer, self).__init__()
        self.version = version
        self.ignore_source = ignore_source
        self.variational = variational
        self.death_rate = death_rate

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model, p, sequence='da', variational=self.variational)

        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model, p, sequence='da', variational=self.variational)

        d_head = d_model // h
        self.multihead_tgt = RelPartialLearnableMultiHeadAttn(h, d_model, d_head, dropatt=attn_p)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model, d_ff, ff_p, variational=self.variational)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        elif onmt.constants.activation_layer == 'linear_swish_linear':
            ff_p = p
            feedforward = FeedForwardSwish(d_model, d_ff, ff_p)
        else:
            raise NotImplementedError
        self.feedforward = Bottle(feedforward)
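
A minimal instantiation sketch for the layer above (the hyperparameter values are illustrative assumptions, not taken from the original repository); d_head is derived inside the constructor as d_model // h:

# Hypothetical values, for illustration only.
layer = TransformerXLDecoderLayer(h=8, d_model=512, p=0.1, d_ff=2048,
                                  attn_p=0.1, variational=False, death_rate=0.0)
# Inside __init__, d_head = 512 // 8 = 64 per attention head.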
Example #2
    def __init__(self, opt):
        super().__init__()
        # self.layer_norm = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
        self.layer_norm = nn.LayerNorm((opt.model_size, ),
                                       elementwise_affine=True)
        # self.attn = MultiHeadAttention(opt.n_heads, opt.model_size, attn_p=opt.attn_dropout, share=1)
        self.attn = RelPartialLearnableMultiHeadAttn(opt.n_heads,
                                                     opt.model_size,
                                                     opt.model_size // opt.n_heads,
                                                     dropatt=opt.attn_dropout)
        self.dropout = opt.attn_dropout
        self.variational = opt.variational_dropout
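
For reference, nn.LayerNorm((opt.model_size,), elementwise_affine=True) normalizes over the last (feature) dimension with a learned scale and bias, which appears to be the role the commented-out PrePostProcessing(..., sequence='n') wrapper plays in the other examples. A self-contained sketch (tensor sizes are illustrative):

import torch
import torch.nn as nn

d_model = 512                            # illustrative size
norm = nn.LayerNorm((d_model,), elementwise_affine=True)
x = torch.randn(10, 2, d_model)          # (time, batch, feature)
y = norm(x)                              # normalized over the last dimension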
Example #3
    def __init__(self, opt, death_rate=0.0, **kwargs):
        super(RelativeTransformerEncoderLayer, self).__init__()
        self.variational = opt.variational_dropout
        self.death_rate = death_rate
        self.fast_self_attention = opt.fast_self_attention

        self.preprocess_attn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='n')
        self.postprocess_attn = PrePostProcessing(opt.model_size,
                                                  opt.dropout,
                                                  sequence='da',
                                                  variational=self.variational)
        self.preprocess_ffn = PrePostProcessing(opt.model_size,
                                                opt.dropout,
                                                sequence='n')
        self.postprocess_ffn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='da',
                                                 variational=self.variational)
        d_head = opt.model_size // opt.n_heads
        if not self.fast_self_attention:
            self.multihead = RelPartialLearnableMultiHeadAttn(
                opt.n_heads, opt.model_size, d_head, dropatt=opt.attn_dropout)
        else:
            self.multihead = RelativeSelfMultiheadAttn(opt.model_size,
                                                       opt.n_heads,
                                                       opt.attn_dropout)

        if not opt.fast_feed_forward:
            feedforward = FeedForward(opt.model_size,
                                      opt.inner_size,
                                      opt.dropout,
                                      variational=self.variational)
            self.feedforward = Bottle(feedforward)
        else:
            self.feedforward = PositionWiseFeedForward(
                opt.model_size,
                opt.inner_size,
                opt.dropout,
                variational=self.variational)
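
A hedged instantiation sketch for the encoder layer above, using a plain argparse.Namespace in place of the project's real option object (all field values are illustrative assumptions; the repository itself populates opt from its command-line parser):

from argparse import Namespace

opt = Namespace(model_size=512, n_heads=8, inner_size=2048,
                dropout=0.1, attn_dropout=0.1,
                variational_dropout=False,
                fast_self_attention=False,   # selects RelPartialLearnableMultiHeadAttn
                fast_feed_forward=False)     # selects Bottle(FeedForward(...))
layer = RelativeTransformerEncoderLayer(opt, death_rate=0.0)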
Example #4
    def __init__(self, opt, death_rate=0.0):
        super(RelativeTransformerDecoderLayer, self).__init__()
        self.ignore_source = opt.ignore_source
        self.variational = opt.variational_dropout
        self.death_rate = death_rate
        self.fast_self_attention = opt.fast_self_attention
        # self.lfv_multilingual = opt.lfv_multilingual

        self.preprocess_attn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='n')
        self.postprocess_attn = PrePostProcessing(opt.model_size,
                                                  opt.dropout,
                                                  sequence='da',
                                                  variational=self.variational)

        if not self.ignore_source:
            self.preprocess_src_attn = PrePostProcessing(opt.model_size,
                                                         opt.dropout,
                                                         sequence='n')
            self.postprocess_src_attn = PrePostProcessing(
                opt.model_size,
                opt.dropout,
                sequence='da',
                variational=self.variational)

            if opt.fast_xattention:
                self.multihead_src = EncdecMultiheadAttn(
                    opt.n_heads, opt.model_size, opt.attn_dropout)
            else:
                self.multihead_src = MultiHeadAttention(
                    opt.n_heads,
                    opt.model_size,
                    attn_p=opt.attn_dropout,
                    share=2)

        self.preprocess_ffn = PrePostProcessing(opt.model_size,
                                                opt.dropout,
                                                sequence='n')
        self.postprocess_ffn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='da',
                                                 variational=self.variational)

        d_head = opt.model_size // opt.n_heads

        if not self.fast_self_attention:
            self.multihead_tgt = RelPartialLearnableMultiHeadAttn(
                opt.n_heads, opt.model_size, d_head, dropatt=opt.attn_dropout)
        else:
            self.multihead_tgt = RelativeSelfMultiheadAttn(
                opt.model_size, opt.n_heads, opt.attn_dropout)

        if not opt.fast_feed_forward:
            feedforward = FeedForward(opt.model_size,
                                      opt.inner_size,
                                      opt.dropout,
                                      variational=self.variational)
            self.feedforward = Bottle(feedforward)
        else:
            self.feedforward = PositionWiseFeedForward(
                opt.model_size,
                opt.inner_size,
                opt.dropout,
                variational=self.variational)
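
The constructor above wires its sub-modules from four configuration switches. A compact summary of which implementation each flag selects (module names as they appear in the code; only construction is summarized here, not runtime behavior):

# ignore_source        -> skip preprocess_src_attn / postprocess_src_attn / multihead_src
# fast_xattention      -> EncdecMultiheadAttn           else MultiHeadAttention(..., share=2)
# fast_self_attention  -> RelativeSelfMultiheadAttn     else RelPartialLearnableMultiHeadAttn
# fast_feed_forward    -> PositionWiseFeedForward       else Bottle(FeedForward(...))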
Example #5
class RelativeTransformerDecoderLayer(nn.Module):
    def __init__(self,
                 h,
                 d_model,
                 p,
                 d_ff,
                 attn_p=0.1,
                 version=1.0,
                 ignore_source=False,
                 variational=False,
                 death_rate=0.0):
        super(RelativeTransformerDecoderLayer, self).__init__()
        self.version = version
        self.ignore_source = ignore_source
        self.variational = variational
        self.death_rate = death_rate

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model,
                                                  p,
                                                  sequence='da',
                                                  variational=self.variational)

        if not self.ignore_source:
            self.preprocess_src_attn = PrePostProcessing(d_model,
                                                         p,
                                                         sequence='n')
            self.postprocess_src_attn = PrePostProcessing(
                d_model, p, sequence='da', variational=self.variational)
            self.multihead_src = MultiHeadAttention(h,
                                                    d_model,
                                                    attn_p=attn_p,
                                                    share=2)

        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model,
                                                 p,
                                                 sequence='da',
                                                 variational=self.variational)

        d_head = d_model // h
        self.multihead_tgt = RelPartialLearnableMultiHeadAttn(h,
                                                              d_model,
                                                              d_head,
                                                              dropatt=attn_p)
        # self.multihead_tgt = MultiHeadAttention(h, d_model, attn_p=attn_p, share=1)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model,
                                      d_ff,
                                      ff_p,
                                      variational=self.variational)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        elif onmt.constants.activation_layer == 'linear_swish_linear':
            ff_p = p
            feedforward = FeedForwardSwish(d_model, d_ff, ff_p)
        else:
            raise NotImplementedError
        self.feedforward = Bottle(feedforward)

    # def forward(self, input, context, pos_emb, r_w_bias, r_r_bias, mask_tgt, mask_src):
    def forward(self,
                input,
                context,
                pos_emb,
                mask_tgt,
                mask_src,
                incremental=False,
                incremental_cache=None,
                reuse_source=True,
                mems=None):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """

        coin = True
        if self.training and self.death_rate > 0:
            coin = (torch.rand(1)[0].item() >= self.death_rate)

        if coin:
            # input and context should be time first ?
            if mems is not None and mems.size(0) > 0:
                mems = self.preprocess_attn(mems)
            else:
                mems = None

            query = self.preprocess_attn(input)

            # out, _ = self.multihead_tgt(query, pos_emb, r_w_bias, r_r_bias, attn_mask=mask_tgt)
            # print(query.size(), pos_emb.size(), mask_tgt.size(), mems.size() if mems is not None else 0)
            out, _, incremental_cache = self.multihead_tgt(
                query,
                pos_emb,
                attn_mask=mask_tgt,
                mems=mems,
                incremental=incremental,
                incremental_cache=incremental_cache)

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_attn(out, input)
            """ Context Attention layer 
                layernorm > attn > dropout > residual
            """
            if not self.ignore_source:
                query = self.preprocess_src_attn(input)
                incremental_source = incremental and reuse_source
                out, coverage, incremental_cache = self.multihead_src(
                    query,
                    context,
                    context,
                    mask_src,
                    incremental=incremental_source,
                    incremental_cache=incremental_cache)

                # rescaling before residual
                if self.training and self.death_rate > 0:
                    out = out / (1 - self.death_rate)

                input = self.postprocess_src_attn(out, input)
            else:
                coverage = None
            """ Feed forward layer 
                layernorm > ffn > dropout > residual
            """
            out = self.feedforward(self.preprocess_ffn(input))

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_ffn(out, input)
        else:
            coverage = None

        return input, coverage, incremental_cache

    def step(self, input, context, pos_emb, mask_tgt, mask_src, buffer=None):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """

        query = self.preprocess_attn(input)

        out, _, buffer = self.multihead_tgt.step(query,
                                                 pos_emb,
                                                 attn_mask=mask_tgt,
                                                 buffer=buffer)

        input = self.postprocess_attn(out, input)
        """ Context Attention layer
            layernorm > attn > dropout > residual
        """
        if not self.ignore_source:
            query = self.preprocess_src_attn(input)
            out, coverage, buffer = self.multihead_src.step(query,
                                                            context,
                                                            context,
                                                            mask_src,
                                                            buffer=buffer)
            input = self.postprocess_src_attn(out, input)
        else:
            coverage = None
        """ Feed forward layer
            layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input))
        input = self.postprocess_ffn(out, input)

        return input, coverage, buffer
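
The forward pass above combines pre-norm residual sub-layers (layernorm > attention/ffn > dropout > residual) with the stochastic-depth scheme noted in the comments. A self-contained sketch of that residual pattern, independent of the repository classes (function and variable names here are illustrative, not from the source):

import torch

def stochastic_residual(x, sublayer, death_rate=0.1, training=True):
    # With probability death_rate, skip the sub-layer entirely during training.
    if training and death_rate > 0:
        if torch.rand(1).item() < death_rate:
            return x
        # Rescale the surviving branch so its expectation matches evaluation,
        # where the branch is always applied without rescaling.
        return x + sublayer(x) / (1 - death_rate)
    return x + sublayer(x)

# Example use with a generic pre-norm sub-layer:
#     out = stochastic_residual(x, lambda t: ffn(norm(t)))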