Exemplo n.º 1
0
def plot_DR_3x3(u, fig=None, sps=None):
    """
    Plot 3x3 direction response plot, with polar plot in center.

    The center cell of the inner 3x3 grid holds a polar direction tuning plot
    with one curve per stimulus of `constants.stim_dur.index`. Each of the
    eight surrounding cells holds the raster & rate plot (`plot_SR`) of one
    direction of `constants.all_dirs`.

    Parameters
    ----------
    u : unit object
        Unit to plot. Nothing is drawn if `u.to_plot()` is falsy.
    fig : figure, optional
    sps : subplot spec, optional
        Figure and subplot spec to draw into (resolved by `putil.sps_fig`).

    Returns
    -------
    tuple or None
        (ax_polar, rate_axs): polar axes and list of rate axes of all
        directions (each axes listed once). None if unit is not to be plotted.
    """

    if not u.to_plot():
        return None

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 3, 3)  # inner gsp with subplots

    # Polar plot in the center cell (index 4 of the 3x3 grid).
    putil.set_style('notebook', 'white')
    ax_polar = fig.add_subplot(gsp[4], polar=True)
    for stim in constants.stim_dur.index:  # for each stimulus
        stim_resp = u.get_stim_resp_vals(stim, 'Dir')
        resp_stats = util.calc_stim_resp_stats(stim_resp)
        dirs, resp = np.array(resp_stats.index) * deg, resp_stats['mean']
        c = putil.stim_colors[stim]
        baseline = u.get_baseline()
        ptuning.plot_DR(dirs, resp, color=c, baseline=baseline, ax=ax_polar)
    putil.hide_ticks(ax_polar, 'y')

    # Raster-rate plots.
    putil.set_style('notebook', 'ticks')
    # Grid cell index of each direction of constants.all_dirs (cell 4, the
    # center, is taken by the polar plot).
    rr_pos = [5, 2, 1, 0, 3, 6, 7, 8]
    rr_dir_plot_pos = pd.Series(constants.all_dirs, index=rr_pos)

    rate_axs = []
    # Series.items() works on all pandas versions (iteritems was removed).
    for isp, d in rr_dir_plot_pos.items():

        # Tick labels are only shown on the panel in cell 0.
        first_dir = (isp == 0)

        # Plot direction response across trial periods.
        res = plot_SR(u, 'Dir', [d], fig=fig, sps=gsp[isp], no_labels=True)
        draster_axs, drate_axs, _ = res

        # Hide tick labels, except x labels of cell 0 and y labels of the
        # first period's axes in cell 0.
        for i, ax in enumerate(drate_axs):
            first_prd = (i == 0)
            show_x_tick_lbls = first_dir
            show_y_tick_lbls = first_dir & first_prd
            putil.hide_tick_labels(ax, show_x_tick_lbls, show_y_tick_lbls)

        # Add task name as title (to top center axes, cell 1).
        if isp == 1:
            ttl = u.get_task() + (' [excluded]' if u.is_excluded() else '')
            putil.set_labels(draster_axs[0],
                             title=ttl,
                             ytitle=1.10,
                             title_kws={'loc': 'right'})

        # Collect rate axes of this direction (once per direction).
        rate_axs.extend(drate_axs)

    # Match scale of y axes.
    putil.sync_axes(rate_axs, sync_y=True)
    for ax in rate_axs:
        putil.adjust_decorators(ax)

    return ax_polar, rate_axs
Exemplo n.º 2
0
def plot_RF_results(RF_res, stims, fdir, sup_title):
    """
    Plot receptive field results.

    Saves two kinds of figures:
      - a histogram of S1 coverage values, to
        `fdir + <sup_title as fname> + '_S1_coverage.png'`;
      - for each of 'mean_rate', 'max_rate' and 'mDSI', a grid of regression
        plots of coverage vs. that variable, with one row per stimulus and
        one column per task, to `util.join([fdir, vname, fname])`.

    Parameters
    ----------
    RF_res : DataFrame
        RF results. Task is taken from the last index level. Columns named
        '<stim>_cover' and '<stim>_<vname>' are read for each stim in `stims`
        (plus 'S1_cover' for the histogram).
    stims : list of str
        Stimulus names (row of the regression grids).
    fdir : str
        Output folder.
    sup_title : str
        Figure super-title, also used (formatted) in file names.
    """

    fst = util.format_to_fname(sup_title)

    # Plot distribution of coverage values.
    putil.set_style('poster', 'white')
    fig = putil.figure()
    ax = putil.axes()
    sns.distplot(RF_res.S1_cover, bins=np.arange(0, 1.01, 0.1),
                 kde=False, rug=True, ax=ax)
    putil.set_limits(ax, [0, 1])
    ffig = fdir + fst + '_S1_coverage.png'
    putil.save_fig(ffig, fig, sup_title)

    # Plot RF coverage against rate / DSI on regression plot for each
    # stimulus and task.
    tasks = RF_res.index.get_level_values(-1).unique()
    colors = sns.color_palette('muted', len(tasks))
    xlim = [0, 1]
    for vname, ylim in [('mean_rate', [0, None]), ('max_rate', [0, None]),
                        ('mDSI', [0, 1])]:
        fig, _, axes = putil.get_gs_subplots(nrow=len(stims), ncol=len(tasks),
                                             subw=4, subh=4, ax_kws_list=None,
                                             create_axes=True)
        for istim, stim in enumerate(stims):
            for itask, task in enumerate(tasks):
                # Plot regression plot. Keyword args are required by
                # seaborn >= 0.12 (x, y and data are keyword-only there).
                ax = axes[istim, itask]
                scov, sval = [stim + '_' + name for name in ('cover', vname)]
                df = RF_res.xs(task, level=-1)
                sns.regplot(x=scov, y=sval, data=df, color=colors[itask],
                            ax=ax)
                # Add unit labels.
                uids = df.index.droplevel(0)
                putil.add_unit_labels(ax, uids, df[scov], df[sval])
                # Add stats.
                r, p = sp.stats.pearsonr(df[sval], df[scov])
                pstr = util.format_pvalue(p)
                txt = 'r = {:.2f}, {}'.format(r, pstr)
                ax.text(0.02, 0.98, txt, va='top', ha='left',
                        transform=ax.transAxes)
                # Set labels.
                title = '{} {}'.format(task, stim)
                xlab, ylab = [sn.replace('_', ' ') for sn in (scov, sval)]
                putil.set_labels(ax, xlab, ylab, title)
                # Set limits.
                putil.set_limits(ax, xlim, ylim)

        # Save plot.
        fname = '{}_cover_{}.png'.format(fst, vname)
        ffig = util.join([fdir, vname, fname])
        putil.save_fig(ffig, fig, sup_title)
