Exemple #1
0
def raster_rate(spk_list,
                rate_list,
                names=None,
                prds=None,
                evts=None,
                cols=None,
                baseline=None,
                title=None,
                rs_ylab=True,
                rate_kws=None,
                fig=None,
                ffig=None,
                sps=None):
    """Plot raster plots (one per spike train set) above a common rate plot.

    Parameters
    ----------
    spk_list : list
        Spike trains of each set; each element is passed to `raster`. If
        empty, a single empty raster axes is still created.
    rate_list : list
        Rate data of each set, passed to `rate`.
    names : list, optional
        Set names; used as raster y labels (if `rs_ylab`) and passed to
        `rate`.
    prds, evts, baseline : optional
        Periods, events and baseline forwarded to `rate` (`prds` is also
        forwarded to `raster`).
    cols : list, optional
        One color per set. Drawn from `putil.get_colors` if not given.
    title : str, optional
        Title put on the top raster axes.
    rs_ylab : bool
        Whether to label raster axes with `names`.
    rate_kws : dict, optional
        Extra keyword arguments for `rate` (must not contain 'ax').
    fig, sps : optional
        Figure and subplot spec to draw into.
    ffig : str, optional
        File name, passed to `putil.save_fig`.

    Returns
    -------
    fig, raster_axs, rate_ax
    """

    if rate_kws is None:
        rate_kws = dict()

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 2, 1, height_ratios=[.66, 1], hspace=.15)
    n_sets = max(len(spk_list), 1)  # let's add an empty axes if no data
    gsp_raster = putil.embed_gsp(gsp[0], n_sets, 1, hspace=.15)
    gsp_rate = putil.embed_gsp(gsp[1], 1, 1)

    # Init colors.
    if cols is None:
        col_cyc = putil.get_colors(mpl_colors=True)
        cols = [next(col_cyc) for _ in range(n_sets)]

    # Raster plots.
    raster_axs = [fig.add_subplot(gsp_raster[i, 0]) for i in range(n_sets)]
    for i, (spk_trs, ax) in enumerate(zip(spk_list, raster_axs)):
        ylab = names[i] if (rs_ylab and names is not None) else None
        raster(spk_trs, prds=prds, c=cols[i], ylab=ylab, ax=ax)
        putil.hide_axes(ax)
    if raster_axs:
        putil.set_labels(raster_axs[0], title=title)  # add title to top raster

    # Rate plot.
    rate_ax = fig.add_subplot(gsp_rate[0, 0])
    rate(rate_list, names, prds, evts, cols, baseline, **rate_kws, ax=rate_ax)

    # Synchronize raster's x axis limits to rate plot's limits.
    xlim = rate_ax.get_xlim()
    for ax in raster_axs:
        ax.set_xlim(xlim)

    # Save and return plot.
    putil.save_fig(ffig, fig)
    return fig, raster_axs, rate_ax
Exemple #2
0
def plot_DR_3x3(u, fig=None, sps=None):
    """Plot 3x3 direction response plot, with polar plot in center.

    The center cell holds a polar direction-response plot with one curve per
    stimulus. The eight surrounding cells hold raster-rate plots (drawn by
    `plot_SR`), one per direction in `constants.all_dirs`.

    Parameters
    ----------
    u : unit object
        Unit to plot.
    fig, sps : optional
        Figure and subplot spec to draw into.

    Returns
    -------
    None if `u.to_plot()` is False, otherwise (ax_polar, rate_axs), where
    rate_axs lists every rate axes of the surrounding plots exactly once.
    """

    if not u.to_plot():
        return

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 3, 3)  # inner gsp with subplots

    # Polar plot (center cell: index 4 of the 3x3 grid).
    putil.set_style('notebook', 'white')
    ax_polar = fig.add_subplot(gsp[4], polar=True)
    baseline = u.get_baseline()  # takes no stimulus argument: fetch once
    for stim in constants.stim_dur.index:  # for each stimulus
        stim_resp = u.get_stim_resp_vals(stim, 'Dir')
        resp_stats = util.calc_stim_resp_stats(stim_resp)
        dirs, resp = np.array(resp_stats.index) * deg, resp_stats['mean']
        c = putil.stim_colors[stim]
        ptuning.plot_DR(dirs, resp, color=c, baseline=baseline, ax=ax_polar)
    putil.hide_ticks(ax_polar, 'y')

    # Raster-rate plots.
    putil.set_style('notebook', 'ticks')
    rr_pos = [5, 2, 1, 0, 3, 6, 7, 8]  # Position of each direction.
    rr_dir_plot_pos = pd.Series(constants.all_dirs, index=rr_pos)

    rate_axs = []
    # Series.items (Series.iteritems was removed in pandas 2.0).
    for isp, d in rr_dir_plot_pos.items():

        # Prepare plot formatting.
        first_dir = (isp == 0)

        # Plot direction response across trial periods.
        res = plot_SR(u, 'Dir', [d], fig=fig, sps=gsp[isp], no_labels=True)
        draster_axs, drate_axs, _ = res

        # Remove axis tick labels: x labels only on the first direction's
        # plot, y labels only on its first period.
        for i, ax in enumerate(drate_axs):
            first_prd = (i == 0)
            show_x_tick_lbls = first_dir
            show_y_tick_lbls = first_dir and first_prd
            putil.hide_tick_labels(ax, show_x_tick_lbls, show_y_tick_lbls)

        # Add task name as title (to top center axes). plot_SR may return
        # empty axes lists, hence the guard.
        if isp == 1 and draster_axs:
            ttl = u.get_task() + (' [excluded]' if u.is_excluded() else '')
            putil.set_labels(draster_axs[0],
                             title=ttl,
                             ytitle=1.10,
                             title_kws={'loc': 'right'})

        # Collect each direction's rate axes once.
        rate_axs.extend(drate_axs)

    # Match scale of y axes.
    putil.sync_axes(rate_axs, sync_y=True)
    for ax in rate_axs:
        putil.adjust_decorators(ax)

    return ax_polar, rate_axs
