def plot_avg_contra_ipsi_psth(units, axs=None):
    """
    Plot trial-averaged PSTHs of contra- and ipsi-selective units.

    The left panel shows contra-selective units and the right panel ipsi-selective units.
    Each panel compares ipsi-side vs. contra-side "good_noearlylick_*_hit" trials, where
    ipsi/contra is defined relative to the hemisphere of `units`. Vertical lines mark the
    sample, delay and go event times. Only units whose `unit_quality` is not "all" are used.
    Both panels share a y-range starting at 0.

    :param units: DataJoint query restricting the units to plot
    :param axs: optional array of exactly two matplotlib axes; a new 1x2 figure is created if None
    :return: the newly created figure, or None if `axs` was supplied
    """
    units = units.proj()

    # get event start times: sample, delay, response (period names are not needed here)
    _, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    hemi = _get_units_hemisphere(units)

    good_unit = ephys.Unit & 'unit_quality != "all"'

    # ipsi/contra trial conditions, relative to the units' hemisphere
    conds_i = (psth.TrialCondition
               & {'trial_condition_name':
                  'good_noearlylick_left_hit' if hemi == 'left' else 'good_noearlylick_right_hit'}).fetch('KEY')

    conds_c = (psth.TrialCondition
               & {'trial_condition_name':
                  'good_noearlylick_right_hit' if hemi == 'left' else 'good_noearlylick_left_hit'}).fetch('KEY')

    sel_i = (ephys.Unit * psth.UnitSelectivity
             & 'unit_selectivity = "ipsi-selective"' & units)

    sel_c = (ephys.Unit * psth.UnitSelectivity
             & 'unit_selectivity = "contra-selective"' & units)

    def _fetch_psth(conds, selective):
        # unit_psth of good units in `selective` for trial conditions `conds`,
        # ordered by unit_posy (descending)
        return (((psth.UnitPsth & conds)
                 * ephys.Unit.proj('unit_posy'))
                & good_unit.proj() & selective.proj()).fetch(
                    'unit_psth', order_by='unit_posy desc')

    psth_is_it = _fetch_psth(conds_i, sel_i)  # ipsi-selective units, ipsi trials
    psth_is_ct = _fetch_psth(conds_c, sel_i)  # ipsi-selective units, contra trials
    psth_cs_ct = _fetch_psth(conds_c, sel_c)  # contra-selective units, contra trials
    psth_cs_it = _fetch_psth(conds_i, sel_c)  # contra-selective units, ipsi trials

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    _plot_avg_psth(psth_cs_it, psth_cs_ct, period_starts, axs[0],
                   'Contra-selective')
    _plot_avg_psth(psth_is_it, psth_is_ct, period_starts, axs[1],
                   'Ipsi-selective')

    # use a common y-range (starting at 0) for both panels so they are comparable
    ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, ymax))

    return fig
def plot_stacked_contra_ipsi_psth(units, axs=None):
    """
    Plot per-unit (depth-stacked) PSTH differences for contra- and ipsi-selective units.

    The left panel holds contra-selective units (contra minus ipsi trials, drawn with
    ``flip=True``) and the right panel holds ipsi-selective units (ipsi vs. contra trials).
    Units are ordered by ``unit_posy`` descending. Ipsi/contra is relative to the hemisphere
    of `units`, and vertical lines mark the sample, delay and go event times.

    :param units: DataJoint query restricting the units to plot
    :param axs: optional array of exactly two matplotlib axes; a new 1x2 figure is created if None
    :return: the newly created figure, or None if `axs` was supplied
    """
    units = units.proj()

    # event start times for sample, delay and go; the period names themselves are not used
    _, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    hemi = _get_units_hemisphere(units)
    if hemi == 'left':
        ipsi_name, contra_name = 'good_noearlylick_left_hit', 'good_noearlylick_right_hit'
    else:
        ipsi_name, contra_name = 'good_noearlylick_right_hit', 'good_noearlylick_left_hit'

    ipsi_cond = (psth.TrialCondition & {'trial_condition_name': ipsi_name}).fetch1('KEY')
    contra_cond = (psth.TrialCondition & {'trial_condition_name': contra_name}).fetch1('KEY')

    ipsi_sel_units = ephys.Unit * psth.UnitSelectivity & 'unit_selectivity = "ipsi-selective"' & units
    contra_sel_units = ephys.Unit * psth.UnitSelectivity & 'unit_selectivity = "contra-selective"' & units

    def _by_depth(selected, cond):
        # full UnitPsth rows for the selected units under `cond`, ordered by unit_posy (descending)
        joined = psth.UnitPsth * selected.proj('unit_posy')
        return (joined & cond).fetch(order_by='unit_posy desc')

    ipsi_sel_ipsi_trials = _by_depth(ipsi_sel_units, ipsi_cond)
    ipsi_sel_contra_trials = _by_depth(ipsi_sel_units, contra_cond)
    contra_sel_contra_trials = _by_depth(contra_sel_units, contra_cond)
    contra_sel_ipsi_trials = _by_depth(contra_sel_units, ipsi_cond)

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(20, 20))
    assert axs.size == 2

    panels = (
        (axs[0], 'Contra-selective Units',
         (contra_sel_contra_trials, contra_sel_ipsi_trials), {'flip': True}),
        (axs[1], 'Ipsi-selective Units',
         (ipsi_sel_ipsi_trials, ipsi_sel_contra_trials), {}),
    )
    for panel_ax, panel_title, (first, second), extra_kwargs in panels:
        _plot_stacked_psth_diff(first, second, ax=panel_ax, vlines=period_starts, **extra_kwargs)
        panel_ax.set_title(panel_title)
        panel_ax.set_ylabel('Unit (by depth)')
        panel_ax.set_xlabel('Time to go (s)')

    return fig
