def t(src: SparseTensor) -> SparseTensor:
    """Return the transpose of ``src``.

    The COO data is re-ordered with the CSR-to-CSC permutation and the
    row/column roles are exchanged, so the new storage can be built with
    ``is_sorted=True``. Cached pointers, counts and permutations of the
    source storage are handed over with their row/column meaning swapped.
    """
    perm = src.storage.csr2csc()
    old_row, old_col, old_value = src.coo()

    new_value = old_value
    if new_value is not None:
        new_value = new_value[perm]

    old_sizes = src.storage.sparse_sizes()
    new_sizes = (old_sizes[1], old_sizes[0])

    transposed = SparseStorage(
        # Former columns become rows and vice versa.
        row=old_col[perm],
        rowptr=src.storage._colptr,
        col=old_row[perm],
        value=new_value,
        sparse_sizes=new_sizes,
        rowcount=src.storage._colcount,
        colptr=src.storage._rowptr,
        colcount=src.storage._rowcount,
        # The two permutations trade places as well.
        csr2csc=src.storage._csc2csr,
        csc2csr=perm,
        is_sorted=True,
    )
    return src.from_storage(transposed)
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    """Return ``src`` without the stored entries on its ``k``-th diagonal.

    An entry is on the ``k``-th diagonal when ``row == col - k``. Cached
    row and column counts, if present, are cloned and decremented for the
    dropped entries; every other cached layout field is passed as ``None``.
    """
    row, col, value = src.coo()

    # `keep` is True for every entry that is NOT on the requested diagonal.
    if k == 0:
        keep = row != col
    else:
        keep = row != (col - k)

    kept_row = row[keep]
    kept_col = col[keep]
    if value is not None:
        value = value[keep]

    rowcount = src.storage._rowcount
    colcount = src.storage._colcount
    if rowcount is not None or colcount is not None:
        dropped = ~keep
        if rowcount is not None:
            # Clone first so the cached tensor of `src` is left untouched.
            rowcount = rowcount.clone()
            rowcount[row[dropped]] -= 1
        if colcount is not None:
            colcount = colcount.clone()
            colcount[col[dropped]] -= 1

    storage = SparseStorage(
        row=kept_row,
        rowptr=None,
        col=kept_col,
        value=value,
        sparse_sizes=src.sparse_sizes(),
        rowcount=rowcount,
        colptr=None,
        colcount=colcount,
        csr2csc=None,
        csc2csr=None,
        is_sorted=True,
    )
    return src.from_storage(storage)
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    """Return ``src`` with every position of its ``k``-th diagonal stored.

    Existing entries on the diagonal are removed first (via
    ``remove_diag``) and the full diagonal is then inserted.

    Args:
        src: The sparse tensor to modify (not changed in place).
        values: Values written to the diagonal entries. If ``None``, the
            diagonal is filled with ones. Ignored when ``src`` stores no
            values, in which case the result stores no values either.
        k: Diagonal offset; a diagonal entry satisfies ``col == row + k``.

    Returns:
        A new sparse tensor with the same sparse sizes as ``src``.
    """
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    # `mask` has one slot per entry of the result: True slots receive the
    # existing (non-diagonal) entries, False slots the new diagonal ones.
    # NOTE(review): the exact slot ordering is decided by the custom op.
    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    # First row index on the diagonal and number of diagonal entries.
    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    # In-place shift is safe: `diag` was already copied into `new_row`.
    new_col[inv_mask] = diag.add_(k)

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            # Match the trailing value dimensions; a flat `(num_diag, )`
            # tensor does not broadcast into multi-dimensional values.
            new_value[inv_mask] = torch.ones(
                (num_diag, ) + value.size()[1:], dtype=value.dtype,
                device=value.device)

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
                            value=new_value, sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount, colptr=None, colcount=colcount,
                            csr2csc=None, csc2csr=None, is_sorted=True)
    return src.from_storage(storage)
