Exemple #1
0
 def test_2d(self):
     ipeak_, maxi_ = parabolic_max(self.x)
     self.assertTrue(np.all(np.isclose(self.maxi, maxi_, equal_nan=True)))
     self.assertTrue(np.all(np.isclose(self.ipeak, ipeak_, equal_nan=True)))
Exemple #2
0
 def test_1d(self):
     # look over the 2D array as 1D chunks
     for i, x in enumerate(self.x):
         ipeak_, maxi_ = parabolic_max(x)
         self.assertTrue(np.all(np.isclose(self.ipeak[i], ipeak_, equal_nan=True)))
         self.assertTrue(np.all(np.isclose(self.maxi[i], maxi_, equal_nan=True)))
Exemple #3
0
def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Electrode drift for spike sorted data.
    :param spike_times:
    :param spike_amps:
    :param spike_depths:
    :param display:
    :return: drift (ntimes vector) in input units (usually um)
    :return: ts (ntimes vector) time scale in seconds

    """
    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_BIN_LOG10 = [1.25,
                     3.25]  # binning parameter for amplitudes (log10 in uV)
    N_AMP = 1  # number of amplitude bins

    NXCORR = 50  # positive and negative lag in depth samples to look for depth
    NT_SMOOTH = 9  # length of the Gaussian smoothing window in samples (DT_SECS rate)

    # experimental: try the amp with a log scale
    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    tmin, tmax = (np.min(spike_times), np.max(spike_times))
    nt = int((np.ceil(tmax) - np.floor(tmin)) / DT_SECS)

    # 3d histogram of spikes along amplitude, depths and time
    atd_hist = np.zeros((N_AMP, nt, nd), dtype=np.single)
    abins = (np.log10(spike_amps * 1e6) -
             AMP_BIN_LOG10[0]) / np.diff(AMP_BIN_LOG10) * N_AMP
    abins = np.minimum(np.maximum(0, np.floor(abins)), N_AMP - 1)

    for i, abin in enumerate(np.unique(abins)):
        inds = np.where(np.logical_and(abins == abin,
                                       ~np.isnan(spike_depths)))[0]
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds],
                             DEPTH_BIN_UM, DT_SECS, [0, nd * DEPTH_BIN_UM],
                             [np.floor(tmin), np.ceil(tmax)])
        atd_hist[i] = a[:-1, :-1]

    fdscale = np.abs(np.fft.fftfreq(nd, d=DEPTH_BIN_UM))
    # k-filter along the depth direction
    lp = dsp.fourier._freq_vector(fdscale, np.array([1 / 16, 1 / 8]), typ='lp')
    # compute the depth lag by xcorr
    # to experiment: LP the fft for a better tracking ?
    atd_ = np.fft.fft(atd_hist, axis=-1)
    # xcorrelation against reference
    xcorr = np.real(
        np.fft.ifft(lp * atd_ *
                    np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
    xcorr = np.sum(xcorr, axis=0)
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
    xcorr = xcorr - np.mean(xcorr, 1)[:, np.newaxis]
    # import easyqc
    # easyqc.viewdata(xcorr - np.mean(xcorr, 1)[:, np.newaxis], DEPTH_BIN_UM, title='xcor')

    # to experiment: parabolic fit to get max values
    raw_drift = (parabolic_max(xcorr)[0] - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift,
                                  window_len=NT_SMOOTH,
                                  window='hanning')
    drift = drift - np.mean(drift)
    ts = DT_SECS * np.arange(drift.size)
    if display:
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        fig1, axs = plt.subplots(2,
                                 1,
                                 gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True,
                                 figsize=(20, 10))
        axs[0].plot(ts, drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([-NXCORR * 2, 3840 + NXCORR * 2])
        fig2, axs = plt.subplots(2,
                                 1,
                                 gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True,
                                 figsize=(20, 10))
        axs[0].plot(ts, drift)
        dd = np.interp(spike_times, ts, drift)
        driftmap(spike_times, spike_depths - dd, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([-NXCORR * 2, 3840 + NXCORR * 2])
        return drift, ts, [fig1, fig2]

    return drift, ts