def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    """Sparse-sparse matrix product with sum reduction, i.e. ``C = A @ B``.

    Both operands are consumed in CSR form; the result is returned as a new
    row-sorted :class:`SparseTensor` of shape ``(A.rows, B.cols)``.
    """
    # Inner dimensions must agree for the product to be defined.
    assert src.sparse_size(1) == other.sparse_size(0)

    ptr_a, idx_a, val_a = src.csr()
    ptr_b, idx_b, val_b = other.csr()

    n_rows = src.sparse_size(0)
    n_cols = other.sparse_size(1)

    ptr_c, idx_c, val_c = torch.ops.torch_sparse.spspmm_sum(
        ptr_a, idx_a, val_a, ptr_b, idx_b, val_b, n_cols)

    return SparseTensor(row=None, rowptr=ptr_c, col=idx_c, value=val_c,
                        sparse_sizes=(n_rows, n_cols), is_sorted=True)
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    """Sparse-sparse matrix product with sum reduction, i.e. ``C = A @ B``.

    Half-precision values are upcast to ``float32`` before dispatching to the
    kernel (presumably unsupported there in fp16), and the result is cast
    back to the original value dtype afterwards.
    """
    # Inner dimensions must agree for the product to be defined.
    assert src.sparse_size(1) == other.sparse_size(0)

    ptr_a, idx_a, val_a = src.csr()
    ptr_b, idx_b, val_b = other.csr()

    # Remember an original value tensor so the output dtype can be restored.
    orig_value = val_b if val_a is None else val_a

    # Upcast fp16 operands; the computation itself runs in fp32.
    if val_a is not None and val_a.dtype == torch.half:
        val_a = val_a.float()
    if val_b is not None and val_b.dtype == torch.half:
        val_b = val_b.float()

    n_rows = src.sparse_size(0)
    n_cols = other.sparse_size(1)

    ptr_c, idx_c, val_c = torch.ops.torch_sparse.spspmm_sum(
        ptr_a, idx_a, val_a, ptr_b, idx_b, val_b, n_cols)

    # Cast back to the incoming value dtype (e.g. fp32 -> fp16).
    if val_c is not None and orig_value is not None:
        val_c = val_c.to(orig_value.dtype)

    return SparseTensor(row=None, rowptr=ptr_c, col=idx_c, value=val_c,
                        sparse_sizes=(n_rows, n_cols), is_sorted=True)
