def _get_trial_event_times(events, units, trial_cond_name):
    """
    Median onset of each requested trial event, expressed relative to the go cue.

    Trial events of the requested types (plus 'go') are fetched for all trials of
    the trial condition `trial_cond_name` restricted by `units`. The go-cue time is
    subtracted from each event time, and the median (NaN-ignoring) is taken.

    :param events: iterable of trial event type names, e.g. ['sample', 'delay', 'go']
    :param units: query used to restrict the trials (e.g. a set of ephys units)
    :param trial_cond_name: name of the psth.TrialCondition whose trials are used
    :return: list of median (event - go) offsets, one per entry of `events`, in order
    """
    requested = list(events)
    wanted_types = [{'trial_event_type': name} for name in requested + ['go']]

    trial_events = (psth.TrialCondition().get_trials(trial_cond_name)
                    * (experiment.TrialEvent & wanted_types)
                    & units)
    event_types, event_times = trial_events.fetch('trial_event_type', 'trial_event_time')

    # NOTE(review): the element-wise subtraction pairs the i-th event of each type
    # with the i-th go cue, so it assumes exactly one event of each type per trial
    # and a consistent fetch order across types — confirm against the schema
    go_times = event_times[event_types == 'go']

    medians = []
    for name in requested:
        offsets = (event_times[event_types == name] - go_times).astype(float)
        medians.append(np.nanmedian(offsets))
    return medians
def plot_psth_photostim_effect(units, condition_name_kw=None, axs=None):
    """
    For the specified `units`, plot PSTH comparison between stim vs. no-stim with
    left/right trial instruction.

    The stim location (or other appropriate search keywords) can be specified in
    `condition_name_kw` (default: ['both_alm'], i.e. both ALM).

    :param units: query of ephys units whose PSTHs are averaged
    :param condition_name_kw: list of keywords used to look up the photostim trial
        conditions; None means ['both_alm']
    :param axs: array of exactly 2 matplotlib axes (control, photostim); a new 1x2
        figure is created when None
    :return: the newly created figure, or None when `axs` was supplied
    """
    # a None default avoids sharing one mutable list object between calls
    condition_name_kw = ['both_alm'] if condition_name_kw is None else list(condition_name_kw)

    units = units.proj()

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    hemi = _get_units_hemisphere(units)

    def _fetch_unit_psths(cond_name):
        # non-NULL PSTHs of `units` for a single trial condition
        return (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': cond_name}
                & 'unit_psth is not NULL').fetch('unit_psth')

    # no photostim
    psth_n_l = _fetch_unit_psths(
        psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_left'])[0])
    psth_n_r = _fetch_unit_psths(
        psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_right'])[0])

    # with photostim
    psth_s_l = _fetch_unit_psths(
        psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim_left'])[0])
    psth_s_r = _fetch_unit_psths(
        psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim_right'])[0])

    # event start times (sample, delay, go) relative to the go cue.
    # _get_trial_event_times returns only the list of starts (not a (names, starts)
    # pair), so it must not be unpacked into two names.
    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units,
                                           'good_noearlylick_hit')

    # photostim onset and duration
    stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['_stim'])[0]
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent
         * psth.TrialCondition().get_trials(stim_trial_cond_name)
         & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)
    stim_time = _get_stim_onset_time(units, stim_trial_cond_name)

    # ipsi/contra assignment relative to the hemisphere from _get_units_hemisphere
    if hemi == 'left':
        psth_s_i, psth_n_i, psth_s_c, psth_n_c = psth_s_l, psth_n_l, psth_s_r, psth_n_r
    else:
        psth_s_i, psth_n_i, psth_s_c, psth_n_c = psth_s_r, psth_n_r, psth_s_l, psth_n_l

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0], 'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1], 'Photostim')

    # cosmetic: shared y-range starting at 0 and common x-range on both panels
    ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])

    # add shaded bar for photostim
    axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')

    return fig
