Ejemplo n.º 1
0
def _get_trial_event_times(events, units, trial_cond_name):
    """
    Compute the median onset time of each requested trial event, aligned to the GO CUE.

    Medians are taken over all trials of the psth.TrialCondition named
    `trial_cond_name`, further restricted by `units`.

    :param events: iterable of trial event type names (e.g. 'sample', 'delay')
    :param units: DataJoint query used as an additional restriction
    :param trial_cond_name: name of the trial condition whose trials are used
    :return: list of NaN-aware median start times (go cue at 0), one per entry of
             `events`, in the same order
    """
    requested = list(events)
    wanted_types = requested + ['go']

    trial_query = psth.TrialCondition().get_trials(trial_cond_name)
    event_query = experiment.TrialEvent & [{'trial_event_type': name} for name in wanted_types]
    types, times = (trial_query * event_query & units).fetch('trial_event_type', 'trial_event_time')

    # NOTE(review): the element-wise subtraction below assumes every trial has exactly
    # one event of each requested type and that fetch returns them in matching trial
    # order - confirm against the TrialEvent schema.
    go_times = times[types == 'go']
    medians = []
    for name in requested:
        offsets = (times[types == name] - go_times).astype(float)
        medians.append(np.nanmedian(offsets))
    return medians
def plot_psth_photostim_effect(units, condition_name_kw=None, axs=None):
    """
    For the specified `units`, plot PSTH comparison between stim vs. no-stim with left/right trial instruction.

    The stim location (or other appropriate search keywords) can be specified in
    `condition_name_kw` (default: ['both_alm'], i.e. both ALM).

    :param units: DataJoint query of units; projected to its primary key before use
    :param condition_name_kw: list of keywords (or a single keyword string) used to look up
        the photostim trial-condition names; defaults to ['both_alm']
    :param axs: optional array of exactly 2 matplotlib axes; created when None
    :return: the newly created Figure, or None when `axs` was supplied
    """
    # use a None sentinel instead of a shared mutable default list
    if condition_name_kw is None:
        condition_name_kw = ['both_alm']
    elif isinstance(condition_name_kw, str):
        condition_name_kw = [condition_name_kw]
    else:
        condition_name_kw = list(condition_name_kw)

    units = units.proj()

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    hemi = _get_units_hemisphere(units)

    # no photostim:
    psth_n_l = psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_left'])[0]
    psth_n_r = psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_right'])[0]

    psth_n_l = (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': psth_n_l} & 'unit_psth is not NULL').fetch('unit_psth')
    psth_n_r = (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': psth_n_r} & 'unit_psth is not NULL').fetch('unit_psth')

    # with photostim
    psth_s_l = psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim_left'])[0]
    psth_s_r = psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim_right'])[0]

    psth_s_l = (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': psth_s_l} & 'unit_psth is not NULL').fetch('unit_psth')
    psth_s_r = (psth.UnitPsth * psth.TrialCondition & units
                & {'trial_condition_name': psth_s_r} & 'unit_psth is not NULL').fetch('unit_psth')

    # get event start times (median, aligned to go cue): sample, delay, go.
    # _get_trial_event_times returns only the list of start times, so it must not be unpacked.
    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    # get photostim onset and duration
    stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim'])[0]
    stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent
                           * psth.TrialCondition().get_trials(stim_trial_cond_name)
                           & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)
    stim_time = _get_stim_onset_time(units, stim_trial_cond_name)

    # ipsi = trials instructed toward the recorded hemisphere, contra = the other side
    if hemi == 'left':
        psth_s_i = psth_s_l
        psth_n_i = psth_n_l
        psth_s_c = psth_s_r
        psth_n_c = psth_n_r
    else:
        psth_s_i = psth_s_r
        psth_n_i = psth_n_r
        psth_s_c = psth_s_l
        psth_n_c = psth_n_l

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0],
                   'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1],
                   'Photostim')

    # cosmetic: share the y-range across both panels
    ymax = max([ax.get_ylim()[1] for ax in axs])
    for ax in axs:
        ax.set_ylim((0, ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])

    # add shaded bar for photostim
    axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')

    return fig
