def test_correlogram_view(qtbot, tempdir):
    """Smoke-test CorrelogramView: selection, normalization, bin/window size."""
    n_samples = 50

    def get_correlograms(cluster_ids, bin_size, window_size):
        # Bin and window sizes are ignored: the fake correlograms only depend
        # on how many clusters are currently selected.
        return artificial_correlograms(len(cluster_ids), n_samples)

    view = CorrelogramView(correlograms=get_correlograms, sample_rate=100.)

    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    # Empty, single, multiple, then a shrunk selection.
    for cluster_ids in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids)

    view.toggle_normalization()

    view.set_bin(1)
    view.set_window(100)

    gui.close()
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI at (200, 100); close and release it afterwards."""
    window = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir)
    window.show()
    qtbot.waitForWindowShown(window)

    yield window

    # Short event-processing pauses around the close.
    qtbot.wait(5)
    window.close()
    del window
    qtbot.wait(5)
# Example #3 (0)
def gui(tempdir, qtbot):
    """Yield a shown 800x600 GUI at (200, 200) with its default actions set."""
    main_window = GUI(position=(200, 200), size=(800, 600), config_dir=tempdir)
    main_window.set_default_actions()
    main_window.show()
    qtbot.wait(1)

    yield main_window

    # Let pending Qt events run before and after closing the window.
    qtbot.wait(1)
    main_window.close()
    qtbot.wait(1)
# Example #4 (0)
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI at (200, 100) with message boxes stubbed out.

    ``gui_component._show_box`` is replaced by an identity function (mock
    patch of the show-box ``exec_``, presumably so modal dialogs do not block
    the test). The original function is restored on teardown, even if GUI
    creation fails, so the patch does not leak into other tests.
    """
    original_show_box = gui_component._show_box
    gui_component._show_box = lambda _: _
    try:
        gui = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir)
        gui.show()
        qtbot.waitForWindowShown(gui)
        yield gui
        qtbot.wait(5)
        gui.close()
        del gui
        qtbot.wait(5)
    finally:
        gui_component._show_box = original_show_box
# Example #5 (0)
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI at (200, 100) with message boxes stubbed out.

    ``_supervisor._show_box`` is replaced by an identity function (mock patch
    of the show-box ``exec_``, presumably so modal dialogs do not block the
    test). The original function is restored on teardown, even if GUI
    creation fails, so the patch does not leak into other tests.
    """
    original_show_box = _supervisor._show_box
    _supervisor._show_box = lambda _: _
    try:
        gui = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir)
        gui.show()
        qtbot.waitForWindowShown(gui)
        yield gui
        qtbot.wait(5)
        gui.close()
        del gui
        qtbot.wait(5)
    finally:
        _supervisor._show_box = original_show_box
# Example #6 (0)
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI with default actions and stubbed message boxes.

    ``_supervisor._show_box`` is replaced by an identity function (mock patch
    of the show-box ``exec_``, presumably so modal dialogs do not block the
    test). The original function is restored on teardown, even if GUI
    creation fails, so the patch does not leak into other tests.
    """
    original_show_box = _supervisor._show_box
    _supervisor._show_box = lambda _: _
    try:
        gui = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir)
        gui.set_default_actions()
        gui.show()
        qtbot.waitForWindowShown(gui)
        yield gui
        qtbot.wait(5)
        gui.close()
        del gui
        qtbot.wait(5)
    finally:
        _supervisor._show_box = original_show_box
# Example #7 (0)
def test_scatter_view(qtbot, tempdir):
    """Smoke-test ScatterView fed with random point clouds."""
    n_points = 1000

    def random_coords(c):
        # Same random cloud shape for any cluster; no explicit data bounds.
        return Bunch(x=np.random.randn(n_points),
                     y=np.random.randn(n_points),
                     data_bounds=None)

    view = ScatterView(coords=random_coords)

    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    gui.close()