Exemple #3
0
def plot_DSI(u,
             nrate=None,
             fig=None,
             sps=None,
             prd_pars=None,
             no_labels=False):
    """Plot direction selectivity indices.

    Two stacked rate-style plots: maximum - opposite DSI (top) and weighted
    DSI (bottom), each drawn with `prate.rate` over the stimulus periods.

    Parameters
    ----------
    u : unit object
        Unit to analyse.
    nrate : optional
        Passed through to `direction.calc_DSI`.
    fig, sps : optional
        Figure and subplot spec to draw into.
    prd_pars : optional
        Analysis periods passed to `direction.calc_DSI`. Defaults to
        `u.get_analysis_prds()`.
    no_labels : bool
        Currently unused; kept for interface compatibility.

    Returns
    -------
    ax_mDSI, ax_wDSI
    """

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 2, 1, hspace=0.6)
    ax_mDSI, ax_wDSI = [fig.add_subplot(igsp) for igsp in gsp]

    # Use the unit's default analysis periods unless the caller gave some
    # (previously a caller-supplied value was silently overwritten).
    if prd_pars is None:
        prd_pars = u.get_analysis_prds()

    # Calculate DSI using maximum - opposite and weighted measure.
    maxDSI, wghtDSI = [
        direction.calc_DSI(u, fDSI, prd_pars, nrate)
        for fDSI in (direction.max_DS, direction.weighted_DS)
    ]

    # Get stimulus periods.
    stim_prds = u.get_stim_prds()

    # Plot DSIs: (DSI table, y label, title, axes) for each measure.
    plot_specs = [(maxDSI, 'mDSI', 'max - opposite DSI', ax_mDSI),
                  (wghtDSI, 'wDSI', 'weighted DSI', ax_wDSI)]
    for DSI_df, ylab, ttl, ax in plot_specs:
        # One single-row DataFrame per row of the DSI table.
        DSI_list = [pd.DataFrame(row).T for _, row in DSI_df.iterrows()]
        prate.rate(DSI_list,
                   DSI_df.index,
                   prds=stim_prds,
                   pval=None,
                   ylab=ylab,
                   title=ttl,
                   add_lgn=True,
                   lgn_lbl=None,
                   ax=ax)

    return ax_mDSI, ax_wDSI
Exemple #4
0
def plot_selectivity(u, fig=None, sps=None):
    """Plot selectivity summary plot.

    Stacks vertically: the all-trials plot, then a location-selectivity plot
    (only if stimulus location varied over included trials), then a
    direction-selectivity plot (only if stimulus direction varied).

    Parameters
    ----------
    u : unit object
        Unit to plot.
    fig, sps : optional
        Figure and subplot spec to draw into.

    Returns
    -------
    None if `u.to_plot()` is False, otherwise (lr_rate_axs, dr_rate_axs),
    each an empty list when the corresponding plot was skipped.
    """

    if not u.to_plot():
        return

    # Check which stimulus parameters were variable. Don't plot selectivity for
    # variables that were kept constant (a single value for both S1 and S2).
    stims = ('S1', 'S2')
    trs = u.inc_trials()

    def n_unique_vals(param):
        """Return number of distinct `param` values of each stimulus."""
        return [len(u.TrData[(stim, param)][trs].unique()) for stim in stims]

    plot_lr = not all(n == 1 for n in n_unique_vals('Loc'))  # location
    plot_dr = not all(n == 1 for n in n_unique_vals('Dir'))  # direction

    # Init subplots.
    nsps = 1 + plot_lr + plot_dr  # all trials + location sel + direction sel
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, nsps, 1, hspace=0.3)

    # Plot task-relatedness.
    plot_all_trials(u, gsp[0], fig)
    igsp = 1  # next free row of gsp

    # Plot location-specific activity.
    lr_rate_axs = []
    if plot_lr:
        _, lr_rate_axs, _ = plot_LR(u, gsp[igsp], fig)
        igsp += 1

    # Plot direction-specific activity.
    dr_rate_axs = []
    if plot_dr:
        _, dr_rate_axs, _ = plot_DR(u, gsp[igsp], fig)

    return lr_rate_axs, dr_rate_axs