def plot_psth_bilateral_photostim_effect(units, axs=None):
    """
    Plot averaged ipsi/contra PSTHs of `units` for control vs. bilateral ALM
    photostim trials, with a shaded bar marking the photostim window.

    :param units: query of ephys units whose PSTHs are averaged
    :param axs: array of exactly 2 matplotlib axes (control, photostim); a new 1x2
        figure is created when None
    :return: the newly created figure, or None when `axs` was supplied
    """
    units = units.proj()

    hemi = _get_units_hemisphere(units)

    def _fetch_unit_psths(cond_name):
        # PSTHs of `units` for a single trial condition
        return (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': cond_name}).fetch('unit_psth')

    psth_s_l = _fetch_unit_psths('all_noearlylick_both_alm_stim_left')
    psth_n_l = _fetch_unit_psths('all_noearlylick_both_alm_nostim_left')
    psth_s_r = _fetch_unit_psths('all_noearlylick_both_alm_stim_right')
    psth_n_r = _fetch_unit_psths('all_noearlylick_both_alm_nostim_right')

    # event start times (sample, delay, response) relative to the go cue.
    # _get_trial_event_times returns only the list of starts, in the order requested,
    # so the names are kept locally rather than unpacked from its result.
    period_names = ['sample', 'delay', 'go']
    period_starts = _get_trial_event_times(period_names, units, 'good_noearlylick_hit')

    # get photostim duration
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent
         * psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    # ipsi/contra assignment relative to the hemisphere from _get_units_hemisphere
    if hemi == 'left':
        psth_s_i, psth_n_i, psth_s_c, psth_n_c = psth_s_l, psth_n_l, psth_s_r, psth_n_r
    else:
        psth_s_i, psth_n_i, psth_s_c, psth_n_c = psth_s_r, psth_n_r, psth_s_l, psth_n_l

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0], 'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1], 'Bilateral ALM photostim')

    # cosmetic: shared y-range starting at 0
    ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, ymax))

    # add shaded bar for photostim, starting at the median delay onset
    stim_time = period_starts[period_names.index('delay')]
    axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')

    return fig
