Example #1
0
def plot_avg_contra_ipsi_psth(units, axs=None):
    """
    Plot the average PSTH of contra- and ipsi-selective units, each split
    into ipsi- vs. contra-instructed "good_noearlylick" hit trials.

    Only units passing the `unit_quality != "all"` restriction on
    `ephys.Unit` are included.

    :param units: query expression selecting the units to include
        (projected to its primary key before use)
    :param axs: optional array of exactly two axes; a new 1x2 figure is
        created when omitted. axs[0] shows the contra-selective units,
        axs[1] the ipsi-selective units.
    """
    units = units.proj()

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units,
                                           'good_noearlylick_hit')

    hemi = _get_units_hemisphere(units)

    # the ipsi trial condition is the one on the same side as `hemi`
    if hemi == 'left':
        ipsi_cond_name = 'good_noearlylick_left_hit'
        contra_cond_name = 'good_noearlylick_right_hit'
    else:
        ipsi_cond_name = 'good_noearlylick_right_hit'
        contra_cond_name = 'good_noearlylick_left_hit'

    conds_i = (psth.TrialCondition
               & {'trial_condition_name': ipsi_cond_name}).fetch('KEY')
    conds_c = (psth.TrialCondition
               & {'trial_condition_name': contra_cond_name}).fetch('KEY')

    good_unit_keys = (ephys.Unit & 'unit_quality != "all"').proj()
    unit_depths = ephys.Unit.proj('unit_posy')
    ipsi_sel_keys = (ephys.Unit * psth.UnitSelectivity
                     & 'unit_selectivity = "ipsi-selective"' & units).proj()
    contra_sel_keys = (ephys.Unit * psth.UnitSelectivity
                       & 'unit_selectivity = "contra-selective"'
                       & units).proj()

    def fetch_psths(conds, sel_keys):
        # PSTHs of the good `sel_keys` units under `conds`,
        # ordered by descending unit_posy
        query = (((psth.UnitPsth & conds) * unit_depths)
                 & good_unit_keys & sel_keys)
        return query.fetch('unit_psth', order_by='unit_posy desc')

    psth_is_it = fetch_psths(conds_i, ipsi_sel_keys)    # ipsi-sel, ipsi trials
    psth_is_ct = fetch_psths(conds_c, ipsi_sel_keys)    # ipsi-sel, contra trials
    psth_cs_ct = fetch_psths(conds_c, contra_sel_keys)  # contra-sel, contra trials
    psth_cs_it = fetch_psths(conds_i, contra_sel_keys)  # contra-sel, ipsi trials

    _plot_avg_psth(psth_cs_it, psth_cs_ct, period_starts, axs[0],
                   'Contra-selective')
    _plot_avg_psth(psth_is_it, psth_is_ct, period_starts, axs[1],
                   'Ipsi-selective')

    # share one y range (starting at 0) and the module x range across panels
    ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])
