def raster_rate(spk_list, rate_list, names=None, prds=None, evts=None,
                cols=None, baseline=None, title=None, rs_ylab=True,
                rate_kws=None, fig=None, ffig=None, sps=None):
    """Plot a stack of raster plots above a single rate plot.

    Parameters
    ----------
    spk_list : list
        One spike-train set per raster axes.
    rate_list : list
        Rate data, passed on to ``rate``.
    names : list or None
        Set names; used as raster y labels when ``rs_ylab`` is true, and
        passed on to ``rate``.
    prds, evts, baseline :
        Passed on to ``raster`` (``prds`` only) and ``rate``.
    cols : list or None
        One colour per set; taken from the colour cycle if None.
    title : str or None
        Title placed on the top raster axes.
    rate_kws : dict or None
        Extra keyword arguments for ``rate``.
    fig, sps : figure / subplot spec to draw into.
    ffig : passed to ``putil.save_fig``.

    Returns
    -------
    fig, raster_axs, rate_ax
    """
    kws_for_rate = dict() if rate_kws is None else rate_kws

    # Layout: rasters on top, rate plot below.
    sps, fig = putil.sps_fig(sps, fig)
    outer_gsp = putil.embed_gsp(sps, 2, 1, height_ratios=[.66, 1], hspace=.15)
    # Always create at least one raster axes, even without any data.
    n_sets = max(len(spk_list), 1)
    raster_gsp = putil.embed_gsp(outer_gsp[0], n_sets, 1, hspace=.15)
    rate_gsp = putil.embed_gsp(outer_gsp[1], 1, 1)

    # Default colours: the first n_sets entries of the colour cycle.
    if cols is None:
        color_cycle = putil.get_colors(mpl_colors=True)
        cols = [next(color_cycle) for _ in range(n_sets)]

    # One raster axes per spike-train set.
    raster_axs = [fig.add_subplot(raster_gsp[irow, 0])
                  for irow in range(n_sets)]
    use_names = rs_ylab and names is not None
    for idx, (spk_trs, ax) in enumerate(zip(spk_list, raster_axs)):
        ylab = names[idx] if use_names else None
        raster(spk_trs, prds=prds, c=cols[idx], ylab=ylab, ax=ax)
        putil.hide_axes(ax)
    if raster_axs:
        putil.set_labels(raster_axs[0], title=title)  # title on top raster

    # Rate plot.
    rate_ax = fig.add_subplot(rate_gsp[0, 0])
    rate(rate_list, names, prds, evts, cols, baseline, **kws_for_rate,
         ax=rate_ax)

    # Align the rasters' x limits with the rate plot's.
    shared_xlim = rate_ax.get_xlim()
    for ax in raster_axs:
        ax.set_xlim(shared_xlim)

    putil.save_fig(ffig, fig)
    return fig, raster_axs, rate_ax
def plot_DR_3x3(u, fig=None, sps=None):
    """Plot 3x3 direction response plot, with polar plot in center.

    The centre cell of a 3x3 grid holds a polar plot of the mean response to
    each direction for every stimulus in ``constants.stim_dur``. The eight
    surrounding cells hold raster/rate plots (via ``plot_SR``), one for each
    direction in ``constants.all_dirs``.

    Returns
    -------
    (ax_polar, rate_axs), or None if ``u.to_plot()`` is falsy.
    """
    if not u.to_plot():
        return

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 3, 3)  # inner gsp with subplots

    # Polar plot.
    putil.set_style('notebook', 'white')
    ax_polar = fig.add_subplot(gsp[4], polar=True)
    baseline = u.get_baseline()  # does not depend on the stimulus
    for stim in constants.stim_dur.index:  # for each stimulus
        stim_resp = u.get_stim_resp_vals(stim, 'Dir')
        resp_stats = util.calc_stim_resp_stats(stim_resp)
        dirs, resp = np.array(resp_stats.index) * deg, resp_stats['mean']
        c = putil.stim_colors[stim]
        ptuning.plot_DR(dirs, resp, color=c, baseline=baseline, ax=ax_polar)
    putil.hide_ticks(ax_polar, 'y')

    # Raster-rate plots.
    putil.set_style('notebook', 'ticks')
    rr_pos = [5, 2, 1, 0, 3, 6, 7, 8]  # Grid cell of each direction.
    rr_dir_plot_pos = pd.Series(constants.all_dirs, index=rr_pos)
    rate_axs = []
    # Series.iteritems() was removed in pandas 2.0; items() is equivalent.
    for isp, d in rr_dir_plot_pos.items():

        # Tick labels are only kept in grid cell 0 (top-left).
        first_dir = (isp == 0)

        # Plot direction response across trial periods.
        res = plot_SR(u, 'Dir', [d], fig=fig, sps=gsp[isp], no_labels=True)
        draster_axs, drate_axs, _ = res

        # Remove axis tick labels.
        for i, ax in enumerate(drate_axs):
            first_prd = (i == 0)
            show_x_tick_lbls = first_dir
            show_y_tick_lbls = first_dir and first_prd
            putil.hide_tick_labels(ax, show_x_tick_lbls, show_y_tick_lbls)

        # Add task name as title (to top center axes). plot_SR returns empty
        # lists when nothing was plotted, so guard the indexing.
        if isp == 1 and draster_axs:
            ttl = u.get_task() + (' [excluded]' if u.is_excluded() else '')
            putil.set_labels(draster_axs[0], title=ttl, ytitle=1.10,
                             title_kws={'loc': 'right'})

        # Collect each direction's rate axes exactly once.
        rate_axs.extend(drate_axs)

    # Match scale of y axes.
    putil.sync_axes(rate_axs, sync_y=True)
    for ax in rate_axs:
        putil.adjust_decorators(ax)

    return ax_polar, rate_axs
