def signal_flow_marginal(adj, labels, col_wrap=5, palette="tab20"):
    """Plot the distribution of signal flow for each class on a facet grid.

    Signal flow is computed for every node with ``signal_flow(adj)``. One
    facet is drawn per unique value in ``labels``, and facets are ordered by
    decreasing median signal flow of the class.

    Parameters
    ----------
    adj : array-like
        Adjacency matrix; passed unchanged to ``signal_flow``.
    labels : array-like
        Class label for each node, in the same order as the signal flow
        values. Lists and tuples are accepted as well as arrays.
    col_wrap : int, default 5
        Forwarded to ``sns.FacetGrid`` as ``col_wrap``.
    palette : default "tab20"
        Forwarded to ``sns.FacetGrid`` as ``palette``.

    Returns
    -------
    The ``sns.FacetGrid`` holding the plots.
    """
    sf = signal_flow(adj)

    # Compare against an ndarray: with a plain list, ``labels == lab`` is a
    # single ``False`` instead of an elementwise mask, which would make every
    # median NaN and the column order meaningless.
    label_arr = np.asarray(labels)
    uni_labels = np.unique(label_arr)
    medians = [
        np.median(sf[np.where(label_arr == lab)[0]]) for lab in uni_labels
    ]

    # facets go from the highest median signal flow to the lowest
    sort_inds = np.argsort(medians)[::-1]
    col_order = uni_labels[sort_inds]

    plot_df = pd.DataFrame()
    plot_df["Signal flow"] = sf
    plot_df["Class"] = labels

    fg = sns.FacetGrid(
        plot_df,
        col="Class",
        aspect=1.5,
        palette=palette,
        col_order=col_order,
        sharey=False,
        col_wrap=col_wrap,
        xlim=(-3, 3),
    )
    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # kept here so the rendered figures do not change.
    fg = fg.map(sns.distplot, "Signal flow")
    # density heights are not comparable across facets (sharey=False), so
    # the y ticks are hidden
    fg.set(yticks=[], yticklabels=[])
    plt.tight_layout()
    return fg
