def __init__(self, d_model, n_head, d_head, dropout, dropatt,
             pre_lnorm=False, local_size=None):
        """Build the projections, relative-position biases and dropout modules.

        Args:
            d_model: number of input/output channels of the projections.
            n_head: number of attention heads.
            d_head: channels per head.
            dropout: rate passed to ``VariationalHidDropout`` (stored as ``self.drop``).
            dropatt: rate passed to ``VariationalAttnDropout`` (stored as ``self.dropatt``).
            pre_lnorm: stored flag selecting pre- vs post-layer-norm behaviour.
            local_size: stored local attention horizon (``None`` allowed).

        ``temp_encoder`` maps a ``3 * d_model`` feature vector to the QKV width
        (``3 * n_head * d_head``). It is not used inside this constructor.
        NOTE(review): confirm how the corresponding forward pass consumes it.
        """
        super(WeightShareSelfAttention, self).__init__()
        self.d_model, self.n_head, self.d_head = d_model, n_head, d_head
        self.dropout = dropout
        self.scale = 1 / (d_head**0.5)  # 1/sqrt(d_head) attention temperature

        inner_dim = n_head * d_head
        qkv_dim = 3 * inner_dim

        def _pointwise(in_ch, out_ch, bias):
            # A kernel-size-1 Conv1d is a per-position linear map over channels.
            return nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=bias)

        def _small_uniform_bias():
            # Per-head bias vectors initialised in [-0.05, 0.05).
            return nn.Parameter(torch.rand(n_head, d_head).uniform_(-0.05, 0.05))

        # Construction order is significant: each layer draws from the global RNG
        # during initialisation, and modules register in assignment order.
        self.qkv_net = _pointwise(d_model, qkv_dim, bias=False)
        self.r_net = _pointwise(d_model, inner_dim, bias=False)
        self.temp_encoder = nn.Linear(3 * d_model, qkv_dim)
        self.r_w_bias = _small_uniform_bias()
        self.r_r_bias = _small_uniform_bias()
        self.o_net = _pointwise(inner_dim, d_model, bias=True)
        self.dropatt = VariationalAttnDropout(dropout=dropatt)
        self.drop = VariationalHidDropout(dropout=dropout)

        self.pre_lnorm = pre_lnorm
        self.local_size = local_size