def plot_DSI(u, nrate=None, fig=None, sps=None, prd_pars=None,
             no_labels=False):
    """Plot direction selectivity indices.

    Creates two stacked axes: the max - opposite DSI (top) and the weighted
    DSI (bottom). Both are computed by ``direction.calc_DSI`` and drawn with
    ``prate.rate`` over the unit's stimulus periods.

    Parameters
    ----------
    u : unit object.
    nrate : passed on to ``direction.calc_DSI``.
    fig, sps : figure / subplot spec to draw into.
    prd_pars : analysis period parameters, or None to use
        ``u.get_analysis_prds()``.
    no_labels : currently unused; kept for interface compatibility.

    Returns
    -------
    ax_mDSI, ax_wDSI
    """
    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 2, 1, hspace=0.6)
    ax_mDSI, ax_wDSI = [fig.add_subplot(igsp) for igsp in gsp]

    # Fall back to the unit's default periods only when none were given.
    # Previously a caller-supplied prd_pars was silently overwritten.
    if prd_pars is None:
        prd_pars = u.get_analysis_prds()

    # Calculate DSI using maximum - opposite and weighted measure.
    maxDSI, wghtDSI = [direction.calc_DSI(u, fDSI, prd_pars, nrate)
                       for fDSI in (direction.max_DS, direction.weighted_DS)]

    # Get stimulus periods.
    stim_prds = u.get_stim_prds()

    # Plot DSIs: one single-row DataFrame per row of each DSI table.
    plot_specs = [(maxDSI, 'mDSI', 'max - opposite DSI', ax_mDSI),
                  (wghtDSI, 'wDSI', 'weighted DSI', ax_wDSI)]
    for DSI, ylab, ttl, ax in plot_specs:
        DSI_list = [pd.DataFrame(row).T for _, row in DSI.iterrows()]
        prate.rate(DSI_list, DSI.index, prds=stim_prds, pval=None,
                   ylab=ylab, title=ttl, add_lgn=True, lgn_lbl=None, ax=ax)

    return ax_mDSI, ax_wDSI
def plot_selectivity(u, fig=None, sps=None):
    """Plot selectivity summary plot.

    Stacks the following rows:

    - an all-trials plot;
    - a location plot (``plot_LR``), only if the location of S1 or S2 took
      more than one distinct value over the included trials;
    - a direction plot (``plot_DR``), only if the direction did.

    Returns
    -------
    (lr_rate_axs, dr_rate_axs) : rate axes lists (empty when not plotted),
        or None if ``u.to_plot()`` is falsy.
    """
    if not u.to_plot():
        return

    # Skip selectivity rows for stimulus features kept constant.
    stims = ('S1', 'S2')
    trs = u.inc_trials()

    def n_unique(feature):
        # Number of distinct values of `feature` for S1 and S2.
        return [len(u.TrData[(stim, feature)][trs].unique()) for stim in stims]

    n_locs = n_unique('Loc')
    plot_lr = not (n_locs[0] == 1 and n_locs[1] == 1)

    n_dirs = n_unique('Dir')
    plot_dr = not (n_dirs[0] == 1 and n_dirs[1] == 1)

    # One row for all trials plus one per variable feature.
    nsps = 1 + plot_lr + plot_dr
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, nsps, 1, hspace=0.3)

    # Task-relatedness.
    plot_all_trials(u, gsp[0], fig)
    next_row = 1

    # Location-specific activity.
    lr_rate_axs = []
    if plot_lr:
        _, lr_rate_axs, _ = plot_LR(u, gsp[next_row], fig)
        next_row += 1

    # Direction-specific activity.
    dr_rate_axs = []
    if plot_dr:
        _, dr_rate_axs, _ = plot_DR(u, gsp[next_row], fig)

    return lr_rate_axs, dr_rate_axs