# Example #3
def plot_unit_psth(unit_key, axs=None, title='', xlim=_plt_xlim):
    """
    Default raster and PSTH plot for a specified unit - only {good, no early lick} trials selected.

    Produces a 2x2 grid. The left column shows correct (hit) trials and the right column
    incorrect (miss) trials. The top row is the spike raster and the bottom row the PSTH.
    Each panel compares ipsi- vs. contra-lateral trials, relative to the unit's hemisphere.
    Vertical lines mark the sample, delay and go event times.

    :param unit_key: key identifying a single unit (must contain "unit")
    :param axs: optional 2x2 array of matplotlib axes; a new figure is created if None
    :param title: raster title for both columns; when empty, a default title with the
        unit number and response type is used
    :param xlim: x-axis limits applied to every panel
    :return: the newly created figure, or None if `axs` was supplied
    """
    hemi = _get_units_hemisphere(unit_key)
    # any hemisphere value other than 'left' is treated as right
    ipsi_side = 'left' if hemi == 'left' else 'right'
    contra_side = 'right' if hemi == 'left' else 'left'

    def _get_psth(side, outcome):
        return psth.UnitPsth.get_plotting_data(
            unit_key, {'trial_condition_name': f'good_noearlylick_{side}_{outcome}'})

    ipsi_hit_unit_psth = _get_psth(ipsi_side, 'hit')
    contra_hit_unit_psth = _get_psth(contra_side, 'hit')
    ipsi_miss_unit_psth = _get_psth(ipsi_side, 'miss')
    contra_miss_unit_psth = _get_psth(contra_side, 'miss')

    # get event start times: sample, delay, response (period names are not needed)
    _, period_starts = _get_trial_event_times(['sample', 'delay', 'go'],
                                              unit_key,
                                              'good_noearlylick_hit')

    fig = None
    if axs is None:
        fig, axs = plt.subplots(2, 2)

    # column 0: correct response, column 1: incorrect response
    columns = (
        (ipsi_hit_unit_psth, contra_hit_unit_psth, 'Correct Response'),
        (ipsi_miss_unit_psth, contra_miss_unit_psth, 'Incorrect Response'),
    )
    for col, (ipsi_psth, contra_psth, response_label) in enumerate(columns):
        _plot_spike_raster(ipsi_psth,
                           contra_psth,
                           ax=axs[0, col],
                           vlines=period_starts,
                           title=title if title else
                           f'Unit #: {unit_key["unit"]}\n{response_label}',
                           xlim=xlim)
        _plot_psth(ipsi_psth,
                   contra_psth,
                   vlines=period_starts,
                   ax=axs[1, col],
                   xlim=xlim)

    return fig
