Exemplo n.º 1
0
 def __init__(self, row: Optional[torch.Tensor] = None,
              rowptr: Optional[torch.Tensor] = None,
              col: Optional[torch.Tensor] = None,
              value: Optional[torch.Tensor] = None,
              sparse_sizes: Optional[Tuple[int, int]] = None,
              is_sorted: bool = False):
     """Wrap the given COO (``row``) and/or CSR (``rowptr``) data in a
     ``SparseStorage``.

     Only the index/value tensors, ``sparse_sizes`` and ``is_sorted`` are
     forwarded; every auxiliary cache is handed over as ``None``.
     """
     storage = SparseStorage(
         row=row,
         rowptr=rowptr,
         col=col,
         value=value,
         sparse_sizes=sparse_sizes,
         is_sorted=is_sorted,
         # No cached helpers are precomputed at construction time.
         rowcount=None,
         colptr=None,
         colcount=None,
         csr2csc=None,
         csc2csr=None,
     )
     self.storage = storage
Exemplo n.º 2
0
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` without the entries on its ``k``-th diagonal.

    An entry ``(row, col)`` lies on the ``k``-th diagonal when
    ``row == col - k``: ``k == 0`` is the main diagonal, ``k > 0`` lies above
    it and ``k < 0`` below it.

    Cached ``rowcount``/``colcount`` tensors of ``src`` (if present) are
    carried over and decremented once per removed entry. All other caches
    (``rowptr``, ``colptr``, ``csr2csc``, ``csc2csr``) are dropped.

    Args:
        src (SparseTensor): The input sparse matrix.
        k (int, optional): The diagonal offset. (default: :obj:`0`)

    :rtype: :class:`SparseTensor`
    """
    row, col, value = src.coo()
    # `inv_mask` marks the entries to *keep* (everything off the diagonal).
    inv_mask = row != col if k == 0 else row != (col - k)
    new_row, new_col = row[inv_mask], col[inv_mask]

    if value is not None:
        value = value[inv_mask]

    rowcount = src.storage._rowcount
    colcount = src.storage._colcount
    if rowcount is not None or colcount is not None:
        mask = ~inv_mask  # entries being removed
        # `index_add_` accumulates over repeated indices. Duplicate diagonal
        # entries of a non-coalesced input are therefore each subtracted.
        # A plain `count[idx] -= 1` writes every repeated index only once.
        if rowcount is not None:
            removed_rows = row[mask]
            rowcount = rowcount.clone().index_add_(
                0, removed_rows,
                rowcount.new_full((removed_rows.numel(), ), -1))
        if colcount is not None:
            removed_cols = col[mask]
            colcount = colcount.clone().index_add_(
                0, removed_cols,
                colcount.new_full((removed_cols.numel(), ), -1))

    # Filtering with a boolean mask preserves the relative order of the
    # entries returned by `src.coo()`.
    storage = SparseStorage(row=new_row,
                            rowptr=None,
                            col=new_col,
                            value=value,
                            sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount,
                            colptr=None,
                            colcount=colcount,
                            csr2csc=None,
                            csc2csr=None,
                            is_sorted=True)
    return src.from_storage(storage)
Exemplo n.º 3
0
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
    """Convert a SciPy sparse matrix into a :class:`SparseTensor`.

    The input is converted to CSR (for ``rowptr``) and then to COO (for
    ``row``/``col``). When the input is in CSC format, its ``indptr`` is
    also reused as the cached ``colptr``.

    Args:
        mat: A SciPy sparse matrix or sparse array of any format.
        has_value (bool, optional): If set to :obj:`False`, the stored values
            are dropped. (default: :obj:`True`)

    :rtype: :class:`SparseTensor`
    """
    colptr = None
    # Test the format tag instead of `isinstance(mat, csc_matrix)` so that
    # `scipy.sparse.csc_array` inputs are recognized as well.
    if getattr(mat, 'format', None) == 'csc':
        colptr = torch.from_numpy(mat.indptr).to(torch.long)

    mat = mat.tocsr()
    # The storage below is flagged `is_sorted=True`. That requires sorted
    # column indices inside every row, but CSR inputs may carry unsorted
    # indices. `sorted_indices()` sorts a *copy*, so the caller's matrix is
    # not modified in place.
    if not mat.has_sorted_indices:
        mat = mat.sorted_indices()
    rowptr = torch.from_numpy(mat.indptr).to(torch.long)
    mat = mat.tocoo()
    row = torch.from_numpy(mat.row).to(torch.long)
    col = torch.from_numpy(mat.col).to(torch.long)
    value = None
    if has_value:
        value = torch.from_numpy(mat.data)
    sparse_sizes = mat.shape[:2]

    storage = SparseStorage(row=row,
                            rowptr=rowptr,
                            col=col,
                            value=value,
                            sparse_sizes=sparse_sizes,
                            rowcount=None,
                            colptr=colptr,
                            colcount=None,
                            csr2csc=None,
                            csc2csr=None,
                            is_sorted=True)

    return SparseTensor.from_storage(storage)
Exemplo n.º 4
0
def t(src: SparseTensor) -> SparseTensor:
    """Return the transpose of ``src``.

    The entries are permuted into the CSC order of ``src`` via its
    ``csr2csc`` permutation, which is the CSR order of the transpose.
    Row-side and column-side caches trade places, and the permutation itself
    becomes the transpose's ``csc2csr``.
    """
    src_storage = src.storage
    perm = src_storage.csr2csc()

    row, col, value = src.coo()

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value[perm]

    num_rows, num_cols = src_storage.sparse_sizes()
    # Swapping the roles of row and col is what transposes the matrix.
    t_row = col[perm]
    t_col = row[perm]

    storage = SparseStorage(row=t_row,
                            rowptr=src_storage._colptr,
                            col=t_col,
                            value=new_value,
                            sparse_sizes=(num_cols, num_rows),
                            rowcount=src_storage._colcount,
                            colptr=src_storage._rowptr,
                            colcount=src_storage._rowcount,
                            csr2csc=src_storage._csc2csr,
                            csc2csr=perm,
                            is_sorted=True)
    return src.from_storage(storage)
Exemplo n.º 5
0
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    """Extract the diagonal block of ``src`` at ``start`` with size ``length``.

    Args:
        src: A sparse matrix built by diagonally stacking blocks.
        start: ``(first_row, first_col)`` of the block to extract.
        length: ``(num_rows, num_cols)`` of the block to extract.

    Returns:
        The block as a new ``SparseTensor`` whose sparse sizes are ``length``.
        Caches present on ``src`` are sliced and carried over.
    """
    # This function builds the inverse operation of `cat_diag` and should hence
    # only be used on *diagonally stacked* sparse matrices.
    # That's the reason why this method is marked as *private*.

    rowptr, col, value = src.csr()

    # Keep the `length[0] + 1` pointer entries that bound the block's rows.
    rowptr = rowptr.narrow(0, start=start[0], length=length[0] + 1)
    row_start = int(rowptr[0])  # index of the block's first stored entry
    rowptr = rowptr - row_start  # rebase so the block's pointers start at 0
    row_length = int(rowptr[-1])  # number of stored entries in the block

    # The block's entries form the contiguous slice
    # [row_start, row_start + row_length) of every per-entry tensor below.
    row = src.storage._row
    if row is not None:
        row = row.narrow(0, row_start, row_length) - start[0]

    # Shifting by start[1] relies on the diagonal-stacking assumption: the
    # entries of these rows all lie in the block's own column range.
    col = col.narrow(0, row_start, row_length) - start[1]

    if value is not None:
        value = value.narrow(0, row_start, row_length)

    sparse_sizes = length

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, start[0], length[0])

    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, start[1], length[1] + 1)
        colptr = colptr - int(colptr[0])  # i.e. `row_start`

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, start[1], length[1])

    # Under the diagonal-stacking assumption the block's entries occupy the
    # same index range in CSC order as in CSR order, so each permutation
    # only needs to be sliced and rebased by `row_start`.
    csr2csc = src.storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, row_start, row_length) - row_start

    csc2csr = src.storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, row_start, row_length) - row_start

    storage = SparseStorage(row=row,
                            rowptr=rowptr,
                            col=col,
                            value=value,
                            sparse_sizes=sparse_sizes,
                            rowcount=rowcount,
                            colptr=colptr,
                            colcount=colcount,
                            csr2csc=csr2csc,
                            csc2csr=csc2csr,
                            is_sorted=True)
    return src.from_storage(storage)
Exemplo n.º 6
0
def cat_second(tensors: List[SparseTensor]) -> SparseTensor:
    """Concatenate sparse matrices side by side (along dimension 1).

    Column indices of each input are shifted by the total column count of
    the inputs before it. The result has as many rows as the tallest input.
    Values, ``colptr`` and ``colcount`` are kept only when every input
    provides them. The storage is built with ``is_sorted=False``.
    """
    row_parts: List[torch.Tensor] = []
    col_parts: List[torch.Tensor] = []
    value_parts: List[torch.Tensor] = []
    colptr_parts: List[torch.Tensor] = []
    colcount_parts: List[torch.Tensor] = []

    num_rows: int = 0
    col_offset: int = 0
    nnz_offset: int = 0
    for tensor in tensors:
        row, _, value = tensor.coo()
        row_parts.append(row)
        col_parts.append(tensor.storage._col + col_offset)
        if value is not None:
            value_parts.append(value)

        colptr = tensor.storage._colptr
        if colptr is not None:
            if len(colptr_parts) == 0:
                colptr_parts.append(colptr)
            else:
                # Drop the leading pointer and shift past earlier entries.
                colptr_parts.append(colptr[1:] + nnz_offset)

        colcount = tensor.storage._colcount
        if colcount is not None:
            colcount_parts.append(colcount)

        num_rows = max(num_rows, tensor.sparse_size(0))
        col_offset += tensor.sparse_size(1)
        nnz_offset += tensor.nnz()

    out_row = torch.cat(row_parts, dim=0)
    out_col = torch.cat(col_parts, dim=0)

    # Optional parts survive only if *every* input contributed one.
    num_inputs = len(tensors)

    out_value: Optional[torch.Tensor] = None
    if len(value_parts) == num_inputs:
        out_value = torch.cat(value_parts, dim=0)

    out_colptr: Optional[torch.Tensor] = None
    if len(colptr_parts) == num_inputs:
        out_colptr = torch.cat(colptr_parts, dim=0)

    out_colcount: Optional[torch.Tensor] = None
    if len(colcount_parts) == num_inputs:
        out_colcount = torch.cat(colcount_parts, dim=0)

    storage = SparseStorage(row=out_row,
                            rowptr=None,
                            col=out_col,
                            value=out_value,
                            sparse_sizes=(num_rows, col_offset),
                            rowcount=None,
                            colptr=out_colptr,
                            colcount=out_colcount,
                            csr2csc=None,
                            csc2csr=None,
                            is_sorted=False)
    return tensors[0].from_storage(storage)
Exemplo n.º 7
0
    def eye(self,
            M: int,
            N: Optional[int] = None,
            has_value: bool = True,
            dtype: Optional[int] = None,
            device: Optional[torch.device] = None,
            fill_cache: bool = False):
        """Create an ``M x N`` sparse matrix with ones on the main diagonal.

        Args:
            M: Number of rows.
            N: Number of columns. Defaults to ``M`` when ``None``.
            has_value: If ``False``, no value tensor is stored.
            dtype: dtype of the value tensor (only used with ``has_value``).
            device: Device on which the index tensors are allocated.
            fill_cache: If ``True``, also precompute rowcount, colptr,
                colcount, csr2csc and csc2csr.
        """
        num_cols = M if N is None else N
        diag = torch.arange(min(M, num_cols), device=device)
        dev = diag.device

        # Rows past the diagonal (only when M > N) are empty, so their
        # pointers all stay at the final entry count N.
        rowptr = torch.arange(M + 1, device=dev)
        if M > num_cols:
            rowptr[num_cols + 1:] = num_cols

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(diag.numel(), dtype=dtype, device=dev)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=dev)
            colptr = torch.arange(num_cols + 1, dtype=torch.long, device=dev)
            colcount = torch.ones(num_cols, dtype=torch.long, device=dev)
            if M > num_cols:
                rowcount[num_cols:] = 0
            if num_cols > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # CSR and CSC order coincide, so both permutations are identity.
            csr2csc = diag
            csc2csr = diag

        storage: SparseStorage = SparseStorage(row=diag,
                                               rowptr=rowptr,
                                               col=diag,
                                               value=value,
                                               sparse_sizes=(M, num_cols),
                                               rowcount=rowcount,
                                               colptr=colptr,
                                               colcount=colcount,
                                               csr2csc=csr2csc,
                                               csc2csr=csc2csr,
                                               is_sorted=True)

        out = SparseTensor.__new__(SparseTensor)
        out.storage = storage
        return out
Exemplo n.º 8
0
def set_diag(src: SparseTensor,
             values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` whose ``k``-th diagonal is fully populated.

    Existing entries on the ``k``-th diagonal are first removed via
    :func:`remove_diag`. One entry is then inserted for every position on
    that diagonal, starting at row ``-k`` for ``k < 0`` and row ``0``
    otherwise.

    Args:
        src (SparseTensor): The input sparse matrix.
        values (Tensor, optional): Values for the inserted diagonal entries.
            If :obj:`None`, ones are used. Ignored when ``src`` has no value
            tensor. (default: :obj:`None`)
        k (int, optional): The diagonal offset. (default: :obj:`0`)

    :rtype: :class:`SparseTensor`
    """
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    # Boolean mask over the output entries. True marks slots for the
    # existing entries; False marks slots for the inserted diagonal.
    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    new_col[inv_mask] = diag.add_(k)  # column = row + k (reuses `diag`)

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            # Include the trailing value dimensions: a flat `(num_diag,)`
            # tensor cannot broadcast against `(num_diag, *value.shape[1:])`.
            new_value[inv_mask] = value.new_ones((num_diag, ) +
                                                 value.size()[1:])

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row,
                            rowptr=None,
                            col=new_col,
                            value=new_value,
                            sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount,
                            colptr=None,
                            colcount=colcount,
                            csr2csc=None,
                            csc2csr=None,
                            is_sorted=True)
    return src.from_storage(storage)
Exemplo n.º 9
0
def coalesce(index: torch.Tensor, m: int, n: int,
             value: Optional[torch.Tensor] = None, op: str = "add"):
    """Row-wise sorts :obj:`index` and :obj:`value` and removes duplicate
    entries. Duplicate entries are merged by scattering them together. For
    scattering, any operation of `torch_scatter
    <https://github.com/rusty1s/pytorch_scatter>`_ can be used.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix, with
            row indices in ``index[0]`` and column indices in ``index[1]``.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        value (:class:`Tensor`, optional): The value tensor of sparse matrix.
            (default: :obj:`None`)
        op (string, optional): The scatter operation to use. It is forwarded
            as ``reduce`` to ``SparseStorage.coalesce``.
            (default: :obj:`"add"`)

    :rtype: (:class:`LongTensor`, :class:`Tensor` or :obj:`None`)
    """
    sparse_sizes = (m, n)

    storage = SparseStorage(
        row=index[0], col=index[1], value=value,
        sparse_sizes=sparse_sizes,
        is_sorted=False,
        rowptr=None, rowcount=None, colptr=None, colcount=None,
        csr2csc=None, csc2csr=None)
    storage = storage.coalesce(reduce=op)
    return torch.stack([storage.row(), storage.col()], dim=0), storage.value()
Exemplo n.º 10
0
def transpose(index, value, m, n, coalesced=True):
    """Transpose a sparse matrix given in (index, value) form.

    Transposing swaps the roles of the row and column indices. Unless
    ``coalesced`` is falsy, the swapped entries are then sorted and
    de-duplicated through a ``SparseStorage`` of size ``(n, m)``.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`False`, will not coalesce
            the output. (default: :obj:`True`)
    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    src_row, src_col = index
    new_row, new_col = src_col, src_row

    if not coalesced:
        return torch.stack([new_row, new_col], dim=0), value

    storage = SparseStorage(row=new_row,
                            col=new_col,
                            value=value,
                            sparse_sizes=(n, m),
                            is_sorted=False).coalesce()
    out_index = torch.stack([storage.row(), storage.col()], dim=0)
    return out_index, storage.value()