def plot_qm(u, bs_stats, stab_prd_res, prd_inc, tr_inc, spk_inc,
            add_lbls=False, ftempl=None, fig=None, sps=None):
    """Plot quality metrics related figures.

    Layout: a unit-info header, then a 3x2 grid containing:

    - included and excluded waveforms;
    - waveform amplitude and duration over recording time;
    - waveform duration against amplitude;
    - firing rate over recording time.

    Parameters
    ----------
    u : unit object.
    bs_stats : provides 'tmean' (period centre times) and 'rate'.
    stab_prd_res : provides 'tstart' and 'tstop' of the stable period.
    prd_inc : per-period inclusion flags, indexed in step with
        ``bs_stats['tmean']`` (presumably booleans -- confirm with caller).
    tr_inc, spk_inc : boolean masks of included trials / spikes.
    add_lbls : add axis labels if True.
    ftempl : if not None, output file name template, formatted with
        ``u.name_to_fname()``; the figure is then saved.
    fig, sps : figure / subplot spec to draw into.

    Returns
    -------
    [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
    """
    # Init values.
    waveforms = np.array(u.Waveforms)
    wavetime = u.Waveforms.columns * us
    spk_times = np.array(u.SpikeParams['time'], dtype=float)
    base_rate = u.QualityMetrics['baseline']

    # Minimum and maximum gain.
    gmin = u.SessParams['minV']
    gmax = u.SessParams['maxV']

    # %% Init plots.

    # Disable inline plotting to prevent memory leak. It is re-enabled in
    # `finally` so that an error while plotting does not leave it off.
    putil.inline_off()
    try:
        # Init figure and gridspec.
        fig = putil.figure(fig)
        if sps is None:
            sps = putil.gridspec(1, 1)[0]
        ogsp = putil.embed_gsp(sps, 2, 1, height_ratios=[0.02, 1])

        info_sps, qm_sps = ogsp[0], ogsp[1]

        # Info header.
        info_ax = fig.add_subplot(info_sps)
        putil.hide_axes(info_ax)
        title = putil.get_unit_info_title(u)
        putil.set_labels(ax=info_ax, title=title, ytitle=0.80)

        # Create axes.
        gsp = putil.embed_gsp(qm_sps, 3, 2, wspace=0.3, hspace=0.4)
        ax_wf_inc, ax_wf_exc = [fig.add_subplot(gsp[0, i]) for i in (0, 1)]
        ax_wf_amp, ax_wf_dur = [fig.add_subplot(gsp[1, i]) for i in (0, 1)]
        ax_amp_dur, ax_rate = [fig.add_subplot(gsp[2, i]) for i in (0, 1)]

        # Trial markers: every 10th trial start, every other one labelled.
        trial_starts, trial_stops = u.TrData.TrialStart, u.TrData.TrialStop
        tr_markers = pd.DataFrame({'time': trial_starts[9::10]})
        tr_markers['label'] = [str(itr + 1) if i % 2 else ''
                               for i, itr in enumerate(tr_markers.index)]

        # Common variables, limits and labels.
        WF_T_START = test_sorting.WF_T_START
        spk_t = u.SessParams.sampl_prd * (np.arange(waveforms.shape[1]) -
                                          WF_T_START)
        ses_t_lim = test_sorting.get_start_stop_times(spk_times, trial_starts,
                                                      trial_stops)
        ss, sa = 1.0, 0.8  # marker size and alpha on scatter plot

        # Color spikes by their occurrence over session time.
        my_cmap = putil.get_cmap('jet')
        spk_cols = np.tile(np.array([.25, .25, .25, .25]),
                           (len(spk_times), 1))
        if np.any(spk_inc):  # check if there is any spike included
            spk_t_inc = np.array(spk_times[spk_inc])
            tmin, tmax = float(spk_times.min()), float(spk_times.max())
            # Avoid 0/0 (NaN colours) when all spikes share one time stamp.
            tspan = (tmax - tmin) or 1.0
            spk_cols[spk_inc, :] = my_cmap((spk_t_inc - tmin) / tspan)

        # Put excluded spikes to the front, and randomise order of included
        # spikes so later spikes don't systematically cover earlier ones.
        spk_order = np.hstack((np.where(np.invert(spk_inc))[0],
                               np.random.permutation(np.where(spk_inc)[0])))

        # Common labels for plots
        ses_t_lab = 'Recording time (s)'

        # %% Waveform shape analysis.

        # Plot included and excluded waveforms on different axes.
        # Color included by occurrence in session time to help detect drifts.
        s_waveforms, s_spk_cols = waveforms[spk_order, :], spk_cols[spk_order]
        wf_t_lim, glim = [min(spk_t), max(spk_t)], [gmin, gmax]
        # Raw strings: '\m' and '\p' are invalid escape sequences
        # (SyntaxWarning since Python 3.12); the resulting text is unchanged.
        wf_t_lab, volt_lab = r'WF time ($\mu$s)', 'Voltage'
        for st in ('Included', 'Excluded'):
            ax = ax_wf_inc if st == 'Included' else ax_wf_exc
            spk_idx = spk_inc if st == 'Included' else np.invert(spk_inc)
            tr_idx = tr_inc if st == 'Included' else np.invert(tr_inc)

            nspks, ntrs = sum(spk_idx), sum(tr_idx)
            title = '{} WFs, {} spikes, {} trials'.format(st, nspks, ntrs)

            # Select waveforms and colors (in the spk_order ordering).
            rand_spk_idx = spk_idx[spk_order]
            wfs = s_waveforms[rand_spk_idx, :]
            cols = s_spk_cols[rand_spk_idx]

            # Plot waveforms.
            xlab, ylab = (wf_t_lab, volt_lab) if add_lbls else (None, None)
            pwaveform.plot_wfs(wfs, spk_t, cols=cols, lw=0.1, alpha=0.05,
                               xlim=wf_t_lim, ylim=glim, title=title,
                               xlab=xlab, ylab=ylab, ax=ax)

        # %% Waveform summary metrics.

        # Init data.
        wf_amp_all = u.SpikeParams['amplitude']
        wf_amp_inc = wf_amp_all[spk_inc]
        wf_dur_all = u.SpikeParams['duration']
        wf_dur_inc = wf_dur_all[spk_inc]

        # Set common limits and labels.
        dur_lim = [0, wavetime[-2] - wavetime[WF_T_START]]  # same across units
        glim = max(wf_amp_all.max(), gmax - gmin)
        amp_lim = [0, glim]

        amp_lab = 'Amplitude'
        dur_lab = r'Duration ($\mu$s)'

        # Waveform amplitude across session time.
        m_amp, sd_amp = wf_amp_inc.mean(), wf_amp_inc.std()
        title = r'WF amplitude: {:.1f} $\pm$ {:.1f}'.format(m_amp, sd_amp)
        xlab, ylab = (ses_t_lab, amp_lab) if add_lbls else (None, None)
        pplot.scatter(spk_times, wf_amp_all, spk_inc, c='m', bc='grey', s=ss,
                      xlab=xlab, ylab=ylab, xlim=ses_t_lim, ylim=amp_lim,
                      edgecolors='', alpha=sa, id_line=False, title=title,
                      ax=ax_wf_amp)

        # Waveform duration across session time.
        mdur, sdur = wf_dur_inc.mean(), wf_dur_inc.std()
        title = r'WF duration: {:.1f} $\pm$ {:.1f} $\mu$s'.format(mdur, sdur)
        xlab, ylab = (ses_t_lab, dur_lab) if add_lbls else (None, None)
        pplot.scatter(spk_times, wf_dur_all, spk_inc, c='c', bc='grey', s=ss,
                      xlab=xlab, ylab=ylab, xlim=ses_t_lim, ylim=dur_lim,
                      edgecolors='', alpha=sa, id_line=False, title=title,
                      ax=ax_wf_dur)

        # Waveform duration against amplitude.
        title = 'WF duration - amplitude'
        xlab, ylab = (dur_lab, amp_lab) if add_lbls else (None, None)
        pplot.scatter(wf_dur_all[spk_order], wf_amp_all[spk_order],
                      c=spk_cols[spk_order], s=ss, xlab=xlab, ylab=ylab,
                      xlim=dur_lim, ylim=amp_lim, edgecolors='', alpha=sa,
                      id_line=False, title=title, ax=ax_amp_dur)

        # %% Firing rate.

        tmean = np.array(bs_stats['tmean'])
        rmean = util.remove_dim_from_series(bs_stats['rate'])
        prd_tstart, prd_tstop = stab_prd_res['tstart'], stab_prd_res['tstop']

        # Color segments depending on whether they are included / excluded.
        def plot_periods(v, color, ax):
            # Plot line segments: coloured only if both end points included.
            for i in range(len(prd_inc[:-1])):
                col = color if prd_inc[i] and prd_inc[i + 1] else 'grey'
                x, y = [(tmean[i], tmean[i + 1]), (v[i], v[i + 1])]
                ax.plot(x, y, color=col)
            # Plot line points.
            for i in range(len(prd_inc)):
                col = color if prd_inc[i] else 'grey'
                x, y = [tmean[i], v[i]]
                ax.plot(x, y, color=col, marker='o', markersize=3,
                        markeredgecolor=col)

        # Firing rate over session time.
        title = 'Baseline rate: {:.1f} spike/s'.format(float(base_rate))
        xlab, ylab = (ses_t_lab, putil.FR_lbl) if add_lbls else (None, None)
        ylim = [0, 1.25 * np.max(rmean)]
        plot_periods(rmean, 'b', ax_rate)
        pplot.lines([], [], c='b', xlim=ses_t_lim, ylim=ylim, title=title,
                    xlab=xlab, ylab=ylab, ax=ax_rate)

        # Trial markers.
        putil.plot_events(tr_markers, lw=0.5, ls='--', alpha=0.35, y_lbl=0.92,
                          ax=ax_rate)

        # Excluded periods.
        excl_prds = []
        tstart, tstop = ses_t_lim
        if tstart != prd_tstart:
            excl_prds.append(('beg', tstart, prd_tstart))
        if tstop != prd_tstop:
            excl_prds.append(('end', prd_tstop, tstop))
        putil.plot_periods(excl_prds, ymax=0.92, ax=ax_rate)

        # %% Post-formatting.

        # Maximize number of ticks on recording time axes to prevent covering.
        for ax in (ax_wf_amp, ax_wf_dur, ax_rate):
            putil.set_max_n_ticks(ax, 6, 'x')

        # %% Save figure.
        if ftempl is not None:
            fname = ftempl.format(u.name_to_fname())
            # NOTE(review): `title` here holds the last subplot title set
            # above ('Baseline rate: ...'), not the unit-info header title.
            # Confirm this is the intended figure title.
            putil.save_fig(fname, fig, title, rect_height=0.92)

        return ([ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur,
                ax_rate)
    finally:
        putil.inline_on()
def plot_DR_tuning(DS, title=None, labels=True, baseline=None, DR_legend=True,
                   tuning_legend=True, ffig=None, fig=None, sps=None):
    """Plot direction selectivity on polar plot and tuning curve.

    For each stimulus in ``DS.DSI.index``, the function draws two panels:

    - a polar plot of the direction response (left);
    - a fitted Gaussian tuning curve centred on the preferred direction
      (right).

    It adds legends summarising DSI / PD and the fit parameters.

    Parameters
    ----------
    DS : direction-selectivity results; attributes TP (fit params), DSI
        (with column wDS), PD (with PD / cPD) and DR (per-stimulus response
        stats with 'mean' / 'sem') are read.
    title : figure super title, or None.
    labels : add axis labels, titles and legend titles if True.
    baseline : passed on to the DR and tuning plots.
    DR_legend, tuning_legend : whether to add each legend.
    ffig : if not None, the figure is saved to this file.
    fig, sps : figure / subplot spec to draw into.

    Returns
    -------
    ax_DR, ax_tuning
    """
    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 1, 2, wspace=0.3)
    ax_DR = fig.add_subplot(gsp[0], polar=True)
    ax_tuning = fig.add_subplot(gsp[1])

    DR_patches = []
    tuning_patches = []
    xticks = np.arange(-180, 180 + 1, 45)  # tuning curve x ticks (deg)

    stims = DS.DSI.index
    for stim in stims:

        # Extract params and results.
        a, b, x0, sigma, FWHM, R2, RMSE = DS.TP.loc[stim]
        DSI = DS.DSI.wDS[stim]
        PD, cPD = DS.PD.loc[(stim, 'weighted'), ['PD', 'cPD']]
        dirs = np.array(DS.DR[stim].index) * deg
        SR, SRsem = [DS.DR[stim][col] for col in ('mean', 'sem')]

        # Center direction and response stats.
        dirsc = direction.center_to_dir(dirs, PD)

        # Generate data points for plotting fitted tuning curve.
        x, y = tuning.gen_fit_curve(tuning.gaus, -180, 180,
                                    a=a, b=b, x0=x0, sigma=sigma)

        # Plot DR polar plot.
        color = putil.stim_colors.loc[stim]
        ttl = 'Direction response' if labels else None
        plot_DR(dirs, SR, DSI, PD, baseline, title=ttl, color=color, ax=ax_DR)

        # Collect parameters of DR plot (stimulus - response).
        s_pd = str(float(round(PD, 1)))
        s_pd_c = str(int(cPD)) if not np.isnan(float(cPD)) else 'nan'
        lgd_lbl = '{}: {:.3f}'.format(stim, DSI)
        # Raw string: '\c' is an invalid escape sequence (SyntaxWarning since
        # Python 3.12); the resulting text is unchanged.
        lgd_lbl += r' {:>5}$^\circ$ --> {:>3}$^\circ$ '.format(s_pd, s_pd_c)
        DR_patches.append(putil.get_artist(lgd_lbl, color))

        # Calculate and plot direction tuning curve.
        plot_tuning(x, y, dirsc, SR, SRsem, color, baseline, xticks,
                    ax=ax_tuning)

        # Collect parameters tuning curve fit.
        s_a, s_b, s_x0, s_sigma, s_FWHM = [
            str(float(round(p, 1))) for p in (a, b, x0, sigma, FWHM)
        ]
        s_R2 = format(R2, '.2f')
        lgd_lbl = '{}:{}{:>6}{}{:>6}'.format(stim, 5 * ' ', s_a, 5 * ' ', s_b)
        lgd_lbl += '{}{:>6}{}{:>6}'.format(5 * ' ', s_x0, 8 * ' ', s_sigma)
        lgd_lbl += '{}{:>6}{}{:>6}'.format(8 * ' ', s_FWHM, 8 * ' ', s_R2)
        tuning_patches.append(putil.get_artist(lgd_lbl, color))

    # Format tuning curve plot (only after all stimuli have been plotted!).
    xlab = 'Difference from preferred direction (deg)' if labels else None
    ylab = putil.FR_lbl if labels else None
    xlim = [-180 - 5, 180 + 5]  # degrees
    ylim = [0, None]
    ttl = 'Tuning curve' if labels else None
    putil.format_plot(ax_tuning, xlim, ylim, xlab, ylab, ttl)
    putil.add_zero_line('y', ax=ax_tuning)  # add zero reference line

    # Set super title.
    if title is not None:
        fig.suptitle(title, y=1.1, fontsize='x-large')

    # Set legends (framed only in the unlabelled, summary-plot variant).
    ylegend = -0.38 if labels else -0.15
    fr_on = not labels
    lgd_kws = {'fancybox': True,
               'shadow': False,
               'frameon': fr_on,
               'framealpha': 1.0,
               'loc': 'lower center',
               'bbox_to_anchor': [0., ylegend, 1., .0],
               'prop': {'family': 'monospace'}}
    DR_lgn_ttl = 'DSI'.rjust(20) + 'PD'.rjust(14) + 'PD8'.rjust(14)
    tuning_lgd_ttl = ('a (sp/s)'.rjust(35) + 'b (sp/s)'.rjust(15) +
                      'x0 (deg)'.rjust(13) + 'sigma (deg)'.rjust(15) +
                      'FWHM (deg)'.rjust(15) + 'R-squared'.rjust(15))

    lgn_params = [(DR_legend, DR_lgn_ttl, DR_patches, ax_DR),
                  (tuning_legend, tuning_lgd_ttl, tuning_patches, ax_tuning)]

    for (plot_legend, lgd_ttl, patches, ax) in lgn_params:
        if not plot_legend:
            continue
        if not labels:  # customize for summary plot
            lgd_ttl = None
        lgd = putil.set_legend(ax, handles=patches, title=lgd_ttl, **lgd_kws)
        lgd.get_title().set_ha('left')
        if lgd_kws['frameon']:
            lgd.get_frame().set_linewidth(.5)

    # Save figure.
    if ffig is not None:
        putil.save_fig(ffig, fig, rect_height=0.55)

    return ax_DR, ax_tuning
