Example #1
0
    wspace=0,
)

# Main matrix axes in grid cell [1, 1]. The commented-out lines below are
# alternative (inactive) ways of forcing an equal aspect ratio.
ax = fig.add_subplot(gs[1, 1], adjustable="box")  # this is the main
# ax.set_aspect(1)
# ax.axis("equal")
# ax.set(adjustable="box", aspect="equal")

# Colour-annotation axes: the top strip shares its x axis with `ax` and the
# left strip shares its y axis, so both stay aligned with the matrix.
top_cax = fig.add_subplot(gs[0, 1], adjustable="box", sharex=ax)
top_cax.set_aspect("auto")
left_cax = fig.add_subplot(gs[1, 0], adjustable="box", sharey=ax)
left_cax.set_aspect("auto")

# Class label of every row, and the colour associated with each label.
# NOTE(review): `sort_meta` here looks like a DataFrame/Series-like object
# indexed by the column name `sort_class` — confirm against the code that
# precedes this fragment.
classes = sort_meta[sort_class].values
class_colors = np.vectorize(CLASS_COLOR_DICT.get)(classes)
gridmap(data, ax=ax, sizes=(0.5, 1))

from matplotlib.colors import ListedColormap

# make colormap
# Give each unique class (np.unique returns them sorted) an integer code,
# and build a discrete colormap whose i-th colour belongs to the i-th class.
uni_classes = np.unique(classes)
class_map = dict(zip(uni_classes, range(len(uni_classes))))
color_list = []
for u in uni_classes:
    color_list.append(CLASS_COLOR_DICT[u])
