コード例 #1
0
def plot_multi_neuron(neurons, layout, to_save="", scalebar=(200, r"$\mu m$"), fig_size=None):
    """Plot multiple neurons arranged in a grid of subplots.

    neurons: neurom.population.Population (a sized iterable of neurons
        accepted by ``view.plot_neuron``).
    layout: (n_rows, n_cols) tuple giving the grid shape.
    to_save: the file to save as. If left empty, the figure is not saved.
    scalebar: (length, unit label) passed to ``add_scalebar`` on the last axes.
    fig_size: (width, height) in inches. Default: 2 inches per column wide
        by 2 inches per row tall.
    """
    n_rows, n_cols = layout
    if fig_size is None:
        # matplotlib's figsize is (width, height): width follows the number
        # of columns, height the number of rows.
        fig_size = (2 * n_cols, 2 * n_rows)
    # squeeze=False always yields a 2-D array of Axes, including the 1x1,
    # 1xN and Nx1 layouts, so no reshaping special cases are needed.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=fig_size, squeeze=False)
    for ax, neuron in zip(axs.flat, neurons):
        view.plot_neuron(ax, neuron)
        ax.autoscale()
        ax.set_title("")
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_axis_off()
    # Hide the trailing grid cells that received no neuron.
    left = len(neurons) - n_rows * n_cols
    if left < 0:
        for ax in axs.flat[left:]:
            ax.set_visible(False)
    plot_unified_scale_grid(fig, axs)
    add_scalebar(axs.flat[-1], scalebar[0], scalebar[1], fig)
    if bool(to_save):
        to_save_figure(to_save)
コード例 #2
0
    def plot_demo_epsc(self, file_id, to_save=""):
        """Plot a demo EPSC trace from one of the loaded ABF files.

        file_id: index of the ABF file in ``self.files["abf"]`` to plot.
        to_save: filename to save the figure as (via ``to_save_figure``).
            An empty string (default) means the figure is not saved.
        """
        abf_path = self.files["abf"][file_id]
        trace = pyabf.ABF(abf_path)

        plt.figure(figsize=(10, 3))
        plt.plot(trace.sweepX, trace.sweepY, color="C0", alpha=0.8)
        plt.ylabel(trace.sweepLabelY)
        plt.xlabel(trace.sweepLabelX)
        plt.axis("off")

        # Empty zoomed inset in the lower-right corner (loc=4) with no ticks
        # and only its left/bottom spines; presumably it serves as a
        # "10 s / 10 pA" scale bar.
        inset_kwargs = {
            "xlabel": "10 s",
            "ylabel": "10 pA",
            "xticks": [],
            "yticks": [],
        }
        inset = zoomed_inset_axes(plt.gca(), 10, 4, axes_kwargs=inset_kwargs)
        for side in ("top", "right"):
            inset.spines[side].set_visible(False)
        if to_save:
            to_save_figure(to_save)
コード例 #3
0
 def plot_sholl(self, sholl_part=True, to_save=""):
     """Plot Sholl analysis curves (mean with a mean +/- SEM band vs. radius).

     sholl_part: bool. If True (default), plot the apical and basal parts
         separately using ``self.compute_sholl_parts_stat()``. If False, plot
         the whole-neuron curve from ``self.get_imaris_stat()``.
     to_save: filename to save the figure as. If empty, not saved.
     """
     if not sholl_part:
         imarisData = self.get_imaris_stat()
         shollData = imarisData[
             imarisData.Variable == "Filament No. Sholl Intersections"
         ].loc[:, ["neuron_ID", "Radius", "Value"]]
         # Aggregate only "Value": neuron_ID is an identifier, and averaging
         # a non-numeric column raises in recent pandas versions.
         shollPlotData = (
             zero_padding(shollData, "Value")
             .groupby("Radius")[["Value"]]
             .agg(["mean", sem])
             .reset_index()
         )
         mean = shollPlotData[("Value", "mean")]
         err = shollPlotData[("Value", "sem")]
         plt.plot(shollPlotData["Radius"], mean)
         plt.fill_between(shollPlotData["Radius"], mean + err, mean - err, alpha=0.6)
     else:
         dat = self.compute_sholl_parts_stat()
         dat = dat.pivot(
             index="radius",
             columns="label",
             values=["intersections_mean", "intersections_sem"],
         )
         plt.plot(
             dat.index,
             dat.intersections_mean.apical,
             dat.index,
             dat.intersections_mean.basal,
             label="",
         )
         # One shaded mean +/- SEM band per part; only the bands are labelled,
         # so they are what the legend shows.
         for part in ("apical", "basal"):
             mean = dat.loc[:, ("intersections_mean", part)]
             err = dat.loc[:, ("intersections_sem", part)]
             plt.fill_between(dat.index, mean + err, mean - err, alpha=0.6, label=part)
         plt.legend()
     plt.ylabel("Sholl intersections")
     plt.xlabel(r"$Radius\ (\mu m)$")
     if bool(to_save):
         to_save_figure(to_save)
