def thalamic_mapping_panel(dest_dir_appdx='../',
                           which_paradigms=None,
                           anatomy_dir=None,
                           ts_plots_dir=None):
    """Create the thalamic-mapping panel for every mouse and paradigm.

    The panel is similar in style to the one produced by
    cortical_mapping_panel(). It shows the LFP responses of the thalamic
    channels (the last 10 channels), the averaged cortical region LFPs, a
    firing-rate heatmap of the thalamic channels and their current region
    assignment. Optionally the injection slice (anatomy image) and the
    first-response timestamp plot are pasted to the right of the panel.

    Args:
        dest_dir_appdx: Sub directory appended to ``P["outputPath"]``. Panels
            are saved as ``P["outputPath"]/dest_dir_appdx/{mouse}_{n}_{parad}.png``
            where ``n`` is the 1-based position of the paradigm in
            ``const.PARAD_ORDER[mouse]``.
        which_paradigms: Paradigms to plot; defaults to ``['DAC1', 'DAC2']``.
        anatomy_dir: Directory with one image per mouse named ``{mouse}.png``
            (e.g. ``mGE82.png``).
        ts_plots_dir: Output directory of plot_first_response().

    Both ``anatomy_dir`` and ``ts_plots_dir`` must be passed for either of
    them to be drawn, even if you are only interested in one of them.

    The cortical mapping is expected to be done already: the channel map used
    for cortical mapping is read from ``P["outputPath"]/../chnls_map.csv``.
    Subsequent scripts expect ``VPM`` as the label of this thalamic mapping.
    """
    import sys

    if which_paradigms is None:
        which_paradigms = ['DAC1', 'DAC2']

    def make_plot(lfp, frates, lfp_summ, mua_summ, mouse, parad, stim_t):
        """Draw the panel for one mouse/paradigm/stimulus; return the figure.

        `lfp_summ` and `mua_summ` are only referenced by the commented-out
        first-response markers below.
        """
        # last 10 rows = thalamic channels, first 4 rows = collapsed cortical
        # regions; the last 50 time samples are dropped. Scaled by 1e6 for
        # the [uV] axis label (assumes the LFP is stored in volts - confirm).
        thal_lfp = lfp.iloc[-10:, :-50] * 1_000_000
        ctx_lfp = lfp.iloc[:4, :-50] * 1_000_000
        # channel indices 22..31, displayed 1-based as 23..32
        frates = frates.loc[range(22, 32), :]
        x_time = lfp.columns[:-50].values.astype(float)

        # init figure
        fig = plt.figure(figsize=(14, 10.2))
        gs = fig.add_gridspec(
            7, 3, width_ratios=[.5, .35, .15],
            height_ratios=[.25, .025, .15, .15, .15, .15, .15],
            hspace=0, wspace=.12, right=.975, top=.95, left=.02, bottom=.05)
        lfp_ax_ctx = fig.add_subplot(gs[0, 0])
        frate_ax = fig.add_subplot(gs[0, 1])
        assig_ax = fig.add_subplot(gs[0, 2])
        lfp_axl = [fig.add_subplot(gs[i, 0]) for i in range(2, 7)]
        lfp_axr = [fig.add_subplot(gs[i, 1:]) for i in range(2, 7)]
        all_axs = lfp_axl + lfp_axr + [lfp_ax_ctx, frate_ax, assig_ax]
        # indices 0-4: left plots, 5-9: right plots, 10: cortical plot
        lfp_axs = lfp_axl + lfp_axr + [lfp_ax_ctx]

        # clean axes
        for ax in all_axs:
            ax.tick_params(left=False, bottom=False, labelleft=False,
                           labelbottom=False)

        # set up all LFP axes (bottom blocks and the cortical plot)
        for which_ax, ax in enumerate(lfp_axs):
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.patch.set_facecolor('grey')
            ax.patch.set_alpha(.16)
            ax.hlines(0, -5, 20, color='grey')

            # x axis
            xleft, xright = -0.05, 0.2
            ax.xaxis.grid(True, which='major')
            ax.set_xlim((xleft, xright))
            xticks = (xleft, 0, 0.05, 0.1, 0.15, xright)
            ax.set_xticks(xticks)
            if which_ax in (4, 9):  # bottom row of each column
                ax.tick_params(labelbottom=True)
                ax.set_xticklabels(xticks, size=8)
                ax.set_xlabel('[s]', size=8)

            # y axis
            if which_ax != 10:
                ybot, ytop = -25, 25
            else:  # cortical plot
                ybot, ytop = -180, 90
            ax.yaxis.grid(True, which='major')
            ax.set_ylim((ybot, ytop))
            yts = np.concatenate(
                [np.arange(0, ybot, -15)[1:], np.arange(0, ytop, 15)])
            ax.set_yticks(yts)
            if which_ax in (5, 6, 7, 8, 9):  # right plots
                ax.tick_params(labelright=True)
                ax.set_yticklabels((-15, 0, 15), size=8)
            elif which_ax == 10:  # cortical plot
                ax.tick_params(labelright=True)
                ax.set_yticklabels(
                    [yt if yt in (-150, -75, 0, 60) else None for yt in yts],
                    size=8)
            if which_ax in (7, 10):  # cortical plot and middle right plot
                ax.annotate('[uV]', (.263, 0), size=8, annotation_clip=False)
            if which_ax == 10:
                ax.set_title(f'{mouse}-{parad}-{stim_t}')

        # draw the thalamic LFPs, one channel per axis
        for (chnl, dat), ax in zip(thal_lfp.iterrows(), lfp_axl + lfp_axr):
            ax.set_ylabel(chnl + 1, fontsize=10, rotation=0, ha='right',
                          va='center', labelpad=2)
            ax.plot(x_time, dat.values, clip_on=False,
                    color=const.REGION_CMAP['Th'])
            # optional first-response markers:
            # ax.vlines((lfp_summ.loc[chnl, 'Peak_neg_ts']/1000), ybot, ytop)
            # ax.vlines((mua_summ.loc[chnl, 'AvgTimetoFirstSpike']/1000), ybot, ytop, color='green')
        for region, region_lfp in ctx_lfp.iterrows():
            lfp_ax_ctx.plot(x_time, region_lfp, label=region,
                            color=const.REGION_CMAP[region], clip_on=False)
        lfp_ax_ctx.legend()

        # setup the assignment plot (channel 23 at the top)
        assig_ax.tick_params(left=True, labelleft=True, pad=9.5)
        assig_ax.set_xlim((0, 1))
        assig_ax.set_ylim((32.5, 22.5))
        assig_ax.set_yticks(np.arange(23, 33))

        # Compare the first 'SG' channel of this mouse's previous column in
        # the channel map with the current one; if it differs, annotate the
        # shift of the cortical map.
        first_chnls = pd.Series({
            mouse_parad: mapping.index[this_map == 'SG'][0]
            for mouse_parad, this_map in mapping.items()
        })
        mouse_parad = f'{mouse}-{parad}'
        mouse_first_chnls = first_chnls[[
            entr for entr in first_chnls.index if mouse in entr
        ]]
        current_first_chnl = mouse_first_chnls.loc[mouse_parad]
        if parad != const.PARAD_ORDER[mouse][0]:  # not the first paradigm
            prev_pos = mouse_first_chnls.index.get_loc(mouse_parad) - 1
            # positional access must go through .iloc on a string index
            before_first_chnl = mouse_first_chnls.iloc[prev_pos]
            moved_by = before_first_chnl - current_first_chnl
            if moved_by:
                up_dwn = 'up' if moved_by > 0 else 'down'
                ann = f'Cortical map moved\n{up_dwn} by {abs(moved_by)}!'
                assig_ax.annotate(ann, (.1, 22), annotation_clip=False,
                                  fontsize=13)

        # color and label each thalamic channel by its assigned region
        th_mapping = mapping.iloc[-10:, :].loc[:, [mouse_parad]]
        cols = [const.REGION_CMAP[region[0]] for region in th_mapping.values]
        assig_ax.barh(np.arange(23, 33), 1, height=1, edgecolor='k',
                      color=cols, alpha=.6)
        for i, label in enumerate(th_mapping.values):
            assig_ax.annotate(label[0], (.2, i + 23 + .2), fontsize=11)

        # firing-rate heatmap; time axis in seconds like the LFP plots
        frate_ax.tick_params(top=True, labeltop=True, right=True)
        frate_ax.xaxis.set_label_position('top')
        frate_ax.set_xticks((-.05, 0, .1, .2))
        frate_ax.set_xticklabels((-.05, 0, .1, .2), size=8)
        frate_ax.set_xlabel('[s]', size=8)
        frate_ax.set_yticks(np.arange(23, 33))
        # With origin='upper' the first row is drawn at extent[3]; put
        # channel 23 (first row) at y=23 and at the top, matching assig_ax.
        frate_ax.set_ylim((32.5, 22.5))
        frate_ax.imshow(frates, cmap='gnuplot', aspect='auto',
                        extent=[-.0525, 0.2025, 32.5, 22.5], vmin=0, vmax=15)
        frate_ax.vlines([-.002], 22.5, 33.5, color='w', alpha=.6)
        return fig

    data = fetch(paradigms=which_paradigms, collapse_ctx_chnls=True,
                 collapse_th_chnls=False, drop_not_assigned_chnls=False)
    # the current channel assignment (also read by make_plot)
    mapping = pd.read_csv(f'{const.P["outputPath"]}/../chnls_map.csv',
                          index_col=0).reset_index(drop=True)

    def slice_one(mouse, parad, stim, **kind):
        """Return the single DataFrame slice_data() yields for one recording."""
        return slice_data(data, [mouse], [parad], [stim], drop_labels=True,
                          **kind)[0]

    def paste_side_images(panel_path, mouse, parad, peak_stim):
        """Extend the saved panel with the anatomy image (upper right) and the
        first-response plot (lower right); overwrites `panel_path`."""
        anat_path = f'{anatomy_dir}/{mouse}.png'
        ts_path = f'{ts_plots_dir}/{mouse}_{parad}_{peak_stim}_first_ts.png'
        with Image.open(panel_path) as plot, Image.open(anat_path) as anat, \
                Image.open(ts_path) as ts_plot:
            # never crop the panel if it is taller than the right column
            height = max(plot.height, anat.height + ts_plot.height)
            final_img = Image.new('RGB', (plot.width + anat.width, height),
                                  color='white')
            final_img.paste(plot, (0, 0))  # upper left
            final_img.paste(anat, (plot.width, 0))  # upper right
            final_img.paste(ts_plot, (plot.width, anat.height))  # lower right
        final_img.save(panel_path)

    for mouse in const.ALL_MICE:
        if mouse not in const.PARAD_ORDER:
            print(f'`{mouse}` not found in constant PARAD_ORDER '
                  '(MUA_constants.py). Check.')
            sys.exit(1)
        for i, parad in enumerate(const.PARAD_ORDER[mouse]):
            if parad not in which_paradigms:
                continue
            if parad not in ['MS', 'DOC1', 'DOC2']:
                peak_stim = 'Deviant'
            else:
                peak_stim = 'C1' if parad == 'MS' else 'Standard'

            frates = slice_one(mouse, parad, peak_stim, firingrate=True,
                               frate_noise_subtraction='paradigm_wise')
            lfp = slice_one(mouse, parad, peak_stim, lfp=True)
            lfp_summ = slice_one(mouse, parad, peak_stim, lfp_summary=True)
            mua_summ = slice_one(mouse, parad, peak_stim, mua_summary=True)

            fig = make_plot(lfp, frates, lfp_summ, mua_summ, mouse, parad,
                            peak_stim)
            f = f'{const.P["outputPath"]}/{dest_dir_appdx}/{mouse}_{i+1}_{parad}.png'
            fig.savefig(f)
            # one figure per mouse/paradigm: release it to avoid a leak
            plt.close(fig)

            if anatomy_dir and ts_plots_dir:
                paste_side_images(f, mouse, parad, peak_stim)
