def plot_coding_direction(units, time_period=None, label=None, axs=None):
    """
    Plot the coding-direction (CD) projected PSTH of contra trials (blue)
    and ipsi trials (red) for the specified `units`.

    units: unit query - its keys are fetched and handed to
        `psth.compute_CD_projected_psth`
    time_period: forwarded unchanged to `psth.compute_CD_projected_psth`
    label: (optional) plot title - no title is set when empty/None
    axs: (optional) a single matplotlib axes to draw on - when None, a new
        8x6 figure is created

    Returns the newly created figure, or None when `axs` was supplied.
    """
    # start times of the sample, delay and go events (drawn as dashed lines)
    _, event_times = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    _, contra_proj, ipsi_proj, t_vec, _ = psth.compute_CD_projected_psth(
        units.fetch('KEY'), time_period=time_period)

    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(8, 6))
    else:
        fig = None

    # projected traces with SEM
    for projection, color in ((contra_proj, 'b'), (ipsi_proj, 'r')):
        _plot_with_sem(projection, t_vec, ax=axs, c=color)

    for event_time in event_times:
        axs.axvline(x=event_time, linestyle='--', color='k')

    # cosmetic
    for side in ('right', 'top'):
        axs.spines[side].set_visible(False)
    axs.set_ylabel('CD projection (a.u.)')
    axs.set_xlabel('Time (s)')
    if label:
        axs.set_title(label)

    return fig
def plot_avg_contra_ipsi_psth(units, axs=None):
    """
    Plot the average ipsi-trial vs. contra-trial PSTH, separately for the
    contra-selective units (first axes) and the ipsi-selective units
    (second axes) among `units`.

    Ipsi/contra trial conditions are derived from the hemisphere returned by
    `_get_units_hemisphere`: for 'left', the left-hit condition is ipsi and
    the right-hit condition is contra; otherwise the reverse.

    units: unit query (projected to its primary key before use)
    axs: (optional) array of exactly 2 matplotlib axes - when None, a new
        1x2 figure (16x6) is created

    Returns the newly created figure, or None when `axs` was supplied.
    """
    units = units.proj()

    # start times of the sample, delay and go events
    _, event_times = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    hemi = _get_units_hemisphere(units)
    if hemi == 'left':
        ipsi_cond_name = 'good_noearlylick_left_hit'
        contra_cond_name = 'good_noearlylick_right_hit'
    else:
        ipsi_cond_name = 'good_noearlylick_right_hit'
        contra_cond_name = 'good_noearlylick_left_hit'

    good_unit = ephys.Unit & 'unit_quality != "all"'

    ipsi_conds = (psth.TrialCondition
                  & {'trial_condition_name': ipsi_cond_name}).fetch('KEY')
    contra_conds = (psth.TrialCondition
                    & {'trial_condition_name': contra_cond_name}).fetch('KEY')

    ipsi_selective = (ephys.Unit * psth.UnitSelectivity
                      & 'unit_selectivity = "ipsi-selective"' & units)
    contra_selective = (ephys.Unit * psth.UnitSelectivity
                        & 'unit_selectivity = "contra-selective"' & units)

    def _fetch_psth(conds, selective_units):
        # `unit_psth` of the good, selective units under the given trial
        # conditions, ordered by descending `unit_posy`
        query = (((psth.UnitPsth & conds) * ephys.Unit.proj('unit_posy'))
                 & good_unit.proj() & selective_units.proj())
        return query.fetch('unit_psth', order_by='unit_posy desc')

    ipsi_sel_ipsi_trials = _fetch_psth(ipsi_conds, ipsi_selective)
    ipsi_sel_contra_trials = _fetch_psth(contra_conds, ipsi_selective)
    contra_sel_contra_trials = _fetch_psth(contra_conds, contra_selective)
    contra_sel_ipsi_trials = _fetch_psth(ipsi_conds, contra_selective)

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    else:
        fig = None
    assert axs.size == 2

    _plot_avg_psth(contra_sel_ipsi_trials, contra_sel_contra_trials,
                   event_times, axs[0], 'Contra-selective')
    _plot_avg_psth(ipsi_sel_ipsi_trials, ipsi_sel_contra_trials,
                   event_times, axs[1], 'Ipsi-selective')

    # share the same y-range (0 to the larger upper limit) across both panels
    shared_ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, shared_ymax))

    return fig
