def get_traces(self, interval):
    """Load traces and spikes in an interval.

    Returns a list of two Bunch objects: the traces with their mean along
    the first axis subtracted, and the output of `subtract_templates()`
    applied to those traces (this one also carries a `color`).
    """
    sr = self.sample_rate
    traces = select_traces(self.all_traces, interval, sample_rate=sr)
    traces = traces - np.mean(traces, axis=0)

    # Range of spike indices returned by searchsorted for the interval
    # bounds (presumably spike_times is sorted -- TODO confirm).
    first, last = self.spike_times.searchsorted(interval)
    template_ids = self.spike_templates[first:last]

    # Remove templates.
    traces_sub = subtract_templates(
        traces,
        start=interval[0],
        spike_times=self.spike_times[first:last],
        spike_clusters=template_ids,
        amplitudes=self.all_amplitudes[first:last],
        spike_templates=self.templates_unw[template_ids],
        sample_rate=sr,
    )

    color_sub = (.5, .5, .5, .75)
    return [Bunch(traces=traces), Bunch(traces=traces_sub, color=color_sub)]
def test_gui_state_view(tempdir):
    """A view state is retrieved by the exact view name only."""
    my_view = Bunch(name='MyView0')
    gui_state = GUIState(config_dir=tempdir)
    gui_state.update_view_state(my_view, {'hello': 'world'})

    # Views with a different name have no state.
    for other_name in ('MyView', 'MyView1'):
        assert not gui_state.get_view_state(Bunch(name=other_name))

    assert gui_state.get_view_state(my_view) == Bunch(hello='world')
def get_features(self, cluster_id, load_all=False):
    """Return the normalized, densified features of a cluster's spikes.

    Overridden to take into account the sparse structure. The returned
    Bunch has `data`, `spike_ids`, `spike_clusters` and `masks`.
    """
    subset = self.features_spike_ids
    if subset is None:
        # No restriction: absolute and relative spike ids coincide.
        n_max = None if load_all else self.n_spikes_features
        spike_ids = self._select_spikes(cluster_id, n_max)
        spike_ids_rel = spike_ids
    else:
        # Only keep spikes belonging to the features spike ids.
        spike_ids = np.intersect1d(self._select_spikes(cluster_id), subset)
        # Relative indices of the spikes in the features_spike_ids array,
        # necessary to load features from all_features which only
        # contains the subset of the spikes.
        spike_ids_rel = _index_of(spike_ids, subset)

    template_ids = self.spike_templates[spike_ids]
    n_channels = self.n_channels
    n_pcs = self.n_features_per_channel
    n_spikes = len(spike_ids)

    dense = _densify(
        spike_ids_rel,
        self.all_features,
        self.features_ind[template_ids, :],
        self.n_channels,
    )
    dense = np.transpose(dense, (0, 2, 1))
    assert dense.shape == (n_spikes, n_channels, n_pcs)

    # Normalize features.
    lim = self.get_feature_lim()
    return Bunch(
        data=_normalize(dense, -lim, lim),
        spike_ids=spike_ids,
        spike_clusters=self.spike_clusters[spike_ids],
        masks=self.all_masks[spike_ids],
    )
def _get_traces(self, interval):
    """Get traces and spike waveforms.

    In `out.waveforms`, every spike waveform is immediately followed by its
    residual, i.e. a copy whose data is `waveform - amplitude * template`.
    """
    n_samples = self.model.n_samples_templates
    model = self.model
    traces_interval = select_traces(
        model.traces, interval, sample_rate=model.sample_rate)

    out = Bunch(data=traces_interval)
    out.waveforms = []

    waveform_iter = _iter_spike_waveforms(
        interval=interval,
        traces_interval=traces_interval,
        model=self.model,
        supervisor=self.supervisor,
        color_selector=self.color_selector,
        n_samples_waveforms=n_samples,
        get_best_channels=(
            lambda cluster_id: self.get_best_channels(cluster_id)),
        show_all_spikes=self._show_all_spikes,
    )
    for waveform in waveform_iter:
        spike_id = waveform.spike_id
        # Compute the residual: waveform - amplitude * template.
        residual = waveform.copy()
        template_id = model.spike_templates[spike_id]
        template = model.get_template(template_id).template
        amplitude = model.amplitudes[spike_id]
        residual.data = residual.data - amplitude * template
        out.waveforms.append(waveform)
        out.waveforms.append(residual)
    return out