Exemplo n.º 3
0
def plot_spike_count_results(bspk_cnts, rec, task, prds, binsize):
    """
    Draw spike count results of several periods into one composite figure.

    Each period of `prds` gets its own row, drawn by `plot_prd_spike_count`
    from `bspk_cnts[prd]`. The figure is saved under results/UpDown/, with
    recording, task and (integer) bin size in both title and file name.
    """

    putil.set_style('notebook', 'ticks')
    fig, gsp, _ = putil.get_gs_subplots(nrow=len(prds), ncol=1, subw=15,
                                        subh=4, create_axes=False)

    # One grid row per period.
    for period, period_sps in zip(prds, gsp):
        plot_prd_spike_count(bspk_cnts[period], period, period_sps, fig)

    bin_ms = int(binsize)
    fig_title = '{} {}, binsize: {} ms'.format(rec, task, bin_ms)
    fig_name = 'UpDown_spk_cnt_hist_{}_{}_bin_{}.png'.format(rec, task,
                                                             bin_ms)
    putil.save_fig(util.join(['results', 'UpDown', fig_name]), fig, fig_title)
Exemplo n.º 4
0
def DR_plot(UA, ftempl=None, match_scales=True):
    """
    Plot direction response figure of each unit across all tasks.

    For every unit of `UA`, one figure is created with a column per task.
    Each column is filled by `pselectivity.plot_DR_3x3`, or with mock axes
    when the unit is not to be plotted (or nothing was drawn) in that task.

    Parameters
    ----------
    UA : unit array
    ftempl : str or None
        File name template, formatted with the unit ID string. Figures are
        not saved when None.
    match_scales : bool
        Sync y axes of polar and of rate plots across tasks if True.
    """

    putil.set_style('notebook', 'ticks')

    for uid in UA.uids():
        fig, gsp, _ = putil.get_gs_subplots(nrow=1, ncol=len(UA.tasks()),
                                            subw=subw, subh=subw)
        rate_axes, polar_axes = [], []

        for task, task_sps in zip(UA.tasks(), gsp):
            unit = UA.get_unit(uid, task)

            drawn = None
            if unit.to_plot():
                drawn = pselectivity.plot_DR_3x3(unit, fig, task_sps)

            # Keep grid layout intact when nothing was drawn.
            if drawn is None:
                putil.add_mock_axes(fig, task_sps)
                continue

            polar_ax, task_rate_axes = drawn
            rate_axes.extend(task_rate_axes)
            polar_axes.append(polar_ax)

        if match_scales:
            putil.sync_axes(polar_axes, sync_y=True)
            putil.sync_axes(rate_axes, sync_y=True)
            for ax in rate_axes:
                putil.adjust_decorators(ax)

        if ftempl is None:
            continue
        uid_str = util.format_uid(uid)
        fig_title = uid_str.replace('_', ' ')
        putil.save_fig(ftempl.format(uid_str), fig, fig_title)
Exemplo n.º 5
0
def selectivity_summary(UA, ftempl=None, match_scales=True):
    """
    Plot stimulus selectivity summary of each unit across all tasks.

    For every unit of `UA`, one figure is created with a column per task.
    Each column is filled by `pselectivity.plot_selectivity`, or with mock
    axes when the unit is not to be plotted (or nothing was drawn).

    Parameters
    ----------
    UA : unit array
    ftempl : str or None
        File name template, formatted with the unit ID string. Figures are
        not saved when None.
    match_scales : bool
        If True, sync y axes of each of the two returned groups of rate axes
        across tasks.
    """

    putil.set_style('notebook', 'ticks')

    for uid in UA.uids():
        fig, gsp, _ = putil.get_gs_subplots(nrow=1, ncol=len(UA.tasks()),
                                            subw=8, subh=16)
        # First and second group of axes returned by plot_selectivity.
        ls_rate_axs, ds_rate_axs = [], []

        for task, task_sps in zip(UA.tasks(), gsp):
            unit = UA.get_unit(uid, task)

            summary = None
            if unit.to_plot():
                summary = pselectivity.plot_selectivity(unit, fig, task_sps)

            if summary is None:
                putil.add_mock_axes(fig, task_sps)
                continue

            ls_rate_axs.extend(summary[0])
            ds_rate_axs.extend(summary[1])

        if match_scales:
            for axs_group in (ls_rate_axs, ds_rate_axs):
                putil.sync_axes(axs_group, sync_y=True)
                for ax in axs_group:
                    putil.adjust_decorators(ax)

        if ftempl is None:
            continue
        uid_str = util.format_uid(uid)
        fig_title = uid_str.replace('_', ' ')
        putil.save_fig(ftempl.format(uid_str), fig, fig_title)
Exemplo n.º 6
0
def plot_CE_time_distribution(ulists,
                              eff_t_res,
                              eff_pars,
                              aroc_res_dir,
                              bins=None):
    """
    Plot distribution of ROC comparison effect timing across groups.

    Parameters
    ----------
    ulists : iterable
        Unit group names; each is used as a label to select rows of the
        effect timing table (`.loc[nlist]`).
    eff_t_res : DataFrame
        Effect timing results with (at least) 'effect_dir' and 'time'
        columns.
    eff_pars : list of (eff_dir, eff_lbl) pairs
        One subplot per pair: effect direction to select and subplot title.
    aroc_res_dir : str
        Results folder; figure is saved to
        `aroc_res_dir + 'CE/CE_timing_distributions.png'`.
    bins : array-like, optional
        Histogram bin edges; default 2000 to 2550 ms in 50 ms steps.
    """

    # Init.
    putil.set_style('notebook', 'white')
    if bins is None:
        bins = np.arange(2000, 2600, 50)
    fig, _, axs = putil.get_gs_subplots(nrow=1,
                                        ncol=len(eff_pars),
                                        subw=5,
                                        subh=4,
                                        create_axes=True,
                                        as_array=False)

    # Plot CE timing distribution for each unit group.
    for (eff_dir, eff_lbl), ax in zip(eff_pars, axs):
        etd = eff_t_res.loc[eff_t_res.effect_dir == eff_dir, 'time']
        for nlist in ulists:
            tvals = etd.loc[nlist]
            lbl = '{} (n={})'.format(nlist, len(tvals))
            sns.distplot(tvals, bins=bins, label=lbl, ax=ax)
        putil.set_labels(ax, 'effect timing (ms since S1 onset)', '', eff_lbl)

    # Format plots: every axes is despined, not only the last one.
    for ax in axs:
        sns.despine(ax=ax)
        ax.legend()
        putil.hide_tick_labels(ax, show_x_tick_lbls=True)
    putil.sync_axes(axs, sync_y=True)

    # Save plot.
    ffig = aroc_res_dir + 'CE/CE_timing_distributions.png'
    putil.save_fig(ffig, fig)