def plot_stacked_contra_ipsi_psth(units, axs=None):
    """
    Plot stacked PSTH differences, one row per unit ordered by descending
    `unit_posy`, for the contra-selective units (first axes) and the
    ipsi-selective units (second axes) among `units`.

    Ipsi/contra trial conditions are derived from the hemisphere returned by
    `_get_units_hemisphere`: for 'left', the left-hit condition is ipsi and
    the right-hit condition is contra; otherwise the reverse.

    units: unit query (projected to its primary key before use)
    axs: (optional) array of exactly 2 matplotlib axes - when None, a new
        1x2 figure (20x20) is created

    Returns the newly created figure, or None when `axs` was supplied.
    """
    units = units.proj()

    # start times of the sample, delay and go events
    _, event_times = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    hemi = _get_units_hemisphere(units)
    if hemi == 'left':
        ipsi_cond_name = 'good_noearlylick_left_hit'
        contra_cond_name = 'good_noearlylick_right_hit'
    else:
        ipsi_cond_name = 'good_noearlylick_right_hit'
        contra_cond_name = 'good_noearlylick_left_hit'

    ipsi_cond = (psth.TrialCondition
                 & {'trial_condition_name': ipsi_cond_name}).fetch1('KEY')
    contra_cond = (psth.TrialCondition
                   & {'trial_condition_name': contra_cond_name}).fetch1('KEY')

    ipsi_selective = (ephys.Unit * psth.UnitSelectivity
                      & 'unit_selectivity = "ipsi-selective"' & units)
    contra_selective = (ephys.Unit * psth.UnitSelectivity
                        & 'unit_selectivity = "contra-selective"' & units)

    def _fetch_sorted(selective_units, cond):
        # all UnitPsth attributes of the selective units under one trial
        # condition, ordered by descending `unit_posy`
        query = psth.UnitPsth * selective_units.proj('unit_posy') & cond
        return query.fetch(order_by='unit_posy desc')

    ipsi_sel_ipsi_trials = _fetch_sorted(ipsi_selective, ipsi_cond)
    ipsi_sel_contra_trials = _fetch_sorted(ipsi_selective, contra_cond)
    contra_sel_contra_trials = _fetch_sorted(contra_selective, contra_cond)
    contra_sel_ipsi_trials = _fetch_sorted(contra_selective, ipsi_cond)

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(20, 20))
    else:
        fig = None
    assert axs.size == 2

    contra_ax, ipsi_ax = axs[0], axs[1]

    _plot_stacked_psth_diff(contra_sel_contra_trials, contra_sel_ipsi_trials,
                            ax=contra_ax, vlines=event_times, flip=True)
    contra_ax.set_title('Contra-selective Units')
    contra_ax.set_ylabel('Unit (by depth)')
    contra_ax.set_xlabel('Time to go (s)')

    _plot_stacked_psth_diff(ipsi_sel_ipsi_trials, ipsi_sel_contra_trials,
                            ax=ipsi_ax, vlines=event_times)
    ipsi_ax.set_title('Ipsi-selective Units')
    ipsi_ax.set_ylabel('Unit (by depth)')
    ipsi_ax.set_xlabel('Time to go (s)')

    return fig
