# Example #1
def partition(
        src: SparseTensor,
        num_parts: int,
        recursive: bool = False,
        weighted: bool = False) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Partition the rows of ``src`` into ``num_parts`` parts and reorder it.

    The part assignment comes from the native
    ``torch.ops.torch_sparse.partition`` operator (presumably METIS-based,
    given the ``weight2metis`` conversion below - confirm). ``src`` is then
    permuted so that rows assigned to the same part are contiguous.

    Args:
        src: Sparse matrix to partition; ``src.size(0)`` is its row count.
        num_parts: Number of parts to produce; must be at least 1.
        recursive: Forwarded unchanged to the native partition operator
            (presumably selects recursive bisection - confirm).
        weighted: If ``True`` and ``src`` stores values, they are passed to
            the partitioner as edge weights. Otherwise no weights are used.

    Returns:
        A tuple ``(out, partptr, perm)``:
        ``perm`` is the row permutation that groups rows by part,
        ``out`` is ``permute(src, perm)``, and ``partptr`` holds the part
        boundaries produced by ``ind2ptr``. For ``num_parts == 1`` this is
        ``[0, src.size(0)]``; presumably part ``i`` spans
        ``partptr[i]:partptr[i + 1]`` in general - confirm.

    Raises:
        AssertionError: If ``num_parts < 1``, or if weighted values do not
            match the number of stored entries.
    """
    assert num_parts >= 1
    if num_parts == 1:
        # Trivial case: everything is in one part and no reordering is needed.
        partptr = torch.tensor([0, src.size(0)], device=src.device())
        perm = torch.arange(src.size(0), device=src.device())
        return src, partptr, perm

    rowptr, col, value = src.csr()
    # The native operator is fed CPU tensors (presumably it has no GPU
    # kernel - confirm).
    rowptr, col = rowptr.cpu(), col.cpu()

    if value is not None and weighted:
        # Edge weights must be a flat vector with one entry per stored edge.
        assert value.numel() == col.numel()
        value = value.view(-1).detach().cpu()
        if value.is_floating_point():
            # Floating-point weights are converted by weight2metis
            # (presumably to integer weights the partitioner accepts).
            value = weight2metis(value)
    else:
        value = None

    cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
                                               recursive)
    cluster = cluster.to(src.device())

    # A stable sort keeps rows of the same part in their original relative
    # order. The default sort does not guarantee the order of equal keys, so
    # without it `perm` (and therefore `out`) would not be reproducible.
    cluster, perm = cluster.sort(stable=True)
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
# Example #2
def reverse_cuthill_mckee(
        src: SparseTensor,
        is_symmetric: Optional[bool] = None
) -> Tuple[SparseTensor, torch.Tensor]:
    """Reorder ``src`` using the reverse Cuthill-McKee permutation.

    The ordering is computed by ``sp.csgraph.reverse_cuthill_mckee`` on a
    CSR SciPy copy of the matrix, always with ``symmetric_mode=True``.
    A non-symmetric ``src`` is first replaced by ``src.to_symmetric()``.
    In that case the returned matrix is the permuted *symmetrized* matrix,
    not a permutation of the original input.

    Args:
        src: Sparse matrix to reorder.
        is_symmetric: Whether ``src`` is already symmetric. ``None`` means
            "find out via ``src.is_symmetric()``".

    Returns:
        ``(out, perm)`` where ``perm`` is a ``torch.long`` tensor on
        ``src.device()`` and ``out`` is ``permute(src, perm)``.
    """
    symmetric = src.is_symmetric() if is_symmetric is None else is_symmetric
    if not symmetric:
        # symmetric_mode=True tells SciPy the input is symmetric, so make it so.
        src = src.to_symmetric()

    scipy_csr = src.to_scipy(layout='csr')
    ordering = sp.csgraph.reverse_cuthill_mckee(
        scipy_csr, symmetric_mode=True).copy()
    perm = torch.from_numpy(ordering).to(device=src.device(),
                                         dtype=torch.long)

    return permute(src, perm), perm