Esempio n. 1
0
def plot_DR_3x3(u, fig=None, sps=None):
    """Plot 3x3 direction response plot, with polar plot in center.

    The center cell of the 3x3 grid holds a polar direction tuning plot with
    one curve per stimulus; the eight surrounding cells hold raster/rate
    plots (via ``plot_SR``), one per direction in ``constants.all_dirs``.

    Parameters
    ----------
    u : unit object
        Unit to plot. Nothing is plotted if ``u.to_plot()`` is falsy.
    fig : Figure, optional
        Figure to draw into (resolved by ``putil.sps_fig``).
    sps : SubplotSpec, optional
        Region of the figure to embed the 3x3 grid into.

    Returns
    -------
    tuple or None
        ``(ax_polar, rate_axs)``: the polar axes and the list of all rate
        axes (each one listed once), or None if the unit is not plotted.
    """

    if not u.to_plot():
        return

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 3, 3)  # inner gsp with subplots

    # Polar plot (center cell, index 4 of the 3x3 grid).
    putil.set_style('notebook', 'white')
    ax_polar = fig.add_subplot(gsp[4], polar=True)
    baseline = u.get_baseline()  # takes no stimulus argument: fetch once
    for stim in constants.stim_dur.index:  # for each stimulus
        stim_resp = u.get_stim_resp_vals(stim, 'Dir')
        resp_stats = util.calc_stim_resp_stats(stim_resp)
        dirs, resp = np.array(resp_stats.index) * deg, resp_stats['mean']
        c = putil.stim_colors[stim]
        ptuning.plot_DR(dirs, resp, color=c, baseline=baseline, ax=ax_polar)
    putil.hide_ticks(ax_polar, 'y')

    # Raster-rate plots.
    putil.set_style('notebook', 'ticks')
    rr_pos = [5, 2, 1, 0, 3, 6, 7, 8]  # Position of each direction.
    rr_dir_plot_pos = pd.Series(constants.all_dirs, index=rr_pos)

    rate_axs = []
    # Series.items(): iteritems() was removed in pandas 2.0.
    for isp, d in rr_dir_plot_pos.items():

        # Tick labels are kept only on the plot at grid position 0
        # (y tick labels: on its first period only).
        first_dir = (isp == 0)

        # Plot direction response across trial periods.
        res = plot_SR(u, 'Dir', [d], fig=fig, sps=gsp[isp], no_labels=True)
        draster_axs, drate_axs, _ = res

        # Remove tick labels.
        for i, ax in enumerate(drate_axs):
            first_prd = (i == 0)
            show_x_tick_lbls = first_dir
            show_y_tick_lbls = first_dir and first_prd
            putil.hide_tick_labels(ax, show_x_tick_lbls, show_y_tick_lbls)

        # Add task name as title (to top center axes).
        if isp == 1:
            ttl = u.get_task() + (' [excluded]' if u.is_excluded() else '')
            putil.set_labels(draster_axs[0],
                             title=ttl,
                             ytitle=1.10,
                             title_kws={'loc': 'right'})

        # Collect each rate axes exactly once (it used to be added twice,
        # duplicating the returned axes and calling adjust_decorators twice
        # on each of them).
        rate_axs.extend(drate_axs)

    # Match scale of y axes.
    putil.sync_axes(rate_axs, sync_y=True)
    for ax in rate_axs:
        putil.adjust_decorators(ax)

    return ax_polar, rate_axs
