예제 #1
0
def plot_weights(ax, Coefs, prds=None, xlim=None, xlab=tlab,
                 ylab='unit coefficient', title='', ytitle=1.04):
    """
    Draw decoding weights as time series on the given axes.

    Coefs is unstacked twice into a long-format frame whose three index
    levels are exposed as the 'time', 'value' and 'uid' columns (in that
    order). The frame is then plotted with one condition per 'uid', using
    'value' as the repetition unit.

    ax: axes to draw on.
    Coefs: coefficient table; must support `.unstack().unstack()`.
    prds: periods forwarded to `putil.plot_periods`.
    xlim: x axis limits; the module-level `tlim` is used when None.
    xlab, ylab, title, ytitle: forwarded to `putil.set_labels`.
    """

    # Long format: one row per coefficient, with a plain integer index.
    long_coefs = pd.DataFrame(Coefs.unstack().unstack(), columns=['coef'])
    stacked_index = long_coefs.index
    for ilevel, col_name in enumerate(('time', 'value', 'uid')):
        long_coefs[col_name] = stacked_index.get_level_values(ilevel)
    long_coefs.index = np.arange(len(stacked_index))

    # Time series of the coefficients.
    sns.tsplot(long_coefs, time='time', value='coef', unit='value',
               condition='uid', ax=ax)

    # Reference line at zero, then the stimulus periods.
    putil.add_chance_level(ax=ax, ylevel=0)
    putil.plot_periods(prds, ax=ax)

    # Fall back on the default time limits if none were requested.
    if xlim is None:
        xlim = tlim
    putil.set_limits(ax, xlim)

    # Labels on, legend off.
    putil.set_labels(ax, xlab, ylab, title, ytitle)
    putil.hide_legend(ax)
