def plot_weights(ax, Coefs, prds=None, xlim=None, xlab=tlab,
                 ylab='unit coefficient', title='', ytitle=1.04):
    """Plot decoding weights.

    Parameters
    ----------
    ax : axes to draw on.
    Coefs : coefficient frame. It is unstacked twice, and the three
        resulting index levels are used as the 'time', 'value' and 'uid'
        columns of the plotted long-format frame.
    prds : periods forwarded to ``putil.plot_periods``.
    xlim : x-axis limits; the module-level ``tlim`` is used when None.
    xlab, ylab, title, ytitle : forwarded to ``putil.set_labels``.
    """

    # Long format: one row per coefficient, index levels copied into columns.
    long_coefs = pd.DataFrame(Coefs.unstack().unstack(), columns=['coef'])
    orig_idx = long_coefs.index
    for level, col_name in enumerate(('time', 'value', 'uid')):
        long_coefs[col_name] = orig_idx.get_level_values(level)
    long_coefs.index = np.arange(len(orig_idx))

    sns.tsplot(long_coefs, time='time', value='coef', unit='value',
               condition='uid', ax=ax)

    # Zero line and stimulus periods.
    putil.add_chance_level(ax=ax, ylevel=0)
    putil.plot_periods(prds, ax=ax)

    # Limits, labels, legend.
    putil.set_limits(ax, tlim if xlim is None else xlim)
    putil.set_labels(ax, xlab, ylab, title, ytitle)
    putil.hide_legend(ax)
def plot_mean_rates(mRates, aa_res_dir, tasks=None, task_lbls=None,
                    baseline=None, xlim=None, ylim=None, ci=68, ffig=None,
                    figsize=(6, 4)):
    """Plot mean rates across tasks.

    Parameters
    ----------
    mRates : mapping from task to a rate frame. Each frame is unstacked,
        and its two resulting index levels are used as 'time' and 'unit'.
    aa_res_dir : not used by this function (kept for interface
        compatibility).
    tasks : tasks to plot; all keys of ``mRates`` when None.
    task_lbls : optional mapping from task name to displayed label.
    baseline : if not None, drawn with ``putil.add_baseline``.
    xlim, ylim : axis limits.
    ci : confidence interval passed to ``sns.tsplot``.
    ffig : file name passed to ``putil.save_fig``.
    figsize : figure size.

    Returns
    -------
    fig, ax : the created figure and axes.
    """

    # Init.
    if tasks is None:
        tasks = mRates.keys()

    # Collect rates of all tasks into one long-format frame.
    lRates = []
    for task in tasks:
        lrates = pd.DataFrame(mRates[task].unstack(), columns=['rate'])
        lrates['task'] = task
        lRates.append(lrates)
    lRates = pd.concat(lRates)
    lRates['time'] = lRates.index.get_level_values(0)
    lRates['unit'] = lRates.index.get_level_values(1)
    if task_lbls is not None:
        # Assign back: a chained ``.task.replace(..., inplace=True)`` is a
        # silent no-op under pandas Copy-on-Write.
        lRates['task'] = lRates['task'].replace(task_lbls)

    # Plot as time series.
    fig = putil.figure(figsize=figsize)
    ax = putil.axes()
    sns.tsplot(lRates, time='time', value='rate', unit='unit',
               condition='task', ci=ci, ax=ax)

    # Add periods and baseline.
    putil.plot_periods(ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Format plot.
    sns.despine(ax=ax)
    putil.set_labels(ax, xlab='time since S1 onset', ylab='rate (sp/s)')
    putil.set_limits(ax, xlim, ylim)
    putil.hide_legend_title(ax)

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
def raster(spk_trains, t_unit=ms, prds=None, c='b', xlim=None, title=None,
           xlab=None, ylab=None, ffig=None, ax=None):
    """Plot rasterplot.

    The i-th spike train (0-based) is drawn on row i + 1. Each spike is a
    small rectangle in data coordinates, so marker size scales with the axes.
    Rows are ordered from top to bottom. If ``spk_trains`` is empty, only
    the periods and x limits are drawn and the axes are returned.

    Parameters
    ----------
    spk_trains : sequence of spike trains (each must support
        ``.rescale(t_unit)``).
    t_unit : time unit spike times are rescaled to.
    prds : periods forwarded to ``putil.plot_periods``.
    c : face and edge color of the spike markers.
    xlim, title, ylab : forwarded to ``putil.format_plot``.
    xlab : if given, formatted through ``putil.t_lbl``.
    ffig : file name passed to ``putil.save_fig``.
    ax : axes to draw on (created by ``putil.axes`` if None).

    Returns
    -------
    ax
    """

    ax = putil.axes(ax)
    putil.plot_periods(prds, ax=ax)
    putil.set_limits(ax, xlim)

    n_trains = len(spk_trains)
    if not n_trains:
        return ax  # nothing to plot

    # NOTE(review): ``wsp`` / ``hsp`` (marker width / height) are not defined
    # in this function; presumably module-level constants - confirm.
    for row, spk_tr in enumerate(spk_trains, start=1):
        spk_x = np.array(spk_tr.rescale(t_unit))
        spk_y = row * np.ones_like(spk_x)
        # Absolute-size alternative:
        # ax.scatter(spk_x, spk_y, c=c, s=1.8, edgecolor=c, marker='|')
        rects = [Rectangle((xi - wsp / 2, yi - hsp / 2), wsp, hsp)
                 for xi, yi in zip(spk_x, spk_y)]
        ax.add_collection(PatchCollection(rects, facecolor=c, edgecolor=c))

    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, [0.5, n_trains + 0.5], xlab, ylab, title)
    putil.hide_axes(ax, show_x=True)
    putil.hide_spines(ax)

    # Invert only after the limits are set, so the first train is on top.
    ax.invert_yaxis()

    putil.save_fig(ffig)
    return ax
def plot_ROC_mean(d_faroc, t1=None, t2=None, ylim=None, colors=None,
                  ylab='AROC', ffig=None):
    """Plot mean ROC curves over given period.

    Parameters
    ----------
    d_faroc : mapping from task name to a results file; the 'aroc' object
        is read from each file with ``util.read_objects``.
    t1, t2 : x-axis limits.
    ylim : y-axis limits.
    colors : passed as ``color`` to ``sns.tsplot``.
    ylab : y-axis label.
    ffig : file name passed to ``putil.save_fig``.

    Returns
    -------
    fig, ax : the created figure and axes (previously returned None).
    """

    # Import results.
    d_aroc = {}
    for name, faroc in d_faroc.items():
        aroc = util.read_objects(faroc, 'aroc')
        d_aroc[name] = aroc.unstack().T

    # Format results into long format: (task, time, unit) -> aroc.
    laroc = pd.DataFrame(pd.concat(d_aroc), columns=['aroc'])
    laroc['task'] = laroc.index.get_level_values(0)
    laroc['time'] = laroc.index.get_level_values(1)
    laroc['unit'] = laroc.index.get_level_values(2)
    laroc.index = np.arange(len(laroc.index))

    # Init figure.
    fig = putil.figure(figsize=(6, 6))
    ax = sns.tsplot(laroc, time='time', value='aroc', unit='unit',
                    condition='task', color=colors)

    # Highlight stimulus periods.
    putil.plot_periods(ax=ax)

    # Thicken all lines drawn so far (the mean curves).
    for line in ax.lines:
        line.set_linewidth(3)

    # Add chance level line (the last line added) with a thinner width.
    putil.add_chance_level(ax=ax, alpha=0.8, color='k')
    ax.lines[-1].set_linewidth(1.5)

    # Format plot.
    xlab = 'Time since S1 onset (ms)'
    putil.set_labels(ax, xlab, ylab)
    putil.set_limits(ax, [t1, t2], ylim)
    putil.set_spines(ax, bottom=True, left=True, top=False, right=False)
    putil.set_legend(ax, loc=0)

    # Save plot.
    putil.save_fig(ffig, fig, ytitle=1.05, w_pad=15)

    return fig, ax