def _get_traces(self, interval):
    """Get traces and spike waveforms.

    The trace columns are reordered with `self.channel_vertical_order`, and
    every waveform gets a `channel_labels` attribute taken from
    `model.channel_order`.
    """
    n_samples = self.model.n_samples_waveforms
    model = self.model
    vertical_order = self.channel_vertical_order

    # Reorder vertically.
    traces_interval = select_traces(
        model.traces, interval, sample_rate=model.sample_rate)
    traces_interval = traces_interval[:, vertical_order]

    out = Bunch(data=traces_interval)
    out.waveforms = []

    waveform_iter = _iter_spike_waveforms(
        interval=interval,
        traces_interval=traces_interval,
        model=self.model,
        supervisor=self.supervisor,
        color_selector=self.color_selector,
        n_samples_waveforms=n_samples,
        get_best_channels=(
            lambda cluster_id: self.get_best_channels(cluster_id)),
        show_all_spikes=self._show_all_spikes,
    )
    for waveform in waveform_iter:
        waveform.channel_labels = model.channel_order[waveform.channel_ids]
        out.waveforms.append(waveform)
    return out
def state(tempdir):
    """Build a test GUI state holding one Bunch of options per view.

    NOTE(review): nothing is written to `tempdir` here; the argument is
    unused in this function.
    """
    return Bunch(
        WaveformView0=Bunch(overlap=False),
        TraceView0=Bunch(scaling=1.),
        FeatureView0=Bunch(feature_scaling=.5),
        CorrelogramView0=Bunch(uniform_normalization=True),
    )
def _select_data(self, cluster_id, arr, n_max=None, batch_size=None):
    """Return the entries of `arr` and the masks for selected spikes.

    The spikes are chosen by `self._select_spikes()`; the returned Bunch has
    `data`, `spike_ids` and `masks`.
    """
    ids = self._select_spikes(cluster_id, n_max, batch_size=batch_size)
    return Bunch(
        data=arr[ids],
        spike_ids=ids,
        masks=self.all_masks[ids],
    )
def _select_data(self, cluster_id, arr, n_max=None):
    """Return the entries of `arr`, clusters and masks for selected spikes.

    The spikes are chosen by `self._select_spikes()`; the returned Bunch has
    `data`, `spike_ids`, `spike_clusters` and `masks`.
    """
    ids = self._select_spikes(cluster_id, n_max)
    return Bunch(
        data=arr[ids],
        spike_ids=ids,
        spike_clusters=self.spike_clusters[ids],
        masks=self.all_masks[ids],
    )
def get_amplitudes(self, cluster_id):
    """Return spike times (`x`) and amplitudes (`y`) for a cluster.

    `spike_clusters` is filled with `cluster_id`, one entry per selected
    spike.
    """
    ids = self._select_spikes(cluster_id, self.n_spikes_features)
    clusters = cluster_id * np.ones(len(ids), dtype=np.int32)
    return Bunch(
        spike_ids=ids,
        spike_clusters=clusters,
        x=self.spike_times[ids],
        y=self.all_amplitudes[ids],
    )