def plot_first_response(dest_dir_appdx='../', which_paradigms=None):
    """Plot the first-response time stamps of the thalamic channels.

    For every mouse and selected paradigm one figure with two rows is drawn:
    the top row shows the LFP ``Peak_neg_ts`` (LFP summary), the bottom row
    the ``AvgTimetoFirstSpike`` (MUA summary) of the last 10 channels, each
    next to the value of the granular layer ``G``. Channels whose time stamps
    are less than 5 ms apart are staggered vertically so their labels stay
    readable. These plots are used by thalamic_mapping_panel().

    Args:
        dest_dir_appdx: Sub directory appended to ``P["outputPath"]``; plots
            are saved as ``{mouse}_{parad}_{stimulus}_first_ts.png``.
        which_paradigms: Paradigms to plot; defaults to ``['DAC1', 'DAC2']``.
            Pass ``const.ALL_PARADIGMS`` to plot every paradigm.
    """
    import sys

    if which_paradigms is None:
        which_paradigms = ['DAC1', 'DAC2']

    data = fetch(paradigms=which_paradigms, collapse_ctx_chnls=True,
                 collapse_th_chnls=False, drop_not_assigned_chnls=False)
    for mouse in const.ALL_MICE:
        if mouse not in const.PARAD_ORDER:
            print(f'`{mouse}` not found in constant PARAD_ORDER '
                  '(MUA_constants.py). Check.')
            sys.exit(1)
        for parad in const.PARAD_ORDER[mouse]:
            if parad not in which_paradigms:
                continue
            if parad not in ['MS', 'DOC1', 'DOC2']:
                peak_stim = 'Deviant'
            else:
                peak_stim = 'C1' if parad == 'MS' else 'Standard'

            fig, axes = plt.subplots(nrows=2, figsize=(10, 3.4))
            fig.subplots_adjust(right=.975, top=.94, left=.02, bottom=.13,
                                hspace=.15)
            for which_ax, ax in enumerate(axes):
                # row 0: LFP peak negative time stamp, row 1: MUA first spike
                if which_ax == 0:
                    summ = slice_data(data, [mouse], [parad], [peak_stim],
                                      lfp_summary=True, drop_labels=True)[0]
                    ts_col = 'Peak_neg_ts'
                else:
                    summ = slice_data(data, [mouse], [parad], [peak_stim],
                                      mua_summary=True, drop_labels=True)[0]
                    ts_col = 'AvgTimetoFirstSpike'
                first_ts = summ.loc[:, ts_col].iloc[-10:].sort_values()
                first_ts_G = summ.loc['G', ts_col]

                ax.tick_params(left=False, labelleft=False)
                if which_ax == 0:
                    ax.tick_params(bottom=False, labelbottom=False)
                for spine in ax.spines.values():
                    spine.set_visible(False)
                ax.patch.set_facecolor('grey')
                ax.patch.set_alpha(.16)
                ax.hlines((0), -10, 200, color='grey')

                tit = ('Peak negative LFP time stamp' if which_ax == 0
                       else 'Avg first spike time stamp')
                nan_chnls = [
                    chnl + 1 for chnl in first_ts.index[first_ts.isna()]
                ]
                if nan_chnls:
                    tit = f'{tit} (NA channels: {nan_chnls})'
                if which_ax == 1:
                    # an average first-spike time of 0 means "no spike":
                    # hide those channels and list them in the title
                    zero_chnls = first_ts.index[first_ts == 0].values
                    first_ts = first_ts.mask(first_ts == 0)
                    if len(zero_chnls):
                        tit = (f'{tit} (no spike channels: '
                               f'{list(zero_chnls + 1)})')
                ax.set_title(tit, loc='left', pad=2, size=9)

                ax.set_ylim((-1.5, 1.5))
                ax.set_xlim((-10, 200))
                ax.set_xticks(
                    (0, 4, 8, 12, 16, 20, 30, 40, 60, 80, 100, 150, 200))
                ax.xaxis.grid(True, which='major')
                if which_ax == 1:
                    ax.set_xlabel('[ms] post stimulus', labelpad=2, size=9)

                # Stagger labels: consecutive (sorted) time stamps closer than
                # 5 ms alternate above/below the axis, moving outwards by .4
                # whenever a height is already taken within the close group.
                step = .4
                ycoords = []
                close_coords = []
                for pos in range(len(first_ts)):
                    if pos and first_ts.iloc[pos] - first_ts.iloc[pos - 1] < 5:
                        step *= -1
                        if step in close_coords:
                            step = step + .4 if step > 0 else step - .4
                        if len(close_coords) == 0:
                            close_coords.append(step * -1)
                        close_coords.append(step)
                    else:
                        step = .4
                        close_coords = []
                    ycoords.append(step)

                for y, (chnl, x) in zip(ycoords, first_ts.items()):
                    ax.vlines(x, 0, y, linewidth=1,
                              color=const.GENERAL_CMAP['Th'], alpha=.7)
                    ax.annotate(chnl + 1, (x, y), va='center', ha='center',
                                zorder=20)
                ax.vlines(first_ts_G, -.4, .4, linewidth=1,
                          color=const.GENERAL_CMAP['G'], alpha=.7)
                ax.annotate('G', (first_ts_G, 0), va='center', ha='center',
                            zorder=20, size=12)

            f = f'{const.P["outputPath"]}/{dest_dir_appdx}/{mouse}_{parad}_{peak_stim}_first_ts.png'
            fig.savefig(f)
            # one figure per mouse/paradigm: release it to avoid a leak
            plt.close(fig)
