def test_spmm(dtype, device, reduce):
    """Compare sparse-dense ``matmul`` with a ``torch_scatter`` reference.

    The forward result and the gradients w.r.t. both the sparse values and
    the dense operand must agree for the given ``reduce`` mode.
    """

    def take_grad(tensor):
        # Return the accumulated gradient and clear it for the next pass.
        grad = tensor.grad
        tensor.grad = None
        return grad

    dense = torch.randn((10, 8), dtype=dtype, device=device)
    dense[2:4, :] = 0  # Empty out a block of rows.
    dense[:, 2:4] = 0  # Empty out a block of columns.
    sparse = SparseTensor.from_dense(dense).requires_grad_()
    row, col, value = sparse.coo()

    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                        requires_grad=True)

    # Reference path: pick the rows of `other` addressed by `col`, weight
    # them by the matching sparse values, and reduce them per output row.
    weighted_rows = other.index_select(-2, col) * value.unsqueeze(-1)
    expected = torch_scatter.scatter(weighted_rows, row, dim=-2,
                                     reduce=reduce)
    # NOTE(review): presumably the min/max reference leaves large sentinel
    # values in output rows that receive no entries; they are zeroed here
    # so they compare equal to matmul's output — confirm.
    if reduce == 'min':
        expected[expected > 1000] = 0
    elif reduce == 'max':
        expected[expected < -1000] = 0

    grad_out = torch.randn_like(expected)
    expected.backward(grad_out)
    expected_grad_value = take_grad(value)
    expected_grad_other = take_grad(other)

    out = matmul(sparse, other, reduce)
    out.backward(grad_out)

    assert torch.allclose(expected, out, atol=1e-6)
    assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
    assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
def test_metis(device):
    """Partition 6x6 matrices with METIS and check the output sizes.

    Each dense input is converted to a ``SparseTensor`` and partitioned
    weighted and unweighted into 2 parts, then weighted into 1 part.
    """
    # NOTE(review): this module also defines ``test_metis(device, weighted)``
    # further down. That later definition rebinds the name, so pytest would
    # only collect it and this test never runs. Consider renaming one of
    # them — TODO confirm which version is meant to stay.
    dense_inputs = (
        torch.randn(6 * 6, device=device).view(6, 6),
        torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6),
        torch.ones(6 * 6, device=device).view(6, 6),
    )
    # (num_parts, weighted) in the order the checks are performed.
    configs = ((2, True), (2, False), (1, True))

    for dense in dense_inputs:
        mat = SparseTensor.from_dense(dense)
        for num_parts, weighted in configs:
            _, partptr, perm = mat.partition(num_parts=num_parts,
                                             recursive=False,
                                             weighted=weighted)
            # Expect num_parts + 1 entries in partptr and one perm entry
            # per row of the 6x6 input.
            assert partptr.numel() == num_parts + 1
            assert perm.numel() == 6
def test_spmm_half_precision():
    """Sparse @ dense in float16 must match a float32 dense reference."""
    dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
    dense[2:4, :] = 0  # Blank out rows 2-3.
    dense[:, 2:4] = 0  # Blank out columns 2-3.
    sparse = SparseTensor.from_dense(dense)

    rhs = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')

    # The reference product is computed in float32 and only cast to half at
    # the end; atol=1e-2 leaves room for half-precision error in the
    # sparse path.
    reference = torch.matmul(dense.float(), rhs).half()
    result = sparse @ rhs.half()
    assert torch.allclose(reference, result, atol=1e-2)
def test_spspmm(dtype, device):
    """Sparse-sparse ``matmul`` of a 3x3 identity with itself.

    The product is checked first with stored values, then again after the
    values have been dropped via ``set_value_(None)``. In the second case
    the sparsity pattern must be unchanged and no value tensor returned.
    """
    src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                       device=device)
    src = SparseTensor.from_dense(src)

    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == [0, 1, 2, 3]
    assert col.tolist() == [0, 1, 2]
    assert value.tolist() == [1, 1, 1]

    # Pattern-only product: no values on the inputs, so none on the output.
    src.set_value_(None)
    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert not out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == [0, 1, 2, 3]
    assert col.tolist() == [0, 1, 2]
    # Previously unpacked but unchecked; must agree with has_value() above.
    assert value is None
def test_metis(device, weighted):
    """METIS partitioning into 1 and 2 parts yields correctly sized outputs.

    Three 6x6 dense inputs are each tried with and without explicit node
    weights; ``weighted`` is forwarded unchanged to ``partition``.
    """
    dense_inputs = [
        torch.randn(6 * 6, device=device).view(6, 6),
        torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6),
        torch.ones(6 * 6, device=device).view(6, 6),
    ]
    node_weights = [None, torch.rand(6, device=device)]

    for dense, node_weight in product(dense_inputs, node_weights):
        sparse = SparseTensor.from_dense(dense)
        for num_parts in (1, 2):
            _, partptr, perm = sparse.partition(num_parts=num_parts,
                                                recursive=False,
                                                weighted=weighted,
                                                node_weight=node_weight)
            # Expect num_parts + 1 entries in partptr and one perm entry
            # per row of the 6x6 input.
            assert partptr.numel() == num_parts + 1
            assert perm.numel() == 6