def plot_mean_rates(mRates, aa_res_dir, tasks=None, task_lbls=None,
                    baseline=None, xlim=None, ylim=None, ci=68, ffig=None,
                    figsize=(6, 4)):
    """Plot mean rates across tasks as time series.

    Parameters
    ----------
    mRates : dict-like
        Rates per task. Each ``mRates[task]`` is unstacked into a long
        series; its first index level is used as 'time' and its second as
        'unit'.
    aa_res_dir : any
        Not used; kept for interface compatibility.
    tasks : iterable, optional
        Tasks to plot. Defaults to all keys of `mRates`.
    task_lbls : dict, optional
        Mapping from task name to the label shown in the legend.
    baseline : optional
        If given, drawn with ``putil.add_baseline``.
    xlim, ylim : optional
        Axis limits passed to ``putil.set_limits``.
    ci : int
        Confidence interval passed to ``sns.tsplot``.
    ffig : str, optional
        Output file name passed to ``putil.save_fig``.
    figsize : tuple
        Figure size.

    Returns
    -------
    fig, ax
    """
    if tasks is None:
        tasks = mRates.keys()

    # Collect the rates of all tasks into a single long-format DataFrame.
    lRates = []
    for task in tasks:
        lrates = pd.DataFrame(mRates[task].unstack(), columns=['rate'])
        lrates['task'] = task
        lRates.append(lrates)
    lRates = pd.concat(lRates)
    lRates['time'] = lRates.index.get_level_values(0)
    lRates['unit'] = lRates.index.get_level_values(1)
    if task_lbls is not None:
        # Assign the result back: an in-place replace on the attribute-accessed
        # column is chained assignment and does not update `lRates` under
        # pandas Copy-on-Write.
        lRates['task'] = lRates['task'].replace(task_lbls)

    # Plot as time series.
    fig = putil.figure(figsize=figsize)
    ax = putil.axes()
    sns.tsplot(lRates, time='time', value='rate', unit='unit',
               condition='task', ci=ci, ax=ax)

    # Add periods and baseline.
    putil.plot_periods(ax=ax)
    if baseline is not None:
        putil.add_baseline(baseline, ax=ax)

    # Format plot.
    sns.despine(ax=ax)
    putil.set_labels(ax, xlab='time since S1 onset', ylab='rate (sp/s)')
    putil.set_limits(ax, xlim, ylim)
    putil.hide_legend_title(ax)

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
def plot_RF_results(RF_res, stims, fdir, sup_title):
    """Plot receptive field (RF) coverage results.

    First saves a histogram of S1 coverage values. Then, for each of the
    variables 'mean_rate', 'max_rate' and 'mDSI', saves a grid of regression
    plots (rows: stimuli, columns: tasks) of that variable against RF
    coverage, annotated with Pearson's r and its p-value.

    Parameters
    ----------
    RF_res : pandas.DataFrame
        Results table. Its last index level holds the task. Columns are read
        as '<stim>_cover' and '<stim>_<vname>' (e.g. 'S1_cover').
    stims : sequence of str
        Stimulus names, used as column prefixes.
    fdir : str
        Output directory prefix.
    sup_title : str
        Figure super-title; also used to build the output file names.
    """
    fst = util.format_to_fname(sup_title)

    # Plot distribution of coverage values.
    putil.set_style('poster', 'white')
    fig = putil.figure()
    ax = putil.axes()
    sns.distplot(RF_res.S1_cover, bins=np.arange(0, 1.01, 0.1), kde=False,
                 rug=True, ax=ax)
    putil.set_limits(ax, [0, 1])
    ffig = fdir + fst + '_S1_coverage.png'
    putil.save_fig(ffig, fig, sup_title)

    # Plot RF coverage against each rate variable on regression plots, for
    # each stimulus and task.
    tasks = RF_res.index.get_level_values(-1).unique()
    for vname, ylim in [('mean_rate', [0, None]), ('max_rate', [0, None]),
                        ('mDSI', [0, 1])]:
        fig, gs, axes = putil.get_gs_subplots(nrow=len(stims),
                                              ncol=len(tasks), subw=4,
                                              subh=4, ax_kws_list=None,
                                              create_axes=True)
        colors = sns.color_palette('muted', len(tasks))
        for istim, stim in enumerate(stims):
            for itask, task in enumerate(tasks):
                ax = axes[istim, itask]
                scov, sval = [stim + '_' + name for name in ('cover', vname)]
                df = RF_res.xs(task, level=-1)

                # Regression plot. x/y/data are passed as keywords because
                # seaborn >= 0.12 makes them keyword-only.
                sns.regplot(x=scov, y=sval, data=df, color=colors[itask],
                            ax=ax)

                # Add unit labels.
                uids = df.index.droplevel(0)
                putil.add_unit_labels(ax, uids, df[scov], df[sval])

                # Add stats, computed on complete cases only (as plotted by
                # regplot). pearsonr needs at least two points.
                valid = df[[scov, sval]].dropna()
                if len(valid) >= 2:
                    r, p = sp.stats.pearsonr(valid[sval], valid[scov])
                    pstr = util.format_pvalue(p)
                    txt = 'r = {:.2f}, {}'.format(r, pstr)
                    ax.text(0.02, 0.98, txt, va='top', ha='left',
                            transform=ax.transAxes)

                # Set labels.
                title = '{} {}'.format(task, stim)
                xlab, ylab = [sn.replace('_', ' ') for sn in (scov, sval)]
                putil.set_labels(ax, xlab, ylab, title)

                # Set limits.
                putil.set_limits(ax, [0, 1], ylim)

        # Save plot.
        fname = '{}_cover_{}.png'.format(fst, vname)
        ffig = util.join([fdir, vname, fname])
        putil.save_fig(ffig, fig, sup_title)