def firingrate_noise_timeline(dest_dir_appdx='../', fname_postfix='',
                              subtr_noise=False):
    """Plot the firing rates along the experimental paradigm time line.

    One row per mouse (first 4 of ``const.ALL_MICE``) and one heatmap column
    per paradigm in ``const.PARAD_ORDER[mouse]`` order (up to 11), each the
    firing rate averaged over the stimulus types of that paradigm. The last
    column is a bar plot of the proportion of negative firing-rate values
    (of 32 channels x 50 time bins) per paradigm, plus their average.

    Args:
        dest_dir_appdx: Sub directory appended to ``P["outputPath"]``.
        fname_postfix: Appended to the output file name
            ``firingrate_over_time{fname_postfix}.{const.PLOT_FORMAT}``.
        subtr_noise: ``False`` or the noise subtraction method passed on to
            slice_data(), i.e. ``'deviant_alone'`` or ``'paradigm_wise'``.
    """
    import copy

    n_mice, n_parads = 4, 11
    n_values = 32 * 50  # channels x time bins per firing-rate frame

    data = fetch()
    fig, axes = plt.subplots(
        n_mice, n_parads + 1, sharex=False, sharey=False, figsize=(15, 13),
        gridspec_kw={'width_ratios': [.1] * n_parads + [.28]})
    fig.subplots_adjust(hspace=.2, wspace=.03, right=.95, top=.86, left=.1,
                        bottom=.07)
    title = 'noise over experimental paradigm timeline\n(mean over stimulus types)'
    if subtr_noise:
        title += '\n\nNoise subtracted: ' + subtr_noise
    fig.suptitle(title, size=14)

    # Private copy with gamma applied and passed explicitly: plt.cm.get_cmap
    # is gone in recent matplotlib, and where get_cmap returns a copy the
    # in-place set_gamma of the registered map had no effect.
    cmap = copy.copy(plt.get_cmap('gnuplot'))
    cmap.set_gamma(.8)

    for ax in axes.flatten():
        ax.tick_params(bottom=False, left=False, labelbottom=False,
                       labelleft=False)

    im = None
    for i, mouse in zip(range(n_mice), const.ALL_MICE):
        mouse_dat = slice_data(data, [mouse], firingrate=True,
                               frate_noise_subtraction=subtr_noise)
        axes[i, 0].set_ylabel(mouse + '\nchannels', size=12, rotation=0,
                              ha='right', va='center')
        parad_order = const.PARAD_ORDER[mouse]

        neg_frates = []
        for j, parad in zip(range(n_parads), parad_order):
            ax = axes[i, j]
            ax.set_title(parad, size=7, pad=2)
            # keys are formatted '{mouse}-{parad}-{stim}': match the exact
            # paradigm instead of any substring
            parad_frates = [
                df for key, df in mouse_dat.items() if f'-{parad}-' in key
            ]
            if not parad_frates:  # nothing recorded: leave the panel empty
                neg_frates.append(np.nan)
                continue
            neg_counts = [(fr < 0).sum().sum() for fr in parad_frates]
            neg_frates.append(sum(neg_counts) / len(neg_counts))
            frates = (sum(parad_frates) / len(parad_frates)).astype(int)

            im = ax.imshow(frates, cmap=cmap, aspect='auto',
                           extent=[-52.5, 202.5, -.5, 31.5], vmin=0,
                           vmax=500)
            if 'DA' in parad and subtr_noise == 'deviant_alone':
                ax.set_title('**' + parad + '**', size=9, pad=2)
                for spine in ax.spines.values():
                    spine.set_color('yellow')
                    spine.set_linewidth(1.5)
            if i == 0 and j == 0:
                ax.set_xlim((-52.5, 202.5))
                ax.set_xticks([-50, 0, 80, 160])

        # proportion of negative firing rates per paradigm + average
        n_shown = len(neg_frates)
        rel_neg_frates = np.array(neg_frates, dtype=float) / n_values
        bar_ax = axes[i, n_parads]
        avg_pos = n_parads + 1  # one gap after the paradigm bars
        bar_ax.bar(range(n_shown), rel_neg_frates)
        bar_ax.bar([avg_pos], np.nanmean(rel_neg_frates))
        # x axis
        bar_ax.tick_params(labelbottom=True, rotation=35, labelsize=6.5)
        bar_ax.set_xticks(list(range(n_shown)) + [avg_pos])
        bar_ax.set_xticklabels(list(parad_order[:n_shown]) + ['Average'],
                               clip_on=False, ha='right', y=.04)
        # y axis
        bar_ax.tick_params(right=True, labelright=True)
        bar_ax.set_ylim((0, 1))
        yticks = np.arange(0, 1.1, .1)
        bar_ax.set_yticks(yticks)
        bar_ax.set_yticklabels(
            [str(yt) if yt in (0, 1) else '' for yt in yticks])
        bar_ax.yaxis.set_label_position("right")
        bar_ax.yaxis.grid(True, alpha=.5)
        bar_ax.set_axisbelow(True)
        if i == 1:
            bar_ax.set_ylabel(
                'Proportion of negative firing rates (of 32 channels x 50 time bins)',
                size=10)

    # a single shared colorbar (all heatmaps use the same cmap/vmin/vmax)
    if im is not None:
        cb = fig.colorbar(im, cax=fig.add_axes((0.77, .95, .2, .012)),
                          orientation='horizontal')
        cb.set_label('Mean Firing Rate in 5ms frame', size=12)
        cb.ax.get_xaxis().set_label_position('top')

    fig.savefig(
        f'{const.P["outputPath"]}/{dest_dir_appdx}/firingrate_over_time{fname_postfix}.{const.PLOT_FORMAT}'
    )
    plt.close(fig)
