def plot_coding_direction(units, time_period=None, label=None, axs=None):
    """
    Plot the coding-direction (CD) projected PSTH of the given units:
    contra trials in blue, ipsi trials in red, with dashed vertical lines
    at the start of the sample, delay and go events.

    :param units: unit query; its keys are passed to psth.compute_CD_projected_psth
    :param time_period: forwarded unchanged to psth.compute_CD_projected_psth
    :param label: optional axes title (ignored when falsy)
    :param axs: axes to draw on; when None a new 8x6 figure is created
    :return: the figure created here, or None when `axs` was supplied by the caller
    """
    # start times of the trial events, used for the vertical markers
    _, event_starts = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    _, contra_proj, ipsi_proj, t_vec, _ = psth.compute_CD_projected_psth(
        units.fetch('KEY'), time_period=time_period)

    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(8, 6))
    else:
        fig = None

    # projected traces: contra first, then ipsi
    for trace, color in ((contra_proj, 'b'), (ipsi_proj, 'r')):
        _plot_with_sem(trace, t_vec, ax=axs, c=color)

    for event_time in event_starts:
        axs.axvline(x=event_time, linestyle='--', color='k')

    # cosmetic
    for side in ('right', 'top'):
        axs.spines[side].set_visible(False)
    axs.set_ylabel('CD projection (a.u.)')
    axs.set_xlabel('Time (s)')
    if label:
        axs.set_title(label)

    return fig
def plot_avg_contra_ipsi_psth(units, axs=None):
    """
    Plot PSTHs of ipsi vs. contra trials in two panels: contra-selective units
    on axs[0] and ipsi-selective units on axs[1]. Both panels share the
    same y-range, starting at 0.

    "Ipsi" trials are the hit trials on the same side as the units' hemisphere
    (left-hit trials for left-hemisphere units), "contra" trials the opposite side.

    :param units: unit query restricting which units are included
    :param axs: array of exactly 2 axes; when None a new 1x2 figure is created
    :return: the figure created here, or None when `axs` was supplied by the caller
    """
    units = units.proj()

    # start times of the trial events: sample, delay, go
    _, event_starts = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    hemi = _get_units_hemisphere(units)
    if hemi == 'left':
        ipsi_cond_name = 'good_noearlylick_left_hit'
        contra_cond_name = 'good_noearlylick_right_hit'
    else:
        ipsi_cond_name = 'good_noearlylick_right_hit'
        contra_cond_name = 'good_noearlylick_left_hit'

    quality_units = ephys.Unit & 'unit_quality != "all"'

    ipsi_conds = (psth.TrialCondition
                  & {'trial_condition_name': ipsi_cond_name}).fetch('KEY')
    contra_conds = (psth.TrialCondition
                    & {'trial_condition_name': contra_cond_name}).fetch('KEY')

    ipsi_sel_units = (ephys.Unit * psth.UnitSelectivity
                      & 'unit_selectivity = "ipsi-selective"' & units)
    contra_sel_units = (ephys.Unit * psth.UnitSelectivity
                        & 'unit_selectivity = "contra-selective"' & units)

    def _fetch_psths(trial_conds, selective_units):
        # unit PSTHs for the given trial conditions, restricted to the
        # quality-filtered and selectivity-filtered units, ordered by descending unit_posy
        query = (((psth.UnitPsth & trial_conds)
                  * ephys.Unit.proj('unit_posy'))
                 & quality_units.proj() & selective_units.proj())
        return query.fetch('unit_psth', order_by='unit_posy desc')

    ipsi_sel_ipsi_trials = _fetch_psths(ipsi_conds, ipsi_sel_units)
    ipsi_sel_contra_trials = _fetch_psths(contra_conds, ipsi_sel_units)
    contra_sel_contra_trials = _fetch_psths(contra_conds, contra_sel_units)
    contra_sel_ipsi_trials = _fetch_psths(ipsi_conds, contra_sel_units)

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    else:
        fig = None
    assert axs.size == 2

    _plot_avg_psth(contra_sel_ipsi_trials, contra_sel_contra_trials,
                   event_starts, axs[0], 'Contra-selective')
    _plot_avg_psth(ipsi_sel_ipsi_trials, ipsi_sel_contra_trials,
                   event_starts, axs[1], 'Ipsi-selective')

    # use a common y-range on both panels
    shared_ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, shared_ymax))

    return fig
