Example no. 1
class LowRankBilinearLayer(nn.Module):
    def __init__(self, embed_dim, att_type, att_heads, att_mid_dim,
                 att_mid_drop, dropout):
        super(LowRankBilinearLayer, self).__init__()
        # Low-rank bilinear attention over the input features.
        self.encoder_attn = LowRank(embed_dim=embed_dim,
                                    att_type=att_type,
                                    att_heads=att_heads,
                                    att_mid_dim=att_mid_dim,
                                    att_mid_drop=att_mid_drop)
        # Optional dropout on the attended output.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
Example no. 2
    def __init__(
        self,
        embed_dim,
        dropout,
        att_type,
        att_heads,
        att_mid_dim,
        att_mid_drop,
        bifeat_emb_act,
        bifeat_emb_drop,
        ff_dropout,
        last_layer=False
    ):
        super(DecoderLayer, self).__init__()
        self.last_layer = last_layer
        self.word_attn = LowRank(
            embed_dim=embed_dim,
            att_type=att_type,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop)
        self.word_dropout = nn.Dropout(dropout)

        self.cross_att = LowRank(
            embed_dim=embed_dim,
            att_type=att_type,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop)
        self.cross_dropout = nn.Dropout(dropout)
        self.layer_norm_cross = torch.nn.LayerNorm(embed_dim)

        if not self.last_layer:
            self.bifeat_emb = nn.Sequential(
                nn.Linear(2 * embed_dim, embed_dim),
                utils.activation(bifeat_emb_act),
                nn.Dropout(bifeat_emb_drop)
            )
            self.layer_norm_x = torch.nn.LayerNorm(embed_dim)

            self.ff_layer = blocks.create(
                'FeedForward',
                embed_dim=embed_dim,
                ffn_embed_dim=embed_dim * 4,
                relu_dropout=ff_dropout,
                dropout=ff_dropout)

        self.layer_norm_gx = torch.nn.LayerNorm(embed_dim)
Example no. 3
class LowRankBilinearLayer(nn.Module):
    def __init__(self, embed_dim, att_type, att_heads, att_mid_dim,
                 att_mid_drop, dropout):
        super(LowRankBilinearLayer, self).__init__()
        self.encoder_attn = LowRank(embed_dim=embed_dim,
                                    att_type=att_type,
                                    att_heads=att_heads,
                                    att_mid_dim=att_mid_dim,
                                    att_mid_drop=att_mid_drop)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self,
                x,
                key=None,
                mask=None,
                value1=None,
                value2=None,
                precompute=False):
        x = self.encoder_attn(query=x,
                              key=key if key is not None else x,
                              mask=mask,
                              value1=value1 if value1 is not None else x,
                              value2=value2 if value2 is not None else x,
                              precompute=precompute)
        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def precompute(self, key, value2):
        return self.encoder_attn.precompute(key, value2)
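
A minimal usage sketch for the class above, assuming the LowRank module from the surrounding codebase is importable. The constructor, forward, and precompute signatures come from the code itself; the concrete hyperparameter values, the att_type string, the att_mid_dim format, and the tensor shapes are assumptions.

# Hypothetical usage of LowRankBilinearLayer; att_type and att_mid_dim must
# match whatever the (not shown) LowRank implementation expects.
import torch

layer = LowRankBilinearLayer(embed_dim=512,
                             att_type='SCAtt',          # assumed attention type
                             att_heads=8,
                             att_mid_dim=[64, 32, 64],  # assumed mid dims
                             att_mid_drop=0.1,
                             dropout=0.1)

feats = torch.randn(2, 36, 512)          # (batch, regions, embed_dim), assumed shape
out = layer(feats)                       # key/value default to the input itself

# Optionally precompute keys/values once and reuse them on later calls,
# mirroring how DecoderLayer.precompute is used below.
k, v2 = layer.precompute(feats, feats)
out2 = layer(feats, key=k, value2=v2, precompute=True)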
Example no. 4
    def __init__(
        self,
        embed_dim,
        dropout,
        att_type,
        att_heads,
        att_mid_dim,
        att_mid_drop,
        bifeat_emb_act,
        bifeat_emb_drop,
        ff_dropout
    ):
        super(EncoderLayer, self).__init__()
        self.encoder_attn = LowRank(
            embed_dim=embed_dim,
            att_type=att_type,
            att_heads=att_heads,
            att_mid_dim=att_mid_dim,
            att_mid_drop=att_mid_drop)
        self.dropout = nn.Dropout(dropout)

        self.bifeat_emb = nn.Sequential(
            nn.Linear(2 * embed_dim, embed_dim),
            utils.activation(bifeat_emb_act),
            nn.Dropout(bifeat_emb_drop)
        )
        self.layer_norm = torch.nn.LayerNorm(embed_dim)

        self.ff_layer = blocks.create(
            'FeedForward',
            embed_dim=embed_dim,
            ffn_embed_dim=embed_dim * 4,
            relu_dropout=ff_dropout,
            dropout=ff_dropout)
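
Only the constructor of EncoderLayer appears above, so the sketch below stops at instantiation. The argument names come from the code; every concrete value is an assumption, and utils.activation / blocks.create are helpers from the surrounding codebase whose accepted values are not shown here.

# Hypothetical EncoderLayer construction; all hyperparameter values are assumed.
enc_layer = EncoderLayer(embed_dim=512,
                         dropout=0.1,
                         att_type='SCAtt',          # assumed
                         att_heads=8,
                         att_mid_dim=[64, 32, 64],  # assumed
                         att_mid_drop=0.1,
                         bifeat_emb_act='relu',     # assumed activation name
                         bifeat_emb_drop=0.3,
                         ff_dropout=0.1)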
Example no. 5
class DecoderLayer(nn.Module):
    def __init__(self,
                 embed_dim,
                 dropout,
                 att_type,
                 att_heads,
                 att_mid_dim,
                 att_mid_drop,
                 bifeat_emb_act,
                 bifeat_emb_drop,
                 ff_dropout,
                 last_layer=False):
        super(DecoderLayer, self).__init__()
        self.last_layer = last_layer
        self.word_attn = LowRank(embed_dim=embed_dim,
                                 att_type=att_type,
                                 att_heads=att_heads,
                                 att_mid_dim=att_mid_dim,
                                 att_mid_drop=att_mid_drop)
        self.word_dropout = nn.Dropout(dropout)

        self.cross_att = LowRank(embed_dim=embed_dim,
                                 att_type=att_type,
                                 att_heads=att_heads,
                                 att_mid_dim=att_mid_dim,
                                 att_mid_drop=att_mid_drop)
        self.cross_dropout = nn.Dropout(dropout)
        self.layer_norm_cross = torch.nn.LayerNorm(embed_dim)

        if not self.last_layer:
            self.bifeat_emb = nn.Sequential(
                nn.Linear(2 * embed_dim, embed_dim),
                utils.activation(bifeat_emb_act), nn.Dropout(bifeat_emb_drop))
            self.layer_norm_x = torch.nn.LayerNorm(embed_dim)

            self.ff_layer = blocks.create('FeedForward',
                                          embed_dim=embed_dim,
                                          ffn_embed_dim=embed_dim * 4,
                                          relu_dropout=ff_dropout,
                                          dropout=ff_dropout)

        self.layer_norm_gx = torch.nn.LayerNorm(embed_dim)

    def apply_to_states(self, fn):
        self.word_attn.apply_to_states(fn)

    def init_buffer(self, batch_size):
        self.word_attn.init_buffer(batch_size)

    def clear_buffer(self):
        self.word_attn.clear_buffer()

    def precompute(self, encoder_out):
        key, value2 = self.cross_att.precompute(encoder_out, encoder_out)
        return key, value2

    def forward(self,
                gx,
                x,
                encoder_out,
                att_mask,
                seq_mask,
                p_key=None,
                p_value2=None,
                precompute=False):
        # Word-level attention: the global feature gx queries the word
        # features x, followed by dropout and a residual connection.
        word_x = x
        residual = x
        x = self.word_attn.forward2(query=gx,
                                    key=x,
                                    mask=seq_mask,
                                    value1=gx,
                                    value2=x)
        x = self.word_dropout(x)
        x = residual + x

        # Cross-attention over the encoder output (or its precomputed
        # keys/values), with dropout, a residual connection, and layer norm.
        residual = x
        x = self.layer_norm_cross(x)
        x = self.cross_att.forward2(
            query=x,
            key=encoder_out if not precompute else p_key,
            mask=att_mask,
            value1=x,
            value2=encoder_out if not precompute else p_value2,
            precompute=precompute)
        x = self.cross_dropout(x)
        gx = residual + x
        gx = self.layer_norm_gx(gx)

        # Except in the last layer, fuse the updated global feature with the
        # original word features and refine them with the feed-forward block.
        if not self.last_layer:
            x_ = torch.cat([gx, word_x], dim=-1)
            x = self.bifeat_emb(x_) + word_x
            x = self.layer_norm_x(x)

            if self.ff_layer is not None:
                x = self.ff_layer(x)
        else:
            x = None
        return gx, x
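
A usage sketch for a single DecoderLayer forward pass. The argument list matches the forward signature above; the hyperparameter values and tensor shapes (in particular whether gx is a single global vector or per-position features) are assumptions, since the LowRank internals are not shown, and the masks are left as None for brevity.

# Hypothetical single-step usage; shapes and hyperparameters are assumed.
import torch

dec_layer = DecoderLayer(embed_dim=512,
                         dropout=0.1,
                         att_type='SCAtt',          # assumed
                         att_heads=8,
                         att_mid_dim=[64, 32, 64],  # assumed
                         att_mid_drop=0.1,
                         bifeat_emb_act='relu',     # assumed
                         bifeat_emb_drop=0.3,
                         ff_dropout=0.1,
                         last_layer=False)

gx = torch.randn(2, 17, 512)            # global features, shape assumed
x = torch.randn(2, 17, 512)             # word features, shape assumed
encoder_out = torch.randn(2, 36, 512)   # encoder output, shape assumed

gx, x = dec_layer(gx, x, encoder_out, att_mask=None, seq_mask=None)

# For step-by-step decoding, the cross-attention keys/values can be
# precomputed once from the encoder output and reused at every step.
p_key, p_value2 = dec_layer.precompute(encoder_out)
gx, x = dec_layer(gx, x, encoder_out, att_mask=None, seq_mask=None,
                  p_key=p_key, p_value2=p_value2, precompute=True)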