def reverse_cuthill_mckee(
    src: SparseTensor, is_symmetric: Optional[bool] = None
) -> Tuple[SparseTensor, torch.Tensor]:
    """Reorder ``src`` with the reverse Cuthill-McKee (RCM) permutation.

    Args:
        src: The sparse matrix to reorder.
        is_symmetric: Whether ``src`` is already symmetric. When ``None``
            this is determined by calling ``src.is_symmetric()``.

    Returns:
        A tuple ``(out, perm)``:
        - ``perm`` is the RCM ordering from
          ``scipy.sparse.csgraph.reverse_cuthill_mckee``, as an int64 tensor
          on ``src.device()``.
        - ``out`` is ``permute(src, perm)``.

        If ``src`` was not symmetric, it is first replaced by
        ``src.to_symmetric()``. In that case ``out`` is the permuted
        *symmetrized* matrix, not the permuted input.
    """
    symmetric = src.is_symmetric() if is_symmetric is None else is_symmetric
    if not symmetric:
        # The ordering is computed with symmetric_mode=True below, so make
        # the structure symmetric first.
        src = src.to_symmetric()

    csr = src.to_scipy(layout='csr')
    # Defensive copy of the SciPy result before handing it to torch.from_numpy.
    order = sp.csgraph.reverse_cuthill_mckee(csr, symmetric_mode=True).copy()
    perm = torch.from_numpy(order).to(torch.long).to(src.device())

    return permute(src, perm), perm
def timing(dataset):
    """Benchmark four sparse-dense product variants on one matrix and print a table.

    Args:
        dataset: A ``(group, name)`` pair. The matrix is loaded from
            ``f'{name}.mat'``; ``f'{group}/{name}'`` is used as the
            table title.

    For every entry of the module-level ``sizes``, a random dense operand of
    shape ``(mat.size(1), size)`` is allocated on ``args.device``. Each
    variant is then timed with ``time_func``:

    - Scatter: ``scatter_add`` over the COO indices.
    - SciPy: CSR ``@``.
    - PyTorch: sparse COO ``@``.
    - Own: ``SparseTensor @``.

    A step that fails with an out-of-memory ``RuntimeError`` is recorded as
    ``inf``. Any other ``RuntimeError`` propagates unchanged. Per column, the
    smallest finite time is passed to ``bold`` as the winner.
    """
    group, name = dataset
    # NOTE(review): index [0][0][2] assumes the .mat layout where the third
    # field of ``Problem`` is the sparse matrix (SuiteSparse-style) — confirm
    # for new datasets.
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    coo = mat_scipy.tocoo()  # convert once, reused for both index arrays
    row = torch.from_numpy(coo.row).to(args.device, torch.long)
    col = torch.from_numpy(coo.col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
    mat_scipy = mat.to_scipy(layout='csr')

    def scatter(x):
        return scatter_add(x[col], row, dim=0, dim_size=mat_scipy.shape[0])

    def spmm_scipy(x):
        # SciPy cannot operate on CUDA tensors. Raising an "out of memory"
        # error makes the benchmark loop record this variant as ``inf``.
        if x.is_cuda:
            raise RuntimeError('out of memory')
        return mat_scipy @ x

    def spmm_pytorch(x):
        return mat_pytorch @ x

    def spmm(x):
        return mat @ x

    funcs = (scatter, spmm_scipy, spmm_pytorch, spmm)
    t1, t2, t3, t4 = [], [], [], []
    timings = (t1, t2, t3, t4)

    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            torch.cuda.empty_cache()
            # The operand could not be allocated, so no variant can run.
            for t in timings:
                t.append(float('inf'))
            continue

        # Time each variant independently. An OOM in one variant then neither
        # discards the others' timings nor leaves the four lists with unequal
        # lengths (which would break ``torch.tensor`` below).
        for t, func in zip(timings, funcs):
            try:
                t.append(time_func(func, x))
            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    raise
                torch.cuda.empty_cache()
                t.append(float('inf'))

        del x

    ts = torch.tensor([t1, t2, t3, t4])
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    # A column in which every variant failed (all ``inf``) has no real winner.
    winner &= torch.isfinite(ts)
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {mat.avg_row_length():.2f}):')
    print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
    labels = ('Scatter ', 'SPMM SciPy ', 'SPMM PyTorch', 'SPMM Own ')
    for label, times, wins in zip(labels, timings, winner):
        cells = [bold(f'{t:.5f}', f) for t, f in zip(times, wins)]
        print('\t'.join([bold(label)] + cells))
    print()