def plot_scores(ax, Scores, Perm=None, Psdo=None, nvals=None, prds=None,
                col='b', perm_col='grey', psdo_col='g', xlim=None,
                ylim=ylim_scr, xlab=tlab, ylab=ylab_scr, title='',
                ytitle=1.04):
    """Plot decoding accuracy results.

    Optional permuted (``Perm``) and pseudo-population (``Psdo``) results are
    drawn first, when present and not all-NaN. Each is a frame with 'mean',
    'std' and 'pval' rows and time columns. ``Scores`` is then drawn on top.

    Parameters
    ----------
    ax : axes to draw on.
    Scores : scores passed to ``plot_score_set``.
    Perm, Psdo : optional result frames (see above).
    nvals : if given, a chance level line is drawn at 1 / nvals.
    prds : periods forwarded to ``putil.plot_periods`` (skipped if None).
    col, perm_col, psdo_col : colors of scores, permuted and
        pseudo-population results.
    xlim : x-axis limits; the module-level ``tlim`` is used when None.
    ylim : y-axis limits.
    xlab, ylab, title, ytitle : forwarded to ``putil.set_labels``.
    """

    handles = []

    # Permuted (thick) and pseudo-population (thinner) results, if present.
    optional_sets = ((Perm, 6, perm_col, 'permuted'),
                     (Psdo, 3, psdo_col, 'pseudo-population'))
    for res, lw, color, label in optional_sets:
        if util.is_null(res) or res.isnull().all().all():
            continue
        plot_mean_std_sdiff(res.columns, res.loc['mean'], res.loc['std'],
                            res.loc['pval'], pth=0.01, lw=lw, color=color,
                            ax=ax)
        handles.append(putil.get_artist(label, color))

    # Actual scores.
    plot_score_set(Scores, ax, color=col)
    handles.append(putil.get_artist('synchronous', col))

    # Legend lists entries in reverse order of drawing.
    putil.set_legend(ax, handles=handles[::-1])

    # Chance level line.
    # This currently plots all nvals combined across stimulus period!
    if nvals is not None:
        putil.add_chance_level(ax=ax, ylevel=1.0 / nvals)

    if prds is not None:
        putil.plot_periods(prds, ax=ax)

    putil.set_limits(ax, tlim if xlim is None else xlim, ylim)
    putil.set_labels(ax, xlab, ylab, title, ytitle)
def plot_auc_over_time(auc, tvec, prds=None, evts=None, xlim=None, ylim=None,
                       xlab='time', ylab='AUC', title=None, ax=None):
    """Plot AROC values over time.

    Parameters
    ----------
    auc : values to plot against ``tvec``.
    tvec : time vector; its min and max are the default x limits.
    prds : periods forwarded to ``putil.plot_periods`` (drawn first).
    evts : event markers forwarded to ``putil.plot_event_markers``.
    xlim, ylim, xlab, ylab, title : forwarded to ``pplot.lines``. When
        ylim is exactly [0, 1], five y ticks are set with every other
        label blanked.
    ax : axes to draw on (created by ``putil.axes`` if None).

    Returns
    -------
    ax
    """

    ax = putil.axes(ax)
    xlim = [min(tvec), max(tvec)] if xlim is None else xlim

    # Periods go underneath the curve.
    putil.plot_periods(prds, ax=ax)

    pplot.lines(tvec, auc, ylim, xlim, xlab, ylab, title, color='green',
                ax=ax)
    putil.add_chance_level(ax=ax)

    # For a [0, 1] range: ticks at 0, .25, .5, .75, 1, labels on 0, .5, 1.
    unit_range = ylim is not None and ylim[0] == 0 and ylim[1] == 1
    if unit_range:
        marks = np.linspace(0, 1, 5)
        labels = np.array(marks, dtype=str)
        labels[1::2] = ''
        putil.set_ytick_labels(ax, marks, labels)
    putil.set_max_n_ticks(ax, 5, 'y')

    putil.plot_event_markers(evts, ax=ax)

    return ax
def _greedy_unit_removal(IncTrsMat):
    """Greedily remove units from a units x trials inclusion matrix.

    In each step, the unit removed is (a) one whose removal maximises the
    number of trials included for *all* remaining units and, among those,
    (b) the one with the fewest included trials (first one if still tied).
    Removal continues until a single unit is left.

    Parameters
    ----------
    IncTrsMat : DataFrame, units (rows) x trials (columns), nonzero where the
        trial is included for the unit.

    Returns
    -------
    DataFrame indexed 0..n_units with columns
    ('uid', 'ntrs_cov', 'n_rem_u', 'trial x units'). Row 0 describes the full
    unit set (uid 'none'); row k describes the state after the k-th removal.
    The last row holds the final remaining unit with zero counts.
    """

    def n_cov_trs(df):
        # Number of trials included for every unit (row) of df.
        return sum(df.all())

    def calc_heuristic(df):
        # Number of units times number of trials covered by all of them.
        return df.shape[0] * n_cov_trs(df)

    n_trs = IncTrsMat.sum(1)
    n_units = IncTrsMat.shape[0]

    columns = ('uid', 'ntrs_cov', 'n_rem_u', 'trial x units')
    tr_covs = pd.DataFrame(columns=columns, index=range(n_units+1))
    tr_covs.loc[0] = ('none', n_cov_trs(IncTrsMat), n_units,
                      calc_heuristic(IncTrsMat))

    # Subset of included units (updated in each iteration).
    uinc = IncTrsMat.index.to_series()
    for iu in range(1, len(uinc)):

        # Number of covered trials after removing each unit.
        sntrscov = pd.Series([n_cov_trs(IncTrsMat.loc[uinc.drop(uid)])
                              for uid in uinc], index=uinc.index)

        # Select unit that (a) yields maximum trial coverage and
        # (b) has minimum number of trials.
        maxtrscov = sntrscov.max()
        worst_us = sntrscov[sntrscov == maxtrscov].index  # (a)
        utrs = n_trs.loc[worst_us]
        uid_remove = utrs[utrs == min(utrs)].index[0]     # (b)

        # Update current subset of units and record state.
        uinc = uinc.drop(uid_remove)
        tr_covs.loc[iu] = (uid_remove, maxtrscov, len(uinc),
                           calc_heuristic(IncTrsMat.loc[uinc]))

    # Add last unit (positional access: the index holds tuples, not ints).
    tr_covs.iloc[-1] = (uinc.iloc[0], 0, 0, 0)

    return tr_covs