def plot_unit_bilateral_photostim_effect(probe_insertion, clustering_method=None, axs=None):
    """
    Scatter-plot every unit of a probe insertion at its (x, depth-adjusted y)
    position, with marker size proportional to the firing-rate decrease caused by
    bilateral ALM photostimulation (normalized to the largest decrease).

    :param probe_insertion: query identifying the probe insertion
    :param clustering_method: clustering method of the units; inferred with
        _get_clustering_method when None
    :param axs: a single matplotlib axis; a new figure is created when None
    :return: the newly created figure, or None when `axs` was supplied
    :raises PhotostimError: if no bilateral ALM photostim trials exist for the insertion
    :raises ValueError: if `clustering_method` is None and cannot be inferred
    """
    probe_insertion = probe_insertion.proj()

    if not (psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim') & probe_insertion):
        raise PhotostimError('No Bilateral ALM Photo-stimulation present')

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(
                str(e) + '\nPlease specify one with the kwarg "clustering_method"') from e

    dv_loc = (ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch1('depth')

    no_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name': 'all_noearlylick_nostim'}).fetch1('KEY')
    bi_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name': 'all_noearlylick_both_alm_stim'}).fetch1('KEY')

    units = (ephys.Unit & probe_insertion & {'clustering_method': clustering_method}
             & 'unit_quality != "all"')

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # get photostim onset and duration
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent
         * psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & probe_insertion).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)
    stim_time = _get_stim_onset_time(units, 'all_noearlylick_both_alm_stim')

    # the trial sets do not depend on the unit, so build them once
    nostim_trial_set = psth.TrialCondition.get_trials(no_stim_cond['trial_condition_name'])
    bistim_trial_set = psth.TrialCondition.get_trials(bi_stim_cond['trial_condition_name'])

    # XXX: could be done with 1x fetch+join
    for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):
        # exact match: a parenthesized string is NOT a tuple, so the previous
        # `in ('kilosort2')` performed a substring test
        if clustering_method in ('kilosort2',):
            x, y = (ephys.Unit * lab.ElectrodeConfig.Electrode.proj()
                    * lab.ProbeType.Electrode.proj('x_coord', 'y_coord')
                    & unit).fetch1('x_coord', 'y_coord')
        else:
            x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

        # obtain unit psth per trial, for all nostim and bistim trials
        nostim_trials = ephys.Unit.TrialSpikes & unit & nostim_trial_set
        bistim_trials = ephys.Unit.TrialSpikes & unit & bistim_trial_set

        nostim_psths, nostim_edge = psth.compute_unit_psth(
            unit, nostim_trials.fetch('KEY'), per_trial=True)
        bistim_psths, bistim_edge = psth.compute_unit_psth(
            unit, bistim_trials.fetch('KEY'), per_trial=True)

        # compare stim vs. control firing rate within the photostim time window
        ctrl_window = np.logical_and(nostim_edge >= stim_time,
                                     nostim_edge <= stim_time + stim_dur)
        stim_window = np.logical_and(bistim_edge >= stim_time,
                                     bistim_edge <= stim_time + stim_dur)
        ctrl_frate = np.array([trial_psth[ctrl_window].mean() for trial_psth in nostim_psths])
        stim_frate = np.array([trial_psth[stim_window].mean() for trial_psth in bistim_psths])

        frate_change = (stim_frate.mean() - ctrl_frate.mean()) / ctrl_frate.mean()
        # only decreases are sized by their magnitude; anything else gets a tiny marker
        frate_change = abs(frate_change) if frate_change < 0 else 0.0001

        metrics.loc[u_idx] = (int(unit['unit']), x, float(dv_loc) + y, frate_change)

    # normalize marker sizes to the largest decrease
    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    # --- prepare for plotting
    shank_count = (ephys.ProbeInsertion & probe_insertion).aggr(
        lab.ElectrodeConfig.Electrode * lab.ProbeType.Electrode,
        shank_count='count(distinct shank)').fetch1('shank_count')
    m_scale = get_m_scale(shank_count)

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(4, 8))

    xmax = 1.3 * metrics.x.max()
    xmin = -1 / 6 * xmax

    cosmetic = {'legend': None,
                'linewidth': 1.75,
                'alpha': 0.9,
                'facecolor': 'none', 'edgecolor': 'k'}

    sns.scatterplot(data=metrics, x='x', y='y', s=metrics.frate_change * m_scale,
                    ax=axs, **cosmetic)

    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((xmin, xmax))

    return fig
