예제 #1
0
def plot_wfs(waveforms,
             tvec,
             cols=None,
             lw=0.1,
             alpha=0.05,
             xlim=None,
             ylim=None,
             title=None,
             xlab=None,
             ylab=None,
             ffig=None,
             ax=None,
             **kwargs):
    """
    Plot large set of waveforms efficiently as LineCollections.

    waveforms: waveform matrix of (waveform x time sample)
    tvec: vector of sample times, one per column of waveforms
    cols: color matrix of (waveform x RGBA); defaults to green for every
          waveform
    lw: line width of the waveforms
    alpha: transparency applied to all waveform lines (note: when set on the
           collection, matplotlib uses it instead of the alpha channel of
           cols)
    xlim, ylim, title, xlab, ylab: plot formatting options
    ffig: file name to save figure to (optional)
    ax: axes to plot on (optional)
    kwargs: additional keyword arguments passed to LineCollection

    Returns the axes plotted on.
    """

    # Init. Accept array-likes (e.g. lists) as well as ndarrays.
    waveforms = np.asarray(waveforms)
    tvec = np.asarray(tvec)
    nwfs = waveforms.shape[0]
    ntsamp = tvec.size
    ax = putil.axes(ax)

    if cols is None:
        cols = np.tile(putil.convert_to_rgba('g'), (nwfs, 1))

    # Plot all waveforms efficiently at the same time as LineCollections.
    # Reformat waveform matrix and tvec vector into format:
    # [[(t0, v0), (t1, v1)], [(t0, v0), (t1, v1)], ...]
    wf_col = waveforms.reshape((-1, 1))
    tvec_col = np.tile(tvec, (nwfs, 1)).reshape((-1, 1))
    tvec_wf_cols = np.hstack((tvec_col, wf_col)).reshape(-1, 1, 2)
    tv_segments = np.hstack([tvec_wf_cols[:-1], tvec_wf_cols[1:]])
    # Drop the segments that would join the last sample of one waveform to
    # the first sample of the next one.
    btw_wf_mask = ntsamp * np.arange(1, nwfs) - 1
    tv_segments = np.delete(tv_segments, btw_wf_mask, axis=0)

    # Set color of each segment (ntsamp - 1 segments per waveform).
    cols_segments = np.repeat(cols, ntsamp - 1, axis=0)

    # Create and add LineCollection to axes. alpha was previously accepted
    # but never applied.
    lc = LineCollection(tv_segments,
                        linewidths=lw,
                        colors=cols_segments,
                        alpha=alpha,
                        **kwargs)
    ax.add_collection(lc)
    # Need to update view manually after adding artists manually.
    ax.autoscale_view()

    # Format plot.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)

    # Save and return plot.
    putil.save_fig(ffig)
    return ax
예제 #2
0
def plot_mean_rates(mRates,
                    aa_res_dir,
                    tasks=None,
                    task_lbls=None,
                    baseline=None,
                    xlim=None,
                    ylim=None,
                    ci=68,
                    ffig=None,
                    figsize=(6, 4)):
    """
    Plot mean rates across tasks.

    mRates: mapping of task -> DataFrame of rates. Each frame is unstacked
            into a long series whose first index level is used as 'time'
            and second as 'unit' (presumably units x time points — verify).
    aa_res_dir: not used by this function (kept for interface compatibility)
    tasks: tasks to plot (default: all keys of mRates)
    task_lbls: optional mapping of task -> label shown in the plot
    baseline: optional baseline level to draw
    xlim, ylim: axis limits
    ci: confidence interval size passed to seaborn's tsplot
    ffig: file name to save figure to (optional)
    figsize: figure size

    Returns the figure and axes.
    """

    # Init.
    if tasks is None:
        tasks = mRates.keys()

    # Collect rates of all tasks into a single long-format DataFrame.
    lRates = []
    for task in tasks:
        lrates = pd.DataFrame(mRates[task].unstack(), columns=['rate'])
        lrates['task'] = task
        lRates.append(lrates)
    lRates = pd.concat(lRates)
    lRates['time'] = lRates.index.get_level_values(0)
    lRates['unit'] = lRates.index.get_level_values(1)

    if task_lbls is not None:
        # Assign back rather than calling replace(..., inplace=True) on the
        # column accessor: chained in-place modification is not guaranteed
        # to reach lRates (and is a no-op under pandas Copy-on-Write).
        lRates['task'] = lRates['task'].replace(task_lbls)

    # Plot as time series.
    fig = putil.figure(figsize=figsize)
    ax = putil.axes()

    sns.tsplot(lRates,
               time='time',
               value='rate',
               unit='unit',
               condition='task',
               ci=ci,
               ax=ax)

    # Add periods and baseline.
    putil.plot_periods(ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Format plot.
    sns.despine(ax=ax)
    putil.set_labels(ax, xlab='time since S1 onset', ylab='rate (sp/s)')
    putil.set_limits(ax, xlim, ylim)
    putil.hide_legend_title(ax)

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
예제 #3
0
파일: test_RF.py 프로젝트: mnislamraju/seal
def plot_RF_results(RF_res, stims, fdir, sup_title):
    """
    Plot receptive field results.

    RF_res: DataFrame of RF results. Its last index level is the task, and
            for each stimulus in stims it has '<stim>_cover' plus
            '<stim>_mean_rate', '<stim>_max_rate' and '<stim>_mDSI' columns.
    stims: stimulus names (one subplot row each)
    fdir: output directory prefix (concatenated directly with the coverage
          plot's file name)
    sup_title: figure super title, also used to derive file names
    """

    fst = util.format_to_fname(sup_title)

    # Plot distribution of coverage values.
    putil.set_style('poster', 'white')
    fig = putil.figure()
    ax = putil.axes()
    sns.distplot(RF_res.S1_cover, bins=np.arange(0, 1.01, 0.1),
                 kde=False, rug=True, ax=ax)
    putil.set_limits(ax, [0, 1])
    ffig = fdir + fst + '_S1_coverage.png'
    putil.save_fig(ffig, fig, sup_title)

    # Plot RF coverage and rate during S1 on regression plot for each
    # recording and task.
    tasks = RF_res.index.get_level_values(-1).unique()
    colors = sns.color_palette('muted', len(tasks))
    for vname, ylim in [('mean_rate', [0, None]), ('max_rate', [0, None]),
                        ('mDSI', [0, 1])]:
        fig, _, axes = putil.get_gs_subplots(nrow=len(stims), ncol=len(tasks),
                                             subw=4, subh=4, ax_kws_list=None,
                                             create_axes=True)
        for istim, stim in enumerate(stims):
            scov, sval = [stim + '_' + name for name in ('cover', vname)]
            xlab, ylab = [sn.replace('_', ' ') for sn in (scov, sval)]
            for itask, task in enumerate(tasks):
                ax = axes[istim, itask]
                df = RF_res.xs(task, level=-1)
                # Plot regression plot. Pass columns by keyword: seaborn
                # >= 0.12 makes x / y keyword-only.
                sns.regplot(x=scov, y=sval, data=df, color=colors[itask],
                            ax=ax)
                # Add unit labels.
                uids = df.index.droplevel(0)
                putil.add_unit_labels(ax, uids, df[scov], df[sval])
                # Add stats (Pearson's r is symmetric in its arguments).
                r, p = sp.stats.pearsonr(df[scov], df[sval])
                pstr = util.format_pvalue(p)
                txt = 'r = {:.2f}, {}'.format(r, pstr)
                ax.text(0.02, 0.98, txt, va='top', ha='left',
                        transform=ax.transAxes)
                # Set labels and limits (coverage is within [0, 1]).
                title = '{} {}'.format(task, stim)
                putil.set_labels(ax, xlab, ylab, title)
                putil.set_limits(ax, [0, 1], ylim)

        # Save plot.
        fname = '{}_cover_{}.png'.format(fst, vname)
        ffig = util.join([fdir, vname, fname])
        putil.save_fig(ffig, fig, sup_title)
예제 #4
0
def raster(spk_trains,
           t_unit=ms,
           prds=None,
           c='b',
           xlim=None,
           title=None,
           xlab=None,
           ylab=None,
           ffig=None,
           ax=None):
    """
    Plot rasterplot.

    spk_trains: sequence of spike trains, one row (trial) each; every train
                must support .rescale(t_unit)
    t_unit: time unit spike times are rescaled to before plotting
    prds: periods to highlight (passed to putil.plot_periods)
    c: face and edge color of spike markers
    xlim, title, xlab, ylab: plot formatting options; xlab is inserted into
                             the putil.t_lbl template
    ffig: file name to save figure to (optional)
    ax: axes to plot on (optional)

    Returns the axes plotted on. If spk_trains is empty, only the periods
    and x limits are set (no further formatting or saving is done).
    """

    # Init.
    ax = putil.axes(ax)

    putil.plot_periods(prds, ax=ax)
    putil.set_limits(ax, xlim)

    # There's nothing to plot.
    if not len(spk_trains):
        return ax

    # Plot raster: trial i is drawn on row i + 1.
    for i, spk_tr in enumerate(spk_trains):
        x = np.array(spk_tr.rescale(t_unit))
        y = (i + 1) * np.ones_like(x)

        # Spike markers are rectangles in data coordinates, so their size
        # scales with the axes (rather than a fixed-size scatter marker).
        # NOTE(review): wsp / hsp (marker width / height) are not defined in
        # this function; presumably module-level constants — verify.
        patches = [
            Rectangle((xi - wsp / 2, yi - hsp / 2), wsp, hsp)
            for xi, yi in zip(x, y)
        ]
        collection = PatchCollection(patches, facecolor=c, edgecolor=c)
        ax.add_collection(collection)

    # Format plot. spk_trains is non-empty here (early return above), so the
    # y range always covers one row per trial.
    ylim = [0.5, len(spk_trains) + 0.5]
    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.hide_axes(ax, show_x=True)
    putil.hide_spines(ax)

    # Order trials from top to bottom, only after setting axis limits.
    ax.invert_yaxis()

    # Save and return plot.
    putil.save_fig(ffig)
    return ax
예제 #5
0
파일: pauc.py 프로젝트: mnislamraju/seal
def plot_auc_heatmap(aroc_mat,
                     cmap='viridis',
                     events=None,
                     xlbl_freq=500,
                     ylbl_freq=10,
                     xlab='time',
                     ylab='unit index',
                     title='AROC over time',
                     ffig=None,
                     fig=None):
    """
    Plot ROC AUC of list of units on heatmap.

    aroc_mat: DataFrame of AROC values; rows are labelled 1..n on the y axis
              (presumably one row per unit) and columns must be numeric
              (time points), as they are tested for divisibility by
              xlbl_freq
    cmap: colormap (color range is fixed to [0, 1])
    events: events to mark on the plot (optional)
    xlbl_freq: only x tick labels of columns divisible by this are shown
    ylbl_freq: frequency of y tick labels to keep
    xlab, ylab, title: axis labels and title
    ffig: file name to save figure to (optional, saved at 300 dpi)
    fig: figure to plot on (optional)

    Returns the figure and axes plotted on.
    """

    fig = putil.figure(fig)

    # Plot heatmap.
    yticklabels = np.arange(len(aroc_mat.index)) + 1
    ax = pplot.heatmap(aroc_mat,
                       vmin=0,
                       vmax=1,
                       cmap=cmap,
                       xlab=xlab,
                       ylab=ylab,
                       title=title,
                       yticklabels=yticklabels)

    # Format labels: blank out x labels not on a multiple of xlbl_freq.
    xlbls = pd.Series(aroc_mat.columns.map(str))
    xlbls[aroc_mat.columns % xlbl_freq != 0] = ''
    putil.set_xtick_labels(ax, lbls=xlbls)
    putil.rot_xtick_labels(ax, rot=0, ha='center')
    putil.sparsify_tick_labels(fig,
                               ax,
                               'y',
                               istart=ylbl_freq - 1,
                               freq=ylbl_freq,
                               reverse=True)
    putil.hide_tick_marks(ax)
    putil.hide_spines(ax)

    # Plot events.
    if events is not None:
        putil.plot_events(events,
                          add_names=False,
                          color='black',
                          alpha=0.3,
                          ls='-',
                          lw=1,
                          ax=ax)

    # Save plot.
    putil.save_fig(ffig, dpi=300)

    return fig, ax
예제 #6
0
파일: pauc.py 프로젝트: mnislamraju/seal
def plot_ROC_mean(d_faroc,
                  t1=None,
                  t2=None,
                  ylim=None,
                  colors=None,
                  ylab='AROC',
                  ffig=None):
    """
    Plot mean ROC curves over given period.

    d_faroc: mapping of name -> results file holding an 'aroc' object (read
             with util.read_objects); names become the 'task' condition
    t1, t2: x (time) axis limits
    ylim: y axis limits
    colors: line colors passed to seaborn's tsplot
    ylab: y axis label
    ffig: file name to save figure to (optional)

    Returns the figure and axes.
    """

    # Import results.
    d_aroc = {name: util.read_objects(faroc, 'aroc').unstack().T
              for name, faroc in d_faroc.items()}

    # Format results into long format with task / time / unit columns.
    laroc = pd.DataFrame(pd.concat(d_aroc), columns=['aroc'])
    laroc['task'] = laroc.index.get_level_values(0)
    laroc['time'] = laroc.index.get_level_values(1)
    laroc['unit'] = laroc.index.get_level_values(2)
    laroc = laroc.reset_index(drop=True)

    # Init figure.
    fig = putil.figure(figsize=(6, 6))
    ax = sns.tsplot(laroc,
                    time='time',
                    value='aroc',
                    unit='unit',
                    condition='task',
                    color=colors)

    # Highlight stimulus periods.
    putil.plot_periods(ax=ax)

    # Thicken all lines plotted so far (the mean results).
    for line in ax.lines:
        line.set_linewidth(3)

    # Add chance level line (the last line added) with a thinner width.
    putil.add_chance_level(ax=ax, alpha=0.8, color='k')
    ax.lines[-1].set_linewidth(1.5)

    # Format plot.
    xlab = 'Time since S1 onset (ms)'
    putil.set_labels(ax, xlab, ylab)
    putil.set_limits(ax, [t1, t2], ylim)
    putil.set_spines(ax, bottom=True, left=True, top=False, right=False)
    putil.set_legend(ax, loc=0)

    # Save plot.
    putil.save_fig(ffig, fig, ytitle=1.05, w_pad=15)

    return fig, ax
예제 #7
0
def raster_rate(spk_list,
                rate_list,
                names=None,
                prds=None,
                evts=None,
                cols=None,
                baseline=None,
                title=None,
                rs_ylab=True,
                rate_kws=None,
                fig=None,
                ffig=None,
                sps=None):
    """
    Plot raster and rate plots.

    A column of raster plots (one per element of spk_list, or a single empty
    one if there is no data) is drawn above a single rate plot.

    spk_list: list of spike train sets, one raster plot each
    rate_list: rates passed to rate()
    names: optional set names, used as raster y labels (if rs_ylab) and
           passed to rate()
    prds, evts, baseline: periods, events and baseline passed to the plots
    cols: colors per set (default: taken from putil's color cycle)
    title: title added to the top raster plot
    rs_ylab: whether to label raster plots by set name
    rate_kws: extra keyword arguments for rate()
    fig, sps: figure / subplot spec to plot into (optional)
    ffig: file name to save figure to (optional)

    Returns the figure, list of raster axes and rate axes.
    """

    if rate_kws is None:
        rate_kws = dict()

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 2, 1, height_ratios=[.66, 1], hspace=.15)
    n_sets = max(len(spk_list), 1)  # let's add an empty axes if no data
    gsp_raster = putil.embed_gsp(gsp[0], n_sets, 1, hspace=.15)
    gsp_rate = putil.embed_gsp(gsp[1], 1, 1)

    # Init colors.
    if cols is None:
        col_cyc = putil.get_colors(mpl_colors=True)
        cols = [next(col_cyc) for _ in range(n_sets)]

    # Raster plots.
    raster_axs = [fig.add_subplot(gsp_raster[i, 0]) for i in range(n_sets)]
    for i, (spk_trs, ax) in enumerate(zip(spk_list, raster_axs)):
        ylab = names[i] if (rs_ylab and names is not None) else None
        raster(spk_trs, prds=prds, c=cols[i], ylab=ylab, ax=ax)
        putil.hide_axes(ax)
    # raster_axs always holds at least one axes (n_sets >= 1).
    putil.set_labels(raster_axs[0], title=title)  # add title to top raster

    # Rate plot.
    rate_ax = fig.add_subplot(gsp_rate[0, 0])
    rate(rate_list, names, prds, evts, cols, baseline, ax=rate_ax, **rate_kws)

    # Synchronize raster's x axis limits to rate plot's limits.
    xlim = rate_ax.get_xlim()
    for ax in raster_axs:
        ax.set_xlim(xlim)

    # Save and return plot.
    putil.save_fig(ffig, fig)
    return fig, raster_axs, rate_ax