Exemplo n.º 7
0
def plot_trial_type_distribution(UA, RecInfo, utids=None, tr_par=('S1', 'Dir'),
                                 save_plot=False, fname=None):
    """
    Plot distribution of trial types.

    Draws one count plot per recording (row) and task (column), showing the
    number of included trials per value of trial parameter `tr_par`, for
    all, correct and error trials.

    Parameters
    ----------
    UA : unit array
    RecInfo : DataFrame
        Indexed by (subj, date, task) tuples, with a 'trials' column holding
        the included trials of each recording & task.
    utids : Series, optional
        Unit & task IDs; defaults to all of `UA`.
    tr_par : column key
        Column of the unit's trial data (`TrData`) to count trials by.
    save_plot : bool
        Save figure if True.
    fname : str, optional
        File to save figure to; defaults to
        results/decoding/prepare/<tr_par>_trial_type_distr.pdf.
    """

    # Init.
    par_str = util.format_to_fname(str(tr_par))
    if utids is None:
        utids = UA.utids(as_series=True)
    recs = util.get_subj_date_pairs(utids)
    tasks = RecInfo.index.get_level_values('task').unique()
    tasks = [task for task in UA.tasks() if task in tasks]  # reorder tasks

    # Init plotting.
    putil.set_style('notebook', 'darkgrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(recs), ncol=len(tasks),
                                          subw=4, subh=3, create_axes=True)

    for ir, rec in enumerate(recs):
        for it, task in enumerate(tasks):

            ax = axs[ir, it]
            rt = rec + (task,)
            if rt not in RecInfo.index:
                ax.set_axis_off()
                continue

            # Get included trials and their parameters (from the first unit
            # of the recording & task).
            inc_trs = RecInfo.loc[rt, 'trials']
            utid = utids.xs(rt, level=('subj', 'date', 'task')).iloc[0]
            TrData = UA.get_unit(utid[:-1], utid[-1]).TrData.loc[inc_trs]

            # Nothing to plot without any included trial.
            if not TrData.size:
                ax.set_axis_off()
                continue

            # Create DF to plot: each trial once labelled by its answer
            # and once as 'all'.
            anw_df = TrData[[tr_par, 'correct']].copy()
            anw_df['answer'] = 'error'
            anw_df.loc[anw_df.correct, 'answer'] = 'correct'
            all_df = anw_df.copy()
            all_df.answer = 'all'
            comb_df = pd.concat([anw_df, all_df])

            # Plot as countplot.
            sns.countplot(x=tr_par, hue='answer', data=comb_df,
                          hue_order=['all', 'correct', 'error'], ax=ax)
            sns.despine(ax=ax)
            putil.hide_tick_marks(ax)
            putil.set_max_n_ticks(ax, 6, 'y')
            ax.legend(loc=[0.95, 0.7])

            # Add title.
            title = '{} {}'.format(' '.join(rec), task)
            nce = anw_df.answer.value_counts()
            nc, ne = [nce[c] if c in nce else 0 for c in ('correct', 'error')]
            pnc, pne = 100*nc/nce.sum(), 100*ne/nce.sum()
            title += '\n\n# correct: {} ({:.0f}%)'.format(nc, pnc)
            title += '      # error: {} ({:.0f}%)'.format(ne, pne)
            putil.set_labels(ax, title=title, xlab=par_str)

            # Format legend: keep it only on the top left plot.
            if (ir != 0) or (it != 0):
                ax.legend_.remove()

    # Save plot.
    if save_plot:
        title = 'Trial type distribution'
        if fname is None:
            fname = util.join(['results', 'decoding', 'prepare',
                               par_str + '_trial_type_distr.pdf'])

        putil.save_fig(fname, fig, title, w_pad=3, h_pad=3)
Exemplo n.º 8
0
def PD_across_units(UA, UInc, utids=None, fres=None, ffig=None):
    """
    Test consistency/spread of PD across units per recording.
    What is the spread in the preferred directions across units?

    Return population level preferred direction (and direction selectivity),
    that can be used to determine dominant preferred direction to decode.

    Parameters
    ----------
    UA : unit array
    UInc : Series of bool
        Unit inclusion flags, stored as DSInfo['include']; only included
        units enter the population PD calculation.
    utids : Series, optional
        Unit & task IDs; defaults to all of `UA`.
    fres : str, optional
        File to write DSInfo and PPDres to (`util.write_objects`).
    ffig : str
        File to save the polar plots to.

    Returns
    -------
    DSInfo : DataFrame
        Direction selectivity info per unit (from `ua_query`).
    PPDres : DataFrame
        Output of `direction.calc_PPD` per recording & task; unpacked below
        as (PDSI, PPD, PPDc, PADc).
    """

    # Init.
    if utids is None:
        utids = UA.utids(as_series=True)
    tasks = utids.index.get_level_values('task').unique()
    recs = util.get_subj_date_pairs(utids)

    # Get DS info frame.
    DSInfo = ua_query.get_DSInfo_table(UA, utids)
    DSInfo['include'] = UInc

    # Calculate population PD and DSI.
    dPPDres = {}
    for rec in recs:
        for task in tasks:

            # Init.
            rt = rec + (task,)
            rtDSInfo = DSInfo.xs(rt, level=[0, 1, -1])
            if rtDSInfo.empty:
                continue

            # Calculate population PD and population DSI.
            res = direction.calc_PPD(rtDSInfo.loc[rtDSInfo.include])
            dPPDres[rt] = res

    PPDres = pd.DataFrame(dPPDres).T

    # Save results.
    if fres is not None:
        results = {'DSInfo': DSInfo, 'PPDres': PPDres}
        util.write_objects(results, fres)

    # Plot results.

    # Init plotting.
    putil.set_style('notebook', 'darkgrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(recs), ncol=len(tasks),
                                          subw=6, subh=6, create_axes=True,
                                          ax_kws_list={'projection': 'polar'})
    # Minor ticks halfway between the 8 directions, used as grid lines.
    xticks = direction.deg2rad(constants.all_dirs + 360/8/2*deg)

    for ir, rec in enumerate(recs):
        for it, task in enumerate(tasks):

            # Init.
            rt = rec + (task,)
            rtDSInfo = DSInfo.xs(rt, level=[0, 1, -1])
            ax = axs[ir, it]
            if rtDSInfo.empty:
                ax.set_axis_off()
                continue
            PDSI, PPD, PPDc, PADc = PPDres.loc[rt]

            # Plot PD - DSI on polar plot. Title shows the integer-cast
            # directions (NaN kept as is).
            sPPDc, sPADc = [int(v) if not np.isnan(v) else v
                            for v in (PPDc, PADc)]
            title = (' '.join(rt) + '\n' +
                     r'PPDc = {}$^\circ$ - {}$^\circ$'.format(sPPDc, sPADc) +
                     ', PDSI = {:.2f}'.format(PDSI))
            PDrad = direction.deg2rad(util.remove_dim_from_series(rtDSInfo.PD))
            pplot.scatter(PDrad, rtDSInfo.DSI, rtDSInfo.include, ylim=[0, 1],
                          title=title, ytitle=1.08, c='darkblue', ec='k',
                          linewidth=1, s=80, alpha=0.8, zorder=2, ax=ax)

            # Highlight PPD and PAD.
            offsets = np.array([-45, 0, 45]) * deg
            for D, c in [(PPDc, 'g'), (PADc, 'r')]:
                if np.isnan(D):
                    continue
                hlDs = direction.deg2rad(np.array(D+offsets))
                for hlD, alpha in [(hlDs, 0.2), ([hlDs[1]], 0.4)]:
                    pplot.bars(hlD, len(hlD)*[1], align='center',
                               alpha=alpha, color=c, zorder=1, ax=ax)

            # Format ticks. Grid visibility is passed positionally: its
            # keyword was renamed from `b` to `visible` in matplotlib 3.5.
            ax.set_xticks(xticks, minor=True)
            ax.grid(True, axis='x', which='minor')
            ax.grid(False, axis='x', which='major')
            putil.hide_tick_marks(ax)

    # Save plot.
    title = 'Population direction selectivity'
    putil.save_fig(ffig, fig, title, w_pad=12, h_pad=20)

    return DSInfo, PPDres