def plot_stacked_contra_ipsi_psth(units, axs=None):
    """
    Plot per-unit PSTH comparisons (via _plot_stacked_psth_diff) in two panels:
    contra-selective units on axs[0] (contra vs. ipsi trials, drawn with flip=True)
    and ipsi-selective units on axs[1] (ipsi vs. contra trials).
    Units are fetched in descending unit_posy order.

    :param units: unit query restricting which units are included
    :param axs: array of exactly 2 axes; when None a new 1x2 figure (20x20) is created
    :return: the figure created here, or None when `axs` was supplied by the caller
    """
    units = units.proj()

    # start times of the trial events: sample, delay, go
    _, event_starts = _get_trial_event_times(
        ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    hemi = _get_units_hemisphere(units)
    if hemi == 'left':
        ipsi_cond_name = 'good_noearlylick_left_hit'
        contra_cond_name = 'good_noearlylick_right_hit'
    else:
        ipsi_cond_name = 'good_noearlylick_right_hit'
        contra_cond_name = 'good_noearlylick_left_hit'

    ipsi_cond = (psth.TrialCondition
                 & {'trial_condition_name': ipsi_cond_name}).fetch1('KEY')
    contra_cond = (psth.TrialCondition
                   & {'trial_condition_name': contra_cond_name}).fetch1('KEY')

    ipsi_sel_units = (ephys.Unit * psth.UnitSelectivity
                      & 'unit_selectivity = "ipsi-selective"' & units)
    contra_sel_units = (ephys.Unit * psth.UnitSelectivity
                        & 'unit_selectivity = "contra-selective"' & units)

    def _fetch_rows(selective_units, trial_cond):
        # full UnitPsth rows of the selective units for one trial condition
        query = psth.UnitPsth * selective_units.proj('unit_posy') & trial_cond
        return query.fetch(order_by='unit_posy desc')

    ipsi_sel_ipsi_trials = _fetch_rows(ipsi_sel_units, ipsi_cond)
    ipsi_sel_contra_trials = _fetch_rows(ipsi_sel_units, contra_cond)
    contra_sel_contra_trials = _fetch_rows(contra_sel_units, contra_cond)
    contra_sel_ipsi_trials = _fetch_rows(contra_sel_units, ipsi_cond)

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(20, 20))
    else:
        fig = None
    assert axs.size == 2

    # (axes, first data set, second data set, extra kwargs, title) per panel
    panels = (
        (axs[0], contra_sel_contra_trials, contra_sel_ipsi_trials,
         {'flip': True}, 'Contra-selective Units'),
        (axs[1], ipsi_sel_ipsi_trials, ipsi_sel_contra_trials,
         {}, 'Ipsi-selective Units'),
    )
    for ax, first_rows, second_rows, extra_kwargs, panel_title in panels:
        _plot_stacked_psth_diff(first_rows, second_rows, ax=ax,
                                vlines=event_starts, **extra_kwargs)
        ax.set_title(panel_title)
        ax.set_ylabel('Unit (by depth)')
        ax.set_xlabel('Time to go (s)')

    return fig