Ejemplo n.º 3
0
def plot_psth_bilateral_photostim_effect(units, axs=None):
    """
    Plot population PSTHs of `units` for control vs. bilateral ALM photostim trials,
    split into ipsi- and contra-lateral trial instruction.

    :param units: DataJoint query of units; projected to its primary key before use
    :param axs: optional array of exactly 2 matplotlib axes; created when None
    :return: the newly created Figure, or None when `axs` was supplied
    """
    units = units.proj()

    hemi = _get_units_hemisphere(units)

    psth_s_l = (
        psth.UnitPsth * psth.TrialCondition & units
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_stim_left'
        }).fetch('unit_psth')

    psth_n_l = (
        psth.UnitPsth * psth.TrialCondition & units
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_nostim_left'
        }).fetch('unit_psth')

    psth_s_r = (
        psth.UnitPsth * psth.TrialCondition & units
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_stim_right'
        }).fetch('unit_psth')

    psth_n_r = (
        psth.UnitPsth * psth.TrialCondition & units
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_nostim_right'
        }).fetch('unit_psth')

    # get event start times (median, aligned to go cue): sample, delay, go.
    # _get_trial_event_times returns only the start times, in the order requested,
    # so the names are kept here to locate the delay onset later.
    period_names = ['sample', 'delay', 'go']
    period_starts = _get_trial_event_times(
        period_names, units, 'good_noearlylick_hit')

    # get photostim duration
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent *
         psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    # ipsi = trials instructed toward the recorded hemisphere, contra = the other side
    if hemi == 'left':
        psth_s_i = psth_s_l
        psth_n_i = psth_n_l
        psth_s_c = psth_s_r
        psth_n_c = psth_n_r
    else:
        psth_s_i = psth_s_r
        psth_n_i = psth_n_r
        psth_s_c = psth_s_l
        psth_n_c = psth_n_l

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0], 'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1],
                   'Bilateral ALM photostim')
    # cosmetic: share the y-range across both panels
    ymax = max([ax.get_ylim()[1] for ax in axs])
    for ax in axs:
        ax.set_ylim((0, ymax))

    # add shaded bar for photostim, starting at the (median) delay onset
    stim_time = period_starts[period_names.index('delay')]
    axs[1].axvspan(stim_time,
                   stim_time + stim_dur,
                   alpha=0.3,
                   color='royalblue')

    return fig