コード例 #4
0
def plot_cluster_hist(
    dat,
    col,
    group1=0,
    group2=1,
    to_save="",
    fig_size=(3,3),
    alpha=0.6,
    bins=20,
    xlabel=None,
    ylabel="No. of neuron",
    plot_cumulative=False,
    legend_kw=None,
    show_legend=True,
    **kwargs
):
    """Plot overlaid histograms of `col` for two clusters.

    dat: CA result data, whose columns contain `cluster`.
    col: column in `dat` to be plotted.
    group1, group2: cluster IDs to compare. Legend labels are 1-based
        ("cluster {id + 1}").
    to_save: filename to save the figure as. If empty, not saved.
    fig_size: figure size in inches.
    alpha, bins: passed to `axes.hist()`. Both histograms use the same range
        (min/max of `col` over all of `dat`), so their bins line up.
    xlabel: x-label string. Default is equal to `col`.
    ylabel: y-label string.
    plot_cumulative: {True, False}. If True, draw a cumulative-curve inset
        with `plot_cumulative_curve()`.
    legend_kw: dict of parameters passed to `axes.legend()`. Default: None
        (no extra parameters).
    show_legend: {True, False}. Whether to draw the legend.
    **kwargs: other parameters passed to `axes.hist()`.
    """
    range_ = dat[col].min(), dat[col].max()
    plt.figure(figsize=fig_size)
    ax = plt.axes()
    for group in (group1, group2):
        ax.hist(
            dat[dat.cluster.values == group][col],
            alpha=alpha,
            bins=bins,
            label="cluster {}".format(group + 1),
            range=range_,
            **kwargs
        )
    if plot_cumulative:
        fig = plt.gcf()
        axin = fig.add_axes([0.7, 0.6, 0.17, 0.2])
        plot_cumulative_curve(dat, col, (group1, group2), ax=axin)
    if show_legend:
        # legend_kw defaults to None, and unpacking None raises TypeError.
        ax.legend(**(legend_kw or {}))
    if xlabel is None:
        xlabel = col
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        plt.tight_layout()
    if bool(to_save):
        to_save_figure(to_save)
コード例 #5
0
def plot_iv_traces(IVdata, to_save=""):
    """Draw one I-V curve (Vm vs. I, line with markers) per cell.

    IVdata: DataFrame as produced by ap_parser.get_all_data('iv'). It must
        contain the columns "CellID", "I" and "Vm".
    to_save: filename to save the figure as. If empty, not saved.
    """
    for _cell_id, cell_rows in IVdata.groupby("CellID"):
        plt.plot(cell_rows["I"], cell_rows["Vm"], "-o", alpha=0.5, color="C0")
    plt.xlabel("I (nA)")
    plt.ylabel("Vm (mV)")
    if to_save:
        to_save_figure(to_save=to_save)
