class MyModel(object):
    """Synthetic stand-in for a spike-sorting model, used by the view tests.

    Every data attribute is generated once, at class-definition time, from the
    ``artificial_*`` helpers after seeding NumPy's global RNG with 0, so all
    instances share the same arrays.
    """
    # ``np.random.seed`` returns None; the assignment is kept so the global RNG
    # is seeded before the random class attributes below are generated.
    seed = np.random.seed(0)
    n_channels = 8
    n_spikes = 20000
    n_clusters = 32
    n_templates = n_clusters
    n_pcs = 5
    n_samples_waveforms = 100
    channel_positions = np.random.normal(size=(n_channels, 2))
    channel_mapping = np.arange(0, n_channels)
    channel_shanks = np.zeros(n_channels, dtype=np.int32)
    features = artificial_features(n_spikes, n_channels, n_pcs)
    metadata = {'group': {3: 'noise', 4: 'mua', 5: 'good'}}
    sample_rate = 10000
    spike_attributes = {}
    amplitudes = np.random.normal(size=n_spikes, loc=1, scale=.1)
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)
    spike_templates = spike_clusters
    spike_samples = artificial_spike_samples(n_spikes)
    spike_times = spike_samples / sample_rate
    spike_times_reordered = artificial_spike_samples(n_spikes) / sample_rate
    # The recording ends at the last spike; traces are sized to match.
    duration = spike_times[-1]
    spike_waveforms = None
    traces = artificial_traces(int(sample_rate * duration), n_channels)

    def _get_some_channels(self, offset, size):
        """Return ``size`` consecutive channel ids starting at ``offset``.

        Channel ids wrap around modulo ``self.n_channels``.
        """
        return list(
            islice(cycle(range(self.n_channels)), offset, offset + size))

    def get_features(self, spike_ids, channel_ids):
        """Return fresh artificial features for the requested spikes/channels.

        Only the lengths of ``spike_ids`` and ``channel_ids`` are used.
        """
        return artificial_features(len(spike_ids), len(channel_ids), self.n_pcs)

    def get_waveforms(self, spike_ids, channel_ids):
        """Return artificial waveforms for ``spike_ids``.

        ``channel_ids`` may be None, an empty sequence (both mean "all
        channels"), or any sized sequence including a NumPy array. An explicit
        length check is used instead of truthiness: ``bool()`` on a
        multi-element ndarray raises ``ValueError``, and ``np.array([0])`` is
        falsy.
        """
        if channel_ids is None or len(channel_ids) == 0:
            n_channels = self.n_channels
        else:
            n_channels = len(channel_ids)
        return artificial_waveforms(
            len(spike_ids), self.n_samples_waveforms, n_channels)

    def get_template(self, template_id):
        """Return a Bunch with an artificial template on half of the channels.

        The channel ids are chosen by ``_get_some_channels`` starting at
        ``template_id``.
        """
        nc = self.n_channels // 2
        return Bunch(
            template=artificial_waveforms(1, self.n_samples_waveforms, nc)[0, ...],
            channel_ids=self._get_some_channels(template_id, nc),
        )

    def save_spike_clusters(self, spike_clusters):
        """No-op: the fake model does not persist anything."""
        pass

    def save_metadata(self, name, values):
        """No-op: the fake model does not persist anything."""
        pass
def waveform_loader(do_filter=False, mask_threshold=None):
    """Build a ``WaveformLoader`` over a small artificial 5-channel recording.

    Parameters
    ----------
    do_filter : bool
        When true, the loader is created with ``filter_order=3``; otherwise
        ``filter_order=None``.
    mask_threshold : optional
        NOTE(review): accepted but never used by this helper — confirm whether
        it should be forwarded to ``WaveformLoader``.

    Returns
    -------
    WaveformLoader
        Loader with 20-sample waveforms at a 2000 Hz sample rate.
    """
    half_width = 10
    waveform_length = half_width * 2
    trace_length = 1000
    channel_count = 5
    # Spikes are generated with max_isi = 2 * waveform_length.
    spike_count = trace_length // (waveform_length * 2)

    # Keep the RNG order: traces first, then spike samples.
    fake_traces = artificial_traces(trace_length, channel_count)
    fake_spikes = artificial_spike_samples(
        spike_count, max_isi=waveform_length * 2)

    if do_filter:
        order = 3
    else:
        order = None

    return WaveformLoader(
        traces=fake_traces,
        spike_samples=fake_spikes,
        n_samples_waveforms=waveform_length,
        filter_order=order,
        sample_rate=2000.,
    )