# Example #8 (0)
def test_scatter_view(qtbot, tempdir):
    """Smoke-test ScatterView with freshly drawn Gaussian coordinates."""
    point_count = 1000

    def coords(c):
        # x then y are drawn independently from a standard normal.
        xs = np.random.randn(point_count)
        ys = np.random.randn(point_count)
        return Bunch(x=xs, y=ys, data_bounds=None)

    view = ScatterView(coords=coords)
    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    selections = [[], [0], [0, 2, 3], [0, 2]]
    for cluster_ids in selections:
        view.on_select(cluster_ids)

    gui.close()
# Example #9 (0)
def test_probe_view(qtbot, tempdir):
    """Smoke-test ProbeView on a 50-channel staggered probe layout."""
    positions = staggered_positions(50)

    def best_channels(cluster_id):
        # Every cluster reports the same channels: 1, 3, 5, 7.
        return range(1, 9, 2)

    view = ProbeView(positions=positions, best_channels=best_channels)

    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    for cluster_ids in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids)

    gui.close()
# Example #10 (0)
def test_probe_view(qtbot, tempdir):
    """Check that ProbeView accepts a range of cluster selections."""
    n_sites = 50
    layout = staggered_positions(n_sites)

    def channels_for(cluster_id):
        # Fixed odd channel ids, independent of the cluster.
        return range(1, 9, 2)

    view = ProbeView(positions=layout, best_channels=channels_for)
    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    selections = ([], [0], [0, 2, 3], [0, 2])
    for selection in selections:
        view.on_select(selection)

    gui.close()
# Example #11 (0)
def test_trace_view(tempdir, qtbot):
    """Exercise TraceView end to end.

    Covers spike-waveform iteration, cluster selection, interval navigation
    and resizing, label toggling, channel scaling, origin switching and
    ctrl+click spike selection.
    """
    nc = 5
    ns = 9
    sr = 1000.
    ch = list(range(nc))
    duration = 1.
    st = np.linspace(0.1, .9, ns)  # ns spike times evenly spread in [0.1, 0.9]
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)
    cs = ColorSelector()

    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    sw = _iter_spike_waveforms(
        interval=[0., 1.],
        traces_interval=traces,
        model=m,
        supervisor=s,
        n_samples_waveforms=ns,
        get_best_channels=lambda cluster_id: ch,
        color_selector=cs,
    )
    assert len(list(sw))

    def get_traces(interval):
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, ) * 4,
        )
        # Indices of the spikes whose time falls inside ``interval``.
        a, b = st.searchsorted(interval)
        out.waveforms = []
        k = 20  # half-width, in samples, of each spike snippet
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            # Named ``sample`` (not ``s``) to avoid shadowing the supervisor
            # bunch defined above.
            sample = int(round(t * sr))
            d = Bunch(
                data=traces[sample - k:sample + k, :],
                start_time=t - k / sr,
                color=cs.get(c),
                channel_ids=np.arange(nc),
                spike_id=i,
                spike_cluster=c,
            )
            out.waveforms.append(d)
        return out

    v = TraceView(
        traces=get_traces,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_vertical_order=np.arange(nc)[::-1],
    )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    v.set_interval((.375, .625))
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    # The expected value shows the requested time is not kept as-is (here
    # .125 rather than -.5).
    v.go_to(-.5)
    assert v.time == .125

    v.go_left()
    assert v.time == .125

    v.go_right()
    assert v.time == .175

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()

    v.toggle_show_labels()
    v.go_right()
    assert v.do_show_labels

    # Change channel scaling: increase then decrease must round-trip.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # Simulate spike selection.
    _clicked = []

    @v.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    v.events.key_press(key=keys.Key('Control'))
    v.events.mouse_press(pos=(400., 200.),
                         button=1,
                         modifiers=(keys.CONTROL, ))
    v.events.key_release(key=keys.Key('Control'))

    assert _clicked == [(1, 4, 1)]

    gui.close()