# Example #4
# 0
def plot_unit_psth(unit_key, axs=None, title='', xlim=_plt_xlim):
    """
    Default raster and PSTH plot for a specified unit, comparing ipsi vs. contra
    trials (only good, no-early-lick trials are selected).

    Layout of the 2x2 axes grid:
        column 0 - correct response (hit trials), column 1 - incorrect response (miss trials)
        row 0 - spike raster, row 1 - PSTH

    :param unit_key: key of the unit to plot; unit_key["unit"] is used in the default titles
    :param axs: 2x2 array of axes; when None a new 2x2 figure is created
    :param title: title for both raster panels; when empty, a default title with
        the unit number and the response type is used
    :param xlim: x-axis limits forwarded to the raster and PSTH plotting helpers
    :return: the figure created here, or None when `axs` was supplied by the caller
    """
    hemi = _get_units_hemisphere(unit_key)
    # ipsi = trial side matching the unit's hemisphere
    ipsi_side, contra_side = ('left', 'right') if hemi == 'left' else ('right', 'left')

    def _plotting_data(side, outcome):
        return psth.UnitPsth.get_plotting_data(
            unit_key,
            {'trial_condition_name': f'good_noearlylick_{side}_{outcome}'})

    ipsi_hit = _plotting_data(ipsi_side, 'hit')
    contra_hit = _plotting_data(contra_side, 'hit')
    ipsi_miss = _plotting_data(ipsi_side, 'miss')
    contra_miss = _plotting_data(contra_side, 'miss')

    # start times of the trial events: sample, delay, go
    _, event_starts = _get_trial_event_times(['sample', 'delay', 'go'],
                                             unit_key,
                                             'good_noearlylick_hit')

    if axs is None:
        fig, axs = plt.subplots(2, 2)
    else:
        fig = None

    columns = (
        (0, ipsi_hit, contra_hit, 'Correct Response'),
        (1, ipsi_miss, contra_miss, 'Incorrect Response'),
    )
    for col, ipsi_data, contra_data, response_label in columns:
        _plot_spike_raster(
            ipsi_data,
            contra_data,
            ax=axs[0, col],
            vlines=event_starts,
            title=title or f'Unit #: {unit_key["unit"]}\n{response_label}',
            xlim=xlim)
        _plot_psth(ipsi_data,
                   contra_data,
                   vlines=event_starts,
                   ax=axs[1, col],
                   xlim=xlim)

    return fig