def plot_unit_bilateral_photostim_effect(probe_insertion, clustering_method=None, axs=None):
    """
    Scatter-plot every unit of a probe insertion at its (x, y - dv_location)
    position, with marker size proportional to the firing-rate decrease caused by
    bilateral ALM photostimulation starting at the delay onset (normalized to the
    largest decrease).

    :param probe_insertion: query identifying the probe insertion
    :param clustering_method: clustering method of the units; inferred with
        _get_clustering_method when None
    :param axs: a single matplotlib axis; a new figure is created when None
    :return: the newly created figure, or None when `axs` was supplied
    :raises ValueError: if `clustering_method` is None and cannot be inferred
    """
    probe_insertion = probe_insertion.proj()

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(
                str(e) + '\nPlease specify one with the kwarg "clustering_method"') from e

    dv_loc = (ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch1('dv_location')

    no_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name': 'all_noearlylick_both_alm_nostim'}).fetch1('KEY')
    bi_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name': 'all_noearlylick_both_alm_stim'}).fetch1('KEY')

    # get photostim duration
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent
         * psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & probe_insertion).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    units = (ephys.Unit & probe_insertion & {'clustering_method': clustering_method}
             & 'unit_quality != "all"')

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # median delay onset (relative to the go cue). _get_trial_event_times returns a
    # plain list with one start per requested event, so index it instead of unpacking.
    cue_onset = _get_trial_event_times(['delay'], units, 'all_noearlylick_both_alm_nostim')[0]

    # the trial sets do not depend on the unit, so build them once
    nostim_trial_set = psth.TrialCondition.get_trials(no_stim_cond['trial_condition_name'])
    bistim_trial_set = psth.TrialCondition.get_trials(bi_stim_cond['trial_condition_name'])

    # XXX: could be done with 1x fetch+join
    for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):
        x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

        # obtain unit psth per trial, for all nostim and bistim trials
        nostim_trials = ephys.Unit.TrialSpikes & unit & nostim_trial_set
        bistim_trials = ephys.Unit.TrialSpikes & unit & bistim_trial_set

        nostim_psths, nostim_edge = psth.compute_unit_psth(
            unit, nostim_trials.fetch('KEY'), per_trial=True)
        bistim_psths, bistim_edge = psth.compute_unit_psth(
            unit, bistim_trials.fetch('KEY'), per_trial=True)

        # compare stim vs. control firing rate within the stimulation window
        ctrl_window = np.logical_and(nostim_edge >= cue_onset,
                                     nostim_edge <= cue_onset + stim_dur)
        stim_window = np.logical_and(bistim_edge >= cue_onset,
                                     bistim_edge <= cue_onset + stim_dur)
        ctrl_frate = np.array([trial_psth[ctrl_window].mean() for trial_psth in nostim_psths])
        stim_frate = np.array([trial_psth[stim_window].mean() for trial_psth in bistim_psths])

        frate_change = (stim_frate.mean() - ctrl_frate.mean()) / ctrl_frate.mean()
        # only decreases are sized by their magnitude; anything else gets a tiny marker
        frate_change = abs(frate_change) if frate_change < 0 else 0.0001

        # float() keeps the subtraction numeric if dv_location is fetched as a Decimal
        metrics.loc[u_idx] = (int(unit['unit']), x, y - float(dv_loc), frate_change)

    # normalize marker sizes to the largest decrease
    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(4, 8))

    cosmetic = {'legend': None,
                'linewidth': 1.75,
                'alpha': 0.9,
                'facecolor': 'none', 'edgecolor': 'k'}

    # NOTE(review): m_scale is not defined in this function; assumed to be a
    # module-level constant — confirm it exists
    sns.scatterplot(data=metrics, x='x', y='y', s=metrics.frate_change * m_scale,
                    ax=axs, **cosmetic)

    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((-10, 60))

    return fig
def plot_psth_bilateral_photostim_effect(probe_insert_key, axs=None):
    """
    Plot averaged ipsi/contra PSTHs of one probe insertion for control vs.
    bilateral ALM photostim trials, shading the photostim window from delay onset.

    :param probe_insert_key: restriction identifying the probe insertion
    :param axs: array of exactly 2 matplotlib axes (control, photostim); a new 1x2
        figure is created when None
    """
    if axs is None:
        _, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    location = (ephys.ProbeInsertion.InsertionLocation * experiment.BrainLocation
                & probe_insert_key).fetch1()

    period_starts = (experiment.Period
                     & 'period in ("sample", "delay", "response")').fetch('period_start')

    # (stimulated?, instructed side) -> trial condition name
    condition_of = {
        ('stim', 'left'): 'all_noearlylick_both_alm_stim_left',
        ('nostim', 'left'): 'all_noearlylick_both_alm_nostim_left',
        ('stim', 'right'): 'all_noearlylick_both_alm_stim_right',
        ('nostim', 'right'): 'all_noearlylick_both_alm_nostim_right',
    }
    unit_psths = {
        group: (psth.UnitPsth * psth.TrialCondition & probe_insert_key
                & {'trial_condition_name': cond_name}).fetch('unit_psth')
        for group, cond_name in condition_of.items()
    }

    # photostim duration
    durations = (experiment.Photostim & experiment.PhotostimEvent
                 * psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
                 & probe_insert_key).fetch('duration')
    stim_dur = _extract_one_stim_dur(np.unique(durations))

    # ipsi side follows the insertion hemisphere; anything but 'left' counts as right
    ipsi, contra = ('left', 'right') if location['hemisphere'] == 'left' else ('right', 'left')

    control_ax, stim_ax = axs[0], axs[1]
    _plot_avg_psth(unit_psths[('nostim', ipsi)], unit_psths[('nostim', contra)],
                   period_starts, control_ax, 'Control')
    _plot_avg_psth(unit_psths[('stim', ipsi)], unit_psths[('stim', contra)],
                   period_starts, stim_ax, 'Bilateral ALM photostim')

    # cosmetic: same y-range on both panels, starting at 0
    top = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, top))

    # add shaded bar for photostim
    # TODO: use from period_starts
    delay_start = (experiment.Period & 'period = "delay"').fetch1('period_start')
    stim_ax.axvspan(delay_start, delay_start + stim_dur, alpha=0.3, color='royalblue')
