示例#1
0
def t(src: SparseTensor) -> SparseTensor:
    """Return the transpose of ``src`` as a new sparse tensor.

    The COO entries of ``src`` are re-indexed with the permutation returned
    by ``src.storage.csr2csc()`` and the row and column roles are exchanged.
    Cached row/column metadata of the source storage is handed over with the
    roles swapped, so nothing that is already cached has to be recomputed.

    Args:
        src: The sparse tensor to transpose.

    Returns:
        A sparse tensor created through ``src.from_storage`` whose sparse
        sizes are those of ``src`` in reversed order.
    """
    perm = src.storage.csr2csc()

    old_row, old_col, value = src.coo()
    if value is not None:
        value = value[perm]

    old_sizes = src.storage.sparse_sizes()

    transposed = SparseStorage(
        # Rows of the result are the permuted columns of the source ...
        row=old_col[perm],
        rowptr=src.storage._colptr,
        # ... and its columns are the permuted rows.
        col=old_row[perm],
        value=value,
        sparse_sizes=(old_sizes[1], old_sizes[0]),
        rowcount=src.storage._colcount,
        colptr=src.storage._rowptr,
        colcount=src.storage._rowcount,
        # The two permutations trade places as well.
        csr2csc=src.storage._csc2csr,
        csc2csr=perm,
        is_sorted=True,
    )
    return src.from_storage(transposed)
示例#2
0
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    """Drop every stored entry of ``src`` that lies on the ``k``-th diagonal.

    An entry is considered to be on the diagonal when ``row == col - k``.

    Args:
        src: The sparse tensor to filter.
        k: Diagonal offset. ``0`` selects the main diagonal.

    Returns:
        A sparse tensor created through ``src.from_storage`` with the same
        sparse sizes as ``src``. Cached row/column counts are carried over
        (on a clone) with the removed entries subtracted; all other cached
        metadata is left unset.
    """
    row, col, value = src.coo()

    # ``keep`` flags the entries that survive the removal.
    if k == 0:
        keep = row != col
    else:
        keep = row != (col - k)

    if value is not None:
        value = value[keep]

    # Only counts that are already cached get updated, and only on a clone
    # so that the storage of ``src`` is not written to.
    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[row[~keep]] -= 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[col[~keep]] -= 1

    return src.from_storage(
        SparseStorage(row=row[keep],
                      rowptr=None,
                      col=col[keep],
                      value=value,
                      sparse_sizes=src.sparse_sizes(),
                      rowcount=rowcount,
                      colptr=None,
                      colcount=colcount,
                      csr2csc=None,
                      csc2csr=None,
                      is_sorted=True))
示例#3
0
def set_diag(src: SparseTensor,
             values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    """Return a sparse tensor equal to ``src`` with its ``k``-th diagonal set.

    Entries already on the diagonal are dropped with ``remove_diag`` first;
    the new diagonal entries are then merged with the remaining ones.

    Args:
        src: The sparse tensor whose diagonal should be set.
        values: Values written to the new diagonal entries. When ``None``,
            ones are written instead. Only used if ``src`` stores values.
        k: Diagonal offset; new entries are placed at ``col == row + k``.

    Returns:
        A sparse tensor created through ``from_storage`` with the same sparse
        sizes as ``src``. Cached row/column counts are carried over (on a
        clone) and incremented; all other cached metadata is left unset.
    """
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    # `mask` has one slot per entry of the result: True slots receive the
    # existing entries and the remaining slots (`inv_mask`) receive the new
    # diagonal entries (see the assignments below). How the custom op decides
    # the slot order is not visible here -- presumably it keeps the result
    # sorted; confirm against the op's implementation.
    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    # `start` is the first row index on the diagonal (`-k` for negative
    # offsets, otherwise 0); `num_diag` is the number of slots that are not
    # taken by existing entries.
    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    # `diag` is shifted in place to obtain the column indices; `new_row` was
    # already filled from it above, so this must stay after that assignment.
    new_col[inv_mask] = diag.add_(k)

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        # Same trailing dimensions as `value`, one leading slot per entry.
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            # NOTE(review): the default fill is 1-D with `num_diag` elements;
            # confirm this assigns as intended when `value` has trailing
            # dimensions.
            new_value[inv_mask] = torch.ones((num_diag, ),
                                             dtype=value.dtype,
                                             device=value.device)

    # Every row in [start, start + num_diag) gains exactly one entry.
    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    # The matching columns are the same range shifted by `k`.
    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row,
                            rowptr=None,
                            col=new_col,
                            value=new_value,
                            sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount,
                            colptr=None,
                            colcount=colcount,
                            csr2csc=None,
                            csc2csr=None,
                            is_sorted=True)
    return src.from_storage(storage)