# Example #5
# 0
    def make(self, key):
        """
        Generate, save and register the coding-direction (CD) figure of one session.

        With more than one probe insertion, the figure is an N x N grid (N = number
        of probes, ordered by insertion_number):
            - diagonal [i, i]: CD-projected PSTH of probe i (contra in blue, ipsi in red)
            - upper triangle [i, j], i < j: trial-to-trial CD-endpoint correlation
              between probes i and j; when the two probes are in different
              hemispheres, the contra/ipsi trials of the second probe are swapped
              so that the definition follows the first probe
        Each sub-plot is rendered as its own figure and embedded in the grid as an image.
        With a single probe, the CD-projected PSTH is drawn directly on the only axes.

        The figure is saved under store_stage/<water_res_num>/<sess_date> and one row
        (key + saved-figure fields + 'cd_probe_count') is inserted into this table.

        :param key: key of the session to process
        """
        water_res_num, sess_date = get_wr_sessdate(key)
        sess_dir = store_stage / water_res_num / sess_date
        sess_dir.mkdir(parents=True, exist_ok=True)

        # ---- Setup ----
        time_period = (-0.4, 0)
        probe_keys = (ephys.ProbeInsertion & key).fetch(
            'KEY', order_by='insertion_number')

        fig1, axs = plt.subplots(len(probe_keys),
                                 len(probe_keys),
                                 figsize=(16, 16))

        def _brain_region_label(probe):
            """Comma-separated "<hemisphere> <brain_area>" of the probe's recordable brain regions."""
            return (ephys.ProbeInsertion & probe).aggr(
                ephys.ProbeInsertion.RecordableBrainRegion.proj(
                    brain_region='CONCAT(hemisphere, " ", brain_area)'),
                brain_regions='GROUP_CONCAT(brain_region SEPARATOR", ")'
            ).fetch1('brain_regions')

        def _cd_endpoint(proj_trial, time_stamps):
            """Per-trial mean of the projection within time_period (start inclusive, end exclusive)."""
            p_start, p_end = time_period
            in_window = np.logical_and(time_stamps >= p_start,
                                       time_stamps < p_end)
            return proj_trial[:, in_window].mean(axis=1)

        def _embed_figure(fig, ax):
            """Render `fig` to an in-memory PNG, show it on `ax`, then close `fig`."""
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            ax.imshow(Image.open(buf))
            buf.close()
            plt.close(fig)

        if len(probe_keys) > 1:
            # the grid only hosts embedded images - hide all axes decorations
            for a in axs.flatten():
                a.axis('off')

            # ---- Plot Coding Direction per probe ----
            probe_proj = {}
            for pid, probe in enumerate(probe_keys):
                units = ephys.Unit & probe
                label = '({}) {}'.format(probe['insertion_number'],
                                         _brain_region_label(probe))

                _, period_starts = _get_trial_event_times(
                    ['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

                # ---- compute CD projected PSTH ----
                _, proj_contra_trial, proj_ipsi_trial, time_stamps, hemi = psth.compute_CD_projected_psth(
                    units.fetch('KEY'), time_period=time_period)

                # ---- save projection results ----
                probe_proj[pid] = (proj_contra_trial, proj_ipsi_trial,
                                   time_stamps, label, hemi)

                # ---- generate fig with CD plot for this probe ----
                fig, ax = plt.subplots(1, 1, figsize=(6, 6))
                _plot_with_sem(proj_contra_trial, time_stamps, ax=ax, c='b')
                _plot_with_sem(proj_ipsi_trial, time_stamps, ax=ax, c='r')
                # cosmetic
                for x in period_starts:
                    ax.axvline(x=x, linestyle='--', color='k')
                ax.spines['right'].set_visible(False)
                ax.spines['top'].set_visible(False)
                ax.set_ylabel('CD projection (a.u.)')
                ax.set_xlabel('Time (s)')
                ax.set_title(label)
                fig.tight_layout()

                # ---- plot this fig into the main figure ----
                _embed_figure(fig, axs[pid, pid])

            # ---- Plot probe-pair correlation ----
            for p1, p2 in itertools.combinations(probe_proj.keys(), r=2):
                proj_contra_trial_g1, proj_ipsi_trial_g1, _, label_g1, p1_hemi = probe_proj[p1]
                # NOTE: the time stamps of the second probe are used for both probes
                proj_contra_trial_g2, proj_ipsi_trial_g2, time_stamps, label_g2, p2_hemi = probe_proj[p2]
                labels = [label_g1, label_g2]

                # trial CD-endpoints of the first probe
                contra_cdend_1 = _cd_endpoint(proj_contra_trial_g1, time_stamps)
                ipsi_cdend_1 = _cd_endpoint(proj_ipsi_trial_g1, time_stamps)

                # trial CD-endpoints of the second probe - contra/ipsi swapped
                # when the two probes are in different hemispheres
                if p1_hemi == p2_hemi:
                    contra_cdend_2 = _cd_endpoint(proj_contra_trial_g2, time_stamps)
                    ipsi_cdend_2 = _cd_endpoint(proj_ipsi_trial_g2, time_stamps)
                else:
                    contra_cdend_2 = _cd_endpoint(proj_ipsi_trial_g2, time_stamps)
                    ipsi_cdend_2 = _cd_endpoint(proj_contra_trial_g2, time_stamps)

                c_df = pd.DataFrame([contra_cdend_1, contra_cdend_2]).T
                c_df.columns = labels
                c_df['trial-type'] = 'contra'
                i_df = pd.DataFrame([ipsi_cdend_1, ipsi_cdend_2]).T
                i_df.columns = labels
                i_df['trial-type'] = 'ipsi'
                # DataFrame.append was removed in pandas 2.0 - pd.concat is the equivalent
                df = pd.concat([c_df, i_df])

                # remove NaN trial - could be due to some trials having no spikes
                non_nan = ~np.logical_or(
                    np.isnan(df[labels[0]]).values,
                    np.isnan(df[labels[1]]).values)
                df = df[non_nan]

                fig = plt.figure(figsize=(6, 6))
                _jointplot_w_hue(data=df,
                                 x=labels[0],
                                 y=labels[1],
                                 hue='trial-type',
                                 colormap=['b', 'r'],
                                 figsize=(8, 6),
                                 fig=fig,
                                 scatter_kws=None)

                # ---- plot this fig into the main figure ----
                _embed_figure(fig, axs[p1, p2])

        else:
            # ---- Plot Single-Probe Coding Direction ----
            probe = probe_keys[0]
            units = ephys.Unit & probe
            label = _brain_region_label(probe)

            unit_characteristic_plot.plot_coding_direction(
                units, time_period=time_period, label=label, axs=axs)

        # ---- Save fig and insert ----
        fn_prefix = f'{water_res_num}_{sess_date}_'
        fig_dict = save_figs((fig1, ), ('coding_direction', ), sess_dir,
                             fn_prefix)

        plt.close('all')
        self.insert1({**key, **fig_dict, 'cd_probe_count': len(probe_keys)})
def plot_paired_coding_direction(unit_g1, unit_g2, labels=None, time_period=None):
    """
    Plot trial-to-trial CD-endpoint correlation between CD-projected trial-psth from two unit-groups (e.g. two brain regions)
    Note: coding direction is calculated on selective units, contra vs. ipsi, within the specified time_period

    Two figures are produced:
        - a 1x2 figure with the CD-projected PSTH of each group (contra in blue, ipsi in red) - returned
        - a joint plot of the per-trial CD-endpoints of group 1 vs. group 2, colored by trial type - shown
    If the two groups are from different hemispheres, the contra/ipsi trials of
    the second group are swapped, so that the definition follows the first group.

    :param unit_g1: unit query of the first group (also used to look up the trial event times)
    :param unit_g2: unit query of the second group
    :param labels: optional pair of labels for the two groups, used as panel titles and joint-plot axes
    :param time_period: (start, end) window, forwarded to psth.compute_CD_projected_psth and
        used as the CD-endpoint window (start inclusive, end exclusive).
        NOTE: the endpoint computation unpacks this value, so the default None raises a TypeError there
    :return: the 1x2 figure with the projected PSTHs
    """
    _, proj_contra_trial_g1, proj_ipsi_trial_g1, time_stamps, unit_g1_hemi = psth.compute_CD_projected_psth(
        unit_g1.fetch('KEY'), time_period=time_period)
    _, proj_contra_trial_g2, proj_ipsi_trial_g2, time_stamps, unit_g2_hemi = psth.compute_CD_projected_psth(
        unit_g2.fetch('KEY'), time_period=time_period)

    # get event start times: sample, delay, response
    _, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], unit_g1, 'good_noearlylick_hit')

    if labels:
        assert len(labels) == 2
    else:
        labels = ('unit group 1', 'unit group 2')

    # plot projected trial-psth
    fig, axs = plt.subplots(1, 2, figsize=(16, 6))

    _plot_with_sem(proj_contra_trial_g1, time_stamps, ax=axs[0], c='b')
    _plot_with_sem(proj_ipsi_trial_g1, time_stamps, ax=axs[0], c='r')
    _plot_with_sem(proj_contra_trial_g2, time_stamps, ax=axs[1], c='b')
    _plot_with_sem(proj_ipsi_trial_g2, time_stamps, ax=axs[1], c='r')

    # cosmetic
    for ax, label in zip(axs, labels):
        for x in period_starts:
            ax.axvline(x=x, linestyle='--', color='k')
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.set_ylabel('CD projection (a.u.)')
        ax.set_xlabel('Time (s)')
        ax.set_title(label)

    # plot trial CD-endpoint correlation - if 2 unit-groups are from 2 hemispheres,
    #   then contra-ipsi definition is based on the first group
    p_start, p_end = time_period

    def _cd_endpoint(proj_trial):
        # per-trial mean of the projection within [p_start, p_end)
        in_window = np.logical_and(time_stamps >= p_start, time_stamps < p_end)
        return proj_trial[:, in_window].mean(axis=1)

    contra_cdend_1 = _cd_endpoint(proj_contra_trial_g1)
    ipsi_cdend_1 = _cd_endpoint(proj_ipsi_trial_g1)
    # compare group 1 against group 2 (the original compared group 1 with itself,
    # so the swap below could never happen)
    if unit_g1_hemi == unit_g2_hemi:
        contra_cdend_2 = _cd_endpoint(proj_contra_trial_g2)
        ipsi_cdend_2 = _cd_endpoint(proj_ipsi_trial_g2)
    else:
        contra_cdend_2 = _cd_endpoint(proj_ipsi_trial_g2)
        ipsi_cdend_2 = _cd_endpoint(proj_contra_trial_g2)

    c_df = pd.DataFrame([contra_cdend_1, contra_cdend_2]).T
    c_df.columns = labels
    c_df['trial-type'] = 'contra'
    i_df = pd.DataFrame([ipsi_cdend_1, ipsi_cdend_2]).T
    i_df.columns = labels
    i_df['trial-type'] = 'ipsi'
    # DataFrame.append was removed in pandas 2.0 - pd.concat is the equivalent
    df = pd.concat([c_df, i_df])

    jplot = _jointplot_w_hue(data=df, x=labels[0], y=labels[1], hue='trial-type', colormap=['b', 'r'],
                             figsize=(8, 6), fig=None, scatter_kws=None)
    jplot['fig'].show()

    return fig