예제 #2
0
def plot_mean_rates(mRates,
                    aa_res_dir,
                    tasks=None,
                    task_lbls=None,
                    baseline=None,
                    xlim=None,
                    ylim=None,
                    ci=68,
                    ffig=None,
                    figsize=(6, 4)):
    """
    Plot mean rates across tasks.

    mRates: mapping from task name to a rate table supporting `.unstack()`;
        after unstacking, index level 0 is used as 'time' and level 1 as
        'unit'.
    aa_res_dir: not used by this function (kept for interface
        compatibility).
    tasks: tasks to plot; all keys of mRates when None.
    task_lbls: optional mapping used to replace task names by display labels.
    baseline: if not None, passed to `putil.add_baseline`.
    xlim, ylim: axis limits passed to `putil.set_limits`.
    ci: confidence interval passed to `sns.tsplot`.
    ffig: figure file name passed to `putil.save_fig`.
    figsize: size of the figure to create.

    Returns the figure and axes objects.
    """

    # Init.
    if tasks is None:
        tasks = mRates.keys()

    # Collect the rates of each task in long format.
    lRates = []
    for task in tasks:
        lrates = pd.DataFrame(mRates[task].unstack(), columns=['rate'])
        lrates['task'] = task
        lRates.append(lrates)
    lRates = pd.concat(lRates)
    lRates['time'] = lRates.index.get_level_values(0)
    lRates['unit'] = lRates.index.get_level_values(1)

    # Assign the relabelled column back instead of replacing in place on
    # `lRates.task`: the in-place form modifies an intermediate object and
    # leaves the frame untouched when pandas Copy-on-Write is active.
    if task_lbls is not None:
        lRates['task'] = lRates['task'].replace(task_lbls)

    # Plot as time series.
    fig = putil.figure(figsize=figsize)
    ax = putil.axes()

    sns.tsplot(lRates,
               time='time',
               value='rate',
               unit='unit',
               condition='task',
               ci=ci,
               ax=ax)

    # Add periods and baseline.
    putil.plot_periods(ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Format plot.
    sns.despine(ax=ax)
    putil.set_labels(ax, xlab='time since S1 onset', ylab='rate (sp/s)')
    putil.set_limits(ax, xlim, ylim)
    putil.hide_legend_title(ax)

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
예제 #3
0
def raster(spk_trains,
           t_unit=ms,
           prds=None,
           c='b',
           xlim=None,
           title=None,
           xlab=None,
           ylab=None,
           ffig=None,
           ax=None):
    """
    Draw a raster plot, one row per spike train.

    spk_trains: sized sequence of spike trains, each supporting
        `.rescale(t_unit)`.
    t_unit: time unit the spike times are rescaled to before plotting.
    prds: periods forwarded to `putil.plot_periods`.
    c: face and edge color of the spike markers.
    xlim: x axis limits.
    title, xlab, ylab: plot labels; xlab is inserted into `putil.t_lbl`.
    ffig: file name forwarded to `putil.save_fig`.
    ax: axes to draw on, passed through `putil.axes`.

    Returns the axes. With no spike trains, only the periods and x limits
    are applied and the function returns before formatting and saving.
    """

    ax = putil.axes(ax)

    # Periods and x limits are set even if there are no spikes to draw.
    putil.plot_periods(prds, ax=ax)
    putil.set_limits(ax, xlim)

    n_trains = len(spk_trains)
    if n_trains == 0:
        return ax

    # One patch collection per spike train, on rows 1, 2, ...
    # Markers are rectangles of size wsp x hsp centered on each spike, so
    # their size is expressed in data coordinates.
    for row, train in enumerate(spk_trains, start=1):
        spk_times = np.array(train.rescale(t_unit))
        rows = np.ones_like(spk_times) * row
        markers = [
            Rectangle((t - wsp / 2, r - hsp / 2), wsp, hsp)
            for t, r in zip(spk_times, rows)
        ]
        ax.add_collection(PatchCollection(markers, facecolor=c, edgecolor=c))

    # Labels and limits; leave half a row of margin above and below.
    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    ylim = [0.5, n_trains + 0.5]
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.hide_axes(ax, show_x=True)
    putil.hide_spines(ax)

    # Flip the y axis last, so the first train ends up on top.
    ax.invert_yaxis()

    putil.save_fig(ffig)
    return ax
예제 #4
0
파일: pauc.py 프로젝트: mnislamraju/seal
def plot_ROC_mean(d_faroc,
                  t1=None,
                  t2=None,
                  ylim=None,
                  colors=None,
                  ylab='AROC',
                  ffig=None):
    """
    Plot mean AROC time courses of several result sets in one figure.

    d_faroc: mapping from name to the file the 'aroc' object is read from.
    t1, t2: lower and upper x axis limits.
    ylim: y axis limits.
    colors: forwarded to `sns.tsplot` as `color`.
    ylab: y axis label.
    ffig: figure file name forwarded to `putil.save_fig`.
    """

    # Load the AROC results of each entry.
    d_aroc = {name: util.read_objects(faroc, 'aroc').unstack().T
              for name, faroc in d_faroc.items()}

    # Long format; the concatenated index levels become regular columns.
    laroc = pd.DataFrame(pd.concat(d_aroc), columns=['aroc'])
    stacked_index = laroc.index
    for ilevel, col_name in enumerate(('task', 'time', 'unit')):
        laroc[col_name] = stacked_index.get_level_values(ilevel)
    laroc.index = np.arange(len(stacked_index))

    # Create figure and draw the time series.
    fig = putil.figure(figsize=(6, 6))
    ax = sns.tsplot(laroc,
                    time='time',
                    value='aroc',
                    unit='unit',
                    condition='task',
                    color=colors)

    # Mark stimulus periods.
    putil.plot_periods(ax=ax)

    # Thicken every line that is on the axes at this point.
    for line in ax.lines:
        line.set_linewidth(3)

    # Chance level goes on afterwards, so it is the last line on the axes
    # and can be given its own, thinner width.
    putil.add_chance_level(ax=ax, alpha=0.8, color='k')
    ax.lines[-1].set_linewidth(1.5)

    # Labels, limits, spines and legend.
    putil.set_labels(ax, 'Time since S1 onset (ms)', ylab)
    putil.set_limits(ax, [t1, t2], ylim)
    putil.set_spines(ax, bottom=True, left=True, top=False, right=False)
    putil.set_legend(ax, loc=0)

    putil.save_fig(ffig, fig, ytitle=1.05, w_pad=15)
예제 #5
0
def plot_scores(ax, Scores, Perm=None, Psdo=None, nvals=None, prds=None,
                col='b', perm_col='grey', psdo_col='g', xlim=None,
                ylim=ylim_scr, xlab=tlab, ylab=ylab_scr, title='',
                ytitle=1.04):
    """
    Plot decoding scores together with optional control results.

    ax: axes to draw on.
    Scores: scores passed to `plot_score_set`.
    Perm, Psdo: optional permuted and pseudo-population results; each is
        plotted only if it is not null and not entirely NaN. Rows 'mean',
        'std' and 'pval' are read, and the columns are used as x values.
    nvals: if given, a chance level line is drawn at 1 / nvals.
    prds: if given, periods forwarded to `putil.plot_periods`.
    col, perm_col, psdo_col: colors of the three result types.
    xlim: x axis limits; the module-level `tlim` is used when None.
    ylim, xlab, ylab, title, ytitle: axis limits and labels.
    """

    handles = []

    # Control results, drawn first (only those that exist).
    controls = ((Perm, 'permuted', perm_col, 6),
                (Psdo, 'pseudo-population', psdo_col, 3))
    for res, label, color, lw in controls:
        if util.is_null(res) or res.isnull().all().all():
            continue
        tvec, pvals = res.columns, res.loc['pval']
        means, stds = res.loc['mean'], res.loc['std']
        plot_mean_std_sdiff(tvec, means, stds, pvals, pth=0.01, lw=lw,
                            color=color, ax=ax)
        handles.append(putil.get_artist(label, color))

    # Actual scores.
    plot_score_set(Scores, ax, color=col)
    handles.append(putil.get_artist('synchronous', col))

    # Legend entries are listed in reverse plotting order.
    putil.set_legend(ax, handles=handles[::-1])

    # Chance level line.
    # This currently plots all nvals combined across stimulus period!
    if nvals is not None:
        putil.add_chance_level(ax=ax, ylevel=1.0 / nvals)

    # Stimulus periods.
    if prds is not None:
        putil.plot_periods(prds, ax=ax)

    # Axis limits, falling back on the default time limits.
    if xlim is None:
        xlim = tlim
    putil.set_limits(ax, xlim, ylim)

    # Labels.
    putil.set_labels(ax, xlab, ylab, title, ytitle)
예제 #6
0
파일: pauc.py 프로젝트: mnislamraju/seal
def plot_auc_over_time(auc,
                       tvec,
                       prds=None,
                       evts=None,
                       xlim=None,
                       ylim=None,
                       xlab='time',
                       ylab='AUC',
                       title=None,
                       ax=None):
    """
    Plot AUC values against time.

    auc: AUC values to plot.
    tvec: time points of the AUC values; also defines xlim when none given.
    prds: periods forwarded to `putil.plot_periods`.
    evts: events forwarded to `putil.plot_event_markers`.
    xlim, ylim: axis limits.
    xlab, ylab, title: plot labels.
    ax: axes to draw on, passed through `putil.axes`.

    Returns the axes.
    """

    ax = putil.axes(ax)

    # Default x range spans the whole time vector.
    xlim = [min(tvec), max(tvec)] if xlim is None else xlim

    # Periods go on first, then the AUC curve and the chance level line.
    putil.plot_periods(prds, ax=ax)
    pplot.lines(tvec, auc, ylim, xlim, xlab, ylab, title, color='green', ax=ax)
    putil.add_chance_level(ax=ax)

    # For the exact [0, 1] range: five tick marks, every second one
    # left without a label.
    is_unit_range = (ylim is not None) and (ylim[0] == 0) and (ylim[1] == 1)
    if is_unit_range:
        marks = np.linspace(0, 1, 5)
        labels = np.array(marks, dtype=str)
        labels[1::2] = ''
        putil.set_ytick_labels(ax, marks, labels)
    putil.set_max_n_ticks(ax, 5, 'y')

    # Event markers.
    putil.plot_event_markers(evts, ax=ax)

    return ax
예제 #7
0
파일: prepare.py 프로젝트: mnislamraju/seal
def select_units_trials(UA, utids=None, fres=None, ffig=None,
                        min_n_units=5, min_n_trs_per_unit=5):
    """
    Select optimal set of units and trials for population decoding.

    For each (subj, date, task) group of units, units are removed one by one
    (greedily) and the number of trials included by all remaining units is
    tracked. Among the removal steps satisfying the minimum criteria below,
    the one maximizing (# remaining units) x (# covered trials) is selected.
    A three-panel summary row is plotted per group.

    UA: unit collection providing `utids(as_series=True)` and `get_unit`.
    utids: unit-task IDs to consider; taken from UA when None.
    fres: if not None, results are written to this file.
    ffig: figure file name passed to `putil.save_fig`.
    min_n_units: minimum number of units to keep (0: off)
    min_n_trs_per_unit: minimum number of trials per unit to keep (0: off)

    Returns (RecInfo, UInc): a per-group summary DataFrame and a boolean
    Series, indexed like utids, flagging the selected units.
    """

    print('Selecting optimal set of units and trials for decoding...')

    # Init.
    if utids is None:
        utids = UA.utids(as_series=True)
    u_rt_grpby = utids.groupby(level=['subj', 'date', 'task'])

    # Unit info frame: inclusion flag per unit, all False to start with.
    UInc = pd.Series(False, index=utids.index)

    # Included trials by unit.
    # Each utid is split into its leading part and its last element to look
    # up the unit (presumably the last element is the task, matching the
    # tuple built further below -- TODO confirm).
    IncTrs = pd.Series([(UA.get_unit(utid[:-1], utid[-1]).inc_trials())
                        for utid in utids], index=utids.index)

    # Result DF: one row per (subj, date, task) group.
    rec_task = pd.MultiIndex.from_tuples([rt for rt, _ in u_rt_grpby],
                                         names=['subj', 'date', 'task'])
    cols = ['elec', 'units', 'nunits', 'nallunits', '% remaining units',
            'trials', 'ntrials', 'nalltrials', '% remaining trials']
    RecInfo = pd.DataFrame(index=rec_task, columns=cols)
    rt_utids = [utids.xs((s, d, t), level=('subj', 'date', 'task'))
                for s, d, t in rec_task]
    RecInfo.nallunits = [len(utids) for utids in rt_utids]
    # Total number of trials is read from the first unit of each group.
    rt_ulist = [UA.get_unit(utids[0][:-1], utids[0][-1]) for utids in rt_utids]
    RecInfo.nalltrials = [int(u.QualityMetrics['NTrialsTotal'])
                          for u in rt_ulist]

    # Function to plot matrix (DF) included/excluded trials.
    def plot_inc_exc_trials(IncTrsMat, ax, title=None, ytitle=None,
                            xlab='Trial #', ylab=None,):
        # Plot on heatmap.
        sns.heatmap(IncTrsMat, cmap='RdYlGn', center=0.5, cbar=False, ax=ax)
        # Set tick labels: trial 1, then every 25th trial.
        putil.hide_tick_marks(ax)
        tr_ticks = [1] + list(np.arange(25, IncTrsMat.shape[1]+1, 25))
        ax.xaxis.set_ticks(tr_ticks)
        ax.set_xticklabels(tr_ticks)
        putil.rot_xtick_labels(ax, 0)
        putil.rot_ytick_labels(ax, 0, va='center')
        putil.set_labels(ax, xlab, ylab, title, ytitle)

    # Init plotting: one row of three axes per group.
    ytitle = 1.40
    putil.set_style('notebook', 'whitegrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(rec_task), ncol=3,
                                          subw=6, subh=4, create_axes=True)

    # NOTE: the loop variable rt_utids rebinds the list of the same name
    # created above, and is itself rebound inside the loop body.
    for i_rt, ((subj, date, task), rt_utids) in enumerate(u_rt_grpby):
        print('{} / {}: {} - {} {}'.format(i_rt+1, len(u_rt_grpby), subj,
                                           date, task))

        # Init electrode.
        elecs = rt_utids.index.get_level_values('elec').unique()
        if len(elecs) != 1:
            warnings.warn('More than one electrode?')
        elec = elecs[0]
        RecInfo.loc[(subj, date, task), 'elec'] = elec

        # Create matrix of included trials of recording & task of units.
        # Drop the last level and the first three levels of the index; the
        # remaining levels are unpacked as (ch, ui) pairs further below.
        ch_idxs = rt_utids.index.droplevel(-1).droplevel(2).droplevel(1).droplevel(0)
        n_alltrs = RecInfo.nalltrials[(subj, date, task)]
        IncTrsMat = pd.DataFrame(np.zeros((len(ch_idxs), n_alltrs), dtype=int),
                                 index=ch_idxs, columns=np.arange(n_alltrs)+1)
        # NOTE(review): chained indexing (.loc[...].iloc[...] = 1); whether
        # this writes into IncTrsMat or into a temporary copy depends on the
        # pandas version / Copy-on-Write mode -- confirm.
        for ch_idx, utid in zip(ch_idxs, rt_utids):
            IncTrsMat.loc[ch_idx].iloc[IncTrs[utid]] = 1

        # Plot included/excluded trials after preprocessing.
        ax = axs[i_rt, 0]
        ylab = '{} {} {}'.format(subj, date, task)
        title = ('Included (green) and excluded (red) trials'
                 if i_rt == 0 else None)
        plot_inc_exc_trials(IncTrsMat, ax, title, ytitle, ylab=ylab)

        # Calculate and plot overlap of trials across units.
        # How many trials will remain if we iteratively exclude units
        # with the least overlap with the rest of the units?
        def n_cov_trs(df):  # return number of trials covered in df
            return sum(df.all())

        # Heuristic to maximize: (# units) x (# trials covered by all units).
        def calc_heuristic(df):
            return df.shape[0] * n_cov_trs(df)

        n_trs = IncTrsMat.sum(1)
        n_units = IncTrsMat.shape[0]

        # Init results DF: row 0 is the state before removing any unit,
        # row i the state after the i-th removal.
        columns = ('uid', 'ntrs_cov', 'n_rem_u', 'trial x units')
        tr_covs = pd.DataFrame(columns=columns, index=range(n_units+1))
        tr_covs.loc[0] = ('none', n_cov_trs(IncTrsMat), n_units,
                          calc_heuristic(IncTrsMat))

        # Subset of included units (to be updated in each iteration).
        uinc = IncTrsMat.index.to_series()
        for iu in range(1, len(uinc)):

            # Number of covered trials after removing each unit.
            sntrscov = pd.Series([n_cov_trs(IncTrsMat.loc[uinc.drop(uid)])
                                  for uid in uinc], index=uinc.index)

            #########################################
            # Select and remove unit that           #
            # (a) yields maximum trial coverage,    #
            # (b) has minimum number of trials      #
            #########################################
            maxtrscov = sntrscov.max()
            worst_us = sntrscov[sntrscov == maxtrscov].index  # (a)
            utrs = n_trs.loc[worst_us]
            uid_remove = utrs[(utrs == min(utrs))].index[0]   # (b)

            # Update current subset of units and their trial DF.
            uinc.drop(uid_remove, inplace=True)
            tr_covs.loc[iu] = (uid_remove, maxtrscov, len(uinc),
                               calc_heuristic(IncTrsMat.loc[uinc]))

        # Add last unit: removing it leaves no units and no covered trials.
        # NOTE(review): uinc[0] relies on positional fallback of integer
        # indexing on a non-integer index -- confirm with the pandas version.
        tr_covs.iloc[-1] = (uinc[0], 0, 0, 0)

        # Plot covered trials against each units removed.
        ax_trc = axs[i_rt, 1]
        sns.tsplot(tr_covs['ntrs_cov'], marker='o', ms=4, color='b',
                   ax=ax_trc)
        title = ('Trial coverage during iterative unit removal'
                 if i_rt == 0 else None)
        xlab, ylab = 'current unit removed', '# trials covered'
        putil.set_labels(ax_trc, xlab, ylab, title, ytitle)
        ax_trc.xaxis.set_ticks(tr_covs.index)
        x_ticklabs = ['none'] + ['{} - {}'.format(ch, ui)
                                 for ch, ui in tr_covs.uid.loc[1:]]
        ax_trc.set_xticklabels(x_ticklabs)
        putil.rot_xtick_labels(ax_trc, 45)
        ax_trc.grid(True)

        # Add # of remaining units to top (counting down from left to right).
        ax_remu = ax_trc.twiny()
        ax_remu.xaxis.set_ticks(tr_covs.index)
        ax_remu.set_xticklabels(list(range(len(x_ticklabs)))[::-1])
        ax_remu.set_xlabel('# units remaining')
        ax_remu.grid(None)

        # Add heuristic index on a secondary y axis.
        ax_heur = ax_trc.twinx()
        sns.tsplot(tr_covs['trial x units'], linestyle='--', marker='o',
                   ms=4, color='m',  ax=ax_heur)
        putil.set_labels(ax_heur, ylab='remaining units x covered trials')
        # Color tick labels to match the curve of each y axis.
        [tl.set_color('m') for tl in ax_heur.get_yticklabels()]
        [tl.set_color('b') for tl in ax_trc.get_yticklabels()]
        ax_heur.grid(None)

        # Decide on which units to exclude.
        # NOTE(review): the threshold on covered trials is
        # min_n_trs_per_unit multiplied by the number of remaining units,
        # i.e. it scales with the unit count -- confirm this is intended.
        min_n_trials = min_n_trs_per_unit * tr_covs['n_rem_u']
        sub_tr_covs = tr_covs[(tr_covs['n_rem_u'] >= min_n_units) &
                              (tr_covs['ntrs_cov'] >= min_n_trials)]

        # If any subset of units passed above criteria.
        # Defaults: nothing remains, every unit and trial is excluded.
        rem_uids, exc_uids = pd.Series(), tr_covs.uid[1:]
        n_tr_rem, n_tr_exc = 0, IncTrsMat.shape[1]
        if len(sub_tr_covs.index):
            # NOTE(review): hmax_idx is used below to index tr_covs, but is
            # computed on the filtered sub_tr_covs. Depending on the pandas
            # version argmax returns a label or a position, and a position
            # within the subset need not match the row of tr_covs -- confirm.
            hmax_idx = sub_tr_covs['trial x units'].argmax()
            rem_uids = tr_covs.uid[(hmax_idx+1):]
            exc_uids = tr_covs.uid[1:hmax_idx+1]
            n_tr_rem = tr_covs.ntrs_cov[hmax_idx]
            n_tr_exc = IncTrsMat.shape[1] - n_tr_rem

            # Add to UnitInfo dataframe
            rt_utids = [(subj, date, elec, ch, ui, task)
                        for ch, ui in rem_uids]
            UInc[rt_utids] = True

        # Highlight selected point in middle plot.
        sel_seg = [('selection', exc_uids.shape[0]-0.4,
                    exc_uids.shape[0]+0.4)]
        putil.plot_periods(sel_seg, ax=ax_trc, alpha=0.3)
        [ax.set_xlim([-0.5, n_units+0.5]) for ax in (ax_trc, ax_remu)]

        # Generate remaining trials dataframe.
        # Values: 1 = remaining, 0.5 = removed by this selection,
        # 0 = already excluded during preprocessing.
        RemTrsMat = IncTrsMat.copy().astype(float)
        for exc_uid in exc_uids:   # Remove all trials from excluded units.
            RemTrsMat.loc[exc_uid] = 0.5
        # Remove uncovered trials in remaining units.
        exc_trs = np.where(~RemTrsMat.loc[list(rem_uids)].all())[0]
        if exc_trs.size:
            RemTrsMat.iloc[:, exc_trs] = 0.5
        # Overwrite by trials excluded during preprocessing.
        RemTrsMat[IncTrsMat == False] = 0.0

        # Plot remaining trials.
        ax = axs[i_rt, 2]
        n_u_rem, n_u_exc = len(rem_uids), len(exc_uids)
        title = ('# units remaining: {}, excluded: {}'.format(n_u_rem,
                                                              n_u_exc) +
                 '\n# trials remaining: {}, excluded: {}'.format(n_tr_rem,
                                                                 n_tr_exc))
        plot_inc_exc_trials(RemTrsMat, ax, title=title, ylab='')

        # Add remaining units and trials to RecInfo.
        rt = (subj, date, task)
        RecInfo.loc[rt, ('units', 'nunits')] = list(rem_uids), len(rem_uids)
        cov_trs = RemTrsMat.loc[list(rem_uids)].all()
        # Trials are stored as zero-based positions of the covered columns.
        # NOTE(review): pd.Int64Index is not available in recent pandas
        # releases -- confirm the targeted pandas version.
        inc_trs = pd.Int64Index(np.where(cov_trs)[0])
        RecInfo.loc[rt, ('trials', 'ntrials')] = inc_trs, sum(cov_trs)

    RecInfo['% remaining units'] = 100 * RecInfo.nunits / RecInfo.nallunits
    RecInfo['% remaining trials'] = 100 * RecInfo.ntrials / RecInfo.nalltrials

    # Save results.
    if fres is not None:
        results = {'RecInfo': RecInfo, 'UInc': UInc}
        util.write_objects(results, fres)

    # Save plot.
    title = 'Trial & unit selection prior decoding'
    putil.save_fig(ffig, fig, title, w_pad=3, h_pad=3)

    return RecInfo, UInc
예제 #8
0
def plot_qm(u,
            bs_stats,
            stab_prd_res,
            prd_inc,
            tr_inc,
            spk_inc,
            add_lbls=False,
            ftempl=None,
            fig=None,
            sps=None):
    """
    Plot quality metrics related figures of a unit.

    The figure has an info header and a 3 x 2 grid of axes: included and
    excluded waveforms, waveform amplitude and duration over session time,
    duration against amplitude, and firing rate over session time.

    u: unit object; its Waveforms, SpikeParams, QualityMetrics, SessParams
        and TrData attributes are read.
    bs_stats: table with 'tmean' and 'rate' entries used for the rate plot.
    stab_prd_res: mapping with 'tstart' and 'tstop' of the stable period.
    prd_inc: per-period inclusion flags (one per entry of bs_stats).
    tr_inc: boolean mask of included trials.
    spk_inc: boolean mask of included spikes.
    add_lbls: whether to add axis labels to the subplots.
    ftempl: file name template; if not None, the figure is saved to
        `ftempl.format(u.name_to_fname())`.
    fig: figure to plot on, passed through `putil.figure`.
    sps: subplot spec to embed the plots into; a new 1 x 1 gridspec is
        created when None.

    Returns ([ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur,
    ax_rate).
    """

    # Init values.
    waveforms = np.array(u.Waveforms)
    wavetime = u.Waveforms.columns * us
    spk_times = np.array(u.SpikeParams['time'], dtype=float)
    base_rate = u.QualityMetrics['baseline']

    # Minimum and maximum gain.
    gmin = u.SessParams['minV']
    gmax = u.SessParams['maxV']

    # %% Init plots.

    # Disable inline plotting to prevent memory leak.
    # NOTE(review): inline plotting is only switched back on at the end if
    # ftempl is given -- confirm that callers re-enable it otherwise.
    putil.inline_off()

    # Init figure and gridspec.
    fig = putil.figure(fig)
    if sps is None:
        sps = putil.gridspec(1, 1)[0]
    ogsp = putil.embed_gsp(sps, 2, 1, height_ratios=[0.02, 1])

    info_sps, qm_sps = ogsp[0], ogsp[1]

    # Info header.
    info_ax = fig.add_subplot(info_sps)
    putil.hide_axes(info_ax)
    title = putil.get_unit_info_title(u)
    putil.set_labels(ax=info_ax, title=title, ytitle=0.80)

    # Create axes.
    gsp = putil.embed_gsp(qm_sps, 3, 2, wspace=0.3, hspace=0.4)
    ax_wf_inc, ax_wf_exc = [fig.add_subplot(gsp[0, i]) for i in (0, 1)]
    ax_wf_amp, ax_wf_dur = [fig.add_subplot(gsp[1, i]) for i in (0, 1)]
    ax_amp_dur, ax_rate = [fig.add_subplot(gsp[2, i]) for i in (0, 1)]

    # Trial markers: every 10th trial start, every second of them labelled
    # with its index + 1.
    trial_starts, trial_stops = u.TrData.TrialStart, u.TrData.TrialStop
    tr_markers = pd.DataFrame({'time': trial_starts[9::10]})
    tr_markers['label'] = [
        str(itr + 1) if i % 2 else '' for i, itr in enumerate(tr_markers.index)
    ]

    # Common variables, limits and labels.
    WF_T_START = test_sorting.WF_T_START
    spk_t = u.SessParams.sampl_prd * (np.arange(waveforms.shape[1]) -
                                      WF_T_START)
    ses_t_lim = test_sorting.get_start_stop_times(spk_times, trial_starts,
                                                  trial_stops)
    ss, sa = 1.0, 0.8  # marker size and alpha on scatter plot

    # Color spikes by their occurrence over session time.
    # Excluded spikes keep the default semi-transparent grey.
    my_cmap = putil.get_cmap('jet')
    spk_cols = np.tile(np.array([.25, .25, .25, .25]), (len(spk_times), 1))
    if np.any(spk_inc):  # check if there is any spike included
        spk_t_inc = np.array(spk_times[spk_inc])
        tmin, tmax = float(spk_times.min()), float(spk_times.max())
        spk_cols[spk_inc, :] = my_cmap((spk_t_inc - tmin) / (tmax - tmin))
    # Put excluded spikes to the front, and randomise order of included
    # spikes so later spikes don't systematically cover earlier ones.
    spk_order = np.hstack((np.where(np.invert(spk_inc))[0],
                           np.random.permutation(np.where(spk_inc)[0])))

    # Common labels for plots
    ses_t_lab = 'Recording time (s)'

    # %% Waveform shape analysis.

    # Plot included and excluded waveforms on different axes.
    # Color included by occurrence in session time to help detect drifts.
    s_waveforms, s_spk_cols = waveforms[spk_order, :], spk_cols[spk_order]
    wf_t_lim, glim = [min(spk_t), max(spk_t)], [gmin, gmax]
    # Raw strings: '\m' is not a valid escape sequence in a normal literal.
    wf_t_lab, volt_lab = r'WF time ($\mu$s)', 'Voltage'
    for st in ('Included', 'Excluded'):
        ax = ax_wf_inc if st == 'Included' else ax_wf_exc
        spk_idx = spk_inc if st == 'Included' else np.invert(spk_inc)
        tr_idx = tr_inc if st == 'Included' else np.invert(tr_inc)

        nspsk, ntrs = sum(spk_idx), sum(tr_idx)
        title = '{} WFs, {} spikes, {} trials'.format(st, nspsk, ntrs)

        # Select waveforms and colors (in the shuffled plotting order).
        rand_spk_idx = spk_idx[spk_order]
        wfs = s_waveforms[rand_spk_idx, :]
        cols = s_spk_cols[rand_spk_idx]

        # Plot waveforms.
        xlab, ylab = (wf_t_lab, volt_lab) if add_lbls else (None, None)
        pwaveform.plot_wfs(wfs,
                           spk_t,
                           cols=cols,
                           lw=0.1,
                           alpha=0.05,
                           xlim=wf_t_lim,
                           ylim=glim,
                           title=title,
                           xlab=xlab,
                           ylab=ylab,
                           ax=ax)

    # %% Waveform summary metrics.

    # Init data.
    wf_amp_all = u.SpikeParams['amplitude']
    wf_amp_inc = wf_amp_all[spk_inc]
    wf_dur_all = u.SpikeParams['duration']
    wf_dur_inc = wf_dur_all[spk_inc]

    # Set common limits and labels.
    dur_lim = [0, wavetime[-2] - wavetime[WF_T_START]]  # same across units
    glim = max(wf_amp_all.max(), gmax - gmin)
    amp_lim = [0, glim]

    amp_lab = 'Amplitude'
    dur_lab = r'Duration ($\mu$s)'

    # Waveform amplitude across session time.
    m_amp, sd_amp = wf_amp_inc.mean(), wf_amp_inc.std()
    title = r'WF amplitude: {:.1f} $\pm$ {:.1f}'.format(m_amp, sd_amp)
    xlab, ylab = (ses_t_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times,
                  wf_amp_all,
                  spk_inc,
                  c='m',
                  bc='grey',
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=ses_t_lim,
                  ylim=amp_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=title,
                  ax=ax_wf_amp)

    # Waveform duration across session time.
    mdur, sdur = wf_dur_inc.mean(), wf_dur_inc.std()
    title = r'WF duration: {:.1f} $\pm$ {:.1f} $\mu$s'.format(mdur, sdur)
    xlab, ylab = (ses_t_lab, dur_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times,
                  wf_dur_all,
                  spk_inc,
                  c='c',
                  bc='grey',
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=ses_t_lim,
                  ylim=dur_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=title,
                  ax=ax_wf_dur)

    # Waveform duration against amplitude.
    title = 'WF duration - amplitude'
    xlab, ylab = (dur_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(wf_dur_all[spk_order],
                  wf_amp_all[spk_order],
                  c=spk_cols[spk_order],
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=dur_lim,
                  ylim=amp_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=title,
                  ax=ax_amp_dur)

    # %% Firing rate.

    tmean = np.array(bs_stats['tmean'])
    rmean = util.remove_dim_from_series(bs_stats['rate'])
    prd_tstart, prd_tstop = stab_prd_res['tstart'], stab_prd_res['tstop']

    # Color segments depending on whether they are included / excluded.
    def plot_periods(v, color, ax):
        # Plot line segments; a segment is colored only if the periods at
        # both of its ends are included.
        for i in range(len(prd_inc[:-1])):
            col = color if prd_inc[i] and prd_inc[i + 1] else 'grey'
            x, y = [(tmean[i], tmean[i + 1]), (v[i], v[i + 1])]
            ax.plot(x, y, color=col)
        # Plot line points.
        for i in range(len(prd_inc)):
            col = color if prd_inc[i] else 'grey'
            x, y = [tmean[i], v[i]]
            ax.plot(x,
                    y,
                    color=col,
                    marker='o',
                    markersize=3,
                    markeredgecolor=col)

    # Firing rate over session time.
    title = 'Baseline rate: {:.1f} spike/s'.format(float(base_rate))
    xlab, ylab = (ses_t_lab, putil.FR_lbl) if add_lbls else (None, None)
    ylim = [0, 1.25 * np.max(rmean)]
    plot_periods(rmean, 'b', ax_rate)
    # Empty line plot, used only to apply limits, title and labels.
    pplot.lines([], [],
                c='b',
                xlim=ses_t_lim,
                ylim=ylim,
                title=title,
                xlab=xlab,
                ylab=ylab,
                ax=ax_rate)

    # Trial markers.
    putil.plot_events(tr_markers,
                      lw=0.5,
                      ls='--',
                      alpha=0.35,
                      y_lbl=0.92,
                      ax=ax_rate)

    # Excluded periods: session time before and after the stable period.
    excl_prds = []
    tstart, tstop = ses_t_lim
    if tstart != prd_tstart:
        excl_prds.append(('beg', tstart, prd_tstart))
    if tstop != prd_tstop:
        excl_prds.append(('end', prd_tstop, tstop))
    putil.plot_periods(excl_prds, ymax=0.92, ax=ax_rate)

    # %% Post-formatting.

    # Limit number of ticks on recording time axes to prevent covering.
    for ax in (ax_wf_amp, ax_wf_dur, ax_rate):
        putil.set_max_n_ticks(ax, 6, 'x')

    # %% Save figure.
    if ftempl is not None:
        fname = ftempl.format(u.name_to_fname())
        # NOTE(review): title holds the 'Baseline rate: ...' string at this
        # point, not the unit info title -- confirm this is intended.
        putil.save_fig(fname, fig, title, rect_height=0.92)
        putil.inline_on()

    return [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
예제 #9
0
def rate(rate_list,
         names=None,
         prds=None,
         evts=None,
         cols=None,
         baseline=None,
         pval=0.05,
         test='mann_whitney_u',
         test_kws=None,
         xlim=None,
         ylim=None,
         title=None,
         xlab=None,
         ylab=putil.FR_lbl,
         add_lgn=True,
         lgn_lbl='trs',
         ffig=None,
         ax=None):
    """
    Plot mean +- SEM firing rate of one or more sets of rate vectors.

    rate_list: sized sequence of rate tables; rows are averaged and the
        columns are used as the time axis.
    names: name of each rate table, used as line label.
    prds, evts: periods and events to mark on the plot.
    cols: color of each rate table; default colors are used when None.
    baseline: if not None, passed to `putil.add_baseline`.
    pval, test, test_kws: significance threshold, test name and extra test
        arguments; a significance line is added only for exactly two rate
        tables and a non-None pval.
    xlim, ylim: axis limits; derived from the data when None.
    title, xlab, ylab: plot labels; xlab is inserted into `putil.t_lbl`.
    add_lgn: whether to add a legend.
    lgn_lbl: suffix appended to each label after the number of rows, or
        None to omit the count.
    ffig: file name forwarded to `putil.save_fig`.
    ax: axes to draw on, passed through `putil.axes`.

    Returns the axes. With an empty rate_list, only periods, baseline and
    x limits are applied and the function returns early.
    """

    ax = putil.axes(ax)
    test_kws = dict() if test_kws is None else test_kws

    # Background elements go on first.
    putil.plot_periods(prds, ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)
    putil.set_limits(ax, xlim)

    # Nothing else to do without rates.
    if not len(rate_list):
        return ax

    # Defaults for colors and names.
    cols = putil.get_colors(as_cycle=False) if cols is None else cols
    names = len(rate_list) * [''] if names is None else names

    # Draw each set of rates, tracking the extent of the plotted data.
    xmin = xmax = ymax = None
    for i, rts in enumerate(rate_list):

        name, col = names[i], cols[i]

        # No trials in this set: nothing to draw.
        n_trs = rts.shape[0]
        if not n_trs:
            continue

        # Line label. Iterable names go through a Numpy array so that
        # floats are formatted nicely.
        lbl = str(np.array(name)) if util.is_iterable(name) else str(name)
        if lgn_lbl is not None:
            lbl += ' ({} {})'.format(n_trs, lgn_lbl)

        # Mean line with a SEM band around it.
        tvec, meanr, semr = rts.columns, rts.mean(), rts.sem()
        lower, upper = meanr - semr, meanr + semr
        ax.plot(tvec, meanr, label=lbl, color=col)
        ax.fill_between(tvec,
                        lower,
                        upper,
                        alpha=0.2,
                        facecolor=col,
                        edgecolor=col)

        # Extend the data limits seen so far.
        tmin, tmax, rmax = tvec.min(), tvec.max(), upper.max()
        xmin = tmin if xmin is None else np.min([xmin, tmin])
        xmax = tmax if xmax is None else np.max([xmax, tmax])
        ymax = rmax if ymax is None else np.max([ymax, rmax])

    # Derive missing axis limits from the data.
    if xlim is None:
        # Identical limits are avoided by leaving the upper one open.
        xlim = (xmin, None if xmin == xmax else xmax)
    if ylim is None:
        has_pos_max = (ymax is not None) and (ymax > 0)
        ylim = (0, 1.02 * ymax if has_pos_max else None)
    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)

    # Ticks; query the actual x limits, as they may have been left open.
    t_lo, t_hi = ax.get_xlim()
    tmarks, tlbls = putil.get_tick_marks_and_labels(t_lo, t_hi)
    putil.set_xtick_labels(ax, tmarks, tlbls)
    putil.set_max_n_ticks(ax, 7, 'y')

    # Legend (rate_list is known to be non-empty here).
    if add_lgn:
        putil.set_legend(ax,
                         loc=1,
                         borderaxespad=0.0,
                         handletextpad=0.4,
                         handlelength=0.6)

    # Significance line on top of the axes, for exactly two sets of rates.
    if (pval is not None) and (len(rate_list) == 2):
        first, second = rate_list
        sign_prds = stats.sign_periods(first, second, pval, test, **test_kws)
        putil.plot_signif_prds(sign_prds, color='m', linewidth=4.0, ax=ax)

    # Event markers.
    putil.plot_event_markers(evts, ax=ax)

    putil.save_fig(ffig)
    return ax
예제 #10
0
def plot_combined_rec_mean(recs, stims, res_dir, par_kws,
                           list_n_most_DS, list_min_nunits,
                           n_boot=1e4, ci=95,
                           tasks=None, task_labels=None, add_title=True,
                           fig=None):
    """
    Test and plot results combined across sessions.

    Creates a grid of axes with one row per entry of the loaded results
    (keyed by n_most_DS) and one column per entry of list_min_nunits. Each
    axes shows, per task, the accuracy time course across the recordings
    having at least min_nunits units (drawn with sns.tsplot), plus a chance
    level line and the stimulus periods.

    Parameters:
        recs: Not used by this function.
        stims: Stimulus names, used as row labels into
            constants.fixed_tr_prds to build the periods to mark.
        res_dir: Results directory, passed on to the decutil helpers.
        par_kws: Analysis parameters, forwarded as keyword arguments to the
            decutil helpers. The passed mapping is not modified.
        list_n_most_DS: Values selecting the results to plot, one row of
            axes each. A value of 0 is labelled as 'all units'.
        list_min_nunits: Minimum unit counts a recording needs to be
            included, one column of axes each.
        n_boot: Number of bootstrap samples, passed to sns.tsplot.
        ci: Confidence interval size, passed to sns.tsplot.
        tasks: Tasks to plot. Defaults to the tasks present in the data
            selected for each axes. Requested tasks missing from the data
            of an axes are left out of that axes.
        task_labels: Mapping from task name to the label shown in the
            legend. Defaults to the task names themselves.
        add_title: Whether to add a title to the figure.
        fig: Figure object passed on to putil.get_gs_subplots.

    Returns:
        Tuple of the figure and the axes array returned by
        putil.get_gs_subplots, and the file name returned by
        decutil.fig_fname.
    """

    # Init.
    # putil.set_style('notebook', 'ticks')
    vkey = 'all'

    # This should be made more explicit!
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    # Load all results to plot.
    dict_rt_res = decutil.load_res(res_dir, list_n_most_DS, **par_kws)

    # Create figures.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(dict_rt_res),
                                                ncol=len(list_min_nunits),
                                                subw=8, subh=6, fig=fig,
                                                create_axes=True)

    # Query data.
    allScores = {}
    allnunits = {}
    allnvals = {}
    for n_most_DS, rt_res in dict_rt_res.items():
        # Keep only the (rec, task) entries with a valid result under vkey.
        valid_res = {rt: res[vkey] for rt, res in rt_res.items()
                     if (vkey in res) and (res[vkey] is not None)}
        # Get accuracy scores.
        dScores = {rt: vres['Scores'].mean()
                   for rt, vres in valid_res.items()}
        allScores[n_most_DS] = pd.concat(dScores, axis=1).T
        # Get number of units.
        allnunits[n_most_DS] = {rt: vres['nunits'].iloc[0]
                                for rt, vres in valid_res.items()}
        # Get # values (for baseline plotting).
        all_nvals = pd.Series({rt: vres['nclasses'].iloc[0]
                               for rt, vres in valid_res.items()})
        un_nvals = all_nvals.unique()
        if len(un_nvals) > 1 and verbose:
            print('Found multiple # of classes to decode: {}'.format(un_nvals))
        # Stored per n_most_DS, so that each row of axes gets the chance
        # level of its own results rather than that of the last loaded set.
        allnvals[n_most_DS] = un_nvals[0]

    allnunits = pd.DataFrame(allnunits)

    # Plot mean performance across recordings and
    # test significance by bootstrapping.
    for inmost, n_most_DS in enumerate(list_n_most_DS):
        Scores = allScores[n_most_DS]
        nunits = allnunits[n_most_DS]
        chance_lvl = 1.0 / allnvals[n_most_DS]

        for iminu, min_nunits in enumerate(list_min_nunits):

            ax_scr = axs_scr[inmost, iminu]

            # Select only recordings with minimum number of units.
            sel_rt = nunits.index[nunits >= min_nunits]
            nScores = Scores.loc[sel_rt].copy()

            # Nothing to plot.
            if nScores.empty:
                ax_scr.axis('off')
                continue

            # Resolve tasks and labels for this axes only. Assigning the
            # defaults back to the arguments would carry the task list of
            # the first axes over to all later ones, where some of those
            # tasks may no longer be present in the selected data.
            data_tasks = nScores.index.get_level_values(1).unique()
            if tasks is None:
                ax_tasks = list(data_tasks)
            else:
                ax_tasks = [task for task in tasks if task in data_tasks]
            if not ax_tasks:
                ax_scr.axis('off')
                continue
            if task_labels is None:
                ax_labels = {task: task for task in ax_tasks}
            else:
                ax_labels = task_labels

            # Prepare data: long format table with one accuracy value per
            # task, time point and recording.
            dScores = {task: pd.DataFrame(nScores.xs(task, level=1).unstack(),
                                          columns=['accuracy'])
                       for task in ax_tasks}
            lScores = pd.concat(dScores, axis=0)
            lScores['time'] = lScores.index.get_level_values(1)
            lScores['task'] = lScores.index.get_level_values(0)
            lScores['rec'] = lScores.index.get_level_values(2)
            lScores.index = np.arange(len(lScores.index))
            # Assign the result back instead of replacing in place on the
            # column accessor, which may operate on a temporary copy.
            lScores['task'] = lScores['task'].replace(ax_labels)

            # Add altered task names for legend plotting.
            nrecs = {ax_labels[task]: len(nScores.xs(task, level=1))
                     for task in ax_tasks}
            lScores['task_nrecs'] = lScores['task'].apply(
                lambda lbl: '{} (n={})'.format(lbl, nrecs[lbl]))

            # Plot as time series.
            sns.tsplot(lScores, time='time', value='accuracy', unit='rec',
                       condition='task_nrecs', ci=ci, n_boot=n_boot, ax=ax_scr)

            # Add chance level line.
            putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Add stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Set axis limits.
            putil.set_limits(ax_scr, tlim)

            # Format plot.
            title = ('{} most DS units'.format(n_most_DS)
                     if n_most_DS != 0 else 'all units')
            title += (', recordings with at least {} units'.format(min_nunits)
                      if (min_nunits > 1 and len(list_min_nunits) > 1) else '')
            ytitle = 1.0
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)
            putil.hide_legend_title(ax_scr)

    # Match axes across decoding plots (within each row).
    for inmost in range(axs_scr.shape[0]):
        putil.sync_axes(axs_scr[inmost, :], sync_y=True)

    # Save plots.
    # Work on a copy of the parameters, so that the caller's mapping keeps
    # its original 'n_most_DS' entry.
    fig_kws = dict(par_kws)
    list_n_most_DS_str = [str(i) if i != 0 else 'all' for i in list_n_most_DS]
    fig_kws['n_most_DS'] = ', '.join(list_n_most_DS_str)
    title = ''
    if add_title:
        title = decutil.fig_title(res_dir, **fig_kws)
        title += '\n{}% CE with {} bootstrapped subsamples'.format(ci,
                                                                   int(n_boot))
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    fig_kws['n_most_DS'] = '_'.join(list_n_most_DS_str)
    ffig = decutil.fig_fname(res_dir, 'combined_score', fformat, **fig_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)

    return fig_scr, axs_scr, ffig