Esempio n. 2
0
def DR_plot(UA, ftempl=None, match_scales=True):
    """Plot responses to all 8 directions and polar plot in the center.

    Makes one figure per unit of ``UA`` with one 3x3 direction response
    panel per task (``pselectivity.plot_DR_3x3``). Tasks in which the unit
    is not plotted get mock axes instead. If ``match_scales`` is set, the y
    axes of the polar plots and of the rate plots are synced across tasks.
    If ``ftempl`` is given, the figure is saved to
    ``ftempl.format(uid_str)``.
    """

    putil.set_style('notebook', 'ticks')

    for unit_id in UA.uids():
        unit_tasks = UA.tasks()
        fig, gsp, _ = putil.get_gs_subplots(nrow=1, ncol=len(unit_tasks),
                                            subw=subw, subh=subw)
        all_rate_axs = []
        all_polar_axs = []

        for task, task_sps in zip(unit_tasks, gsp):
            unit = UA.get_unit(uid=unit_id, task=task) if False else \
                UA.get_unit(unit_id, task)
            dr_res = None
            if unit.to_plot():
                dr_res = pselectivity.plot_DR_3x3(unit, fig, task_sps)

            if dr_res is None:
                # Unit not plotted in this task: keep an empty placeholder.
                putil.add_mock_axes(fig, task_sps)
                continue

            polar_ax, unit_rate_axs = dr_res
            all_rate_axs.extend(unit_rate_axs)
            all_polar_axs.append(polar_ax)

        if match_scales:
            putil.sync_axes(all_polar_axs, sync_y=True)
            putil.sync_axes(all_rate_axs, sync_y=True)
            for ax in all_rate_axs:
                putil.adjust_decorators(ax)

        if ftempl is None:
            continue
        uid_str = util.format_uid(unit_id)
        putil.save_fig(ftempl.format(uid_str), fig, uid_str.replace('_', ' '))
Esempio n. 3
0
def selectivity_summary(UA, ftempl=None, match_scales=True):
    """Test selectivity of unit responses.

    Makes one figure per unit of ``UA`` with one selectivity panel per task
    (``pselectivity.plot_selectivity``). Tasks in which the unit is not
    plotted get mock axes instead. If ``match_scales`` is set, the y axes of
    the first and of the second group of returned axes are synced across
    tasks. If ``ftempl`` is given, the figure is saved to
    ``ftempl.format(uid_str)``.
    """

    putil.set_style('notebook', 'ticks')

    for unit_id in UA.uids():
        unit_tasks = UA.tasks()
        fig, gsp, _ = putil.get_gs_subplots(nrow=1, ncol=len(unit_tasks),
                                            subw=8, subh=16)
        ls_axs = []
        ds_axs = []

        for task, task_sps in zip(unit_tasks, gsp):
            unit = UA.get_unit(unit_id, task)
            sel_res = None
            if unit.to_plot():
                sel_res = pselectivity.plot_selectivity(unit, fig, task_sps)

            if sel_res is None:
                # Unit not plotted in this task: keep an empty placeholder.
                putil.add_mock_axes(fig, task_sps)
                continue

            ls_axs.extend(sel_res[0])
            ds_axs.extend(sel_res[1])

        if match_scales:
            for axs_group in (ls_axs, ds_axs):
                putil.sync_axes(axs_group, sync_y=True)
                for ax in axs_group:
                    putil.adjust_decorators(ax)

        if ftempl is None:
            continue
        uid_str = util.format_uid(unit_id)
        putil.save_fig(ftempl.format(uid_str), fig, uid_str.replace('_', ' '))