def onset_offset_response(plots_dest_dir_appdx='',
                          csv_dest_dir_appdx='',
                          generate_plots=True,
                          single_channels=True,
                          draw_gmm_fit=True):
    """Bin post-stimulus spike times: start of the onset-offset pipeline.

    Iterates all mice, paradigms and stimulus types, fetches the negative
    spike time stamps and histograms them per channel (or region). The bin
    counts, normalized to 200 trials, serve as the input features for the
    onset/offset classifier.

    Binning: the -50..200 ms range is split into 2000 bins (0.125 ms each);
    bins 400-600 are kept, i.e. 200 bins covering 0-25 ms post stimulus.
    NOTE(review): earlier documentation said "0-20 ms, 0.1 ms bins"; the
    binning was kept unchanged so classifier inputs stay comparable - confirm
    which range was intended.

    Args:
        plots_dest_dir_appdx: Sub directory of ``P["outputPath"]`` where the
            plots ``onset_offset_{key}_{channels|regions}.png`` are saved.
        csv_dest_dir_appdx: Sub directory of ``P["outputPath"]`` where
            ``onset_offset_spikebins_{channels|regions}.csv`` is saved. If
            falsy (``None`` or the default ``''``) no CSV is written.
        generate_plots: Draw a histogram heatmap per channel/region.
        single_channels: If True (default), the 32 channels are not collapsed
            into the region mapping (the classifier works on single channels)
            and channels without any spikes are added as all-zero rows. False
            collapses channels into regions, meant for plotting region
            histograms.
        draw_gmm_fit: Overlay the density of a Gaussian mixture model
            (``bgmm``, 10 components) fitted to the 0-20 ms spike times.

    Returns:
        DataFrame with one row per ``{mouse}-{parad}-{stim}-{channel|region}``
        and 200 columns of spike bin counts.
    """
    # get all the available data from the output dir
    if single_channels:
        data = fetch()
        which_region = 'channels'
        plt_spacers = {'hspace': 0, 'right': .97, 'top': .96, 'left': .1,
                       'bottom': .07}
    else:
        data = fetch(collapse_ctx_chnls=True, collapse_th_chnls=True,
                     drop_not_assigned_chnls=True)
        which_region = 'regions'
        plt_spacers = {'hspace': 0, 'right': .97, 'top': .85, 'left': .1,
                       'bottom': .22}

    # due to different base level activity, set the heatmap vmax per mouse
    vmaxs = {'mGE82': 20, 'mGE83': 40, 'mGE84': 40, 'mGE85': 40}
    default_vmax = 40

    n_hist_bins, hist_range = 2000, (-50, 200)
    start, stop = 400, 600  # bin slice, relative to the 2000 bins above
    n_norm_trials = 200  # spike counts are normalized to this many trials

    def _label(key, region):
        """Row label: zero-padded channel number or region name."""
        if single_channels:
            return key + f'-{region:0>2d}'
        return key + f'-{region}'

    nspikes = 0
    all_spike_bins = []
    for m_id in const.ALL_MICE:
        for parad in const.ALL_PARADIGMS:
            for stim_t in const.ALL_STIMTYPES:
                key = '-'.join([m_id, parad, stim_t])
                # not all combinations of paradigm/stimtype exist
                if key not in data.keys():
                    continue

                # negative spike time stamps (a MultiIndex DataFrame)
                spikes = slice_data(data, m_id, parad, stim_t,
                                    neg_spikes=True,
                                    drop_labels=True)[0].sort_index(axis=1)
                nspikes += np.count_nonzero(spikes)
                if not single_channels:
                    spikes = spikes.reindex(['SG', 'G', 'IG', 'dIG', 'VPM'],
                                            axis=1, level=0)
                regions = spikes.columns.unique(0)
                draw = generate_plots and len(regions) > 0

                if draw:
                    # height depends on the number of channels/regions;
                    # squeeze=False keeps a 2D array even for a single row
                    fig, axes = plt.subplots(nrows=len(regions),
                                             figsize=(6, .4 * len(regions)),
                                             squeeze=False)
                    axes = axes[:, 0]
                    fig.subplots_adjust(**plt_spacers)
                    for ax in axes:
                        ax.tick_params(bottom=False, left=False,
                                       labelbottom=False, labelleft=False)
                    axes[0].set_title(key)
                else:
                    axes = [None] * len(regions)

                if single_channels:
                    # add channels without spikes as all-zero rows so each
                    # recording contributes a consistent set of channels
                    for chnl in range(1, 33):
                        if chnl not in regions:
                            all_spike_bins.append(
                                pd.Series(np.zeros(stop - start, dtype=int),
                                          name=_label(key, chnl)))

                last_region = regions[-1] if len(regions) else None
                for region, ax in zip(regions, axes):
                    # float copy so 0 ("no spike") can be set to NaN, which
                    # np.histogram then ignores
                    region_spikes = spikes[region].to_numpy(
                        dtype=float).flatten()
                    region_spikes[region_spikes == 0] = np.nan
                    hist = np.histogram(region_spikes, bins=n_hist_bins,
                                        range=hist_range)
                    spike_bins = hist[0][np.newaxis, start:stop]
                    if spikes.shape[0] != n_norm_trials:
                        spike_bins = spike_bins / spikes.shape[0]
                        spike_bins = (spike_bins * n_norm_trials).astype(int)
                    all_spike_bins.append(
                        pd.Series(spike_bins[0], name=_label(key, region)))

                    if not draw:
                        continue
                    # draw the heatmap, setup axis
                    ax.imshow(spike_bins, aspect='auto',
                              extent=(hist[1][start], hist[1][stop], 0, 1),
                              vmin=0, vmax=vmaxs.get(m_id, default_vmax))
                    ax.set_ylabel(region, rotation=0, va='center',
                                  labelpad=20)
                    ax.set_ylim(0, .4)
                    xt = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
                    ax.set_xticks(xt)
                    ax.set_xlim(xt[0], xt[-1])

                    if draw_gmm_fit and (spike_bins > 1).sum() > 15:
                        model = bgmm(10, covariance_type='diag',
                                     random_state=1, mean_prior=(8, ),
                                     covariance_prior=(.1, ),
                                     degrees_of_freedom_prior=10)
                        post_stim = np.logical_and(region_spikes > 0,
                                                   region_spikes < 20)
                        model.fit(region_spikes[post_stim, np.newaxis])
                        x = np.linspace(-50, 200, 2000).reshape(2000, 1)
                        pdf = np.exp(model.score_samples(x))
                        ax.plot(x, pdf, '-w', linewidth=.8)

                    # the last plot gets a red stimulus indication bar (0-8 ms)
                    if region == last_region:
                        ax.tick_params(bottom=True, labelbottom=True)
                        ax.set_xlabel('[ms]')
                        ax.hlines(0, 0, 8, clip_on=False, linewidth=6,
                                  color='r')
                        ax.annotate('Stimulus', (2.3, -.6), color='r',
                                    fontsize=15, annotation_clip=False)

                if draw:
                    f = (f'{const.P["outputPath"]}/{plots_dest_dir_appdx}/'
                         f'onset_offset_{key}_{which_region}.png')
                    fig.savefig(f)
                    plt.close(fig)

    on_off_scores = pd.concat(all_spike_bins, axis=1).T
    if csv_dest_dir_appdx:
        f = f'{const.P["outputPath"]}/{csv_dest_dir_appdx}/onset_offset_spikebins_{which_region}.csv'
        on_off_scores.to_csv(f)
    print(f'\nSpikes found: {nspikes:,}')
    return on_off_scores
