Beispiel #1
0
 def compute_local_community(self, g):
     r = []
     for i in xrange(3):
         r.append(minimize(lambda x: -1*(gt.modularity(g, gt.community_structure(g, 100, x)))
         , 0, method='COBYLA', options={'disp': True}).x)
     print "Optimal global Modularity: ", np.mean(r)
     b = gt.community_structure(g, 100, np.mean(r))
     return b
Beispiel #2
0
def modularity_communities(network):
    """Greedy agglomerative community detection driven by modularity.

    Every vertex starts in its own community. On each pass every pair of
    communities ``(i, j)`` with ``i < j`` is tentatively merged and the
    resulting modularity evaluated with ``gt.modularity``; the best merge
    of the pass (ties resolved in favour of the pair scanned last) is
    applied. Passes repeat until a pass fails to strictly increase the
    modularity.

    Parameters
    ----------
    network
        Graph exposing ``vertices()`` and ``new_vertex_property``. Community
        members are the enumeration positions of ``network.vertices()`` and
        are used directly as property-map keys.

    Returns
    -------
    list of list of int
        The non-empty communities, largest first.
    """
    membership = network.new_vertex_property("int")
    communities = []
    for i, v in enumerate(network.vertices()):
        communities.append([i])
        membership[v] = i

    merges = []
    old_modularity = None
    new_modularity = gt.modularity(network, membership)
    while old_modularity is None or new_modularity > old_modularity:
        old_modularity = new_modularity

        # Scratch labelling rebuilt from the current list positions; only
        # this property map is mutated while trying merges, so the real
        # community lists can never be corrupted by a trial.
        trial_vp = network.new_vertex_property("int")
        for label, members in enumerate(communities):
            for agent in members:
                trial_vp[agent] = label

        to_merge = None
        for i, u in enumerate(communities):
            if len(u) == 0:
                continue
            for j in range(i + 1, len(communities)):
                if len(communities[j]) == 0:
                    continue
                # Tentatively move community i into community j ...
                for agent in u:
                    trial_vp[agent] = j
                trial_modularity = gt.modularity(network, trial_vp)
                if trial_modularity >= new_modularity:
                    new_modularity = trial_modularity
                    to_merge = (i, j, new_modularity - old_modularity)
                # ... and undo the move before trying the next pair.
                for agent in u:
                    trial_vp[agent] = i

        if to_merge is not None:
            merges.append(to_merge)
            i, j, _ = to_merge
            communities[j].extend(communities[i])
            communities[i] = []

        communities = sorted([c for c in communities if len(c) > 0],
                             key=len, reverse=True)

    # Returning only after the loop lets the greedy search run to
    # convergence instead of stopping after the first pass.
    return communities
Beispiel #3
0
def community_structure_test(graph):
    """Time a block-model fit on ``graph`` and return its modularity.

    Writes a progress message to stdout, fits a block model with
    ``gt.minimize_blockmodel_dl``, scores the resulting blocks with
    ``gt.modularity`` and reports the elapsed wall-clock time.

    Parameters
    ----------
    graph
        Graph passed to ``gt.minimize_blockmodel_dl``.

    Returns
    -------
    The value returned by ``gt.modularity`` for the fitted blocks.
    """
    out = sys.stdout
    out.write('Getting community structure ...')
    out.flush()

    started = time.time()
    blocks = gt.minimize_blockmodel_dl(graph).get_blocks()
    modularity = gt.modularity(graph, blocks)
    elapsed = time.time() - started

    out.write('Ok! ({0} s.)\n'.format(elapsed))
    out.flush()

    return modularity
Beispiel #4
0
 def compute_modularity(self, user, bus):
     # Create filtered network including only the cluster which restaurant belongs and the restaurant's network
     # User provides a business and I take my master network and filter based on the community which the restaurant
     # belongs to AND on the user-business network.
     group = self.g.vertex_properties['components'][self.businesses_vertex[bus]]
     business_vertex_list = self.gather_business_verticies_of_a_user(user)
     u = gt.GraphView(self.g, vfilt=lambda x: self.g.vertex_properties['components'][x] == group or x in business_vertex_list)
     # state = gt.minimize_blockmodel_dl(u, eweight= u.edge_properties['rating'])
     # b = state.b
     b = self.compute_local_community(u)
     modularity = gt.modularity(u, b)
     print "LOCAL Modularity: ", modularity
     return modularity
Beispiel #5
0
    def validate_communities(self):
        # Use optimal communities to validate new ratings
        self.b = self.compute_global_community()
        globmod = gt.modularity(self.g, self.b)
        print "Optimal GLOBAL Modularity: ", globmod
        self.g.vertex_properties['components'] = self.b

        self.user_community_vertices = defaultdict()
        self.mat_community = np.zeros((self.num_users, self.num_businesses))

        for user, bus, rating, cruisine in review_[0:10]:
            comm = self.g.vertex_properties['components'][self.users_vertex[user]]
            if user not in self.user_community_vertices:
                u = gt.GraphView(self.g, vfilt=lambda x: self.g.vertex_properties['components'][x] == comm and
                                                len(self.g.vertex_properties['user_id'][x]) > 0)
                self.user_community_vertices[user] = u.vertices()
            # need users vertex only which are in this community
            # then find the respective user-bus for each eacher within the community
            bus_vertex = self.businesses_vertex[bus]
            rate = []
            for i in self.user_community_vertices[user]:
                try:
        #             print i, bus_vertex
                    rating = self.g.edge_properties['rating'][self.g.edge(i, bus_vertex)]
                    rate.append(rating)
                except:
                    continue
            # if no community to base rating of off, we simply use the average
            if len(rate) == 0:
                print 'Average Used'
                rate = self.avg_restaurant_rating[self.businesses_mat_dic[bus]]
            self.mat_community[self.users_mat_dic[user],                  self.businesses_mat_dic[bus]] = np.nanmean(rate)

        mse2 = []
        for i in xrange(self.num_businesses):
            mask = self.base_mat[:,i] != 0
            base_vector = self.base_mat[:,i][mask]
            mask = self.mat_community[:,i] != 0
            community_vector = self.mat_community[:,i][mask]

            mse = np.sum((base_vector - community_vector) ** 2)
            mse2.append(mse)
        mse2 = np.array(mse2)
        print "Community MSE: ", np.nanmean(mse2[mse2!=0])
 def get_modularity(self):
     """Return the weighted modularity of the stored ``'comm'`` partition.

     Uses the ``'comm'`` vertex property of ``self.g`` as the partition and
     ``self.e_weights`` as the ``weight`` argument of ``gt.modularity``.
     """
     communities = self.g.vertex_properties['comm']
     return gt.modularity(self.g, communities, weight=self.e_weights)