def plot_unit_bilateral_photostim_effect(probe_insertion, clustering_method=None, axs=None):
    """
    Scatter-plot, at each unit's position, the relative firing-rate decrease caused by
    bilateral ALM photostimulation.

    For every unit of `probe_insertion` (restricted to `clustering_method`, excluding
    unit_quality "all"), the mean per-trial firing rate inside the photostim window
    [stim_time, stim_time + stim_dur] is compared between no-stim and bilateral-stim
    trials. Only decreases are kept (increases map to a tiny placeholder value), and the
    values are normalized to the maximum across units before being used as marker sizes.

    :param probe_insertion: DataJoint query selecting one probe insertion
    :param clustering_method: clustering method to use; inferred with
        `_get_clustering_method` when None
    :param axs: optional single matplotlib axis; created when None
    :return: the newly created Figure, or None when `axs` was supplied
    :raises PhotostimError: if no bilateral ALM photostim trials exist for this insertion
    :raises ValueError: if `clustering_method` is None and cannot be inferred
    """
    probe_insertion = probe_insertion.proj()

    if not (psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim') & probe_insertion):
        raise PhotostimError('No Bilateral ALM Photo-stimulation present')

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"') from e

    dv_loc = (ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch1('depth')

    no_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name':
                       'all_noearlylick_nostim'}).fetch1('KEY')

    bi_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name':
                       'all_noearlylick_both_alm_stim'}).fetch1('KEY')

    units = ephys.Unit & probe_insertion & {'clustering_method': clustering_method} & 'unit_quality != "all"'

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # get photostim onset and duration
    stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent
                           * psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
                           & probe_insertion).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)
    stim_time = _get_stim_onset_time(units, 'all_noearlylick_both_alm_stim')

    # XXX: could be done with 1x fetch+join
    for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):
        # one-element tuple (note the trailing comma): exact membership test, not a
        # substring test against the string 'kilosort2'
        if clustering_method in ('kilosort2',):
            x, y = (ephys.Unit * lab.ElectrodeConfig.Electrode.proj()
                    * lab.ProbeType.Electrode.proj('x_coord', 'y_coord') & unit).fetch1('x_coord', 'y_coord')
        else:
            x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

        # obtain unit psth per trial, for all nostim and bistim trials
        nostim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(no_stim_cond['trial_condition_name'])
        bistim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(bi_stim_cond['trial_condition_name'])

        nostim_psths, nostim_edge = psth.compute_unit_psth(unit, nostim_trials.fetch('KEY'), per_trial=True)
        bistim_psths, bistim_edge = psth.compute_unit_psth(unit, bistim_trials.fetch('KEY'), per_trial=True)

        # mean firing rate per trial within the stimulation time window
        ctrl_frate = np.array([nostim_psth[np.logical_and(nostim_edge >= stim_time,
                                                          nostim_edge <= stim_time + stim_dur)].mean()
                               for nostim_psth in nostim_psths])
        stim_frate = np.array([bistim_psth[np.logical_and(bistim_edge >= stim_time,
                                                          bistim_edge <= stim_time + stim_dur)].mean()
                               for bistim_psth in bistim_psths])

        # relative change; only decreases are kept, anything else becomes a tiny marker
        frate_change = (stim_frate.mean() - ctrl_frate.mean()) / ctrl_frate.mean()
        frate_change = abs(frate_change) if frate_change < 0 else 0.0001

        metrics.loc[u_idx] = (int(unit['unit']), x, float(dv_loc) + y, frate_change)

    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    # --- prepare for plotting
    shank_count = (ephys.ProbeInsertion & probe_insertion).aggr(lab.ElectrodeConfig.Electrode * lab.ProbeType.Electrode,
                                                                shank_count='count(distinct shank)').fetch1('shank_count')
    m_scale = get_m_scale(shank_count)

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(4, 8))

    xmax = 1.3 * metrics.x.max()
    xmin = -1/6*xmax

    cosmetic = {'legend': None,
                'linewidth': 1.75,
                'alpha': 0.9,
                'facecolor': 'none', 'edgecolor': 'k'}

    sns.scatterplot(data=metrics, x='x', y='y', s=metrics.frate_change*m_scale,
                    ax=axs, **cosmetic)

    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((xmin, xmax))

    return fig
Ejemplo n.º 5
0
def plot_unit_bilateral_photostim_effect(probe_insertion,
                                         clustering_method=None,
                                         axs=None):
    """
    Scatter-plot, at each unit's position, the relative firing-rate decrease caused by
    bilateral ALM photostimulation.

    The comparison window starts at the median delay onset (aligned to the go cue) of
    the 'all_noearlylick_both_alm_nostim' trials and lasts for the photostim duration.
    Only decreases are kept (increases map to a tiny placeholder value), and the values
    are normalized to the maximum across units before being used as marker sizes.

    :param probe_insertion: DataJoint query selecting one probe insertion
    :param clustering_method: clustering method to use; inferred with
        `_get_clustering_method` when None
    :param axs: optional single matplotlib axis; created when None
    :return: the newly created Figure, or None when `axs` was supplied
    :raises ValueError: if `clustering_method` is None and cannot be inferred
    """
    probe_insertion = probe_insertion.proj()

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(
                str(e) +
                '\nPlease specify one with the kwarg "clustering_method"') from e

    dv_loc = (ephys.ProbeInsertion.InsertionLocation
              & probe_insertion).fetch1('dv_location')

    no_stim_cond = (
        psth.TrialCondition
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_nostim'
        }).fetch1('KEY')

    bi_stim_cond = (psth.TrialCondition
                    & {
                        'trial_condition_name': 'all_noearlylick_both_alm_stim'
                    }).fetch1('KEY')

    # get photostim duration
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent *
         psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & probe_insertion).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    units = ephys.Unit & probe_insertion & {
        'clustering_method': clustering_method
    } & 'unit_quality != "all"'

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # _get_trial_event_times returns a plain list of start times (one per requested
    # event), so take its single element rather than unpacking two values
    cue_onset = _get_trial_event_times(['delay'], units,
                                       'all_noearlylick_both_alm_nostim')[0]

    # XXX: could be done with 1x fetch+join
    for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):

        x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

        # obtain unit psth per trial, for all nostim and bistim trials
        nostim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(
            no_stim_cond['trial_condition_name'])
        bistim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(
            bi_stim_cond['trial_condition_name'])

        nostim_psths, nostim_edge = psth.compute_unit_psth(
            unit, nostim_trials.fetch('KEY'), per_trial=True)
        bistim_psths, bistim_edge = psth.compute_unit_psth(
            unit, bistim_trials.fetch('KEY'), per_trial=True)

        # mean firing rate per trial within the stimulation duration
        ctrl_frate = np.array([
            nostim_psth[np.logical_and(
                nostim_edge >= cue_onset,
                nostim_edge <= cue_onset + stim_dur)].mean()
            for nostim_psth in nostim_psths
        ])
        stim_frate = np.array([
            bistim_psth[np.logical_and(
                bistim_edge >= cue_onset,
                bistim_edge <= cue_onset + stim_dur)].mean()
            for bistim_psth in bistim_psths
        ])

        # relative change; only decreases are kept, anything else becomes a tiny marker
        frate_change = (stim_frate.mean() -
                        ctrl_frate.mean()) / ctrl_frate.mean()
        frate_change = abs(frate_change) if frate_change < 0 else 0.0001

        metrics.loc[u_idx] = (int(unit['unit']), x, y - dv_loc, frate_change)

    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(4, 8))

    cosmetic = {
        'legend': None,
        'linewidth': 1.75,
        'alpha': 0.9,
        'facecolor': 'none',
        'edgecolor': 'k'
    }

    # NOTE(review): m_scale is not defined in this function; presumably a
    # module-level constant - confirm it exists.
    sns.scatterplot(data=metrics,
                    x='x',
                    y='y',
                    s=metrics.frate_change * m_scale,
                    ax=axs,
                    **cosmetic)

    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((-10, 60))

    return fig