def test_feature_view(qtbot, tempdir, n_channels):
    """Exercise FeatureView on fake data.

    Covers cluster selection, scaling, channel clicks, automatic channel
    selection, and a ctrl+click lasso followed by a split request.
    """
    nc = n_channels
    ns = 500
    features = artificial_features(ns, nc, 4)
    spike_clusters = artificial_spike_clusters(ns, 4)
    spike_times = np.linspace(0., 1., ns)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        # All spikes when no cluster is given, otherwise that cluster's spikes.
        return (spc[cluster_id] if cluster_id is not None else np.arange(ns))

    def get_features(cluster_id=None,
                     channel_ids=None,
                     spike_ids=None,
                     load_all=None):
        # ``spike_ids`` and ``load_all`` are accepted for interface
        # compatibility but do not change the result: every spike of the
        # cluster is always returned. Going through ``get_spike_ids`` in all
        # cases also keeps ``cluster_id=None`` valid when ``load_all`` is set
        # (a direct ``spc[None]`` lookup would presumably raise KeyError).
        spike_ids = get_spike_ids(cluster_id)
        return Bunch(
            data=features[spike_ids],
            spike_ids=spike_ids,
            # NOTE(review): ``masks`` has ``ns`` rows while ``data`` has
            # ``len(spike_ids)`` rows; kept as-is, confirm FeatureView does
            # not rely on them matching.
            masks=np.random.rand(ns, nc),
            channel_ids=(channel_ids
                         if channel_ids is not None else np.arange(nc)[::-1]),
        )

    def get_time(cluster_id=None, load_all=None):
        return Bunch(
            data=spike_times[get_spike_ids(cluster_id)],
            lim=(0., 1.),
        )

    v = FeatureView(
        features=get_features,
        attributes={'time': get_time},
    )

    v.set_state(GUIState(scaling=None))

    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    gui.emit('select', [0, 2])
    qtbot.wait(10)

    v.increase()
    v.decrease()

    v.on_channel_click(channel_id=3, button=1, key=2)
    v.clear_channels()
    v.toggle_automatic_channel_selection()

    # Split without selection.
    spike_ids = v.on_request_split()
    assert len(spike_ids) == 0

    # Draw a lasso: ctrl+click the four corners of a square (widget pixels).
    def _click(x, y):
        qtbot.mouseClick(v.native,
                         Qt.LeftButton,
                         pos=QPoint(x, y),
                         modifier=Qt.ControlModifier)

    _click(10, 10)
    _click(10, 100)
    _click(100, 100)
    _click(100, 10)

    # Split lassoed points.
    spike_ids = v.on_request_split()
    assert len(spike_ids) > 0

    gui.close()
# Example #13 (0)
def test_feature_view(qtbot, tempdir, n_channels):
    """Exercise FeatureView on fake data.

    Covers cluster selection, scaling, channel clicks, automatic channel
    selection, and a ctrl+click lasso followed by a split request.
    """
    nc = n_channels
    ns = 500
    features = artificial_features(ns, nc, 4)
    spike_clusters = artificial_spike_clusters(ns, 4)
    spike_times = np.linspace(0., 1., ns)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        # All spikes when no cluster is given, otherwise that cluster's spikes.
        return (spc[cluster_id] if cluster_id is not None else np.arange(ns))

    def get_features(cluster_id=None, channel_ids=None, spike_ids=None,
                     load_all=None):
        # ``spike_ids`` and ``load_all`` are accepted for interface
        # compatibility but do not change the result: every spike of the
        # cluster is always returned. Going through ``get_spike_ids`` in all
        # cases also keeps ``cluster_id=None`` valid when ``load_all`` is set
        # (a direct ``spc[None]`` lookup would presumably raise KeyError).
        spike_ids = get_spike_ids(cluster_id)
        # NOTE(review): ``masks`` has ``ns`` rows while ``data`` has
        # ``len(spike_ids)`` rows; kept as-is, confirm FeatureView does not
        # rely on them matching.
        return Bunch(data=features[spike_ids],
                     spike_ids=spike_ids,
                     masks=np.random.rand(ns, nc),
                     channel_ids=(channel_ids
                                  if channel_ids is not None
                                  else np.arange(nc)[::-1]),
                     )

    def get_time(cluster_id=None, load_all=None):
        return Bunch(data=spike_times[get_spike_ids(cluster_id)],
                     lim=(0., 1.),
                     )

    v = FeatureView(features=get_features,
                    attributes={'time': get_time},
                    )

    v.set_state(GUIState(scaling=None))

    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    gui.emit('select', [0, 2])
    qtbot.wait(10)

    v.increase()
    v.decrease()

    v.on_channel_click(channel_id=3, button=1, key=2)
    v.clear_channels()
    v.toggle_automatic_channel_selection()

    # Split without selection.
    spike_ids = v.on_request_split()
    assert len(spike_ids) == 0

    # Draw a lasso: ctrl+click the four corners of a square (widget pixels).
    def _click(x, y):
        qtbot.mouseClick(v.native, Qt.LeftButton, pos=QPoint(x, y),
                         modifier=Qt.ControlModifier)

    _click(10, 10)
    _click(10, 100)
    _click(100, 100)
    _click(100, 10)

    # Split lassoed points.
    spike_ids = v.on_request_split()
    assert len(spike_ids) > 0

    gui.close()