예제 #11
0
def plot_scores_across_nunits(recs, stims, res_dir, list_n_most_DS, par_kws):
    """
    Plot prediction score results across different number of units included.

    Creates one axes per recording (rows) and task (columns). On each axes,
    the score time courses of every loaded result set are drawn with
    sns.tsplot, grouped by the number of units of the result, together with
    chance level lines and the stimulus periods. The figure is saved to the
    file name returned by decutil.fig_fname.

    Parameters:
        recs: Recordings to plot. Each recording is joined with spaces for
            display, so it is presumably a sequence of strings -- TODO
            confirm against caller.
        stims: Stimulus names, used as row labels into
            constants.fixed_tr_prds to build the periods to mark.
        res_dir: Results directory, passed on to the decutil helpers.
        list_n_most_DS: Values selecting the result sets to load. A value
            of 0 is written as 'all' in the figure title and file name.
        par_kws: Analysis parameters, forwarded as keyword arguments to the
            decutil helpers. Its 'tasks' and 'feat' entries are read here.
            The passed mapping is not modified.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load all results to plot.
    dict_rt_res = decutil.load_res(res_dir, list_n_most_DS, **par_kws)

    # Create figures.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # Stimulus periods are the same on every axes.
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    # Do plotting per recording and task.
    for irec, rec in enumerate(recs):
        if verbose:
            # Format recording the same way as in the axes title.
            print('\n' + ' '.join(rec))
        for itask, task in enumerate(tasks):
            if verbose:
                print('    ' + task)

            ax_scr = axs_scr[irec, itask]

            # Init data.
            dict_lScores = {}
            lncls = []
            for rt_res in dict_rt_res.values():

                # Check if results exist for rec-task combination.
                if (rec, task) not in rt_res:
                    continue

                # Go through every value of the results (none if empty).
                for v, vres in rt_res[(rec, task)].items():
                    # Skip values without results.
                    if vres is None:
                        continue

                    Scores = vres['Scores']
                    lncls.append(vres['nclasses'])

                    # Unstack dataframe with results.
                    lScores = pd.DataFrame(Scores.unstack(), columns=['score'])
                    lScores['time'] = lScores.index.get_level_values(0)
                    lScores['fold'] = lScores.index.get_level_values(1)
                    lScores.index = np.arange(len(lScores.index))

                    # Get number of units tested.
                    uni_nunits = vres['nunits'].unique()
                    if len(uni_nunits) > 1 and verbose:
                        print('Different number of units found.')
                    nunits = uni_nunits[0]

                    # Collect results.
                    dict_lScores[(nunits, v)] = lScores

            # Skip rest if no data is available.
            # Check if any result exists for rec-task combination.
            if not len(dict_lScores):
                ax_scr.axis('off')
                continue

            # Concatenate accuracy scores of every result set. The first
            # index level holds the number of units of each set.
            all_lScores = pd.concat(dict_lScores)
            all_lScores['n_most_DS'] = all_lScores.index.get_level_values(0)
            all_lScores.index = np.arange(len(all_lScores.index))

            # Plot decoding results.
            nnunits = len(all_lScores['n_most_DS'].unique())
            title = '{} {}, {} sets of units'.format(' '.join(rec), task,
                                                     nnunits)
            ytitle = 1.0

            # Plot time series.
            palette = sns.color_palette('muted')
            sns.tsplot(all_lScores, time='time', value='score', unit='fold',
                       condition='n_most_DS', color=palette, ax=ax_scr)

            # Add chance level line.
            # This currently plots a chance level line for every nvals,
            # combined across stimulus period!
            uni_ncls = np.unique(np.array(lncls).flatten())
            if len(uni_ncls) > 1 and verbose:
                print('Different number of classes found.')
            for nvals in uni_ncls:
                chance_lvl = 1.0 / nvals
                putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Add stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Set axis limits.
            putil.set_limits(ax_scr, tlim, ylim_scr)

            # Format plot.
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)

    # Match axes across decoding plots.
    # [putil.sync_axes(axs_scr[:, itask], sync_y=True)
    #  for itask in range(axs_scr.shape[1])]

    # Save plots.
    # Work on a copy of the parameters, so that the caller's mapping keeps
    # its original 'n_most_DS' entry.
    fig_kws = dict(par_kws)
    list_n_most_DS_str = [str(i) if i != 0 else 'all' for i in list_n_most_DS]
    fig_kws['n_most_DS'] = ', '.join(list_n_most_DS_str)
    title = decutil.fig_title(res_dir, **fig_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    fig_kws['n_most_DS'] = '_'.join(list_n_most_DS_str)
    ffig = decutil.fig_fname(res_dir, 'score_nunits', fformat, **fig_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)
예제 #12
0
def plot_score_multi_rec(recs, stims, res_dir, par_kws):
    """
    Plot prediction scores for multiple recordings.

    Creates a single row of axes with one axes per task. On each axes, the
    score time courses of all recordings having results for the task are
    drawn with sns.tsplot, one condition per recording, together with
    chance level lines and the stimulus periods. The figure is saved to the
    file name returned by decutil.fig_fname.

    Parameters:
        recs: Recordings to plot. Each recording is joined with spaces for
            the legend, so it is presumably a sequence of strings -- TODO
            confirm against caller.
        stims: Stimulus names, used as row labels into
            constants.fixed_tr_prds to build the periods to mark.
        res_dir: Results directory, passed on to the decutil helpers.
        par_kws: Analysis parameters, forwarded as keyword arguments to the
            decutil helpers. Its 'n_most_DS', 'tasks' and 'feat' entries
            are read here.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    n_most_DS = par_kws['n_most_DS']
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load results.
    rt_res = decutil.load_res(res_dir, **par_kws)[n_most_DS]

    # Create figure.
    ret = putil.get_gs_subplots(nrow=1, ncol=len(tasks),
                                subw=8, subh=6, create_axes=True)
    fig_scr, _, axs_scr = ret

    # Stimulus periods are the same on every axes.
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    print('\nPlotting multi-recording results...')
    for itask, task in enumerate(tasks):
        if verbose:
            print('    ' + task)
        ax_scr = axs_scr[0, itask]

        # Init data. The class counts are collected across all recordings
        # of the task, so that the chance levels drawn below do not depend
        # on which recording happened to be processed last.
        dict_lScores = {}
        lncls = []
        for rec in recs:

            # Check if results exist for rec-task combination.
            if (rec, task) not in rt_res:
                continue

            # Go through every value of the results (none if empty).
            for v, vres in rt_res[(rec, task)].items():
                # Skip values without results.
                if vres is None:
                    continue

                Scores = vres['Scores']
                lncls.append(vres['nclasses'])

                # Unstack dataframe with results.
                lScores = pd.DataFrame(Scores.unstack(), columns=['score'])
                lScores['time'] = lScores.index.get_level_values(0)
                lScores['fold'] = lScores.index.get_level_values(1)
                lScores.index = np.arange(len(lScores.index))

                dict_lScores[(rec, v)] = lScores

        # Nothing to plot for this task.
        if not len(dict_lScores):
            ax_scr.axis('off')
            continue

        # Concatenate accuracy scores from every recording.
        all_lScores = pd.concat(dict_lScores)
        all_lScores['rec'] = all_lScores.index.get_level_values(0)
        all_lScores['rec'] = all_lScores['rec'].str.join(' ')  # format label
        all_lScores.index = np.arange(len(all_lScores.index))

        # Plot decoding results.
        nrec = len(all_lScores['rec'].unique())
        title = '{}, {} recordings'.format(task, nrec)
        ytitle = 1.0

        # Plot time series.
        palette = sns.color_palette('muted')
        sns.tsplot(all_lScores, time='time', value='score', unit='fold',
                   condition='rec', color=palette, ax=ax_scr)

        # Add chance level line.
        # This currently plots a chance level line for every nvals,
        # combined across stimulus period!
        uni_ncls = np.unique(np.array(lncls).flatten())
        if len(uni_ncls) > 1 and verbose:
            print('Different number of classes found.')
        for nvals in uni_ncls:
            chance_lvl = 1.0 / nvals
            putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

        # Add stimulus periods.
        putil.plot_periods(prds, ax=ax_scr)

        # Set axis limits.
        putil.set_limits(ax_scr, tlim, ylim_scr)

        # Format plot.
        putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)

    # Save figure.
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3
    ffig = decutil.fig_fname(res_dir, 'all_scores', fformat, **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)