def test_iter_spike_waveforms():
    """Every item yielded by ``_iter_spike_waveforms`` must be truthy.

    NOTE(review): if the iterator yields nothing, the loop below never runs and
    the test passes vacuously — consider also asserting a non-zero count.
    """
    # The same count is used for clusters and channels.
    n_clusters = n_channels = 5
    n_spikes = 20
    sample_rate = 2000.
    duration = 1.
    best_channels = list(range(n_channels))

    spike_times = np.linspace(0.1, .9, n_spikes)
    # RNG order preserved: spike clusters are drawn before the traces.
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)
    n_trace_samples = int(round(duration * sample_rate))
    traces = 10 * artificial_traces(n_trace_samples, n_channels)

    model = Bunch(
        spike_times=spike_times,
        spike_clusters=spike_clusters,
        sample_rate=sample_rate,
    )
    supervisor = Bunch(cluster_meta={}, selected=[0])

    def get_best_channels(cluster_id):
        # Every cluster maps to the same list of all channels.
        return best_channels

    waveforms = _iter_spike_waveforms(
        interval=[0., 1.],
        traces_interval=traces,
        model=model,
        supervisor=supervisor,
        # The spike count (20) is reused as the waveform length.
        n_samples_waveforms=n_spikes,
        show_all_spikes=True,
        get_best_channels=get_best_channels,
    )
    for waveform in waveforms:
        assert waveform
def test_trace_view_1(qtbot, tempdir, gui):
    """Exercise TraceView navigation, scaling and mouse selection end to end.

    Builds a 5-channel, 1-second artificial recording with 20 spikes and
    attaches a TraceView to ``gui``. It then checks:

    - time/interval bookkeeping;
    - stacked box sizes and auto-scaling;
    - origin switching;
    - the ``select_spike`` / ``select_channel`` events emitted by Control- and
      Shift-clicks.
    """
    nc = 5
    ns = 20
    sr = 2000.
    duration = 1.
    st = np.linspace(0.1, .9, ns)
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)

    def get_traces(interval):
        """Return the traces in ``interval`` plus one waveform per spike in it."""
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, .75, .75, 1),
        )
        a, b = st.searchsorted(interval)
        out.waveforms = []
        k = 20  # Half-width, in samples, of each spike's waveform window.
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            s = int(round(t * sr))
            d = Bunch(
                data=traces[s - k:s + k, :],
                start_time=(s - k) / sr,
                channel_ids=np.arange(nc),
                spike_id=i,
                spike_cluster=c,
                select_index=0,
            )
            out.waveforms.append(d)
        return out

    def get_spike_times():
        return st

    v = TraceView(
        traces=get_traces,
        spike_times=get_spike_times,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_positions=linear_positions(nc),
    )
    v.show()
    qtbot.waitForWindowShown(v.canvas)
    v.attach(gui)

    v.on_select(cluster_ids=[])
    v.on_select(cluster_ids=[0])
    v.on_select(cluster_ids=[0, 2, 3])
    v.on_select(cluster_ids=[0, 2])

    v.stacked.add_boxes(v.canvas)
    ac(v.stacked.box_size, (.950, .165), atol=1e-3)

    v.set_interval((.375, .625))
    assert v.time == .5
    qtbot.wait(1)

    v.go_to(.25)
    assert v.time == .25
    qtbot.wait(1)

    v.go_to(-.5)
    assert v.time == .125
    qtbot.wait(1)

    v.go_left()
    assert v.time == .125
    qtbot.wait(1)

    v.go_right()
    ac(v.time, .150)
    qtbot.wait(1)

    v.jump_left()
    qtbot.wait(1)

    v.jump_right()
    qtbot.wait(1)

    v.go_to_next_spike()
    qtbot.wait(1)

    v.go_to_previous_spike()
    qtbot.wait(1)

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    qtbot.wait(1)

    v.widen()
    ac(v.interval, (.1875, .8125))
    qtbot.wait(1)

    v.narrow()
    ac(v.interval, (.25, .75))
    qtbot.wait(1)

    v.go_to_start()
    qtbot.wait(1)
    assert v.interval[0] == 0

    v.go_to_end()
    qtbot.wait(1)
    assert v.interval[1] == duration

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()
    qtbot.wait(1)

    v.toggle_show_labels(True)
    v.go_right()

    # Check auto scaling.
    db = v.data_bounds
    v.toggle_auto_scale(False)
    v.narrow()
    qtbot.wait(1)
    # Check that ymin and ymax have not changed.
    assert v.data_bounds[1] == db[1]
    assert v.data_bounds[3] == db[3]

    v.toggle_auto_update(True)
    assert v.do_show_labels
    qtbot.wait(1)

    v.toggle_highlighted_spikes(True)
    qtbot.wait(50)

    # Change channel scaling.
    bs = v.stacked.box_size
    v.decrease()
    qtbot.wait(1)

    v.increase()
    ac(v.stacked.box_size, bs, atol=.05)
    qtbot.wait(1)

    v.origin = 'bottom'
    v.switch_origin()
    assert v.origin == 'top'
    qtbot.wait(1)

    # Simulate spike selection. Each event records into its own list: a
    # closure captures the *variable*, so rebinding one shared name between
    # the two handlers would redirect this handler's later appends into the
    # channel-click list.
    _clicked_spikes = []

    @connect(sender=v)
    def on_select_spike(sender, channel_id=None, spike_id=None, cluster_id=None, key=None):
        _clicked_spikes.append((channel_id, spike_id, cluster_id))

    mouse_click(qtbot, v.canvas, pos=(0., 0.), button='Left', modifiers=('Control',))

    v.set_state(v.state)

    assert _clicked_spikes, "Control-click did not emit select_spike"
    assert len(_clicked_spikes[0]) == 3

    # Simulate channel selection.
    _clicked_channels = []

    @connect(sender=v)
    def on_select_channel(sender, channel_id=None, button=None):
        _clicked_channels.append((channel_id, button))

    mouse_click(qtbot, v.canvas, pos=(0., 0.), button='Left', modifiers=('Shift',))
    mouse_click(qtbot, v.canvas, pos=(0., 0.), button='Right', modifiers=('Shift',))
    # Position (0, 0) is expected to land on channel 2 (the middle of 5).
    assert _clicked_channels == [(2, 'Left'), (2, 'Right')]

    _stop_and_close(qtbot, v)
