def partition(
    src: SparseTensor,
    num_parts: int,
    recursive: bool = False,
    weighted: bool = False,
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Partition the rows of ``src`` into ``num_parts`` clusters and reorder it.

    Calls ``torch.ops.torch_sparse.partition`` on the CSR structure of
    ``src``. It then sorts the resulting per-row partition ids and permutes
    ``src`` so that rows of the same partition are contiguous.

    Args:
        src: Sparse matrix to partition. Only ``src.size(0)`` and its CSR
            representation (``src.csr()``) are used to compute the clusters.
        num_parts: Number of partitions; must be ``>= 1``.
        recursive: Forwarded unchanged to ``torch.ops.torch_sparse.partition``.
            NOTE(review): presumably selects recursive bisection vs. k-way
            partitioning. Confirm against the operator.
        weighted: If ``True`` and ``src`` stores values, those values are
            passed to the partitioner as weights. Floating-point values are
            first converted with ``weight2metis``. Otherwise no weights are
            passed.

    Returns:
        A tuple ``(out, partptr, perm)``:

        * ``perm``: row permutation that sorts rows by partition id. Rows in
          the same partition keep their original relative order.
        * ``out``: ``permute(src, perm)``.
        * ``partptr``: result of ``ind2ptr`` over the sorted partition ids.
          It is presumably a CSR-style pointer, so part ``i`` spans
          ``perm[partptr[i]:partptr[i + 1]]``. Confirm with ``ind2ptr``.

        When ``num_parts == 1``, ``src`` is returned unchanged together with
        ``partptr = [0, N]`` and the identity permutation.

    Raises:
        AssertionError: If ``num_parts < 1``, or if ``weighted`` is set and
            the number of values differs from the number of column indices.
    """
    assert num_parts >= 1
    if num_parts == 1:
        # Trivial case: a single part containing every row, identity ordering.
        partptr = torch.tensor([0, src.size(0)], device=src.device())
        perm = torch.arange(src.size(0), device=src.device())
        return src, partptr, perm

    # The structure is moved to CPU before calling the partition op
    # (presumably a CPU-only kernel).
    rowptr, col, value = src.csr()
    rowptr, col = rowptr.cpu(), col.cpu()

    if value is not None and weighted:
        assert value.numel() == col.numel()
        value = value.view(-1).detach().cpu()
        if value.is_floating_point():
            # Float weights are converted before being handed to the partitioner.
            value = weight2metis(value)
    else:
        value = None

    cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
                                               recursive)
    cluster = cluster.to(src.device())

    # A stable sort keeps rows with the same partition id in their original
    # relative order. An unstable sort could shuffle them arbitrarily,
    # making `perm` (and `out`) non-deterministic.
    cluster, perm = cluster.sort(stable=True)
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm
def reverse_cuthill_mckee(
    src: SparseTensor,
    is_symmetric: Optional[bool] = None,
) -> Tuple[SparseTensor, torch.Tensor]:
    """Reorder ``src`` with SciPy's reverse Cuthill-McKee algorithm.

    Args:
        src: Sparse matrix to reorder.
        is_symmetric: Whether ``src`` is already symmetric. When ``None``, the
            code checks with ``src.is_symmetric()``. If the matrix is not
            symmetric, it is replaced by ``src.to_symmetric()`` before
            computing the ordering.

    Returns:
        ``(out, perm)``. ``perm`` is the RCM ordering as an ``int64`` tensor
        on the matrix's device. ``out`` is the (possibly symmetrized) matrix
        permuted with ``permute(..., perm)``.
    """
    symmetric = src.is_symmetric() if is_symmetric is None else is_symmetric
    graph = src if symmetric else src.to_symmetric()

    # The matrix is symmetric at this point, so SciPy may treat it as such.
    csr_matrix = graph.to_scipy(layout='csr')
    ordering = sp.csgraph.reverse_cuthill_mckee(csr_matrix, symmetric_mode=True)

    # Copy the NumPy result before wrapping it in a tensor, then widen to int64.
    perm = torch.from_numpy(ordering.copy()).to(torch.long).to(graph.device())

    return permute(graph, perm), perm