def select_units_trials(UA, utids=None, fres=None, ffig=None, min_n_units=5,
                        min_n_trs_per_unit=5):
    """
    Select optimal set of units and trials for population decoding.

    For every (subj, date, task) recording, units are removed greedily (see
    ``_greedy_unit_removal``). Among the removal steps that satisfy both
    criteria below, the one maximising "remaining units x covered trials" is
    selected. If no step qualifies, all units of the recording are excluded.
    One row of three diagnostic plots is drawn per recording.

    min_n_units: minimum number of units to keep (0: off)
    min_n_trs_per_unit: the number of covered trials must be at least this
        many times the number of remaining units (0: off)

    Returns
    -------
    RecInfo : DataFrame indexed by (subj, date, task) with the selected
        units / trials and summary counts per recording.
    UInc : boolean Series indexed like ``utids``; True for selected units.
    """

    print('Selecting optimal set of units and trials for decoding...')

    # Init.
    if utids is None:
        utids = UA.utids(as_series=True)
    u_rt_grpby = utids.groupby(level=['subj', 'date', 'task'])

    # Unit inclusion flags (set to True for selected units below).
    UInc = pd.Series(False, index=utids.index)

    # Included trials by unit.
    IncTrs = pd.Series([UA.get_unit(utid[:-1], utid[-1]).inc_trials()
                        for utid in utids], index=utids.index)

    # Result DF.
    rec_task = pd.MultiIndex.from_tuples([rt for rt, _ in u_rt_grpby],
                                         names=['subj', 'date', 'task'])
    cols = ['elec', 'units', 'nunits', 'nallunits', '% remaining units',
            'trials', 'ntrials', 'nalltrials', '% remaining trials']
    RecInfo = pd.DataFrame(index=rec_task, columns=cols)
    rt_utids_list = [utids.xs((s, d, t), level=('subj', 'date', 'task'))
                     for s, d, t in rec_task]
    RecInfo['nallunits'] = [len(rtu) for rtu in rt_utids_list]
    # ``.iloc[0]``: positional access on a Series with a non-integer index.
    rt_ulist = [UA.get_unit(rtu.iloc[0][:-1], rtu.iloc[0][-1])
                for rtu in rt_utids_list]
    RecInfo['nalltrials'] = [int(u.QualityMetrics['NTrialsTotal'])
                             for u in rt_ulist]

    # Function to plot matrix (DF) of included/excluded trials.
    def plot_inc_exc_trials(IncTrsMat, ax, title=None, ytitle=None,
                            xlab='Trial #', ylab=None):
        # Plot on heatmap.
        sns.heatmap(IncTrsMat, cmap='RdYlGn', center=0.5, cbar=False, ax=ax)
        # Set tick labels.
        putil.hide_tick_marks(ax)
        tr_ticks = [1] + list(np.arange(25, IncTrsMat.shape[1]+1, 25))
        ax.xaxis.set_ticks(tr_ticks)
        ax.set_xticklabels(tr_ticks)
        putil.rot_xtick_labels(ax, 0)
        putil.rot_ytick_labels(ax, 0, va='center')
        putil.set_labels(ax, xlab, ylab, title, ytitle)

    # Init plotting.
    ytitle = 1.40
    putil.set_style('notebook', 'whitegrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(rec_task), ncol=3,
                                          subw=6, subh=4, create_axes=True)

    for i_rt, ((subj, date, task), rt_utids) in enumerate(u_rt_grpby):
        print('{} / {}: {} - {} {}'.format(i_rt+1, len(u_rt_grpby), subj,
                                           date, task))
        rt = (subj, date, task)

        # Init electrode.
        elecs = rt_utids.index.get_level_values('elec').unique()
        if len(elecs) != 1:
            warnings.warn('More than one electrode?')
        elec = elecs[0]
        RecInfo.loc[rt, 'elec'] = elec

        # Matrix of included (1) / excluded (0) trials of recording & task.
        # Rows keep only the (ch, ui) levels of the unit index.
        ch_idxs = (rt_utids.index.droplevel(-1).droplevel(2)
                   .droplevel(1).droplevel(0))
        n_alltrs = RecInfo.loc[rt, 'nalltrials']
        IncTrsMat = pd.DataFrame(np.zeros((len(ch_idxs), n_alltrs), dtype=int),
                                 index=ch_idxs,
                                 columns=np.arange(n_alltrs)+1)
        for ch_idx, utid in zip(ch_idxs, rt_utids):
            # Single indexer call: chained ``.loc[..].iloc[..] = 1`` is a
            # silent no-op under pandas Copy-on-Write.
            irow = IncTrsMat.index.get_loc(ch_idx)
            IncTrsMat.iloc[irow, IncTrs[utid]] = 1

        # Plot included/excluded trials after preprocessing.
        ax = axs[i_rt, 0]
        ylab = '{} {} {}'.format(subj, date, task)
        title = ('Included (green) and excluded (red) trials'
                 if i_rt == 0 else None)
        plot_inc_exc_trials(IncTrsMat, ax, title, ytitle, ylab=ylab)

        # How many trials remain if we iteratively exclude units with the
        # least overlap with the rest of the units?
        n_units = IncTrsMat.shape[0]
        tr_covs = _greedy_unit_removal(IncTrsMat)

        # Plot covered trials against each unit removed.
        ax_trc = axs[i_rt, 1]
        sns.tsplot(tr_covs['ntrs_cov'], marker='o', ms=4, color='b',
                   ax=ax_trc)
        title = ('Trial coverage during iterative unit removal'
                 if i_rt == 0 else None)
        xlab, ylab = 'current unit removed', '# trials covered'
        putil.set_labels(ax_trc, xlab, ylab, title, ytitle)
        ax_trc.xaxis.set_ticks(tr_covs.index)
        x_ticklabs = ['none'] + ['{} - {}'.format(ch, ui)
                                 for ch, ui in tr_covs.uid.loc[1:]]
        ax_trc.set_xticklabels(x_ticklabs)
        putil.rot_xtick_labels(ax_trc, 45)
        ax_trc.grid(True)

        # Add # of remaining units to top.
        ax_remu = ax_trc.twiny()
        ax_remu.xaxis.set_ticks(tr_covs.index)
        ax_remu.set_xticklabels(list(range(len(x_ticklabs)))[::-1])
        ax_remu.set_xlabel('# units remaining')
        ax_remu.grid(None)

        # Add heuristic index.
        ax_heur = ax_trc.twinx()
        sns.tsplot(tr_covs['trial x units'], linestyle='--', marker='o',
                   ms=4, color='m', ax=ax_heur)
        putil.set_labels(ax_heur, ylab='remaining units x covered trials')
        for tl in ax_heur.get_yticklabels():
            tl.set_color('m')
        for tl in ax_trc.get_yticklabels():
            tl.set_color('b')
        ax_heur.grid(None)

        # Decide on which units to exclude.
        min_n_trials = min_n_trs_per_unit * tr_covs['n_rem_u']
        sub_tr_covs = tr_covs[(tr_covs['n_rem_u'] >= min_n_units) &
                              (tr_covs['ntrs_cov'] >= min_n_trials)]

        # Default (no step passes the criteria): exclude every unit.
        rem_uids = pd.Series(dtype=object)
        exc_uids = tr_covs.uid.iloc[1:]
        n_tr_rem, n_tr_exc = 0, IncTrsMat.shape[1]
        if len(sub_tr_covs.index):
            # Label of the best step in tr_covs. ``argmax`` would return a
            # position within the filtered ``sub_tr_covs`` instead.
            hmax_idx = sub_tr_covs['trial x units'].astype(float).idxmax()
            # tr_covs has a 0..n RangeIndex, so labels equal positions.
            rem_uids = tr_covs.uid.iloc[hmax_idx+1:]
            exc_uids = tr_covs.uid.iloc[1:hmax_idx+1]
            n_tr_rem = tr_covs.ntrs_cov.loc[hmax_idx]
            n_tr_exc = IncTrsMat.shape[1] - n_tr_rem

            # Flag selected units in UInc.
            sel_utids = [(subj, date, elec, ch, ui, task)
                         for ch, ui in rem_uids]
            UInc.loc[sel_utids] = True

        # Highlight selected point in middle plot.
        n_exc = exc_uids.shape[0]
        sel_seg = [('selection', n_exc-0.4, n_exc+0.4)]
        putil.plot_periods(sel_seg, ax=ax_trc, alpha=0.3)
        for ax_x in (ax_trc, ax_remu):
            ax_x.set_xlim([-0.5, n_units+0.5])

        # Remaining trials matrix: 1 = remaining, 0.5 = removed by this
        # selection, 0 = excluded during preprocessing.
        RemTrsMat = IncTrsMat.copy().astype(float)
        for exc_uid in exc_uids:
            # Remove all trials from excluded units.
            RemTrsMat.loc[exc_uid] = 0.5
        # Remove uncovered trials in remaining units.
        exc_trs = np.where(~RemTrsMat.loc[list(rem_uids)].all())[0]
        if exc_trs.size:
            RemTrsMat.iloc[:, exc_trs] = 0.5
        # Overwrite by trials excluded during preprocessing.
        RemTrsMat[IncTrsMat == 0] = 0.0

        # Plot remaining trials.
        ax = axs[i_rt, 2]
        n_u_rem, n_u_exc = len(rem_uids), len(exc_uids)
        title = ('# units remaining: {}, excluded: {}'.format(n_u_rem, n_u_exc) +
                 '\n# trials remaining: {}, excluded: {}'.format(n_tr_rem,
                                                                 n_tr_exc))
        plot_inc_exc_trials(RemTrsMat, ax, title=title, ylab='')

        # Add remaining units and trials to RecInfo.
        RecInfo.loc[rt, ('units', 'nunits')] = list(rem_uids), len(rem_uids)
        if len(rem_uids):
            cov_trs = RemTrsMat.loc[list(rem_uids)].all()
        else:
            # No unit left: no trial is covered. ``.all()`` of an empty frame
            # would instead mark every trial as covered.
            cov_trs = pd.Series(False, index=RemTrsMat.columns)
        # pd.Index with int64 dtype (pd.Int64Index was removed in pandas 2).
        inc_trs = pd.Index(np.where(cov_trs)[0], dtype='int64')
        RecInfo.loc[rt, ('trials', 'ntrials')] = inc_trs, sum(cov_trs)

    RecInfo['% remaining units'] = 100 * RecInfo.nunits / RecInfo.nallunits
    RecInfo['% remaining trials'] = 100 * RecInfo.ntrials / RecInfo.nalltrials

    # Save results.
    if fres is not None:
        results = {'RecInfo': RecInfo, 'UInc': UInc}
        util.write_objects(results, fres)

    # Save plot.
    title = 'Trial & unit selection prior decoding'
    putil.save_fig(ffig, fig, title, w_pad=3, h_pad=3)

    return RecInfo, UInc
