コード例 #1
0
ファイル: pwaveform.py プロジェクト: mnislamraju/seal
def plot_wfs(waveforms,
             tvec,
             cols=None,
             lw=0.1,
             alpha=0.05,
             xlim=None,
             ylim=None,
             title=None,
             xlab=None,
             ylab=None,
             ffig=None,
             ax=None,
             **kwargs):
    """
    Draw a large set of waveforms at once through a single LineCollection.

    waveforms: waveform matrix, one waveform per row, one time sample per
        column
    tvec: vector of sample times (its size is taken as samples per waveform)
    cols: color matrix of (waveform x RGBA); every waveform is green if None
    lw: line width handed to the LineCollection
    alpha: accepted but not used anywhere in this function
    xlim, ylim, xlab, ylab, title: forwarded to putil.format_plot
    ffig: forwarded to putil.save_fig
    ax: axes to draw on, resolved through putil.axes
    kwargs: extra keyword arguments for LineCollection

    Returns the axes that received the collection.
    """

    n_wf = waveforms.shape[0]
    n_samp = tvec.size
    ax = putil.axes(ax)

    if cols is None:
        green = putil.convert_to_rgba('g')
        cols = np.tile(green, (n_wf, 1))

    # Flatten everything into one long list of (t, v) points, waveform after
    # waveform, so that consecutive points can be paired into segments:
    # [[(t0, v0), (t1, v1)], [(t1, v1), (t2, v2)], ...]
    v_flat = waveforms.reshape((-1, 1))
    t_flat = np.tile(np.array(tvec), (n_wf, 1)).reshape((-1, 1))
    points = np.concatenate((t_flat, v_flat), axis=1).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # Pairing consecutive points also joins the last sample of one waveform
    # to the first sample of the next one; drop those bridging segments.
    bridge_idx = np.arange(1, n_wf) * n_samp - 1
    segments = np.delete(segments, bridge_idx, axis=0)

    # Each waveform contributes (n_samp - 1) segments of the same color.
    seg_colors = np.repeat(cols, n_samp - 1, axis=0)

    collection = LineCollection(segments,
                                linewidths=lw,
                                colors=seg_colors,
                                **kwargs)
    ax.add_collection(collection)
    # Artists added by hand do not trigger autoscaling, so do it explicitly.
    ax.autoscale_view()

    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.save_fig(ffig)

    return ax