def plot_psth_bilateral_photostim_effect(probe_insert_key, axs=None):
    """
    Plot population PSTHs of the units of `probe_insert_key` for control vs. bilateral
    ALM photostim trials, split into ipsi- and contra-lateral trial instruction.

    :param probe_insert_key: restriction selecting one probe insertion
    :param axs: optional array of exactly 2 matplotlib axes; created when None
    :return: the newly created Figure, or None when `axs` was supplied
    """
    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    insert = (ephys.ProbeInsertion.InsertionLocation * experiment.BrainLocation
              & probe_insert_key).fetch1()

    # fetch the period names alongside their start times so the delay onset
    # can be reused below without a second query
    period_names, period_starts = (
        experiment.Period
        & 'period in ("sample", "delay", "response")').fetch('period', 'period_start')

    psth_s_l = (
        psth.UnitPsth * psth.TrialCondition & probe_insert_key
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_stim_left'
        }).fetch('unit_psth')

    psth_n_l = (
        psth.UnitPsth * psth.TrialCondition & probe_insert_key
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_nostim_left'
        }).fetch('unit_psth')

    psth_s_r = (
        psth.UnitPsth * psth.TrialCondition & probe_insert_key
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_stim_right'
        }).fetch('unit_psth')

    psth_n_r = (
        psth.UnitPsth * psth.TrialCondition & probe_insert_key
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_nostim_right'
        }).fetch('unit_psth')

    # get photostim duration
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent *
         psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & probe_insert_key).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    # ipsi = trials instructed toward the recorded hemisphere, contra = the other side
    if insert['hemisphere'] == 'left':
        psth_s_i = psth_s_l
        psth_n_i = psth_n_l
        psth_s_c = psth_s_r
        psth_n_c = psth_n_r
    else:
        psth_s_i = psth_s_r
        psth_n_i = psth_n_r
        psth_s_c = psth_s_l
        psth_n_c = psth_n_l

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0], 'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1],
                   'Bilateral ALM photostim')

    # cosmetic: share the y-range across both panels
    ymax = max([ax.get_ylim()[1] for ax in axs])
    for ax in axs:
        ax.set_ylim((0, ymax))

    # add shaded bar for photostim, starting at the delay-period onset
    delay = period_starts[list(period_names).index('delay')]
    axs[1].axvspan(delay, delay + stim_dur, alpha=0.3, color='royalblue')

    return fig