def plot_qm(u, bs_stats, stab_prd_res, prd_inc, tr_inc, spk_inc,
            add_lbls=False, ftempl=None, fig=None, sps=None):
    """Plot quality metrics related figures.

    Draws a unit info header and a 3 x 2 grid: included / excluded waveforms,
    waveform amplitude and duration over recording time, duration vs
    amplitude, and firing rate over recording time.

    Parameters
    ----------
    u : unit object; reads Waveforms, SpikeParams, QualityMetrics,
        SessParams and TrData.
    bs_stats : provides 'tmean' and 'rate' per time segment.
    stab_prd_res : provides 'tstart' and 'tstop' of the stable period.
    prd_inc : per-segment inclusion flags (same order as ``bs_stats``).
    tr_inc : trial inclusion flags (inverted with ``np.invert``).
    spk_inc : boolean spike inclusion mask.
    add_lbls : whether to add axis labels.
    ftempl : file name template; the figure is saved only if given.
    fig, sps : optional figure and subplot spec to draw into.

    Returns
    -------
    [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
    """

    # Init values.
    waveforms = np.array(u.Waveforms)
    wavetime = u.Waveforms.columns * us
    spk_times = np.array(u.SpikeParams['time'], dtype=float)
    base_rate = u.QualityMetrics['baseline']

    # Minimum and maximum gain.
    gmin = u.SessParams['minV']
    gmax = u.SessParams['maxV']

    # %% Init plots.

    # Disable inline plotting to prevent memory leak.
    putil.inline_off()

    # Init figure and gridspec.
    fig = putil.figure(fig)
    if sps is None:
        sps = putil.gridspec(1, 1)[0]
    ogsp = putil.embed_gsp(sps, 2, 1, height_ratios=[0.02, 1])

    info_sps, qm_sps = ogsp[0], ogsp[1]

    # Info header.
    info_ax = fig.add_subplot(info_sps)
    putil.hide_axes(info_ax)
    title = putil.get_unit_info_title(u)
    putil.set_labels(ax=info_ax, title=title, ytitle=0.80)

    # Create axes.
    gsp = putil.embed_gsp(qm_sps, 3, 2, wspace=0.3, hspace=0.4)
    ax_wf_inc, ax_wf_exc = [fig.add_subplot(gsp[0, i]) for i in (0, 1)]
    ax_wf_amp, ax_wf_dur = [fig.add_subplot(gsp[1, i]) for i in (0, 1)]
    ax_amp_dur, ax_rate = [fig.add_subplot(gsp[2, i]) for i in (0, 1)]

    # Trial markers: every 10th trial, labelled on every other marker.
    trial_starts, trial_stops = u.TrData.TrialStart, u.TrData.TrialStop
    tr_markers = pd.DataFrame({'time': trial_starts[9::10]})
    tr_markers['label'] = [str(itr+1) if i % 2 else ''
                           for i, itr in enumerate(tr_markers.index)]

    # Common variables, limits and labels.
    WF_T_START = test_sorting.WF_T_START
    spk_t = u.SessParams.sampl_prd * (np.arange(waveforms.shape[1]) -
                                      WF_T_START)
    ses_t_lim = test_sorting.get_start_stop_times(spk_times, trial_starts,
                                                  trial_stops)
    ss, sa = 1.0, 0.8  # marker size and alpha on scatter plot

    # Color spikes by their occurrence over session time.
    my_cmap = putil.get_cmap('jet')
    spk_cols = np.tile(np.array([.25, .25, .25, .25]), (len(spk_times), 1))
    if np.any(spk_inc):  # check if there is any spike included
        spk_t_inc = np.array(spk_times[spk_inc])
        tmin, tmax = float(spk_times.min()), float(spk_times.max())
        # Guard against a zero time range (e.g. a single spike).
        t_range = (tmax - tmin) or 1.0
        spk_cols[spk_inc, :] = my_cmap((spk_t_inc - tmin) / t_range)
    # Put excluded spikes to the front, and randomise order of included
    # spikes so later spikes don't systematically cover earlier ones.
    spk_order = np.hstack((np.where(np.invert(spk_inc))[0],
                           np.random.permutation(np.where(spk_inc)[0])))

    # Common labels for plots
    ses_t_lab = 'Recording time (s)'

    # %% Waveform shape analysis.

    # Plot included and excluded waveforms on different axes.
    # Color included by occurrence in session time to help detect drifts.
    # Raw strings: '\m' and '\p' are invalid escape sequences otherwise.
    s_waveforms, s_spk_cols = waveforms[spk_order, :], spk_cols[spk_order]
    wf_t_lim, glim = [min(spk_t), max(spk_t)], [gmin, gmax]
    wf_t_lab, volt_lab = r'WF time ($\mu$s)', 'Voltage'
    for st in ('Included', 'Excluded'):
        ax = ax_wf_inc if st == 'Included' else ax_wf_exc
        spk_idx = spk_inc if st == 'Included' else np.invert(spk_inc)
        tr_idx = tr_inc if st == 'Included' else np.invert(tr_inc)

        nspsk, ntrs = sum(spk_idx), sum(tr_idx)
        title = '{} WFs, {} spikes, {} trials'.format(st, nspsk, ntrs)

        # Select waveforms and colors (in the shuffled spike order).
        rand_spk_idx = spk_idx[spk_order]
        wfs = s_waveforms[rand_spk_idx, :]
        cols = s_spk_cols[rand_spk_idx]

        # Plot waveforms.
        xlab, ylab = (wf_t_lab, volt_lab) if add_lbls else (None, None)
        pwaveform.plot_wfs(wfs, spk_t, cols=cols, lw=0.1, alpha=0.05,
                           xlim=wf_t_lim, ylim=glim, title=title,
                           xlab=xlab, ylab=ylab, ax=ax)

    # %% Waveform summary metrics.

    # Init data.
    wf_amp_all = u.SpikeParams['amplitude']
    wf_amp_inc = wf_amp_all[spk_inc]
    wf_dur_all = u.SpikeParams['duration']
    wf_dur_inc = wf_dur_all[spk_inc]

    # Set common limits and labels.
    dur_lim = [0, wavetime[-2] - wavetime[WF_T_START]]  # same across units
    glim = max(wf_amp_all.max(), gmax - gmin)
    amp_lim = [0, glim]

    amp_lab = 'Amplitude'
    dur_lab = r'Duration ($\mu$s)'

    # Waveform amplitude across session time.
    m_amp, sd_amp = wf_amp_inc.mean(), wf_amp_inc.std()
    title = r'WF amplitude: {:.1f} $\pm$ {:.1f}'.format(m_amp, sd_amp)
    xlab, ylab = (ses_t_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times, wf_amp_all, spk_inc, c='m', bc='grey', s=ss,
                  xlab=xlab, ylab=ylab, xlim=ses_t_lim, ylim=amp_lim,
                  edgecolors='', alpha=sa, id_line=False, title=title,
                  ax=ax_wf_amp)

    # Waveform duration across session time.
    mdur, sdur = wf_dur_inc.mean(), wf_dur_inc.std()
    title = r'WF duration: {:.1f} $\pm$ {:.1f} $\mu$s'.format(mdur, sdur)
    xlab, ylab = (ses_t_lab, dur_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times, wf_dur_all, spk_inc, c='c', bc='grey', s=ss,
                  xlab=xlab, ylab=ylab, xlim=ses_t_lim, ylim=dur_lim,
                  edgecolors='', alpha=sa, id_line=False, title=title,
                  ax=ax_wf_dur)

    # Waveform duration against amplitude.
    title = 'WF duration - amplitude'
    xlab, ylab = (dur_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(wf_dur_all[spk_order], wf_amp_all[spk_order],
                  c=spk_cols[spk_order], s=ss, xlab=xlab, ylab=ylab,
                  xlim=dur_lim, ylim=amp_lim, edgecolors='', alpha=sa,
                  id_line=False, title=title, ax=ax_amp_dur)

    # %% Firing rate.

    tmean = np.array(bs_stats['tmean'])
    rmean = util.remove_dim_from_series(bs_stats['rate'])
    prd_tstart, prd_tstop = stab_prd_res['tstart'], stab_prd_res['tstop']

    # Color segments depending on whether they are included / excluded.
    def plot_periods(v, color, ax):
        # Line segments: colored only if both endpoints are included.
        for i in range(len(prd_inc[:-1])):
            col = color if prd_inc[i] and prd_inc[i+1] else 'grey'
            x, y = [(tmean[i], tmean[i+1]), (v[i], v[i+1])]
            ax.plot(x, y, color=col)
        # Line points.
        for i in range(len(prd_inc)):
            col = color if prd_inc[i] else 'grey'
            x, y = [tmean[i], v[i]]
            ax.plot(x, y, color=col, marker='o', markersize=3,
                    markeredgecolor=col)

    # Firing rate over session time.
    title = 'Baseline rate: {:.1f} spike/s'.format(float(base_rate))
    xlab, ylab = (ses_t_lab, putil.FR_lbl) if add_lbls else (None, None)
    ylim = [0, 1.25*np.max(rmean)]
    plot_periods(rmean, 'b', ax_rate)
    pplot.lines([], [], c='b', xlim=ses_t_lim, ylim=ylim, title=title,
                xlab=xlab, ylab=ylab, ax=ax_rate)

    # Trial markers.
    putil.plot_events(tr_markers, lw=0.5, ls='--', alpha=0.35, y_lbl=0.92,
                      ax=ax_rate)

    # Excluded periods.
    excl_prds = []
    tstart, tstop = ses_t_lim
    if tstart != prd_tstart:
        excl_prds.append(('beg', tstart, prd_tstart))
    if tstop != prd_tstop:
        excl_prds.append(('end', prd_tstop, tstop))
    putil.plot_periods(excl_prds, ymax=0.92, ax=ax_rate)

    # %% Post-formatting.

    # Maximize number of ticks on recording time axes to prevent covering.
    for ax in (ax_wf_amp, ax_wf_dur, ax_rate):
        putil.set_max_n_ticks(ax, 6, 'x')

    # %% Save figure.
    if ftempl is not None:
        fname = ftempl.format(u.name_to_fname())
        # NOTE(review): ``title`` here is the last one assigned above (the
        # baseline-rate title), not the unit info title - confirm intent.
        putil.save_fig(fname, fig, title, rect_height=0.92)

    # Re-enable inline plotting whether or not the figure was saved.
    putil.inline_on()

    return [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
def rate(rate_list, names=None, prds=None, evts=None, cols=None,
         baseline=None, pval=0.05, test='mann_whitney_u', test_kws=None,
         xlim=None, ylim=None, title=None, xlab=None, ylab=putil.FR_lbl,
         add_lgn=True, lgn_lbl='trs', ffig=None, ax=None):
    """Plot firing rate.

    For each frame in ``rate_list`` (rows x time-point columns), the
    row-wise mean +- SEM is drawn as a line with a shaded band. Frames with
    no rows are skipped. With exactly two frames and ``pval`` set,
    significant periods from ``stats.sign_periods`` are marked.

    Parameters
    ----------
    rate_list : list of rate frames.
    names : line labels, one per frame ('' when None).
    prds, evts : periods / event markers to draw.
    cols : line colors, one per frame (``putil.get_colors`` when None).
    baseline : if not None, drawn with ``putil.add_baseline``.
    pval, test, test_kws : significance test settings.
    xlim, ylim : axis limits; derived from the data when None.
    title, xlab, ylab : labels (xlab formatted via ``putil.t_lbl``).
    add_lgn : whether to add a legend.
    lgn_lbl : if not None, '(<n rows> <lgn_lbl>)' is appended to labels.
    ffig : file name passed to ``putil.save_fig``.
    ax : axes to draw on (created by ``putil.axes`` if None).

    Returns
    -------
    ax
    """

    ax = putil.axes(ax)
    test_kws = dict() if test_kws is None else test_kws

    # Periods and baseline go underneath the rates.
    putil.plot_periods(prds, ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)
    putil.set_limits(ax, xlim)

    if not len(rate_list):
        return ax

    cols = putil.get_colors(as_cycle=False) if cols is None else cols
    names = len(rate_list) * [''] if names is None else names

    # Running data extent across all plotted frames.
    t_lo, t_hi, r_hi = None, None, None
    for i, rts in enumerate(rate_list):
        name, col = names[i], cols[i]

        if not rts.shape[0]:
            continue  # no trials

        # Convert iterables to a numpy array so floats are formatted nicely.
        lbl = str(np.array(name)) if util.is_iterable(name) else str(name)
        if lgn_lbl is not None:
            lbl += ' ({} {})'.format(rts.shape[0], lgn_lbl)

        tvec, meanr, semr = rts.columns, rts.mean(), rts.sem()
        ax.plot(tvec, meanr, label=lbl, color=col)
        ax.fill_between(tvec, meanr - semr, meanr + semr, alpha=0.2,
                        facecolor=col, edgecolor=col)

        tmin, tmax, rmax = tvec.min(), tvec.max(), (meanr + semr).max()
        t_lo = tmin if t_lo is None else np.min([t_lo, tmin])
        t_hi = tmax if t_hi is None else np.max([t_hi, tmax])
        r_hi = rmax if r_hi is None else np.max([r_hi, rmax])

    if xlim is None:
        # Avoid setting identical limits.
        xlim = (t_lo, None if t_lo == t_hi else t_hi)
    if ylim is None:
        top = 1.02 * r_hi if (r_hi is not None) and (r_hi > 0) else None
        ylim = (0, top)

    if xlab is not None:
        xlab = putil.t_lbl.format(xlab)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)

    t1, t2 = ax.get_xlim()  # in case it was set to None
    tmarks, tlbls = putil.get_tick_marks_and_labels(t1, t2)
    putil.set_xtick_labels(ax, tmarks, tlbls)
    putil.set_max_n_ticks(ax, 7, 'y')

    if add_lgn and len(rate_list):
        putil.set_legend(ax, loc=1, borderaxespad=0.0, handletextpad=0.4,
                         handlelength=0.6)

    # Significance line on top of the axes.
    if (pval is not None) and (len(rate_list) == 2):
        rates1, rates2 = rate_list
        sign_prds = stats.sign_periods(rates1, rates2, pval, test, **test_kws)
        putil.plot_signif_prds(sign_prds, color='m', linewidth=4.0, ax=ax)

    putil.plot_event_markers(evts, ax=ax)

    putil.save_fig(ffig)
    return ax
def plot_combined_rec_mean(recs, stims, res_dir, par_kws, list_n_most_DS,
                           list_min_nunits, n_boot=1e4, ci=95, tasks=None,
                           task_labels=None, add_title=True, fig=None):
    """Test and plot results combined across sessions.

    One panel per (n_most_DS, min_nunits) pair. Each panel shows the mean
    'all'-key accuracy across the recordings that have at least min_nunits
    units, with a bootstrapped CI per task.

    Parameters
    ----------
    recs : unused here (kept for interface compatibility).
    stims : stimuli whose fixed trial periods are highlighted.
    res_dir : results directory passed to ``decutil`` helpers.
    par_kws : analysis parameters; it is copied, so the caller's dict is not
        modified.
    list_n_most_DS, list_min_nunits : panel rows / columns.
    n_boot, ci : bootstrap settings for ``sns.tsplot``.
    tasks : tasks to plot; when None, the tasks present in each panel's
        data are used.
    task_labels : mapping from task to label; identity when None.
    add_title : whether to add a figure title.
    fig : optional figure to draw into.

    Returns
    -------
    fig_scr, axs_scr, ffig
    """

    # Work on a copy: 'n_most_DS' is overwritten below to build the title
    # and file name, and must not leak into the caller's dict.
    par_kws = dict(par_kws)

    vkey = 'all'  # This should be made more explicit!
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim]) for stim in stims]

    # Load all results to plot.
    dict_rt_res = decutil.load_res(res_dir, list_n_most_DS, **par_kws)

    # Create figures.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(dict_rt_res),
                                                ncol=len(list_min_nunits),
                                                subw=8, subh=6, fig=fig,
                                                create_axes=True)

    # Query data.
    allScores = {}
    allnunits = {}
    for n_most_DS, rt_res in dict_rt_res.items():
        # Results with an available 'all' entry.
        vres_dict = {(rec, task): res[vkey]
                     for (rec, task), res in rt_res.items()
                     if (vkey in res) and (res[vkey] is not None)}

        # Get accuracy scores.
        dScores = {rt: vres['Scores'].mean() for rt, vres in vres_dict.items()}
        allScores[n_most_DS] = pd.concat(dScores, axis=1).T

        # Get number of units.
        allnunits[n_most_DS] = {rt: vres['nunits'].iloc[0]
                                for rt, vres in vres_dict.items()}

        # Get # values (for baseline plotting.)
        all_nvals = pd.Series({rt: vres['nclasses'].iloc[0]
                               for rt, vres in vres_dict.items()})
        un_nvals = all_nvals.unique()
        if len(un_nvals) > 1 and verbose:
            print('Found multiple # of classes to decode: {}'.format(un_nvals))
        nvals = un_nvals[0]

    allnunits = pd.DataFrame(allnunits)

    # Plot mean performance across recordings and
    # test significance by bootstrapping.
    for inmost, n_most_DS in enumerate(list_n_most_DS):
        Scores = allScores[n_most_DS]
        nunits = allnunits[n_most_DS]

        for iminu, min_nunits in enumerate(list_min_nunits):
            ax_scr = axs_scr[inmost, iminu]

            # Select only recordings with minimum number of units.
            sel_rt = nunits.index[nunits >= min_nunits]
            nScores = Scores.loc[sel_rt].copy()

            # Nothing to plot.
            if nScores.empty:
                ax_scr.axis('off')
                continue

            # Resolve defaults per panel, so tasks found in one subset are
            # not reused for another (which could raise a KeyError in xs).
            rt_tasks = (nScores.index.get_level_values(1).unique()
                        if tasks is None else tasks)
            rt_task_labels = ({task: task for task in rt_tasks}
                              if task_labels is None else task_labels)

            # Prepare data.
            dScores = {task: pd.DataFrame(nScores.xs(task, level=1).unstack(),
                                          columns=['accuracy'])
                       for task in rt_tasks}
            lScores = pd.concat(dScores, axis=0)
            lScores['time'] = lScores.index.get_level_values(1)
            lScores['task'] = lScores.index.get_level_values(0)
            lScores['rec'] = lScores.index.get_level_values(2)
            lScores.index = np.arange(len(lScores.index))
            # Assign back: chained inplace replace is a no-op under CoW.
            lScores['task'] = lScores['task'].replace(rt_task_labels)

            # Add altered task names for legend plotting.
            nrecs = {rt_task_labels[task]: len(nScores.xs(task, level=1))
                     for task in rt_tasks}
            lScores['task_nrecs'] = lScores['task'].apply(
                lambda x: '{} (n={})'.format(x, nrecs[x]))

            # Plot as time series.
            sns.tsplot(lScores, time='time', value='accuracy', unit='rec',
                       condition='task_nrecs', ci=ci, n_boot=n_boot,
                       ax=ax_scr)

            # Add chance level line.
            putil.add_chance_level(ax=ax_scr, ylevel=1.0 / nvals)

            # Add stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Set axis limits.
            putil.set_limits(ax_scr, tlim)

            # Format plot.
            title = ('{} most DS units'.format(n_most_DS)
                     if n_most_DS != 0 else 'all units')
            title += (', recordings with at least {} units'.format(min_nunits)
                      if (min_nunits > 1 and len(list_min_nunits) > 1)
                      else '')
            ytitle = 1.0
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)
            putil.hide_legend_title(ax_scr)

    # Match axes across decoding plots.
    for inmost in range(axs_scr.shape[0]):
        putil.sync_axes(axs_scr[inmost, :], sync_y=True)

    # Save plots.
    list_n_most_DS_str = [str(i) if i != 0 else 'all' for i in list_n_most_DS]
    par_kws['n_most_DS'] = ', '.join(list_n_most_DS_str)
    title = ''
    if add_title:
        title = decutil.fig_title(res_dir, **par_kws)
        title += '\n{}% CE with {} bootstrapped subsamples'.format(ci,
                                                                  int(n_boot))
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    par_kws['n_most_DS'] = '_'.join(list_n_most_DS_str)
    ffig = decutil.fig_fname(res_dir, 'combined_score', fformat, **par_kws)

    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)

    return fig_scr, axs_scr, ffig