def to_minigraph(
    adj,
    labels,
    drop_neg=True,
    remove_diag=True,
    size_scaler=1,
    use_counts=False,
    use_weights=True,
    color_map=None,
):
    """Convert an adjacency matrix and a partition into a block-level graph.

    The block table returned by ``get_blockmodel_df`` becomes the weighted
    adjacency of a directed "minigraph" with one node per block. Several node
    attributes are attached for later plotting/export.

    Parameters
    ----------
    adj : array-like
        Adjacency matrix; passed unchanged to ``get_blockmodel_df``.
    labels : array-like
        Block label of each vertex.
    drop_neg : bool, default True
        If True and ``"-1"`` is in the block table's index, that block is
        removed from both rows and columns.
    remove_diag : bool, default True
        If True, the diagonal (within-block) entries of the block table are
        set to zero before the graph is built.
    size_scaler : number, default 1
        Multiplier applied to the per-block vertex counts for ``"Size"``.
    use_counts : bool, default False
        Forwarded to ``get_blockmodel_df`` as ``return_counts``.
    use_weights : bool, default True
        Forwarded to ``get_blockmodel_df`` as ``use_weights``.
    color_map : dict or None
        Mapping node -> color. When None, colors from ``cc.glasbey_light``
        are assigned to the sorted unique labels.

    Returns
    -------
    networkx.DiGraph
        Graph with node attributes ``"Size"``, ``"Signal Flow"``,
        ``"AdjEvec-{i}"``, ``"Spring-x"``, ``"Spring-y"`` and ``"Color"``.
    """
    prob_df = get_blockmodel_df(
        adj, labels, return_counts=use_counts, use_weights=use_weights
    )
    if drop_neg and ("-1" in prob_df.index):
        prob_df = prob_df.drop("-1", axis=0).drop("-1", axis=1)

    if remove_diag:
        # Build a new frame rather than mutating ``prob_df.values`` in place:
        # whether ``.values`` is a writable view depends on dtypes and on the
        # pandas copy-on-write mode, so the in-place trick is not reliable.
        prob_df = prob_df - np.diag(np.diag(prob_df.values))

    g = nx.from_pandas_adjacency(prob_df, create_using=nx.DiGraph())

    all_labels, all_counts = np.unique(labels, return_counts=True)
    # Blocks dropped above are still present in ``labels``; keep only labels
    # that are actual nodes so ``nodelist`` below never names a missing node.
    in_graph = np.array([label in g for label in all_labels], dtype=bool)
    uni_labels = all_labels[in_graph]
    counts = all_counts[in_graph]

    # add size attribute based on number of vertices
    size_map = dict(zip(uni_labels, size_scaler * counts))
    nx.set_node_attributes(g, size_map, name="Size")

    # add signal flow attribute (for the minigraph itself)
    mini_adj = nx.to_numpy_array(g, nodelist=uni_labels)
    node_signal_flow = signal_flow(mini_adj)
    sf_map = dict(zip(uni_labels, node_signal_flow))
    nx.set_node_attributes(g, sf_map, name="Signal Flow")

    # add spectral properties: one attribute per embedding dimension
    sym_adj = symmetrize(mini_adj)
    n_components = 10
    latent = AdjacencySpectralEmbed(n_components=n_components).fit_transform(sym_adj)
    for i in range(n_components):
        latent_dim = latent[:, i]
        lap_map = dict(zip(uni_labels, latent_dim))
        nx.set_node_attributes(g, lap_map, name=f"AdjEvec-{i}")

    # add spring layout properties
    pos = nx.spring_layout(g)
    spring_x = {node: xy[0] for node, xy in pos.items()}
    spring_y = {node: xy[1] for node, xy in pos.items()}
    nx.set_node_attributes(g, spring_x, name="Spring-x")
    nx.set_node_attributes(g, spring_y, name="Spring-y")

    # add colors; the default is keyed on the full label set so the color
    # given to a block does not depend on whether another block was dropped
    if color_map is None:
        color_map = dict(zip(all_labels, cc.glasbey_light))
    nx.set_node_attributes(g, color_map, name="Color")
    return g
def signal_flow_sort(A, return_inds=False):
    """Reorder a square matrix by decreasing signal flow.

    Parameters
    ----------
    A : array-like
        Matrix to sort; it is copied first, and the copy is what gets passed
        to ``signal_flow`` and indexed.
    return_inds : bool, default False
        If True, also return the permutation that was applied.

    Returns
    -------
    The matrix with rows and columns permuted so that nodes appear from the
    highest to the lowest signal flow, and, when ``return_inds`` is True, the
    index array used for that permutation.
    """
    work = A.copy()
    # argsort is ascending, so reverse it to put the largest value first
    descending = np.argsort(signal_flow(work))[::-1]
    reordered = work[np.ix_(descending, descending)]
    return (reordered, descending) if return_inds else reordered
