def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    """Reduce the stored entries of ``src`` globally or along one dimension.

    Args:
        src: The sparse tensor to reduce.
        dim: Dimension to reduce over. ``None`` reduces over all stored
            values at once. ``0`` reduces over rows and yields one result per
            column, ``1`` reduces over columns and yields one result per row.
            Larger values reduce dimension ``dim - 1`` of the value tensor.
            Negative values are offset by ``src.dim()``.
        reduce: Reduction to apply: ``'sum'`` (alias ``'add'``), ``'mean'``,
            ``'min'`` or ``'max'``.

    Returns:
        The reduced tensor. When ``src`` stores no explicit values, ``'sum'``
        returns the number of stored entries (globally, per column or per
        row) and ``'mean'``/``'min'``/``'max'`` return ones.

    Raises:
        ValueError: If ``reduce`` is not supported, or if ``dim`` addresses a
            value dimension while ``src`` stores no values.
    """
    value = src.storage.value()

    if dim is None:
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError("Unsupported reduce operation")
        else:
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(), device=src.device())
            else:
                raise ValueError("Unsupported reduce operation")
    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            col = src.storage.col()
            # Reducing over rows produces one entry per *column*, so the
            # output must be sized by `src.size(1)`; the requested reduction
            # has to be forwarded instead of relying on scatter's default.
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                # Keep the result on the same device as `src`, matching the
                # `dim is None` branch above.
                return torch.ones(src.size(1), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError("Unsupported reduce operation")
        elif dim == 1 and value is not None:
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(0), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError("Unsupported reduce operation")
        elif dim > 1 and value is not None:
            # Sparse dims 0 and 1 share the first value dimension, so tensor
            # dimension `dim` corresponds to value dimension `dim - 1`.
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError("Unsupported reduce operation")
        else:
            raise ValueError(
                "Cannot reduce over a value dimension without values")
def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    """Select the slices of ``src`` listed in ``idx`` along ``dim``.

    Args:
        src: The sparse tensor to index.
        dim: Dimension to index. ``0`` selects rows, ``1`` selects columns,
            larger values index dimension ``dim - 1`` of the value tensor.
            Negative values are offset by ``src.dim()``.
        idx: One-dimensional index tensor.

    Returns:
        A new sparse tensor built from the selected entries.

    Raises:
        ValueError: If ``dim`` addresses a value dimension while ``src``
            stores no values.
    """
    if dim < 0:
        dim = src.dim() + dim
    assert idx.dim() == 1

    if dim == 0:
        num_rows = idx.size(0)
        src_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount()[idx]

        # Output row pointer: running sum of the selected rows' sizes.
        rowptr = col.new_zeros(num_rows + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        row = torch.arange(num_rows, device=col.device)
        row = row.repeat_interleave(rowcount)

        # Position of every output entry in the source arrays: its output
        # position plus the per-row shift between source and output row
        # starts (presumably `gather_csr` expands that shift to one value
        # per entry -- confirm against its definition).
        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(src_rowptr[idx] - rowptr[:-1], rowptr)

        col = col[perm]
        if value is not None:
            value = value[perm]

        storage = SparseStorage(
            row=row, rowptr=rowptr, col=col, value=value,
            sparse_sizes=(num_rows, src.sparse_size(1)), rowcount=rowcount,
            colptr=None, colcount=None, csr2csc=None, csc2csr=None,
            is_sorted=True)
        return src.from_storage(storage)

    if dim == 1:
        num_cols = idx.size(0)
        src_colptr, row, value = src.csc()
        colcount = src.storage.colcount()[idx]

        # Output column pointer: running sum of the selected columns' sizes.
        colptr = row.new_zeros(num_cols + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        col = torch.arange(num_cols, device=row.device)
        col = col.repeat_interleave(colcount)

        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(src_colptr[idx] - colptr[:-1], colptr)
        row = row[perm]

        # Entries were gathered column by column; sort them by (row, col)
        # via a linearised key before handing them to the storage.
        csc2csr = (num_cols * row + col).argsort()
        row = row[csc2csr]
        col = col[csc2csr]
        if value is not None:
            value = value[perm][csc2csr]

        storage = SparseStorage(
            row=row, rowptr=None, col=col, value=value,
            sparse_sizes=(src.sparse_size(0), num_cols), rowcount=None,
            colptr=colptr, colcount=colcount, csr2csc=None, csc2csr=csc2csr,
            is_sorted=True)
        return src.from_storage(storage)

    # Remaining dimensions live in the value tensor, shifted down by one.
    dense_value = src.storage.value()
    if dense_value is None:
        raise ValueError
    return src.set_value(dense_value.index_select(dim - 1, idx),
                         layout='coo')
def masked_select(src: SparseTensor, dim: int,
                  mask: torch.Tensor) -> SparseTensor:
    """Keep the slices of ``src`` along ``dim`` selected by ``mask``.

    Args:
        src: The sparse tensor to filter.
        dim: Dimension to filter. ``0`` filters rows, ``1`` filters columns,
            larger values filter dimension ``dim - 1`` of the value tensor.
            Negative values are offset by ``src.dim()``.
        mask: One-dimensional mask tensor used to index along ``dim``.

    Returns:
        A new sparse tensor containing only the selected entries.

    Raises:
        ValueError: If ``dim`` addresses a value dimension while ``src``
            stores no values.
    """
    if dim < 0:
        dim = src.dim() + dim
    assert mask.dim() == 1

    src_storage = src.storage

    if dim == 0:
        row, col, value = src.coo()
        rowcount = src_storage.rowcount()[mask]
        num_rows = rowcount.size(0)

        # Expand the per-row mask to one flag per stored entry.
        entry_mask = mask[row]
        # Surviving rows are renumbered consecutively from zero.
        new_row = torch.arange(num_rows, device=row.device)
        new_row = new_row.repeat_interleave(rowcount)

        col = col[entry_mask]
        if value is not None:
            value = value[entry_mask]

        storage = SparseStorage(
            row=new_row, rowptr=None, col=col, value=value,
            sparse_sizes=(num_rows, src.sparse_size(1)), rowcount=rowcount,
            colcount=None, colptr=None, csr2csc=None, csc2csr=None,
            is_sorted=True)
        return src.from_storage(storage)

    if dim == 1:
        row, col, value = src.coo()
        # Reorder the entries with the storage's `csr2csc` permutation so
        # the `repeat_interleave` over column counts below lines up.
        csr2csc = src_storage.csr2csc()
        row = row[csr2csc]
        col = col[csr2csc]

        colcount = src_storage.colcount()[mask]
        num_cols = colcount.size(0)

        # Expand the per-column mask to one flag per stored entry.
        entry_mask = mask[col]
        # Surviving columns are renumbered consecutively from zero.
        new_col = torch.arange(num_cols, device=col.device)
        new_col = new_col.repeat_interleave(colcount)
        row = row[entry_mask]

        # Sort the entries by (row, col) via a linearised key.
        csc2csr = (num_cols * row + new_col).argsort()
        row = row[csc2csr]
        new_col = new_col[csc2csr]
        if value is not None:
            value = value[csr2csc][entry_mask][csc2csr]

        storage = SparseStorage(
            row=row, rowptr=None, col=new_col, value=value,
            sparse_sizes=(src.sparse_size(0), num_cols), rowcount=None,
            colcount=colcount, colptr=None, csr2csc=None, csc2csr=csc2csr,
            is_sorted=True)
        return src.from_storage(storage)

    # Remaining dimensions live in the value tensor, shifted down by one.
    dense_value = src_storage.value()
    if dense_value is None:
        raise ValueError
    idx = mask.nonzero().flatten()
    return src.set_value(dense_value.index_select(dim - 1, idx),
                         layout='coo')
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Restrict ``src`` to ``length`` slices along ``dim`` from ``start``.

    Args:
        src: The sparse tensor to narrow.
        dim: Dimension to narrow. ``0`` narrows rows, ``1`` narrows columns,
            larger values narrow dimension ``dim - 1`` of the value tensor.
            Negative values are offset by ``src.dim()``.
        start: First index to keep. Negative values are offset by
            ``src.size(dim)``.
        length: Number of consecutive indices to keep.

    Returns:
        A new sparse tensor covering ``[start, start + length)`` along
        ``dim``, with indices shifted so that the range starts at zero.

    Raises:
        ValueError: If ``dim`` addresses a value dimension while ``src``
            stores no values.
    """
    if dim < 0:
        dim = src.dim() + dim
    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        full_rowptr, col, value = src.csr()

        # The kept rows occupy one contiguous slice of the entry arrays,
        # delimited by the narrowed row pointer.
        rowptr = full_rowptr.narrow(0, start=start, length=length + 1)
        first_entry = rowptr[0]
        rowptr = rowptr - first_entry
        num_entries = rowptr[-1]

        # Only reuse the row vector if the storage already holds one.
        row = src.storage._row
        if row is not None:
            row = row.narrow(0, first_entry, num_entries) - start

        col = col.narrow(0, first_entry, num_entries)
        if value is not None:
            value = value.narrow(0, first_entry, num_entries)

        # Likewise, slice the row counts only if they are already present.
        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        storage = SparseStorage(
            row=row, rowptr=rowptr, col=col, value=value,
            sparse_sizes=(length, src.sparse_size(1)), rowcount=rowcount,
            colptr=None, colcount=None, csr2csc=None, csc2csr=None,
            is_sorted=True)
        return src.from_storage(storage)

    if dim == 1:
        # Unlike the `dim=0` case, masking the COO entries is faster here
        # than going through `csc()`.
        row, col, value = src.coo()
        keep = (col >= start) & (col < start + length)

        row = row[keep]
        col = col[keep] - start
        if value is not None:
            value = value[keep]

        # Slice column metadata only if the storage already holds it.
        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        storage = SparseStorage(
            row=row, rowptr=None, col=col, value=value,
            sparse_sizes=(src.sparse_size(0), length), rowcount=None,
            colptr=colptr, colcount=colcount, csr2csc=None, csc2csr=None,
            is_sorted=True)
        return src.from_storage(storage)

    # Remaining dimensions live in the value tensor, shifted down by one.
    dense_value = src.storage.value()
    if dense_value is None:
        raise ValueError
    return src.set_value(dense_value.narrow(dim - 1, start, length),
                         layout='coo')