def plot_scores_across_nunits(recs, stims, res_dir, list_n_most_DS, par_kws):
    """
    Plot prediction score results across different number of units included.

    One panel per (recording, task). Each panel overlays the score time
    courses of every available result, grouped by the number of units used.

    Parameters
    ----------
    recs : recordings (iterables of strings, joined with spaces for labels).
    stims : stimuli whose fixed trial periods are highlighted.
    res_dir : results directory passed to ``decutil`` helpers.
    list_n_most_DS : unit-set sizes whose results are loaded.
    par_kws : analysis parameters (reads 'tasks' and 'feat'); it is copied,
        so the caller's dict is not modified.
    """

    # Work on a copy: 'n_most_DS' is overwritten below for title / file name.
    par_kws = dict(par_kws)

    # Init.
    putil.set_style('notebook', 'ticks')
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load all results to plot.
    dict_rt_res = decutil.load_res(res_dir, list_n_most_DS, **par_kws)

    # Create figures.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # Do plotting per recording and task.
    for irec, rec in enumerate(recs):
        if verbose:
            # rec is an iterable of strings (see title below), not a str.
            print('\n' + ' '.join(rec))
        for itask, task in enumerate(tasks):
            if verbose:
                print(' ' + task)
            ax_scr = axs_scr[irec, itask]

            # Init data.
            dict_lScores = {}
            lncls = []
            for rt_res in dict_rt_res.values():

                # Check if results exist for rec-task combination.
                if (((rec, task) not in rt_res.keys()) or
                        (not len(rt_res[(rec, task)].keys()))):
                    continue

                # Iterate over every value key (previously truncated by
                # zipping with a palette sized by the number of unit sets).
                res = rt_res[(rec, task)]
                for v, vres in res.items():
                    if vres is None:
                        continue
                    Scores = vres['Scores']
                    lncls.append(vres['nclasses'])

                    # Unstack dataframe with results.
                    lScores = pd.DataFrame(Scores.unstack(), columns=['score'])
                    lScores['time'] = lScores.index.get_level_values(0)
                    lScores['fold'] = lScores.index.get_level_values(1)
                    lScores.index = np.arange(len(lScores.index))

                    # Get number of units tested.
                    nunits = vres['nunits']
                    uni_nunits = nunits.unique()
                    if len(uni_nunits) > 1 and verbose:
                        print('Different number of units found.')
                    nunits = uni_nunits[0]

                    # Collect results.
                    dict_lScores[(nunits, v)] = lScores

            # Skip rest if no data is available.
            if not len(dict_lScores):
                ax_scr.axis('off')
                continue

            # Concatenate accuracy scores from every recording.
            all_lScores = pd.concat(dict_lScores)
            all_lScores['n_most_DS'] = all_lScores.index.get_level_values(0)
            all_lScores.index = np.arange(len(all_lScores.index))

            # Plot decoding results.
            nnunits = len(all_lScores['n_most_DS'].unique())
            title = '{} {}, {} sets of units'.format(' '.join(rec), task,
                                                     nnunits)
            ytitle = 1.0
            prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
                    for stim in stims]

            # Plot time series.
            palette = sns.color_palette('muted')
            sns.tsplot(all_lScores, time='time', value='score', unit='fold',
                       condition='n_most_DS', color=palette, ax=ax_scr)

            # Add a chance level line for every distinct number of classes,
            # combined across stimulus periods. Flatten each entry first so
            # entries of different lengths are handled.
            uni_ncls = np.unique(np.hstack([np.ravel(n) for n in lncls]))
            if len(uni_ncls) > 1 and verbose:
                print('Different number of classes found.')
            for nvals in uni_ncls:
                putil.add_chance_level(ax=ax_scr, ylevel=1.0 / nvals)

            # Add stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Set axis limits.
            putil.set_limits(ax_scr, tlim, ylim_scr)

            # Format plot.
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)

    # Save plots.
    list_n_most_DS_str = [str(i) if i != 0 else 'all' for i in list_n_most_DS]
    par_kws['n_most_DS'] = ', '.join(list_n_most_DS_str)
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3
    par_kws['n_most_DS'] = '_'.join(list_n_most_DS_str)
    ffig = decutil.fig_fname(res_dir, 'score_nunits', fformat, **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)