def plot_paradigm(parad):
    """Draw the mean firing-rate heatmaps of one paradigm for the 4 mice.

    NOTE(review): `grouping`, `chnls_to_regions`, `all_stimuli` and
    `subtr_noise` are free variables not defined in this block; presumably
    they come from an enclosing function - confirm.

    `grouping` selects which recordings are shown:
      * 'paradigm_wise': all stimulus types of `parad`.
      * 'whisker_wise': the Deviant of `parad` plus the Predeviant /
        UniquePredeviant stimuli (with `all_stimuli` also Standard,
        Postdeviant, UniquePostdeviant) of the counterpart paradigm whose
        name has `dev` replaced by `std`; for DOC1/DOC2 only its 'Standard'.
        'MS' fetches all of its stimulus types.
      * 'whisker_wise_reduced': `parad` must contain 'C1' or 'C2'; one column
        per entry of a fixed paradigm/stimulus `order`, separated by hidden
        spacer columns.

    Returns the matplotlib figure (4 rows, one per mouse).
    """
    # NOTE(review): looks like leftover debug output
    print(parad)
    # deviant whisker is the last two characters of the paradigm name;
    # `dev`/`std` stay undefined for paradigms without 'C1'/'C2' (e.g. 'MS')
    if 'C1' in parad or 'C2' in parad:
        dev = parad[-2:]
        std = 'C2' if dev == 'C1' else 'C1'
    # fetch() keyword arguments that collapse channels into regions
    to_regions = [
        'collapse_ctx_chnls', 'collapse_th_chnls', 'drop_not_assigned_chnls'
    ]
    to_regions = dict.fromkeys(to_regions,
                               True if chnls_to_regions else False)
    if grouping == 'paradigm_wise':
        # NOTE(review): `to_regions` is not passed here (nor in the
        # 'whisker_wise_reduced' branch), although `chnls_to_regions` later
        # relabels the y axis with region names - confirm this is intended.
        data = fetch(paradigms=[parad])
    elif grouping == 'whisker_wise':
        if parad != 'MS':
            dev_data = fetch(paradigms=[parad],
                             stim_types=['Deviant'],
                             **to_regions)
            # counterpart paradigm: name with `dev` replaced by `std`
            std_parad = parad.replace(dev, std)
            stim_types = ['Predeviant', 'UniquePredeviant']
            if all_stimuli:
                stim_types.extend(
                    ['Standard', 'Postdeviant', 'UniquePostdeviant'])
            std_data = fetch(paradigms=[std_parad],
                             stim_types=stim_types,
                             **to_regions)
            # DOC paradigms: only the 'Standard' stimulus is used
            if std_parad in ['DOC1', 'DOC2']:
                std_data = fetch(paradigms=[std_parad],
                                 stim_types=['Standard'],
                                 **to_regions)
            data = {**std_data, **dev_data}
        else:
            data = fetch(paradigms=[parad], **to_regions)
    elif grouping == 'whisker_wise_reduced':
        dev_parads = [
            this_parad for this_parad in const.ALL_PARADIGMS
            if dev in this_parad
        ]
        std_parads = [
            this_parad.replace(dev, std) for this_parad in dev_parads
        ]
        # fixed column order ('{paradigm}-{stimulus}' key fragments)
        if dev == 'C1':
            order = (
                'DAC1-Deviant',
                'O10C2-Predeviant',
                'O10C1-Deviant',
                'O25C2-UniquePredeviant',
                'O25C1-Deviant',
                'MS-C1',
                'DOC2-Standard',
                'DOC1-Deviant',
            )
        else:
            order = ('DAC2-Deviant', 'O10C1-Predeviant', 'O10C2-Deviant',
                     'O25C1-UniquePredeviant', 'O25C2-Deviant', 'MS-C2',
                     'DOC1-Standard', 'DOC2-Deviant')
        dev_data = fetch(paradigms=dev_parads, stim_types=['Deviant'])
        std_data = fetch(paradigms=std_parads,
                         stim_types=['Predeviant', 'UniquePredeviant'])
        std_data.update(
            fetch(paradigms=['DO' + std], stim_types=['Standard']))
        ms_data = fetch(paradigms=['MS'], stim_types=[dev])
        data = {**std_data, **dev_data, **ms_data}
        # keep only keys matching `order`, in that order
        data = OrderedDict({
            key: data[key]
            for ord_key in order for key in data.keys() if ord_key in key
        })

    # figure layout: number of columns and width depend on the grouping
    if grouping != 'whisker_wise_reduced':
        if parad != 'MS' and not all_stimuli:
            args = {'ncols': 2}
            width = 7
        else:
            args = {'ncols': 4}
            width = 13
    else:
        # 8 data columns + 4 narrow spacer columns (hidden in the loop below)
        args = {
            'ncols': 8 + 4,
            'gridspec_kw': {
                'width_ratios':
                [.1, .015, .1, .1, .015, .1, .1, .015, .1, .015, .1, .1],
            }
        }
        width = 20
    fig, axes = plt.subplots(4,
                             **args,
                             sharex=True,
                             sharey=True,
                             figsize=(width, 13))
    fig.subplots_adjust(hspace=.06,
                        wspace=.03,
                        right=.98,
                        top=.86,
                        left=.13,
                        bottom=.07)
    [
        ax.tick_params(bottom=False,
                       left=False,
                       labelbottom=False,
                       labelleft=False) for ax in axes.flatten()
    ]
    if grouping != 'whisker_wise_reduced':
        title = const.PARAD_FULL[
            parad] + '- mean firing rates across 4 mice'
    else:
        title = parad + '- mean firing rates across 4 mice'
    if subtr_noise:
        title += '\nNOISE SUBTRACTED'
    fig.suptitle(title, size=14)
    # NOTE(review): plt.cm.get_cmap was removed in matplotlib 3.9; where it
    # returns a copy, this set_gamma does not affect the cmap='gnuplot' below.
    plt.cm.get_cmap('gnuplot').set_gamma(.8)
    # NOTE(review): looks like leftover debug output
    print()
    print()
    print()

    # one row per mouse (first 4 of const.ALL_MICE)
    for mouse, i in zip(const.ALL_MICE, range(4)):
        mouse_dat = slice_data(data, [mouse],
                               firingrate=True,
                               frate_noise_subtraction=subtr_noise)
        if chnls_to_regions:
            # label the collapsed regions; index reversed because imshow
            # draws the first row at the top of the extent
            axes[i, 0].tick_params(left=True, labelleft=True)
            lbls = list(mouse_dat.values())[0].index[::-1]
            axes[i, 0].set_yticks(np.arange(6.4 / 2, 32, 6.4))
            axes[i, 0].set_yticklabels(lbls)
        axes[i, 0].set_ylabel(mouse + '\nchannels',
                              size=12,
                              rotation=0,
                              ha='right',
                              va='center',
                              x=-50)
        which_ax = 0
        for (key, frates), j in zip(mouse_dat.items(), range(args['ncols'])):
            im = axes[i, which_ax].imshow(frates,
                                          cmap='gnuplot',
                                          aspect='auto',
                                          extent=[-52.5, 202.5, -.5, 31.5],
                                          vmin=0,
                                          vmax=500)
            # white line at stimulus onset (t=0)
            axes[i, which_ax].vlines(0,
                                     -.5,
                                     31.5,
                                     color='w',
                                     alpha=.6,
                                     linewidth=1)
            if i == 0:
                # column titles on the first row; key is '{mouse}-{parad}-{stim}'
                stim_t = key[key.rfind('-') + 1:]
                col = 'k' if stim_t != 'Deviant' else const.COLORS[
                    'deep_red']
                if grouping == 'paradigm_wise' or parad == 'MS':
                    axes[i, which_ax].set_title(stim_t, color=col)
                elif grouping == 'whisker_wise':
                    axes[i, which_ax].set_title(f'{parad[-2:]} {stim_t}',
                                                color=col)
                elif grouping == 'whisker_wise_reduced':
                    pard_full = const.PARAD_FULL[
                        key[key.find('-') + 1:key.rfind('-')]][:-3]
                    title = f'{dev} {stim_t}\n{pard_full}'
                    if 'MS' not in key:
                        axes[i, which_ax].set_title(title, color=col)
                    else:
                        axes[i, which_ax].set_title(stim_t +
                                                    '\nMany Standards',
                                                    color=col)
            elif i == 3:
                # x axis labels on the last row only
                axes[i, which_ax].tick_params(bottom=True, labelbottom=True)
                axes[i, which_ax].set_xlabel('ms')
            if (i == 0) and (which_ax == 0):
                axes[i, which_ax].set_xlim((-52.5, 202.5))
                axes[i, which_ax].set_xticks([-50, 0, 80, 160])
                # colorbar and legend
                # NOTE(review): nesting reconstructed from whitespace-collapsed
                # source; the colorbar is assumed to be created once here.
                at = (
                    .58,
                    .9,
                    2.5 / width,
                    .012,
                )
                cb = fig.colorbar(im,
                                  cax=fig.add_axes(at),
                                  orientation='horizontal')
                cb.set_label('Mean Firing Rate in 5ms frame', size=12)
                cb.ax.get_xaxis().set_label_position('top')
            which_ax += 1
            # skip (and hide) the narrow spacer columns
            if grouping == 'whisker_wise_reduced' and which_ax in [
                    1, 4, 7, 9
            ]:
                axes[i, which_ax].set_visible(False)
                which_ax += 1
    return fig