def plot_unit_bilateral_photostim_effect(probe_insert_key, axs=None):
    """
    Scatter-plot every unit of a probe insertion at its (x, y) position, with
    marker size proportional to the relative absolute firing-rate change between
    bilateral ALM photostim and control PSTHs inside the stimulation window that
    starts at the delay period (normalized to the largest change).

    :param probe_insert_key: restriction identifying the probe insertion
    :param axs: a single matplotlib axis; a new figure is created when None
    """
    cue_onset = (experiment.Period & 'period = "delay"').fetch1('period_start')

    no_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name': 'all_noearlylick_both_alm_nostim'}).fetch1('KEY')
    bi_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name': 'all_noearlylick_both_alm_stim'}).fetch1('KEY')

    # photostim duration
    durations = (experiment.Photostim & experiment.PhotostimEvent
                 * psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
                 & probe_insert_key).fetch('duration')
    stim_dur = _extract_one_stim_dur(np.unique(durations))

    window_end = cue_onset + stim_dur

    def _window_mean(rates, edges):
        # mean rate over bins whose right edge lies in [cue_onset, window_end]
        right_edges = edges[1:]
        in_window = (right_edges >= cue_onset) & (right_edges <= window_end)
        return rates[in_window].mean()

    selected_units = ephys.Unit & probe_insert_key & 'unit_quality != "all"'

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # XXX: could be done with 1x fetch+join
    for row, unit_key in enumerate(selected_units.fetch('KEY')):
        pos_x, pos_y = (ephys.Unit & unit_key).fetch1('unit_posx', 'unit_posy')

        ctrl_rates, ctrl_edges = (psth.UnitPsth
                                  & {**unit_key, **no_stim_cond}).fetch1('unit_psth')
        stim_rates, stim_edges = (psth.UnitPsth
                                  & {**unit_key, **bi_stim_cond}).fetch1('unit_psth')

        # relative stim vs. control difference within the stimulation window
        ctrl_mean = _window_mean(ctrl_rates, ctrl_edges)
        stim_mean = _window_mean(stim_rates, stim_edges)
        relative_change = np.abs(stim_mean - ctrl_mean) / ctrl_mean

        metrics.loc[row] = (int(unit_key['unit']), pos_x, pos_y, relative_change)

    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    if axs is None:
        _, axs = plt.subplots(1, 1, figsize=(4, 8))

    style_kwargs = dict(legend=None, linewidth=1.75, alpha=0.9,
                        facecolor='none', edgecolor='k')

    # NOTE(review): m_scale is not defined here; assumed to be a module-level constant
    sns.scatterplot(data=metrics, x='x', y='y', s=metrics.frate_change * m_scale,
                    ax=axs, **style_kwargs)

    for side in ('right', 'top'):
        axs.spines[side].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((-10, 60))