Exemplo n.º 9
0
def select_units_trials(UA, utids=None, fres=None, ffig=None,
                        min_n_units=5, min_n_trs_per_unit=5):
    """
    Select optimal set of units and trials for population decoding.

    For each recording & task, units are removed one by one. At each step,
    the unit removed is the one whose removal maximizes the number of trials
    included in all remaining units; ties go to the unit with the fewest
    included trials. Among the resulting unit subsets that pass the criteria
    below, the one maximizing (# units) x (# covered trials) is kept.

    Parameters
    ----------
    UA : unit array
    utids : Series, optional
        Unit & task IDs, with 'subj', 'date', 'elec' and 'task' index
        levels. Defaults to all units of `UA`.
    fres : str, optional
        File to write RecInfo and UInc to (`util.write_objects`).
    ffig : str
        File to save the selection figure to.
    min_n_units : int
        Minimum number of units to keep (0: off).
    min_n_trs_per_unit : int
        Minimum number of trials per unit to keep (0: off).

    Returns
    -------
    RecInfo : DataFrame
        Per recording & task: electrode, selected units and trials, their
        counts and percentages of all units / trials.
    UInc : Series of bool
        Indexed like `utids`; True for units selected for decoding.
    """

    print('Selecting optimal set of units and trials for decoding...')

    # Init.
    if utids is None:
        utids = UA.utids(as_series=True)
    u_rt_grpby = utids.groupby(level=['subj', 'date', 'task'])

    # Unit inclusion flags (set to True below for the selected units).
    UInc = pd.Series(False, index=utids.index)

    # Included trials by unit.
    IncTrs = pd.Series([(UA.get_unit(utid[:-1], utid[-1]).inc_trials())
                        for utid in utids], index=utids.index)

    # Result DF: one row per recording & task.
    rec_task = pd.MultiIndex.from_tuples([rt for rt, _ in u_rt_grpby],
                                         names=['subj', 'date', 'task'])
    cols = ['elec', 'units', 'nunits', 'nallunits', '% remaining units',
            'trials', 'ntrials', 'nalltrials', '% remaining trials']
    RecInfo = pd.DataFrame(index=rec_task, columns=cols)
    rt_utid_list = [utids.xs((s, d, t), level=('subj', 'date', 'task'))
                    for s, d, t in rec_task]
    RecInfo.nallunits = [len(rtu) for rtu in rt_utid_list]
    # Total number of trials is read from the first unit of each rec & task.
    rt_ulist = [UA.get_unit(rtu.iloc[0][:-1], rtu.iloc[0][-1])
                for rtu in rt_utid_list]
    RecInfo.nalltrials = [int(u.QualityMetrics['NTrialsTotal'])
                          for u in rt_ulist]

    # Function to plot matrix (DF) of included/excluded trials.
    def plot_inc_exc_trials(IncTrsMat, ax, title=None, ytitle=None,
                            xlab='Trial #', ylab=None):
        # Plot on heatmap.
        sns.heatmap(IncTrsMat, cmap='RdYlGn', center=0.5, cbar=False, ax=ax)
        # Set tick labels.
        putil.hide_tick_marks(ax)
        tr_ticks = [1] + list(np.arange(25, IncTrsMat.shape[1]+1, 25))
        ax.xaxis.set_ticks(tr_ticks)
        ax.set_xticklabels(tr_ticks)
        putil.rot_xtick_labels(ax, 0)
        putil.rot_ytick_labels(ax, 0, va='center')
        putil.set_labels(ax, xlab, ylab, title, ytitle)

    # Number of trials covered (i.e. included in every unit) in df.
    def n_cov_trs(df):
        return sum(df.all())

    # Selection heuristic: number of units x number of covered trials.
    def calc_heuristic(df):
        return df.shape[0] * n_cov_trs(df)

    # Init plotting.
    ytitle = 1.40
    putil.set_style('notebook', 'whitegrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(rec_task), ncol=3,
                                          subw=6, subh=4, create_axes=True)

    for i_rt, ((subj, date, task), rt_utids) in enumerate(u_rt_grpby):
        print('{} / {}: {} - {} {}'.format(i_rt+1, len(u_rt_grpby), subj,
                                           date, task))

        # Init electrode.
        elecs = rt_utids.index.get_level_values('elec').unique()
        if len(elecs) != 1:
            warnings.warn('More than one electrode?')
        elec = elecs[0]
        RecInfo.loc[(subj, date, task), 'elec'] = elec

        # Create matrix of included trials of recording & task of units:
        # one row per unit (indexed by the (ch, ui) index levels), one column
        # per trial (numbered from 1), 1: included, 0: excluded.
        ch_idxs = rt_utids.index.droplevel(-1).droplevel(2).droplevel(1).droplevel(0)
        n_alltrs = RecInfo.nalltrials[(subj, date, task)]
        IncTrsMat = pd.DataFrame(np.zeros((len(ch_idxs), n_alltrs), dtype=int),
                                 index=ch_idxs, columns=np.arange(n_alltrs)+1)
        # Assign on the frame itself: chained indexing
        # (`.loc[row].iloc[trs] = 1`) may write into a temporary copy.
        for irow, utid in enumerate(rt_utids):
            IncTrsMat.iloc[irow, IncTrs[utid]] = 1

        # Plot included/excluded trials after preprocessing.
        ax = axs[i_rt, 0]
        ylab = '{} {} {}'.format(subj, date, task)
        title = ('Included (green) and excluded (red) trials'
                 if i_rt == 0 else None)
        plot_inc_exc_trials(IncTrsMat, ax, title, ytitle, ylab=ylab)

        # Calculate and plot overlap of trials across units.
        # How many trials will remain if we iteratively exclude units
        # with the least overlap with the rest of the units?
        n_trs = IncTrsMat.sum(1)
        n_units = IncTrsMat.shape[0]

        # Init results DF: row i is the state after removing i units.
        columns = ('uid', 'ntrs_cov', 'n_rem_u', 'trial x units')
        tr_covs = pd.DataFrame(columns=columns, index=range(n_units+1))
        tr_covs.loc[0] = ('none', n_cov_trs(IncTrsMat), n_units,
                          calc_heuristic(IncTrsMat))

        # Subset of included units (to be updated in each iteration).
        uinc = IncTrsMat.index.to_series()
        for iu in range(1, len(uinc)):

            # Number of covered trials after removing each unit.
            sntrscov = pd.Series([n_cov_trs(IncTrsMat.loc[uinc.drop(uid)])
                                  for uid in uinc], index=uinc.index)

            #########################################
            # Select and remove unit that           #
            # (a) yields maximum trial coverage,    #
            # (b) has minimum number of trials      #
            #########################################
            maxtrscov = sntrscov.max()
            worst_us = sntrscov[sntrscov == maxtrscov].index  # (a)
            utrs = n_trs.loc[worst_us]
            uid_remove = utrs[(utrs == min(utrs))].index[0]   # (b)

            # Update current subset of units and their trial DF.
            uinc.drop(uid_remove, inplace=True)
            tr_covs.loc[iu] = (uid_remove, maxtrscov, len(uinc),
                               calc_heuristic(IncTrsMat.loc[uinc]))

        # Add last unit.
        tr_covs.iloc[-1] = (uinc.iloc[0], 0, 0, 0)

        # Plot covered trials against each units removed.
        # NOTE(review): sns.tsplot is deprecated / removed in newer seaborn.
        ax_trc = axs[i_rt, 1]
        sns.tsplot(tr_covs['ntrs_cov'], marker='o', ms=4, color='b',
                   ax=ax_trc)
        title = ('Trial coverage during iterative unit removal'
                 if i_rt == 0 else None)
        xlab, ylab = 'current unit removed', '# trials covered'
        putil.set_labels(ax_trc, xlab, ylab, title, ytitle)
        ax_trc.xaxis.set_ticks(tr_covs.index)
        x_ticklabs = ['none'] + ['{} - {}'.format(ch, ui)
                                 for ch, ui in tr_covs.uid.loc[1:]]
        ax_trc.set_xticklabels(x_ticklabs)
        putil.rot_xtick_labels(ax_trc, 45)
        ax_trc.grid(True)

        # Add # of remaining units to top.
        ax_remu = ax_trc.twiny()
        ax_remu.xaxis.set_ticks(tr_covs.index)
        ax_remu.set_xticklabels(list(range(len(x_ticklabs)))[::-1])
        ax_remu.set_xlabel('# units remaining')
        ax_remu.grid(None)

        # Add heuristic index.
        ax_heur = ax_trc.twinx()
        sns.tsplot(tr_covs['trial x units'], linestyle='--', marker='o',
                   ms=4, color='m',  ax=ax_heur)
        putil.set_labels(ax_heur, ylab='remaining units x covered trials')
        for tl in ax_heur.get_yticklabels():
            tl.set_color('m')
        for tl in ax_trc.get_yticklabels():
            tl.set_color('b')
        ax_heur.grid(None)

        # Decide on which units to exclude.
        min_n_trials = min_n_trs_per_unit * tr_covs['n_rem_u']
        sub_tr_covs = tr_covs[(tr_covs['n_rem_u'] >= min_n_units) &
                              (tr_covs['ntrs_cov'] >= min_n_trials)]

        # If any subset of units passed above criteria.
        rem_uids, exc_uids = pd.Series(dtype=object), tr_covs.uid.iloc[1:]
        n_tr_rem, n_tr_exc = 0, IncTrsMat.shape[1]
        if len(sub_tr_covs.index):
            # Row *label* of the best subset in tr_covs (Series.argmax would
            # give its position within the filtered sub_tr_covs instead).
            hmax_idx = sub_tr_covs['trial x units'].astype(float).idxmax()
            rem_uids = tr_covs.uid.iloc[(hmax_idx+1):]
            exc_uids = tr_covs.uid.iloc[1:hmax_idx+1]
            n_tr_rem = tr_covs.ntrs_cov.loc[hmax_idx]
            n_tr_exc = IncTrsMat.shape[1] - n_tr_rem

            # Add to UnitInfo dataframe
            sel_utids = [(subj, date, elec, ch, ui, task)
                         for ch, ui in rem_uids]
            UInc.loc[sel_utids] = True

        # Highlight selected point in middle plot.
        sel_seg = [('selection', exc_uids.shape[0]-0.4,
                    exc_uids.shape[0]+0.4)]
        putil.plot_periods(sel_seg, ax=ax_trc, alpha=0.3)
        for ax in (ax_trc, ax_remu):
            ax.set_xlim([-0.5, n_units+0.5])

        # Generate remaining trials dataframe.
        RemTrsMat = IncTrsMat.copy().astype(float)
        for exc_uid in exc_uids:   # Remove all trials from excluded units.
            RemTrsMat.loc[exc_uid] = 0.5
        # Remove uncovered trials in remaining units.
        exc_trs = np.where(~RemTrsMat.loc[list(rem_uids)].all())[0]
        if exc_trs.size:
            RemTrsMat.iloc[:, exc_trs] = 0.5
        # Overwrite by trials excluded during preprocessing.
        RemTrsMat[IncTrsMat == 0] = 0.0

        # Plot remaining trials.
        ax = axs[i_rt, 2]
        n_u_rem, n_u_exc = len(rem_uids), len(exc_uids)
        title = ('# units remaining: {}, excluded: {}'.format(n_u_rem,
                                                              n_u_exc) +
                 '\n# trials remaining: {}, excluded: {}'.format(n_tr_rem,
                                                                 n_tr_exc))
        plot_inc_exc_trials(RemTrsMat, ax, title=title, ylab='')

        # Add remaining units and trials to RecInfo.
        rt = (subj, date, task)
        RecInfo.loc[rt, ('units', 'nunits')] = list(rem_uids), len(rem_uids)
        cov_trs = RemTrsMat.loc[list(rem_uids)].all()
        # pd.Index with int64 dtype (pd.Int64Index was removed in pandas 2).
        inc_trs = pd.Index(np.where(cov_trs)[0], dtype='int64')
        RecInfo.loc[rt, ('trials', 'ntrials')] = inc_trs, sum(cov_trs)

    RecInfo['% remaining units'] = 100 * RecInfo.nunits / RecInfo.nallunits
    RecInfo['% remaining trials'] = 100 * RecInfo.ntrials / RecInfo.nalltrials

    # Save results.
    if fres is not None:
        results = {'RecInfo': RecInfo, 'UInc': UInc}
        util.write_objects(results, fres)

    # Save plot.
    title = 'Trial & unit selection prior decoding'
    putil.save_fig(ffig, fig, title, w_pad=3, h_pad=3)

    return RecInfo, UInc
