def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Estimate the depth drift of a recording over time from spike sorted data.

    Spikes are counted in an (amplitude, time, depth) histogram. The depth
    profile of every time bin is cross-correlated, through an FFT along the
    depth axis, with the median spectrum over time of the same amplitude
    bin. The correlations are summed over amplitude bins and the lag of the
    maximum, searched within +/- NXCORR depth bins, is the raw drift, which
    is finally smoothed with a rolling window.

    :param spike_times: spike times; binned with a step of DT_SECS
    :param spike_amps: spike amplitudes; binned with a step of AMP_RES_V
    :param spike_depths: spike depths; NaN entries are left out of the
     histogram
    :param display: if True, plot the drift above a driftmap of the spikes
    :return: smoothed drift, one value per time bin, in depth units
    """
    # binning parameters
    DT_SECS = 1  # time step of the drift estimate (seconds)
    DEPTH_BIN_UM = 2  # depth bin size
    AMP_RES_V = 100 * 1e-6  # amplitude bin size
    NXCORR = 50  # largest lag, in depth bins, searched on either side of zero
    NT_SMOOTH = 9  # length of the smoothing window in samples (DT_SECS rate)

    # histogram dimensions
    n_amp = int(np.ceil(np.nanmax(spike_amps) / AMP_RES_V))
    n_depth = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    n_time = int(np.ceil(np.max(spike_times) / DT_SECS))
    depth_lim = [0, n_depth * DEPTH_BIN_UM]
    time_lim = [0, n_time * DT_SECS]

    # one (time, depth) count image per amplitude level present in the data
    hist = np.zeros((n_amp, n_time, n_depth))
    amp_level = np.ceil(spike_amps / AMP_RES_V)
    has_depth = ~np.isnan(spike_depths)
    for row, level in enumerate(np.unique(amp_level)):
        sel = np.nonzero((amp_level == level) & has_depth)[0]
        counts = bincount2D(spike_depths[sel], spike_times[sel],
                            DEPTH_BIN_UM, DT_SECS, depth_lim, time_lim)[0]
        # the last row and column returned by bincount2D are dropped
        hist[row] = counts[:-1, :-1]

    # cross-correlation along depth against the median spectrum over time
    spectrum = np.fft.fft(hist, axis=-1)
    reference = np.conj(np.median(spectrum, axis=1))
    correlation = np.real(np.fft.ifft(spectrum * reference[:, np.newaxis, :]))
    correlation = correlation.sum(axis=0)
    # reorder the circular lags as -NXCORR ... 0 ... +NXCORR
    correlation = np.concatenate(
        (correlation[:, -NXCORR:], correlation[:, :NXCORR + 1]), axis=1)

    # lag of the correlation peak, converted to depth units, then smoothed
    peak_lag = np.argmax(correlation, axis=-1) - NXCORR
    raw_drift = peak_lag * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift,
                                  window_len=NT_SMOOTH,
                                  window='hanning')

    if display:
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        _, axs = plt.subplots(2, 1, sharex=True,
                              gridspec_kw={'height_ratios': [.15, .85]})
        ax_drift, ax_raster = axs
        ax_drift.plot(DT_SECS * np.arange(drift.size), drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=ax_raster)

    return drift
Example #2
0
        def plot_driftmap(self, spikes, clusters, channels, collection,
                          ylim=(0, 3840)):
            """
            Save a spike raster (driftmap) figure for one spike sorting run.

            The left panel holds the driftmap of the spikes. The narrow right
            panel shows the brain regions along the probe when
            `self.histology_status` is truthy; otherwise its outline is
            removed.

            :param spikes: object exposing `times`, `depths` and `clusters`
            :param clusters: object exposing `depths`
            :param channels: mapping with the 'atlas_id' and 'axial_um' keys
            :param collection: collection path, expected under
             alf/<self.pname>
            :param ylim: depth limits applied to both panels
            :return: (outfile, fig, axs); the figure is closed before returning
            """
            fig, axs = plt.subplots(1, 2, figsize=(16, 9),
                                    gridspec_kw={'width_ratios': [.95, .05]})
            raster_ax = axs[0]
            region_ax = axs[1]

            driftmap(spikes.times, spikes.depths,
                     t_bin=0.007, d_bin=10, vmax=0.5, ax=raster_ax)
            title_str = (
                f"{self.pid_label}, {collection}, {self.pid} \n "
                f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
            )
            raster_ax.set(ylim=ylim, title=title_str)

            # the output file is named after the collection path relative to
            # the probe folder; the probe folder itself maps to "ks2matlab"
            run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
            if run_label == '.':
                run_label = "ks2matlab"
            outfile = self.output_directory.joinpath(
                f"spike_sorting_raster_{run_label}.png")
            set_axis_label_size(raster_ax)

            if not self.histology_status:
                remove_axis_outline(region_ax)
            else:
                plot_brain_regions(channels['atlas_id'],
                                   channel_depths=channels['axial_um'],
                                   brain_regions=self.brain_regions,
                                   display=True,
                                   ax=region_ax,
                                   title=self.histology_status)
                region_ax.set(ylim=ylim)
                set_axis_label_size(region_ax)

            fig.savefig(outfile)
            plt.close(fig)

            return outfile, fig, axs
Example #3
0
# Example script: load the spikes of one session, estimate the drift and plot it
# above the driftmap of the same spikes.
one = ONE()

# Dataset types needed by estimate_drift, in the order they are unpacked below
dataset_types = ['spikes.times', 'spikes.amps', 'spikes.depths']

# Alternative: search for sessions instead of using the hard-coded one below
# eids = one.search(dataset_types=dataset_types,
#                   project='ibl_neuropixel_brainwide_01',
#                   task_protocol='_iblrig_tasks_ephysChoiceWorld')
#
# eid = eids[0]  # Test with little drift: '7cdb71fb-928d-4eea-988f-0b655081f21c'

eid = '89f0d6ff-69f4-45bc-b89e-72868abb042a'  # Test with huge drift

# Load the three datasets of the session
# NOTE(review): assumes one.load returns them in the order of dataset_types - confirm

spike_times, spike_amps, spike_depths = \
    one.load(eid, dataset_types=dataset_types)

# The result is used as a single array below, i.e. this expects the variant of
# estimate_drift that returns only the drift.
drift = estimate_drift(spike_times, spike_amps, spike_depths, display=False)

# PLOT
# Constrained layout on a 3 x 3 grid: the drift trace takes the top row and the
# driftmap the two rows below it
fig3 = plt.figure(constrained_layout=True)
gs = fig3.add_gridspec(3, 3)
f3_ax0 = fig3.add_subplot(gs[0, :])
# drift is plotted against its sample index
# NOTE(review): this lines up with the driftmap time axis only if the drift is
# sampled once per second starting at t = 0 - confirm against estimate_drift
f3_ax0.plot(drift)
f3_ax1 = fig3.add_subplot(gs[1:, :])
bbplot.driftmap(spike_times, spike_depths, ax=f3_ax1, plot_style='bincount')
# Give the drift trace the same x limits as the driftmap
f3_ax0.set_xlim(f3_ax1.get_xlim())
Example #4
0
# NOTE(review): `pid`, `sr`, `dsets`, `one` and `t0` are not defined in this
# excerpt; they are expected to come from an earlier part of the script.

## Example 2: Plot Insertion for a given PID (todo: use Needles 2 for interactive)
av = run_needles2.view(lazy=True)
av.add_insertion_by_id(pid)

## Example 3: Show the PSD
# Drop the last column of the reader output (presumably the sync channel -
# TODO confirm) and transpose; `raw` is assumed to be (channels, samples).
raw = sr[:, :-1].T
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(18, 7)
show_psd(raw, sr.fs, ax=axes[0])

## Example 4: Display the raw / pre-proc and KS2 parts -
h = neuropixel.trace_header()
# 3rd order Butterworth high-pass with a 300 Hz corner. Without the `fs`
# argument scipy expects Wn as a fraction of the Nyquist frequency (fs / 2).
# The previous expression `300 / sr.fs / 2` evaluated to 300 / (2 * fs),
# which put the corner at 75 Hz instead of 300 Hz.
sos = scipy.signal.butter(3, 300 / (sr.fs / 2), btype='highpass', output='sos')
# zero-phase filtering along the last axis of `raw`
butt = scipy.signal.sosfiltfilt(sos, raw)
fk_kwargs = {'dx': 1, 'vbounds': [0, 1e6], 'ntr_pad': 160, 'ntr_tap': 0, 'lagc': .01, 'btype': 'lowpass'}
destripe = voltage.destripe(raw, fs=sr.fs, fk_kwargs=fk_kwargs, tr_sel=np.arange(raw.shape[0]))
ks2 = get_ks2(raw, dsets, one)
# one viewer per processing stage; the arrays are transposed back and the time
# axis is given as axis 0
eqc_butt = viewseis(butt.T, si=1 / sr.fs, h=h, t0=t0, title='butt', taxis=0)
eqc_dest = viewseis(destripe.T, si=1 / sr.fs, h=h, t0=t0, title='destr', taxis=0)
eqc_ks2 = viewseis(ks2.T, si=1 / sr.fs, h=h, t0=t0, title='ks2', taxis=0)

## Example 5: overlay the spikes on the existing easyqc instances
spikes, clusters, channels = get_spikes(dsets, one)
overlay_spikes(eqc_butt, spikes, clusters, channels)
overlay_spikes(eqc_dest, spikes, clusters, channels)
overlay_spikes(eqc_ks2, spikes, clusters, channels)

# Do the driftmap in the second panel of the PSD figure
driftmap(spikes['times'], spikes['depths'], t_bin=0.1, d_bin=5, ax=axes[1])
Example #5
0
    # NOTE(review): fragment of a larger body; `one`, `eid`, `probe_label`,
    # `i_probe`, `ba`, `channels` and `spikes` are defined outside this excerpt.

    # Query Alyx for the 'Histology track' trajectory of this probe and keep
    # the first record returned
    traj = one.alyx.rest('trajectories',
                         'list',
                         session=eid,
                         provenance='Histology track',
                         probe=probe_label)[0]

    # Build an atlas Insertion from the trajectory record
    ins = atlas.Insertion.from_dict(traj)

    # Initialise fig subplots
    # NOTE(review): plt.subplots() opens a figure of its own, so the figure
    # created by plt.figure(num=i_probe) appears to stay empty - confirm intent
    plt.figure(num=i_probe)
    fig, axs = plt.subplots(1, 3)
    fig.suptitle(f'Probe {probe_label}', fontsize=16)

    # Sagittal view: columns 1 and 2 of the insertion coordinates, then the
    # channel y / z positions as yellow dots. Values are scaled by 1e6
    # (presumably metres to micrometres - TODO confirm)
    sax = ba.plot_tilted_slice(ins.xyz, axis=0, ax=axs[0])
    sax.plot(ins.xyz[:, 1] * 1e6, ins.xyz[:, 2] * 1e6)
    sax.plot(channels[probe_label].y * 1e6, channels[probe_label].z * 1e6,
             'y.')

    # Coronal view: columns 0 and 2 of the insertion coordinates, then the
    # channel x / z positions as yellow dots, with the same 1e6 scaling
    cax = ba.plot_tilted_slice(ins.xyz, axis=1, ax=axs[1])
    cax.plot(ins.xyz[:, 0] * 1e6, ins.xyz[:, 2] * 1e6)
    cax.plot(channels[probe_label].x * 1e6, channels[probe_label].z * 1e6,
             'y.')

    # Raster plot -- Brainbox driftmap of this probe's spikes in the third panel
    bbplot.driftmap(spikes[probe_label].times,
                    spikes[probe_label].depths,
                    ax=axs[2],
                    plot_style='bincount')
Example #6
0
def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Electrode drift for spike sorted data.

    Spikes are counted in an (amplitude, time, depth) histogram. The depth
    profile of every time bin is cross-correlated, through an FFT along the
    depth axis and a low-pass on the spatial frequencies, with the median
    spectrum over time. The lag of the correlation peak, refined by
    `parabolic_max`, gives the raw drift, which is then smoothed with a
    rolling window and de-meaned.

    :param spike_times: spike times in seconds
    :param spike_amps: spike amplitudes; scaled by 1e6 before the log10
     binning, i.e. expected in V. Spikes whose amplitude bin is NaN are
     ignored.
    :param spike_depths: spike depths (usually um); NaN entries are ignored
    :param display: if True, plot the drift above the driftmap, once with the
     original depths and once with the drift-corrected depths
    :return: drift (ntimes vector) in input units (usually um)
    :return: ts (ntimes vector) time scale in seconds, in the same time
     reference as spike_times (left edge of each time bin)
    :return: [fig1, fig2], only when display is True
    """
    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_BIN_LOG10 = [1.25,
                     3.25]  # binning parameter for amplitudes (log10 in uV)
    N_AMP = 1  # number of amplitude bins

    NXCORR = 50  # positive and negative lag in depth samples to look for depth
    NT_SMOOTH = 9  # length of the smoothing window in samples (DT_SECS rate)

    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    tmin, tmax = (np.min(spike_times), np.max(spike_times))
    t_start = np.floor(tmin)  # left edge of the first time bin
    nt = int((np.ceil(tmax) - t_start) / DT_SECS)

    # 3d histogram of spikes along amplitude, depths and time
    atd_hist = np.zeros((N_AMP, nt, nd), dtype=np.single)
    abins = (np.log10(spike_amps * 1e6) -
             AMP_BIN_LOG10[0]) / np.diff(AMP_BIN_LOG10) * N_AMP
    # clip to the first / last amplitude bin
    abins = np.minimum(np.maximum(0, np.floor(abins)), N_AMP - 1)
    valid_depth = ~np.isnan(spike_depths)

    # A NaN amplitude bin (e.g. log10 of a negative or NaN amplitude) must not
    # count as an extra bin: enumerating np.unique(abins) would then index
    # past the N_AMP rows of atd_hist. Rows are addressed by the bin value.
    for abin in np.unique(abins[~np.isnan(abins)]):
        inds = np.where(np.logical_and(abins == abin, valid_depth))[0]
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds],
                             DEPTH_BIN_UM, DT_SECS, [0, nd * DEPTH_BIN_UM],
                             [t_start, np.ceil(tmax)])
        # the last row and column returned by bincount2D are dropped
        atd_hist[int(abin)] = a[:-1, :-1]

    fdscale = np.abs(np.fft.fftfreq(nd, d=DEPTH_BIN_UM))
    # k-filter along the depth direction
    lp = dsp.fourier._freq_vector(fdscale, np.array([1 / 16, 1 / 8]), typ='lp')
    # compute the depth lag by xcorr
    atd_ = np.fft.fft(atd_hist, axis=-1)
    # xcorrelation against reference (median spectrum over time)
    xcorr = np.real(
        np.fft.ifft(lp * atd_ *
                    np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
    xcorr = np.sum(xcorr, axis=0)
    # reorder the circular lags as -NXCORR ... 0 ... +NXCORR
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
    xcorr = xcorr - np.mean(xcorr, 1)[:, np.newaxis]

    # sub-sample position of the correlation peak, converted to depth units
    raw_drift = (parabolic_max(xcorr)[0] - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift,
                                  window_len=NT_SMOOTH,
                                  window='hanning')
    drift = drift - np.mean(drift)
    # The histogram starts at floor(tmin), not at 0: the time scale carries the
    # same offset so that it lines up with spike_times in the interpolation
    # and in the plots below.
    ts = t_start + DT_SECS * np.arange(drift.size)
    if display:
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        # figure 1: drift above the driftmap of the original depths
        fig1, axs = plt.subplots(2,
                                 1,
                                 gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True,
                                 figsize=(20, 10))
        axs[0].plot(ts, drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([-NXCORR * 2, 3840 + NXCORR * 2])
        # figure 2: same layout with the interpolated drift subtracted from
        # every spike depth
        fig2, axs = plt.subplots(2,
                                 1,
                                 gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True,
                                 figsize=(20, 10))
        axs[0].plot(ts, drift)
        dd = np.interp(spike_times, ts, drift)
        driftmap(spike_times, spike_depths - dd, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([-NXCORR * 2, 3840 + NXCORR * 2])
        return drift, ts, [fig1, fig2]

    return drift, ts