def partition(
        src: SparseTensor, num_parts: int, recursive: bool = False,
        weighted=False) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Cluster the rows of `src` into `num_parts` partitions via the
    `torch_sparse` partition kernel, returning the permuted matrix, the
    partition pointer, and the node permutation.
    """
    assert num_parts >= 1

    # Trivial case: one partition containing everything, identity permutation.
    if num_parts == 1:
        partptr = torch.tensor([0, src.size(0)], device=src.device())
        perm = torch.arange(src.size(0), device=src.device())
        return src, partptr, perm

    rowptr, col, value = src.csr()
    # The partition kernel operates on CPU tensors.
    rowptr, col = rowptr.cpu(), col.cpu()

    if weighted and value is not None:
        assert value.numel() == col.numel()
        value = value.view(-1).detach().cpu()
        if value.is_floating_point():
            # Convert floating-point weights into the integer form the
            # partitioner expects.
            value = weight2metis(value)
    else:
        value = None

    cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
                                               recursive)
    cluster = cluster.to(src.device())

    # Sort nodes by cluster id; `perm` is the induced node permutation.
    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
def spmm_max(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sparse-dense matrix product with max reduction.

    Returns the pair produced by the `torch_sparse` kernel (reduced output
    plus argmax indices).
    """
    rowptr, col, value = src.csr()
    # Align the sparse values with the dense operand's dtype before dispatch.
    value = None if value is None else value.to(other.dtype)
    return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    # This function builds the inverse operation of `cat_diag` and should hence
    # only be used on *diagonally stacked* sparse matrices.
    # That's the reason why this method is marked as *private*.
    rowptr, col, value = src.csr()

    # Narrow the row pointer to the requested row window (`+ 1` keeps the
    # final exclusive boundary entry) and shift it so it starts at zero.
    rowptr = rowptr.narrow(0, start=start[0], length=length[0] + 1)
    row_start = int(rowptr[0])
    rowptr = rowptr - row_start
    row_length = int(rowptr[-1])  # number of nonzeros inside the window

    # Because the matrix is diagonally stacked, the nonzeros of the selected
    # rows coincide with those of the selected columns, so one contiguous
    # slice of `col`/`value` covers the whole block.
    row = src.storage._row
    if row is not None:
        row = row.narrow(0, row_start, row_length) - start[0]
    col = col.narrow(0, row_start, row_length) - start[1]
    if value is not None:
        value = value.narrow(0, row_start, row_length)

    sparse_sizes = length

    # Narrow every cached layout tensor as well, so no cached work is lost.
    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, start[0], length[0])
    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, start[1], length[1] + 1)
        colptr = colptr - int(colptr[0])  # i.e. `row_start`
    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, start[1], length[1])
    # Both permutations index into the nonzero slice, so re-base them by
    # `row_start` after narrowing.
    csr2csc = src.storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, row_start, row_length) - row_start
    csc2csr = src.storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, row_start, row_length) - row_start

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return src.from_storage(storage)
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """In-place addition of a broadcastable dense tensor onto the nonzero
    values of `src`.

    `other` must be of shape `(num_rows, 1)` (row-wise broadcast) or
    `(1, num_cols)` (column-wise broadcast); anything else raises
    :class:`ValueError`.
    """
    rowptr, col, value = src.csr()

    row_wise = other.size(0) == src.size(0) and other.size(1) == 1
    col_wise = other.size(0) == 1 and other.size(1) == src.size(1)

    if row_wise:
        # Expand each row's scalar across all nonzeros of that row.
        other = gather_csr(other.squeeze(1), rowptr)
    elif col_wise:
        # Pick the scalar matching each nonzero's column index.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if value is None:
        # Implicit all-ones values: materialize them as `other + 1`.
        value = other.add_(1)
    else:
        value = value.add_(other.to(value.dtype))

    return src.set_value_(value, layout='coo')
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Sparse-dense matrix product with sum reduction over each row."""
    rowptr, col, value = src.csr()

    storage = src.storage
    row, csr2csc, colptr = storage._row, storage._csr2csc, storage._colptr

    # The backward pass needs additional layouts; materialize them only when
    # gradients will actually flow through the corresponding inputs.
    if value is not None and value.requires_grad:
        row = storage.row()
    if other.requires_grad:
        row = storage.row()
        csr2csc = storage.csr2csc()
        colptr = storage.colptr()

    return torch.ops.torch_sparse.spmm_sum(row, rowptr, col, value, colptr,
                                           csr2csc, other)
def sample(src: SparseTensor, num_neighbors: int,
           subset: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Sample `num_neighbors` column indices (with replacement) for every row
    of `src`, or only for the rows in `subset` if given.

    Returns a `(num_rows, num_neighbors)` tensor of column indices.
    """
    rowptr, col, _ = src.csr()
    rowcount = src.storage.rowcount()

    if subset is not None:
        rowcount, rowptr = rowcount[subset], rowptr[subset]

    # Draw uniform [0, 1) samples, scale by each row's degree, and offset by
    # the row's start pointer to get absolute positions into `col`.
    pos = torch.rand((rowcount.size(0), num_neighbors), device=col.device)
    pos.mul_(rowcount.to(pos.dtype).view(-1, 1))
    pos = pos.to(torch.long)
    pos.add_(rowptr.view(-1, 1))

    return col[pos]