def plot_auc_heatmap(aroc_mat, cmap='viridis', events=None, xlbl_freq=500,
                     ylbl_freq=10, xlab='time', ylab='unit index',
                     title='AROC over time', ffig=None, fig=None):
    """Plot ROC AUC of a list of units on a heatmap.

    Parameters
    ----------
    aroc_mat : pandas.DataFrame
        AROC values; rows are units and columns are time points. Columns
        must be numeric, because they are tested with ``% xlbl_freq``.
    cmap : str
        Colormap, with the color scale fixed to [0, 1].
    events : optional
        Events drawn with ``putil.plot_events``, if given.
    xlbl_freq : int
        Only column labels that are multiples of this value are shown.
    ylbl_freq : int
        Keep every `ylbl_freq`-th unit index label.
    xlab, ylab, title : str
        Axis labels and title.
    ffig : str, optional
        Output file name passed to ``putil.save_fig``.
    fig : optional
        Figure passed to ``putil.figure``.

    Returns
    -------
    fig, ax
    """
    fig = putil.figure(fig)

    # Plot heatmap, rows labelled by 1-based unit index.
    yticklabels = np.arange(len(aroc_mat.index)) + 1
    ax = pplot.heatmap(aroc_mat, vmin=0, vmax=1, cmap=cmap, xlab=xlab,
                       ylab=ylab, title=title, yticklabels=yticklabels)

    # Format labels: blank out x labels that are not multiples of xlbl_freq.
    xlbls = pd.Series(aroc_mat.columns.map(str))
    xlbls[aroc_mat.columns % xlbl_freq != 0] = ''
    putil.set_xtick_labels(ax, lbls=xlbls)
    putil.rot_xtick_labels(ax, rot=0, ha='center')
    putil.sparsify_tick_labels(fig, ax, 'y', istart=ylbl_freq - 1,
                               freq=ylbl_freq, reverse=True)
    putil.hide_tick_marks(ax)
    putil.hide_spines(ax)

    # Plot events.
    if events is not None:
        putil.plot_events(events, add_names=False, color='black', alpha=0.3,
                          ls='-', lw=1, ax=ax)

    # Save plot, passing the figure explicitly as the other plotters do.
    putil.save_fig(ffig, fig, dpi=300)

    return fig, ax