def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
    """
    For the specified `units`, plot PSTH comparison between stim vs. no-stim with left/right trial instruction
    The stim location (or other appropriate search keywords) can be specified in `condition_name_kw` (default: both ALM)

    axs[0] shows the no-stim trials ('Control'), axs[1] the photostim trials ('Photostim')
    with the photostim period shaded. In each panel, ipsi vs. contra trials are
    compared, where ipsi trials are those on the same side as the units' hemisphere.

    :param units: unit query restricting which units are included
    :param condition_name_kw: list of keywords used to look up the photostim trial-condition names
    :param axs: array of exactly 2 axes; when None a new 1x2 figure is created
    :return: the figure created here, or None when `axs` was supplied by the caller
    """
    units = units.proj()

    if axs is None:
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    else:
        fig = None
    assert axs.size == 2

    hemi = _get_units_hemisphere(units)

    def _fetch_unit_psths(cond_name):
        # non-NULL unit PSTHs of `units` for one trial condition
        query = (psth.UnitPsth * psth.TrialCondition & units
                 & {'trial_condition_name': cond_name} & 'unit_psth is not NULL')
        return query.fetch('unit_psth')

    # no photostim:
    nostim_left_name = psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_left'])[0]
    nostim_right_name = psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_right'])[0]
    nostim_left = _fetch_unit_psths(nostim_left_name)
    nostim_right = _fetch_unit_psths(nostim_right_name)

    # with photostim
    stim_left_name = psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim_left'])[0]
    stim_right_name = psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim_right'])[0]
    stim_left = _fetch_unit_psths(stim_left_name)
    stim_right = _fetch_unit_psths(stim_right_name)

    # start times of the trial events: sample, delay, go
    _, event_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')

    # photostim onset and duration
    stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim'])[0]
    stim_trials = psth.TrialCondition().get_trials(stim_trial_cond_name)
    stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent * stim_trials
                           & units).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)
    stim_time = _get_stim_onset_time(units, stim_trial_cond_name)

    # map left/right trials to ipsi/contra according to the recording hemisphere
    if hemi == 'left':
        stim_ipsi, nostim_ipsi = stim_left, nostim_left
        stim_contra, nostim_contra = stim_right, nostim_right
    else:
        stim_ipsi, nostim_ipsi = stim_right, nostim_right
        stim_contra, nostim_contra = stim_left, nostim_left

    _plot_avg_psth(nostim_ipsi, nostim_contra, event_starts, axs[0],
                   'Control')
    _plot_avg_psth(stim_ipsi, stim_contra, event_starts, axs[1],
                   'Photostim')

    # cosmetic: common y-range and x-range on both panels
    shared_ymax = max(ax.get_ylim()[1] for ax in axs)
    for ax in axs:
        ax.set_ylim((0, shared_ymax))
        ax.set_xlim([_plt_xmin, _plt_xmax])

    # add shaded bar for photostim
    axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')

    return fig