# Example #4
def plot_unit_psth_choice_outcome(unit_key,
                                  align_types=None,
                                  if_raster=True,
                                  axs=None,
                                  title='',
                                  if_exclude_early_lick=False):
    """Plot psth grouped by (choice x outcome) for the foraging task.

    In general, PSTH is specified by two things: trial conditions and alignment types.

    Here, trial conditions include all four combinations of choice (ipsi or contra) and outcome (hit or miss),
    whereas align types are defined by the user. See `psth_foraging.AlignType`

    Parameters
    ----------
    unit_key : dict-like
        key identifying a single unit; must contain at least "insertion_number" and "unit"
    align_types : list, optional
        list of align_type_name in psth_foraging.AlignType. ``None`` (the default) selects
        ['trial_start', 'go_cue', 'first_lick_after_go_cue', 'iti_start', 'next_trial_start']
    if_raster : bool, optional
        whether to plot raster, by default True
    axs : array of matplotlib Axes, optional
        axes laid out as (1 + if_raster) rows x len(align_types) columns; a new figure is
        created if None. A 1-D array (or a single Axes) is reshaped to that layout.
    title : str, optional
        currently unused (kept for backward compatibility), by default ''
    if_exclude_early_lick : bool, optional
        whether to exclude early licks, by default False

    Returns
    -------
    matplotlib.figure.Figure or None
        the newly created figure, or None when `axs` was supplied

    Raises
    ------
    ValueError
        if `align_types` is empty
    """
    # None sentinel instead of a mutable list default
    if align_types is None:
        align_types = ['trial_start', 'go_cue', 'first_lick_after_go_cue',
                       'iti_start', 'next_trial_start']
    if not align_types:
        raise ValueError('align_types must contain at least one align type')

    # for (the very few) sessions without zaber feedback signal, use 'bitcodestart' with manual correction (see compute_unit_psth_and_raster)
    if not ephys.TrialEvent & unit_key & 'trial_event_type = "zaberready"':
        align_types = [
            a + '_bitcode' if 'trial_start' in a else a for a in align_types
        ]

    hemi = _get_units_hemisphere(unit_key)
    ipsi = "L" if hemi == "left" else "R"
    contra = "R" if hemi == "left" else "L"
    no_early_lick = '_noearlylick' if if_exclude_early_lick else ''

    unit_info = (
        f'{(lab.WaterRestriction & unit_key).fetch1("water_restriction_number")}, '
        f'{(experiment.Session & unit_key).fetch1("session_date")}, '
        f'imec {unit_key["insertion_number"]-1}\n'
        f'Unit #: {unit_key["unit"]}, '
        f'{(((ephys.Unit & unit_key) * histology.ElectrodeCCFPosition.ElectrodePosition) * ccf.CCFAnnotation).fetch1("annotation")}'
    )

    fig = None
    if axs is None:
        fig = plt.figure(figsize=(len(align_types) / 5 * 16,
                                  (1 + if_raster) / 2 * 9))
        axs = fig.subplots(1 + if_raster,
                           len(align_types),
                           sharey='row',
                           sharex='col')
    # fig.subplots squeezes singleton dimensions (and callers may pass a 1-D array or a
    # single Axes); normalize to a 2-D (rows x align_types) array for axs[row, col] indexing
    axs = np.asarray(axs)
    if axs.ndim < 2:
        axs = axs.reshape((1 + if_raster, -1))
    xlims = []

    for ax_i, align_type in enumerate(align_types):

        offset, xlim = (psth_foraging.AlignType & {
            'align_type_name': align_type
        }).fetch1('trial_offset', 'xlim')
        xlims.append(xlim)

        # align_trial_offset is added on the get_trials, which effectively
        # makes the psth conditioned on the previous {align_trial_offset} trials
        ipsi_hit_trials = psth_foraging.TrialCondition.get_trials(
            f'{ipsi}_hit{no_early_lick}', offset) & unit_key
        ipsi_hit_unit_psth = psth_foraging.compute_unit_psth_and_raster(
            unit_key, ipsi_hit_trials, align_type)

        contra_hit_trials = psth_foraging.TrialCondition.get_trials(
            f'{contra}_hit{no_early_lick}', offset) & unit_key
        contra_hit_unit_psth = psth_foraging.compute_unit_psth_and_raster(
            unit_key, contra_hit_trials, align_type)

        ipsi_miss_trials = psth_foraging.TrialCondition.get_trials(
            f'{ipsi}_miss{no_early_lick}', offset) & unit_key
        ipsi_miss_unit_psth = psth_foraging.compute_unit_psth_and_raster(
            unit_key, ipsi_miss_trials, align_type)

        contra_miss_trials = psth_foraging.TrialCondition.get_trials(
            f'{contra}_miss{no_early_lick}', offset) & unit_key
        contra_miss_unit_psth = psth_foraging.compute_unit_psth_and_raster(
            unit_key, contra_miss_trials, align_type)

        # --- plot psths (all 4 in one plot) ---
        ax_psth = axs[1 if if_raster else 0, ax_i]
        period_starts_hit = _get_ephys_trial_event_times(
            align_types,
            align_to=align_type,
            trial_keys=psth_foraging.TrialCondition.get_trials(
                f'LR_hit{no_early_lick}') & unit_key,
            # cannot use *_hit_trials because it could have been offset
        )

        _plot_psth_foraging(ipsi_hit_unit_psth,
                            contra_hit_unit_psth,
                            vlines=period_starts_hit,
                            ax=ax_psth,
                            xlim=xlim,
                            label='rew',
                            linestyle='-')

        _plot_psth_foraging(ipsi_miss_unit_psth,
                            contra_miss_unit_psth,
                            vlines=[],
                            ax=ax_psth,
                            xlim=xlim,
                            label='norew',
                            linestyle='--')

        ax_psth.set(title=f'{align_type}')
        if ax_i > 0:
            # only the first column keeps its y axis
            ax_psth.spines['left'].set_visible(False)
            ax_psth.get_yaxis().set_visible(False)

        # --- plot rasters (optional) ---
        if if_raster:
            ax_raster = axs[0, ax_i]
            _plot_spike_raster_foraging(ipsi_hit_unit_psth,
                                        contra_hit_unit_psth,
                                        ax=ax_raster,
                                        offset=0,
                                        vlines=period_starts_hit,
                                        title='',
                                        xlim=xlim)
            # miss trials are stacked after all hit trials
            _plot_spike_raster_foraging(
                ipsi_miss_unit_psth,
                contra_miss_unit_psth,
                ax=ax_raster,
                offset=len(ipsi_hit_unit_psth['trials']) +
                len(contra_hit_unit_psth['trials']),
                vlines=[],
                title='' if ax_i > 0 else unit_info,
                xlim=xlim)
            ax_raster.invert_yaxis()

    # Scale axis widths to keep the same horizontal aspect ratio (time) across axs
    _set_same_horizonal_aspect_ratio(axs[1 if if_raster else 0, :], xlims)
    if if_raster:
        _set_same_horizonal_aspect_ratio(axs[0, :], xlims)
    ax_psth.legend(fontsize=8)

    return fig