def get_diag(src: SparseTensor) -> Tensor:
    """Return the main diagonal of ``src`` as a dense tensor.

    The first dimension of the result has ``min(src.size(0), src.size(1))``
    entries; trailing value dimensions are kept. Diagonal positions without
    a stored entry are zero. If ``src`` stores no values, every stored
    entry counts as ``1``.
    """
    row, col, value = src.coo()

    if value is None:
        # Allocate the implicit ones next to the indices: a tensor on the
        # default device cannot be indexed by a mask on another device.
        value = torch.ones(row.size(0), device=row.device)

    sizes = list(value.size())
    sizes[0] = min(src.size(0), src.size(1))

    out = value.new_zeros(sizes)

    mask = row == col
    out[row[mask]] = value[mask]
    return out
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
                      layout: Optional[str] = None) -> SparseTensor:
    """Keep only the stored entries of ``src`` selected by ``mask``.

    ``mask`` must be one-dimensional. When ``layout`` resolves to
    ``'csc'`` (through ``get_layout``), the mask is first re-ordered with
    the CSC-to-CSR permutation so that it lines up with the data returned
    by ``src.coo()``. The sparse sizes are unchanged.
    """
    assert mask.dim() == 1

    if get_layout(layout) == 'csc':
        csc2csr = src.storage.csc2csr()
        mask = mask[csc2csr]

    row, col, value = src.coo()
    kept_row = row[mask]
    kept_col = col[mask]
    if value is not None:
        value = value[mask]

    # Dropping entries keeps the relative order of the remaining ones.
    return SparseTensor(
        row=kept_row,
        rowptr=None,
        col=kept_col,
        value=value,
        sparse_sizes=src.sparse_sizes(),
        is_sorted=True,
    )
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    """Gather the stored entries of ``src`` at the positions ``idx``.

    Args:
        src: The sparse tensor to select from.
        idx: One-dimensional tensor of positions into the stored entries,
            expressed in the ordering named by ``layout``.
        layout: Resolved through ``get_layout``. If it resolves to
            ``'csc'``, ``idx`` is translated into positions of the data
            returned by ``src.coo()`` before gathering.

    Returns:
        A new sparse tensor with the same sparse sizes as ``src``.
    """
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        # `coo_data[csr2csc]` is the column-major ordering (as relied on by
        # `t`), so entry `j` of that ordering sits at `csr2csc[j]` in the
        # data returned by `coo()`.
        idx = src.storage.csr2csc()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if value is not None:
        value = value[idx]

    # `idx` may be in any order, so the gathered entries must not be
    # declared sorted; let the constructor check and sort them if needed.
    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=False)
def saint_subgraph(
        src: SparseTensor,
        node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]:
    """Extract a square sub-matrix of ``src`` for the nodes in ``node_idx``.

    The work is delegated to the ``torch_sparse.saint_subgraph`` custom op,
    which yields the new row/column indices plus an index tensor that is
    used here to gather the matching values.

    Returns:
        A tuple ``(out, edge_index)`` where ``out`` has sparse sizes
        ``(node_idx.size(0), node_idx.size(0))`` and ``edge_index`` is the
        index tensor produced by the op.
    """
    row, col, value = src.coo()
    rowptr = src.storage.rowptr()

    sub_row, sub_col, edge_index = torch.ops.torch_sparse.saint_subgraph(
        node_idx, rowptr, row, col)

    if value is not None:
        value = value[edge_index]

    num_nodes = node_idx.size(0)
    out = SparseTensor(
        row=sub_row,
        rowptr=None,
        col=sub_col,
        value=value,
        sparse_sizes=(num_nodes, num_nodes),
        is_sorted=True,
    )
    return out, edge_index
