Example #1
def test_saint_subgraph():
    row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
    col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
    adj = SparseTensor(row=row, col=col)
    node_idx = torch.tensor([0, 1, 2])

    adj, edge_index = adj.saint_subgraph(node_idx)
Example #2
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    row, col, value = src.coo()
    inv_mask = row != col if k == 0 else row != (col - k)
    new_row, new_col = row[inv_mask], col[inv_mask]

    if value is not None:
        value = value[inv_mask]

    rowcount = src.storage._rowcount
    colcount = src.storage._colcount
    if rowcount is not None or colcount is not None:
        mask = ~inv_mask
        if rowcount is not None:
            rowcount = rowcount.clone()
            rowcount[row[mask]] -= 1
        if colcount is not None:
            colcount = colcount.clone()
            colcount[col[mask]] -= 1

    storage = SparseStorage(row=new_row,
                            rowptr=None,
                            col=new_col,
                            value=value,
                            sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount,
                            colptr=None,
                            colcount=colcount,
                            csr2csc=None,
                            csc2csr=None,
                            is_sorted=True)
    return src.from_storage(storage)
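A minimal usage sketch for the k offset above (illustrative values, assuming torch_sparse is installed): k=0 removes entries with row == col, while k=1 removes the first super-diagonal, i.e. entries with row == col - 1.

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1])
col = torch.tensor([0, 1, 1])
mat = SparseTensor(row=row, col=col, sparse_sizes=(2, 2))

main = remove_diag(mat)        # drops (0, 0) and (1, 1), keeps (0, 1)
upper = remove_diag(mat, k=1)  # drops (0, 1), keeps (0, 0) and (1, 1)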
Example #3
def partition(
        src: SparseTensor,
        num_parts: int,
        recursive: bool = False,
        weighted=False) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    rowptr, col, value = src.csr()
    rowptr, col = rowptr.cpu(), col.cpu()

    if value is not None and weighted:
        assert value.numel() == col.numel()
        value = value.view(-1).detach().cpu()
        if value.is_floating_point():
            value = weight2metis(value)
    else:
        value = None

    cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
                                               recursive)
    cluster = cluster.to(src.device())

    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
Example #4
def t(src: SparseTensor) -> SparseTensor:
    csr2csc = src.storage.csr2csc()

    row, col, value = src.coo()

    if value is not None:
        value = value[csr2csc]

    sparse_sizes = src.storage.sparse_sizes()

    storage = SparseStorage(
        row=col[csr2csc],
        rowptr=src.storage._colptr,
        col=row[csr2csc],
        value=value,
        sparse_sizes=(sparse_sizes[1], sparse_sizes[0]),
        rowcount=src.storage._colcount,
        colptr=src.storage._rowptr,
        colcount=src.storage._rowcount,
        csr2csc=src.storage._csc2csr,
        csc2csr=csr2csc,
        is_sorted=True,
    )

    return src.from_storage(storage)
Example #5
def test_fill_diag(dtype, device):
    row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    mat = SparseTensor(row=row, col=col, value=value)

    mat = mat.fill_diag(-8, k=-1)
    mat = mat.fill_diag(-8, k=1)
Example #6
def adamic_adar(indexA, valueA, indexB, valueB, m, k, n, coalesced=False,
                sampling=True):
    A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
                     sparse_sizes=(m, k), is_sorted=not coalesced)
    B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
                     sparse_sizes=(k, n), is_sorted=not coalesced)

    deg_A = A.storage.colcount()
    deg_B = B.storage.rowcount()
    deg_normalized = 1.0 / (deg_A + deg_B).to(torch.float)
    deg_normalized[deg_normalized == float('inf')] = 0.0

    D = SparseTensor(row=torch.arange(deg_normalized.size(0), device=valueA.device),
                     col=torch.arange(deg_normalized.size(0), device=valueA.device),
                     value=deg_normalized.type_as(valueA),
                     sparse_sizes=(deg_normalized.size(0), deg_normalized.size(0)))

    out = A @ D @ B
    row, col, values = out.coo()

    num_samples = min(int(valueA.numel()), int(valueB.numel()), values.numel())
    if sampling and values.numel() > num_samples:
        idx = torch.multinomial(values, num_samples=num_samples,
                                replacement=False)
        row, col, values = row[idx], col[idx], values[idx]

    return torch.stack([row, col], dim=0), values
Example #7
def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
    """Matrix product of two sparse tensors. Both input sparse matrices need to
    be coalesced (use the :obj:`coalesced` attribute to force).

    Args:
        indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
        valueA (:class:`Tensor`): The value tensor of first sparse matrix.
        indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
        valueB (:class:`Tensor`): The value tensor of second sparse matrix.
        m (int): The first dimension of first corresponding dense matrix.
        k (int): The second dimension of first corresponding dense matrix and
            first dimension of second corresponding dense matrix.
        n (int): The second dimension of second corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`True`, will coalesce both
            input sparse matrices. (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """

    A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
                     sparse_sizes=(m, k), is_sorted=not coalesced)
    B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
                     sparse_sizes=(k, n), is_sorted=not coalesced)

    C = matmul(A, B)
    row, col, value = C.coo()

    return torch.stack([row, col], dim=0), value
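A minimal usage sketch for spspmm, assuming torch_sparse is installed: squaring a 3x3 cyclic permutation matrix.

import torch

indexA = torch.tensor([[0, 1, 2], [1, 2, 0]])
valueA = torch.ones(3)
index, value = spspmm(indexA, valueA, indexA, valueA, 3, 3, 3)
# Applying the cyclic shift twice yields index == [[0, 1, 2], [2, 0, 1]].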
Example #8
def test_permute(device):
    row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long, device)
    value = tensor([1, 2, 3, 4, 5], torch.float, device)
    adj = SparseTensor(row=row, col=col, value=value)

    row, col, value = adj.permute(torch.tensor([1, 0, 2])).coo()
    assert row.tolist() == [0, 1, 1, 2, 2]
    assert col.tolist() == [1, 0, 1, 0, 2]
    assert value.tolist() == [3, 2, 1, 4, 5]
Example #9
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    # This function builds the inverse operation of `cat_diag` and should hence
    # only be used on *diagonally stacked* sparse matrices.
    # That's the reason why this method is marked as *private*.

    rowptr, col, value = src.csr()

    rowptr = rowptr.narrow(0, start=start[0], length=length[0] + 1)
    row_start = int(rowptr[0])
    rowptr = rowptr - row_start
    row_length = int(rowptr[-1])

    row = src.storage._row
    if row is not None:
        row = row.narrow(0, row_start, row_length) - start[0]

    col = col.narrow(0, row_start, row_length) - start[1]

    if value is not None:
        value = value.narrow(0, row_start, row_length)

    sparse_sizes = length

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, start[0], length[0])

    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, start[1], length[1] + 1)
        colptr = colptr - int(colptr[0])  # i.e. `row_start`

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, start[1], length[1])

    csr2csc = src.storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, row_start, row_length) - row_start

    csc2csr = src.storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, row_start, row_length) - row_start

    storage = SparseStorage(row=row,
                            rowptr=rowptr,
                            col=col,
                            value=value,
                            sparse_sizes=sparse_sizes,
                            rowcount=rowcount,
                            colptr=colptr,
                            colcount=colcount,
                            csr2csc=csr2csc,
                            csc2csr=csc2csr,
                            is_sorted=True)
    return src.from_storage(storage)
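A minimal sketch of the narrowing above on a hand-built diagonally stacked matrix (the 2x2 block and the 3x3 identity block are illustrative assumptions):

import torch
from torch_sparse import SparseTensor

dense = torch.zeros(5, 5)
dense[:2, :2] = torch.tensor([[1., 2.], [3., 4.]])
dense[2:, 2:] = torch.eye(3)
stacked = SparseTensor.from_dense(dense)

# Recover the second block: it starts at (row, col) = (2, 2) and spans (3, 3).
block = __narrow_diag__(stacked, start=(2, 2), length=(3, 3))
assert block.to_dense().tolist() == torch.eye(3).tolist()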
Example #10
def fill_diag(src: SparseTensor, fill_value: float,
              k: int = 0) -> SparseTensor:
    num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
    if k < 0:
        num_diag = min(src.sparse_size(0) + k, src.sparse_size(1))

    value = src.storage.value()
    if value is not None:
        sizes = [num_diag] + src.sizes()[2:]
        return set_diag(src, value.new_full(sizes, fill_value), k)
    else:
        return set_diag(src, None, k)
Example #11
def set_diag(src: SparseTensor,
             values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    new_col[inv_mask] = diag.add_(k)

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            new_value[inv_mask] = torch.ones((num_diag, ),
                                             dtype=value.dtype,
                                             device=value.device)

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row,
                            rowptr=None,
                            col=new_col,
                            value=new_value,
                            sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount,
                            colptr=None,
                            colcount=colcount,
                            csr2csc=None,
                            csc2csr=None,
                            is_sorted=True)
    return src.from_storage(storage)
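A minimal usage sketch for set_diag (illustrative values): the old diagonal is removed first, so an existing diagonal entry is overwritten rather than duplicated.

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 1])
col = torch.tensor([1, 1])
value = torch.tensor([2., 5.])  # (1, 1) currently holds 5.
mat = SparseTensor(row=row, col=col, value=value, sparse_sizes=(2, 2))

out = set_diag(mat, torch.tensor([7., 8.]))
# Dense result: [[7., 2.], [0., 8.]] -- the old (1, 1) entry was replaced.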
Example #12
def get_diag(src: SparseTensor) -> Tensor:
    row, col, value = src.coo()

    if value is None:
        # Create the implicit all-ones values on the same device as the
        # indices, so indexing below also works for CUDA tensors.
        value = torch.ones(row.size(0), device=row.device)

    sizes = list(value.size())
    sizes[0] = min(src.size(0), src.size(1))

    out = value.new_zeros(sizes)

    mask = row == col
    out[row[mask]] = value[mask]
    return out
Example #13
    def forward(self, x, g, **kwargs):
        """

        """
        # Pre-compute (I+D)^{-1} (A+I) (multiply each row of A+I with (1+di)^-1)
        adj_mat_filled_diag = SparseTensor.from_torch_sparse_coo_tensor(
            g.adjacency_matrix(False)).fill_diag(1.)
        adj_mat_filled_diag = adj_mat_filled_diag / adj_mat_filled_diag.sum(
            -1).unsqueeze(-1)  # Divide each row by (1+di)

        if x.is_cuda:
            adj_mat_filled_diag = adj_mat_filled_diag.cuda()

        # This will be of shape [num_output_dim, nx, nx] -> Prohibitive for big nx
        covar_xx = self.covar_module(x).evaluate()

        # covar_xx = (I+D)^{-1} (A+I) K_xx (A+I)^top (I+D)^{-1}
        # First compute  (I+D)^{-1} (A+I) @ K_xx
        xx_t1 = self.sparse_adj_matmul(adj_mat_filled_diag, covar_xx)
        # Then compute  (I+D)^{-1} (A+I) @ ((A+I) @ K_xx).T = (A+I) @ K_xx @ (A+I).T
        covar_full = self.sparse_adj_matmul(adj_mat_filled_diag,
                                            xx_t1.transpose(-2, -1))

        mean_full = self.mean_module(x)
        mean_full = self.sparse_adj_matmul(adj_mat_filled_diag, mean_full)

        return gpytorch.distributions.MultivariateNormal(mean_full, covar_full)
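The row normalization above can be sanity-checked in isolation; a minimal sketch on an assumed two-node undirected graph:

import torch
from torch_sparse import SparseTensor

adj = SparseTensor.from_dense(torch.tensor([[0., 1.], [1., 0.]]))
norm = adj.fill_diag(1.)                    # A + I
norm = norm / norm.sum(-1).unsqueeze(-1)    # row i scaled by 1 / (1 + d_i)
assert torch.allclose(norm.to_dense(), torch.full((2, 2), 0.5))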
Example #14
def test_metis(device):
    value1 = torch.randn(6 * 6, device=device).view(6, 6)
    value2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
    value3 = torch.ones(6 * 6, device=device).view(6, 6)

    for value in [value1, value2, value3]:
        mat = SparseTensor.from_dense(value)

        _, partptr, perm = mat.partition(num_parts=2,
                                         recursive=False,
                                         weighted=True)
        assert partptr.numel() == 3
        assert perm.numel() == 6

        _, partptr, perm = mat.partition(num_parts=2,
                                         recursive=False,
                                         weighted=False)
        assert partptr.numel() == 3
        assert perm.numel() == 6

        _, partptr, perm = mat.partition(num_parts=1,
                                         recursive=False,
                                         weighted=True)
        assert partptr.numel() == 2
        assert perm.numel() == 6
Example #15
def test_spmm(dtype, device, reduce):
    src = torch.randn((10, 8), dtype=dtype, device=device)
    src[2:4, :] = 0  # Remove multiple rows.
    src[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src).requires_grad_()
    row, col, value = src.coo()

    other = torch.randn((2, 8, 2),
                        dtype=dtype,
                        device=device,
                        requires_grad=True)

    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
    expected = torch_scatter.scatter(src_col, row, dim=-2, reduce=reduce)
    if reduce == 'min':
        expected[expected > 1000] = 0
    if reduce == 'max':
        expected[expected < -1000] = 0

    grad_out = torch.randn_like(expected)

    expected.backward(grad_out)
    expected_grad_value = value.grad
    value.grad = None
    expected_grad_other = other.grad
    other.grad = None

    out = matmul(src, other, reduce)
    out.backward(grad_out)

    assert torch.allclose(expected, out, atol=1e-6)
    assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
    assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
Example #16
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
        other = gather_csr(other.squeeze(1), rowptr)
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        value = other.add_(1)
    return src.set_value_(value, layout='coo')
Example #17
def spmm_max(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    rowptr, col, value = src.csr()

    if value is not None:
        value = value.to(other.dtype)

    return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
Example #18
def add_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    value = src.storage.value()
    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        value = other.add(1)
    return src.set_value_(value, layout=layout)
Example #19
def sparse_select(adj, dim, index):
    """index select on sparse tensor (temporary function)
       torch.index_select on sparse tesnor is too slow to be useful
       https://github.com/pytorch/pytorch/issues/54561
    """
    adj = SparseTensor.from_torch_sparse_coo_tensor(adj)
    adj = index_select(adj, dim, index)
    return adj.to_torch_sparse_coo_tensor()
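A minimal usage sketch, assuming a small coalesced COO input:

import torch

i = torch.tensor([[0, 1, 2], [2, 1, 0]])
v = torch.tensor([1., 2., 3.])
adj = torch.sparse_coo_tensor(i, v, (3, 3)).coalesce()

sub = sparse_select(adj, dim=0, index=torch.tensor([0, 2]))
# sub is a 2 x 3 sparse tensor holding rows 0 and 2 of adj.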
Example #20
def test_eye(dtype, device):
    mat = SparseTensor.eye(3, dtype=dtype, device=device)
    assert mat.storage.col().device == device
    assert mat.storage.sparse_sizes() == (3, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.value().tolist() == [1, 1, 1]
    assert mat.storage.value().dtype == dtype
    assert mat.storage.num_cached_keys() == 0

    mat = SparseTensor.eye(3, has_value=False, device=device)
    assert mat.storage.col().device == device
    assert mat.storage.sparse_sizes() == (3, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.value() is None
    assert mat.storage.num_cached_keys() == 0

    mat = SparseTensor.eye(3, 4, fill_cache=True, device=device)
    assert mat.storage.col().device == device
    assert mat.storage.sparse_sizes() == (3, 4)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.num_cached_keys() == 5
    assert mat.storage.rowcount().tolist() == [1, 1, 1]
    assert mat.storage.colptr().tolist() == [0, 1, 2, 3, 3]
    assert mat.storage.colcount().tolist() == [1, 1, 1, 0]
    assert mat.storage.csr2csc().tolist() == [0, 1, 2]
    assert mat.storage.csc2csr().tolist() == [0, 1, 2]

    mat = SparseTensor.eye(4, 3, fill_cache=True, device=device)
    assert mat.storage.col().device == device
    assert mat.storage.sparse_sizes() == (4, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.num_cached_keys() == 5
    assert mat.storage.rowcount().tolist() == [1, 1, 1, 0]
    assert mat.storage.colptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.colcount().tolist() == [1, 1, 1]
    assert mat.storage.csr2csc().tolist() == [0, 1, 2]
    assert mat.storage.csc2csr().tolist() == [0, 1, 2]
Example #21
def reverse_cuthill_mckee(
        src: SparseTensor,
        is_symmetric: Optional[bool] = None
) -> Tuple[SparseTensor, torch.Tensor]:

    if is_symmetric is None:
        is_symmetric = src.is_symmetric()

    if not is_symmetric:
        src = src.to_symmetric()

    sp_src = src.to_scipy(layout='csr')
    perm = sp.csgraph.reverse_cuthill_mckee(sp_src, symmetric_mode=True).copy()
    perm = torch.from_numpy(perm).to(torch.long).to(src.device())

    out = permute(src, perm)

    return out, perm
Example #22
def mul_nnz(src: SparseTensor,
            other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    value = src.storage.value()
    if value is not None:
        value = value.mul(other.to(value.dtype))
    else:
        value = other
    return src.set_value(value, layout=layout)
Example #23
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    assert src.sparse_size(1) == other.sparse_size(0)
    rowptrA, colA, valueA = src.csr()
    rowptrB, colB, valueB = other.csr()
    # Output has shape M x N, matching the M x K times K x N convention.
    M, N = src.sparse_size(0), other.sparse_size(1)
    rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
        rowptrA, colA, valueA, rowptrB, colB, valueB, N)
    return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
                        sparse_sizes=(M, N), is_sorted=True)
Example #24
def masked_select_nnz(src: SparseTensor,
                      mask: torch.Tensor,
                      layout: Optional[str] = None) -> SparseTensor:
    assert mask.dim() == 1

    if get_layout(layout) == 'csc':
        mask = mask[src.storage.csc2csr()]

    row, col, value = src.coo()
    row, col = row[mask], col[mask]

    if value is not None:
        value = value[mask]

    return SparseTensor(row=row,
                        rowptr=None,
                        col=col,
                        value=value,
                        sparse_sizes=src.sparse_sizes(),
                        is_sorted=True)
Example #25
def saint_subgraph(
        src: SparseTensor,
        node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]:
    row, col, value = src.coo()
    rowptr = src.storage.rowptr()

    data = torch.ops.torch_sparse.saint_subgraph(node_idx, rowptr, row, col)
    row, col, edge_index = data

    if value is not None:
        value = value[edge_index]

    out = SparseTensor(row=row,
                       rowptr=None,
                       col=col,
                       value=value,
                       sparse_sizes=(node_idx.size(0), node_idx.size(0)),
                       is_sorted=True)

    return out, edge_index
Example #26
def index_select_nnz(src: SparseTensor,
                     idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        idx = src.storage.csc2csr()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if value is not None:
        value = value[idx]

    return SparseTensor(row=row,
                        rowptr=None,
                        col=col,
                        value=value,
                        sparse_sizes=src.sparse_sizes(),
                        is_sorted=True)
Example #27
def add(src, other):  # noqa: F811
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
            other = gather_csr(other.squeeze(1), rowptr)
        elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')
        if value is not None:
            value = other.to(value.dtype).add_(value)
        else:
            value = other.add_(1)
        return src.set_value(value, layout='coo')

    elif isinstance(other, SparseTensor):
        rowA, colA, valueA = src.coo()
        rowB, colB, valueB = other.coo()

        row = torch.cat([rowA, rowB], dim=0)
        col = torch.cat([colA, colB], dim=0)

        value: Optional[Tensor] = None
        if valueA is not None and valueB is not None:
            value = torch.cat([valueA, valueB], dim=0)

        M = max(src.size(0), other.size(0))
        N = max(src.size(1), other.size(1))
        sparse_sizes = (M, N)

        out = SparseTensor(row=row,
                           col=col,
                           value=value,
                           sparse_sizes=sparse_sizes)
        out = out.coalesce(reduce='sum')
        return out

    else:
        raise NotImplementedError
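A minimal usage sketch for the dense branch above (illustrative values): a per-row offset must have shape (num_rows, 1).

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.tensor([[1., 0.], [0., 2.]]))
row_offsets = torch.tensor([[10.], [20.]])
out = add(mat, row_offsets)  # stored values become [11., 22.]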
Example #28
def test_spmm_half_precision():
    src_dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
    src_dense[2:4, :] = 0  # Remove multiple rows.
    src_dense[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src_dense)

    other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')

    expected = (src_dense.to(torch.float) @ other).to(torch.half)
    out = src @ other.to(torch.half)

    assert torch.allclose(expected, out, atol=1e-2)
Example #29
def sample_adj(src: SparseTensor,
               subset: torch.Tensor,
               num_neighbors: int,
               replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:

    rowptr, col, value = src.csr()

    rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
        rowptr, col, subset, num_neighbors, replace)

    if value is not None:
        value = value[e_id]

    out = SparseTensor(rowptr=rowptr,
                       row=None,
                       col=col,
                       value=value,
                       sparse_sizes=(subset.size(0), n_id.size(0)),
                       is_sorted=True)

    return out, n_id
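A minimal usage sketch for neighbor sampling on an assumed small symmetric graph:

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1, 1, 2, 2])
col = torch.tensor([1, 2, 0, 2, 0, 1])
adj = SparseTensor(row=row, col=col, sparse_sizes=(3, 3))

# Sample up to two neighbors for seed nodes 0 and 1; n_id lists the seeds
# first, followed by any newly sampled nodes.
sub, n_id = sample_adj(adj, torch.tensor([0, 1]), num_neighbors=2)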
Example #30
def test_get_diag(dtype, device):
    row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    value = tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype, device)
    mat = SparseTensor(row=row, col=col, value=value)
    assert mat.get_diag().tolist() == [[1, 1], [0, 0], [4, 4]]

    row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    mat = SparseTensor(row=row, col=col)
    assert mat.get_diag().tolist() == [1, 0, 1]