def plot_unit_bilateral_photostim_effect(probe_insert_key, axs=None):
    """
    Scatter-plot, at each unit's position, the relative firing-rate change caused by
    bilateral ALM photostimulation.

    The comparison window starts at the delay-period onset (from experiment.Period) and
    lasts for the photostim duration. The absolute relative change is normalized to the
    maximum across units and used as the marker size.

    :param probe_insert_key: restriction selecting one probe insertion
    :param axs: optional single matplotlib axis; created when None
    :return: the newly created Figure, or None when `axs` was supplied
    """
    cue_onset = (experiment.Period & 'period = "delay"').fetch1('period_start')

    no_stim_cond = (
        psth.TrialCondition
        & {
            'trial_condition_name': 'all_noearlylick_both_alm_nostim'
        }).fetch1('KEY')

    bi_stim_cond = (psth.TrialCondition
                    & {
                        'trial_condition_name': 'all_noearlylick_both_alm_stim'
                    }).fetch1('KEY')

    # get photostim duration
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent *
         psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & probe_insert_key).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    units = ephys.Unit & probe_insert_key & 'unit_quality != "all"'

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # XXX: could be done with 1x fetch+join
    for u_idx, unit in enumerate(units.fetch('KEY')):

        x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

        # unit_psth is unpacked as (psth values, bin edges)
        nostim_psth, nostim_edge = (psth.UnitPsth & {
            **unit,
            **no_stim_cond
        }).fetch1('unit_psth')

        bistim_psth, bistim_edge = (psth.UnitPsth & {
            **unit,
            **bi_stim_cond
        }).fetch1('unit_psth')

        # firing rates within the stimulation duration; edges[1:] is used as the mask,
        # presumably because there is one more edge than psth bins - confirm
        ctrl_frate = nostim_psth[np.logical_and(
            nostim_edge[1:] >= cue_onset,
            nostim_edge[1:] <= cue_onset + stim_dur)]
        stim_frate = bistim_psth[np.logical_and(
            bistim_edge[1:] >= cue_onset,
            bistim_edge[1:] <= cue_onset + stim_dur)]

        frate_change = np.abs(stim_frate.mean() -
                              ctrl_frate.mean()) / ctrl_frate.mean()

        metrics.loc[u_idx] = (int(unit['unit']), x, y, frate_change)

    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(4, 8))

    cosmetic = {
        'legend': None,
        'linewidth': 1.75,
        'alpha': 0.9,
        'facecolor': 'none',
        'edgecolor': 'k'
    }

    # NOTE(review): m_scale is not defined in this function; presumably a
    # module-level constant - confirm it exists.
    sns.scatterplot(data=metrics,
                    x='x',
                    y='y',
                    s=metrics.frate_change * m_scale,
                    ax=axs,
                    **cosmetic)

    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((-10, 60))

    return fig