Esempio n. 4
0
def plot_CE_time_distribution(ulists,
                              eff_t_res,
                              eff_pars,
                              aroc_res_dir,
                              bins=None):
    """Plot distribution of ROC comparison effect timing across groups.

    One subplot is drawn per effect in ``eff_pars``. Each subplot shows one
    timing distribution per unit group in ``ulists``.

    Parameters
    ----------
    ulists : iterable
        Unit group names. Each is used as a ``.loc`` key into the effect
        times and in the legend labels.
    eff_t_res : DataFrame
        Effect timing results with ``effect_dir`` and ``time`` columns.
    eff_pars : sequence of (eff_dir, eff_lbl) pairs
        Effect direction to select rows by, and the subplot title.
    aroc_res_dir : str
        Results directory prefix. The figure is saved to
        ``aroc_res_dir + 'CE/CE_timing_distributions.png'``.
    bins : array-like, optional
        Histogram bins. Defaults to ``np.arange(2000, 2600, 50)``.
    """

    # Init.
    putil.set_style('notebook', 'white')
    if bins is None:
        bins = np.arange(2000, 2600, 50)
    fig, _, axs = putil.get_gs_subplots(nrow=1,
                                        ncol=len(eff_pars),
                                        subw=5,
                                        subh=4,
                                        create_axes=True,
                                        as_array=False)

    # Plot CE timing distribution for each unit group.
    for (eff_dir, eff_lbl), ax in zip(eff_pars, axs):
        etd = eff_t_res.loc[eff_t_res.effect_dir == eff_dir, 'time']
        for nlist in ulists:
            tvals = etd.loc[nlist]
            lbl = '{} (n={})'.format(nlist, len(tvals))
            sns.distplot(tvals, bins, label=lbl, ax=ax)
        putil.set_labels(ax, 'effect timing (ms since S1 onset)', '', eff_lbl)

    # Format plots: despine every axes. This used to be applied only to the
    # last axes left over from the loop above, and failed when it was empty.
    for ax in axs:
        sns.despine(ax=ax)
        ax.legend()
        putil.hide_tick_labels(ax, show_x_tick_lbls=True)
    putil.sync_axes(axs, sync_y=True)

    # Save plots.
    ffig = aroc_res_dir + 'CE/CE_timing_distributions.png'
    putil.save_fig(ffig, fig)