コード例 #6
0
def plot_Rm_distribution(Rm_data, to_save=""):
    """Histogram of Rm values across neurons.

    Only rows whose R_score is greater than 0.9 are included.
    Rm_data: DataFrame as produced by compute_Rm(). It must contain the
        columns "R_score" and "Rm".
    to_save: string. Filename to be saved. If empty, not saved.
    """
    kept_rows = Rm_data.R_score > 0.9
    Rm_data[kept_rows].Rm.plot.hist(alpha=0.75)
    plt.xlabel(r"$Rm\ (M\Omega)$")
    plt.ylabel("Number of neuron")
    if to_save:
        to_save_figure(to_save)
コード例 #7
0
 def plot_tau_distribution(self, to_save=""):
     """Histogram of the time constant tau, excluding outliers.

     Only tau values within [0, 500] (inclusive) are plotted.
     to_save: file path. If left empty, not to save.
     """
     tau = self.get_tau_stat().tau
     # Negative or very large tau values are treated as abnormal and dropped.
     in_range = tau.between(0, 500)
     tau[in_range].plot.hist(alpha=0.7)
     plt.xlabel("Time constant (ms)")
     plt.ylabel("Number of neuron")
     if to_save:
         to_save_figure(to_save)
コード例 #8
0
 def plot_decomposition_scatter(
     self, ca_dat=None, dim1=0, dim2=1, to_save="", plot_3d=False, fig_size=(3,3), show_legend=True,
     label_map=None,
 ):
     """Scatter plot of cluster-analysis samples in PCA space, coloured by cluster.

     ca_dat: DataFrame with a `cluster` column. All columns except the last
         are used as PCA input (presumably the last column is `cluster`;
         confirm for custom inputs). Default:
         ``self.k_means(return_filled=True, return_scaled=True)``.
     dim1, dim2: 0-based indices of the PCA components on the x and y axes
         (2-D plot only). PCA keeps the components explaining 80% of the
         variance, so the requested components must exist.
     to_save: filename to save the figure as. If empty, not saved.
     plot_3d: 3D scatter plot. If True, the first 3 components are used and
         dim1/dim2 are ignored.
     fig_size: figure size in inches.
     show_legend: whether to draw the legend (2-D and 3-D plots).
     label_map: dict. Map cluster ID to label string.
         Default: {0: 'cluster 1', 1: 'cluster 2', ...}.
     """
     if ca_dat is None:
         ca_dat = self.k_means(return_filled=True, return_scaled=True)
     X = ca_dat.to_numpy()[:, :-1]
     pca = PCA(n_components=0.8)
     newX = pca.fit_transform(X)
     clusters = np.sort(ca_dat.cluster.unique())
     if label_map is None:
         label_map = {clust: 'cluster {}'.format(clust + 1) for clust in clusters}
     fig = plt.figure(figsize=fig_size)
     if plot_3d:
         dims = (0, 1, 2)
         # Axes3D(fig) no longer adds itself to the figure in recent
         # matplotlib (the plot came out empty); add_subplot(projection="3d")
         # does.
         ax = fig.add_subplot(1, 1, 1, projection="3d")
         ax.view_init(30, 30)
     else:
         dims = (dim1, dim2)
         ax = fig.add_subplot(1, 1, 1)
     for i, clust in enumerate(clusters):
         clust_row = ca_dat.cluster.values == clust
         coords = [newX[clust_row, d] for d in dims]
         # Wrap the colour index so that more than 9 clusters no longer
         # exhausts the colour list (the old generator raised StopIteration).
         ax.scatter(*coords, color="C{}".format(i % 10), label=label_map[clust])
     if plot_3d:
         # NOTE(review): Axes3D.dist is deprecated in newer matplotlib; confirm
         # it still enlarges the margin on the installed version.
         ax.dist = 12  # avoid labels incomplete
     if show_legend:
         ax.legend()
     ax.set_xlabel("Component {}".format(dims[0] + 1))
     ax.set_ylabel("Component {}".format(dims[1] + 1))
     if plot_3d:
         ax.set_zlabel("Component {}".format(dims[2] + 1))
         plt.yticks(rotation=30, horizontalalignment='center',
                    verticalalignment='baseline', rotation_mode='anchor')
     plt.tight_layout()
     if bool(to_save):
         to_save_figure(to_save)
