def __init__(self, row: Optional[torch.Tensor] = None,
             rowptr: Optional[torch.Tensor] = None,
             col: Optional[torch.Tensor] = None,
             value: Optional[torch.Tensor] = None,
             sparse_sizes: Optional[Tuple[int, int]] = None,
             is_sorted: bool = False):
    # Wrap the given COO (`row`/`col`) and/or CSR (`rowptr`/`col`) data in a
    # `SparseStorage`. All derived layouts (rowcount, colptr, colcount,
    # csr2csc, csc2csr) start out as `None` and are computed lazily by the
    # storage on first access. `is_sorted=True` asserts the input is already
    # sorted by (row, col) and skips re-sorting.
    self.storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                 sparse_sizes=sparse_sizes, rowcount=None,
                                 colptr=None, colcount=None, csr2csc=None,
                                 csc2csr=None, is_sorted=is_sorted)
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    """Remove all entries lying on the `k`-th diagonal of `src`.

    `k = 0` is the main diagonal; an entry (i, j) is on diagonal `k` when
    `j - i == k`. Cached `rowcount`/`colcount` tensors are patched in place
    (on clones) instead of being recomputed.
    """
    row, col, value = src.coo()
    # Keep everything that is NOT on the k-th diagonal (row == col - k).
    inv_mask = row != col if k == 0 else row != (col - k)
    new_row, new_col = row[inv_mask], col[inv_mask]

    if value is not None:
        value = value[inv_mask]

    rowcount = src.storage._rowcount
    colcount = src.storage._colcount
    if rowcount is not None or colcount is not None:
        # `mask` selects the removed (diagonal) entries.
        mask = ~inv_mask
        if rowcount is not None:
            # Clone so the original tensor's cache is left untouched.
            rowcount = rowcount.clone()
            rowcount[row[mask]] -= 1
        if colcount is not None:
            colcount = colcount.clone()
            colcount[col[mask]] -= 1

    # Filtering preserves (row, col) order, hence `is_sorted=True`.
    storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
                            sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
                            colptr=None, colcount=colcount, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return src.from_storage(storage)
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Convert a scipy sparse matrix into a :class:`SparseTensor`.

    When the input is CSC, its column pointer is preserved as a cached
    attribute; the row pointer is taken from the CSR form, and indices and
    values from the COO form. Set :obj:`has_value=False` to drop the values.
    """
    # Capture the CSC column pointer up front — it is lost once the matrix
    # is converted to other formats below.
    colptr = None
    if isinstance(mat, scipy.sparse.csc_matrix):
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)
    value = torch.from_numpy(mat.data) if has_value else None

    # CSR/COO conversions leave entries sorted by (row, col).
    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=mat.shape[:2], rowcount=None,
                            colptr=colptr, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return SparseTensor.from_storage(storage)
def t(src: SparseTensor) -> SparseTensor:
    """Transpose `src` by reusing its CSC representation.

    `csr2csc` permutes the COO entries into column-major order, which is
    exactly row-major order of the transpose. All cached attributes of the
    storage are swapped role-for-role (rowptr <-> colptr, rowcount <->
    colcount, csr2csc <-> csc2csr) instead of being recomputed.
    """
    csr2csc = src.storage.csr2csc()

    row, col, value = src.coo()

    if value is not None:
        value = value[csr2csc]

    sparse_sizes = src.storage.sparse_sizes()

    storage = SparseStorage(
        row=col[csr2csc],  # old columns become new rows (sorted via csr2csc)
        rowptr=src.storage._colptr,
        col=row[csr2csc],
        value=value,
        sparse_sizes=(sparse_sizes[1], sparse_sizes[0]),
        rowcount=src.storage._colcount,
        colptr=src.storage._rowptr,
        colcount=src.storage._rowcount,
        csr2csc=src.storage._csc2csr,
        csc2csr=csr2csc,
        is_sorted=True,
    )
    return src.from_storage(storage)
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    # This function builds the inverse operation of `cat_diag` and should
    # hence only be used on *diagonally stacked* sparse matrices.
    # That's the reason why this method is marked as *private*.
    # Because the blocks are diagonal, a contiguous row range maps to a
    # single contiguous slice of the nnz entries, so every cached attribute
    # can be narrowed with `Tensor.narrow` instead of being recomputed.
    rowptr, col, value = src.csr()

    rowptr = rowptr.narrow(0, start=start[0], length=length[0] + 1)
    row_start = int(rowptr[0])
    rowptr = rowptr - row_start  # re-base pointers for the sub-matrix
    row_length = int(rowptr[-1])

    row = src.storage._row
    if row is not None:
        row = row.narrow(0, row_start, row_length) - start[0]

    col = col.narrow(0, row_start, row_length) - start[1]

    if value is not None:
        value = value.narrow(0, row_start, row_length)

    sparse_sizes = length

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, start[0], length[0])

    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, start[1], length[1] + 1)
        colptr = colptr - int(colptr[0])  # i.e. `row_start`

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, start[1], length[1])

    # The block's entries occupy one contiguous nnz range, so the CSR<->CSC
    # permutations of the block are the narrowed, re-based permutations.
    csr2csc = src.storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, row_start, row_length) - row_start

    csc2csr = src.storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, row_start, row_length) - row_start

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return src.from_storage(storage)
def cat_second(tensors: List[SparseTensor]) -> SparseTensor:
    """Concatenate a list of sparse tensors along the column dimension.

    Row indices are kept, column indices are shifted by the number of
    columns consumed so far, and cached `colptr`/`colcount` attributes are
    stitched together when every input provides them.
    """
    row_parts: List[torch.Tensor] = []
    col_parts: List[torch.Tensor] = []
    value_parts: List[torch.Tensor] = []
    out_sizes: List[int] = [0, 0]
    colptr_parts: List[torch.Tensor] = []
    colcount_parts: List[torch.Tensor] = []
    nnz: int = 0

    for tensor in tensors:
        row, col, value = tensor.coo()
        row_parts.append(row)
        # Shift column indices past all previously appended columns.
        col_parts.append(tensor.storage._col + out_sizes[1])
        if value is not None:
            value_parts.append(value)

        colptr = tensor.storage._colptr
        if colptr is not None:
            # Later pointers drop their redundant leading zero and are
            # offset by the number of entries gathered so far.
            colptr_parts.append(
                colptr[1:] + nnz if len(colptr_parts) > 0 else colptr)

        colcount = tensor.storage._colcount
        if colcount is not None:
            colcount_parts.append(colcount)

        out_sizes[0] = max(out_sizes[0], tensor.sparse_size(0))
        out_sizes[1] += tensor.sparse_size(1)
        nnz += tensor.nnz()

    row = torch.cat(row_parts, dim=0)
    col = torch.cat(col_parts, dim=0)

    # Optional attributes survive only when every input supplied them.
    value: Optional[torch.Tensor] = None
    if len(value_parts) == len(tensors):
        value = torch.cat(value_parts, dim=0)

    colptr: Optional[torch.Tensor] = None
    if len(colptr_parts) == len(tensors):
        colptr = torch.cat(colptr_parts, dim=0)

    colcount: Optional[torch.Tensor] = None
    if len(colcount_parts) == len(tensors):
        colcount = torch.cat(colcount_parts, dim=0)

    # Interleaving columns breaks (row, col) order -> is_sorted=False.
    storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(out_sizes[0], out_sizes[1]),
                            rowcount=None, colptr=colptr, colcount=colcount,
                            csr2csc=None, csc2csr=None, is_sorted=False)
    return tensors[0].from_storage(storage)
def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
        dtype: Optional[int] = None, device: Optional[torch.device] = None,
        fill_cache: bool = False):
    # Build an (M, N) identity-pattern sparse matrix with `min(M, N)`
    # diagonal entries. When `fill_cache` is set, all derived layouts are
    # constructed analytically instead of being computed lazily.
    # NOTE(review): first parameter is named `self` but is used as the class
    # (instantiation goes through `SparseTensor.__new__`) — presumably a
    # `@classmethod` whose decorator lives outside this view.
    N = M if N is None else N

    row = torch.arange(min(M, N), device=device)
    col = row  # diagonal: row index == column index

    rowptr = torch.arange(M + 1, device=row.device)
    if M > N:
        # Rows below the last diagonal entry are empty; clamp the pointers.
        rowptr[N + 1:] = N

    value: Optional[torch.Tensor] = None
    if has_value:
        value = torch.ones(row.numel(), dtype=dtype, device=row.device)

    rowcount: Optional[torch.Tensor] = None
    colptr: Optional[torch.Tensor] = None
    colcount: Optional[torch.Tensor] = None
    csr2csc: Optional[torch.Tensor] = None
    csc2csr: Optional[torch.Tensor] = None
    if fill_cache:
        rowcount = torch.ones(M, dtype=torch.long, device=row.device)
        if M > N:
            rowcount[N:] = 0  # trailing rows hold no entries
        colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
        colcount = torch.ones(N, dtype=torch.long, device=row.device)
        if N > M:
            colptr[M + 1:] = M
            colcount[M:] = 0  # trailing columns hold no entries
        # A diagonal matrix is already in both CSR and CSC order, so the
        # permutation in either direction is the identity.
        csr2csc = csc2csr = row

    storage: SparseStorage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                           value=value, sparse_sizes=(M, N),
                                           rowcount=rowcount, colptr=colptr,
                                           colcount=colcount, csr2csc=csr2csc,
                                           csc2csr=csc2csr, is_sorted=True)
    self = SparseTensor.__new__(SparseTensor)
    self.storage = storage
    return self
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    """Set all entries on the `k`-th diagonal to `values` (or 1 if `None`).

    Existing diagonal entries are removed first, then the new diagonal
    entries are merged into the sorted COO layout via a position mask.
    """
    # Drop any pre-existing entries on this diagonal so they are replaced,
    # not duplicated.
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    # Custom op: presumably returns a boolean mask over the merged entry
    # positions marking the pre-existing (non-diagonal) entries, with the
    # diagonal slots left False — TODO confirm against torch_sparse C++ op.
    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    # Number of diagonal entries is the size difference between the merged
    # layout and the existing entries; `start` is the first diagonal row.
    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()

    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    # In-place add is safe: the row assignment above copied `diag`'s values.
    new_col[inv_mask] = diag.add_(k)

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            # Default diagonal value is 1.
            new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
                                             device=value.device)

    # Patch cached counts on clones: each diagonal row/column gains one entry.
    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
                            value=new_value, sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount, colptr=None, colcount=colcount,
                            csr2csc=None, csc2csr=None, is_sorted=True)
    return src.from_storage(storage)
def coalesce(index: torch.Tensor, m: int, n: int,
             value: Optional[torch.Tensor] = None, op: str = "add"):
    """Row-wise sorts :obj:`value` and removes duplicate entries.

    Duplicate entries are removed by scattering them together. For
    scattering, any operation of
    `"torch_scatter" <https://github.com/rusty1s/pytorch_scatter>`_ can
    be used.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        value (:class:`Tensor`, optional): The value tensor of sparse
            matrix. (default: :obj:`None`)
        op (string, optional): The scatter operation to use.
            (default: :obj:`"add"`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    storage = SparseStorage(row=index[0], rowptr=None, col=index[1],
                            value=value, sparse_sizes=(m, n), rowcount=None,
                            colptr=None, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=False)
    # Sorting + duplicate reduction both happen inside `coalesce`.
    storage = storage.coalesce(reduce=op)

    index = torch.stack([storage.row(), storage.col()], dim=0)
    return index, storage.value()
def transpose(index, value, m, n, coalesced=True):
    """Transposes dimensions 0 and 1 of a sparse tensor.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`False`, will not
            coalesce the output. (default: :obj:`True`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # Swapping row and column indices performs the transpose in COO form.
    row, col = index[1], index[0]

    if coalesced:
        # Re-sort (and merge duplicates) in the transposed (n, m) layout.
        storage = SparseStorage(row=row, col=col, value=value,
                                sparse_sizes=(n, m), is_sorted=False)
        storage = storage.coalesce()
        row, col, value = storage.row(), storage.col(), storage.value()

    return torch.stack([row, col], dim=0), value
class SparseTensor(object):
    """A 2D sparse matrix with optional (possibly multi-dimensional) values.

    Thin wrapper around a single :class:`SparseStorage` object, which holds
    the COO/CSR/CSC representations and caches derived layouts lazily.
    Most methods delegate to the storage and wrap the result back into a
    :class:`SparseTensor` via :meth:`from_storage`.
    """
    # The only attribute: all index/value data and caches live here.
    storage: SparseStorage

    def __init__(self, row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[int, int]] = None,
                 is_sorted: bool = False):
        self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                     value=value, sparse_sizes=sparse_sizes,
                                     rowcount=None, colptr=None, colcount=None,
                                     csr2csc=None, csc2csr=None,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(self, storage: SparseStorage):
        # NOTE: `self` is the class here (classmethod); `__new__` bypasses
        # `__init__` so the given storage is adopted without re-processing.
        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    @classmethod
    def from_edge_index(self, edge_index: torch.Tensor,
                        edge_attr: Optional[torch.Tensor] = None,
                        sparse_sizes: Optional[Tuple[int, int]] = None,
                        is_sorted: bool = False):
        # edge_index is a (2, nnz) tensor of (row, col) pairs.
        return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
                            value=edge_attr, sparse_sizes=sparse_sizes,
                            is_sorted=is_sorted)

    @classmethod
    def from_dense(self, mat: torch.Tensor, has_value: bool = True):
        if mat.dim() > 2:
            # Multi-dimensional values: an entry is non-zero if any element
            # of its trailing dimensions is non-zero.
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        # `nonzero` yields indices in row-major order -> already sorted.
        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
                                     has_value: bool = True):
        # `coalesce` sorts and de-duplicates the torch COO tensor.
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
            dtype: Optional[int] = None,
            device: Optional[torch.device] = None, fill_cache: bool = False):
        # (M, N) identity pattern with min(M, N) diagonal entries; all
        # cached layouts can be constructed analytically if requested.
        N = M if N is None else N

        row = torch.arange(min(M, N), device=device)
        col = row  # diagonal entries: row == col

        rowptr = torch.arange(M + 1, device=row.device)
        if M > N:
            # Rows past the last diagonal entry are empty; clamp pointers.
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None
        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0
            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # A diagonal matrix is in CSR and CSC order simultaneously, so
            # both permutations are the identity.
            csr2csc = csc2csr = row

        storage: SparseStorage = SparseStorage(row=row, rowptr=rowptr,
                                               col=col, value=value,
                                               sparse_sizes=(M, N),
                                               rowcount=rowcount,
                                               colptr=colptr,
                                               colcount=colcount,
                                               csr2csc=csr2csc,
                                               csc2csr=csc2csr,
                                               is_sorted=True)
        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    def copy(self):
        # Shallow copy: the storage (and its tensors) is shared.
        return self.from_storage(self.storage)

    def clone(self):
        # Deep copy: the storage clones its tensors.
        return self.from_storage(self.storage.clone())

    def type_as(self, tensor: torch.Tensor):
        value = self.storage.value()
        # No-op when there are no values or the dtype already matches.
        if value is None or tensor.dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type_as(tensor))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        if tensor.device == self.device():
            return self
        return self.from_storage(self.storage.device_as(tensor, non_blocking))

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # CSC is derived by permuting the COO entries into column-major order.
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        return self.storage.has_value()

    def set_value_(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        # In-place variant: mutates this tensor's storage and returns self.
        self.storage.set_value_(value, layout)
        return self

    def set_value(self, value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def sparse_reshape(self, num_rows: int, num_cols: int):
        return self.from_storage(
            self.storage.sparse_reshape(num_rows, num_cols))

    def is_coalesced(self) -> bool:
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        self.storage.clear_cache_()
        return self

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
                           device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        # Sparse dims first, then any trailing dense value dims.
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        return self.sizes()[dim]

    def dim(self) -> int:
        return len(self.sizes())

    def nnz(self) -> int:
        # Number of stored entries (col always exists in storage).
        return self.storage.col().numel()

    def numel(self) -> int:
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))

    def sparsity(self) -> float:
        return 1 - self.density()

    def avg_row_length(self) -> float:
        return self.nnz() / self.sparse_size(0)

    def avg_col_length(self) -> float:
        return self.nnz() / self.sparse_size(1)

    def bandwidth(self) -> int:
        # Maximum distance of any entry from the main diagonal. NOTE: the
        # in-place `abs_` mutates the freshly computed difference, not the
        # stored indices.
        row, col, _ = self.coo()
        return int((row - col).abs_().max())

    def avg_bandwidth(self) -> float:
        row, col, _ = self.coo()
        return float((row - col).abs_().to(torch.float).mean())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        # Fraction of entries within `bandwidth` of the main diagonal.
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        if not self.is_quadratic():
            return False

        # A matrix equals its transpose iff its CSR and CSC representations
        # coincide element-wise.
        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        # Symmetrize by stacking the matrix with its transpose and reducing
        # duplicate positions with `reduce`.
        row, col, value = self.coo()

        row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
        if value is not None:
            value = torch.cat([value, value], dim=0)

        N = max(self.size(0), self.size(1))
        out = SparseTensor(row=row, rowptr=None, col=col, value=value,
                           sparse_sizes=(N, N), is_sorted=False)
        out = out.coalesce(reduce)
        return out

    def detach_(self):
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self, requires_grad: bool = True,
                       dtype: Optional[int] = None):
        # Gradients need a value tensor; materialize all-ones if missing.
        if requires_grad and not self.has_value():
            self.fill_value_(1., dtype)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        return self.storage.is_pinned()

    def device(self):
        return self.storage.col().device

    def cpu(self):
        # `torch.tensor(0)` lives on CPU, so this moves storage to CPU.
        return self.device_as(torch.tensor(0), non_blocking=False)

    def cuda(self):
        return self.from_storage(self.storage.cuda())

    def is_cuda(self) -> bool:
        return self.storage.col().is_cuda

    def dtype(self):
        value = self.storage.value()
        # Valueless tensors report the default float dtype.
        return value.dtype if value is not None else torch.float

    def is_floating_point(self) -> bool:
        value = self.storage.value()
        return torch.is_floating_point(value) if value is not None else True

    # dtype conversions: each builds a zero-dim prototype tensor of the
    # target dtype and delegates to `type_as`.
    def bfloat16(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.bfloat16, device=self.device()))

    def bool(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.bool, device=self.device()))

    def byte(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.uint8, device=self.device()))

    def char(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.int8, device=self.device()))

    def half(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.half, device=self.device()))

    def float(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.float, device=self.device()))

    def double(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.double, device=self.device()))

    def short(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.short, device=self.device()))

    def int(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.int, device=self.device()))

    def long(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.long, device=self.device()))

    # Conversions #############################################################

    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
        row, col, value = self.coo()

        # `dtype` is only honored when there is no value tensor; otherwise
        # the value dtype wins.
        if value is not None:
            mat = torch.zeros(self.sizes(), dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(self, dtype: Optional[int] = None
                                   ) -> torch.Tensor:
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)

        # torch COO tensors require values; default to all-ones.
        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())
def masked_select(src: SparseTensor, dim: int, mask: torch.Tensor) -> SparseTensor:
    """Select the rows (`dim=0`), columns (`dim=1`) or value slices
    (`dim>1`) of `src` for which the boolean `mask` is True.
    """
    dim = src.dim() + dim if dim < 0 else dim

    assert mask.dim() == 1
    storage = src.storage

    if dim == 0:
        row, col, value = src.coo()
        rowcount = src.storage.rowcount()

        rowcount = rowcount[mask]

        # Expand the row mask to a per-entry mask, then renumber the kept
        # rows compactly via their (already selected) counts.
        mask = mask[row]
        row = torch.arange(rowcount.size(0),
                           device=row.device).repeat_interleave(rowcount)

        col = col[mask]

        if value is not None:
            value = value[mask]

        sparse_sizes = (rowcount.size(0), src.sparse_size(1))

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colcount=None, colptr=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # Work in CSC order so column selection is contiguous, then sort
        # back into CSR order at the end.
        row, col, value = src.coo()
        csr2csc = src.storage.csr2csc()
        row = row[csr2csc]
        col = col[csr2csc]
        colcount = src.storage.colcount()

        colcount = colcount[mask]
        mask = mask[col]
        # Renumber kept columns compactly.
        col = torch.arange(colcount.size(0),
                           device=col.device).repeat_interleave(colcount)
        row = row[mask]
        # Lexicographic (row, col) sort restores CSR order.
        csc2csr = (colcount.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[csr2csc][mask][csc2csr]

        sparse_sizes = (src.sparse_size(0), colcount.size(0))

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colcount=colcount, colptr=None, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dense value dimension: select along the value tensor instead.
        value = src.storage.value()
        if value is not None:
            idx = mask.nonzero().flatten()
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
    """Stack sparse tensors block-diagonally.

    Row and column indices of each input are shifted by the sizes consumed
    so far; every cached attribute (rowptr, colptr, counts, permutations)
    is stitched together, but only survives if *all* inputs provide it.
    """
    assert len(tensors) > 0

    rows: List[torch.Tensor] = []
    rowptrs: List[torch.Tensor] = []
    cols: List[torch.Tensor] = []
    values: List[torch.Tensor] = []
    sparse_sizes: List[int] = [0, 0]
    rowcounts: List[torch.Tensor] = []
    colptrs: List[torch.Tensor] = []
    colcounts: List[torch.Tensor] = []
    csr2cscs: List[torch.Tensor] = []
    csc2csrs: List[torch.Tensor] = []
    nnz: int = 0

    for tensor in tensors:
        row = tensor.storage._row
        if row is not None:
            rows.append(row + sparse_sizes[0])

        rowptr = tensor.storage._rowptr
        if rowptr is not None:
            # Later pointer arrays drop the redundant leading zero and are
            # offset by the entries gathered so far.
            rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)

        cols.append(tensor.storage._col + sparse_sizes[1])

        value = tensor.storage._value
        if value is not None:
            values.append(value)

        rowcount = tensor.storage._rowcount
        if rowcount is not None:
            rowcounts.append(rowcount)

        colptr = tensor.storage._colptr
        if colptr is not None:
            colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)

        colcount = tensor.storage._colcount
        if colcount is not None:
            colcounts.append(colcount)

        # Diagonal blocks occupy disjoint nnz ranges, so the permutations
        # only need a constant offset.
        csr2csc = tensor.storage._csr2csc
        if csr2csc is not None:
            csr2cscs.append(csr2csc + nnz)

        csc2csr = tensor.storage._csc2csr
        if csc2csr is not None:
            csc2csrs.append(csc2csr + nnz)

        sparse_sizes[0] += tensor.sparse_size(0)
        sparse_sizes[1] += tensor.sparse_size(1)
        nnz += tensor.nnz()

    # Each optional attribute is kept only if every input supplied it.
    row: Optional[torch.Tensor] = None
    if len(rows) == len(tensors):
        row = torch.cat(rows, dim=0)

    rowptr: Optional[torch.Tensor] = None
    if len(rowptrs) == len(tensors):
        rowptr = torch.cat(rowptrs, dim=0)

    col = torch.cat(cols, dim=0)

    value: Optional[torch.Tensor] = None
    if len(values) == len(tensors):
        value = torch.cat(values, dim=0)

    rowcount: Optional[torch.Tensor] = None
    if len(rowcounts) == len(tensors):
        rowcount = torch.cat(rowcounts, dim=0)

    colptr: Optional[torch.Tensor] = None
    if len(colptrs) == len(tensors):
        colptr = torch.cat(colptrs, dim=0)

    colcount: Optional[torch.Tensor] = None
    if len(colcounts) == len(tensors):
        colcount = torch.cat(colcounts, dim=0)

    csr2csc: Optional[torch.Tensor] = None
    if len(csr2cscs) == len(tensors):
        csr2csc = torch.cat(csr2cscs, dim=0)

    csc2csr: Optional[torch.Tensor] = None
    if len(csc2csrs) == len(tensors):
        csc2csr = torch.cat(csc2csrs, dim=0)

    # Diagonal stacking preserves (row, col) order across blocks.
    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
                            rowcount=rowcount, colptr=colptr,
                            colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return tensors[0].from_storage(storage)
def test_sparse_reshape(dtype, device):
    """Reshaping must relocate entries as if the matrix were dense,
    row-major, including `-1` dimension inference."""
    row, col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long, device)
    store = SparseStorage(row=row, col=col)

    # (4, 4) diagonal -> (2, 8): flat positions 0, 5, 10, 15.
    store = store.sparse_reshape(2, 8)
    assert store.sparse_sizes() == (2, 8)
    assert store.row().tolist() == [0, 0, 1, 1]
    assert store.col().tolist() == [0, 5, 2, 7]

    # Infer the first dimension: (-1, 4) resolves back to (4, 4).
    store = store.sparse_reshape(-1, 4)
    assert store.sparse_sizes() == (4, 4)
    assert store.row().tolist() == [0, 1, 2, 3]
    assert store.col().tolist() == [0, 1, 2, 3]

    # Infer the second dimension: (2, -1) resolves to (2, 8) again.
    store = store.sparse_reshape(2, -1)
    assert store.sparse_sizes() == (2, 8)
    assert store.row().tolist() == [0, 0, 1, 1]
    assert store.col().tolist() == [0, 5, 2, 7]
def test_storage(dtype, device):
    """Construction sorts entries by (row, col) and infers sparse sizes."""
    # Already-sorted input passes through unchanged; no value tensor.
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    store = SparseStorage(row=row, col=col)
    assert store.row().tolist() == [0, 0, 1, 1]
    assert store.col().tolist() == [0, 1, 0, 1]
    assert store.value() is None
    assert store.sparse_sizes() == (2, 2)

    # Unsorted input: entries and their values get reordered together.
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([2, 1, 4, 3], dtype, device)
    store = SparseStorage(row=row, col=col, value=value)
    assert store.row().tolist() == [0, 0, 1, 1]
    assert store.col().tolist() == [0, 1, 0, 1]
    assert store.value().tolist() == [1, 2, 3, 4]
    assert store.sparse_sizes() == (2, 2)
def test_coalesce(dtype, device):
    """`coalesce` must merge duplicate (row, col) entries by summation."""
    # Entry (0, 1) appears twice with value 1 each.
    row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
    value = tensor([1, 1, 1, 3, 4], dtype, device)
    store = SparseStorage(row=row, col=col, value=value)

    # Sorted-but-duplicated input is stored verbatim at first.
    assert store.row().tolist() == row.tolist()
    assert store.col().tolist() == col.tolist()
    assert store.value().tolist() == value.tolist()
    assert not store.is_coalesced()

    store = store.coalesce()
    assert store.is_coalesced()

    # The duplicates at (0, 1) collapse into a single entry of value 2.
    assert store.row().tolist() == [0, 0, 1, 1]
    assert store.col().tolist() == [0, 1, 0, 1]
    assert store.value().tolist() == [1, 2, 3, 4]
def test_caching(dtype, device):
    """Cache lifecycle: lazily empty -> `fill_cache_` populates all derived
    layouts -> constructor-provided attributes survive `clear_cache_`."""
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    # Fresh storage keeps only what was passed in; every derived attribute
    # starts out uncomputed.
    assert storage._row.tolist() == row.tolist()
    assert storage._col.tolist() == col.tolist()
    assert storage._value is None

    assert storage._rowcount is None
    assert storage._rowptr is None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.num_cached_keys() == 0

    storage.fill_cache_()

    # Dense 2x2 pattern: two entries per row/column, pointers [0, 2, 4],
    # and the CSR<->CSC permutation swaps the middle entries.
    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    # Re-creating the storage with all cached attributes supplied keeps them.
    storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col,
                            value=storage._value,
                            sparse_sizes=storage._sparse_sizes,
                            rowcount=storage._rowcount,
                            colptr=storage._colptr,
                            colcount=storage._colcount,
                            csr2csc=storage._csr2csc,
                            csc2csr=storage._csc2csr)

    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    storage.clear_cache_()

    # `rowptr` was passed to the constructor, so it is primary data and is
    # NOT cleared; everything else counts as cache and is dropped.
    assert storage._rowcount is None
    assert storage._rowptr is not None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.num_cached_keys() == 0
def test_utility(dtype, device):
    """set_value layouts, sparse_resize, and copy-vs-clone semantics."""
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    store = SparseStorage(row=row, col=col, value=value)
    assert store.has_value()

    # layout='csc' interprets the given values as column-major and permutes
    # them into the internal CSR order; layout='coo' stores them verbatim.
    store.set_value_(value, layout='csc')
    assert store.value().tolist() == [1, 3, 2, 4]
    store.set_value_(value, layout='coo')
    assert store.value().tolist() == [1, 2, 3, 4]

    # Out-of-place variants behave identically.
    store = store.set_value(value, layout='csc')
    assert store.value().tolist() == [1, 3, 2, 4]
    store = store.set_value(value, layout='coo')
    assert store.value().tolist() == [1, 2, 3, 4]

    store = store.sparse_resize((3, 3))
    assert store.sparse_sizes() == (3, 3)

    # copy(): new wrapper, shared tensors. clone(): deep-copied tensors.
    other = store.copy()
    assert other != store
    assert other.col().data_ptr() == store.col().data_ptr()

    other = store.clone()
    assert other != store
    assert other.col().data_ptr() != store.col().data_ptr()
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
    """Concatenate a list of sparse tensors along dimension `dim`.

    `dim=0` stacks rows (keeping cached `rowptr`/`rowcount` when every
    input provides them), `dim=1` stacks columns (keeping `colptr`/
    `colcount`), and `dim>1` concatenates the dense value dimensions.

    Fix: `SparseStorage` declares `sparse_sizes` as
    `Optional[Tuple[int, int]]`; the accumulator list is now converted to a
    tuple before construction (consistent with `cat_diag`/`cat_second`).
    """
    assert len(tensors) > 0
    if dim < 0:
        dim = tensors[0].dim() + dim

    if dim == 0:
        rows: List[torch.Tensor] = []
        rowptrs: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        rowcounts: List[torch.Tensor] = []
        nnz: int = 0

        for tensor in tensors:
            row = tensor.storage._row
            if row is not None:
                # Shift row indices past all previously stacked rows.
                rows.append(row + sparse_sizes[0])

            rowptr = tensor.storage._rowptr
            if rowptr is not None:
                # Later pointer arrays drop their redundant leading zero and
                # are offset by the entries gathered so far.
                if len(rowptrs) > 0:
                    rowptr = rowptr[1:]
                rowptrs.append(rowptr + nnz)

            cols.append(tensor.storage._col)

            value = tensor.storage._value
            if value is not None:
                values.append(value)

            rowcount = tensor.storage._rowcount
            if rowcount is not None:
                rowcounts.append(rowcount)

            sparse_sizes[0] += tensor.sparse_size(0)
            sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
            nnz += tensor.nnz()

        # Each optional attribute survives only if every input supplied it.
        row: Optional[torch.Tensor] = None
        if len(rows) == len(tensors):
            row = torch.cat(rows, dim=0)

        rowptr: Optional[torch.Tensor] = None
        if len(rowptrs) == len(tensors):
            rowptr = torch.cat(rowptrs, dim=0)

        col = torch.cat(cols, dim=0)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)

        rowcount: Optional[torch.Tensor] = None
        if len(rowcounts) == len(tensors):
            rowcount = torch.cat(rowcounts, dim=0)

        # Row-wise stacking keeps entries sorted by (row, col).
        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=(sparse_sizes[0],
                                              sparse_sizes[1]),
                                rowcount=rowcount, colptr=None, colcount=None,
                                csr2csc=None, csc2csr=None, is_sorted=True)
        return tensors[0].from_storage(storage)

    elif dim == 1:
        rows: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        colptrs: List[torch.Tensor] = []
        colcounts: List[torch.Tensor] = []
        nnz: int = 0

        for tensor in tensors:
            row, col, value = tensor.coo()
            rows.append(row)
            # Shift column indices past all previously appended columns.
            cols.append(tensor.storage._col + sparse_sizes[1])
            if value is not None:
                values.append(value)

            colptr = tensor.storage._colptr
            if colptr is not None:
                if len(colptrs) > 0:
                    colptr = colptr[1:]
                colptrs.append(colptr + nnz)

            colcount = tensor.storage._colcount
            if colcount is not None:
                colcounts.append(colcount)

            sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
            sparse_sizes[1] += tensor.sparse_size(1)
            nnz += tensor.nnz()

        row = torch.cat(rows, dim=0)
        col = torch.cat(cols, dim=0)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)

        colptr: Optional[torch.Tensor] = None
        if len(colptrs) == len(tensors):
            colptr = torch.cat(colptrs, dim=0)

        colcount: Optional[torch.Tensor] = None
        if len(colcounts) == len(tensors):
            colcount = torch.cat(colcounts, dim=0)

        # Interleaving columns breaks (row, col) order -> is_sorted=False.
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=(sparse_sizes[0],
                                              sparse_sizes[1]),
                                rowcount=None, colptr=colptr,
                                colcount=colcount, csr2csc=None, csc2csr=None,
                                is_sorted=False)
        return tensors[0].from_storage(storage)

    elif dim > 1 and dim < tensors[0].dim():
        # Dense value dimension: concatenate the value tensors.
        values: List[torch.Tensor] = []
        for tensor in tensors:
            value = tensor.storage.value()
            if value is not None:
                values.append(value)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=dim - 1)

        return tensors[0].set_value(value, layout='coo')

    else:
        raise IndexError(
            (f'Dimension out of range: Expected to be in range of '
             f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.'))
def index_select(src: SparseTensor, dim: int, idx: torch.Tensor) -> SparseTensor:
    """Select rows (`dim=0`), columns (`dim=1`) or value slices (`dim>1`)
    of `src` given by the index tensor `idx` (duplicates/reordering allowed).
    """
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount()

        rowcount = rowcount[idx]

        # New row pointers are the cumulative counts of the selected rows.
        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        # For each output entry, map its offset within the new row to the
        # matching offset inside the old row's CSR segment.
        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)

        col = col[perm]

        if value is not None:
            value = value[perm]

        sparse_sizes = (idx.size(0), src.sparse_size(1))
        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # Column selection: work in CSC order, then sort back to CSR.
        old_colptr, row, value = src.csc()
        colcount = src.storage.colcount()

        colcount = colcount[idx]

        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)

        row = row[perm]
        # Lexicographic (row, col) sort restores CSR order.
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[perm][csc2csr]

        sparse_sizes = (src.sparse_size(0), idx.size(0))
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dense value dimension: select along the value tensor instead.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError
def narrow(src: SparseTensor, dim: int, start: int, length: int) -> SparseTensor:
    """Narrow `src` along `dim` to the half-open range
    `[start, start + length)`, analogous to `torch.Tensor.narrow`.
    """
    if dim < 0:
        dim = src.dim() + dim
    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        # Rows are contiguous in CSR, so narrowing is pure pointer slicing.
        rowptr, col, value = src.csr()

        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        row_start = rowptr[0]
        rowptr = rowptr - row_start  # re-base pointers for the sub-matrix
        row_length = rowptr[-1]

        row = src.storage._row
        if row is not None:
            row = row.narrow(0, row_start, row_length) - start

        col = col.narrow(0, row_start, row_length)

        if value is not None:
            value = value.narrow(0, row_start, row_length)

        sparse_sizes = (length, src.sparse_size(1))

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()
        mask = (col >= start) & (col < start + length)

        row = row[mask]
        col = col[mask] - start

        if value is not None:
            value = value[mask]

        sparse_sizes = (src.sparse_size(0), length)

        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]  # re-base for the sub-matrix

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dense value dimension: narrow the value tensor instead.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.narrow(dim - 1, start, length),
                                 layout='coo')
        else:
            raise ValueError