Esempio n. 5
0
def plot_SR(u,
            param=None,
            vals=None,
            from_trs=None,
            prd_pars=None,
            nrate=None,
            colors=None,
            add_roc=False,
            add_prd_name=False,
            fig=None,
            sps=None,
            title=None,
            **kwargs):
    """Plot stimulus response (raster, rate and ROC) for multiple stimuli.

    One column is drawn per period in ``prd_pars``. Each column has raster
    and rate plots. If ``add_roc`` is set and exactly two trial sets are
    selected, an ROC (AUC over time) plot is added below them.

    Parameters
    ----------
    u : unit object
        Unit to plot.
    param : optional
        Trial parameter to select trials by.
        - None: use ``u.ser_inc_trials()``.
        - A column of ``u.TrData``: used as is.
        - Anything else: used as the key ``(period stimulus, param)``.
    vals : optional
        Values of ``param`` to select trials by. If any value is NaN,
        nothing is plotted.
    from_trs : optional
        Trials to restrict the selection to (``util.filter_lists``).
    prd_pars : DataFrame, optional
        Period parameters, one row per period. Defaults to
        ``u.get_analysis_prds()``.
    nrate : optional
        Passed on to ``prate.plot_rr`` / ``prate.prep_rr_plot_params``.
    colors : list, optional
        Colors of the trial sets. Defaults to colors drawn from
        ``putil.get_colors()``.
    add_roc : bool
        Add ROC plots when exactly two trial sets are plotted.
    add_prd_name : bool
        Write the period name in the top-left corner of each rate plot.
    fig, sps : optional
        Figure and subplot spec to draw into.
    title : str, optional
        Title placed above the first raster axes.
    **kwargs
        Passed on to ``prate.plot_rr``. A truthy ``no_labels`` also
        suppresses the common x axis label.

    Returns
    -------
    tuple of lists
        ``(axes_raster, axes_rate, axes_roc)``. All three are empty if the
        unit is not to be plotted or ``vals`` contains NaN.
    """

    has_nan_val = (util.is_iterable(vals)
                   and any(np.isnan(float(v)) for v in vals))
    if not u.to_plot() or has_nan_val:
        return [], [], []

    # Set up stimulus parameters.
    if prd_pars is None:
        prd_pars = u.get_analysis_prds()

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)

    # Create a gridspec for each period, width set by its (capped) duration.
    wratio = [
        float(min(dur, lmax))
        for dur, lmax in zip(prd_pars.dur, prd_pars.max_len)
    ]
    wspace = 0.0  # 0.05 to separate periods
    gsp = putil.embed_gsp(sps,
                          1,
                          len(prd_pars.index),
                          width_ratios=wratio,
                          wspace=wspace)

    axes_raster, axes_rate, axes_roc = [], [], []

    for i, prd in enumerate(prd_pars.index):

        ppars = prd_pars.loc[prd]
        ref = ppars.ref

        # Prepare trial set.
        if param is None:
            trs = u.ser_inc_trials()
        elif param in u.TrData.columns:
            trs = u.trials_by_param(param, vals)
        else:
            trs = u.trials_by_param((ppars.stim, param), vals)

        if from_trs is not None:
            trs = util.filter_lists(trs, from_trs)

        plot_roc = add_roc and (len(trs) == 2)

        # Init subplots.
        if plot_roc:
            rr_sps, roc_sps = putil.embed_gsp(gsp[i],
                                              2,
                                              1,
                                              hspace=0.2,
                                              height_ratios=[1, .4])
        else:
            rr_sps = putil.embed_gsp(gsp[i], 1, 1, hspace=0.3)[0]

        # Init params.
        prds = [constants.ev_stims.loc[ref]]
        evnts = None
        if ('cue' in ppars) and (ppars.cue is not None):
            evnts = [{'time': ppars.cue}]
            evnts[0]['color'] = (ppars.cue_color if 'cue_color' in ppars.index
                                 else putil.cue_colors['all'])

        # Colors are chosen on the first period and reused for later ones.
        if colors is None:
            colcyc = putil.get_colors()
            colors = [next(colcyc) for _ in range(len(trs))]

        # Plot response on raster and rate plots.
        _, raster_axs, rate_ax = prate.plot_rr(u,
                                               prd,
                                               ref,
                                               prds,
                                               evnts,
                                               nrate,
                                               trs,
                                               ppars.max_len,
                                               cols=colors,
                                               fig=fig,
                                               sps=rr_sps,
                                               **kwargs)

        # Add period name to rate plot.
        if add_prd_name:
            rate_ax.text(0.02,
                         0.95,
                         prd,
                         fontsize=10,
                         color='k',
                         va='top',
                         ha='left',
                         transform=rate_ax.transAxes)

        # Plot ROC curve.
        if plot_roc:
            roc_ax = fig.add_subplot(roc_sps)
            # Init rates.
            plot_params = prate.prep_rr_plot_params(u, prd, ref, nrate, trs)
            _, _, (rates1, rates2), _, _ = plot_params
            # Calculate ROC results.
            aroc = roccore.run_ROC_over_time(rates1, rates2, n_perm=0)
            # Set up plot params and plot results.
            tvec, auc = aroc.index, aroc.auc
            xlim = rate_ax.get_xlim()
            pauc.plot_auc_over_time(auc,
                                    tvec,
                                    prds,
                                    evnts,
                                    xlim=xlim,
                                    ax=roc_ax)

        # Some exception handling: set x axis range to predefined values if
        # it is not set (probably due to no trial data plotted).
        if prd in constants.fixed_tr_prds.index:
            t1, t2 = constants.fixed_tr_prds.loc[prd]
            axs = raster_axs + [rate_ax] + ([roc_ax] if plot_roc else [])
            for ax in axs:
                if ax.get_xlim() == (-0.001, 0.001):
                    ax.set_xlim([t1, t2])

        # Collect axes.
        axes_raster.extend(raster_axs)
        axes_rate.append(rate_ax)
        if plot_roc:
            axes_roc.append(roc_ax)

    # Format rate and roc plots to make them match across trial periods.

    # Remove y-axis label, spine and ticks from second and later periods.
    for ax in axes_rate[1:] + axes_roc[1:]:
        ax.set_ylabel('')
        putil.set_spines(ax, bottom=True, left=False)
        putil.hide_ticks(ax, show_x_ticks=True, show_y_ticks=False)

    # Remove (hide) legend from all but last rate plot.
    for ax in axes_rate[:-1]:
        putil.hide_legend(ax)

    # Get middle x point in relative coordinates of first rate axes.
    xranges = np.array([ax.get_xlim() for ax in axes_rate])
    xlens = xranges[:, 1] - xranges[:, 0]
    xmid = xlens.sum() / xlens[0] / 2

    # Add title.
    if title is not None:
        axes_raster[0].set_title(title, x=xmid, y=1.0)

    # Set common x label.
    if not kwargs.get('no_labels', False):
        for ax in axes_rate + axes_roc:
            ax.set_xlabel('')
        # .iloc: positional access; Series[0] on a label index is deprecated.
        xlab = putil.t_lbl.format(prd_pars.stim.iloc[0] + ' onset')
        ax = axes_roc[0] if axes_roc else axes_rate[0]
        ax.set_xlabel(xlab, x=xmid)

    # Reformat ticks on x axis. First period is reference.
    for axes in [axes_rate, axes_roc]:
        for ax, lbl_shift in zip(axes, prd_pars.lbl_shift):
            x1, x2 = ax.get_xlim()
            shift = float(lbl_shift)
            tmarks, tlbls = putil.get_tick_marks_and_labels(
                x1 + shift, x2 + shift)
            putil.set_xtick_labels(ax, tmarks - shift, tlbls)
    # Remove x tick labels from rate axes if roc's are present. A bool is
    # required: the string 'off' is truthy and no longer hides the labels.
    if axes_roc:
        for ax in axes_rate:
            ax.tick_params(labelbottom=False)

    # Match scale of rate and roc plots' y axes.
    for axes in [axes_rate, axes_roc]:
        putil.sync_axes(axes, sync_y=True)
        for ax in axes:
            putil.adjust_decorators(ax)

    return axes_raster, axes_rate, axes_roc