Example #2
0
def plot_unit_psth(unit_key,
                   condition_name_kw=['good_noearlylick_', '_hit'],
                   axs=None,
                   title='',
                   xlim=_plt_xlim):
    """
    Default raster and PSTH plot for a specified unit - only {good, no early lick, correct trials} selected

    condition_name_kw: list of keywords to match for the TrialCondition name.
        It is only read here (new lists are built with `+`), so the shared
        list default is never mutated.
    axs: optional pair of axes (raster, PSTH); a new 2x1 figure is created
        when None.
    title: raster title; defaults to 'Unit #: <unit>' when empty.
    xlim: x-axis limits applied to both panels.

    The photostim shaded bar is best-effort: if looking up the matching
    '_stim' condition or its photostim timing raises, the plot is drawn
    without the bar.
    """
    hemi = _get_units_hemisphere(unit_key)

    # ipsi = instruction on the same side as the unit's hemisphere
    ipsi_cond_name = TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['left' if hemi == 'left' else 'right'])[0]
    contra_cond_name = TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['right' if hemi == 'left' else 'left'])[0]

    ipsi_hit_unit_psth = UnitPsth.get_plotting_data(
        unit_key, {'trial_condition_name': ipsi_cond_name})

    contra_hit_unit_psth = UnitPsth.get_plotting_data(
        unit_key, {'trial_condition_name': contra_cond_name})

    # get event start times: sample, delay, response
    trial_cond_name = TrialCondition.get_cond_name_from_keywords(
        condition_name_kw)[0]
    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], unit_key,
                                           trial_cond_name)

    # photostim shaded bar (if applicable). Deliberately best-effort, e.g.
    # when no '_stim' condition matches the keywords. Catch Exception rather
    # than using a bare `except:` so KeyboardInterrupt/SystemExit propagate.
    try:
        stim_trial_cond_name = TrialCondition.get_cond_name_from_keywords(
            condition_name_kw + ['_stim'])[0]
        stim_bar = _get_photostim_time_and_duration(
            unit_key,
            TrialCondition().get_trials(stim_trial_cond_name))
    except Exception:
        stim_bar = None

    if axs is None:
        _, axs = plt.subplots(2, 1)

    _plot_spike_raster(ipsi_hit_unit_psth,
                       contra_hit_unit_psth,
                       ax=axs[0],
                       vlines=period_starts,
                       shade_bar=stim_bar,
                       title=title if title else f'Unit #: {unit_key["unit"]}',
                       xlim=xlim)
    _plot_psth(ipsi_hit_unit_psth,
               contra_hit_unit_psth,
               vlines=period_starts,
               shade_bar=stim_bar,
               ax=axs[1],
               xlim=xlim)
Example #3
0
def plot_psth_photostim_effect(units,
                               condition_name_kw=['both_alm'],
                               axs=None):
    """
    For the specified `units`, plot PSTH comparison between stim vs. no-stim with left/right trial instruction
    The stim location (or other appropriate search keywords) can be specified in `condition_name_kw` (default: bilateral ALM)

    axs: optional array of exactly two axes (control, photostim); a new 1x2
        figure is created when None.
    Returns the newly created figure, or None when `axs` was supplied.
    """
    units = units.proj()

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    hemi = _get_units_hemisphere(units)

    find_cond = psth.TrialCondition.get_cond_name_from_keywords

    def fetch_psth(cond_name):
        # non-NULL PSTHs of `units` for a single trial condition
        return (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': cond_name}
                & 'unit_psth is not NULL').fetch('unit_psth')

    # control (no photostim)
    nostim_left_name = find_cond(['_nostim', '_left'])[0]
    nostim_right_name = find_cond(['_nostim', '_right'])[0]
    psth_n_l = fetch_psth(nostim_left_name)
    psth_n_r = fetch_psth(nostim_right_name)

    # photostim at the location selected by `condition_name_kw`
    stim_left_name = find_cond(condition_name_kw + ['_stim_left'])[0]
    stim_right_name = find_cond(condition_name_kw + ['_stim_right'])[0]
    psth_s_l = fetch_psth(stim_left_name)
    psth_s_r = fetch_psth(stim_right_name)

    # event start times: sample, delay, response
    period_names, period_starts = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    # photostim duration, reduced to one value by _extract_one_stim_dur
    stim_trial_cond_name = find_cond(condition_name_kw + ['_stim'])[0]
    stim_trials = psth.TrialCondition().get_trials(stim_trial_cond_name)
    stim_durs = np.unique(
        (experiment.Photostim
         & (experiment.PhotostimEvent * stim_trials)
         & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    # ipsi = same side as `hemi`; any hemi other than 'left' treats right as ipsi
    psths_by_side = {'left': (psth_n_l, psth_s_l),
                     'right': (psth_n_r, psth_s_r)}
    ipsi_side, contra_side = (('left', 'right') if hemi == 'left'
                              else ('right', 'left'))
    psth_n_i, psth_s_i = psths_by_side[ipsi_side]
    psth_n_c, psth_s_c = psths_by_side[contra_side]

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0], 'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1], 'Photostim')

    # cosmetic: shared y range starting at 0, module-wide x range
    ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])

    # shaded bar for the photostim period, starting at the 'delay' event
    delay_idx = np.where(period_names == 'delay')[0][0]
    stim_time = period_starts[delay_idx]
    axs[1].axvspan(stim_time,
                   stim_time + stim_dur,
                   alpha=0.3,
                   color='royalblue')

    return fig