def plot_ROC_mean(d_faroc, t1=None, t2=None, ylim=None, colors=None,
                  ylab='AROC', ffig=None):
    """Plot mean ROC AUC time courses, one line per named result file.

    Parameters
    ----------
    d_faroc : dict
        Maps a name (used as the 'task' condition) to a file from which the
        'aroc' object is read with ``util.read_objects``.
    t1, t2 : optional
        x-axis limits.
    ylim : optional
        y-axis limits.
    colors : optional
        Passed as `color` to ``sns.tsplot``.
    ylab : str
        y-axis label.
    ffig : str, optional
        Output file name passed to ``putil.save_fig``.

    Returns
    -------
    fig, ax
    """
    # Import results.
    d_aroc = {}
    for name, faroc in d_faroc.items():
        aroc = util.read_objects(faroc, 'aroc')
        d_aroc[name] = aroc.unstack().T

    # Format results into long format: index levels become columns.
    laroc = pd.DataFrame(pd.concat(d_aroc), columns=['aroc'])
    for ilvl, col in enumerate(('task', 'time', 'unit')):
        laroc[col] = laroc.index.get_level_values(ilvl)
    laroc = laroc.reset_index(drop=True)

    # Init figure.
    fig = putil.figure(figsize=(6, 6))
    ax = sns.tsplot(laroc, time='time', value='aroc', unit='unit',
                    condition='task', color=colors)

    # Highlight stimulus periods.
    putil.plot_periods(ax=ax)

    # Thicken all lines drawn so far (the mean curves).
    for line in ax.lines:
        line.set_linewidth(3)

    # Add chance level line (the last line added to the axes).
    putil.add_chance_level(ax=ax, alpha=0.8, color='k')
    ax.lines[-1].set_linewidth(1.5)

    # Format plot.
    xlab = 'Time since S1 onset (ms)'
    putil.set_labels(ax, xlab, ylab)
    putil.set_limits(ax, [t1, t2], ylim)
    putil.set_spines(ax, bottom=True, left=True, top=False, right=False)
    putil.set_legend(ax, loc=0)

    # Save plot.
    putil.save_fig(ffig, fig, ytitle=1.05, w_pad=15)

    return fig, ax
def plot_slope_diffs_btw_groups(fit_res, prd, res_dir, groups=None,
                                figsize=None):
    """Plot pair-wise differences in mean slope between groups.

    For every pair (grp1, grp2) from ``combinations(groups, 2)``, the
    difference of mean slopes (grp1 - grp2) is stored at [grp1, grp2]. The
    slopes are tested with ``stats.mann_whithney_u_test``, and the result is
    drawn as an annotated heatmap with a significance star string.

    Parameters
    ----------
    fit_res : pandas.DataFrame
        Needs 'group' and 'slope' columns.
    prd : str
        Period name, used in the title and file name.
    res_dir : str
        Output directory prefix.
    groups : sequence, optional
        Groups to compare, in order. Defaults to the unique values of
        ``fit_res['group']``.
    figsize : optional
        Figure size.

    Returns
    -------
    pandas.DataFrame
        Group x group matrix of mean slope differences. Only the cells
        visited by ``combinations`` are filled; the others are NaN.
    """
    if groups is None:
        groups = fit_res['group'].unique()

    pw_diff_v = pd.DataFrame(np.nan, index=groups, columns=groups)
    # Object dtype, so annotation strings can be stored without the
    # incompatible-dtype upcast of a float frame (deprecated in pandas 2.1).
    pw_diff_a = pw_diff_v.astype(object)

    for grp1, grp2 in combinations(groups, 2):
        # Slopes in each group.
        slp_1 = fit_res.slope[fit_res.group == grp1]
        slp_2 = fit_res.slope[fit_res.group == grp2]

        # Mean difference value.
        diff = slp_1.mean() - slp_2.mean()
        pw_diff_v.loc[grp1, grp2] = diff

        # Test for statistical difference and build the annotation.
        _, pval = stats.mann_whithney_u_test(slp_1, slp_2)
        pw_diff_a.loc[grp1, grp2] = '{:.2f}{}'.format(
            diff, util.star_pvalue(pval))

    # Plot and save figure of mean pair-wise difference.
    fig = putil.figure(figsize=figsize)
    sns.heatmap(pw_diff_v, annot=pw_diff_a, fmt='', linewidths=0.5,
                cbar=False)
    title = 'Mean difference in {} slopes (sp/s / s)'.format(prd)
    putil.set_labels(title=title)
    ffig = res_dir + '{}_anticipatory_slope_pairwise_diff.png'.format(prd)
    putil.save_fig(ffig, fig)

    return pw_diff_v