Esempio n. 6
0
def quality_test(UA, ftempl=None, plot_qm=False, fselection=None):
    """Test and plot quality metrics of recording and spike sorting.

    For every unit of ``UA`` in every task:
    - Read the unit/trial selection parameters (from the Excel file
      ``fselection``, if given).
    - Run ``test_sorting.test_qm``.
    - Collect its ``'QC_tests'`` entry.

    If ``plot_qm`` is set, one figure per unit is drawn with
    ``pquality.plot_qm`` per task, and axis scales are matched across tasks.
    The figure is saved to ``ftempl.format(uid_str)`` when ``ftempl`` is
    given.

    Returns
    -------
    DataFrame
        QC test results, one row per ``uid + (task,)`` key.
    """

    putil.set_style('notebook', 'ticks')

    # Optional unit & trial selection table.
    UnTrSel = None
    if fselection is not None:
        UnTrSel = pd.read_excel(fselection)

    d_QC_tests = {}
    for uid in UA.uids():

        if plot_qm:
            fig, gsp, _ = putil.get_gs_subplots(nrow=1,
                                                ncol=len(UA.tasks()),
                                                subw=subw,
                                                subh=1.6 * subw)
            (wf_axs, amp_axs, dur_axs,
             amp_dur_axs, rate_axs) = ([] for _ in range(5))

        for itask, task in enumerate(UA.tasks()):

            unit = UA.get_unit(uid, task)
            include, first_tr, last_tr = get_selection_params(unit, UnTrSel)
            qm_res = test_sorting.test_qm(unit, include, first_tr, last_tr)

            if qm_res is not None:
                d_QC_tests[uid + (task, )] = qm_res.pop('QC_tests')

            if not plot_qm:
                continue

            task_sps = gsp[itask]
            if qm_res is None:
                putil.add_mock_axes(fig, task_sps)
                continue

            (ax_wfs, ax_wf_amp, ax_wf_dur,
             ax_amp_dur, ax_rate) = pquality.plot_qm(unit, fig=fig,
                                                     sps=task_sps, **qm_res)
            wf_axs.extend(ax_wfs)
            amp_axs.append(ax_wf_amp)
            dur_axs.append(ax_wf_dur)
            amp_dur_axs.append(ax_amp_dur)
            rate_axs.append(ax_rate)

        if not plot_qm:
            continue

        # Match axis scales across tasks.
        putil.sync_axes(wf_axs, sync_x=True, sync_y=True)
        putil.sync_axes(amp_axs, sync_y=True)
        putil.sync_axes(dur_axs, sync_y=True)
        putil.sync_axes(amp_dur_axs, sync_x=True, sync_y=True)
        putil.sync_axes(rate_axs, sync_y=True)
        for ax in rate_axs:
            putil.move_event_lbls(ax, y_lbl=0.92)

        if ftempl is not None:
            uid_str = util.format_uid(uid)
            putil.save_fig(ftempl.format(uid_str), fig,
                           uid_str.replace('_', ' '), w_pad=15)

    return pd.DataFrame(d_QC_tests).T
