Example #1
def _sync_to_alf(raw_ephys_apfile, output_path=None, save=False, parts=''):
    """
    Extracts sync.times, sync.channels and sync.polarities from binary ephys dataset

    :param raw_ephys_apfile: bin file containing ephys data, or a spikeglx.Reader instance
    :param output_path: output directory
    :param save: bool, write to disk only if True
    :param parts: string or list of strings that will be appended to the filename before the extension
    :return: Bunch with keys 'times', 'channels' and 'polarities'; if save is True, the list of
     written files is returned as well
    """
    # handles input argument: support ibllib.io.spikeglx.Reader, str and pathlib.Path
    if isinstance(raw_ephys_apfile, spikeglx.Reader):
        sr = raw_ephys_apfile
    else:
        raw_ephys_apfile = Path(raw_ephys_apfile)
        sr = spikeglx.Reader(raw_ephys_apfile)
    opened = sr.is_open
    if not opened:  # if not (opened := sr.is_open)  # py3.8
        sr.open()
    # if no output, need a temp folder to swap for big files
    if not output_path:
        output_path = raw_ephys_apfile.parent
    file_ftcp = Path(output_path).joinpath(
        f'fronts_times_channel_polarity{str(uuid.uuid4())}.bin')

    # loop over chunks of the raw ephys file
    wg = neurodsp.utils.WindowGenerator(sr.ns,
                                        int(SYNC_BATCH_SIZE_SECS * sr.fs),
                                        overlap=1)
    fid_ftcp = open(file_ftcp, 'wb')
    for sl in wg.slice:
        ss = sr.read_sync(sl)
        ind, fronts = neurodsp.utils.fronts(ss, axis=0)
        # a = sr.read_sync_analog(sl)
        sav = np.c_[(ind[0, :] + sl.start) / sr.fs, ind[1, :],
                    fronts.astype(np.double)]
        sav.tofile(fid_ftcp)
    # close temp file, read from it and delete
    fid_ftcp.close()
    tim_chan_pol = np.fromfile(str(file_ftcp))
    tim_chan_pol = tim_chan_pol.reshape((int(tim_chan_pol.size / 3), 3))
    file_ftcp.unlink()
    sync = {
        'times': tim_chan_pol[:, 0],
        'channels': tim_chan_pol[:, 1],
        'polarities': tim_chan_pol[:, 2]
    }
    # If opened Reader was passed into function, leave open
    if not opened:
        sr.close()
    if save:
        out_files = alfio.save_object_npy(output_path,
                                          sync,
                                          'sync',
                                          namespace='spikeglx',
                                          parts=parts)
        return Bunch(sync), out_files
    else:
        return Bunch(sync)
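A minimal usage sketch for `_sync_to_alf` above; the file path is a made-up placeholder and the function is assumed to be in scope together with its module dependencies (spikeglx, neurodsp, alfio).

from pathlib import Path

# Hypothetical path to a SpikeGLX ap binary containing the sync channel
ap_file = Path('/data/raw_ephys_data/probe00/_spikeglx_ephysData_g0_t0.imec0.ap.bin')

# In-memory only: returns a Bunch with 'times', 'channels' and 'polarities'
sync = _sync_to_alf(ap_file)

# Write to disk as well: also returns the list of written npy files
sync, out_files = _sync_to_alf(ap_file, output_path=ap_file.parent, save=True, parts='probe00')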
Example #2
def detection(data,
              fs,
              h,
              detect_threshold=-4,
              time_tol=.002,
              distance_threshold_um=70):
    """
    Detects and de-duplicates negative voltage spikes based on voltage thresholding.
    The de-duplication step keeps the maximum-amplitude event. To account for collisions, the
    amplitude is assumed to decay away from the peak; if the event has several peaks, each peak is
    labeled as a separate spike.

    :param data: 2D numpy array nsamples x nchannels
    :param fs: sampling frequency (Hz)
    :param h: dictionary with neuropixel geometry header: see. neuropixel.trace_header
    :param detect_threshold: negative value below which the voltage is considered to be a spike
    :param time_tol: time in seconds for which samples before and after are assumed to be part of the spike
    :param distance_threshold_um: distance (um) within which threshold-crossing samples are assumed to be part of the same spike
    :return: spikes dictionary of vectors with keys "time", "trace", "amp" and "ispike"
    """
    multipeak = False
    time_bracket = np.array([-1, 1]) * time_tol
    inds, indtr = np.where(data < detect_threshold)
    picks = Bunch(time=inds / fs,
                  trace=indtr,
                  amp=data[inds, indtr],
                  ispike=np.zeros(inds.size))
    amp_order = np.argsort(picks.amp)

    hxy = h['x'] + 1j * h['y']

    spike_id = 1
    while np.any(picks.ispike == 0):
        # find the first unassigned spike with the highest amplitude
        iamp = np.where(picks.ispike[amp_order] == 0)[0][0]
        imax = amp_order[iamp]
        # look only within the time range
        itlims = np.searchsorted(picks.time, picks.time[imax] + time_bracket)
        itlims = np.arange(itlims[0], itlims[1])

        offset = np.abs(hxy[picks.trace[itlims]] - hxy[picks.trace[imax]])
        iit = np.where(offset < distance_threshold_um)[0]

        picks.ispike[itlims[iit]] = -1
        picks.ispike[imax] = spike_id
        # handles collision with a simple amplitude decay model: if amplitude doesn't decay
        # as a function of offset, then it's a collision and another spike is set
        if multipeak:  # noqa
            iii = np.lexsort((picks.amp[itlims[iit]], offset[iit]))
            sorted_amps_db = 20 * np.log10(np.abs(picks.amp[itlims[iit][iii]]))
            idetect = np.r_[0, np.where(np.diff(sorted_amps_db) > 12)[0] + 1]
            picks.ispike[itlims[iit[iii[idetect]]]] = np.arange(
                idetect.size) + spike_id
            spike_id += idetect.size
        else:
            spike_id += 1

    detects = Bunch({k: picks[k][picks.ispike > 0] for k in picks})
    return detects
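A small synthetic-data sketch of calling `detection` above; the toy geometry dict stands in for neuropixel.trace_header(), and the Bunch import is an assumption about where the class used throughout these examples lives.

import numpy as np
from iblutil.util import Bunch  # assumed source of the Bunch class used by detection

fs = 30000                                   # sampling rate (Hz)
nsamples, nchannels = 3000, 8
h = {'x': np.zeros(nchannels), 'y': np.arange(nchannels) * 20.}  # toy probe geometry (um)

rng = np.random.default_rng(0)
data = rng.normal(0, 0.5, (nsamples, nchannels))   # background noise well above threshold
data[1500, 3] = -8.                                # inject one clearly negative deflection

spikes = detection(data, fs, h, detect_threshold=-4)
print(spikes.time, spikes.trace, spikes.amp)       # one detected event on trace 3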
Example #3
    def __init__(self, x, y, z=None, c=None, cmap=None, plot_type='scatter'):
        """
        Class for organising data that will be used to create scatter plots. Can be 2D or 3D (if
        z is given). An additional variable can be represented through colour by specifying c.

        :param x: x values for data
        :param y: y values for data
        :param z: z values for data
        :param c: values to use to represent color of scatter points
        :param cmap: name of colormap to use if c is given
        :param plot_type: type of plot to create (defaults to 'scatter')
        """
        data = Bunch({'x': x, 'y': y, 'z': z, 'c': c})

        assert len(data['x']) == len(data['y']), 'dimensions must agree'
        if data['z'] is not None:
            assert len(data['z']) == len(data['x']), 'dimensions must agree'
        if data['c'] is not None:
            assert len(data['c']) == len(data['x']), 'dimensions must agree'

        super().__init__(plot_type, data)

        self._set_init_style()
        self.set_xlim()
        self.set_ylim()
        # If we have 3D data
        if data['z'] is not None:
            self.set_zlim()
        # If we want colorbar associated with scatter plot
        self.set_clim()
        self.cmap = self._set_default(cmap, 'viridis')
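A usage sketch for this constructor, on the assumption that it belongs to a ScatterPlot-style class importable as brainbox.plot_base.ScatterPlot (both the class name and the import path are assumptions here).

import numpy as np
from brainbox.plot_base import ScatterPlot  # assumed location of the class

n = 500
x, y = np.random.rand(n), np.random.rand(n)
c = np.hypot(x, y)                          # colour each point by a derived value

scat = ScatterPlot(x, y, c=c, cmap='magma')
scat.set_labels(title='toy scatter', xlabel='x', ylabel='y', clabel='radius')  # see Example #25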
Example #4
def get_video_meta(video_path, one=None):
    """
    Return a bunch of video information with the fields ('length', 'fps', 'width', 'height',
    'duration', 'size')
    :param video_path: A path to the video.  May be a file path or URL.
    :param one: An instance of ONE
    :return: A Bunch of video meta data
    """
    is_url = isinstance(video_path, str) and video_path.startswith('http')
    cap = VideoStreamer(video_path).cap if is_url else cv2.VideoCapture(str(video_path))
    assert cap.isOpened(), f'Failed to open video file {video_path}'

    # Get basic properties of video
    meta = Bunch()
    meta.length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    meta.fps = int(cap.get(cv2.CAP_PROP_FPS))
    meta.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    meta.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    meta.duration = timedelta(seconds=meta.length / meta.fps) if meta.fps > 0 else 0
    if is_url and one:
        eid = one.path2eid(video_path)
        datasets = one.list_datasets(eid, details=True)
        label = label_from_path(video_path)
        record = datasets[datasets['rel_path'].str.contains(f'_iblrig_{label}Camera.raw')]
        assert len(record) == 1
        meta.size = record['file_size'].iloc[0]
    elif is_url and not one:
        meta.size = None
    else:
        meta.size = Path(video_path).stat().st_size
    cap.release()
    return meta
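A brief sketch of calling get_video_meta above on a local file and on a remote URL; the path and URL are placeholders, and the ONE instance is only needed so the remote file size can be looked up.

from one.api import ONE

# Local file: the size is read from the filesystem, no ONE instance needed.
meta = get_video_meta('/data/session/raw_video_data/_iblrig_leftCamera.raw.mp4')
print(meta.length, meta.fps, meta.width, meta.height, meta.duration, meta.size)

# Remote file: pass an ONE instance so the dataset size can be fetched from Alyx.
one = ONE()
meta = get_video_meta('https://example.org/.../_iblrig_leftCamera.raw.mp4', one=one)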
Example #5
    def __init__(self, img, x=None, y=None, cmap=None):
        """
        Class for organising data that will be used to create 2D image plots

        :param img: 2D image data
        :param x: x coordinate of each image voxel in x dimension
        :param y: y coordinate of each image voxel in y dimension
        :param cmap: name of colormap to use
        """

        data = Bunch({
            'x': self._set_default(x, np.arange(img.shape[0])),
            'y': self._set_default(y, np.arange(img.shape[1])),
            'c': img
        })

        # Make sure dimensions agree
        assert data['c'].shape[0] == data['x'].shape[0], 'dimensions must agree'
        assert data['c'].shape[1] == data['y'].shape[0], 'dimensions must agree'

        # Initialise default plot class with data
        super().__init__('image', data)
        self.scale = None
        self.offset = None
        self.cmap = self._set_default(cmap, 'viridis')

        self.set_xlim()
        self.set_ylim()
        self.set_clim()
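A usage sketch for this constructor, assuming it belongs to an ImagePlot-style class importable as brainbox.plot_base.ImagePlot (class name and import path are assumptions).

import numpy as np
from brainbox.plot_base import ImagePlot  # assumed location of the class

img = np.random.rand(40, 60)              # 40 bins along x by 60 bins along y
x = np.linspace(0, 2.0, 40)               # e.g. a time axis (s)
y = np.linspace(0, 3840, 60)              # e.g. a depth axis (um)

im = ImagePlot(img, x=x, y=y, cmap='plasma')
im.set_labels(title='toy image', xlabel='time (s)', ylabel='depth (um)', clabel='value')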
Example #6
    def combine_behaviour_data(self, wheel, dlc_left, dlc_right):
        behav = Bunch()
        behav_options = []
        aligned = [True]

        if all([val in wheel.keys() for val in ['timestamps', 'position']]):
            behav['wheel'] = wheel
            behav_options = behav_options + ['wheel']

        if all([val in dlc_left.keys() for val in ['times', 'dlc']]):
            dlc = get_dlc_everything(dlc_left, 'left')
            behav['leftCamera'] = dlc
            keys = [f'leftCamera_{key}' for key in dlc.dlc.keys() if 'speed' in key]
            behav_options = behav_options + keys
            keys = [f'leftCamera_{key}' for key in dlc.keys() if
                    any([key in name for name in ['licks', 'sniffs']])]
            behav_options = behav_options + keys
            aligned.append(dlc['aligned'])

        if all([val in dlc_right.keys() for val in ['times', 'dlc']]):
            dlc = get_dlc_everything(dlc_right, 'right')
            behav['rightCamera'] = dlc
            keys = [f'rightCamera_{key}' for key in dlc.dlc.keys() if 'speed' in key]
            behav_options = behav_options + keys
            keys = [f'rightCamera_{key}' for key in dlc.keys() if
                    any([key in name for name in ['licks', 'sniffs']])]
            behav_options = behav_options + keys
            aligned.append(dlc['aligned'])

        aligned = all(aligned)

        return behav, behav_options, aligned
Example #7
    def __init__(self, img, x, y, cmap=None):
        """
        Class for organising data that will be used to create 2D probe plots. Use the function
        plot_base.arrange_channels2banks to prepare the data in the correct format before using this class

        :param img: list of image data for each bank of the probe
        :param x: list of x coordinates for each bank of the probe
        :param y: list of y coordinates for each bank of the probe
        :param cmap: name of cmap
        """

        # Make sure we have inputs as lists, can get input from arrange_channels2banks
        assert (type(img) == list)
        assert (type(x) == list)
        assert (type(y) == list)

        data = Bunch({'x': x, 'y': y, 'c': img})
        super().__init__('probe', data)
        self.cmap = self._set_default(cmap, 'viridis')

        self.set_xlim()
        self.set_ylim()
        self.set_clim()
        self.set_scale()
        self.set_offset()
Example #8
    def get_qc_info(self, eid, one):
        data = Bunch()
        sess = one.alyx.rest('sessions', 'read', id=eid)['extended_qc']
        for qc in SESS_QC:
            data[qc] = sess.get(qc, '')

        return data
Example #9
 def __init__(self,
              eid=None,
              one=None,
              log=logging.getLogger('ibllib'),
              **kwargs):
     self.one = one or ONE()
     self.eid = eid
     self.session_path = kwargs.pop('session_path',
                                    None) or self.one.eid2path(eid)
     self.ref = self.one.dict2ref(self.one.path2ref(self.session_path))
     self.log = log
     self.trials = self.wheel = self.camera_times = None
     raw_cam_path = self.session_path.joinpath('raw_video_data')
     camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*'))
     self.video_paths = {vidio.label_from_path(x): x for x in camera_path}
     self.data = Bunch()
     self.alignment = Bunch()
Example #10
    def get_template_for_selection(self):
        data = Bunch()
        template = (self.clusters.waveforms[self.clust_ids[self.clust], :, 0]) * 1e6
        t_template = 1e3 * np.arange(template.shape[0]) / FS

        data['vals'] = template
        data['time'] = t_template

        return data
Example #11
    def _get_raster(self, raster, trials_id):
        data = Bunch()
        data['raster'] = np.r_[raster.vals[trials_id],
                               np.full((self.n_trials - trials_id.shape[0],
                                        raster.vals.shape[1]), np.nan)]
        data['time'] = raster.time
        data['cmap'] = raster.cmap
        data['clevels'] = raster.clevels

        return data
Example #12
    def sort_clusters(self, sort):
        clust_sort = Bunch()
        self.clust_ids = self.clusters.clust_ids[self.clusters[sort]]
        clust_sort['ids'] = self.clust_ids
        clust_sort['amps'] = self.clusters.amps[self.clusters[sort]] * 1e6
        clust_sort['depths'] = self.clusters.depths[self.clusters[sort]]
        clust_sort['colours_ks'] = self.clusters.colours_ks[self.clusters[sort]]
        clust_sort['colours_ibl'] = self.clusters.colours_ibl[self.clusters[sort]]

        return clust_sort
Example #13
def get_sync_fronts(sync, channel_nb, tmin=None, tmax=None):
    selection = sync['channels'] == channel_nb
    selection = np.logical_and(selection,
                               sync['times'] <= tmax) if tmax else selection
    selection = np.logical_and(selection,
                               sync['times'] >= tmin) if tmin else selection
    return Bunch({
        'times': sync['times'][selection],
        'polarities': sync['polarities'][selection]
    })
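A quick self-contained sketch of get_sync_fronts above on a toy sync dictionary shaped like the output of _sync_to_alf (Example #1); the Bunch import is an assumption about where the class lives.

import numpy as np
from iblutil.util import Bunch  # assumed source of the Bunch class

sync = Bunch({'times': np.arange(10.),
              'channels': np.tile([0, 3], 5),
              'polarities': np.tile([1, -1], 5)})

fronts = get_sync_fronts(sync, 3, tmin=2, tmax=8)
print(fronts.times)       # -> [3. 5. 7.]
print(fronts.polarities)  # -> [-1 -1 -1]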
Example #14
    def get_cluster_metrics(self):
        data = Bunch()
        # Only if we have the metrics
        if any(self.clusters.get('metrics', [None])):
            for qc in CLUSTER_QC:
                data[qc] = round((self.clusters.metrics[qc][self.clust_ids[self.clust]]), 3)
        else:
            for qc in CLUSTER_QC:
                data[qc] = None

        return data
Example #15
 def test_get_active_wheel_period(self):
     """Check that warning is raised, period is returned None, and QC is NOT_SET
      if there is active wheel period to be found"""
     wheel_keys = ('timestamps', 'position')
     wheel_data = (np.arange(1000), np.ones(1000))
     self.qc.data['wheel'] = Bunch(zip(wheel_keys, wheel_data))
     with self.assertLogs(logging.getLogger('ibllib'), logging.WARNING):
         period = self.qc.get_active_wheel_period(self.qc.data['wheel'])
     self.assertEqual(None, period)
     outcome = self.qc.check_wheel_alignment()
     self.assertEqual('NOT_SET', outcome)
Example #16
    def get_autocorr_for_selection(self):
        data = Bunch()
        x_corr = xcorr(self.spikes.times[self.spikes.clusters == self.clust_ids[self.clust]],
                       self.spikes.clusters[self.spikes.clusters == self.clust_ids[self.clust]],
                       AUTOCORR_BIN, AUTOCORR_WINDOW)
        t_corr = np.arange(0, AUTOCORR_WINDOW + AUTOCORR_BIN, AUTOCORR_BIN) - AUTOCORR_WINDOW / 2

        data['vals'] = x_corr[0, 0, :]
        data['time'] = t_corr

        return data
Example #17
 def get(self, ids) -> Bunch:
     """
     Return a Bunch of every field on this object (id, name, ...) sliced to match the
     requested ids; order and duplicates in `ids` are preserved.
     """
     uid, uind = np.unique(ids, return_inverse=True)
     a, iself, _ = np.intersect1d(self.id,
                                  uid,
                                  assume_unique=False,
                                  return_indices=True)
     b = Bunch()
     for k in self.__dataclass_fields__.keys():
         b[k] = self.__getattribute__(k)[iself[uind]]
     return b
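The indexing in get above (np.unique with return_inverse, then np.intersect1d with return_indices) maps an arbitrary, possibly repeated list of ids onto rows of the object's arrays while preserving order and duplicates. A standalone sketch of that remapping with made-up values:

import numpy as np

table_id = np.array([10, 20, 30, 40])        # ids stored on the object (self.id)
table_name = np.array(['a', 'b', 'c', 'd'])  # one of the per-id fields
ids = np.array([30, 10, 30])                 # requested ids, repeated and out of order

uid, uind = np.unique(ids, return_inverse=True)                   # uid=[10, 30], uind=[1, 0, 1]
_, iself, _ = np.intersect1d(table_id, uid, return_indices=True)  # rows [0, 2] of the table
print(table_name[iself[uind]])                                    # -> ['c' 'a' 'c']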
Example #18
 def test_check_timestamps(self):
     FPS = 60.
     n = 1000
     self.qc.data['video'] = Bunch({'fps': FPS, 'length': n})
     self.qc.data['timestamps'] = np.array([round(1 / FPS, 4)] * n).cumsum()
     # Verify passes
     self.assertEqual('PASS', self.qc.check_timestamps())
     # Verify fails
     self.qc.data['timestamps'] = np.array([round(1 / 30, 4)] *
                                           100).cumsum()
     self.assertEqual('FAIL', self.qc.check_timestamps())
     # Verify not set
     self.qc.data['video'] = None
     self.assertEqual('NOT_SET', self.qc.check_timestamps())
Example #19
 def add_lines(self,
               pos,
               orientation,
               lim=None,
               style='--',
               width=3,
               color='k'):
     """
     Method to specify position and style of horizontal or vertical reference lines
     :param pos: position of line
     :param orientation: either 'v' for vertical line or 'h' for horizontal line
     :param lim: extent of lines
     :param style: line style
     :param width: line width
     :param color: line colour
     :return:
     """
     if orientation == 'v':
         lim = self._set_default(lim, self.ylim)
         self.vlines.append(
             Bunch({
                 'pos': pos,
                 'lim': lim,
                 'style': style,
                 'width': width,
                 'color': color
             }))
     if orientation == 'h':
         lim = self._set_default(lim, self.xlim)
         self.hlines.append(
             Bunch({
                 'pos': pos,
                 'lim': lim,
                 'style': style,
                 'width': width,
                 'color': color
             }))
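A short sketch of add_lines on an assumed plot object; the ScatterPlot import mirrors the earlier sketches and is an assumption, as is the expectation that the base class initialises empty vlines/hlines lists.

import numpy as np
from brainbox.plot_base import ScatterPlot  # assumed location of a class exposing add_lines

scat = ScatterPlot(np.random.rand(100), np.random.rand(100))
scat.add_lines(0.5, 'v', style=':', width=1, color='r')  # vertical line at x = 0.5
scat.add_lines(0.25, 'h', lim=[0, 0.5], color='gray')    # horizontal line spanning x in [0, 0.5]
print(scat.vlines[-1]['pos'], scat.hlines[-1]['lim'])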
Example #20
    def _get_psth(self, raster, trials_id, tbin=1):
        data = Bunch()
        if len(trials_id) == 0:
            data['psth_mean'] = np.zeros((raster.vals.shape[1]))
            data['psth_std'] = np.zeros((raster.vals.shape[1]))
            data['time'] = raster.time
            data['ylabel'] = raster.ylabel
        else:

            data['psth_mean'] = np.nanmean(raster.vals[trials_id], axis=0) / tbin
            data['psth_std'] = ((np.nanstd(raster.vals[trials_id], axis=0) / tbin) /
                                np.sqrt(trials_id.shape[0]))
            data['time'] = raster.time
            data['ylabel'] = raster.ylabel

        return data
Example #21
def concatenate_trials(trials):
    """
    Concatenate trials from different training sessions

    :param trials: dict containing trials objects from three consecutive training sessions,
    keys are session dates
    :type trials: Bunch
    :return: trials object with data concatenated over three training sessions
    :rtype: dict
    """
    trials_all = Bunch()
    for k in TRIALS_KEYS:
        trials_all[k] = np.concatenate(
            list(trials[kk][k] for kk in trials.keys()))

    return trials_all
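A toy sketch of concatenate_trials; the two-key TRIALS_KEYS below is a stand-in for the module constant, and the function above is assumed to be defined in the same namespace so that it picks the stand-in up.

import numpy as np
from iblutil.util import Bunch  # assumed source of the Bunch class

TRIALS_KEYS = ['contrastLeft', 'contrastRight']  # stand-in for the real module constant
trials = {
    '2021-01-01': Bunch({'contrastLeft': np.array([1., 0.]), 'contrastRight': np.array([0., 1.])}),
    '2021-01-02': Bunch({'contrastLeft': np.array([.25]), 'contrastRight': np.array([0.])}),
    '2021-01-03': Bunch({'contrastLeft': np.array([.5]), 'contrastRight': np.array([1.])}),
}
trials_all = concatenate_trials(trials)
print(trials_all['contrastLeft'])  # -> [1. 0. 0.25 0.5]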
Example #22
    def __init__(self, session_path_or_eid, side, **kwargs):
        """
        :param session_path_or_eid: A session eid or path
        :param side: The camera to run QC on
        :param log: A logging.Logger instance, if None the 'ibllib' logger is used
        :param one: An ONE instance for fetching and setting the QC on Alyx
        """
        # Make sure the type of camera is chosen
        self.side = side
        # When an eid is provided, we will download the required data by default (if necessary)
        download_data = not is_session_path(session_path_or_eid)
        self.download_data = kwargs.pop('download_data', download_data)
        super().__init__(session_path_or_eid, **kwargs)
        self.data = Bunch()

        # QC outcomes map
        self.metrics = None
Example #23
 def get_sync_fronts(auxiliary_name):
     d = Bunch({'times': [], 'nsync': np.zeros(nprobes, )})
     # auxiliary_name: frame2ttl or right_camera
     for ind, ephys_file in enumerate(ephys_files):
         sync = alfio.load_object(ephys_file.ap.parent,
                                  'sync',
                                  namespace='spikeglx',
                                  short_keys=True)
         sync_map = get_ibl_sync_map(ephys_file, '3A')
         # return None if the sync label is not found for the current probe
         if auxiliary_name not in sync_map:
             return
         isync = np.in1d(sync['channels'],
                         np.array([sync_map[auxiliary_name]]))
         # only returns syncs if we get fronts for all probes
         if np.all(~isync):
             return
         d.nsync[ind] = len(sync.channels)
         d['times'].append(sync['times'][isync])
     return d
Example #24
    def test_check_camera_times(self):
        outcome = self.qc.check_camera_times()
        self.assertEqual('NOT_SET', outcome)

        # Verify passes
        self.qc.label = 'body'
        ts_path = Path(__file__).parents[1].joinpath('extractors', 'data',
                                                     'session_ephys')
        ssv_times = load_camera_ssv_times(ts_path, self.qc.label)
        self.qc.data.bonsai_times, self.qc.data.camera_times = ssv_times
        self.qc.data.video = Bunch({'length': self.qc.data.bonsai_times.size})

        outcome, _ = self.qc.check_camera_times()
        self.assertEqual('PASS', outcome)

        # Verify warning
        n_over = 14
        self.qc.data.video['length'] -= n_over
        outcome, actual = self.qc.check_camera_times()

        self.assertEqual('WARNING', outcome)
        self.assertEqual(n_over, actual)
Example #25
    def set_labels(self,
                   title=None,
                   xlabel=None,
                   ylabel=None,
                   zlabel=None,
                   clabel=None):
        """
        Set labels for plot

        :param title: title
        :param xlabel: x axis label
        :param ylabel: y axis label
        :param zlabel: z axis label
        :param clabel: cbar label
        :return:
        """
        self.labels = Bunch({
            'title': title,
            'xlabel': xlabel,
            'ylabel': ylabel,
            'zlabel': zlabel,
            'clabel': clabel
        })
Example #26
    def _get_spike_raster(self, trial_ids):
        # Not the most efficient approach, but it will do for now.

        data = Bunch()

        x = np.empty(0)
        y = np.empty(0)

        epoch = [0.4, 1]
        spk_times = self.spikes.times[self.spikes.clusters == self.clust_ids[self.clust]]

        for idx, val in enumerate(self.trials[self.trial_event][trial_ids]):
            spks_to_include = np.bitwise_and(spk_times >= val - epoch[0],
                                             spk_times <= val + epoch[1])
            trial_spk_times = spk_times[spks_to_include] - val
            x = np.append(x, trial_spk_times)
            y = np.append(y, np.ones(len(trial_spk_times)) * idx)

        data['raster'] = np.c_[x, y]
        data['time'] = EPOCH
        data['n_trials'] = self.n_trials

        return data
Example #27
subjects = [
    'dop_24', 'dop_14', 'dop_13', 'dop_16', 'dop_21', 'dop_22', 'dop_36'
]

fig = rendering.figure()

for subject in subjects:
    ba = atlas.AllenAtlas(25)
    channels_rest = one.alyx.rest('channels', 'list', subject=subject)
    channels = Bunch({
        'atlas_id': np.array([ch['brain_region'] for ch in channels_rest]),
        'xyz': np.c_[np.array([ch['x'] for ch in channels_rest]),
                     np.array([ch['y'] for ch in channels_rest]),
                     np.array([ch['z'] for ch in channels_rest])] / 1e6,
        'axial_um': np.array([ch['axial'] for ch in channels_rest]),
        'lateral_um': np.array([ch['lateral'] for ch in channels_rest]),
        'trajectory_id': np.array([ch['trajectory_estimate'] for ch in channels_rest])
    })

    for m, probe_id in enumerate(np.unique(channels['trajectory_id'])):
        traj_dict = one.alyx.rest('trajectories', 'read', id=probe_id)
        ses = traj_dict['session']
        label = (f"{ses['subject']}/{ses['start_time'][:10]}/"
                 f"{str(ses['number']).zfill(3)}/{traj_dict['probe_name']}")
        print(label)

        color = ibllib.plots.color_cycle(m)
Example #28
 def setUp(self):
     eid = '8dd0fcb0-1151-4c97-ae35-2e2421695ad7'
     one = ONE(**TEST_DB)
     self.qc = qcmetrics.HabituationQC(eid, one=one)
     self.qc.extractor = Bunch({'data': self.load_fake_bpod_data()})  # Dummy extractor obj
Example #29
eids = [
    # ... (session eids omitted)
]
probes = [
    'probe01', 'probe01', 'probe00', 'probe01', 'probe00', 'probe01', 'probe00'
]

# fetching data part
brain_atlas = atlas.AllenAtlas(25)
file_pickle = Path(cache_dir).joinpath('repeated_sites_channels.pkl')
if file_pickle.exists() or False:
    ins = pickle.load(open(file_pickle, 'rb'))
else:
    one = ONE()
    ins = Bunch({
        'eid': eids,
        'probe_label': probes,
        'insertion': [],
        'channels': [],
        'session': []
    })
    for eid, probe_label in zip(ins.eid, ins.probe_label):
        traj = one.alyx.rest('trajectories',
                             'list',
                             session=eid,
                             provenance='Histology track',
                             probe=probe_label)[0]
        ses = one.alyx.rest('sessions', 'read', id=eid)
        channels = bbone.load_channel_locations(eid=ses,
                                                one=one,
                                                probe=probe_label)[probe_label]
        insertion = atlas.Insertion.from_dict(traj)
        ins.insertion.append(insertion)
Example #30
def quick_unit_metrics(spike_clusters,
                       spike_times,
                       spike_amps,
                       spike_depths,
                       params=METRICS_PARAMS,
                       cluster_ids=None,
                       tbounds=None):
    """
    Computes single unit metrics from only the spike times, amplitudes, and
    depths for a set of units.

    Metrics computed:
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'spike_count'

    Parameters (see the METRICS_PARAMS constant)
    ----------
    spike_clusters : ndarray_like
        A vector of the unit ids for a set of spikes.
    spike_times : ndarray_like
        A vector of the timestamps for a set of spikes.
    spike_amps : ndarray_like
        A vector of the amplitudes for a set of spikes.
    spike_depths : ndarray_like
        A vector of the depths for a set of spikes.
    cluster_ids : ndarray_like (optional)
        A list of cluster ids. If not all clusters are represented in spike_clusters (i.e. a
        cluster has no spikes), this ensures the output size is consistent with the input arrays.
    tbounds : list or 2-element array (optional)
        A time selection (in seconds) restricting the metrics computation to that window.
    params : dict (optional)
        Parameters used for computing some of the metrics in the function:
            'presence_window': float
                The time window (in s) used to look for spikes when computing the presence ratio.
            'refractory_period': float
                The refractory period used when computing isi violations and the contamination
                estimate.
            'min_isi': float
                The minimum interspike-interval (in s) for counting duplicate spikes when computing
                the contamination estimate.
            'spks_per_bin_for_missed_spks_est': int
                The number of spikes per bin used to compute the spike amplitude pdf for a unit,
                when computing the missed spikes estimate.
            'std_smoothing_kernel_for_missed_spks_est': float
                The standard deviation for the gaussian kernel used to compute the spike amplitude
                pdf for a unit, when computing the missed spikes estimate.
            'min_num_bins_for_missed_spks_est': int
                The minimum number of bins used to compute the spike amplitude pdf for a unit,
                when computing the missed spikes estimate.

    Returns
    -------
    r : bunch
        A bunch whose keys are the computed spike metrics.

    Notes
    -----
    This function is called by `ephysqc.unit_metrics_ks2` which is called by `spikes.ks2_to_alf`
    during alf extraction of an ephys dataset in the ibl ephys extraction pipeline.

    Examples
    --------
    1) Compute quick metrics from a ks2 output directory:
        >>> from ibllib.ephys.ephysqc import phy_model_from_ks2_path
        >>> m = phy_model_from_ks2_path(path_to_ks2_out)
        >>> spike_clusters = m.spike_clusters
        >>> ts = m.spike_times
        >>> amps = m.amplitudes
        >>> depths = m.depths
        >>> r = bb.metrics.quick_unit_metrics(spike_clusters, ts, amps, depths)
    """
    metrics_list = [
        'cluster_id', 'amp_max', 'amp_min', 'amp_median', 'amp_std_dB',
        'contamination', 'contamination_alt', 'drift', 'missed_spikes_est',
        'noise_cutoff', 'presence_ratio', 'presence_ratio_std',
        'slidingRP_viol', 'spike_count'
    ]
    if tbounds:
        ispi = between_sorted(spike_times, tbounds)
        spike_times = spike_times[ispi]
        spike_clusters = spike_clusters[ispi]
        spike_amps = spike_amps[ispi]
        spike_depths = spike_depths[ispi]

    if cluster_ids is None:
        cluster_ids = np.unique(spike_clusters)
    nclust = cluster_ids.size

    r = Bunch({k: np.full((nclust, ), np.nan) for k in metrics_list})
    r['cluster_id'] = cluster_ids

    # vectorized computation of basic metrics such as presence ratio and firing rate
    tmin = spike_times[0]
    tmax = spike_times[-1]
    presence_ratio = bincount2D(spike_times,
                                spike_clusters,
                                xbin=params['presence_window'],
                                ybin=cluster_ids,
                                xlim=[tmin, tmax])[0]
    r.presence_ratio = np.sum(presence_ratio > 0,
                              axis=1) / presence_ratio.shape[1]
    r.presence_ratio_std = np.std(presence_ratio, axis=1)
    r.spike_count = np.sum(presence_ratio, axis=1)
    r.firing_rate = r.spike_count / (tmax - tmin)

    # computing amplitude statistical indicators by aggregating over cluster id
    camp = pd.DataFrame(np.c_[spike_amps, 20 * np.log10(spike_amps),
                              spike_clusters],
                        columns=['amps', 'log_amps', 'clusters'])
    camp = camp.groupby('clusters')
    ir, ib = ismember(r.cluster_id, camp.clusters.unique())
    r.amp_min[ir] = np.array(camp['amps'].min())
    r.amp_max[ir] = np.array(camp['amps'].max())
    # median amplitude computed in log (dB) space
    r.amp_median[ir] = np.array(10**(camp['log_amps'].median() / 20))
    r.amp_std_dB[ir] = np.array(camp['log_amps'].std())

    # loop over each cluster to compute the rest of the metrics
    for ic in np.arange(nclust):
        # slice the spike_times array
        ispikes = spike_clusters == cluster_ids[ic]
        if np.all(~ispikes):  # if this cluster has no spikes, continue
            continue
        ts = spike_times[ispikes]
        amps = spike_amps[ispikes]
        depths = spike_depths[ispikes]

        # compute metrics
        r.contamination_alt[ic] = contamination_alt(
            ts, rp=params['refractory_period'])
        r.contamination[ic], _ = contamination(ts,
                                               tmin,
                                               tmax,
                                               rp=params['refractory_period'],
                                               min_isi=params['min_isi'])
        r.slidingRP_viol[ic] = slidingRP_viol(
            ts,
            bin_size=params['bin_size'],
            thresh=params['RPslide_thresh'],
            acceptThresh=params['acceptable_contamination'])
        r.noise_cutoff[ic] = noise_cutoff(
            amps,
            quartile_length=params['nc_quartile_length'],
            n_bins=params['nc_bins'],
            n_low_bins=params['nc_n_low_bins'])
        r.missed_spikes_est[ic], _, _ = missed_spikes_est(
            amps,
            spks_per_bin=params['spks_per_bin_for_missed_spks_est'],
            sigma=params['std_smoothing_kernel_for_missed_spks_est'],
            min_num_bins=params['min_num_bins_for_missed_spks_est'])

        # wonder if there is a need to low-cut this
        r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600

    r.label = compute_labels(r)
    return r
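A minimal synthetic-data sketch of calling quick_unit_metrics directly, complementing the ks2 example in the docstring; the toy spike arrays are made up purely to exercise the function and carry no physiological meaning.

import numpy as np

rng = np.random.default_rng(0)
n_spikes, n_units = 20_000, 5
spike_clusters = rng.integers(0, n_units, n_spikes)
spike_times = np.sort(rng.uniform(0, 3600, n_spikes))      # one hour of recording (s)
spike_amps = rng.lognormal(np.log(100e-6), 0.3, n_spikes)  # amplitudes (V)
spike_depths = rng.uniform(0, 3840, n_spikes)              # depths along the probe (um)

r = quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths)
print(r.cluster_id, r.spike_count, r.presence_ratio)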