def extract_spikes(traces, interval, sample_rate=None,
                   spike_times=None, spike_clusters=None,
                   cluster_groups=None,
                   all_masks=None,
                   n_samples_waveforms=None):
    """Extract the waveforms of the spikes falling inside a time interval.

    Parameters
    ----------
    traces : array
        Data extracted for `interval`; sample 0 is taken to correspond to
        ``interval[0]``. Presumably shaped ``(n_samples, n_channels)`` --
        confirm against ``_extract_wave``.
    interval : pair
        ``(start, end)`` bounds, in the same unit as `spike_times`
        (presumably seconds -- verify against callers).
    sample_rate : float
        Factor converting `spike_times` units into samples.
    spike_times : array
        Spike times. Must be sorted, since ``searchsorted`` is used to find
        the spikes inside `interval`.
    spike_clusters : array
        Cluster of each spike, aligned with `spike_times`.
    cluster_groups : dict, optional
        Mapping from cluster id to group; clusters missing from it get
        ``None``.
    all_masks : array
        Per-spike, per-channel masks, indexed ``(spike, channel)``.
    n_samples_waveforms : int or tuple
        Either the total waveform length or an explicit
        ``(n_before, n_after)`` pair. An integer is split around the spike
        sample; for odd values the extra sample goes after the spike, so
        the waveform has exactly `n_samples_waveforms` samples.

    Returns
    -------
    spikes : list of Bunch
        One bunch per successfully extracted spike, with the fields
        ``waveforms``, ``channels``, ``masks``, ``spike_time``,
        ``spike_cluster``, ``cluster_group`` and ``offset_samples``.
        Spikes for which ``_extract_wave`` returns ``None`` are skipped.

    Raises
    ------
    ValueError
        If `spike_clusters` or `all_masks` are not aligned with
        `spike_times` within the interval.
    """
    cluster_groups = cluster_groups or {}
    sr = sample_rate
    ns = n_samples_waveforms
    if not isinstance(ns, tuple):
        # For odd lengths, give the extra sample to the "after" side.
        # The previous ``(ns // 2, ns // 2)`` split silently dropped it.
        ns = (ns // 2, ns - ns // 2)
    offset_samples = ns[0]
    wave_len = ns[0] + ns[1]

    # Find the spikes within the interval.
    start, stop = spike_times.searchsorted(interval)
    st = spike_times[start:stop]
    sc = spike_clusters[start:stop]
    m = all_masks[start:stop]
    n = len(st)
    # Explicit checks rather than asserts, so they survive ``python -O``.
    if len(sc) != n:
        raise ValueError("spike_clusters is not aligned with spike_times "
                         "(%d != %d)." % (len(sc), n))
    if m.shape[0] != n:
        raise ValueError("all_masks is not aligned with spike_times "
                         "(%d != %d)." % (m.shape[0], n))

    # Extract waveforms.
    spikes = []
    for i in range(n):
        # Start of the waveform, in samples relative to the extracted traces.
        sample_start = int(round((st[i] - interval[0]) * sr)) - offset_samples
        o = _extract_wave(traces, sample_start, m[i], wave_len)
        if o is None:  # pragma: no cover
            # Log the absolute spike index; ``i`` is only the position
            # inside the interval slice.
            logger.debug("Unable to extract spike %d.", start + i)
            continue
        spike = Bunch()
        spike.waveforms, spike.channels = o
        # Masks on unmasked channels.
        spike.masks = m[i, spike.channels]
        spike.spike_time = st[i]
        spike.spike_cluster = sc[i]
        spike.cluster_group = cluster_groups.get(spike.spike_cluster, None)
        spike.offset_samples = offset_samples
        spikes.append(spike)
    return spikes