Beispiel #7
0
def plug_state(adata: AnnData,
    state: Union[gt.NestedBlockState, 
                gt.BlockState, 
                gt.PPBlockState],
    nested: bool = True,
    key_added: str = 'nsbm',
    calculate_affinity: bool = False,
    copy: bool = False

) -> Optional[AnnData]:
    """\
    Add a state to a dataset, populate the AnnData.obs consistently

    Parameters
    ----------
    adata
        The annotated data matrix.
    state
        The graph_tool state. Supported types are NestedBlockState
        BlockState and PPBlockState
    nested
        If False plug only the lowest level, otherwise the full hierarchy
    key_added
        The prefix for annotations
    calculate_affinity
        If True, also store the output of `get_cell_loglikelihood`
        (called with `as_prob=True`) in `adata.obsm` under
        `CA_{key_added}_level_*` keys
    copy
        Whether to copy `adata` or modify it inplace.

    Returns
    -------
    The annotated copy when `copy` is True, otherwise None. In both cases
    `adata.uns['schist']` is replaced by a dict with the keys `stats`,
    `state` and `params`.
    """

    adata = adata.copy() if copy else adata
    g = state.g

    # Anything that is neither planted nor flat is recorded as 'nested'.
    if isinstance(state, gt.PPBlockState):
        model_type = 'planted'
    elif isinstance(state, gt.BlockState):
        model_type = 'flat'
    else:
        model_type = 'nested'

    if isinstance(state, gt.NestedBlockState):
        bs = state.get_bs()
        if not nested:
            bs = bs[:1]
        n_levels = len(bs)
        groups = np.zeros((g.num_vertices(), n_levels), dtype=int)
        for x in range(n_levels):
            groups[:, x] = state.project_partition(x, 0).get_array()
        groups = pd.DataFrame(groups).astype('category')
        for c in groups.columns:
            ncat = len(groups[c].cat.categories)
            new_cat = [u'%s' % x for x in range(ncat)]
            # Assign the result back: `inplace=True` is not available in
            # recent pandas and would act on a temporary column anyway.
            groups[c] = groups[c].cat.rename_categories(new_cat)
        groups.columns = [f"{key_added}_level_{level}" for level in range(n_levels)]
        groups.index = adata.obs_names
        # remove any column with the same key
        keep_columns = [x for x in adata.obs.columns if not x.startswith('%s_level_' % key_added)]
        adata.obs = adata.obs[keep_columns]
        adata.obs = pd.concat([adata.obs, groups], axis=1)

        adata.uns['schist'] = {}
        adata.uns['schist']['stats'] = dict(
            level_entropy=np.array([state.level_entropy(x) for x in range(n_levels)]),
            modularity=np.array([gt.modularity(g, state.project_partition(x, 0))
                                 for x in range(n_levels)])
        )
        adata.uns['schist']['state'] = state

        if calculate_affinity:
            p0 = get_cell_loglikelihood(state, level=0, as_prob=True)
            adata.obsm[f'CA_{key_added}_level_0'] = p0
            l0 = "%s_level_0" % key_added
            for nl, level in enumerate(groups.columns[1:]):
                cross_tab = pd.crosstab(groups[l0], groups[level])
                cl = np.zeros((p0.shape[0], cross_tab.shape[1]), dtype=p0.dtype)
                for x in range(cl.shape[1]):
                    # sum counts of level_0 groups corresponding to
                    # this group at current level
                    cl[:, x] = p0[:, np.where(cross_tab.iloc[:, x] > 0)[0]].sum(axis=1)
                # normalise each row so that it sums to one
                adata.obsm[f'CA_{key_added}_level_{nl + 1}'] = cl / np.sum(cl, axis=1)[:, None]

    else:

        groups = pd.Series(state.get_blocks().get_array()).astype('category')
        ncat = len(groups.cat.categories)
        new_cat = [u'%s' % x for x in range(ncat)]
        groups = groups.cat.rename_categories(new_cat)
        groups.index = adata.obs_names
        adata.obs[key_added] = groups
        adata.uns['schist'] = {}
        adata.uns['schist']['stats'] = dict(
            modularity=gt.modularity(g, state.get_blocks())
        )
        adata.uns['schist']['state'] = state
        if calculate_affinity:
            adata.obsm[f'CA_{key_added}_level_1'] = get_cell_loglikelihood(state, as_prob=True)

    adata.uns['schist']['params'] = dict(
        model=model_type,
        calculate_affinity=calculate_affinity,
    )

    return adata if copy else None
Beispiel #8
0
def planted_model(
    adata: AnnData,
    n_sweep: int = 10,
    beta: float = np.inf,
    tolerance=1e-6,
    collect_marginals: bool = True,
    deg_corr: bool = True,
    samples: int = 100,
    n_jobs: int = -1,
    *,
    restrict_to: Optional[Tuple[str, Sequence[str]]] = None,
    random_seed: Optional[int] = None,
    key_added: str = 'ppbm',
    adjacency: Optional[sparse.spmatrix] = None,
    neighbors_key: Optional[str] = 'neighbors',
    directed: bool = False,
    use_weights: bool = False,
    copy: bool = False,
    save_model: Union[str, None] = None,
    #    minimize_args: Optional[Dict] = {},
    dispatch_backend: Optional[str] = 'processes',
) -> Optional[AnnData]:
    """\
    Cluster cells into subgroups [Peixoto14]_.

    Cluster cells using the  Planted Partition Block Model [Peixoto14]_, performing
    Bayesian inference on node groups. This function, in particular, uses
    the Planted Block Model, which is particularly suitable in case of
    assortative graphs and it returns the optimal number of communities

    This requires having ran :func:`~scanpy.pp.neighbors` or
    :func:`~scanpy.external.pp.bbknn` first.

    Parameters
    ----------
    adata
        The annotated data matrix.
    n_sweep
        Number of MCMC sweeps to get the initial guess
    beta
        Inverse temperature for the initial MCMC sweep
    tolerance
        Difference in description length to stop MCMC sweep iterations
    collect_marginals
        Whether or not collect node probability of belonging
        to a specific partition.
    deg_corr
        Recorded in the stored parameters; it is not otherwise used by
        this function.
    samples
        Number of initial minimizations to be performed. This influences also the
        precision for marginals. Values below 1 are raised to 1.
    n_jobs
        Number of parallel computations used during model initialization
    restrict_to
        Tuple of an `adata.obs` key and a sequence of categories, passed
        to `restrict_adjacency` to cluster only a subset of the graph.
    random_seed
        Seed for the numpy random generator, from which the per-sample
        graph-tool seeds are drawn. `0` is a valid seed.
    key_added
        `adata.obs` key under which to add the cluster labels.
    adjacency
        Sparse adjacency matrix of the graph, defaults to
        `adata.uns['neighbors']['connectivities']` in case of scanpy<=1.4.6 or
        `adata.obsp[neighbors_key][connectivity_key]` for scanpy>1.4.6
    neighbors_key
        The key passed to `sc.pp.neighbors`
    directed
        Whether to treat the graph as directed or undirected.
    use_weights
        If `True`, the edge weights are recorded in the stored parameters
        as `recs` / `rec_types`.
    copy
        Whether to copy `adata` or modify it inplace.
    save_model
        If provided, this will be the filename for the PartitionModeState to
        be saved (a `.pkl` suffix is appended when missing)
    dispatch_backend
        Passed as `prefer` to `Parallel`.

    Returns
    -------
    `adata.obs[key_added]`
        Array of dim (number of samples) that stores the subgroup id
        (`'0'`, `'1'`, ...) for each cell.
    `adata.uns['schist'][key_added]['params']`
        A dict with the parameters used for this run.
    `adata.uns['schist'][key_added]['stats']`
        A dict with the entropy and the modularity of the consensus state.
    `adata.uns['schist'][key_added]['blocks']`
        A dict with the single key `'0'` holding the block assignments.
    `adata.obsm['CM_' + key_added]`
        A `np.ndarray` with cell probability of belonging to a specific
        group (only when `collect_marginals` is True).
    """

    if random_seed is not None:
        np.random.seed(random_seed)

    # Clamp before drawing the seeds, so that `seeds` always holds one
    # entry per state that is going to be minimized.
    if samples < 1:
        samples = 1

    seeds = np.random.choice(range(samples**2), size=samples, replace=False)

    if collect_marginals and samples < 100:
        logg.warning(
            'Collecting marginals requires sufficient number of samples\n'
            f'It is now set to {samples} and should be at least 100')

    start = logg.info('minimizing the Planted Partition Block Model')
    adata = adata.copy() if copy else adata
    # are we clustering a user-provided graph or the default AnnData one?
    if adjacency is None:
        if neighbors_key not in adata.uns:
            raise ValueError('You need to run `pp.neighbors` first '
                             'to compute a neighborhood graph.')
        elif 'connectivities_key' in adata.uns[neighbors_key]:
            # scanpy>1.4.6 has matrix in another slot
            conn_key = adata.uns[neighbors_key]['connectivities_key']
            adjacency = adata.obsp[conn_key]
        else:
            # scanpy<=1.4.6 has sparse matrix here
            adjacency = adata.uns[neighbors_key]['connectivities']
    if restrict_to is not None:
        restrict_key, restrict_categories = restrict_to
        adjacency, restrict_indices = restrict_adjacency(
            adata,
            restrict_key,
            restrict_categories,
            adjacency,
        )
    # convert it to igraph and graph-tool
    g = get_igraph_from_adjacency(adjacency, directed=directed)
    g = g.to_graph_tool()
    gt.remove_parallel_edges(g)

    recs = []
    rec_types = []
    if use_weights:
        # this is not ideal to me, possibly we may need to transform
        # weights. More tests needed.
        recs = [g.ep.weight]
        rec_types = ['real-normal']

    # initialize  the block states
    def fast_min(state, beta, n_sweep, fast_tol, seed=None):
        # `seed is not None` so that a drawn seed of 0 is honoured too
        if seed is not None:
            gt.seed_rng(seed)
        dS = 1
        # sweep until the change in description length drops below tolerance
        while np.abs(dS) > fast_tol:
            dS, _, _ = state.multiflip_mcmc_sweep(beta=beta, niter=n_sweep)
        return state

    states = [gt.PPBlockState(g) for x in range(samples)]

    # perform a mcmc sweep on each
    # no list comprehension as I need to collect stats

    states = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
        delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x])
        for x in range(samples))
    logg.info('        minimization step done', time=start)
    pmode = gt.PartitionModeState([x.get_blocks().a for x in states],
                                  converge=True)

    bs = pmode.get_max(g)
    logg.info('        consensus step done', time=start)

    if save_model:
        import pickle
        fname = save_model
        if not fname.endswith('pkl'):
            fname = f'{fname}.pkl'
        logg.info(f'Saving model into {fname}')
        with open(fname, 'wb') as fout:
            pickle.dump(pmode, fout, 2)

    state = gt.PPBlockState(g, b=bs)
    logg.info('    done', time=start)

    groups = np.array(bs.get_array())
    u_groups = np.unique(groups)
    last_group = np.max(u_groups) + 1
    if collect_marginals:
        # keep only the columns of groups that are actually used and
        # turn the counts into frequencies over the samples
        pv_array = pmode.get_marginal(g).get_2d_array(
            range(last_group)).T[:, u_groups] / samples

    # relabel the used groups as 0..n-1
    rosetta = dict(zip(u_groups, range(len(u_groups))))
    groups = np.array([rosetta[x] for x in groups])

    if restrict_to is not None:
        if key_added == 'ppbm':
            key_added += '_R'
        groups = rename_groups(
            adata,
            key_added,
            restrict_key,
            restrict_categories,
            restrict_indices,
            groups,
        )

    # add column names
    adata.obs[key_added] = pd.Categorical(
        values=groups.astype('U'),
        categories=natsorted(map(str, np.unique(groups))),
    )

    # now add marginal probabilities.

    if collect_marginals:
        # cell marginals will be a list of arrays with probabilities
        # of belonging to a specific group
        adata.obsm[f"CM_{key_added}"] = pv_array

    # add some unstructured info
    if 'schist' not in adata.uns:
        adata.uns['schist'] = {}

    adata.uns['schist'][f'{key_added}'] = {}
    adata.uns['schist'][f'{key_added}']['stats'] = dict(
        entropy=state.entropy(),
        modularity=gt.modularity(g, state.get_blocks()))

    # record state as list of blocks
    # for compatibility with nested model, use a dictionary with a single key here
    # although a np.array would be ok
    adata.uns['schist'][f'{key_added}']['blocks'] = {
        '0': np.array(state.get_blocks().a)
    }

    # last step is recording some parameters used in this analysis
    adata.uns['schist'][f'{key_added}']['params'] = dict(
        model='planted',
        use_weights=use_weights,
        neighbors_key=neighbors_key,
        key_added=key_added,
        samples=samples,
        collect_marginals=collect_marginals,
        random_seed=random_seed,
        deg_corr=deg_corr,
        recs=recs,
        rec_types=rec_types)

    logg.info(
        '    finished',
        time=start,
        deep=(
            f'found {state.get_B()} clusters and added\n'
            f'    {key_added!r}, the cluster labels (adata.obs, categorical)'),
    )
    return adata if copy else None