def plot_qm(u, bs_stats, stab_prd_res, prd_inc, tr_inc, spk_inc,
            add_lbls=False, ftempl=None, fig=None, sps=None):
    """Plot quality-metric figures of a unit.

    Layout: an info header, then a 3 x 2 grid with included / excluded
    waveforms, waveform amplitude and duration over recording time, waveform
    duration vs. amplitude, and firing rate over recording time with trial
    markers and excluded periods.

    Parameters
    ----------
    u : unit object
        `Waveforms`, `SpikeParams`, `QualityMetrics`, `SessParams` and
        `TrData` are read from it.
    bs_stats : dict-like
        'tmean' (bin times) and 'rate' (bin rates) are read.
    stab_prd_res : dict-like
        'tstart' and 'tstop' of the stable period are read.
    prd_inc : sequence
        Per-bin inclusion flags, aligned with ``bs_stats['tmean']``.
    tr_inc : numpy bool array (assumed)
        Per-trial inclusion flags.
    spk_inc : numpy bool array (assumed)
        Per-spike inclusion flags.
    add_lbls : bool
        Whether to add axis labels.
    ftempl : str, optional
        File name template with one ``{}`` slot for ``u.name_to_fname()``.
        The figure is saved only if this is given.
    fig : optional
        Figure passed to ``putil.figure``.
    sps : optional
        Subplot spec to draw into; defaults to a fresh 1 x 1 gridspec.

    Returns
    -------
    tuple
        ([ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate)
    """
    # Init values.
    waveforms = np.array(u.Waveforms)
    wavetime = u.Waveforms.columns * us
    spk_times = np.array(u.SpikeParams['time'], dtype=float)
    base_rate = u.QualityMetrics['baseline']

    # Minimum and maximum gain.
    gmin = u.SessParams['minV']
    gmax = u.SessParams['maxV']

    # %% Init plots.

    # Disable inline plotting to prevent memory leak.
    putil.inline_off()

    # Init figure and gridspec.
    fig = putil.figure(fig)
    if sps is None:
        sps = putil.gridspec(1, 1)[0]
    ogsp = putil.embed_gsp(sps, 2, 1, height_ratios=[0.02, 1])

    info_sps, qm_sps = ogsp[0], ogsp[1]

    # Info header.
    info_ax = fig.add_subplot(info_sps)
    putil.hide_axes(info_ax)
    title = putil.get_unit_info_title(u)
    putil.set_labels(ax=info_ax, title=title, ytitle=0.80)

    # Create axes.
    gsp = putil.embed_gsp(qm_sps, 3, 2, wspace=0.3, hspace=0.4)
    ax_wf_inc, ax_wf_exc = [fig.add_subplot(gsp[0, i]) for i in (0, 1)]
    ax_wf_amp, ax_wf_dur = [fig.add_subplot(gsp[1, i]) for i in (0, 1)]
    ax_amp_dur, ax_rate = [fig.add_subplot(gsp[2, i]) for i in (0, 1)]

    # Trial markers: every 10th trial start, labelling every second marker.
    trial_starts, trial_stops = u.TrData.TrialStart, u.TrData.TrialStop
    tr_markers = pd.DataFrame({'time': trial_starts[9::10]})
    tr_markers['label'] = [str(itr + 1) if i % 2 else ''
                           for i, itr in enumerate(tr_markers.index)]

    # Common variables, limits and labels.
    WF_T_START = test_sorting.WF_T_START
    spk_t = u.SessParams.sampl_prd * (np.arange(waveforms.shape[1]) -
                                      WF_T_START)
    ses_t_lim = test_sorting.get_start_stop_times(spk_times, trial_starts,
                                                  trial_stops)
    ss, sa = 1.0, 0.8  # marker size and alpha on scatter plot

    # Color spikes by their occurrence over session time.
    my_cmap = putil.get_cmap('jet')
    spk_cols = np.tile(np.array([.25, .25, .25, .25]), (len(spk_times), 1))
    if np.any(spk_inc):  # check if there is any spike included
        spk_t_inc = np.array(spk_times[spk_inc])
        tmin, tmax = float(spk_times.min()), float(spk_times.max())
        tspan = tmax - tmin
        # Guard against a zero time span (e.g. a single spike): 0 / 0 would
        # give NaN, which the colormap renders as a fully transparent color.
        if tspan > 0:
            t_norm = (spk_t_inc - tmin) / tspan
        else:
            t_norm = np.zeros_like(spk_t_inc)
        spk_cols[spk_inc, :] = my_cmap(t_norm)

    # Put excluded spikes to the front, and randomise order of included
    # spikes so later spikes don't systematically cover earlier ones.
    spk_order = np.hstack((np.where(np.invert(spk_inc))[0],
                           np.random.permutation(np.where(spk_inc)[0])))

    # Common labels for plots
    ses_t_lab = 'Recording time (s)'

    # %% Waveform shape analysis.

    # Plot included and excluded waveforms on different axes.
    # Color included by occurrence in session time to help detect drifts.
    s_waveforms, s_spk_cols = waveforms[spk_order, :], spk_cols[spk_order]
    wf_t_lim, glim = [min(spk_t), max(spk_t)], [gmin, gmax]
    # Raw strings: '\m' is not a valid escape sequence (SyntaxWarning in
    # Python 3.12+); the runtime value is unchanged.
    wf_t_lab, volt_lab = r'WF time ($\mu$s)', 'Voltage'
    for st in ('Included', 'Excluded'):
        ax = ax_wf_inc if st == 'Included' else ax_wf_exc
        spk_idx = spk_inc if st == 'Included' else np.invert(spk_inc)
        tr_idx = tr_inc if st == 'Included' else np.invert(tr_inc)

        nspks, ntrs = sum(spk_idx), sum(tr_idx)
        title = '{} WFs, {} spikes, {} trials'.format(st, nspks, ntrs)

        # Select waveforms and colors, in the shuffled plotting order.
        rand_spk_idx = spk_idx[spk_order]
        wfs = s_waveforms[rand_spk_idx, :]
        cols = s_spk_cols[rand_spk_idx]

        # Plot waveforms.
        xlab, ylab = (wf_t_lab, volt_lab) if add_lbls else (None, None)
        pwaveform.plot_wfs(wfs, spk_t, cols=cols, lw=0.1, alpha=0.05,
                           xlim=wf_t_lim, ylim=glim, title=title,
                           xlab=xlab, ylab=ylab, ax=ax)

    # %% Waveform summary metrics.

    # Init data.
    wf_amp_all = u.SpikeParams['amplitude']
    wf_amp_inc = wf_amp_all[spk_inc]
    wf_dur_all = u.SpikeParams['duration']
    wf_dur_inc = wf_dur_all[spk_inc]

    # Set common limits and labels.
    dur_lim = [0, wavetime[-2] - wavetime[WF_T_START]]  # same across units
    glim = max(wf_amp_all.max(), gmax - gmin)
    amp_lim = [0, glim]

    amp_lab = 'Amplitude'
    dur_lab = r'Duration ($\mu$s)'

    # Waveform amplitude across session time.
    m_amp, sd_amp = wf_amp_inc.mean(), wf_amp_inc.std()
    title = r'WF amplitude: {:.1f} $\pm$ {:.1f}'.format(m_amp, sd_amp)
    xlab, ylab = (ses_t_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times, wf_amp_all, spk_inc, c='m', bc='grey', s=ss,
                  xlab=xlab, ylab=ylab, xlim=ses_t_lim, ylim=amp_lim,
                  edgecolors='', alpha=sa, id_line=False, title=title,
                  ax=ax_wf_amp)

    # Waveform duration across session time.
    mdur, sdur = wf_dur_inc.mean(), wf_dur_inc.std()
    title = r'WF duration: {:.1f} $\pm$ {:.1f} $\mu$s'.format(mdur, sdur)
    xlab, ylab = (ses_t_lab, dur_lab) if add_lbls else (None, None)
    pplot.scatter(spk_times, wf_dur_all, spk_inc, c='c', bc='grey', s=ss,
                  xlab=xlab, ylab=ylab, xlim=ses_t_lim, ylim=dur_lim,
                  edgecolors='', alpha=sa, id_line=False, title=title,
                  ax=ax_wf_dur)

    # Waveform duration against amplitude.
    title = 'WF duration - amplitude'
    xlab, ylab = (dur_lab, amp_lab) if add_lbls else (None, None)
    pplot.scatter(wf_dur_all[spk_order], wf_amp_all[spk_order],
                  c=spk_cols[spk_order], s=ss, xlab=xlab, ylab=ylab,
                  xlim=dur_lim, ylim=amp_lim, edgecolors='', alpha=sa,
                  id_line=False, title=title, ax=ax_amp_dur)

    # %% Firing rate.

    tmean = np.array(bs_stats['tmean'])
    rmean = util.remove_dim_from_series(bs_stats['rate'])
    prd_tstart, prd_tstop = stab_prd_res['tstart'], stab_prd_res['tstop']

    def plot_periods(v, color, ax):
        """Plot `v` over `tmean`, greying out excluded bins / segments."""
        # Line segments: colored only if both end bins are included.
        for i in range(len(prd_inc[:-1])):
            col = color if prd_inc[i] and prd_inc[i + 1] else 'grey'
            x, y = [(tmean[i], tmean[i + 1]), (v[i], v[i + 1])]
            ax.plot(x, y, color=col)
        # Line points.
        for i in range(len(prd_inc)):
            col = color if prd_inc[i] else 'grey'
            x, y = [tmean[i], v[i]]
            ax.plot(x, y, color=col, marker='o', markersize=3,
                    markeredgecolor=col)

    # Firing rate over session time.
    title = 'Baseline rate: {:.1f} spike/s'.format(float(base_rate))
    xlab, ylab = (ses_t_lab, putil.FR_lbl) if add_lbls else (None, None)
    ylim = [0, 1.25 * np.max(rmean)]
    plot_periods(rmean, 'b', ax_rate)
    pplot.lines([], [], c='b', xlim=ses_t_lim, ylim=ylim, title=title,
                xlab=xlab, ylab=ylab, ax=ax_rate)

    # Trial markers.
    putil.plot_events(tr_markers, lw=0.5, ls='--', alpha=0.35, y_lbl=0.92,
                      ax=ax_rate)

    # Excluded periods: session start / end outside the stable period.
    excl_prds = []
    tstart, tstop = ses_t_lim
    if tstart != prd_tstart:
        excl_prds.append(('beg', tstart, prd_tstart))
    if tstop != prd_tstop:
        excl_prds.append(('end', prd_tstop, tstop))
    putil.plot_periods(excl_prds, ymax=0.92, ax=ax_rate)

    # %% Post-formatting.

    # Maximize number of ticks on recording time axes to prevent covering.
    for ax in (ax_wf_amp, ax_wf_dur, ax_rate):
        putil.set_max_n_ticks(ax, 6, 'x')

    # %% Save figure.
    if ftempl is not None:
        fname = ftempl.format(u.name_to_fname())
        # NOTE(review): `title` at this point is the firing-rate panel title
        # ('Baseline rate: ...'), not the unit info title set for the header.
        # Confirm this is the intended figure title.
        putil.save_fig(fname, fig, title, rect_height=0.92)

    # Re-enable inline plotting, matching inline_off() above.
    putil.inline_on()

    return [ax_wf_inc, ax_wf_exc], ax_wf_amp, ax_wf_dur, ax_amp_dur, ax_rate