def test_trace_image_view_1(qtbot, tempdir, gui):
    """Drive a TraceImageView through navigation, zoom and scaling actions.

    Uses 350 artificial channels over one second at 2000 Hz. It checks the
    view's ``time`` and ``interval`` after each navigation step, then exercises
    auto-update, channel scaling and origin switching.
    """
    n_channels = 350
    sample_rate = 2000.
    duration = 1.
    n_samples = int(round(duration * sample_rate))
    all_traces = 10 * artificial_traces(n_samples, n_channels)

    def get_traces(interval):
        # Only the raw data slice and a fixed grey colour; no spike waveforms.
        sliced = select_traces(all_traces, interval, sample_rate=sample_rate)
        return Bunch(data=sliced, color=(.75, .75, .75, 1))

    def pause():
        # Short wait between steps.
        qtbot.wait(1)

    view = TraceImageView(
        traces=get_traces,
        n_channels=n_channels,
        sample_rate=sample_rate,
        duration=duration,
        channel_positions=linear_positions(n_channels),
    )
    view.show()
    qtbot.waitForWindowShown(view.canvas)
    view.attach(gui)
    view.update_color()

    # Navigation.
    view.set_interval((.375, .625))
    assert view.time == .5
    pause()
    view.go_to(.25)
    assert view.time == .25
    pause()
    view.go_to(-.5)
    assert view.time == .125
    pause()
    view.go_left()
    assert view.time == .125
    pause()
    view.go_right()
    ac(view.time, .150)
    pause()
    view.jump_left()
    pause()
    view.jump_right()
    pause()

    # Interval size.
    view.interval = (.25, .75)
    ac(view.interval, (.25, .75))
    pause()
    view.widen()
    ac(view.interval, (.1875, .8125))
    pause()
    view.narrow()
    ac(view.interval, (.25, .75))
    pause()
    view.go_to_start()
    pause()
    assert view.interval[0] == 0
    view.go_to_end()
    pause()
    assert view.interval[1] == duration

    # Widening when already showing the whole recording.
    view.set_interval((0, duration))
    view.widen()
    pause()

    view.toggle_auto_update(True)
    assert view.do_show_labels
    pause()

    # Channel scaling.
    view.decrease()
    pause()
    view.increase()
    pause()

    view.origin = 'bottom'
    view.switch_origin()
    # NOTE(review): the `origin == 'top'` check that TraceView's test makes is
    # disabled for TraceImageView; confirm whether that is intentional.
    pause()

    _stop_and_close(qtbot, view)