Beispiel #9
0
def planted_model(
    adata: AnnData,
    n_sweep: int = 10,
    beta: float = np.inf,
    tolerance=1e-6,
    max_iterations: int = 1000000,
    epsilon: float = 0,
    equilibrate: bool = False,
    wait: int = 1000,
    nbreaks: int = 2,
    collect_marginals: bool = False,
    niter_collect: int = 10000,
    deg_corr: bool = True,
    n_init: int = 1,
    beta_range: Tuple[float] = (1., 100.),
    steps_anneal: int = 5,
    resume: bool = False,
    *,
    restrict_to: Optional[Tuple[str, Sequence[str]]] = None,
    random_seed: Optional[int] = None,
    key_added: str = 'ppbm',
    adjacency: Optional[sparse.spmatrix] = None,
    neighbors_key: Optional[str] = 'neighbors',
    directed: bool = False,
    use_weights: bool = False,
    copy: bool = False,
    minimize_args: Optional[Dict] = None,
    equilibrate_args: Optional[Dict] = None,
) -> Optional[AnnData]:
    """\
    Cluster cells into subgroups [Peixoto14]_.

    Cluster cells using the  Stochastic Block Model [Peixoto14]_, performing
    Bayesian inference on node groups. This function, in particular, uses
    the Planted Block Model, which is particularly suitable in case of
    assortative graphs and it returns the optimal number of communities

    This requires having ran :func:`~scanpy.pp.neighbors` or
    :func:`~scanpy.external.pp.bbknn` first.

    Parameters
    ----------
    adata
        The annotated data matrix.
    n_sweep
        Number of MCMC sweeps to get the initial guess
    beta
        Inverse temperature for the initial MCMC sweep
    tolerance
        Difference in description length to stop MCMC sweep iterations
    max_iterations
        Maximal number of iterations to be performed by the equilibrate step.
    epsilon
        Relative changes in entropy smaller than epsilon will
        not be considered as record-breaking.
    equilibrate
        Whether or not perform the mcmc_equilibrate step.
        Equilibration should always be performed. Note, also, that without
        equilibration it won't be possible to collect marginals.
    collect_marginals
        Whether or not collect node probability of belonging
        to a specific partition.
    niter_collect
        Number of iterations to force when collecting marginals. This will
        increase the precision when calculating probabilities
    wait
        Number of iterations to wait for a record-breaking event.
        Higher values result in longer computations. Set it to small values
        when performing quick tests.
    nbreaks
        Number of iteration intervals (of size `wait`) without
        record-breaking events necessary to stop the algorithm.
    deg_corr
        Not used by this function.
    n_init
        Number of initial minimizations to be performed. The one with smaller
        entropy is chosen
    beta_range
        Inverse temperature at the beginning and the end of the equilibration
    steps_anneal
        Number of steps in which the simulated annealing is performed
    resume
        Start from a previously created model, if any, without initializing a novel
        model. Implies `equilibrate`.
    restrict_to
        Tuple of an `adata.obs` key and a sequence of categories, passed
        to `restrict_adjacency` to cluster only a subset of the graph.
    key_added
        `adata.obs` key under which to add the cluster labels.
    adjacency
        Sparse adjacency matrix of the graph, defaults to
        `adata.uns['neighbors']['connectivities']` in case of scanpy<=1.4.6 or
        `adata.obsp[neighbors_key][connectivity_key]` for scanpy>1.4.6
    neighbors_key
        The key passed to `sc.pp.neighbors`
    directed
        Whether to treat the graph as directed or undirected.
    use_weights
        If `True`, the edge weight property of the graph is read; it is
        not otherwise used by this function.
    copy
        Whether to copy `adata` or modify it inplace.
    random_seed
        Random number to be used as seed for numpy and graph-tool.
        `0` is a valid seed.
    minimize_args
        Not used by this function.
    equilibrate_args
        Extra arguments for the equilibration performed during annealing.
        The keys `wait`, `nbreaks`, `max_niter` and `mcmc_args` are always
        overridden. The dictionary passed in is not modified.

    Returns
    -------
    `adata.obs[key_added]`
        Array of dim (number of samples) that stores the subgroup id
        (`'0'`, `'1'`, ...) for each cell.
    `adata.uns[key_added]['params']`
        A dict with the values for the parameters `epsilon`, `wait`,
        `nbreaks`, `equilibrate`, `collect_marginals` and `random_seed`.
    `adata.uns[key_added]['stats']`
        A dict with `dS`, `nattempts`, `nmoves` and the modularity
    `adata.uns[key_added]['group_marginals']`
        Histogram of the number of blocks observed while collecting
        marginals (only when `collect_marginals` is True)
    `adata.uns[key_added]['state']`
        The BlockModel state object
    """

    # first things first
    check_gt_version()

    # Work on a private copy: the dictionary is filled in below and must
    # neither leak between calls nor alter the caller's object.
    equilibrate_args = dict(equilibrate_args) if equilibrate_args else {}

    if resume:
        equilibrate = True

    if resume and (key_added not in adata.uns
                   or 'state' not in adata.uns[key_added]):
        # let the model proceed as default
        logg.warning('Resuming has been specified but a state was not found\n'
                     'Will continue with default minimization step')

        resume = False

    if random_seed is not None:
        np.random.seed(random_seed)
        gt.seed_rng(random_seed)

    if collect_marginals:
        logg.warning('Collecting marginals has a large impact on running time')
        if not equilibrate:
            raise ValueError(
                "You can't collect marginals without MCMC equilibrate "
                "step. Either set `equlibrate` to `True` or "
                "`collect_marginals` to `False`")

    start = logg.info('minimizing the Planted Partition Block Model')
    adata = adata.copy() if copy else adata
    # are we clustering a user-provided graph or the default AnnData one?
    if adjacency is None:
        if neighbors_key not in adata.uns:
            raise ValueError('You need to run `pp.neighbors` first '
                             'to compute a neighborhood graph.')
        elif 'connectivities_key' in adata.uns[neighbors_key]:
            # scanpy>1.4.6 has matrix in another slot
            conn_key = adata.uns[neighbors_key]['connectivities_key']
            adjacency = adata.obsp[conn_key]
        else:
            # scanpy<=1.4.6 has sparse matrix here
            adjacency = adata.uns[neighbors_key]['connectivities']
    if restrict_to is not None:
        restrict_key, restrict_categories = restrict_to
        adjacency, restrict_indices = restrict_adjacency(
            adata,
            restrict_key,
            restrict_categories,
            adjacency,
        )
    # convert it to graph-tool
    g = get_graph_tool_from_adjacency(adjacency, directed=directed)

    recs = []
    rec_types = []
    if use_weights:
        # this is not ideal to me, possibly we may need to transform
        # weights. More tests needed.
        recs = [g.ep.weight]
        rec_types = ['real-normal']

    if resume:
        # create the state and make sure sampling is performed
        state = adata.uns[key_added]['state'].copy()
        g = state.g
    else:
        if n_init < 1:
            n_init = 1

        # initialize  the block states
        states = [gt.PPBlockState(g) for x in range(n_init)]

        # perform a mcmc sweep on each
        # no list comprehension as I need to collect stats

        _dS = np.zeros(n_init)
        _nattempts = np.zeros(n_init)
        _nmoves = np.zeros(n_init)
        for x in range(n_init):
            t_ds = 1
            while np.abs(t_ds) > tolerance:
                # perform sweep until a tolerance is reached
                t_ds, t_natt, t_nm = states[x].multiflip_mcmc_sweep(
                    beta=beta, niter=n_sweep)
                _dS[x] += t_ds
                _nattempts[x] += t_natt
                _nmoves[x] += t_nm

        # keep the initialization with the smallest entropy
        _amin = np.argmin([s.entropy() for s in states])
        state = states[_amin]
        dS = _dS[_amin]
        nattempts = _nattempts[_amin]
        nmoves = _nmoves[_amin]

        logg.info('    done', time=start)

    # equilibrate the Markov chain
    if equilibrate:
        logg.info('running MCMC equilibration step')
        equilibrate_args['wait'] = wait
        equilibrate_args['nbreaks'] = nbreaks
        equilibrate_args['max_niter'] = max_iterations
        equilibrate_args['mcmc_args'] = {'niter': 10}

        dS, nattempts, nmoves = gt.mcmc_anneal(
            state,
            mcmc_equilibrate_args=equilibrate_args,
            niter=steps_anneal,
            beta_range=beta_range)

    if collect_marginals and equilibrate:
        # count how often each number of blocks is observed while the
        # chain keeps running
        logg.info('    collecting marginals')
        group_marginals = np.zeros(g.num_vertices() + 1)

        def _collect_marginals(s):
            group_marginals[s.get_B()] += 1

        gt.mcmc_equilibrate(state,
                            wait=wait,
                            nbreaks=nbreaks,
                            epsilon=epsilon,
                            max_niter=max_iterations,
                            multiflip=True,
                            force_niter=niter_collect,
                            mcmc_args=dict(niter=10),
                            callback=_collect_marginals)
        logg.info('    done', time=start)

    # everything is in place, we need to fill all slots
    # first build an array with the block labels renamed to '0', '1', ...
    groups = pd.Series(state.get_blocks().get_array()).astype('category')
    new_cat_names = {cx: u'%s' % cn
                     for cn, cx in enumerate(groups.cat.categories)}
    # assign the result: `inplace=True` is not available in recent pandas
    groups = groups.cat.rename_categories(new_cat_names)

    if restrict_to is not None:
        # NOTE(review): `restrict_indices` is never used in this function;
        # confirm that the restricted groups line up with this index.
        groups.index = adata.obs[restrict_key].index
    else:
        groups.index = adata.obs_names

    # add column names
    adata.obs.loc[:, key_added] = groups

    # add some unstructured info

    adata.uns[key_added] = {}
    adata.uns[key_added]['stats'] = dict(dS=dS,
                                         nattempts=nattempts,
                                         nmoves=nmoves,
                                         modularity=gt.modularity(
                                             g, state.get_blocks()))
    adata.uns[key_added]['state'] = state

    # now add marginal probabilities.

    if collect_marginals:
        # histogram of the number of blocks seen during collection
        adata.uns[key_added]['group_marginals'] = group_marginals

    # last step is recording some parameters used in this analysis
    adata.uns[key_added]['params'] = dict(epsilon=epsilon,
                                          wait=wait,
                                          nbreaks=nbreaks,
                                          equilibrate=equilibrate,
                                          collect_marginals=collect_marginals,
                                          random_seed=random_seed)

    logg.info(
        '    finished',
        time=start,
        deep=(
            f'found {state.get_B()} clusters and added\n'
            f'    {key_added!r}, the cluster labels (adata.obs, categorical)'),
    )
    return adata if copy else None
