def __init__(self, h, d_model, p, d_ff, position_encoder, time_encoder, attn_p=0.1, version=1.0):
    super(UniversalDecoderLayer, self).__init__()
    self.version = version
    self.position_encoder = position_encoder
    self.time_encoder = time_encoder

    self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
    self.postprocess_attn = PrePostProcessing(d_model, p, sequence='da', static=onmt.constants.static)
    self.preprocess_src_attn = PrePostProcessing(d_model, p, sequence='n')
    self.postprocess_src_attn = PrePostProcessing(d_model, p, sequence='da', static=onmt.constants.static)
    self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
    self.postprocess_ffn = PrePostProcessing(d_model, p, sequence='da', static=onmt.constants.static)

    self.multihead_tgt = MultiHeadAttention(h, d_model, attn_p=attn_p, static=onmt.constants.static)
    self.multihead_src = MultiHeadAttention(h, d_model, attn_p=attn_p, static=onmt.constants.static)

    if onmt.constants.activation_layer == 'linear_relu_linear':
        ff_p = p
        feedforward = FeedForward(d_model, d_ff, ff_p, static=onmt.constants.static)
    elif onmt.constants.activation_layer == 'maxout':
        k = int(math.ceil(d_ff / d_model))
        feedforward = MaxOut(d_model, d_model, k)
    self.feedforward = Bottle(feedforward)
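# Note on the PrePostProcessing `sequence` codes used throughout these layers (an assumption
# inferred from how the pre-/post-processing pairs are applied here, not confirmed in this
# section): 'n' appears to apply layer normalization before a sub-layer, while 'da' appears
# to apply dropout followed by the residual add after it, i.e. the pre-norm pattern
# "layernorm > sublayer > dropout > residual" described in the docstrings below.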
def __init__(self, opt, death_rate=0.0):
    super(RelativeTransformerDecoderLayer, self).__init__()
    self.ignore_source = opt.ignore_source
    self.variational = opt.variational_dropout
    self.death_rate = death_rate
    self.fast_self_attention = opt.fast_self_attention
    # self.lfv_multilingual = opt.lfv_multilingual

    self.preprocess_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
    self.postprocess_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='da', variational=self.variational)

    if not self.ignore_source:
        self.preprocess_src_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
        self.postprocess_src_attn = PrePostProcessing(
            opt.model_size, opt.dropout, sequence='da', variational=self.variational)

        if opt.fast_xattention:
            self.multihead_src = EncdecMultiheadAttn(opt.n_heads, opt.model_size, opt.attn_dropout)
        else:
            self.multihead_src = MultiHeadAttention(opt.n_heads, opt.model_size, attn_p=opt.attn_dropout, share=2)

    self.preprocess_ffn = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
    self.postprocess_ffn = PrePostProcessing(opt.model_size, opt.dropout, sequence='da', variational=self.variational)

    d_head = opt.model_size // opt.n_heads

    if not self.fast_self_attention:
        self.multihead_tgt = RelPartialLearnableMultiHeadAttn(
            opt.n_heads, opt.model_size, d_head, dropatt=opt.attn_dropout)
    else:
        self.multihead_tgt = RelativeSelfMultiheadAttn(opt.model_size, opt.n_heads, opt.attn_dropout)

    if not opt.fast_feed_forward:
        feedforward = FeedForward(opt.model_size, opt.inner_size, opt.dropout, variational=self.variational)
        self.feedforward = Bottle(feedforward)
    else:
        self.feedforward = PositionWiseFeedForward(
            opt.model_size, opt.inner_size, opt.dropout, variational=self.variational)
class RelativeTransformerDecoderLayer(nn.Module):

    def __init__(self, h, d_model, p, d_ff, attn_p=0.1, version=1.0,
                 ignore_source=False, variational=False, death_rate=0.0):
        super(RelativeTransformerDecoderLayer, self).__init__()
        self.version = version
        self.ignore_source = ignore_source
        self.variational = variational
        self.death_rate = death_rate

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model, p, sequence='da', variational=self.variational)

        if not self.ignore_source:
            self.preprocess_src_attn = PrePostProcessing(d_model, p, sequence='n')
            self.postprocess_src_attn = PrePostProcessing(
                d_model, p, sequence='da', variational=self.variational)
            self.multihead_src = MultiHeadAttention(h, d_model, attn_p=attn_p, share=2)

        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model, p, sequence='da', variational=self.variational)

        d_head = d_model // h
        self.multihead_tgt = RelPartialLearnableMultiHeadAttn(h, d_model, d_head, dropatt=attn_p)
        # self.multihead_tgt = MultiHeadAttention(h, d_model, attn_p=attn_p, share=1)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model, d_ff, ff_p, variational=self.variational)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        elif onmt.constants.activation_layer == 'linear_swish_linear':
            ff_p = p
            feedforward = FeedForwardSwish(d_model, d_ff, ff_p)
        else:
            raise NotImplementedError
        self.feedforward = Bottle(feedforward)

    # def forward(self, input, context, pos_emb, r_w_bias, r_r_bias, mask_tgt, mask_src):
    def forward(self, input, context, pos_emb, mask_tgt, mask_src, incremental=False,
                incremental_cache=None, reuse_source=True, mems=None):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """
        coin = True
        if self.training and self.death_rate > 0:
            coin = (torch.rand(1)[0].item() >= self.death_rate)

        if coin:
            # input and context should be time first ?
            if mems is not None and mems.size(0) > 0:
                mems = self.preprocess_attn(mems)
            else:
                mems = None

            query = self.preprocess_attn(input)

            # out, _ = self.multihead_tgt(query, pos_emb, r_w_bias, r_r_bias, attn_mask=mask_tgt)
            # print(query.size(), pos_emb.size(), mask_tgt.size(), mems.size() if mems is not None else 0)
            out, _, incremental_cache = self.multihead_tgt(
                query, pos_emb, attn_mask=mask_tgt, mems=mems,
                incremental=incremental, incremental_cache=incremental_cache)

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_attn(out, input)

            """ Context Attention layer
                layernorm > attn > dropout > residual
            """
            if not self.ignore_source:
                query = self.preprocess_src_attn(input)
                incremental_source = incremental and reuse_source
                out, coverage, incremental_cache = self.multihead_src(
                    query, context, context, mask_src,
                    incremental=incremental_source, incremental_cache=incremental_cache)

                # rescaling before residual
                if self.training and self.death_rate > 0:
                    out = out / (1 - self.death_rate)

                input = self.postprocess_src_attn(out, input)
            else:
                coverage = None

            """ Feed forward layer
                layernorm > ffn > dropout > residual
            """
            out = self.feedforward(self.preprocess_ffn(input))

            # rescaling before residual
            if self.training and self.death_rate > 0:
                out = out / (1 - self.death_rate)

            input = self.postprocess_ffn(out, input)
        else:
            coverage = None

        return input, coverage, incremental_cache

    def step(self, input, context, pos_emb, mask_tgt, mask_src, buffer=None):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """
        query = self.preprocess_attn(input)

        out, _, buffer = self.multihead_tgt.step(query, pos_emb, attn_mask=mask_tgt, buffer=buffer)

        input = self.postprocess_attn(out, input)

        """ Context Attention layer
            layernorm > attn > dropout > residual
        """
        if not self.ignore_source:
            query = self.preprocess_src_attn(input)
            out, coverage, buffer = self.multihead_src.step(query, context, context, mask_src, buffer=buffer)

            input = self.postprocess_src_attn(out, input)
        else:
            coverage = None

        """ Feed forward layer
            layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input))
        input = self.postprocess_ffn(out, input)

        return input, coverage, buffer
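# The forward pass above uses stochastic depth (layer drop): during training the whole
# residual branch is skipped with probability `death_rate`, and when it does run, each
# sub-layer output is rescaled by 1 / (1 - death_rate) before the residual add so its
# expected contribution matches inference. A minimal, self-contained sketch of that
# pattern follows; `StochasticDepthBlock` is a hypothetical helper, not part of this codebase.

import torch
import torch.nn as nn


class StochasticDepthBlock(nn.Module):
    """Minimal sketch of the layer-drop pattern used in the decoder layer above."""

    def __init__(self, sublayer, death_rate=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.death_rate = death_rate

    def forward(self, x):
        if self.training and self.death_rate > 0:
            # skip the whole residual branch with probability death_rate
            if torch.rand(1).item() < self.death_rate:
                return x
            # rescale the surviving branch so its expected value matches inference
            return x + self.sublayer(x) / (1 - self.death_rate)
        return x + self.sublayer(x)


# usage sketch: a feed-forward sub-layer wrapped with layer drop
# block = StochasticDepthBlock(
#     nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512)), death_rate=0.1)
# y = block(torch.randn(10, 4, 512))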
class LMDecoderLayer(nn.Module):
    """Wraps multi-head attention and position-wise feed forward into one decoder layer.

    Args:
        h:       number of heads
        d_model: dimension of model
        p:       dropout probability
        d_ff:    dimension of feed forward

    Params:
        multihead_tgt: multi-head self attention layer
        multihead_src: multi-head encoder-decoder attention layer
        feedforward:   feed forward layer

    Input Shapes:
        query:    batch_size x len_query x d_model
        key:      batch_size x len_key x d_model
        value:    batch_size x len_key x d_model
        context:  batch_size x len_src x d_model
        mask_tgt: batch_size x len_query x len_key or broadcastable
        mask_src: batch_size x len_query x len_src or broadcastable

    Output Shapes:
        out:      batch_size x len_query x d_model
        coverage: batch_size x len_query x len_key
    """

    def __init__(self, h, d_model, p, d_ff, attn_p=0.1):
        super(LMDecoderLayer, self).__init__()

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model, p, sequence='da', static=onmt.constants.static)
        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model, p, sequence='da', static=onmt.constants.static)
        self.multihead_tgt = MultiHeadAttention(h, d_model, attn_p=attn_p, static=onmt.constants.static, share=1)

        ff_p = p
        feedforward = FeedForward(d_model, d_ff, ff_p, static=onmt.constants.static)
        self.feedforward = Bottle(feedforward)

    def forward(self, input, mask_tgt):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """
        # input and context should be time first ?
        query = self.preprocess_attn(input)
        self_context = query
        out, _ = self.multihead_tgt(query, self_context, self_context, mask_tgt)
        input = self.postprocess_attn(out, input)

        """ Feed forward layer
            layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input))
        input = self.postprocess_ffn(out, input)

        coverage = None
        return input, coverage

    def step(self, input, mask_tgt, buffer=None):
        """ Self attention layer
            layernorm > attn > dropout > residual
        """
        query = self.preprocess_attn(input)
        out, _, buffer = self.multihead_tgt.step(query, query, query, mask_tgt, buffer=buffer)
        input = self.postprocess_attn(out, input)
        coverage = None

        """ Feed forward layer
            layernorm > ffn > dropout > residual
        """
        out = self.feedforward(self.preprocess_ffn(input))
        input = self.postprocess_ffn(out, input)

        return input, coverage, buffer
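# The pre-/post-processing pairs in LMDecoderLayer implement the pre-norm residual pattern
# described in its docstrings ("layernorm > attn > dropout > residual"). Below is a
# self-contained sketch of one such self-attention sub-layer in plain PyTorch, under the
# assumption stated earlier about the 'n' and 'da' sequence codes; the class and variable
# names here are hypothetical and only illustrative of the pattern.

import torch
import torch.nn as nn


class PreNormSelfAttention(nn.Module):
    """Hypothetical sketch of the 'layernorm > attn > dropout > residual' sub-layer."""

    def __init__(self, d_model=512, n_heads=8, p=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)        # preprocess: sequence='n'
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=p, batch_first=True)
        self.dropout = nn.Dropout(p)             # postprocess: sequence='da' (dropout + add)

    def forward(self, x, attn_mask=None):
        q = self.norm(x)
        out, _ = self.attn(q, q, q, attn_mask=attn_mask)
        return x + self.dropout(out)             # residual add


# usage sketch with a causal mask, as a language-model decoder would use for mask_tgt
# x = torch.randn(2, 7, 512)
# causal = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=1)
# y = PreNormSelfAttention()(x, attn_mask=causal)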