コード例 #2
0
def plot_mean_rates(mRates,
                    aa_res_dir,
                    tasks=None,
                    task_lbls=None,
                    baseline=None,
                    xlim=None,
                    ylim=None,
                    ci=68,
                    ffig=None,
                    figsize=(6, 4)):
    """
    Plot mean rates across tasks as time series.

    mRates: mapping from task to a rate table supporting `.unstack()`
        (presumably a time x unit DataFrame -- TODO confirm against caller)
    aa_res_dir: not used by this function
    tasks: tasks to plot, all keys of mRates if None
    task_lbls: optional mapping used to replace task names for the legend
    baseline: baseline level to add to the plot, skipped if None
    xlim, ylim: axis limits passed to putil.set_limits
    ci: confidence interval passed to sns.tsplot
    ffig: passed to putil.save_fig
    figsize: figure size passed to putil.figure

    Returns the (fig, ax) pair.
    """

    # Init.
    if tasks is None:
        tasks = mRates.keys()

    # Stack rates of each task into a long-format table.
    task_frames = []
    for task in tasks:
        lrates = pd.DataFrame(mRates[task].unstack(), columns=['rate'])
        lrates['task'] = task
        task_frames.append(lrates)
    lRates = pd.concat(task_frames)
    lRates['time'] = lRates.index.get_level_values(0)
    lRates['unit'] = lRates.index.get_level_values(1)

    # Assign the relabelled column back explicitly: an in-place replace on
    # the Series returned by attribute access is a chained update that is not
    # guaranteed to reach the frame (pandas Copy-on-Write).
    if task_lbls is not None:
        lRates['task'] = lRates['task'].replace(task_lbls)

    # Plot as time series.
    fig = putil.figure(figsize=figsize)
    ax = putil.axes()

    # NOTE(review): tsplot is no longer shipped by newer seaborn releases --
    # confirm the seaborn version this project is pinned to.
    sns.tsplot(lRates,
               time='time',
               value='rate',
               unit='unit',
               condition='task',
               ci=ci,
               ax=ax)

    # Add periods and baseline.
    putil.plot_periods(ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Format plot.
    sns.despine(ax=ax)
    putil.set_labels(ax, xlab='time since S1 onset', ylab='rate (sp/s)')
    putil.set_limits(ax, xlim, ylim)
    putil.hide_legend_title(ax)

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
コード例 #3
0
def raster(spk_trains,
           t_unit=ms,
           prds=None,
           c='b',
           xlim=None,
           title=None,
           xlab=None,
           ylab=None,
           ffig=None,
           ax=None):
    """
    Draw a raster plot, one row per spike train.

    spk_trains: sequence of spike trains, each supporting `.rescale(t_unit)`
    t_unit: time unit the spike times are rescaled to before plotting
    prds: periods passed to putil.plot_periods
    c: face and edge color of the spike markers
    xlim, xlab, ylab, title: forwarded to the putil formatting helpers
    ffig: passed to putil.save_fig
    ax: axes to draw on, resolved through putil.axes

    Returns the axes. With no spike trains, only periods and x limits are
    set and the axes is returned without further formatting or saving.
    """

    ax = putil.axes(ax)
    putil.plot_periods(prds, ax=ax)
    putil.set_limits(ax, xlim)

    # Nothing to draw: bail out early.
    n_trains = len(spk_trains)
    if n_trains == 0:
        return ax

    for row, train in enumerate(spk_trains, start=1):
        x = np.array(train.rescale(t_unit))
        y = np.ones_like(x) * row

        # Markers are rectangles of size (wsp, hsp) in data coordinates,
        # centered on each spike, so they scale together with the axes.
        markers = []
        for xi, yi in zip(x, y):
            corner = (xi - wsp / 2, yi - hsp / 2)
            markers.append(Rectangle(corner, wsp, hsp))
        ax.add_collection(PatchCollection(markers, facecolor=c, edgecolor=c))

    # At least one train is present here, so rows span 1 .. n_trains.
    ylim = [0.5, n_trains + 0.5]
    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.hide_axes(ax, show_x=True)
    putil.hide_spines(ax)

    # Flip the y axis only after limits are set, so that the first train
    # ends up on the top row.
    ax.invert_yaxis()

    putil.save_fig(ffig)
    return ax
コード例 #4
0
ファイル: test_RF.py プロジェクト: mnislamraju/seal
def plot_RF_results(RF_res, stims, fdir, sup_title):
    """
    Plot receptive field results.

    Creates one histogram of the S1 coverage values, then, for each of
    'mean_rate', 'max_rate' and 'mDSI', a grid of regression plots
    (stimulus x task) of that variable against RF coverage.

    RF_res: DataFrame with an 'S1_cover' column and, for each stimulus,
        '<stim>_cover' and '<stim>_<variable>' columns; the last index level
        holds the task
    stims: stimulus name prefixes used to build the column names
    fdir: output folder; the coverage figure path is built by plain string
        concatenation, so fdir is expected to end with a path separator
    sup_title: super title of the figures, also used for the file names
    """

    # Plot distribution of coverage values.
    putil.set_style('poster', 'white')
    fig = putil.figure()
    ax = putil.axes()
    sns.distplot(RF_res.S1_cover, bins=np.arange(0, 1.01, 0.1),
                 kde=False, rug=True, ax=ax)
    putil.set_limits(ax, [0, 1])
    fst = util.format_to_fname(sup_title)
    ffig = fdir + fst + '_S1_coverage.png'
    putil.save_fig(ffig, fig, sup_title)

    # Plot RF coverage and rate during S1 on regression plot for each
    # recording and task.
    tasks = RF_res.index.get_level_values(-1).unique()
    xlim = [0, 1]
    for vname, ylim in [('mean_rate', [0, None]), ('max_rate', [0, None]),
                        ('mDSI', [0, 1])]:
        fig, gs, axes = putil.get_gs_subplots(nrow=len(stims), ncol=len(tasks),
                                              subw=4, subh=4, ax_kws_list=None,
                                              create_axes=True)
        colors = sns.color_palette('muted', len(tasks))
        for istim, stim in enumerate(stims):
            for itask, task in enumerate(tasks):
                # Plot regression plot.
                ax = axes[istim, itask]
                scov, sval = [stim + '_' + name for name in ('cover', vname)]
                df = RF_res.xs(task, level=-1)
                # x, y and data go by keyword: newer seaborn releases no
                # longer accept them positionally.
                sns.regplot(x=scov, y=sval, data=df, color=colors[itask],
                            ax=ax)
                # Add unit labels.
                uids = df.index.droplevel(0)
                putil.add_unit_labels(ax, uids, df[scov], df[sval])
                # Add stats.
                r, p = sp.stats.pearsonr(df[sval], df[scov])
                pstr = util.format_pvalue(p)
                txt = 'r = {:.2f}, {}'.format(r, pstr)
                ax.text(0.02, 0.98, txt, va='top', ha='left',
                        transform=ax.transAxes)
                # Set labels.
                title = '{} {}'.format(task, stim)
                xlab, ylab = [sn.replace('_', ' ') for sn in (scov, sval)]
                putil.set_labels(ax, xlab, ylab, title)
                # Set limits.
                putil.set_limits(ax, xlim, ylim)

        # Save plot (file name stem was already derived from sup_title).
        fname = '{}_cover_{}.png'.format(fst, vname)
        ffig = util.join([fdir, vname, fname])
        putil.save_fig(ffig, fig, sup_title)
コード例 #5
0
def plot_tuning(xfit,
                yfit,
                vals=None,
                meanr=None,
                semr=None,
                color='b',
                baseline=None,
                xticks=None,
                xlim=None,
                ylim=None,
                xlab=None,
                ylab=None,
                title=None,
                ffig=None,
                ax=None,
                **kwargs):
    """
    Draw a fitted tuning curve and, optionally, the sampled data points.

    xfit, yfit: coordinates of the fitted curve
    vals: x positions of the data samples; also used as x ticks when xticks
        is not given
    meanr, semr: sample means and their errors; points with error bars are
        drawn only if both are provided
    color: color of the curve and the samples
    baseline: baseline level to add first, skipped if None
    xticks: explicit x tick values, taking precedence over vals
    xlim, ylim, xlab, ylab, title: forwarded to putil.format_plot
    ffig: passed to putil.save_fig
    ax: axes to draw on, resolved through putil.axes
    kwargs: extra keyword arguments for pplot.errorbar

    Returns the axes.
    """

    ax = putil.axes(ax)

    # Baseline goes first.
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Fitted curve.
    pplot.lines(xfit, yfit, color=color, ax=ax)

    # Data samples, only when both mean and error are available.
    have_samples = (meanr is not None) and (semr is not None)
    if have_samples:
        pplot.errorbar(vals, meanr, yerr=semr, fmt='o', color=color, ax=ax,
                       **kwargs)

    # Explicit ticks win over sample positions.
    tick_vals = xticks if xticks is not None else vals
    if tick_vals is not None:
        putil.set_xtick_labels(ax, tick_vals)

    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.save_fig(ffig)

    return ax
コード例 #6
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def scatter(x,
            y,
            is_sign=None,
            c='b',
            bc='w',
            nc='grey',
            ec='k',
            alpha=0.5,
            xlim=None,
            ylim=None,
            xlab=None,
            ylab=None,
            title=None,
            ytitle=None,
            polar=False,
            id_line=True,
            ffig=None,
            ax=None,
            **kwargs):
    """
    Scatter plot of paired data, optionally colored by significance.

    Face colors:
        - c:  significant points (is_sign == True)
        - bc: non-significant points (is_sign == False)
        - nc: untested/untestable points (is_sign == None)

    The per-point color matrix is built only when is_sign is given and both
    c and bc are strings; otherwise c is handed to matplotlib unchanged.
    ec is the edge color; kwargs go to ax.scatter. An identity line is added
    when id_line is set. Returns the axes.
    """

    ax = putil.axes(ax, polar=polar)

    # Decide between a point-specific color matrix and the raw color.
    named_colors = isinstance(c, str) and isinstance(bc, str)
    if is_sign is not None and named_colors:
        point_cols = putil.get_cmat(is_sign, c, bc, nc)
    else:
        point_cols = c

    ax.scatter(x, y, c=point_cols, edgecolor=ec, alpha=alpha, **kwargs)

    if id_line:
        putil.add_identity_line(equal_xy=True, zorder=99, ax=ax)

    fmt_args = (xlim, ylim, xlab, ylab, title, ytitle)
    putil.format_plot(ax, *fmt_args)
    putil.save_fig(ffig)

    return ax
コード例 #7
0
ファイル: pauc.py プロジェクト: mnislamraju/seal
def plot_auc_over_time(auc,
                       tvec,
                       prds=None,
                       evts=None,
                       xlim=None,
                       ylim=None,
                       xlab='time',
                       ylab='AUC',
                       title=None,
                       ax=None):
    """
    Plot AROC values over time.

    auc: AUC values, one per entry of tvec
    tvec: time points; their min and max give the x limits if xlim is None
    prds: periods passed to putil.plot_periods (drawn first)
    evts: events passed to putil.plot_event_markers (drawn last)
    xlim, ylim, xlab, ylab, title: forwarded to pplot.lines
    ax: axes to draw on, resolved through putil.axes

    Returns the axes.
    """

    ax = putil.axes(ax)
    if xlim is None:
        xlim = [min(tvec), max(tvec)]

    # Periods first, then the AUC trace and the chance level line.
    putil.plot_periods(prds, ax=ax)
    pplot.lines(tvec, auc, ylim, xlim, xlab, ylab, title, color='green', ax=ax)
    putil.add_chance_level(ax=ax)

    # For the full [0, 1] range, put ticks at every quarter but label only
    # every second one.
    if ylim is not None:
        if ylim[0] == 0 and ylim[1] == 1:
            marks = np.linspace(0, 1, 5)
            labels = np.array(marks, dtype=str)
            labels[1::2] = ''
            putil.set_ytick_labels(ax, marks, labels)
    putil.set_max_n_ticks(ax, 5, 'y')

    putil.plot_event_markers(evts, ax=ax)

    return ax
コード例 #8
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def sign_scatter(v1,
                 v2,
                 pvals=None,
                 pth=0.01,
                 scol='g',
                 nscol='k',
                 id_line=False,
                 fit_reg=False,
                 ax=None):
    """
    Plot scatter plot with significant points highlighted.

    v1, v2: paired values (indexed with `.loc`, so pandas Series sharing an
        index with pvals)
    pvals: p-value of each pair; every point counts as significant if None
    pth: significance threshold, points with pvals < pth are significant
    scol, nscol: color of significant and non-significant points
    id_line: if set, limits are changed to [0, max] on both axes and an
        identity line is added
    fit_reg: passed to sns.regplot
    ax: axes to draw on, resolved through putil.axes

    Returns the axes.
    """

    # Init.
    ax = putil.axes(ax)
    s_pars = (True, scol, {'alpha': 1.0})
    ns_pars = (False, nscol, {'alpha': 0.8})

    # Binarize significance stats.
    if pvals is not None:
        vsig = pvals < pth
    else:
        vsig = pd.Series(True, index=v1.index)

    # Plot non-significant points first, so significant ones end up on top.
    for is_sig, col, sc_kws in [ns_pars, s_pars]:
        sel = (vsig == is_sig)
        if sel.any():
            # x and y go by keyword: newer seaborn releases no longer accept
            # them positionally.
            sns.regplot(x=v1.loc[sel],
                        y=v2.loc[sel],
                        fit_reg=fit_reg,
                        color=col,
                        scatter_kws=sc_kws,
                        ax=ax)

    # Format plot.
    sns.despine(ax=ax)

    # Add identity line.
    if id_line:
        v_max = max(ax.get_xlim()[1], ax.get_ylim()[1])
        putil.set_limits(ax, [0, v_max], [0, v_max])
        putil.add_identity_line(ax=ax, equal_xy=True)

    return ax
コード例 #9
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def multi_hist(vals,
               xlim=None,
               ylim=None,
               xlab=None,
               ylab='n',
               title=None,
               ytitle=None,
               polar=False,
               ffig=None,
               ax=None,
               **kwargs):
    """
    Histogram of several samples drawn side by side.

    vals: data handed to ax.hist as is
    polar: passed to putil.axes when resolving the axes
    kwargs: extra keyword arguments for ax.hist
    Remaining arguments are forwarded to putil.format_plot / putil.save_fig.

    Returns the axes.
    """

    ax = putil.axes(ax, polar=polar)
    ax.hist(vals, **kwargs)

    fmt_args = (xlim, ylim, xlab, ylab, title, ytitle)
    putil.format_plot(ax, *fmt_args)
    putil.save_fig(ffig)

    return ax
コード例 #10
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def cat_hist(vals,
             xlim=None,
             ylim=None,
             xlab=None,
             ylab='n',
             title=None,
             ytitle=None,
             polar=False,
             ffig=None,
             ax=None,
             **kwargs):
    """
    Plot histogram of categorical data.

    vals: categorical values to count
    polar: passed to putil.axes when resolving the axes
    kwargs: accepted but not used by this function
    Remaining arguments are forwarded to putil.format_plot / putil.save_fig.

    Returns the axes returned by sns.countplot.
    """

    # Plot data. The values go in as `x` explicitly: newer seaborn releases
    # interpret the first positional argument as `data` instead.
    ax = putil.axes(ax, polar=polar)
    ax = sns.countplot(x=vals, ax=ax)

    # Format and save figure.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)

    return ax
コード例 #11
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def heatmap(mat,
            vmin=None,
            vmax=None,
            cmap=None,
            cbar=True,
            cbar_ax=None,
            annot=None,
            square=False,
            xlab=None,
            ylab=None,
            title=None,
            ytitle=None,
            xlim=None,
            ylim=None,
            xticklabels=True,
            yticklabels=True,
            ffig=None,
            ax=None):
    """
    Plot rectangular data as heatmap.

    mat: rectangular data handed to sns.heatmap
    vmin, vmax, cmap, cbar, cbar_ax, annot, square, xticklabels,
        yticklabels: passed to sns.heatmap under the same names
    xlim, ylim, xlab, ylab, title, ytitle: forwarded to putil.format_plot
    ffig: passed to putil.save_fig
    ax: axes to draw on, resolved through putil.axes

    Returns the axes.
    """

    # Plot data. Everything after the data goes by keyword: newer seaborn
    # releases accept only `data` positionally.
    ax = putil.axes(ax)
    sns.heatmap(mat,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                annot=annot,
                cbar=cbar,
                cbar_ax=cbar_ax,
                square=square,
                xticklabels=xticklabels,
                yticklabels=yticklabels,
                ax=ax)

    # Format and save figure.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)

    return ax
コード例 #12
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def bars(x,
         y,
         ylim=None,
         xlim=None,
         xlab=None,
         ylab=None,
         title=None,
         ytitle=None,
         polar=False,
         ffig=None,
         ax=None,
         **kwargs):
    """
    Bar plot of y at positions x.

    polar: passed to putil.axes when resolving the axes
    kwargs: extra keyword arguments for ax.bar
    Remaining arguments are forwarded to putil.format_plot / putil.save_fig.

    Returns the axes.
    """

    ax = putil.axes(ax, polar=polar)
    ax.bar(x, y, **kwargs)

    fmt_args = (xlim, ylim, xlab, ylab, title, ytitle)
    putil.format_plot(ax, *fmt_args)
    putil.save_fig(ffig)

    return ax
コード例 #13
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def mean_err(x,
             ymean,
             ystd,
             ylim=None,
             xlim=None,
             xlab=None,
             ylab=None,
             title=None,
             ytitle=None,
             polar=False,
             ffig=None,
             ax=None,
             mean_kws=None,
             band_kws=None):
    """
    Draw a mean line with a shaded band of ymean +- ystd around it.

    x: x coordinates
    ymean, ystd: center line and half-width of the band
    polar: passed to putil.axes when resolving the axes
    mean_kws: keyword arguments for lines(), none if None
    band_kws: keyword arguments for band(), none if None
    Remaining arguments are forwarded to putil.format_plot / putil.save_fig.

    Returns the axes.
    """

    mean_kws = {} if mean_kws is None else mean_kws
    band_kws = {} if band_kws is None else band_kws

    # Band edges around the mean.
    lo, hi = ymean - ystd, ymean + ystd

    ax = putil.axes(ax, polar=polar)
    lines(x, ymean, ax=ax, **mean_kws)
    band(x, lo, hi, ax=ax, **band_kws)

    fmt_args = (xlim, ylim, xlab, ylab, title, ytitle)
    putil.format_plot(ax, *fmt_args)
    putil.save_fig(ffig)

    return ax
コード例 #14
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def sign_hist(v,
              pvals=None,
              pth=0.01,
              bins=None,
              scol='g',
              nscol='k',
              ax=None):
    """
    Plot histogram of significant and non-significant values stacked.

    All values are drawn first in scol, then the non-significant subset
    (pvals >= pth) is drawn over them in nscol, so the visible scol part
    corresponds to the significant values.

    v: values to plot (must support boolean-mask indexing)
    pvals: p-value of each value; if None, every value counts as
        significant and only the first histogram is drawn
    pth: significance threshold
    bins: passed to sns.distplot
    scol, nscol: color of significant and non-significant values
    ax: axes to draw on, resolved through putil.axes

    Returns the axes.
    """

    # Init.
    ax = putil.axes(ax)

    # Plot all values.
    sns.distplot(v, kde=False, bins=bins, color=scol, ax=ax)

    # Overlay non-significant values, if there is anything to overlay.
    # Without p-values nothing is marked non-significant (comparing None to
    # the threshold would raise).
    if pvals is not None:
        vnonsig = pvals >= pth
        if vnonsig.any():
            sns.distplot(v[vnonsig], kde=False, bins=bins, color=nscol, ax=ax)

    # Add vertical zero line.
    ax.axvline(0, color='gray', lw=1, ls='dashed')

    # Format plot.
    sns.despine(ax=ax)

    return ax
コード例 #15
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def band(x,
         ylower,
         yupper,
         ylim=None,
         xlim=None,
         xlab=None,
         ylab=None,
         title=None,
         ytitle=None,
         polar=False,
         ffig=None,
         ax=None,
         **kwargs):
    """
    Fill the area between ylower and yupper along x.

    polar: passed to putil.axes when resolving the axes
    kwargs: extra keyword arguments for ax.fill_between
    Remaining arguments are forwarded to putil.format_plot / putil.save_fig.

    Returns the axes.
    """

    ax = putil.axes(ax, polar=polar)
    ax.fill_between(x, ylower, yupper, **kwargs)

    fmt_args = (xlim, ylim, xlab, ylab, title, ytitle)
    putil.format_plot(ax, *fmt_args)
    putil.save_fig(ffig)

    return ax
コード例 #16
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def plot_group_violin(res,
                      x,
                      y,
                      groups=None,
                      npval=None,
                      pth=0.01,
                      color='grey',
                      ylim=None,
                      ylab=None,
                      ffig=None):
    """
    Plot group-wise results on violin plots.

    Each group gets a violin with the individual points swarmed on top, and
    an annotation above it with the group mean, the p-value of a one-sample
    t-test against zero, and the counts of significantly positive (+),
    non-significant (=) and significantly negative (-) points.

    res: DataFrame of results; NOTE that the columns 'is_sign' and
        'direction' are added to (or overwritten in) the passed frame
    x: name of the grouping column
    y: name of the value column
    groups: groups to show, in order; all unique values of column x if None
    npval: name of the column holding per-point p-values; every point counts
        as significant if None
    pth: significance threshold applied to column npval
    color: color passed to sns.swarmplot
    ylim, ylab: y axis limits and label
    ffig: passed to putil.save_fig

    Returns the (fig, ax) pair.
    """

    # Groups are the values of the grouping column, the same column that the
    # t-tests and the plots below are based on.
    if groups is None:
        groups = res[x].unique()

    # Test difference from zero in each group. Only the p-value is needed,
    # so read it directly off each test result.
    ttest_pvals = {
        group: sp.stats.ttest_1samp(gres[y], 0).pvalue
        for group, gres in res.groupby(x)
    }

    # Binarize significance test.
    res['is_sign'] = res[npval] < pth if npval is not None else True
    res['direction'] = np.sign(res[y])

    # Set up figure and plot data.
    fig = putil.figure()
    ax = putil.axes()
    putil.add_baseline(ax=ax)
    sns.violinplot(x=x, y=y, data=res, inner=None, order=groups, ax=ax)
    sns.swarmplot(x=x,
                  y=y,
                  hue='is_sign',
                  data=res,
                  color=color,
                  order=groups,
                  hue_order=[True, False],
                  ax=ax)
    putil.set_labels(ax, xlab='', ylab=ylab)
    putil.set_limits(ax, ylim=ylim)
    putil.hide_legend(ax)

    # Add annotations, anchored at the top of the current y range.
    ymin, ymax = ax.get_ylim()
    ylvl = ymax
    for i, group in enumerate(groups):
        gres = res.loc[res[x] == group]
        # Mean (the header is printed above the first group only).
        mean_str = 'Mean:\n' if i == 0 else '\n'
        mean_str += '{:.2f}'.format(gres[y].mean())
        # Non-zero test of distribution.
        str_pval = util.format_pvalue(ttest_pvals[group])
        mean_str += '\n({})'.format(str_pval)
        # Stats on difference from baseline.
        nnonsign, ntot = (~gres.is_sign).sum(), len(gres)
        npos, nneg = [
            sum(gres.is_sign & (gres.direction == d)) for d in (1, -1)
        ]
        sign_to_report = [('+', npos), ('=', nnonsign), ('-', nneg)]
        nsign_str = ''
        for symb, n in sign_to_report:
            prc = str(int(round(100 * n / ntot)))
            nsign_str += '{} {:>3} / {} ({:>2}%)\n'.format(
                symb, int(n), ntot, prc)
        lbl = '{}\n\n{}'.format(mean_str, nsign_str)
        ax.text(i, ylvl, lbl, fontsize='smaller', va='bottom', ha='center')

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
コード例 #17
0
def plot_DR(dirs,
            resp,
            DSI=None,
            PD=None,
            baseline=None,
            plot_type='line',
            complete_missing_dirs=False,
            color='b',
            title=None,
            ffig=None,
            ax=None):
    """
    Plot response to each directions on polar plot, with a vector pointing to
    preferred direction (PD) with length DSI.
    Use plot_type to change between sector ('bar') and connected ('line') plot
    types.

    dirs: directions; must expose `.units` and `.rescale()` (unit-carrying
        array -- presumably a quantities array, TODO confirm)
    resp: response to each direction, same length as dirs
    DSI, PD: direction selectivity index and preferred direction; the arrow
        is drawn only if both are given, PD must support `.rescale()`
    baseline: baseline level to add first, skipped if None
    plot_type: 'bar' for sector plot, anything else gives the line plot
    complete_missing_dirs: if set, directions dropped as NaN are re-inserted
        with 0 response
    color: color of the plotted data and the arrow
    title: plot title
    ffig: passed to putil.save_fig
    ax: axes to draw on, resolved through putil.axes with polar=True

    Returns the axes.
    """

    ax = putil.axes(ax, polar=True)

    # Plot baseline.
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Remove NaNs.
    # Keep the full direction list: it is needed below to restore missing
    # directions and to size the sectors of the bar plot.
    all_dirs = dirs
    not_nan = np.array(~pd.isnull(dirs) & ~pd.isnull(resp))
    dirs, resp = dirs[not_nan], np.array(resp[not_nan])

    # Prepare data.
    # Complete missing directions with 0 response.
    if complete_missing_dirs:
        for i, d in enumerate(all_dirs):
            if d not in dirs:
                # Units are re-applied after each insertion.
                # NOTE(review): `resp` is multiplied by 1 / s on every
                # inserted direction -- confirm its units stay as intended
                # when more than one direction is missing.
                dirs = np.insert(dirs, i, d) * dirs.units
                resp = np.insert(resp, i, 0) * 1 / s

    rad_dirs = dirs.rescale(rad)

    # Plot response to each directions on polar plot.
    if plot_type == 'bar':  # sector plot
        ndirs = all_dirs.size
        # Shift by half a sector width, so that bars are centered on their
        # direction.
        left_rad_dirs = rad_dirs - np.pi / ndirs  # no need for this in MPL 2.0?
        w = 2 * np.pi / ndirs  # same with edgecolor and else?
        pplot.bars(left_rad_dirs,
                   resp,
                   width=w,
                   alpha=0.50,
                   color=color,
                   lw=1,
                   edgecolor='w',
                   title=title,
                   ytitle=1.08,
                   ax=ax)
    else:  # line plot
        # Append the first point to the end to close the curve.
        rad_dirs, resp = [np.append(v, [v[0]]) for v in (rad_dirs, resp)]
        pplot.lines(rad_dirs,
                    resp,
                    color=color,
                    marker='o',
                    lw=1,
                    ms=4,
                    mew=0,
                    title=title,
                    ytitle=1.08,
                    ax=ax)
        ax.fill(rad_dirs, resp, color=color, alpha=0.15)

    # Add arrow representing PD and weighted DSI.
    if DSI is not None and PD is not None:
        # Arrow length is DSI scaled by the maximal response.
        rho = np.max(resp) * DSI
        xy = (float(PD.rescale(rad)), rho)
        arr_props = dict(facecolor=color, edgecolor='k', shrink=0.0, alpha=0.5)
        ax.annotate('', xy, xytext=(0, 0), arrowprops=arr_props)

    # Remove spines and tick marks, maximize tick labels.
    putil.set_spines(ax, False, False)
    putil.hide_tick_marks(ax)

    # Save and return plot.
    putil.save_fig(ffig)
    return ax
コード例 #18
0
ファイル: pplot.py プロジェクト: mnislamraju/seal
def cat_mean(df,
             x,
             y,
             add_stats=True,
             fstats=None,
             bar_ylvl=None,
             ci=68,
             add_mean=True,
             mean_ylvl=None,
             ylbl=None,
             fig=None,
             ax=None,
             ffig=None,
             palette=None,
             errwidth=None,
             **kwargs):
    """
    Plot mean of two categorical dataset.

    df: DataFrame holding the data
    x: name of the categorical column; it must yield exactly two groups
    y: name of the value column
    add_stats: add a significance bar with the p-value of the group
        comparison
    fstats: test function returning (statistic, p-value) for the two
        groups; defaults to stats.mann_whithney_u_test
    bar_ylvl: y level of the significance bar; by default 1.1 times the
        larger of the two (mean + SEM) values
    ci: confidence interval passed to sns.barplot
    add_mean: print the mean value next to each bar
    mean_ylvl: y level of the mean labels; by default just above each mean
    ylbl: y axis label
    fig, ax: figure and axes to use; a new 3 x 3 figure is created only if
        both are None
    ffig: passed to putil.save_fig
    palette, errwidth: passed to sns.barplot when not None
    kwargs: extra keyword arguments for sns.barplot

    Returns the axes.
    """

    # Init.
    if fig is None and ax is None:
        fig = putil.figure(figsize=(3, 3))
    if ax is None:
        ax = putil.axes()
    if fstats is None:
        fstats = stats.mann_whithney_u_test

    # Collect optional bar plot settings. Unset ones are left out, so that
    # seaborn applies its own defaults.
    bar_kws = dict(kwargs)
    if palette is not None:
        bar_kws['palette'] = palette
    if errwidth is not None:
        bar_kws['errwidth'] = errwidth

    # Plot means as bars.
    sns.barplot(x=x, y=y, data=df, ci=ci, ax=ax, **bar_kws)

    # Get plotted vectors, in the order the groups appear on the x axis.
    ngrps = [t.get_text() for t in ax.get_xticklabels()]
    v1, v2 = [df.loc[df[x] == ngrp, y] for ngrp in ngrps]

    # Add significance bar.
    if add_stats:
        _, pval = fstats(v1, v2)
        pval_str = util.format_pvalue(pval)
        if bar_ylvl is None:
            bar_ylvl = 1.1 * max(v1.mean() + stats.sem(v1),
                                 v2.mean() + stats.sem(v2))
        lines([0.1, 0.9], [bar_ylvl, bar_ylvl], color='grey', ax=ax)
        ax.text(0.5,
                1.01 * bar_ylvl,
                pval_str,
                fontsize='medium',
                fontstyle='italic',
                va='bottom',
                ha='center')

    # Add mean values.
    if add_mean:
        for vec, xpos in [(v1, 0.2), (v2, 1.2)]:
            vmean = vec.mean()
            mstr = '{:.2f}'.format(vmean)
            ypos = mean_ylvl if mean_ylvl is not None else 1.005 * vmean
            ax.text(xpos,
                    ypos,
                    mstr,
                    fontstyle='italic',
                    fontsize='smaller',
                    va='bottom',
                    ha='center')

    # Format plot.
    sns.despine()
    putil.hide_legend_title(ax)
    putil.set_labels(ax, '', ylbl)
    putil.sparsify_tick_labels(fig, ax, 'y', freq=2)
    # Save plot.
    putil.save_fig(ffig, fig)

    return ax
コード例 #19
0
def rate(rate_list,
         names=None,
         prds=None,
         evts=None,
         cols=None,
         baseline=None,
         pval=0.05,
         test='mann_whitney_u',
         test_kws=None,
         xlim=None,
         ylim=None,
         title=None,
         xlab=None,
         ylab=putil.FR_lbl,
         add_lgn=True,
         lgn_lbl='trs',
         ffig=None,
         ax=None):
    """
    Plot mean firing rate (+- SEM band) of one or more sets of trials.

    rate_list: list of rate tables (rows: trials, columns: time points)
    names: label of each rate table, empty strings if None
    prds: periods passed to putil.plot_periods
    evts: events passed to putil.plot_event_markers
    cols: color of each rate table, taken from putil.get_colors if None
    baseline: baseline level to add, skipped if None
    pval, test, test_kws: when pval is not None and exactly two rate tables
        are given, periods of significant difference between them are
        computed by stats.sign_periods and marked on the plot
    xlim, ylim: axis limits; derived from the plotted data if None
    title, xlab, ylab: plot texts; xlab is inserted into putil.t_lbl
    add_lgn: add legend
    lgn_lbl: if not None, the trial count followed by this text is appended
        to each line label
    ffig: passed to putil.save_fig
    ax: axes to draw on, resolved through putil.axes

    Returns the axes. With an empty rate_list, only periods, baseline and x
    limits are set before returning.
    """

    ax = putil.axes(ax)
    if test_kws is None:
        test_kws = {}

    # Background elements go first.
    putil.plot_periods(prds, ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)
    putil.set_limits(ax, xlim)

    n_rates = len(rate_list)
    if n_rates == 0:
        return ax

    if cols is None:
        cols = putil.get_colors(as_cycle=False)
    if names is None:
        names = n_rates * ['']

    # Track the extent of everything plotted, to derive axis limits below.
    xmin = xmax = ymax = None
    for idx, rts in enumerate(rate_list):

        name, col = names[idx], cols[idx]

        # No trials, nothing to plot.
        n_trials = rts.shape[0]
        if not n_trials:
            continue

        # Line label. Iterable names go through a Numpy array to get nicely
        # formatted floats.
        if util.is_iterable(name):
            lbl = str(np.array(name))
        else:
            lbl = str(name)
        if lgn_lbl is not None:
            lbl += ' ({} {})'.format(n_trials, lgn_lbl)

        # Mean line with SEM band around it.
        tvec = rts.columns
        meanr = rts.mean()
        semr = rts.sem()
        ax.plot(tvec, meanr, label=lbl, color=col)
        ax.fill_between(tvec, meanr - semr, meanr + semr, alpha=0.2,
                        facecolor=col, edgecolor=col)

        # Extend tracked extent.
        tmin, tmax = tvec.min(), tvec.max()
        rmax = (meanr + semr).max()
        xmin = tmin if xmin is None else np.min([xmin, tmin])
        xmax = tmax if xmax is None else np.max([xmax, tmax])
        ymax = rmax if ymax is None else np.max([ymax, rmax])

    # Derive missing limits from the data.
    if xlim is None:
        # Identical lower and upper limits are avoided by leaving the upper
        # one open.
        if xmin == xmax:
            xmax = None
        xlim = (xmin, xmax)
    if ylim is None:
        if (ymax is not None) and (ymax > 0):
            ymax = 1.02 * ymax
        else:
            ymax = None
        ylim = (0, ymax)
    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)

    # Read the limits back from the axes, as some may have been left open.
    t1, t2 = ax.get_xlim()
    tmarks, tlbls = putil.get_tick_marks_and_labels(t1, t2)
    putil.set_xtick_labels(ax, tmarks, tlbls)
    putil.set_max_n_ticks(ax, 7, 'y')

    # Legend (rate_list is known to be non-empty at this point).
    if add_lgn:
        putil.set_legend(ax,
                         loc=1,
                         borderaxespad=0.0,
                         handletextpad=0.4,
                         handlelength=0.6)

    # Significance line at the top of the axes, for a pair of rates only.
    if (pval is not None) and (n_rates == 2):
        rates1, rates2 = rate_list
        sign_prds = stats.sign_periods(rates1, rates2, pval, test, **test_kws)
        putil.plot_signif_prds(sign_prds, color='m', linewidth=4.0, ax=ax)

    putil.plot_event_markers(evts, ax=ax)

    putil.save_fig(ffig)
    return ax