Beispiel #10
0
 def compute_local_blocks(self):
     state = gt.minimize_blockmodel_dl(self.g, verbose=True)
     b = state.b
     mod = gt.modularity(self.g, b)
     print mod
     return b
Beispiel #11
0
 def optimize_global_modularity(self, x):
     """Objective function: negated modularity of ``self.g`` for ``x``.

     ``x`` is forwarded as the third positional argument of
     ``gt.community_structure``. The sign is flipped so that a minimiser
     maximises the modularity.
     """
     partition = gt.community_structure(self.g, 100, x)
     score = gt.modularity(self.g, partition)
     return -1*score
Beispiel #12
0
def nested_model(
    adata: AnnData,
    max_iterations: int = 1000000,
    epsilon: float = 0,
    equilibrate: bool = False,
    wait: int = 1000,
    nbreaks: int = 2,
    collect_marginals: bool = False,
    niter_collect: int = 10000,
    hierarchy_length: int = 10,
    deg_corr: bool = True,
    multiflip: bool = True,
    fast_model: bool = False,
    fast_tol: float = 1e-6,
    n_sweep: int = 10,
    beta: float = np.inf,
    n_init: int = 1,
    beta_range: Tuple[float, float] = (1., 1000.),
    steps_anneal: int = 3,
    resume: bool = False,
    *,
    restrict_to: Optional[Tuple[str, Sequence[str]]] = None,
    random_seed: Optional[int] = None,
    key_added: str = 'nsbm',
    adjacency: Optional[sparse.spmatrix] = None,
    neighbors_key: Optional[str] = 'neighbors',
    directed: bool = False,
    use_weights: bool = False,
    prune: bool = False,
    return_low: bool = False,
    copy: bool = False,
    minimize_args: Optional[Dict] = None,
    equilibrate_args: Optional[Dict] = None,
) -> Optional[AnnData]:
    """\
    Cluster cells into subgroups [Peixoto14]_.

    Cluster cells using the nested Stochastic Block Model [Peixoto14]_,
    a hierarchical version of Stochastic Block Model [Holland83]_, performing
    Bayesian inference on node groups. NSBM should circumvent classical
    limitations of SBM in detecting small groups in large graphs
    replacing the noninformative priors used by a hierarchy of priors
    and hyperpriors.

    This requires having ran :func:`~scanpy.pp.neighbors` or
    :func:`~scanpy.external.pp.bbknn` first.

    Parameters
    ----------
    adata
        The annotated data matrix.
    max_iterations
        Maximal number of iterations to be performed by the equilibrate step.
    epsilon
        Relative changes in entropy smaller than epsilon will
        not be considered as record-breaking. It is passed to the
        equilibration performed while collecting marginals.
    equilibrate
        Whether or not perform the mcmc_equilibrate step.
        Equilibration should always be performed. Note, also, that without
        equilibration it won't be possible to collect marginals.
    collect_marginals
        Whether or not collect node probability of belonging
        to a specific partition.
    niter_collect
        Number of iterations to force when collecting marginals. This will
        increase the precision when calculating probabilities
    wait
        Number of iterations to wait for a record-breaking event.
        Higher values result in longer computations. Set it to small values
        when performing quick tests.
    nbreaks
        Number of iteration intervals (of size `wait`) without
        record-breaking events necessary to stop the algorithm.
    hierarchy_length
        Initial length of the hierarchy. When large values are
        passed, the top-most levels will be uninformative as they
        will likely contain the very same groups. Increase this value
        if a very large number of cells is analyzed (>100.000).
    deg_corr
        Whether to use degree correction in the minimization step. In many
        real world networks this is the case, although this doesn't seem
        the case for KNN graphs used in scanpy.
    multiflip
        Whether to perform MCMC sweep with multiple simultaneous moves to sample
        network partitions. It may result in slightly longer runtimes, but under
        the hood it allows for a more efficient space exploration.
    fast_model
        Whether to skip initial minimization step and let the MCMC find a solution.
        This approach tend to be faster and consume less memory, but may be
        less accurate.
    fast_tol
        Tolerance for fast model convergence.
    n_sweep
        Number of iterations to be performed in the fast model MCMC greedy approach
    beta
        Inverse temperature for MCMC greedy approach
    n_init
        Number of initial minimizations to be performed. The one with smaller
        entropy is chosen. Values below 1 are treated as 1.
    beta_range
        Inverse temperature at the beginning and the end of the equilibration
    steps_anneal
        Number of steps in which the simulated annealing is performed
    resume
        Start from a previously created model, if any, without initializing a novel
        model. The model is read from `adata.uns['nsbm']['state']`; when it is
        missing a warning is logged and the default minimization is performed.
        Requesting `resume` forces `equilibrate` to `True`.
    restrict_to
        A `(obs key, categories)` pair; the adjacency is restricted through
        `restrict_adjacency` before building the graph.
    key_added
        Prefix of the `adata.obs` columns under which to add the cluster labels.
    adjacency
        Sparse adjacency matrix of the graph, defaults to
        `adata.uns['neighbors']['connectivities']` in case of scanpy<=1.4.6 or
        `adata.obsp[neighbors_key][connectivity_key]` for scanpy>1.4.6
    neighbors_key
        The key passed to `sc.pp.neighbors`
    directed
        Whether to treat the graph as directed or undirected.
    use_weights
        If `True`, edge weights from the graph are used in the computation
        (placing more emphasis on stronger edges). Note that this
        increases computation times
    prune
        Some high levels in hierarchy may contain the same information in terms of
        cell assignments, even if they apparently have different group names. When this
        option is set to `True`, the function only returns informative levels.
        Note, however, that cell affinities are still reported for all levels. Pruning
        does not rename group levels
    return_low
        Whether or not return nsbm_level_0 in adata.obs. This level usually contains
        so many groups that it cannot be plot anyway, but it may be useful for particular
        analysis. By default it is not returned
    copy
        Whether to copy `adata` or modify it inplace.
    random_seed
        Random number to be used as seed for numpy and graph-tool
    minimize_args
        Extra keyword arguments forwarded to
        `gt.minimize_nested_blockmodel_dl`. The dict is not modified.
    equilibrate_args
        Extra keyword arguments forwarded (as `mcmc_equilibrate_args`) to
        `gt.mcmc_anneal`. The keys `wait`, `nbreaks`, `max_niter`,
        `multiflip` and `mcmc_args` are always set by this function. The
        dict passed by the caller is not modified.

    Returns
    -------
    The modified copy of `adata` if `copy=True`, otherwise `None`. The
    following slots are written:

    `adata.obs['{key_added}_level_{n}']`
        Categorical columns storing, for every hierarchy level `n`, the
        subgroup id (`'0'`, `'1'`, ...) of each cell. Level 0 is only
        added when `return_low` is `True`.
    `adata.uns['nsbm']['params']`
        A dict with the values of `epsilon`, `wait`, `nbreaks`,
        `equilibrate`, `fast_model`, `collect_marginals`,
        `hierarchy_length`, `random_seed` and `prune`.
    `adata.uns['nsbm']['stats']`
        A dict with per-level `level_entropy` and `modularity` arrays and,
        when equilibration is performed, the `dS`, `nattempts` and `nmoves`
        values returned by `gt.mcmc_anneal`.
    `adata.uns['nsbm']['group_marginals']`
        Only when `collect_marginals` is `True`: per-level arrays counting
        how often each number of non-empty groups was observed.
    `adata.uns['nsbm']['cell_affinity']`
        A dict of `np.ndarray` with cell probability of belonging to a
        specific group, one entry per level.
    `adata.uns['nsbm']['state']`
        The NestedBlockModel state object

    Raises
    ------
    ValueError
        If `collect_marginals` is requested without equilibration, or if no
        adjacency is given and `neighbors_key` is not in `adata.uns`.
    """

    # Private copies of the option dicts: `equilibrate_args` is filled in
    # below and that must never leak into the caller's dict or persist
    # between calls.
    minimize_args = dict(minimize_args) if minimize_args else {}
    equilibrate_args = dict(equilibrate_args) if equilibrate_args else {}

    if resume:
        # a resumed model always goes through the equilibration step
        equilibrate = True

    if resume and ('nsbm' not in adata.uns
                   or 'state' not in adata.uns['nsbm']):
        # let the model proceed as default
        logg.warning('Resuming has been specified but a state was not found\n'
                     'Will continue with default minimization step')

        resume = False

    # `is not None` so that 0 is honoured as a legitimate seed
    if random_seed is not None:
        np.random.seed(random_seed)
        gt.seed_rng(random_seed)

    if collect_marginals:
        logg.warning('Collecting marginals has a large impact on running time')
        if not equilibrate:
            raise ValueError(
                "You can't collect marginals without MCMC equilibrate "
                "step. Either set `equilibrate` to `True` or "
                "`collect_marginals` to `False`")

    start = logg.info('minimizing the nested Stochastic Block Model')
    adata = adata.copy() if copy else adata
    # are we clustering a user-provided graph or the default AnnData one?
    if adjacency is None:
        if neighbors_key not in adata.uns:
            raise ValueError('You need to run `pp.neighbors` first '
                             'to compute a neighborhood graph.')
        elif 'connectivities_key' in adata.uns[neighbors_key]:
            # scanpy>1.4.6 has matrix in another slot
            conn_key = adata.uns[neighbors_key]['connectivities_key']
            adjacency = adata.obsp[conn_key]
        else:
            # scanpy<=1.4.6 has sparse matrix here
            adjacency = adata.uns[neighbors_key]['connectivities']
    if restrict_to is not None:
        restrict_key, restrict_categories = restrict_to
        adjacency, restrict_indices = restrict_adjacency(
            adata,
            restrict_key,
            restrict_categories,
            adjacency,
        )
    # build the graph from the adjacency matrix
    g = get_graph_tool_from_adjacency(adjacency, directed=directed)

    recs = []
    rec_types = []
    if use_weights:
        # this is not ideal to me, possibly we may need to transform
        # weights. More tests needed.
        recs = [g.ep.weight]
        rec_types = ['real-normal']

    if n_init < 1:
        n_init = 1

    if fast_model:
        # do not minimize, start with a dummy state and perform only
        # greedy sweeps until the entropy change falls below `fast_tol`
        states = [
            gt.NestedBlockState(g=g,
                                state_args=dict(deg_corr=deg_corr,
                                                recs=recs,
                                                rec_types=rec_types))
            for n in range(n_init)
        ]
        for candidate in states:
            dS = 1
            while np.abs(dS) > fast_tol:
                # perform sweep until a tolerance is reached
                dS, _, _ = candidate.multiflip_mcmc_sweep(beta=beta,
                                                          niter=n_sweep)

        # keep the candidate with the smallest entropy
        state = states[np.argmin([s.entropy() for s in states])]
        bs = state.get_bs()
        logg.info('    done', time=start)

    elif resume:
        # create the state and make sure sampling is performed
        state = adata.uns['nsbm']['state'].copy(sampling=True)
        bs = state.get_bs()
        # get the graph from state
        g = state.g
    else:
        states = [
            gt.minimize_nested_blockmodel_dl(
                g,
                deg_corr=deg_corr,
                state_args=dict(recs=recs, rec_types=rec_types),
                **minimize_args) for n in range(n_init)
        ]

        # keep the candidate with the smallest entropy
        state = states[np.argmin([s.entropy() for s in states])]
        logg.info('    done', time=start)
        bs = state.get_bs()
        if len(bs) <= hierarchy_length:
            # increase hierarchy length up to the specified value
            # according to Tiago Peixoto 10 is reasonably large as number of
            # groups decays exponentially
            bs += [np.zeros(1)] * (hierarchy_length - len(bs))
        else:
            logg.warning(
                f'A hierarchy length of {hierarchy_length} has been specified\n'
                f'but the minimized model contains {len(bs)} levels')
        # create a new state with inferred blocks; deg_corr is forwarded so
        # the rebuilt state uses the same model that was minimized (as the
        # fast-model branch above already does)
        state = gt.NestedBlockState(g,
                                    bs,
                                    state_args=dict(deg_corr=deg_corr,
                                                    recs=recs,
                                                    rec_types=rec_types),
                                    sampling=True)

    # equilibrate the Markov chain
    if equilibrate:
        logg.info('running MCMC equilibration step')
        # equilibration done by simulated annealing; these keys always
        # override whatever the caller supplied in `equilibrate_args`
        equilibrate_args['wait'] = wait
        equilibrate_args['nbreaks'] = nbreaks
        equilibrate_args['max_niter'] = max_iterations
        equilibrate_args['multiflip'] = multiflip
        equilibrate_args['mcmc_args'] = {'niter': 10}

        dS, nattempts, nmoves = gt.mcmc_anneal(
            state,
            mcmc_equilibrate_args=equilibrate_args,
            niter=steps_anneal,
            beta_range=beta_range)
    if collect_marginals and equilibrate:
        # we here only retain level_0 counts, until I can't figure out
        # how to propagate correctly counts to higher levels
        # I wonder if this should be placed after group definition or not
        logg.info('    collecting marginals')
        group_marginals = [
            np.zeros(g.num_vertices() + 1) for s in state.get_levels()
        ]

        def _collect_marginals(s):
            # count, per level, how often each number of non-empty groups
            # is visited by the chain
            for l, sl in enumerate(s.get_levels()):
                group_marginals[l][sl.get_nonempty_B()] += 1

        gt.mcmc_equilibrate(state,
                            wait=wait,
                            nbreaks=nbreaks,
                            epsilon=epsilon,
                            max_niter=max_iterations,
                            multiflip=True,
                            force_niter=niter_collect,
                            mcmc_args=dict(niter=10),
                            callback=_collect_marginals)
        logg.info('    done', time=start)

    # everything is in place, we need to fill all slots
    # first build an array with one column of labels per level
    groups = np.zeros((g.num_vertices(), len(bs)), dtype=int)

    for x in range(len(bs)):
        # for each level, project labels to the vertex level
        # so that every cell has a name. Note that at this level
        # the labels are not necessarily consecutive
        groups[:, x] = state.project_partition(x, 0).get_array()

    groups = pd.DataFrame(groups).astype('category')

    # rename categories from 0 to n. The renamed column is assigned back
    # instead of relying on `inplace=True`, which pandas removed in 2.0.
    for c in groups.columns:
        new_cat_names = {
            cx: u'%s' % cn
            for cn, cx in enumerate(groups[c].cat.categories)
        }
        groups[c] = groups[c].cat.rename_categories(new_cat_names)

    if restrict_to is not None:
        # NOTE(review): this is the index of *all* observations while
        # `groups` has one row per graph vertex; if `restrict_adjacency`
        # returns a reduced adjacency the lengths will not match and
        # `restrict_indices` should probably be used here -- confirm.
        groups.index = adata.obs[restrict_key].index
    else:
        groups.index = adata.obs_names

    # add column names
    groups.columns = [
        "%s_level_%d" % (key_added, level) for level in range(len(bs))
    ]

    # remove any column with the same key
    keep_columns = [
        x for x in adata.obs.columns
        if not x.startswith('%s_level_' % key_added)
    ]
    adata.obs = adata.obs.loc[:, keep_columns]
    # concatenate obs with new data, skipping level_0 unless requested.
    # we need it in this function anyway, to match groups with node marginals
    if return_low:
        adata.obs = pd.concat([adata.obs, groups], axis=1)
    else:
        adata.obs = pd.concat([adata.obs, groups.iloc[:, 1:]], axis=1)

    # add some unstructured info

    adata.uns['nsbm'] = {}
    n_levels = len(state.levels)
    adata.uns['nsbm']['stats'] = dict(
        level_entropy=np.array(
            [state.level_entropy(x) for x in range(n_levels)]),
        modularity=np.array([
            gt.modularity(g, state.project_partition(x, 0))
            for x in range(n_levels)
        ]))
    if equilibrate:
        adata.uns['nsbm']['stats']['dS'] = dS
        adata.uns['nsbm']['stats']['nattempts'] = nattempts
        adata.uns['nsbm']['stats']['nmoves'] = nmoves

    adata.uns['nsbm']['state'] = state

    # now add marginal probabilities.

    if collect_marginals:
        # trim group marginals. We collected data in vectors as long as
        # the number of cells, cut them after the last non-zero entry
        adata.uns['nsbm']['group_marginals'] = {}
        for nl, level_marginals in enumerate(group_marginals):
            idx = np.where(level_marginals > 0)[0] + 1
            adata.uns['nsbm']['group_marginals'][nl] = np.array(
                level_marginals[:np.max(idx)])

    # prune uninformative levels, if any
    if prune:
        to_remove = prune_groups(groups)
        logg.info(f'    Removing levels {to_remove}')
        adata.obs.drop(to_remove, axis='columns', inplace=True)

    # calculate log-likelihood of cell moves over the remaining levels
    # we have to calculate events at level 0 and propagate to upper levels
    logg.info('    calculating cell affinity to groups')
    levels = [
        int(x.split('_')[-1]) for x in adata.obs.columns
        if x.startswith(f'{key_added}_level')
    ]
    adata.uns['nsbm']['cell_affinity'] = dict.fromkeys(
        [str(x) for x in levels])
    p0 = get_cell_loglikelihood(state, level=0, as_prob=True)

    # NOTE(review): level 0 is stored under the int key 0 while every other
    # level uses a string key ('1', '2', ...); kept as is for backward
    # compatibility -- confirm whether '0' was intended.
    adata.uns['nsbm']['cell_affinity'][0] = p0
    l0 = "%s_level_0" % key_added
    for nl, level in enumerate(groups.columns[1:]):
        cross_tab = pd.crosstab(groups.loc[:, l0], groups.loc[:, level])
        cl = np.zeros((p0.shape[0], cross_tab.shape[1]), dtype=p0.dtype)
        for x in range(cl.shape[1]):
            # sum counts of level_0 groups corresponding to
            # this group at current level
            cl[:, x] = p0[:, np.where(cross_tab.iloc[:, x] > 0)[0]].sum(axis=1)
        # normalize rows so that affinities sum to 1 for each cell
        adata.uns['nsbm']['cell_affinity'][str(nl + 1)] = cl / np.sum(
            cl, axis=1)[:, None]

    # last step is recording some parameters used in this analysis
    adata.uns['nsbm']['params'] = dict(
        epsilon=epsilon,
        wait=wait,
        nbreaks=nbreaks,
        equilibrate=equilibrate,
        fast_model=fast_model,
        collect_marginals=collect_marginals,
        hierarchy_length=hierarchy_length,
        random_seed=random_seed,
        prune=prune,
    )

    logg.info(
        '    finished',
        time=start,
        deep=
        (f'found {state.get_levels()[1].get_nonempty_B()} clusters at level_1, and added\n'
         f'    {key_added!r}, the cluster labels (adata.obs, categorical)'),
    )
    return adata if copy else None