Exemplo n.º 10
0
def quality_test(UA, ftempl=None, plot_qm=False, fselection=None):
    """
    Run (and optionally plot) recording & spike sorting quality tests.

    Every unit of `UA` is tested in every task by `test_sorting.test_qm`,
    using the selection parameters from `get_selection_params`.

    Parameters
    ----------
    UA : unit array
    ftempl : str or None
        File name template for the per-unit figure, formatted with the unit
        ID string; only used when `plot_qm` is True.
    plot_qm : bool
        Plot quality metrics of each unit across tasks if True.
    fselection : str or None
        Excel file with unit & trial selection, read with `pd.read_excel`.

    Returns
    -------
    QC_tests : DataFrame
        One row per (unit ID + task) with test results, built from the
        'QC_tests' entry returned by `test_sorting.test_qm`.
    """

    putil.set_style('notebook', 'ticks')

    # Optional unit & trial selection table.
    UnTrSel = None
    if fselection is not None:
        UnTrSel = pd.read_excel(fselection)

    qc_by_unit_task = {}
    for uid in UA.uids():

        if plot_qm:
            fig, gsp, _ = putil.get_gs_subplots(nrow=1, ncol=len(UA.tasks()),
                                                subw=subw, subh=1.6 * subw)
            wf_axs, amp_axs, dur_axs = [], [], []
            amp_dur_axs, rate_axs = [], []

        for itask, task in enumerate(UA.tasks()):
            u = UA.get_unit(uid, task)
            include, first_tr, last_tr = get_selection_params(u, UnTrSel)
            res = test_sorting.test_qm(u, include, first_tr, last_tr)

            # QC tests are popped off, the rest of `res` is passed on as
            # plotting keyword arguments below.
            if res is not None:
                qc_by_unit_task[uid + (task, )] = res.pop('QC_tests')

            if not plot_qm:
                continue

            if res is None:
                putil.add_mock_axes(fig, gsp[itask])
                continue

            ax_wfs, ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate = \
                pquality.plot_qm(u, fig=fig, sps=gsp[itask], **res)
            wf_axs.extend(ax_wfs)
            amp_axs.append(ax_wf_amp)
            dur_axs.append(ax_wf_dur)
            amp_dur_axs.append(ax_amp_dur)
            rate_axs.append(ax_rate)

        if not plot_qm:
            continue

        # Same axis scales across tasks.
        putil.sync_axes(wf_axs, sync_x=True, sync_y=True)
        putil.sync_axes(amp_axs, sync_y=True)
        putil.sync_axes(dur_axs, sync_y=True)
        putil.sync_axes(amp_dur_axs, sync_x=True, sync_y=True)
        putil.sync_axes(rate_axs, sync_y=True)
        for ax in rate_axs:
            putil.move_event_lbls(ax, y_lbl=0.92)

        if ftempl is not None:
            uid_str = util.format_uid(uid)
            fig_title = uid_str.replace('_', ' ')
            putil.save_fig(ftempl.format(uid_str), fig, fig_title, w_pad=15)

    return pd.DataFrame(qc_by_unit_task).T