示例#4
0
def get_diag(src: SparseTensor) -> Tensor:
    """Extract the main diagonal of ``src`` as a dense tensor.

    Args:
        src: The sparse tensor to read the diagonal from.

    Returns:
        A tensor whose first dimension has ``min(src.size(0), src.size(1))``
        elements and whose remaining dimensions match those of the stored
        values. Positions without a stored diagonal entry are zero. If
        ``src`` stores no values, every stored entry counts as ``1``.
    """
    row, col, value = src.coo()

    if value is None:
        # Allocate the implicit values on the same device as the indices;
        # otherwise the masked indexing below mixes devices when ``src``
        # does not live on the default device.
        value = torch.ones(row.size(0), device=row.device)

    # Same trailing shape as ``value``, one leading slot per diagonal entry.
    sizes = list(value.size())
    sizes[0] = min(src.size(0), src.size(1))

    out = value.new_zeros(sizes)

    # On the main diagonal the row index doubles as the output position.
    mask = row == col
    out[row[mask]] = value[mask]
    return out
示例#5
0
def masked_select_nnz(src: SparseTensor,
                      mask: torch.Tensor,
                      layout: Optional[str] = None) -> SparseTensor:
    """Keep only the stored entries of ``src`` selected by ``mask``.

    Args:
        src: The sparse tensor to filter.
        mask: One-dimensional tensor used to index the stored entries.
        layout: Layout the mask refers to, resolved through ``get_layout``.
            If it resolves to ``'csc'``, the mask is first re-indexed with
            ``src.storage.csc2csr()`` (presumably so that it lines up with
            the order returned by ``src.coo()``).

    Returns:
        A new ``SparseTensor`` holding the selected entries, with the same
        sparse sizes as ``src``.
    """
    assert mask.dim() == 1

    keep = mask
    if get_layout(layout) == 'csc':
        keep = mask[src.storage.csc2csr()]

    row, col, value = src.coo()
    if value is not None:
        value = value[keep]

    return SparseTensor(row=row[keep],
                        rowptr=None,
                        col=col[keep],
                        value=value,
                        sparse_sizes=src.sparse_sizes(),
                        is_sorted=True)