Exemple #5
0
def plot_qm(u,
            bs_stats,
            stab_prd_res,
            prd_inc,
            tr_inc,
            spk_inc,
            add_lbls=False,
            ftempl=None,
            fig=None,
            sps=None):
    """Plot quality metrics related figures.

    Layout: an info header (unit title) on top of a 3 x 2 grid with
    included / excluded waveforms, waveform amplitude and duration over
    recording time, waveform duration vs amplitude, and firing rate over
    recording time.

    Parameters
    ----------
    u : unit object
        Unit to plot; its Waveforms, SpikeParams, QualityMetrics, SessParams
        and TrData attributes are read.
    bs_stats : table
        Its 'tmean' and 'rate' entries are plotted as the firing rate time
        course.
    stab_prd_res : mapping
        Its 'tstart' and 'tstop' entries delimit the included recording
        period; time outside it is shaded as excluded.
    prd_inc : sequence of bool
        Inclusion flag of each firing rate point (aligned with 'tmean').
    tr_inc : boolean array
        Inclusion flag of each trial.
    spk_inc : boolean array
        Inclusion flag of each spike.
    add_lbls : bool
        Add axis labels to the plots.
    ftempl : str, optional
        File name template, formatted with `u.name_to_fname()`. If given,
        the figure is saved and inline plotting is switched back on.
    fig, sps : optional
        Figure and subplot spec to draw into.

    Returns
    -------
    [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
    """

    # Init values.
    waveforms = np.array(u.Waveforms)
    wavetime = u.Waveforms.columns * us
    spk_times = np.array(u.SpikeParams['time'], dtype=float)
    base_rate = u.QualityMetrics['baseline']

    # Minimum and maximum gain.
    gmin = u.SessParams['minV']
    gmax = u.SessParams['maxV']

    # %% Init plots.

    # Disable inline plotting to prevent memory leak.
    putil.inline_off()

    # Init figure and gridspec.
    fig = putil.figure(fig)
    if sps is None:
        sps = putil.gridspec(1, 1)[0]
    ogsp = putil.embed_gsp(sps, 2, 1, height_ratios=[0.02, 1])

    info_sps, qm_sps = ogsp[0], ogsp[1]

    # Info header.
    info_ax = fig.add_subplot(info_sps)
    putil.hide_axes(info_ax)
    unit_title = putil.get_unit_info_title(u)
    putil.set_labels(ax=info_ax, title=unit_title, ytitle=0.80)

    # Create axes.
    gsp = putil.embed_gsp(qm_sps, 3, 2, wspace=0.3, hspace=0.4)
    ax_wf_inc, ax_wf_exc = [fig.add_subplot(gsp[0, i]) for i in (0, 1)]
    ax_wf_amp, ax_wf_dur = [fig.add_subplot(gsp[1, i]) for i in (0, 1)]
    ax_amp_dur, ax_rate = [fig.add_subplot(gsp[2, i]) for i in (0, 1)]

    # Trial markers: every 10th trial start, labelled on every other marker.
    trial_starts, trial_stops = u.TrData.TrialStart, u.TrData.TrialStop
    tr_markers = pd.DataFrame({'time': trial_starts[9::10]})
    tr_markers['label'] = [
        str(itr + 1) if i % 2 else '' for i, itr in enumerate(tr_markers.index)
    ]

    # Common variables, limits and labels.
    WF_T_START = test_sorting.WF_T_START
    spk_t = u.SessParams.sampl_prd * (np.arange(waveforms.shape[1]) -
                                      WF_T_START)
    ses_t_lim = test_sorting.get_start_stop_times(spk_times, trial_starts,
                                                  trial_stops)
    ss, sa = 1.0, 0.8  # marker size and alpha on scatter plot

    # Color included spikes by their occurrence over session time; excluded
    # spikes keep a fixed, translucent dark grey.
    my_cmap = putil.get_cmap('jet')
    spk_cols = np.tile(np.array([.25, .25, .25, .25]), (len(spk_times), 1))
    if np.any(spk_inc):  # check if there is any spike included
        spk_t_inc = np.array(spk_times[spk_inc])
        tmin, tmax = float(spk_times.min()), float(spk_times.max())
        tspan = tmax - tmin
        # Guard against 0 / 0 when all spikes share a single time stamp.
        if tspan > 0:
            t_norm = (spk_t_inc - tmin) / tspan
        else:
            t_norm = np.zeros_like(spk_t_inc)
        spk_cols[spk_inc, :] = my_cmap(t_norm)
    # Put excluded spikes to the front, and randomise order of included
    # spikes so later spikes don't systematically cover earlier ones.
    spk_order = np.hstack((np.where(np.invert(spk_inc))[0],
                           np.random.permutation(np.where(spk_inc)[0])))

    # Common labels for plots
    ses_t_lab = 'Recording time (s)'

    # %% Waveform shape analysis.

    # Plot included and excluded waveforms on different axes.
    # Color included by occurrence in session time to help detect drifts.
    s_waveforms, s_spk_cols = waveforms[spk_order, :], spk_cols[spk_order]
    wf_t_lim, volt_lim = [min(spk_t), max(spk_t)], [gmin, gmax]
    # Raw strings: '\m' and '\p' are invalid escapes in plain literals.
    wf_t_lab, volt_lab = r'WF time ($\mu$s)', 'Voltage'
    for st in ('Included', 'Excluded'):
        ax = ax_wf_inc if st == 'Included' else ax_wf_exc
        spk_idx = spk_inc if st == 'Included' else np.invert(spk_inc)
        tr_idx = tr_inc if st == 'Included' else np.invert(tr_inc)

        nspks, ntrs = sum(spk_idx), sum(tr_idx)
        wf_title = '{} WFs, {} spikes, {} trials'.format(st, nspks, ntrs)

        # Select waveforms and colors (in the shuffled spike order).
        rand_spk_idx = spk_idx[spk_order]
        wfs = s_waveforms[rand_spk_idx, :]
        cols = s_spk_cols[rand_spk_idx]

        # Plot waveforms.
        xlab, ylab = (wf_t_lab, volt_lab) if add_lbls else (None, None)
        pwaveform.plot_wfs(wfs,
                           spk_t,
                           cols=cols,
                           lw=0.1,
                           alpha=0.05,
                           xlim=wf_t_lim,
                           ylim=volt_lim,
                           title=wf_title,
                           xlab=xlab,
                           ylab=ylab,
                           ax=ax)

    # %% Waveform summary metrics.

    # Init data.
    wf_amp_all = u.SpikeParams['amplitude']
    wf_amp_inc = wf_amp_all[spk_inc]
    wf_dur_all = u.SpikeParams['duration']
    wf_dur_inc = wf_dur_all[spk_inc]

    # Set common limits and labels.
    dur_lim = [0, wavetime[-2] - wavetime[WF_T_START]]  # same across units
    amp_lim = [0, max(wf_amp_all.max(), gmax - gmin)]

    amp_lab = 'Amplitude'
    dur_lab = r'Duration ($\mu$s)'

    # Waveform amplitude across session time.
    m_amp, sd_amp = wf_amp_inc.mean(), wf_amp_inc.std()
    amp_title = r'WF amplitude: {:.1f} $\pm$ {:.1f}'.format(m_amp, sd_amp)
    xlab, ylab = (ses_t_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times,
                  wf_amp_all,
                  spk_inc,
                  c='m',
                  bc='grey',
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=ses_t_lim,
                  ylim=amp_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=amp_title,
                  ax=ax_wf_amp)

    # Waveform duration across session time.
    mdur, sdur = wf_dur_inc.mean(), wf_dur_inc.std()
    dur_title = r'WF duration: {:.1f} $\pm$ {:.1f} $\mu$s'.format(mdur, sdur)
    xlab, ylab = (ses_t_lab, dur_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times,
                  wf_dur_all,
                  spk_inc,
                  c='c',
                  bc='grey',
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=ses_t_lim,
                  ylim=dur_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=dur_title,
                  ax=ax_wf_dur)

    # Waveform duration against amplitude.
    amp_dur_title = 'WF duration - amplitude'
    xlab, ylab = (dur_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(wf_dur_all[spk_order],
                  wf_amp_all[spk_order],
                  c=spk_cols[spk_order],
                  s=ss,
                  xlab=xlab,
                  ylab=ylab,
                  xlim=dur_lim,
                  ylim=amp_lim,
                  edgecolors='',
                  alpha=sa,
                  id_line=False,
                  title=amp_dur_title,
                  ax=ax_amp_dur)

    # %% Firing rate.

    tmean = np.array(bs_stats['tmean'])
    rmean = util.remove_dim_from_series(bs_stats['rate'])
    prd_tstart, prd_tstop = stab_prd_res['tstart'], stab_prd_res['tstop']

    # Color segments depending on whether they are included / excluded.
    def plot_periods(v, color, ax):
        # Plot line segments: colored only if both end points are included.
        for i in range(len(prd_inc) - 1):
            col = color if prd_inc[i] and prd_inc[i + 1] else 'grey'
            x, y = [(tmean[i], tmean[i + 1]), (v[i], v[i + 1])]
            ax.plot(x, y, color=col)
        # Plot line points.
        for i in range(len(prd_inc)):
            col = color if prd_inc[i] else 'grey'
            x, y = [tmean[i], v[i]]
            ax.plot(x,
                    y,
                    color=col,
                    marker='o',
                    markersize=3,
                    markeredgecolor=col)

    # Firing rate over session time.
    rate_title = 'Baseline rate: {:.1f} spike/s'.format(float(base_rate))
    xlab, ylab = (ses_t_lab, putil.FR_lbl) if add_lbls else (None, None)
    ylim = [0, 1.25 * np.max(rmean)]
    plot_periods(rmean, 'b', ax_rate)
    pplot.lines([], [],
                c='b',
                xlim=ses_t_lim,
                ylim=ylim,
                title=rate_title,
                xlab=xlab,
                ylab=ylab,
                ax=ax_rate)

    # Trial markers.
    putil.plot_events(tr_markers,
                      lw=0.5,
                      ls='--',
                      alpha=0.35,
                      y_lbl=0.92,
                      ax=ax_rate)

    # Excluded periods (before / after the stable period).
    excl_prds = []
    tstart, tstop = ses_t_lim
    if tstart != prd_tstart:
        excl_prds.append(('beg', tstart, prd_tstart))
    if tstop != prd_tstop:
        excl_prds.append(('end', prd_tstop, tstop))
    putil.plot_periods(excl_prds, ymax=0.92, ax=ax_rate)

    # %% Post-formatting.

    # Maximize number of ticks on recording time axes to prevent covering.
    for ax in (ax_wf_amp, ax_wf_dur, ax_rate):
        putil.set_max_n_ticks(ax, 6, 'x')

    # %% Save figure.
    # NOTE(review): inline plotting is only re-enabled when the figure is
    # saved; with ftempl=None it stays off after return. Confirm intended.
    if ftempl is not None:
        fname = ftempl.format(u.name_to_fname())
        # NOTE(review): the title passed here is the firing rate plot's
        # title, not `unit_title` shown in the header; confirm which one
        # save_fig should receive.
        putil.save_fig(fname, fig, rate_title, rect_height=0.92)
        putil.inline_on()

    return [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
Exemple #6
0
def plot_DR_tuning(DS,
                   title=None,
                   labels=True,
                   baseline=None,
                   DR_legend=True,
                   tuning_legend=True,
                   ffig=None,
                   fig=None,
                   sps=None):
    """Plot direction selectivity on polar plot and tuning curve.

    Left: polar direction-response plot. Right: fitted Gaussian tuning curve
    (`tuning.gaus`) against difference from the preferred direction. One
    curve per stimulus in `DS.DSI.index`, with legends listing the DSI / PD
    and the fitted tuning parameters.

    Parameters
    ----------
    DS : direction selectivity results
        Object with DSI, TP (tuning parameters), PD and DR tables.
    title : str, optional
        Figure super title.
    labels : bool
        If True, add axis labels, titles and legend titles; if False, use the
        compact (framed, untitled legend) format of summary plots.
    baseline : optional
        Baseline passed to the DR and tuning plotting functions.
    DR_legend, tuning_legend : bool
        Whether to add the legend to the polar / tuning plot.
    ffig : str, optional
        File name to save figure to.
    fig, sps : optional
        Figure and subplot spec to draw into.

    Returns
    -------
    ax_DR, ax_tuning
    """

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 1, 2, wspace=0.3)
    ax_DR = fig.add_subplot(gsp[0], polar=True)
    ax_tuning = fig.add_subplot(gsp[1])
    DR_patches = []
    tuning_patches = []

    xticks = np.arange(-180, 180 + 1, 45)  # tuning plot x ticks (deg)

    stims = DS.DSI.index
    for stim in stims:

        # Extract params and results.
        a, b, x0, sigma, FWHM, R2, RMSE = DS.TP.loc[stim]
        DSI = DS.DSI.wDS[stim]
        PD, cPD = DS.PD.loc[(stim, 'weighted'), ['PD', 'cPD']]
        dirs = np.array(DS.DR[stim].index) * deg
        SR, SRsem = [DS.DR[stim][col] for col in ('mean', 'sem')]

        # Center directions on the preferred direction.
        dirsc = direction.center_to_dir(dirs, PD)

        # Generate data points for plotting fitted tuning curve.
        x, y = tuning.gen_fit_curve(tuning.gaus,
                                    -180,
                                    180,
                                    a=a,
                                    b=b,
                                    x0=x0,
                                    sigma=sigma)

        # Plot DR polar plot.
        color = putil.stim_colors.loc[stim]
        ttl = 'Direction response' if labels else None
        plot_DR(dirs, SR, DSI, PD, baseline, title=ttl, color=color, ax=ax_DR)

        # Collect parameters of DR plot (stimulus - response).
        s_pd = str(float(round(PD, 1)))
        s_pd_c = str(int(cPD)) if not np.isnan(float(cPD)) else 'nan'
        lgd_lbl = '{}:   {:.3f}'.format(stim, DSI)
        # Raw string: '\c' is an invalid escape in a plain literal.
        pd_lbl_fmt = r'     {:>5}$^\circ$ --> {:>3}$^\circ$ '
        lgd_lbl += pd_lbl_fmt.format(s_pd, s_pd_c)
        DR_patches.append(putil.get_artist(lgd_lbl, color))

        # Plot direction tuning curve.
        plot_tuning(x,
                    y,
                    dirsc,
                    SR,
                    SRsem,
                    color,
                    baseline,
                    xticks,
                    ax=ax_tuning)

        # Collect parameters tuning curve fit.
        s_a, s_b, s_x0, s_sigma, s_FWHM = [
            str(float(round(p, 1))) for p in (a, b, x0, sigma, FWHM)
        ]
        s_R2 = format(R2, '.2f')
        lgd_lbl = '{}:{}{:>6}{}{:>6}'.format(stim, 5 * ' ', s_a, 5 * ' ', s_b)
        lgd_lbl += '{}{:>6}{}{:>6}'.format(5 * ' ', s_x0, 8 * ' ', s_sigma)
        lgd_lbl += '{}{:>6}{}{:>6}'.format(8 * ' ', s_FWHM, 8 * ' ', s_R2)
        tuning_patches.append(putil.get_artist(lgd_lbl, color))

    # Format tuning curve plot (only after all stimuli has been plotted!).
    xlab = 'Difference from preferred direction (deg)' if labels else None
    ylab = putil.FR_lbl if labels else None
    xlim = [-180 - 5, 180 + 5]  # degrees
    ylim = [0, None]
    ttl = 'Tuning curve' if labels else None
    putil.format_plot(ax_tuning, xlim, ylim, xlab, ylab, ttl)
    putil.add_zero_line('y', ax=ax_tuning)  # add zero reference line

    # Set super title.
    if title is not None:
        fig.suptitle(title, y=1.1, fontsize='x-large')

    # Set legends (frame only in the compact, label-less format).
    ylegend = -0.38 if labels else -0.15
    fr_on = not labels
    lgd_kws = {
        'fancybox': True,
        'shadow': False,
        'frameon': fr_on,
        'framealpha': 1.0,
        'loc': 'lower center',
        'bbox_to_anchor': [0., ylegend, 1., .0],
        'prop': {'family': 'monospace'},
    }
    DR_lgn_ttl = 'DSI'.rjust(20) + 'PD'.rjust(14) + 'PD8'.rjust(14)
    tuning_lgd_ttl = ('a (sp/s)'.rjust(35) + 'b (sp/s)'.rjust(15) +
                      'x0 (deg)'.rjust(13) + 'sigma (deg)'.rjust(15) +
                      'FWHM (deg)'.rjust(15) + 'R-squared'.rjust(15))

    lgn_params = [(DR_legend, DR_lgn_ttl, DR_patches, ax_DR),
                  (tuning_legend, tuning_lgd_ttl, tuning_patches, ax_tuning)]

    for (plot_legend, lgd_ttl, patches, ax) in lgn_params:
        if not plot_legend:
            continue
        if not labels:  # customize for summary plot
            lgd_ttl = None
        lgd = putil.set_legend(ax, handles=patches, title=lgd_ttl, **lgd_kws)
        lgd.get_title().set_ha('left')
        if lgd_kws['frameon']:
            lgd.get_frame().set_linewidth(.5)

    # Save figure.
    if ffig is not None:
        putil.save_fig(ffig, fig, rect_height=0.55)

    return ax_DR, ax_tuning
Exemple #7
0
def plot_SR(u,
            param=None,
            vals=None,
            from_trs=None,
            prd_pars=None,
            nrate=None,
            colors=None,
            add_roc=False,
            add_prd_name=False,
            fig=None,
            sps=None,
            title=None,
            **kwargs):
    """Plot stimulus response (raster, rate and ROC) for multiple stimuli.

    Draws one column per analysis period (width set by the period's duration,
    capped at its maximum length), each with raster and rate plots. An ROC
    (AUC over time) plot is added below when `add_roc` is set and exactly
    two trial sets are plotted.

    Parameters
    ----------
    u : unit object
        Unit to plot.
    param : optional
        Trial parameter to split trials by. None: all included trials. If it
        is not a column of `u.TrData`, the (period stimulus, param) column is
        used.
    vals : list, optional
        Parameter values to select trials by.
    from_trs : optional
        Restrict trial sets to these trials (via `util.filter_lists`).
    prd_pars : table, optional
        Period parameters; defaults to `u.get_analysis_prds()`.
    nrate : optional
        Passed through to `prate.plot_rr` / `prate.prep_rr_plot_params`.
    colors : list, optional
        Colors of the trial sets.
    add_roc : bool
        Add ROC plots (only when there are exactly two trial sets).
    add_prd_name : bool
        Print the period name onto each rate plot.
    fig, sps : optional
        Figure and subplot spec to draw into.
    title : str, optional
        Title above the raster plots.
    **kwargs
        Passed to `prate.plot_rr`; a truthy `no_labels` also suppresses the
        common x label.

    Returns
    -------
    axes_raster, axes_rate, axes_roc : lists of axes. All are empty if the
    unit is not to be plotted or `vals` contains NaN.
    """

    has_nan_val = (util.is_iterable(vals)
                   and any(np.isnan(float(v)) for v in vals))
    if not u.to_plot() or has_nan_val:
        return [], [], []

    # Set up stimulus parameters.
    if prd_pars is None:
        prd_pars = u.get_analysis_prds()

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)

    # Create a gridspec for each period.
    wratio = [
        float(min(dur, lmax))
        for dur, lmax in zip(prd_pars.dur, prd_pars.max_len)
    ]
    wspace = 0.0  # 0.05 to separate periods
    gsp = putil.embed_gsp(sps,
                          1,
                          len(prd_pars.index),
                          width_ratios=wratio,
                          wspace=wspace)

    axes_raster, axes_rate, axes_roc = [], [], []

    for i, prd in enumerate(prd_pars.index):

        ppars = prd_pars.loc[prd]
        ref = ppars.ref

        # Prepare trial set.
        if param is None:
            trs = u.ser_inc_trials()
        elif param in u.TrData.columns:
            trs = u.trials_by_param(param, vals)
        else:
            trs = u.trials_by_param((ppars.stim, param), vals)

        if from_trs is not None:
            trs = util.filter_lists(trs, from_trs)

        plot_roc = add_roc and (len(trs) == 2)

        # Init subplots.
        if plot_roc:
            rr_sps, roc_sps = putil.embed_gsp(gsp[i],
                                              2,
                                              1,
                                              hspace=0.2,
                                              height_ratios=[1, .4])
        else:
            rr_sps = putil.embed_gsp(gsp[i], 1, 1, hspace=0.3)[0]

        # Init params.
        prds = [constants.ev_stims.loc[ref]]
        evnts = None
        if ('cue' in ppars) and (ppars.cue is not None):
            evnts = [{'time': ppars.cue}]
            evnts[0]['color'] = (ppars.cue_color if 'cue_color' in ppars.index
                                 else putil.cue_colors['all'])

        # Default colors are drawn on the first period and reused afterwards.
        if colors is None:
            colcyc = putil.get_colors()
            colors = [next(colcyc) for _ in range(len(trs))]

        # Plot response on raster and rate plots.
        _, raster_axs, rate_ax = prate.plot_rr(u,
                                               prd,
                                               ref,
                                               prds,
                                               evnts,
                                               nrate,
                                               trs,
                                               ppars.max_len,
                                               cols=colors,
                                               fig=fig,
                                               sps=rr_sps,
                                               **kwargs)

        # Add period name to rate plot.
        if add_prd_name:
            rate_ax.text(0.02,
                         0.95,
                         prd,
                         fontsize=10,
                         color='k',
                         va='top',
                         ha='left',
                         transform=rate_ax.transAxes)

        # Plot ROC curve.
        if plot_roc:
            roc_ax = fig.add_subplot(roc_sps)
            # Init rates.
            plot_params = prate.prep_rr_plot_params(u, prd, ref, nrate, trs)
            _, _, (rates1, rates2), _, _ = plot_params
            # Calculate ROC results.
            aroc = roccore.run_ROC_over_time(rates1, rates2, n_perm=0)
            # Set up plot params and plot results.
            tvec, auc = aroc.index, aroc.auc
            xlim = rate_ax.get_xlim()
            pauc.plot_auc_over_time(auc,
                                    tvec,
                                    prds,
                                    evnts,
                                    xlim=xlim,
                                    ax=roc_ax)

        # Some exception handling: set x axis range to predefined values if
        # it is not set (probably due to no trial data plotted).
        if prd in constants.fixed_tr_prds.index:
            t1, t2 = constants.fixed_tr_prds.loc[prd]
            axs = raster_axs + [rate_ax] + ([roc_ax] if plot_roc else [])
            for ax in axs:
                if ax.get_xlim() == (-0.001, 0.001):
                    ax.set_xlim([t1, t2])

        # Collect axes.
        axes_raster.extend(raster_axs)
        axes_rate.append(rate_ax)
        if plot_roc:
            axes_roc.append(roc_ax)

    # Format rate and roc plots to make them match across trial periods.

    # Remove y-axis label, spine and ticks from second and later periods.
    for ax in axes_rate[1:] + axes_roc[1:]:
        ax.set_ylabel('')
        putil.set_spines(ax, bottom=True, left=False)
        putil.hide_ticks(ax, show_x_ticks=True, show_y_ticks=False)

    # Remove (hide) legend from all but last rate plot.
    for ax in axes_rate[:-1]:
        putil.hide_legend(ax)

    # Get middle x point in relative coordinates of first rate axes.
    xranges = np.array([ax.get_xlim() for ax in axes_rate])
    xlens = xranges[:, 1] - xranges[:, 0]
    xmid = xlens.sum() / xlens[0] / 2

    # Add title.
    if title is not None:
        axes_raster[0].set_title(title, x=xmid, y=1.0)

    # Set common x label (unless the caller asked for no labels).
    if not kwargs.get('no_labels', False):
        for ax in axes_rate + axes_roc:
            ax.set_xlabel('')
        # .iloc: positional access to the first period's stimulus.
        xlab = putil.t_lbl.format(prd_pars.stim.iloc[0] + ' onset')
        ax = axes_roc[0] if axes_roc else axes_rate[0]
        ax.set_xlabel(xlab, x=xmid)

    # Reformat ticks on x axis. First period is reference.
    for axes in [axes_rate, axes_roc]:
        for ax, lbl_shift in zip(axes, prd_pars.lbl_shift):
            x1, x2 = ax.get_xlim()
            shift = float(lbl_shift)
            tmakrs, tlbls = putil.get_tick_marks_and_labels(
                x1 + shift, x2 + shift)
            putil.set_xtick_labels(ax, tmakrs - shift, tlbls)
    # Remove x tick labels from rate axes if roc's are present.
    # tick_params expects a bool here; the string 'off' is truthy.
    if axes_roc:
        for ax in axes_rate:
            ax.tick_params(labelbottom=False)

    # Match scale of rate and roc plots' y axes.
    for axes in [axes_rate, axes_roc]:
        putil.sync_axes(axes, sync_y=True)
        for ax in axes:
            putil.adjust_decorators(ax)

    return axes_raster, axes_rate, axes_roc
