Example #1
    def compute_waveform(self, clust):
        # Only possible when a raw ephys binary file is available.
        if len(self.ephys_file_path) != 0:
            # Spike times for the selected cluster, restricted to the current
            # waveform interval.
            spk_times = self.spikes.times[self.clus_idx][self.spk_intervals[
                self.n_waveform]:self.spk_intervals[self.n_waveform + 1]]
            # Channel of max amplitude for this cluster.
            max_ch = self.clusters['channels'][self.clust_ids[clust]]
            wf = extract_waveforms(self.ephys_file_path,
                                   spk_times,
                                   max_ch,
                                   t=self.waveform_window,
                                   car=self.CAR)
            # Mean and standard deviation across spikes on the max channel.
            wf_mean = np.mean(wf[:, :, 0], axis=0)
            wf_std = np.std(wf[:, :, 0], axis=0)

            return self.t_waveform, wf_mean, wf_std
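
The method returns a time axis plus the per-sample mean and standard deviation across spikes. A minimal, self-contained sketch (synthetic data; `t`, `wf_mean`, and `wf_std` below are toy stand-ins for the return values, not from the source) of how they might be plotted as a mean waveform with a +/- 1 SD band:

import matplotlib.pyplot as plt
import numpy as np

# Toy stand-ins for the (t_waveform, wf_mean, wf_std) return values.
t = np.linspace(0, 2.0, 60)                      # hypothetical time axis (ms)
wf_mean = -np.exp(-((t - 1.0) ** 2) / 0.02)      # toy mean spike waveform
wf_std = np.full_like(wf_mean, 0.1)              # toy per-sample SD

fig, ax = plt.subplots()
ax.plot(t, wf_mean, c='k')
ax.fill_between(t, wf_mean - wf_std, wf_mean + wf_std, alpha=0.3)
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Amplitude (a.u.)')
plt.show()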
Example #2
import time

import numpy as np
from scipy import stats

# `extract_waveforms` and `spikeglx` are provided by the surrounding library
# (exact module paths are not shown in this snippet).


def ptp_over_noise(ephys_file,
                   ts,
                   ch,
                   t=2.0,
                   sr=30000,
                   n_ch_probe=385,
                   car=True):
    """
    For specified channels and timestamps, computes the mean peak-to-peak amplitude of the spike
    waveforms divided by the MAD of the background noise, for each channel.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    ch : ndarray_like
        The channels on which to extract the waveforms.
    t : numeric (optional)
        The duration (in ms) of the waveforms to extract for computing the ptp.
    sr : int (optional)
        The sampling rate (in Hz) at which the ephys data was acquired.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    car : bool (optional)
        A flag to perform common-average-referencing before extracting waveforms.

    Returns
    -------
    ptp_sigma : ndarray
        An array containing the mean ptp_over_noise values for the specified `ts` and `ch`.

    Examples
    --------
    1) Compute ptp_over_noise for all spikes on 20 channels around the channel of max amplitude
    for unit 1.
        >>> ts = units_b['times']['1']
        >>> max_ch = clstrs_b['channels'][1]
        >>> if max_ch < 10:  # take only channels greater than `max_ch`.
        >>>     ch = np.arange(max_ch, max_ch + 20)
        >>> elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
        >>>     ch = np.arange(max_ch - 20, max_ch)
        >>> else:  # take 10 channels on either side of `max_ch`.
        >>>     ch = np.arange(max_ch - 10, max_ch + 10)
        >>> p = bb.metrics.ptp_over_noise(ephys_file, ts, ch)
    """

    # Ensure `ch` is ndarray
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch

    # Get waveforms.
    wf = extract_waveforms(ephys_file,
                           ts,
                           ch,
                           t=t,
                           sr=sr,
                           n_ch_probe=n_ch_probe,
                           car=car)

    # Initialize `mean_ptp` based on `ch`, and compute mean ptp of all spikes for each ch.
    mean_ptp = np.zeros((ch.size, ))
    for cur_ch in range(ch.size):
        mean_ptp[cur_ch] = np.mean(
            np.max(wf[:, :, cur_ch], axis=1) -
            np.min(wf[:, :, cur_ch], axis=1))

    # Compute MAD for `ch` in chunks.
    with spikeglx.Reader(ephys_file) as s_reader:
        file_m = s_reader.data  # the memmapped array
        n_chunk_samples = int(5e6)  # number of samples per chunk
        n_chunks = np.ceil(file_m.shape[0] / n_chunk_samples).astype('int')
        # Get samples that make up each chunk, e.g. the samples from `chunk_sample[0]` to
        # `chunk_sample[1]` make up the first chunk.
        chunk_sample = np.arange(0,
                                 file_m.shape[0],
                                 n_chunk_samples,
                                 dtype=int)
        chunk_sample = np.append(chunk_sample, file_m.shape[0])
        # Give time estimate for computing MAD.
        t0 = time.perf_counter()
        # NB: newer SciPy releases rename this function to `stats.median_abs_deviation`.
        stats.median_absolute_deviation(file_m[chunk_sample[0]:chunk_sample[1],
                                               ch],
                                        axis=0)
        dt = time.perf_counter() - t0
        print('Performing MAD computation. Estimated time is {:.2f} mins.'
              ' ({})'.format(dt * n_chunks / 60, time.ctime()))
        # Compute MAD for each chunk, then take the median MAD of all chunks.
        mad_chunks = np.zeros((n_chunks, ch.size))  # float: MAD values need not be integral
        for chunk in range(n_chunks):
            mad_chunks[chunk, :] = stats.median_absolute_deviation(
                file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch],
                axis=0,
                scale=1)
    print('Done. ({})'.format(time.ctime()))

    # Return `mean_ptp` over `mad`
    mad = np.median(mad_chunks, axis=0)
    ptp_sigma = mean_ptp / mad
    return ptp_sigma
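
As a quick illustration of what `ptp_over_noise` measures, here is a self-contained sketch on synthetic data (assumed, not from the source; `stats.median_abs_deviation` is the newer SciPy name for the MAD function used above):

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
raw = rng.normal(0, 10, size=100_000)        # toy raw trace for one channel
wf = rng.normal(0, 10, size=(50, 60))        # 50 toy waveforms, 60 samples each
wf[:, 30] += 80                              # inject a spike peak

# Mean peak-to-peak amplitude over spikes, divided by the MAD of the raw trace
# (scale=1, matching the function above).
mean_ptp = np.mean(wf.max(axis=1) - wf.min(axis=1))
mad = stats.median_abs_deviation(raw, scale=1)
print(mean_ptp / mad)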
Example #3
import matplotlib.pyplot as plt
import numpy as np

# `extract_waveforms` and `single_units` are provided by the surrounding library
# (exact module paths are not shown in this snippet).


def wf_comp(ephys_file,
            ts1,
            ts2,
            ch,
            sr=30000,
            n_ch_probe=385,
            dtype='int16',
            car=True,
            col=['b', 'r'],
            ax=None):
    '''
    Plots two different sets of waveforms across specified channels after (optionally)
    common-average-referencing. In this way, waveforms can be compared to see if there is,
    e.g. drift during the recording, or if two units should be merged, or one unit should be split.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts1 : array_like
        A set of timestamps for which to compare waveforms with `ts2`.
    ts2 : array_like
        A set of timestamps for which to compare waveforms with `ts1`.
    ch : array_like
        The channels to use for extracting and plotting the waveforms.
    sr : int (optional)
        The sampling rate (in Hz) at which the ephys data was acquired.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    dtype : str (optional)
        The datatype represented by the bytes in `ephys_file`.
    car : bool (optional)
        A flag for whether or not to perform common-average-referencing before extracting
        waveforms.
    col : list of strings or float arrays (optional)
        Two elements in the list, each specifying the color in which the `ts1` and `ts2`
        waveforms will be plotted, respectively.
    ax : axessubplot (optional)
        The axis handle to plot the waveforms on. (if `None`, a new figure and axis is created)

    Returns
    -------
    wf1 : ndarray
        The waveforms for the spikes in `ts1`: an array of shape (#spikes, #samples, #channels).
    wf2 : ndarray
        The waveforms for the spikes in `ts2`: an array of shape (#spikes, #samples, #channels).
    s : float
        The similarity score between the two sets of waveforms, calculated by
        `single_units.wf_similarity`

    See Also
    --------
    io.extract_waveforms
    single_units.wf_similarity

    Examples
    --------
    1) Compare the first and last 100 spike waveforms for unit 1, across 20 channels around the
    channel of max amplitude, and compare the waveforms in the first minute to those in the
    fourth minute for unit 2, across 10 channels around the channel of max amplitude.
        # Get first and last 100 spikes, and 20 channels around channel of max amp for unit 1:
        >>> ts1 = units_b['times']['1'][:100]
        >>> ts2 = units_b['times']['1'][-100:]
        >>> max_ch = clstrs_b['channels'][1]
        >>> if max_ch < 10:  # take only channels greater than `max_ch`.
        >>>     ch = np.arange(max_ch, max_ch + 20)
        >>> elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
        >>>     ch = np.arange(max_ch - 20, max_ch)
        >>> else:  # take 10 channels on either side of `max_ch`.
        >>>     ch = np.arange(max_ch - 10, max_ch + 10)
        >>> wf1, wf2, s = bb.plot.wf_comp(path_to_ephys_file, ts1, ts2, ch)
        # Plot waveforms for unit2 from the first and fourth minutes across 10 channels.
        >>> ts = units_b['times']['2']
        >>> ts1_2 = ts[np.where(ts < 60)[0]]
        >>> ts2_2 = ts[np.where(ts > 180)[0][:len(ts1_2)]]
        >>> max_ch = clstrs_b['channels'][2]
        >>> if max_ch < 5:  # take only channels greater than `max_ch`.
        >>>     ch = np.arange(max_ch, max_ch + 10)
        >>> elif (max_ch + 5) > 385:  # take only channels less than `max_ch`.
        >>>     ch = np.arange(max_ch - 10, max_ch)
        >>> else:  # take 5 channels on either side of `max_ch`.
        >>>     ch = np.arange(max_ch - 5, max_ch + 5)
        >>> wf1_2, wf2_2, s_2 = bb.plot.wf_comp(path_to_ephys_file, ts1_2, ts2_2, ch)
    '''

    # Ensure `ch` is ndarray
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch

    # Extract the waveforms for these timestamps and compute similarity score.
    wf1 = extract_waveforms(ephys_file,
                            ts1,
                            ch,
                            sr=sr,
                            n_ch_probe=n_ch_probe,
                            dtype=dtype,
                            car=car)
    wf2 = extract_waveforms(ephys_file,
                            ts2,
                            ch,
                            sr=sr,
                            n_ch_probe=n_ch_probe,
                            dtype=dtype,
                            car=car)
    s = single_units.wf_similarity(wf1, wf2)

    # Plot these waveforms against each other.
    n_ch = ch.size
    if ax is None:
        fig, ax = plt.subplots(
            nrows=n_ch,
            ncols=2)  # left col is all waveforms, right col is mean
    for cur_ax, cur_ch in enumerate(ch):
        ax[cur_ax][0].plot(wf1[:, :, cur_ax].T, c=col[0])
        ax[cur_ax][0].plot(wf2[:, :, cur_ax].T, c=col[1])
        ax[cur_ax][1].plot(np.mean(wf1[:, :, cur_ax], axis=0), c=col[0])
        ax[cur_ax][1].plot(np.mean(wf2[:, :, cur_ax], axis=0), c=col[1])
        ax[cur_ax][0].set_ylabel('Ch {0}'.format(cur_ch))
    ax[0][0].set_title('All Waveforms. S = {:.2f}'.format(s))
    ax[0][1].set_title('Mean Waveforms')
    # Attach the legend to the top mean-waveform axes (robust when `ax` is passed in).
    ax[0][1].legend(['1st spike set', '2nd spike set'])

    return wf1, wf2, s
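
The channel-window selection around `max_ch` recurs in the docstring examples above and in the body of `depth` in Example #4. A hypothetical helper (not part of the library; the name `ch_window` and its defaults are assumptions) that factors out that logic:

import numpy as np

def ch_window(max_ch, n_ch=20, n_ch_probe=385):
    """Return `n_ch` channel indices around `max_ch`, clipped to the probe."""
    half = n_ch // 2
    if max_ch < half:                    # near the bottom: take channels above
        return np.arange(max_ch, max_ch + n_ch)
    elif max_ch + half > n_ch_probe:     # near the top: take channels below
        return np.arange(max_ch - n_ch, max_ch)
    else:                                # otherwise: center the window
        return np.arange(max_ch - half, max_ch + half)

print(ch_window(3), ch_window(100), ch_window(383))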
Example #4
import numpy as np

# `extract_waveforms` is provided by the surrounding library (exact module path
# is not shown in this snippet).


def depth(ephys_file,
          spks_b,
          clstrs_b,
          chnls_b,
          tmplts_b,
          unit,
          n_ch=12,
          n_ch_probe=385,
          sr=30000,
          dtype='int16',
          car=False):
    '''
    Gets `n_ch` channels around a unit's channel of max amplitude, extracts all unit spike
    waveforms from binary datafile for these channels, and for each spike, computes the dot
    products of waveform by unit template for those channels, and computes center-of-mass of these
    dot products to get spike depth estimates.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    spks_b : bunch
        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
        etc.) for all spikes.
    clstrs_b : bunch
        A clusters bunch containing fields with cluster information (e.g. amp, ch of max amp, depth
        of ch of max amp, etc.) for all clusters.
    chnls_b : bunch
        A channels bunch containing fields with channel information (e.g. coordinates, indices,
        etc.) for all probe channels.
    tmplts_b : bunch
        A unit templates bunch containing fields with unit template information (e.g. template
        waveforms, etc.) for all unit templates.
    unit : numeric
        The unit for which to return the spikes depths.
    n_ch : int (optional)
        The number of channels to sample around the channel of max amplitude to compute the depths.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    sr : int (optional)
        The sampling rate (in Hz) at which the ephys data was acquired.
    dtype : str (optional)
        The datatype represented by the bytes in `ephys_file`.
    car : bool (optional)
        A flag to perform common-average-referencing before extracting waveforms.

    Returns
    -------
    d : ndarray
        The estimated spike depths for all spikes in `unit`.

    See Also
    --------
    io.extract_waveforms

    Examples
    --------
    1) Get the spike depths for unit 1.
        >>> import numpy as np
        >>> import brainbox as bb
        >>> import alf.io as aio
        >>> import ibllib.ephys.spikes as e_spks
        # Note: if there is no 'alf' directory, create it from the 'ks2' output directory:
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        # Get the necessary alf objects from an alf directory.
        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
        >>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
        >>> chnls_b = aio.load_object(path_to_alf_out, 'channels')
        >>> tmplts_b = aio.load_object(path_to_alf_out, 'templates')
        # Compute spike depths.
        >>> unit1_depths = bb.spike_features.depth(path_to_ephys_file, spks_b, clstrs_b, chnls_b,
                                                   tmplts_b, unit=1)
    '''

    # Set constants: #
    n_c_ch = n_ch // 2  # number of close channels to take on either side of max channel

    # Get unit waveforms: #
    # Get unit timestamps.
    unit_spk_indxs = np.where(spks_b['clusters'] == unit)[0]
    ts = spks_b['times'][unit_spk_indxs]
    # Get `n_ch` channels around the channel of max amplitude.
    max_ch = clstrs_b['channels'][unit]
    if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
        ch = np.arange(max_ch, max_ch + n_ch)
    elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
        ch = np.arange(max_ch - n_ch, max_ch)
    else:  # take `n_c_ch` around `max_ch`.
        ch = np.arange(max_ch - n_c_ch, max_ch + n_c_ch)
    # Get unit template across `ch` and extract waveforms from `ephys_file`.
    tmplt_wfs = tmplts_b['waveforms']
    unit_tmplt = tmplt_wfs[unit, :, ch].T
    wf_t = tmplt_wfs.shape[1] / (sr / 1000)  # duration (ms) of each waveform
    wf = extract_waveforms(ephys_file=ephys_file,
                           ts=ts,
                           ch=ch,
                           t=wf_t,
                           sr=sr,
                           n_ch_probe=n_ch_probe,
                           dtype=dtype,
                           car=car)

    # Compute center-of-mass: #
    ch_depths = chnls_b['localCoordinates'][ch, 1]  # depth (y) coordinate of each channel
    d = np.zeros_like(ts)  # depths array
    # Compute normalized dot product of (waveforms,unit_template) across `ch`,
    # and get center-of-mass, `c_o_m`, of these dot products (one dot for each ch)
    for spk in range(len(ts)):
        dot_wf_template = np.sum(wf[spk, :, :] * unit_tmplt, axis=0)
        dot_wf_template += np.abs(np.min(dot_wf_template))
        dot_wf_template /= np.max(dot_wf_template)
        c_o_m = (1 / np.sum(dot_wf_template)) * np.sum(
            dot_wf_template * ch_depths)
        d[spk] = c_o_m
    return d
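
A minimal numeric sketch (toy values, not from the source) of the center-of-mass step at the end of `depth`: the normalized per-channel dot products act as weights on the channel depths:

import numpy as np

ch_depths = np.array([20., 40., 60., 80.])        # toy channel depths (um)
dots = np.array([0.1, 0.9, 1.0, 0.2])             # toy normalized dot products
c_o_m = np.sum(dots * ch_depths) / np.sum(dots)   # weighted mean depth
print(c_o_m)                                      # ~51.8 um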