예제 #8
0
def plot_tuning(xfit, yfit, vals=None, meanr=None, semr=None, color='b',
                baseline=None, xticks=None, xlim=None, ylim=None, xlab=None,
                ylab=None, title=None, ffig=None, ax=None, **kwargs):
    """
    Plot tuning curve, optionally with data samples.

    xfit, yfit: points of the fitted tuning curve
    vals: x positions of the data samples (also default x tick labels)
    meanr, semr: sample means and their errors; samples are drawn as error
                 bars only when both are given
    color: color of curve and samples
    baseline: optional baseline level to draw
    xticks: x tick labels (take precedence over vals)
    xlim, ylim, xlab, ylab, title: plot formatting options
    ffig: file name to save figure to (optional)
    ax: axes to plot on (optional)
    kwargs: extra keyword arguments for the error bar plot

    Returns the axes plotted on.
    """

    ax = putil.axes(ax)

    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Fitted curve first, samples on top.
    pplot.lines(xfit, yfit, color=color, ax=ax)

    with_samples = meanr is not None and semr is not None
    if with_samples:
        pplot.errorbar(vals, meanr, yerr=semr, fmt='o', color=color, ax=ax,
                       **kwargs)

    # Explicit ticks win; otherwise fall back on sample positions.
    tick_lbls = xticks if xticks is not None else vals
    if tick_lbls is not None:
        putil.set_xtick_labels(ax, tick_lbls)

    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.save_fig(ffig)
    return ax
예제 #9
0
파일: pplot.py 프로젝트: mnislamraju/seal
def joint_scatter(x,
                  y,
                  is_sign=None,
                  kind='reg',
                  stat_func=util.pearson_r,
                  c='b',
                  xlim=None,
                  ylim=None,
                  xlab=None,
                  ylab=None,
                  title=None,
                  ytitle=None,
                  ffig=None,
                  **kwargs):
    """
    Plot paired data on scatter plot with
        - marginal distributions added to the side
        - linear regression on center scatter
        - N, r and p values reported

    is_sign: optional per-point significance, used to color points (see
             scatter)
    Additional parameters for scatter (e.g. size) can be passed as kwargs.

    Returns the central (joint) axes.
    """

    # Create scatter plot and distribution plots on the side. Data is passed
    # by keyword, as seaborn >= 0.12 makes x / y keyword-only.
    g = sns.jointplot(x=x,
                      y=y,
                      color=c,
                      kind=kind,
                      stat_func=stat_func,
                      xlim=xlim,
                      ylim=ylim)
    ax = g.ax_joint  # scatter axes

    # Re-draw the central points with scatter() so that non-significant
    # points are hollow and extra scatter kwargs apply. kwargs is always a
    # dict (never None), so this has always run unconditionally.
    ax.collections[0].set_visible(False)  # hide seaborn's scatter points
    scatter(x, y, is_sign, c=c, ax=ax, **kwargs)

    # Add N to legend, if seaborn created one (it does not in all versions,
    # e.g. without stat_func annotation).
    legend = ax.get_legend()
    if legend is not None and legend.texts:
        leg_txt = legend.texts[0]
        new_txt = 'n = {}\n'.format(len(x)) + leg_txt.get_text()
        leg_txt.set_text(new_txt)

    # Format and save figure.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)

    return ax
예제 #10
0
파일: pplot.py 프로젝트: mnislamraju/seal
def scatter(x, y, is_sign=None, c='b', bc='w', nc='grey', ec='k', alpha=0.5,
            xlim=None, ylim=None, xlab=None, ylab=None, title=None,
            ytitle=None, polar=False, id_line=True, ffig=None, ax=None,
            **kwargs):
    """
    Plot paired data on scatter plot.
    Color settings:
        - c:  face color of significant points (is_sign == True)
        - bc: face color of non-significant points (is_sign == False)
        - nc: face color of untested/untestable points (is_sign == None)

    Per-point colors are built only when is_sign is given and both c and bc
    are strings; otherwise c is passed to matplotlib unchanged.
    ec: edge color of points; alpha: point transparency.
    id_line: whether to draw the identity (x == y) line.
    kwargs: extra keyword arguments for Axes.scatter.

    Returns the axes plotted on.
    """

    ax = putil.axes(ax, polar=polar)

    build_cmat = (is_sign is not None and isinstance(c, str)
                  and isinstance(bc, str))
    face_cols = putil.get_cmat(is_sign, c, bc, nc) if build_cmat else c

    ax.scatter(x, y, c=face_cols, edgecolor=ec, alpha=alpha, **kwargs)

    if id_line:
        putil.add_identity_line(equal_xy=True, zorder=99, ax=ax)

    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return ax
예제 #11
0
파일: pupdown.py 프로젝트: mnislamraju/seal
def plot_spike_count_results(bspk_cnts, rec, task, prds, binsize):
    """
    Plot spike count results on composite plot for multiple periods.

    One row per period in prds, each drawn by plot_prd_spike_count from
    bspk_cnts[prd]. The figure is saved under results/UpDown with a file
    name built from rec, task and the (integer) binsize in ms.
    """

    putil.set_style('notebook', 'ticks')
    fig, gsp, _ = putil.get_gs_subplots(nrow=len(prds), ncol=1, subw=15,
                                        subh=4, create_axes=False)

    for prd, prd_sps in zip(prds, gsp):
        plot_prd_spike_count(bspk_cnts[prd], prd, prd_sps, fig)

    bin_ms = int(binsize)
    title = '{} {}, binsize: {} ms'.format(rec, task, bin_ms)
    fname = 'UpDown_spk_cnt_hist_{}_{}_bin_{}.png'.format(rec, task, bin_ms)
    putil.save_fig(util.join(['results', 'UpDown', fname]), fig, title)