def plot_unit_bilateral_photostim_effect(probe_insertion,
                                         clustering_method=None,
                                         axs=None):
    """
    Scatter plot of the units of one probe insertion at their (x, depth + y)
    position, with the marker size reflecting the relative decrease in firing
    rate during bilateral ALM photostim trials ('all_noearlylick_both_alm_stim')
    compared to no-stim trials ('all_noearlylick_nostim').

    Firing rates are averaged over the window from the start of the 'delay'
    event to that start plus the photostim duration. Only decreases are shown:
    units with no change or an increase get a minimal marker size. The changes
    are normalized to the largest one across units.

    :param probe_insertion: query of the probe insertion to plot
    :param clustering_method: clustering method of the units to include;
        when None it is looked up with _get_clustering_method
    :param axs: axes to draw on; when None a new 4x8 figure is created
    :return: the figure created here, or None when `axs` was supplied by the caller
    :raises ValueError: when no clustering method is given and the lookup fails
    """
    probe_insertion = probe_insertion.proj()

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(
                str(e) +
                '\nPlease specify one with the kwarg "clustering_method"') from e

    dv_loc = (ephys.ProbeInsertion.InsertionLocation
              & probe_insertion).fetch1('depth')

    no_stim_cond = (psth.TrialCondition
                    & {
                        'trial_condition_name': 'all_noearlylick_nostim'
                    }).fetch1('KEY')

    bi_stim_cond = (psth.TrialCondition
                    & {
                        'trial_condition_name': 'all_noearlylick_both_alm_stim'
                    }).fetch1('KEY')

    # get photostim duration
    stim_durs = np.unique(
        (experiment.Photostim & experiment.PhotostimEvent *
         psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
         & probe_insertion).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)

    units = ephys.Unit & probe_insertion & {
        'clustering_method': clustering_method
    } & 'unit_quality != "all"'

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # start time of the 'delay' event - the start of the comparison window
    _, cue_onset = _get_trial_event_times(['delay'], units,
                                          'all_noearlylick_nostim')
    cue_onset = cue_onset[0]

    def _mean_rate_in_window(trial_psths, edges):
        # per-trial mean firing rate within [cue_onset, cue_onset + stim_dur]
        return np.array([
            trial_psth[np.logical_and(edges >= cue_onset,
                                      edges <= cue_onset + stim_dur)].mean()
            for trial_psth in trial_psths
        ])

    # XXX: could be done with 1x fetch+join
    for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):
        # equality test: the original `in ('kilosort2')` was a substring test
        # against a plain string (the parentheses do not make a tuple)
        if clustering_method == 'kilosort2':
            # position taken from the coordinates of the unit's electrode
            x, y = (ephys.Unit * lab.ElectrodeConfig.Electrode.proj() *
                    lab.ProbeType.Electrode.proj('x_coord', 'y_coord')
                    & unit).fetch1('x_coord', 'y_coord')
        else:
            x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

        # obtain unit psth per trial, for all nostim and bistim trials
        nostim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(
            no_stim_cond['trial_condition_name'])
        bistim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(
            bi_stim_cond['trial_condition_name'])

        nostim_psths, nostim_edge = psth.compute_unit_psth(
            unit, nostim_trials.fetch('KEY'), per_trial=True)
        bistim_psths, bistim_edge = psth.compute_unit_psth(
            unit, bistim_trials.fetch('KEY'), per_trial=True)

        # compute the firing rate difference between stim vs. no-stim trials
        # within the stimulation duration
        ctrl_frate = _mean_rate_in_window(nostim_psths, nostim_edge)
        stim_frate = _mean_rate_in_window(bistim_psths, bistim_edge)

        frate_change = (stim_frate.mean() -
                        ctrl_frate.mean()) / ctrl_frate.mean()
        # keep the magnitude of decreases only; anything else gets a tiny placeholder
        frate_change = abs(frate_change) if frate_change < 0 else 0.0001

        metrics.loc[u_idx] = (int(unit['unit']), x, float(dv_loc) + y,
                              frate_change)

    # normalize to the largest change across units
    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    # --- prepare for plotting
    shank_count = (ephys.ProbeInsertion & probe_insertion).aggr(
        lab.ElectrodeConfig.Electrode * lab.ProbeType.Electrode,
        shank_count='count(distinct shank)').fetch1('shank_count')
    m_scale = get_m_scale(shank_count)

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(4, 8))

    xmax = 1.3 * metrics.x.max()
    xmin = -1 / 6 * xmax

    cosmetic = {
        'legend': None,
        'linewidth': 1.75,
        'alpha': 0.9,
        'facecolor': 'none',
        'edgecolor': 'k'
    }

    sns.scatterplot(data=metrics,
                    x='x',
                    y='y',
                    s=metrics.frate_change * m_scale,
                    ax=axs,
                    **cosmetic)

    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((xmin, xmax))

    return fig