Code Example #1
import torch
from scipy.io import loadmat
from torch_sparse import SparseTensor


# `args` (with an `args.device` attribute) and `sizes` (the list of feature dimensions
# to test) are assumed to be module-level globals of the benchmark script this snippet
# was taken from.
def correctness(dataset):
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    row = torch.from_numpy(mat_scipy.tocoo().row).to(args.device, torch.long)
    col = torch.from_numpy(mat_scipy.tocoo().col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
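    # fill_cache_() pre-computes and caches derived layouts (rowptr, CSC permutation, ...)
    # so that later operations can reuse them.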
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()

    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

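            # Compare torch_sparse's SpMM against PyTorch's native sparse-dense matmul.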
            out1 = mat @ x
            out2 = mat_pytorch @ x

            assert torch.allclose(out1, out2, atol=1e-4)

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()
Code Example #2
File: test_cat.py  Project: zeta1999/pytorch_sparse
import torch
from torch_sparse import SparseTensor, cat


# `device` is supplied by pytest parametrization; `tensor(...)` is the test suite's
# helper that builds a torch.Tensor with the given dtype and device.
def test_cat(device):
    row, col = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    mat1 = SparseTensor(row=row, col=col)
    mat1.fill_cache_()

    row, col = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
    mat2 = SparseTensor(row=row, col=col)
    mat2.fill_cache_()

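    # dim=0 concatenates along the (sparse) row dimension.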
    out = cat([mat1, mat2], dim=0)
    assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0],
                                       [0, 1, 0], [1, 0, 0]]
    assert out.storage.has_row()
    assert out.storage.has_rowptr()
    assert out.storage.has_rowcount()
    assert out.storage.num_cached_keys() == 1

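    # dim=1 concatenates along the (sparse) column dimension.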
    out = cat([mat1, mat2], dim=1)
    assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
                                       [0, 0, 0, 1, 0]]
    assert out.storage.has_row()
    assert not out.storage.has_rowptr()
    assert out.storage.num_cached_keys() == 2

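    # dim=(0, 1) concatenates diagonally, producing a block-diagonal matrix.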
    out = cat([mat1, mat2], dim=(0, 1))
    assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
                                       [0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
                                       [0, 0, 0, 1, 0]]
    assert out.storage.has_row()
    assert out.storage.has_rowptr()
    assert out.storage.num_cached_keys() == 5

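    # Attach a (nnz, 4) value matrix and concatenate along the value dimension (dim=-1).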
    value = torch.randn((mat1.nnz(), 4), device=device)
    mat1 = mat1.set_value_(value, layout='coo')
    out = cat([mat1, mat1], dim=-1)
    assert out.storage.value().size() == (mat1.nnz(), 8)
    assert out.storage.has_row()
    assert out.storage.has_rowptr()
    assert out.storage.num_cached_keys() == 5
Code Example #3
import torch
from torch_sparse import SparseTensor


# `dtype` and `device` are supplied by pytest parametrization; `tensor(...)` is the
# test suite's helper that builds a torch.Tensor with the given dtype and device.
def test_remove_diag(dtype, device):
    row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    mat = SparseTensor(row=row, col=col, value=value)
    mat.fill_cache_()

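    # remove_diag() drops the entries on the main diagonal, here (0, 0) and (2, 2).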
    mat = mat.remove_diag()
    assert mat.storage.row().tolist() == [0, 1]
    assert mat.storage.col().tolist() == [1, 2]
    assert mat.storage.value().tolist() == [2, 3]
    assert mat.storage.num_cached_keys() == 2
    assert mat.storage.rowcount().tolist() == [1, 1, 0]
    assert mat.storage.colcount().tolist() == [0, 1, 1]

    mat = SparseTensor(row=row, col=col, value=value)
    mat.fill_cache_()

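    # k=1 removes the first super-diagonal instead, i.e. entries with col == row + 1,
    # here (0, 1) and (1, 2).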
    mat = mat.remove_diag(k=1)
    assert mat.storage.row().tolist() == [0, 2]
    assert mat.storage.col().tolist() == [0, 2]
    assert mat.storage.value().tolist() == [1, 4]
    assert mat.storage.num_cached_keys() == 2
    assert mat.storage.rowcount().tolist() == [1, 0, 1]
    assert mat.storage.colcount().tolist() == [1, 0, 1]
Code Example #4
import torch
from scipy.io import loadmat
from torch_scatter import scatter_add
from torch_sparse import SparseTensor


# `args`, `sizes`, `time_func`, and `bold` are assumed to be module-level helpers of the
# benchmark script: command-line arguments, the feature sizes to test, a timing wrapper,
# and a bold-text formatter, respectively.
def timing(dataset):
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    row = torch.from_numpy(mat_scipy.tocoo().row).to(args.device, torch.long)
    col = torch.from_numpy(mat_scipy.tocoo().col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
    mat_scipy = mat.to_scipy(layout='csr')

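    # Four SpMM variants to compare: gather + scatter_add, SciPy CSR, torch.sparse COO,
    # and torch_sparse's SparseTensor.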
    def scatter(x):
        return scatter_add(x[col], row, dim=0, dim_size=mat_scipy.shape[0])

    def spmm_scipy(x):
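        # SciPy cannot multiply CUDA tensors; raise the same error the handler below skips.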
        if x.is_cuda:
            raise RuntimeError('out of memory')
        return mat_scipy @ x

    def spmm_pytorch(x):
        return mat_pytorch @ x

    def spmm(x):
        return mat @ x

    t1, t2, t3, t4 = [], [], [], []

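    # Benchmark each implementation for every feature size; out-of-memory runs are
    # recorded as inf.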
    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

            t1 += [time_func(scatter, x)]
            t2 += [time_func(spmm_scipy, x)]
            t3 += [time_func(spmm_pytorch, x)]
            t4 += [time_func(spmm, x)]

            del x

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()
            for t in (t1, t2, t3, t4):
                t.append(float('inf'))

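    # Find the fastest implementation per feature size so it can be printed in bold.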
    ts = torch.tensor([t1, t2, t3, t4])
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {mat.avg_row_length():.2f}):')
    print('\t'.join(['            '] + [f'{size:>5}' for size in sizes]))
    print('\t'.join([bold('Scatter     ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
    print('\t'.join([bold('SPMM SciPy  ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
    print('\t'.join([bold('SPMM PyTorch')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
    print('\t'.join([bold('SPMM Own    ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
    print()
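A minimal driver sketch, assuming `datasets` is a list of `(group, name)` tuples and each matrix has already been downloaded as `<name>.mat`:

# Hypothetical driver loop; `datasets`, `args`, and `sizes` are assumed to be defined.
for dataset in datasets:
    correctness(dataset)  # check torch_sparse SpMM against torch.sparse
    timing(dataset)       # print per-size timings for the four implementations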