import torch
import torch.nn as nn

# LowRank (the low-rank bilinear attention module), blocks (the layer factory)
# and utils (activation helpers) are defined elsewhere in this codebase; their
# project-specific imports are omitted here.


class LowRankBilinearLayer(nn.Module):
    def __init__(self, embed_dim, att_type, att_heads, att_mid_dim, att_mid_drop, dropout):
        super(LowRankBilinearLayer, self).__init__()
        self.encoder_attn = LowRank(
            embed_dim=embed_dim,
            att_type=att_type,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x, key=None, mask=None, value1=None, value2=None, precompute=False):
        # When key/value1/value2 are omitted, the layer attends over x itself.
        x = self.encoder_attn(
            query=x,
            key=key if key is not None else x,
            mask=mask,
            value1=value1 if value1 is not None else x,
            value2=value2 if value2 is not None else x,
            precompute=precompute)
        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def precompute(self, key, value2):
        # Cache the projected keys/values so repeated attention calls can reuse them.
        return self.encoder_attn.precompute(key, value2)
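# Usage sketch. Assumptions: the project's LowRank module is importable, feature
# tensors are laid out as [batch, num_regions, embed_dim], and the hyperparameter
# values below are illustrative placeholders, not the paper's configuration.
def _lowrank_bilinear_layer_example():
    layer = LowRankBilinearLayer(
        embed_dim=512, att_type='SCAtt', att_heads=8,
        att_mid_dim=[64, 32, 64], att_mid_drop=0.1, dropout=0.3)
    feats = torch.randn(2, 36, 512)    # e.g. 36 region features per image
    att_mask = torch.ones(2, 36)       # 1 = valid region, 0 = padding
    # key/value1/value2 default to feats, so this is self-attention over regions.
    return layer(feats, mask=att_mask)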
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, dropout, att_type, att_heads, att_mid_dim,
                 att_mid_drop, bifeat_emb_act, bifeat_emb_drop, ff_dropout):
        super(EncoderLayer, self).__init__()
        # Low-rank bilinear attention over the input features.
        self.encoder_attn = LowRank(
            embed_dim=embed_dim,
            att_type=att_type,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop)
        self.dropout = nn.Dropout(dropout)
        # Fuses the attended feature with each local feature:
        # concatenation -> linear -> activation -> dropout.
        self.bifeat_emb = nn.Sequential(
            nn.Linear(2 * embed_dim, embed_dim),
            utils.activation(bifeat_emb_act),
            nn.Dropout(bifeat_emb_drop))
        self.layer_norm = torch.nn.LayerNorm(embed_dim)
        # Position-wise feed-forward block from the project's block factory.
        self.ff_layer = blocks.create(
            'FeedForward',
            embed_dim=embed_dim,
            ffn_embed_dim=embed_dim * 4,
            relu_dropout=ff_dropout,
            dropout=ff_dropout)
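# EncoderLayer's forward pass is not part of this listing. The sketch below,
# given as a hypothetical subclass, shows what it might look like by analogy
# with DecoderLayer.forward further down and the modules created in __init__;
# it is an assumption, not the repository's actual implementation. gx is taken
# to be a global feature of shape [batch, embed_dim] and x the per-region
# features of shape [batch, num_regions, embed_dim].
class EncoderLayerForwardSketch(EncoderLayer):
    def forward(self, gx, x, att_mask):
        # Refine the global feature by attending over all region features.
        residual = gx
        gx = self.encoder_attn(
            query=gx.unsqueeze(1), key=x, mask=att_mask,
            value1=gx.unsqueeze(1), value2=x).squeeze(1)
        gx = residual + self.dropout(gx)
        # Fuse the refined global feature back into each region feature.
        x_ = torch.cat([gx.unsqueeze(1).expand_as(x), x], dim=-1)
        x = self.bifeat_emb(x_) + x
        x = self.layer_norm(x)
        x = self.ff_layer(x)
        return gx, x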
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, dropout, att_type, att_heads, att_mid_dim,
                 att_mid_drop, bifeat_emb_act, bifeat_emb_drop, ff_dropout,
                 last_layer=False):
        super(DecoderLayer, self).__init__()
        self.last_layer = last_layer
        # Self-attention over the words generated so far.
        self.word_attn = LowRank(
            embed_dim=embed_dim,
            att_type=att_type,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop)
        self.word_dropout = nn.Dropout(dropout)
        # Cross-attention from the decoder state to the encoder output.
        self.cross_att = LowRank(
            embed_dim=embed_dim,
            att_type=att_type,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop)
        self.cross_dropout = nn.Dropout(dropout)
        self.layer_norm_cross = torch.nn.LayerNorm(embed_dim)
        # All layers except the last also refine the per-word features,
        # so the fusion and feed-forward modules are only built in that case.
        if not self.last_layer:
            self.bifeat_emb = nn.Sequential(
                nn.Linear(2 * embed_dim, embed_dim),
                utils.activation(bifeat_emb_act),
                nn.Dropout(bifeat_emb_drop))
            self.layer_norm_x = torch.nn.LayerNorm(embed_dim)
            self.ff_layer = blocks.create(
                'FeedForward',
                embed_dim=embed_dim,
                ffn_embed_dim=embed_dim * 4,
                relu_dropout=ff_dropout,
                dropout=ff_dropout)
        self.layer_norm_gx = torch.nn.LayerNorm(embed_dim)

    # Incremental-decoding buffer management, delegated to the word self-attention.
    def apply_to_states(self, fn):
        self.word_attn.apply_to_states(fn)

    def init_buffer(self, batch_size):
        self.word_attn.init_buffer(batch_size)

    def clear_buffer(self):
        self.word_attn.clear_buffer()

    def precompute(self, encoder_out):
        # Cache the cross-attention keys/values derived from the encoder output.
        key, value2 = self.cross_att.precompute(encoder_out, encoder_out)
        return key, value2

    def forward(self, gx, x, encoder_out, att_mask, seq_mask,
                p_key=None, p_value2=None, precompute=False):
        word_x = x
        residual = x
        # Word attention: the global feature gx queries the word features x.
        x = self.word_attn.forward2(query=gx, key=x, mask=seq_mask, value1=gx, value2=x)
        x = self.word_dropout(x)
        x = residual + x

        residual = x
        x = self.layer_norm_cross(x)
        # Cross-attention over the encoder output; use the precomputed
        # keys/values (p_key, p_value2) when precompute is set.
        x = self.cross_att.forward2(
            query=x,
            key=encoder_out if not precompute else p_key,
            mask=att_mask,
            value1=x,
            value2=encoder_out if not precompute else p_value2,
            precompute=precompute)
        x = self.cross_dropout(x)
        gx = residual + x
        gx = self.layer_norm_gx(gx)

        # All but the last layer also update the per-word features.
        if not self.last_layer:
            x_ = torch.cat([gx, word_x], dim=-1)
            x = self.bifeat_emb(x_) + word_x
            x = self.layer_norm_x(x)
            if self.ff_layer is not None:
                x = self.ff_layer(x)
        else:
            x = None
        return gx, x
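# Stacking sketch (assumed wiring, hypothetical class name): one plausible way
# to chain the decoder layers above, threading (gx, x) through each layer and
# marking only the final layer last_layer=True so it returns just the refined
# global feature gx (its per-word output x is None).
class DecoderStackSketch(nn.Module):
    def __init__(self, num_layers, embed_dim, **layer_kwargs):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(embed_dim=embed_dim,
                         last_layer=(i == num_layers - 1),
                         **layer_kwargs)
            for i in range(num_layers)])

    def forward(self, gx, x, encoder_out, att_mask, seq_mask):
        for layer in self.layers:
            gx, x = layer(gx, x, encoder_out, att_mask, seq_mask)
        return gx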