def plot_psth_photostim_effect(units, condition_name_kw=None, axs=None):
    """
    For the specified `units`, plot PSTH comparison between stim vs. no-stim with left/right trial instruction.

    The left panel shows control (no-stim) trials and the right panel photostim trials. The
    photostim panel has a shaded bar spanning the stim onset and duration. Left/right
    instructed trials are mapped to ipsi/contra relative to the units' hemisphere. Both panels
    share a y-range starting at 0.
    The stim location (or other appropriate search keywords) can be specified in `condition_name_kw` (default: both ALM)

    :param units: DataJoint query restricting the units to plot
    :param condition_name_kw: keywords used to look up the stim TrialCondition names;
        None (the default) selects ['both_alm']
    :param axs: optional array of exactly two matplotlib axes; a new 1x2 figure is created if None
    :return: the newly created figure, or None if `axs` was supplied
    """
    # None sentinel instead of a mutable list default
    if condition_name_kw is None:
        condition_name_kw = ['both_alm']

    units = units.proj()

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    hemi = _get_units_hemisphere(units)

    def _cond_name(keywords):
        # first TrialCondition name matching all `keywords`
        return psth.TrialCondition.get_cond_name_from_keywords(keywords)[0]

    def _fetch_psth(cond_name):
        # non-NULL unit_psth values of `units` for the given trial condition
        return (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': cond_name} & 'unit_psth is not NULL').fetch('unit_psth')

    # no photostim:
    cond_n_l = _cond_name(['_nostim', '_left'])
    cond_n_r = _cond_name(['_nostim', '_right'])
    psth_n_l = _fetch_psth(cond_n_l)
    psth_n_r = _fetch_psth(cond_n_r)

    # with photostim
    cond_s_l = _cond_name(condition_name_kw + ['_stim_left'])
    cond_s_r = _cond_name(condition_name_kw + ['_stim_right'])
    psth_s_l = _fetch_psth(cond_s_l)
    psth_s_r = _fetch_psth(cond_s_r)

    # get event start times: sample, delay, response (period names are not needed)
    _, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    # get photostim onset and duration
    stim_trial_cond_name = _cond_name(condition_name_kw + ['_stim'])
    stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent
                           * psth.TrialCondition().get_trials(stim_trial_cond_name)
                           & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)
    stim_time = _get_stim_onset_time(units, stim_trial_cond_name)

    # map left/right trials to ipsi/contra relative to the recording hemisphere
    if hemi == 'left':
        psth_s_i, psth_n_i = psth_s_l, psth_n_l
        psth_s_c, psth_n_c = psth_s_r, psth_n_r
    else:
        psth_s_i, psth_n_i = psth_s_r, psth_n_r
        psth_s_c, psth_n_c = psth_s_l, psth_n_l

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0],
                   'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1],
                   'Photostim')

    # cosmetic
    ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])

    # add shaded bar for photostim
    axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')

    return fig