Ejemplo n.º 8
0
def plot_selectivity_change_photostim_effect(units,
                                             condition_name_kw,
                                             recover_time_window=None,
                                             ax=None):
    """
    For each unit in the specified units, extract:
    + control, left-instruct PSTH (ctrl_left)
    + control, right-instruct PSTH (ctrl_right)
    + stim, left-instruct PSTH (stim_left)
    + stim, right-instruct PSTH (stim_right)
    Then, control_PSTH and stim_PSTH is defined as
        (ctrl_left - ctrl_right) for ipsi-selective unit that locates on the left-hemisphere, and vice versa
        (stim_left - stim_right) for ipsi-selective unit that locates on the left-hemisphere, and vice versa
    Selectivity change is then defined as: control_PSTH - stim_PSTH

    Units are skipped when:
    + their selectivity is neither "ipsi-selective" nor "contra-selective"
    + they have fewer than 5 control trials or fewer than 2 stim trials for either side
    + their stim PSTHs cannot be retrieved

    :param units: DataJoint query of units
    :param condition_name_kw: list of keywords used to look up the stim trial conditions
    :param recover_time_window: optional (t_start, t_end); when given, a 1000-sample bootstrap
        estimate of the first time the normalized selectivity change drops below 0.2 inside
        that window is drawn as a green line +/- 1 SD
    :param ax: optional matplotlib axis; created when None
    :return: the newly created Figure, or None when `ax` was supplied
    :raises ValueError: if no unit satisfies the selection criteria
    """
    trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        ['good_noearlylick_', '_hit'])[0]
    period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units,
                                           trial_cond_name)

    stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['_stim'])[0]
    stim_time, stim_dur = _get_photostim_time_and_duration(
        units,
        psth.TrialCondition().get_trials(stim_trial_cond_name))

    ctrl_left_cond_name = 'all_noearlylick_nostim_left'
    ctrl_right_cond_name = 'all_noearlylick_nostim_right'
    stim_left_cond_name = psth.TrialCondition().get_cond_name_from_keywords(
        condition_name_kw + ['noearlylick', 'stim', 'left'])[0]
    stim_right_cond_name = psth.TrialCondition().get_cond_name_from_keywords(
        condition_name_kw + ['noearlylick', 'stim', 'right'])[0]

    delta_sels, ctrl_psths = [], []
    for unit in (units * psth.UnitSelectivity
                 & 'unit_selectivity != "non-selective"'
                 ).proj('unit_selectivity').fetch(as_dict=True):
        # ---- selectivity criteria ----
        # only ipsi/contra-selective units define a sign for the PSTH difference;
        # anything else would otherwise reuse the previous unit's diffs
        if unit['unit_selectivity'] not in ('ipsi-selective', 'contra-selective'):
            continue
        # ---- trial count criteria ----
        # no less than 5 trials for control
        if (len(psth.TrialCondition.get_trials(ctrl_left_cond_name) & unit) < 5
                or len(
                    psth.TrialCondition.get_trials(ctrl_right_cond_name)
                    & unit) < 5):
            continue
        # no less than 2 trials for stimulation
        if (len(psth.TrialCondition.get_trials(stim_left_cond_name) & unit) < 2
                or len(
                    psth.TrialCondition.get_trials(stim_right_cond_name)
                    & unit) < 2):
            continue

        hemi = _get_units_hemisphere(unit)

        ctrl_left_psth, t_vec = psth.UnitPsth.get_plotting_data(
            unit, {'trial_condition_name': ctrl_left_cond_name})['psth']
        ctrl_right_psth, _ = psth.UnitPsth.get_plotting_data(
            unit, {'trial_condition_name': ctrl_right_cond_name})['psth']
        try:
            stim_left_psth, _ = psth.UnitPsth.get_plotting_data(
                unit, {'trial_condition_name': stim_left_cond_name})['psth']
            stim_right_psth, _ = psth.UnitPsth.get_plotting_data(
                unit, {'trial_condition_name': stim_right_cond_name})['psth']
        except Exception:
            # best-effort: skip units whose stim PSTH is unavailable
            # (exact exception type raised by get_plotting_data not visible here)
            continue

        if unit['unit_selectivity'] == 'ipsi-selective':
            ctrl_psth_diff = ctrl_left_psth - ctrl_right_psth if hemi == 'left' else ctrl_right_psth - ctrl_left_psth
            stim_psth_diff = stim_left_psth - stim_right_psth if hemi == 'left' else stim_right_psth - stim_left_psth
        else:  # contra-selective
            ctrl_psth_diff = ctrl_left_psth - ctrl_right_psth if hemi == 'right' else ctrl_right_psth - ctrl_left_psth
            stim_psth_diff = stim_left_psth - stim_right_psth if hemi == 'right' else stim_right_psth - stim_left_psth

        ctrl_psths.append(ctrl_psth_diff)
        delta_sels.append(ctrl_psth_diff - stim_psth_diff)

    if not delta_sels:
        raise ValueError('No unit satisfies the selectivity and trial-count criteria - nothing to plot')

    ctrl_psths = np.vstack(ctrl_psths)
    delta_sels = np.vstack(delta_sels)

    fig = None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(4, 6))

    _plot_with_sem(delta_sels, t_vec, ax)

    if recover_time_window:
        recovery_times = []
        for i in range(1000):
            i_sample = np.random.choice(delta_sels.shape[0],
                                        delta_sels.shape[0],
                                        replace=True)
            btstrp_diff = np.nanmean(delta_sels[i_sample, :],
                                     axis=0) / np.nanmean(
                                         ctrl_psths[i_sample, :], axis=0)
            t_recovered = t_vec[(btstrp_diff < 0.2)
                                & (t_vec > recover_time_window[0]) &
                                (t_vec < recover_time_window[1])]
            if len(t_recovered) > 0:
                recovery_times.append(t_recovered[0])
        # only draw the recovery estimate when at least one bootstrap sample recovered
        if recovery_times:
            ax.axvline(x=np.mean(recovery_times), linestyle='--', color='g')
            ax.axvspan(np.mean(recovery_times) - np.std(recovery_times),
                       np.mean(recovery_times) + np.std(recovery_times),
                       alpha=0.2,
                       color='g')

    ax.axhline(y=0, color='k')
    for x in period_starts:
        ax.axvline(x=x, linestyle='--', color='k')
    # add shaded bar for photostim (drawn in the top 5% of the axis)
    ax.axvspan(stim_time,
               stim_time + stim_dur,
               0.95,
               1,
               alpha=0.3,
               color='royalblue')
    ax.set_ylabel('Selectivity change (spike/s)')
    ax.set_xlabel('Time (s)')

    return fig
