Ejemplo n.º 1
0
class MyModel(object):
    """Artificial in-memory model exposing synthetic spike-sorting data.

    Every class attribute is computed once, when the class body executes.
    The first statement seeds NumPy's global random generator and several
    later attributes draw from it via `np.random.normal` (the `artificial_*`
    helpers presumably do as well -- confirm against their definitions), so
    the order of the attribute definitions should not be changed.
    """

    # `np.random.seed` returns None: this line exists for its side effect of
    # seeding the global generator, and `seed` itself ends up being None.
    seed = np.random.seed(0)

    # Sizes of the artificial dataset.
    n_channels = 8
    n_spikes = 20000
    n_clusters = 32
    n_templates = n_clusters  # one template per cluster
    n_pcs = 5
    n_samples_waveforms = 100

    # Per-channel data: random 2D positions, identity mapping, a single shank.
    channel_positions = np.random.normal(size=(n_channels, 2))
    channel_mapping = np.arange(0, n_channels)
    channel_shanks = np.zeros(n_channels, dtype=np.int32)

    features = artificial_features(n_spikes, n_channels, n_pcs)
    metadata = {'group': {3: 'noise', 4: 'mua', 5: 'good'}}
    sample_rate = 10000
    spike_attributes = {}

    # Per-spike data.
    amplitudes = np.random.normal(size=n_spikes, loc=1, scale=.1)
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)
    spike_templates = spike_clusters  # same object as `spike_clusters`
    spike_samples = artificial_spike_samples(n_spikes)
    spike_times = spike_samples / sample_rate
    # Built from a second, separate call to `artificial_spike_samples`.
    spike_times_reordered = artificial_spike_samples(n_spikes) / sample_rate

    # The duration is taken to be the time of the last spike.
    duration = spike_times[-1]
    spike_waveforms = None
    traces = artificial_traces(int(sample_rate * duration), n_channels)

    def _get_some_channels(self, offset, size):
        """Return `size` channel indices, cycling over all channels.

        The indices are taken from an endless repetition of
        `range(self.n_channels)`, starting at position `offset`.
        """
        return list(
            islice(cycle(range(self.n_channels)), offset, offset + size))

    def get_features(self, spike_ids, channel_ids):
        """Return artificial features sized by the requested spikes/channels.

        Only the lengths of `spike_ids` and `channel_ids` are used.
        """
        return artificial_features(len(spike_ids), len(channel_ids),
                                   self.n_pcs)

    def get_waveforms(self, spike_ids, channel_ids):
        """Return artificial waveforms for the requested spikes and channels.

        Parameters
        ----------
        spike_ids : sequence
            Only its length is used.
        channel_ids : sequence or None
            Only its length is used. When it is None or empty, all
            `self.n_channels` channels are used instead.
        """
        # An explicit None/length check is used instead of `if channel_ids`:
        # truthiness raises ValueError for NumPy arrays with more than one
        # element, and a one-element array such as `np.array([0])` is falsy.
        if channel_ids is not None and len(channel_ids) > 0:
            n_channels = len(channel_ids)
        else:
            n_channels = self.n_channels
        return artificial_waveforms(len(spike_ids), self.n_samples_waveforms,
                                    n_channels)

    def get_template(self, template_id):
        """Return a `Bunch` with an artificial template and its channels.

        The template spans half of the channels (`self.n_channels // 2`);
        the channel ids are chosen by `_get_some_channels`, using
        `template_id` as the offset.
        """
        nc = self.n_channels // 2
        return Bunch(template=artificial_waveforms(1, self.n_samples_waveforms,
                                                   nc)[0, ...],
                     channel_ids=self._get_some_channels(template_id, nc))

    def save_spike_clusters(self, spike_clusters):
        """Do nothing: this artificial model does not persist clusters."""
        pass

    def save_metadata(self, name, values):
        """Do nothing: this artificial model does not persist metadata."""
        pass
Ejemplo n.º 2
0
def waveform_loader(do_filter=False, mask_threshold=None):
    """Build a `WaveformLoader` on top of small artificial traces.

    Parameters
    ----------
    do_filter : bool
        When truthy, the loader is created with `filter_order=3`;
        otherwise `filter_order` is None.
    mask_threshold : object
        Accepted but not used by this function.

    Returns
    -------
    The `WaveformLoader` instance.
    """
    n_channels = 5
    n_samples_trace = 1000
    half_width = 10
    n_samples_waveforms = 2 * half_width

    # The same quantity serves as the maximum inter-spike interval and as
    # the divisor that determines how many spikes are generated.
    max_isi = 2 * n_samples_waveforms
    n_spikes = n_samples_trace // max_isi

    traces = artificial_traces(n_samples_trace, n_channels)
    spike_samples = artificial_spike_samples(n_spikes, max_isi=max_isi)

    if do_filter:
        filter_order = 3
    else:
        filter_order = None

    return WaveformLoader(
        traces=traces,
        spike_samples=spike_samples,
        n_samples_waveforms=n_samples_waveforms,
        filter_order=filter_order,
        sample_rate=2000.,
    )