示例#6
0
def index_select_nnz(src: SparseTensor,
                     idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    """Gather the stored entries of ``src`` at the positions in ``idx``.

    Args:
        src: The sparse tensor to gather from.
        idx: One-dimensional index tensor into the stored entries.
        layout: Layout the indices refer to, resolved through ``get_layout``.
            If it resolves to ``'csc'``, the indices are first mapped through
            ``src.storage.csc2csr()``.

    Returns:
        A new ``SparseTensor`` holding the gathered entries, with the same
        sparse sizes as ``src``.

    Note:
        The result is always constructed with ``is_sorted=True``; presumably
        callers are expected to pass indices that keep the entries in sorted
        order -- verify against the call sites.
    """
    assert idx.dim() == 1

    picks = idx
    if get_layout(layout) == 'csc':
        picks = src.storage.csc2csr()[idx]

    row, col, value = src.coo()
    if value is not None:
        value = value[picks]

    return SparseTensor(row=row[picks],
                        rowptr=None,
                        col=col[picks],
                        value=value,
                        sparse_sizes=src.sparse_sizes(),
                        is_sorted=True)
示例#7
0
def saint_subgraph(
        src: SparseTensor,
        node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]:
    """Build the square sub-matrix of ``src`` for the nodes in ``node_idx``.

    The heavy lifting is delegated to the custom op
    ``torch.ops.torch_sparse.saint_subgraph``.

    Args:
        src: The sparse tensor to sample from.
        node_idx: Tensor of selected nodes; its length determines both sparse
            sizes of the result.

    Returns:
        A tuple ``(out, edge_index)``: ``out`` is the new ``SparseTensor`` of
        sparse size ``(node_idx.size(0), node_idx.size(0))`` and
        ``edge_index`` is the third output of the custom op, which is also
        used to pick the matching values out of ``src``.
    """
    row, col, value = src.coo()
    rowptr = src.storage.rowptr()

    sub_row, sub_col, edge_index = torch.ops.torch_sparse.saint_subgraph(
        node_idx, rowptr, row, col)

    if value is not None:
        value = value[edge_index]

    num_nodes = node_idx.size(0)
    subgraph = SparseTensor(row=sub_row,
                            rowptr=None,
                            col=sub_col,
                            value=value,
                            sparse_sizes=(num_nodes, num_nodes),
                            is_sorted=True)
    return subgraph, edge_index
示例#8
0
def masked_select(src: SparseTensor, dim: int,
                  mask: torch.Tensor) -> SparseTensor:
    """Keep only the slices of ``src`` along ``dim`` selected by ``mask``.

    Args:
        src: The sparse tensor to filter.
        dim: Dimension to filter. Negative values are offset by ``src.dim()``.
            ``0`` filters rows, ``1`` filters columns, and any other value
            filters dimension ``dim - 1`` of the stored values.
        mask: One-dimensional tensor used to index the per-row / per-column
            counts (presumably boolean -- TODO confirm against callers).

    Returns:
        A sparse tensor with the unselected rows/columns removed and the
        remaining ones renumbered consecutively from zero, or, for any other
        ``dim``, the result of ``src.set_value`` with the filtered values.

    Raises:
        AssertionError: If ``mask`` is not one-dimensional.
        ValueError: If ``dim`` is neither 0 nor 1 and ``src`` stores no
            values.
    """
    dim = src.dim() + dim if dim < 0 else dim

    assert mask.dim() == 1
    # NOTE(review): this binding is overwritten in the `dim == 0` / `dim == 1`
    # branches and never read in the last one.
    storage = src.storage

    if dim == 0:
        row, col, value = src.coo()
        rowcount = src.storage.rowcount()

        # Entry counts of the surviving rows only.
        rowcount = rowcount[mask]

        # Expand the per-row mask to one flag per stored entry. This must use
        # the original `row`, i.e. happen before `row` is rebuilt below.
        mask = mask[row]
        # Surviving rows are renumbered 0..n-1; each new index is repeated
        # once per entry in that row.
        row = torch.arange(rowcount.size(0),
                           device=row.device).repeat_interleave(rowcount)

        col = col[mask]

        if value is not None:
            value = value[mask]

        sparse_sizes = (rowcount.size(0), src.sparse_size(1))

        storage = SparseStorage(row=row,
                                rowptr=None,
                                col=col,
                                value=value,
                                sparse_sizes=sparse_sizes,
                                rowcount=rowcount,
                                colcount=None,
                                colptr=None,
                                csr2csc=None,
                                csc2csr=None,
                                is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        row, col, value = src.coo()
        # Re-index the entries with the `csr2csc` permutation (presumably
        # column-major order, which the `repeat_interleave` below relies on).
        csr2csc = src.storage.csr2csc()
        row = row[csr2csc]
        col = col[csr2csc]
        colcount = src.storage.colcount()

        # Entry counts of the surviving columns only.
        colcount = colcount[mask]

        # Expand the per-column mask to one flag per stored entry, using the
        # permuted `col` before it is rebuilt below.
        mask = mask[col]
        # Surviving columns are renumbered 0..n-1, one repeat per entry.
        col = torch.arange(colcount.size(0),
                           device=col.device).repeat_interleave(colcount)
        row = row[mask]
        # Sorting the linearised (row, col) key orders the entries by row
        # first and column second.
        csc2csr = (colcount.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            # Apply the same three steps to the values: permute, filter,
            # permute back.
            value = value[csr2csc][mask][csc2csr]

        sparse_sizes = (src.sparse_size(0), colcount.size(0))

        storage = SparseStorage(row=row,
                                rowptr=None,
                                col=col,
                                value=value,
                                sparse_sizes=sparse_sizes,
                                rowcount=None,
                                colcount=colcount,
                                colptr=None,
                                csr2csc=None,
                                csc2csr=csc2csr,
                                is_sorted=True)
        return src.from_storage(storage)

    else:
        value = src.storage.value()
        if value is not None:
            # Turn the mask into indices for `index_select`; the value tensor
            # is indexed at `dim - 1` (presumably because both sparse
            # dimensions share its leading dimension).
            idx = mask.nonzero().flatten()
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError
示例#9
0
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Return the slice ``[start, start + length)`` of ``src`` along ``dim``.

    Args:
        src: The sparse tensor to slice.
        dim: Dimension to slice. Negative values are offset by ``src.dim()``.
            ``0`` slices rows, ``1`` slices columns, and any other value
            slices dimension ``dim - 1`` of the stored values.
        start: First index of the slice. Negative values are offset by
            ``src.size(dim)``.
        length: Number of indices in the slice.

    Returns:
        A sparse tensor restricted to the requested range, with row/column
        indices shifted so that the slice starts at zero, or, for any other
        ``dim``, the result of ``src.set_value`` with the narrowed values.

    Raises:
        ValueError: If ``dim`` is neither 0 nor 1 and ``src`` stores no
            values.
    """
    dim = src.dim() + dim if dim < 0 else dim
    start = src.size(dim) + start if start < 0 else start

    if dim == 0:
        rowptr, col, value = src.csr()

        # Cut the row pointer down to the requested rows and rebase it so
        # that it starts at zero again.
        window = rowptr.narrow(0, start=start, length=length + 1)
        first_nnz = window[0]
        new_rowptr = window - first_nnz
        num_nnz = new_rowptr[-1]

        # The explicit row vector is only sliced if it is already cached.
        new_row = src.storage._row
        if new_row is not None:
            new_row = new_row.narrow(0, first_nnz, num_nnz) - start

        new_col = col.narrow(0, first_nnz, num_nnz)

        if value is not None:
            value = value.narrow(0, first_nnz, num_nnz)

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        return src.from_storage(
            SparseStorage(row=new_row,
                          rowptr=new_rowptr,
                          col=new_col,
                          value=value,
                          sparse_sizes=(length, src.sparse_size(1)),
                          rowcount=rowcount,
                          colptr=None,
                          colcount=None,
                          csr2csc=None,
                          csc2csr=None,
                          is_sorted=True))

    if dim == 1:
        # Original author's note: filtering the COO entries is faster here
        # than going through `csc()`, contrary to the `dim=0` case.
        row, col, value = src.coo()
        in_range = (col >= start) & (col < start + length)

        new_row = row[in_range]
        new_col = col[in_range] - start

        if value is not None:
            value = value[in_range]

        # Cached column metadata is sliced along; the pointer is rebased so
        # that it starts at zero.
        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        return src.from_storage(
            SparseStorage(row=new_row,
                          rowptr=None,
                          col=new_col,
                          value=value,
                          sparse_sizes=(src.sparse_size(0), length),
                          rowcount=None,
                          colptr=colptr,
                          colcount=colcount,
                          csr2csc=None,
                          csc2csr=None,
                          is_sorted=True))

    # Any other dimension lives in the value tensor, at `dim - 1`.
    value = src.storage.value()
    if value is None:
        raise ValueError
    return src.set_value(value.narrow(dim - 1, start, length), layout='coo')