Beispiel #13
0
def nested_model_multi(
    adatas: List[AnnData],
    deg_corr: bool = True,
    tolerance: float = 1e-6,
    n_sweep: int = 10,
    beta: float = np.inf,
    samples: int = 100,
    collect_marginals: bool = True,
    n_jobs: int = -1,
    *,
    random_seed: Optional[int] = None,
    key_added: str = 'multi_nsbm',
    adjacency: Optional[List[sparse.spmatrix]] = None,
    neighbors_key: Optional[List[str]] = ['neighbors'],
    directed: bool = False,
    use_weights: bool = False,
    save_model: Union[str, None] = None,
    copy: bool = False,
    dispatch_backend: Optional[str] = 'processes',
) -> Optional[List[AnnData]]:
    """\
    Cluster cells into subgroups using multiple modalities.

    Cluster cells using the nested Stochastic Block Model [Peixoto14]_,
    performing Bayesian inference on node groups. This function takes multiple
    experiments, possibly across different modalities, and perform joint
    clustering.

    One graph per AnnData is merged into a single layered graph whose
    vertices are the union of all cell names; `samples` models are minimized
    independently and a consensus partition is extracted from them.

    This requires having ran :func:`~scanpy.pp.neighbors` or
    :func:`~scanpy.external.pp.bbknn` first. It also requires cells having the same
    names if coming from paired experiments

    Parameters
    ----------
    adatas
        A list of processed AnnData. Neighbors must have been already
        calculated.
    deg_corr
        Whether to use degree correction in the minimization step. In many
        real world networks this is the case, although this doesn't seem
        the case for KNN graphs used in scanpy.
    tolerance
        Tolerance for fast model convergence.
    n_sweep
        Number of iterations to be performed in the fast model MCMC greedy approach
    beta
        Inverse temperature for MCMC greedy approach
    samples
        Number of independent minimizations to be performed; the consensus
        partition is computed over all of them. Values below 1 are treated
        as 1.
    collect_marginals
        Whether to store the group marginals of each cell in `adata.obsm`.
        A warning is issued when fewer than 100 samples are used.
    n_jobs
        Number of parallel computations used during model initialization
    key_added
        Prefix of the `adata.obs` columns under which to add the cluster labels.
    adjacency
        Sparse adjacency matrix of the graph, defaults to
        `adata.uns['neighbors']['connectivities']` in case of scanpy<=1.4.6 or
        `adata.obsp[neighbors_key][connectivity_key]` for scanpy>1.4.6
    neighbors_key
        The key passed to `sc.pp.neighbors`. If all AnnData share the same key, one
        only has to be specified, otherwise the full list of all keys must
        be provided
    directed
        Whether to treat the graph as directed or undirected.
    use_weights
        Currently only recorded in the stored parameters: edge weights are
        not used by this function yet.
    save_model
        If provided, this will be the filename for the PartitionModeState to
        be saved (a `.pkl` suffix is appended when missing)
    copy
        Whether to copy `adata` or modify it inplace.
    random_seed
        Random number to be used as seed for numpy; the per-sample
        graph-tool seeds are drawn from it.
    dispatch_backend
        Passed as `prefer` to `joblib.Parallel`.

    Returns
    -------
    The list of modified copies if `copy=True`, otherwise `None`. For every
    AnnData the following slots are written:

    `adata.obs['{key_added}_level_{n}']`
        Categorical columns storing the subgroup id (`'0'`, `'1'`, ...)
        of each cell at hierarchy level `n`.
    `adata.uns['schist'][key_added]['params']`
        A dict with the values of `model`, `use_weights`, `neighbors_key`,
        `key_added`, `samples`, `collect_marginals`, `random_seed` and
        `deg_corr`.
    `adata.uns['schist'][key_added]['stats']`
        A dict with per-level `level_entropy` and `modularity` arrays.
    `adata.uns['schist'][key_added]['blocks']`
        A dict mapping each level (as a string) to its block array.
    `adata.obsm['CM_{key_added}_level_{n}']`
        Only when `collect_marginals` is `True`: a `np.ndarray` with cell
        marginals of belonging to each group at level `n`.

    Raises
    ------
    ValueError
        If the number of neighbors keys does not match the number of data
        matrices, or if a neighbors key is missing from an AnnData.
    """

    # Clamp before drawing the seeds: exactly one seed per sample is needed
    if samples < 1:
        samples = 1

    # `is not None` so that 0 is honoured as a legitimate seed
    if random_seed is not None:
        np.random.seed(random_seed)

    seeds = np.random.choice(samples**2, size=samples, replace=False)

    if collect_marginals and samples < 100:
        logg.warning(
            'Collecting marginals requires sufficient number of samples\n'
            f'It is now set to {samples} and should be at least 100')

    start = logg.info('minimizing the nested Stochastic Block Model')

    if copy:
        adatas = [x.copy() for x in adatas]

    # accept a bare string as a single shared key
    if isinstance(neighbors_key, str):
        neighbors_key = [neighbors_key]

    n_keys = len(neighbors_key)
    n_data = len(adatas)
    if n_keys != 1 and n_keys < n_data:
        raise ValueError(
            'The number of neighbors keys does not match '
            'the number of data matrices. Either fix this '
            'or pass a neighbor key that is shared across all modalities')
    if n_keys == 1:
        # Broadcast the shared key. This is done even when `adjacency` is
        # user-provided, because the key of every dataset is recorded in
        # the stored parameters below.
        neighbors_key = [neighbors_key[0]] * n_data

    # are we clustering a user-provided graph or the default AnnData one?
    if adjacency is None:
        adjacency = []
        for x in range(n_data):
            logg.info(f'getting adjacency for data {x}', time=start)
            if neighbors_key[x] not in adatas[x].uns:
                raise ValueError('You need to run `pp.neighbors` first '
                                 'to compute a neighborhood graph for '
                                 f'data entry {x}')
            elif 'connectivities_key' in adatas[x].uns[neighbors_key[x]]:
                # scanpy>1.4.6 has matrix in another slot
                conn_key = adatas[x].uns[
                    neighbors_key[x]]['connectivities_key']
                adjacency.append(adatas[x].obsp[conn_key])
            else:
                # scanpy<=1.4.6 has sparse matrix here
                adjacency.append(
                    adatas[x].uns[neighbors_key[x]]['connectivities'])

    # convert it to igraph and graph-tool

    graph_list = []
    for x in range(n_data):
        g = get_igraph_from_adjacency(adjacency[x], directed=directed)
        g = g.to_graph_tool()
        gt.remove_parallel_edges(g)
        # add cell names to graph, this will be used to create
        # layered graph
        g_names = g.new_vertex_property('string')
        for xn, cell_name in enumerate(adatas[x].obs_names):
            g_names[xn] = cell_name
        g.vp['cell'] = g_names
        graph_list.append(g)

    # get a non-redundant list of all cell names across all modalities.
    # First-seen order is kept (instead of going through a set) so that the
    # vertex order does not depend on string hashing and runs are reproducible.
    all_names = list(
        dict.fromkeys(name for data in adatas for name in data.obs_names))
    # create the shared graph
    union_g = gt.Graph(directed=False)
    union_g.add_vertex(len(all_names))
    u_names = union_g.new_vertex_property('string')
    for xn, cell_name in enumerate(all_names):
        u_names[xn] = cell_name
    union_g.vp['cell'] = u_names

    # map every cell name to its vertex index in the unified graph
    u_cell_index = {cell_name: xn for xn, cell_name in enumerate(all_names)}

    # now create layers: every edge of every modality is copied into the
    # union graph and tagged with the (1-based) index of its modality
    layer = union_g.new_edge_property('int')
    for ng, graph in enumerate(graph_list):
        cell_of = graph.vp['cell']
        for e in graph.edges():
            Sidx = u_cell_index[cell_of[e.source()]]
            Tidx = u_cell_index[cell_of[e.target()]]
            ne = union_g.add_edge(Sidx, Tidx)
            layer[ne] = ng + 1  # this is the layer label

    union_g.ep['layer'] = layer
    # DONE! now proceed with standard minimization, ish

    states = [
        gt.NestedBlockState(g=union_g,
                            base_type=gt.LayeredBlockState,
                            state_args=dict(deg_corr=deg_corr,
                                            ec=union_g.ep.layer,
                                            layers=True))
        for n in range(samples)
    ]

    def fast_min(state, beta, n_sweep, fast_tol, seed=None):
        # sweep greedily until the entropy change drops below the tolerance;
        # a seed of 0 is valid, so compare against None explicitly
        if seed is not None:
            gt.seed_rng(int(seed))
        dS = 1
        while np.abs(dS) > fast_tol:
            dS, _, _ = state.multiflip_mcmc_sweep(beta=beta,
                                                  niter=n_sweep,
                                                  c=0.5)
        return state

    states = Parallel(n_jobs=n_jobs, prefer=dispatch_backend)(
        delayed(fast_min)(states[x], beta, n_sweep, tolerance, seeds[x])
        for x in range(samples))
    logg.info('        minimization step done', time=start)
    pmode = gt.PartitionModeState([x.get_bs() for x in states],
                                  converge=True,
                                  nested=True)
    bs = pmode.get_max_nested()
    logg.info('        consensus step done', time=start)

    if save_model:
        import pickle
        fname = save_model
        if not fname.endswith('pkl'):
            fname = f'{fname}.pkl'
        logg.info(f'Saving model into {fname}')
        with open(fname, 'wb') as fout:
            pickle.dump(pmode, fout, 2)

    # prune redundant levels at the top
    bs = [x for x in bs if len(np.unique(x)) > 1]
    bs.append(np.array([0],
                       dtype=np.int32))  #in case of type changes, check this
    state = gt.NestedBlockState(union_g,
                                bs=bs,
                                base_type=gt.LayeredBlockState,
                                state_args=dict(deg_corr=deg_corr,
                                                ec=union_g.ep.layer,
                                                layers=True))

    logg.info('    done', time=start)
    u_groups = np.unique(bs[0])
    last_group = np.max(u_groups) + 1

    if collect_marginals:
        # note that the size of this will be equal to the number of the groups in Mode
        # but some entries won't sum to 1 as in the collection there may be differently
        # sized partitions
        pv_array = pmode.get_marginal(union_g).get_2d_array(
            range(last_group)).T[:, u_groups] / samples
        # indexed by cell name so each dataset can pick its own rows
        pv_frame = pd.DataFrame(pv_array, index=all_names)

    groups = np.zeros((union_g.num_vertices(), len(bs)), dtype=int)

    for x in range(len(bs)):
        # for each level, project labels to the vertex level
        # so that every cell has a name. Note that at this level
        # the labels are not necessarily consecutive
        groups[:, x] = state.project_partition(x, 0).get_array()

    groups = pd.DataFrame(groups).astype('category')

    # rename categories from 0 to n. The renamed column is assigned back
    # instead of relying on `inplace=True`, which pandas removed in 2.0.
    for c in groups.columns:
        ncat = len(groups[c].cat.categories)
        groups[c] = groups[c].cat.rename_categories(
            [u'%s' % x for x in range(ncat)])

    # recode block names to have consistency with group names
    i_groups = groups.astype(int)
    bs = [i_groups.iloc[:, 0].values]
    for x in range(1, groups.shape[1]):
        # for every group at level x-1, the column holding a non-zero count
        # is the group it belongs to at level x
        bs.append(
            np.where(
                pd.crosstab(i_groups.iloc[:, x - 1], i_groups.iloc[:,
                                                                   x]) > 0)[1])
    # NOTE(review): the state is rebuilt here without `base_type` /
    # `state_args`, unlike the layered states above; the statistics below
    # are computed from this state -- confirm this is intended.
    state = gt.NestedBlockState(union_g, bs)
    del i_groups

    groups.index = all_names

    # add column names
    groups.columns = [f"{key_added}_level_{level}" for level in range(len(bs))]

    # statistics and blocks are the same for every dataset: compute them
    # once and hand an independent copy to each AnnData
    n_levels = len(state.levels)
    level_entropy = np.array(
        [state.level_entropy(x) for x in range(n_levels)])
    modularity = np.array([
        gt.modularity(union_g, state.project_partition(x, 0))
        for x in range(n_levels)
    ])
    blocks = {
        str(nl): np.array(level_state.get_blocks().a)
        for nl, level_state in enumerate(state.get_levels())
    }

    for xn in range(n_data):
        # remove any column with the same key
        drop_columns = groups.columns.intersection(adatas[xn].obs.columns)
        adatas[xn].obs.drop(drop_columns, axis='columns', inplace=True)
        adatas[xn].obs = pd.concat(
            [adatas[xn].obs, groups.loc[adatas[xn].obs_names]], axis=1)

        # now add marginal probabilities.

        if collect_marginals:
            # add marginals for level 0, the sum up according to the hierarchy
            _groups = groups.loc[adatas[xn].obs_names]
            _pv_array = pv_frame.loc[adatas[xn].obs_names].values
            adatas[xn].obsm[f"CM_{key_added}_level_0"] = _pv_array
            for group in groups.columns[1:]:
                ct = pd.crosstab(_groups[_groups.columns[0]],
                                 _groups[group],
                                 normalize='index',
                                 dropna=False)
                adatas[xn].obsm[f'CM_{group}'] = _pv_array @ ct.values

        # add some unstructured info
        if 'schist' not in adatas[xn].uns:
            adatas[xn].uns['schist'] = {}

        adatas[xn].uns['schist'][f'{key_added}'] = {}
        adatas[xn].uns['schist'][f'{key_added}']['stats'] = dict(
            level_entropy=level_entropy.copy(),
            modularity=modularity.copy())

        adatas[xn].uns['schist'][f'{key_added}']['blocks'] = {
            nl: block.copy() for nl, block in blocks.items()
        }

        # last step is recording some parameters used in this analysis
        adatas[xn].uns['schist'][f'{key_added}']['params'] = dict(
            model='multiome_nested',
            use_weights=use_weights,
            neighbors_key=neighbors_key[xn],
            key_added=key_added,
            samples=samples,
            collect_marginals=collect_marginals,
            random_seed=random_seed,
            deg_corr=deg_corr,
        )

    logg.info(
        '    finished',
        time=start,
        deep=(
            f'and added\n'
            f'    {key_added!r}, the cluster labels (adata.obs, categorical)'),
    )
    return adatas if copy else None