def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    """Reduce the (non-zero) entries of ``src``.

    Args:
        src: The sparse matrix.
        dim: The dimension to reduce over.  :obj:`None` reduces all entries
            to a scalar; values ``> 1`` reduce the dense value dimensions.
            (default: :obj:`None`)
        reduce: The reduction to apply: :obj:`"sum"`/:obj:`"add"`,
            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`.
            (default: :obj:`"sum"`)

    Raises:
        ValueError: On an unknown ``reduce`` or an unsupported ``dim``.
    """
    value = src.storage.value()

    if dim is None:
        # Full reduction to a scalar.
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError
        else:
            # No explicit values: every non-zero entry is implicitly one.
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(),
                                    device=src.device())
            else:
                raise ValueError
    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            col = src.storage.col()
            # BUGFIX: a reduction over dim 0 yields one entry per *column*,
            # so the output size is `src.size(1)` (not `src.size(0)`), and
            # `reduce` must be forwarded (it previously always summed).
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(1), dtype=src.dtype())
            else:
                raise ValueError
        elif dim == 1 and value is not None:
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(0), dtype=src.dtype())
            else:
                raise ValueError
        elif dim > 1 and value is not None:
            # Reduce a dense value dimension (shifted by one w.r.t. `dim`).
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError
        else:
            raise ValueError