def plot_unit_psth(unit_key, axs=None, title='', xlim=_plt_xlim):
    """
    Default raster and PSTH plot for a specified unit, using the
    "good_noearlylick" left/right trial conditions.

    Layout (2x2): the left column shows hit trials (raster on top, PSTH
    below), the right column shows miss trials. Ipsi/contra is derived from
    the hemisphere returned by `_get_units_hemisphere`: for 'left', left
    trials are ipsi and right trials are contra; otherwise the reverse.

    unit_key: key of the unit to plot - `unit_key["unit"]` is used in the
        default titles
    axs: (optional) 2x2 array of matplotlib axes - when None, a new figure
        is created
    title: (optional) title for both raster panels - when empty, a default
        "Unit #" title is used
    xlim: x-axis limits forwarded to the raster and PSTH plotting helpers

    Returns the newly created figure, or None when `axs` was supplied.
    """
    hemi = _get_units_hemisphere(unit_key)
    if hemi == 'left':
        ipsi_side, contra_side = 'left', 'right'
    else:
        ipsi_side, contra_side = 'right', 'left'

    def _plotting_data(side, outcome):
        # plotting data of this unit for one "good_noearlylick" condition
        cond = {'trial_condition_name': f'good_noearlylick_{side}_{outcome}'}
        return psth.UnitPsth.get_plotting_data(unit_key, cond)

    ipsi_hit = _plotting_data(ipsi_side, 'hit')
    contra_hit = _plotting_data(contra_side, 'hit')
    ipsi_miss = _plotting_data(ipsi_side, 'miss')
    contra_miss = _plotting_data(contra_side, 'miss')

    # start times of the sample, delay and go events
    _, event_times = _get_trial_event_times(
        ['sample', 'delay', 'go'], unit_key, 'good_noearlylick_hit')

    if axs is None:
        fig, axs = plt.subplots(2, 2)
    else:
        fig = None

    # correct response
    if title:
        hit_title = title
    else:
        hit_title = f'Unit #: {unit_key["unit"]}\nCorrect Response'
    _plot_spike_raster(ipsi_hit, contra_hit, ax=axs[0, 0],
                       vlines=event_times, title=hit_title, xlim=xlim)
    _plot_psth(ipsi_hit, contra_hit, vlines=event_times,
               ax=axs[1, 0], xlim=xlim)

    # incorrect response
    if title:
        miss_title = title
    else:
        miss_title = f'Unit #: {unit_key["unit"]}\nIncorrect Response'
    _plot_spike_raster(ipsi_miss, contra_miss, ax=axs[0, 1],
                       vlines=event_times, title=miss_title, xlim=xlim)
    _plot_psth(ipsi_miss, contra_miss, vlines=event_times,
               ax=axs[1, 1], xlim=xlim)

    return fig