def plot_group_violin(res, x, y, groups=None, npval=None, pth=0.01,
                      color='grey', ylim=None, ylab=None, ffig=None):
    """Plot group-wise results on violin plots with per-group annotations.

    Each group gets a violin with a swarm of individual points on top,
    colored by significance. Above each group the annotation shows the mean,
    the p-value of a one-sample t-test against 0, and the counts and
    percentages of significantly positive (+), non-significant (=) and
    significantly negative (-) entries.

    Parameters
    ----------
    res : pandas.DataFrame
        Results table; not modified (a copy is used).
    x : str
        Column holding the group of each row.
    y : str
        Column holding the values.
    groups : sequence, optional
        Groups to plot, in order. Defaults to the unique values of ``res[x]``.
    npval : str, optional
        Column of p-values. If None, every row counts as significant.
    pth : float
        Significance threshold applied to ``res[npval]``.
    color : color
        Swarm point color.
    ylim : optional
        y-axis limits.
    ylab : str, optional
        y-axis label.
    ffig : str, optional
        Output file name passed to ``putil.save_fig``.

    Returns
    -------
    fig, ax
    """
    # Work on a copy so the helper columns added below don't leak into the
    # caller's DataFrame.
    res = res.copy()
    if groups is None:
        groups = res[x].unique()

    # Test difference from zero in each group.
    pvals = {group: sp.stats.ttest_1samp(gres[y], 0).pvalue
             for group, gres in res.groupby(x)}

    # Binarize significance test.
    res['is_sign'] = res[npval] < pth if npval is not None else True
    res['direction'] = np.sign(res[y])

    # Set up figure and plot data.
    fig = putil.figure()
    ax = putil.axes()
    putil.add_baseline(ax=ax)
    sns.violinplot(x=x, y=y, data=res, inner=None, order=groups, ax=ax)
    sns.swarmplot(x=x, y=y, hue='is_sign', data=res, color=color,
                  order=groups, hue_order=[True, False], ax=ax)
    putil.set_labels(ax, xlab='', ylab=ylab)
    putil.set_limits(ax, ylim=ylim)
    putil.hide_legend(ax)

    # Add annotations above the top of the axes.
    ylvl = ax.get_ylim()[1]
    for i, group in enumerate(groups):
        gres = res.loc[res[x] == group]
        ntot = len(gres)
        if ntot == 0:
            # Nothing to annotate (and no t-test result) for an empty group.
            continue

        # Mean.
        mean_str = 'Mean:\n' if i == 0 else '\n'
        mean_str += '{:.2f}'.format(gres[y].mean())

        # Non-zero test of distribution.
        str_pval = util.format_pvalue(pvals[group])
        mean_str += '\n({})'.format(str_pval)

        # Stats on difference from baseline.
        nnonsign = int((~gres.is_sign).sum())
        npos, nneg = [int((gres.is_sign & (gres.direction == d)).sum())
                      for d in (1, -1)]
        nsign_str = ''
        for symb, n in [('+', npos), ('=', nnonsign), ('-', nneg)]:
            prc = str(int(round(100 * n / ntot)))
            nsign_str += '{} {:>3} / {} ({:>2}%)\n'.format(symb, int(n),
                                                           ntot, prc)

        lbl = '{}\n\n{}'.format(mean_str, nsign_str)
        ax.text(i, ylvl, lbl, fontsize='smaller', va='bottom', ha='center')

    # Save plot.
    putil.save_fig(ffig, fig)

    return fig, ax