def plot_SR(u, param=None, vals=None, from_trs=None, prd_pars=None,
            nrate=None, colors=None, add_roc=False, add_prd_name=False,
            fig=None, sps=None, title=None, **kwargs):
    """Plot stimulus response (raster, rate and ROC) for multiple stimuli.

    One column of subplots is created per analysis period in ``prd_pars``,
    with width proportional to min(dur, max_len). Each column holds raster
    and rate plots (via ``prate.plot_rr``). If ``add_roc`` is set and
    exactly two trial sets are plotted, an AUC-over-time plot goes below.

    Parameters
    ----------
    u : unit object.
    param, vals : trial parameter and values to split trials by. With
        ``param=None``, all included trials (``u.ser_inc_trials()``) are used.
    from_trs : if not None, trial sets are filtered to these trials.
    prd_pars : analysis period parameters, or None to use
        ``u.get_analysis_prds()``.
    nrate : passed on to the rate plotting / ROC functions.
    colors : one colour per trial set, or None to use the colour cycle.
    add_roc, add_prd_name : add ROC plots / period name text.
    fig, sps : figure / subplot spec to draw into.
    title : title on the first raster axes, or None.
    **kwargs : passed on to ``prate.plot_rr``. ``no_labels=True`` also
        suppresses the common x axis label.

    Returns
    -------
    axes_raster, axes_rate, axes_roc : lists of axes. All three are empty if
        the unit is not to be plotted or any value in ``vals`` is NaN.
    """
    has_nan_val = (util.is_iterable(vals) and
                   any(np.isnan(float(v)) for v in vals))
    if not u.to_plot() or has_nan_val:
        return [], [], []

    # Set up stimulus parameters.
    if prd_pars is None:
        prd_pars = u.get_analysis_prds()

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)

    # Create a gridspec for each period.
    wratio = [float(min(dur, lmax))
              for dur, lmax in zip(prd_pars.dur, prd_pars.max_len)]
    wspace = 0.0  # 0.05 to separate periods
    gsp = putil.embed_gsp(sps, 1, len(prd_pars.index), width_ratios=wratio,
                          wspace=wspace)

    axes_raster, axes_rate, axes_roc = [], [], []

    for i, prd in enumerate(prd_pars.index):

        ppars = prd_pars.loc[prd]
        ref = ppars.ref

        # Prepare trial set.
        if param is None:
            trs = u.ser_inc_trials()
        elif param in u.TrData.columns:
            trs = u.trials_by_param(param, vals)
        else:
            trs = u.trials_by_param((ppars.stim, param), vals)

        if from_trs is not None:
            trs = util.filter_lists(trs, from_trs)

        plot_roc = add_roc and (len(trs) == 2)

        # Init subplots.
        if plot_roc:
            rr_sps, roc_sps = putil.embed_gsp(gsp[i], 2, 1, hspace=0.2,
                                              height_ratios=[1, .4])
        else:
            rr_sps = putil.embed_gsp(gsp[i], 1, 1, hspace=0.3)[0]

        # Init params.
        prds = [constants.ev_stims.loc[ref]]
        evnts = None
        if ('cue' in ppars) and (ppars.cue is not None):
            evnts = [{'time': ppars.cue}]
            evnts[0]['color'] = (ppars.cue_color if 'cue_color' in ppars.index
                                 else putil.cue_colors['all'])

        # Colours are chosen once, based on the first period's trial sets.
        if colors is None:
            colcyc = putil.get_colors()
            colors = [next(colcyc) for _ in range(len(trs))]

        # Plot response on raster and rate plots.
        _, raster_axs, rate_ax = prate.plot_rr(u, prd, ref, prds, evnts,
                                               nrate, trs, ppars.max_len,
                                               cols=colors, fig=fig,
                                               sps=rr_sps, **kwargs)

        # Add period name to rate plot.
        if add_prd_name:
            rate_ax.text(0.02, 0.95, prd, fontsize=10, color='k',
                         va='top', ha='left', transform=rate_ax.transAxes)

        # Plot ROC curve.
        if plot_roc:
            roc_ax = fig.add_subplot(roc_sps)
            # Init rates.
            plot_params = prate.prep_rr_plot_params(u, prd, ref, nrate, trs)
            _, _, (rates1, rates2), _, _ = plot_params
            # Calculate ROC results.
            aroc = roccore.run_ROC_over_time(rates1, rates2, n_perm=0)
            # Set up plot params and plot results.
            tvec, auc = aroc.index, aroc.auc
            xlim = rate_ax.get_xlim()
            pauc.plot_auc_over_time(auc, tvec, prds, evnts, xlim=xlim,
                                    ax=roc_ax)

        # Some exception handling: set x axis range to predefined values if
        # it is not set (probably due to no trial data plotted).
        if prd in constants.fixed_tr_prds.index:
            t1, t2 = constants.fixed_tr_prds.loc[prd]
            axs = raster_axs + [rate_ax] + ([roc_ax] if plot_roc else [])
            for ax in axs:
                if ax.get_xlim() == (-0.001, 0.001):
                    ax.set_xlim([t1, t2])

        # Collect axes.
        axes_raster.extend(raster_axs)
        axes_rate.append(rate_ax)
        if plot_roc:
            axes_roc.append(roc_ax)

    # Nothing was plotted (no periods): nothing to format.
    if not axes_rate:
        return axes_raster, axes_rate, axes_roc

    # Format rate and roc plots to make them match across trial periods.

    # Remove y-axis label, spine and ticks from second and later periods.
    for ax in axes_rate[1:] + axes_roc[1:]:
        ax.set_ylabel('')
        putil.set_spines(ax, bottom=True, left=False)
        putil.hide_ticks(ax, show_x_ticks=True, show_y_ticks=False)

    # Remove (hide) legend from all but last rate plot.
    for ax in axes_rate[:-1]:
        putil.hide_legend(ax)

    # Get middle x point in relative coordinates of first rate axes.
    xranges = np.array([ax.get_xlim() for ax in axes_rate])
    xlens = xranges[:, 1] - xranges[:, 0]
    xmid = xlens.sum() / xlens[0] / 2

    # Add title.
    if title is not None and axes_raster:
        axes_raster[0].set_title(title, x=xmid, y=1.0)

    # Set common x label.
    if not kwargs.get('no_labels', False):
        for ax in axes_rate + axes_roc:
            ax.set_xlabel('')
        # .iloc: integer keys on a label-indexed Series are deprecated as
        # positional access in pandas 2.x.
        xlab = putil.t_lbl.format(prd_pars.stim.iloc[0] + ' onset')
        ax = axes_roc[0] if axes_roc else axes_rate[0]
        ax.set_xlabel(xlab, x=xmid)

    # Reformat ticks on x axis. First period is reference.
    for axes in (axes_rate, axes_roc):
        for ax, lbl_shift in zip(axes, prd_pars.lbl_shift):
            x1, x2 = ax.get_xlim()
            shift = float(lbl_shift)
            tick_marks, tick_lbls = putil.get_tick_marks_and_labels(
                x1 + shift, x2 + shift)
            putil.set_xtick_labels(ax, tick_marks - shift, tick_lbls)

    # Remove x tick labels from rate axes if ROC plots are present. A bool is
    # required: current matplotlib does not honour the string 'off' here.
    if axes_roc:
        for ax in axes_rate:
            ax.tick_params(labelbottom=False)

    # Match scale of rate and roc plots' y axes.
    for axes in (axes_rate, axes_roc):
        putil.sync_axes(axes, sync_y=True)
        for ax in axes:
            putil.adjust_decorators(ax)

    return axes_raster, axes_rate, axes_roc