コード例 #9
0
def plot_cumulative_curve(tab, feature, clusters=(0,1), n_bins=50, ax=None,fig_size=(3,3), to_save="", **kwargs):
    """Plot normalized cumulative histograms of `feature` for two clusters.

    tab: DataFrame with a `cluster` column and a `feature` column.
    feature: column of `tab` to plot.
    clusters: cluster IDs to compare. Only the first two are used.
    n_bins: number of histogram bins. Both curves share the range
        [min, max] of `feature` over all of `tab`.
    ax: axes to draw on. If None, a new figure of size `fig_size` is created.
    fig_size: figure size in inches, used only when `ax` is None.
    to_save: filename to save the figure as. If empty, not saved.
    **kwargs: extra parameters passed to both `ax.hist()` calls. They
        override the defaults (cumulative=True, histtype='step', shared
        range, density=True). Previously these were silently ignored.

    Returns the axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=fig_size)
    range_ = tab[feature].min(), tab[feature].max()
    hist_kw = dict(cumulative=True, histtype='step', range=range_, density=True)
    hist_kw.update(kwargs)
    for clust in (clusters[0], clusters[1]):
        ax.hist(tab[tab.cluster == clust][feature], n_bins, **hist_kw)
    if bool(to_save):
        to_save_figure(to_save)
    return ax
コード例 #10
0
 def plot_demo_distribution(self, term, file_id, to_save=""):
     """Plot a bar plot of AP number vs. sweep number for one demo file.

     term: String. Key into ``self.files``, e.g. 'freq' or 'amp'.
     file_id: Integer. Index of the file within ``self.files[term]``.
     to_save: file path. If left empty, not to save. The figure is saved
         with "<term> distribution of " prefixed to the file name. Any
         directory part of the path is preserved.
     """
     import os

     demoData = self.read_demo_stat(term, file_id)
     demoData.plot("sweep", "n", kind="bar", legend=False)
     identifier = getCellID(self.files[term][file_id])
     plt.title(identifier)
     plt.ylabel("AP number")
     if bool(to_save):
         # Prefix only the file name, so "out/fig.png" becomes
         # "out/<term> distribution of fig.png" instead of a path into a
         # nonexistent "<term> distribution of out" directory. A bare
         # file name gives the same result as before.
         folder, filename = os.path.split(to_save)
         to_save_figure(os.path.join(folder, term + " distribution of " + filename))
コード例 #11
0
 def plot_depth(self, to_save=""):
     """Plot the branch order distribution as bars with SEM error bars.

     Data from ``self.get_depth_data()`` are grouped by "Depth". The x-axis
     shows Depth + 1 as the branch order, the bar height is the mean of
     "counts", and the error bars are its SEM.
     to_save: filename to save the figure as. If empty, not saved.
     """
     stats = self.get_depth_data().groupby("Depth").agg([np.mean, sem]).reset_index()
     branch_order = stats["Depth"] + 1
     mean_counts = stats[("counts", "mean")]
     plt.bar(branch_order, mean_counts, alpha=0.6)
     plt.errorbar(branch_order, mean_counts, stats[("counts", "sem")], linestyle="")
     plt.ylabel("Number of filaments")
     plt.xlabel("Branch order")
     if to_save:
         to_save_figure(to_save)
コード例 #12
0
 def plot_demo_ap(self,
                  file_id=None,
                  to_save="",
                  sweep="all",
                  cell_id=None,
                  fig_size=(3, 3),
                  aspect_ratio=150,
                  sweepYcolor='C0',
                  sweepCcolor='C0'):
     """Plot an AP demo trace from a file (sweepY on top, sweepC below).

     file_id: Integer. The n-th file in ``self.files["abf"]``.
     to_save: String. The figure to be saved. Default: empty (not to save).
     sweep: String or integer. 1-based index of the sweep to be plotted.
         Default: 'all', which overlays all sweeps.
     cell_id: if given, overrides `file_id` via ``self.get_file_id(cell_id)``.
     fig_size: figure size in inches.
     aspect_ratio: y-axis scale ratio of voltage and current.
     sweepYcolor, sweepCcolor: the color of the sweepY and sweepC traces.
     Only samples 500:5500 of each sweep are drawn.
     """
     if cell_id is not None:
         file_id = self.get_file_id(cell_id)
     # Use the caller's fig_size (it was previously ignored in favour of a
     # hard-coded (3, 3)).
     fig = plt.figure(figsize=fig_size)
     ax1 = fig.add_subplot(211)
     ax2 = fig.add_subplot(212)
     demo = pyabf.ABF(self.files["abf"][file_id])
     window = slice(500, 5500)
     if sweep == "all":
         sweep_indices = demo.sweepList
     else:
         # `sweep` may be a numeric string; it is 1-based, setSweep takes
         # the index minus one.
         sweep_indices = [int(sweep) - 1]
     for sweep_idx in sweep_indices:
         demo.setSweep(sweep_idx)
         ax1.plot(demo.sweepX[window], demo.sweepY[window], sweepYcolor)
         ax2.plot(demo.sweepX[window], demo.sweepC[window], sweepCcolor)
     asp1 = get_aspect_scale(ax1)
     ax2.set_aspect(asp1[1] * aspect_ratio / asp1[0])
     add_scalebar(ax2, (.1, .2), ('s', 'nA'), fig, y_label='200pA/\n30mV')
     ax1.axis("off")
     ax2.axis("off")
     with warnings.catch_warnings():
         warnings.simplefilter('ignore')
         plt.tight_layout()
     if bool(to_save):
         to_save_figure(to_save)
コード例 #13
0
def plot_cluster_stat(tab, feature, clusters=(0,1), fig_size=(2.5,2), xticks=None, capsize=8, plot_box=False,
                      ylabel=None, to_save="", signif=None,
                      brokenaxes_dict=None, **kwargs):
    """Compare `feature` across clusters with bars (mean +/- SEM) or a boxplot.

    Parameters:
    tab: DataFrame. Must contain the columns `feature` and 'cluster'.
    feature: String. A column in `tab`.
    clusters: cluster IDs to plot, in order.
    fig_size: figure size in inches.
    xticks: tick labels. Default: 'cluster {id + 1}' for each cluster.
    capsize: cap size of the error bars (bar plot only). Default 8 points.
    plot_box: Draw a boxplot (NaNs dropped) instead of a bar plot with
        error bars.
    ylabel: y-axis label. Default: `feature`.
    to_save: filename to save. If empty, not saved.
    signif: significance marker. NOTE(review): accepted but currently unused.
    brokenaxes_dict: Dictionary. Parameters passed to `brokenaxes()`.
        If None (default), a regular axes is used instead of broken axes.
    **kwargs: other parameters passed to `bar()` or `boxplot()`.
    """
    plt.figure(figsize=fig_size)
    ax = plt.gca() if brokenaxes_dict is None else brokenaxes(**brokenaxes_dict)
    per_cluster = [tab[tab['cluster'] == c][feature] for c in clusters]
    if plot_box:
        # Boxes sit at positions 1..n.
        positions = list(range(1, len(per_cluster) + 1))
        ax.boxplot([values.dropna().values for values in per_cluster], **kwargs)
    else:
        positions = list(range(len(per_cluster)))
        means = [np.nanmean(values) for values in per_cluster]
        errors = [sem(values, nan_policy='omit') for values in per_cluster]
        ax.bar(positions, means, yerr=errors, capsize=capsize, **kwargs)
    if xticks is None:
        xticks = ['cluster {}'.format(c + 1) for c in clusters]
    plt.xticks(positions, xticks)
    plt.ylabel(feature if ylabel is None else ylabel)
    plt.tight_layout()
    if to_save:
        to_save_figure(to_save)
コード例 #14
0
 def plot_cluster_scatter(self, item1, item2, to_save="", ca_dat=None):
     """Plot cluster analysis scatter plot `item1` vs. `item2`, coloured by cluster.

     Parameters:
     item1, item2: column names from cluster_processor.k_means() or
         cluster_processor.ca_dat.
     to_save: filename to save the figure as. If empty, not saved.
     ca_dat: result from cluster_processor.k_means(). Default: a copy of
         ``self.ca_dat``, or ``self.k_means()`` if that is None.
     """
     if ca_dat is None:
         ca_dat = self.k_means() if self.ca_dat is None else self.ca_dat.copy()
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
     # Sorted cluster IDs give a stable colour and legend order. The wrapped
     # colour index avoids the StopIteration the old 9-colour generator
     # raised for more than 9 clusters.
     for i, clust in enumerate(np.sort(ca_dat.cluster.unique())):
         clust_row = ca_dat.cluster.values == clust
         ax.scatter(
             ca_dat[clust_row][item1],
             ca_dat[clust_row][item2],
             color="C{}".format(i % 10),
             label='cluster {}'.format(clust + 1),
         )
     ax.legend()
     plt.xlabel(item1)
     plt.ylabel(item2)
     if bool(to_save):
         to_save_figure(to_save)
コード例 #15
0
 def plot_demo_ramp(self,
                    file_id=0,
                    cell_id=None,
                    sweep='all',
                    fig_size=(3, 3),
                    to_save="",
                    aspect_ratio=150,
                    sweepYcolor='C0',
                    sweepCcolor='C0'):
     """Plot voltage reaction to ramp current (sweepY on top, sweepC below).

     file_id: Integer. The n-th file in ``self.files["abf"]``.
     cell_id: if given, overrides `file_id` via ``self.get_file_id(cell_id)``.
     sweep: 'all' (default) to overlay every sweep, or a 1-based sweep
         index (integer or numeric string).
     fig_size: figure size in inches.
     to_save: filename to save the figure as. If empty, not saved.
     aspect_ratio, sweepYcolor, sweepCcolor: see `ap_parser().plot_demo_ap()`.
     """
     if cell_id is not None:
         file_id = self.get_file_id(cell_id)
     fig = plt.figure(figsize=fig_size)
     ax1 = fig.add_subplot(211)
     ax2 = fig.add_subplot(212)
     abf = pyabf.ABF(self.files["abf"][file_id])
     if sweep == 'all':
         sweep_indices = abf.sweepList
     else:
         # int() accepts numeric strings (previously `sweep - 1` raised
         # TypeError for them). `sweep` is 1-based, setSweep takes index - 1.
         sweep_indices = [int(sweep) - 1]
     for sweep_idx in sweep_indices:
         abf.setSweep(sweep_idx)
         ax1.plot(abf.sweepX, abf.sweepY, sweepYcolor)
         ax2.plot(abf.sweepX, abf.sweepC, sweepCcolor)
     asp1 = get_aspect_scale(ax1)
     ax2.set_aspect(asp1[1] * aspect_ratio / asp1[0])
     ax1.axis("off")
     ax2.axis("off")
     # NOTE(review): size 0.2 with unit 'pA' does not obviously match the
     # '200pA' label (the AP demo uses 'nA' for the same label); confirm.
     add_scalebar(ax2, (1, 0.2), ('s', 'pA'), fig, y_label='200pA/\n30mV')
     with warnings.catch_warnings():
         warnings.simplefilter('ignore')
         plt.tight_layout()
     if bool(to_save):
         to_save_figure(to_save)
コード例 #16
0
def plot_cluster_sholl(
    morpho_parser,
    cluster1_ids, cluster2_ids,
    clusters=(0,1),
    label_map=None,
    sholl_part=True,
    fig_size=(3,3),
    to_save="",
):
    """Plot whole-neuron Sholl curves (mean with a mean +/- SEM band) for two clusters.

    morpho_parser: `morpho_parser` object providing ``get_imaris_stat()``.
    cluster1_ids, cluster2_ids: neuron IDs of the two clusters.
    clusters: cluster IDs used for the plot labels (first two are used).
    label_map: dict mapping cluster ID to legend label.
        Default: {c: 'cluster {c + 1}'}.
    sholl_part: only False (whole-neuron curves) is implemented. With True
        (the default) a warning is emitted and nothing is plotted.
    fig_size: figure size in inches.
    to_save: filename to save the figure as. If empty, not saved.
    """
    import warnings

    if sholl_part:
        # Previously this case silently did nothing; make that visible.
        warnings.warn(
            "plot_cluster_sholl: sholl_part=True is not implemented; nothing "
            "was plotted. Pass sholl_part=False for whole-neuron Sholl curves.",
            stacklevel=2,
        )
        return
    imarisData = morpho_parser.get_imaris_stat()
    is_sholl = imarisData.Variable == "Filament No. Sholl Intersections"
    per_cluster = []
    for clust, ids in zip(clusters, (cluster1_ids, cluster2_ids)):
        clust_data = imarisData[is_sholl & imarisData.neuron_ID.isin(ids)].loc[
            :, ["neuron_ID", "Radius", "Value"]
        ].copy()
        clust_data['cluster'] = clust
        per_cluster.append(zero_padding(clust_data, 'Value'))
    shollData = pd.concat(per_cluster, sort=False, ignore_index=True)
    # Aggregate only "Value": neuron_ID is an identifier, and averaging a
    # non-numeric column raises in recent pandas versions.
    shollPlotData = (
        shollData
        .groupby(["cluster", "Radius"])[["Value"]]
        .agg(["mean", sem])
    )
    if label_map is None:
        label_map = {c: 'cluster {}'.format(c + 1) for c in clusters}
    plt.figure(figsize=fig_size)
    handles = []
    for i, clust in enumerate(clusters[:2]):
        clust_stat = shollPlotData.loc[clust]
        mean = clust_stat[('Value', 'mean')]
        err = clust_stat[('Value', 'sem')]
        line, = plt.plot(clust_stat.index, mean)
        plt.fill_between(clust_stat.index, mean + err, mean - err, alpha=0.6)
        # Legend entry = line overlaid on a patch in the band's colour.
        handles.append((line, mpatch.Patch(color='C{}'.format(i), alpha=0.6)))
    plt.ylabel('Sholl intersections')
    plt.xlabel(r'$Radius\ (\mu m)$')
    plt.legend(
        tuple(handles),
        tuple(label_map[c] for c in clusters[:2]),
    )
    if bool(to_save):
        to_save_figure(to_save)
コード例 #17
0
        basal_points = np.concatenate([x.points for x in nm.iter_neurites(neuron, filt=lambda t: t.type==nm.BASAL_DENDRITE)])
        basal_label_pos_xs = [basal_points[:,0].min()-center[0],basal_points[:,0].max()-center[0]]
        basal_label_pos_x = basal_label_pos_xs[0] if np.abs(basal_label_pos_xs[0])>np.abs(basal_label_pos_xs[1]) else basal_label_pos_xs[1]
        basal_label_pos_y = center[1]+(basal_points[:,1].min()-center[1])/2
        label_dict = {'Apical':(apical_label_pos_x, apical_label_pos_y), 'Basal': (basal_label_pos_x, basal_label_pos_y)}

    for name,pos in label_dict.items():
        plt.annotate(name, pos)

    ax.autoscale()
    ax.set_axis_off()
    plt.title(None)
    if bool(to_save):
<<<<<<< Updated upstream
=======
        to_save_figure(to_save)

def plot_single_neuron(neuron, put_apical_upside=False, to_save=""):
    """Plot a neuron.
    neuron: [neurom.fst._core.FstNeuron] The neuron to plot.
    put_apical_upside: logical. Whether put apical dendrite upside."""
    fig, ax = plt.subplots(subplot_kw={'aspect':'equal'})
    if put_apical_upside:
        neuron=apical_upside(neuron)
    view.plot_neuron(ax, neuron)
    ax.autoscale()
    ax.set_title("")
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    ax.set_axis_off()
    if bool(to_save):