def masked_select(src: SparseTensor, dim: int,
                  mask: torch.Tensor) -> SparseTensor:
    """Select slices of ``src`` along ``dim`` according to ``mask``.

    * ``dim == 0``: keeps the rows where ``mask`` is set and renumbers
      them consecutively.
    * ``dim == 1``: the same for columns.
    * otherwise: selects along dimension ``dim - 1`` of the value tensor.

    A negative ``dim`` is counted from ``src.dim()``. ``mask`` must be
    one-dimensional.

    Raises:
        ValueError: If ``dim`` addresses a value dimension but ``src``
            stores no values.
    """
    if dim < 0:
        dim = src.dim() + dim

    assert mask.dim() == 1

    if dim == 0:
        row, col, value = src.coo()

        # Entry counts of the surviving rows only.
        rowcount = src.storage.rowcount()
        rowcount = rowcount[mask]
        num_rows = rowcount.size(0)

        # Lift the per-row mask to a per-entry mask.
        entry_mask = mask[row]

        # Rebuild consecutive row indices from the surviving counts.
        new_row = torch.arange(num_rows, device=row.device)
        new_row = new_row.repeat_interleave(rowcount)
        new_col = col[entry_mask]

        if value is not None:
            value = value[entry_mask]

        storage = SparseStorage(
            row=new_row,
            rowptr=None,
            col=new_col,
            value=value,
            sparse_sizes=(num_rows, src.sparse_size(1)),
            rowcount=rowcount,
            colcount=None,
            colptr=None,
            csr2csc=None,
            csc2csr=None,
            is_sorted=True,
        )
        return src.from_storage(storage)

    if dim == 1:
        row, col, value = src.coo()

        # Work on the column-major ordering of the entries.
        csr2csc = src.storage.csr2csc()
        csc_row = row[csr2csc]
        csc_col = col[csr2csc]

        colcount = src.storage.colcount()
        colcount = colcount[mask]
        num_cols = colcount.size(0)

        entry_mask = mask[csc_col]

        new_col = torch.arange(num_cols, device=csc_col.device)
        new_col = new_col.repeat_interleave(colcount)
        new_row = csc_row[entry_mask]

        # Permutation taking the filtered column-major data back to
        # row-major order.
        csc2csr = (num_cols * new_row + new_col).argsort()
        new_row = new_row[csc2csr]
        new_col = new_col[csc2csr]

        if value is not None:
            value = value[csr2csc][entry_mask][csc2csr]

        storage = SparseStorage(
            row=new_row,
            rowptr=None,
            col=new_col,
            value=value,
            sparse_sizes=(src.sparse_size(0), num_cols),
            rowcount=None,
            colcount=colcount,
            colptr=None,
            csr2csc=None,
            csc2csr=csc2csr,
            is_sorted=True,
        )
        return src.from_storage(storage)

    # Remaining dimensions live in the value tensor, shifted by one.
    value = src.storage.value()
    if value is None:
        raise ValueError

    idx = mask.nonzero().flatten()
    return src.set_value(value.index_select(dim - 1, idx), layout='coo')
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Return the slice ``[start, start + length)`` of ``src`` along ``dim``.

    * ``dim == 0`` slices rows using the CSR row pointer.
    * ``dim == 1`` slices columns by masking the COO column indices.
    * otherwise the value tensor is narrowed along ``dim - 1``.

    Negative ``dim`` and ``start`` are counted from ``src.dim()`` and
    ``src.size(dim)`` respectively.

    Raises:
        ValueError: If ``dim`` addresses a value dimension but ``src``
            stores no values.
    """
    if dim < 0:
        dim = src.dim() + dim

    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        rowptr, col, value = src.csr()

        # Cut the pointer window and rebase it so it starts at zero.
        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        first_entry = rowptr[0]
        rowptr = rowptr - first_entry
        num_entries = rowptr[-1]

        # Only shift the row indices if they are already materialised.
        row = src.storage._row
        if row is not None:
            row = row.narrow(0, first_entry, num_entries) - start

        col = col.narrow(0, first_entry, num_entries)

        if value is not None:
            value = value.narrow(0, first_entry, num_entries)

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        storage = SparseStorage(
            row=row,
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=(length, src.sparse_size(1)),
            rowcount=rowcount,
            colptr=None,
            colcount=None,
            csr2csc=None,
            csc2csr=None,
            is_sorted=True,
        )
        return src.from_storage(storage)

    if dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()
        in_range = (col >= start) & (col < start + length)

        new_row = row[in_range]
        new_col = col[in_range] - start

        if value is not None:
            value = value[in_range]

        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        storage = SparseStorage(
            row=new_row,
            rowptr=None,
            col=new_col,
            value=value,
            sparse_sizes=(src.sparse_size(0), length),
            rowcount=None,
            colptr=colptr,
            colcount=colcount,
            csr2csc=None,
            csc2csr=None,
            is_sorted=True,
        )
        return src.from_storage(storage)

    # Remaining dimensions live in the value tensor, shifted by one.
    value = src.storage.value()
    if value is None:
        raise ValueError

    return src.set_value(value.narrow(dim - 1, start, length), layout='coo')