Ejemplo n.º 3
0
def test_iter_spike_waveforms():
    """Check that every item yielded by `_iter_spike_waveforms` is truthy."""
    n_channels = 5
    n_spikes = 20
    sample_rate = 2000.
    duration = 1.

    channel_ids = list(range(n_channels))
    spike_times = np.linspace(0.1, .9, n_spikes)
    spike_clusters = artificial_spike_clusters(n_spikes, n_channels)
    n_samples = int(round(duration * sample_rate))
    traces = 10 * artificial_traces(n_samples, n_channels)

    # Minimal stand-ins for the model and the supervisor.
    model = Bunch(
        spike_times=spike_times,
        spike_clusters=spike_clusters,
        sample_rate=sample_rate,
    )
    supervisor = Bunch(cluster_meta={}, selected=[0])

    def best_channels(cluster_id):
        # Every cluster is mapped to the full list of channels.
        return channel_ids

    waveforms = _iter_spike_waveforms(
        interval=[0., 1.],
        traces_interval=traces,
        model=model,
        supervisor=supervisor,
        # The number of spikes (20) doubles as the waveform length here.
        n_samples_waveforms=n_spikes,
        show_all_spikes=True,
        get_best_channels=best_channels,
    )
    for waveform in waveforms:
        assert waveform
Ejemplo n.º 4
0
def test_trace_view_1(qtbot, tempdir, gui):
    """Drive a `TraceView` built on artificial traces through its actions.

    Covers cluster selection, time navigation, interval resizing, auto
    scaling, channel scaling, origin switching, and simulated Control/Shift
    mouse clicks.
    """
    n_channels = 5
    n_spikes = 20
    sample_rate = 2000.
    duration = 1.
    spike_times = np.linspace(0.1, .9, n_spikes)
    spike_clusters = artificial_spike_clusters(n_spikes, n_channels)
    n_samples = int(round(duration * sample_rate))
    traces = 10 * artificial_traces(n_samples, n_channels)

    def get_traces(interval):
        """Return the traces in `interval` with one waveform per spike."""
        half = 20  # number of samples kept on each side of a spike

        def _waveform(spike_id):
            center = int(round(spike_times[spike_id] * sample_rate))
            return Bunch(
                data=traces[center - half:center + half, :],
                start_time=(center - half) / sample_rate,
                channel_ids=np.arange(5),
                spike_id=spike_id,
                spike_cluster=spike_clusters[spike_id],
                select_index=0,
            )

        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sample_rate),
            color=(.75, .75, .75, 1),
        )
        # Indices of the spikes that fall within the interval.
        first, last = spike_times.searchsorted(interval)
        out.waveforms = [_waveform(i) for i in range(first, last)]
        return out

    view = TraceView(
        traces=get_traces,
        spike_times=lambda: spike_times,
        n_channels=n_channels,
        sample_rate=sample_rate,
        duration=duration,
        channel_positions=linear_positions(n_channels),
    )
    view.show()
    qtbot.waitForWindowShown(view.canvas)
    view.attach(gui)

    # Successive cluster selections, starting with an empty one.
    for cluster_ids in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids=cluster_ids)

    view.stacked.add_boxes(view.canvas)

    ac(view.stacked.box_size, (.950, .165), atol=1e-3)
    view.set_interval((.375, .625))
    assert view.time == .5
    qtbot.wait(1)

    # Time navigation.
    view.go_to(.25)
    assert view.time == .25
    qtbot.wait(1)

    # A negative target time is expected to end up at .125.
    view.go_to(-.5)
    assert view.time == .125
    qtbot.wait(1)

    # Going left from there is expected to leave the time unchanged.
    view.go_left()
    assert view.time == .125
    qtbot.wait(1)

    view.go_right()
    ac(view.time, .150)
    qtbot.wait(1)

    # These actions are only exercised; no resulting state is asserted.
    for action in (
            view.jump_left,
            view.jump_right,
            view.go_to_next_spike,
            view.go_to_previous_spike,
    ):
        action()
        qtbot.wait(1)

    # Resize the interval.
    view.interval = (.25, .75)
    ac(view.interval, (.25, .75))
    qtbot.wait(1)

    view.widen()
    ac(view.interval, (.1875, .8125))
    qtbot.wait(1)

    view.narrow()
    ac(view.interval, (.25, .75))
    qtbot.wait(1)

    view.go_to_start()
    qtbot.wait(1)
    assert view.interval[0] == 0

    view.go_to_end()
    qtbot.wait(1)
    assert view.interval[1] == duration

    # Widen while the interval already spans the whole duration.
    view.set_interval((0, duration))
    view.widen()
    qtbot.wait(1)

    view.toggle_show_labels(True)
    view.go_right()

    # With auto scaling off, narrowing must keep ymin and ymax unchanged.
    bounds_before = view.data_bounds
    view.toggle_auto_scale(False)
    view.narrow()
    qtbot.wait(1)
    assert view.data_bounds[1] == bounds_before[1]
    assert view.data_bounds[3] == bounds_before[3]

    view.toggle_auto_update(True)
    assert view.do_show_labels
    qtbot.wait(1)

    view.toggle_highlighted_spikes(True)
    qtbot.wait(50)

    # Channel scaling: decrease then increase should roughly restore the
    # initial box size.
    initial_box_size = view.stacked.box_size
    view.decrease()
    qtbot.wait(1)

    view.increase()
    ac(view.stacked.box_size, initial_box_size, atol=.05)
    qtbot.wait(1)

    view.origin = 'bottom'
    view.switch_origin()
    assert view.origin == 'top'
    qtbot.wait(1)

    # Spike selection, simulated with a Control + left click.
    clicked = []

    @connect(sender=view)
    def on_select_spike(sender,
                        channel_id=None,
                        spike_id=None,
                        cluster_id=None,
                        key=None):
        clicked.append((channel_id, spike_id, cluster_id))

    mouse_click(qtbot,
                view.canvas,
                pos=(0., 0.),
                button='Left',
                modifiers=('Control', ))

    view.set_state(view.state)

    assert len(clicked[0]) == 3

    # Channel selection, simulated with Shift + left/right clicks. The name
    # is rebound to a fresh list, which both callbacks see from now on.
    clicked = []

    @connect(sender=view)
    def on_select_channel(sender, channel_id=None, button=None):
        clicked.append((channel_id, button))

    mouse_click(qtbot,
                view.canvas,
                pos=(0., 0.),
                button='Left',
                modifiers=('Shift', ))
    mouse_click(qtbot,
                view.canvas,
                pos=(0., 0.),
                button='Right',
                modifiers=('Shift', ))

    assert clicked == [(2, 'Left'), (2, 'Right')]

    _stop_and_close(qtbot, view)