Example #4
0
def plot_psth_bilateral_photostim_effect(units, axs=None):
    """
    Compare average PSTHs of `units` between control (no-stim) and bilateral
    ALM photostim trials, each split into ipsi- vs. contra-instructed trials.

    Uses the fixed 'all_noearlylick_both_alm_{stim,nostim}_{left,right}'
    trial conditions. The photostim period is shaded on the photostim panel,
    starting at the 'delay' event time.

    :param units: query expression selecting the units (projected to its
        primary key before use)
    :param axs: optional array of exactly two axes (control, photostim);
        a new 1x2 figure is created when omitted
    :return: the newly created figure, or None when `axs` was supplied
    """
    units = units.proj()

    hemi = _get_units_hemisphere(units)

    def fetch_psth(cond_name):
        # PSTHs of `units` for a single trial condition
        return (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': cond_name}).fetch('unit_psth')

    psth_s_l = fetch_psth('all_noearlylick_both_alm_stim_left')
    psth_n_l = fetch_psth('all_noearlylick_both_alm_nostim_left')
    psth_s_r = fetch_psth('all_noearlylick_both_alm_stim_right')
    psth_n_r = fetch_psth('all_noearlylick_both_alm_nostim_right')

    # event start times: sample, delay, response
    period_names, period_starts = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    # photostim duration, reduced to one value by _extract_one_stim_dur
    stim_trials = psth.TrialCondition().get_trials(
        'all_noearlylick_both_alm_stim')
    stim_durs = np.unique(
        (experiment.Photostim
         & (experiment.PhotostimEvent * stim_trials)
         & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    # ipsi = same side as `hemi`; any hemi other than 'left' treats right as ipsi
    psths_by_side = {'left': (psth_n_l, psth_s_l),
                     'right': (psth_n_r, psth_s_r)}
    ipsi_side, contra_side = (('left', 'right') if hemi == 'left'
                              else ('right', 'left'))
    psth_n_i, psth_s_i = psths_by_side[ipsi_side]
    psth_n_c, psth_s_c = psths_by_side[contra_side]

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0], 'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1],
                   'Bilateral ALM photostim')

    # cosmetic: shared y range starting at 0
    ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, ymax))

    # shaded bar for the photostim period, starting at the 'delay' event
    delay_idx = np.where(period_names == 'delay')[0][0]
    stim_time = period_starts[delay_idx]
    axs[1].axvspan(stim_time,
                   stim_time + stim_dur,
                   alpha=0.3,
                   color='royalblue')

    return fig
Example #5
0
def plot_stacked_contra_ipsi_psth(units, axs=None):
    """
    Plot stacked (one row per unit) PSTH-difference images for the
    contra-selective (axs[0]) and ipsi-selective (axs[1]) units in `units`,
    using "good_noearlylick" hit trials and ordering units by descending
    unit_posy.

    :param units: query expression selecting the units (projected to its
        primary key before use)
    :param axs: optional array of exactly two axes; a new 1x2 figure is
        created when omitted
    :return: the newly created figure, or None when `axs` was supplied
    """
    units = units.proj()

    # event start times: sample, delay, response
    period_names, period_starts = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    hemi = _get_units_hemisphere(units)

    # the ipsi trial condition is the one on the same side as `hemi`
    ipsi_cond_name, contra_cond_name = (
        ('good_noearlylick_left_hit', 'good_noearlylick_right_hit')
        if hemi == 'left' else
        ('good_noearlylick_right_hit', 'good_noearlylick_left_hit'))

    conds_i = (psth.TrialCondition
               & {'trial_condition_name': ipsi_cond_name}).fetch1('KEY')
    conds_c = (psth.TrialCondition
               & {'trial_condition_name': contra_cond_name}).fetch1('KEY')

    sel_i = (ephys.Unit * psth.UnitSelectivity
             & 'unit_selectivity = "ipsi-selective"' & units)
    sel_c = (ephys.Unit * psth.UnitSelectivity
             & 'unit_selectivity = "contra-selective"' & units)

    def fetch_by_depth(selected, conds):
        # full PSTH rows of `selected` units under `conds`, by descending unit_posy
        return (psth.UnitPsth * selected.proj('unit_posy')
                & conds).fetch(order_by='unit_posy desc')

    psth_is_it = fetch_by_depth(sel_i, conds_i)  # ipsi-sel, ipsi trials
    psth_is_ct = fetch_by_depth(sel_i, conds_c)  # ipsi-sel, contra trials
    psth_cs_ct = fetch_by_depth(sel_c, conds_c)  # contra-sel, contra trials
    psth_cs_it = fetch_by_depth(sel_c, conds_i)  # contra-sel, ipsi trials

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(20, 20))
    assert axs.size == 2

    # (axis, first, second, extra kwargs, title); only the contra panel is flipped
    panels = (
        (axs[0], psth_cs_ct, psth_cs_it, {'flip': True},
         'Contra-selective Units'),
        (axs[1], psth_is_it, psth_is_ct, {}, 'Ipsi-selective Units'),
    )
    for ax, first, second, extra, panel_title in panels:
        _plot_stacked_psth_diff(first,
                                second,
                                ax=ax,
                                vlines=period_starts,
                                **extra)
        ax.set_title(panel_title)
        ax.set_ylabel('Unit (by depth)')
        ax.set_xlabel('Time to go (s)')

    return fig
