def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Estimate drift for spike sorted data.

    Spikes are binned in a 3d histogram (amplitude, time, depth). For each amplitude level,
    the depth profile of every time bin is cross-correlated (FFT along depth) with the median
    profile over time. The correlations are summed over amplitude levels and the lag of the
    maximum, within +/- NXCORR depth bins, is taken as the raw drift, which is then smoothed.

    :param spike_times: vector of spike times, binned at DT_SECS (1 second)
    :param spike_amps: vector of spike amplitudes, binned at AMP_RES_V (100e-6); spikes with a
     NaN amplitude are ignored
    :param spike_depths: vector of spike depths, binned at DEPTH_BIN_UM (2); spikes with a
     NaN depth are ignored
    :param display: if True, plots the drift on top of the driftmap of the spikes
    :return: drift vector (smoothed), in the same units as spike_depths
    :raises ValueError: if no spike has a finite amplitude
    """
    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_RES_V = 100 * 1e-6  # binning parameter for amplitudes
    NXCORR = 50  # positive and negative lag in depth samples to look for depth
    NT_SMOOTH = 9  # length of the Hanning smoothing window in samples (DT_SECS rate)

    # experimental: try the amp with a log scale
    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    nt = int(np.ceil(np.max(spike_times) / DT_SECS))

    # Amplitude levels actually present in the data. NaN amplitudes are excluded: they never
    # compare equal to a level and would otherwise add spurious entries to the level list,
    # indexing past the end of the histogram.
    abins = np.ceil(spike_amps / AMP_RES_V)
    amp_levels = np.unique(abins[~np.isnan(abins)])
    if amp_levels.size == 0:
        raise ValueError("estimate_drift: no spike with a finite amplitude")
    valid_depth = ~np.isnan(spike_depths)

    # 3d histogram of spikes along amplitude, depths and time
    # One slice per populated amplitude level: empty levels would only contribute zeros to the
    # summed cross-correlation below, so they are not allocated.
    atd_hist = np.zeros((amp_levels.size, nt, nd))
    for i, abin in enumerate(amp_levels):
        inds = np.where(np.logical_and(abins == abin, valid_depth))[0]
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds], DEPTH_BIN_UM, DT_SECS,
                             [0, nd * DEPTH_BIN_UM], [0, nt * DT_SECS])
        # the histogram comes back with one extra bin on each axis: drop it
        atd_hist[i] = a[:-1, :-1]

    # compute the depth lag by xcorr
    # experimental: LP the fft for a better tracking ?
    atd_ = np.fft.fft(atd_hist, axis=-1)
    # circular cross-correlation of each time bin against the median over time (axis 1)
    xcorr = np.real(
        np.fft.ifft(atd_ * np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
    xcorr = np.sum(xcorr, axis=0)
    # keep only lags -NXCORR..+NXCORR, negative lags first so that index NXCORR is lag 0
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
    # experimental: parabolic fit to get max values
    raw_drift = (np.argmax(xcorr, axis=-1) - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift, window_len=NT_SMOOTH, window='hanning')

    if display:
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        _, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]}, sharex=True)
        axs[0].plot(DT_SECS * np.arange(drift.size), drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])

    return drift
def plot_driftmap(self, spikes, clusters, channels, collection, ylim=(0, 3840)):
    """
    Plot the spike raster (driftmap) of a spike sorting run and save it as a png file.

    The figure has two panels: the driftmap on the left and, when `self.histology_status`
    is set, the brain regions along the probe on the right (otherwise the right panel
    outline is removed).

    :param spikes: spikes object, `times`, `depths` and `clusters` are read
    :param clusters: clusters object, only `depths.size` is read (for the title)
    :param channels: channels mapping, `atlas_id` and `axial_um` are read when histology
     is available
    :param collection: collection of the spike sorting run, has to be relative to
     `alf/{self.pname}`
    :param ylim: y-axis limits applied to both panels
    :return: path of the saved image, figure (closed), array of the two axes
    """
    fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
    ax_raster, ax_regions = axs[0], axs[1]

    # left panel: raster of spike depths over time
    driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=ax_raster)
    title_str = (f"{self.pid_label}, {collection}, {self.pid} \n "
                 f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters")
    ax_raster.set(ylim=ylim, title=title_str)

    # the run label is the collection path below the probe folder; a collection that is
    # the probe folder itself is labelled "ks2matlab"
    run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
    if run_label == '.':
        run_label = "ks2matlab"
    outfile = self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png")
    set_axis_label_size(ax_raster)

    # right panel: brain regions if histology is available, blank otherwise
    if not self.histology_status:
        remove_axis_outline(ax_regions)
    else:
        plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                           brain_regions=self.brain_regions, display=True, ax=ax_regions,
                           title=self.histology_status)
        ax_regions.set(ylim=ylim)
        set_axis_label_size(ax_regions)

    fig.savefig(outfile)
    plt.close(fig)
    return outfile, fig, axs
one = ONE()

# Datasets needed by the drift estimation
dataset_types = ['spikes.times', 'spikes.amps', 'spikes.depths']

# Alternative: search for sessions instead of using a fixed one
# eids = one.search(dataset_types=dataset_types,
#                   project='ibl_neuropixel_brainwide_01',
#                   task_protocol='_iblrig_tasks_ephysChoiceWorld')
#
# eid = eids[0]  # Test with little drift: '7cdb71fb-928d-4eea-988f-0b655081f21c'
eid = '89f0d6ff-69f4-45bc-b89e-72868abb042a'  # Test with huge drift

# Load the three spike datasets, returned in the order of dataset_types
spike_times, spike_amps, spike_depths = one.load(eid, dataset_types=dataset_types)

# Drift estimate; display is off as the figure is built below
drift = estimate_drift(spike_times, spike_amps, spike_depths, display=False)

# PLOT
# Constrained layout on a 3 x 3 grid: drift on the first row, driftmap on the other two
fig3 = plt.figure(constrained_layout=True)
gs = fig3.add_gridspec(nrows=3, ncols=3)
f3_ax0 = fig3.add_subplot(gs[0, :])
f3_ax1 = fig3.add_subplot(gs[1:, :])
f3_ax0.plot(drift)
bbplot.driftmap(spike_times, spike_depths, ax=f3_ax1, plot_style='bincount')
# align the time axis of the drift trace on the one of the driftmap
f3_ax0.set_xlim(f3_ax1.get_xlim())
## Example 2: Plot Insertion for a given PID (todo: use Needles 2 for interactive)
av = run_needles2.view(lazy=True)
av.add_insertion_by_id(pid)

## Example 3: Show the PSD
# all columns of sr but the last, transposed
# NOTE(review): presumably the last column is the sync channel and raw ends up as
# (n_traces, n_samples) - confirm against the reader
raw = sr[:, :-1].T
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(18, 7)
show_psd(raw, sr.fs, ax=axes[0])

## Example 4: Display the raw / pre-proc and KS2 parts
h = neuropixel.trace_header()
# 300 Hz high-pass: the critical frequency of a digital filter is given as a fraction of the
# Nyquist frequency (fs / 2), hence 300 / (fs / 2) and not 300 / fs / 2 (which is 75 Hz)
sos = scipy.signal.butter(3, 300 / (sr.fs / 2), btype='highpass', output='sos')
butt = scipy.signal.sosfiltfilt(sos, raw)  # zero-phase filtering along the last axis
fk_kwargs = {'dx': 1, 'vbounds': [0, 1e6], 'ntr_pad': 160, 'ntr_tap': 0, 'lagc': .01,
             'btype': 'lowpass'}
# destriping applied to all traces (tr_sel selects every row of raw)
destripe = voltage.destripe(raw, fs=sr.fs, fk_kwargs=fk_kwargs,
                            tr_sel=np.arange(raw.shape[0]))
ks2 = get_ks2(raw, dsets, one)
# one viewer per processing, all sharing the same sampling interval, header and start time
eqc_butt = viewseis(butt.T, si=1 / sr.fs, h=h, t0=t0, title='butt', taxis=0)
eqc_dest = viewseis(destripe.T, si=1 / sr.fs, h=h, t0=t0, title='destr', taxis=0)
eqc_ks2 = viewseis(ks2.T, si=1 / sr.fs, h=h, t0=t0, title='ks2', taxis=0)

# Example 5: overlay the spikes on the existing easyqc instances
spikes, clusters, channels = get_spikes(dsets, one)
overlay_spikes(eqc_butt, spikes, clusters, channels)
overlay_spikes(eqc_dest, spikes, clusters, channels)
overlay_spikes(eqc_ks2, spikes, clusters, channels)

# Do the driftmap, on the second panel of the PSD figure (Example 3)
driftmap(spikes['times'], spikes['depths'], t_bin=0.1, d_bin=5, ax=axes[1])
# First 'Histology track' trajectory registered for this session and probe
traj = one.alyx.rest('trajectories', 'list', session=eid,
                     provenance='Histology track', probe=probe_label)[0]
ins = atlas.Insertion.from_dict(traj)

# Initialise fig subplots
# The figure number goes directly to subplots: calling plt.figure(num=i_probe) beforehand
# would leave an extra, empty figure behind as subplots creates its own figure.
# clear=True empties the figure if this number already exists (e.g. when re-running).
fig, axs = plt.subplots(1, 3, num=i_probe, clear=True)
fig.suptitle(f'Probe {probe_label}', fontsize=16)

# Sagittal view: slice along the insertion, trajectory and channels as (y, z)
# NOTE(review): the 1e6 factor presumably converts m to um - confirm against the atlas units
sax = ba.plot_tilted_slice(ins.xyz, axis=0, ax=axs[0])
sax.plot(ins.xyz[:, 1] * 1e6, ins.xyz[:, 2] * 1e6)
sax.plot(channels[probe_label].y * 1e6, channels[probe_label].z * 1e6, 'y.')

# Coronal view: slice along the insertion, trajectory and channels as (x, z)
cax = ba.plot_tilted_slice(ins.xyz, axis=1, ax=axs[1])
cax.plot(ins.xyz[:, 0] * 1e6, ins.xyz[:, 2] * 1e6)
cax.plot(channels[probe_label].x * 1e6, channels[probe_label].z * 1e6, 'y.')

# Raster plot -- Brainbox
bbplot.driftmap(spikes[probe_label].times,
                spikes[probe_label].depths,
                ax=axs[2], plot_style='bincount')
def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Electrode drift for spike sorted data.

    Spikes are binned in a 3d histogram (amplitude, time, depth). For each amplitude bin, the
    depth profile of every time bin is cross-correlated (FFT along depth, low-passed in
    wavenumber) with the median profile over time. The correlations are summed over amplitude
    bins, the position of the maximum within +/- NXCORR depth bins gives the raw drift, which
    is then smoothed and de-meaned.

    :param spike_times: vector of spike times in seconds
    :param spike_amps: vector of spike amplitudes; scaled by 1e6 before the log10 binning in uV.
     Spikes whose amplitude is NaN, or negative (log10 undefined), are ignored
    :param spike_depths: vector of spike depths; spikes with a NaN depth are ignored
    :param display: if True, plots the driftmap before and after drift correction and
     returns the two figures as a third output
    :return: drift (ntimes vector) in input units (usually um)
    :return: ts (ntimes vector) time scale in seconds
    :return: [fig1, fig2] list of figures, only when display is True
    """
    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_BIN_LOG10 = [1.25, 3.25]  # binning parameter for amplitudes (log10 in uV)
    N_AMP = 1  # number of amplitude bins

    NXCORR = 50  # positive and negative lag in depth samples to look for depth
    NT_SMOOTH = 9  # length of the Hanning smoothing window in samples (DT_SECS rate)

    # experimental: try the amp with a log scale
    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    tmin, tmax = (np.min(spike_times), np.max(spike_times))
    nt = int((np.ceil(tmax) - np.floor(tmin)) / DT_SECS)

    # 3d histogram of spikes along amplitude, depths and time
    atd_hist = np.zeros((N_AMP, nt, nd), dtype=np.single)
    # amplitude bin number of each spike, clipped to [0, N_AMP - 1]; NaN stays NaN
    abins = (np.log10(spike_amps * 1e6) - AMP_BIN_LOG10[0]) / np.diff(AMP_BIN_LOG10) * N_AMP
    abins = np.minimum(np.maximum(0, np.floor(abins)), N_AMP - 1)
    valid_depth = ~np.isnan(spike_depths)
    # Loop over the bin numbers rather than over np.unique(abins): NaN amplitudes add entries
    # to the unique values, which would index past the N_AMP slices of the histogram.
    for abin in range(N_AMP):
        inds = np.where(np.logical_and(abins == abin, valid_depth))[0]
        if inds.size == 0:
            continue  # no spike in this amplitude bin: the slice stays at zero
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds], DEPTH_BIN_UM, DT_SECS,
                             [0, nd * DEPTH_BIN_UM], [np.floor(tmin), np.ceil(tmax)])
        # the histogram comes back with one extra bin on each axis: drop it
        atd_hist[abin] = a[:-1, :-1]

    fdscale = np.abs(np.fft.fftfreq(nd, d=DEPTH_BIN_UM))
    # k-filter along the depth direction
    lp = dsp.fourier._freq_vector(fdscale, np.array([1 / 16, 1 / 8]), typ='lp')
    # compute the depth lag by xcorr
    # to experiment: LP the fft for a better tracking ?
    atd_ = np.fft.fft(atd_hist, axis=-1)
    # xcorrelation against reference (median over time, axis 1)
    xcorr = np.real(
        np.fft.ifft(lp * atd_ * np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
    xcorr = np.sum(xcorr, axis=0)
    # keep only lags -NXCORR..+NXCORR, negative lags first so that index NXCORR is lag 0
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
    xcorr = xcorr - np.mean(xcorr, 1)[:, np.newaxis]
    # import easyqc
    # easyqc.viewdata(xcorr - np.mean(xcorr, 1)[:, np.newaxis], DEPTH_BIN_UM, title='xcor')

    # to experiment: parabolic fit to get max values
    raw_drift = (parabolic_max(xcorr)[0] - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift, window_len=NT_SMOOTH, window='hanning')
    drift = drift - np.mean(drift)
    # NOTE(review): ts starts at 0 whereas the time bins start at floor(tmin); both only
    # coincide when the first spike occurs within the first second - confirm with callers
    ts = DT_SECS * np.arange(drift.size)
    if display:
        import matplotlib.pyplot as plt
        from brainbox.plot import driftmap
        # figure 1: drift on top of the original driftmap
        fig1, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True, figsize=(20, 10))
        axs[0].plot(ts, drift)
        driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([-NXCORR * 2, 3840 + NXCORR * 2])
        # figure 2: drift on top of the driftmap with the interpolated drift subtracted
        fig2, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]},
                                 sharex=True, figsize=(20, 10))
        axs[0].plot(ts, drift)
        dd = np.interp(spike_times, ts, drift)
        driftmap(spike_times, spike_depths - dd, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim([-NXCORR * 2, 3840 + NXCORR * 2])
        return drift, ts, [fig1, fig2]

    return drift, ts