def plot_SR_matrix(u, param, vals=None, sps=None, fig=None):
    """
    Plot stimulus response in matrix layout by target and delay length for
    combined task.

    Rows are 'all' plus each reported target, and columns are 'all' plus each
    delay length. Each split dimension is only included if it has more than
    one value. Each cell is drawn by ``plot_SR`` on the matching trial subset.

    This function is currently out of date, needs updating!

    Returns
    -------
    raster_axs, rate_axs, roc_axs : lists of all axes created.
    """
    # Init params.
    dsplit_prd_pars = constants.tr_half_prds.copy()
    dcomb_prd_pars = constants.tr_third_prds.copy()
    targets = u.TrData['ToReport'].unique()
    dlens = util.remove_dim_from_array(constants.del_lens)

    # Init axes (resolve default figure / subplot spec like sibling plots).
    sps, fig = putil.sps_fig(sps, fig)
    nrow = len(targets) + (len(targets) > 1)
    ncol = len(dlens) + (len(dlens) > 1)

    # Width ratio, depending on delay lengths.
    wratio = None
    if len(dlens) > 1:
        len_wo_delay = float(dsplit_prd_pars.max_len.sum()) - dlens[-1]
        wratio = [len_wo_delay + dlens[0]]  # combined (unsplit) column
        wratio.extend([len_wo_delay + dlen for dlen in dlens])

    gsp = putil.embed_gsp(sps, nrow, ncol, width_ratios=wratio,
                          wspace=0.2, hspace=0.4)

    raster_axs, rate_axs, roc_axs = [], [], []

    # Set up trial splits to plot by.
    mtargets = ['all'] + list(targets if len(targets) > 1 else [])
    mdlens = ['all'] + list(dlens if len(dlens) > 1 else [])

    # No split (all included trials).
    trs_splits = pd.Series([u.inc_trials()], index=[('all', 'all')])

    # Series.append was removed in pandas 2.0; pd.concat is the replacement.

    # Split by report only.
    if len(targets) > 1:
        by_report = u.trials_by_params(['ToReport'])
        by_report.index = [(r, 'all') for r in by_report.index]
        trs_splits = pd.concat([trs_splits, by_report])

    # Split by delay length only.
    if len(dlens) > 1:
        by_dlen = u.trials_by_params(['DelayLen'])
        by_dlen.index = [('all', dl) for dl in by_dlen.index]
        trs_splits = pd.concat([trs_splits, by_dlen])

    # Split by both report and delay length.
    if len(targets) > 1 and len(dlens) > 1:
        target_dlen_trs = u.trials_by_params(['ToReport', 'DelayLen'])
        trs_splits = pd.concat([trs_splits, target_dlen_trs])

    # Plot each split on matrix layout.
    for i, target in enumerate(mtargets):
        for j, dlen in enumerate(mdlens):

            # Init plotting params.
            dlen_str = str(int(dlen)) + ' ms' if util.is_number(dlen) else dlen
            title = 'report: {} | delay: {}'.format(target, dlen_str)
            from_trs = trs_splits[(target, dlen)]

            # Periods to plot.
            if dlen == 'all':
                prd_pars = dcomb_prd_pars.copy()
                prd_pars = u.get_analysis_prds(prd_pars, from_trs)
            else:
                prd_pars = dsplit_prd_pars.copy()
                prd_pars = u.get_analysis_prds(prd_pars, from_trs)
                # Use .loc: chained assignment (df.col[row] = x) does not
                # write back under pandas copy-on-write.
                S2_shift = constants.stim_dur['S1'] + dlen * ms
                prd_pars.loc['S2 half', 'lbl_shift'] = S2_shift
                if len(dlens) > 1:
                    tdiff = (dlen - dlens[-1]) * ms
                    S1_max_len = prd_pars.max_len['S1 half'] + tdiff
                    prd_pars.loc['S1 half', 'max_len'] = S1_max_len
            if 'cue' in prd_pars.columns:
                prd_pars['cue_color'] = putil.cue_colors[target]

            # Plot response.
            res = plot_SR(u, param, vals, from_trs, prd_pars, title=title,
                          sps=gsp[i, j], fig=fig)
            axes_raster, axes_rate, axes_roc = res

            # Remove superfluous labels: x labels except on the bottom row,
            # y labels except in the first column.
            if i < len(mtargets) - 1:
                for ax in axes_rate:
                    ax.set_xlabel('')
            if j > 0:
                for ax in axes_rate:
                    ax.set_ylabel('')

            # Collect axes.
            raster_axs.extend(axes_raster)
            rate_axs.extend(axes_rate)
            roc_axs.extend(axes_roc)

    return raster_axs, rate_axs, roc_axs