Ejemplo n.º 9
0
def plot_psth_photostim_effect(units,
                               condition_name_kw=None,
                               axs=None):
    """
    For the specified `units`, plot PSTH comparison between stim vs. no-stim with left/right trial instruction.

    The stim location (or other appropriate search keywords) can be specified in
    `condition_name_kw` (default: ['both_alm'], i.e. bilateral ALM).

    :param units: DataJoint query of units; projected to its primary key before use
    :param condition_name_kw: list of keywords (or a single keyword string) used to look up
        the photostim trial-condition names; defaults to ['both_alm']
    :param axs: optional array of exactly 2 matplotlib axes; created when None
    :return: the newly created Figure, or None when `axs` was supplied
    """
    # use a None sentinel instead of a shared mutable default list
    if condition_name_kw is None:
        condition_name_kw = ['both_alm']
    elif isinstance(condition_name_kw, str):
        condition_name_kw = [condition_name_kw]
    else:
        condition_name_kw = list(condition_name_kw)

    units = units.proj()

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    assert axs.size == 2

    hemi = _get_units_hemisphere(units)

    period_starts = (
        experiment.Period
        & 'period in ("sample", "delay", "response")').fetch('period_start')

    # no photostim:
    psth_n_l = psth.TrialCondition.get_cond_name_from_keywords(
        ['_nostim', '_left'])[0]
    psth_n_r = psth.TrialCondition.get_cond_name_from_keywords(
        ['_nostim', '_right'])[0]

    psth_n_l = (psth.UnitPsth * psth.TrialCondition & units
                & {
                    'trial_condition_name': psth_n_l
                } & 'unit_psth is not NULL').fetch('unit_psth')
    psth_n_r = (psth.UnitPsth * psth.TrialCondition & units
                & {
                    'trial_condition_name': psth_n_r
                } & 'unit_psth is not NULL').fetch('unit_psth')

    # with photostim
    psth_s_l = psth.TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['_stim_left'])[0]
    psth_s_r = psth.TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['_stim_right'])[0]

    psth_s_l = (psth.UnitPsth * psth.TrialCondition & units
                & {
                    'trial_condition_name': psth_s_l
                } & 'unit_psth is not NULL').fetch('unit_psth')
    psth_s_r = (psth.UnitPsth * psth.TrialCondition & units
                & {
                    'trial_condition_name': psth_s_r
                } & 'unit_psth is not NULL').fetch('unit_psth')

    # get photostim duration and stim time (relative to go-cue)
    stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(
        condition_name_kw + ['_stim'])[0]
    stim_time, stim_dur = _get_photostim_time_and_duration(
        units,
        psth.TrialCondition().get_trials(stim_trial_cond_name))

    # ipsi = trials instructed toward the recorded hemisphere, contra = the other side
    if hemi == 'left':
        psth_s_i = psth_s_l
        psth_n_i = psth_n_l
        psth_s_c = psth_s_r
        psth_n_c = psth_n_r
    else:
        psth_s_i = psth_s_r
        psth_n_i = psth_n_r
        psth_s_c = psth_s_l
        psth_n_c = psth_n_l

    _plot_avg_psth(psth_n_i, psth_n_c, period_starts, axs[0], 'Control')
    _plot_avg_psth(psth_s_i, psth_s_c, period_starts, axs[1], 'Photostim')

    # cosmetic: share the y-range across both panels
    ymax = max([ax.get_ylim()[1] for ax in axs])
    for ax in axs:
        ax.set_ylim((0, ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])

    # add shaded bar for photostim
    axs[1].axvspan(stim_time,
                   stim_time + stim_dur,
                   alpha=0.3,
                   color='royalblue')

    return fig