lc = ListedColormap(color_list)
# Replace the class labels by their integer codes and reshape them into a
# single column so they can be drawn as a one-column heatmap with `lc`.
classes = np.vectorize(class_map.get)(classes)
classes = classes.reshape(len(classes), 1)
sns.heatmap(
    classes,
    cmap=lc,
Example #2
0
middle_labels = list(middle_df.index)

# Create a fresh 10x10 figure/axes when the caller did not supply one.
if ax is None:
    _, ax = plt.subplots(1, 1, figsize=(10, 10))

# do the actual plotting!
if plot_type == "heatmap":
    sns.heatmap(data,
                cmap=cmap,
                ax=ax,
                vmin=0,
                center=0,
                cbar=False,
                square=True)
elif plot_type == "scattermap":
    gridmap(data, ax=ax, sizes=sizes, border=False)

# Equal aspect, and limits from -1 to len(data) on both axes. The y limits
# are given as (bottom, top) with bottom > top, so row 0 is drawn at the top.
# NOTE(review): len(data) (the row count) is also used for the x limit,
# which assumes a square matrix — confirm.
ax.axis("square")
ax.set_ylim(len(data), -1)
ax.set_xlim(-1, len(data))

# add grid lines separating classes
# The scattermap lines are shifted by 0.5, presumably because gridmap
# centres markers on integer coordinates while heatmap cells start at
# integer coordinates — confirm.
# NOTE(review): `boost` stays unbound if plot_type is neither value.
if plot_type == "heatmap":
    boost = 0
elif plot_type == "scattermap":
    boost = 0.5
# Draw one vertical and one horizontal line at each class start index.
for t in first_inds:
    ax.axvline(t - boost, **gridline_kws)
    ax.axhline(t - boost, **gridline_kws)
if use_colors:  # TODO experimental!
Example #3
0
    # sort the graph
    mg = load_metagraph(graph_version, brain_version)
    # Keep only vertices whose "Pair ID" is not -1 (presumably -1 marks an
    # unpaired vertex).
    paired_inds = np.where(mg.meta["Pair ID"] != -1)[0]
    mg = mg.reindex(paired_inds)
    # NOTE(review): the return value is discarded, so this relies on
    # sort_values sorting the metagraph in place — confirm.
    mg.sort_values(["Merge Class", "Pair ID", "Hemisphere"], ascending=True)
    # if graph_version not in ["G", "Gn"]:
    mg.verify(n_checks=10000, graph_type=graph_version, version=brain_version)

    # plot the sorted graph
    # Record each vertex's position after sorting. For every "Merge Class",
    # the mean position is used as the tick location and the first position
    # as the class border.
    mg.meta["Index"] = range(len(mg))
    groups = mg.meta.groupby("Merge Class", as_index=True)
    tick_locs = groups["Index"].mean()
    border_locs = groups["Index"].first()

    fig, ax = plt.subplots(1, 1, figsize=(30, 30))
    gridmap(mg.adj, sizes=(3, 5), ax=ax)
    # Dashed grey lines at every class border, on both axes.
    for bl in border_locs:
        ax.axvline(bl, linewidth=1, linestyle="--", color="grey", alpha=0.5)
        ax.axhline(bl, linewidth=1, linestyle="--", color="grey", alpha=0.5)

    # Alternate class labels between major (even) and minor (odd) ticks. The
    # minor ticks get a larger pad, presumably so that neighbouring labels do
    # not overlap.
    ticklabels = np.array(list(groups.groups.keys()))
    for axis in [ax.yaxis, ax.xaxis]:
        axis.set_major_locator(plt.FixedLocator(tick_locs[0::2]))
        axis.set_minor_locator(plt.FixedLocator(tick_locs[1::2]))
        axis.set_minor_formatter(plt.FormatStrFormatter("%s"))
    ax.tick_params(which="minor", pad=80)
    ax.set_yticklabels(ticklabels[0::2])
    ax.set_yticklabels(ticklabels[1::2], minor=True)
    ax.set_xticklabels(ticklabels[0::2])
    ax.set_xticklabels(ticklabels[1::2], minor=True)
    # Show the x tick labels above the matrix.
    ax.xaxis.tick_top()
Example #4
0
def matrixplot(
    data,
    ax=None,
    plot_type="heatmap",
    row_meta=None,
    col_meta=None,
    row_sort_class=None,
    col_sort_class=None,
    row_class_order=None,
    col_class_order=None,
    row_ticks=True,
    col_ticks=True,
    row_item_order=None,
    col_item_order=None,
    row_colors=None,
    col_colors=None,
    row_palette="tab10",
    col_palette="tab10",
    col_highlight=None,
    row_highlight=None,
    col_tick_pad=None,
    row_tick_pad=None,
    border=True,
    minor_ticking=False,
    tick_rot=0,
    center=0,
    cmap="RdBu_r",
    sizes=(5, 10),
    square=False,
    gridline_kws=None,
    spinestyle_kws=None,
    highlight_kws=None,
    # dot_color=None,
    **kws,
):
    """Plot a matrix as a heatmap or scattermap, sorted and annotated by metadata.

    Rows and columns are permuted according to their metadata (``sort_meta``).
    The permuted matrix is then drawn, and the following are added as requested:
    colour annotations, class separator lines, and class tick labels. Column
    annotations go above the matrix and row annotations to its left.

    Parameters
    ----------
    data : np.ndarray, ndim=2
        Matrix to plot.
    ax : matplotlib axes object, optional
        Axes to draw on. If None, a new 10x10 figure is created. By default None.
    plot_type : str, optional
        One of "heatmap" or "scattermap", by default "heatmap".
    row_meta, col_meta : pd.DataFrame, pd.Series, list of pd.Series or np.array, optional
        Metadata for the rows / columns, used for sorting, colours, separators
        and ticks, by default None.
    row_sort_class, col_sort_class : list or np.ndarray, optional
        Metadata attribute(s) whose classes group the rows / columns. Class
        boundaries get separator lines and (optionally) tick labels.
        By default None.
    row_class_order, col_class_order : optional
        How to order the classes. Forwarded to ``sort_meta`` as
        ``class_order``. By default None.
    row_item_order, col_item_order : str or list of str, optional
        Attribute(s) in the metadata by which to sort elements within a class,
        by default None.
    row_ticks, col_ticks : bool, optional
        Whether to draw class tick labels on the left / top, by default True.
    row_colors, col_colors : optional
        Colour annotation(s) drawn beside the matrix. Forwarded to
        ``draw_colors``. By default None.
    row_palette, col_palette : str, optional
        Palette for the colour annotations, by default "tab10".
    col_highlight, row_highlight, highlight_kws : optional
        Currently unused; accepted for backward compatibility.
    col_tick_pad, row_tick_pad : list of float, optional
        Padding for each tick axis level (one entry per sort class). Entry 0
        is not used. By default, every level gets 0.5.
    border : bool, optional
        Whether to draw all four spines of ``ax``, styled by
        ``spinestyle_kws``. By default True.
    minor_ticking : bool, optional
        Currently unused, by default False.
    tick_rot : int, optional
        Rotation passed to ``draw_ticks`` for the column (top) ticks,
        by default 0.
    center : float, optional
        Value at which to center the colormap (heatmap only), by default 0.
    cmap : str, optional
        Colormap for the heatmap, by default "RdBu_r".
    sizes : tuple, optional
        Marker size range passed to ``gridmap`` (scattermap only),
        by default (5, 10).
    square : bool, optional
        Whether to call ``ax.axis("square")``, by default False.
    gridline_kws : dict, optional
        Style of the class separator lines. Forwarded to ``draw_separators``.
        By default None.
    spinestyle_kws : dict, optional
        Spine style with keys "linestyle", "linewidth" and "alpha". Missing
        keys fall back to ``linestyle="-", linewidth=1, alpha=0.7``.
        By default None.
    **kws
        Extra keyword arguments forwarded to ``sns.heatmap`` or ``gridmap``.

    Returns
    -------
    ax : matplotlib axes
        The main axes holding the matrix.
    divider : AxesDivider
        Divider from ``make_axes_locatable(ax)``; use it to append more axes.
    top_cax, left_cax : matplotlib axes
        Axes holding the column / row colour annotations. Either may be ``ax``
        itself when no colours are requested.

    Raises
    ------
    ValueError
        If ``plot_type`` is not "scattermap" or "heatmap".
    """

    _check_data(data)

    plot_type_opts = ["scattermap", "heatmap"]
    if plot_type not in plot_type_opts:
        raise ValueError(f"`plot_type` must be one of {plot_type_opts}")

    # validate / normalize the sorting-related arguments for each axis
    row_meta, row_sort_class, row_class_order, row_item_order, row_colors = _check_sorting_kws(
        data.shape[0],
        row_meta,
        row_sort_class,
        row_class_order,
        row_item_order,
        row_colors,
    )

    col_meta, col_sort_class, col_class_order, col_item_order, col_colors = _check_sorting_kws(
        data.shape[1],
        col_meta,
        col_sort_class,
        col_class_order,
        col_item_order,
        col_colors,
    )

    # sort the data and metadata
    row_perm_inds, row_meta = sort_meta(
        data.shape[0],
        row_meta,
        row_sort_class,
        class_order=row_class_order,
        sort_item=row_item_order,
    )
    col_perm_inds, col_meta = sort_meta(
        data.shape[1],
        col_meta,
        col_sort_class,
        class_order=col_class_order,
        sort_item=col_item_order,
    )
    data = data[np.ix_(row_perm_inds, col_perm_inds)]

    # draw the main heatmap/scattermap
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(10, 10))

    if plot_type == "heatmap":
        sns.heatmap(data, cmap=cmap, ax=ax, center=center, **kws)
    elif plot_type == "scattermap":
        gridmap(data, ax=ax, sizes=sizes, border=False, **kws)

    if square:
        ax.axis("square")

    # Scattermap gets a half-unit margin on every side; the y limits are
    # given as (bottom, top) with bottom > top, so row 0 is drawn at the top.
    ax_pad = 0.5 if plot_type == "scattermap" else 0
    ax.set_ylim(data.shape[0] + ax_pad, 0 - ax_pad)
    ax.set_xlim(0 - ax_pad, data.shape[1] + ax_pad)

    # this will let us make axes for the colors and ticks as necessary
    divider = make_axes_locatable(ax)

    # draw colors
    # note that top_cax and left_cax may = ax if no colors are requested
    top_cax = draw_colors(
        ax,
        divider=divider,
        ax_type="x",
        colors=col_colors,
        palette=col_palette,
        sort_meta=col_meta,
    )
    top_cax.xaxis.set_label_position("top")

    left_cax = draw_colors(
        ax,
        divider=divider,
        ax_type="y",
        colors=row_colors,
        palette=row_palette,
        sort_meta=row_meta,
    )

    remove_shared_ax(ax)

    # draw separators
    draw_separators(
        ax,
        ax_type="x",
        sort_meta=col_meta,
        sort_class=col_sort_class,
        plot_type=plot_type,
        gridline_kws=gridline_kws,
    )
    draw_separators(
        ax,
        ax_type="y",
        sort_meta=row_meta,
        sort_class=row_sort_class,
        plot_type=plot_type,
        gridline_kws=gridline_kws,
    )

    # draw ticks: one tick level per sort class, iterating the sort classes
    # in reverse. Level 0 reuses the colour axes; every further level gets a
    # new thin axis appended by the divider (top for columns, left for rows).
    if len(col_sort_class) > 0 and col_ticks:
        if col_tick_pad is None:
            col_tick_pad = len(col_sort_class) * [0.5]

        tick_ax = top_cax  # start with the axes we already have
        tick_ax_border = False
        rev_col_sort_class = list(col_sort_class[::-1])

        for i in range(len(rev_col_sort_class)):
            if i > 0:  # add a new axis for ticks
                tick_ax = divider.append_axes("top",
                                              size="1%",
                                              pad=col_tick_pad[i],
                                              sharex=ax)
                remove_shared_ax(tick_ax)
                tick_ax.spines["right"].set_visible(True)
                tick_ax.spines["top"].set_visible(True)
                tick_ax.spines["left"].set_visible(True)
                tick_ax.spines["bottom"].set_visible(False)
                tick_ax_border = True

            draw_ticks(
                tick_ax,
                col_meta,
                rev_col_sort_class[i:],
                ax_type="x",
                tick_rot=tick_rot,
                tick_ax_border=tick_ax_border,
            )
            ax.xaxis.set_label_position("top")

    if len(row_sort_class) > 0 and row_ticks:
        if row_tick_pad is None:
            row_tick_pad = len(row_sort_class) * [0.5]

        tick_ax = left_cax  # start with the axes we already have
        tick_ax_border = False
        rev_row_sort_class = list(row_sort_class[::-1])

        for i in range(len(rev_row_sort_class)):
            if i > 0:  # add a new axis for ticks
                tick_ax = divider.append_axes("left",
                                              size="1%",
                                              pad=row_tick_pad[i],
                                              sharey=ax)
                remove_shared_ax(tick_ax)
                tick_ax.spines["right"].set_visible(False)
                tick_ax.spines["top"].set_visible(True)
                tick_ax.spines["bottom"].set_visible(True)
                tick_ax.spines["left"].set_visible(True)
                tick_ax_border = True

            draw_ticks(
                tick_ax,
                row_meta,
                rev_row_sort_class[i:],
                ax_type="y",
                tick_ax_border=tick_ax_border,
            )

    # NOTE: col_highlight / row_highlight / highlight_kws are accepted for
    # backward compatibility but are not drawn yet.

    # spines
    # Merge user overrides onto the defaults so that a partial dict such as
    # dict(linewidth=2) does not raise KeyError on the missing keys.
    spinestyle_kws = {
        **dict(linestyle="-", linewidth=1, alpha=0.7),
        **(spinestyle_kws or {}),
    }
    if border:
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(spinestyle_kws["linewidth"])
            spine.set_linestyle(spinestyle_kws["linestyle"])
            spine.set_alpha(spinestyle_kws["alpha"])

    return ax, divider, top_cax, left_cax
