def test_bwd_dw(self, mb, i, o, block_size):
    """Gradient w.r.t. the weight: the blocksparse backward pass must agree
    with the dense masked reference (compared via total gradient mass)."""
    X, masked_W, bs_W, mask = get_rand_vals(mb, i, o, block_size, True)
    masked_W.requires_grad = True
    bs_W.data.requires_grad = True
    # Backprop a scalar through both the blocksparse and the dense matmul.
    ts.mm(X, bs_W).sum().backward()
    (X @ masked_W).sum().backward()
    # Only blocks kept by the mask should carry gradient; compare the
    # summed gradients (block storage order may differ between the two).
    dense_grad_total = (masked_W.grad * mask).flatten().sum()
    sparse_grad_total = bs_W.data.grad.flatten().sum()
    torch.testing.assert_allclose(dense_grad_total, sparse_grad_total)
def test_bwd_dx(self, mb, i, o, block_size):
    """Gradient w.r.t. the input: both the forward output and the input
    gradient of the blocksparse matmul must match the dense reference."""
    X1, masked_W, bs_W = get_rand_vals(mb, i, o, block_size)
    X2 = X1.clone()
    for inp in (X1, X2):
        inp.requires_grad = True
    # Blocksparse path.
    Y = ts.mm(X1, bs_W)
    Y.sum().backward()
    # Dense masked reference path.
    Y_ref = X2 @ masked_W
    Y_ref.sum().backward()
    torch.testing.assert_allclose(Y, Y_ref)
    torch.testing.assert_allclose(X1.grad, X2.grad)
import torch
import numpy as np
import torch_sparse as ts

# Demo: build a blocksparse weight and run a differentiable matmul.
# Feature dimensions are chosen divisible by block_size so the layout
# grid below is exact.
minibatch = 64
input_features = 128
output_features = 512
block_size = 16
cuda = torch.device("cuda")
X = torch.randn(minibatch, input_features, device=cuda)
W = torch.randn(input_features, output_features, device=cuda)

# First, generate a sparse layout: one 0/1 entry per (block_size x
# block_size) block of W. Presumably 1 marks a kept block — confirm
# against the ts docs.
ib = input_features // block_size
ob = output_features // block_size
layout = np.random.randint(2, size=(ib, ob))

# Then, create a blocksparse weight from the dense W and the layout.
bs_W = ts.BlockSparseTensor(W, layout)

# Differentiable matrix multiplication
Y = ts.mm(X, bs_W)
def test_fwd(self, mb, i, o, block_size):
    """Forward pass: blocksparse matmul equals a dense matmul against the
    masked weight."""
    X, masked_W, bs_W = get_rand_vals(mb, i, o, block_size)
    actual = ts.mm(X, bs_W)
    expected = X @ masked_W
    torch.testing.assert_allclose(actual, expected)
def test_ident(self):
    """Multiplying by a blocksparse identity with a fully dense layout is
    a no-op."""
    inp = torch.randn(128, 128, device=cuda)
    eye = torch.eye(128, device=cuda)
    # All-ones 8x8 layout: every 16x16 block of the 128x128 identity kept.
    # NOTE(review): factory is createBlockSparseTensor here — verify this
    # name against the ts API.
    bs_eye = ts.createBlockSparseTensor(eye, torch.ones(8, 8))
    out = ts.mm(inp, bs_eye)
    torch.testing.assert_allclose(inp, out)
def f_t(X, W):
    """Blocksparse matmul with the third flag set (presumably transpose —
    confirm against ts.mm's signature)."""
    return ts.mm(X, W, True)
def f(X, W):
    """Plain blocksparse matmul wrapper."""
    return ts.mm(X, W)