def plot_score_multi_rec(recs, stims, res_dir, par_kws):
    """
    Plot prediction scores for multiple recordings.

    One panel is drawn per task. Each panel overlays one score time series per
    recording. The resulting figure is saved via ``putil.save_fig``.

    Parameters
    ----------
    recs : sequence
        Recordings to include. Each one is part of a ``(rec, task)`` results
        key and is space-joined for the legend label, so presumably a tuple
        of strings -- TODO confirm.
    stims : sequence
        Stimulus names. Their periods are looked up in
        ``constants.fixed_tr_prds`` and drawn on every panel.
    res_dir : str
        Results directory, passed to ``decutil`` for loading and file naming.
    par_kws : mapping
        Analysis parameters. Must contain ``'n_most_DS'``, ``'tasks'`` (a
        pandas string Series/Index) and ``'feat'``. It is forwarded as keyword
        arguments to ``decutil``.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    n_most_DS = par_kws['n_most_DS']
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load results.
    rt_res = decutil.load_res(res_dir, **par_kws)[n_most_DS]

    # Create figure.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=1, ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    print('\nPlotting multi-recording results...')

    # Stimulus periods are identical for every panel.
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    for itask, task in enumerate(tasks):
        if verbose:
            print('    ' + task)
        ax_scr = axs_scr[0, itask]

        # Collect scores AND number of classes across all recordings of this
        # task, so that chance levels reflect every plotted recording.
        dict_lScores = {}
        lncls = []
        for rec in recs:

            # Check if results exist for rec-task combination.
            if (((rec, task) not in rt_res.keys()) or
                    (not len(rt_res[(rec, task)].keys()))):
                continue

            # Init data.
            res = rt_res[(rec, task)]
            for v in res.keys():
                vres = res[v]
                if vres is None:
                    continue
                Scores = vres['Scores']
                lncls.append(vres['nclasses'])

                # Unstack dataframe with results.
                lScores = pd.DataFrame(Scores.unstack(), columns=['score'])
                lScores['time'] = lScores.index.get_level_values(0)
                lScores['fold'] = lScores.index.get_level_values(1)
                lScores.index = np.arange(len(lScores.index))

                dict_lScores[(rec, v)] = lScores

        if not len(dict_lScores):
            ax_scr.axis('off')
            continue

        # Concatenate accuracy scores from every recording.
        all_lScores = pd.concat(dict_lScores)
        all_lScores['rec'] = all_lScores.index.get_level_values(0)
        all_lScores['rec'] = all_lScores['rec'].str.join(' ')  # format label
        all_lScores.index = np.arange(len(all_lScores.index))

        # Plot decoding results.
        nrec = len(all_lScores['rec'].unique())
        title = '{}, {} recordings'.format(task, nrec)
        ytitle = 1.0

        # Plot time series.
        palette = sns.color_palette('muted')
        sns.tsplot(all_lScores, time='time', value='score', unit='fold',
                   condition='rec', color=palette, ax=ax_scr)

        # Add chance level line.
        # This currently plots a chance level line for every nvals,
        # combined across stimulus period!
        uni_ncls = np.unique(np.array(lncls).flatten())
        if len(uni_ncls) > 1 and verbose:
            print('Different number of classes found.')
        for nvals in uni_ncls:
            chance_lvl = 1.0 / nvals
            putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

        # Add stimulus periods.
        putil.plot_periods(prds, ax=ax_scr)

        # Set axis limits.
        putil.set_limits(ax_scr, tlim, ylim_scr)

        # Format plot.
        putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)

    # Save figure.
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3
    ffig = decutil.fig_fname(res_dir, 'all_scores', fformat, **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)
def plot_scores_weights(recs, stims, res_dir, par_kws):
    """
    Plot prediction scores and model weights for given recording and analysis.

    Two figures are built with one panel per (recording, task): decoding
    scores (one line per value, via ``plot_scores``) and unit weights (via
    ``plot_weights``). Both are saved as PDF via ``putil.save_fig``.

    Parameters
    ----------
    recs : sequence
        Recordings to plot. Each one is part of a ``(rec, task)`` results key
        and is space-joined for titles, so presumably a tuple of strings
        -- TODO confirm.
    stims : sequence
        Stimulus names. Their periods are looked up in
        ``constants.fixed_tr_prds`` and drawn on the panels.
    res_dir : str
        Results directory, passed to ``decutil`` for loading and file naming.
    par_kws : mapping
        Analysis parameters. Must contain ``'n_most_DS'``, ``'tasks'`` (a
        pandas string Series/Index) and ``'feat'``. It is forwarded as keyword
        arguments to ``decutil``.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    n_most_DS = par_kws['n_most_DS']
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load results.
    rt_res = decutil.load_res(res_dir, **par_kws)[n_most_DS]

    # Create figures.
    # For prediction scores.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # For unit weights (coefficients).
    fig_wgt, _, axs_wgt = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # Stimulus periods are identical for every panel.
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    for irec, rec in enumerate(recs):
        if verbose:
            # rec is a sequence of strings (see title below), not a str.
            print('\n' + ' '.join(rec))
        for itask, task in enumerate(tasks):
            if verbose:
                print('    ' + task)

            # Init figures.
            ax_scr = axs_scr[irec, itask]
            ax_wgt = axs_wgt[irec, itask]

            # Check if any result exists for rec-task combination.
            if (((rec, task) not in rt_res.keys()) or
                    (not len(rt_res[(rec, task)].keys()))):
                ax_scr.axis('off')
                ax_wgt.axis('off')
                continue

            # Init data.
            res = rt_res[(rec, task)]
            vals = [v for v in res.keys() if not util.is_null(res[v])]

            # If every value is null there is nothing to plot. Without this
            # guard, indexing the unit counts below would fail and Coefs would
            # be undefined or left over from a previous panel.
            if not vals:
                ax_scr.axis('off')
                ax_wgt.axis('off')
                continue

            cols = sns.color_palette('hls', len(vals))
            lnunits, lncls = [], []
            for v, col in zip(vals, cols):

                # Basic results.
                vres = res[v]
                Scores = vres['Scores']
                Coefs = vres['Coefs']
                Perm = vres['Perm']
                Psdo = vres['Psdo']

                # Decoding params.
                lnunits.append(vres['nunits'])
                lncls.append(vres['nclasses'])

                # Plot decoding accuracy.
                plot_scores(ax_scr, Scores, Perm, Psdo, col=col)

            # Add labels.
            uni_lnunits = np.unique(np.array(lnunits).flatten())
            if len(uni_lnunits) > 1 and verbose:
                print('Different number of units found.')
            nunits = uni_lnunits[0]
            title = '{} {}, {} units'.format(' '.join(rec), task, nunits)
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle=1.04)

            # Add chance level line.
            uni_ncls = np.unique(np.array(lncls).flatten())
            if len(uni_ncls) > 1 and verbose:
                print('Different number of classes found.')
            for nvals in uni_ncls:
                chance_lvl = 1.0 / nvals
                putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Plot stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Plot unit weights over time.
            # NOTE(review): only the coefficients of the last value in vals
            # are plotted -- confirm this is intended.
            plot_weights(ax_wgt, Coefs, prds, tlim, tlab, title=title)

    # Save plots.
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    # Performance.
    ffig = decutil.fig_fname(res_dir, 'score', 'pdf', **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)

    # Weights.
    ffig = decutil.fig_fname(res_dir, 'weight', 'pdf', **par_kws)
    putil.save_fig(ffig, fig_wgt, title, fs_title, w_pad=w_pad, h_pad=h_pad)