예제 #2
0
class WeightShareSelfAttention(nn.Module):
    """Relative-position multi-head self-attention with input injection.

    This is similar to the RelPartialLearnableMultiHeadAttn class in Transformer-XL.
    Attention scores are the sum of two terms:

    * a content term: queries plus ``r_w_bias``, matched against keys;
    * a position term: queries plus ``r_r_bias``, matched against projected
      positional embeddings and realigned by ``_rel_shift``.

    All projections are kernel-size-1 ``Conv1d`` layers, so activations are
    channels-first: ``(bsz, channels, length)``.
    """

    def __init__(self, d_model, n_head, d_head, dropout, dropatt,
                 pre_lnorm=False, local_size=None):
        """Build the projections, relative-position biases and dropout modules.

        Args:
            d_model: number of input/output channels.
            n_head: number of attention heads.
            d_head: channels per head.
            dropout: rate for ``VariationalHidDropout``, applied to the output
                projection in ``forward``.
            dropatt: rate for ``VariationalAttnDropout``, applied to the
                attention probabilities in ``forward``.
            pre_lnorm: if True, layer-normalise the input before the QKV
                projection and skip the post-residual layer norm.
            local_size: local attention horizon in keys; ``None`` (or 0) falls
                back to 1000.
        """
        super(WeightShareSelfAttention, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        self.dropout = dropout
        self.scale = 1 / (d_head ** 0.5)

        # Construction order is kept as-is: layer initialisation consumes the
        # global RNG, and modules register in assignment order.
        self.qkv_net = nn.Conv1d(d_model, 3 * n_head * d_head, kernel_size=1, bias=False)
        self.r_net = nn.Conv1d(d_model, n_head * d_head, kernel_size=1, bias=False)
        self.r_w_bias = nn.Parameter(torch.rand(n_head, d_head).uniform_(-0.05, 0.05))
        self.r_r_bias = nn.Parameter(torch.rand(n_head, d_head).uniform_(-0.05, 0.05))
        self.o_net = nn.Conv1d(n_head * d_head, d_model, kernel_size=1)
        self.dropatt = VariationalAttnDropout(dropout=dropatt)
        self.drop = VariationalHidDropout(dropout=dropout)

        self.pre_lnorm = pre_lnorm
        self.local_size = local_size

    def wnorm(self):
        """Apply weight normalisation to the QKV, positional and output projections.

        The handles returned by ``weight_norm`` are kept as ``qkv_fn``, ``r_fn``
        and ``o_fn`` so that ``reset`` can call their ``reset`` hooks.
        """
        print("Weight normalization applied to SA")
        self.qkv_net, self.qkv_fn = weight_norm(module=self.qkv_net, names=['weight'], dim=0)
        self.r_net, self.r_fn = weight_norm(module=self.r_net, names=['weight'], dim=0)
        self.o_net, self.o_fn = weight_norm(module=self.o_net, names=['weight'], dim=0)

    def reset(self, bsz, qlen, klen):
        """Reset dropout masks and any weight-norm handles for a new input shape.

        The attention-dropout mask template has the attention-probability shape
        ``(bsz, n_head, qlen, klen)``. The hidden-dropout template has the output
        shape ``(bsz, d_model, qlen)``.
        """
        self.dropatt.reset_mask(torch.zeros(bsz, self.n_head, qlen, klen))
        self.drop.reset_mask(torch.zeros(bsz, self.d_model, qlen))
        # getattr finds the handles whether they were stored as plain attributes
        # or registered as submodules; they only exist after ``wnorm()``.
        for fn_name, net_name in (('qkv_fn', 'qkv_net'), ('r_fn', 'r_net'), ('o_fn', 'o_net')):
            fn = getattr(self, fn_name, None)
            if fn is not None:
                fn.reset(getattr(self, net_name))

    def copy(self, func):
        """Destructive copy: overwrite weights, biases and dropout masks with clones from ``func``."""
        self.qkv_net.weight.data = func.qkv_net.weight.data.clone()
        self.r_net.weight.data = func.r_net.weight.data.clone()
        self.r_w_bias.data = func.r_w_bias.data.clone()
        self.r_r_bias.data = func.r_r_bias.data.clone()
        self.o_net.weight.data = func.o_net.weight.data.clone()
        self.o_net.bias.data = func.o_net.bias.data.clone()
        self.dropatt.mask = func.dropatt.mask.clone()
        self.drop.mask = func.drop.mask.clone()

    def _rel_shift(self, x):
        """Apply the Transformer-XL relative shift to position scores.

        ``x`` has shape ``(bsz, n_head, qlen, klen)``. The method left-pads the key
        axis with one zero column and reinterprets the buffer as
        ``(klen + 1, qlen)``. It then drops the first row and views the result
        back as ``x``'s shape, which shifts each query row by a different offset.
        """
        bsz, n_head, qlen, klen = x.size()
        x_padded = F.pad(x, (1, 0))
        x_padded = x_padded.view(bsz, n_head, klen + 1, qlen)
        return x_padded[:, :, 1:].view_as(x)

    def _attn_mask(self, qlen, klen, mlen, device):
        """Return a bool mask of shape ``(1, qlen, klen)``; True marks disallowed keys.

        Query i sits at key position ``i + mlen``. It may not attend to later keys
        (the causal part). It also may not attend to keys at or before
        ``i + mlen - local_size`` (the local horizon), where ``local_size`` falls
        back to 1000 when unset.
        """
        local_size = self.local_size or 1000
        ones = torch.ones(qlen, klen, device=device)
        future = torch.triu(ones, diagonal=1 + mlen).bool()
        too_far = torch.tril(ones, diagonal=mlen - local_size).bool()
        # Built as bool on the scores' device. uint8 masks are deprecated (and
        # rejected by recent PyTorch releases) in masked_fill, and the mask must
        # be on the same device as the scores.
        return (future | too_far)[None]

    def forward(self, z1ss, pos_emb, u1ss, mems=None):
        """Attend from ``z1ss`` over ``[mems; z1ss]`` and return the updated features.

        Args:
            z1ss: current subsequence, shape ``(bsz, d_model, qlen)``.
            pos_emb: positional embeddings fed to ``r_net``. The projection is
                viewed as ``(n_head, d_head, rlen)``, so this presumably has batch
                size 1. It also assumes ``rlen == klen`` so the content and
                position scores add — TODO confirm with callers.
            u1ss: input injection added to the QKV projections. It must broadcast
                to ``(bsz, 3 * n_head * d_head, klen)``.
            mems: optional memory ``(bsz, d_model, mlen)``, prepended along the
                length axis. ``None`` means no memory (``mlen = 0``).

        Returns:
            Tensor of shape ``(bsz, d_model, qlen)``.
        """
        # qlen is the length of the (small) subsequence and mlen the length of
        # the memory padding; their sum is klen.
        bsz, d_model, qlen = z1ss.size()
        r_w_bias, r_r_bias = self.r_w_bias, self.r_r_bias
        n_head, d_head = self.n_head, self.d_head
        rlen = pos_emb.size(2)

        if mems is None:
            # Concatenating an empty placeholder here can fail: torch.cat only
            # skips 1-D empty tensors, and a CPU placeholder mismatches GPU inputs.
            mlen = 0
            cat = z1ss
        else:
            mlen = mems.size(2)
            cat = torch.cat([mems, z1ss], dim=-1)

        if self.pre_lnorm:
            cat = F.layer_norm(cat.transpose(1, 2), (d_model,)).transpose(1, 2)
        w_heads = self.qkv_net(cat)      # (bsz, 3 * n_head * d_head, klen)
        r_head_k = self.r_net(pos_emb)

        # Input injection
        w_heads += u1ss
        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=1)
        w_head_q = w_head_q[:, :, -qlen:]  # only the current subsequence issues queries

        klen = w_head_k.size(2)

        w_head_q = w_head_q.view(bsz, n_head, d_head, qlen)   # bsz x n_head x d_head x qlen
        w_head_k = w_head_k.view(bsz, n_head, d_head, klen)   # bsz x n_head x d_head x klen
        w_head_v = w_head_v.view(bsz, n_head, d_head, klen)   # bsz x n_head x d_head x klen

        r_head_k = r_head_k.view(n_head, d_head, rlen)        # n_head x d_head x rlen

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[:, :, None]
        AC = torch.einsum('bndi,bndj->bnij', rw_head_q, w_head_k)
        rr_head_q = w_head_q + r_r_bias[:, :, None]
        BD = torch.einsum('bndi,ndj->bnij', rr_head_q, r_head_k)
        BD = self._rel_shift(BD)    # for the sake of relative positional embedding

        attn_score = AC + BD        # bsz x n_head x qlen x klen
        attn_score.mul_(self.scale)

        #### compute attention probability (causal + local-horizon mask)
        attn_mask = self._attn_mask(qlen, klen, mlen, attn_score.device)
        if attn_mask.any().item():
            attn_score = attn_score.float().masked_fill(
                attn_mask[None], -float('inf')).type_as(attn_score)

        attn_prob = F.softmax(attn_score, dim=-1)   # bsz x n_head x qlen x klen
        # Attention dropout: configured in __init__ and reset in reset(), so it
        # is applied here just like self.drop is applied to the output.
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('bnij,bndj->bndi', attn_prob, w_head_v)

        # [bsz x (n_head * d_head) x qlen]
        attn_vec = attn_vec.contiguous().view(bsz, n_head * d_head, attn_vec.size(-1))

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        ##### residual connection + layer normalization (if applicable)
        if self.pre_lnorm:
            out = attn_out + z1ss
        else:
            out = F.layer_norm((attn_out + z1ss).transpose(1, 2), (d_model,)).transpose(1, 2)
        return out