# Example #14 (0)
def test_trace_view(tempdir, qtbot):
    """Exercise TraceView end to end.

    Covers spike-waveform iteration, cluster selection, interval navigation
    and resizing, label toggling, channel scaling, origin switching and
    ctrl+click spike selection.
    """
    nc = 5
    ns = 9
    sr = 1000.
    ch = list(range(nc))
    duration = 1.
    st = np.linspace(0.1, .9, ns)  # ns spike times evenly spread in [0.1, 0.9]
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)
    cs = ColorSelector()

    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    sw = _iter_spike_waveforms(interval=[0., 1.],
                               traces_interval=traces,
                               model=m,
                               supervisor=s,
                               n_samples_waveforms=ns,
                               get_best_channels=lambda cluster_id: ch,
                               color_selector=cs,
                               )
    assert len(list(sw))

    def get_traces(interval):
        out = Bunch(data=select_traces(traces, interval, sample_rate=sr),
                    color=(.75,) * 4,
                    )
        # Indices of the spikes whose time falls inside ``interval``.
        a, b = st.searchsorted(interval)
        out.waveforms = []
        k = 20  # half-width, in samples, of each spike snippet
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            # Named ``sample`` (not ``s``) to avoid shadowing the supervisor
            # bunch defined above.
            sample = int(round(t * sr))
            d = Bunch(data=traces[sample - k:sample + k, :],
                      start_time=t - k / sr,
                      color=cs.get(c),
                      channel_ids=np.arange(nc),
                      spike_id=i,
                      spike_cluster=c,
                      )
            out.waveforms.append(d)
        return out

    v = TraceView(traces=get_traces,
                  n_channels=nc,
                  sample_rate=sr,
                  duration=duration,
                  channel_vertical_order=np.arange(nc)[::-1],
                  )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    v.set_interval((.375, .625))
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    # The expected value shows the requested time is not kept as-is (here
    # .125 rather than -.5).
    v.go_to(-.5)
    assert v.time == .125

    v.go_left()
    assert v.time == .125

    v.go_right()
    assert v.time == .175

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()

    v.toggle_show_labels()
    v.go_right()
    assert v.do_show_labels

    # Change channel scaling: increase then decrease must round-trip.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # Simulate spike selection.
    _clicked = []

    @v.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    v.events.key_press(key=keys.Key('Control'))
    v.events.mouse_press(pos=(400., 200.), button=1, modifiers=(keys.CONTROL,))
    v.events.key_release(key=keys.Key('Control'))

    assert _clicked == [(1, 4, 1)]

    gui.close()