def partition(
        src: SparseTensor, num_parts: int, recursive: bool = False,
        weighted=False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Cluster the rows of ``src`` into ``num_parts`` partitions via METIS.

    Returns the row-permuted sparse matrix, a pointer tensor delimiting the
    partitions, and the permutation that was applied to the rows.
    """
    assert num_parts >= 1

    # Trivial case: one partition keeps the matrix untouched.
    if num_parts == 1:
        device = src.device()
        partptr = torch.tensor([0, src.size(0)], device=device)
        perm = torch.arange(src.size(0), device=device)
        return src, partptr, perm

    rowptr, col, value = src.csr()
    rowptr, col = rowptr.cpu(), col.cpu()

    # Edge weights are only used when requested and present; floating-point
    # weights are converted via `weight2metis` before being handed to METIS.
    if weighted and value is not None:
        assert value.numel() == col.numel()
        value = value.view(-1).detach().cpu()
        if value.is_floating_point():
            value = weight2metis(value)
    else:
        value = None

    cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
                                               recursive)
    cluster = cluster.to(src.device())

    # Sorting by cluster id yields the permutation that groups each
    # partition's rows together.
    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    """Set the ``k``-th diagonal of ``src`` to ``values``.

    Existing entries on that diagonal are removed first.  If ``values`` is
    :obj:`None` and ``src`` holds values, the diagonal is filled with ones.

    Args:
        src: The sparse matrix.
        values: The diagonal values to insert, one per diagonal element.
            (default: :obj:`None`)
        k: Which diagonal to set (negative selects sub-diagonals).
            (default: :obj:`0`)
    """
    # Clear the target diagonal so every diagonal slot gets (re-)inserted.
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    # `mask` marks the positions of the surviving (non-diagonal) entries
    # inside the enlarged coordinate lists; `inv_mask` the diagonal slots.
    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    # First row touched by diagonal `k`, and how many entries are inserted.
    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()

    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    # In-place shift by `k` is safe: `new_row` already copied `diag`'s
    # values in the assignment above.
    new_col[inv_mask] = diag.add_(k)

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            # Default diagonal value is one.
            new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
                                             device=value.device)

    # Update cached row/col counts in place of invalidating them: each row
    # (resp. column) crossed by the diagonal gains exactly one entry.
    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
                            value=new_value,
                            sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount, colptr=None,
                            colcount=colcount, csr2csc=None, csc2csr=None,
                            is_sorted=True)
    return src.from_storage(storage)
def get_diag(src: SparseTensor) -> Tensor:
    """Return the main diagonal of ``src`` as a dense tensor.

    The output has ``min(src.size(0), src.size(1))`` leading entries (plus
    any trailing dense value dimensions); positions without a stored
    diagonal entry are zero.
    """
    row, col, value = src.coo()
    if value is None:
        # BUGFIX: allocate the implicit all-ones values on the same device
        # as the indices; previously a CUDA sparse tensor without values
        # produced a CPU tensor here, crashing the masked assignment below.
        value = torch.ones(row.size(0), device=row.device)

    sizes = list(value.size())
    sizes[0] = min(src.size(0), src.size(1))

    out = value.new_zeros(sizes)
    mask = row == col
    out[row[mask]] = value[mask]
    return out
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """In-place addition of a broadcastable dense tensor to ``src``'s values.

    ``other`` must have shape ``(src.size(0), 1, ...)`` (row-wise
    broadcast) or ``(1, src.size(1), ...)`` (column-wise broadcast).
    """
    rowptr, col, value = src.csr()

    row_wise = other.size(0) == src.size(0) and other.size(1) == 1
    col_wise = other.size(0) == 1 and other.size(1) == src.size(1)

    if row_wise:
        # Expand one entry per row into one entry per non-zero.
        other = gather_csr(other.squeeze(1), rowptr)
    elif col_wise:
        # Gather one entry per column for every non-zero.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        # Missing values are implicitly one.
        value = other.add_(1)

    return src.set_value_(value, layout='coo')
def correctness(dataset):
    """Check that our SpMM matches PyTorch's sparse matmul on ``dataset``.

    ``dataset`` is a ``(group, name)`` pair; ``{name}.mat`` is loaded from
    the working directory.  Out-of-memory failures are skipped.
    """
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()

    # Convert to COO once instead of twice for the row/col extraction.
    coo = mat_scipy.tocoo()
    row = torch.from_numpy(coo.row).to(args.device, torch.long)
    col = torch.from_numpy(coo.col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()

    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

            out1 = mat @ x
            out2 = mat_pytorch @ x
            assert torch.allclose(out1, out2, atol=1e-4)

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # BUGFIX: re-raise the original exception instead of
                # wrapping it in a fresh RuntimeError, which destroyed
                # the original traceback.
                raise
            torch.cuda.empty_cache()
def timing(dataset):
    """Benchmark scatter, SciPy, PyTorch and our own SpMM on ``dataset``.

    ``dataset`` is a ``(group, name)`` pair; ``{name}.mat`` is loaded from
    the working directory.  Prints one timing row per implementation with
    the fastest entry per column highlighted.
    """
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()

    # Convert to COO once instead of twice for the row/col extraction.
    coo = mat_scipy.tocoo()
    row = torch.from_numpy(coo.row).to(args.device, torch.long)
    col = torch.from_numpy(coo.col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
    mat_scipy = mat.to_scipy(layout='csr')

    def scatter(x):
        return scatter_add(x[col], row, dim=0, dim_size=mat_scipy.shape[0])

    def spmm_scipy(x):
        if x.is_cuda:
            raise RuntimeError('out of memory')
        return mat_scipy @ x

    def spmm_pytorch(x):
        return mat_pytorch @ x

    def spmm(x):
        return mat @ x

    t1, t2, t3, t4 = [], [], [], []
    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)

            t1 += [time_func(scatter, x)]
            t2 += [time_func(spmm_scipy, x)]
            t3 += [time_func(spmm_pytorch, x)]
            t4 += [time_func(spmm, x)]

            del x

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # BUGFIX: re-raise the original exception instead of
                # wrapping it in a fresh RuntimeError, which destroyed
                # the original traceback.
                raise
            torch.cuda.empty_cache()
            # Keep the lists aligned with `sizes`: an OOM configuration is
            # treated as infinitely slow so it can never win.
            for t in (t1, t2, t3, t4):
                t.append(float('inf'))

    ts = torch.tensor([t1, t2, t3, t4])
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()

    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {mat.avg_row_length():.2f}):')
    print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
    print('\t'.join([bold('Scatter ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
    print('\t'.join([bold('SPMM SciPy ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
    print('\t'.join([bold('SPMM PyTorch')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
    print('\t'.join([bold('SPMM Own ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
    print()
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Return a narrowed version of ``src`` along dimension ``dim``.

    Mirrors :meth:`torch.Tensor.narrow`: the result keeps the entries of
    ``src`` with indices in ``[start, start + length)`` along ``dim``,
    re-indexed to start at zero.

    Args:
        src: The sparse matrix.
        dim: The dimension to narrow (negative values wrap around).
        start: First index to keep (negative values wrap around).
        length: Number of indices to keep.
    """
    if dim < 0:
        dim = src.dim() + dim
    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        # Row narrowing via CSR: the kept non-zeros form one contiguous
        # slice of `col`/`value`, delimited by the narrowed row pointers.
        rowptr, col, value = src.csr()

        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        row_start = rowptr[0]
        # Shift pointers so the first kept row starts at offset zero.
        rowptr = rowptr - row_start
        row_length = rowptr[-1]

        row = src.storage._row
        if row is not None:
            row = row.narrow(0, row_start, row_length) - start

        col = col.narrow(0, row_start, row_length)

        if value is not None:
            value = value.narrow(0, row_start, row_length)

        sparse_sizes = (length, src.sparse_size(1))

        # Cached per-row counts can simply be narrowed instead of rebuilt.
        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()
        mask = (col >= start) & (col < start + length)

        row = row[mask]
        col = col[mask] - start

        if value is not None:
            value = value[mask]

        sparse_sizes = (src.sparse_size(0), length)

        # Cached column metadata, when present, narrows the same way the
        # row metadata does in the `dim == 0` branch.
        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount,
                                csr2csc=None, csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    else:
        # `dim > 1` narrows a dense value dimension (shifted by one).
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.narrow(dim - 1, start, length),
                                 layout='coo')
        else:
            raise ValueError