def site_cell_summary(id): """ plots a grid of subplots, each one showing the real dprime, histogram of significant bins and fitted exponential decay to the significant bins. Both the dprime and significant bins are the cell grand mean across probes and context pairs :param id: str. site id :return: fig, axes """ site_cells = set([ cell for cell in batch_dprimes['SC']['dprime'].keys() if cell[:7] == id ]) fig, axes = fplt.subplots_sqr(len(site_cells), sharex=True, sharey=True) for ax, cell in zip(axes, site_cells): grand_mean, _ = cDP.flip_dprimes(batch_dprimes['SC']['dprime'][cell], flip='max') line = np.mean(grand_mean, axis=(0, 1)) hist = np.mean(batch_dprimes['SC']['shuffled_significance'][cell], axis=(0, 1)) ax.plot(times[:len(line)], line, color='black') ax.bar(times[:len(hist)], hist, width=bar_width, align='center', color='C0', edgecolor='white') _ = fplt.exp_decay(times[:len(hist)], hist, ax=ax, linestyle='--', color='gray') # ax.set_title(cell, fontsize=10) ax.legend(loc='upper right', fontsize='small', markerscale=3, frameon=False) ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) return fig, axes
def fit_example_plot(id, source): """ Plots dprime (top) and significant bins (bottom) with their fitted exponential decays. Both dprime and significant bins are the grand mean across all probes and context pairs for a given cell or site. :param id: str. cell or site id :param source: str. 'SC', 'dPCA', or 'LDA' :return: fig, axes. """ # flips signs of dprimes and montecarlos as neede dprimes, shuffleds = cDP.flip_dprimes( batch_dprimes[source]['dprime'][id], batch_dprimes[source]['shuffled_dprime'][id], flip='max') signif_bars = batch_dprimes[source]['shuffled_significance'][id] mean_dprime = np.mean(dprimes[:, :, :], axis=(0, 1)) mean_signif = np.mean(signif_bars[:, :, :], axis=(0, 1)) t = times[:dprimes.shape[-1]] fig, axes = plt.subplots(2, 1, sharex='all', sharey='all') # plots dprime plus fit axes[0].plot(t, mean_dprime, color='black') axes[0].axhline(0, color='gray', linestyle='--') _ = fplt.exp_decay(t, mean_dprime, ax=axes[0], linestyle='--', color='black') # plots confifence bins plut fit axes[1].bar( t, mean_signif, width=bar_width, align='center', edgecolor='white', ) _ = fplt.exp_decay(times, mean_signif, ax=axes[1], linestyle='--', color='black') axes[0].legend(loc='upper right', fontsize=ax_val_size, markerscale=3, frameon=False) axes[1].legend(loc='upper right', fontsize=ax_val_size, markerscale=3, frameon=False) # formats axis, legend and so on. axes[0].set_ylabel(f'dprime', fontsize=ax_lab_size) axes[0].tick_params(labelsize=ax_val_size) axes[1].set_ylabel(f'mean significance', fontsize=ax_lab_size) axes[1].tick_params(labelsize=ax_val_size) axes[1].set_xlabel('time (ms)', fontsize=ax_lab_size) axes[1].tick_params(labelsize=ax_val_size) for ax in axes: ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) return fig, axes
def analysis_steps_plot(id, probe, source): site = id[:7] if source == 'SC' else id # loads the raw data recs = load(site, rasterfs=meta['raster_fs'], recache=False) sig = recs['trip0']['resp'] # calculates response realiability and select only good cells to improve analysis r_vals, goodcells = signal_reliability(sig, r'\ASTIM_*', threshold=meta['reliability']) goodcells = goodcells.tolist() # get the full data raster Context x Probe x Rep x Neuron x Time raster = src.data.rasters.raster_from_sig( sig, probe, channels=goodcells, contexts=meta['transitions'], smooth_window=meta['smoothing_window'], raster_fs=meta['raster_fs'], zscore=meta['zscore'], part='probe') # trialR shape: Trial x Cell x Context x Probe x Time; R shape: Cell x Context x Probe x Time trialR, R, _ = cdPCA.format_raster(raster) trialR, R = trialR.squeeze(axis=3), R.squeeze(axis=2) # squeezes out probe if source == 'dPCA': projection, _ = cdPCA.fit_transform(R, trialR) elif source == 'LDA': projection, _ = cLDA.fit_transform_over_time(trialR) projection = projection.squeeze(axis=1) if meta['zscore'] is False: trialR = trialR * meta['raster_fs'] if source == 'dPCA': projection = projection * meta['raster_fs'] # flips signs of dprimes and montecarlos as needed dprimes, shuffleds = cDP.flip_dprimes( batch_dprimes[source]['dprime'][id], batch_dprimes[source]['shuffled_dprime'][id], flip='max') if source in ['dPCA', 'LDA']: _, simulations = cDP.flip_dprimes( batch_dprimes[source]['dprime'][id], batch_dprimes[source]['simulated_dprime'][id], flip='max') t = times[:trialR.shape[-1]] nrows = 2 if source == 'SC' else 3 fig, axes = plt.subplots(nrows, 6, sharex='all', sharey='row') # PSTH for tt, trans in enumerate(itt.combinations(meta['transitions'], 2)): t0_idx = meta['transitions'].index(trans[0]) t1_idx = meta['transitions'].index(trans[1]) if source == 'SC': cell_idx = goodcells.index(id) axes[0, tt].plot(t, trialR[:, cell_idx, t0_idx, :].mean(axis=0), color=trans_color_map[trans[0]], linewidth=3) axes[0, tt].plot(t, trialR[:, cell_idx, t1_idx, :].mean(axis=0), color=trans_color_map[trans[1]], linewidth=3) else: axes[0, tt].plot(t, projection[:, t0_idx, :].mean(axis=0), color=trans_color_map[trans[0]], linewidth=3) axes[0, tt].plot(t, projection[:, t1_idx, :].mean(axis=0), color=trans_color_map[trans[1]], linewidth=3) # Raster, dprime, CI bottom, top = axes[0, 0].get_ylim() half = ((top - bottom) / 2) + bottom for tt, trans in enumerate(itt.combinations(meta['transitions'], 2)): prb_idx = all_probes.index(probe) pair_idx = tt if source == 'SC': # raster cell_idx = goodcells.index(id) t0_idx = meta['transitions'].index(trans[0]) t1_idx = meta['transitions'].index(trans[1]) _ = fplt._raster(t, trialR[:, cell_idx, t0_idx, :], y_offset=0, y_range=(bottom, half), ax=axes[0, tt], scatter_kws={ 'color': trans_color_map[trans[0]], 'alpha': 0.4, 's': 10 }) _ = fplt._raster(t, trialR[:, cell_idx, t1_idx, :], y_offset=0, y_range=(half, top), ax=axes[0, tt], scatter_kws={ 'color': trans_color_map[trans[1]], 'alpha': 0.4, 's': 10 }) # plots the real dprime and the shuffled dprime ci axes[1, tt].plot(t, dprimes[prb_idx, pair_idx, :], color='black') _ = fplt._cint(t, shuffleds[:, prb_idx, pair_idx, :], confidence=0.95, ax=axes[1, tt], fillkwargs={ 'color': 'black', 'alpha': 0.5 }) if source in ['dPCA', 'LDA']: # plots the real dprime and simulated dprime ci axes[2, tt].plot(t, dprimes[prb_idx, pair_idx, :], color='black') _ = fplt._cint(t, simulations[:, prb_idx, pair_idx, :], confidence=0.95, ax=axes[2, tt], fillkwargs={ 'color': 'black', 'alpha': 0.5 }) # significance bars ax1_bottom = axes[1, 0].get_ylim()[0] if source == 'dPCA': ax2_bottom = axes[2, 0].get_ylim()[0] for tt, trans in enumerate(itt.combinations(meta['transitions'], 2)): prb_idx = all_probes.index(probe) pair_idx = tt # histogram of context discrimination axes[1, tt].bar( t, batch_dprimes[source]['shuffled_significance'][id][prb_idx, pair_idx, :], width=bar_width, align='center', edgecolor='white', bottom=ax1_bottom) if source in ['dPCA', 'LDA']: # histogram of population effects axes[2, tt].bar(t, batch_dprimes[source]['simulated_significance'][id] [prb_idx, pair_idx, :], width=bar_width, align='center', edgecolor='white', bottom=ax2_bottom) # formats legend if tt == 0: axes[0, tt].set_ylabel(f'dPC', fontsize=ax_lab_size) axes[0, tt].tick_params(labelsize=ax_val_size) axes[1, tt].set_ylabel(f'dprime', fontsize=ax_lab_size) axes[1, tt].tick_params(labelsize=ax_val_size) if source in ['dPCA', 'LDA']: axes[2, tt].set_ylabel(f'dprime', fontsize=ax_lab_size) axes[2, tt].tick_params(labelsize=ax_val_size) axes[-1, tt].set_xlabel('time (ms)', fontsize=ax_lab_size) axes[-1, tt].tick_params(labelsize=ax_val_size) axes[0, tt].set_title(f'{trans[0]}_{trans[1]}', fontsize=sub_title_size) for ax in np.ravel(axes): ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) return fig, axes
def category_summary_plot(id, source): """ Plots calculated dprime, confidense interval of shuffled dprime, and histogram of significant bins, for all contexts and probes. Subplots are a grid of al combinations of probe (rows) and context pairs (columns), plus the means of each category, and the grand mean :param id: str. cell or site id :param source: str. 'SC', 'dPCA', or 'LDA' :return: fig, axes. """ # flips signs of dprimes and montecarlos as neede dprimes, shuffleds = cDP.flip_dprimes( batch_dprimes[source]['dprime'][id], batch_dprimes[source]['shuffled_dprime'][id], flip='max') signif_bars = batch_dprimes[source]['shuffled_significance'][id] t = times[:dprimes.shape[-1]] fig, axes = plt.subplots(5, 7, sharex='all', sharey='all') # dprime and confidence interval for each probe-transition combinations for (pp, probe), (tt, trans) in itt.product( enumerate(all_probes), enumerate(itt.combinations(meta['transitions'], 2))): prb_idx = all_probes.index(probe) # plots the real dprime and the shuffled dprime axes[pp, tt].plot(t, dprimes[prb_idx, tt, :], color='black') _ = fplt._cint(t, shuffleds[:, prb_idx, tt, :], confidence=0.95, ax=axes[pp, tt], fillkwargs={ 'color': 'black', 'alpha': 0.5 }) # dprime and ci for the mean across context pairs for pp, probe in enumerate(all_probes): prb_idx = all_probes.index(probe) axes[pp, -1].plot(t, np.mean(dprimes[prb_idx, :, :], axis=0), color='black') axes[pp, -1].axhline(0, color='gray', linestyle='--') # dprime and ci for the mean across probes for tt, trans in enumerate(itt.combinations(meta['transitions'], 2)): axes[-1, tt].plot(t, np.mean(dprimes[:, tt, :], axis=0), color='black') axes[-1, tt].axhline(0, color='gray', linestyle='--') # significance bars for each probe-transition combinations bar_bottom = axes[0, 0].get_ylim()[0] for (pp, probe), (tt, trans) in itt.product( enumerate(all_probes), enumerate(itt.combinations(meta['transitions'], 2))): prb_idx = all_probes.index(probe) axes[pp, tt].bar(t, signif_bars[prb_idx, tt, :], width=bar_width, align='center', edgecolor='white', bottom=bar_bottom) # _ = fplt.exp_decay(t, signif_bars[prb_idx, tt, :], ax=axes[2, tt]) # significance bars for the mean across context pairs for pp, probe in enumerate(all_probes): prb_idx = all_probes.index(probe) axes[pp, -1].bar(t, np.mean(signif_bars[prb_idx, :, :], axis=0), width=bar_width, align='center', edgecolor='white', bottom=bar_bottom) _ = fplt.exp_decay(t, np.mean(signif_bars[prb_idx, :, :], axis=0), ax=axes[pp, -1], yoffset=bar_bottom, linestyle='.', color='gray') axes[pp, -1].legend(loc='upper right', fontsize='small', markerscale=3, frameon=False) # significance bars for the mean across probes for tt, trans in enumerate(itt.combinations(meta['transitions'], 2)): axes[-1, tt].bar(t, np.mean(signif_bars[:, tt, :], axis=0), width=bar_width, align='center', edgecolor='white', bottom=bar_bottom) _ = fplt.exp_decay(t, np.mean(signif_bars[:, tt, :], axis=0), axes[-1, tt], yoffset=bar_bottom, linestyle='.', color='gray') axes[-1, tt].legend(loc='upper right', fontsize='small', markerscale=3, frameon=False) # cell summary mean: dprime, confidence interval axes[-1, -1].plot(t, np.mean(dprimes[:, :, :], axis=(0, 1)), color='black') axes[-1, -1].axhline(0, color='gray', linestyle='--') axes[-1, -1].bar(t, np.mean(signif_bars[:, :, :], axis=(0, 1)), width=bar_width, align='center', edgecolor='white', bottom=bar_bottom) _ = fplt.exp_decay(t, np.mean(signif_bars[:, :, :], axis=(0, 1)), ax=axes[-1, -1], yoffset=bar_bottom, linestyle='.', color='gray') axes[-1, -1].legend(loc='upper right', fontsize='small', markerscale=3, frameon=False) # formats axis, legend and so on. for pp, probe in enumerate(all_probes): axes[pp, 0].set_ylabel(f'probe {probe}', fontsize=ax_lab_size) axes[pp, 0].tick_params(labelsize=ax_val_size) axes[-1, 0].set_ylabel(f'probe\nmean', fontsize=ax_lab_size) axes[-1, 0].tick_params(labelsize=ax_val_size) for tt, trans in enumerate(itt.combinations(meta['transitions'], 2)): axes[0, tt].set_title(f'{trans[0]}_{trans[1]}', fontsize=sub_title_size) axes[-1, tt].set_xlabel('time (ms)', fontsize=ax_lab_size) axes[-1, tt].tick_params(labelsize=ax_val_size) axes[0, -1].set_title(f'pair\nmean', fontsize=sub_title_size) axes[-1, -1].set_xlabel('time (ms)', fontsize=ax_lab_size) axes[-1, -1].tick_params(labelsize=ax_val_size) for ax in np.ravel(axes): ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) return fig, axes
np.linspace(0, dprime.shape[-1] / meta['raster_fs'], dprime.shape[-1], endpoint=False) * 1000 } # calculats different significaces/corrections # calculate significant time bins, both raw and corrected for multiple comparisons for corr_name, corr in multiple_corrections.items(): print(f' comp_corr: {corr_name}') significance, confidence_interval = _significance(dprime, shuffled_dprime, corr, alpha=alpha) fliped, _ = flip_dprimes(dprime, None, flip='sum') for mean_type in mean_types: print(f' mean_signif: {mean_type}') # masks dprime with different significances, uses different approaches to define significance of the mean. masked, masked_lab_dict = _mask_with_significance( fliped, significance, dim_lab_dict, mean_type=mean_type) # calculate different metrics and organize into a dataframe df = metrics_to_DF(masked, masked_lab_dict, metrics=metrics) df['mult_comp_corr'] = corr_name df['mean_signif_type'] = mean_type df['stim_type'] = meta['stim_type'] df['analysis'] = fname df['siteid'] = site df['region'] = region_map[site]