def make(self, key):
    """
    Generate the coding-direction (CD) figure for the session in `key`,
    save it under the session's staging directory and insert the result.

    With more than one probe insertion, the figure is an N x N grid
    (N = number of probes, ordered by `insertion_number`):
      - panel [i, i]: CD-projected PSTH of probe i (contra in blue, ipsi in
        red)
      - panel [i, j], i < j: trial-to-trial CD-endpoint correlation between
        probes i and j; when the two probes are in different hemispheres,
        contra/ipsi of the second probe are swapped so that both follow the
        first probe's definition
    With a single probe, only the CD-projected PSTH of that probe is drawn.

    key: primary key of the entry to populate
    """
    water_res_num, sess_date = get_wr_sessdate(key)
    sess_dir = store_stage / water_res_num / sess_date
    sess_dir.mkdir(parents=True, exist_ok=True)

    # ---- Setup ----
    time_period = (-0.4, 0)
    probe_keys = (ephys.ProbeInsertion & key).fetch(
        'KEY', order_by='insertion_number')
    n_probes = len(probe_keys)

    def _brain_region_label(probe):
        # comma-separated 'hemisphere brain_area' of every recordable brain
        # region of this probe insertion
        return (ephys.ProbeInsertion & probe).aggr(
            ephys.ProbeInsertion.RecordableBrainRegion.proj(
                brain_region='CONCAT(hemisphere, " ", brain_area)'),
            brain_regions='GROUP_CONCAT(brain_region SEPARATOR", ")'
        ).fetch1('brain_regions')

    def _embed_fig(sub_fig, target_ax):
        # render `sub_fig` to an in-memory PNG, show it in `target_ax` of
        # the main figure, then close `sub_fig`
        with io.BytesIO() as buf:
            sub_fig.savefig(buf, format='png')
            buf.seek(0)
            target_ax.imshow(Image.open(buf))
        plt.close(sub_fig)

    fig1, axs = plt.subplots(n_probes, n_probes, figsize=(16, 16))

    if n_probes > 1:
        # panels are filled with rendered images - hide all axis decorations
        for a in axs.flatten():
            a.axis('off')

        # ---- Plot Coding Direction per probe ----
        probe_proj = {}
        for pid, probe in enumerate(probe_keys):
            units = ephys.Unit & probe
            label = '({}) {}'.format(probe['insertion_number'],
                                     _brain_region_label(probe))

            _, period_starts = _get_trial_event_times(
                ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

            # ---- compute CD projected PSTH ----
            (_, proj_contra_trial, proj_ipsi_trial,
             time_stamps, hemi) = psth.compute_CD_projected_psth(
                units.fetch('KEY'), time_period=time_period)

            # ---- save projection results ----
            probe_proj[pid] = (proj_contra_trial, proj_ipsi_trial,
                               time_stamps, label, hemi)

            # ---- generate fig with CD plot for this probe ----
            fig, ax = plt.subplots(1, 1, figsize=(6, 6))
            _plot_with_sem(proj_contra_trial, time_stamps, ax=ax, c='b')
            _plot_with_sem(proj_ipsi_trial, time_stamps, ax=ax, c='r')
            # cosmetic
            for x in period_starts:
                ax.axvline(x=x, linestyle='--', color='k')
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.set_ylabel('CD projection (a.u.)')
            ax.set_xlabel('Time (s)')
            ax.set_title(label)
            fig.tight_layout()

            # ---- plot this fig into the main figure ----
            _embed_fig(fig, axs[pid, pid])

        # ---- Plot probe-pair correlation ----
        p_start, p_end = time_period
        for p1, p2 in itertools.combinations(probe_proj, r=2):
            contra_g1, ipsi_g1, _, label_g1, p1_hemi = probe_proj[p1]
            # NOTE: the time stamps of the second probe are used for both
            # probes - presumably identical across probes, TODO confirm
            contra_g2, ipsi_g2, time_stamps, label_g2, p2_hemi = probe_proj[p2]
            labels = [label_g1, label_g2]

            # CD endpoint: mean projection within `time_period`, per trial
            in_period = np.logical_and(time_stamps >= p_start,
                                       time_stamps < p_end)
            contra_cdend_1 = contra_g1[:, in_period].mean(axis=1)
            ipsi_cdend_1 = ipsi_g1[:, in_period].mean(axis=1)
            if p1_hemi != p2_hemi:
                # different hemispheres: swap the second probe's trial types
                # so contra/ipsi follow the first probe's definition
                contra_g2, ipsi_g2 = ipsi_g2, contra_g2
            contra_cdend_2 = contra_g2[:, in_period].mean(axis=1)
            ipsi_cdend_2 = ipsi_g2[:, in_period].mean(axis=1)

            c_df = pd.DataFrame([contra_cdend_1, contra_cdend_2]).T
            c_df.columns = labels
            c_df['trial-type'] = 'contra'
            i_df = pd.DataFrame([ipsi_cdend_1, ipsi_cdend_2]).T
            i_df.columns = labels
            i_df['trial-type'] = 'ipsi'
            # DataFrame.append was removed in pandas 2.0 - use pd.concat
            df = pd.concat([c_df, i_df])

            # remove NaN trial - could be due to some trials having no spikes
            df = df.dropna(subset=labels)

            fig = plt.figure(figsize=(6, 6))
            _jointplot_w_hue(data=df, x=labels[0], y=labels[1],
                             hue='trial-type', colormap=['b', 'r'],
                             figsize=(8, 6), fig=fig, scatter_kws=None)

            # ---- plot this fig into the main figure ----
            _embed_fig(fig, axs[p1, p2])
    else:
        # ---- Plot Single-Probe Coding Direction ----
        # with a 1x1 grid, `axs` is a single axes and is drawn on directly
        probe = probe_keys[0]
        units = ephys.Unit & probe
        label = _brain_region_label(probe)

        unit_characteristic_plot.plot_coding_direction(
            units, time_period=time_period, label=label, axs=axs)

    # ---- Save fig and insert ----
    fn_prefix = f'{water_res_num}_{sess_date}_'
    fig_dict = save_figs((fig1, ), ('coding_direction', ), sess_dir, fn_prefix)
    plt.close('all')
    self.insert1({**key, **fig_dict, 'cd_probe_count': len(probe_keys)})
def plot_paired_coding_direction(unit_g1, unit_g2, labels=None, time_period=None):
    """
    Plot trial-to-trial CD-endpoint correlation between CD-projected trial-psth from two unit-groups (e.g. two brain regions)
    Note: coding direction is calculated on selective units, contra vs. ipsi, within the specified time_period

    Two figures are produced:
      - a 1x2 figure with the CD-projected PSTH (contra in blue, ipsi in
        red) of each unit-group - this is the returned figure
      - a joint plot of the per-trial CD endpoints (mean projection within
        `time_period`) of group 1 vs. group 2, which is shown via `.show()`

    unit_g1, unit_g2: unit queries of the two groups
    labels: (optional) 2 names for the unit-groups, used as panel titles and
        joint-plot axes; defaults to ('unit group 1', 'unit group 2')
    time_period: (start, end) window - although the default is None, the
        endpoint computation unpacks it, so a 2-item value is required

    Returns the 1x2 figure of the projected PSTHs.
    """
    _, proj_contra_trial_g1, proj_ipsi_trial_g1, time_stamps, unit_g1_hemi = psth.compute_CD_projected_psth(
        unit_g1.fetch('KEY'), time_period=time_period)
    # NOTE: `time_stamps` of the second group is the one used below -
    # presumably identical to the first group's, TODO confirm
    _, proj_contra_trial_g2, proj_ipsi_trial_g2, time_stamps, unit_g2_hemi = psth.compute_CD_projected_psth(
        unit_g2.fetch('KEY'), time_period=time_period)

    # get event start times: sample, delay, response
    _, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], unit_g1, 'good_noearlylick_hit')

    if labels:
        assert len(labels) == 2
    else:
        labels = ('unit group 1', 'unit group 2')
    labels = list(labels)

    # plot projected trial-psth
    fig, axs = plt.subplots(1, 2, figsize=(16, 6))

    _plot_with_sem(proj_contra_trial_g1, time_stamps, ax=axs[0], c='b')
    _plot_with_sem(proj_ipsi_trial_g1, time_stamps, ax=axs[0], c='r')
    _plot_with_sem(proj_contra_trial_g2, time_stamps, ax=axs[1], c='b')
    _plot_with_sem(proj_ipsi_trial_g2, time_stamps, ax=axs[1], c='r')

    # cosmetic
    for ax, label in zip(axs, labels):
        for x in period_starts:
            ax.axvline(x=x, linestyle='--', color='k')
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.set_ylabel('CD projection (a.u.)')
        ax.set_xlabel('Time (s)')
        ax.set_title(label)

    # plot trial CD-endpoint correlation - if 2 unit-groups are from 2 hemispheres,
    #   then contra-ipsi definition is based on the first group
    p_start, p_end = time_period
    in_period = np.logical_and(time_stamps >= p_start, time_stamps < p_end)

    contra_cdend_1 = proj_contra_trial_g1[:, in_period].mean(axis=1)
    ipsi_cdend_1 = proj_ipsi_trial_g1[:, in_period].mean(axis=1)
    if unit_g1_hemi == unit_g2_hemi:
        contra_cdend_2 = proj_contra_trial_g2[:, in_period].mean(axis=1)
        ipsi_cdend_2 = proj_ipsi_trial_g2[:, in_period].mean(axis=1)
    else:
        # different hemispheres: group 2's ipsi trials are group 1's contra
        # trials, and vice versa
        contra_cdend_2 = proj_ipsi_trial_g2[:, in_period].mean(axis=1)
        ipsi_cdend_2 = proj_contra_trial_g2[:, in_period].mean(axis=1)

    c_df = pd.DataFrame([contra_cdend_1, contra_cdend_2]).T
    c_df.columns = labels
    c_df['trial-type'] = 'contra'
    i_df = pd.DataFrame([ipsi_cdend_1, ipsi_cdend_2]).T
    i_df.columns = labels
    i_df['trial-type'] = 'ipsi'
    # DataFrame.append was removed in pandas 2.0 - use pd.concat
    df = pd.concat([c_df, i_df])

    jplot = _jointplot_w_hue(data=df, x=labels[0], y=labels[1], hue='trial-type', colormap=['b', 'r'],
                             figsize=(8, 6), fig=None, scatter_kws=None)
    jplot['fig'].show()

    return fig