Exemplo n.º 11
0
def plot_scores_across_nunits(recs, stims, res_dir, list_n_most_DS, par_kws):
    """
    Plot prediction score results across different number of units included.

    One subplot per recording (row) and task (column), showing one score
    time series per number of units tested, with chance level(s) and
    stimulus periods marked.

    Parameters
    ----------
    recs : list of tuples of str
        Recordings to plot (elements are joined by spaces for display).
    stims : list
        Stimuli whose trial periods are highlighted.
    res_dir : str
        Results folder.
    list_n_most_DS : list of int
        Numbers of most DS units to load results for (0 is shown as 'all').
    par_kws : dict
        Decoding parameters; must contain 'tasks' and 'feat'. It is not
        modified (a copy is used to build the figure title & file name).

    Notes
    -----
    Presumably relies on module-level settings `verbose`, `tlim`,
    `ylim_scr`, `tlab`, `ylab_scr` and `fformat` (not defined here).
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load all results to plot.
    dict_rt_res = decutil.load_res(res_dir, list_n_most_DS, **par_kws)

    # Create figures.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)
    # Do plotting per recording and task.
    for irec, rec in enumerate(recs):
        if verbose:
            print('\n' + ' '.join(rec))
        for itask, task in enumerate(tasks):
            if verbose:
                print('    ' + task)

            ax_scr = axs_scr[irec, itask]

            # Init data.
            dict_lScores = {}
            lncls = []
            for rt_res in dict_rt_res.values():

                # Check if results exist for rec-task combination.
                if (((rec, task) not in rt_res.keys()) or
                        (not len(rt_res[(rec, task)].keys()))):
                    continue

                res = rt_res[(rec, task)]
                for v in res.keys():
                    vres = res[v]
                    if vres is None:  # no results for this value
                        continue
                    Scores = vres['Scores']
                    lncls.append(vres['nclasses'])

                    # Unstack dataframe with results.
                    lScores = pd.DataFrame(Scores.unstack(), columns=['score'])
                    lScores['time'] = lScores.index.get_level_values(0)
                    lScores['fold'] = lScores.index.get_level_values(1)
                    lScores.index = np.arange(len(lScores.index))

                    # Get number of units tested.
                    nunits = vres['nunits']
                    uni_nunits = nunits.unique()
                    if len(uni_nunits) > 1 and verbose:
                        print('Different number of units found.')
                    nunits = uni_nunits[0]

                    # Collect results.
                    dict_lScores[(nunits, v)] = lScores

            # Skip rest if no data is available.
            # Check if any result exists for rec-task combination.
            if not len(dict_lScores):
                ax_scr.axis('off')
                continue

            # Concatenate accuracy scores from every recording; the first
            # key level (number of units) becomes the 'n_most_DS' condition.
            all_lScores = pd.concat(dict_lScores)
            all_lScores['n_most_DS'] = all_lScores.index.get_level_values(0)
            all_lScores.index = np.arange(len(all_lScores.index))

            # Plot decoding results.
            nnunits = len(all_lScores['n_most_DS'].unique())
            title = '{} {}, {} sets of units'.format(' '.join(rec), task,
                                                     nnunits)
            ytitle = 1.0
            prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
                    for stim in stims]

            # Plot time series.
            palette = sns.color_palette('muted')
            sns.tsplot(all_lScores, time='time', value='score', unit='fold',
                       condition='n_most_DS', color=palette, ax=ax_scr)

            # Add chance level line.
            # This currently plots a chance level line for every nvals,
            # combined across stimulus period!
            uni_ncls = np.unique(np.array(lncls).flatten())
            if len(uni_ncls) > 1 and verbose:
                print('Different number of classes found.')
            for nvals in uni_ncls:
                chance_lvl = 1.0 / nvals
                putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Add stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Set axis limits.
            putil.set_limits(ax_scr, tlim, ylim_scr)

            # Format plot.
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)

    # Match axes across decoding plots.
    # [putil.sync_axes(axs_scr[:, itask], sync_y=True)
    #  for itask in range(axs_scr.shape[1])]

    # Save plots. Work on a copy so the caller's par_kws is left untouched.
    fig_kws = dict(par_kws)
    list_n_most_DS_str = [str(i) if i != 0 else 'all' for i in list_n_most_DS]
    fig_kws['n_most_DS'] = ', '.join(list_n_most_DS_str)
    title = decutil.fig_title(res_dir, **fig_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    fig_kws['n_most_DS'] = '_'.join(list_n_most_DS_str)
    ffig = decutil.fig_fname(res_dir, 'score_nunits', fformat, **fig_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)
Exemplo n.º 12
0
def plot_score_multi_rec(recs, stims, res_dir, par_kws):
    """
    Plot prediction scores for multiple recordings.

    Creates one subplot per task and overlays the decoding score time course
    of every recording that has results for that task (one line per
    recording, drawn by `sns.tsplot` with folds as units). Chance level
    lines and stimulus periods are added, and the figure is saved via
    `putil.save_fig`. Tasks without any result get their axes switched off.

    Parameters
    ----------
    recs : iterable
        Recordings to plot. Each is used in the `(rec, task)` result key and
        joined with ' ' for its legend label (presumably a tuple of strings;
        verify against callers).
    stims : iterable
        Stimuli whose periods (from `constants.fixed_tr_prds`) are marked.
    res_dir : str
        Results directory, forwarded to the `decutil` helpers.
    par_kws : dict
        Analysis parameters. 'n_most_DS', 'tasks' and 'feat' are read here;
        the whole dict is forwarded to `decutil.load_res`,
        `decutil.fig_title` and `decutil.fig_fname`.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    n_most_DS = par_kws['n_most_DS']
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load results.
    rt_res = decutil.load_res(res_dir, **par_kws)[n_most_DS]

    # Create figure.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=1, ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # Stimulus periods do not depend on task or recording: compute once.
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    print('\nPlotting multi-recording results...')
    for itask, task in enumerate(tasks):
        if verbose:
            print('    ' + task)
        ax_scr = axs_scr[0, itask]

        # Gather scores and class counts across ALL recordings of this task,
        # so that the chance level lines below reflect every recording
        # plotted, not only the last one visited.
        dict_lScores = {}
        lncls = []
        for rec in recs:

            # Check if results exist for rec-task combination.
            if (((rec, task) not in rt_res.keys()) or
               (not len(rt_res[(rec, task)].keys()))):
                continue

            res = rt_res[(rec, task)]
            for v in res.keys():
                vres = res[v]
                if vres is None:
                    continue

                lncls.append(vres['nclasses'])

                # Unstack dataframe with results into long format:
                # one row per (time, fold) score.
                lScores = pd.DataFrame(vres['Scores'].unstack(),
                                       columns=['score'])
                lScores['time'] = lScores.index.get_level_values(0)
                lScores['fold'] = lScores.index.get_level_values(1)
                lScores.index = np.arange(len(lScores.index))

                dict_lScores[(rec, v)] = lScores

        if not dict_lScores:
            ax_scr.axis('off')
            continue

        # Concatenate accuracy scores from every recording; the first level
        # of the concatenated index holds the recording.
        all_lScores = pd.concat(dict_lScores)
        all_lScores['rec'] = all_lScores.index.get_level_values(0)
        all_lScores['rec'] = all_lScores['rec'].str.join(' ')  # format label
        all_lScores.index = np.arange(len(all_lScores.index))

        # Plot decoding results.
        nrec = len(all_lScores['rec'].unique())
        title = '{}, {} recordings'.format(task, nrec)
        ytitle = 1.0

        # Plot time series.
        palette = sns.color_palette('muted')
        sns.tsplot(all_lScores, time='time', value='score', unit='fold',
                   condition='rec', color=palette, ax=ax_scr)

        # Add chance level line(s): one per distinct number of classes,
        # combined across stimulus periods and recordings.
        uni_ncls = np.unique(np.array(lncls).flatten())
        if len(uni_ncls) > 1 and verbose:
            print('Different number of classes found.')
        for nvals in uni_ncls:
            putil.add_chance_level(ax=ax_scr, ylevel=1.0 / nvals)

        # Add stimulus periods.
        putil.plot_periods(prds, ax=ax_scr)

        # Set axis limits.
        putil.set_limits(ax_scr, tlim, ylim_scr)

        # Format plot.
        putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)

    # Save figure.
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3
    ffig = decutil.fig_fname(res_dir, 'all_scores', fformat, **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)