def get_background_features(self):
    """Return a regularly strided subset of all features.

    Note that `spike_ids` in the returned Bunch is a `slice` object, not an
    array of indices.
    """
    # Stride such that about n_spikes_background_features spikes are kept.
    step = max(1, int(self.n_spikes // self.n_spikes_background_features))
    strided = slice(None, None, step)
    return Bunch(
        data=self.all_features[strided],
        spike_ids=strided,
        spike_clusters=self.spike_clusters[strided],
        masks=self.all_masks[strided],
    )
def get_amplitudes(self, cluster_id):
    """Return spike times (`x`), amplitudes (`y`) and data bounds.

    The bounds are `[0, 0, self.duration, max amplitude]`.
    """
    ids = self._select_spikes(cluster_id, self.n_spikes_features)
    out = Bunch(
        spike_ids=ids,
        x=self.spike_times[ids],
        y=self.all_amplitudes[ids],
    )
    max_amplitude = out.y.max()
    out.data_bounds = [0, 0, self.duration, max_amplitude]
    return out
def wrapped(cluster_ids, **kwargs):
    """Call `f` on one cluster, or on several and merge the results.

    An argument without `__len__` is treated as a single cluster id and
    passed straight to `f`. Otherwise `f` is called once per cluster and the
    results are combined with `_accumulate()`.
    """
    # Single cluster.
    if not hasattr(cluster_ids, '__len__'):
        return f(cluster_ids, **kwargs)

    # Concatenate the result of multiple clusters.
    results = [f(cluster, **kwargs) for cluster in cluster_ids]

    if not (results and isinstance(results[0], list)):
        return Bunch(_accumulate(results))

    # Every call returned a list of Bunch: merge position by position.
    # We assume that all items have the same length.
    n_items = len(results[0])
    merged = []
    for i in range(n_items):
        merged.append(Bunch(_accumulate([res[i] for res in results])))
    return merged
def validate(pos=None, text=None, anchor=None, data_bounds=None):
    """Check and normalize the position, text, anchor and data bounds.

    Returns a Bunch with `pos`, `text`, `anchor` and `data_bounds`, holding
    one row per string (`data_bounds` stays None when not given).
    """
    # Accept no text, a single string, or a sequence of strings.
    if text is None:
        text = []
    elif isinstance(text, string_types):
        text = [text]

    # Without positions, every string is placed at (0, 0).
    if pos is None:
        pos = np.zeros((len(text), 2))
    pos = np.atleast_2d(pos)
    assert pos.ndim == 2
    assert pos.shape[1] == 2
    n_text = pos.shape[0]
    assert len(text) == n_text

    if anchor is None:
        anchor = (0., 0.)
    anchor = np.atleast_2d(anchor)
    # A single anchor is shared by all strings.
    if anchor.shape[0] == 1:
        anchor = np.repeat(anchor, n_text, axis=0)
    assert anchor.ndim == 2
    assert anchor.shape == (n_text, 2)

    if data_bounds is not None:
        data_bounds = _get_data_bounds(data_bounds, pos)
        assert data_bounds.shape[0] == n_text
        data_bounds = data_bounds.astype(np.float64)
        assert data_bounds.shape == (n_text, 4)

    return Bunch(
        pos=pos,
        text=text,
        anchor=anchor,
        data_bounds=data_bounds,
    )
def _get_amplitudes(self, cluster_id):
    """Return spike times (`x`) and amplitudes (`y`) of selected spikes.

    The data bounds are `(0, 0, model.duration, max amplitude)`.
    """
    n_max = self.n_spikes_amplitudes
    model = self.model
    ids = self.selector.select_spikes([cluster_id], n_max)
    times = model.spike_times[ids]
    amplitudes = model.amplitudes[ids]
    bounds = (0., 0., model.duration, amplitudes.max())
    return Bunch(x=times, y=amplitudes, data_bounds=bounds)
def state(self):
    """Return the view parameters as a Bunch.

    The two scaling attributes are converted to tuples.
    """
    out = Bunch()
    out.box_scaling = tuple(self.box_scaling)
    out.probe_scaling = tuple(self.probe_scaling)
    out.overlap = self.overlap
    out.do_zoom_on_channels = self.do_zoom_on_channels
    out.do_show_labels = self.do_show_labels
    return out
def get_traces(self, interval):
    """Return the traces in an interval as a one-element list of Bunch."""
    traces = select_traces(
        self.all_traces, interval, sample_rate=self.sample_rate)
    out = Bunch(traces=traces)
    return [out]
def validate(x=None, y=None, pos=None, color=None, size=None, depth=None,
             data_bounds='auto'):
    """Check and normalize the scatter inputs.

    Positions come either from `pos` or from `x` and `y`. Returns a Bunch
    with `pos`, `color`, `size`, `depth` and `data_bounds`, all with one row
    per point (`data_bounds` stays None when passed as None).
    """
    if pos is None:
        x, y = _get_pos(x, y)
        pos = np.c_[x, y]
    pos = np.asarray(pos)
    assert pos.ndim == 2
    assert pos.shape[1] == 2
    n_points = pos.shape[0]

    # Validate the data.
    color = _get_array(
        color, (n_points, 4), ScatterVisual._default_color, dtype=np.float32)
    size = _get_array(
        size, (n_points, 1), ScatterVisual._default_marker_size)
    depth = _get_array(depth, (n_points, 1), 0)

    if data_bounds is not None:
        data_bounds = _get_data_bounds(data_bounds, pos)
        assert data_bounds.shape[0] == n_points

    out = Bunch(pos=pos, color=color, size=size, depth=depth)
    out.data_bounds = data_bounds
    return out
def state(self):
    """Return the view parameters as a Bunch."""
    out = Bunch()
    out.speed_threshold = self.speed_threshold
    out.speed_threshold_mode = self.speed_threshold_mode
    # NOTE(review): the key is singular but the attribute is plural --
    # confirm this is intended.
    out.time_range = self.time_ranges
    out.n_rate_map_contours = self.n_rate_map_contours
    out.rate_map_contour_mode = self.rate_map_contour_mode
    out.spike_pos_shift = self.spike_pos_shift
    return out
def validate(x=None, y=None, pos=None, masks=None, data_bounds='auto'):
    """Check and normalize the positions, masks and data bounds.

    Positions come either from `pos` or from `x` and `y`. Masks default to
    1 and have shape `(n, 1)`. Returns a Bunch with `pos`, `masks` and
    `data_bounds` (`data_bounds` stays None when passed as None).
    """
    if pos is None:
        x, y = _get_pos(x, y)
        pos = np.c_[x, y]
    pos = np.asarray(pos)
    assert pos.ndim == 2
    assert pos.shape[1] == 2
    n_points = pos.shape[0]

    masks = _get_array(masks, (n_points, 1), 1., np.float32)
    assert masks.shape == (n_points, 1)

    # Validate the data.
    if data_bounds is not None:
        data_bounds = _get_data_bounds(data_bounds, pos)
        assert data_bounds.shape[0] == n_points

    out = Bunch(pos=pos, masks=masks)
    out.data_bounds = data_bounds
    return out
def validate(hist=None, color=None, ylim=None):
    """Check and normalize the histogram inputs.

    A 1D `hist` is treated as a single histogram. `ylim` defaults to the
    maximum of `hist` (or 1 when `hist` is empty) and ends up with shape
    `(n_hists, 1)`. Returns a Bunch with `hist`, `ylim` and `color`.
    """
    assert hist is not None
    hist = np.asarray(hist, np.float64)
    if hist.ndim == 1:
        hist = hist[np.newaxis, :]
    assert hist.ndim == 2
    n_hists, _ = hist.shape

    # Validate the data.
    color = _get_array(
        color, (n_hists, 4), HistogramVisual._default_color,
        dtype=np.float32)

    # Validate ylim.
    if ylim is None:
        ylim = hist.max() if hist.size > 0 else 1.
    ylim = np.atleast_1d(ylim)
    # A single limit is shared by all histograms.
    if len(ylim) == 1:
        ylim = np.tile(ylim, n_hists)
    if ylim.ndim == 1:
        ylim = ylim[:, None]
    assert ylim.shape == (n_hists, 1)

    out = Bunch(hist=hist, ylim=ylim)
    out.color = color
    return out
def _get_template_waveforms(self, cluster_id):
    """Return the waveforms of the templates corresponding to a cluster.

    Each template is scaled by the mean amplitude of the cluster; the masks
    are the template counts divided by the largest count, repeated for
    every best channel.
    """
    positions = self.model.channel_positions

    # Keep the templates with a nonzero count for this cluster.
    counts = self.get_template_counts(cluster_id)
    template_ids = np.nonzero(counts)[0]
    counts = counts[template_ids]

    # Get local channels.
    channel_ids = self.get_best_channels(cluster_id)

    # Get masks.
    relative_counts = counts / float(counts.max())
    masks = np.tile(
        relative_counts.reshape((-1, 1)), (1, len(channel_ids)))

    # Get the mean amplitude for the cluster.
    mean_amp = self._get_amplitudes(cluster_id).y.mean()

    # Get all templates from which this cluster stems.
    templates = []
    for template_id in template_ids:
        templates.append(self.model.get_template(template_id))
    data = np.stack([tmp.template * mean_amp for tmp in templates], axis=0)
    cols = np.stack([tmp.channel_ids for tmp in templates], axis=0)

    # NOTE: transposition because the channels should be in the second
    # dimension for from_sparse.
    data = data.transpose((0, 2, 1))
    assert data.ndim == 3
    assert data.shape[1] == cols.shape[1]
    waveforms = from_sparse(data, cols, channel_ids)
    # Transpose back.
    waveforms = waveforms.transpose((0, 2, 1))

    out = Bunch(data=waveforms, channel_ids=channel_ids)
    out.channel_positions = positions[channel_ids]
    out.masks = masks
    out.alpha = 1.
    return out
def _get_template_features(self, cluster_ids):
    """Return the template-feature coordinates of the spikes of two clusters.

    `cluster_ids` must contain exactly two cluster ids. For each cluster,
    the template features of its spikes are averaged along axis 1, weighted
    by the template counts of the first cluster (x coordinate) and of the
    second cluster (y coordinate).

    Returns a Bunch with `x0`, `y0` (first cluster), `x1`, `y1` (second
    cluster) and `data_bounds = (xmin, ymin, xmax, ymax)` computed over
    both clusters.
    """
    assert len(cluster_ids) == 2
    clu0, clu1 = cluster_ids

    s0 = self._get_spike_ids(clu0)
    s1 = self._get_spike_ids(clu1)

    n0 = self.get_template_counts(clu0)
    n1 = self.get_template_counts(clu1)

    t0 = self.model.get_template_features(s0)
    t1 = self.model.get_template_features(s1)

    x0 = np.average(t0, weights=n0, axis=1)
    y0 = np.average(t0, weights=n1, axis=1)

    x1 = np.average(t1, weights=n0, axis=1)
    y1 = np.average(t1, weights=n1, axis=1)

    # The upper x bound must come from the x coordinates; it was previously
    # computed from the y coordinates.
    data_bounds = (
        min(x0.min(), x1.min()),
        min(y0.min(), y1.min()),
        max(x0.max(), x1.max()),
        max(y0.max(), y1.max()),
    )
    return Bunch(x0=x0, y0=y0, x1=x1, y1=y1, data_bounds=data_bounds)
def _get_axis_data(self, bunch, dim, cluster_id=None, load_all=None):
    """Extract the points from the data on a given dimension.

    bunch is returned by the features() function.
    dim is the string specifying the dimensions to extract for the data.

    A `dim` found in `self.attributes` is delegated to the corresponding
    callable. Otherwise `dim` is a relative channel index followed by a
    letter (A=0, B=1, etc.) selecting the principal component. Returns None
    when the channel is not in `bunch.channel_ids`.
    """
    if dim in self.attributes:
        return self.attributes[dim](cluster_id, load_all=load_all)

    # From here on, this is called only on PC data.
    masks = bunch.get('masks', None)
    letters = 'ABCDEFGHIJ'

    # Channel relative index.
    channel_rel = int(dim[:-1])
    # Get the channel_id from the currently-selected channels.
    channel_id = self.channel_ids[channel_rel % len(self.channel_ids)]

    # Skip the plot if the channel id is not displayed.
    if channel_id not in bunch.channel_ids:  # pragma: no cover
        return None

    # Get the column index of the current channel in data.
    col = list(bunch.channel_ids).index(channel_id)
    # Principal component: A=0, B=1, etc.
    pc = letters.index(dim[-1])

    channel_masks = masks[:, col] if masks is not None else masks
    return Bunch(data=bunch.data[:, col, pc], masks=channel_masks)
def state(self):
    """Return the view parameters as a Bunch."""
    out = Bunch()
    out.scaling = self.scaling
    out.origin = self.origin
    out.interval = self._interval
    out.do_show_labels = self.do_show_labels
    return out
def waveform_loader(request):
    """Yield a Bunch with a WaveformLoader on artificial traces.

    `request.param` provides `(scale_factor, dc_offset)`. The Bunch holds
    the loader, the number of samples per waveform, the number of spikes and
    the spike samples.
    """
    scale_factor, dc_offset = request.param

    n_samples_trace = 1000
    n_channels = 5
    h = 10
    n_samples_waveforms = 2 * h
    n_spikes = n_samples_trace // (2 * n_samples_waveforms)

    traces = artificial_traces(n_samples_trace, n_channels)
    spike_samples = artificial_spike_samples(
        n_spikes, max_isi=2 * n_samples_waveforms)

    # Creating a loader from the traces alone must fail.
    with raises(ValueError):
        WaveformLoader(traces)

    loader = WaveformLoader(
        traces=traces,
        n_samples_waveforms=n_samples_waveforms,
        scale_factor=scale_factor,
        dc_offset=dc_offset,
    )

    yield Bunch(
        loader=loader,
        n_samples_waveforms=n_samples_waveforms,
        n_spikes=n_spikes,
        spike_samples=spike_samples,
    )
def state(self):
    """Return the view parameters as a Bunch."""
    out = Bunch()
    out.bin_size = self.bin_size
    out.window_size = self.window_size
    out.excerpt_size = self.excerpt_size
    out.n_excerpts = self.n_excerpts
    out.uniform_normalization = self.uniform_normalization
    return out
def _get_amplitudes(self, cluster_id):
    """Return spike times (`x`) and best-channel amplitudes (`y`).

    The amplitudes are read on the channel given by `get_best_channel()`.
    The data bounds are `(0, min amplitude, model.duration, max amplitude)`.
    """
    n_max = self.n_spikes_amplitudes
    model = self.model
    ids = self.selector.select_spikes([cluster_id], n_max)
    best_channel = self.get_best_channel(cluster_id)
    times = model.spike_times[ids]
    amplitudes = model.amplitudes[ids, best_channel]
    bounds = (0., amplitudes.min(), model.duration, amplitudes.max())
    return Bunch(x=times, y=amplitudes, data_bounds=bounds)
def state(self):
    """View state.

    This Bunch will be automatically persisted in the GUI state when the
    GUI is closed. To be overridden; this implementation returns an empty
    Bunch.

    """
    empty_state = Bunch()
    return empty_state
def get_traces(interval):
    """Return the traces in an interval and one waveform per spike in it.

    Every waveform covers 20 samples on each side of the spike sample and
    carries its start time, color, channel ids and cluster id.
    """
    half_width = 20

    def _waveform(i):
        spike_time = st[i]
        cluster = sc[i]
        sample = int(round(spike_time * sr))
        return Bunch(
            data=traces[sample - half_width:sample + half_width, :],
            start_time=spike_time - half_width / sr,
            color=cs.get(cluster),
            channel_ids=np.arange(5),
            cluster_id=cluster,
        )

    out = Bunch(
        data=select_traces(traces, interval, sample_rate=sr),
        color=(.75, .75, .75, .75),
    )
    first, last = st.searchsorted(interval)
    out.waveforms = [_waveform(i) for i in range(first, last)]
    return out
def get_cluster_pair_features(self, ci, cj):
    """Return the template-feature coordinates of two clusters' spikes.

    For each cluster, the template features of its spikes are averaged
    along axis 1, weighted by the templates of `ci` (x) and of `cj` (y).
    Returns a list of two Bunch objects with `x`, `y` and `spike_ids`.
    """
    spikes_i = self._select_spikes(ci, self.n_spikes_features)
    spikes_j = self._select_spikes(cj, self.n_spikes_features)

    weights_i = self.get_cluster_templates(ci)
    weights_j = self.get_cluster_templates(cj)

    out = []
    for spike_ids in (spikes_i, spikes_j):
        features = self._get_template_features(spike_ids)
        x = np.average(features, weights=weights_i, axis=1)
        y = np.average(features, weights=weights_j, axis=1)
        out.append(Bunch(x=x, y=y, spike_ids=spike_ids))
    return out
def get_traces(interval):
    """Return the traces in an interval and one waveform per spike in it.

    Every waveform covers 20 samples on each side of the spike sample and
    carries its start time, color, channel ids, spike id and spike cluster.
    """
    half_width = 20

    def _waveform(i):
        spike_time = st[i]
        cluster = sc[i]
        sample = int(round(spike_time * sr))
        return Bunch(
            data=traces[sample - half_width:sample + half_width, :],
            start_time=spike_time - half_width / sr,
            color=cs.get(cluster),
            channel_ids=np.arange(5),
            spike_id=i,
            spike_cluster=cluster,
        )

    out = Bunch(
        data=select_traces(traces, interval, sample_rate=sr),
        color=(.75, .75, .75, .75),
    )
    first, last = st.searchsorted(interval)
    out.waveforms = [_waveform(i) for i in range(first, last)]
    return out
def get_waveforms(self, cluster_id):
    """Return the waveforms (when available) and templates of a cluster.

    Returns a list of Bunch objects: the spike waveforms tagged
    'waveforms' (only when `self.all_waveforms` is not None), followed by
    the cluster's templates tagged 'templates'.
    """
    lim_min, lim_max = self.get_waveform_lims()
    out = []

    if self.all_waveforms is not None:
        # Waveforms.
        waveforms_b = self._select_data(
            cluster_id, self.all_waveforms, self.n_spikes_waveforms)
        w = waveforms_b.data
        # Sparsify: keep the channels with a nonzero mean.
        channels = np.nonzero(w.mean(axis=1).mean(axis=0))[0]
        w = w[:, :, channels]
        waveforms_b.channels = channels
        # Normalize: remove the mean of every spike, then rescale.
        spike_means = w.mean(axis=1).mean(axis=1)
        w = w.astype(np.float64)
        w -= spike_means[:, np.newaxis, np.newaxis]
        waveforms_b.data = _normalize(w, lim_min, lim_max)
        waveforms_b.cluster_id = cluster_id
        waveforms_b.tag = 'waveforms'
        out.append(waveforms_b)

    # Find the templates corresponding to the cluster.
    template_ids = np.nonzero(self.get_cluster_templates(cluster_id))[0]

    # Templates.
    templates = self.templates_unw[template_ids]
    assert templates.ndim == 3

    # Masks.
    masks = self.template_masks[template_ids]
    assert masks.ndim == 2
    assert templates.shape[0] == masks.shape[0]

    # Find mean amplitude.
    spike_ids = self._select_spikes(cluster_id, self.n_spikes_waveforms_lim)
    mean_amp = self.all_amplitudes[spike_ids].mean()

    # Normalize: scale by the mean amplitude and by the waveform limits.
    templates = templates.astype(np.float64).copy()
    templates *= mean_amp
    templates *= 2. / (lim_max - lim_min)

    out.append(Bunch(
        data=templates,
        masks=masks,
        alpha=1.,
        cluster_id=cluster_id,
        tag='templates',
    ))
    return out
def get_cluster_pair_features(self, ci, cj):
    """Return the template-feature coordinates of two clusters' spikes.

    For each cluster, the template features of its spikes are combined
    along axis 1 as a weighted mean, using the templates of `ci` (x) and of
    `cj` (y) as weights. The spikes of both clusters are concatenated in a
    single Bunch with `x`, `y`, `spike_ids` and `spike_clusters`.
    """
    def _weighted_mean(features, weights):
        # Weighted mean over the second axis.
        return np.sum(features * weights[np.newaxis, :], axis=1) / \
            weights.sum()

    spikes_i = self._select_spikes(ci, self.n_spikes_features)
    spikes_j = self._select_spikes(cj, self.n_spikes_features)

    weights_i = self.get_cluster_templates(ci)
    weights_j = self.get_cluster_templates(cj)

    features_i = self._get_template_features(spikes_i)
    x0 = _weighted_mean(features_i, weights_i)
    y0 = _weighted_mean(features_i, weights_j)

    features_j = self._get_template_features(spikes_j)
    x1 = _weighted_mean(features_j, weights_i)
    y1 = _weighted_mean(features_j, weights_j)

    out = Bunch(
        x=np.hstack((x0, x1)),
        y=np.hstack((y0, y1)),
        spike_ids=np.hstack((spikes_i, spikes_j)),
    )
    out.spike_clusters = self.spike_clusters[out.spike_ids]
    return out
def get_features(self, cluster_id, load_all=False):
    """Return the normalized, densified features of a cluster's spikes.

    Overridden to take into account the sparse structure. The returned
    Bunch has `data`, `spike_ids`, `spike_clusters` and `masks`.
    """
    subset = self.features_spike_ids
    if subset is None:
        # No restriction: absolute and relative spike ids coincide.
        n_max = None if load_all else self.n_spikes_features
        spike_ids = self._select_spikes(cluster_id, n_max)
        spike_ids_rel = spike_ids
    else:
        # Only keep spikes belonging to the features spike ids.
        spike_ids = np.intersect1d(self._select_spikes(cluster_id), subset)
        # Relative indices of the spikes in the features_spike_ids array,
        # necessary to load features from all_features which only
        # contains the subset of the spikes.
        spike_ids_rel = _index_of(spike_ids, subset)

    template_ids = self.spike_templates[spike_ids]
    n_channels = self.n_channels
    n_pcs = self.n_features_per_channel
    n_spikes = len(spike_ids)

    dense = _densify(
        spike_ids_rel,
        self.all_features,
        self.features_ind[template_ids, :],
        self.n_channels,
    )
    dense = np.transpose(dense, (0, 2, 1))
    assert dense.shape == (n_spikes, n_channels, n_pcs)

    # Normalize features.
    lim = self.get_feature_lim()
    return Bunch(
        data=_normalize(dense, -lim, lim),
        spike_ids=spike_ids,
        spike_clusters=self.spike_clusters[spike_ids],
        masks=self.all_masks[spike_ids],
    )
def get_background_features(self):
    """Return a regularly strided, normalized subset of all features.

    Note that `spike_ids` in the returned Bunch is a `slice` object, not an
    array of indices.
    """
    # Stride such that about n_spikes_background_features spikes are kept.
    step = max(1, int(self.n_spikes // self.n_spikes_background_features))
    strided = slice(None, None, step)

    features = self.all_features[strided]
    lim = self.get_feature_lim()
    # Normalize a copy of the selected features.
    return Bunch(
        data=_normalize(features.copy(), -lim, +lim),
        spike_ids=strided,
        spike_clusters=self.spike_clusters[strided],
        masks=self.all_masks[strided],
    )
def _get_waveforms(self, cluster_id):
    """Return a selection of waveforms for a cluster.

    The waveforms are loaded on the best channels of the cluster and their
    global mean is subtracted.
    """
    positions = self.model.channel_positions
    spike_ids = self.selector.select_spikes(
        [cluster_id],
        self.n_spikes_waveforms,
        self.batch_size_waveforms,
    )
    channel_ids = self.get_best_channels(cluster_id)

    waveforms = self.model.get_waveforms(spike_ids, channel_ids)
    waveforms = waveforms - waveforms.mean()

    out = Bunch(data=waveforms, channel_ids=channel_ids)
    out.channel_positions = positions[channel_ids]
    return out
def extract_spikes(traces, interval, sample_rate=None, spike_times=None,
                   spike_clusters=None, all_masks=None,
                   n_samples_waveforms=None):
    """Extract the waveforms of the spikes falling in an interval.

    `n_samples_waveforms` is either a tuple, or a single number that is
    split in two equal halves. Returns a list of Bunch objects with
    `waveforms`, `channels`, `masks`, `spike_time`, `spike_cluster` and
    `offset_samples`. Spikes for which `_extract_wave()` returns None are
    skipped.
    """
    sr = sample_rate
    ns = n_samples_waveforms
    if not isinstance(ns, tuple):
        ns = (ns // 2, ns // 2)
    offset_samples = ns[0]
    wave_len = ns[0] + ns[1]

    # Find spikes.
    first, last = spike_times.searchsorted(interval)
    st = spike_times[first:last]
    sc = spike_clusters[first:last]
    masks = all_masks[first:last]
    n = len(st)
    assert len(sc) == n
    assert masks.shape[0] == n

    # Extract waveforms.
    spikes = []
    for i, spike_time in enumerate(st):
        # Find the start of the waveform in the extracted traces.
        sample_start = int(round((spike_time - interval[0]) * sr))
        sample_start = sample_start - offset_samples
        extracted = _extract_wave(traces, sample_start, masks[i], wave_len)
        if extracted is None:  # pragma: no cover
            logger.debug("Unable to extract spike %d.", i)
            continue
        waveforms, channels = extracted
        spikes.append(Bunch(
            waveforms=waveforms,
            channels=channels,
            # Masks on unmasked channels.
            masks=masks[i, channels],
            spike_time=spike_time,
            spike_cluster=sc[i],
            offset_samples=offset_samples,
        ))
    return spikes