print(f"FAQ took {(end - start)/60.0} minutes") perm_inds = faq.perm_inds_ # perm_inds = unshuffle(shuffle_inds, shuffle_perm_inds) # %% [markdown] # # from graspy.plot import gridplot gridplot([adj[np.ix_(perm_inds, perm_inds)]]) stashfig("unshuffled-real-heatmap-faq") # %% [markdown] # # from src.hierarchy import signal_flow z = signal_flow(adj) sort_inds = np.argsort(z)[::-1] gridplot([adj[np.ix_(sort_inds, sort_inds)]]) stashfig("unshuffled-real-heatmap-sf") # %% [markdown] # # # %% [markdown] # # def shuffle_edges(A): n_verts = A.shape[0] A_fake = A.copy().ravel()
meta=perm_meta, colors="merge_class", palette=CLASS_COLOR_DICT, plot_type="scattermap", sizes=(1, 10), ax=ax, ) # %% [markdown] # ## from src.hierarchy import signal_flow from src.visualization import remove_axis import pandas as pd n_verts = len(adj) sf = signal_flow(adj) sf_perm = np.argsort(-sf) inds = np.arange(n_verts) plot_df = pd.DataFrame() # plot_df["labels"] = labels plot_df["x"] = inds def format_order_ax(ax): ax.set_xticks([]) ax.set_yticks([]) ax.set_ylabel("") ax.set_xlabel("True order") ax.axis("square")
norm_bar_width=False, ax=ax, ) ax.set_yticks([]) ax.get_legend().remove() ax.set_title(title) plt.tight_layout() stashfig(f"count-barplot-lvl{i}" + basename) plt.close() inds = np.concatenate((lp_inds, rp_inds)) new_adj = adj[np.ix_(inds, inds)] new_meta = mc.meta new_meta["sf"] = -signal_flow(new_adj) for l in range(n_levels): fig, ax = plt.subplots(1, 1, figsize=(20, 20)) sort_class = [f"lvl{i}_labels" for i in range(l)] sort_class += [f"lvl{l}_labels_side"] _, _, top, _ = adjplot( new_adj, meta=new_meta, sort_class=sort_class, item_order="merge_class", plot_type="scattermap", class_order="sf", sizes=(0.5, 1), ticks=False, colors="merge_class",
not_pdiff = np.where(~mg["is_pdiff"])[0] mg = mg.reindex(not_pdiff) print(len(mg.meta)) g_sym = nx.to_undirected(mg.g) skeleton_labels = np.array(list(g_sym.nodes())) out_dict = cm.best_partition(g_sym, resolution=r) partition = np.array(itemgetter(*skeleton_labels)(out_dict)) adj = nx.to_numpy_array(g_sym, nodelist=skeleton_labels) part_unique, part_count = np.unique(partition, return_counts=True) for uni, count in zip(part_unique, part_count): if count < 3: inds = np.where(partition == uni)[0] partition[inds] = -1 mg.meta["signal_flow"] = signal_flow(mg.adj) mg.meta["partition"] = partition partition_sf = mg.meta.groupby("partition")["signal_flow"].median() sort_partition_sf = partition_sf.sort_values(ascending=False) basename = f"louvain-res{r}-t{thresh}-{graph_type}-" title = f"Louvain, {graph_type}, res = {r}, thresh = {thresh}" # barplot by merge class label (more detail) class_label_dict = nx.get_node_attributes(g_sym, "Merge Class") class_labels = np.array( itemgetter(*skeleton_labels)(class_label_dict)) part_color_dict = dict(zip(np.unique(partition), cc.glasbey_warm)) true_color_dict = dict(zip(names, colors)) color_dict = {**part_color_dict, **true_color_dict} fig, ax = plt.subplots(1, 1, figsize=(20, 20))
sbm.p_mat_, ax=ax, plot_type="heatmap", sort_class=["hemisphere"] + level_names[:level + 1], item_order=["merge_class_sf_order", "merge_class", "sf"], class_order="sf", meta=meta, palette=CLASS_COLOR_DICT, colors="merge_class", ticks=False, gridline_kws=dict(linewidth=0.05, color="grey", linestyle="--"), cbar_kws=dict(shrink=0.6), ) # TODO show sorted by SF for leaf nodes bhat = sbm.block_p_ block_sf = -signal_flow(bhat) meta["block_sf"] = meta[label_name].map(dict(zip(labels, block_sf))) ax = axs[2, level] _, _, top, _ = adjplot( sbm.p_mat_, ax=ax, plot_type="heatmap", sort_class=label_name, # item_order=["merge_class_sf_order", "merge_class", "sf"], class_order="block_sf", meta=meta, palette=CLASS_COLOR_DICT, colors="merge_class", ticks=False, gridline_kws=dict(linewidth=0.05, color="grey", linestyle="--"), cbar_kws=dict(shrink=0.6),
P1[6, 7] = 0.5 P2 = np.zeros((n_blocks, n_blocks)) P2[8, 9] = 0.5 P2[9, 10] = 0.5 P2[10, 6] = 0.5 P2[6, 7] = 0.5 P3 = np.zeros((n_blocks, n_blocks)) P3[11, 12] = 0.5 P3[12, 6] = 0.5 P3[6, 13] = 0.5 B = 0.25 * P0 + 0.25 * P1 + 0.25 * P2 + 0.25 * P3 B = add_noise(B) sf = -signal_flow(B) perm = np.argsort(sf) B = B[np.ix_(perm, perm)] P0 = P0[np.ix_(perm, perm)] P1 = P1[np.ix_(perm, perm)] P2 = P2[np.ix_(perm, perm)] P3 = P3[np.ix_(perm, perm)] n_row = 2 n_col = 6 scale = 5 fig = plt.figure(figsize=(n_col * scale, n_row * scale)) from matplotlib.gridspec import GridSpec gs = GridSpec(n_row, n_col, figure=fig)
fake_adj = signal_flow_sort(fake_adj) fake_triu_prop = compute_triu_prop(fake_adj) out_dict = { "Proportion": fake_triu_prop, "Graph": name, "Type": "Shuffled" } shuffled_triu_outs.append(out_dict) print( f"{g} shuffled graph sorted proportion in upper triangle: {fake_triu_prop}" ) gridmap(fake_adj, ax=axs[0]) axs[0].set_title("Shuffled edges") z = signal_flow(adj) sort_inds = np.argsort(z)[::-1] adj = adj[np.ix_(sort_inds, sort_inds)] true_triu_prop = compute_triu_prop(adj) out_dict = {"Proportion": true_triu_prop, "Graph": name, "Type": "True"} true_triu_outs.append(out_dict) print(f"{g} graph sorted proportion in upper triangle: {true_triu_prop}") gridmap(adj, ax=axs[1]) axs[1].set_title("True edges") fig.suptitle(f"{name} ({n_verts})", fontsize=40, y=1.02) plt.tight_layout() stashfig(f"{g}-gridplot-sf-sorted")
# this you read as "thing at position i goes to arr[i]" so the rank of thing at i gm_left_perm, gm_left_score = get_best_run(left_perms, left_scores) gm_right_perm, gm_right_score = get_best_run(right_perms, right_scores) double_adj_plot(gm_left_perm, gm_right_perm) stashfig("adj-gm-flow" + basename) gm_left_sort = np.argsort(gm_left_perm) gm_right_sort = np.argsort(gm_right_perm) ax = rank_corr_plot(gm_left_sort, gm_right_sort) ax.set_title("Pair graph match flow") stashfig("rank-gm-flow" + basename) # signal flow left_sf = -signal_flow(left_adj) right_sf = -signal_flow(right_adj) # how to permute the graph to sort in signal flow sf_left_perm = np.argsort(left_sf) sf_right_perm = np.argsort(right_sf) double_adj_plot(sf_left_perm, sf_right_perm) stashfig("adj-signal-flow" + basename) # how things get ranked in terms of the above sf_left_sort = np.argsort(sf_left_perm) sf_right_sort = np.argsort(sf_right_perm) ax = rank_corr_plot(sf_left_sort, sf_right_sort) ax.set_title("Pair signal flow") stashfig("rank-signal-flow" + basename)
#%% # uni_pred_labels, counts = np.unique(pred_labels, return_counts=True) # uni_ints = range(len(uni_pred_labels)) # label_map = dict(zip(uni_pred_labels, uni_ints)) # int_labels = np.array(itemgetter(*uni_pred_labels)(label_map)) # synapse_counts = _calculate_block_counts(adj, uni_ints, pred_labels) block_df = sbm_prob block_adj = sbm_prob.values block_labels = sbm_prob.index.values sym_adj = symmetrize(block_adj) lse_embed = LaplacianSpectralEmbed(form="DAD", n_components=1) latent = lse_embed.fit_transform(sym_adj) latent = np.squeeze(latent) block_signal_flow = signal_flow(block_adj) block_g = nx.from_pandas_adjacency(block_df, create_using=nx.DiGraph()) pos = dict(zip(block_labels, zip(latent, block_signal_flow))) weights = nx.get_edge_attributes(block_g, "weight") node_colors = np.array(itemgetter(*block_labels)(pred_color_dict)) uni_pred_labels, pred_counts = np.unique(pred_labels, return_counts=True) size_map = dict(zip(uni_pred_labels, pred_counts)) node_sizes = np.array(itemgetter(*block_labels)(size_map)) node_sizes *= 4 norm = mpl.colors.Normalize(vmin=0.1, vmax=block_adj.max()) sm = ScalarMappable(cmap="Blues", norm=norm) cmap = sm.to_rgba(np.array(list(weights.values()))) # cmap = mpl.colors.LinearSegmentedColormap("Blues", block_counts.ravel()).to_rgba( # np.array(list(labels.values()))
# adj_dict[g] = temp_adj if remove_missing: temp_mg = temp_mg.make_lcc() mg_dict[g] = temp_mg graph_type_colors = dict( zip(graph_types, sns.color_palette("colorblind", n_colors=len(graph_types)))) # for mg, g in zip(adjs, graph_types): # sf = -signal_flow(adj) # TODO replace with GM flow # meta[f"{g}_flow"] = sf for g, mg in mg_dict.items(): adj = mg.adj meta = mg.meta sf = -signal_flow(adj) meta[f"{g}_flow"] = sf main_meta.loc[meta.index, f"{g}_flow"] = sf line_kws = dict(linewidth=1, linestyle="--", color="grey") rc_dict = { "axes.spines.right": False, "axes.spines.top": False, "axes.formatter.limits": (-3, 3), "figure.figsize": (6, 3), "figure.dpi": 100, "axes.edgecolor": "lightgrey", "ytick.color": "grey", "xtick.color": "grey", "axes.labelcolor": "grey",
def run_experiment(graph_type=None, thresh=None, res=None):
    """Run one Louvain experiment on a graph and stash the summary figures.

    The graph named by ``graph_type`` is loaded, edges are kept only where
    ``max_weight`` exceeds ``thresh``, and the result is passed through
    ``make_lcc`` and ``remove_pdiff``. ``run_louvain`` is then called on the
    undirected graph with ``res`` (presumably the Louvain resolution —
    confirm against ``run_louvain``). Every figure is saved through
    ``stashfig`` under a name starting with
    ``louvain-res{res}-t{thresh}-{graph_type}-``.

    Parameters
    ----------
    graph_type :
        Passed to ``load_metagraph``; also used in file names and titles.
    thresh :
        Lower bound (exclusive) on an edge's ``max_weight``.
    res :
        Passed to ``run_louvain``; also used in file names and titles.

    Returns
    -------
    None
    """
    # load the graph and keep only edges above the weight threshold
    loaded = load_metagraph(graph_type, version=BRAIN_VERSION)
    edges = add_max_weight(loaded.to_edgelist())
    edges = edges[edges["max_weight"] > thresh]
    mg = edgelist_to_mg(edges, loaded.meta).make_lcc().remove_pdiff()

    # cluster the undirected version of the graph
    g_sym = nx.to_undirected(mg.g)
    skeleton_labels = np.array(list(g_sym.nodes()))
    partition = run_louvain(g_sym, res, skeleton_labels)

    # median signal flow per partition, used below to order the blocks
    mg.meta["signal_flow"] = signal_flow(mg.adj)
    mg.meta["partition"] = partition
    sort_partition_sf = (
        mg.meta.groupby("partition")["signal_flow"]
        .median()
        .sort_values(ascending=False)
    )

    # shared prefix for file names and shared figure title
    basename = f"louvain-res{res}-t{thresh}-{graph_type}-"
    title = f"Louvain, {graph_type}, res = {res}, thresh = {thresh}"

    def node_attr_array(attr_name):
        # node attribute values, in the same order as skeleton_labels
        attr_dict = nx.get_node_attributes(g_sym, attr_name)
        return np.array(itemgetter(*skeleton_labels)(attr_dict))

    class_labels = node_attr_array("Merge Class")
    lineage_labels = node_attr_array("lineage")
    lineage_labels = np.vectorize(lambda x: "~" + x)(lineage_labels)
    classlin_labels, color_dict, hatch_dict = augment_classes(
        skeleton_labels, class_labels, lineage_labels
    )

    # (file-name suffix, category labels, keyword arguments), in the order
    # in which the figures must be created and saved
    barplot_specs = [
        (
            "barplot-mergeclasslin-props",
            classlin_labels,
            dict(
                color_dict=color_dict,
                plot_proportions=False,
                norm_bar_width=True,
                figsize=(24, 18),
                title=title,
                hatch_dict=hatch_dict,
            ),
        ),
        (
            "barplot-mergeclass-props",
            class_labels,
            dict(
                color_dict=color_dict,
                plot_proportions=False,
                norm_bar_width=True,
                figsize=(24, 18),
                title=title,
                hatch_dict=None,
            ),
        ),
        (
            "barplot-mergeclass-counts",
            class_labels,
            dict(
                color_dict=color_dict,
                plot_proportions=False,
                norm_bar_width=False,
                figsize=(24, 18),
                title=title,
                hatch_dict=None,
            ),
        ),
        (
            "barplot-lineage-props",
            lineage_labels,
            dict(
                color_dict=None,
                plot_proportions=False,
                norm_bar_width=True,
                figsize=(24, 18),
                title=title,
            ),
        ),
    ]
    for suffix, bar_labels, bar_kws in barplot_specs:
        fig, axs = barplot_text(partition, bar_labels, **bar_kws)
        stashfig(basename + suffix)

    # sorted heatmap
    heatmap(
        mg.adj,
        transform="simple-nonzero",
        figsize=(20, 20),
        inner_hier_labels=partition,
        hier_label_fontsize=10,
        title=title,
        title_pad=80,
    )
    stashfig(basename + "heatmap")

    # block probability matrix, rows and columns ordered by median signal flow
    counts = False
    weights = False
    prob_df = get_blockmodel_df(
        mg.adj, partition, return_counts=counts, use_weights=weights
    )
    block_order = sort_partition_sf.index
    prob_df = prob_df.reindex(block_order, axis=0).reindex(block_order, axis=1)
    ax = probplot(
        100 * prob_df,
        fmt="2.0f",
        figsize=(20, 20),
        title=f"Louvain, res = {res}, counts = {counts}, weights = {weights}",
    )
    ax.set_ylabel(r"Median signal flow $\to$", fontsize=28)
    stashfig(basename + f"probplot-counts{counts}-weights{weights}")

    # plot minigraph with layout
    adjusted_partition = adjust_partition(partition, class_labels)
    minigraph = to_minigraph(
        mg.adj, adjusted_partition, use_counts=True, size_scaler=10
    )
    draw_networkx_nice(
        minigraph,
        "Spring-x",
        "Signal Flow",
        sizes="Size",
        colors="Color",
        cmap="Greys",
        vmin=100,
        weight_scale=0.001,
    )
    stashfig(basename + "sbm-drawn-network")