Example #6
0
def plot_selectivity_change_photostim_effect(units,
                                             condition_name_kw,
                                             recover_time_window=None,
                                             ax=None):
    """
    For each unit in the specified units, extract:
    + control, left-instruct PSTH (ctrl_left)
    + control, right-instruct PSTH (ctrl_right)
    + stim, left-instruct PSTH (stim_left)
    + stim, right-instruct PSTH (stim_right)
    Then, control_PSTH and stim_PSTH is defined as
        (ctrl_left - ctrl_right) for ipsi-selective unit that locates on the left-hemisphere, and vice versa
        (stim_left - stim_right) for ipsi-selective unit that locates on the left-hemisphere, and vice versa
    Selectivity change is then defined as: control_PSTH - stim_PSTH

    Units are skipped when they have fewer than 5 control trials or fewer
    than 2 stim trials on either side, when their stim PSTHs cannot be
    retrieved, or when their selectivity is neither "ipsi-selective" nor
    "contra-selective".

    recover_time_window: optional (start, stop). When given, 1000 bootstrap
        resamples of the units are drawn. For each one, the first time within
        the window where mean(selectivity change) / mean(control PSTH) < 0.2
        is taken. The mean +/- std of those times is then drawn in green.
    ax: axis to draw on; a new 4x6 figure is created when None.

    Raises:
        ValueError: if no unit passes the criteria above (nothing to plot).
    """
    trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        ['good_noearlylick_', '_hit'])[0]
    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units,
                                           trial_cond_name)

    stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['_stim'])[0]
    stim_time, stim_dur = _get_photostim_time_and_duration(
        units,
        psth.TrialCondition().get_trials(stim_trial_cond_name))

    ctrl_left_cond_name = 'all_noearlylick_nostim_left'
    ctrl_right_cond_name = 'all_noearlylick_nostim_right'
    stim_left_cond_name = psth.TrialCondition().get_cond_name_from_keywords(
        condition_name_kw + ['noearlylick', 'stim', 'left'])[0]
    stim_right_cond_name = psth.TrialCondition().get_cond_name_from_keywords(
        condition_name_kw + ['noearlylick', 'stim', 'right'])[0]

    # The trial queries do not depend on the unit: build them once and only
    # restrict them per unit inside the loop.
    ctrl_left_trials = psth.TrialCondition.get_trials(ctrl_left_cond_name)
    ctrl_right_trials = psth.TrialCondition.get_trials(ctrl_right_cond_name)
    stim_left_trials = psth.TrialCondition.get_trials(stim_left_cond_name)
    stim_right_trials = psth.TrialCondition.get_trials(stim_right_cond_name)

    delta_sels, ctrl_psths = [], []
    for unit in (units * psth.UnitSelectivity
                 & 'unit_selectivity != "non-selective"'
                 ).proj('unit_selectivity').fetch(as_dict=True):
        # ---- trial count criteria ----
        # no less than 5 trials for control
        if (len(ctrl_left_trials & unit) < 5
                or len(ctrl_right_trials & unit) < 5):
            continue
        # no less than 2 trials for stimulation
        if (len(stim_left_trials & unit) < 2
                or len(stim_right_trials & unit) < 2):
            continue

        hemi = _get_units_hemisphere(unit)

        ctrl_left_psth, t_vec = psth.UnitPsth.get_plotting_data(
            unit, {'trial_condition_name': ctrl_left_cond_name})['psth']
        ctrl_right_psth, _ = psth.UnitPsth.get_plotting_data(
            unit, {'trial_condition_name': ctrl_right_cond_name})['psth']
        try:
            stim_left_psth, _ = psth.UnitPsth.get_plotting_data(
                unit, {'trial_condition_name': stim_left_cond_name})['psth']
            stim_right_psth, _ = psth.UnitPsth.get_plotting_data(
                unit, {'trial_condition_name': stim_right_cond_name})['psth']
        except Exception:
            # best-effort: skip units whose stim PSTHs cannot be retrieved
            continue

        if unit['unit_selectivity'] == 'ipsi-selective':
            ctrl_psth_diff = ctrl_left_psth - ctrl_right_psth if hemi == 'left' else ctrl_right_psth - ctrl_left_psth
            stim_psth_diff = stim_left_psth - stim_right_psth if hemi == 'left' else stim_right_psth - stim_left_psth
        elif unit['unit_selectivity'] == 'contra-selective':
            ctrl_psth_diff = ctrl_left_psth - ctrl_right_psth if hemi == 'right' else ctrl_right_psth - ctrl_left_psth
            stim_psth_diff = stim_left_psth - stim_right_psth if hemi == 'right' else stim_right_psth - stim_left_psth
        else:
            # Any other label would leave the diffs unset, or stale from the
            # previous unit, so skip it.
            continue

        ctrl_psths.append(ctrl_psth_diff)
        delta_sels.append(ctrl_psth_diff - stim_psth_diff)

    if not delta_sels:
        raise ValueError('No unit passed the selectivity and trial-count '
                         'criteria; nothing to plot')

    ctrl_psths = np.vstack(ctrl_psths)
    delta_sels = np.vstack(delta_sels)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(4, 6))

    _plot_with_sem(delta_sels, t_vec, ax)

    if recover_time_window:
        recovery_times = []
        for _ in range(1000):
            i_sample = np.random.choice(delta_sels.shape[0],
                                        delta_sels.shape[0],
                                        replace=True)
            btstrp_diff = np.nanmean(delta_sels[i_sample, :],
                                     axis=0) / np.nanmean(
                                         ctrl_psths[i_sample, :], axis=0)
            t_recovered = t_vec[(btstrp_diff < 0.2)
                                & (t_vec > recover_time_window[0]) &
                                (t_vec < recover_time_window[1])]
            if len(t_recovered) > 0:
                recovery_times.append(t_recovered[0])
        # Only mark recovery if at least one resample recovered; the mean/std
        # of an empty list would be NaN.
        if recovery_times:
            ax.axvline(x=np.mean(recovery_times), linestyle='--', color='g')
            ax.axvspan(np.mean(recovery_times) - np.std(recovery_times),
                       np.mean(recovery_times) + np.std(recovery_times),
                       alpha=0.2,
                       color='g')

    ax.axhline(y=0, color='k')
    for x in period_starts:
        ax.axvline(x=x, linestyle='--', color='k')
    # add shaded bar for photostim (top 5% of the axis height)
    ax.axvspan(stim_time,
               stim_time + stim_dur,
               0.95,
               1,
               alpha=0.3,
               color='royalblue')
    ax.set_ylabel('Selectivity change (spike/s)')
    ax.set_xlabel('Time (s)')