Exemplo n.º 11
0
class SparseTensor(object):
    """A sparse matrix, optionally with dense trailing value dimensions.

    All index, value and cache tensors are held by ``self.storage`` (a
    ``SparseStorage``). The first two dimensions are the sparse ones (see
    ``sparse_sizes``); any further dimensions come from the value tensor
    (see ``sizes``). Most methods delegate to the storage and wrap the result
    in a new ``SparseTensor`` via ``from_storage``.
    """
    storage: SparseStorage

    def __init__(self,
                 row: Optional[torch.Tensor] = None,
                 rowptr: Optional[torch.Tensor] = None,
                 col: Optional[torch.Tensor] = None,
                 value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[Tuple[int, int]] = None,
                 is_sorted: bool = False):
        """Build from COO (``row``) and/or CSR (``rowptr``) indices.

        No caches (rowcount, colptr, colcount, csr2csc, csc2csr) are
        precomputed; they are passed to ``SparseStorage`` as ``None``.
        """
        self.storage = SparseStorage(row=row,
                                     rowptr=rowptr,
                                     col=col,
                                     value=value,
                                     sparse_sizes=sparse_sizes,
                                     rowcount=None,
                                     colptr=None,
                                     colcount=None,
                                     csr2csc=None,
                                     csc2csr=None,
                                     is_sorted=is_sorted)

    @classmethod
    def from_storage(self, storage: SparseStorage):
        """Wrap an existing ``storage`` object (no copy is made)."""
        # `__new__` bypasses `__init__`, so no new SparseStorage is built.
        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    @classmethod
    def from_edge_index(self,
                        edge_index: torch.Tensor,
                        edge_attr: Optional[torch.Tensor] = None,
                        sparse_sizes: Optional[Tuple[int, int]] = None,
                        is_sorted: bool = False):
        """Build from an ``edge_index`` whose ``[0]``/``[1]`` entries are the
        row/column indices, with ``edge_attr`` as the optional values."""
        return SparseTensor(row=edge_index[0],
                            rowptr=None,
                            col=edge_index[1],
                            value=edge_attr,
                            sparse_sizes=sparse_sizes,
                            is_sorted=is_sorted)

    @classmethod
    def from_dense(self, mat: torch.Tensor, has_value: bool = True):
        """Build from a dense tensor, keeping its non-zero positions.

        For ``mat.dim() > 2`` a position ``(i, j)`` is kept when any entry of
        ``mat[i, j]`` is non-zero; the stored value is then the whole slice
        ``mat[i, j]``.
        """
        if mat.dim() > 2:
            # Collapse trailing dims into one non-negative score per (i, j).
            index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
        else:
            index = mat.nonzero()
        index = index.t()

        row = index[0]
        col = index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat[row, col]

        # `nonzero()` yields indices in row-major (lexicographic) order.
        return SparseTensor(row=row,
                            rowptr=None,
                            col=col,
                            value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def from_torch_sparse_coo_tensor(self,
                                     mat: torch.Tensor,
                                     has_value: bool = True):
        """Build from a ``torch.sparse_coo_tensor`` (coalesced first)."""
        # Coalescing sorts the indices, which justifies `is_sorted=True`.
        mat = mat.coalesce()
        index = mat._indices()
        row, col = index[0], index[1]

        value: Optional[torch.Tensor] = None
        if has_value:
            value = mat.values()

        return SparseTensor(row=row,
                            rowptr=None,
                            col=col,
                            value=value,
                            sparse_sizes=(mat.size(0), mat.size(1)),
                            is_sorted=True)

    @classmethod
    def eye(self,
            M: int,
            N: Optional[int] = None,
            has_value: bool = True,
            dtype: Optional[int] = None,
            device: Optional[torch.device] = None,
            fill_cache: bool = False):
        """Create an ``M x N`` (``N`` defaults to ``M``) sparse matrix with
        ones on the main diagonal; ``fill_cache`` also precomputes rowcount,
        colptr, colcount, csr2csc and csc2csr."""

        N = M if N is None else N

        row = torch.arange(min(M, N), device=device)
        col = row

        # Rows past the diagonal (M > N) are empty: pointers stay at N.
        rowptr = torch.arange(M + 1, device=row.device)
        if M > N:
            rowptr[N + 1:] = N

        value: Optional[torch.Tensor] = None
        if has_value:
            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

        rowcount: Optional[torch.Tensor] = None
        colptr: Optional[torch.Tensor] = None
        colcount: Optional[torch.Tensor] = None
        csr2csc: Optional[torch.Tensor] = None
        csc2csr: Optional[torch.Tensor] = None

        if fill_cache:
            rowcount = torch.ones(M, dtype=torch.long, device=row.device)
            if M > N:
                rowcount[N:] = 0

            colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
            colcount = torch.ones(N, dtype=torch.long, device=row.device)
            if N > M:
                colptr[M + 1:] = M
                colcount[M:] = 0
            # CSR and CSC order coincide: both permutations are the identity.
            csr2csc = csc2csr = row

        storage: SparseStorage = SparseStorage(row=row,
                                               rowptr=rowptr,
                                               col=col,
                                               value=value,
                                               sparse_sizes=(M, N),
                                               rowcount=rowcount,
                                               colptr=colptr,
                                               colcount=colcount,
                                               csr2csc=csr2csc,
                                               csc2csr=csc2csr,
                                               is_sorted=True)

        self = SparseTensor.__new__(SparseTensor)
        self.storage = storage
        return self

    def copy(self):
        """Shallow copy: the new tensor shares this tensor's storage object."""
        return self.from_storage(self.storage)

    def clone(self):
        """Copy backed by ``self.storage.clone()``."""
        return self.from_storage(self.storage.clone())

    def type_as(self, tensor: torch.Tensor):
        """Cast the values to ``tensor.dtype``.

        Returns ``self`` unchanged when there are no values or the dtype
        already matches.
        """
        value = self.storage.value()
        if value is None or tensor.dtype == value.dtype:
            return self
        return self.from_storage(self.storage.type_as(tensor))

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        """Move to ``tensor.device``; returns ``self`` if already there."""
        if tensor.device == self.device():
            return self
        return self.from_storage(self.storage.device_as(tensor, non_blocking))

    # Formats #################################################################

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(row, col, value)``."""
        return self.storage.row(), self.storage.col(), self.storage.value()

    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(rowptr, col, value)``."""
        return self.storage.rowptr(), self.storage.col(), self.storage.value()

    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(colptr, row, value)``, with ``row`` and ``value``
        permuted from CSR into CSC order via ``csr2csc``."""
        perm = self.storage.csr2csc()
        value = self.storage.value()
        if value is not None:
            value = value[perm]
        return self.storage.colptr(), self.storage.row()[perm], value

    # Storage inheritance #####################################################

    def has_value(self) -> bool:
        """Whether a value tensor is stored."""
        return self.storage.has_value()

    def set_value_(self,
                   value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
        """In-place variant of ``set_value``; returns ``self``."""
        self.storage.set_value_(value, layout)
        return self

    def set_value(self,
                  value: Optional[torch.Tensor],
                  layout: Optional[str] = None):
        """Return a new tensor whose storage has ``value`` set (``layout`` is
        forwarded to ``storage.set_value``)."""
        return self.from_storage(self.storage.set_value(value, layout))

    def sparse_sizes(self) -> Tuple[int, int]:
        """Return the sizes of the two sparse dimensions."""
        return self.storage.sparse_sizes()

    def sparse_size(self, dim: int) -> int:
        """Return the size of sparse dimension ``dim``."""
        return self.storage.sparse_sizes()[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        """New tensor from ``storage.sparse_resize(sparse_sizes)``."""
        return self.from_storage(self.storage.sparse_resize(sparse_sizes))

    def sparse_reshape(self, num_rows: int, num_cols: int):
        """New tensor from ``storage.sparse_reshape(num_rows, num_cols)``."""
        return self.from_storage(
            self.storage.sparse_reshape(num_rows, num_cols))

    def is_coalesced(self) -> bool:
        """Return ``storage.is_coalesced()``."""
        return self.storage.is_coalesced()

    def coalesce(self, reduce: str = "sum"):
        """Merge duplicate entries using the ``reduce`` operation."""
        return self.from_storage(self.storage.coalesce(reduce))

    def fill_cache_(self):
        """Precompute the storage caches in place; returns ``self``."""
        self.storage.fill_cache_()
        return self

    def clear_cache_(self):
        """Drop the storage caches in place; returns ``self``."""
        self.storage.clear_cache_()
        return self

    # Utility functions #######################################################

    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
        """Set every value to ``fill_value`` in place (1-D values)."""
        value = torch.full((self.nnz(), ),
                           fill_value,
                           dtype=dtype,
                           device=self.device())
        return self.set_value_(value, layout='coo')

    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
        """Return a copy with every value set to ``fill_value``."""
        value = torch.full((self.nnz(), ),
                           fill_value,
                           dtype=dtype,
                           device=self.device())
        return self.set_value(value, layout='coo')

    def sizes(self) -> List[int]:
        """Sparse sizes followed by the trailing dims of the value tensor."""
        sparse_sizes = self.sparse_sizes()
        value = self.storage.value()
        if value is not None:
            return list(sparse_sizes) + list(value.size())[1:]
        else:
            return list(sparse_sizes)

    def size(self, dim: int) -> int:
        """Return ``sizes()[dim]``."""
        return self.sizes()[dim]

    def dim(self) -> int:
        """Number of dimensions (2 plus any trailing value dims)."""
        return len(self.sizes())

    def nnz(self) -> int:
        """Number of stored entries (length of the column index tensor)."""
        return self.storage.col().numel()

    def numel(self) -> int:
        """Number of stored scalars: ``value.numel()``, else ``nnz()``."""
        value = self.storage.value()
        if value is not None:
            return value.numel()
        else:
            return self.nnz()

    def density(self) -> float:
        """Fraction of stored entries among all ``rows * cols`` positions.

        NOTE(review): there is no guard for a zero-sized sparse dimension
        (division by zero).
        """
        return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))

    def sparsity(self) -> float:
        """Return ``1 - density()``."""
        return 1 - self.density()

    def avg_row_length(self) -> float:
        """Average stored entries per row (no guard for 0 rows)."""
        return self.nnz() / self.sparse_size(0)

    def avg_col_length(self) -> float:
        """Average stored entries per column (no guard for 0 columns)."""
        return self.nnz() / self.sparse_size(1)

    def bandwidth(self) -> int:
        """Largest ``|row - col|`` over the stored entries.

        NOTE(review): ``.max()`` on an empty tensor raises when nnz == 0.
        """
        row, col, _ = self.coo()
        return int((row - col).abs_().max())

    def avg_bandwidth(self) -> float:
        """Mean ``|row - col|`` over the stored entries."""
        row, col, _ = self.coo()
        return float((row - col).abs_().to(torch.float).mean())

    def bandwidth_proportion(self, bandwidth: int) -> float:
        """Fraction of stored entries with ``|row - col| <= bandwidth``
        (no guard for nnz == 0)."""
        row, col, _ = self.coo()
        tmp = (row - col).abs_()
        return int((tmp <= bandwidth).sum()) / self.nnz()

    def is_quadratic(self) -> bool:
        """Whether both sparse dimensions have the same size."""
        return self.sparse_size(0) == self.sparse_size(1)

    def is_symmetric(self) -> bool:
        """Whether the matrix equals its transpose.

        Compares the CSR arrays with the CSC arrays, which describe the
        transpose in CSR layout. Values are compared only when both are
        present.
        """
        if not self.is_quadratic():
            return False

        rowptr, col, value1 = self.csr()
        colptr, row, value2 = self.csc()

        if (rowptr != colptr).any() or (col != row).any():
            return False

        if value1 is None or value2 is None:
            return True
        else:
            return bool((value1 == value2).all())

    def to_symmetric(self, reduce: str = "sum"):
        """Return ``A`` merged with ``A^T`` as an ``N x N`` matrix.

        Entries of both are concatenated and merged by ``coalesce(reduce)``.
        With ``"sum"``, diagonal entries therefore end up doubled.
        ``N = max(rows, cols)``.
        """
        row, col, value = self.coo()

        row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
        if value is not None:
            value = torch.cat([value, value], dim=0)

        N = max(self.size(0), self.size(1))

        out = SparseTensor(row=row,
                           rowptr=None,
                           col=col,
                           value=value,
                           sparse_sizes=(N, N),
                           is_sorted=False)
        out = out.coalesce(reduce)
        return out

    def detach_(self):
        """Detach the value tensor in place; returns ``self``."""
        value = self.storage.value()
        if value is not None:
            value.detach_()
        return self

    def detach(self):
        """Return a copy whose value tensor is detached from autograd."""
        value = self.storage.value()
        if value is not None:
            value = value.detach()
        return self.set_value(value, layout='coo')

    def requires_grad(self) -> bool:
        """``value.requires_grad``, or ``False`` when there are no values."""
        value = self.storage.value()
        if value is not None:
            return value.requires_grad
        else:
            return False

    def requires_grad_(self,
                       requires_grad: bool = True,
                       dtype: Optional[int] = None):
        """Toggle gradient tracking on the values in place.

        If gradients are requested but no values exist, the values are first
        filled with ones (of ``dtype``).
        """
        if requires_grad and not self.has_value():
            self.fill_value_(1., dtype)

        value = self.storage.value()
        if value is not None:
            value.requires_grad_(requires_grad)
        return self

    def pin_memory(self):
        """Return a copy backed by ``storage.pin_memory()``."""
        return self.from_storage(self.storage.pin_memory())

    def is_pinned(self) -> bool:
        """Return ``storage.is_pinned()``."""
        return self.storage.is_pinned()

    def device(self):
        """Device of the column index tensor."""
        return self.storage.col().device

    def cpu(self):
        """Move to the CPU (via a CPU dummy tensor passed to ``device_as``)."""
        return self.device_as(torch.tensor(0), non_blocking=False)

    def cuda(self):
        """Return a copy backed by ``storage.cuda()``."""
        return self.from_storage(self.storage.cuda())

    def is_cuda(self) -> bool:
        """Whether the column index tensor lives on a CUDA device."""
        return self.storage.col().is_cuda

    def dtype(self):
        """Value dtype, or ``torch.float`` when there are no values."""
        value = self.storage.value()
        return value.dtype if value is not None else torch.float

    def is_floating_point(self) -> bool:
        """Whether the values are floating point (``True`` if no values)."""
        value = self.storage.value()
        return torch.is_floating_point(value) if value is not None else True

    # dtype casts: each one below delegates to `type_as` with a 0-dim dummy
    # tensor of the target dtype placed on this tensor's device.

    def bfloat16(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.bfloat16, device=self.device()))

    def bool(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.bool, device=self.device()))

    def byte(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.uint8, device=self.device()))

    def char(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.int8, device=self.device()))

    def half(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.half, device=self.device()))

    def float(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.float, device=self.device()))

    def double(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.double, device=self.device()))

    def short(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.short, device=self.device()))

    def int(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.int, device=self.device()))

    def long(self):
        return self.type_as(
            torch.tensor(0, dtype=torch.long, device=self.device()))

    # Conversions #############################################################

    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
        """Return a dense tensor of shape ``sizes()``.

        ``dtype`` is only used when there are no values, in which case the
        stored positions are set to one.

        NOTE(review): values are written with plain indexed assignment, so
        duplicate entries of a non-coalesced tensor are not summed.
        """
        row, col, value = self.coo()

        if value is not None:
            mat = torch.zeros(self.sizes(),
                              dtype=value.dtype,
                              device=self.device())
        else:
            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

        if value is not None:
            mat[row, col] = value
        else:
            mat[row, col] = torch.ones(self.nnz(),
                                       dtype=mat.dtype,
                                       device=mat.device)

        return mat

    def to_torch_sparse_coo_tensor(self,
                                   dtype: Optional[int] = None
                                   ) -> torch.Tensor:
        """Return a ``torch.sparse_coo_tensor`` of shape ``sizes()``.

        ``dtype`` is only used when there are no values (ones are stored).
        """
        row, col, value = self.coo()
        index = torch.stack([row, col], dim=0)

        if value is None:
            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

        return torch.sparse_coo_tensor(index, value, self.sizes())
Exemplo n.º 12
0
def masked_select(src: SparseTensor, dim: int,
                  mask: torch.Tensor) -> SparseTensor:
    """Keep only the slices of ``src`` along ``dim`` selected by ``mask``.

    Args:
        src (SparseTensor): The input sparse tensor.
        dim (int): The dimension to select along. Negative values count from
            the end (``src.dim() + dim``). ``0`` selects rows, ``1`` selects
            columns, and larger values select along the dense value dims.
        mask (Tensor): A one-dimensional mask over ``dim``. Presumably a
            boolean tensor of length ``src.size(dim)`` -- only
            ``mask.dim() == 1`` is checked here.

    :rtype: :class:`SparseTensor`

    Raises:
        ValueError: If ``dim`` refers to a dense value dimension but ``src``
            stores no values.
    """
    dim = src.dim() + dim if dim < 0 else dim

    assert mask.dim() == 1

    if dim == 0:
        row, col, value = src.coo()
        rowcount = src.storage.rowcount()

        rowcount = rowcount[mask]  # entry counts of the kept rows

        mask = mask[row]  # expand the row mask to a per-entry mask
        # Renumber kept rows 0..len-1. Repeating each new id `rowcount`
        # times relies on `coo()` returning entries grouped by row.
        row = torch.arange(rowcount.size(0),
                           device=row.device).repeat_interleave(rowcount)

        col = col[mask]

        if value is not None:
            value = value[mask]

        sparse_sizes = (rowcount.size(0), src.sparse_size(1))

        storage = SparseStorage(row=row,
                                rowptr=None,
                                col=col,
                                value=value,
                                sparse_sizes=sparse_sizes,
                                rowcount=rowcount,
                                colcount=None,
                                colptr=None,
                                csr2csc=None,
                                csc2csr=None,
                                is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # Work in CSC order so the entries are grouped by column.
        row, col, value = src.coo()
        csr2csc = src.storage.csr2csc()
        row = row[csr2csc]
        col = col[csr2csc]
        colcount = src.storage.colcount()

        colcount = colcount[mask]  # entry counts of the kept columns

        mask = mask[col]  # expand the column mask to a per-entry mask
        col = torch.arange(colcount.size(0),
                           device=col.device).repeat_interleave(colcount)
        row = row[mask]
        # Sort by (row, col) to return from CSC to CSR order; the argsort is
        # itself the new tensor's csc2csr permutation.
        csc2csr = (colcount.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[csr2csc][mask][csc2csr]

        sparse_sizes = (src.sparse_size(0), colcount.size(0))

        storage = SparseStorage(row=row,
                                rowptr=None,
                                col=col,
                                value=value,
                                sparse_sizes=sparse_sizes,
                                rowcount=None,
                                colcount=colcount,
                                colptr=None,
                                csr2csc=None,
                                csc2csr=csc2csr,
                                is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dense value dimension: value dim 1 corresponds to tensor dim 2.
        value = src.storage.value()
        if value is not None:
            idx = mask.nonzero().flatten()
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError('masked_select: cannot select along a dense '
                             'value dimension of a SparseTensor that has no '
                             'values')
Exemplo n.º 13
0
def _cat_if_all_present(parts: List[torch.Tensor],
                        expected: int) -> Optional[torch.Tensor]:
    """Concatenate ``parts`` along dim 0 when there are exactly ``expected``
    of them (one per input tensor); otherwise return ``None``."""
    if len(parts) == expected:
        return torch.cat(parts, dim=0)
    return None


def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
    """Build a block-diagonal sparse tensor out of ``tensors``.

    Block ``i`` is placed at a row/column offset equal to the summed sparse
    sizes of blocks ``0..i-1``. The value tensor and every cached index
    structure (``row``, ``rowptr``, ``rowcount``, ``colptr``, ``colcount``,
    ``csr2csc``, ``csc2csr``) are carried over only when *every* input has
    it; otherwise that slot of the result is ``None``.

    Raises:
        AssertionError: if ``tensors`` is empty.
    """
    assert len(tensors) > 0

    row_parts: List[torch.Tensor] = []
    rowptr_parts: List[torch.Tensor] = []
    col_parts: List[torch.Tensor] = []
    value_parts: List[torch.Tensor] = []
    rowcount_parts: List[torch.Tensor] = []
    colptr_parts: List[torch.Tensor] = []
    colcount_parts: List[torch.Tensor] = []
    csr2csc_parts: List[torch.Tensor] = []
    csc2csr_parts: List[torch.Tensor] = []

    row_offset: int = 0
    col_offset: int = 0
    nnz_offset: int = 0  # non-zeros emitted by the preceding blocks

    for t in tensors:
        store = t.storage

        t_row = store._row
        if t_row is not None:
            row_parts.append(t_row + row_offset)

        t_rowptr = store._rowptr
        if t_rowptr is not None:
            # The first pointer array keeps its leading zero; later ones drop
            # it and are shifted past the non-zeros already emitted.
            if len(rowptr_parts) == 0:
                rowptr_parts.append(t_rowptr)
            else:
                rowptr_parts.append(t_rowptr[1:] + nnz_offset)

        col_parts.append(store._col + col_offset)

        t_value = store._value
        if t_value is not None:
            value_parts.append(t_value)

        t_rowcount = store._rowcount
        if t_rowcount is not None:
            rowcount_parts.append(t_rowcount)

        t_colptr = store._colptr
        if t_colptr is not None:
            if len(colptr_parts) == 0:
                colptr_parts.append(t_colptr)
            else:
                colptr_parts.append(t_colptr[1:] + nnz_offset)

        t_colcount = store._colcount
        if t_colcount is not None:
            colcount_parts.append(t_colcount)

        # In a block-diagonal layout each block's CSR<->CSC permutation stays
        # inside its own non-zero range, so shifting by the offset suffices.
        t_csr2csc = store._csr2csc
        if t_csr2csc is not None:
            csr2csc_parts.append(t_csr2csc + nnz_offset)

        t_csc2csr = store._csc2csr
        if t_csc2csr is not None:
            csc2csr_parts.append(t_csc2csr + nnz_offset)

        row_offset += t.sparse_size(0)
        col_offset += t.sparse_size(1)
        nnz_offset += t.nnz()

    n = len(tensors)
    storage = SparseStorage(row=_cat_if_all_present(row_parts, n),
                            rowptr=_cat_if_all_present(rowptr_parts, n),
                            col=torch.cat(col_parts, dim=0),
                            value=_cat_if_all_present(value_parts, n),
                            sparse_sizes=(row_offset, col_offset),
                            rowcount=_cat_if_all_present(rowcount_parts, n),
                            colptr=_cat_if_all_present(colptr_parts, n),
                            colcount=_cat_if_all_present(colcount_parts, n),
                            csr2csc=_cat_if_all_present(csr2csc_parts, n),
                            csc2csr=_cat_if_all_present(csc2csr_parts, n),
                            is_sorted=True)
    return tensors[0].from_storage(storage)
Exemplo n.º 14
0
def test_sparse_reshape(dtype, device):
    # A 4x4 diagonal pattern reshaped to (2, 8) and back; a `-1` entry lets
    # that dimension be inferred from the other one.
    row, col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    wide = ((2, 8), [0, 0, 1, 1], [0, 5, 2, 7])
    square = ((4, 4), [0, 1, 2, 3], [0, 1, 2, 3])
    steps = [((2, 8), wide), ((-1, 4), square), ((2, -1), wide)]

    for new_shape, (sizes, rows, cols) in steps:
        storage = storage.sparse_reshape(*new_shape)
        assert storage.sparse_sizes() == sizes
        assert storage.row().tolist() == rows
        assert storage.col().tolist() == cols
Exemplo n.º 15
0
def test_storage(dtype, device):
    # Already-sorted indices without values; sparse sizes are inferred.
    sorted_row, sorted_col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]],
                                    torch.long, device)
    plain = SparseStorage(row=sorted_row, col=sorted_col)
    assert plain.row().tolist() == [0, 0, 1, 1]
    assert plain.col().tolist() == [0, 1, 0, 1]
    assert plain.value() is None
    assert plain.sparse_sizes() == (2, 2)

    # Unsorted indices: the storage comes back in row-major order, with the
    # values permuted alongside the indices.
    raw_row, raw_col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]],
                              torch.long, device)
    weights = tensor([2, 1, 4, 3], dtype, device)
    weighted = SparseStorage(row=raw_row, col=raw_col, value=weights)
    assert weighted.row().tolist() == [0, 0, 1, 1]
    assert weighted.col().tolist() == [0, 1, 0, 1]
    assert weighted.value().tolist() == [1, 2, 3, 4]
    assert weighted.sparse_sizes() == (2, 2)
Exemplo n.º 16
0
def test_coalesce(dtype, device):
    # Entry (0, 1) appears twice; coalescing must merge it by summing values.
    row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
    value = tensor([1, 1, 1, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    # Construction leaves the duplicates in place.
    pairs = ((storage.row(), row), (storage.col(), col),
             (storage.value(), value))
    for stored, given in pairs:
        assert stored.tolist() == given.tolist()

    assert not storage.is_coalesced()
    merged = storage.coalesce()
    assert merged.is_coalesced()

    assert merged.row().tolist() == [0, 0, 1, 1]
    assert merged.col().tolist() == [0, 1, 0, 1]
    assert merged.value().tolist() == [1, 2, 3, 4]
Exemplo n.º 17
0
def test_caching(dtype, device):
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    # Before any cache fill only the user-supplied COO indices are present.
    assert storage._row.tolist() == row.tolist()
    assert storage._col.tolist() == col.tolist()
    assert storage._value is None
    for attr in ('_rowcount', '_rowptr', '_colcount', '_colptr', '_csr2csc'):
        assert getattr(storage, attr) is None
    assert storage.num_cached_keys() == 0

    # Expected content of every cached structure for this fully-populated
    # 2x2 pattern.
    full_cache = {
        '_rowcount': [2, 2],
        '_rowptr': [0, 2, 4],
        '_colcount': [2, 2],
        '_colptr': [0, 2, 4],
        '_csr2csc': [0, 2, 1, 3],
        '_csc2csr': [0, 2, 1, 3],
    }

    def check_full_cache(s):
        for attr, expected in full_cache.items():
            assert getattr(s, attr).tolist() == expected
        assert s.num_cached_keys() == 5

    storage.fill_cache_()
    check_full_cache(storage)

    # Handing every cached tensor to the constructor keeps them all.
    storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col,
                            value=storage._value,
                            sparse_sizes=storage._sparse_sizes,
                            rowcount=storage._rowcount,
                            colptr=storage._colptr,
                            colcount=storage._colcount,
                            csr2csc=storage._csr2csc,
                            csc2csr=storage._csc2csr)
    check_full_cache(storage)

    # Clearing the cache empties it but keeps `rowptr`.
    storage.clear_cache_()
    assert storage._rowptr is not None
    for attr in ('_rowcount', '_colcount', '_colptr', '_csr2csc'):
        assert getattr(storage, attr) is None
    assert storage.num_cached_keys() == 0
Exemplo n.º 18
0
def test_utility(dtype, device):
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    assert storage.has_value()

    # The same value tensor read as 'csc' is permuted into the storage's
    # row-major order; read as 'coo' it is taken as given.
    layouts = (('csc', [1, 3, 2, 4]), ('coo', [1, 2, 3, 4]))

    # In-place setter.
    for layout, expected in layouts:
        storage.set_value_(value, layout=layout)
        assert storage.value().tolist() == expected

    # Functional setter; the returned storage is re-bound.
    for layout, expected in layouts:
        storage = storage.set_value(value, layout=layout)
        assert storage.value().tolist() == expected

    storage = storage.sparse_resize((3, 3))
    assert storage.sparse_sizes() == (3, 3)

    # `copy()` shares the index tensors, `clone()` duplicates them; both
    # results compare unequal to the source storage.
    shallow = storage.copy()
    assert shallow != storage
    assert shallow.col().data_ptr() == storage.col().data_ptr()

    deep = storage.clone()
    assert deep != storage
    assert deep.col().data_ptr() != storage.col().data_ptr()
Exemplo n.º 19
0
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
    """Concatenate sparse tensors along ``dim``.

    * ``dim == 0`` stacks rows: row indices (and ``rowptr``) are shifted by
      the rows / non-zeros of the preceding tensors; the column size of the
      result is the maximum over all inputs.
    * ``dim == 1`` stacks columns: column indices (and ``colptr``) are
      shifted likewise; the row size is the maximum over all inputs, and the
      result is handed to ``SparseStorage`` with ``is_sorted=False``.
    * ``1 < dim < tensors[0].dim()`` concatenates the dense value tensors
      along ``dim - 1`` and keeps the sparsity pattern of ``tensors[0]``.
      If any input has no value tensor, ``None`` is passed to ``set_value``.

    Values and cached index structures are carried over only when every
    input has them; otherwise the corresponding slot is ``None``.

    Args:
        tensors: Non-empty list of sparse tensors.
        dim: Dimension to concatenate along; negative values wrap around
            ``tensors[0].dim()``.

    Raises:
        AssertionError: if ``tensors`` is empty.
        IndexError: if ``dim`` is out of range.
    """
    assert len(tensors) > 0
    if dim < 0:
        dim = tensors[0].dim() + dim

    if dim == 0:
        rows: List[torch.Tensor] = []
        rowptrs: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        rowcounts: List[torch.Tensor] = []

        nnz: int = 0
        for tensor in tensors:
            row = tensor.storage._row
            if row is not None:
                rows.append(row + sparse_sizes[0])

            rowptr = tensor.storage._rowptr
            if rowptr is not None:
                # Only the first pointer array keeps its leading zero; every
                # pointer is shifted past the non-zeros already emitted.
                if len(rowptrs) > 0:
                    rowptr = rowptr[1:]
                rowptrs.append(rowptr + nnz)

            cols.append(tensor.storage._col)

            value = tensor.storage._value
            if value is not None:
                values.append(value)

            rowcount = tensor.storage._rowcount
            if rowcount is not None:
                rowcounts.append(rowcount)

            sparse_sizes[0] += tensor.sparse_size(0)
            sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
            nnz += tensor.nnz()

        row: Optional[torch.Tensor] = None
        if len(rows) == len(tensors):
            row = torch.cat(rows, dim=0)

        rowptr: Optional[torch.Tensor] = None
        if len(rowptrs) == len(tensors):
            rowptr = torch.cat(rowptrs, dim=0)

        col = torch.cat(cols, dim=0)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)

        rowcount: Optional[torch.Tensor] = None
        if len(rowcounts) == len(tensors):
            rowcount = torch.cat(rowcounts, dim=0)

        # `SparseStorage` expects a `(rows, cols)` tuple, not a list.
        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=(sparse_sizes[0],
                                              sparse_sizes[1]),
                                rowcount=rowcount, colptr=None,
                                colcount=None, csr2csc=None, csc2csr=None,
                                is_sorted=True)
        return tensors[0].from_storage(storage)

    elif dim == 1:
        rows: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        colptrs: List[torch.Tensor] = []
        colcounts: List[torch.Tensor] = []

        nnz: int = 0
        for tensor in tensors:
            row, col, value = tensor.coo()

            rows.append(row)

            cols.append(tensor.storage._col + sparse_sizes[1])

            if value is not None:
                values.append(value)

            colptr = tensor.storage._colptr
            if colptr is not None:
                # Same pointer stitching as for `rowptr` in the dim=0 case.
                if len(colptrs) > 0:
                    colptr = colptr[1:]
                colptrs.append(colptr + nnz)

            colcount = tensor.storage._colcount
            if colcount is not None:
                colcounts.append(colcount)

            sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
            sparse_sizes[1] += tensor.sparse_size(1)
            nnz += tensor.nnz()

        row = torch.cat(rows, dim=0)

        col = torch.cat(cols, dim=0)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)

        colptr: Optional[torch.Tensor] = None
        if len(colptrs) == len(tensors):
            colptr = torch.cat(colptrs, dim=0)

        colcount: Optional[torch.Tensor] = None
        if len(colcounts) == len(tensors):
            colcount = torch.cat(colcounts, dim=0)

        # Concatenating column blocks breaks row-major order, hence
        # `is_sorted=False`. `SparseStorage` expects a tuple of sizes.
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=(sparse_sizes[0],
                                              sparse_sizes[1]),
                                rowcount=None, colptr=colptr,
                                colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=False)
        return tensors[0].from_storage(storage)

    elif dim > 1 and dim < tensors[0].dim():
        values: List[torch.Tensor] = []
        for tensor in tensors:
            value = tensor.storage.value()
            if value is not None:
                values.append(value)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=dim - 1)

        return tensors[0].set_value(value, layout='coo')
    else:
        raise IndexError(
            (f'Dimension out of range: Expected to be in range of '
             f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.'))
Exemplo n.º 20
0
def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    """Select slices of ``src`` along ``dim`` at the positions in ``idx``.

    * ``dim == 0``: output row ``i`` is input row ``idx[i]`` (read via CSR).
    * ``dim == 1``: output column ``j`` is input column ``idx[j]`` (read via
      CSC); the result is then re-sorted into row-major order.
    * ``dim >= 2``: the dense value tensor is indexed along ``dim - 1``.

    Negative ``dim`` wraps around ``src.dim()``.

    Args:
        src: Sparse tensor to select from.
        dim: Dimension to select along.
        idx: 1-D index tensor into ``src.size(dim)``.

    Returns:
        A new sparse tensor (``src.from_storage`` for ``dim`` 0/1,
        ``src.set_value`` for ``dim >= 2``).

    Raises:
        AssertionError: if ``idx`` is not 1-D.
        ValueError: if ``dim >= 2`` and ``src`` has no value tensor.
    """
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount()

        # Non-zero count of every selected row, in `idx` order.
        rowcount = rowcount[idx]

        # New CSR pointer: rowptr[0] = 0, rowptr[1:] = cumsum(rowcount).
        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        # Output row index of every kept non-zero.
        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        # Map output slot `p` (in new row `i`) to source slot
        # `p - rowptr[i] + old_rowptr[idx[i]]`. `gather_csr` is presumably
        # broadcasting the per-row offset to every entry of that row —
        # confirm against the torch_scatter documentation.
        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)

        col = col[perm]

        if value is not None:
            value = value[perm]

        sparse_sizes = (idx.size(0), src.sparse_size(1))

        storage = SparseStorage(row=row,
                                rowptr=rowptr,
                                col=col,
                                value=value,
                                sparse_sizes=sparse_sizes,
                                rowcount=rowcount,
                                colptr=None,
                                colcount=None,
                                csr2csc=None,
                                csc2csr=None,
                                is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # Mirror image of the dim=0 case, operating on the CSC layout.
        old_colptr, row, value = src.csc()
        colcount = src.storage.colcount()

        colcount = colcount[idx]

        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)

        row = row[perm]
        # Entries are currently column-major; sorting by the key
        # `row * num_cols + col` (num_cols = idx.size(0)) yields row-major
        # order. The permutation is stored as the `csc2csr` cache.
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[perm][csc2csr]

        sparse_sizes = (src.sparse_size(0), idx.size(0))

        storage = SparseStorage(row=row,
                                rowptr=None,
                                col=col,
                                value=value,
                                sparse_sizes=sparse_sizes,
                                rowcount=None,
                                colptr=colptr,
                                colcount=colcount,
                                csr2csc=None,
                                csc2csr=csc2csr,
                                is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dense value dimensions: sparse dims 0 and 1 are not stored in
        # `value`, hence the `dim - 1` shift.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError
Exemplo n.º 21
0
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Return the slice ``[start, start + length)`` of ``src`` along ``dim``.

    Negative ``dim`` wraps around ``src.dim()``; negative ``start`` wraps
    around ``src.size(dim)``.

    * ``dim == 0``: rows are sliced through the CSR pointer, so the kept
      non-zeros form one contiguous range of the index/value arrays.
    * ``dim == 1``: columns are selected by masking the COO column indices.
    * ``dim >= 2``: the dense value tensor is narrowed along ``dim - 1``.

    Raises:
        ValueError: if ``dim >= 2`` and ``src`` has no value tensor.
    """
    if dim < 0:
        dim += src.dim()
    if start < 0:
        start += src.size(dim)

    if dim == 0:
        csr_ptr, col, value = src.csr()

        # `length + 1` pointers bound the `length` selected rows.
        rowptr = csr_ptr.narrow(0, start=start, length=length + 1)
        first = rowptr[0]
        rowptr = rowptr - first  # re-base the slice to start at 0
        count = rowptr[-1]

        row = src.storage._row
        if row is not None:
            row = row.narrow(0, first, count) - start
        col = col.narrow(0, first, count)
        if value is not None:
            value = value.narrow(0, first, count)

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=(length, src.sparse_size(1)),
                                rowcount=rowcount, colptr=None,
                                colcount=None, csr2csc=None, csc2csr=None,
                                is_sorted=True)
        return src.from_storage(storage)

    if dim == 1:
        # Masking the COO columns is faster here than going through `csc()`.
        row, col, value = src.coo()
        keep = (col >= start) & (col < start + length)
        row = row[keep]
        col = col[keep] - start
        if value is not None:
            value = value[keep]

        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=(src.sparse_size(0), length),
                                rowcount=None, colptr=colptr,
                                colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    # Dense value dimensions are offset by one relative to the sparse dims.
    value = src.storage.value()
    if value is not None:
        return src.set_value(value.narrow(dim - 1, start, length),
                             layout='coo')
    raise ValueError