Beispiel #1
0
def test_spmm(dtype, device, reduce):
    """Check sparse-dense ``matmul`` and its gradients against a reference.

    The reference gathers rows of ``other`` by column index, scales them by
    the non-zero values, and reduces them per row with
    ``torch_scatter.scatter``. The forward output and the gradients with
    respect to both the sparse values and ``other`` must agree with
    ``matmul`` within ``atol=1e-6``.
    """
    dense = torch.randn((10, 8), dtype=dtype, device=device)
    # Zero rows 2-3 and columns 2-3 so the sparse matrix has empty rows and
    # empty columns.
    dense[2:4, :] = 0
    dense[:, 2:4] = 0
    src = SparseTensor.from_dense(dense).requires_grad_()
    row, col, value = src.coo()

    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                        requires_grad=True)

    # Reference forward pass built from differentiable primitives.
    weighted_rows = other.index_select(-2, col) * value.unsqueeze(-1)
    expected = torch_scatter.scatter(weighted_rows, row, dim=-2,
                                     reduce=reduce)
    # Zero out large-magnitude entries for min/max. Presumably these are
    # fill values left in rows without non-zeros -- TODO confirm against
    # torch_scatter's empty-segment behaviour.
    if reduce == 'min':
        expected[expected > 1000] = 0
    elif reduce == 'max':
        expected[expected < -1000] = 0

    grad_out = torch.randn_like(expected)

    # Record the reference gradients, then clear them so the next backward
    # pass fills .grad from scratch instead of accumulating into it.
    expected.backward(grad_out)
    expected_grad_value, value.grad = value.grad, None
    expected_grad_other, other.grad = other.grad, None

    out = matmul(src, other, reduce)
    out.backward(grad_out)

    comparisons = ((expected, out),
                   (expected_grad_value, value.grad),
                   (expected_grad_other, other.grad))
    for reference, actual in comparisons:
        assert torch.allclose(reference, actual, atol=1e-6)
Beispiel #2
0
def test_metis(device):
    """Exercise ``SparseTensor.partition`` on three small 6x6 inputs.

    Each input is partitioned non-recursively into two parts (weighted and
    unweighted) and into one part (weighted).

    NOTE(review): a second ``test_metis`` definition appears later in this
    source. If both end up in the same module, the later definition shadows
    this one and pytest will never collect it -- confirm and rename if
    needed.
    """
    dense_inputs = [
        torch.randn(6 * 6, device=device).view(6, 6),
        torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6),
        torch.ones(6 * 6, device=device).view(6, 6),
    ]
    # (num_parts, weighted) combinations, exercised in this order.
    cases = ((2, True), (2, False), (1, True))

    for dense in dense_inputs:
        mat = SparseTensor.from_dense(dense)
        for num_parts, weighted in cases:
            _, partptr, perm = mat.partition(num_parts=num_parts,
                                             recursive=False,
                                             weighted=weighted)
            # num_parts + 1 entries in partptr; 6 entries in perm, matching
            # the 6x6 input.
            assert partptr.numel() == num_parts + 1
            assert perm.numel() == 6
Beispiel #3
0
def test_spmm_half_precision():
    """Check half-precision sparse-dense ``@`` against a float32 reference.

    ``other`` is rounded to half precision once, and both the reference and
    the sparse product consume that same rounded tensor. The two results
    then differ only in how the product is computed and rounded, not in
    their inputs. As a result, the ``atol=1e-2`` budget is not partly spent
    on quantizing ``other``.
    """
    src_dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
    src_dense[2:4, :] = 0  # Remove multiple rows.
    src_dense[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src_dense)

    other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')
    # Round once; previously the reference used the unrounded float tensor
    # while the kernel under test received the half-rounded one.
    other_half = other.to(torch.half)

    # Reference: upcast the identical half-precision operands, multiply in
    # float32, and round the result back to half.
    reference = src_dense.to(torch.float) @ other_half.to(torch.float)
    expected = reference.to(torch.half)
    out = src @ other_half

    assert torch.allclose(expected, out, atol=1e-2)
def test_spspmm(dtype, device):
    """Square a 3x3 identity ``SparseTensor`` with and without stored values.

    In both phases the result must be 3x3 with one entry per row on the
    diagonal. With values present they must all be 1. After the values are
    dropped via ``set_value_(None)``, the product must carry no values.
    """
    identity = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                            device=device)
    src = SparseTensor.from_dense(identity)

    diagonal_rowptr = [0, 1, 2, 3]
    diagonal_col = [0, 1, 2]

    # Phase 1: the operand stores explicit values.
    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == diagonal_rowptr
    assert col.tolist() == diagonal_col
    assert value.tolist() == [1, 1, 1]

    # Phase 2: the operand keeps only its sparsity pattern.
    src.set_value_(None)
    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert not out.has_value()
    rowptr, col, _ = out.csr()
    assert rowptr.tolist() == diagonal_rowptr
    assert col.tolist() == diagonal_col
def test_metis(device, weighted):
    """Exercise ``SparseTensor.partition`` with optional node weights.

    Every combination of three 6x6 dense inputs and two node-weight
    settings (``None`` and a random length-6 vector) is partitioned
    non-recursively into one part and then into two parts.
    """
    dense_inputs = [
        torch.randn(6 * 6, device=device).view(6, 6),
        torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6),
        torch.ones(6 * 6, device=device).view(6, 6),
    ]
    node_weights = [None, torch.rand(6, device=device)]

    # Nested loops visit pairs in the same order as a cartesian product of
    # dense_inputs x node_weights.
    for dense in dense_inputs:
        for vec in node_weights:
            mat = SparseTensor.from_dense(dense)
            for num_parts in (1, 2):
                _, partptr, perm = mat.partition(num_parts=num_parts,
                                                 recursive=False,
                                                 weighted=weighted,
                                                 node_weight=vec)
                assert partptr.numel() == num_parts + 1
                assert perm.numel() == 6