def spike_sorting_metrics(times, clusters, amps, depths, cluster_ids=None, params=METRICS_PARAMS):
    """
    Computes:
        - cell level metrics (cf quick_unit_metrics); any quality labelling is presumably done
          inside quick_unit_metrics according to `params` thresholds - TODO confirm
        - electrode drift as a function of time (cf electrode_drift.estimate_drift)

    :param times: vector of spike times
    :param clusters: per-spike cluster assignments (assumed aligned with `times`)
    :param amps: per-spike amplitudes (assumed aligned with `times`)
    :param depths: per-spike depths (assumed aligned with `times`)
    :param cluster_ids (optional): set of clusters (if None the output dataframe will match the
     unique set of clusters represented in spike clusters)
    :param params: dict (optional) parameters for qc computation (see constant at the top of the
     module for default values and keys)
    :return: data_frame of metrics (cluster records, columns are qc attributes)
    :return: dictionary of recording qc (keys 'time_scale' and 'drift_um')
    """
    # compute per-cluster metrics and convert to `DataFrame`
    # note: quick_unit_metrics takes `clusters` first, unlike this function's signature
    df_units = quick_unit_metrics(clusters, times, amps, depths, cluster_ids=cluster_ids, params=params)
    df_units = pd.DataFrame(df_units)
    # estimate_drift returns the drift vector and the time scale it is sampled at
    drift, ts = electrode_drift.estimate_drift(times, amps, depths)
    rec_qc = {'time_scale': ts, 'drift_um': drift}
    return df_units, rec_qc
def test_drift_estimate():
    """
    From spike depths, xcorrelate drift maps to find a drift estimate.

    Checks that (1) a recording without drift (and with times shifted negative) yields a
    near-zero drift estimate and (2) an injected sinusoidal drift is recovered.
    """
    np.random.seed(42)
    ncells = 200
    cells_depth = np.random.random(ncells) * 3800 + 50
    frs = np.random.randn(ncells) * 50 + 200
    t, a, c = multiple_spike_trains(firing_rates=frs, rec_len_secs=200)

    # test negative times, no drift
    drift, ts = electrode_drift.estimate_drift(t - 2, a, cells_depth[c])
    assert np.all(np.abs(drift) < 0.01)

    # test drift recovery - sinusoid with `drift_amp_um` peak amplitude and `n_cycles`
    # periods over the recording; the same parameters build the injected and expected signals
    drift_amp_um = 50
    n_cycles = 2
    dcor = np.sin(2 * np.pi * t / np.max(t) * n_cycles) * drift_amp_um
    drift, ts = electrode_drift.estimate_drift(t, a, cells_depth[c] + dcor, display=False)
    drift_ = np.sin(2 * np.pi * ts / np.max(t) * n_cycles) * drift_amp_um
    # the first two estimated samples are excluded from the comparison
    # (presumably edge effects of the estimator - TODO confirm)
    assert np.all(np.abs(drift - drift_)[2:] < 4)
one = ONE()
# Find sessions
dataset_types = ['spikes.times', 'spikes.amps', 'spikes.depths']
# eids = one.search(dataset_types=dataset_types,
#                   project='ibl_neuropixel_brainwide_01',
#                   task_protocol='_iblrig_tasks_ephysChoiceWorld')
#
# eid = eids[0]
# Test with little drift: '7cdb71fb-928d-4eea-988f-0b655081f21c'
eid = '89f0d6ff-69f4-45bc-b89e-72868abb042a'  # Test with huge drift

# Get dataset
spike_times, spike_amps, spike_depths = \
    one.load(eid, dataset_types=dataset_types)

# estimate_drift returns (drift, time scale); unpack both so the drift is plotted against time
# rather than plotting the (drift, ts) tuple as a 2-row array
drift, ts = estimate_drift(spike_times, spike_amps, spike_depths, display=False)

# PLOT
# Tight layout
fig3 = plt.figure(constrained_layout=True)
gs = fig3.add_gridspec(3, 3)
f3_ax0 = fig3.add_subplot(gs[0, :])
f3_ax0.plot(ts, drift)
f3_ax1 = fig3.add_subplot(gs[1:, :])
bbplot.driftmap(spike_times, spike_depths, ax=f3_ax1, plot_style='bincount')
# share the x range with the driftmap (presumably spike time on x) so both panels line up
f3_ax0.set_xlim(f3_ax1.get_xlim())
""" Download data and plot drift over the session ============================================== Downloads LFP power spectrum for a given session and probe and plots a heatmap of power spectrum on the channels along probe against frequency """ # import modules from one.api import ONE from brainbox.metrics import electrode_drift # instantiate one one = ONE(base_url='https://openalyx.internationalbrainlab.org', silent=True) # Specify subject, date and probe we are interested in subject = 'CSHL049' date = '2020-01-08' sess_no = 1 probe_label = 'probe00' eid = one.search(subject=subject, date=date, number=sess_no)[0] # Download and load the spikes data spikes = one.load_object(eid, 'spikes', collection=f'alf/{probe_label}') # Use brainbox function to compute drift over session drift = electrode_drift.estimate_drift(spikes['times'], spikes['amps'], spikes['depths'], display=True)
import matplotlib.pyplot as plt

one = ONE()
# Find sessions
# `dsets` is only used by the commented-out session search below
dsets = ['spikes.times.npy', 'spikes.amps.npy', 'spikes.depths.npy']
# eids = one.search(dataset=dsets,
#                   project='ibl_neuropixel_brainwide_01',
#                   task_protocol='_iblrig_tasks_ephysChoiceWorld')
#
# eid = eids[0]
# Test with little drift: '7cdb71fb-928d-4eea-988f-0b655081f21c'
eid = '89f0d6ff-69f4-45bc-b89e-72868abb042a'  # Test with huge drift

# Get dataset
spikes = one.load_object(eid, 'spikes', collection='alf/probe00')

# estimate_drift returns (drift, time scale); unpack both so the drift is plotted against time
# rather than plotting the (drift, ts) tuple as a 2-row array
drift, ts = estimate_drift(spikes.times, spikes.amps, spikes.depths, display=False)

# PLOT
# Tight layout
fig3 = plt.figure(constrained_layout=True)
gs = fig3.add_gridspec(3, 3)
f3_ax0 = fig3.add_subplot(gs[0, :])
f3_ax0.plot(ts, drift)
f3_ax1 = fig3.add_subplot(gs[1:, :])
bbplot.driftmap(spikes.times, spikes.depths, ax=f3_ax1, plot_style='bincount')
# share the x range with the driftmap (presumably spike time on x) so both panels line up
f3_ax0.set_xlim(f3_ax1.get_xlim())