def plot_selectivity_change_photostim_effect(units, condition_name_kw, recover_time_window=None, ax=None):
    """
    Plot the mean (with SEM) photostim-induced selectivity change across units.

    For each unit in `units` not labelled "non-selective", extract:
        + control, left-instruct PSTH (ctrl_left)
        + control, right-instruct PSTH (ctrl_right)
        + stim, left-instruct PSTH (stim_left)
        + stim, right-instruct PSTH (stim_right)

    Selectivity is the preferred-minus-non-preferred PSTH. The preferred side is the
    unit's own hemisphere for ipsi-selective units and the opposite side for
    contra-selective units, e.g.:
        ipsi-selective unit on the left hemisphere:   ctrl_left - ctrl_right
                                                      (stim: stim_left - stim_right)
        contra-selective unit on the left hemisphere: ctrl_right - ctrl_left
                                                      (stim: stim_right - stim_left)
    Selectivity change is then: control_PSTH - stim_PSTH.

    Units are skipped when they have fewer than 5 control trials or fewer than 2 stim
    trials for either instruction, when their stim PSTHs cannot be retrieved, or when
    their selectivity label is neither ipsi- nor contra-selective.

    :param units: query of ephys units
    :param condition_name_kw: list of keywords used to look up the photostim trial
        conditions (e.g. ['both_alm'])
    :param recover_time_window: optional (start, end). If given, 1000 bootstrap
        resamples find the first time inside the window at which the selectivity
        change, normalized by control selectivity, drops below 0.2. The mean ± std of
        these times is marked in green.
    :param ax: matplotlib axis; a new figure is created when None
    :raises ValueError: if no unit satisfies the inclusion criteria
    """
    trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        ['good_noearlylick_', '_hit'])[0]
    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, trial_cond_name)

    stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['_stim'])[0]
    stim_time, stim_dur = _get_photostim_time_and_duration(
        units, psth.TrialCondition().get_trials(stim_trial_cond_name))

    ctrl_left_cond_name = 'all_noearlylick_nostim_left'
    ctrl_right_cond_name = 'all_noearlylick_nostim_right'
    stim_left_cond_name = psth.TrialCondition().get_cond_name_from_keywords(
        condition_name_kw + ['noearlylick', 'stim', 'left'])[0]
    stim_right_cond_name = psth.TrialCondition().get_cond_name_from_keywords(
        condition_name_kw + ['noearlylick', 'stim', 'right'])[0]

    delta_sels, ctrl_psths = [], []
    t_vec = None
    selective_units = (units * psth.UnitSelectivity
                       & 'unit_selectivity != "non-selective"').proj('unit_selectivity')
    for unit in selective_units.fetch(as_dict=True):
        # only ipsi/contra labels define a preferred side; any other label would
        # otherwise reuse the previous unit's PSTH differences
        if unit['unit_selectivity'] not in ('ipsi-selective', 'contra-selective'):
            continue

        # ---- trial count criteria ----
        # no less than 5 trials for control
        if (len(psth.TrialCondition.get_trials(ctrl_left_cond_name) & unit) < 5
                or len(psth.TrialCondition.get_trials(ctrl_right_cond_name) & unit) < 5):
            continue
        # no less than 2 trials for stimulation
        if (len(psth.TrialCondition.get_trials(stim_left_cond_name) & unit) < 2
                or len(psth.TrialCondition.get_trials(stim_right_cond_name) & unit) < 2):
            continue

        hemi = _get_units_hemisphere(unit)

        ctrl_left_psth, t_vec = psth.UnitPsth.get_plotting_data(
            unit, {'trial_condition_name': ctrl_left_cond_name})['psth']
        ctrl_right_psth, _ = psth.UnitPsth.get_plotting_data(
            unit, {'trial_condition_name': ctrl_right_cond_name})['psth']
        try:
            stim_left_psth, _ = psth.UnitPsth.get_plotting_data(
                unit, {'trial_condition_name': stim_left_cond_name})['psth']
            stim_right_psth, _ = psth.UnitPsth.get_plotting_data(
                unit, {'trial_condition_name': stim_right_cond_name})['psth']
        except Exception:
            # best-effort: skip units whose stim PSTHs are unavailable (but no longer
            # swallow KeyboardInterrupt/SystemExit as the bare `except:` did)
            continue

        if unit['unit_selectivity'] == 'ipsi-selective':
            left_preferred = hemi == 'left'
        else:  # contra-selective
            left_preferred = hemi == 'right'

        if left_preferred:
            ctrl_psth_diff = ctrl_left_psth - ctrl_right_psth
            stim_psth_diff = stim_left_psth - stim_right_psth
        else:
            ctrl_psth_diff = ctrl_right_psth - ctrl_left_psth
            stim_psth_diff = stim_right_psth - stim_left_psth

        ctrl_psths.append(ctrl_psth_diff)
        delta_sels.append(ctrl_psth_diff - stim_psth_diff)

    if not delta_sels:
        raise ValueError('No unit satisfies the selectivity and trial-count criteria')

    ctrl_psths = np.vstack(ctrl_psths)
    delta_sels = np.vstack(delta_sels)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(4, 6))

    _plot_with_sem(delta_sels, t_vec, ax)

    if recover_time_window:
        n_units = delta_sels.shape[0]
        in_window = (t_vec > recover_time_window[0]) & (t_vec < recover_time_window[1])
        recovery_times = []
        for _ in range(1000):
            i_sample = np.random.choice(n_units, n_units, replace=True)
            btstrp_diff = (np.nanmean(delta_sels[i_sample, :], axis=0)
                           / np.nanmean(ctrl_psths[i_sample, :], axis=0))
            t_recovered = t_vec[(btstrp_diff < 0.2) & in_window]
            if len(t_recovered) > 0:
                recovery_times.append(t_recovered[0])

        # without any recovered resample the mean/std would be NaN; draw nothing
        if recovery_times:
            recovery_mean = np.mean(recovery_times)
            recovery_std = np.std(recovery_times)
            ax.axvline(x=recovery_mean, linestyle='--', color='g')
            ax.axvspan(recovery_mean - recovery_std, recovery_mean + recovery_std,
                       alpha=0.2, color='g')

    ax.axhline(y=0, color='k')
    for x in period_starts:
        ax.axvline(x=x, linestyle='--', color='k')
    # add shaded bar for photostim
    ax.axvspan(stim_time, stim_time + stim_dur, 0.95, 1, alpha=0.3, color='royalblue')

    ax.set_ylabel('Selectivity change (spike/s)')
    ax.set_xlabel('Time (s)')
