Example No. 1
def test(model, data, evaluator):
    print('Evaluating full-batch GNN on CPU...')

    weights = [(conv.lin_rel.weight.t().cpu().detach().numpy(),
                conv.lin_rel.bias.cpu().detach().numpy(),
                conv.lin_root.weight.t().cpu().detach().numpy())
               for conv in model.convs]
    model = SAGEInference(weights)

    x = data.x.numpy()
    adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1])
    deg_inv = adj.sum(dim=1).to(torch.float).pow(-1)
    deg_inv[deg_inv == float('inf')] = 0  # guard against isolated nodes
    adj = deg_inv.view(-1, 1) * adj
    adj = adj.to_scipy(layout='csr')

    out = model(x, adj)

    y_true = data.y
    y_pred = torch.from_numpy(out).argmax(dim=-1, keepdim=True)

    train_acc = evaluator.eval({
        'y_true': y_true[data.train_mask],
        'y_pred': y_pred[data.train_mask]
    })['acc']
    valid_acc = evaluator.eval({
        'y_true': y_true[data.valid_mask],
        'y_pred': y_pred[data.valid_mask]
    })['acc']
    test_acc = evaluator.eval({
        'y_true': y_true[data.test_mask],
        'y_pred': y_pred[data.test_mask]
    })['acc']

    return train_acc, valid_acc, test_acc
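
SAGEInference is not defined in this snippet. The sketch below is a plausible NumPy-only forward pass consistent with the weight tuples exported above (lin_rel.weight.t(), lin_rel.bias, lin_root.weight.t()) and with the row-normalized adjacency, which turns adj @ x into a neighbor-mean aggregation; treat it as an assumption, not the original class.

import numpy as np

class SAGEInference:
    # Hypothetical stand-in for the SAGEInference class used above.
    def __init__(self, weights):
        self.weights = weights

    def __call__(self, x, adj):
        # `adj` is the row-normalized scipy CSR matrix, so `adj @ x`
        # averages each node's neighbor features.
        for i, (w_rel, b_rel, w_root) in enumerate(self.weights):
            x = (adj @ x) @ w_rel + b_rel + x @ w_root
            if i < len(self.weights) - 1:
                x = np.clip(x, 0, None)  # ReLU between layers
        return x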
Example No. 2
def get_adj(row, col, N, asymm_norm=False, set_diag=True, remove_diag=False):

    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    if set_diag:
        print('... setting diagonal entries')
        adj = adj.set_diag()
    elif remove_diag:
        print('... removing diagonal entries')
        adj = adj.remove_diag()
    else:
        print('... keeping diagonal entries as they are')
    if not asymm_norm:
        print('... performing symmetric normalization')
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
    else:
        print('... performing asymmetric normalization')
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv = deg.pow(-1.0)
        deg_inv[deg_inv == float('inf')] = 0
        adj = deg_inv.view(-1, 1) * adj

    adj = adj.to_scipy(layout='csr')

    return adj
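
A minimal usage sketch on a toy three-node path graph; the returned scipy CSR matrix can then be multiplied against a feature matrix for one propagation step:

import torch
import numpy as np

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected path 0-1-2
adj = get_adj(edge_index[0], edge_index[1], N=3)  # defaults: set diag, symmetric norm
x = np.random.randn(3, 8).astype(np.float32)
x = adj @ x  # one step of D^-1/2 (A + I) D^-1/2 propagation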
Example No. 3
def spectral(data, name, embedding_dim=128):
    try:
        result = load_embedding(name, 'spectral')
        return result
    except FileNotFoundError:
        print(
            f'cache/feature/spectral_{name}.pt not found! Regenerating it now')

    from julia.api import Julia
    jl = Julia(compiled_modules=False)
    from julia import Main
    Main.include(f'{CWD}/norm_spec.jl')
    print('Setting up spectral embedding')

    if data.setting == 'inductive':
        N = data.num_train_nodes
        edge_index = to_undirected(data.train_edge_index,
                                   num_nodes=data.num_train_nodes)
    else:
        N = data.num_nodes
        edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)

    np_edge_index = np.array(edge_index.T)
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    adj = adj.to_scipy(layout='csr')

    result = torch.tensor(Main.main(adj, embedding_dim)).float()
    save_embedding(result, name, 'spectral')
    return result
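
load_embedding and save_embedding are assumed cache helpers around cache/feature/{method}_{name}.pt (the path printed above, also used by the svd example below); a minimal sketch:

import os
import torch

def load_embedding(name, method):
    # Raises FileNotFoundError when the cache file is absent, which the
    # callers above catch in order to regenerate the embedding.
    return torch.load(f'cache/feature/{method}_{name}.pt')

def save_embedding(result, name, method):
    os.makedirs('cache/feature', exist_ok=True)
    torch.save(result, f'cache/feature/{method}_{name}.pt')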
Example No. 4
    def init_adj(self, edge_index):
        """ cache normalized adjacency and normalized strict two-hop adjacency,
        neither has self loops
        """
        n = self.num_nodes
        
        if isinstance(edge_index, SparseTensor):
            adj_t = edge_index
            dev = adj_t.device
            adj_t = scipy.sparse.csr_matrix(adj_t.to_scipy())
            adj_t[adj_t > 0] = 1
            adj_t[adj_t < 0] = 0
            adj_t = SparseTensor.from_scipy(adj_t).to(dev)
        elif isinstance(edge_index, torch.Tensor):
            row, col = edge_index
            adj_t = SparseTensor(row=col, col=row, value=None, sparse_sizes=(n, n))

        adj_t = adj_t.remove_diag()
        adj_t2 = matmul(adj_t, adj_t)
        adj_t2 = adj_t2.remove_diag()
        adj_t = scipy.sparse.csr_matrix(adj_t.to_scipy())
        adj_t2 = scipy.sparse.csr_matrix(adj_t2.to_scipy())
        adj_t2 = adj_t2 - adj_t
        adj_t2[adj_t2 > 0] = 1
        adj_t2[adj_t2 < 0] = 0

        adj_t = SparseTensor.from_scipy(adj_t)
        adj_t2 = SparseTensor.from_scipy(adj_t2)
        
        adj_t = gcn_norm(adj_t, None, n, add_self_loops=False)
        adj_t2 = gcn_norm(adj_t2, None, n, add_self_loops=False)

        self.adj_t = adj_t.to(edge_index.device)
        self.adj_t2 = adj_t2.to(edge_index.device)
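
A small self-contained check of the strict two-hop construction in init_adj, using only scipy: squaring the adjacency counts two-hop walks, zeroing the diagonal removes self-loops, and subtracting A drops pairs that are already direct neighbors.

import numpy as np
import scipy.sparse

A = scipy.sparse.csr_matrix(np.array([[0, 1, 0],
                                      [1, 0, 1],
                                      [0, 1, 0]], dtype=float))  # path 0-1-2
A2 = A @ A           # two-hop walk counts
A2.setdiag(0)        # drop self loops, as remove_diag() does above
A2 = A2 - A          # strict two-hop: remove one-hop neighbors
A2.data = (A2.data > 0).astype(float)  # binarize, mirroring the masking above
print(A2.toarray())  # only the 0-2 pair survives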
Example No. 5
def preprocess(data,
               preprocess="diffusion",
               num_propagations=10,
               p=None,
               alpha=None,
               use_cache=True,
               post_fix=""):
    if use_cache:
        try:
            x = torch.load(f'embeddings/{preprocess}{post_fix}.pt')
            print('Using cache')
            return x
        except Exception:
            print(
                f'embeddings/{preprocess}{post_fix}.pt not found or not enough iterations! Regenerating it now'
            )
            # Creates a new file
            with open(f'embeddings/{preprocess}{post_fix}.pt', 'w') as fp:
                pass

    if preprocess == "community":
        return community(data, post_fix)

    if preprocess == "spectral":
        return spectral(data, post_fix)

    print('Computing adj...')
    N = data.num_nodes
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)

    row, col = data.edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    adj = adj.set_diag()
    deg = adj.sum(dim=1).to(torch.float)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    adj = adj.to_scipy(layout='csr')

    print(f'Start {preprocess} processing')

    if preprocess == "sgc":
        result = sgc(data.x.numpy(), adj, num_propagations)
#     if preprocess == "lp":
#         result = lp(adj, data.y.data, num_propagations, p = p, alpha = alpha, preprocess = preprocess)
    if preprocess == "diffusion":
        result = diffusion(data.x.numpy(),
                           adj,
                           num_propagations,
                           p=p,
                           alpha=alpha)

    torch.save(result, f'embeddings/{preprocess}{post_fix}.pt')

    return result
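
The sgc and diffusion helpers are not shown here. A plausible minimal sgc, matching the repeated adj @ x propagation used in the papers100M example below, would be:

import torch

def sgc(x, adj, num_propagations):
    # Hypothetical helper: SGC-style smoothing with the symmetrically
    # normalized scipy CSR adjacency built above.
    for _ in range(num_propagations):
        x = adj @ x
    return torch.from_numpy(x).to(torch.float)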
Example No. 6
def test(model, predictor, data, split_edge, evaluator, batch_size, device):
    predictor.eval()
    print('Evaluating full-batch GNN on CPU...')

    weights = [(conv.weight.cpu().detach().numpy(),
                conv.bias.cpu().detach().numpy()) for conv in model.convs]
    model = GCNInference(weights)

    x = data.x.numpy()
    adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1])
    adj = adj.set_diag()
    deg = adj.sum(dim=1).to(torch.float)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
    adj = adj.to_scipy(layout='csr')

    h = torch.from_numpy(model(x, adj)).to(device)

    def test_split(split):
        source = split_edge[split]['source_node'].to(device)
        target = split_edge[split]['target_node'].to(device)
        target_neg = split_edge[split]['target_node_neg'].to(device)

        pos_preds = []
        for perm in DataLoader(range(source.size(0)), batch_size):
            src, dst = source[perm], target[perm]
            pos_preds += [predictor(h[src], h[dst]).squeeze().cpu()]
        pos_pred = torch.cat(pos_preds, dim=0)

        neg_preds = []
        source = source.view(-1, 1).repeat(1, 1000).view(-1)
        target_neg = target_neg.view(-1)
        for perm in DataLoader(range(source.size(0)), batch_size):
            src, dst_neg = source[perm], target_neg[perm]
            neg_preds += [predictor(h[src], h[dst_neg]).squeeze().cpu()]
        neg_pred = torch.cat(neg_preds, dim=0).view(-1, 1000)

        return evaluator.eval({
            'y_pred_pos': pos_pred,
            'y_pred_neg': neg_pred,
        })['mrr_list'].mean().item()

    train_mrr = test_split('eval_train')
    valid_mrr = test_split('valid')
    test_mrr = test_split('test')

    return train_mrr, valid_mrr, test_mrr
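
GCNInference is likewise assumed to be a NumPy-only forward pass over the exported (weight, bias) pairs; a minimal sketch, applying ReLU between layers but not after the last one:

import numpy as np

class GCNInference:
    # Hypothetical stand-in matching the (weight, bias) tuples above.
    def __init__(self, weights):
        self.weights = weights

    def __call__(self, x, adj):
        # `adj` is the symmetrically normalized scipy CSR matrix.
        for i, (weight, bias) in enumerate(self.weights):
            x = adj @ x @ weight + bias
            if i < len(self.weights) - 1:
                x = np.clip(x, 0, None)
        return x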
Example No. 7
def spectral(data, post_fix):
    from julia.api import Julia
    jl = Julia(compiled_modules=False)
    from julia import Main
    Main.include("./norm_spec.jl")
    print('Setting up spectral embedding')
    data.edge_index = to_undirected(data.edge_index)
    np_edge_index = np.array(data.edge_index.T)

    N = data.num_nodes
    row, col = data.edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    adj = adj.to_scipy(layout='csr')
    result = torch.tensor(Main.main(adj, 128)).float()
    torch.save(result, f'embeddings/spectral{post_fix}.pt')

    return result
Example No. 8
def svd(data, name, embedding_dim=64):
    try:
        result = load_embedding(name, 'svd')
        return result
    except FileNotFoundError:
        print(f'cache/feature/svd_{name}.pt not found! Regenerating it now')

    if data.setting == 'inductive':
        N = data.num_train_nodes
        row, col = data.train_edge_index
    else:
        N = data.num_nodes
        row, col = data.edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    adj = adj.to_scipy(layout='csc')
    ut, s, vt = sparsesvd(adj, embedding_dim)

    result = torch.tensor(np.dot(ut.T, np.diag(s)), dtype=torch.float)
    save_embedding(result, name, 'svd')
    return result
Example No. 9
def spectral(data, post_fix):
    # from julia.api import Julia
    # jl = Julia(compiled_modules=False)
    # from julia import Main
    # Main.include("./norm_spec.jl")
    print('Setting up spectral embedding')
    data.edge_index = to_undirected(data.edge_index)
    np_edge_index = np.array(data.edge_index.T)

    N = data.num_nodes
    row, col = data.edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    adj = adj.to_scipy(layout='csr')

    tsvd = TruncatedSVD(n_components=128)
    adj_tsvd = tsvd.fit_transform(adj)
    result = torch.tensor(adj_tsvd).float()

    # result = torch.tensor(adj.todense()).float()
    torch.save(result, f'embeddings/spectral{post_fix}.pt')

    return result
Example No. 10
def main():
    parser = argparse.ArgumentParser(description='OGBN-papers100M (MLP)')
    parser.add_argument('--data_root_dir', type=str, default='../../dataset')
    parser.add_argument('--num_propagations', type=int, default=3)
    parser.add_argument('--dropedge_rate', type=float, default=0.4)
    parser.add_argument('--node_emb_path', type=str, default=None)
    parser.add_argument('--output_path', type=str, required=True)
    args = parser.parse_args()

    # SGC pre-processing ######################################################

    dataset = PygNodePropPredDataset(name='ogbn-papers100M',
                                     root=args.data_root_dir)
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    x = None
    if args.node_emb_path:
        x = np.load(args.node_emb_path)
    else:
        x = data.x.numpy()
    N = data.num_nodes

    print('Making the graph undirected.')
    ### Randomly drop some edges to save computation
    data.edge_index, _ = dropout_adj(data.edge_index,
                                     p=args.dropedge_rate,
                                     num_nodes=data.num_nodes)
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)

    print(data)

    row, col = data.edge_index

    print('Computing adj...')

    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    adj = adj.set_diag()
    deg = adj.sum(dim=1).to(torch.float)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    adj = adj.to_scipy(layout='csr')

    train_idx, valid_idx, test_idx = split_idx['train'], split_idx[
        'valid'], split_idx['test']
    all_idx = torch.cat([train_idx, valid_idx, test_idx])
    mapped_train_idx = torch.arange(len(train_idx))
    mapped_valid_idx = torch.arange(len(train_idx),
                                    len(train_idx) + len(valid_idx))
    mapped_test_idx = torch.arange(
        len(train_idx) + len(valid_idx),
        len(train_idx) + len(valid_idx) + len(test_idx))

    sgc_dict = {}
    sgc_dict['label'] = data.y.data[all_idx].to(torch.long)
    sgc_dict['split_idx'] = {
        'train': mapped_train_idx,
        'valid': mapped_valid_idx,
        'test': mapped_test_idx
    }

    print('Start SGC processing')
    for _ in tqdm(range(args.num_propagations)):
        x = adj @ x
    sgc_dict['sgc_embedding'] = torch.from_numpy(x[all_idx]).to(torch.float)
    torch.save(sgc_dict, args.output_path)
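
A sketch of how the saved dictionary might be consumed downstream (an assumption, since the MLP trainer is not shown):

import torch

sgc_dict = torch.load('sgc_dict.pt')  # i.e. the file written to --output_path
train_idx = sgc_dict['split_idx']['train']
x_train = sgc_dict['sgc_embedding'][train_idx]
y_train = sgc_dict['label'][train_idx]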
Example No. 11
def amfs(
    A: SparseTensor,
    Sigma=None,
    level=None,
    delta=0.1,
    thresh_kld=1e-6,
    priority=True,
    verbose=False,
) -> Tuple[List[lil_matrix], np.ndarray]:
    r"""
    AMFS bipartite approximation for graph wavelet signal processing [3]_.

    Parameters
    ----------
    A:          SparseTensor
        The adjacency matrix.
    Sigma:      scipy.spmatrix, optional
        The covariance matrix specified by the Laplacian matrix L. If None,
        :math:`\Sigma^{-1}=L+\delta I`
    level:      int, optional
        The number of bipartite subgraphs, i.e., the decomposition level. If None,
        :math:`level=\lceil \log_2(\mathcal{X}) \rceil`, where :math:`\mathcal{X}` is
        the chromatic number of :obj:`A`.
    delta:      float, optional
        :math:`1/\delta` is interpreted as the variance of the DC component. Refer
        to [4]_ for more details.
    thresh_kld: float, optional
        Threshold of Kullback-Leibler divergence to perform `AMFS` decomposition.
    priority:   bool, optional
        If True, KLD holds priority.
    verbose:    bool, optional

    Returns
    -------
    bptG:   List[lil_matrix]
        The bipartite subgraphs.
    beta:   np.ndarray of shape (N, M)
        The indicator of the bipartite sets.

    References
    ----------
    .. [3]  Jin Zeng, et al., "Bipartite Subgraph Decomposition for Critically
            Sampled Wavelet Filterbanks on Arbitrary Graphs," IEEE Trans. on SP, 2016.
    .. [4]  A. Gadde, et al., "A Probabilistic Interpretation of Sampling Theory of
            Graph Signals," ICASSP, 2015.

    """

    N = A.size(-1)
    # compute_sigma builds the Laplacian matrix, which prefers the "coo" layout
    A = A.to_scipy(layout="coo").astype("d")
    if Sigma is None:
        Sigma = compute_sigma(A, delta)
    else:
        assert Sigma.shape == (N, N)
    if level is None:
        chromatic = dsatur(A).n_color
        level = int(np.ceil(np.log2(chromatic)))

    A = A.tolil()
    beta = np.zeros((N, level), dtype=bool)
    bptG = [lil_matrix((N, N), dtype=A.dtype) for _ in range(level)]
    for i in range(level):
        if verbose:
            print(f"\n|----decomposition in level: {i:4d} ----|")
        s1, s2 = amfs1level(A, Sigma, delta, thresh_kld, priority, verbose)
        bt = beta[:, i]
        bt[s1] = 1  # set s1 True
        mask = bipartite_mask(bt)
        bptG[i][mask] = A[mask]
        A[mask] = 0
    return bptG, beta
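
A hypothetical call on a 4-cycle, which is already bipartite (chromatic number 2, so level = 1); it assumes the thgsp helpers used above (dsatur, compute_sigma, amfs1level, bipartite_mask) are importable:

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 1, 1, 2, 2, 3, 3, 0])
col = torch.tensor([1, 0, 2, 1, 3, 2, 0, 3])
A = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
bptG, beta = amfs(A)
print(len(bptG), beta.shape)  # expect one bipartite subgraph and beta of shape (4, 1)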
Example No. 12
def osglm(A: SparseTensor,
          lc: Optional[int] = None,
          vtx_color: Optional[VertexColor] = None):
    r"""
    The oversampled bipartite graph approximation method proposed in [1]_

    Parameters
    ----------
    A:        SparseTensor
      The adjacency matrix of the graph.
    lc:       int
      The ordinal of color marking the boundary such that all nodes with a smaller color
      ordinal are grouped into the low-pass channel while those with a larger color
      ordinal are in the high-pass channel.
    vtx_color: iter
      The graph coloring result.

    Returns
    -------
    bptG:         List[lil_matrix]
              The oversampled bipartite graph (with appended nodes)
    beta:         np.ndarray
              The bipartite set indicator of the oversampled graph
    append_nodes: np.ndarray
              The indices of the appended nodes
    vtx_color:    np.ndarray
              The node colors

    References
    ----------
    .. [1]  Akie Sakiyama, et al, "Oversampled Graph Laplacian Matrix for Graph Filter
            Banks", IEEE trans on SP, 2016.

    """
    if vtx_color is None:
        from thgsp.alg import dsatur

        vtx_color = dsatur(A)
    vtx_color = np.asarray(vtx_color)
    n_color = max(vtx_color) + 1

    if lc is None:
        lc = n_color // 2
    assert 1 <= lc < n_color

    A = A.to_scipy(layout="csr").tolil()
    # the foundation bipartite graph Gb
    Gb = lil_matrix(A.shape, dtype=A.dtype)
    N = A.shape[-1]

    bt = np.in1d(vtx_color, range(lc))
    idx_s1 = np.nonzero(bt)[0]   # low-pass (L) set
    idx_s2 = np.nonzero(~bt)[0]  # high-pass (H) set

    mask = bipartite_mask(bt)  # the desired edges
    Gb[mask] = A[mask]
    A[mask] = 0
    eye_mask = eye(N, N, dtype=bool)
    A[eye_mask] = 1  # add vertical edges

    degree = A.sum(0).getA1()  # 2D np.matrix -> 1D np.array
    append_nodes = (degree != 0).nonzero()[0]

    Nos = len(append_nodes) + N  # oversampled size
    bptG = [lil_matrix((Nos, Nos), dtype=A.dtype)]  # the expanded graph
    bptG[0][:N, N:] = A[:, append_nodes]
    bptG[0][:N, :N] = Gb
    bptG[0][N:, :N] = A[append_nodes, :]

    beta = np.zeros((Nos, 1), dtype=bool)
    beta[idx_s1, 0] = 1
    # appended nodes corresponding to idx_s2 are assigned to
    # the L channel of oversampled graph with idx_s1
    _, node_ordinal_append, _ = np.intersect1d(append_nodes,
                                               idx_s2,
                                               return_indices=True)
    beta[N + node_ordinal_append, 0] = 1
    return bptG, beta, append_nodes, vtx_color
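
A hypothetical call on a five-node star graph, which dsatur colors with two colors, so lc defaults to 1:

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 0, 0, 1, 2, 3, 4])
col = torch.tensor([1, 2, 3, 4, 0, 0, 0, 0])
A = SparseTensor(row=row, col=col, sparse_sizes=(5, 5))
bptG, beta, append_nodes, vtx_color = osglm(A)
print(bptG[0].shape, len(append_nodes))  # oversampled size is N + len(append_nodes)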
Example No. 13
def tree_decomposition(mol, return_vocab=False):
    r"""The tree decomposition algorithm of molecules from the
    `"Junction Tree Variational Autoencoder for Molecular Graph Generation"
    <https://arxiv.org/abs/1802.04364>`_ paper.
    Returns the graph connectivity of the junction tree, the assignment
    mapping of each atom to the clique in the junction tree, and the number
    of cliques.

    Args:
        mol (rdkit.Chem.Mol): A :obj:`rdkit` molecule.
        return_vocab (bool, optional): If set to :obj:`True`, will return an
            identifier for each clique (ring, bond, bridged compounds, single).
            (default: :obj:`False`)

    :rtype: (LongTensor, LongTensor, int)
    """
    import rdkit.Chem as Chem

    # Cliques = rings and bonds.
    cliques = [list(x) for x in Chem.GetSymmSSSR(mol)]
    xs = [0] * len(cliques)
    for bond in mol.GetBonds():
        if not bond.IsInRing():
            cliques.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
            xs.append(1)

    # Generate `atom2clique` mappings.
    atom2clique = [[] for i in range(mol.GetNumAtoms())]
    for c in range(len(cliques)):
        for atom in cliques[c]:
            atom2clique[atom].append(c)

    # Merge rings that share more than 2 atoms as they form bridged compounds.
    for c1 in range(len(cliques)):
        for atom in cliques[c1]:
            for c2 in atom2clique[atom]:
                if c1 >= c2 or len(cliques[c1]) <= 2 or len(cliques[c2]) <= 2:
                    continue
                if len(set(cliques[c1]) & set(cliques[c2])) > 2:
                    cliques[c1] = set(cliques[c1]) | set(cliques[c2])
                    xs[c1] = 2
                    cliques[c2] = []
                    xs[c2] = -1
    cliques = [c for c in cliques if len(c) > 0]
    xs = [x for x in xs if x >= 0]

    # Update `atom2clique` mappings.
    atom2clique = [[] for i in range(mol.GetNumAtoms())]
    for c in range(len(cliques)):
        for atom in cliques[c]:
            atom2clique[atom].append(c)

    # Add singleton cliques in case there are more than 2 intersecting
    # cliques. We further compute the "initial" clique graph.
    edges = {}
    for atom in range(mol.GetNumAtoms()):
        cs = atom2clique[atom]
        if len(cs) <= 1:
            continue

        # Number of bond clusters that the atom lies in.
        bonds = [c for c in cs if len(cliques[c]) == 2]
        # Number of ring clusters that the atom lies in.
        rings = [c for c in cs if len(cliques[c]) > 4]

        if len(bonds) > 2 or (len(bonds) == 2 and len(cs) > 2):
            cliques.append([atom])
            xs.append(3)
            c2 = len(cliques) - 1
            for c1 in cs:
                edges[(c1, c2)] = 1

        elif len(rings) > 2:
            cliques.append([atom])
            xs.append(3)
            c2 = len(cliques) - 1
            for c1 in cs:
                edges[(c1, c2)] = 99

        else:
            for i in range(len(cs)):
                for j in range(i + 1, len(cs)):
                    c1, c2 = cs[i], cs[j]
                    count = len(set(cliques[c1]) & set(cliques[c2]))
                    edges[(c1, c2)] = min(count, edges.get((c1, c2), 99))

    # Update `atom2clique` mappings.
    atom2clique = [[] for i in range(mol.GetNumAtoms())]
    for c in range(len(cliques)):
        for atom in cliques[c]:
            atom2clique[atom].append(c)

    if len(edges) > 0:
        edge_index_T, weight = zip(*edges.items())
        row, col = torch.tensor(edge_index_T).t()
        inv_weight = 100 - torch.tensor(weight)
        clique_graph = SparseTensor(row=row,
                                    col=col,
                                    value=inv_weight,
                                    sparse_sizes=(len(cliques), len(cliques)))
        junc_tree = minimum_spanning_tree(clique_graph.to_scipy('csr'))
        row, col, _ = SparseTensor.from_scipy(junc_tree).coo()
        edge_index = torch.stack([row, col], dim=0)
        edge_index = to_undirected(edge_index, num_nodes=len(cliques))
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    rows = [[i] * len(atom2clique[i]) for i in range(mol.GetNumAtoms())]
    row = torch.tensor(list(chain.from_iterable(rows)))
    col = torch.tensor(list(chain.from_iterable(atom2clique)))
    atom2clique = torch.stack([row, col], dim=0).to(torch.long)

    if return_vocab:
        vocab = torch.tensor(xs, dtype=torch.long)
        return edge_index, atom2clique, len(cliques), vocab
    else:
        return edge_index, atom2clique, len(cliques)
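
A short usage sketch with rdkit; the molecule is chosen only for illustration. A four-membered ring with a -CH2-OH tail decomposes into one ring clique plus two bond cliques:

from rdkit import Chem

mol = Chem.MolFromSmiles('C1CCC1CO')  # cyclobutane with a -CH2-OH tail
edge_index, atom2clique, num_cliques = tree_decomposition(mol)
print(num_cliques)  # 3: one ring and two out-of-ring bonds
print(edge_index)   # junction-tree connectivity (undirected)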
Example No. 14
def harary(
    A: SparseTensor,
    vtx_color: Optional[VertexColor] = None,
    threshold: float = 0.97
) -> Tuple[List[lil_matrix], np.ndarray, np.ndarray, np.ndarray, Dict[int,
                                                                      int]]:
    """
    Harary bipartite decomposition

    Parameters
    ----------
    A:      :py:class:`SparseTensor`
        The adjacency matrix
    vtx_color: array_like, optional
        Any type valid for :py:func:`np.asarray` is acceptable, including
        :py:class:`torch.Tensor` on CPU. If None, this function will invoke
        :py:func:`thgsp.alg.dsatur` silently.

    threshold: float, optional

    Returns
    -------
    bptG:   List[lil_matrix]
            A list of :obj:`M` bipartite subgraphs formatted
            as :class:`scipy.sparse.lil_matrix`
    beta:   array
        :obj:`beta[:,i]` is the bipartite set indicator of :obj:`i`-th subgraph
    beta_dist:  array
        A table showing the relationship between :obj:`beta` and  :obj:`channels`
    new_vtx_color:  array
        The node colors
    mapper:     dict
        Maps **new_vtx_color** to the original colors. For example, mapper={1: 2,
        2: 3, 3: 1} maps colors 1, 2, and 3 to 2, 3, and 1, respectively.

    """
    if vtx_color is None:
        vtx_color = dsatur(A)
    vtx_color = np.asarray(vtx_color)
    n_color = max(vtx_color) + 1
    if n_color > 256:
        raise RuntimeError(
            "Too many colors lead to an overly complicated channel division")

    A = A.to_scipy(layout="csr").tolil()
    M = int(np.ceil(np.log2(n_color)))  # the number of bipartite graphs
    N = A.shape[-1]  # the number of nodes

    new_color_ordinal = new_order(n_color)
    mapper = {c: i for i, c in enumerate(new_color_ordinal)}
    new_vtx_color = [mapper[c] for c in vtx_color]

    beta_dist = distribute_color(n_color, M)
    bptG = [lil_matrix((N, N), dtype=A.dtype) for _ in range(M)]
    link_weights = -np.ones(M)
    beta = np.zeros((N, M), dtype=bool)
    for i in range(M):
        colors_L = (beta_dist[:, i] == 1).nonzero()[0]
        bt = np.in1d(new_vtx_color, colors_L)

        beta[:, i] = bt
        mask = bipartite_mask(bt)
        bpt_edges = A[mask]
        bptG[i][mask] = bpt_edges
        link_weights[i] = bpt_edges.sum()
        A[mask] = 0

    ratio_link_weights = link_weights.cumsum(0) / link_weights.sum()
    bpt_idx = (ratio_link_weights >= threshold).nonzero()[0]
    M1 = bpt_idx[0] + 1
    bptG = bptG[:M1]
    max_color = np.power(2, M1)

    beta_dist = distribute_color(max_color, M1)
    beta = beta[:, :M1]
    return bptG, beta, beta_dist, np.asarray(new_vtx_color), mapper
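
A hypothetical call on a triangle, whose chromatic number is 3, giving M = ceil(log2(3)) = 2 candidate bipartite subgraphs:

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 1, 1, 2, 2, 0])
col = torch.tensor([1, 0, 2, 1, 0, 2])
A = SparseTensor(row=row, col=col, sparse_sizes=(3, 3))
bptG, beta, beta_dist, colors, mapper = harary(A)
print(len(bptG), beta.shape)  # at most 2 subgraphs; fewer once 97% of edge weight is covered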