Esempio n. 7
0
def plot_combined_rec_mean(recs, stims, res_dir, par_kws,
                           list_n_most_DS, list_min_nunits,
                           n_boot=1e4, ci=95,
                           tasks=None, task_labels=None, add_title=True,
                           fig=None):
    """Test and plot results combined across sessions.

    Plots mean decoding accuracy across recordings as a time series with
    bootstrapped confidence intervals (``sns.tsplot``). There is one subplot
    per combination of ``n_most_DS`` (rows) and ``min_nunits`` (columns).
    The figure is then saved.

    Parameters
    ----------
    recs : not used in this function.
    stims : list
        Stimuli whose fixed trial periods are marked on every subplot.
    res_dir : str
        Results directory, passed to the ``decutil`` helpers.
    par_kws : dict
        Extra parameters for the ``decutil`` helpers. A copy is used, so the
        caller's dict is not modified.
    list_n_most_DS : list
        Numbers of most DS units to plot results for (0: all units).
    list_min_nunits : list
        Minimum number of units a recording needs to be included.
    n_boot, ci :
        Number of bootstrap samples and confidence interval (%).
    tasks : list, optional
        Tasks to plot. Defaults to the tasks present in each subplot's data.
        Tasks missing from a subplot's data are skipped in that subplot.
    task_labels : dict, optional
        Task name -> legend label. Defaults to the task names.
    add_title : bool
        Add a figure title built by ``decutil.fig_title``.
    fig : Figure, optional
        Figure to draw into.

    Returns
    -------
    fig_scr, axs_scr, ffig
        Figure, array of axes, and the file name the figure was saved to.
    """

    # Init.
    vkey = 'all'

    # Work on a copy: 'n_most_DS' is set below for the figure title and file
    # name, and must not leak into the caller's dict.
    par_kws = dict(par_kws)

    # Periods to mark: [stim, t1, t2] from constants.fixed_tr_prds.
    # This should be made more explicit!
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    # Load all results to plot.
    dict_rt_res = decutil.load_res(res_dir, list_n_most_DS, **par_kws)

    # Create figures.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(dict_rt_res),
                                                ncol=len(list_min_nunits),
                                                subw=8, subh=6, fig=fig,
                                                create_axes=True)

    # Query data, separately for each n_most_DS setting.
    allScores, allnunits, allnvals = {}, {}, {}
    for n_most_DS, rt_res in dict_rt_res.items():
        # Keep (rec, task) results that have a non-None vkey entry.
        vres = {(rec, task): res[vkey]
                for (rec, task), res in rt_res.items()
                if (vkey in res) and (res[vkey] is not None)}
        # Get accuracy scores.
        dScores = {rt: r['Scores'].mean() for rt, r in vres.items()}
        allScores[n_most_DS] = pd.concat(dScores, axis=1).T
        # Get number of units.
        allnunits[n_most_DS] = {rt: r['nunits'].iloc[0]
                                for rt, r in vres.items()}
        # Get # of classes (for chance level). Kept per setting instead of
        # only from the last setting loaded.
        all_nvals = pd.Series({rt: r['nclasses'].iloc[0]
                               for rt, r in vres.items()})
        un_nvals = all_nvals.unique()
        if len(un_nvals) > 1 and verbose:
            print('Found multiple # of classes to decode: {}'.format(un_nvals))
        allnvals[n_most_DS] = un_nvals[0]

    allnunits = pd.DataFrame(allnunits)

    # Plot mean performance across recordings and
    # test significance by bootstrapping.
    for inmost, n_most_DS in enumerate(list_n_most_DS):
        Scores = allScores[n_most_DS]
        nunits = allnunits[n_most_DS]
        nvals = allnvals[n_most_DS]

        for iminu, min_nunits in enumerate(list_min_nunits):

            ax_scr = axs_scr[inmost, iminu]

            # Select only recordings with minimum number of units.
            sel_rt = nunits.index[nunits >= min_nunits]
            nScores = Scores.loc[sel_rt].copy()

            # Nothing to plot.
            if nScores.empty:
                ax_scr.axis('off')
                continue

            # Resolve tasks and labels per subplot. The defaults used to be
            # stored into `tasks` / `task_labels` on the first subplot and
            # reused afterwards, which failed when a task dropped out of
            # the selection.
            data_tasks = nScores.index.get_level_values(1).unique()
            if tasks is None:
                plot_tasks = list(data_tasks)
            else:
                plot_tasks = [task for task in tasks if task in data_tasks]
            if not plot_tasks:
                ax_scr.axis('off')
                continue
            plot_labels = (task_labels if task_labels is not None
                           else {task: task for task in plot_tasks})

            # Prepare data in long format: one row per (task, time, rec).
            dScores = {task: pd.DataFrame(nScores.xs(task, level=1).unstack(),
                                          columns=['accuracy'])
                       for task in plot_tasks}
            lScores = pd.concat(dScores, axis=0)
            lScores['time'] = lScores.index.get_level_values(1)
            lScores['task'] = lScores.index.get_level_values(0)
            lScores['rec'] = lScores.index.get_level_values(2)
            lScores.index = np.arange(len(lScores.index))
            # Assign back: chained inplace replace is not guaranteed to
            # modify lScores (pandas Copy-on-Write).
            lScores['task'] = lScores['task'].replace(plot_labels)

            # Add altered task names for legend plotting.
            nrecs = {plot_labels[task]: len(nScores.xs(task, level=1))
                     for task in plot_tasks}
            lScores['task_nrecs'] = lScores['task'].apply(
                lambda lbl: '{} (n={})'.format(lbl, nrecs[lbl]))

            # Plot as time series.
            sns.tsplot(lScores, time='time', value='accuracy', unit='rec',
                       condition='task_nrecs', ci=ci, n_boot=n_boot, ax=ax_scr)

            # Add chance level line.
            chance_lvl = 1.0 / nvals
            putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Add stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Set axis limits.
            putil.set_limits(ax_scr, tlim)

            # Format plot.
            title = ('{} most DS units'.format(n_most_DS)
                     if n_most_DS != 0 else 'all units')
            title += (', recordings with at least {} units'.format(min_nunits)
                      if (min_nunits > 1 and len(list_min_nunits) > 1) else '')
            ytitle = 1.0
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)
            putil.hide_legend_title(ax_scr)

    # Match axes across decoding plots (within each row).
    for irow in range(axs_scr.shape[0]):
        putil.sync_axes(axs_scr[irow, :], sync_y=True)

    # Save plots.
    list_n_most_DS_str = [str(i) if i != 0 else 'all' for i in list_n_most_DS]
    par_kws['n_most_DS'] = ', '.join(list_n_most_DS_str)
    title = ''
    if add_title:
        title = decutil.fig_title(res_dir, **par_kws)
        title += '\n{}% CE with {} bootstrapped subsamples'.format(ci,
                                                                   int(n_boot))
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    par_kws['n_most_DS'] = '_'.join(list_n_most_DS_str)
    ffig = decutil.fig_fname(res_dir, 'combined_score', fformat, **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)

    return fig_scr, axs_scr, ffig