예제 #13
0
def plot_scores_weights(recs, stims, res_dir, par_kws):
    """
    Plot prediction scores and model weights for given recording and analysis.

    Creates two figures, each with one axes per recording (rows) and task
    (columns): one for the prediction scores, drawn by plot_scores for
    every value with non-null results, and one for the unit weights, drawn
    by plot_weights. Both figures are saved as 'pdf' to the file names
    returned by decutil.fig_fname.

    Parameters:
        recs: Recordings to plot. Each recording is joined with spaces for
            display, so it is presumably a sequence of strings -- TODO
            confirm against caller.
        stims: Stimulus names, used as row labels into
            constants.fixed_tr_prds to build the periods to mark.
        res_dir: Results directory, passed on to the decutil helpers.
        par_kws: Analysis parameters, forwarded as keyword arguments to the
            decutil helpers. Its 'n_most_DS', 'tasks' and 'feat' entries
            are read here.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    n_most_DS = par_kws['n_most_DS']
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load results.
    rt_res = decutil.load_res(res_dir, **par_kws)[n_most_DS]

    # Create figures.
    # For prediction scores.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # For unit weights (coefficients).
    fig_wgt, _, axs_wgt = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # Stimulus periods are the same on every axes.
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    for irec, rec in enumerate(recs):
        if verbose:
            # Format recording the same way as in the axes title.
            print('\n' + ' '.join(rec))
        for itask, task in enumerate(tasks):
            if verbose:
                print('    ' + task)

            # Init figures.
            ax_scr = axs_scr[irec, itask]
            ax_wgt = axs_wgt[irec, itask]

            # Collect values with non-null results for rec-task combination.
            res = rt_res[(rec, task)] if (rec, task) in rt_res else None
            vals = ([] if res is None else
                    [v for v in res.keys() if not util.is_null(res[v])])

            # Nothing to plot: no entry, no values, or only null results.
            # Without this, the unit count lookup below would fail on an
            # empty array and no coefficients would be available to plot.
            if not vals:
                ax_scr.axis('off')
                ax_wgt.axis('off')
                continue

            # Init data.
            cols = sns.color_palette('hls', len(vals))
            lnunits, lncls = [], []
            for v, col in zip(vals, cols):
                # Basic results.
                vres = res[v]
                Scores = vres['Scores']
                Coefs = vres['Coefs']
                Perm = vres['Perm']
                Psdo = vres['Psdo']
                # Decoding params.
                lnunits.append(vres['nunits'])
                lncls.append(vres['nclasses'])
                # Plot decoding accuracy.
                plot_scores(ax_scr, Scores, Perm, Psdo, col=col)

            # Add labels.
            uni_lnunits = np.unique(np.array(lnunits).flatten())
            if len(uni_lnunits) > 1 and verbose:
                print('Different number of units found.')
            nunits = uni_lnunits[0]
            title = '{} {}, {} units'.format(' '.join(rec), task, nunits)
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle=1.04)

            # Add chance level line.
            uni_ncls = np.unique(np.array(lncls).flatten())
            if len(uni_ncls) > 1 and verbose:
                print('Different number of classes found.')
            for nvals in uni_ncls:
                chance_lvl = 1.0 / nvals
                putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Plot stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Plot unit weights over time. Only the coefficients of the
            # last value plotted above are shown.
            plot_weights(ax_wgt, Coefs, prds, tlim, tlab, title=title)

    # Match axes across decoding plots.
    # [putil.sync_axes(axs_scr[:, itask], sync_y=True)
    #  for itask in range(axs_scr.shape[1])]

    # Save plots.
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    # Performance.
    ffig = decutil.fig_fname(res_dir, 'score', 'pdf', **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)

    # Weights.
    ffig = decutil.fig_fname(res_dir, 'weight', 'pdf', **par_kws)
    putil.save_fig(ffig, fig_wgt, title, fs_title, w_pad=w_pad, h_pad=h_pad)