Exemple #8
0
def plot_SR_matrix(u, param, vals=None, sps=None, fig=None):
    """
    Plot stimulus response in matrix layout by target and delay length for
    combined task.

    Rows: all trials, then one row per report target (if there is more than
    one). Columns: all delay lengths, then one column per delay length (if
    there is more than one). Each cell is a `plot_SR` plot of the trials in
    that (target, delay length) split.

    This function is currently out of date, needs updating!

    Returns
    -------
    raster_axs, rate_axs, roc_axs : lists of all axes plotted.
    """

    # Init params.
    dsplit_prd_pars = constants.tr_half_prds.copy()
    dcomb_prd_pars = constants.tr_third_prds.copy()
    targets = u.TrData['ToReport'].unique()
    dlens = util.remove_dim_from_array(constants.del_lens)

    # Init axes.
    nrow = len(targets) + (len(targets) > 1)
    ncol = len(dlens) + (len(dlens) > 1)
    # Width ratio, depending in delay lengths.
    wratio = None
    if len(dlens) > 1:
        len_wo_delay = float(dsplit_prd_pars.max_len.sum()) - dlens[-1]
        wratio = [len_wo_delay + dlens[0]]  # combined (unsplit) column
        wratio.extend([len_wo_delay + dlen for dlen in dlens])
    gsp = putil.embed_gsp(sps,
                          nrow,
                          ncol,
                          width_ratios=wratio,
                          wspace=0.2,
                          hspace=0.4)
    raster_axs, rate_axs, roc_axs = [], [], []

    # Set up trial splits to plot by, indexed by (target, delay length).
    mtargets = ['all'] + list(targets if len(targets) > 1 else [])
    mdlens = ['all'] + list(dlens if len(dlens) > 1 else [])
    # No split (all included trials).
    splits = [pd.Series([u.inc_trials()], index=[('all', 'all')])]
    # Split by report only.
    if len(targets) > 1:
        by_report = u.trials_by_params(['ToReport'])
        by_report.index = [(r, 'all') for r in by_report.index]
        splits.append(by_report)
    # Split by delay length only.
    if len(dlens) > 1:
        by_dlen = u.trials_by_params(['DelayLen'])
        by_dlen.index = [('all', dl) for dl in by_dlen.index]
        splits.append(by_dlen)
    # Split by both report and delay length.
    if len(targets) > 1 and len(dlens) > 1:
        target_dlen_trs = u.trials_by_params(['ToReport', 'DelayLen'])
        splits.append(target_dlen_trs)
    # pd.concat (Series.append was removed in pandas 2.0).
    trs_splits = pd.concat(splits)

    # Plot each split on matrix layout.
    for i, target in enumerate(mtargets):
        for j, dlen in enumerate(mdlens):

            # Init plotting params.
            dlen_str = str(int(dlen)) + ' ms' if util.is_number(dlen) else dlen
            title = 'report: {}    |    delay: {}'.format(target, dlen_str)
            from_trs = trs_splits[(target, dlen)]

            # Periods to plot.
            if dlen == 'all':
                prd_pars = dcomb_prd_pars.copy()
                prd_pars = u.get_analysis_prds(prd_pars, from_trs)
            else:
                prd_pars = dsplit_prd_pars.copy()
                prd_pars = u.get_analysis_prds(prd_pars, from_trs)
                # Single .loc assignment instead of chained assignment,
                # which does not write back under pandas copy-on-write.
                S2_shift = constants.stim_dur['S1'] + dlen * ms
                prd_pars.loc['S2 half', 'lbl_shift'] = S2_shift
                if len(dlens) > 1:
                    tdiff = (dlen - dlens[-1]) * ms
                    S1_max_len = prd_pars.max_len['S1 half'] + tdiff
                    prd_pars.loc['S1 half', 'max_len'] = S1_max_len
            if 'cue' in prd_pars.columns:
                prd_pars['cue_color'] = putil.cue_colors[target]

            # Plot response.
            res = plot_SR(u,
                          param,
                          vals,
                          from_trs,
                          prd_pars,
                          title=title,
                          sps=gsp[i, j],
                          fig=fig)
            axes_raster, axes_rate, axes_roc = res

            # Remove superfluous labels (x: all but bottom row, y: all but
            # first column).
            if i < len(mtargets) - 1:
                for ax in axes_rate:
                    ax.set_xlabel('')
            if j > 0:
                for ax in axes_rate:
                    ax.set_ylabel('')

            # Collect axes.
            raster_axs.extend(axes_raster)
            rate_axs.extend(axes_rate)
            roc_axs.extend(axes_roc)

    return raster_axs, rate_axs, roc_axs