def cat_mean(df, x, y, add_stats=True, fstats=None, bar_ylvl=None, ci=68,
             add_mean=True, mean_ylvl=None, ylbl=None, fig=None, ax=None,
             ffig=None, **kwargs):
    """Plot the means of a two-category dataset as bars.

    Parameters
    ----------
    df : pandas.DataFrame
        Data in long format.
    x : str
        Column holding the category of each row. Exactly two categories are
        expected.
    y : str
        Column holding the values.
    add_stats : bool
        Whether to draw a significance bar labelled with the p-value of
        `fstats`.
    fstats : callable, optional
        Two-sample test returning (statistic, pvalue). Defaults to
        ``stats.mann_whithney_u_test``.
    bar_ylvl : float, optional
        Height of the significance bar. Defaults to 1.1 times the larger of
        mean + SEM across the two categories.
    ci : passed to ``sns.barplot``.
    add_mean : bool
        Whether to write each category's mean next to its bar.
    mean_ylvl : float, optional
        Height of the mean labels. Defaults to just above each mean.
    ylbl : str, optional
        y-axis label.
    fig, ax : optional
        Figure / axes to plot into. If only `ax` is given, its figure is used.
    ffig : str, optional
        Output file name passed to ``putil.save_fig``.
    **kwargs
        Forwarded to ``sns.barplot`` (e.g. `palette`, `errwidth`).

    Returns
    -------
    ax
    """
    # Init.
    if fig is None and ax is None:
        fig = putil.figure(figsize=(3, 3))
    if ax is None:
        ax = putil.axes()
    if fig is None:
        fig = ax.get_figure()
    if fstats is None:
        fstats = stats.mann_whithney_u_test

    # Plot means as bars.
    sns.barplot(x=x, y=y, data=df, ci=ci, ax=ax, **kwargs)

    # Get the plotted vectors, in x-axis order. Tick labels are text, so the
    # categories are compared as strings.
    ngrps = [t.get_text() for t in ax.get_xticklabels()]
    cats = df[x].astype(str)
    v1, v2 = [df.loc[cats == ngrp, y] for ngrp in ngrps]

    # Add significance bar.
    if add_stats:
        _, pval = fstats(v1, v2)
        pval_str = util.format_pvalue(pval)
        if bar_ylvl is None:
            bar_ylvl = 1.1 * max(v1.mean() + v1.sem(), v2.mean() + v2.sem())
        ax.plot([0.1, 0.9], [bar_ylvl, bar_ylvl], color='grey')
        ax.text(0.5, 1.01 * bar_ylvl, pval_str, fontsize='medium',
                fontstyle='italic', va='bottom', ha='center')

    # Add mean values.
    if add_mean:
        for vec, xpos in [(v1, 0.2), (v2, 1.2)]:
            mstr = '{:.2f}'.format(vec.mean())
            ypos = 1.005 * vec.mean() if mean_ylvl is None else mean_ylvl
            ax.text(xpos, ypos, mstr, fontstyle='italic',
                    fontsize='smaller', va='bottom', ha='center')

    # Format plot.
    sns.despine(ax=ax)
    putil.hide_legend_title(ax)
    putil.set_labels(ax, '', ylbl)
    putil.sparsify_tick_labels(fig, ax, 'y', freq=2)

    # Save plot.
    putil.save_fig(ffig, fig)

    return ax
