Example #1
0
def correctness(dataset):
    """Check that ``SparseTensor`` SpMM agrees with PyTorch's sparse SpMM.

    Loads ``f'{name}.mat'`` and takes the ``['Problem'][0][0][2]`` entry as a
    CSR matrix. It then builds a ``SparseTensor`` from that matrix's row/column
    indices only (no values are passed) on ``args.device``. For every column
    count in the module-level ``sizes`` it multiplies the matrix with a random
    dense matrix and compares the result against the ``torch.sparse`` COO
    product with ``atol=1e-4``.

    Args:
        dataset: A ``(group, name)`` pair; only ``name`` is used here, to
            locate the ``.mat`` file.

    Raises:
        AssertionError: If the two products differ beyond the tolerance.
        RuntimeError: Any runtime error whose message does not contain
            ``'out of memory'`` is re-raised with its original traceback.
    """
    _, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    # Convert to COO once and reuse it for both index tensors.
    coo = mat_scipy.tocoo()
    row = torch.from_numpy(coo.row).to(args.device, torch.long)
    col = torch.from_numpy(coo.col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()

    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

            out1 = mat @ x
            out2 = mat_pytorch @ x

            # Explicit raise rather than ``assert`` so the check is not
            # stripped when running under ``python -O``.
            if not torch.allclose(out1, out2, atol=1e-4):
                raise AssertionError(
                    f'SpMM mismatch for {name} with {size} columns')

            # Release this size's buffers before the next allocation.
            del x, out1, out2

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise  # bare raise keeps the original type and traceback
            # Sizes that do not fit in memory are skipped, not reported.
            torch.cuda.empty_cache()
Example #2
0
def timing(dataset):
    """Time four sparse-times-dense multiplication methods and print a table.

    Loads ``f'{name}.mat'`` and builds a ``SparseTensor`` (indices only, no
    values passed) on ``args.device``. For each column count in the
    module-level ``sizes`` it times, via ``time_func``, these methods:

    1. ``scatter_add`` over gathered rows,
    2. SciPy CSR ``@`` (CPU only),
    3. ``torch.sparse`` COO ``@``,
    4. ``SparseTensor`` ``@``.

    Sizes that hit an out-of-memory error are recorded as ``inf``. The
    fastest method per size is highlighted with ``bold``.

    Args:
        dataset: A ``(group, name)`` pair; ``name`` locates the ``.mat`` file
            and ``f'{group}/{name}'`` is used as the table title.

    Raises:
        RuntimeError: Any runtime error whose message does not contain
            ``'out of memory'`` is re-raised with its original traceback.
    """
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    # Convert to COO once and reuse it for both index tensors.
    coo = mat_scipy.tocoo()
    row = torch.from_numpy(coo.row).to(args.device, torch.long)
    col = torch.from_numpy(coo.col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
    # Rebuild the SciPy matrix from ``mat`` (which was built without values)
    # so SciPy multiplies the same matrix as the other methods. Presumably
    # it uses implicit unit values; confirm against torch_sparse.
    mat_scipy = mat.to_scipy(layout='csr')

    def scatter(x):
        # Baseline: gather x by column index, then sum into each output row.
        return scatter_add(x[col], row, dim=0, dim_size=mat_scipy.shape[0])

    def spmm_scipy(x):
        # SciPy cannot consume CUDA tensors; raise an 'out of memory' error
        # so CUDA runs follow the same handling path as a real OOM.
        if x.is_cuda:
            raise RuntimeError('out of memory')
        return mat_scipy @ x

    def spmm_pytorch(x):
        return mat_pytorch @ x

    def spmm(x):
        return mat @ x

    # Row labels, methods and per-method timing lists, in table order.
    labels = ('Scatter     ', 'SPMM SciPy  ', 'SPMM PyTorch', 'SPMM Own    ')
    funcs = (scatter, spmm_scipy, spmm_pytorch, spmm)
    timings = ([], [], [], [])

    for i, size in enumerate(sizes):
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

            for t, func in zip(timings, funcs):
                t.append(time_func(func, x))

            del x

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise  # bare raise keeps the original type and traceback
            torch.cuda.empty_cache()
            # Every list holds exactly ``i`` entries before this size. Pad
            # only the methods not yet timed for this size, so the lists
            # stay aligned even if the failure happened midway.
            for t in timings:
                if len(t) == i:
                    t.append(float('inf'))

    ts = torch.tensor(list(timings))  # (num_methods, len(sizes))
    winner = torch.zeros_like(ts, dtype=torch.bool)
    # Mark the fastest method (row) for every size (column).
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {mat.avg_row_length():.2f}):')
    print('\t'.join(['            '] + [f'{size:>5}' for size in sizes]))
    for label, t, flags in zip(labels, timings, winner):
        print('\t'.join([bold(label)] +
                        [bold(f'{v:.5f}', f) for v, f in zip(t, flags)]))
    print()