remove_pdiff=True, binarize=False, weight=weight, ) print( f"Preprocessed graph {graph_type} with threshold={threshold}, weight={weight}" ) graphs.append(mg) # %% [markdown] # ## signal flow sort and plot sns.set_context("talk", font_scale=1.25) graph_sfs = [] for mg, graph_type in zip(graphs, graph_types): meta = mg.meta sf = signal_flow(mg.adj) meta["signal_flow"] = -sf graph_sfs.append(sf) fig, ax = plt.subplots(1, 1, figsize=(20, 20)) matrixplot( mg.adj, ax=ax, col_meta=meta, row_meta=meta, col_item_order="signal_flow", row_item_order="signal_flow", col_colors="Merge Class", row_colors="Merge Class", col_palette=CLASS_COLOR_DICT, row_palette=CLASS_COLOR_DICT,
result_df = pd.DataFrame(out_dicts) fg = sns.FacetGrid(result_df, col="Metric", col_wrap=3, sharey=False, height=4) fg.map(sns.lineplot, "K", "Score") stashfig(f"metrics-{cluster}-{embed}-right-ad-PTR-raw") # Modifications i need to make to the above # - Increase the height of the sankey diagram overall # - Look into color maps that could be better # - Color the cluster labels by what gets written to the JSON # - Plot the clusters as nodes in a small network # %% [markdown] # # try graph flow node_signal_flow = signal_flow(adj) mean_sf = np.zeros(k) for i in np.unique(pred_labels): inds = np.where(pred_labels == i)[0] mean_sf[i] = np.mean(node_signal_flow[inds]) cluster_mean_latent = gmm.model_.means_[:, 0] block_probs = SBMEstimator().fit(bin_adj, y=pred_labels).block_p_ block_prob_df = pd.DataFrame(data=block_probs, index=range(k), columns=range(k)) block_g = nx.from_pandas_adjacency(block_prob_df, create_using=nx.DiGraph) plt.figure(figsize=(10, 10)) # don't ever let em tell you you're too pythonic pos = dict(zip(range(k), zip(cluster_mean_latent, mean_sf))) # nx.draw_networkx_nodes(block_g, pos=pos)
community_sizes[1::2] = n_feedback community_sizes = n_blocks * [n_feedforward] labels = n_to_labels(community_sizes) A = sbm(community_sizes, block_probs, directed=True, loops=False) n_verts = A.shape[0] perm_inds = np.random.permutation(n_verts) A_perm = A[np.ix_(perm_inds, perm_inds)] heatmap(A, cbar=False, title="Feedforward SBM") stashfig("ffSBM") heatmap(A_perm, cbar=False, title="Feedforward SBM, shuffled") stashfig("ffSBM-shuffle") true_z = signal_flow(A) sort_inds = np.argsort(true_z)[::-1] heatmap( A[np.ix_(sort_inds, sort_inds)], cbar=False, title=r"Feedforward SBM, sorted by $A$ signal flow", ) stashfig("ffSBM-adj-sf") A_fake = A.copy().ravel() np.random.shuffle(A_fake) A_fake = A_fake.reshape((n_verts, n_verts)) fake_z = signal_flow(A_fake) sort_inds = np.argsort(fake_z)[::-1] heatmap( A_fake[np.ix_(sort_inds, sort_inds)],
# %% [markdown] # ## sbm = DCSBMEstimator(directed=True, degree_directed=True, loops=False, max_comm=30) sbm.fit(binarize(adj)) pred_labels = sbm.vertex_assignments_ print(len(np.unique(pred_labels))) meta["pred_labels"] = pred_labels graph = np.squeeze(sbm.sample()) meta["adj_sf"] = -signal_flow(binarize(adj)) block_sf = -signal_flow(sbm.block_p_) block_map = pd.Series(data=block_sf) meta["block_sf"] = meta["pred_labels"].map(block_map) #%% graph_type = "G" fig, axs = plt.subplots(1, 2, figsize=(20, 10)) ax = axs[0] ax, _, tax, _ = matrixplot( binarize(adj), ax=ax, plot_type="scattermap", sizes=(0.25, 0.5), col_colors="merge_class",
ffw_labels[labels % 2 == 1] = "-rec" full_labels = np.core.defchararray.add(labels.astype(str), ffw_labels) A = sbm(community_sizes, block_probs, directed=True, loops=False) n_verts = A.shape[0] perm_inds = np.random.permutation(n_verts) A_perm = A[np.ix_(perm_inds, perm_inds)] heatmap(A, cbar=False, title="Feedforward SBM w/ block recurrence") stashfig("ffSBM") heatmap(A_perm, cbar=False, title="Feedforward SBM w/ block recurrence, shuffled") stashfig("ffSBM-shuffle") true_z = signal_flow(A) sort_inds = np.argsort(true_z)[::-1] heatmap( A[np.ix_(sort_inds, sort_inds)], cbar=False, title="Feedforward SBM w/ block recurrence, sorted by signal flow", ) stashfig("ffSBM-sf") A_fake = A.copy().ravel() np.random.shuffle(A_fake) A_fake = A_fake.reshape((n_verts, n_verts)) fake_z = signal_flow(A_fake) sort_inds = np.argsort(fake_z)[::-1] heatmap( A_fake[np.ix_(sort_inds, sort_inds)],
order = invert_permutation(R["leaves"]) path_meta = pd.DataFrame() path_meta["cluster"] = pred path_meta["dend_order"] = order Z = linkage(squareform(pdist), method="average", optimal_ordering=False) R = dendrogram(Z, truncate_mode=None, get_leaves=True, no_plot=True, color_threshold=-np.inf) order = invert_permutation(R["leaves"]) meta["dend_order"] = order meta["signal_flow"] = -signal_flow(adj) meta["class2"].fillna(" ", inplace=True) # %% [markdown] # ## fig, axs = plt.subplots(1, 2, figsize=(30, 20), gridspec_kw=dict(width_ratios=[0.95, 0.02], wspace=0.02)) ax = axs[0] matrixplot( path_indicator_mat, ax=ax, plot_type="scattermap", col_sort_class=["class1", "class2"], col_class_order="signal_flow",
# ## def compute_mean_visit(hop_hist): n_visits = np.sum(hop_hist, axis=0) weight_sum_visits = (np.arange(1, max_hops + 1)[:, None] * hop_hist).sum(axis=0) mean_visit = weight_sum_visits / n_visits return mean_visit col_df["rw_mean_visit"] = compute_mean_visit(rw_hop_hist) col_df["casc_mean_visit"] = compute_mean_visit(casc_hop_hist) col_df["back_mean_visit"] = compute_mean_visit(back_hop_hist) col_df["diff"] = col_df["casc_mean_visit"] - col_df["back_mean_visit"] col_df["signal_flow"] = -signal_flow(adj) # %% [markdown] # ## sns.set_context("talk", font_scale=1.5) pad = 30 fig, axs = plt.subplots(2, 2, figsize=(20, 20)) ax = axs[0, 0] method = "Signal flow" matrixplot( adj, ax=ax, col_meta=col_df, col_colors="label", col_item_order="signal_flow", row_meta=col_df,
for k in sub_ks[i]: node.plot_model(k) # %% [markdown] # ## sub_k = [2, 0, 2, 0, 2, 2, 2, 0, 2, 2, 2, 0] for i, node in enumerate(get_lowest_level(mc)): print(node.name) print() node.select_model(sub_k[i]) # %% [markdown] # ## meta = mc.meta.copy() meta["rand"] = np.random.uniform(size=len(meta)) sf = signal_flow(adj) meta["signal_flow"] = -sf meta["te"] = -meta["Total edgesum"] # %% [markdown] # ## plot by class and randomly within class fig, ax = plt.subplots(1, 1, figsize=(20, 20)) adjplot( adj, meta=meta, sort_class=["0_pred_side"], colors="merge_class", palette=CLASS_COLOR_DICT, item_order=["merge_class", "rand"], plot_type="scattermap", sizes=(0.5, 1), ax=ax,
# # blockmodel_df = get_blockmodel_df(adj, pred_labels, use_weights=True, return_counts=False) plt.figure(figsize=(20, 20)) sns.heatmap(blockmodel_df, cmap="Reds") g = nx.from_pandas_adjacency(blockmodel_df, create_using=nx.DiGraph()) uni_labels, counts = np.unique(pred_labels, return_counts=True) size_scaler = 5 size_map = dict(zip(uni_labels, size_scaler * counts)) nx.set_node_attributes(g, size_map, name="Size") mini_adj = nx.to_numpy_array(g, nodelist=uni_labels) node_signal_flow = signal_flow(mini_adj) sf_map = dict(zip(uni_labels, node_signal_flow)) nx.set_node_attributes(g, sf_map, name="Signal Flow") sym_adj = symmetrize(mini_adj) node_lap = LaplacianSpectralEmbed(n_components=1).fit_transform(sym_adj) node_lap = np.squeeze(node_lap) lap_map = dict(zip(uni_labels, node_lap)) nx.set_node_attributes(g, lap_map, name="Laplacian-2") color_map = dict(zip(uni_labels, cc.glasbey_light)) nx.set_node_attributes(g, color_map, name="Color") g.nodes(data=True) nx.write_graphml(g, f"maggot_models/notebooks/outs/{FNAME}/mini_g.graphml") # %% sort minigraph based on signal flow sort_inds = np.argsort(node_signal_flow)[::-1]