def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
               replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
    """Sample up to `num_neighbors` neighbors for each node in `subset`.

    Returns the sampled (relabelled) adjacency matrix together with the
    tensor of original node ids `n_id` it refers to.
    """
    rowptr, col, value = src.csr()

    # The kernel also reports `e_id`, the original edge positions, so cached
    # edge values can be carried over to the sampled adjacency.
    rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
        rowptr, col, subset, num_neighbors, replace)

    value = value[e_id] if value is not None else None

    sizes = (subset.size(0), n_id.size(0))
    adj = SparseTensor(rowptr=rowptr, row=None, col=col, value=value,
                       sparse_sizes=sizes, is_sorted=True)
    return adj, n_id
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Sparse-dense matrix product with mean reduction over each row."""
    rowptr, col, value = src.csr()

    storage = src.storage
    row = storage._row
    rowcount = storage._rowcount
    csr2csc = storage._csr2csc
    colptr = storage._colptr

    # Align the sparse values with the dense operand's dtype.
    if value is not None:
        value = value.to(other.dtype)

    # The backward pass needs additional layouts; materialize them only when
    # gradients will actually flow through the corresponding inputs.
    if value is not None and value.requires_grad:
        row = storage.row()
    if other.requires_grad:
        row = storage.row()
        rowcount = storage.rowcount()
        csr2csc = storage.csr2csc()
        colptr = storage.colptr()

    return torch.ops.torch_sparse.spmm_mean(row, rowptr, col, value, rowcount,
                                            colptr, csr2csc, other)
def spmm_max(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sparse-dense matrix product with max reduction.

    Returns the pair produced by the `torch_sparse` kernel (reduced output
    plus argmax indices).
    """
    rowptr, col, value = src.csr()
    if value is not None:
        # Align the sparse values with the dense operand's dtype, matching
        # the sibling `spmm_mean`/`spmm_max` variants in this file; the
        # kernel otherwise receives mismatched dtypes.
        value = value.to(other.dtype)
    return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    # Select the rows (dim == 0), columns (dim == 1) or a dense value
    # dimension (dim > 1) of `src` given by the 1-D index tensor `idx`.
    dim = src.dim() + dim if dim < 0 else dim  # normalize negative dims
    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount()

        rowcount = rowcount[idx]  # degree of each selected row

        # Build the new row pointer from the selected degrees.
        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        # For each output nonzero, its position in the old `col` array is the
        # local offset within its row plus the old row's start pointer.
        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)

        col = col[perm]

        if value is not None:
            value = value[perm]

        sparse_sizes = (idx.size(0), src.sparse_size(1))
        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # Symmetric to the `dim == 0` case, operating on the CSC layout.
        old_colptr, row, value = src.csc()
        colcount = src.storage.colcount()

        colcount = colcount[idx]

        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)

        row = row[perm]
        # Re-sort the gathered nonzeros into row-major order and keep the
        # mapping as the cached `csc2csr` permutation.
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[perm][csc2csr]

        sparse_sizes = (src.sparse_size(0), idx.size(0))
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Index a dense value dimension; the sparse layout is unchanged.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError
def random_walk(src: SparseTensor, start: torch.Tensor,
                walk_length: int) -> torch.Tensor:
    """Sample a random walk of length `walk_length` from every node in
    `start`, delegating to the `torch_sparse` kernel on the CSR layout."""
    rowptr, col, _value = src.csr()  # edge values are ignored by the kernel
    return torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
def narrow(src: SparseTensor, dim: int, start: int, length: int) -> SparseTensor:
    """Narrow `src` along dimension `dim`, returning the slice
    `[start, start + length)` as a new :class:`SparseTensor`.

    Negative `dim`/`start` are interpreted as in :meth:`torch.Tensor.narrow`.
    """
    if dim < 0:
        dim = src.dim() + dim
    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        rowptr, col, value = src.csr()

        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        # Cast to plain Python ints (as `__narrow_diag__` does): these values
        # are used below as `start`/`length` arguments to `Tensor.narrow`,
        # which expects ints, not 0-dim tensors.
        row_start = int(rowptr[0])
        rowptr = rowptr - row_start
        row_length = int(rowptr[-1])  # number of nonzeros in the row window

        row = src.storage._row
        if row is not None:
            row = row.narrow(0, row_start, row_length) - start
        col = col.narrow(0, row_start, row_length)
        if value is not None:
            value = value.narrow(0, row_start, row_length)

        sparse_sizes = (length, src.sparse_size(1))

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()
        mask = (col >= start) & (col < start + length)
        row = row[mask]
        col = col[mask] - start
        if value is not None:
            value = value[mask]

        sparse_sizes = (src.sparse_size(0), length)

        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]
        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Narrow a dense value dimension; the sparse layout is unchanged.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.narrow(dim - 1, start, length),
                                 layout='coo')
        else:
            raise ValueError