# Example #15 (0)
def test_waveform_view(qtbot, tempdir):
    """Exercise WaveformView: selection, toggles, box/probe scaling, and
    channel selection by key+click."""
    n_channels = 5

    def get_waveforms(cluster_id):
        # Fake waveforms on ``n_channels`` channels of a staggered probe.
        return Bunch(data=artificial_waveforms(10, 20, n_channels),
                     channel_ids=np.arange(n_channels),
                     channel_positions=staggered_positions(n_channels))

    view = WaveformView(waveforms=get_waveforms)
    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    for cluster_ids in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids)

    # Each toggle is applied twice, in this order: overlap, overlap, labels,
    # labels.
    for toggle in (view.toggle_waveform_overlap, view.toggle_show_labels):
        toggle()
        toggle()

    def assert_roundtrip(read, forward, backward):
        # ``forward`` followed by ``backward`` must restore the value.
        before = read()
        forward()
        backward()
        ac(read(), before)

    # Box scaling.
    assert_roundtrip(lambda: view.boxed.box_size,
                     view.increase, view.decrease)
    assert_roundtrip(lambda: view.boxed.box_size,
                     view.widen, view.narrow)

    # Probe scaling.
    assert_roundtrip(lambda: view.boxed.box_pos,
                     view.extend_horizontally, view.shrink_horizontally)
    assert_roundtrip(lambda: view.boxed.box_pos,
                     view.extend_vertically, view.shrink_vertically)

    sx, sy = view.probe_scaling
    view.probe_scaling = (sx, sy * 2)
    ac(view.probe_scaling, (sx, sy * 2))

    sx, sy = view.box_scaling
    view.box_scaling = (sx * 2, sy)
    ac(view.box_scaling, (sx * 2, sy))

    # Simulate channel selection: hold "2" while clicking at the origin.
    clicks = []

    @view.gui.connect_
    def on_channel_click(channel_id=None, button=None, key=None):
        clicks.append((channel_id, button, key))

    view.events.key_press(key=keys.Key('2'))
    view.events.mouse_press(pos=(0., 0.), button=1)
    view.events.key_release(key=keys.Key('2'))

    assert clicks == [(0, 1, 2)]

    gui.close()
# Example #16 (0)
def test_waveform_view(qtbot, tempdir):
    """Check WaveformView selection, toggles, scaling round-trips and the
    channel-click event."""
    channel_count = 5

    def get_waveforms(cluster_id):
        # Identical fake data for every cluster.
        waveforms = artificial_waveforms(10, 20, channel_count)
        return Bunch(
            data=waveforms,
            channel_ids=np.arange(channel_count),
            channel_positions=staggered_positions(channel_count),
        )

    view = WaveformView(waveforms=get_waveforms)
    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    selections = ([], [0], [0, 2, 3], [0, 2])
    for selection in selections:
        view.on_select(selection)

    for _ in range(2):
        view.toggle_waveform_overlap()
    for _ in range(2):
        view.toggle_show_labels()

    # Box scaling: paired operations must cancel out.
    size_before = view.boxed.box_size
    view.increase()
    view.decrease()
    ac(view.boxed.box_size, size_before)

    size_before = view.boxed.box_size
    view.widen()
    view.narrow()
    ac(view.boxed.box_size, size_before)

    # Probe scaling: paired operations must cancel out.
    pos_before = view.boxed.box_pos
    view.extend_horizontally()
    view.shrink_horizontally()
    ac(view.boxed.box_pos, pos_before)

    pos_before = view.boxed.box_pos
    view.extend_vertically()
    view.shrink_vertically()
    ac(view.boxed.box_pos, pos_before)

    # Scaling properties are read back exactly as set.
    horiz, vert = view.probe_scaling
    view.probe_scaling = (horiz, vert * 2)
    ac(view.probe_scaling, (horiz, vert * 2))

    horiz, vert = view.box_scaling
    view.box_scaling = (horiz * 2, vert)
    ac(view.box_scaling, (horiz * 2, vert))

    # Simulate channel selection.
    received = []

    @view.gui.connect_
    def on_channel_click(channel_id=None, button=None, key=None):
        received.append((channel_id, button, key))

    view.events.key_press(key=keys.Key("2"))
    view.events.mouse_press(pos=(0.0, 0.0), button=1)
    view.events.key_release(key=keys.Key("2"))

    assert received == [(0, 1, 2)]

    gui.close()