def plot_up_down_raster(Spikes, task, rec, itrs, spk_width=2.0,
                        spk_height=0.8, spk_color='black'):
    """Plot a spike raster for up-down dynamics analysis.

    Selected trials are stacked vertically. Within a trial, each unit gets
    one row, and consecutive trials are separated by a shaded gap of half
    the number of units. Each spike is drawn as a rectangle centred on its
    time and row. The figure is saved as
    results/UpDown/UpDown_dynamics_<rec>_<task>.pdf.

    Parameters
    ----------
    Spikes : pandas.DataFrame
        Spike trains indexed by unit (rows) and trial (columns). Each cell is
        converted with ``.rescale('ms')``; presumably a quantities-based
        spike train (verify against caller).
    task, rec : str
        Used in the title and the output file name.
    itrs : array-like
        Positional indices of the trials (columns of `Spikes`) to plot.
    spk_width : float
        Width of each spike marker, in ms (x-axis units).
    spk_height : float
        Height of each spike marker, in unit rows.
    spk_color : color
        Face and edge color of the spike markers.

    Returns
    -------
    fig, ax
    """
    # Set params for plotting.
    uids, trs = Spikes.index, Spikes.columns
    plot_trs = trs[itrs]
    ntrs = len(plot_trs)
    nunits = len(uids)
    tr_gap = nunits / 2
    tr_step = tr_gap + nunits  # vertical offset between consecutive trials

    # Init figure.
    putil.set_style('notebook', 'ticks')
    fig = putil.figure(figsize=(10, ntrs))
    ax = fig.add_subplot(111)

    # One rectangle per spike, collected across all trials and units and
    # added as a single collection.
    patches = []
    for itr, tr in enumerate(plot_trs):
        for iu, uid in enumerate(uids):
            ylvl = tr_step * itr + iu + 1
            spk_tr = Spikes.loc[uid, tr]
            x = np.array(spk_tr.rescale('ms'))
            patches.extend(
                Rectangle((xi - spk_width / 2, ylvl - spk_height / 2),
                          spk_width, spk_height)
                for xi in x)
    ax.add_collection(PatchCollection(patches, facecolor=spk_color,
                                      edgecolor=spk_color))

    # Add stimulus lines.
    for stim in constants.stim_dur.index:
        t_start, t_stop = constants.fixed_tr_prds.loc[stim]
        events = pd.DataFrame([(t_start, 't_start'), (t_stop, 't_stop')],
                              index=['start', 'stop'],
                              columns=['time', 'label'])
        putil.plot_events(events, add_names=False, color='grey', alpha=0.5,
                          ls='-', lw=0.5, ax=ax)

    # Add inter-trial shading.
    for itr in range(ntrs + 1):
        ymin = itr * tr_step - tr_gap + 0.5
        ax.axhspan(ymin, ymin + tr_gap, alpha=.05, color='grey')

    # Set tick labels: 1-based trial numbers at the middle of each trial.
    pos = np.arange(ntrs) * tr_step + nunits / 2
    lbls = plot_trs + 1
    putil.set_ytick_labels(ax, pos, lbls)
    putil.hide_tick_marks(ax, show_x_tick_mrks=True)

    # Format plot.
    xlim = constants.fixed_tr_prds.loc['whole trial']
    ylim = [-tr_gap / 2, ntrs * tr_step - tr_gap / 2]
    xlab = 'Time since S1 onset (ms)'
    ylab = 'Trial number'
    title = '{} {}'.format(rec, task)
    putil.format_plot(ax, xlim, ylim, xlab, ylab, title)
    putil.set_spines(ax, True, False, False, False)

    # Save figure.
    fname = 'UpDown_dynamics_{}_{}.pdf'.format(rec, task)
    ffig = util.join(['results', 'UpDown', fname])
    putil.save_fig(ffig, fig, dpi=600)

    return fig, ax