def plot_clustering_quality(probe_insertion, clustering_method=None, axs=None):
    '''
    Plot pairwise scatter plots of four per-unit quality metrics.

    The metrics are unit amplitude, SNR, ISI violation (scaled by 100 to a
    percentage) and average firing rate; every pair of them (6 pairs) is
    drawn on its own axes.

    Parameters:

      - probe_insertion: query restricted to the probe insertion of interest
        (only its `.proj()` is used as a restriction)

      - clustering_method: clustering method to restrict units by. If None,
        it is looked up with `_get_clustering_method`.

      - axs: optional array of exactly 6 axes to draw on. If None, a new
        2x3 figure is created.

    Returns the newly created figure, or None when `axs` was supplied.

    Raises ValueError if the clustering method is not given and
    `_get_clustering_method` fails with a ValueError.
    '''
    probe_insertion = probe_insertion.proj()

    if clustering_method is None:
        hint = '\nPlease specify one with the kwarg "clustering_method"'
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(str(e) + hint)

    method_restriction = {'clustering_method': clustering_method}
    q_unit_stats = (
        ephys.Unit * ephys.UnitStat * ephys.ProbeInsertion.InsertionLocation
        & probe_insertion & method_restriction)
    amp, snr, spk_rate, isi_violation = q_unit_stats.fetch(
        'unit_amp', 'unit_snr', 'avg_firing_rate', 'isi_violation')

    # insertion order matters: it fixes the order of the plotted pairs
    metrics = dict(
        amp=amp,
        snr=snr,
        isi=np.array(isi_violation) * 100,  # to percentage
        rate=np.array(spk_rate))
    label_mapper = dict(
        amp='Amplitude',
        snr='Signal to noise ratio (SNR)',
        isi='ISI violation (%)',
        rate='Firing rate (spike/s)')

    fig = None
    if axs is None:
        fig, axs = plt.subplots(2, 3, figsize=(12, 8))
        fig.subplots_adjust(wspace=0.4)

    assert axs.size == 6

    metric_pairs = itertools.combinations(list(metrics), 2)
    for (x_name, y_name), ax in zip(metric_pairs, axs.flatten()):
        ax.plot(metrics[x_name], metrics[y_name], '.k')
        ax.set_xlabel(label_mapper[x_name])
        ax.set_ylabel(label_mapper[y_name])

        # cosmetic
        for side in ('right', 'top'):
            ax.spines[side].set_visible(False)

    return fig
Esempio n. 2
0
    def make(self, key):
        '''
        Insert curated cluster notes for the units matching `key`.

        Locates the session's ephys directory under one of the configured
        ephys paths, reads the curated cluster notes through
        `ephys_ingest.Kilosort`, and inserts one row per (curation source,
        unit) pair whose unit exists in `Unit & key` and whose note is one of
        the `UnitQualityType` values. The 'group' curation source is skipped.

        Note: `key` is updated in place with its 'clustering_method'.

        Raises FileNotFoundError if the ephys file is not found under any of
        the configured paths.
        '''
        # import here to avoid circular imports
        from pipeline.ingest import ephys as ephys_ingest
        from pipeline.util import _get_clustering_method

        q_ephys_file = ephys_ingest.EphysIngest.EphysFile.proj(
            insertion_number='probe_insertion_number') & key
        ephys_file = q_ephys_file.fetch1('ephys_file')

        # first configured root under which the ephys file exists
        candidate_dirs = (pathlib.Path(root) / ephys_file
                          for root in ephys_ingest.get_ephys_paths())
        session_ephys_dir = next(
            (path for path in candidate_dirs if path.exists()), None)
        if session_ephys_dir is None:
            raise FileNotFoundError(
                'Error - No ephys data directory found for {}'.format(
                    ephys_file))

        key['clustering_method'] = _get_clustering_method(key)
        known_units = (Unit & key).fetch('unit')
        valid_qualities = UnitQualityType.fetch('unit_quality')

        sorter = ephys_ingest.Kilosort(session_ephys_dir)
        notes_by_source = sorter.extract_curated_cluster_notes()

        rows = []
        for source, note_set in notes_by_source.items():
            if source == 'group':
                continue
            unit_note_pairs = zip(note_set['cluster_ids'],
                                  note_set['cluster_notes'])
            for unit_id, quality in unit_note_pairs:
                if unit_id in known_units and quality in valid_qualities:
                    rows.append({
                        **key, 'note_source': source,
                        'unit': unit_id,
                        'unit_quality': quality
                    })
        self.insert(rows)
def plot_unit_characteristic(probe_insertion, clustering_method=None, axs=None):
    '''
    Scatter-plot unit amplitude, SNR and firing rate at each unit's position.

    Three panels are drawn (amplitude, SNR, firing rate). In each panel a
    unit is placed at (x, insertion depth + y) and its marker size is the
    metric normalised to its maximum, multiplied by the scale returned by
    `get_m_scale(shank_count)`. A size legend is drawn manually at the top
    of every panel.

    Parameters:

      - probe_insertion: query restricted to the probe insertion of interest
        (only its `.proj()` is used as a restriction)

      - clustering_method: clustering method to restrict units by. If None,
        it is looked up with `_get_clustering_method`. For 'kilosort2' the
        unit position is taken from the electrode coordinates
        ('x_coord', 'y_coord'); otherwise from 'unit_posx', 'unit_posy'.

      - axs: optional array of exactly 3 axes to draw on. If None, a new
        1x3 figure is created.

    Returns the newly created figure, or None when `axs` was supplied.

    Raises ValueError if the clustering method is not given and
    `_get_clustering_method` fails with a ValueError.
    '''
    probe_insertion = probe_insertion.proj()

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"') from e

    # NOTE: the trailing comma is required - `('kilosort2')` is just a string,
    # which would turn this membership test into a substring match
    if clustering_method in ('kilosort2',):
        q_unit = (ephys.Unit * ephys.ProbeInsertion.InsertionLocation.proj('depth') * ephys.UnitStat
                  * lab.ElectrodeConfig.Electrode.proj() * lab.ProbeType.Electrode.proj('x_coord', 'y_coord')
                  & probe_insertion & {'clustering_method': clustering_method} & 'unit_quality != "all"').proj(
            ..., x='x_coord', y='y_coord')
    else:
        q_unit = (ephys.Unit * ephys.ProbeInsertion.InsertionLocation.proj('depth') * ephys.UnitStat
                  & probe_insertion & {'clustering_method': clustering_method} & 'unit_quality != "all"').proj(
            ..., x='unit_posx', y='unit_posy')

    amp, snr, spk_rate, x, y, insertion_depth = q_unit.fetch(
        'unit_amp', 'unit_snr', 'avg_firing_rate', 'x', 'y', 'depth')

    # each metric is normalised to its own maximum; y is offset by the insertion depth
    metrics = pd.DataFrame(list(zip(*(amp/amp.max(), snr/snr.max(), spk_rate/spk_rate.max(),
                                      x, insertion_depth.astype(float) + y))))
    metrics.columns = ['amp', 'snr', 'rate', 'x', 'y']

    # --- prepare for plotting
    shank_count = (ephys.ProbeInsertion & probe_insertion).aggr(lab.ElectrodeConfig.Electrode * lab.ProbeType.Electrode,
                                                                shank_count='count(distinct shank)').fetch1('shank_count')
    m_scale = get_m_scale(shank_count)

    ymin = metrics.y.min() - 100
    ymax = metrics.y.max() + 200
    xmax = 1.3 * metrics.x.max()
    xmin = -1/6*xmax
    cosmetic = {'legend': None,
                'linewidth': 1.75,
                'alpha': 0.9,
                'facecolor': 'none', 'edgecolor': 'k'}

    # --- plot
    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 3, figsize=(10, 8))
        fig.subplots_adjust(wspace=0.6)

    assert axs.size == 3

    sns.scatterplot(data=metrics, x='x', y='y', s=metrics.amp*m_scale, ax=axs[0], **cosmetic)
    sns.scatterplot(data=metrics, x='x', y='y', s=metrics.snr*m_scale, ax=axs[1], **cosmetic)
    sns.scatterplot(data=metrics, x='x', y='y', s=metrics.rate*m_scale, ax=axs[2], **cosmetic)

    # manually draw the legend: three filled markers at 20%, 50% and 80% of
    # the panel's maximum value, each labelled with the corresponding value
    lg_ypos = ymax
    data = pd.DataFrame({'x': [0.1*xmax, 0.4*xmax, 0.75*xmax], 'y': [lg_ypos, lg_ypos, lg_ypos],
                         'size_ratio': np.array([0.2, 0.5, 0.8])})
    for ax, ax_maxval in zip(axs.flatten(), (amp.max(), snr.max(), spk_rate.max())):
        sns.scatterplot(data=data, x='x', y='y', s=data.size_ratio*m_scale, ax=ax, **dict(cosmetic, facecolor='k'))
        for _, r in data.iterrows():
            ax.text(r['x']-4, r['y']+70, (r['size_ratio']*ax_maxval).astype(int))

    # cosmetic
    for title, ax in zip(('Amplitude', 'SNR', 'Firing rate'), axs.flatten()):
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.set_title(title)
        ax.set_xlim((xmin, xmax))
        # horizontal line separating the legend from the data
        ax.plot([0.5*xmin, xmax], [lg_ypos-80, lg_ypos-80], '-k')
        ax.set_ylim((ymin, ymax + 150))

    return fig
def plot_driftmap(probe_insertion, clustering_method=None, shank_no=1):
    '''
    Plot a drift map (time x depth spike-rate image) for one shank.

    The figure has four parts: the time-depth 2D histogram of spike rates,
    its colorbar, the spike count per depth bin, and a column of brain-region
    colors taken from the CCF annotation of the shank's electrodes.

    Parameters:

      - probe_insertion: query restricted to the probe insertion of interest
        (only its `.proj()` is used as a restriction). It must have an entry
        in `histology.InterpolatedShankTrack`.

      - clustering_method: clustering method to restrict units by. If None,
        it is looked up with `_get_clustering_method`.

      - shank_no: shank to plot (default 1)

    Returns the created figure.

    Raises AssertionError if there is no `histology.InterpolatedShankTrack`
    for this insertion, and ValueError if the clustering method is not given
    and `_get_clustering_method` fails with a ValueError.
    '''
    probe_insertion = probe_insertion.proj()

    assert histology.InterpolatedShankTrack & probe_insertion

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"') from e

    units = (ephys.Unit * lab.ElectrodeConfig.Electrode
             & probe_insertion & {'clustering_method': clustering_method}
             & 'unit_quality != "all"')
    units = (units.proj('spike_times', 'spike_depths', 'unit_posy')
             * ephys.ProbeInsertion.proj()
             * lab.ProbeType.Electrode.proj('shank') & {'shank': shank_no})

    # ---- ccf region ----
    annotated_electrodes = (lab.ElectrodeConfig.Electrode * lab.ProbeType.Electrode
                            * ephys.ProbeInsertion
                            * histology.ElectrodeCCFPosition.ElectrodePosition
                            * ccf.CCFAnnotation * ccf.CCFBrainRegion.proj(..., annotation='region_name')
                            & probe_insertion & {'shank': shank_no})
    pos_y, ccf_y, color_code = annotated_electrodes.fetch(
        'y_coord', 'ccf_y', 'color_code', order_by='y_coord DESC')

    # CCF position of most ventral recording site
    last_electrode_site = np.array((histology.InterpolatedShankTrack.DeepestElectrodePoint
                                    & probe_insertion & {'shank': shank_no}).fetch1(
        'ccf_x', 'ccf_y', 'ccf_z'))
    # CCF position of the brain surface where this shank crosses
    brain_surface_site = np.array((histology.InterpolatedShankTrack.BrainSurfacePoint
                                   & probe_insertion & {'shank': shank_no}).fetch1(
        'ccf_x', 'ccf_y', 'ccf_z'))

    # CCF position of most ventral recording site, with respect to the brain surface
    y_ref = -np.linalg.norm(last_electrode_site - brain_surface_site)

    # ---- spikes ----
    spike_times, spike_depths = units.fetch('spike_times', 'spike_depths', order_by='unit')

    # pool the spikes of all units
    spike_times = np.hstack(spike_times)
    spike_depths = np.hstack(spike_depths)

    # time-depth 2D histogram
    time_bin_count = 1000
    depth_bin_count = 200

    spike_bins = np.linspace(0, spike_times.max(), time_bin_count)
    depth_bins = np.linspace(0, np.nanmax(spike_depths), depth_bin_count)

    spk_count, _, depth_edges = np.histogram2d(spike_times, spike_depths, bins=[spike_bins, depth_bins])
    # counts per bin divided by the bin width -> rate
    spk_rates = spk_count / np.mean(np.diff(spike_bins))
    # keep the left edge of each depth bin so it lines up with the counts
    depth_edges = depth_edges[:-1]

    # region colorcode, by depths
    binned_hexcodes = []

    # electrode spacing along y: median of the non-zero differences
    y_spacing = np.abs(np.nanmedian(np.where(np.diff(pos_y)==0, np.nan, np.diff(pos_y))))
    anno_depth_bins = np.arange(0, depth_bins[-1], y_spacing)
    for s, e in zip(anno_depth_bins[:-1], anno_depth_bins[1:]):
        hexcodes = color_code[np.logical_and(pos_y > s, pos_y <= e)]
        if len(hexcodes):
            # most frequent region color among the electrodes in this bin
            binned_hexcodes.append(Counter(hexcodes).most_common()[0][0])
        else:
            binned_hexcodes.append('FFFFFF')

    region_rgba = np.array([list(ImageColor.getcolor("#" + chex, "RGBA")) for chex in binned_hexcodes])
    # repeat each color 10 times along a new axis to form an image column
    region_rgba = np.repeat(region_rgba[:, np.newaxis, :], 10, axis=1)

    # canvas setup
    fig = plt.figure(figsize=(16, 8))
    grid = plt.GridSpec(12, 12)

    ax_main = plt.subplot(grid[1:, 0:9])
    ax_cbar = plt.subplot(grid[0, 0:9])
    ax_spkcount = plt.subplot(grid[1:, 9:11])
    ax_anno = plt.subplot(grid[1:, 11:])

    # -- plot main --
    im = ax_main.imshow(spk_rates.T, aspect='auto', cmap='gray_r',
                        extent=[spike_bins[0], spike_bins[-1], depth_bins[-1], depth_bins[0]])
    # cosmetic
    ax_main.invert_yaxis()
    ax_main.set_xlabel('Time (sec)')
    ax_main.set_ylabel('Distance from tip sites (um)')
    ax_main.set_ylim(depth_edges[0], depth_edges[-1])
    ax_main.spines['right'].set_visible(False)
    ax_main.spines['top'].set_visible(False)

    cb = fig.colorbar(im, cax=ax_cbar, orientation='horizontal')
    cb.outline.set_visible(False)
    cb.ax.xaxis.tick_top()
    cb.set_label('Firing rate (Hz)')
    cb.ax.xaxis.set_label_position('top')

    # -- plot spikecount --
    # the axis is labelled in thousands, so divide by 1e3 (10e3 would be 10000)
    ax_spkcount.plot(spk_count.sum(axis=0) / 1e3, depth_edges, 'k')
    ax_spkcount.set_xlabel('Spike count (x$10^3$)')
    ax_spkcount.set_yticks([])
    ax_spkcount.set_ylim(depth_edges[0], depth_edges[-1])

    ax_spkcount.spines['right'].set_visible(False)
    ax_spkcount.spines['top'].set_visible(False)
    ax_spkcount.spines['bottom'].set_visible(False)
    ax_spkcount.spines['left'].set_visible(False)

    # -- plot colored region annotation
    # depths are shifted by y_ref and divided by 1000 for the mm-labelled axis
    ax_anno.imshow(region_rgba, aspect='auto',
                   extent=[0, 10, (anno_depth_bins[-1] + y_ref) / 1000, (anno_depth_bins[0] + y_ref) / 1000])

    ax_anno.invert_yaxis()

    ax_anno.spines['right'].set_visible(False)
    ax_anno.spines['top'].set_visible(False)
    ax_anno.spines['bottom'].set_visible(False)
    ax_anno.spines['left'].set_visible(False)

    ax_anno.set_xticks([])
    ax_anno.yaxis.tick_right()
    ax_anno.set_ylabel('Depth in the brain (mm)')
    ax_anno.yaxis.set_label_position('right')

    return fig
def plot_unit_bilateral_photostim_effect(probe_insertion, clustering_method=None, axs=None):
    '''
    Scatter-plot, per unit, the firing-rate reduction under bilateral ALM
    photostimulation at the unit's position.

    For every unit the mean firing rate within the stimulation window is
    computed for photostim trials ('all_noearlylick_both_alm_stim') and for
    control trials ('all_noearlylick_nostim'). The relative change
    (stim - ctrl) / ctrl is kept as a magnitude only when negative; units
    whose rate does not decrease get a tiny placeholder value (0.0001).
    Values are then normalised to their maximum and used as marker sizes.

    Parameters:

      - probe_insertion: query restricted to the probe insertion of interest
        (only its `.proj()` is used as a restriction)

      - clustering_method: clustering method to restrict units by. If None,
        it is looked up with `_get_clustering_method`. For 'kilosort2' the
        unit position is taken from the electrode coordinates
        ('x_coord', 'y_coord'); otherwise from 'unit_posx', 'unit_posy'.

      - axs: optional single axes to draw on. If None, a new figure with
        one axes is created.

    Returns the newly created figure, or None when `axs` was supplied.

    Raises PhotostimError if there are no bilateral ALM photostim trials for
    this insertion, and ValueError if the clustering method is not given and
    `_get_clustering_method` fails with a ValueError.
    '''
    probe_insertion = probe_insertion.proj()

    if not (psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim') & probe_insertion):
        raise PhotostimError('No Bilateral ALM Photo-stimulation present')

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"') from e

    # NOTE: the trailing comma is required - `('kilosort2')` is just a string,
    # which would turn this membership test into a substring match
    use_electrode_coords = clustering_method in ('kilosort2',)

    dv_loc = (ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch1('depth')

    no_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name':
                       'all_noearlylick_nostim'}).fetch1('KEY')

    bi_stim_cond = (psth.TrialCondition
                    & {'trial_condition_name':
                       'all_noearlylick_both_alm_stim'}).fetch1('KEY')

    units = ephys.Unit & probe_insertion & {'clustering_method': clustering_method} & 'unit_quality != "all"'

    metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])

    # get photostim onset and duration
    stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent
                           * psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
                           & probe_insertion).fetch('duration'))
    stim_dur = _extract_one_stim_dur(stim_durs)
    stim_time = _get_stim_onset_time(units, 'all_noearlylick_both_alm_stim')

    # XXX: could be done with 1x fetch+join
    for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):
        if use_electrode_coords:
            x, y = (ephys.Unit * lab.ElectrodeConfig.Electrode.proj()
                    * lab.ProbeType.Electrode.proj('x_coord', 'y_coord') & unit).fetch1('x_coord', 'y_coord')
        else:
            x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')

        # obtain unit psth per trial, for all nostim and bistim trials
        nostim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(no_stim_cond['trial_condition_name'])
        bistim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(bi_stim_cond['trial_condition_name'])

        nostim_psths, nostim_edge = psth.compute_unit_psth(unit, nostim_trials.fetch('KEY'), per_trial=True)
        bistim_psths, bistim_edge = psth.compute_unit_psth(unit, bistim_trials.fetch('KEY'), per_trial=True)

        # compute the per-trial mean firing rate of control and photostim
        # trials within the stimulation time window
        ctrl_frate = np.array([nostim_psth[np.logical_and(nostim_edge >= stim_time,
                                                          nostim_edge <= stim_time + stim_dur)].mean()
                               for nostim_psth in nostim_psths])
        stim_frate = np.array([bistim_psth[np.logical_and(bistim_edge >= stim_time,
                                                          bistim_edge <= stim_time + stim_dur)].mean()
                               for bistim_psth in bistim_psths])

        frate_change = (stim_frate.mean() - ctrl_frate.mean()) / ctrl_frate.mean()
        # only a decrease is shown; anything else collapses to a tiny marker
        frate_change = abs(frate_change) if frate_change < 0 else 0.0001

        metrics.loc[u_idx] = (int(unit['unit']), x, float(dv_loc) + y, frate_change)

    metrics.frate_change = metrics.frate_change / metrics.frate_change.max()

    # --- prepare for plotting
    shank_count = (ephys.ProbeInsertion & probe_insertion).aggr(lab.ElectrodeConfig.Electrode * lab.ProbeType.Electrode,
                                                                shank_count='count(distinct shank)').fetch1('shank_count')
    m_scale = get_m_scale(shank_count)

    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 1, figsize=(4, 8))

    xmax = 1.3 * metrics.x.max()
    xmin = -1/6*xmax

    cosmetic = {'legend': None,
                'linewidth': 1.75,
                'alpha': 0.9,
                'facecolor': 'none', 'edgecolor': 'k'}

    sns.scatterplot(data=metrics, x='x', y='y', s=metrics.frate_change*m_scale,
                    ax=axs, **cosmetic)

    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    axs.set_title('% change')
    axs.set_xlim((xmin, xmax))

    return fig
def plot_unit_selectivity(probe_insertion, clustering_method=None, axs=None):
    '''
    Scatter-plot selective units at their position, one panel per period.

    Units whose 'period_selectivity' is not "non-selective" are drawn in
    three panels ('sample', 'delay', 'response') at (x, insertion depth + y).
    Marker size is the absolute ipsi-vs-contra firing rate difference
    normalised to its maximum, multiplied by the scale returned by
    `get_m_scale(shank_count)`; contra-selective units are blue and
    ipsi-selective units are red. Each title reports the percentage of
    contra- and ipsi-selective units in that period.

    Parameters:

      - probe_insertion: query restricted to the probe insertion of interest
        (only its `.proj()` is used as a restriction)

      - clustering_method: clustering method to restrict units by. If None,
        it is looked up with `_get_clustering_method`.

      - axs: optional array of exactly 3 axes to draw on. If None, a new
        1x3 figure is created.

    Returns the newly created figure, or None when `axs` was supplied.

    Raises ValueError if the clustering method is not given and
    `_get_clustering_method` fails with a ValueError.
    '''
    probe_insertion = probe_insertion.proj()

    if clustering_method is None:
        try:
            clustering_method = _get_clustering_method(probe_insertion)
        except ValueError as e:
            raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"') from e

    # NOTE: the trailing comma is required - `('kilosort2')` is just a string,
    # which would turn this membership test into a substring match
    if clustering_method in ('kilosort2',):
        # NOTE(review): 'x'/'y' are defined twice here (first from
        # unit_posx/unit_posy, then from x_coord/y_coord); presumably the
        # electrode coordinates are the intended ones - confirm how the
        # second proj resolves the name clash
        q_unit = (psth.PeriodSelectivity * ephys.Unit * ephys.ProbeInsertion.InsertionLocation
                  * lab.ElectrodeConfig.Electrode.proj() * lab.ProbeType.Electrode.proj('x_coord', 'y_coord')
                  * experiment.Period & probe_insertion & {'clustering_method': clustering_method}
                  & 'period_selectivity != "non-selective"').proj(..., x='unit_posx', y='unit_posy').proj(
            ..., x='x_coord', y='y_coord')
    else:
        q_unit = (psth.PeriodSelectivity * ephys.Unit * ephys.ProbeInsertion.InsertionLocation
                  * experiment.Period & probe_insertion & {'clustering_method': clustering_method}
                  & 'period_selectivity != "non-selective"').proj(..., x='unit_posx', y='unit_posy')

    attr_names = ['unit', 'period', 'period_selectivity', 'contra_firing_rate',
                  'ipsi_firing_rate', 'x', 'y', 'depth']
    selective_units = q_unit.fetch(*attr_names)
    # fetch returns one array per attribute; transpose to one row per unit
    selective_units = pd.DataFrame(selective_units).T
    selective_units.columns = attr_names

    # --- account for insertion depth (manipulator depth)
    selective_units.y = selective_units.depth.values.astype(float) + selective_units.y

    # --- get ipsi vs. contra firing rate difference
    f_rate_diff = np.abs(selective_units.ipsi_firing_rate - selective_units.contra_firing_rate)
    selective_units['f_rate_diff'] = f_rate_diff / f_rate_diff.max()

    # --- prepare for plotting
    shank_count = (ephys.ProbeInsertion & probe_insertion).aggr(lab.ElectrodeConfig.Electrode * lab.ProbeType.Electrode,
                                                                shank_count='count(distinct shank)').fetch1('shank_count')
    m_scale = get_m_scale(shank_count)

    cosmetic = {'legend': None,
                'linewidth': 0.0001}
    ymin = selective_units.y.min() - 100
    ymax = selective_units.y.max() + 100
    xmax = 1.3 * selective_units.x.max()
    xmin = -1/6*xmax

    # a bit of hack to get the 'open circle': an outer circle followed by a
    # reversed inner circle at 70% of its radius
    pts = np.linspace(0, np.pi * 2, 24)
    circ = np.c_[np.sin(pts) / 2, -np.cos(pts) / 2]
    vert = np.r_[circ, circ[::-1] * .7]

    open_circle = mpl.path.Path(vert)

    # --- plot
    fig = None
    if axs is None:
        fig, axs = plt.subplots(1, 3, figsize=(10, 8))
        fig.subplots_adjust(wspace=0.6)

    assert axs.size == 3

    for (title, df), ax in zip(((p, selective_units[selective_units.period == p])
                                for p in ('sample', 'delay', 'response')), axs):
        sns.scatterplot(data=df, x='x', y='y',
                        s=df.f_rate_diff.values.astype(float)*m_scale,
                        hue='period_selectivity', marker=open_circle,
                        palette={'contra-selective': 'b', 'ipsi-selective': 'r'},
                        ax=ax, **cosmetic)
        contra_p = (df.period_selectivity == 'contra-selective').sum() / len(df) * 100
        # cosmetic
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.set_title(f'{title}\n% contra: {contra_p:.2f}\n% ipsi: {100-contra_p:.2f}')
        ax.set_xlim((xmin, xmax))
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_ylim((ymin, ymax))

    return fig
Esempio n. 7
0
def _export_recording(insert_key,
                      output_dir='./',
                      filename=None,
                      overwrite=False):
    '''
    Export a 'recording' (probe specific data + related events) to a file.

    Parameters:

      - insert_key: an ephys.ProbeInsertion.primary_key
        currently: {'subject_id', 'session', 'insertion_number'})

      - output_dir: directory to save the file at (default to be the current working directory)

      - filename: an optional output file path string. If not provided,
        filename will be autogenerated using the 'mkfilename'
        function.
    '''

    if filename is None:
        filename = mkfilename(insert_key)

    filepath = pathlib.Path(output_dir) / filename

    if filepath.exists() and not overwrite:
        print('{} already exists, skipping...'.format(filepath))
        return

    print(
        '\n========================================================================'
    )
    print('exporting {} to {}'.format(insert_key, filepath))

    print('fetching spike/behavior data')

    try:
        insertion = (ephys.ProbeInsertion.InsertionLocation *
                     ephys.ProbeInsertion.proj('probe_type')
                     & insert_key).fetch1()
        loc = (ephys.ProbeInsertion & insert_key).aggr(
            ephys.ProbeInsertion.RecordableBrainRegion.proj(
                brain_region='CONCAT(hemisphere, " ", brain_area)'),
            brain_regions='GROUP_CONCAT(brain_region SEPARATOR ", ")').fetch1(
                'brain_regions')
    except dj.DataJointError:
        raise KeyError('Probe Insertion Location not yet available')

    clustering_method = _get_clustering_method(insert_key)

    q_unit = (ephys.Unit * lab.ElectrodeConfig.Electrode.proj() *
              lab.ProbeType.Electrode.proj('shank') & insert_key & {
                  'clustering_method': clustering_method
              })

    units = q_unit.fetch(order_by='unit')

    behav = (experiment.BehaviorTrial & insert_key).aggr(
        experiment.TrialNote & 'trial_note_type="autolearn"', ...,
        auto_learn='trial_note',
        keep_all_rows=True).fetch(order_by='trial asc')

    trials = behav['trial']

    exports = [
        'probe_insertion_info', 'neuron_single_units', 'neuron_unit_waveforms',
        'neuron_unit_info', 'neuron_unit_quality_control', 'behavior_report',
        'behavior_early_report', 'behavior_lick_times',
        'behavior_lick_directions', 'behavior_is_free_water',
        'behavior_is_auto_water', 'behavior_auto_learn', 'task_trial_type',
        'task_stimulation', 'trial_end_time', 'task_sample_time',
        'task_delay_time', 'task_cue_time', 'tracking', 'histology'
    ]

    edata = {k: [] for k in exports}

    print('reshaping/processing for export')

    # probe_insertion_info
    # -------------------
    edata['probe_insertion_info'] = {
        k: float(v) if isinstance(v, Decimal) else v
        for k, v in dict(insertion, recordable_brain_regions=loc).items()
        if k not in ephys.ProbeInsertion.InsertionLocation.primary_key
    }

    # neuron_single_units
    # -------------------

    # [[u0t0.spikes, ..., u0tN.spikes], ..., [uNt0.spikes, ..., uNtN.spikes]]
    print('... neuron_single_units:', end='')

    q_trial_spikes = (experiment.SessionTrial.proj() * ephys.Unit.proj()
                      & insert_key).aggr(ephys.Unit.TrialSpikes, ...,
                                         spike_times='spike_times',
                                         keep_all_rows=True)

    trial_spikes = q_trial_spikes.fetch(format='frame',
                                        order_by='trial asc').reset_index()

    # replace None with np.array([])
    isna = trial_spikes.spike_times.isna()
    trial_spikes.loc[isna, 'spike_times'] = pd.Series([np.array([])] *
                                                      isna.sum()).values

    single_units = defaultdict(list)
    for u in set(trial_spikes.unit):
        single_units[u] = trial_spikes.spike_times[trial_spikes.unit ==
                                                   u].values.tolist()

    # reformat to a MATLAB compatible form
    ndarray_object = np.empty((len(single_units.keys()), 1), dtype=np.object)
    for idx, i in enumerate(sorted(single_units.keys())):
        ndarray_object[idx, 0] = np.array(single_units[i], ndmin=2).T

    edata['neuron_single_units'] = ndarray_object

    print('ok.')

    # neuron_unit_waveforms
    # -------------------

    edata['neuron_unit_waveforms'] = np.array(units['waveform'], ndmin=2).T

    # neuron_unit_info
    # ----------------
    #
    # [[unit_id, unit_quality, unit_x_in_um, depth_in_um, associated_electrode, shank, cell_type, recording_location] ...]
    print('... neuron_unit_info:', end='')

    dv = float(insertion['depth']) if insertion['depth'] else np.nan

    cell_types = {
        u['unit']: u['cell_type']
        for u in (ephys.UnitCellType
                  & insert_key).fetch(as_dict=True, order_by='unit')
    }

    _ui = []
    for u in units:
        typ = cell_types[u['unit']] if u['unit'] in cell_types else 'unknown'
        _ui.append([
            u['unit'], u['unit_quality'], u['unit_posx'], u['unit_posy'] + dv,
            u['electrode'], u['shank'], typ, loc
        ])

    edata['neuron_unit_info'] = np.array(_ui, dtype='O')

    print('ok.')

    # neuron_unit_quality_control
    # ----------------
    # structure of all of the QC fields, each contains 1d array of length equals to the number of unit. E.g.:
    # presence_ratio: (Nx1)
    # unit_amp: (Nx1)
    # unit_snr: (Nx1)
    # ...

    q_qc = (ephys.Unit & insert_key).proj('unit_amp', 'unit_snr').aggr(
        ephys.UnitStat, ...,
        **{
            n: n
            for n in ephys.UnitStat.heading.names
            if n not in ephys.UnitStat.heading.primary_key
        },
        keep_all_rows=True).aggr(
            ephys.MAPClusterMetric.DriftMetric, ...,
            **{
                n: n
                for n in ephys.MAPClusterMetric.DriftMetric.heading.names if n
                not in ephys.MAPClusterMetric.DriftMetric.heading.primary_key
            },
            keep_all_rows=True).aggr(
                ephys.ClusterMetric,
                ...,
                **{
                    n: n
                    for n in ephys.ClusterMetric.heading.names
                    if n not in ephys.ClusterMetric.heading.primary_key
                },
                keep_all_rows=True).aggr(
                    ephys.WaveformMetric, ...,
                    **{
                        n: n
                        for n in ephys.WaveformMetric.heading.names
                        if n not in ephys.WaveformMetric.heading.primary_key
                    },
                    keep_all_rows=True)
    qc_names = [n for n in q_qc.heading.names if n not in q_qc.primary_key]

    if q_qc:
        qc = (q_qc & insert_key).fetch(*qc_names, order_by='unit')
        qc_df = pd.DataFrame(qc).T
        qc_df.columns = qc_names
        edata['neuron_unit_quality_control'] = {
            n: qc_df.get(n).values
            for n in qc_names if not np.all(np.isnan(qc_df.get(n).values))
        }

    # behavior_report
    # ---------------
    print('... behavior_report:', end='')

    behavior_report_map = {'hit': 1, 'miss': 0, 'ignore': -1}
    edata['behavior_report'] = np.array(
        [behavior_report_map[i] for i in behav['outcome']])

    print('ok.')

    # behavior_early_report
    # ---------------------
    print('... behavior_early_report:', end='')

    early_report_map = {'early': 1, 'no early': 0}
    edata['behavior_early_report'] = np.array(
        [early_report_map[i] for i in behav['early_lick']])

    print('ok.')

    # behavior_is_free_water
    # ---------------------
    print('... behavior_is_free_water:', end='')

    edata['behavior_is_free_water'] = np.array(
        [i for i in behav['free_water']])

    print('ok.')

    # behavior_is_auto_water
    # ---------------------
    print('... behavior_is_auto_water:', end='')

    edata['behavior_is_auto_water'] = np.array(
        [i for i in behav['auto_water']])

    print('ok.')

    # behavior_auto_learn
    # ---------------------
    print('... behavior_auto_learn:', end='')

    edata['behavior_auto_learn'] = np.array(
        [i or 'n/a' for i in behav['auto_learn']])

    print('ok.')

    # behavior_touch_times
    # --------------------

    behavior_touch_times = None  # NOQA no data (see ActionEventType())

    # behavior_lick_times - 0: left lick; 1: right lick
    # -------------------
    print('... behavior_lick_times:', end='')
    lick_direction_mapper = {'left lick': 0, 'right lick': 1}

    _lt, _ld = [], []

    licks = (experiment.ActionEvent() & insert_key
             & "action_event_type in ('left lick', 'right lick')").fetch()

    for t in trials:

        _lt.append([
            float(i) for i in  # decimal -> float
            licks[licks['trial'] == t]['action_event_time']
        ] if t in licks['trial'] else [])
        _ld.append([
            lick_direction_mapper[i] for i in  # decimal -> float
            licks[licks['trial'] == t]['action_event_type']
        ] if t in licks['trial'] else [])

    edata['behavior_lick_times'] = np.array(_lt)
    edata['behavior_lick_directions'] = np.array(_ld)

    behavior_whisker_angle = None  # NOQA no data
    behavior_whisker_dist2pol = None  # NOQA no data

    print('ok.')

    # task_trial_type
    # ---------------
    print('... task_trial_type:', end='')

    task_trial_type_map = {'left': 'l', 'right': 'r'}
    edata['task_trial_type'] = np.array(
        [task_trial_type_map[i] for i in behav['trial_instruction']],
        dtype='O')

    print('ok.')

    # task_stimulation
    # ----------------
    print('... task_stimulation:', end='')

    _ts = []  # [[power, type, on-time, off-time], ...]

    q_photostim = (experiment.Photostim * experiment.PhotostimBrainRegion.proj(
        stim_brain_region='CONCAT(stim_laterality, " ", stim_brain_area)')
                   & insert_key)

    photostim_keyval = {'left ALM': 1, 'right ALM': 2, 'both ALM': 6}

    photostim_map, photostim_dat = {}, {}
    for pstim in q_photostim.fetch():
        photostim_map[pstim['photo_stim']] = photostim_keyval[
            pstim['stim_brain_region']]
        photostim_dat[pstim['photo_stim']] = pstim

    photostim_ev = (experiment.PhotostimEvent & insert_key).fetch()

    for t in trials:

        if t in photostim_ev['trial']:

            ev = photostim_ev[np.where(photostim_ev['trial'] == t)]
            ps = photostim_map[ev['photo_stim'][0]]
            pdat = photostim_dat[ev['photo_stim'][0]]

            _ts.append([
                float(ev['power']), ps,
                float(ev['photostim_event_time']),
                float(ev['photostim_event_time'] + pdat['duration'])
            ])

        else:
            _ts.append([0, math.nan, math.nan, math.nan])

    edata['task_stimulation'] = np.array(_ts)

    print('ok.')

    # task_pole_time
    # --------------

    task_pole_time = None  # NOQA no data

    # task_sample_time - (sample period) - list of (onset, duration) - the LAST "sample" event in a trial
    # -------------

    print('... task_sample_time:', end='')

    _tst, _tsd = ((experiment.BehaviorTrial & insert_key).aggr(
        experiment.TrialEvent & 'trial_event_type = "sample"',
        trial_event_id='max(trial_event_id)') * experiment.TrialEvent).fetch(
            'trial_event_time', 'duration', order_by='trial')

    edata['task_sample_time'] = np.array([_tst, _tsd]).astype(float)

    print('ok.')

    # task_delay_time - (delay period) - list of (onset, duration) - the LAST "delay" event in a trial
    # -------------

    print('... task_delay_time:', end='')

    _tdt, _tdd = ((experiment.BehaviorTrial & insert_key).aggr(
        experiment.TrialEvent & 'trial_event_type = "delay"',
        trial_event_id='max(trial_event_id)') * experiment.TrialEvent).fetch(
            'trial_event_time', 'duration', order_by='trial')

    edata['task_delay_time'] = np.array([_tdt, _tdd]).astype(float)

    print('ok.')

    # task_cue_time - (response period) - 2 x N array (row 0: onsets, row 1: durations),
    # one column per fetched trial, taken from the LAST "go" event in that trial
    # -------------

    print('... task_cue_time:', end='')

    _tct, _tcd = ((experiment.BehaviorTrial & insert_key).aggr(
        experiment.TrialEvent & 'trial_event_type = "go"',
        trial_event_id='max(trial_event_id)') * experiment.TrialEvent).fetch(
            'trial_event_time', 'duration', order_by='trial')

    edata['task_cue_time'] = np.array([_tct, _tcd]).astype(float)

    print('ok.')

    # trial_end_time - 2 x N array (row 0: onsets, row 1: durations),
    # one column per fetched trial, taken from the FIRST "trialend" event in that trial
    # -------------

    print('... trial_end_time:', end='')

    _tet, _ted = ((experiment.BehaviorTrial & insert_key).aggr(
        experiment.TrialEvent & 'trial_event_type = "trialend"',
        trial_event_id='min(trial_event_id)') * experiment.TrialEvent).fetch(
            'trial_event_time', 'duration', order_by='trial')

    edata['trial_end_time'] = np.array([_tet, _ted]).astype(float)

    print('ok.')

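    # tracking
    # ----------------
    # NOTE(review): 'tracking_samples' is fetched below but never read, while
    # 'Nframes' is filled from the first feature attribute (ft_attrs[0]);
    # confirm whether 'Nframes' was meant to hold trk_d['tracking_samples'].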
    print('... tracking:', end='')
    tracking_struct = {}
    for feature, feature_tbl in tracking.Tracking().tracking_features.items():
        ft_attrs = [
            n for n in feature_tbl.heading.names
            if n not in feature_tbl.primary_key
        ]
        trk_data = (
            tracking.Tracking * feature_tbl * tracking.TrackingDevice.proj(
                fs='sampling_rate',
                camera='concat(tracking_device, "_", tracking_position)')
            & insert_key).fetch('camera',
                                'fs',
                                'tracking_samples',
                                'trial',
                                *ft_attrs,
                                order_by='trial',
                                as_dict=True)

        for trk_d in trk_data:
            camera = trk_d['camera'].replace(' ', '_').lower()
            if camera not in tracking_struct:
                tracking_struct[camera] = {
                    'fs': float(trk_d['fs']),
                    'Nframes': [],
                    'trialNum': []
                }
            if trk_d['trial'] not in tracking_struct[camera]['trialNum']:
                tracking_struct[camera]['trialNum'].append(trk_d['trial'])
                tracking_struct[camera]['Nframes'].append(trk_d[ft_attrs[0]])
            for ft in ft_attrs:
                if ft not in tracking_struct[camera]:
                    tracking_struct[camera][ft] = []
                tracking_struct[camera][ft].append(trk_d[ft])

    if tracking_struct:
        edata['tracking'] = tracking_struct
        print('ok.')
    else:
        print('n/a')

    # histology - unit ccf
    # ----------------
    print('... histology:', end='')
    unit_ccfs = []
    for ccf_tbl in (histology.ElectrodeCCFPosition.ElectrodePosition,
                    histology.ElectrodeCCFPosition.ElectrodePositionError):
        unit_ccf = (ephys.Unit * ccf_tbl & insert_key & {
            'clustering_method': clustering_method
        }).aggr(ccf.CCFAnnotation, ...,
                annotation='IFNULL(annotation, "")',
                keep_all_rows=True).fetch('unit',
                                          'ccf_x',
                                          'ccf_y',
                                          'ccf_z',
                                          'annotation',
                                          order_by='unit')
        unit_ccfs.extend(list(zip(*unit_ccf)))

    if unit_ccfs:
        unit_id, ccf_x, ccf_y, ccf_z, anno = zip(
            *sorted(unit_ccfs, key=lambda x: x[0]))
        edata['histology'] = {
            'unit': unit_id,
            'ccf_x': ccf_x,
            'ccf_y': ccf_y,
            'ccf_z': ccf_z,
            'annotation': anno
        }
        print('ok.')
    else:
        print('n/a')

    # savemat
    # -------
    print('... saving to {}:'.format(filepath), end='')

    scio.savemat(filepath, edata)

    print('ok.')