Example #1
    def __init__(self, temperature, topk, grad_sparse, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature  # scaling factor applied to the attention logits
        #self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)  # softmax over the key dimension
        self.grad_sparse = grad_sparse  # flag selecting the sparse-gradient variant
        self.sa = Sparse_attention(top_k=topk)  # keeps only the top-k attention weights
Example #2
    def __init__(self, top_k):
        super(Sparse_grad_attention, self).__init__()

        # Top-k sparsifier reused in the backward pass to mask gradients.
        self.sa = Sparse_attention(top_k=top_k)
Example #3
    def forward(self, inp):
        # The forward pass is the identity: the input is returned unchanged,
        # but the sparsified attention map is saved for the backward pass.
        sparsified = self.sa(inp)
        self.save_for_backward(inp, sparsified)

        return inp

    def backward(self, grad_output):
        inp, sparsified = self.saved_tensors
        # Only positions that survive the top-k sparsification receive gradient.
        return (grad_output) * (sparsified > 0.0).float()


if __name__ == "__main__":
    # Quick smoke test; assumes torch, numpy, torch.autograd.Variable,
    # Sparse_attention, and blocked_grad are available in this module.
    k = 2
    sga = Sparse_grad_attention(k)
    sa = Sparse_attention(k)

    # Two rows of 5 attention weights each.
    x = torch.from_numpy(
        numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4], [0.5, 0.4, 0.1, 0.0, 0.0]]]))
    x = x.reshape((2, 5))
    mask = x * 0.0 + 0.1

    x = Variable(x.data, requires_grad=True)

    bg = blocked_grad()
    # Backpropagate through the masked squared sum.
    (((bg(x, mask))**2).sum()).backward()

    print('normal grad', x.grad)
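
The examples above all rely on a Sparse_attention module that keeps only the top_k largest weights in each attention row. Its implementation is not shown here; the snippet below is a minimal sketch of how such a module could work (the class name SparseAttentionSketch and the renormalization details are assumptions, not the project's actual code).

import torch
import torch.nn as nn


class SparseAttentionSketch(nn.Module):
    """Keep the top_k largest weights per row and renormalize to sum to 1."""

    def __init__(self, top_k):
        super().__init__()
        self.top_k = top_k

    def forward(self, attn):
        # attn: [batch, length] tensor of non-negative attention weights.
        if self.top_k >= attn.shape[-1]:
            return attn
        # Per-row value of the k-th largest entry.
        topk_vals, _ = torch.topk(attn, self.top_k, dim=-1)
        threshold = topk_vals[..., -1].unsqueeze(-1)
        # Zero out everything below the per-row threshold.
        sparse = torch.where(attn >= threshold, attn, torch.zeros_like(attn))
        # Renormalize the surviving weights.
        return sparse / (sparse.sum(dim=-1, keepdim=True) + 1e-8)


if __name__ == "__main__":
    attn = torch.tensor([[0.1, 0.0, 0.3, 0.2, 0.4],
                         [0.5, 0.4, 0.1, 0.0, 0.0]])
    # Each row keeps its two largest weights, renormalized to sum to 1.
    print(SparseAttentionSketch(top_k=2)(attn))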