Example #1
    def fit(self, X, y=None):
        n_samples = X.shape[0]

        if n_samples > self.min_split_samples:
            cluster = GaussianCluster(min_components=1,
                                      max_components=2,
                                      n_init=20)
            cluster.fit(X)
            self.model_ = cluster
        else:
            # too few samples to split; stop recursing and make this a leaf
            self.pred_labels_ = np.zeros(X.shape[0])
            self.left_ = None
            self.right_ = None
            self.model_ = None
            return self

        # recurse only if model selection chose two components
        if cluster.n_components_ != 1:
            pred_labels = cluster.predict(X)
            self.pred_labels_ = pred_labels
            indicator = pred_labels == 0
            self.X_left_ = X[indicator, :]
            self.X_right_ = X[~indicator, :]
            split_left = PartitionCluster()
            self.left_ = split_left.fit(self.X_left_)

            split_right = PartitionCluster()
            self.right_ = split_right.fit(self.X_right_)
        else:
            # a single component fit best; this node is a leaf
            self.pred_labels_ = np.zeros(X.shape[0])
            self.left_ = None
            self.right_ = None
            self.model_ = None
        return self
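
For context, a minimal standalone sketch (not from the original) of the split test this fit performs, assuming graspy's GaussianCluster import path; the two-blob data is illustrative:

import numpy as np
from graspy.cluster import GaussianCluster

rng = np.random.default_rng(0)
X = np.vstack((rng.normal(-2, 1, (100, 3)), rng.normal(2, 1, (100, 3))))

cluster = GaussianCluster(min_components=1, max_components=2, n_init=20)
cluster.fit(X)
if cluster.n_components_ != 1:  # BIC preferred a split, as in the recursion above
    labels = cluster.predict(X)
    X_left, X_right = X[labels == 0, :], X[labels != 0, :]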
Example #2
def test_mug2vec():
    graphs, labels = generate_data()

    mugs = mug2vec(pass_to_ranks=None)
    xhat = mugs.fit_transform(graphs)

    # model selection should settle on 2 components, per the assertion below
    gmm = GaussianCluster(5)
    gmm.fit(xhat, labels)

    assert_equal(gmm.n_components_, 2)
Example #3
def estimate_assignments(graph,
                         n_communities,
                         n_components=None,
                         method="gc",
                         metric=None):
    """Given a graph and n_comunities, sweeps over covariance structures
    Not deterministic
    Not using graph bic or mse to calculate best

    1. Does an embedding on the raw graph
    2. GaussianCluster on the embedding. This will sweep covariance structure for the 
       given n_communities
    3. Returns n_parameters based on the number used in GaussianCluster

    method can be "gc" or "bc" 

    method 
    "gc" : use graspy GaussianCluster
        this defaults to full covariance
    "bc" : tommyclust with defaults
        so sweep covariance, agglom, linkage
    "bc-metric" : tommyclust with custom metric
        still sweep everything
    "bc-none" : mostly for testing, should behave just like GaussianCluster

    """
    embed_graph = graph.copy()
    latent = AdjacencySpectralEmbed(
        n_components=n_components).fit_transform(embed_graph)
    if isinstance(latent, tuple):
        latent = np.concatenate(latent, axis=1)
    if method == "gc":
        gc = GaussianCluster(
            min_components=n_communities,
            max_components=n_communities,
            covariance_type="all",
        )
        vertex_assignments = gc.fit_predict(latent)
        n_params = gc.model_._n_parameters()
    elif method == "bc":
        vertex_assignments, n_params = brute_cluster(latent, [n_communities])
    elif method == "bc-metric":
        vertex_assignments, n_params = brute_cluster(latent, [n_communities],
                                                     metric=metric)
    elif method == "bc-none":
        vertex_assignments, n_params = brute_cluster(
            latent,
            [n_communities],
            affinities=["none"],
            linkages=["none"],
            covariance_types=["full"],
        )
    else:
        raise ValueError("Unspecified clustering method")
    return (vertex_assignments, n_params)
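
A hypothetical driver for the function above, assuming graspy's sbm simulator; the two-block sizes and probabilities are illustrative:

from graspy.simulations import sbm

adj = sbm(n=[50, 50], p=[[0.3, 0.05], [0.05, 0.3]])
labels, n_params = estimate_assignments(adj, n_communities=2, method="gc")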
Example #4
def fit_and_score(X_train, X_test, k, **kws):
    gc = GaussianCluster(min_components=k, max_components=k, **kws)
    gc.fit(X_train)
    model = gc.model_
    train_bic = model.bic(X_train)
    train_lik = model.score(X_train)
    test_bic = model.bic(X_test)
    test_lik = model.score(X_test)
    bic = model.bic(np.concatenate((X_train, X_test), axis=0))
    # sklearn's BIC is negated (lower raw BIC is better) so that, like the
    # likelihood scores, higher is better
    res = {
        "train_bic": -train_bic,
        "train_lik": train_lik,
        "test_bic": -test_bic,
        "test_lik": test_lik,
        "bic": -bic,
        "lik": train_lik + test_lik,
        "k": k,
        "model": gc.model_,
    }
    return res, model
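
A sketch (not from the original) of driving fit_and_score over a range of k, assuming X is an (n_samples, n_features) array; the split and the k range are illustrative:

from sklearn.model_selection import train_test_split

X_train, X_test = train_test_split(X, test_size=0.3, random_state=0)
scores = [fit_and_score(X_train, X_test, k)[0] for k in range(2, 9)]
best = max(scores, key=lambda r: r["test_lik"])  # highest held-out likelihood
print(best["k"])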
Example #5
    def fit(self, X, y=None):
        n_samples = X.shape[0]
        self.n_samples_ = n_samples
        if n_samples > self.min_split_samples:
            cluster = GaussianCluster(min_components=1,
                                      max_components=2,
                                      n_init=40)
            cluster.fit(X)
            pred_labels = cluster.predict(X)
            self.pred_labels_ = pred_labels
            self.model_ = cluster
            if cluster.n_components_ != 1:
                indicator = pred_labels == 0
                self.X_children_ = (X[indicator, :], X[~indicator, :])
                children = []
                for i, X_child in enumerate(self.X_children_):
                    child = DivisiveCluster(name=self.name + str(i),
                                            parent=self)
                    child = child.fit(X_child)
                    children.append(child)
                self.children = children
        return self
Example #6
    def fit(self, X, y=None):
        n_samples = X.shape[0]
        self.n_samples_ = n_samples
        self.cum_dist_ = 0
        if n_samples > self.min_split_samples:
            if self.cluster_method == "graspy-gmm":
                cluster = GaussianCluster(
                    min_components=1,
                    max_components=2,
                    n_init=self.n_init,
                    covariance_type="all",
                )
            elif self.cluster_method == "auto-gmm":
                cluster = AutoGMMCluster(
                    min_components=1, max_components=2, max_agglom_size=None
                )
            elif self.cluster_method == "vmm":
                # cluster = VonMisesFisherMixture(n)
                raise NotImplementedError("`vmm` clustering is not implemented")
            else:
                raise ValueError(
                    "`cluster_method` must be one of "
                    "['graspy-gmm', 'auto-gmm', 'vmm']"
                )
            cluster.fit(X)
            pred_labels = cluster.predict(X)
            self.pred_labels_ = pred_labels
            self.model_ = cluster
            if hasattr(cluster, "bic_"):
                bics = cluster.bic_
                self.bics_ = bics
                # ratio of the best 2-component BIC to the best 1-component BIC
                bic_ratio = bics.loc[2].min() / bics.loc[1].min()
                self.bic_ratio_ = bic_ratio
            if cluster.n_components_ != 1:  # recurse
                indicator = pred_labels == 0
                self.X_children_ = (X[indicator, :], X[~indicator, :])
                children = []
                for i, X_child in enumerate(self.X_children_):
                    child = DivisiveCluster(
                        name=self.name + str(i),
                        parent=self,
                        min_split_samples=self.min_split_samples,
                        n_init=self.n_init,
                        cluster_method=self.cluster_method,
                    )
                    child = child.fit(X_child)
                    children.append(child)
                self.children = children
        return self
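
A usage sketch, assuming DivisiveCluster subclasses anytree's NodeMixin (the name/parent/children attributes above suggest it) and accepts the constructor arguments shown in the recursion; the root name and data are illustrative:

import numpy as np
from anytree import LevelOrderIter

rng = np.random.default_rng(0)
X = np.vstack((rng.normal(-2, 1, (200, 4)), rng.normal(2, 1, (200, 4))))

root = DivisiveCluster(name="0", min_split_samples=32, n_init=40,
                       cluster_method="graspy-gmm")
root = root.fit(X)
for node in LevelOrderIter(root):  # walk the fitted cluster tree
    print(node.name, node.n_samples_)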
Example #7
n = 100
d = 3

np.random.seed(3)

X1 = np.random.normal(0.5, 0.5, size=(n, d))
X2 = np.random.normal(-0.5, 0.5, size=(n, d))
X3 = np.random.normal(0.8, 0.6, size=(n, d))
X4 = np.random.uniform(0.2, 0.3, size=(n, d))
X = np.vstack((X1, X2, X3, X4))
pairplot(X)

np.random.seed(3)

gclust = GaussianCluster(min_components=2, max_components=2, n_init=1, max_iter=100)
gclust.fit(X)

bic1 = gclust.bic_

np.random.seed(3)

gclust = GaussianCluster(min_components=2, max_components=2, n_init=50, max_iter=100)
gclust.fit(X)

bic2 = gclust.bic_

# bic2 should be at least as low as bic1: more restarts can only improve the fit
print(bic1)
print(bic2)
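
For intuition, a small check (not from the original) that restarts can only help: GaussianCluster's model_ is an sklearn GaussianMixture, which keeps the best of n_init EM runs, so with a shared seed the 50-restart fit scores at least as well on the X above as the single-restart fit:

from sklearn.mixture import GaussianMixture

lik1 = GaussianMixture(n_components=2, n_init=1, random_state=3).fit(X).score(X)
lik50 = GaussianMixture(n_components=2, n_init=50, random_state=3).fit(X).score(X)
assert lik50 >= lik1  # the 50-restart search includes the single-restart start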
Example #8
    #- BIC
    bic_ = 2 * likeli - temp_n_params * np.log(n)

    #- ARI
    ari_ = ari(true_labels, temp_c_hat)

    return [combo, likeli, ari_, bic_]


np.random.seed(16661)
A = binarize(right_adj)
X_hat = np.concatenate(ASE(n_components=3).fit_transform(A), axis=1)
n, d = X_hat.shape

gclust = GCLUST(max_components=15)
est_labels = gclust.fit_predict(X_hat)

loglikelihoods = [np.sum(gclust.model_.score_samples(X_hat))]
combos = [None]
aris = [ari(right_labels, est_labels)]
bic = [gclust.model_.bic(X_hat)]

unique_labels = np.unique(est_labels)

class_idx = np.array([np.where(est_labels == u)[0] for u in unique_labels],
                     dtype=object)  # ragged per-label index arrays

for k in range(len(unique_labels)):
    for combo in list(combinations(np.unique(est_labels), k + 1)):
        combo = np.array(list(combo)).astype(int)
        combos.append(combo)
Example #9
    if normalize:
        sums = X.sum(axis=1)
        sums[sums == 0] = 1
        X = X / sums[:, None]

    if log_cluster:
        X = np.log10(X + 1)

    agmm = AutoGMMCluster(**cluster_kws)
    pred_labels = agmm.fit_predict(X)
    results = agmm.results_
    fg_col_meta[i]["pred_labels"] = pred_labels
    fg_autoclusters.append(agmm)

    ggmm = GaussianCluster(min_components=10,
                           max_components=40,
                           n_init=20,
                           covariance_type="diag")
    ggmm.fit(X)
    fg_graspyclusters.append(ggmm)

    gbics = ggmm.bic_

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    scatter_kws = {"s": 30}
    best_results = results.groupby("n_components").min()
    best_results = best_results.reset_index()
    sns.scatterplot(
        data=best_results,
        y="bic/aic",
        x="n_components",
        label="AutoGMM",
Example #10
max_components = 35
comps = list(range(min_components, max_components))
cluster_kws = dict(min_components=min_components,
                   max_components=max_components,
                   covariance_type="full")
n_sims = 25
n_components_range = list(range(3, 25))

bic_results = []
for i, n_components in tqdm(enumerate(n_components_range)):
    ase = AdjacencySpectralEmbed(n_components=n_components)
    latent = ase.fit_transform(embed_graph)
    latent = np.concatenate(latent, axis=-1)

    for _ in range(n_sims):
        cluster = GaussianCluster(**cluster_kws)
        cluster.fit(latent)
        bics = cluster.bic_
        bics["n_clusters"] = bics.index
        bics["n_components"] = n_components
        bic_results.append(bics)

#%%
result_df = pd.concat(bic_results, axis=0)
result_df.rename(columns={"full": "bic"}, inplace=True)

plt.figure(figsize=(15, 10))
sns.lineplot(data=result_df, x="n_clusters", y="bic", hue="n_components")

save("clustering_june_bics")
# #%%
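
One way (not in the original) to summarize the result_df built above: average BIC over the 25 repeats and take the minimizer, since sklearn-style BIC is lower-is-better:

summary = result_df.groupby(["n_components", "n_clusters"])["bic"].mean()
best_n_components, best_n_clusters = summary.idxmin()
print(best_n_components, best_n_clusters)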
Example #11
embed = "LSE"
cluster = "GMM"

lse_latent = lse(adj, 4, regularizer=None)

latent = lse_latent
pairplot(latent, labels=simple_class_labels, title=embed)

for k in range(MIN_CLUSTERS, MAX_CLUSTERS + 1):
    run_name = f"k = {k}, {cluster}, {embed}, right hemisphere (A to D), PTR, raw"
    print(run_name)
    print()

    # Cluster
    gmm = GaussianCluster(min_components=k, max_components=k, **gmm_params)
    gmm.fit(latent)
    pred_labels = gmm.predict(latent)

    # ARI
    base_dict = {
        "K": k,
        "Cluster": cluster,
        "Embed": embed,
        "Method": f"{cluster} o {embed}",
        "Score": gmm.model_.score(latent),
    }
    mb_ari = sub_ari(known_inds, mb_labels, pred_labels)
    mb_ari_dict = base_dict.copy()
    mb_ari_dict["ARI"] = mb_ari
    mb_ari_dict["Metric"] = "MB ARI"
Example #12
n = 1000
pi = 0.9

A, counts = generate_cyclops(X, n, pi, None)
c = [0] * counts[0]
c += [1] * counts[1]

ase = ASE(n_components=3)
X_hat = ase.fit_transform(A)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X_hat[:, 0], X_hat[:, 1], X_hat[:, 2], c=c)

gclust = GCLUST(max_components=4)
c_hat = gclust.fit_predict(X_hat)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X_hat[:, 0], X_hat[:, 1], X_hat[:, 2], c=c_hat)


def quadratic(data, params):
    if data.ndim == 1:
        sum_ = np.sum(data[:-1]**2 * params[:-1]) + params[-1]
        return sum_
    elif data.ndim == 2:
        sums = np.sum(data[:, :-1]**2 * params[:-1], axis=1) + params[-1]
        return sums
    else:
        raise ValueError("`data` must be 1- or 2-dimensional")
Example #13
def lse(adj, n_components, regularizer=None):
    if PTR:
        adj = pass_to_ranks(adj)
    lap = to_laplace(adj, form="R-DAD", regularizer=regularizer)
    ase = AdjacencySpectralEmbed(n_components=n_components)
    latent = ase.fit_transform(lap)
    latent = np.concatenate(latent, axis=-1)
    return latent


n_components = None
k = 30

latent = lse(adj, n_components, regularizer=None)

gmm = GaussianCluster(min_components=k, max_components=k)

pred_labels = gmm.fit_predict(latent)

stacked_barplot(pred_labels, class_labels, palette="tab20")

# %% [markdown]
# # verify on sklearn toy dataset
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

X, y = make_blobs(n_samples=200, n_features=3, centers=None, cluster_std=3)
# y = y.astype(int).astype(str)
data_df = pd.DataFrame(
    data=np.concatenate((X, y[:, np.newaxis]), axis=-1),
    columns=("Dim 0", "Dim 1", "Dim 2", "Labels"),
)
Example #14
        side=side,
        n_components=agmm_model.n_components,
        covariance_type=agmm_model.covariance_type,
        model=agmm_model,
        embed=embed,
        labels=labels,
    )
    rows.append(row)

    # cluster using GaussianCluster
    method = "GClust"
    gclust = GaussianCluster(
        min_components=2,
        max_components=10,
        n_init=200,
        covariance_type="full",
        max_iter=200,
        tol=0.001,
        init_params="kmeans",
    )
    gclust.fit(embed)
    gclust_model = gclust.model_
    gclust_results = gclust.bic_.copy()
    gclust_results.index.name = "n_components"
    gclust_results.reset_index(inplace=True)
    gclust_results = pd.melt(
        gclust_results,
        id_vars="n_components",
        var_name="covariance_type",
        value_name="bic/aic",
    )
Example #15
def cluster_func(k, seed):
    np.random.seed(seed)
    run_name = f"k = {k}, {cluster}, {embed}, right hemisphere (A to D), PTR, raw"
    print(run_name)
    print()

    # Cluster
    gmm = GaussianCluster(min_components=k, max_components=k, **gmm_params)
    gmm.fit(latent)
    pred_labels = gmm.predict(latent)

    # ARI
    base_dict = {
        "K": k,
        "Cluster": cluster,
        "Embed": embed,
        "Method": f"{cluster} o {embed}",
        "Score": gmm.model_.score(latent),
    }
    mb_ari = sub_ari(known_inds, mb_labels, pred_labels)
    mb_ari_dict = base_dict.copy()
    mb_ari_dict["ARI"] = mb_ari
    mb_ari_dict["Metric"] = "MB ARI"
    out_dicts.append(mb_ari_dict)

    simple_ari = sub_ari(known_inds, simple_class_labels, pred_labels)
    simple_ari_dict = base_dict.copy()
    simple_ari_dict["ARI"] = simple_ari
    simple_ari_dict["Metric"] = "Simple ARI"
    out_dicts.append(simple_ari_dict)

    full_ari = adjusted_rand_score(class_labels, pred_labels)
    full_ari_dict = base_dict.copy()
    full_ari_dict["ARI"] = full_ari
    full_ari_dict["Metric"] = "Full ARI"
    out_dicts.append(full_ari_dict)

    save_name = f"k{k}-{cluster}-{embed}-right-ad-PTR-raw"

    # Plot embedding
    pairplot(latent, labels=pred_labels, title=run_name)
    # stashfig("latent-" + save_name)

    # Plot everything else
    prob_df = get_sbm_prob(adj, pred_labels)
    block_sum_df = get_block_edgesums(adj, pred_labels, prob_df.columns.values)

    clustergram(adj, latent, prob_df, block_sum_df, simple_class_labels,
                pred_labels)
    plt.suptitle(run_name, fontsize=40)
    stashfig("clustergram-" + save_name)

    # output skeletons
    _, colormap, pal = stashskel(save_name,
                                 skeleton_labels,
                                 pred_labels,
                                 palette="viridis",
                                 multiout=True)

    sns.set_context("talk")
    palplot(k, cmap="viridis")

    stashfig("palplot-" + save_name)

    # save dict colormapping
    filename = (Path("./maggot_models/notebooks/outs") / Path(FNAME) /
                str("colormap-" + save_name + ".json"))
    with open(filename, "w") as fout:
        json.dump(colormap, fout)

    stashskel(save_name,
              skeleton_labels,
              pred_labels,
              palette="viridis",
              multiout=False)
Example #16
def crossval_cluster(
    embed,
    left_inds,
    right_inds,
    R,
    min_clusters=2,
    max_clusters=15,
    n_init=25,
    left_pair_inds=None,
    right_pair_inds=None,
):
    left_embed = embed[left_inds]
    right_embed = embed[right_inds]
    print("Running left/right clustering with cross-validation\n")
    currtime = time.time()
    rows = []
    for k in tqdm(range(min_clusters, max_clusters)):
        # train left, test right
        # TODO add option for AutoGMM as well, might as well check
        left_gc = GaussianCluster(min_components=k,
                                  max_components=k,
                                  n_init=n_init)
        left_gc.fit(left_embed)
        model = left_gc.model_
        train_left_bic = model.bic(left_embed)
        train_left_lik = model.score(left_embed)
        test_left_bic = model.bic(right_embed @ R.T)
        test_left_lik = model.score(right_embed @ R.T)

        # train right, test left
        right_gc = GaussianCluster(min_components=k,
                                   max_components=k,
                                   n_init=n_init)
        right_gc.fit(right_embed)
        model = right_gc.model_
        train_right_bic = model.bic(right_embed)
        train_right_lik = model.score(right_embed)
        test_right_bic = model.bic(left_embed @ R)
        test_right_lik = model.score(left_embed @ R)

        left_row = {
            "k": k,
            "contra_bic": -test_left_bic,
            "contra_lik": test_left_lik,
            "ipsi_bic": -train_left_bic,
            "ipsi_lik": train_left_lik,
            "cluster": left_gc,
            "train": "left",
            "n_components": n_components,
        }
        right_row = {
            "k": k,
            "contra_bic": -test_right_bic,
            "contra_lik": test_right_lik,
            "ipsi_bic": -train_right_bic,
            "ipsi_lik": train_right_lik,
            "cluster": right_gc,
            "train": "right",
            "n_components": n_components,
        }

        # pairedness computation, if available
        if left_pair_inds is not None and right_pair_inds is not None:
            # TODO double check this is right
            pred_left = left_gc.predict(embed[left_pair_inds])
            pred_right = right_gc.predict(embed[right_pair_inds])
            pness, _, _ = compute_pairedness_bipartite(pred_left, pred_right)
            left_row["pairedness"] = pness
            right_row["pairedness"] = pness

            ari = adjusted_rand_score(pred_left, pred_right)
            left_row["ARI"] = ari
            right_row["ARI"] = ari

        rows.append(left_row)
        rows.append(right_row)

    results = pd.DataFrame(rows)
    print(f"{time.time() - currtime} elapsed")
    return results
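
A hypothetical call, borrowing variable names from the next example (embed, the left/right index arrays, paired indices lp_inds/rp_inds, and the Procrustes map R); the cluster range is illustrative:

results = crossval_cluster(embed, left_inds, right_inds, R,
                           min_clusters=2, max_clusters=12, n_init=25,
                           left_pair_inds=lp_inds, right_pair_inds=rp_inds)
best_k = results.groupby("k")["contra_lik"].mean().idxmax()  # held-out likelihood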
Example #17
Rs = []
n_components = embed[0].shape[1]
print(n_components)
print()
train_embed = np.concatenate(
    (embed[0][:, :n_components], embed[1][:, :n_components]), axis=-1)
R, _ = orthogonal_procrustes(train_embed[lp_inds], train_embed[rp_inds])
Rs.append(R)
left_embed = train_embed[left_inds]
left_embed = left_embed @ R
right_embed = train_embed[right_inds]

for k in tqdm(range(2, 15)):
    # train left, test right
    left_gc = GaussianCluster(min_components=k,
                              max_components=k,
                              n_init=n_init)
    left_gc.fit(left_embed)
    model = left_gc.model_
    train_left_bic = model.bic(left_embed)
    train_left_lik = model.score(left_embed)
    test_left_bic = model.bic(right_embed)
    test_left_lik = model.score(right_embed)

    row = {
        "k": k,
        "contra_bic": test_left_bic,
        "contra_lik": test_left_lik,
        "ipsi_bic": train_left_bic,
        "ipsi_lik": train_left_lik,
        "cluster": left_gc,
Example #18
def brute_graspy_cluster(Ns,
                         x,
                         covariance_types,
                         ks,
                         c_true,
                         savefigs=None,
                         graphList=None):
    if graphList is not None and 'all_bics' in graphList:
        _, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2,
                                                   2,
                                                   sharey='row',
                                                   sharex='col',
                                                   figsize=(10, 10))
    titles = ['full', 'tied', 'diag', 'spherical']
    best_bic = -np.inf
    for N in Ns:
        bics = np.zeros([len(ks), len(covariance_types), N])
        aris = np.zeros([len(ks), len(covariance_types), N])
        for i in np.arange(N):
            graspy_gmm = GaussianCluster(min_components=ks[0],
                                         max_components=ks[len(ks) - 1],
                                         covariance_type=covariance_types,
                                         random_state=i)
            c_hat, ari = graspy_gmm.fit_predict(x, y=c_true)
            bic_values = -graspy_gmm.bic_.values
            ari_values = graspy_gmm.ari_.values
            bics[:, :, i] = bic_values
            aris[:, :, i] = ari_values
            bic = bic_values.max()

            if bic > best_bic:
                idx = np.argmax(bic_values)
                idxs = np.unravel_index(idx, bic_values.shape)
                best_ari_bic = ari
                best_bic = bic
                best_k_bic = ks[idxs[0]]
                best_cov_bic = titles[3 - idxs[1]]
                best_c_hat_bic = c_hat

        max_bics = np.amax(bics, axis=2)
        title = 'N=' + str(N)
        if graphList is not None and 'all_bics' in graphList:
            ax0.plot(np.arange(1, len(ks) + 1), max_bics[:, 3])
            ax1.plot(np.arange(1, len(ks) + 1), max_bics[:, 2], label=title)
            ax2.plot(np.arange(1, len(ks) + 1), max_bics[:, 1])
            ax3.plot(np.arange(1, len(ks) + 1), max_bics[:, 0])

    if graphList is not None and 'best_bic' in graphList:
        # Plot with best BIC
        if c_true is None:
            best_ari_bic_str = 'NA'
        else:
            best_ari_bic_str = '%1.3f' % best_ari_bic

        fig_bestbic = plt.figure(figsize=(8, 8))
        ax_bestbic = fig_bestbic.add_subplot(1, 1, 1)
        #ptcolors = [colors[i] for i in best_c_hat_bic]
        ax_bestbic.scatter(x[:, 0], x[:, 1], c=best_c_hat_bic)
        #mncolors = [colors[i] for i in np.arange(best_k_bic)]
        mncolors = [i for i in np.arange(best_k_bic)]
        ax_bestbic.set_title(
            "py(agg-gmm) BIC %3.0f from " % best_bic + str(best_cov_bic) +
            " k=" + str(best_k_bic) + ' ari=' +
            best_ari_bic_str)  # + "iter=" + str(best_iter_bic))
        ax_bestbic.set_xlabel("First feature")
        ax_bestbic.set_ylabel("Second feature")
        if savefigs is not None:
            plt.savefig(savefigs + '_python_bestbic.jpg')

    if graphList is not None and 'all_bics' in graphList:
        # Plot of all BICs
        titles = ['full', 'tied', 'diag', 'spherical']
        #ax0.set_title(titles[0],fontsize=20,fontweight='bold')
        #ax0.set_ylabel('BIC',fontsize=20)
        ax0.locator_params(axis='y', tight=True, nbins=4)
        ax0.set_yticklabels(ax0.get_yticks(), fontsize=14)

        #ax1.set_title(titles[1],fontsize=20,fontweight='bold')
        legend = ax1.legend(loc='best', title='Number of\nRuns', fontsize=12)
        plt.setp(legend.get_title(), fontsize=14)

        #ax2.set_title(titles[2],fontsize=20,fontweight='bold')
        #ax2.set_xlabel('Number of components',fontsize=20)
        ax2.set_xticks(np.arange(0, 21, 4))
        ax2.set_xticklabels(ax2.get_xticks(), fontsize=14)
        #ax2.set_ylabel('BIC',fontsize=20)
        ax2.locator_params(axis='y', tight=True, nbins=4)
        ax2.set_yticklabels(ax2.get_yticks(), fontsize=14)

        #ax3.set_title(titles[3],fontsize=20,fontweight='bold')
        #ax3.set_xlabel('Number of components',fontsize=20)
        ax3.set_xticks(np.arange(0, 21, 4))
        ax3.set_xticklabels(ax3.get_xticks(), fontsize=14)

        if savefigs is not None:
            plt.savefig('.\\figures\\25_6_19_paperv2\\' + savefigs +
                        '_graspy_bicplot2.jpg')
    plt.show()

    return best_c_hat_bic, best_cov_bic, best_k_bic, best_ari_bic, best_bic
Example #19
concatenate_latent = np.concatenate(list(latent), axis=-1)
concatenate_latent.shape
pairplot(concatenate_latent, labels=unknown)
#%%
g = graphs[0]
classes = [meta["Class"] for node, meta in g.nodes(data=True)]
classes = np.array(classes)
unknown = classes == "Other"
plot_unknown = np.tile(unknown, n_graphs)
pairplot(plot_latent, labels=plot_unknown, alpha=0.3, legend_name="Unknown")


clust_latent = np.concatenate(list(latent), axis=-1)
clust_latent.shape
#%%
gc = GaussianCluster(min_components=2, max_components=15, covariance_type="all")

filterwarnings("ignore")
n_init = 50
sim_mat = np.zeros((n_verts, n_verts))

for i in tqdm(range(n_init)):
    assignments = gc.fit_predict(clust_latent)
    for c in np.unique(assignments):
        inds = np.where(assignments == c)[0]
        sim_mat[np.ix_(inds, inds)] += 1


sim_mat -= np.diag(np.diag(sim_mat))
sim_mat = sim_mat / n_init  # fraction of runs in which each pair co-clustered
heatmap(sim_mat)
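
A possible follow-up (not in the original): cut the co-clustering similarity matrix into consensus groups with scipy's hierarchical clustering; the 0.5 distance threshold is an arbitrary illustrative choice.

from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

dist = 1 - sim_mat          # co-assignment fraction -> dissimilarity
np.fill_diagonal(dist, 0)   # squareform requires a zero diagonal
Z = linkage(squareform(dist, checks=False), method="average")
consensus_labels = fcluster(Z, t=0.5, criterion="distance")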