def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
    """
    For the specified `units`, plot the average ipsi vs. contra PSTH for
    no-stim trials (first axes, 'Control') and photostim trials (second
    axes, 'Photostim'), with the photostim period shaded on the second axes.

    The stim location (or other appropriate search keywords) can be
    specified in `condition_name_kw` (default: both ALM). Ipsi/contra is
    derived from the hemisphere returned by `_get_units_hemisphere`: for
    'left', left trials are ipsi; otherwise right trials are ipsi.

    units: unit query (projected to its primary key before use)
    condition_name_kw: list of keywords used to look up the photostim trial
        conditions (not modified by this function)
    axs: (optional) array of exactly 2 matplotlib axes - when None, a new
        1x2 figure (16x6) is created

    Returns the newly created figure, or None when `axs` was supplied.
    """
    units = units.proj()

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    else:
        fig = None
    assert axs.size == 2

    hemi = _get_units_hemisphere(units)

    def _first_cond_name(keywords):
        # first trial-condition name matching the given keywords
        return psth.TrialCondition.get_cond_name_from_keywords(keywords)[0]

    def _fetch_unit_psth(cond_name):
        # non-NULL `unit_psth` of `units` for one trial condition
        query = (psth.UnitPsth * psth.TrialCondition & units
                 & {'trial_condition_name': cond_name}
                 & 'unit_psth is not NULL')
        return query.fetch('unit_psth')

    # no photostim:
    nostim_left_name = _first_cond_name(['_nostim', '_left'])
    nostim_right_name = _first_cond_name(['_nostim', '_right'])
    nostim_left = _fetch_unit_psth(nostim_left_name)
    nostim_right = _fetch_unit_psth(nostim_right_name)

    # with photostim
    stim_left_name = _first_cond_name(condition_name_kw + ['_stim_left'])
    stim_right_name = _first_cond_name(condition_name_kw + ['_stim_right'])
    stim_left = _fetch_unit_psth(stim_left_name)
    stim_right = _fetch_unit_psth(stim_right_name)

    # start times of the sample, delay and go events
    _, event_times = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    # get photostim onset and duration
    stim_cond_name = _first_cond_name(condition_name_kw + ['_stim'])
    stim_trials = psth.TrialCondition().get_trials(stim_cond_name)
    stim_durs = np.unique(
        (experiment.Photostim
         & experiment.PhotostimEvent * stim_trials
         & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)
    stim_time = _get_stim_onset_time(units, stim_cond_name)

    # map left/right trials onto ipsi/contra relative to the recorded hemisphere
    if hemi == 'left':
        ctrl_ipsi, ctrl_contra = nostim_left, nostim_right
        stim_ipsi, stim_contra = stim_left, stim_right
    else:
        ctrl_ipsi, ctrl_contra = nostim_right, nostim_left
        stim_ipsi, stim_contra = stim_right, stim_left

    _plot_avg_psth(ctrl_ipsi, ctrl_contra, event_times, axs[0], 'Control')
    _plot_avg_psth(stim_ipsi, stim_contra, event_times, axs[1], 'Photostim')

    # cosmetic - same y-range (0 to the larger upper limit) and x-range on both
    shared_ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, shared_ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])

    # add shaded bar for photostim
    axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')

    return fig
