예제 #1
0
 def _plot_bars(self, pred_side):
     """Draw and save a stacked barplot of merge classes per predicted label.

     Bars follow the sorted unique values of ``pred_side``. The reported ``k``
     is half the number of distinct labels in ``pred_side``.
     """
     categories = np.unique(pred_side)
     n_half = len(categories) // 2
     bar_ax = stacked_barplot(
         pred_side,
         self.meta["merge_class"],
         color_dict=CLASS_COLOR_DICT,
         legend_ncol=6,
         category_order=categories,
     )
     bar_ax.set_title(f"{self.name}, k={n_half}")
     self._stashfig(f"bars-k={n_half}")
예제 #2
0
def one_iteration(start_labels, class_key="Merge Class"):
    """Run one walk -> embed -> cluster pass, plotting every stage.

    Parameters
    ----------
    start_labels
        Passed straight to ``random_walk_classes``.
    class_key : str
        Column used by ``random_walk_classes`` to look up known classes.

    Returns
    -------
    Cluster labels predicted by ``AutoGMMCluster`` on an 8-component PCA of
    the log10(x + 1)-transformed walk data.
    """
    walk_data, bins, classes = random_walk_classes(
        start_labels, seed=None, class_key=class_key
    )
    # +1 keeps zero entries finite under the log.
    log_walks = np.log10(walk_data + 1)
    path_clustermap(log_walks, classes, bins)

    # Embed, then look at the embedding coloured by known class.
    embedding = PCA(n_components=8).fit_transform(log_walks)
    pairplot(embedding, labels=classes, palette=CLASS_COLOR_DICT)

    autogmm = AutoGMMCluster(
        min_components=2, max_components=20, n_jobs=-1, verbose=10
    )
    cluster_labels = autogmm.fit_predict(embedding)

    # Model-selection curve on a fresh figure.
    plt.figure()
    sns.scatterplot(data=autogmm.results_, x="n_components", y="bic/aic")

    # Same embedding coloured by cluster, then cluster vs. known class.
    pairplot(embedding, labels=cluster_labels, palette=cc.glasbey_light)
    stacked_barplot(cluster_labels, classes, color_dict=CLASS_COLOR_DICT)
    return cluster_labels
예제 #3
0
def bartreeplot(
        dc,
        class_labels,
        show_props=True,
        print_props=True,
        text_pad=0.01,
        inverse_memberships=True,
        figsize=(24, 23),
        title=None,
        palette=cc.glasbey_light,
        color_dict=None,
        latent=None,
):
    """Plot a cluster dendrogram beside stacked bars and per-cluster text.

    Three columns are drawn: the dendrogram built by ``dc.build_linkage``, a
    stacked barplot of ``class_labels`` within each cluster predicted by
    ``dc.predict``, and a row of coloured text listing every class present in
    each cluster.

    Parameters
    ----------
    dc
        Fitted clustering model providing ``build_linkage`` and ``predict``.
    class_labels : array-like
        Known class label for each sample passed to ``dc.predict``.
    show_props : bool
        Forwarded to ``stacked_barplot`` as ``norm_bar_width``; also picks the
        barplot title.
    print_props : bool
        Only used when ``inverse_memberships`` is False: print each class as a
        percentage of its cluster (True) or as the raw value (False).
    text_pad : float
        Gap between text entries, in data units of the text column.
    inverse_memberships : bool
        If True, print each value divided by that class's total count.
    figsize, title, palette, color_dict
        Figure size, suptitle and colour options.
    latent : array-like, optional
        Data passed to ``dc.predict``. Earlier versions of this function read a
        module-level ``latent`` implicitly; that fallback is kept when this
        argument is omitted.

    Raises
    ------
    ValueError
        If ``latent`` is omitted and no module-level ``latent`` exists.
    """
    if latent is None:
        # Backward compatibility with callers that relied on a global
        # ``latent``; fail with a clear message instead of a NameError.
        try:
            latent = globals()["latent"]
        except KeyError:
            raise ValueError(
                "bartreeplot needs the data to predict on; pass `latent=`."
            ) from None

    # gather necessary info from model
    linkage, labels = dc.build_linkage(
        bic_distance=False)  # hackily built like scipy's
    pred_labels = dc.predict(latent)
    # Total size of each known class, in np.unique's sorted order.
    _, uni_class_counts = np.unique(class_labels, return_counts=True)

    # set up the figure: gs0 holds dendrogram + bars, gs1 the text column
    fig = plt.figure(figsize=figsize)
    r = fig.canvas.get_renderer()  # used to measure text extents below
    gs0 = plt.GridSpec(1, 2, figure=fig, width_ratios=[0.2, 0.8], wspace=0)
    gs1 = plt.GridSpec(1, 1, figure=fig, width_ratios=[0.2], wspace=0.1)

    # title the plot
    plt.suptitle(title, y=0.92, fontsize=30, x=0.5)

    # plot the dendrogram
    ax0 = fig.add_subplot(gs0[0])

    dendr_data = dendrogram(
        linkage,
        orientation="left",
        labels=labels,
        color_threshold=0,
        above_threshold_color="k",
        ax=ax0,
    )
    ax0.axis("off")
    ax0.set_title("Dendrogram", loc="left")

    # get the ticks from the dendrogram to apply to the bar plot
    ticks = ax0.get_yticks()

    # plot the barplot (and ticks to the right of them)
    leaf_names = np.array(dendr_data["ivl"])[::-1]
    ax1 = fig.add_subplot(gs0[1], sharey=ax0)
    ax1, prop_data, uni_class, subcategory_colors = stacked_barplot(
        pred_labels,
        class_labels,
        label_pos=ticks,
        category_order=leaf_names,
        ax=ax1,
        bar_height=5,
        horizontal_pad=0,
        palette=palette,
        norm_bar_width=show_props,
        return_data=True,
        color_dict=color_dict,
    )
    ax1.set_frame_on(False)
    ax1.yaxis.tick_right()

    if show_props:
        ax1_title = "Cluster proportion of known cell types"
    else:
        ax1_title = "Cluster counts by known cell types"

    # Place a "Cluster name (size)" header at the right edge of the bars, at
    # the height of the bar title (title pixel extent -> ax1 data coords).
    ax1_title = ax1.set_title(ax1_title, loc="left")
    transformer = ax1.transData.inverted()
    bbox = ax1_title.get_window_extent(renderer=r)
    bbox_points = bbox.get_points()
    out_points = transformer.transform(bbox_points)
    xlim = ax1.get_xlim()
    ax1.text(xlim[1],
             out_points[0][1],
             "Cluster name (size)",
             verticalalignment="bottom")

    # plot the cluster compositions as text to the right of the bars
    gs0.update(right=0.4)
    ax2 = fig.add_subplot(gs1[0], sharey=ax0)
    ax2.axis("off")
    gs1.update(left=0.48)

    text_kws = {
        "verticalalignment": "center",
        "horizontalalignment": "left",
        "fontsize": 12,
        "alpha": 1,
        "weight": "bold",
    }

    ax2.set_xlim((0, 1))
    transformer = ax2.transData.inverted()

    cluster_sizes = prop_data.sum(axis=1)
    for i, y in enumerate(ticks):
        # Entries are laid out left to right; x advances by the measured
        # width of each drawn entry plus text_pad.
        x = 0
        for j, (colname, color) in enumerate(zip(uni_class,
                                                 subcategory_colors)):
            prop = prop_data[i, j]
            if prop > 0:
                if inverse_memberships:
                    # NOTE(review): assumes uni_class shares np.unique's
                    # sorted order with uni_class_counts so index j matches.
                    prop = prop / uni_class_counts[j]
                    name = f"{colname} ({prop:3.0%})"
                elif print_props:
                    name = f"{colname} ({prop / cluster_sizes[i]:3.0%})"
                else:
                    name = f"{colname} ({prop})"
                text = ax2.text(x, y, name, color=color, **text_kws)
                bbox = text.get_window_extent(renderer=r)
                bbox_points = bbox.get_points()
                out_points = transformer.transform(bbox_points)
                width = out_points[1][0] - out_points[0][0]
                x += width + text_pad

    # deal with title for the last plot column based on options
    if inverse_memberships:
        ax2_title = "Known cell type (percentage of cell type in cluster)"
    elif print_props:
        ax2_title = "Known cell type (percentage of cluster)"
    else:
        ax2_title = "Known cell type (count in cluster)"
    ax2.set_title(ax2_title, loc="left")