Example #5
0
    # Null distribution: for each shuffle, permute the edges, sort the result
    # with signal_flow_sort, and record the proportion of edges in the upper
    # triangle.
    for _ in range(n_shuffles):
        fake_adj = shuffle_edges(adj)
        fake_adj = signal_flow_sort(fake_adj)
        fake_triu_prop = compute_triu_prop(fake_adj)
        out_dict = {
            "Proportion": fake_triu_prop,
            "Graph": name,
            "Type": "Shuffled"
        }
        shuffled_triu_outs.append(out_dict)
        # NOTE(review): the message uses `g` while the record above uses
        # `name` — confirm they refer to the same graph.
        print(
            f"{g} shuffled graph sorted proportion in upper triangle: {fake_triu_prop}"
        )
    # Only the last shuffled adjacency from the loop is plotted.
    # NOTE(review): if n_shuffles == 0, `fake_adj` is never bound and the
    # gridmap call below raises NameError.
    gridmap_kws = {"c": [palette[i]], "sizes": (5, 10)}

    gridmap(fake_adj, ax=axs[0], **gridmap_kws)
    axs[0].set_title("Shuffled edges")
    # Dashed diagonal from (0, 0) to (n_verts, n_verts) as a visual guide
    # separating the upper and lower triangles.
    axs[0].plot([0, n_verts], [0, n_verts],
                color="grey",
                linewidth=2,
                linestyle="--",
                alpha=0.8)

    # Sort the true graph by descending signal-flow score (argsort ascending,
    # then reversed) and record its upper-triangle proportion.
    z = signal_flow(adj)
    sort_inds = np.argsort(z)[::-1]
    adj = adj[np.ix_(sort_inds, sort_inds)]

    true_triu_prop = compute_triu_prop(adj)

    out_dict = {"Proportion": true_triu_prop, "Graph": name, "Type": "True"}
    true_triu_outs.append(out_dict)