Ejemplo n.º 5
0
def test_trace_image_view_1(qtbot, tempdir, gui):
    """Drive a `TraceImageView` built on artificial traces through its actions.

    Covers color update, time navigation, interval resizing, channel
    scaling, and origin switching.
    """
    n_channels = 350
    sample_rate = 2000.
    duration = 1.
    n_samples = int(round(duration * sample_rate))
    traces = 10 * artificial_traces(n_samples, n_channels)

    def get_traces(interval):
        """Return the traces in `interval`, with a fixed gray color."""
        selected = select_traces(traces, interval, sample_rate=sample_rate)
        return Bunch(data=selected, color=(.75, .75, .75, 1))

    view = TraceImageView(
        traces=get_traces,
        n_channels=n_channels,
        sample_rate=sample_rate,
        duration=duration,
        channel_positions=linear_positions(n_channels),
    )
    view.show()
    qtbot.waitForWindowShown(view.canvas)
    view.attach(gui)

    view.update_color()

    view.set_interval((.375, .625))
    assert view.time == .5
    qtbot.wait(1)

    # Time navigation.
    view.go_to(.25)
    assert view.time == .25
    qtbot.wait(1)

    # A negative target time is expected to end up at .125.
    view.go_to(-.5)
    assert view.time == .125
    qtbot.wait(1)

    # Going left from there is expected to leave the time unchanged.
    view.go_left()
    assert view.time == .125
    qtbot.wait(1)

    view.go_right()
    ac(view.time, .150)
    qtbot.wait(1)

    # Jumps are only exercised; no resulting state is asserted.
    for action in (view.jump_left, view.jump_right):
        action()
        qtbot.wait(1)

    # Resize the interval.
    view.interval = (.25, .75)
    ac(view.interval, (.25, .75))
    qtbot.wait(1)

    view.widen()
    ac(view.interval, (.1875, .8125))
    qtbot.wait(1)

    view.narrow()
    ac(view.interval, (.25, .75))
    qtbot.wait(1)

    view.go_to_start()
    qtbot.wait(1)
    assert view.interval[0] == 0

    view.go_to_end()
    qtbot.wait(1)
    assert view.interval[1] == duration

    # Widen while the interval already spans the whole duration.
    view.set_interval((0, duration))
    view.widen()
    qtbot.wait(1)

    view.toggle_auto_update(True)
    assert view.do_show_labels
    qtbot.wait(1)

    # Channel scaling is only exercised; no resulting state is asserted.
    for action in (view.decrease, view.increase):
        action()
        qtbot.wait(1)

    # The origin resulting from the switch is not asserted for this view.
    view.origin = 'bottom'
    view.switch_origin()
    qtbot.wait(1)

    _stop_and_close(qtbot, view)