def plot_weights(ax, Coefs, prds=None, xlim=None, xlab=tlab,
                 ylab='unit coefficient', title='', ytitle=1.04):
    """Plot decoding weights."""

    # Unstack dataframe with results.
    lCoefs = pd.DataFrame(Coefs.unstack().unstack(), columns=['coef'])
    lCoefs['time'] = lCoefs.index.get_level_values(0)
    lCoefs['value'] = lCoefs.index.get_level_values(1)
    lCoefs['uid'] = lCoefs.index.get_level_values(2)
    lCoefs.index = np.arange(len(lCoefs.index))

    # Plot time series.
    sns.tsplot(lCoefs, time='time', value='coef', unit='value',
               condition='uid', ax=ax)

    # Add chance level line and stimulus periods.
    putil.add_chance_level(ax=ax, ylevel=0)
    putil.plot_periods(prds, ax=ax)

    # Set axis limits.
    xlim = xlim if xlim is not None else tlim
    putil.set_limits(ax, xlim)

    # Format plot.
    putil.set_labels(ax, xlab, ylab, title, ytitle)
    putil.hide_legend(ax)
def plot_DR_3x3(u, fig=None, sps=None):
    """Plot 3x3 direction response plot, with polar plot in center."""

    if not u.to_plot():
        return

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 3, 3)  # inner gsp with subplots

    # Polar plot.
    putil.set_style('notebook', 'white')
    ax_polar = fig.add_subplot(gsp[4], polar=True)
    for stim in constants.stim_dur.index:  # for each stimulus
        stim_resp = u.get_stim_resp_vals(stim, 'Dir')
        resp_stats = util.calc_stim_resp_stats(stim_resp)
        dirs, resp = np.array(resp_stats.index) * deg, resp_stats['mean']
        c = putil.stim_colors[stim]
        baseline = u.get_baseline()
        ptuning.plot_DR(dirs, resp, color=c, baseline=baseline, ax=ax_polar)
    putil.hide_ticks(ax_polar, 'y')

    # Raster-rate plots.
    putil.set_style('notebook', 'ticks')
    rr_pos = [5, 2, 1, 0, 3, 6, 7, 8]  # position of each direction
    rr_dir_plot_pos = pd.Series(constants.all_dirs, index=rr_pos)
    rate_axs = []
    for isp, d in rr_dir_plot_pos.iteritems():

        # Prepare plot formatting.
        first_dir = (isp == 0)

        # Plot direction response across trial periods.
        res = plot_SR(u, 'Dir', [d], fig=fig, sps=gsp[isp], no_labels=True)
        draster_axs, drate_axs, _ = res

        # Remove axis ticks.
        for i, ax in enumerate(drate_axs):
            first_prd = (i == 0)
            show_x_tick_lbls = first_dir
            show_y_tick_lbls = first_dir & first_prd
            putil.hide_tick_labels(ax, show_x_tick_lbls, show_y_tick_lbls)

        # Add task name as title (to top center axes).
        if isp == 1:
            ttl = u.get_task() + (' [excluded]' if u.is_excluded() else '')
            putil.set_labels(draster_axs[0], title=ttl, ytitle=1.10,
                             title_kws={'loc': 'right'})

        rate_axs.extend(drate_axs)

    # Match scale of y axes.
    putil.sync_axes(rate_axs, sync_y=True)
    [putil.adjust_decorators(ax) for ax in rate_axs]

    return ax_polar, rate_axs
def plot_mean_rates(mRates, aa_res_dir, tasks=None, task_lbls=None,
                    baseline=None, xlim=None, ylim=None, ci=68,
                    ffig=None, figsize=(6, 4)):
    """Plot mean rates across tasks."""

    # Init.
    if tasks is None:
        tasks = mRates.keys()

    # Plot mean activity.
    lRates = []
    for task in tasks:
        lrates = pd.DataFrame(mRates[task].unstack(), columns=['rate'])
        lrates['task'] = task
        lRates.append(lrates)
    lRates = pd.concat(lRates)
    lRates['time'] = lRates.index.get_level_values(0)
    lRates['unit'] = lRates.index.get_level_values(1)

    if task_lbls is not None:
        lRates.task.replace(task_lbls, inplace=True)

    # Plot as time series.
    # putil.set_style('notebook', 'white')
    fig = putil.figure(figsize=figsize)
    ax = putil.axes()

    sns.tsplot(lRates, time='time', value='rate', unit='unit',
               condition='task', ci=ci, ax=ax)

    # Add periods and baseline.
    putil.plot_periods(ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Format plot.
    sns.despine(ax=ax)
    putil.set_labels(ax, xlab='time since S1 onset', ylab='rate (sp/s)')
    putil.set_limits(ax, xlim, ylim)
    putil.hide_legend_title(ax)

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
def plot_inc_exc_trials(IncTrsMat, ax, title=None, ytitle=None,
                        xlab='Trial #', ylab=None):
    """Plot matrix of included and excluded trials on a heatmap."""

    # Plot on heatmap.
    sns.heatmap(IncTrsMat, cmap='RdYlGn', center=0.5, cbar=False, ax=ax)

    # Set tick labels.
    putil.hide_tick_marks(ax)
    tr_ticks = [1] + list(np.arange(25, IncTrsMat.shape[1]+1, 25))
    ax.xaxis.set_ticks(tr_ticks)
    ax.set_xticklabels(tr_ticks)
    putil.rot_xtick_labels(ax, 0)
    putil.rot_ytick_labels(ax, 0, va='center')

    putil.set_labels(ax, xlab, ylab, title, ytitle)
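# Usage sketch for plot_inc_exc_trials (hypothetical data, not part of the
# original analysis): a binary units x trials matrix is assumed, as produced
# by select_units_trials below; np, pd and putil are imported at module level.
#
#   IncTrsMat = pd.DataFrame(np.random.randint(0, 2, size=(10, 150)))
#   fig = putil.figure()
#   ax = putil.axes()
#   plot_inc_exc_trials(IncTrsMat, ax, title='Included/excluded trials')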
def plot_RF_results(RF_res, stims, fdir, sup_title):
    """Plot receptive field results."""

    # Plot distribution of coverage values.
    putil.set_style('poster', 'white')
    fig = putil.figure()
    ax = putil.axes()
    sns.distplot(RF_res.S1_cover, bins=np.arange(0, 1.01, 0.1),
                 kde=False, rug=True, ax=ax)
    putil.set_limits(ax, [0, 1])
    fst = util.format_to_fname(sup_title)
    ffig = fdir + fst + '_S1_coverage.png'
    putil.save_fig(ffig, fig, sup_title)

    # Plot RF coverage and rate during S1 on regression plot for each
    # recording and task.
    tasks = RF_res.index.get_level_values(-1).unique()
    for vname, ylim in [('mean_rate', [0, None]), ('max_rate', [0, None]),
                        ('mDSI', [0, 1])]:
        fig, gs, axes = putil.get_gs_subplots(nrow=len(stims),
                                              ncol=len(tasks),
                                              subw=4, subh=4,
                                              ax_kws_list=None,
                                              create_axes=True)
        colors = sns.color_palette('muted', len(tasks))
        for istim, stim in enumerate(stims):
            for itask, task in enumerate(tasks):

                # Plot regression plot.
                ax = axes[istim, itask]
                scov, sval = [stim + '_' + name for name in ('cover', vname)]
                df = RF_res.xs(task, level=-1)
                sns.regplot(scov, sval, df, color=colors[itask], ax=ax)

                # Add unit labels.
                uids = df.index.droplevel(0)
                putil.add_unit_labels(ax, uids, df[scov], df[sval])

                # Add stats.
                r, p = sp.stats.pearsonr(df[sval], df[scov])
                pstr = util.format_pvalue(p)
                txt = 'r = {:.2f}, {}'.format(r, pstr)
                ax.text(0.02, 0.98, txt, va='top', ha='left',
                        transform=ax.transAxes)

                # Set labels.
                title = '{} {}'.format(task, stim)
                xlab, ylab = [sn.replace('_', ' ') for sn in (scov, sval)]
                putil.set_labels(ax, xlab, ylab, title)

                # Set limits.
                xlim = [0, 1]
                putil.set_limits(ax, xlim, ylim)

        # Save plot.
        fst = util.format_to_fname(sup_title)
        fname = '{}_cover_{}.png'.format(fst, vname)
        ffig = util.join([fdir, vname, fname])
        putil.save_fig(ffig, fig, sup_title)
def raster_rate(spk_list, rate_list, names=None, prds=None, evts=None,
                cols=None, baseline=None, title=None, rs_ylab=True,
                rate_kws=None, fig=None, ffig=None, sps=None):
    """Plot raster and rate plots."""

    if rate_kws is None:
        rate_kws = dict()

    # Init subplots.
    sps, fig = putil.sps_fig(sps, fig)
    gsp = putil.embed_gsp(sps, 2, 1, height_ratios=[.66, 1], hspace=.15)
    n_sets = max(len(spk_list), 1)  # let's add an empty axes if no data
    gsp_raster = putil.embed_gsp(gsp[0], n_sets, 1, hspace=.15)
    gsp_rate = putil.embed_gsp(gsp[1], 1, 1)

    # Init colors.
    if cols is None:
        col_cyc = putil.get_colors(mpl_colors=True)
        cols = [next(col_cyc) for i in range(n_sets)]

    # Raster plots.
    raster_axs = [fig.add_subplot(gsp_raster[i, 0]) for i in range(n_sets)]
    for i, (spk_trs, ax) in enumerate(zip(spk_list, raster_axs)):
        ylab = names[i] if (rs_ylab and names is not None) else None
        raster(spk_trs, prds=prds, c=cols[i], ylab=ylab, ax=ax)
        putil.hide_axes(ax)
    if len(raster_axs):
        putil.set_labels(raster_axs[0], title=title)  # add title to top raster

    # Rate plot.
    rate_ax = fig.add_subplot(gsp_rate[0, 0])
    rate(rate_list, names, prds, evts, cols, baseline, **rate_kws, ax=rate_ax)

    # Synchronize raster's x axis limits to rate plot's limits.
    xlim = rate_ax.get_xlim()
    [ax.set_xlim(xlim) for ax in raster_axs]

    # Save and return plot.
    putil.save_fig(ffig, fig)
    return fig, raster_axs, rate_ax
def plot_ROC_mean(d_faroc, t1=None, t2=None, ylim=None, colors=None,
                  ylab='AROC', ffig=None):
    """Plot mean ROC curves over given period."""

    # Import results.
    d_aroc = {}
    for name, faroc in d_faroc.items():
        aroc = util.read_objects(faroc, 'aroc')
        d_aroc[name] = aroc.unstack().T

    # Format results.
    laroc = pd.DataFrame(pd.concat(d_aroc), columns=['aroc'])
    laroc['task'] = laroc.index.get_level_values(0)
    laroc['time'] = laroc.index.get_level_values(1)
    laroc['unit'] = laroc.index.get_level_values(2)
    laroc.index = np.arange(len(laroc.index))

    # Init figure.
    fig = putil.figure(figsize=(6, 6))
    ax = sns.tsplot(laroc, time='time', value='aroc', unit='unit',
                    condition='task', color=colors)

    # Highlight stimulus periods.
    putil.plot_periods(ax=ax)

    # Plot mean results.
    [ax.lines[i].set_linewidth(3) for i in range(len(ax.lines))]

    # Add chance level line.
    putil.add_chance_level(ax=ax, alpha=0.8, color='k')
    ax.lines[-1].set_linewidth(1.5)

    # Format plot.
    xlab = 'Time since S1 onset (ms)'
    putil.set_labels(ax, xlab, ylab)
    putil.set_limits(ax, [t1, t2], ylim)
    putil.set_spines(ax, bottom=True, left=True, top=False, right=False)
    putil.set_legend(ax, loc=0)

    # Save plot.
    putil.save_fig(ffig, fig, ytitle=1.05, w_pad=15)
def plot_scores(ax, Scores, Perm=None, Psdo=None, nvals=None, prds=None,
                col='b', perm_col='grey', psdo_col='g', xlim=None,
                ylim=ylim_scr, xlab=tlab, ylab=ylab_scr, title='',
                ytitle=1.04):
    """Plot decoding accuracy results."""

    lgn_patches = []

    # Plot permuted results (if exist).
    if not util.is_null(Perm) and not Perm.isnull().all().all():
        x, pval = Perm.columns, Perm.loc['pval']
        ymean, ystd = Perm.loc['mean'], Perm.loc['std']
        plot_mean_std_sdiff(x, ymean, ystd, pval, pth=0.01, lw=6,
                            color=perm_col, ax=ax)
        lgn_patches.append(putil.get_artist('permuted', perm_col))

    # Plot population shuffled results (if exist).
    if not util.is_null(Psdo) and not Psdo.isnull().all().all():
        x, pval = Psdo.columns, Psdo.loc['pval']
        ymean, ystd = Psdo.loc['mean'], Psdo.loc['std']
        plot_mean_std_sdiff(x, ymean, ystd, pval, pth=0.01, lw=3,
                            color=psdo_col, ax=ax)
        lgn_patches.append(putil.get_artist('pseudo-population', psdo_col))

    # Plot scores.
    plot_score_set(Scores, ax, color=col)
    lgn_patches.append(putil.get_artist('synchronous', col))

    # Add legend.
    lgn_patches = lgn_patches[::-1]
    putil.set_legend(ax, handles=lgn_patches)

    # Add chance level line.
    # This currently plots all nvals combined across stimulus period!
    if nvals is not None:
        chance_lvl = 1.0 / nvals
        putil.add_chance_level(ax=ax, ylevel=chance_lvl)

    # Add stimulus periods.
    if prds is not None:
        putil.plot_periods(prds, ax=ax)

    # Set axis limits.
    xlim = xlim if xlim is not None else tlim
    putil.set_limits(ax, xlim, ylim)

    # Format plot.
    putil.set_labels(ax, xlab, ylab, title, ytitle)
def plot_slope_diffs_btw_groups(fit_res, prd, res_dir, groups=None,
                                figsize=None):
    """Plot differences in slopes between group pairs."""

    # Test pair-wise difference from each other.
    if groups is None:
        groups = fit_res['group'].unique()
    empty_df = pd.DataFrame(np.nan, index=groups, columns=groups)
    pw_diff_v = empty_df.copy()
    pw_diff_p = empty_df.copy()
    pw_diff_a = empty_df.copy()
    for grp1, grp2 in combinations(groups, 2):

        # Get slopes in each task.
        slp_t1 = fit_res.slope[fit_res.group == grp1]
        slp_t2 = fit_res.slope[fit_res.group == grp2]

        # Mean difference value.
        diff = slp_t1.mean() - slp_t2.mean()
        pw_diff_v.loc[grp1, grp2] = diff

        # Do test for statistical difference.
        stat, pval = stats.mann_whithney_u_test(slp_t1, slp_t2)
        pw_diff_p.loc[grp1, grp2] = pval

        # Annotation DF plot.
        a1, a2 = ['{:.2f}{}'.format(v, util.star_pvalue(pval))
                  for v in (diff, -diff)]
        pw_diff_a.loc[grp1, grp2] = a1

    # Plot and save figure of mean pair-wise difference.
    fig = putil.figure(figsize=figsize)
    sns.heatmap(pw_diff_v, annot=pw_diff_a, fmt='', linewidths=0.5,
                cbar=False)
    title = 'Mean difference in {} slopes (sp/s / s)'.format(prd)
    putil.set_labels(title=title)
    ffig = res_dir + '{}_anticipatory_slope_pairwise_diff.png'.format(prd)
    putil.save_fig(ffig, fig)

    return pw_diff_v
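# Usage sketch for plot_slope_diffs_btw_groups (hypothetical data; the
# 'group' and 'slope' column names are taken from the code above, while the
# group labels, period name and output directory are made up):
#
#   fit_res = pd.DataFrame({'group': ['DS'] * 10 + ['non-DS'] * 10,
#                           'slope': np.random.normal(1.0, 0.5, 20)})
#   pw_diff_v = plot_slope_diffs_btw_groups(fit_res, 'delay', res_dir='./')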
def plot_CE_time_distribution(ulists, eff_t_res, eff_pars, aroc_res_dir,
                              bins=None):
    """Plot distribution of ROC comparison effect timing across groups."""

    # Init.
    putil.set_style('notebook', 'white')
    if bins is None:
        bins = np.arange(2000, 2600, 50)
    fig, _, axs = putil.get_gs_subplots(nrow=1, ncol=len(eff_pars),
                                        subw=5, subh=4, create_axes=True,
                                        as_array=False)

    # Plot CE timing distribution for each unit group.
    for (eff_dir, eff_lbl), ax in zip(eff_pars, axs):
        etd = eff_t_res.loc[eff_t_res.effect_dir == eff_dir, 'time']
        for nlist in ulists:
            tvals = etd.loc[nlist]
            lbl = '{} (n={})'.format(nlist, len(tvals))
            sns.distplot(tvals, bins, label=lbl, ax=ax)
        putil.set_labels(ax, 'effect timing (ms since S1 onset)', '', eff_lbl)

    # Format plots.
    sns.despine(ax=ax)
    [ax.legend() for ax in axs]
    [putil.hide_tick_labels(ax, show_x_tick_lbls=True) for ax in axs]
    putil.sync_axes(axs, sync_y=True)

    # Save plot.
    ffig = aroc_res_dir + 'CE/CE_timing_distributions.png'
    putil.save_fig(ffig, fig)
def cat_mean(df, x, y, add_stats=True, fstats=None, bar_ylvl=None, ci=68,
             add_mean=True, mean_ylvl=None, ylbl=None, fig=None, ax=None,
             ffig=None, palette=None, errwidth=None, **kwargs):
    """Plot means of two categorical datasets."""

    # Init.
    if fig is None and ax is None:
        fig = putil.figure(figsize=(3, 3))
    if ax is None:
        ax = putil.axes()
    if fstats is None:
        fstats = stats.mann_whithney_u_test

    # Plot means as bars.
    # Note: palette, errwidth and **kwargs were added to the signature so
    # that the call below has them defined; they are forwarded to sns.barplot.
    sns.barplot(x=x, y=y, data=df, ci=ci, ax=ax, palette=palette,
                errwidth=errwidth, **kwargs)

    # Get plotted vectors.
    ngrps = [t.get_text() for t in ax.get_xticklabels()]
    v1, v2 = [df.loc[df[x] == ngrp, y] for ngrp in ngrps]

    # Add significance bar.
    if add_stats:
        _, pval = fstats(v1, v2)
        pval_str = util.format_pvalue(pval)
        if bar_ylvl is None:
            bar_ylvl = 1.1 * max(v1.mean() + stats.sem(v1),
                                 v2.mean() + stats.sem(v2))
        lines([0.1, 0.9], [bar_ylvl, bar_ylvl], color='grey', ax=ax)
        ax.text(0.5, 1.01 * bar_ylvl, pval_str, fontsize='medium',
                fontstyle='italic', va='bottom', ha='center')

    # Add mean values.
    for vec, xpos in [(v1, 0.2), (v2, 1.2)]:
        mstr = '{:.2f}'.format(vec.mean())
        ypos = 1.005 * vec.mean()
        ax.text(xpos, ypos, mstr, fontstyle='italic', fontsize='smaller',
                va='bottom', ha='center')

    # Format plot.
    sns.despine()
    putil.hide_legend_title(ax)
    putil.set_labels(ax, '', ylbl)
    putil.sparsify_tick_labels(fig, ax, 'y', freq=2)

    # Save plot.
    putil.save_fig(ffig, fig)

    return ax
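# Usage sketch for cat_mean (hypothetical data; the 'group' / 'rate' column
# names are illustrative only, not part of the original analysis):
#
#   df = pd.DataFrame({'group': ['A'] * 20 + ['B'] * 20,
#                      'rate': np.r_[np.random.normal(10, 2, 20),
#                                    np.random.normal(14, 2, 20)]})
#   ax = cat_mean(df, x='group', y='rate', ylbl='rate (sp/s)')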
def plot_trial_type_distribution(UA, RecInfo, utids=None, tr_par=('S1', 'Dir'),
                                 save_plot=False, fname=None):
    """Plot distribution of trial types."""

    # Init.
    par_str = util.format_to_fname(str(tr_par))
    if utids is None:
        utids = UA.utids(as_series=True)
    recs = util.get_subj_date_pairs(utids)
    tasks = RecInfo.index.get_level_values('task').unique()
    tasks = [task for task in UA.tasks() if task in tasks]  # reorder tasks

    # Init plotting.
    putil.set_style('notebook', 'darkgrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(recs), ncol=len(tasks),
                                          subw=4, subh=3, create_axes=True)

    for ir, rec in enumerate(recs):
        for it, task in enumerate(tasks):

            ax = axs[ir, it]
            rt = rec + (task,)
            if rt not in RecInfo.index:
                ax.set_axis_off()
                continue

            # Get included trials and their parameters.
            inc_trs = RecInfo.loc[rt, 'trials']
            utid = utids.xs(rt, level=('subj', 'date', 'task'))[0]
            TrData = UA.get_unit(utid[:-1], utid[-1]).TrData.loc[inc_trs]

            # Create DF to plot.
            anw_df = TrData[[tr_par, 'correct']].copy()
            anw_df['answer'] = 'error'
            anw_df.loc[anw_df.correct, 'answer'] = 'correct'
            all_df = anw_df.copy()
            all_df.answer = 'all'
            comb_df = pd.concat([anw_df, all_df])

            if not TrData.size:
                ax.set_axis_off()
                continue

            # Plot as countplot.
            sns.countplot(x=tr_par, hue='answer', data=comb_df,
                          hue_order=['all', 'correct', 'error'], ax=ax)
            sns.despine(ax=ax)
            putil.hide_tick_marks(ax)
            putil.set_max_n_ticks(ax, 6, 'y')
            ax.legend(loc=[0.95, 0.7])

            # Add title.
            title = '{} {}'.format(rec, task)
            nce = anw_df.answer.value_counts()
            nc, ne = [nce[c] if c in nce else 0 for c in ('correct', 'error')]
            pnc, pne = 100*nc/nce.sum(), 100*ne/nce.sum()
            title += '\n\n# correct: {} ({:.0f}%)'.format(nc, pnc)
            title += '   # error: {} ({:.0f}%)'.format(ne, pne)
            putil.set_labels(ax, title=title, xlab=par_str)

            # Format legend.
            if (ir != 0) or (it != 0):
                ax.legend_.remove()

    # Save plot.
    if save_plot:
        title = 'Trial type distribution'
        if fname is None:
            fname = util.join(['results', 'decoding', 'prepare',
                               par_str + '_trial_type_distr.pdf'])
        putil.save_fig(fname, fig, title, w_pad=3, h_pad=3)
def select_units_trials(UA, utids=None, fres=None, ffig=None,
                        min_n_units=5, min_n_trs_per_unit=5):
    """
    Select optimal set of units and trials for population decoding.

    min_n_units: minimum number of units to keep (0: off)
    min_n_trs_per_unit: minimum number of trials per unit to keep (0: off)
    """

    print('Selecting optimal set of units and trials for decoding...')

    # Init.
    if utids is None:
        utids = UA.utids(as_series=True)
    u_rt_grpby = utids.groupby(level=['subj', 'date', 'task'])

    # Unit info frame.
    UInc = pd.Series(False, index=utids.index)

    # Included trials by unit.
    IncTrs = pd.Series([(UA.get_unit(utid[:-1], utid[-1]).inc_trials())
                        for utid in utids], index=utids.index)

    # Result DF.
    rec_task = pd.MultiIndex.from_tuples([rt for rt, _ in u_rt_grpby],
                                         names=['subj', 'date', 'task'])
    cols = ['elec', 'units', 'nunits', 'nallunits', '% remaining units',
            'trials', 'ntrials', 'nalltrials', '% remaining trials']
    RecInfo = pd.DataFrame(index=rec_task, columns=cols)
    rt_utids = [utids.xs((s, d, t), level=('subj', 'date', 'task'))
                for s, d, t in rec_task]
    RecInfo.nallunits = [len(utids) for utids in rt_utids]
    rt_ulist = [UA.get_unit(utids[0][:-1], utids[0][-1]) for utids in rt_utids]
    RecInfo.nalltrials = [int(u.QualityMetrics['NTrialsTotal'])
                          for u in rt_ulist]

    # Function to plot matrix (DF) of included/excluded trials.
    def plot_inc_exc_trials(IncTrsMat, ax, title=None, ytitle=None,
                            xlab='Trial #', ylab=None):
        # Plot on heatmap.
        sns.heatmap(IncTrsMat, cmap='RdYlGn', center=0.5, cbar=False, ax=ax)
        # Set tick labels.
        putil.hide_tick_marks(ax)
        tr_ticks = [1] + list(np.arange(25, IncTrsMat.shape[1]+1, 25))
        ax.xaxis.set_ticks(tr_ticks)
        ax.set_xticklabels(tr_ticks)
        putil.rot_xtick_labels(ax, 0)
        putil.rot_ytick_labels(ax, 0, va='center')
        putil.set_labels(ax, xlab, ylab, title, ytitle)

    # Init plotting.
    ytitle = 1.40
    putil.set_style('notebook', 'whitegrid')
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(rec_task), ncol=3,
                                          subw=6, subh=4, create_axes=True)

    for i_rt, ((subj, date, task), rt_utids) in enumerate(u_rt_grpby):
        print('{} / {}: {} - {} {}'.format(i_rt+1, len(u_rt_grpby),
                                           subj, date, task))

        # Init electrode.
        elecs = rt_utids.index.get_level_values('elec').unique()
        if len(elecs) != 1:
            warnings.warn('More than one electrode?')
        elec = elecs[0]
        RecInfo.loc[(subj, date, task), 'elec'] = elec

        # Create matrix of included trials of recording & task of units.
        ch_idxs = rt_utids.index.droplevel(-1).droplevel(2).droplevel(1).droplevel(0)
        n_alltrs = RecInfo.nalltrials[(subj, date, task)]
        IncTrsMat = pd.DataFrame(np.zeros((len(ch_idxs), n_alltrs), dtype=int),
                                 index=ch_idxs,
                                 columns=np.arange(n_alltrs)+1)
        for ch_idx, utid in zip(ch_idxs, rt_utids):
            IncTrsMat.loc[ch_idx].iloc[IncTrs[utid]] = 1

        # Plot included/excluded trials after preprocessing.
        ax = axs[i_rt, 0]
        ylab = '{} {} {}'.format(subj, date, task)
        title = ('Included (green) and excluded (red) trials'
                 if i_rt == 0 else None)
        plot_inc_exc_trials(IncTrsMat, ax, title, ytitle, ylab=ylab)

        # Calculate and plot overlap of trials across units.
        # How many trials will remain if we iteratively exclude units
        # with the least overlap with the rest of the units?
        def n_cov_trs(df):  # return number of trials covered in df
            return sum(df.all())

        def calc_heuristic(df):
            return df.shape[0] * n_cov_trs(df)

        n_trs = IncTrsMat.sum(1)
        n_units = IncTrsMat.shape[0]

        # Init results DF.
        columns = ('uid', 'ntrs_cov', 'n_rem_u', 'trial x units')
        tr_covs = pd.DataFrame(columns=columns, index=range(n_units+1))
        tr_covs.loc[0] = ('none', n_cov_trs(IncTrsMat), n_units,
                          calc_heuristic(IncTrsMat))

        # Subset of included units (to be updated in each iteration).
        uinc = IncTrsMat.index.to_series()
        for iu in range(1, len(uinc)):

            # Number of covered trials after removing each unit.
            sntrscov = pd.Series([n_cov_trs(IncTrsMat.loc[uinc.drop(uid)])
                                  for uid in uinc], index=uinc.index)

            #########################################
            #  Select and remove unit that          #
            #  (a) yields maximum trial coverage,   #
            #  (b) has minimum number of trials     #
            #########################################
            maxtrscov = sntrscov.max()
            worst_us = sntrscov[sntrscov == maxtrscov].index  # (a)
            utrs = n_trs.loc[worst_us]
            uid_remove = utrs[(utrs == min(utrs))].index[0]   # (b)

            # Update current subset of units and their trial DF.
            uinc.drop(uid_remove, inplace=True)
            tr_covs.loc[iu] = (uid_remove, maxtrscov, len(uinc),
                               calc_heuristic(IncTrsMat.loc[uinc]))

        # Add last unit.
        tr_covs.iloc[-1] = (uinc[0], 0, 0, 0)

        # Plot covered trials against each unit removed.
        ax_trc = axs[i_rt, 1]
        sns.tsplot(tr_covs['ntrs_cov'], marker='o', ms=4, color='b',
                   ax=ax_trc)
        title = ('Trial coverage during iterative unit removal'
                 if i_rt == 0 else None)
        xlab, ylab = 'current unit removed', '# trials covered'
        putil.set_labels(ax_trc, xlab, ylab, title, ytitle)
        ax_trc.xaxis.set_ticks(tr_covs.index)
        x_ticklabs = ['none'] + ['{} - {}'.format(ch, ui)
                                 for ch, ui in tr_covs.uid.loc[1:]]
        ax_trc.set_xticklabels(x_ticklabs)
        putil.rot_xtick_labels(ax_trc, 45)
        ax_trc.grid(True)

        # Add # of remaining units to top.
        ax_remu = ax_trc.twiny()
        ax_remu.xaxis.set_ticks(tr_covs.index)
        ax_remu.set_xticklabels(list(range(len(x_ticklabs)))[::-1])
        ax_remu.set_xlabel('# units remaining')
        ax_remu.grid(None)

        # Add heuristic index.
        ax_heur = ax_trc.twinx()
        sns.tsplot(tr_covs['trial x units'], linestyle='--', marker='o',
                   ms=4, color='m', ax=ax_heur)
        putil.set_labels(ax_heur, ylab='remaining units x covered trials')
        [tl.set_color('m') for tl in ax_heur.get_yticklabels()]
        [tl.set_color('b') for tl in ax_trc.get_yticklabels()]
        ax_heur.grid(None)

        # Decide on which units to exclude.
        min_n_trials = min_n_trs_per_unit * tr_covs['n_rem_u']
        sub_tr_covs = tr_covs[(tr_covs['n_rem_u'] >= min_n_units) &
                              (tr_covs['ntrs_cov'] >= min_n_trials)]

        # If any subset of units passed above criteria.
        rem_uids, exc_uids = pd.Series(), tr_covs.uid[1:]
        n_tr_rem, n_tr_exc = 0, IncTrsMat.shape[1]
        if len(sub_tr_covs.index):
            hmax_idx = sub_tr_covs['trial x units'].argmax()
            rem_uids = tr_covs.uid[(hmax_idx+1):]
            exc_uids = tr_covs.uid[1:hmax_idx+1]
            n_tr_rem = tr_covs.ntrs_cov[hmax_idx]
            n_tr_exc = IncTrsMat.shape[1] - n_tr_rem

        # Add to UnitInfo dataframe.
        rt_utids = [(subj, date, elec, ch, ui, task) for ch, ui in rem_uids]
        UInc[rt_utids] = True

        # Highlight selected point in middle plot.
        sel_seg = [('selection', exc_uids.shape[0]-0.4,
                    exc_uids.shape[0]+0.4)]
        putil.plot_periods(sel_seg, ax=ax_trc, alpha=0.3)
        [ax.set_xlim([-0.5, n_units+0.5]) for ax in (ax_trc, ax_remu)]

        # Generate remaining trials dataframe.
        RemTrsMat = IncTrsMat.copy().astype(float)
        # Remove all trials from excluded units.
        for exc_uid in exc_uids:
            RemTrsMat.loc[exc_uid] = 0.5
        # Remove uncovered trials in remaining units.
        exc_trs = np.where(~RemTrsMat.loc[list(rem_uids)].all())[0]
        if exc_trs.size:
            RemTrsMat.iloc[:, exc_trs] = 0.5
        # Overwrite by trials excluded during preprocessing.
        RemTrsMat[IncTrsMat == False] = 0.0

        # Plot remaining trials.
        ax = axs[i_rt, 2]
        n_u_rem, n_u_exc = len(rem_uids), len(exc_uids)
        title = ('# units remaining: {}, excluded: {}'.format(n_u_rem,
                                                              n_u_exc) +
                 '\n# trials remaining: {}, excluded: {}'.format(n_tr_rem,
                                                                 n_tr_exc))
        plot_inc_exc_trials(RemTrsMat, ax, title=title, ylab='')

        # Add remaining units and trials to RecInfo.
        rt = (subj, date, task)
        RecInfo.loc[rt, ('units', 'nunits')] = list(rem_uids), len(rem_uids)
        cov_trs = RemTrsMat.loc[list(rem_uids)].all()
        inc_trs = pd.Int64Index(np.where(cov_trs)[0])
        RecInfo.loc[rt, ('trials', 'ntrials')] = inc_trs, sum(cov_trs)

    RecInfo['% remaining units'] = 100 * RecInfo.nunits / RecInfo.nallunits
    RecInfo['% remaining trials'] = 100 * RecInfo.ntrials / RecInfo.nalltrials

    # Save results.
    if fres is not None:
        results = {'RecInfo': RecInfo, 'UInc': UInc}
        util.write_objects(results, fres)

    # Save plot.
    title = 'Trial & unit selection prior to decoding'
    putil.save_fig(ffig, fig, title, w_pad=3, h_pad=3)

    return RecInfo, UInc
def plot_qm(u, bs_stats, stab_prd_res, prd_inc, tr_inc, spk_inc,
            add_lbls=False, ftempl=None, fig=None, sps=None):
    """Plot quality metrics related figures."""

    # Init values.
    waveforms = np.array(u.Waveforms)
    wavetime = u.Waveforms.columns * us
    spk_times = np.array(u.SpikeParams['time'], dtype=float)
    base_rate = u.QualityMetrics['baseline']

    # Minimum and maximum gain.
    gmin = u.SessParams['minV']
    gmax = u.SessParams['maxV']

    # %% Init plots.

    # Disable inline plotting to prevent memory leak.
    putil.inline_off()

    # Init figure and gridspec.
    fig = putil.figure(fig)
    if sps is None:
        sps = putil.gridspec(1, 1)[0]
    ogsp = putil.embed_gsp(sps, 2, 1, height_ratios=[0.02, 1])
    info_sps, qm_sps = ogsp[0], ogsp[1]

    # Info header.
    info_ax = fig.add_subplot(info_sps)
    putil.hide_axes(info_ax)
    title = putil.get_unit_info_title(u)
    putil.set_labels(ax=info_ax, title=title, ytitle=0.80)

    # Create axes.
    gsp = putil.embed_gsp(qm_sps, 3, 2, wspace=0.3, hspace=0.4)
    ax_wf_inc, ax_wf_exc = [fig.add_subplot(gsp[0, i]) for i in (0, 1)]
    ax_wf_amp, ax_wf_dur = [fig.add_subplot(gsp[1, i]) for i in (0, 1)]
    ax_amp_dur, ax_rate = [fig.add_subplot(gsp[2, i]) for i in (0, 1)]

    # Trial markers.
    trial_starts, trial_stops = u.TrData.TrialStart, u.TrData.TrialStop
    tr_markers = pd.DataFrame({'time': trial_starts[9::10]})
    tr_markers['label'] = [str(itr + 1) if i % 2 else ''
                           for i, itr in enumerate(tr_markers.index)]

    # Common variables, limits and labels.
    WF_T_START = test_sorting.WF_T_START
    spk_t = u.SessParams.sampl_prd * (np.arange(waveforms.shape[1]) -
                                      WF_T_START)
    ses_t_lim = test_sorting.get_start_stop_times(spk_times, trial_starts,
                                                  trial_stops)
    ss, sa = 1.0, 0.8  # marker size and alpha on scatter plot

    # Color spikes by their occurrence over session time.
    my_cmap = putil.get_cmap('jet')
    spk_cols = np.tile(np.array([.25, .25, .25, .25]), (len(spk_times), 1))
    if np.any(spk_inc):  # check if there is any spike included
        spk_t_inc = np.array(spk_times[spk_inc])
        tmin, tmax = float(spk_times.min()), float(spk_times.max())
        spk_cols[spk_inc, :] = my_cmap((spk_t_inc - tmin) / (tmax - tmin))

    # Put excluded trials to the front, and randomise order of included
    # trials so later spikes don't systematically cover earlier ones.
    spk_order = np.hstack((np.where(np.invert(spk_inc))[0],
                           np.random.permutation(np.where(spk_inc)[0])))

    # Common labels for plots.
    ses_t_lab = 'Recording time (s)'

    # %% Waveform shape analysis.

    # Plot included and excluded waveforms on different axes.
    # Color included by occurrence in session time to help detect drifts.
    s_waveforms, s_spk_cols = waveforms[spk_order, :], spk_cols[spk_order]
    wf_t_lim, glim = [min(spk_t), max(spk_t)], [gmin, gmax]
    wf_t_lab, volt_lab = 'WF time ($\mu$s)', 'Voltage'
    for st in ('Included', 'Excluded'):
        ax = ax_wf_inc if st == 'Included' else ax_wf_exc
        spk_idx = spk_inc if st == 'Included' else np.invert(spk_inc)
        tr_idx = tr_inc if st == 'Included' else np.invert(tr_inc)

        nspsk, ntrs = sum(spk_idx), sum(tr_idx)
        title = '{} WFs, {} spikes, {} trials'.format(st, nspsk, ntrs)

        # Select waveforms and colors.
        rand_spk_idx = spk_idx[spk_order]
        wfs = s_waveforms[rand_spk_idx, :]
        cols = s_spk_cols[rand_spk_idx]

        # Plot waveforms.
        xlab, ylab = (wf_t_lab, volt_lab) if add_lbls else (None, None)
        pwaveform.plot_wfs(wfs, spk_t, cols=cols, lw=0.1, alpha=0.05,
                           xlim=wf_t_lim, ylim=glim, title=title,
                           xlab=xlab, ylab=ylab, ax=ax)

    # %% Waveform summary metrics.

    # Init data.
    wf_amp_all = u.SpikeParams['amplitude']
    wf_amp_inc = wf_amp_all[spk_inc]
    wf_dur_all = u.SpikeParams['duration']
    wf_dur_inc = wf_dur_all[spk_inc]

    # Set common limits and labels.
    dur_lim = [0, wavetime[-2] - wavetime[WF_T_START]]  # same across units
    glim = max(wf_amp_all.max(), gmax - gmin)
    amp_lim = [0, glim]
    amp_lab = 'Amplitude'
    dur_lab = 'Duration ($\mu$s)'

    # Waveform amplitude across session time.
    m_amp, sd_amp = wf_amp_inc.mean(), wf_amp_inc.std()
    title = 'WF amplitude: {:.1f} $\pm$ {:.1f}'.format(m_amp, sd_amp)
    xlab, ylab = (ses_t_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times, wf_amp_all, spk_inc, c='m', bc='grey', s=ss,
                  xlab=xlab, ylab=ylab, xlim=ses_t_lim, ylim=amp_lim,
                  edgecolors='', alpha=sa, id_line=False, title=title,
                  ax=ax_wf_amp)

    # Waveform duration across session time.
    mdur, sdur = wf_dur_inc.mean(), wf_dur_inc.std()
    title = 'WF duration: {:.1f} $\pm$ {:.1f} $\mu$s'.format(mdur, sdur)
    xlab, ylab = (ses_t_lab, dur_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times, wf_dur_all, spk_inc, c='c', bc='grey', s=ss,
                  xlab=xlab, ylab=ylab, xlim=ses_t_lim, ylim=dur_lim,
                  edgecolors='', alpha=sa, id_line=False, title=title,
                  ax=ax_wf_dur)

    # Waveform duration against amplitude.
    title = 'WF duration - amplitude'
    xlab, ylab = (dur_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(wf_dur_all[spk_order], wf_amp_all[spk_order],
                  c=spk_cols[spk_order], s=ss, xlab=xlab, ylab=ylab,
                  xlim=dur_lim, ylim=amp_lim, edgecolors='', alpha=sa,
                  id_line=False, title=title, ax=ax_amp_dur)

    # %% Firing rate.

    tmean = np.array(bs_stats['tmean'])
    rmean = util.remove_dim_from_series(bs_stats['rate'])
    prd_tstart, prd_tstop = stab_prd_res['tstart'], stab_prd_res['tstop']

    # Color segments depending on whether they are included / excluded.
    def plot_periods(v, color, ax):
        # Plot line segments.
        for i in range(len(prd_inc[:-1])):
            col = color if prd_inc[i] and prd_inc[i + 1] else 'grey'
            x, y = [(tmean[i], tmean[i + 1]), (v[i], v[i + 1])]
            ax.plot(x, y, color=col)
        # Plot line points.
        for i in range(len(prd_inc)):
            col = color if prd_inc[i] else 'grey'
            x, y = [tmean[i], v[i]]
            ax.plot(x, y, color=col, marker='o', markersize=3,
                    markeredgecolor=col)

    # Firing rate over session time.
    title = 'Baseline rate: {:.1f} spike/s'.format(float(base_rate))
    xlab, ylab = (ses_t_lab, putil.FR_lbl) if add_lbls else (None, None)
    ylim = [0, 1.25 * np.max(rmean)]
    plot_periods(rmean, 'b', ax_rate)
    pplot.lines([], [], c='b', xlim=ses_t_lim, ylim=ylim, title=title,
                xlab=xlab, ylab=ylab, ax=ax_rate)

    # Trial markers.
    putil.plot_events(tr_markers, lw=0.5, ls='--', alpha=0.35, y_lbl=0.92,
                      ax=ax_rate)

    # Excluded periods.
    excl_prds = []
    tstart, tstop = ses_t_lim
    if tstart != prd_tstart:
        excl_prds.append(('beg', tstart, prd_tstart))
    if tstop != prd_tstop:
        excl_prds.append(('end', prd_tstop, tstop))
    putil.plot_periods(excl_prds, ymax=0.92, ax=ax_rate)

    # %% Post-formatting.

    # Maximize number of ticks on recording time axes to prevent covering.
    for ax in (ax_wf_amp, ax_wf_dur, ax_rate):
        putil.set_max_n_ticks(ax, 6, 'x')

    # %% Save figure.
    if ftempl is not None:
        fname = ftempl.format(u.name_to_fname())
        putil.save_fig(fname, fig, title, rect_height=0.92)

    putil.inline_on()

    return [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
def rec_stability_test(UA, fname=None, periods=None):
    """Check stability of recording session across tasks."""

    # Init.
    if periods is None:
        periods = ['whole trial', 'fixation']

    # Init figure.
    fig, gsp, axs = putil.get_gs_subplots(nrow=len(periods), ncol=1,
                                          subw=10, subh=2.5,
                                          create_axes=True, as_array=False)

    for prd, ax in zip(periods, axs):

        # Calculate and plot firing rate during given period in each trial
        # across session for all units.
        colors = putil.get_colors()
        task_stats = pd.DataFrame(columns=['t_start', 't_stops', 'label'])
        for task, color in zip(UA.tasks(), colors):

            # Get activity of all units in task.
            tr_rates = []
            for u in UA.iter_thru([task]):
                rates = u.get_prd_rates(prd, tr_time_idx=True)
                tr_rates.append(util.remove_dim_from_series(rates))
            tr_rates = pd.DataFrame(tr_rates)

            # No (non-empty and included) unit during task.
            if not len(tr_rates.index):
                continue

            # Plot each rate in task.
            tr_times = tr_rates.columns
            pplot.lines(tr_times, tr_rates.T, zorder=1, alpha=0.5,
                        color=color, ax=ax)

            # Plot mean +- sem rate.
            tr_time = tr_rates.columns
            mean_rate, sem_rate = tr_rates.mean(), tr_rates.std()
            lower, upper = mean_rate - sem_rate, mean_rate + sem_rate
            lower[lower < 0] = 0  # remove negative values
            ax.fill_between(tr_time, lower, upper, zorder=2, alpha=.5,
                            facecolor='grey', edgecolor='grey')
            pplot.lines(tr_time, mean_rate, lw=2, color='k', ax=ax)

            # Add task stats.
            task_lbl = '{}, {} units'.format(task, len(tr_rates.index))

            # Add grand mean FR.
            task_lbl += '\nFR: {:.1f} sp/s'.format(tr_rates.mean().mean())

            # Calculate linear trend to test gradual drift.
            slope, _, _, p_value, _ = sp.stats.linregress(tr_times, mean_rate)
            slope = 3600 * slope  # convert to change in spikes per hour
            pval = util.format_pvalue(p_value, max_digit=3)
            task_lbl += '\n$\delta$FR: {:.1f} sp/s/h'.format(slope)
            task_lbl += '\n{}'.format(pval)

            task_stats.loc[task] = (tr_times.min(), tr_times.max(), task_lbl)

        # Set axes limits.
        tmin, tmax = task_stats.t_start.min(), task_stats.t_stops.max()
        putil.set_limits(ax, xlim=(tmin, tmax))

        # Add task labels after all tasks have been plotted.
        putil.plot_events(task_stats[['t_start', 'label']], y_lbl=0.75,
                          lbl_ha='left', lbl_rotation=0, ax=ax)

        # Format plot.
        xlab = 'Recording time (s)' if prd == periods[-1] else None
        putil.set_labels(ax, xlab=xlab, ylab=prd)
        putil.set_spines(ax, left=False)

    # Save figure.
    title = 'Recording stability of ' + UA.Name
    putil.save_fig(fname, fig, title)
def plot_group_violin(res, x, y, groups=None, npval=None, pth=0.01,
                      color='grey', ylim=None, ylab=None, ffig=None):
    """Plot group-wise results on violin plots."""

    if groups is None:
        groups = res['group'].unique()

    # Test difference from zero in each group.
    ttest_res = {group: sp.stats.ttest_1samp(gres[y], 0)
                 for group, gres in res.groupby(x)}
    ttest_res = pd.DataFrame.from_dict(ttest_res, 'index')

    # Binarize significance test.
    res['is_sign'] = res[npval] < pth if npval is not None else True
    res['direction'] = np.sign(res[y])

    # Set up figure and plot data.
    fig = putil.figure()
    ax = putil.axes()
    putil.add_baseline(ax=ax)
    sns.violinplot(x=x, y=y, data=res, inner=None, order=groups, ax=ax)
    sns.swarmplot(x=x, y=y, hue='is_sign', data=res, color=color,
                  order=groups, hue_order=[True, False], ax=ax)
    putil.set_labels(ax, xlab='', ylab=ylab)
    putil.set_limits(ax, ylim=ylim)
    putil.hide_legend(ax)

    # Add annotations.
    ymin, ymax = ax.get_ylim()
    ylvl = ymax
    for i, group in enumerate(groups):
        gres = res.loc[res.group == group]

        # Mean.
        mean_str = 'Mean:\n' if i == 0 else '\n'
        mean_str += '{:.2f}'.format(gres[y].mean())

        # Non-zero test of distribution.
        str_pval = util.format_pvalue(ttest_res.loc[group, 'pvalue'])
        mean_str += '\n({})'.format(str_pval)

        # Stats on difference from baseline.
        nnonsign, ntot = (~gres.is_sign).sum(), len(gres)
        npos, nneg = [sum(gres.is_sign & (gres.direction == d))
                      for d in (1, -1)]
        sign_to_report = [('+', npos), ('=', nnonsign), ('-', nneg)]
        nsign_str = ''
        for symb, n in sign_to_report:
            prc = str(int(round(100 * n / ntot)))
            nsign_str += '{} {:>3} / {} ({:>2}%)\n'.format(symb, int(n),
                                                           ntot, prc)

        lbl = '{}\n\n{}'.format(mean_str, nsign_str)
        ax.text(i, ylvl, lbl, fontsize='smaller', va='bottom', ha='center')

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
def plot_combined_rec_mean(recs, stims, res_dir, par_kws,
                           list_n_most_DS, list_min_nunits,
                           n_boot=1e4, ci=95, tasks=None, task_labels=None,
                           add_title=True, fig=None):
    """Test and plot results combined across sessions."""

    # Init.
    # putil.set_style('notebook', 'ticks')
    vkey = 'all'  # This should be made more explicit!
    prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
            for stim in stims]

    # Load all results to plot.
    dict_rt_res = decutil.load_res(res_dir, list_n_most_DS, **par_kws)

    # Create figures.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(dict_rt_res),
                                                ncol=len(list_min_nunits),
                                                subw=8, subh=6, fig=fig,
                                                create_axes=True)

    # Query data.
    allScores = {}
    allnunits = {}
    for n_most_DS, rt_res in dict_rt_res.items():

        # Get accuracy scores.
        dScores = {(rec, task): res[vkey]['Scores'].mean()
                   for (rec, task), res in rt_res.items()
                   if (vkey in res) and (res[vkey] is not None)}
        allScores[n_most_DS] = pd.concat(dScores, axis=1).T

        # Get number of units.
        allnunits[n_most_DS] = {(rec, task): res[vkey]['nunits'].iloc[0]
                                for (rec, task), res in rt_res.items()
                                if (vkey in res) and (res[vkey] is not None)}

        # Get number of classes (for chance level plotting).
        all_nvals = pd.Series({(rec, task): res[vkey]['nclasses'].iloc[0]
                               for (rec, task), res in rt_res.items()
                               if (vkey in res) and (res[vkey] is not None)})
        un_nvals = all_nvals.unique()
        if len(un_nvals) > 1 and verbose:
            print('Found multiple # of classes to decode: {}'.format(un_nvals))
        nvals = un_nvals[0]

    allnunits = pd.DataFrame(allnunits)

    # Plot mean performance across recordings and
    # test significance by bootstrapping.
    for inmost, n_most_DS in enumerate(list_n_most_DS):
        Scores = allScores[n_most_DS]
        nunits = allnunits[n_most_DS]

        for iminu, min_nunits in enumerate(list_min_nunits):

            ax_scr = axs_scr[inmost, iminu]

            # Select only recordings with minimum number of units.
            sel_rt = nunits.index[nunits >= min_nunits]
            nScores = Scores.loc[sel_rt].copy()

            # Nothing to plot.
            if nScores.empty:
                ax_scr.axis('off')
                continue

            # Prepare data.
            if tasks is None:
                tasks = nScores.index.get_level_values(1).unique()  # in data
            if task_labels is None:
                task_labels = {task: task for task in tasks}
            dScores = {task: pd.DataFrame(nScores.xs(task, level=1).unstack(),
                                          columns=['accuracy'])
                       for task in tasks}
            lScores = pd.concat(dScores, axis=0)
            lScores['time'] = lScores.index.get_level_values(1)
            lScores['task'] = lScores.index.get_level_values(0)
            lScores['rec'] = lScores.index.get_level_values(2)
            lScores.index = np.arange(len(lScores.index))
            lScores.task.replace(task_labels, inplace=True)

            # Add altered task names for legend plotting.
            nrecs = {task_labels[task]: len(nScores.xs(task, level=1))
                     for task in tasks}
            my_format = lambda x: '{} (n={})'.format(x, nrecs[x])
            lScores['task_nrecs'] = lScores['task'].apply(my_format)

            # Plot as time series.
            sns.tsplot(lScores, time='time', value='accuracy', unit='rec',
                       condition='task_nrecs', ci=ci, n_boot=n_boot,
                       ax=ax_scr)

            # Add chance level line.
            chance_lvl = 1.0 / nvals
            putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Add stimulus periods.
            putil.plot_periods(prds, ax=ax_scr)

            # Set axis limits.
            putil.set_limits(ax_scr, tlim)

            # Format plot.
            title = ('{} most DS units'.format(n_most_DS)
                     if n_most_DS != 0 else 'all units')
            title += (', recordings with at least {} units'.format(min_nunits)
                      if (min_nunits > 1 and len(list_min_nunits) > 1) else '')
            ytitle = 1.0
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)
            putil.hide_legend_title(ax_scr)

    # Match axes across decoding plots.
    [putil.sync_axes(axs_scr[inmost, :], sync_y=True)
     for inmost in range(axs_scr.shape[0])]

    # Save plots.
    list_n_most_DS_str = [str(i) if i != 0 else 'all' for i in list_n_most_DS]
    par_kws['n_most_DS'] = ', '.join(list_n_most_DS_str)
    title = ''
    if add_title:
        title = decutil.fig_title(res_dir, **par_kws)
        title += '\n{}% CI with {} bootstrapped subsamples'.format(ci,
                                                                   int(n_boot))
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    par_kws['n_most_DS'] = '_'.join(list_n_most_DS_str)
    ffig = decutil.fig_fname(res_dir, 'combined_score', fformat, **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)

    return fig_scr, axs_scr, ffig
def plot_scores_across_nunits(recs, stims, res_dir, list_n_most_DS, par_kws):
    """
    Plot prediction score results across different number of units included.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load all results to plot.
    dict_rt_res = decutil.load_res(res_dir, list_n_most_DS, **par_kws)

    # Create figures.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    # Do plotting per recording and task.
    for irec, rec in enumerate(recs):
        if verbose:
            print('\n' + rec)
        for itask, task in enumerate(tasks):
            if verbose:
                print(' ' + task)
            ax_scr = axs_scr[irec, itask]

            # Init data.
            dict_lScores = {}
            cols = sns.color_palette('hls', len(dict_rt_res.keys()))
            lncls = []
            for (n_most_DS, rt_res), col in zip(dict_rt_res.items(), cols):

                # Check if results exist for rec-task combination.
                if (((rec, task) not in rt_res.keys()) or
                        (not len(rt_res[(rec, task)].keys()))):
                    continue

                res = rt_res[(rec, task)]
                for v, col in zip(res.keys(), cols):
                    vres = res[v]
                    Scores = vres['Scores']
                    lncls.append(vres['nclasses'])

                    # Unstack dataframe with results.
                    lScores = pd.DataFrame(Scores.unstack(), columns=['score'])
                    lScores['time'] = lScores.index.get_level_values(0)
                    lScores['fold'] = lScores.index.get_level_values(1)
                    lScores.index = np.arange(len(lScores.index))

                    # Get number of units tested.
                    nunits = vres['nunits']
                    uni_nunits = nunits.unique()
                    if len(uni_nunits) > 1 and verbose:
                        print('Different number of units found.')
                    nunits = uni_nunits[0]

                    # Collect results.
                    dict_lScores[(nunits, v)] = lScores

            # Skip rest if no data is available.
            # Check if any result exists for rec-task combination.
            if not len(dict_lScores):
                ax_scr.axis('off')
                continue

            # Concatenate accuracy scores from every recording.
            all_lScores = pd.concat(dict_lScores)
            all_lScores['n_most_DS'] = all_lScores.index.get_level_values(0)
            all_lScores.index = np.arange(len(all_lScores.index))

            # Plot decoding results.
            nnunits = len(all_lScores['n_most_DS'].unique())
            title = '{} {}, {} sets of units'.format(' '.join(rec), task,
                                                     nnunits)
            ytitle = 1.0
            prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
                    for stim in stims]

            # Plot time series.
            palette = sns.color_palette('muted')
            sns.tsplot(all_lScores, time='time', value='score', unit='fold',
                       condition='n_most_DS', color=palette, ax=ax_scr)

            # Add chance level line.
            # This currently plots a chance level line for every nvals,
            # combined across stimulus period!
            uni_ncls = np.unique(np.array(lncls).flatten())
            if len(uni_ncls) > 1 and verbose:
                print('Different number of classes found.')
            for nvals in uni_ncls:
                chance_lvl = 1.0 / nvals
                putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Add stimulus periods.
            if prds is not None:
                putil.plot_periods(prds, ax=ax_scr)

            # Set axis limits.
            putil.set_limits(ax_scr, tlim, ylim_scr)

            # Format plot.
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)

    # Match axes across decoding plots.
    # [putil.sync_axes(axs_scr[:, itask], sync_y=True)
    #  for itask in range(axs_scr.shape[1])]

    # Save plots.
    list_n_most_DS_str = [str(i) if i != 0 else 'all' for i in list_n_most_DS]
    par_kws['n_most_DS'] = ', '.join(list_n_most_DS_str)
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    par_kws['n_most_DS'] = '_'.join(list_n_most_DS_str)
    ffig = decutil.fig_fname(res_dir, 'score_nunits', fformat, **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)
def plot_score_multi_rec(recs, stims, res_dir, par_kws):
    """Plot prediction scores for multiple recordings."""

    # Init.
    putil.set_style('notebook', 'ticks')
    n_most_DS = par_kws['n_most_DS']
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load results.
    rt_res = decutil.load_res(res_dir, **par_kws)[n_most_DS]

    # Create figure.
    ret = putil.get_gs_subplots(nrow=1, ncol=len(tasks),
                                subw=8, subh=6, create_axes=True)
    fig_scr, _, axs_scr = ret

    print('\nPlotting multi-recording results...')
    for itask, task in enumerate(tasks):
        if verbose:
            print(' ' + task)
        ax_scr = axs_scr[0, itask]

        dict_lScores = {}
        for irec, rec in enumerate(recs):

            # Check if results exist for rec-task combination.
            if (((rec, task) not in rt_res.keys()) or
                    (not len(rt_res[(rec, task)].keys()))):
                continue

            # Init data.
            res = rt_res[(rec, task)]
            cols = sns.color_palette('hls', len(res.keys()))
            lncls = []
            for v, col in zip(res.keys(), cols):
                vres = res[v]
                if vres is None:
                    continue
                Scores = vres['Scores']
                lncls.append(vres['nclasses'])

                # Unstack dataframe with results.
                lScores = pd.DataFrame(Scores.unstack(), columns=['score'])
                lScores['time'] = lScores.index.get_level_values(0)
                lScores['fold'] = lScores.index.get_level_values(1)
                lScores.index = np.arange(len(lScores.index))

                dict_lScores[(rec, v)] = lScores

        if not len(dict_lScores):
            ax_scr.axis('off')
            continue

        # Concatenate accuracy scores from every recording.
        all_lScores = pd.concat(dict_lScores)
        all_lScores['rec'] = all_lScores.index.get_level_values(0)
        all_lScores['rec'] = all_lScores['rec'].str.join(' ')  # format label
        all_lScores.index = np.arange(len(all_lScores.index))

        # Plot decoding results.
        nrec = len(all_lScores['rec'].unique())
        title = '{}, {} recordings'.format(task, nrec)
        ytitle = 1.0
        prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
                for stim in stims]

        # Plot time series.
        palette = sns.color_palette('muted')
        sns.tsplot(all_lScores, time='time', value='score', unit='fold',
                   condition='rec', color=palette, ax=ax_scr)

        # Add chance level line.
        # This currently plots a chance level line for every nvals,
        # combined across stimulus period!
        uni_ncls = np.unique(np.array(lncls).flatten())
        if len(uni_ncls) > 1 and verbose:
            print('Different number of classes found.')
        for nvals in uni_ncls:
            chance_lvl = 1.0 / nvals
            putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

        # Add stimulus periods.
        if prds is not None:
            putil.plot_periods(prds, ax=ax_scr)

        # Set axis limits.
        putil.set_limits(ax_scr, tlim, ylim_scr)

        # Format plot.
        putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle)

    # Save figure.
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3
    ffig = decutil.fig_fname(res_dir, 'all_scores', fformat, **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)
def plot_scores_weights(recs, stims, res_dir, par_kws):
    """
    Plot prediction scores and model weights for given recording and analysis.
    """

    # Init.
    putil.set_style('notebook', 'ticks')
    n_most_DS = par_kws['n_most_DS']
    tasks = par_kws['tasks']

    # Remove Passive if plotting Saccade or Correct.
    if par_kws['feat'] in ['saccade', 'correct']:
        tasks = tasks[~tasks.str.contains('Pas')]

    # Load results.
    rt_res = decutil.load_res(res_dir, **par_kws)[n_most_DS]

    # Create figures.
    # For prediction scores.
    fig_scr, _, axs_scr = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)
    # For unit weights (coefficients).
    fig_wgt, _, axs_wgt = putil.get_gs_subplots(nrow=len(recs),
                                                ncol=len(tasks),
                                                subw=8, subh=6,
                                                create_axes=True)

    for irec, rec in enumerate(recs):
        if verbose:
            print('\n' + rec)
        for itask, task in enumerate(tasks):
            if verbose:
                print(' ' + task)

            # Init figures.
            ax_scr = axs_scr[irec, itask]
            ax_wgt = axs_wgt[irec, itask]

            # Check if any result exists for rec-task combination.
            if (((rec, task) not in rt_res.keys()) or
                    (not len(rt_res[(rec, task)].keys()))):
                ax_scr.axis('off')
                ax_wgt.axis('off')
                continue

            # Init data.
            res = rt_res[(rec, task)]
            vals = [v for v in res.keys() if not util.is_null(res[v])]
            cols = sns.color_palette('hls', len(vals))
            lnunits, lntrs, lncls = [], [], []
            for v, col in zip(vals, cols):

                # Basic results.
                vres = res[v]
                Scores = vres['Scores']
                Coefs = vres['Coefs']
                Perm = vres['Perm']
                Psdo = vres['Psdo']

                # Decoding params.
                lnunits.append(vres['nunits'])
                lntrs.append(vres['ntrials'])
                lncls.append(vres['nclasses'])

                # Plot decoding accuracy.
                plot_scores(ax_scr, Scores, Perm, Psdo, col=col)

            # Add labels.
            uni_lnunits = np.unique(np.array(lnunits).flatten())
            if len(uni_lnunits) > 1 and verbose:
                print('Different number of units found.')
            nunits = uni_lnunits[0]
            title = '{} {}, {} units'.format(' '.join(rec), task, nunits)
            putil.set_labels(ax_scr, tlab, ylab_scr, title, ytitle=1.04)

            # Add chance level line.
            uni_ncls = np.unique(np.array(lncls).flatten())
            if len(uni_ncls) > 1 and verbose:
                print('Different number of classes found.')
            for nvals in uni_ncls:
                chance_lvl = 1.0 / nvals
                putil.add_chance_level(ax=ax_scr, ylevel=chance_lvl)

            # Plot stimulus periods.
            prds = [[stim] + list(constants.fixed_tr_prds.loc[stim])
                    for stim in stims]
            putil.plot_periods(prds, ax=ax_scr)

            # Plot unit weights over time.
            plot_weights(ax_wgt, Coefs, prds, tlim, tlab, title=title)

    # Match axes across decoding plots.
    # [putil.sync_axes(axs_scr[:, itask], sync_y=True)
    #  for itask in range(axs_scr.shape[1])]

    # Save plots.
    title = decutil.fig_title(res_dir, **par_kws)
    fs_title = 'large'
    w_pad, h_pad = 3, 3

    # Performance.
    ffig = decutil.fig_fname(res_dir, 'score', 'pdf', **par_kws)
    putil.save_fig(ffig, fig_scr, title, fs_title, w_pad=w_pad, h_pad=h_pad)

    # Weights.
    ffig = decutil.fig_fname(res_dir, 'weight', 'pdf', **par_kws)
    putil.save_fig(ffig, fig_wgt, title, fs_title, w_pad=w_pad, h_pad=h_pad)