Example #6
0
# One label and one colour for each of the four graphs.
graph_names = [r"A $\to$ D", r"A $\to$ A", r"D $\to$ D", r"D $\to$ A"]
palette = sns.color_palette("deep", 4)
fig, axs = plt.subplots(2, 2, figsize=(30, 30))  # disabled: sharex=True, sharey=True
# Flatten the 2x2 grid of axes so the panels can be indexed 0-3.
axs = axs.ravel()
for i, mg in enumerate(mgs):
    ax = axs[i]
    mg.sort_values(["Hemisphere", "dendrite_input"], ascending=False)
    meta = mg.meta
    meta["Original index"] = range(len(meta))
    first_df = mg.meta.groupby(["Hemisphere"]).first()
    first_inds = list(first_df["Original index"].values)
    first_inds.append(len(meta) + 1)
    middle_df = mg.meta.groupby(["Hemisphere"]).mean()
    middle_inds = list(middle_df["Original index"].values)
    middle_labels = list(middle_df.index.values)
    gridmap(mg.adj, ax=ax, sizes=(8, 16), color=palette[i])
    remove_spines(ax)
    ax.set_xticks(middle_inds)
    ax.set_xticklabels(middle_labels)
    ax.set_yticks(middle_inds)
    ax.set_yticklabels(middle_labels)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xlim((-2, len(meta) + 2))
    ax.set_ylim((len(meta) + 2, -2))
    for t in first_inds:
        ax.axhline(t - 0.5,
                   0.02,
                   0.98,
                   color="grey",
                   linestyle="--",