예제 #12
0
def DR_plot(UA, ftempl=None, match_scales=True):
    """
    Plot responses to all 8 directions and polar plot in the center.

    One figure is created per unit (UA.uids()), with one column per task
    (UA.tasks()). Tasks in which the unit is not plotted get mock axes.

    UA: unit array providing uids(), tasks() and get_unit(uid, task)
    ftempl: file name template formatted with the unit id string; figures
            are saved only if given
    match_scales: synchronise y axes across tasks (polar and rate axes
                  separately)
    """

    # Init plotting theme.
    putil.set_style('notebook', 'ticks')

    # For each unit over all tasks.
    for uid in UA.uids():
        tasks = UA.tasks()

        # Init figure.
        # NOTE(review): `subw` is not defined in this function; presumably a
        # module-level constant (subplot size) — verify.
        fig, gsp, _ = putil.get_gs_subplots(nrow=1,
                                            ncol=len(tasks),
                                            subw=subw,
                                            subh=subw)
        task_rate_axs, task_polar_axs = [], []

        # Plot direction response of unit in each task.
        for task, sps in zip(tasks, gsp):
            u = UA.get_unit(uid, task)

            # Plot DR of unit.
            res = (pselectivity.plot_DR_3x3(u, fig, sps)
                   if u.to_plot() else None)
            if res is not None:
                ax_polar, rate_axs = res
                task_rate_axs.extend(rate_axs)
                task_polar_axs.append(ax_polar)
            else:
                putil.add_mock_axes(fig, sps)

        # Match scale of y axes across tasks.
        if match_scales:
            putil.sync_axes(task_polar_axs, sync_y=True)
            putil.sync_axes(task_rate_axs, sync_y=True)
            for ax in task_rate_axs:
                putil.adjust_decorators(ax)

        # Save figure.
        if ftempl is not None:
            uid_str = util.format_uid(uid)
            title = uid_str.replace('_', ' ')
            ffig = ftempl.format(uid_str)
            putil.save_fig(ffig, fig, title)
예제 #13
0
def selectivity_summary(UA, ftempl=None, match_scales=True):
    """
    Test selectivity of unit responses.

    One figure is created per unit (UA.uids()), with one stimulus response
    summary plot (pselectivity.plot_selectivity) per task. Tasks in which
    the unit is not plotted get mock axes.

    UA: unit array providing uids(), tasks() and get_unit(uid, task)
    ftempl: file name template formatted with the unit id string; figures
            are saved only if given
    match_scales: synchronise y axes across tasks, separately for each of
                  the two groups of rate axes returned by plot_selectivity
    """

    # Init plotting theme.
    putil.set_style('notebook', 'ticks')

    # For each unit over all tasks.
    for uid in UA.uids():
        tasks = UA.tasks()

        # Init figure.
        fig, gsp, _ = putil.get_gs_subplots(nrow=1,
                                            ncol=len(tasks),
                                            subw=8,
                                            subh=16)
        ls_axs, ds_axs = [], []

        # Plot stimulus response summary plot of unit in each task.
        for task, sps in zip(tasks, gsp):
            u = UA.get_unit(uid, task)

            res = (pselectivity.plot_selectivity(u, fig, sps)
                   if u.to_plot() else None)
            if res is not None:
                ls_axs.extend(res[0])
                ds_axs.extend(res[1])
            else:
                putil.add_mock_axes(fig, sps)

        # Match scale of y axes across tasks.
        if match_scales:
            for rate_axs in (ls_axs, ds_axs):
                putil.sync_axes(rate_axs, sync_y=True)
                for ax in rate_axs:
                    putil.adjust_decorators(ax)

        # Save figure.
        if ftempl is not None:
            uid_str = util.format_uid(uid)
            title = uid_str.replace('_', ' ')
            ffig = ftempl.format(uid_str)
            putil.save_fig(ffig, fig, title)
예제 #14
0
파일: pplot.py 프로젝트: mnislamraju/seal
def cat_hist(vals,
             xlim=None,
             ylim=None,
             xlab=None,
             ylab='n',
             title=None,
             ytitle=None,
             polar=False,
             ffig=None,
             ax=None,
             **kwargs):
    """
    Plot histogram of categorical data.

    vals: categorical values to count
    xlim, ylim, xlab, ylab, title, ytitle: plot formatting options
    polar: whether to create polar axes
    ffig: file name to save figure to (optional)
    ax: axes to plot on (optional)
    kwargs: additional keyword arguments passed to seaborn's countplot
            (previously accepted but silently ignored)

    Returns the axes plotted on.
    """

    # Plot data. Passing x explicitly is unambiguous across seaborn
    # versions; newer releases treat a lone positional argument as `data`.
    ax = putil.axes(ax, polar=polar)
    ax = sns.countplot(x=vals, ax=ax, **kwargs)

    # Format and save figure.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)

    return ax
예제 #15
0
파일: pplot.py 프로젝트: mnislamraju/seal
def multi_hist(vals, xlim=None, ylim=None, xlab=None, ylab='n', title=None,
               ytitle=None, polar=False, ffig=None, ax=None, **kwargs):
    """
    Plot histogram of multiple samples side by side.

    vals: samples to histogram (passed to Axes.hist)
    kwargs: extra keyword arguments for Axes.hist (e.g. bins)

    Returns the axes plotted on.
    """

    hist_ax = putil.axes(ax, polar=polar)
    hist_ax.hist(vals, **kwargs)

    putil.format_plot(hist_ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return hist_ax
예제 #16
0
파일: pplot.py 프로젝트: mnislamraju/seal
def heatmap(mat,
            vmin=None,
            vmax=None,
            cmap=None,
            cbar=True,
            cbar_ax=None,
            annot=None,
            square=False,
            xlab=None,
            ylab=None,
            title=None,
            ytitle=None,
            xlim=None,
            ylim=None,
            xticklabels=True,
            yticklabels=True,
            ffig=None,
            ax=None):
    """
    Plot rectangular data as heatmap.

    mat: rectangular data (passed to seaborn's heatmap)
    vmin, vmax, cmap, cbar, cbar_ax, annot, square, xticklabels,
    yticklabels: passed on to seaborn's heatmap
    xlab, ylab, title, ytitle, xlim, ylim: plot formatting options
    ffig: file name to save figure to (optional)
    ax: axes to plot on (optional)

    Returns the axes plotted on.
    """

    # Plot data. All options are passed by keyword: seaborn >= 0.12 makes
    # every heatmap parameter after `data` keyword-only.
    ax = putil.axes(ax)
    sns.heatmap(mat,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                annot=annot,
                cbar=cbar,
                cbar_ax=cbar_ax,
                square=square,
                xticklabels=xticklabels,
                yticklabels=yticklabels,
                ax=ax)

    # Format and save figure.
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)

    return ax
예제 #17
0
def plot_slope_diffs_btw_groups(fit_res,
                                prd,
                                res_dir,
                                groups=None,
                                figsize=None):
    """
    Plot differences in slopes between group pairs.

    fit_res: DataFrame with 'group' and 'slope' columns
    prd: period name, used in the title and file name
    res_dir: output directory prefix (concatenated with the file name)
    groups: groups to compare (default: all unique values of fit_res.group)
    figsize: figure size

    For every pair (grp1, grp2) the mean slope difference grp1 - grp2 is
    stored at [grp1, grp2] and annotated with its Mann-Whitney U test
    significance stars; other cells stay NaN.

    Returns the DataFrame of mean pair-wise differences.
    """

    # Test pair-wise difference from each other.
    if groups is None:
        groups = fit_res['group'].unique()
    empty_df = pd.DataFrame(np.nan, index=groups, columns=groups)
    pw_diff_v = empty_df.copy()
    # Annotations are strings: use an object-dtype frame so that assigning
    # them does not hit pandas' incompatible-dtype setitem path.
    pw_diff_a = empty_df.astype(object)
    for grp1, grp2 in combinations(groups, 2):
        # Get slopes in each group.
        slp_1 = fit_res.slope[fit_res.group == grp1]
        slp_2 = fit_res.slope[fit_res.group == grp2]
        # Mean difference value.
        diff = slp_1.mean() - slp_2.mean()
        pw_diff_v.loc[grp1, grp2] = diff
        # Test for statistical difference and annotate with p-value stars.
        _, pval = stats.mann_whithney_u_test(slp_1, slp_2)
        pw_diff_a.loc[grp1, grp2] = '{:.2f}{}'.format(diff,
                                                      util.star_pvalue(pval))

    # Plot and save figure of mean pair-wise difference.
    fig = putil.figure(figsize=figsize)
    ax = sns.heatmap(pw_diff_v, annot=pw_diff_a, fmt='', linewidths=0.5,
                     cbar=False)
    title = 'Mean difference in {} slopes (sp/s / s)'.format(prd)
    putil.set_labels(ax, title=title)
    ffig = res_dir + '{}_anticipatory_slope_pairwise_diff.png'.format(prd)
    putil.save_fig(ffig, fig)

    return pw_diff_v