Example #7
0
def plot_selectivity_sorted_stacked_contra_ipsi_psth(units, axs=None):
    """
    Plot stacked PSTH-difference images for contra-selective (axs[0]) and
    ipsi-selective (axs[1]) units, grouped by period selectivity:
      i)   selective in sample or delay, not in response
      ii)  selective in sample or delay and in response
      iii) selective in response only
    A colored band at the right edge of each image marks the groups.

    :param units: query expression selecting the units (projected to its
        primary key before use)
    :param axs: optional array of exactly two axes; a new 1x2 figure is
        created when omitted
    """
    units = units.proj()

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(20, 20))
    assert axs.size == 2

    trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        ['good_noearlylick_', '_hit'])[0]
    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units,
                                           trial_cond_name)

    hemi = _get_units_hemisphere(units)

    # the ipsi trial condition is the one on the same side as `hemi`
    if hemi == 'left':
        ipsi_cond_name = 'good_noearlylick_left_hit'
        contra_cond_name = 'good_noearlylick_right_hit'
    else:
        ipsi_cond_name = 'good_noearlylick_right_hit'
        contra_cond_name = 'good_noearlylick_left_hit'

    conds_i = (psth.TrialCondition
               & {'trial_condition_name': ipsi_cond_name}).fetch1('KEY')
    conds_c = (psth.TrialCondition
               & {'trial_condition_name': contra_cond_name}).fetch1('KEY')

    # ---- period-selectivity restrictions used to group the units ----
    sample_or_delay_selective = (psth.PeriodSelectivity
                                 & 'period in ("sample", "delay")'
                                 & 'period_selectivity != "non-selective"')
    response_nonselective = (psth.PeriodSelectivity & units
                             & 'period = "response"'
                             & 'period_selectivity = "non-selective"')
    response_selective = (psth.PeriodSelectivity & units
                          & 'period = "response"'
                          & 'period_selectivity != "non-selective"')
    sample_nonselective = (psth.PeriodSelectivity & 'period in ("sample")'
                           & 'period_selectivity = "non-selective"')
    delay_nonselective = (psth.PeriodSelectivity & 'period in ("delay")'
                          & 'period_selectivity = "non-selective"')

    unit_groups = (
        # i) sample or delay, not response
        units & sample_or_delay_selective & response_nonselective,
        # ii) sample or delay, and response
        units & sample_or_delay_selective & response_selective,
        # iii) neither sample nor delay, only response
        units & sample_nonselective & delay_nonselective & response_selective,
    )

    ipsi_selective_psth, contra_selective_psth = [], []
    for group_units in unit_groups:
        sel_i = (ephys.Unit * psth.UnitSelectivity
                 & 'unit_selectivity = "ipsi-selective"' & group_units)
        sel_c = (ephys.Unit * psth.UnitSelectivity
                 & 'unit_selectivity = "contra-selective"' & group_units)

        psth_is_it = (psth.UnitPsth * sel_i & conds_i).fetch()  # ipsi-sel, ipsi trials
        psth_is_ct = (psth.UnitPsth * sel_i & conds_c).fetch()  # ipsi-sel, contra trials
        psth_cs_ct = (psth.UnitPsth * sel_c & conds_c).fetch()  # contra-sel, contra trials
        psth_cs_it = (psth.UnitPsth * sel_c & conds_i).fetch()  # contra-sel, ipsi trials

        # plot=False: only compute the stacked differences, draw them below
        contra_selective_psth.append(
            _plot_stacked_psth_diff(psth_cs_ct,
                                    psth_cs_it,
                                    ax=axs[0],
                                    flip=True,
                                    plot=False))
        ipsi_selective_psth.append(
            _plot_stacked_psth_diff(psth_is_it,
                                    psth_is_ct,
                                    ax=axs[1],
                                    plot=False))

    # Cumulative group sizes in reversed group order, presumably because
    # imshow draws row 0 at the top by default.
    contra_boundaries = np.cumsum(
        [len(k) for k in contra_selective_psth[::-1]])
    ipsi_boundaries = np.cumsum([len(k) for k in ipsi_selective_psth[::-1]])

    contra_selective_psth = np.vstack(contra_selective_psth)
    ipsi_selective_psth = np.vstack(ipsi_selective_psth)

    xlim = -3, 2
    for ax, stacked in ((axs[0], contra_selective_psth),
                        (axs[1], ipsi_selective_psth)):
        n_rows = stacked.shape[0]
        im = ax.imshow(stacked,
                       cmap=plt.cm.bwr,
                       aspect=4.5 / n_rows,
                       extent=[-3, 3, 0, n_rows])
        im.set_clim((-1, 1))

    # cosmetic
    panel_titles = ('Contra-selective Units', 'Ipsi-selective Units')
    for ax, title, hspans in zip(axs, panel_titles,
                                 (contra_boundaries, ipsi_boundaries)):
        for x in period_starts:
            ax.axvline(x=x, linestyle='--', color='k')
        ax.set_title(title)
        ax.set_ylabel('Unit')
        ax.set_xlabel('Time to go-cue (s)')
        ax.set_xlim(xlim)
        # group marker band at the right edge of the axis
        band_starts = [0] + list(hspans[:-1])
        for ystart, ystop, color in zip(band_starts, hspans,
                                        ('k', 'grey', 'w')):
            ax.axhspan(ystart, ystop, 0.98, 1, alpha=1, color=color)