예제 #4
0
# %%
from src.visualization import stacked_barplot

# Merge-class label of every row in meta, plus the unique classes ordered by
# descending count (argsort of the negated counts).
labels = meta["merge_class"].values
uni_labels, counts = np.unique(labels, return_counts=True)
inds = np.argsort(-counts)

# True where "Pair ID" != -1; presumably -1 marks unpaired neurons (confirm).
# Passed as the subcategory of the right-hand panel below.
paired = meta["Pair ID"] != -1

fig, axs = plt.subplots(1, 2, figsize=(20, 20))
# Left panel: each class stacked against itself, bars in descending-count
# order, with norm_bar_width=False.
ax = axs[0]
# With return_data=True this version of stacked_barplot returns five values.
ax, data, uni_subcat, subcategory_colors, order, = stacked_barplot(
    labels,
    labels,
    color_dict=CLASS_COLOR_DICT,
    category_order=uni_labels[inds],
    norm_bar_width=False,
    ax=ax,
    return_data=True,
)
ax.get_legend().remove()
ax.set_title("Class membership")

ax = axs[1]
ax, data, uni_subcat, subcategory_colors, order, = stacked_barplot(
    labels,
    paired,
    category_order=uni_labels[inds],
    norm_bar_width=False,
    ax=ax,
    return_data=True,
예제 #5
0
stashfig("ffwdness-by-model")

# %% [markdown]
# ## make barplots

from src.visualization import barplot_text

# One stacked barplot per hierarchy level: bars are the labels in that level's
# column (sorted unique order), segments are the merge classes.
# NOTE(review): the lvl*_labels columns read here are only (re)built in the
# next cell; presumably an earlier cell already created them. Confirm the
# intended execution order. barplot_text is imported but not used here.
lvls = ["lvl0_labels", "lvl1_labels", "lvl2_labels"]
for lvl in lvls:
    pred_labels = meta[lvl]
    true_labels = meta["merge_class"].values
    fig, ax = plt.subplots(1, 1, figsize=(15, 20))
    stacked_barplot(
        pred_labels,
        true_labels,
        category_order=np.unique(pred_labels),
        color_dict=CLASS_COLOR_DICT,
        ax=ax,
    )
    # lvl is the full column name, e.g. "barplot-no-text-lvl-lvl0_labels".
    stashfig(f"barplot-no-text-lvl-{lvl}", dpi=200)

# %% [markdown]
# ##
# Hierarchical labels: per-level predictions joined with "-"; the *_side
# variants append the "hemisphere" string to each label.
meta["lvl0_labels"] = meta["0_pred"]
meta["lvl1_labels"] = meta["0_pred"] + "-" + meta["1_pred"]
meta["lvl2_labels"] = meta["0_pred"] + "-" + meta["1_pred"] + "-" + meta[
    "2_pred"]
meta["lvl0_labels_side"] = meta["lvl0_labels"] + meta["hemisphere"]
meta["lvl1_labels_side"] = meta["lvl1_labels"] + meta["hemisphere"]
meta["lvl2_labels_side"] = meta["lvl2_labels"] + meta["hemisphere"]
예제 #6
0
    ase = AdjacencySpectralEmbed(n_components=n_components)
    latent = ase.fit_transform(lap)
    latent = np.concatenate(latent, axis=-1)
    return latent


n_components = None
k = 30

# Embed the adjacency matrix with lse, then fit a Gaussian mixture with
# exactly k components (min_components == max_components).
latent = lse(adj, n_components, regularizer=None)

gmm = GaussianCluster(min_components=k, max_components=k)

pred_labels = gmm.fit_predict(latent)

# Known classes within each predicted cluster.
stacked_barplot(pred_labels, class_labels, palette="tab20")

# %% [markdown]
# # verify on sklearn toy dataset
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

# Toy data: 200 points in 3 dimensions, cluster_std=3. No random_state is
# set, so every run draws different data.
X, y = make_blobs(n_samples=200, n_features=3, centers=None, cluster_std=3)
# y = y.astype(int).astype(str)
# Coordinates and blob label side by side; np.concatenate casts y to the
# common dtype with X, so "Labels" ends up holding floats.
data_df = pd.DataFrame(
    data=np.concatenate((X, y[:, np.newaxis]), axis=-1),
    columns=("Dim 0", "Dim 1", "Dim 2", "Labels"),
)
# data_df["Labels"] = data_df["Labels"].values.astype("<U10")
sns.scatterplot(data=data_df,
                x="Dim 1",
예제 #7
0
            basename = f"louvain-res{r}-t{thresh}-{graph_type}-"
            title = f"Louvain, {graph_type}, res = {r}, thresh = {thresh}"

            # barplot by merge class label (more detail)
            class_label_dict = nx.get_node_attributes(g_sym, "Merge Class")
            class_labels = np.array(
                itemgetter(*skeleton_labels)(class_label_dict))
            part_color_dict = dict(zip(np.unique(partition), cc.glasbey_warm))
            true_color_dict = dict(zip(names, colors))
            color_dict = {**part_color_dict, **true_color_dict}
            fig, ax = plt.subplots(1, 1, figsize=(20, 20))
            stacked_barplot(
                partition,
                class_labels,
                ax=ax,
                color_dict=color_dict,
                plot_proportions=False,
                norm_bar_width=True,
                category_order=sort_partition_sf.index.values,
            )
            ax.set_title(title)
            ax.set_ylabel(r"Median signal flow $\to$", fontsize=28)
            stashfig(basename + "barplot-mergeclass")

            # sorted heatmap
            heatmap(
                mg.adj,
                transform="simple-nonzero",
                figsize=(20, 20),
                inner_hier_labels=partition,
                hier_label_fontsize=10,
예제 #8
0
ax.set_ylabel("Prop. in upper triangle")
ax.set_xlabel("Model")
stashfig("ffwdness-by-model")

# %% [markdown]
# ## make barplots

from src.visualization import barplot_text

# One figure per hierarchy level: merge classes stacked within each label of
# that level. No category_order is passed, so stacked_barplot picks the bar
# order. barplot_text is imported but not used in this cell.
lvls = ["lvl0_labels", "lvl1_labels", "lvl2_labels"]
for lvl in lvls:
    pred_labels = meta[lvl]
    true_labels = meta["merge_class"].values
    fig, ax = plt.subplots(1, 1, figsize=(15, 20))
    stacked_barplot(pred_labels,
                    true_labels,
                    color_dict=CLASS_COLOR_DICT,
                    ax=ax)
    stashfig(f"barplot-no-text-lvl-{lvl}", dpi=200)

# %% [markdown]
# ##
start_instance()
# labels = meta["lvl2_labels"].values

for tp in meta["lvl2_labels"].unique():
    ids = list(meta[meta["lvl2_labels"] == tp].index.values)
    ids = [int(i) for i in ids]
    fig = plt.figure(figsize=(30, 10))

    gs = plt.GridSpec(2,
                      3,
예제 #9
0
    min_split=4,
)

mc.fit(n_levels=n_levels, metric=metric)

# Use the depth the fitted model reports rather than the requested n_levels.
n_levels = mc.height

show_bars = False
if show_bars:
    # One column of count barplots per level. squeeze=False keeps ``axs``
    # two-dimensional, so indexing also works when there is a single level
    # (plt.subplots would otherwise return a bare Axes object).
    fig, axs = plt.subplots(
        1, n_levels, figsize=(8 * n_levels, 30), squeeze=False
    )
    for i in range(n_levels):
        ax = axs[0, i]
        stacked_barplot(
            mc.meta[f"lvl{i}_labels_side"],
            mc.meta["merge_class"],
            category_order=np.unique(mc.meta[f"lvl{i}_labels_side"].values),
            color_dict=CLASS_COLOR_DICT,
            norm_bar_width=False,
            ax=ax,
        )
        ax.set_yticks([])
        ax.get_legend().remove()
        # NOTE(review): ``title`` is defined outside this cell.
        ax.set_title(title)

    plt.tight_layout()

    # ``i`` is the last level index here: the one figure holding every level
    # is saved under that level's name.
    stashfig(f"count-barplot-lvl{i}" + basename)
    plt.close()

# Left-paired indices followed by right-paired indices; slice that subgraph.
inds = np.concatenate((lp_inds, rp_inds))
new_adj = adj[np.ix_(inds, inds)]
new_meta = mc.meta
예제 #10
0
    left=margin + 3 / n_col + gap,
    right=margin + 5 / n_col + gap,
    bottom=margin,
    top=1 - margin,
)

# ax = fig.add_subplot(mid_gs[0, 0])
# Barplot spans every column but the first in the top row of mid_gs.
ax = fig.add_subplot(mid_gs[0, 1:])

# Side-specific labels at this level vs. merge class, with bar names drawn
# in dim grey and the legend removed afterwards.
cat = temp_meta[f"lvl{level}_labels_side"].values
subcat = temp_meta["merge_class"].values
stacked_barplot(
    cat,
    subcat,
    ax=ax,
    color_dict=CLASS_COLOR_DICT,
    plot_names=True,
    text_color="dimgrey",
    bar_height=0.2,
)
ax.get_legend().remove()

subgraph_inds = temp_meta["inds"].values
subgraph_adj = adj[np.ix_(subgraph_inds, subgraph_inds)]
ax = fig.add_subplot(mid_gs[1:, :])
_, _, top, _ = adjplot(
    pass_to_ranks(subgraph_adj),
    plot_type="heatmap",
    cbar=False,
    ax=ax,
    meta=temp_meta,
예제 #11
0
    color_threshold=0,
    above_threshold_color="k",
    ax=ax0,
)
ax0.axis("off")

# y positions of the dendrogram leaves, reused as bar positions so each bar
# lines up with its leaf (ax1 also shares ax0's y axis).
ticks = ax0.get_yticks()

# Leaf labels in dendrogram order, reversed -- presumably so they match the
# order stacked_barplot assigns to label_pos; confirm.
leaf_names = np.array(dendr_data["ivl"])[::-1]
ax1 = fig.add_subplot(gs0[1], sharey=ax0)
ax1, prop_data, uni_class, subcategory_colors = stacked_barplot(
    pred_labels,
    class_labels,
    label_pos=ticks,
    category_order=leaf_names,
    ax=ax1,
    bar_height=5,
    horizontal_pad=0,
    palette="tab20",
    norm_bar_width=show_props,
    return_data=True,
)
ax1.set_frame_on(False)
ax1.yaxis.tick_right()

# Confine dendrogram + bars to the left half, and start the (hidden-axis)
# text column, which shares ax0's y axis, at 60% of the figure width.
gs0.update(right=0.5)
ax2 = fig.add_subplot(gs1[0], sharey=ax0)
ax2.axis("off")
gs1.update(left=0.6)

text_kws = {
    "verticalalignment": "center",
예제 #12
0
    linkage,
    orientation="left",
    labels=labels,
    color_threshold=0,
    above_threshold_color="k",
    ax=ax0,
)
ax0.axis("off")

# Dendrogram leaf y positions, reused so each bar lines up with its leaf.
ticks = ax0.get_yticks()

leaf_names = np.array(dendr_data["ivl"])[::-1]
ax1 = fig.add_subplot(gs[1], sharey=ax0)
# No return_data here, so the return value is used directly as the axis.
ax1 = stacked_barplot(
    pred_labels,
    class_labels,
    label_pos=ticks,
    category_order=leaf_names,
    ax=ax1,
    bar_height=5,
    horizontal_pad=0,
    palette="tab20",
    norm_bar_width=True,
)
ax1.set_frame_on(False)
ax1.yaxis.tick_right()
# plt.title targets the current axes (presumably ax1, the last one added).
plt.title(
    r"Divisive hierarchical clustering, GraspyGMM, LSE, PTR, Full Brain, A $\to$ D"
)
stashfig("hierarchy-bars")
예제 #13
0
# Convert the normalized embedding to spherical coordinates and keep only the
# angular columns.
norm_embed_spherical = n_sphere.convert_spherical(norm_embed)
norm_embed_spherical = norm_embed_spherical[:, 1:]  # chop off R dimension
pg = pairplot(norm_embed_spherical, labels=labels, palette=CLASS_COLOR_DICT)
pg._legend.remove()


# %% [markdown]
# ##

from sklearn.cluster import AgglomerativeClustering

# Average-linkage cosine clustering of norm_embed for k = 2..9, one
# merge-class barplot per k.
# NOTE(review): scikit-learn renamed ``affinity`` to ``metric`` (1.2) and
# later removed ``affinity``; confirm against the pinned sklearn version.
for k in range(2, 10):
    ag = AgglomerativeClustering(n_clusters=k, affinity="cosine", linkage="average")
    pred_labels = ag.fit_predict(norm_embed)

    ax = stacked_barplot(pred_labels, labels, color_dict=CLASS_COLOR_DICT)
    ax.set_title(f"k={k}")
# %% [markdown]
# ## new
np.random.seed(8888)
mc = MaggotCluster(
    "0",
    adj=adj,
    meta=meta,
    n_init=50,
    stashfig=stashfig,
    max_clusters=8,
    n_components=None,
    embed="unscaled_ase",
    reembed=True,
    realign=True,
예제 #14
0
plt.tight_layout()

# White backgrounds for displayed and saved figures from here on.
plt.rcParams["figure.facecolor"] = "w"
plt.rcParams["savefig.facecolor"] = "w"

# Save the current figure twice: with stashfig's defaults and as PDF.
stashfig(f"gmm-crossval-pairs-k={k}-n_components={n_components}")
stashfig(f"gmm-crossval-pairs-k={k}-n_components={n_components}", fmt="pdf")

# %% [markdown]
# ##

from src.visualization import barplot_text, stacked_barplot

# barplot_text(pred, meta["merge_class"].values, color_dict=CLASS_COLOR_DICT)
# Merge-class composition of each cluster in ``pred``.
stacked_barplot(pred,
                meta["merge_class"].values,
                color_dict=CLASS_COLOR_DICT,
                legend_ncol=4)

stashfig(f"gmm-crossval-barplot-k={k}-n_components={n_components}")

# %% [markdown]
# ## SUBCLUSTER !

from scipy.optimize import linear_sum_assignment


def compute_pairedness(partition, meta, rand_adjust=False, plot=False):
    partition = partition.copy()
    meta = meta.copy()

    uni_labels, inv = np.unique(partition, return_inverse=True)
예제 #15
0
# Model from the selected results row, used to predict with relabel=False.
model = results.loc[ind, "model"]
pred = predict(joint_embed, left_inds, right_inds, model, relabel=False)

plot_cluster_pairs(
    joint_embed,
    left_inds,
    right_inds,
    model,
    meta["merge_class"].values,
    lp_inds,
    rp_inds,
)

# %% [markdown]
# ##
# Merge-class composition of each predicted cluster.
stacked_barplot(pred, meta["merge_class"].values, color_dict=CLASS_COLOR_DICT)


# %%
# Positional index per row, then the positions of rows flagged by the
# boolean "left"/"right" columns and the paired positions, as inputs to
# crossval_cluster below.
meta["inds"] = range(len(meta))
left_inds = meta[meta["left"]]["inds"]
right_inds = meta[meta["right"]]["inds"]
lp_inds, rp_inds = get_paired_inds(meta)
results = crossval_cluster(
    embed,
    left_inds,
    right_inds,
    min_clusters=2,
    max_clusters=10,
    left_pair_inds=lp_inds,
    right_pair_inds=rp_inds,
예제 #16
0
#     sub_partition.values,
#     comm_mg["Merge Class"],
#     color_dict=color_dict,
#     plot_proportions=False,
#     norm_bar_width=True,
#     figsize=(24, 18),
#     title=title,
#     hatch_dict=None,
# )
from src.visualization import stacked_barplot

# Left axis: merge classes stacked within each partition
# (plot_proportions=False, norm_bar_width=True).
stacked_barplot(
    partition,
    mg["Merge Class"],
    color_dict=color_dict,
    plot_proportions=False,
    norm_bar_width=True,
    hatch_dict=None,
    ax=axs[0],
)
# Right axis: the minigraph drawn at its stored "Spring-x"/"Spring-y"
# positions, with node sizes from "Size" and colours from "Color".
draw_networkx_nice(
    minigraph,
    "Spring-x",
    "Spring-y",
    sizes="Size",
    colors="Color",
    ax=axs[1],
    weight_scale=20,
    vmin=0.0001,
)
axs[1].set_xlabel("")
예제 #17
0
# Louvain communities (resolution r) on the undirected graph, in node order.
# NOTE(review): cm is presumably the python-louvain ``community`` module.
g_sym = nx.to_undirected(g)
skeleton_labels = np.array(list(g_sym.nodes()))
scales = [1]
r = 0.5
out_dict = cm.best_partition(g_sym, resolution=r)
# Community keys are looked up as strings; the class lookup below uses the
# raw node labels, so these agree only if nodes are already strings.
partition = np.array(itemgetter(*skeleton_labels.astype(str))(out_dict))
adj = nx.to_numpy_array(g_sym, nodelist=skeleton_labels)

# Lump every community with fewer than 3 members into label -1.
part_unique, part_count = np.unique(partition, return_counts=True)
for uni, count in zip(part_unique, part_count):
    if count < 3:
        inds = np.where(partition == uni)[0]
        partition[inds] = -1

# "Class 1" attribute per node; colour dicts for communities and classes are
# merged, with class colours winning on any key collision.
class_label_dict = nx.get_node_attributes(g_sym, "Class 1")
class_labels = np.array(itemgetter(*skeleton_labels)(class_label_dict))
part_color_dict = dict(zip(np.unique(partition), cc.glasbey_warm))
true_color_dict = dict(zip(np.unique(class_labels), cc.glasbey_light))
color_dict = {**part_color_dict, **true_color_dict}
fig, ax = plt.subplots(1, 1, figsize=(20, 20))
stacked_barplot(
    partition,
    class_labels,
    ax=ax,
    color_dict=color_dict,
    plot_proportions=False,
    norm_bar_width=True,
)
stashfig("louvain-barplot")

예제 #18
0
print(f"ARI: {results.loc[ind, 'ARI']}")
print(f"Pairedness: {results.loc[ind, 'pairedness']}\n")

# The selected row's model serves as both the left and the right model.
model = results.loc[ind, "model"]
left_model = model
right_model = model

# Predictions from the same models with relabel=False and relabel=True.
pred = composite_predict(
    X, left_inds, right_inds, left_model, right_model, relabel=False
)
pred_side = composite_predict(
    X, left_inds, right_inds, left_model, right_model, relabel=True
)

ax = stacked_barplot(
    pred_side, meta["merge_class"].values, color_dict=CLASS_COLOR_DICT, legend_ncol=6
)
ax.set_title(basetitle)
stashfig(f"barplot" + basename)


fig, ax = plot_cluster_pairs(
    X, left_inds, right_inds, left_model, right_model, meta["merge_class"].values
)
fig.suptitle(basetitle, y=1)

stashfig(f"pairs" + basename)


# Signal flow of adj, stored negated in meta.
sf = signal_flow(adj)
meta["signal_flow"] = -sf
예제 #19
0
# Average-linkage Euclidean clustering of the raw histograms into 10 groups.
# NOTE(review): newer scikit-learn versions take ``metric`` instead of
# ``affinity``; keep ``affinity`` while the project pins an older sklearn.
agg = AgglomerativeClustering(n_clusters=10, affinity="euclidean", linkage="average")
labels = agg.fit_predict(raw_hist_data)
pairplot(embedding, labels=labels, palette=cc.glasbey_light)

# %% [markdown]
# #

from graspy.cluster import AutoGMMCluster

# AutoGMM model search over 2-20 components on the same embedding.
agm = AutoGMMCluster(min_components=2, max_components=20, n_jobs=-1)
agm.fit(embedding)

# %% [markdown]
# #
# agm.results_.groupby(["affinity", "covariance_type", "linkage"])
# Model-selection curve: "bic/aic" score per number of components tried.
sns.scatterplot(data=agm.results_, x='n_components', y='bic/aic')

# %% [markdown]
# #

new_groups = agm.predict(embedding)

stacked_barplot(new_groups, meta["Merge Class"].values, color_dict=CLASS_COLOR_DICT)
# %% [markdown]
# #

# AutoGMM may select up to 20 clusters, more than tab10's 10 colours; use the
# same large glasbey palette as the agglomerative pairplot above.
pairplot(embedding, labels=new_groups, palette=cc.glasbey_light)

# %% [markdown] 
# #
예제 #20
0
# Predictions from the same left/right models with relabel=False and
# relabel=True; both are stored in meta below.
pred = composite_predict(X,
                         left_inds,
                         right_inds,
                         left_model,
                         right_model,
                         relabel=False)
pred_side = composite_predict(X,
                              left_inds,
                              right_inds,
                              left_model,
                              right_model,
                              relabel=True)

ax = stacked_barplot(pred_side,
                     meta["merge_class"].values,
                     color_dict=CLASS_COLOR_DICT,
                     legend_ncol=6)
ax.set_title(basetitle)
stashfig(f"barplot" + basename)

fig, ax = plot_cluster_pairs(X, left_inds, right_inds, left_model, right_model,
                             meta["merge_class"].values)
fig.suptitle(basetitle, y=1)

stashfig(f"pairs" + basename)

# Signal flow of adj (stored negated) and both predictions, per row of meta.
sf = signal_flow(adj)
meta["signal_flow"] = -sf
meta["pred"] = pred
meta["pred_side"] = pred_side
meta["group_signal_flow"] = meta["pred"].map(
예제 #21
0
    def predict(self, X):
        """Label every sample in ``X`` with ``predict_sample``.

        Returns the per-sample predictions as a NumPy array, in the same
        order as ``X``.
        """
        return np.array([self.predict_sample(row) for row in X])


# Fit the partition clusterer on the embedding and label each row of it.
pgmm = PartitionCluster()
pgmm.fit(latent)
pred_labels = pgmm.predict(latent)

from src.visualization import stacked_barplot

stacked_barplot(pred_labels, class_labels)

# %% [markdown]
# #

uni_labels = np.unique(pred_labels)
# consider only the longest strings:

# Length of each unique label and the maximum of those lengths. Raises
# ValueError via max() if there are no labels.
label_lens = []
for l in uni_labels:
    str_len = len(l)
    label_lens.append(str_len)
label_lens = np.array(label_lens)
max_len = max(label_lens)
print(max_len)
예제 #22
0
def run_experiment(seed=None, graph_type=None, threshold=None, param_key=None):
    """Fit a graph-tool block model to one thresholded graph and save its plots.

    The ``graph_type`` graph is loaded and edge weights ``<= threshold`` are
    zeroed. ``run_minimize_blockmodel`` then assigns block labels, and neurons
    without a label are dropped. Finally, barplots of class, lineage-augmented
    class and hemisphere per block, plus the block probability matrix, are
    saved.

    Parameters
    ----------
    seed : int, optional
        Passed to ``np.random.seed`` first.
    graph_type : str
        Graph to load with ``load_metagraph``.
    threshold : number
        Edge weights ``<= threshold`` are set to 0 before fitting.
    param_key : str
        Names the temporary graphml file. When ``BLIND`` is set, it also
        names and titles the saved figures.

    Returns
    -------
    The block labels from ``run_minimize_blockmodel`` (presumably a pandas
    Series, since it is concatenated column-wise with ``mg.meta``), renamed to
    ``param_key``.
    """
    np.random.seed(seed)
    if BLIND:
        # don't want spaces in filenames
        temp_param_key = param_key.replace(" ", "")
        savename = f"{temp_param_key}-cell-types-"
        title = param_key
    else:
        savename = f"{graph_type}-t{threshold}-cell-types"
        title = f"{graph_type}, threshold = {threshold}"

    mg = load_metagraph(graph_type, version=VERSION)

    # simple threshold
    # TODO they will want symmetric threshold...
    # TODO maybe make that a parameter
    adj = mg.adj.copy()
    adj[adj <= threshold] = 0
    # Only the neuron names travel with the thresholded graph.
    meta = pd.DataFrame(mg.meta["neuron_name"])
    mg = MetaGraph(adj, meta)

    # run the graphtool code
    temp_loc = f"maggot_models/data/interim/temp-{param_key}.graphml"
    block_series = run_minimize_blockmodel(mg, temp_loc)

    # manage the output: reload the graph, attach block labels and keep only
    # neurons that received one
    mg = load_metagraph(graph_type, version=VERSION)
    mg.meta = pd.concat((mg.meta, block_series), axis=1)
    mg.meta["Original index"] = range(len(mg.meta))
    keep_inds = mg.meta[mg.meta["block_label"].notna()]["Original index"].values
    mg.reindex(keep_inds)
    if graph_type != "G":
        mg.verify(10000, graph_type=graph_type, version=VERSION)

    # deal with class labels
    lineage_labels = mg.meta["lineage"].values
    lineage_labels = np.vectorize(lambda x: "~" + x)(lineage_labels)
    class_labels = mg["Merge Class"]
    skeleton_labels = mg.meta.index.values
    classlin_labels, color_dict, hatch_dict = augment_classes(
        skeleton_labels, class_labels, lineage_labels)
    block_label = mg["block_label"].astype(int)

    # barplot with unknown class labels merged in, proportions
    _, _, order = barplot_text(
        block_label,
        classlin_labels,
        norm_bar_width=True,
        color_dict=color_dict,
        hatch_dict=hatch_dict,
        title=title,
        figsize=(24, 18),
        return_order=True,
    )
    stashfig(savename + "barplot-mergeclasslin-props")
    # ``order`` indexes positions in the sorted unique block labels; turn it
    # into the labels themselves so every later plot shares this ordering.
    category_order = np.unique(block_label)[order]

    # barplot with regular class labels
    barplot_text(
        block_label,
        class_labels,
        norm_bar_width=True,
        color_dict=color_dict,
        hatch_dict=hatch_dict,
        title=title,
        figsize=(24, 18),
        category_order=category_order,
    )
    stashfig(savename + "barplot-mergeclass-props")

    # barplot with unknown class labels merged in, counts
    barplot_text(
        block_label,
        classlin_labels,
        norm_bar_width=False,
        color_dict=color_dict,
        hatch_dict=hatch_dict,
        title=title,
        figsize=(24, 18),
        return_order=True,
        category_order=category_order,
    )
    stashfig(savename + "barplot-mergeclasslin-counts")

    # barplot of hemisphere membership
    fig, ax = plt.subplots(1, 1, figsize=(10, 20))
    stacked_barplot(
        block_label,
        mg["Hemisphere"],
        norm_bar_width=True,
        category_order=category_order,
        ax=ax,
    )
    remove_spines(ax)
    stashfig(savename + "barplot-hemisphere")

    # plot block probability matrix
    counts = False
    weights = False
    prob_df = get_blockmodel_df(mg.adj,
                                block_label,
                                return_counts=counts,
                                use_weights=weights)
    # Reorder by block *label*, matching the barplots. ``order`` is
    # positional, so reindexing by it only worked when the labels happened to
    # be 0..K-1. NOTE(review): assumes get_blockmodel_df indexes rows and
    # columns by block label.
    prob_df = prob_df.reindex(category_order, axis=0)
    prob_df = prob_df.reindex(category_order, axis=1)
    probplot(100 * prob_df,
             fmt="2.0f",
             figsize=(20, 20),
             title=title,
             font_scale=0.4)
    stashfig(savename + "probplot")
    block_series.name = param_key
    return block_series
예제 #23
0
        ax=ax,
        s=30,
        alpha=1,
        ellipses="filled",
        ellipse_kws=dict(linewidth=2, alpha=0.1),
    )
    ax.set(xlabel="", ylabel="", xticks=[], yticks=[])

    ax = axs[1, j]
    uni_pred, uni_counts = np.unique(pred_labels, return_counts=True)
    sort_inds = np.argsort(uni_counts)

    stacked_barplot(
        pred_labels,
        labels,
        color_dict=palette,
        norm_bar_width=False,
        ax=ax,
        category_order=uni_pred[sort_inds],
    )
    ax.set(yticks=[], xticks=[0, 25, 50])
    ax.set_xticklabels([0, 25, 50], fontsize="large")
    ax.xaxis.set_visible(True)
    ax.spines["bottom"].set_visible(True)
    handles, labels = ax.get_legend_handles_labels()
    ax.get_legend().remove()

axs[0, 1].legend(
    handles=handles, labels=labels, bbox_to_anchor=(1, 0), loc="lower right"
)
# axs[1, 0].set_xlabel(
#
예제 #24
0
# %% [markdown]
# #
paths = sm_paths
# First node of every path.
# NOTE(review): the empty-list assignment is immediately overwritten.
path_start_labels = []
path_start_labels = [p[0] for p in paths]
path_start_labels = np.array(path_start_labels)

#%%
# "Merge Class" of each path's start node; start nodes are used as
# positional row indices into meta (iloc).
class_start_labels = meta.iloc[path_start_labels,
                               meta.columns.get_loc("Merge Class")].values

# %% [markdown]
# #

# Start-node class within each path cluster; presumably pred_labels holds one
# cluster label per path, in the same order as ``paths``.
fig, ax = plt.subplots(1, 1, figsize=(10, 20))
stacked_barplot(pred_labels, class_start_labels, ax=ax)

# %% [markdown]
# # Start on the big visualization

# choose a cluster
# get union path graph

path_graph = nx.MultiDiGraph()
chosen_cluster = 3
# Paths whose predicted label equals chosen_cluster.
path_cluster = []
for i, path in enumerate(paths):
    if pred_labels[i] == chosen_cluster:
        path_cluster.append(path)

# Every node visited by those paths, repeats included, in path order.
all_nodes = list(itertools.chain.from_iterable(path_cluster))
예제 #25
0
    # ax = fig.add_subplot(1, 3, 3, projection="3d")
    ax = fig.add_subplot(gs[0, 2], projection="3d")
    pymaid.plot2d(
        ids,
        color=skeleton_color_dict,
        ax=ax,
        connectors=False,
        method="3d",
        autoscale=True,
    )
    ax.azim = -90
    ax.elev = 90
    ax.dist = 6
    set_axes_equal(ax)

    ax = fig.add_subplot(gs[1, :])
    temp_meta = meta[meta["lvl2_labels"] == tp]
    cat = temp_meta["lvl2_labels_side"].values
    subcat = temp_meta["merge_class"].values
    stacked_barplot(cat, subcat, ax=ax, color_dict=CLASS_COLOR_DICT)
    ax.get_legend().remove()

    fig.suptitle(tp)

    stashfig(f"plot3d-{tp}")
    plt.close()


# %%
예제 #26
0
def plot_level(sub_results, sub_data, ks, label, metric="bic", basename=""):
    """Save diagnostic plots for sub-clustering one cluster at chosen k values.

    First saves the metric curves from ``plot_metrics(sub_results)`` with a
    dashed red line at every k in ``ks``. Then, for each nonzero k, it picks
    the row with the largest ``metric`` among rows with that k and predicts
    with ``composite_predict(..., relabel=True)`` using that row's model for
    both sides. For each such k it saves a merge-class stacked barplot and a
    ``plot_cluster_pairs`` figure.

    Parameters
    ----------
    sub_results : pandas.DataFrame
        Needs ``"k"``, ``"model"`` and ``metric`` columns.
    sub_data : dict
        Needs ``"X"``, ``"left_inds"``, ``"right_inds"``, ``"meta"`` and
        ``"reembed"`` entries.
    ks : int or iterable of int
        k value(s) to mark and plot; a k of 0 is marked but not plotted.
    label
        Name of the parent cluster, used in titles and filenames.
    metric : str
        Column of ``sub_results`` maximised to choose the model for each k.
    basename : str
        Suffix appended to every saved filename.
    """
    # A single k may arrive as a Python or a NumPy integer; wrap it so the
    # loops below can iterate.
    if isinstance(ks, (int, np.integer)):
        ks = (ks,)

    sub_X = sub_data["X"]
    sub_left_inds = sub_data["left_inds"]
    sub_right_inds = sub_data["right_inds"]
    sub_meta = sub_data["meta"]
    reembed = sub_data["reembed"]

    fig, axs = plot_metrics(sub_results)
    fig.suptitle(f"Clustering for cluster {label}, reembed={reembed}")
    for ax in axs[:-1]:
        for k in ks:
            ax.axvline(k, linestyle="--", color="red", linewidth=2)
    stashfig(f"cluster-metrics-label={label}-reembed={reembed}" + basename)
    plt.close()

    for k in ks:
        if k == 0:
            continue
        sub_basename = f"-label={label}-subk={k}-reembed={reembed}" + basename
        sub_basetitle = (
            f"Cluster for {label}, subk={k}, reembed={reembed},"
            f" metric={metric}, k={k}"
        )

        # Best model for this k according to ``metric`` (larger wins).
        ind = sub_results[sub_results["k"] == k][metric].idxmax()
        sub_model = sub_results.loc[ind, "model"]
        sub_left_model = sub_model
        sub_right_model = sub_model

        sub_pred_side = composite_predict(
            sub_X,
            sub_left_inds,
            sub_right_inds,
            sub_left_model,
            sub_right_model,
            relabel=True,
        )

        ax = stacked_barplot(
            sub_pred_side,
            sub_meta["merge_class"].values,
            color_dict=CLASS_COLOR_DICT,
            legend_ncol=6,
        )
        ax.set_title(sub_basetitle)
        stashfig("barplot" + sub_basename)
        plt.close()

        fig, ax = plot_cluster_pairs(
            sub_X,
            sub_left_inds,
            sub_right_inds,
            sub_left_model,
            sub_right_model,
            sub_meta["merge_class"].values,
        )
        fig.suptitle(sub_basetitle, y=1)
        stashfig("pairs" + sub_basename)
        plt.close()
예제 #27
0
def _plot_neuron_view(fig, subplot_spec, ids, color_dict, azim, elev, dist=5):
    """Add a 3D axis at ``subplot_spec`` showing skeletons ``ids`` from one angle."""
    ax = fig.add_subplot(subplot_spec, projection="3d")
    pymaid.plot2d(
        ids,
        color=color_dict,
        ax=ax,
        connectors=False,
        method="3d",
        autoscale=True,
    )
    ax.azim = azim
    ax.elev = elev
    ax.dist = dist
    set_axes_equal(ax)
    return ax


def plot_neurons(meta, key=None, label=None, barplot=False):
    """Plot neuron skeletons from three camera angles, optionally with a barplot.

    Parameters
    ----------
    meta : pandas.DataFrame
        One row per neuron; index values are cast to ``int`` and used as
        skeleton ids, and ``"merge_class"`` picks each skeleton's colour.
    key : str, optional
        Column to select on; ``key + "_side"`` is also read when ``barplot``
        is True.
    label : optional
        Keep only rows with ``meta[key] == label``; if None, plot every row.
    barplot : bool
        If True, add a stacked barplot of ``key + "_side"`` vs. merge class
        under the three views.

    Returns
    -------
    fig, ax
        The figure and the last axis drawn (the barplot if ``barplot`` is
        True, otherwise the third 3D view).
    """
    if label is not None:
        ids = list(meta[meta[key] == label].index.values)
    else:
        ids = list(meta.index.values)
    ids = [int(i) for i in ids]

    # Keep only skeletons pymaid can fetch; report and skip the rest.
    new_ids = []
    for i in ids:
        try:
            pymaid.get_neuron(
                i, raise_missing=True, with_connectors=False, with_tags=False
            )
            new_ids.append(i)
        except Exception:
            # Best effort per neuron, but unlike a bare ``except`` this lets
            # KeyboardInterrupt / SystemExit through.
            print(f"Missing neuron {i}, not plotting it.")

    ids = new_ids
    meta = meta.loc[ids]

    fig = plt.figure(figsize=(30, 10))

    gs = plt.GridSpec(2, 3, figure=fig, wspace=0, hspace=0, height_ratios=[0.8, 0.2])

    skeleton_color_dict = dict(
        zip(meta.index, np.vectorize(CLASS_COLOR_DICT.get)(meta["merge_class"]))
    )

    # Top row: the same skeletons at (azim, elev) = (-90, 0), (0, 0), (-90, 90).
    for col, (azim, elev) in enumerate([(-90, 0), (0, 0), (-90, 90)]):
        ax = _plot_neuron_view(
            fig, gs[0, col], ids, skeleton_color_dict, azim=azim, elev=elev
        )

    if barplot:
        ax = fig.add_subplot(gs[1, :])
        temp_meta = meta[meta[key] == label]
        cat = temp_meta[key + "_side"].values
        subcat = temp_meta["merge_class"].values
        stacked_barplot(
            cat,
            subcat,
            ax=ax,
            color_dict=CLASS_COLOR_DICT,
            category_order=np.unique(cat),
        )
        ax.get_legend().remove()

    # fig.suptitle(label)
    return fig, ax