예제 #18
0
파일: pplot.py 프로젝트: mnislamraju/seal
def bars(x, y, ylim=None, xlim=None, xlab=None, ylab=None, title=None,
         ytitle=None, polar=False, ffig=None, ax=None, **kwargs):
    """
    Plot bar plot.

    x, y: bar positions and heights (passed to Axes.bar)
    kwargs: extra keyword arguments for Axes.bar

    Returns the axes plotted on.
    """

    bar_ax = putil.axes(ax, polar=polar)
    bar_ax.bar(x, y, **kwargs)

    putil.format_plot(bar_ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return bar_ax
예제 #19
0
def plot_CE_time_distribution(ulists,
                              eff_t_res,
                              eff_pars,
                              aroc_res_dir,
                              bins=None):
    """
    Plot distribution of ROC comparison effect timing across groups.

    ulists: unit list labels; each selects rows of the effect timing data
            via .loc (presumably unit list names — verify)
    eff_t_res: DataFrame with 'effect_dir' and 'time' columns
    eff_pars: sequence of (effect direction, subplot label) pairs, one
              subplot each
    aroc_res_dir: output directory prefix (concatenated with the file name)
    bins: histogram bins (default: 2000 to 2550 ms in 50 ms steps)
    """

    # Init.
    putil.set_style('notebook', 'white')
    if bins is None:
        bins = np.arange(2000, 2600, 50)
    fig, _, axs = putil.get_gs_subplots(nrow=1,
                                        ncol=len(eff_pars),
                                        subw=5,
                                        subh=4,
                                        create_axes=True,
                                        as_array=False)

    # Plot CE timing distribution for each unit group.
    for (eff_dir, eff_lbl), ax in zip(eff_pars, axs):
        etd = eff_t_res.loc[eff_t_res.effect_dir == eff_dir, 'time']
        for nlist in ulists:
            tvals = etd.loc[nlist]
            lbl = '{} (n={})'.format(nlist, len(tvals))
            sns.distplot(tvals, bins, label=lbl, ax=ax)
        putil.set_labels(ax, 'effect timing (ms since S1 onset)', '', eff_lbl)

    # Format plots. Despine every axes, not only the last one the loop
    # above left bound to `ax`.
    for ax in axs:
        sns.despine(ax=ax)
        ax.legend()
        putil.hide_tick_labels(ax, show_x_tick_lbls=True)
    putil.sync_axes(axs, sync_y=True)

    # Save plot.
    ffig = aroc_res_dir + 'CE/CE_timing_distributions.png'
    putil.save_fig(ffig, fig)
예제 #20
0
파일: pplot.py 프로젝트: mnislamraju/seal
def band(x, ylower, yupper, ylim=None, xlim=None, xlab=None, ylab=None,
         title=None, ytitle=None, polar=False, ffig=None, ax=None,
         **kwargs):
    """
    Plot highlighted band area.

    x, ylower, yupper: band coordinates (passed to Axes.fill_between)
    kwargs: extra keyword arguments for Axes.fill_between

    Returns the axes plotted on.
    """

    band_ax = putil.axes(ax, polar=polar)
    band_ax.fill_between(x, ylower, yupper, **kwargs)

    putil.format_plot(band_ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return band_ax
예제 #21
0
파일: pplot.py 프로젝트: mnislamraju/seal
def mean_err(x, ymean, ystd, ylim=None, xlim=None, xlab=None, ylab=None,
             title=None, ytitle=None, polar=False, ffig=None, ax=None,
             mean_kws=None, band_kws=None):
    """
    Plot mean and highlighted band area around it.

    The band spans ymean - ystd to ymean + ystd.
    mean_kws: extra keyword arguments for the mean line (lines())
    band_kws: extra keyword arguments for the band (band())

    Returns the axes plotted on.
    """

    mean_kws = {} if mean_kws is None else mean_kws
    band_kws = {} if band_kws is None else band_kws

    # Band limits are computed before any axes are created.
    lo, hi = ymean - ystd, ymean + ystd

    ax = putil.axes(ax, polar=polar)
    lines(x, ymean, ax=ax, **mean_kws)
    band(x, lo, hi, ax=ax, **band_kws)

    putil.format_plot(ax, xlim, ylim, xlab, ylab, title, ytitle)
    putil.save_fig(ffig)
    return ax
예제 #22
0
파일: prepare.py 프로젝트: mnislamraju/seal
def select_units_trials(UA, utids=None, fres=None, ffig=None,
                        min_n_units=5, min_n_trs_per_unit=5):
    """
    Select optimal set of units and trials for population decoding.

    For each recording & task (index levels 'subj', 'date', 'task'), units
    are greedily removed one at a time. Each step removes the unit whose
    removal maximizes the number of trials included in all remaining units
    ("covered" trials); ties are broken by removing the unit with the fewest
    included trials. Among the steps that satisfy the two minimum criteria
    below, the one maximizing (remaining units x covered trials) is kept.

    UA: unit array object (provides utids() and get_unit())
    utids: Series of unit IDs; its index appears to be
           (subj, date, elec, ch, ui, task) -- inferred from the index
           tuples built below, TODO confirm. All units of UA if None.
    fres: file to write results into via util.write_objects (skipped if None)
    ffig: file name passed to putil.save_fig for the summary figure
    min_n_units: minimum number of units to keep (0: off)
    min_n_trs_per_unit: minimum number of trials per unit to keep (0: off)

    Returns:
    RecInfo: DataFrame with one row per (subj, date, task) holding the
             electrode, selected units / trials and their counts and
             percentages
    UInc: boolean Series over utids.index, True for selected units
    """

    print('Selecting optimal set of units and trials for decoding...')

    # Init.
    if utids is None:
        utids = UA.utids(as_series=True)
    u_rt_grpby = utids.groupby(level=['subj', 'date', 'task'])

    # Unit info frame.
    UInc = pd.Series(False, index=utids.index)

    # Included trials by unit.
    IncTrs = pd.Series([(UA.get_unit(utid[:-1], utid[-1]).inc_trials())
                        for utid in utids], index=utids.index)

    # Result DF.
    rec_task = pd.MultiIndex.from_tuples([rt for rt, _ in u_rt_grpby],
                                         names=['subj', 'date', 'task'])
    cols = ['elec', 'units', 'nunits', 'nallunits', '% remaining units',
            'trials', 'ntrials', 'nalltrials', '% remaining trials']
    RecInfo = pd.DataFrame(index=rec_task, columns=cols)
    rt_utids = [utids.xs((s, d, t), level=('subj', 'date', 'task'))
                for s, d, t in rec_task]
    RecInfo.nallunits = [len(utids) for utids in rt_utids]
    # Total trial count per recording is read from its first unit.
    rt_ulist = [UA.get_unit(utids[0][:-1], utids[0][-1]) for utids in rt_utids]
    RecInfo.nalltrials = [int(u.QualityMetrics['NTrialsTotal'])
                          for u in rt_ulist]

    # Function to plot matrix (DF) included/excluded trials.
    def plot_inc_exc_trials(IncTrsMat, ax, title=None, ytitle=None,
                            xlab='Trial #', ylab=None,):
        # Plot on heatmap.
        sns.heatmap(IncTrsMat, cmap='RdYlGn', center=0.5, cbar=False, ax=ax)
        # Set tick labels: first trial, then every 25th trial.
        putil.hide_tick_marks(ax)
        tr_ticks = [1] + list(np.arange(25, IncTrsMat.shape[1]+1, 25))
        ax.xaxis.set_ticks(tr_ticks)
        ax.set_xticklabels(tr_ticks)
        putil.rot_xtick_labels(ax, 0)
        putil.rot_ytick_labels(ax, 0, va='center')
        putil.set_labels(ax, xlab, ylab, title, ytitle)

    # Init plotting.
    ytitle = 1.40
    putil.set_style('notebook', 'whitegrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(rec_task), ncol=3,
                                          subw=6, subh=4, create_axes=True)

    # NOTE: the loop variable below shadows the list `rt_utids` built above
    # (here it is the Series of unit IDs of one recording & task).
    for i_rt, ((subj, date, task), rt_utids) in enumerate(u_rt_grpby):
        print('{} / {}: {} - {} {}'.format(i_rt+1, len(u_rt_grpby), subj,
                                           date, task))

        # Init electrode.
        elecs = rt_utids.index.get_level_values('elec').unique()
        if len(elecs) != 1:
            warnings.warn('More than one electrode?')
        elec = elecs[0]
        RecInfo.loc[(subj, date, task), 'elec'] = elec

        # Create matrix of included trials of recording & task of units.
        # Rows are indexed by the remaining (ch, ui) index levels, columns
        # by 1-based trial number; 1 = included, 0 = excluded.
        ch_idxs = rt_utids.index.droplevel(-1).droplevel(2).droplevel(1).droplevel(0)
        n_alltrs = RecInfo.nalltrials[(subj, date, task)]
        IncTrsMat = pd.DataFrame(np.zeros((len(ch_idxs), n_alltrs), dtype=int),
                                 index=ch_idxs, columns=np.arange(n_alltrs)+1)
        for ch_idx, utid in zip(ch_idxs, rt_utids):
            IncTrsMat.loc[ch_idx].iloc[IncTrs[utid]] = 1

        # Plot included/excluded trials after preprocessing.
        ax = axs[i_rt, 0]
        ylab = '{} {} {}'.format(subj, date, task)
        title = ('Included (green) and excluded (red) trials'
                 if i_rt == 0 else None)
        plot_inc_exc_trials(IncTrsMat, ax, title, ytitle, ylab=ylab)

        # Calculate and plot overlap of trials across units.
        # How many trials will remain if we iteratively excluding units
        # with the least overlap with the rest of the units?
        def n_cov_trs(df):  # return number of trials covered in df
            return sum(df.all())

        # Selection score: number of units x number of covered trials.
        def calc_heuristic(df):
            return df.shape[0] * n_cov_trs(df)

        n_trs = IncTrsMat.sum(1)
        n_units = IncTrsMat.shape[0]

        # Init results DF: row iu describes the state after iu removals.
        columns = ('uid', 'ntrs_cov', 'n_rem_u', 'trial x units')
        tr_covs = pd.DataFrame(columns=columns, index=range(n_units+1))
        tr_covs.loc[0] = ('none', n_cov_trs(IncTrsMat), n_units,
                          calc_heuristic(IncTrsMat))

        # Subset of included units (to be updated in each iteration).
        uinc = IncTrsMat.index.to_series()
        for iu in range(1, len(uinc)):

            # Number of covered trials after removing each unit.
            sntrscov = pd.Series([n_cov_trs(IncTrsMat.loc[uinc.drop(uid)])
                                  for uid in uinc], index=uinc.index)

            #########################################
            # Select and remove unit that           #
            # (a) yields maximum trial coverage,    #
            # (b) has minimum number of trials      #
            #########################################
            maxtrscov = sntrscov.max()
            worst_us = sntrscov[sntrscov == maxtrscov].index  # (a)
            utrs = n_trs.loc[worst_us]
            uid_remove = utrs[(utrs == min(utrs))].index[0]   # (b)

            # Update current subset of units and their trial DF.
            uinc.drop(uid_remove, inplace=True)
            tr_covs.loc[iu] = (uid_remove, maxtrscov, len(uinc),
                               calc_heuristic(IncTrsMat.loc[uinc]))

        # Add last unit.
        # NOTE(review): `uinc[0]` relies on positional fallback of Series
        # indexing (the index holds (ch, ui) tuples); that fallback is
        # deprecated in recent pandas -- `uinc.iloc[0]` is the explicit form.
        tr_covs.iloc[-1] = (uinc[0], 0, 0, 0)

        # Plot covered trials against each units removed.
        # NOTE(review): sns.tsplot was removed in seaborn 0.10.
        ax_trc = axs[i_rt, 1]
        sns.tsplot(tr_covs['ntrs_cov'], marker='o', ms=4, color='b',
                   ax=ax_trc)
        title = ('Trial coverage during iterative unit removal'
                 if i_rt == 0 else None)
        xlab, ylab = 'current unit removed', '# trials covered'
        putil.set_labels(ax_trc, xlab, ylab, title, ytitle)
        ax_trc.xaxis.set_ticks(tr_covs.index)
        x_ticklabs = ['none'] + ['{} - {}'.format(ch, ui)
                                 for ch, ui in tr_covs.uid.loc[1:]]
        ax_trc.set_xticklabels(x_ticklabs)
        putil.rot_xtick_labels(ax_trc, 45)
        ax_trc.grid(True)

        # Add # of remaining units to top.
        ax_remu = ax_trc.twiny()
        ax_remu.xaxis.set_ticks(tr_covs.index)
        ax_remu.set_xticklabels(list(range(len(x_ticklabs)))[::-1])
        ax_remu.set_xlabel('# units remaining')
        ax_remu.grid(None)

        # Add heuristic index.
        ax_heur = ax_trc.twinx()
        sns.tsplot(tr_covs['trial x units'], linestyle='--', marker='o',
                   ms=4, color='m',  ax=ax_heur)
        putil.set_labels(ax_heur, ylab='remaining units x covered trials')
        [tl.set_color('m') for tl in ax_heur.get_yticklabels()]
        [tl.set_color('b') for tl in ax_trc.get_yticklabels()]
        ax_heur.grid(None)

        # Decide on which units to exclude: keep steps with enough units
        # remaining and enough covered trials for that many units.
        min_n_trials = min_n_trs_per_unit * tr_covs['n_rem_u']
        sub_tr_covs = tr_covs[(tr_covs['n_rem_u'] >= min_n_units) &
                              (tr_covs['ntrs_cov'] >= min_n_trials)]

        # If any subset of units passed above criteria.
        # Defaults: no units remain, all units excluded, no trials remain.
        rem_uids, exc_uids = pd.Series(), tr_covs.uid[1:]
        n_tr_rem, n_tr_exc = 0, IncTrsMat.shape[1]
        if len(sub_tr_covs.index):
            # NOTE(review): `hmax_idx` is used below as a row label of
            # tr_covs. Old pandas Series.argmax returned the label (like
            # idxmax); since pandas 1.0 it returns the position within
            # sub_tr_covs, which differs whenever early rows were filtered
            # out -- confirm and consider idxmax.
            hmax_idx = sub_tr_covs['trial x units'].argmax()
            rem_uids = tr_covs.uid[(hmax_idx+1):]
            exc_uids = tr_covs.uid[1:hmax_idx+1]
            n_tr_rem = tr_covs.ntrs_cov[hmax_idx]
            n_tr_exc = IncTrsMat.shape[1] - n_tr_rem

            # Add to UnitInfo dataframe
            rt_utids = [(subj, date, elec, ch, ui, task)
                        for ch, ui in rem_uids]
            UInc[rt_utids] = True

        # Highlight selected point in middle plot.
        sel_seg = [('selection', exc_uids.shape[0]-0.4,
                    exc_uids.shape[0]+0.4)]
        putil.plot_periods(sel_seg, ax=ax_trc, alpha=0.3)
        [ax.set_xlim([-0.5, n_units+0.5]) for ax in (ax_trc, ax_remu)]

        # Generate remaining trials dataframe. Values:
        # 1 = remaining, 0.5 = excluded by this selection,
        # 0 = excluded during preprocessing.
        RemTrsMat = IncTrsMat.copy().astype(float)
        for exc_uid in exc_uids:   # Remove all trials from excluded units.
            RemTrsMat.loc[exc_uid] = 0.5
        # Remove uncovered trials in remaining units.
        exc_trs = np.where(~RemTrsMat.loc[list(rem_uids)].all())[0]
        if exc_trs.size:
            RemTrsMat.iloc[:, exc_trs] = 0.5
        # Overwrite by trials excluded during preprocessing.
        RemTrsMat[IncTrsMat == False] = 0.0

        # Plot remaining trials.
        ax = axs[i_rt, 2]
        n_u_rem, n_u_exc = len(rem_uids), len(exc_uids)
        title = ('# units remaining: {}, excluded: {}'.format(n_u_rem,
                                                              n_u_exc) +
                 '\n# trials remaining: {}, excluded: {}'.format(n_tr_rem,
                                                                 n_tr_exc))
        plot_inc_exc_trials(RemTrsMat, ax, title=title, ylab='')

        # Add remaining units and trials to RecInfo.
        # NOTE(review): pd.Int64Index was removed in pandas 2.0.
        rt = (subj, date, task)
        RecInfo.loc[rt, ('units', 'nunits')] = list(rem_uids), len(rem_uids)
        cov_trs = RemTrsMat.loc[list(rem_uids)].all()
        inc_trs = pd.Int64Index(np.where(cov_trs)[0])
        RecInfo.loc[rt, ('trials', 'ntrials')] = inc_trs, sum(cov_trs)

    RecInfo['% remaining units'] = 100 * RecInfo.nunits / RecInfo.nallunits
    RecInfo['% remaining trials'] = 100 * RecInfo.ntrials / RecInfo.nalltrials

    # Save results.
    if fres is not None:
        results = {'RecInfo': RecInfo, 'UInc': UInc}
        util.write_objects(results, fres)

    # Save plot.
    title = 'Trial & unit selection prior decoding'
    putil.save_fig(ffig, fig, title, w_pad=3, h_pad=3)

    return RecInfo, UInc
예제 #23
0
def plot_qm(u,
            bs_stats,
            stab_prd_res,
            prd_inc,
            tr_inc,
            spk_inc,
            add_lbls=False,
            ftempl=None,
            fig=None,
            sps=None):
    """
    Plot quality metrics related figures.

    Draws a one-line unit info header above a 3 x 2 grid of panels:
    included / excluded waveforms, waveform amplitude and duration over
    recording time, waveform duration vs amplitude, and binned firing rate
    over recording time with included / excluded periods.

    u: unit object (reads Waveforms, SpikeParams, QualityMetrics,
       SessParams and TrData)
    bs_stats: binned rate statistics with 'tmean' (bin times) and 'rate'
    stab_prd_res: stable period result with 'tstart' and 'tstop'
    prd_inc: boolean sequence with one entry per bin of bs_stats,
             True where the bin is included
    tr_inc: boolean mask of included trials
    spk_inc: boolean mask of included spikes
    add_lbls: add axis labels to the subplots if True
    ftempl: file name template, formatted with u.name_to_fname(); the
            figure is saved only when this is given
    fig: figure to draw into, passed to putil.figure
    sps: subplot spec to draw into (a 1 x 1 gridspec is created if None)

    Returns the (included, excluded) waveform axes and the amplitude,
    duration, duration-amplitude and firing rate axes.
    """

    # Init values.
    waveforms = np.array(u.Waveforms)
    wavetime = u.Waveforms.columns * us
    spk_times = np.array(u.SpikeParams['time'], dtype=float)
    base_rate = u.QualityMetrics['baseline']

    # Minimum and maximum gain.
    gmin = u.SessParams['minV']
    gmax = u.SessParams['maxV']

    # %% Init plots.

    # Disable inline plotting to prevent memory leak.
    # NOTE(review): inline plotting is re-enabled only inside the
    # `ftempl is not None` branch at the end -- confirm this is intended.
    putil.inline_off()

    # Init figure and gridspec.
    fig = putil.figure(fig)
    if sps is None:
        sps = putil.gridspec(1, 1)[0]
    ogsp = putil.embed_gsp(sps, 2, 1, height_ratios=[0.02, 1])

    info_sps, qm_sps = ogsp[0], ogsp[1]

    # Info header.
    info_ax = fig.add_subplot(info_sps)
    putil.hide_axes(info_ax)
    title = putil.get_unit_info_title(u)
    putil.set_labels(ax=info_ax, title=title, ytitle=0.80)

    # Create axes.
    gsp = putil.embed_gsp(qm_sps, 3, 2, wspace=0.3, hspace=0.4)
    ax_wf_inc, ax_wf_exc = [fig.add_subplot(gsp[0, i]) for i in (0, 1)]
    ax_wf_amp, ax_wf_dur = [fig.add_subplot(gsp[1, i]) for i in (0, 1)]
    ax_amp_dur, ax_rate = [fig.add_subplot(gsp[2, i]) for i in (0, 1)]

    # Trial markers: every 10th trial start, labelled on every other marker.
    # Assumes TrData is indexed by 0-based trial number (label = index + 1)
    # -- TODO confirm.
    trial_starts, trial_stops = u.TrData.TrialStart, u.TrData.TrialStop
    tr_markers = pd.DataFrame({'time': trial_starts[9::10]})
    tr_markers['label'] = [
        str(itr + 1) if i % 2 else '' for i, itr in enumerate(tr_markers.index)
    ]

    # Common variables, limits and labels.
    # Spike-aligned waveform time: sample index shifted by WF_T_START,
    # scaled by the sampling period.
    WF_T_START = test_sorting.WF_T_START
    spk_t = u.SessParams.sampl_prd * (np.arange(waveforms.shape[1]) -
                                      WF_T_START)
    ses_t_lim = test_sorting.get_start_stop_times(spk_times, trial_starts,
                                                  trial_stops)
    ss, sa = 1.0, 0.8  # marker size and alpha on scatter plot

    # Color spikes by their occurance over session time.
    # Excluded spikes stay a uniform translucent grey.
    my_cmap = putil.get_cmap('jet')
    spk_cols = np.tile(np.array([.25, .25, .25, .25]), (len(spk_times), 1))
    if np.any(spk_inc):  # check if there is any spike included
        spk_t_inc = np.array(spk_times[spk_inc])
        # NOTE(review): divides by zero if all spike times are equal.
        tmin, tmax = float(spk_times.min()), float(spk_times.max())
        spk_cols[spk_inc, :] = my_cmap((spk_t_inc - tmin) / (tmax - tmin))
    # Put excluded spikes to the front, and randomise order of included spikes
    # so later spikes don't systematically cover earlier ones.
    spk_order = np.hstack((np.where(np.invert(spk_inc))[0],
                           np.random.permutation(np.where(spk_inc)[0])))

    # Common labels for plots
    ses_t_lab = 'Recording time (s)'

    # %% Waveform shape analysis.

    # Plot included and excluded waveforms on different axes.
    # Color included by occurance in session time to help detect drifts.
    # NOTE(review): '\m' and '\p' in the non-raw label strings below are
    # invalid escape sequences (kept as backslash + letter, but they warn:
    # DeprecationWarning, SyntaxWarning since Python 3.12).
    s_waveforms, s_spk_cols = waveforms[spk_order, :], spk_cols[spk_order]
    wf_t_lim, glim = [min(spk_t), max(spk_t)], [gmin, gmax]
    wf_t_lab, volt_lab = 'WF time ($\mu$s)', 'Voltage'
    for st in ('Included', 'Excluded'):
        ax = ax_wf_inc if st == 'Included' else ax_wf_exc
        spk_idx = spk_inc if st == 'Included' else np.invert(spk_inc)
        tr_idx = tr_inc if st == 'Included' else np.invert(tr_inc)

        nspsk, ntrs = sum(spk_idx), sum(tr_idx)
        title = '{} WFs, {} spikes, {} trials'.format(st, nspsk, ntrs)

        # Select waveforms and colors (mask reordered to match spk_order).
        rand_spk_idx = spk_idx[spk_order]
        wfs = s_waveforms[rand_spk_idx, :]
        cols = s_spk_cols[rand_spk_idx]

        # Plot waveforms.
        xlab, ylab = (wf_t_lab, volt_lab) if add_lbls else (None, None)
        pwaveform.plot_wfs(wfs,
                           spk_t,
                           cols=cols,
                           lw=0.1,
                           alpha=0.05,
                           xlim=wf_t_lim,
                           ylim=glim,
                           title=title,
                           xlab=xlab,
                           ylab=ylab,
                           ax=ax)

    # %% Waveform summary metrics.

    # Init data.
    wf_amp_all = u.SpikeParams['amplitude']
    wf_amp_inc = wf_amp_all[spk_inc]
    wf_dur_all = u.SpikeParams['duration']
    wf_dur_inc = wf_dur_all[spk_inc]

    # Set common limits and labels.
    dur_lim = [0, wavetime[-2] - wavetime[WF_T_START]]  # same across units
    glim = max(wf_amp_all.max(), gmax - gmin)
    amp_lim = [0, glim]

    amp_lab = 'Amplitude'
    dur_lab = 'Duration ($\mu$s)'

    # Waveform amplitude across session time.
    m_amp, sd_amp = wf_amp_inc.mean(), wf_amp_inc.std()
    title = 'WF amplitude: {:.1f} $\pm$ {:.1f}'.format(m_amp, sd_amp)
    xlab, ylab = (ses_t_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times,
                  wf_amp_all,
                  spk_inc,
                  c='m',
                  bc='grey',
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=ses_t_lim,
                  ylim=amp_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=title,
                  ax=ax_wf_amp)

    # Waveform duration across session time.
    mdur, sdur = wf_dur_inc.mean(), wf_dur_inc.std()
    title = 'WF duration: {:.1f} $\pm$ {:.1f} $\mu$s'.format(mdur, sdur)
    xlab, ylab = (ses_t_lab, dur_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times,
                  wf_dur_all,
                  spk_inc,
                  c='c',
                  bc='grey',
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=ses_t_lim,
                  ylim=dur_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=title,
                  ax=ax_wf_dur)

    # Waveform duration against amplitude.
    title = 'WF duration - amplitude'
    xlab, ylab = (dur_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(wf_dur_all[spk_order],
                  wf_amp_all[spk_order],
                  c=spk_cols[spk_order],
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=dur_lim,
                  ylim=amp_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=title,
                  ax=ax_amp_dur)

    # %% Firing rate.

    tmean = np.array(bs_stats['tmean'])
    rmean = util.remove_dim_from_series(bs_stats['rate'])
    prd_tstart, prd_tstop = stab_prd_res['tstart'], stab_prd_res['tstop']

    # Color segments depending on whether they are included / excluded.
    # A segment is colored only if both of its end bins are included.
    def plot_periods(v, color, ax):
        # Plot line segments.
        for i in range(len(prd_inc[:-1])):
            col = color if prd_inc[i] and prd_inc[i + 1] else 'grey'
            x, y = [(tmean[i], tmean[i + 1]), (v[i], v[i + 1])]
            ax.plot(x, y, color=col)
        # Plot line points.
        for i in range(len(prd_inc)):
            col = color if prd_inc[i] else 'grey'
            x, y = [tmean[i], v[i]]
            ax.plot(x,
                    y,
                    color=col,
                    marker='o',
                    markersize=3,
                    markeredgecolor=col)

    # Firing rate over session time.
    title = 'Baseline rate: {:.1f} spike/s'.format(float(base_rate))
    xlab, ylab = (ses_t_lab, putil.FR_lbl) if add_lbls else (None, None)
    ylim = [0, 1.25 * np.max(rmean)]
    plot_periods(rmean, 'b', ax_rate)
    # Empty line call: used only to apply limits, title and labels.
    pplot.lines([], [],
                c='b',
                xlim=ses_t_lim,
                ylim=ylim,
                title=title,
                xlab=xlab,
                ylab=ylab,
                ax=ax_rate)

    # Trial markers.
    putil.plot_events(tr_markers,
                      lw=0.5,
                      ls='--',
                      alpha=0.35,
                      y_lbl=0.92,
                      ax=ax_rate)

    # Excluded periods: session parts before / after the stable period.
    excl_prds = []
    tstart, tstop = ses_t_lim
    if tstart != prd_tstart:
        excl_prds.append(('beg', tstart, prd_tstart))
    if tstop != prd_tstop:
        excl_prds.append(('end', prd_tstop, tstop))
    putil.plot_periods(excl_prds, ymax=0.92, ax=ax_rate)

    # %% Post-formatting.

    # Maximize number of ticks on recording time axes to prevent covering.
    for ax in (ax_wf_amp, ax_wf_dur, ax_rate):
        putil.set_max_n_ticks(ax, 6, 'x')

    # %% Save figure.
    # NOTE(review): `title` here is the last value assigned above (the
    # firing rate subplot title) -- confirm it is the intended figure title.
    if ftempl is not None:
        fname = ftempl.format(u.name_to_fname())
        putil.save_fig(fname, fig, title, rect_height=0.92)
        putil.inline_on()

    return [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
예제 #24
0
def plot_DR_tuning(DS,
                   title=None,
                   labels=True,
                   baseline=None,
                   DR_legend=True,
                   tuning_legend=True,
                   ffig=None,
                   fig=None,
                   sps=None):
    """
    Plot direction selectivity on polar plot and tuning curve.

    For each stimulus in DS.DSI.index, draws the direction response on a
    polar axes (left) and the fitted Gaussian tuning curve, centered on the
    preferred direction, on a Cartesian axes (right). Each plot gets a
    monospace legend table with the per-stimulus parameters.

    DS: direction selectivity results; fields read are DSI.wDS, TP (7
        fit / quality values per stimulus), PD ('PD', 'cPD' of the
        'weighted' method) and DR ('mean' / 'sem' response by direction)
    title: figure super title (none if None)
    labels: add titles, axis labels and legend titles if True; False gives
            a compact layout (commented as "summary plot" below)
    baseline: baseline rate forwarded to both plots
    DR_legend, tuning_legend: whether to add each plot's legend
    ffig: file name to save figure into (skipped if None)
    fig, sps: figure and subplot spec, passed to putil.sps_fig

    Returns the polar (direction response) and tuning curve axes.
    """

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 1, 2, wspace=0.3)
    ax_DR = fig.add_subplot(gsp[0], polar=True)
    ax_tuning = fig.add_subplot(gsp[1])
    DR_patches = []
    tuning_patches = []

    stims = DS.DSI.index
    for stim in stims:

        # Extract params and results.
        a, b, x0, sigma, FWHM, R2, RMSE = DS.TP.loc[stim]
        DSI = DS.DSI.wDS[stim]
        PD, cPD = DS.PD.loc[(stim, 'weighted'), ['PD', 'cPD']]
        dirs = np.array(DS.DR[stim].index) * deg
        SR, SRsem = [DS.DR[stim][col] for col in ('mean', 'sem')]

        # Center direction  and response stats.
        dirsc = direction.center_to_dir(dirs, PD)

        # Generate data points for plotting fitted tuning curve.
        x, y = tuning.gen_fit_curve(tuning.gaus,
                                    -180,
                                    180,
                                    a=a,
                                    b=b,
                                    x0=x0,
                                    sigma=sigma)

        # Plot DR polar plot.
        color = putil.stim_colors.loc[stim]
        ttl = 'Direction response' if labels else None
        plot_DR(dirs, SR, DSI, PD, baseline, title=ttl, color=color, ax=ax_DR)

        # Collect parameters of DR plot (stimulus - response).
        # NOTE(review): '\c' in the non-raw string below is an invalid
        # escape sequence (kept as backslash + 'c', but it warns; a raw
        # string would give the same text without the warning).
        s_pd = str(float(round(PD, 1)))
        s_pd_c = str(int(cPD)) if not np.isnan(float(cPD)) else 'nan'
        lgd_lbl = '{}:   {:.3f}'.format(stim, DSI)
        lgd_lbl += '     {:>5}$^\circ$ --> {:>3}$^\circ$ '.format(s_pd, s_pd_c)
        DR_patches.append(putil.get_artist(lgd_lbl, color))

        # Calculate and plot direction tuning curve.
        xticks = np.arange(-180, 180 + 1, 45)
        plot_tuning(x,
                    y,
                    dirsc,
                    SR,
                    SRsem,
                    color,
                    baseline,
                    xticks,
                    ax=ax_tuning)

        # Collect parameters tuning curve fit.
        # Fixed-width columns, aligned with `tuning_lgd_ttl` below.
        s_a, s_b, s_x0, s_sigma, s_FWHM = [
            str(float(round(p, 1))) for p in (a, b, x0, sigma, FWHM)
        ]
        s_R2 = format(R2, '.2f')
        lgd_lbl = '{}:{}{:>6}{}{:>6}'.format(stim, 5 * ' ', s_a, 5 * ' ', s_b)
        lgd_lbl += '{}{:>6}{}{:>6}'.format(5 * ' ', s_x0, 8 * ' ', s_sigma)
        lgd_lbl += '{}{:>6}{}{:>6}'.format(8 * ' ', s_FWHM, 8 * ' ', s_R2)
        tuning_patches.append(putil.get_artist(lgd_lbl, color))

    # Format tuning curve plot (only after all stimuli has been plotted!).
    xlab = 'Difference from preferred direction (deg)' if labels else None
    ylab = putil.FR_lbl if labels else None
    xlim = [-180 - 5, 180 + 5]  # degrees
    ylim = [0, None]
    ttl = 'Tuning curve' if labels else None
    putil.format_plot(ax_tuning, xlim, ylim, xlab, ylab, ttl)
    putil.add_zero_line('y', ax=ax_tuning)  # add zero reference line

    # Set super title.
    if title is not None:
        fig.suptitle(title, y=1.1, fontsize='x-large')

    # Set legends: placed below the axes, framed only in compact mode.
    ylegend = -0.38 if labels else -0.15
    fr_on = False if labels else True
    lgd_kws = dict([('fancybox', True), ('shadow', False), ('frameon', fr_on),
                    ('framealpha', 1.0), ('loc', 'lower center'),
                    ('bbox_to_anchor', [0., ylegend, 1., .0]),
                    ('prop', {
                        'family': 'monospace'
                    })])
    # Legend titles act as column headers over the monospace label tables.
    DR_lgn_ttl = 'DSI'.rjust(20) + 'PD'.rjust(14) + 'PD8'.rjust(14)
    tuning_lgd_ttl = ('a (sp/s)'.rjust(35) + 'b (sp/s)'.rjust(15) +
                      'x0 (deg)'.rjust(13) + 'sigma (deg)'.rjust(15) +
                      'FWHM (deg)'.rjust(15) + 'R-squared'.rjust(15))

    lgn_params = [(DR_legend, DR_lgn_ttl, DR_patches, ax_DR),
                  (tuning_legend, tuning_lgd_ttl, tuning_patches, ax_tuning)]

    for (plot_legend, lgd_ttl, patches, ax) in lgn_params:
        if not plot_legend:
            continue
        if not labels:  # customize for summary plot
            lgd_ttl = None
        lgd = putil.set_legend(ax, handles=patches, title=lgd_ttl, **lgd_kws)
        lgd.get_title().set_ha('left')
        if lgd_kws['frameon']:
            lgd.get_frame().set_linewidth(.5)

    # Save figure.
    if ffig is not None:
        putil.save_fig(ffig, fig, rect_height=0.55)

    return ax_DR, ax_tuning
예제 #25
0
def plot_DR(dirs,
            resp,
            DSI=None,
            PD=None,
            baseline=None,
            plot_type='line',
            complete_missing_dirs=False,
            color='b',
            title=None,
            ffig=None,
            ax=None):
    """
    Plot response to each directions on polar plot, with a vector pointing to
    preferred direction (PD) with length DSI.
    Use plot_type to change between sector ('bar') and connected ('line') plot
    types.

    dirs: directions; expected to carry units (`.rescale(rad)` and `.units`
          are used below)
    resp: response to each direction (same length as dirs)
    DSI: direction selectivity index; the arrow length is max(resp) * DSI
    PD: preferred direction with units (rescaled to radians); the arrow is
        drawn only if both DSI and PD are given
    baseline: baseline response, drawn first (if not None)
    plot_type: 'bar' for sectors, anything else for a closed line plot
    complete_missing_dirs: insert a zero response for directions dropped
                           as NaN
    color: color of the plotted data and arrow
    title: plot title
    ffig: file name passed to putil.save_fig
    ax: polar axes to plot into, passed to putil.axes

    Returns the axes plotted into.
    """

    ax = putil.axes(ax, polar=True)

    # Plot baseline.
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Remove NaNs (from either directions or responses).
    all_dirs = dirs
    not_nan = np.array(~pd.isnull(dirs) & ~pd.isnull(resp))
    dirs, resp = dirs[not_nan], np.array(resp[not_nan])

    # Prepare data.
    # Complete missing directions with 0 response.
    # NOTE(review): `resp` is a plain array here, and each insertion
    # multiplies the result by a unit again (`dirs.units`, `1 / s`). With
    # more than one missing direction the units may compound, depending on
    # whether np.insert keeps units on quantities arrays -- confirm.
    if complete_missing_dirs:
        for i, d in enumerate(all_dirs):
            if d not in dirs:
                dirs = np.insert(dirs, i, d) * dirs.units
                resp = np.insert(resp, i, 0) * 1 / s

    rad_dirs = dirs.rescale(rad)

    # Plot response to each directions on polar plot.
    if plot_type == 'bar':  # sector plot
        # Sector width is one full turn divided by the number of directions
        # (including NaN-removed ones); bars are shifted by half a width.
        ndirs = all_dirs.size
        left_rad_dirs = rad_dirs - np.pi / ndirs  # no need for this in MPL 2.0?
        w = 2 * np.pi / ndirs  # same with edgecolor and else?
        pplot.bars(left_rad_dirs,
                   resp,
                   width=w,
                   alpha=0.50,
                   color=color,
                   lw=1,
                   edgecolor='w',
                   title=title,
                   ytitle=1.08,
                   ax=ax)
    else:  # line plot
        # Repeat the first point at the end to close the polygon.
        rad_dirs, resp = [np.append(v, [v[0]]) for v in (rad_dirs, resp)]
        pplot.lines(rad_dirs,
                    resp,
                    color=color,
                    marker='o',
                    lw=1,
                    ms=4,
                    mew=0,
                    title=title,
                    ytitle=1.08,
                    ax=ax)
        ax.fill(rad_dirs, resp, color=color, alpha=0.15)

    # Add arrow representing PD and weighted DSI.
    if DSI is not None and PD is not None:
        rho = np.max(resp) * DSI
        xy = (float(PD.rescale(rad)), rho)
        arr_props = dict(facecolor=color, edgecolor='k', shrink=0.0, alpha=0.5)
        ax.annotate('', xy, xytext=(0, 0), arrowprops=arr_props)

    # Remove spines and tick marks, maximize tick labels.
    putil.set_spines(ax, False, False)
    putil.hide_tick_marks(ax)

    # Save and return plot.
    putil.save_fig(ffig)
    return ax
예제 #26
0
def rate(rate_list, names=None, prds=None, evts=None, cols=None,
         baseline=None, pval=0.05, test='mann_whitney_u', test_kws=None,
         xlim=None, ylim=None, title=None, xlab=None, ylab=putil.FR_lbl,
         add_lgn=True, lgn_lbl='trs', ffig=None, ax=None):
    """
    Plot firing rate: mean +- SEM curve for each rate table in rate_list.

    rate_list: list of rate tables; each is read row-wise (one row per
               trial, `.shape[0]` counts them) with columns as time points
    names: label per rate table (empty labels if None)
    prds: periods to highlight, forwarded to putil.plot_periods
    evts: events to mark, forwarded to putil.plot_event_markers
    cols: color per rate table (putil.get_colors if None)
    baseline: baseline rate to draw (if not None)
    pval: significance threshold for comparing exactly two rate tables
          (None: no test)
    test, test_kws: test name and options for stats.sign_periods
    xlim, ylim: axis limits (derived from the data if None)
    title, xlab, ylab: plot title and axis labels; xlab is formatted
                       with putil.t_lbl
    add_lgn: add legend if True
    lgn_lbl: unit appended to legend labels after the number of rows
             (None: no count shown)
    ffig: file name forwarded to putil.save_fig
    ax: axes to plot into, forwarded to putil.axes

    Returns the axes plotted into.
    """
    ax = putil.axes(ax)
    test_kws = dict() if test_kws is None else test_kws

    # Background elements are drawn before the rate curves.
    putil.plot_periods(prds, ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)
    putil.set_limits(ax, xlim)

    if len(rate_list) == 0:
        return ax

    cols = putil.get_colors(as_cycle=False) if cols is None else cols
    names = len(rate_list) * [''] if names is None else names

    xmin = xmax = ymax = None
    for idx, rts in enumerate(rate_list):
        # Name and color are looked up for every entry, even empty ones.
        name, col = names[idx], cols[idx]

        n_rows = rts.shape[0]
        if not n_rows:  # nothing to plot (no trials)
            continue

        # Arrays are printed via numpy to get nicely formatted floats.
        if util.is_iterable(name):
            lbl = str(np.array(name))
        else:
            lbl = str(name)
        if lgn_lbl is not None:
            lbl = '{} ({} {})'.format(lbl, n_rows, lgn_lbl)

        tvec, meanr, semr = rts.columns, rts.mean(), rts.sem()
        ax.plot(tvec, meanr, label=lbl, color=col)
        upper = meanr + semr
        ax.fill_between(tvec, meanr - semr, upper, alpha=0.2,
                        facecolor=col, edgecolor=col)

        # Running data extent across all plotted tables.
        tmin, tmax, rmax = tvec.min(), tvec.max(), upper.max()
        xmin = tmin if xmin is None else np.min([xmin, tmin])
        xmax = tmax if xmax is None else np.max([xmax, tmax])
        ymax = rmax if ymax is None else np.max([ymax, rmax])

    # Derive missing limits from the data extent.
    if xlim is None:
        # Identical limits are avoided by leaving the upper one open.
        xlim = (xmin, None if xmin == xmax else xmax)
    if ylim is None:
        has_pos_max = (ymax is not None) and (ymax > 0)
        ylim = (0, 1.02 * ymax if has_pos_max else None)
    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)

    # Tick marks use the final limits (which may have been left as None).
    t1, t2 = ax.get_xlim()
    tmarks, tlbls = putil.get_tick_marks_and_labels(t1, t2)
    putil.set_xtick_labels(ax, tmarks, tlbls)
    putil.set_max_n_ticks(ax, 7, 'y')

    if add_lgn:
        putil.set_legend(ax, loc=1, borderaxespad=0.0, handletextpad=0.4,
                         handlelength=0.6)

    # Significance bar on top is only drawn for a pair of rate tables.
    if (pval is not None) and (len(rate_list) == 2):
        rates1, rates2 = rate_list
        sign_prds = stats.sign_periods(rates1, rates2, pval, test, **test_kws)
        putil.plot_signif_prds(sign_prds, color='m', linewidth=4.0, ax=ax)

    putil.plot_event_markers(evts, ax=ax)

    putil.save_fig(ffig)
    return ax
예제 #27
0
파일: prepare.py 프로젝트: mnislamraju/seal
def PD_across_units(UA, UInc, utids=None, fres=None, ffig=None):
    """
    Test consistency/spread of PD across units per recording.
    What is the spread in the preferred directions across units?

    Return population level preferred direction (and direction selectivity),
    that can be used to determine dominant preferred direction to decode.

    UA: unit array object (passed to ua_query.get_DSInfo_table)
    UInc: boolean Series marking the units included in the population
          estimate (stored as DSInfo['include'])
    utids: unit IDs to analyse; all units of UA if None
    fres: file to write results into via util.write_objects (skipped if None)
    ffig: file name passed to putil.save_fig for the polar summary figure

    Returns:
    DSInfo: direction selectivity info of all units, with an added
            'include' column
    PPDres: DataFrame with one row per (subj, date, task) holding the
            population results of direction.calc_PPD (unpacked below as
            PDSI, PPD, PPDc, PADc)
    """

    # Init.
    if utids is None:
        utids = UA.utids(as_series=True)
    tasks = utids.index.get_level_values('task').unique()
    recs = util.get_subj_date_pairs(utids)

    # Get DS info frame.
    DSInfo = ua_query.get_DSInfo_table(UA, utids)
    DSInfo['include'] = UInc

    # Calculate population PD and DSI.
    dPPDres = {}
    for rec in recs:
        for task in tasks:

            # Init.
            rt = rec + (task,)
            rtDSInfo = DSInfo.xs(rt, level=[0, 1, -1])
            if rtDSInfo.empty:
                continue

            # Calculate population PD and population DSI from the
            # included units only.
            res = direction.calc_PPD(rtDSInfo.loc[rtDSInfo.include])
            dPPDres[rt] = res

    PPDres = pd.DataFrame(dPPDres).T

    # Save results.
    if fres is not None:
        results = {'DSInfo': DSInfo, 'PPDres': PPDres}
        util.write_objects(results, fres)

    # Plot results.

    # Init plotting.
    putil.set_style('notebook', 'darkgrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(recs), ncol=len(tasks),
                                          subw=6, subh=6, create_axes=True,
                                          ax_kws_list={'projection': 'polar'})
    # Minor ticks offset by half of a 360/8 degree sector from all_dirs.
    xticks = direction.deg2rad(constants.all_dirs + 360/8/2*deg)

    for ir, rec in enumerate(recs):
        for it, task in enumerate(tasks):

            # Init.
            rt = rec + (task,)
            rtDSInfo = DSInfo.xs(rt, level=[0, 1, -1])
            ax = axs[ir, it]
            if rtDSInfo.empty:
                ax.set_axis_off()
                continue
            PDSI, PPD, PPDc, PADc = PPDres.loc[rt]

            # Plot PD - DSI on polar plot.
            # Directions are shown as integers in the title (NaN kept as is).
            sPPDc, sPADc = [int(v) if not np.isnan(v) else v
                            for v in (PPDc, PADc)]
            title = (' '.join(rt) + '\n' +
                     r'PPDc = {}$^\circ$ - {}$^\circ$'.format(sPPDc, sPADc) +
                     ', PDSI = {:.2f}'.format(PDSI))
            PDrad = direction.deg2rad(util.remove_dim_from_series(rtDSInfo.PD))
            pplot.scatter(PDrad, rtDSInfo.DSI, rtDSInfo.include, ylim=[0, 1],
                          title=title, ytitle=1.08, c='darkblue', ec='k',
                          linewidth=1, s=80, alpha=0.8, zorder=2, ax=ax)

            # Highlight PPD (green) and PAD (red): a faint sector spanning
            # +-45 degrees and a stronger one at the direction itself.
            offsets = np.array([-45, 0, 45]) * deg
            for D, c in [(PPDc, 'g'), (PADc, 'r')]:
                if np.isnan(D):
                    continue
                hlDs = direction.deg2rad(np.array(D+offsets))
                for hlD, alpha in [(hlDs, 0.2), ([hlDs[1]], 0.4)]:
                    pplot.bars(hlD, len(hlD)*[1], align='center',
                               alpha=alpha, color=c, zorder=1, ax=ax)

            # Format ticks: grid lines on the minor (sector border) ticks
            # only. The visibility flag is passed positionally because the
            # `b` keyword was renamed to `visible` in newer matplotlib.
            ax.set_xticks(xticks, minor=True)
            ax.grid(True, axis='x', which='minor')
            ax.grid(False, axis='x', which='major')
            putil.hide_tick_marks(ax)

    # Save plot.
    title = 'Population direction selectivity'
    putil.save_fig(ffig, fig, title, w_pad=12, h_pad=20)

    return DSInfo, PPDres
예제 #28
0
파일: pplot.py 프로젝트: mnislamraju/seal
def plot_group_violin(res,
                      x,
                      y,
                      groups=None,
                      npval=None,
                      pth=0.01,
                      color='grey',
                      ylim=None,
                      ylab=None,
                      ffig=None):
    """
    Plot group-wise results on violin plots.

    res: DataFrame of results, one row per observation.
    x: name of column in `res` holding the group of each observation.
    y: name of column in `res` holding the value to plot.
    groups: groups to plot, in plotting order (default: unique values of
            column `x`, in order of appearance).
    npval: name of column holding per-observation p-values (optional).
           Observations with p-value < `pth` are counted as significant.
           If None, every observation is counted as significant.
    pth: significance threshold applied to column `npval`.
    color: color of the swarm plot markers.
    ylim: y axis limits.
    ylab: y axis label.
    ffig: file name to save figure into (passed to putil.save_fig).

    Returns the figure and axes objects.
    """

    # Work on a copy so that the helper columns added below do not leak
    # into the caller's DataFrame.
    res = res.copy()

    # Groups are defined by column `x` (consistently with the groupby below).
    if groups is None:
        groups = res[x].unique()

    # Test difference from zero in each group (one-sample t-test).
    pvals = {group: sp.stats.ttest_1samp(gres[y], 0).pvalue
             for group, gres in res.groupby(x)}

    # Binarize significance test.
    res['is_sign'] = res[npval] < pth if npval is not None else True
    res['direction'] = np.sign(res[y])

    # Set up figure and plot data.
    fig = putil.figure()
    ax = putil.axes()
    putil.add_baseline(ax=ax)
    sns.violinplot(x=x, y=y, data=res, inner=None, order=groups, ax=ax)
    sns.swarmplot(x=x,
                  y=y,
                  hue='is_sign',
                  data=res,
                  color=color,
                  order=groups,
                  hue_order=[True, False],
                  ax=ax)
    putil.set_labels(ax, xlab='', ylab=ylab)
    putil.set_limits(ax, ylim=ylim)
    putil.hide_legend(ax)

    # Add annotations, anchored at the top of the y axis.
    ylvl = ax.get_ylim()[1]
    for i, group in enumerate(groups):
        gres = res.loc[res[x] == group]
        ntot = len(gres)
        if not ntot:
            # No observations in this group: nothing to annotate (and no
            # t-test result / percentages can be computed for it).
            continue

        # Mean.
        mean_str = 'Mean:\n' if i == 0 else '\n'
        mean_str += '{:.2f}'.format(gres[y].mean())
        # Non-zero test of distribution.
        str_pval = util.format_pvalue(pvals[group])
        mean_str += '\n({})'.format(str_pval)

        # Stats on difference from baseline.
        nnonsign = (~gres.is_sign).sum()
        npos, nneg = [
            sum(gres.is_sign & (gres.direction == d)) for d in (1, -1)
        ]
        sign_to_report = [('+', npos), ('=', nnonsign), ('-', nneg)]
        nsign_str = ''
        for symb, n in sign_to_report:
            prc = str(int(round(100 * n / ntot)))
            nsign_str += '{} {:>3} / {} ({:>2}%)\n'.format(
                symb, int(n), ntot, prc)
        lbl = '{}\n\n{}'.format(mean_str, nsign_str)
        ax.text(i, ylvl, lbl, fontsize='smaller', va='bottom', ha='center')

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
예제 #29
0
파일: prepare.py 프로젝트: mnislamraju/seal
def plot_trial_type_distribution(UA, RecInfo, utids=None, tr_par=('S1', 'Dir'),
                                 save_plot=False, fname=None):
    """
    Plot distribution of trial types.

    One subplot is drawn per recording (row) and task (column), showing the
    count of each value of trial parameter `tr_par`, split into all, correct
    and error trials.

    UA: unit array (provides utids(), tasks() and get_unit()).
    RecInfo: DataFrame of recording info, looked up by rec + (task,) tuples,
             with a 'task' index level and a 'trials' column holding the
             included trials of each recording-task.
    utids: Series of unit IDs to take recordings from (default: all units
           of UA). Its index must have 'subj', 'date' and 'task' levels;
           presumably its values are utid tuples whose last element is the
           task -- TODO confirm.
    tr_par: trial parameter (TrData column) to plot distribution of.
    save_plot: whether to save the figure.
    fname: file name to save figure into (default: path under
           results/decoding/prepare/ derived from `tr_par`).
    """

    # Init.
    par_str = util.format_to_fname(str(tr_par))
    if utids is None:
        utids = UA.utids(as_series=True)
    recs = util.get_subj_date_pairs(utids)
    tasks = RecInfo.index.get_level_values('task').unique()
    tasks = [task for task in UA.tasks() if task in tasks]  # reorder tasks

    # Init plotting.
    putil.set_style('notebook', 'darkgrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(recs), ncol=len(tasks),
                                          subw=4, subh=3, create_axes=True)

    for ir, rec in enumerate(recs):
        for it, task in enumerate(tasks):

            ax = axs[ir, it]
            rt = rec + (task,)
            if rt not in RecInfo.index:
                ax.set_axis_off()
                continue

            # Get included trials and their parameters.
            inc_trs = RecInfo.loc[rt, 'trials']
            # Take the first utid of this recording-task. Use positional
            # access: plain [0] is label-based if the remaining index
            # levels happen to be integers.
            utid = utids.xs(rt, level=('subj', 'date', 'task')).iloc[0]
            TrData = UA.get_unit(utid[:-1], utid[-1]).TrData.loc[inc_trs]

            # Nothing to plot without trials (also avoids division by zero
            # when computing percentages below).
            if not TrData.size:
                ax.set_axis_off()
                continue

            # Create DF to plot: each trial once labelled by its outcome,
            # and once more labelled as 'all'.
            anw_df = TrData[[tr_par, 'correct']].copy()
            anw_df['answer'] = 'error'
            anw_df.loc[anw_df.correct, 'answer'] = 'correct'
            all_df = anw_df.copy()
            all_df.answer = 'all'
            comb_df = pd.concat([anw_df, all_df])

            # Plot as countplot.
            sns.countplot(x=tr_par, hue='answer', data=comb_df,
                          hue_order=['all', 'correct', 'error'], ax=ax)
            sns.despine(ax=ax)
            putil.hide_tick_marks(ax)
            putil.set_max_n_ticks(ax, 6, 'y')
            ax.legend(loc=[0.95, 0.7])

            # Add title with number and percentage of correct/error trials.
            title = '{} {}'.format(rec, task)
            nce = anw_df.answer.value_counts()
            nc, ne = [nce.get(c, 0) for c in ('correct', 'error')]
            pnc, pne = 100*nc/nce.sum(), 100*ne/nce.sum()
            title += '\n\n# correct: {} ({:.0f}%)'.format(nc, pnc)
            title += '      # error: {} ({:.0f}%)'.format(ne, pne)
            putil.set_labels(ax, title=title, xlab=par_str)

            # Format legend: keep it only on the top-left subplot.
            if (ir != 0) or (it != 0):
                ax.legend_.remove()

    # Save plot.
    if save_plot:
        title = 'Trial type distribution'
        if fname is None:
            fname = util.join(['results', 'decoding', 'prepare',
                               par_str + '_trial_type_distr.pdf'])

        putil.save_fig(fname, fig, title, w_pad=3, h_pad=3)
예제 #30
0
def quality_test(UA, ftempl=None, plot_qm=False, fselection=None):
    """
    Test and plot quality metrics of recording and spike sorting.

    UA: unit array; every unit (UA.uids()) is tested in every task
        (UA.tasks()).
    ftempl: file name template with one '{}' placeholder, filled with the
            formatted unit ID, to save each unit's QC figure into. Figures
            are only saved if this is given and `plot_qm` is True.
    plot_qm: whether to plot quality metrics of each unit.
    fselection: path of Excel file with unit & trial selection parameters
                (optional).

    Returns DataFrame of QC test results, one row per uid + (task,) for
    which test_sorting.test_qm returned results.
    """

    # Init plotting theme.
    putil.set_style('notebook', 'ticks')

    # Import unit&trial selection file.
    UnTrSel = pd.read_excel(fselection) if (fselection is not None) else None

    # For each unit over all tasks.
    d_QC_tests = {}
    for uid in UA.uids():

        tasks = UA.tasks()

        # Init figure: one column of subplots per task.
        if plot_qm:
            # NOTE(review): `subw` is not defined in this function; it is
            # presumably a module-level constant -- confirm.
            fig, gsp, _ = putil.get_gs_subplots(nrow=1,
                                                ncol=len(tasks),
                                                subw=subw,
                                                subh=1.6 * subw)
            wf_axs, amp_axs, dur_axs, amp_dur_axs, rate_axs = ([], [], [], [],
                                                               [])

        for i, task in enumerate(tasks):

            # Do quality test.
            u = UA.get_unit(uid, task)
            include, first_tr, last_tr = get_selection_params(u, UnTrSel)
            res = test_sorting.test_qm(u, include, first_tr, last_tr)

            # QC test results are removed from `res`, so that the remaining
            # entries can be passed on to plotting as keyword arguments.
            if res is not None:
                d_QC_tests[uid + (task, )] = res.pop('QC_tests')

            # Plot QC results.
            if plot_qm:

                if res is not None:
                    ax_res = pquality.plot_qm(u, fig=fig, sps=gsp[i], **res)

                    # Collect axes.
                    ax_wfs, ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate = ax_res
                    wf_axs.extend(ax_wfs)
                    amp_axs.append(ax_wf_amp)
                    dur_axs.append(ax_wf_dur)
                    amp_dur_axs.append(ax_amp_dur)
                    rate_axs.append(ax_rate)

                else:
                    putil.add_mock_axes(fig, gsp[i])

        if plot_qm:

            # Match axis scales across tasks.
            putil.sync_axes(wf_axs, sync_x=True, sync_y=True)
            putil.sync_axes(amp_axs, sync_y=True)
            putil.sync_axes(dur_axs, sync_y=True)
            putil.sync_axes(amp_dur_axs, sync_x=True, sync_y=True)
            putil.sync_axes(rate_axs, sync_y=True)
            for ax in rate_axs:
                putil.move_event_lbls(ax, y_lbl=0.92)

            # Save figure.
            if ftempl is not None:
                uid_str = util.format_uid(uid)
                title = uid_str.replace('_', ' ')
                ffig = ftempl.format(uid_str)
                putil.save_fig(ffig, fig, title, w_pad=15)

    # Collect QC test results.
    QC_tests = pd.DataFrame(d_QC_tests).T

    return QC_tests