Example #8
0
def plot_stacked_contra_ipsi_psth(units, axs=None):
    """
    Plot stacked (one row per unit) PSTH-difference images for the
    contra-selective (axs[0]) and ipsi-selective (axs[1]) units in `units`,
    using "good_noearlylick" hit trials and ordering units by descending
    unit_posy.

    :param units: query expression selecting the units (projected to its
        primary key before use)
    :param axs: optional array of exactly two axes; a new 1x2 figure is
        created when omitted
    """
    units = units.proj()

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(20, 20))
    assert axs.size == 2

    trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        ['good_noearlylick_', '_hit'])[0]
    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units,
                                           trial_cond_name)

    hemi = _get_units_hemisphere(units)

    # the ipsi trial condition is the one on the same side as `hemi`
    ipsi_cond_name, contra_cond_name = (
        ('good_noearlylick_left_hit', 'good_noearlylick_right_hit')
        if hemi == 'left' else
        ('good_noearlylick_right_hit', 'good_noearlylick_left_hit'))

    conds_i = (psth.TrialCondition
               & {'trial_condition_name': ipsi_cond_name}).fetch1('KEY')
    conds_c = (psth.TrialCondition
               & {'trial_condition_name': contra_cond_name}).fetch1('KEY')

    sel_i = (ephys.Unit * psth.UnitSelectivity
             & 'unit_selectivity = "ipsi-selective"' & units)
    sel_c = (ephys.Unit * psth.UnitSelectivity
             & 'unit_selectivity = "contra-selective"' & units)

    def fetch_by_depth(selected, conds):
        # full PSTH rows of `selected` units under `conds`, by descending unit_posy
        return (psth.UnitPsth * selected.proj('unit_posy')
                & conds).fetch(order_by='unit_posy desc')

    psth_is_it = fetch_by_depth(sel_i, conds_i)  # ipsi-sel, ipsi trials
    psth_is_ct = fetch_by_depth(sel_i, conds_c)  # ipsi-sel, contra trials
    psth_cs_ct = fetch_by_depth(sel_c, conds_c)  # contra-sel, contra trials
    psth_cs_it = fetch_by_depth(sel_c, conds_i)  # contra-sel, ipsi trials

    # only the contra-selective panel is flipped
    _plot_stacked_psth_diff(psth_cs_ct,
                            psth_cs_it,
                            ax=axs[0],
                            vlines=period_starts,
                            flip=True)
    _plot_stacked_psth_diff(psth_is_it,
                            psth_is_ct,
                            ax=axs[1],
                            vlines=period_starts)

    # cosmetic: titles, labels, module-wide x range, no right/top spines
    panel_titles = ('Contra-selective Units', 'Ipsi-selective Units')
    for ax, panel_title in zip(axs, panel_titles):
        ax.set_title(panel_title)
        ax.set_ylabel('Unit')
        ax.set_xlabel('Time to go-cue (s)')
        ax.set_xlim([_plt_xmin, _plt_xmax])
        for side in ('right', 'top'):
            ax.spines[side].set_visible(False)