def plot_unit_bilateral_photostim_effect(probe_insertion, clustering_method=None, axs=None):
    """
    For the units of one probe insertion, plot each unit at its (x, y)
    position with a marker size reflecting the relative decrease in firing
    rate under bilateral ALM photostim vs. no-stim trials.

    For each unit (quality other than "all"), the per-trial mean firing rate
    is computed in the window [delay onset, delay onset + photostim
    duration] for the 'all_noearlylick_nostim' and
    'all_noearlylick_both_alm_stim' trial conditions. The fractional change
    (stim - control) / control is kept (as an absolute value) only when
    negative; units without a decrease get a floor value of 0.0001. Values
    are then normalized by the maximum across units.

    probe_insertion: query of a single probe insertion
    clustering_method: (optional) clustering method of the units to plot -
        when None it is resolved via `_get_clustering_method`
    axs: (optional) a single matplotlib axes - when None, a new 4x8 figure
        is created

    Returns the newly created figure, or None when `axs` was supplied.
    Raises ValueError when `clustering_method` is None and cannot be resolved.
    """
    probe_insertion = probe_insertion.proj()

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(
                str(e) + '\nPlease specify one with the kwarg "clustering_method"')

    dv_loc = (ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch1('depth')

    no_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name': 'all_noearlylick_nostim'}).fetch1('KEY')
    bi_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name': 'all_noearlylick_both_alm_stim'}).fetch1('KEY')

    # get photostim duration
    stim_durs = np.unique(
        (experiment.Photostim
         & experiment.PhotostimEvent * psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & probe_insertion).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    units = (ephys.Unit & probe_insertion
             & {'clustering_method': clustering_method}
             & 'unit_quality != "all"')

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # start time of the 'delay' event - start of the analysis window
    _, cue_onset = _get_trial_event_times(['delay'], units, 'all_noearlylick_nostim')
    cue_onset = cue_onset[0]

    # XXX: could be done with 1x fetch+join
    for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):
        # exact match: the previous `in ('kilosort2')` was a substring test
        # against a plain string (no trailing comma, so not a tuple), which
        # also matched e.g. 'kilosort' or ''
        if clustering_method == 'kilosort2':
            # position taken from the unit's electrode coordinates
            x, y = (ephys.Unit * lab.ElectrodeConfig.Electrode.proj()
                    * lab.ProbeType.Electrode.proj('x_coord', 'y_coord')
                    & unit).fetch1('x_coord', 'y_coord')
        else:
            x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

        # obtain unit psth per trial, for all nostim and bistim trials
        nostim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(
            no_stim_cond['trial_condition_name'])
        bistim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(
            bi_stim_cond['trial_condition_name'])

        nostim_psths, nostim_edge = psth.compute_unit_psth(
            unit, nostim_trials.fetch('KEY'), per_trial=True)
        bistim_psths, bistim_edge = psth.compute_unit_psth(
            unit, bistim_trials.fetch('KEY'), per_trial=True)

        # compute the firing rate difference between photostim vs. control
        # trials within the stimulation duration
        nostim_window = np.logical_and(nostim_edge >= cue_onset,
                                       nostim_edge <= cue_onset + stim_dur)
        bistim_window = np.logical_and(bistim_edge >= cue_onset,
                                       bistim_edge <= cue_onset + stim_dur)
        ctrl_frate = np.array([trial_psth[nostim_window].mean()
                               for trial_psth in nostim_psths])
        stim_frate = np.array([trial_psth[bistim_window].mean()
                               for trial_psth in bistim_psths])

        frate_change = (stim_frate.mean() - ctrl_frate.mean()) / ctrl_frate.mean()
        # keep only decreases in firing rate; anything else (increase, or a
        # NaN from an empty/zero control rate) gets a small floor value
        frate_change = abs(frate_change) if frate_change < 0 else 0.0001

        metrics.loc[u_idx] = (int(unit['unit']), x, float(dv_loc) + y, frate_change)

    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    # --- prepare for plotting
    shank_count = (ephys.ProbeInsertion & probe_insertion).aggr(
        lab.ElectrodeConfig.Electrode * lab.ProbeType.Electrode,
        shank_count='count(distinct shank)').fetch1('shank_count')
    m_scale = get_m_scale(shank_count)

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(4, 8))

    xmax = 1.3 * metrics.x.max()
    xmin = -1 / 6 * xmax

    cosmetic = {
        'legend': None,
        'linewidth': 1.75,
        'alpha': 0.9,
        'facecolor': 'none',
        'edgecolor': 'k',
    }

    sns.scatterplot(data=metrics, x='x', y='y',
                    s=metrics.frate_change * m_scale, ax=axs, **cosmetic)

    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((xmin, xmax))

    return fig