def plot_psth_photostim_effect(units, condition_name_kw=None, axs=None):
    """
    For the specified `units`, plot PSTH comparison between stim vs. no-stim with
    left/right trial instruction.

    The stim location (or other appropriate search keywords) can be specified in
    `condition_name_kw` (default: ['both_alm'], i.e. bilateral ALM).

    :param units: query of ephys units whose PSTHs are averaged
    :param condition_name_kw: list of keywords used to look up the photostim trial
        conditions; None means ['both_alm']
    :param axs: array of exactly 2 matplotlib axes (control, photostim); a new 1x2
        figure is created when None
    """
    # a None default avoids sharing one mutable list object between calls
    condition_name_kw = ['both_alm'] if condition_name_kw is None else list(condition_name_kw)

    units = units.proj()

    if axs is None:
        _, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    hemi = _get_units_hemisphere(units)

    period_starts = (experiment.Period
                     & 'period in ("sample", "delay", "response")').fetch('period_start')

    def _fetch_unit_psths(cond_name):
        # non-NULL PSTHs of `units` for a single trial condition
        return (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': cond_name}
                & 'unit_psth is not NULL').fetch('unit_psth')

    # no photostim
    psth_n_l = _fetch_unit_psths(
        psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_left'])[0])
    psth_n_r = _fetch_unit_psths(
        psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_right'])[0])

    # with photostim
    psth_s_l = _fetch_unit_psths(
        psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim_left'])[0])
    psth_s_r = _fetch_unit_psths(
        psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim_right'])[0])

    # get photostim duration and stim time (relative to go-cue)
    stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['_stim'])[0]
    stim_time, stim_dur = _get_photostim_time_and_duration(
        units, psth.TrialCondition().get_trials(stim_trial_cond_name))

    # ipsi/contra assignment relative to the hemisphere from _get_units_hemisphere
    if hemi == 'left':
        psth_s_i, psth_n_i, psth_s_c, psth_n_c = psth_s_l, psth_n_l, psth_s_r, psth_n_r
    else:
        psth_s_i, psth_n_i, psth_s_c, psth_n_c = psth_s_r, psth_n_r, psth_s_l, psth_n_l

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0], 'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1], 'Photostim')

    # cosmetic: shared y-range starting at 0 and common x-range on both panels
    ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])

    # add shaded bar for photostim
    axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')