Exemplo n.º 13
0
def plot_scores_weights(recs, stims, res_dir, par_kws):
    """
    Plot prediction scores and model weights for given recording and analysis.

    Creates two figures, each with one row per recording and one column per
    task: decoding scores over time (with chance level and stimulus periods),
    and unit weights (model coefficients) over time. Both are saved as PDF.
    Subplots without any non-null result have their axes switched off.

    Parameters
    ----------
    recs : iterable
        Recordings to plot; each is joined with ' ' for titles and printing
        (presumably a tuple of strings; verify against callers).
    stims : iterable
        Stimuli whose periods (from `constants.fixed_tr_prds`) are marked.
    res_dir : str
        Results directory, forwarded to the `decutil` helpers.
    par_kws : dict
        Analysis parameters. 'n_most_DS', 'tasks' and 'feat' are read here;
        the whole dict is forwarded to `decutil.load_res`,
        `decutil.fig_title` and `decutil.fig_fname`.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    n_most_DS = par_kws['n_most_DS']
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load results.
    rt_res = decutil.load_res(res_dir, **par_kws)[n_most_DS]

    # Create figures.
    # For prediction scores.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # For unit weights (coefficients).
    fig_wgt, _, axs_wgt = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # Stimulus periods do not depend on recording or task: compute once.
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    for irec, rec in enumerate(recs):
        if verbose:
            # rec is a sequence of strings (see title below), so it must be
            # joined before concatenating with a str.
            print('\n' + ' '.join(rec))
        for itask, task in enumerate(tasks):
            if verbose:
                print('    ' + task)

            # Init figures.
            ax_scr = axs_scr[irec, itask]
            ax_wgt = axs_wgt[irec, itask]

            # Check if any result exists for rec-task combination.
            if (((rec, task) not in rt_res.keys()) or
               (not len(rt_res[(rec, task)].keys()))):
                ax_scr.axis('off')
                ax_wgt.axis('off')
                continue

            # Keep only non-null results. If none is left there is nothing
            # to plot (and indexing the empty unit counts below would raise).
            res = rt_res[(rec, task)]
            vals = [v for v in res.keys() if not util.is_null(res[v])]
            if not vals:
                ax_scr.axis('off')
                ax_wgt.axis('off')
                continue

            cols = sns.color_palette('hls', len(vals))
            lnunits, lncls = [], []
            for v, col in zip(vals, cols):
                # Basic results.
                vres = res[v]
                Scores = vres['Scores']
                # Overwritten each iteration: only the last value's weights
                # are plotted below.
                Coefs = vres['Coefs']
                Perm = vres['Perm']
                Psdo = vres['Psdo']
                # Decoding params.
                lnunits.append(vres['nunits'])
                lncls.append(vres['nclasses'])
                # Plot decoding accuracy.
                plot_scores(ax_scr, Scores, Perm, Psdo, col=col)

            # Add labels.
            uni_lnunits = np.unique(np.array(lnunits).flatten())
            if len(uni_lnunits) > 1 and verbose:
                print('Different number of units found.')
            nunits = uni_lnunits[0]
            title = '{} {}, {} units'.format(' '.join(rec), task, nunits)
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle=1.04)

            # Add chance level line(s), one per distinct number of classes.
            uni_ncls = np.unique(np.array(lncls).flatten())
            if len(uni_ncls) > 1 and verbose:
                print('Different number of classes found.')
            for nvals in uni_ncls:
                putil.add_chance_level(ax=ax_scr, ylevel=1.0 / nvals)

            # Plot stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Plot unit weights over time.
            plot_weights(ax_wgt, Coefs, prds, tlim, tlab, title=title)

    # Match axes across decoding plots.
    # [putil.sync_axes(axs_scr[:, itask], sync_y=True)
    #  for itask in range(axs_scr.shape[1])]

    # Save plots.
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    # Performance.
    ffig = decutil.fig_fname(res_dir, 'score', 'pdf', **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)

    # Weights.
    ffig = decutil.fig_fname(res_dir, 'weight', 'pdf', **par_kws)
    putil.save_fig(ffig, fig_wgt, title, fs_title, w_pad=w_pad, h_pad=h_pad)
Exemplo n.º 14
0
def plot_up_down_raster(Spikes, task, rec, itrs):
    """
    Plot spike raster for up-down dynamics analysis.

    Selected trials are stacked vertically. Within each trial every unit gets
    its own row, and each spike is drawn as a small rectangle centred on
    (spike time, row). Consecutive trials are separated by a gap of
    `nunits / 2` rows, shaded light grey. Stimulus start/stop times are
    marked via `putil.plot_events`. The figure is saved as
    'UpDown_dynamics_{rec}_{task}.pdf' at the path built by
    `util.join(['results', 'UpDown', fname])`.

    Parameters
    ----------
    Spikes : pandas.DataFrame
        Spike trains with units as index and trials as columns. Each cell
        must support `.rescale('ms')` (presumably a neo SpikeTrain /
        quantities array -- TODO confirm).
    task, rec :
        Task and recording identifiers, used in the title and file name.
    itrs :
        Indexer selecting the trials to plot from `Spikes.columns`.
    """

    # Set params for plotting.
    uids, trs = Spikes.index, Spikes.columns
    plot_trs = trs[itrs]
    ntrs = len(plot_trs)
    nunits = len(uids)
    tr_gap = nunits / 2  # number of empty rows between consecutive trials

    # Init figure.
    putil.set_style('notebook', 'ticks')
    fig = putil.figure(figsize=(10, ntrs))
    ax = fig.add_subplot(111)

    # Per trial, per unit.
    for itr, tr in enumerate(plot_trs):
        for iu, uid in enumerate(uids):

            # Init y level and spike times.
            # Row index of unit `iu` inside the block of trial `itr`.
            i = (tr_gap + nunits) * itr + iu
            spk_tr = Spikes.loc[uid, tr]

            # Plot (spike time, y-level) pairs.
            x = np.array(spk_tr.rescale('ms'))
            y = (i+1) * np.ones_like(x)

            # NOTE(review): `wsp`, `hsp` (marker width/height) and `c`
            # (colour) are not defined in this function; presumably module-
            # level globals -- confirm, otherwise this raises NameError.
            # `Rectangle` / `PatchCollection` presumably come from
            # matplotlib imports at the top of the file.
            patches = [Rectangle((xi-wsp/2, yi-hsp/2), wsp, hsp)
                       for xi, yi in zip(x, y)]
            collection = PatchCollection(patches, facecolor=c, edgecolor=c)
            ax.add_collection(collection)

    # Add stimulus lines.
    for stim in constants.stim_dur.index:
        t_start, t_stop = constants.fixed_tr_prds.loc[stim]
        events = pd.DataFrame([(t_start, 't_start'), (t_stop, 't_stop')],
                              index=['start', 'stop'],
                              columns=['time', 'label'])
        putil.plot_events(events, add_names=False, color='grey',
                          alpha=0.5, ls='-', lw=0.5, ax=ax)

    # Add inter-trial shading: one band of height `tr_gap` before each trial
    # block, plus one after the last block.
    for itr in range(ntrs+1):
        ymin = itr * (tr_gap + nunits) - tr_gap + 0.5
        ax.axhspan(ymin, ymin+tr_gap, alpha=.05, color='grey')

    # Set tick labels: one tick at the centre of each trial block, labelled
    # with trial + 1 (assumes integer, 0-based trial labels -- verify).
    pos = np.arange(ntrs) * (tr_gap + nunits) + nunits/2
    lbls = plot_trs + 1
    putil.set_ytick_labels(ax, pos, lbls)
    # putil.sparsify_tick_labels(ax, 'y', freq=2, istart=1)
    putil.hide_tick_marks(ax, show_x_tick_mrks=True)

    # Format plot.
    xlim = constants.fixed_tr_prds.loc['whole trial']
    ylim = [-tr_gap/2, ntrs * (nunits+tr_gap)-tr_gap/2]
    xlab = 'Time since S1 onset (ms)'
    ylab = 'Trial number'
    title = '{} {}'.format(rec, task)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.set_spines(ax, True, False, False, False)

    # Save figure.
    fname = 'UpDown_dynamics_{}_{}.pdf'.format(rec, task)
    ffig = util.join(['results', 'UpDown', fname])
    putil.save_fig(ffig, fig, dpi=600)