Example #1
0
def test_grid_closest_box():
    """Check `Grid.get_closest_box()` at the center and corners of a grid."""
    layout = Grid((3, 7))
    # Each entry pairs a query position with the box index the test expects.
    expectations = (
        ((0.0, 0.0), (1, 3)),
        ((-1.0, +1.0), (0, 0)),
        ((+1.0, -1.0), (2, 6)),
        ((-1.0, -1.0), (2, 0)),
        ((+1.0, +1.0), (0, 6)),
    )
    for position, box in expectations:
        ac(layout.get_closest_box(position), box)
Example #2
0
def test_disk_store():
    """Exercise `DiskStore`: per-cluster store/load/erase, then file I/O."""
    dtype = np.float32
    shape_a, shape_b = (2, 4), (3, 5)
    a = np.random.rand(*shape_a).astype(dtype)
    b = np.random.rand(*shape_b).astype(dtype)

    def _assert_equal(actual, expected):
        """Assert that two dicts have the same keys and matching arrays."""
        assert sorted(actual.keys()) == sorted(expected.keys())
        for name in actual.keys():
            ac(actual[name], expected[name])

    with TemporaryDirectory() as tempdir:
        ds = DiskStore(tempdir)

        ds.register_file_extensions(['key', 'key_bis'])
        assert ds.cluster_ids == []

        # Store a first array; read it back as a dict, then as a bare array.
        ds.store(3, key=a)
        as_dict = ds.load(3, ['key'], dtype=dtype, shape=shape_a)
        _assert_equal(as_dict, {'key': a})
        as_array = ds.load(3, 'key', dtype=dtype, shape=shape_a)
        ac(as_array, a)

        # Loading a non-existing key returns None.
        assert ds.load(3, 'key_bis') is None
        assert ds.cluster_ids == [3]

        # Store a second array under another key of the same cluster.
        ds.store(3, key_bis=b)
        _assert_equal(ds.load(3, ['key'], dtype=dtype, shape=shape_a),
                      {'key': a})
        _assert_equal(ds.load(3, ['key_bis'], dtype=dtype, shape=shape_b),
                      {'key_bis': b})
        # Without an explicit shape, the test expects flattened arrays.
        both = ds.load(3, ['key', 'key_bis'], dtype=dtype)
        _assert_equal(both, {'key': a.ravel(), 'key_bis': b.ravel()})
        ac(ds.load(3, 'key_bis', dtype=dtype, shape=shape_b), b)
        assert ds.cluster_ids == [3]

        # Cluster 2 was never stored; erasing it along with 3 must not fail.
        ds.erase([2, 3])
        assert ds.load(3, ['key']) == {'key': None}
        assert ds.cluster_ids == []

        # Test load/save file, reading back through a fresh store object.
        ds.save_file('test', {'a': a})
        ds = DiskStore(tempdir)
        reloaded = ds.load_file('test')
        ae(reloaded['a'], a)
        assert ds.load_file('test2') is None
Example #3
0
 def _check_arrays(cluster, clusters_for_sc=None, spikes=None):
     """Check the store's features and masks of a given cluster.

     The stored arrays are compared with the reference arrays `f` and `m`
     indexed by `spikes`. When `spikes` is not given, it is computed from
     `spike_clusters` for `clusters_for_sc` (by default, `cluster` alone).
     """
     if spikes is None:
         lookup = [cluster] if clusters_for_sc is None else clusters_for_sc
         spikes = _spikes_in_clusters(spike_clusters, lookup)
     n_spikes = len(spikes)
     n_channels = len(session.model.channel_order)
     n_features = session.model.n_features_per_channel
     shape = (n_spikes, n_channels, n_features)
     ac(cs.features(cluster), f[spikes, :].reshape(shape))
     ac(cs.masks(cluster), m[spikes])
Example #4
0
def _gen_store_1(chunk_size):
    """Write one flat binary file per cluster by streaming `arr` in chunks.

    The module-level array `arr` (asserted to be an `h5py.Dataset`) is read
    `chunk_size` rows at a time, and the rows of each chunk are appended to
    the flat file of the cluster they belong to. Spikes located after the
    last full chunk are not processed. The elapsed time is printed, then the
    flat file of cluster 0 is compared with the rows of `arr` in `spc[0]`.
    """
    _reset_store()
    _free_cache()

    t0 = default_timer()
    for chunk_idx in range(n_spikes // chunk_size):
        start = chunk_idx * chunk_size
        stop = start + chunk_size

        # Load a chunk from HDF5; the slice is checked to be a NumPy array.
        assert isinstance(arr, h5py.Dataset)
        chunk = arr[start:stop]
        assert isinstance(chunk, np.ndarray)
        chunk_clusters = sc[start:stop]
        chunk_spikes = np.arange(start, stop)

        # Group the spikes of this chunk by cluster.
        chunk_spc = _spikes_per_cluster(chunk_spikes, chunk_clusters)

        # Append the rows of every cluster to its own flat binary file.
        for cluster in sorted(chunk_spc.keys()):
            rows = _index_of(chunk_spc[cluster], chunk_spikes)
            with open(_flat_file(cluster), 'ab') as fh:
                chunk[rows].tofile(fh)

    # NOTE: a former extra step that copied the flat files into a DiskStore
    # is disabled; only the flat files are generated and timed here.
    print("time", default_timer() - t0)

    # Check the flat file of the first cluster against the original array.
    cluster = 0
    flat = np.fromfile(_flat_file(cluster), dtype=np.float32)
    arr2 = flat.reshape((-1, n_channels))

    ac(arr[spc[cluster], :], arr2)
Example #5
0
def test_grid_interact():
    """Check `Grid.map()` for three boxes of a (4, 8) grid, and `imap()`."""
    grid = Grid((4, 8))
    # (box index, expected image of the point (0, 0) in that box).
    mapped_origin = (
        ((0, 0), [[-0.875, 0.75]]),
        ((1, 3), [[-0.125, 0.25]]),
        ((3, 7), [[0.875, -0.75]]),
    )
    for box, expected in mapped_origin:
        ac(grid.map([0.0, 0.0], box), expected)

    # The inverse transform is expected to bring the point back to (0, 0).
    ac(grid.imap([[0.875, -0.75]], (3, 7)), [[0.0, 0.0]])
Example #6
0
def test_stacked_closest_box():
    """Check `Stacked.get_closest_box()` with both values of `origin`."""
    high_point = (-0.5, 0.9)
    low_point = (+0.5, -0.9)
    # (origin, expected box for the high point, expected box for the low one).
    cases = (
        ("upper", 0, 3),
        ("lower", 3, 0),
    )
    for origin, high_box, low_box in cases:
        stacked = Stacked(n_boxes=4, origin=origin)
        ac(stacked.get_closest_box(high_point), high_box)
        ac(stacked.get_closest_box(low_point), low_box)
Example #7
0
def test_boxed_closest_box():
    """Check `Boxed.get_closest_box()` with two boxes given by their bounds."""
    bounds = np.array([
        [-0.5, -0.5, 0.0, 0.0],
        [0.0, 0.0, +0.5, +0.5],
    ])
    boxed = Boxed(box_bounds=bounds)

    # (query position, index of the box expected to be the closest).
    cases = (
        ((-1, -1), 0),
        ((-0.001, 0), 0),
        ((+0.001, 0), 1),
        ((-1, +1), 0),
    )
    for position, box in cases:
        ac(boxed.get_closest_box(position), box)
Example #8
0
def _assert_equal(d_0, d_1):
    """Check that two objects are equal.

    Array-like objects are compared with `ae`, falling back to `ac` when
    that first comparison fails. Dictionaries must have the same keys and
    their values are compared recursively. Anything else uses `==`.
    """
    # Arrays: strict comparison first, then the fallback comparison.
    if _is_array_like(d_0):
        try:
            ae(d_0, d_1)
        except AssertionError:
            ac(d_0, d_1)
        return

    # Plain objects: general comparison.
    if not isinstance(d_0, dict):
        assert d_0 == d_1
        return

    # Dictionaries: same keys, then recurse into every value.
    assert set(d_0) == set(d_1)
    for key, value in d_0.items():
        _assert_equal(value, d_1[key])
Example #9
0
def _assert_equal(d_0, d_1):
    """Check that two objects are equal.

    Array-like objects are compared with `ae`, falling back to `ac` when
    that first comparison fails. Dictionaries must have the same set of keys
    and their values are compared recursively. Anything else uses `==`.
    """
    # Compare arrays.
    if _is_array_like(d_0):
        try:
            ae(d_0, d_1)
        except AssertionError:
            ac(d_0, d_1)
    # Compare dicts recursively.
    elif isinstance(d_0, dict):
        # Compare the keys as sets: sorting them raises a TypeError when the
        # keys cannot be ordered (e.g. a mix of str and int keys).
        assert set(d_0) == set(d_1)
        for key in d_0:
            _assert_equal(d_0[key], d_1[key])
    else:
        # General comparison.
        assert d_0 == d_1
Example #10
0
def test_trace_view(qtbot, gui):
    """Exercise the trace view created by the GUI controller.

    The test covers time navigation, interval changes, label toggling,
    channel scaling and the `origin` attribute.
    """
    v = gui.controller.add_trace_view(gui)

    _select_clusters(gui)

    ac(v.stacked.box_size, (1., .08181), atol=1e-3)
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    # A negative target time is expected to leave the view at .125
    # (presumably clamped to the start of the data -- to be confirmed).
    v.go_to(-.5)
    assert v.time == .125

    # Going left from there is expected not to move the view.
    v.go_left()
    assert v.time == .125

    v.go_right()
    assert v.time == .175

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    # Narrowing is expected to undo the widening.
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    v.set_interval((0, gui.controller.duration))
    v.widen()

    # After one toggle, the labels are expected to be hidden.
    v.toggle_show_labels()
    assert not v.do_show_labels

    # Change channel scaling: increase then decrease restores the box size.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # qtbot.stop()
    gui.close()
Example #11
0
def test_write_by_chunk(tempdir):
    """Check that `_write_by_chunk()` writes a list of arrays contiguously."""
    n_chunks = 5
    # Chunk number k (1-based) has k rows and 3 columns.
    chunks = [np.random.rand(k, 3).astype(np.float32)
              for k in range(1, n_chunks + 1)]
    total_rows = sum(chunk.shape[0] for chunk in chunks)

    path = op.join(tempdir, 'test.h5')
    with open_h5(path, 'w') as f:
        dataset = f.write('/test', shape=(total_rows, 3), dtype=np.float32)
        _write_by_chunk(dataset, chunks)

    # Read everything back and compare it chunk by chunk.
    with open_h5(path, 'r') as f:
        written = f.read('/test')[...]
        start = 0
        for n_rows, chunk in enumerate(chunks, 1):
            assert chunk.shape[0] == n_rows
            stop = start + n_rows
            ac(written[start:stop, ...], chunk)
            start = stop
Example #12
0
def test_read_kwd(tempdir):
    """Check that `read_kwd()` gives back traces split into two recordings."""
    n_samples, n_channels = 100, 10
    arr = artificial_traces(n_samples, n_channels)
    half = n_samples // 2

    path = op.join(tempdir, 'test')

    # Write each half of the traces as a separate recording.
    recordings = (
        ('/recordings/0/data', arr[:half, ...]),
        ('/recordings/1/data', arr[half:, ...]),
    )
    with open_h5(path, 'w') as f:
        for dataset_path, part in recordings:
            f.write(dataset_path, part.astype(np.float32))

    # The full read is expected to match the original, unsplit traces.
    with open_h5(path, 'r') as f:
        data = read_kwd(f)[:]

    ac(arr, data)
Example #13
0
def test_mean_masked_features_distance(features, n_channels,
                                       n_features_per_channel):
    """Check the masked distance between two shifted mean feature vectors.

    With a single unmasked channel and a constant shift, the expected
    distance is `sqrt(n_features_per_channel) * shift`.
    """
    # Two mean feature vectors that differ by a constant offset.
    offset = 10.
    f0 = mean(features)
    f1 = mean(features) + offset

    # Only one channel is unmasked. The very same array object is used for
    # both masks, exactly as in the chained assignment this replaces.
    masks = np.zeros(n_channels)
    masks[n_channels // 2] = 1

    # Check the distance.
    d_expected = np.sqrt(n_features_per_channel) * offset
    d_computed = get_mean_masked_features_distance(
        f0, f1, masks, masks, n_features_per_channel)
    ac(d_expected, d_computed)
Example #14
0
def test_boxed_1(qtbot, canvas):
    """Check that `Boxed` keeps its box bounds through a visual and updates."""
    n_boxes = 6
    n_per_box = 1000

    # Columns 0-1 hold one linspace, columns 2-3 the other.
    bounds = np.zeros((n_boxes, 4))
    first_corner = np.linspace(-1.0, 1.0 - 1.0 / 3.0, n_boxes)
    bounds[:, 0] = bounds[:, 1] = first_corner
    second_corner = np.linspace(-1.0 + 1.0 / 3.0, 1.0, n_boxes)
    bounds[:, 2] = bounds[:, 3] = second_corner

    # Every box index is repeated `n_per_box` times.
    box_index = np.repeat(np.arange(n_boxes), n_per_box, axis=0)

    boxed = Boxed(box_bounds=bounds)
    _create_visual(qtbot, canvas, boxed, box_index)

    ae(boxed.box_bounds, bounds)
    boxed.box_bounds = bounds

    # Updating the boxes from their own position and size must be a no-op.
    boxed.update_boxes(boxed.box_pos, boxed.box_size)
    ac(boxed.box_bounds, bounds)
Example #15
0
def test_panzoom_set_range():
    """Check `PanZoom.set_range()`, `get_range()` and the resulting zoom."""
    pz = PanZoom()

    # (bounds passed to set_range(), zoom expected afterwards).
    cases = (
        ((-1, -1, 1, 1), (1, 1)),
        ((-.5, -.5, .5, .5), (2, 2)),
        ((0, 0, 1, 1), (2, 2)),
        ((-1, 0, 1, 1), (1, 2)),
    )
    for bounds, zoom in cases:
        pz.set_range(bounds)
        # The range read back must be the one that was just set.
        ac(pz.get_range(), bounds)
        ac(pz.zoom, zoom)

    # With `keep_aspect=True`, the same zoom is expected on both axes.
    pz.set_range((-1, 0, 1, 1), keep_aspect=True)
    ac(pz.zoom, (1, 1))
Example #16
0
def test_session_store_features(tempdir):
    """Check that the cluster store works for features and masks."""
    model = MockModel(n_spikes=50, n_clusters=3)
    # Spike indices of clusters 0 and 1.
    spikes_0 = np.nonzero(model.spike_clusters == 0)[0]
    spikes_1 = np.nonzero(model.spike_clusters == 1)[0]

    session = _start_manual_clustering(
        model=model,
        tempdir=tempdir,
        chunk_size=4,
    )

    features_0 = session.store.features(0)
    masks_1 = session.store.masks(1)
    waveforms_1 = session.store.waveforms(1)

    # The test expects 28 channels and 2 features per channel.
    n_channels = 28
    assert features_0.shape == (len(spikes_0), n_channels, 2)
    assert masks_1.shape == (len(spikes_1), n_channels)
    assert waveforms_1.shape == (
        len(spikes_1), model.n_samples_waveforms, n_channels)

    # Stored values must match the model's arrays.
    expected_features = model.features[spikes_0].reshape(
        (features_0.shape[0], -1, 2))
    ac(features_0, expected_features, 1e-3)
    ac(masks_1, model.masks[spikes_1], 1e-3)
Example #17
0
def test_tesselate_histogram():
    """Check the shape and a few vertices of a tesselated histogram."""
    n_bins = 7
    vertices = _tesselate_histogram(np.arange(n_bins))
    # Six 2D vertices are expected per histogram bin.
    assert vertices.shape == (6 * n_bins, 2)

    # (vertex index, expected coordinates).
    checks = (
        (0, [0, 0]),
        (-3, [n_bins, n_bins - 1]),
        (-1, [n_bins, 0]),
    )
    for index, expected in checks:
        ac(vertices[index], expected)
Example #18
0
def test_get_boxes():
    """Check `_get_boxes()` on two-site, linear and staggered layouts."""
    # Two sites side by side, default options.
    boxes = _get_boxes([[-1, 0], [1, 0]])
    expected = [[-1, -.25, 0, .25],
                [+0, -.25, 1, .25]]
    ac(boxes, expected, atol=1e-4)

    # Same two sites, without keeping the aspect ratio.
    boxes = _get_boxes([[-1, 0], [1, 0]], keep_aspect_ratio=False)
    expected = [[-1, -1, 0, 1],
                [0, -1, 1, 1]]
    ac(boxes, expected, atol=1e-4)

    # Four sites generated by `linear_positions()`.
    boxes = _get_boxes(linear_positions(4))
    expected = [
        [-0.5, -1.0, +0.5, -0.5],
        [-0.5, -0.5, +0.5, +0.0],
        [-0.5, +0.0, +0.5, +0.5],
        [-0.5, +0.5, +0.5, +1.0],
    ]
    ac(boxes, expected, atol=1e-4)

    # Eight staggered sites: only columns 1 and 3 of the boxes are checked.
    boxes = _get_boxes(staggered_positions(8))
    ac(boxes[:, 1], np.arange(.75, -1.1, -.25), atol=1e-6)
    ac(boxes[:, 3], np.arange(1, -.76, -.25), atol=1e-7)
Example #19
0
def test_compute_threshold():
    """Check the shape and values of `compute_threshold()` in four setups."""
    n_samples, n_channels = 100, 10
    data = artificial_traces(n_samples, n_channels)

    # Single threshold, scalar factor: two identical positive values.
    single = compute_threshold(data, std_factor=1.)
    assert single.shape == (2,)
    assert single[0] > 0
    assert single[0] == single[1]

    # Single threshold, two factors: the second value is twice the first.
    single_pair = compute_threshold(data, std_factor=[1., 2.])
    assert single_pair.shape == (2,)
    assert single_pair[1] == 2 * single_pair[0]

    # Multiple threshold, scalar factor: one column per channel.
    per_channel = compute_threshold(data,
                                    single_threshold=False,
                                    std_factor=2.)
    assert per_channel.shape == (2, n_channels)

    # Multiple threshold, two factors: second row is twice the first one.
    per_channel_pair = compute_threshold(data,
                                         single_threshold=False,
                                         std_factor=(1., 2.))
    assert per_channel_pair.shape == (2, n_channels)
    ac(per_channel_pair[1], 2 * per_channel_pair[0])
Example #20
0
def test_panzoom_map():
    """Check `PanZoom.map()` after setting pan and zoom, then `imap()`."""
    pz = PanZoom()

    # With a pan only, the point (0, 0) is expected to land on the pan.
    pz.pan = (1., -1.)
    panned = pz.map([0., 0.])
    ac(panned, [[1., -1.]])

    # Adding a zoom changes the image of the same point.
    pz.zoom = (2., .5)
    zoomed = pz.map([0., 0.])
    ac(zoomed, [[2., -.5]])

    # The inverse transform is expected to bring that image back to (0, 0).
    restored = pz.imap([2., -.5])
    ac(restored, [[0., 0.]])
Example #21
0
def test_creator_simple(tempdir):
    """Check `KwikCreator`: empty files, metadata, and spikes in one block."""
    basename = op.join(tempdir, 'my_file')
    creator = KwikCreator(basename)

    # Creating the empty files must produce both the .kwik and .kwx files.
    creator.create_empty()
    for extension in ('.kwik', '.kwx'):
        assert op.exists(basename + extension)

    # Metadata set on the creator must be readable as HDF5 attributes.
    group_path = '/application_data/spikedetekt'
    creator.set_metadata(group_path, a=1, b=2., c=[0, 1])

    with open_h5(creator.kwik_path, 'r') as f:
        assert f.read_attr(group_path, 'a') == 1
        assert f.read_attr(group_path, 'b') == 2.
        ae(f.read_attr(group_path, 'c'), [0, 1])

    # Add all spikes in one block.
    n_spikes, n_channels, n_features = 100, 8, 3

    spike_samples = artificial_spike_samples(n_spikes)
    features = artificial_features(n_spikes, n_channels, n_features)
    masks = artificial_masks(n_spikes, n_channels)

    creator.add_spikes(
        group=0,
        spike_samples=spike_samples,
        features=features.astype(np.float32),
        masks=masks.astype(np.float32),
        n_channels=n_channels,
        n_features=n_features,
    )

    # The spike samples must be saved as unsigned 64-bit integers.
    with open_h5(creator.kwik_path, 'r') as f:
        saved_samples = f.read('/channel_groups/0/spikes/time_samples')[...]
        assert saved_samples.dtype == np.uint64
        ac(saved_samples, spike_samples)

    # Features and masks share one array: index 0 of the last axis holds
    # the features, index 1 the masks (checked every `n_features` column).
    with open_h5(creator.kwx_path, 'r') as f:
        fm = f.read('/channel_groups/0/features_masks')[...]
        assert fm.dtype == np.float32
        flat_features = features.reshape((-1, n_channels * n_features))
        ac(fm[:, :, 0], flat_features)
        ac(fm[:, ::n_features, 1], masks)

    # Spikes can only be added once.
    with raises(RuntimeError):
        creator.add_spikes(
            group=0,
            spike_samples=spike_samples,
            n_channels=n_channels,
            n_features=n_features,
        )
Example #22
0
def test_boxed_interact():
    """Check `Boxed.map()` and `Boxed.imap()` with eight boxes."""
    n_boxes = 8
    bounds = np.zeros((n_boxes, 4))
    first_corner = np.linspace(-1.0, 1.0 - 1.0 / 4.0, n_boxes)
    bounds[:, 0] = bounds[:, 1] = first_corner
    second_corner = np.linspace(-1.0 + 1.0 / 4.0, 1.0, n_boxes)
    bounds[:, 2] = bounds[:, 3] = second_corner

    boxed = Boxed(box_bounds=bounds)

    # The point (0, 0) is checked in the first and in the last box.
    for box, expected in ((0, [[-0.875, -0.875]]), (7, [[0.875, 0.875]])):
        ac(boxed.map([0.0, 0.0], box), expected)

    # The inverse transform is expected to bring the point back to (0, 0).
    ac(boxed.imap([[0.875, 0.875]], 7), [[0.0, 0.0]])
Example #23
0
def test_creator_chunks(tempdir):
    """Check `KwikCreator.add_spikes()` with features and masks in chunks."""
    basename = op.join(tempdir, 'my_file')

    creator = KwikCreator(basename)
    creator.create_empty()

    # Spikes are added with the features and masks split into chunks.
    n_spikes, n_channels, n_features = 100, 8, 3

    spike_samples = artificial_spike_samples(n_spikes)
    features = artificial_features(n_spikes, n_channels, n_features)
    features = features.astype(np.float32)
    masks = artificial_masks(n_spikes, n_channels).astype(np.float32)

    def _split(arr):
        """Return the list of ten consecutive chunks of `arr` (first axis)."""
        step = n_spikes // 10
        starts = range(0, n_spikes, step)
        return [arr[start:start + step, ...] for start in starts]

    creator.add_spikes(
        group=0,
        spike_samples=spike_samples,
        features=_split(features),
        masks=_split(masks),
        n_channels=n_channels,
        n_features=n_features,
    )

    # The spike samples must be saved as unsigned 64-bit integers.
    with open_h5(creator.kwik_path, 'r') as f:
        saved_samples = f.read('/channel_groups/0/spikes/time_samples')[...]
        assert saved_samples.dtype == np.uint64
        ac(saved_samples, spike_samples)

    # The chunks must have been written back to back in the .kwx file.
    with open_h5(creator.kwx_path, 'r') as f:
        fm = f.read('/channel_groups/0/features_masks')[...]
        assert fm.dtype == np.float32
        flat_features = features.reshape((-1, n_channels * n_features))
        ac(fm[:, :, 0], flat_features)
        ac(fm[:, ::n_features, 1], masks)
Example #24
0
def test_normalize():
    """Check that `_normalize()` maps the range [0, 10] onto [-1, 1]."""
    lower, upper = 0., 10.
    values = np.linspace(0., 10., 10)
    normalized = _normalize(values, lower, upper)
    expected = np.linspace(-1., 1., 10)
    ac(normalized, expected)
Example #25
0
def test_gui_clustering(qtbot):
    """Check that the cluster store follows clustering actions in the GUI.

    After every merge, split, undo, redo and move, the features and masks
    returned by the store are compared with the model's arrays.
    """

    gui = _start_manual_clustering()
    gui.show()
    qtbot.addWidget(gui.main_window)

    cs = gui.store
    # Snapshot of the initial cluster assignments, used as the reference.
    spike_clusters = gui.model.spike_clusters.copy()

    f = gui.model.features
    m = gui.model.masks

    def _check_arrays(cluster, clusters_for_sc=None, spikes=None):
        """Check the features and masks in the cluster store
        of a given cluster.

        When `spikes` is not given, the spikes are looked up in the initial
        `spike_clusters` for `clusters_for_sc` (default: `cluster` alone).
        """
        if spikes is None:
            if clusters_for_sc is None:
                clusters_for_sc = [cluster]
            spikes = _spikes_in_clusters(spike_clusters, clusters_for_sc)
        shape = (len(spikes),
                 len(gui.model.channel_order),
                 gui.model.n_features_per_channel)
        ac(cs.features(cluster), f[spikes, :].reshape(shape))
        ac(cs.masks(cluster), m[spikes])

    _check_arrays(0)
    _check_arrays(2)

    # Merge two clusters.
    clusters = [0, 2]
    up = gui.merge(clusters)
    new = up.added[0]
    _check_arrays(new, clusters)

    # Split some spikes.
    spikes = [2, 3, 5, 7, 11, 13]
    # clusters = np.unique(spike_clusters[spikes])
    up = gui.split(spikes)
    # The split spikes are expected in the cluster following the merged one.
    _check_arrays(new + 1, spikes=spikes)

    # Undo.
    gui.undo()
    _check_arrays(new, clusters)

    # Undo.
    gui.undo()
    _check_arrays(0)
    _check_arrays(2)

    # Redo.
    gui.redo()
    _check_arrays(new, clusters)

    # Split some spikes.
    spikes = [5, 7, 11, 13, 17, 19]
    # clusters = np.unique(spike_clusters[spikes])
    gui.split(spikes)
    _check_arrays(new + 1, spikes=spikes)

    # Test merge-undo-different-merge combo.
    # Copy of the current assignments, to check the clusters after the undo.
    spc = gui.clustering.spikes_per_cluster.copy()
    clusters = gui.cluster_ids[:3]
    up = gui.merge(clusters)
    _check_arrays(up.added[0], spikes=up.spike_ids)
    # Undo.
    gui.undo()
    for cluster in clusters:
        _check_arrays(cluster, spikes=spc[cluster])
    # Another merge.
    clusters = gui.cluster_ids[1:5]
    up = gui.merge(clusters)
    _check_arrays(up.added[0], spikes=up.spike_ids)

    # Move a cluster to a group.
    cluster = gui.cluster_ids[0]
    gui.move([cluster], 2)
    assert len(gui.store.mean_probe_position(cluster)) == 2

    spike_clusters_new = gui.model.spike_clusters.copy()
    # Check that the spike clusters have changed.
    assert not np.all(spike_clusters_new == spike_clusters)
    ac(gui.model.spike_clusters, gui.clustering.spike_clusters)

    gui.close()
Example #26
0
def test_normalize():
    """Check `_normalize()` on a regular range and on a degenerate one."""
    lower, upper = 0., 10.
    values = np.linspace(0., 10., 10)

    # The range [0, 10] is expected to be mapped onto [-1, 1].
    normalized = _normalize(values, lower, upper)
    ac(normalized, np.linspace(-1., 1., 10))

    # With identical bounds, the values are expected to come back unchanged.
    unchanged = _normalize(values, lower, lower)
    ac(unchanged, values)
Example #27
0
def test_waveform_view(qtbot, gui):
    """Exercise the waveform view created by the GUI controller.

    The test covers the toggles, box and probe scaling, zooming on channels,
    and the `channel_click` event raised by a key press plus a mouse click.
    """
    v = gui.controller.add_waveform_view(gui)
    _select_clusters(gui)

    ac(v.boxed.box_size, (.1818, .0909), atol=1e-2)

    # Each of these toggles is called twice in a row.
    v.toggle_waveform_overlap()
    v.toggle_waveform_overlap()

    v.toggle_zoom_on_channels()
    v.toggle_zoom_on_channels()

    # After a single toggle, the labels are expected to be hidden.
    v.toggle_show_labels()
    assert not v.do_show_labels

    # Box scaling: each pair of opposite actions must restore the box size.
    bs = v.boxed.box_size
    v.increase()
    v.decrease()
    ac(v.boxed.box_size, bs)

    bs = v.boxed.box_size
    v.widen()
    v.narrow()
    ac(v.boxed.box_size, bs)

    # Probe scaling: opposite actions must restore the box positions.
    bp = v.boxed.box_pos
    v.extend_horizontally()
    v.shrink_horizontally()
    ac(v.boxed.box_pos, bp)

    bp = v.boxed.box_pos
    v.extend_vertically()
    v.shrink_vertically()
    ac(v.boxed.box_pos, bp)

    # The scaling properties must return the values assigned to them.
    a, b = v.probe_scaling
    v.probe_scaling = (a, b * 2)
    ac(v.probe_scaling, (a, b * 2))

    a, b = v.box_scaling
    v.box_scaling = (a * 2, b)
    ac(v.box_scaling, (a * 2, b))

    v.zoom_on_channels([0, 2, 4])

    # Simulate channel selection.
    _clicked = []

    @v.gui.connect_
    def on_channel_click(channel_idx=None, button=None, key=None):
        # Record every event as a (channel_idx, button, key) tuple.
        _clicked.append((channel_idx, button, key))

    # Click while the '2' key is held down.
    v.events.key_press(key=keys.Key('2'))
    v.events.mouse_press(pos=(0., 0.), button=1)
    v.events.key_release(key=keys.Key('2'))

    # Exactly one event is expected: channel 0, button 1, key 2.
    assert _clicked == [(0, 1, 2)]

    v.next_data()

    # qtbot.stop()
    gui.close()
Example #28
0
def test_session_gui_clustering(qtbot, session):
    """Check the cluster store and the saved file after GUI clustering.

    After every merge, split, undo, redo and move, the features and masks
    returned by the store are compared with the model's arrays. The session
    is then saved and re-opened to check the saved clusters and groups.
    """

    cs = session.store
    # Snapshot of the initial cluster assignments, used as the reference.
    spike_clusters = session.model.spike_clusters.copy()

    f = session.model.features
    m = session.model.masks

    def _check_arrays(cluster, clusters_for_sc=None, spikes=None):
        """Check the features and masks in the cluster store
        of a given cluster.

        When `spikes` is not given, the spikes are looked up in the initial
        `spike_clusters` for `clusters_for_sc` (default: `cluster` alone).
        """
        if spikes is None:
            if clusters_for_sc is None:
                clusters_for_sc = [cluster]
            spikes = _spikes_in_clusters(spike_clusters, clusters_for_sc)
        shape = (len(spikes),
                 len(session.model.channel_order),
                 session.model.n_features_per_channel)
        ac(cs.features(cluster), f[spikes, :].reshape(shape))
        ac(cs.masks(cluster), m[spikes])

    _check_arrays(0)
    _check_arrays(2)

    gui = session.show_gui()
    qtbot.addWidget(gui.main_window)

    # Merge two clusters.
    clusters = [0, 2]
    gui.merge(clusters)  # Create cluster 5.
    _check_arrays(5, clusters)

    # Split some spikes.
    spikes = [2, 3, 5, 7, 11, 13]
    # clusters = np.unique(spike_clusters[spikes])
    gui.split(spikes)  # Create cluster 6 and more.
    _check_arrays(6, spikes=spikes)

    # Undo.
    gui.undo()
    _check_arrays(5, clusters)

    # Undo.
    gui.undo()
    _check_arrays(0)
    _check_arrays(2)

    # Redo.
    gui.redo()
    _check_arrays(5, clusters)

    # Split some spikes.
    spikes = [5, 7, 11, 13, 17, 19]
    # clusters = np.unique(spike_clusters[spikes])
    gui.split(spikes)  # Create cluster 6 and more.
    _check_arrays(6, spikes=spikes)

    # Test merge-undo-different-merge combo.
    # Copy of the current assignments, to check the clusters after the undo.
    spc = gui.clustering.spikes_per_cluster.copy()
    clusters = gui.cluster_ids[:3]
    up = gui.merge(clusters)
    _check_arrays(up.added[0], spikes=up.spike_ids)
    # Undo.
    gui.undo()
    for cluster in clusters:
        _check_arrays(cluster, spikes=spc[cluster])
    # Another merge.
    clusters = gui.cluster_ids[1:5]
    up = gui.merge(clusters)
    _check_arrays(up.added[0], spikes=up.spike_ids)

    # Move a cluster to a group.
    cluster = gui.cluster_ids[0]
    gui.move([cluster], 2)
    assert len(gui.store.mean_probe_position(cluster)) == 2

    # Save.
    spike_clusters_new = gui.model.spike_clusters.copy()
    # Check that the spike clusters have changed.
    assert not np.all(spike_clusters_new == spike_clusters)
    ac(session.model.spike_clusters, gui.clustering.spike_clusters)
    session.save()

    # Re-open the file and check that the spike clusters and
    # cluster groups have correctly been saved.
    # NOTE: `session` is rebound to the re-opened session here, while `gui`
    # still belongs to the session received as argument.
    session = _start_manual_clustering(kwik_path=session.model.path,
                                       tempdir=session.tempdir)
    ac(session.model.spike_clusters, gui.clustering.spike_clusters)
    ac(session.model.spike_clusters, spike_clusters_new)
    #  Check the cluster groups.
    # NOTE(review): `clusters` is assigned but not used afterwards.
    clusters = gui.clustering.cluster_ids
    groups = session.model.cluster_groups
    # `cluster` is the one that was moved to group 2 before saving.
    assert groups[cluster] == 2

    gui.close()
Example #29
0
def test_waveform_view(qtbot, tempdir):
    """Exercise a standalone `WaveformView` attached to a new GUI.

    The view is fed with artificial waveforms. The test covers cluster
    selection, the toggles, box and probe scaling, and the `channel_click`
    event raised by a key press plus a mouse click.
    """
    nc = 5

    def get_waveforms(cluster_id):
        # Same kind of artificial data whatever the cluster: `cluster_id`
        # is not used.
        return Bunch(
            data=artificial_waveforms(10, 20, nc),
            channel_ids=np.arange(nc),
            channel_positions=staggered_positions(nc),
        )

    v = WaveformView(waveforms=get_waveforms, )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    # Select successively zero, one, three and two clusters.
    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # Each of these toggles is called twice in a row.
    v.toggle_waveform_overlap()
    v.toggle_waveform_overlap()

    v.toggle_show_labels()
    v.toggle_show_labels()

    # Box scaling: each pair of opposite actions must restore the box size.
    bs = v.boxed.box_size
    v.increase()
    v.decrease()
    ac(v.boxed.box_size, bs)

    bs = v.boxed.box_size
    v.widen()
    v.narrow()
    ac(v.boxed.box_size, bs)

    # Probe scaling: opposite actions must restore the box positions.
    bp = v.boxed.box_pos
    v.extend_horizontally()
    v.shrink_horizontally()
    ac(v.boxed.box_pos, bp)

    bp = v.boxed.box_pos
    v.extend_vertically()
    v.shrink_vertically()
    ac(v.boxed.box_pos, bp)

    # The scaling properties must return the values assigned to them.
    a, b = v.probe_scaling
    v.probe_scaling = (a, b * 2)
    ac(v.probe_scaling, (a, b * 2))

    a, b = v.box_scaling
    v.box_scaling = (a * 2, b)
    ac(v.box_scaling, (a * 2, b))

    # Simulate channel selection.
    _clicked = []

    @v.gui.connect_
    def on_channel_click(channel_id=None, button=None, key=None):
        # Record every event as a (channel_id, button, key) tuple.
        _clicked.append((channel_id, button, key))

    # Click while the '2' key is held down.
    v.events.key_press(key=keys.Key('2'))
    v.events.mouse_press(pos=(0., 0.), button=1)
    v.events.key_release(key=keys.Key('2'))

    # Exactly one event is expected: channel 0, button 1, key 2.
    assert _clicked == [(0, 1, 2)]

    # qtbot.stop()
    gui.close()
Example #30
0
 def _assert_equal(d_0, d_1):
     """Assert that two dicts of arrays have equal keys and matching values."""
     keys_0 = sorted(d_0.keys())
     keys_1 = sorted(d_1.keys())
     assert keys_0 == keys_1
     # Values are compared key by key, in the iteration order of `d_0`.
     for name in d_0.keys():
         ac(d_0[name], d_1[name])
Example #31
0
def test_binary_search():
    """Check `_binary_search()` with a predicate that flips at 0.4."""
    def below_threshold(x):
        return x < .4

    # (lower bound, upper bound, expected result).
    cases = (
        (0, 1, .4),
        (0, .3, .3),
        (.5, 1, .5),
    )
    for lower, upper, expected in cases:
        ac(_binary_search(below_threshold, lower, upper), expected)
Example #32
0
def test_gui_clustering(qtbot):
    """Check that the cluster store follows clustering actions in the GUI.

    After every merge, split, undo, redo and move, the features and masks
    returned by the store are compared with the model's arrays.
    """

    gui = _start_manual_clustering()
    gui.show()
    qtbot.addWidget(gui.main_window)

    cs = gui.store
    # Snapshot of the initial cluster assignments, used as the reference.
    spike_clusters = gui.model.spike_clusters.copy()

    f = gui.model.features
    m = gui.model.masks

    def _check_arrays(cluster, clusters_for_sc=None, spikes=None):
        """Check the features and masks in the cluster store
        of a given cluster.

        When `spikes` is not given, the spikes are looked up in the initial
        `spike_clusters` for `clusters_for_sc` (default: `cluster` alone).
        """
        if spikes is None:
            if clusters_for_sc is None:
                clusters_for_sc = [cluster]
            spikes = _spikes_in_clusters(spike_clusters, clusters_for_sc)
        shape = (len(spikes), len(gui.model.channel_order),
                 gui.model.n_features_per_channel)
        ac(cs.features(cluster), f[spikes, :].reshape(shape))
        ac(cs.masks(cluster), m[spikes])

    _check_arrays(0)
    _check_arrays(2)

    # Merge two clusters.
    clusters = [0, 2]
    up = gui.merge(clusters)
    new = up.added[0]
    _check_arrays(new, clusters)

    # Split some spikes.
    spikes = [2, 3, 5, 7, 11, 13]
    # clusters = np.unique(spike_clusters[spikes])
    up = gui.split(spikes)
    # The split spikes are expected in the cluster following the merged one.
    _check_arrays(new + 1, spikes=spikes)

    # Undo.
    gui.undo()
    _check_arrays(new, clusters)

    # Undo.
    gui.undo()
    _check_arrays(0)
    _check_arrays(2)

    # Redo.
    gui.redo()
    _check_arrays(new, clusters)

    # Split some spikes.
    spikes = [5, 7, 11, 13, 17, 19]
    # clusters = np.unique(spike_clusters[spikes])
    gui.split(spikes)
    _check_arrays(new + 1, spikes=spikes)

    # Test merge-undo-different-merge combo.
    # Copy of the current assignments, to check the clusters after the undo.
    spc = gui.clustering.spikes_per_cluster.copy()
    clusters = gui.cluster_ids[:3]
    up = gui.merge(clusters)
    _check_arrays(up.added[0], spikes=up.spike_ids)
    # Undo.
    gui.undo()
    for cluster in clusters:
        _check_arrays(cluster, spikes=spc[cluster])
    # Another merge.
    clusters = gui.cluster_ids[1:5]
    up = gui.merge(clusters)
    _check_arrays(up.added[0], spikes=up.spike_ids)

    # Move a cluster to a group.
    cluster = gui.cluster_ids[0]
    gui.move([cluster], 2)
    assert len(gui.store.mean_probe_position(cluster)) == 2

    spike_clusters_new = gui.model.spike_clusters.copy()
    # Check that the spike clusters have changed.
    assert not np.all(spike_clusters_new == spike_clusters)
    ac(gui.model.spike_clusters, gui.clustering.spike_clusters)

    gui.close()
Example #33
0
def test_trace_view(tempdir, qtbot):
    """Exercise a standalone `TraceView` fed with artificial traces.

    The test first checks that `_iter_spike_waveforms()` yields something,
    then covers time navigation, interval changes, label toggling, channel
    scaling, the `origin` attribute and the `spike_click` event.
    """
    nc = 5  # Number of channels (also used as the number of clusters).
    ns = 9  # Number of spikes (also used as `n_samples_waveforms`).
    sr = 1000.  # Sample rate.
    ch = list(range(nc))
    duration = 1.
    st = np.linspace(0.1, .9, ns)  # Spike times.
    sc = artificial_spike_clusters(ns, nc)  # Spike clusters.
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)
    cs = ColorSelector()

    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    sw = _iter_spike_waveforms(interval=[0., 1.],
                               traces_interval=traces,
                               model=m,
                               supervisor=s,
                               n_samples_waveforms=ns,
                               get_best_channels=lambda cluster_id: ch,
                               color_selector=cs,
                               )
    # At least one spike waveform must be yielded.
    assert len(list(sw))

    def get_traces(interval):
        """Return the traces in `interval`, with one waveform per spike."""
        out = Bunch(data=select_traces(traces, interval, sample_rate=sr),
                    color=(.75,) * 4,
                    )
        # Indices of the spikes whose time falls within the interval.
        a, b = st.searchsorted(interval)
        out.waveforms = []
        k = 20  # Number of samples kept on each side of a spike.
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            # Local `s` (spike sample), distinct from the outer `s` Bunch.
            s = int(round(t * sr))
            # NOTE(review): `np.arange(5)` hard-codes the value of `nc`.
            d = Bunch(data=traces[s - k:s + k, :],
                      start_time=t - k / sr,
                      color=cs.get(c),
                      channel_ids=np.arange(5),
                      spike_id=i,
                      spike_cluster=c,
                      )
            out.waveforms.append(d)
        return out

    v = TraceView(traces=get_traces,
                  n_channels=nc,
                  sample_rate=sr,
                  duration=duration,
                  channel_vertical_order=np.arange(nc)[::-1],
                  )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    # qtbot.waitForWindowShown(gui)

    # Select successively zero, one, three and two clusters.
    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # ac(v.stacked.box_size, (1., .08181), atol=1e-3)
    # The view time is expected to be the center of the interval.
    v.set_interval((.375, .625))
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    # A negative target time is expected to leave the view at .125
    # (presumably clamped to the start of the data -- to be confirmed).
    v.go_to(-.5)
    assert v.time == .125

    # Going left from there is expected not to move the view.
    v.go_left()
    assert v.time == .125

    v.go_right()
    assert v.time == .175

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    # Narrowing is expected to undo the widening.
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()

    # After a single toggle, the labels are expected to be shown.
    v.toggle_show_labels()
    # v.toggle_show_labels()
    v.go_right()
    assert v.do_show_labels

    # Change channel scaling: increase then decrease restores the box size.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # Simulate spike selection.
    _clicked = []

    @v.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        # Record every event as a (channel_id, spike_id, cluster_id) tuple.
        _clicked.append((channel_id, spike_id, cluster_id))

    # Click while the Control key is held down.
    v.events.key_press(key=keys.Key('Control'))
    v.events.mouse_press(pos=(400., 200.), button=1, modifiers=(keys.CONTROL,))
    v.events.key_release(key=keys.Key('Control'))

    # Exactly one event is expected: channel 1, spike 4, cluster 1.
    assert _clicked == [(1, 4, 1)]

    # qtbot.stop()
    gui.close()
Example #34
0
def test_trace_image_view_1(qtbot, tempdir, gui):
    """Drive a ``TraceImageView`` through navigation, zooming and scaling."""
    n_chan = 350
    rate = 2000.
    duration = 1.
    n_samples = int(round(duration * rate))
    raw = 10 * artificial_traces(n_samples, n_chan)

    def get_traces(interval):
        # Slice the synthetic traces to the requested interval.
        chunk = select_traces(raw, interval, sample_rate=rate)
        return Bunch(data=chunk, color=(.75, .75, .75, 1))

    view = TraceImageView(
        traces=get_traces,
        n_channels=n_chan,
        sample_rate=rate,
        duration=duration,
        channel_positions=linear_positions(n_chan),
    )
    view.show()
    qtbot.waitForWindowShown(view.canvas)
    view.attach(gui)

    view.update_color()

    # The view time is expected to sit at the middle of the interval.
    view.set_interval((.375, .625))
    assert view.time == .5
    qtbot.wait(1)

    view.go_to(.25)
    assert view.time == .25
    qtbot.wait(1)

    # Asking for a negative time is expected to leave the time at .125.
    view.go_to(-.5)
    assert view.time == .125
    qtbot.wait(1)

    # go_left() is expected to leave the time unchanged here.
    view.go_left()
    assert view.time == .125
    qtbot.wait(1)

    view.go_right()
    ac(view.time, .150)
    qtbot.wait(1)

    view.jump_left()
    qtbot.wait(1)

    view.jump_right()
    qtbot.wait(1)

    # Resize the interval: widen() then narrow() must round-trip.
    view.interval = (.25, .75)
    ac(view.interval, (.25, .75))
    qtbot.wait(1)

    view.widen()
    ac(view.interval, (.1875, .8125))
    qtbot.wait(1)

    view.narrow()
    ac(view.interval, (.25, .75))
    qtbot.wait(1)

    view.go_to_start()
    qtbot.wait(1)
    assert view.interval[0] == 0

    view.go_to_end()
    qtbot.wait(1)
    assert view.interval[1] == duration

    # Try to widen an interval that already spans the full duration.
    view.set_interval((0, duration))
    view.widen()
    qtbot.wait(1)

    view.toggle_auto_update(True)
    assert view.do_show_labels
    qtbot.wait(1)

    # Channel scaling: shrink, then grow again.
    view.decrease()
    qtbot.wait(1)

    view.increase()
    qtbot.wait(1)

    # NOTE: the origin after switching is not asserted in this test.
    view.origin = 'bottom'
    view.switch_origin()
    qtbot.wait(1)

    _stop_and_close(qtbot, view)
Example #35
0
 def check(y):
     """Compare ``y`` with ``conv_cpu``, ignoring ``npad`` rows at each end.

     ``y`` is first converted with ``cp.asnumpy``; the comparison uses an
     absolute tolerance of 1e-3.
     """
     on_host = cp.asnumpy(y)
     ac(on_host[npad:-npad, :], conv_cpu[npad:-npad, :], atol=1e-3)
Example #36
0
 def _a(f):
     """Check that ``f`` gives matching results on ``traces`` and ``arr``."""
     from_traces = f(traces)[:]
     from_array = f(arr)
     ac(from_traces, from_array)
Example #37
0
def test_disk_store(tempdir):
    """Store, load and erase per-cluster arrays through a ``DiskStore``."""
    dtype = np.float32
    shape_a, shape_b = (2, 4), (3, 5)
    arr_a = np.random.rand(*shape_a).astype(dtype)
    arr_b = np.random.rand(*shape_b).astype(dtype)

    def _assert_equal(d_0, d_1):
        """Assert both dicts share their keys and hold close arrays."""
        assert sorted(d_0.keys()) == sorted(d_1.keys())
        for name in d_0.keys():
            ac(d_0[name], d_1[name])

    store = DiskStore(tempdir)
    store.register_file_extensions(['key', 'key_bis'])
    assert store.cluster_ids == []

    # Store a single array, then read it back as a dict and directly.
    store.store(3, key=arr_a)
    as_dict = store.load(3, ['key'], dtype=dtype, shape=shape_a)
    _assert_equal(as_dict, {'key': arr_a})
    as_array = store.load(3, 'key', dtype=dtype, shape=shape_a)
    ac(as_array, arr_a)

    # A key that has not been stored yet loads as None.
    assert store.load(3, 'key_bis') is None
    assert store.cluster_ids == [3]

    # Add the second array to the same cluster.
    store.store(3, key_bis=arr_b)
    only_a = store.load(3, ['key'], dtype=dtype, shape=shape_a)
    _assert_equal(only_a, {'key': arr_a})
    only_b = store.load(3, ['key_bis'], dtype=dtype, shape=shape_b)
    _assert_equal(only_b, {'key_bis': arr_b})

    # Loading without a shape is compared against the flattened arrays.
    both = store.load(3, ['key', 'key_bis'], dtype=dtype)
    _assert_equal(both, {'key': arr_a.ravel(), 'key_bis': arr_b.ravel()})
    ac(store.load(3, 'key_bis', dtype=dtype, shape=shape_b), arr_b)
    assert store.cluster_ids == [3]

    # Erasing also accepts a cluster (2) that was never stored.
    store.erase([2, 3])
    assert store.load(3, ['key']) == {'key': None}
    assert store.cluster_ids == []

    # Round-trip a named file through a second store on the same directory.
    store.save_file('test', {'a': arr_a})
    store = DiskStore(tempdir)
    reloaded = store.load_file('test')
    ae(reloaded['a'], arr_a)
    assert store.load_file('test2') is None
Example #38
0
def test_extract_simple():
    """Walk through each ``WaveformExtractor`` step on a small component."""
    weak, strong = 1., 2.
    nc, ns = 4, 20
    channels = list(range(nc))

    def _make_extractor(channels_per_group):
        # Every extractor gets its own freshly built thresholds dict.
        return WaveformExtractor(
            extract_before=3,
            extract_after=5,
            thresholds={'weak': weak, 'strong': strong},
            channels_per_group=channels_per_group,
        )

    data = np.random.uniform(size=(ns, nc), low=0., high=1.)

    # Known values at samples 10..12 on channels 0 and 1.
    data[10:13, 0] = [0.5, 1.5, 1.0]
    data[10:13, 1] = [1.5, 2.5, 2.0]

    # (sample, channel) pairs covering that 3 x 2 patch.
    component = np.array(
        [[smp, chan] for smp in (10, 11, 12) for chan in (0, 1)]
    )

    we = _make_extractor({0: channels})

    # _component()
    comp = we._component(component, n_samples=ns)
    ae(comp.comp_s, [10, 10, 11, 11, 12, 12])
    ae(comp.comp_ch, [0, 1, 0, 1, 0, 1])
    expected_bounds = (10 - 3, 12 + 4)
    assert (comp.s_min, comp.s_max) == expected_bounds
    ae(comp.channels, range(nc))

    # _normalize()
    assert we._normalize(weak) == 0
    assert we._normalize(strong) == 1
    midpoint = (weak + strong) / 2.
    ae(we._normalize([midpoint]), [.5])

    # _comp_wave()
    wave = we._comp_wave(data, comp)
    assert wave.shape == (3 + 5 + 1, nc)
    expected_patch = [
        [0.5, 1.5, 0., 0.],
        [1.5, 2.5, 0., 0.],
        [1.0, 2.0, 0., 0.],
    ]
    ae(wave[3:6, :], expected_patch)

    # masks()
    masks = we.masks(data, wave, comp)
    ae(masks, [.5, 1., 0, 0])

    # spike_sample_aligned()
    s = we.spike_sample_aligned(wave, comp)
    assert 11 <= s < 12

    # extract()
    wave_e = we.extract(data, s, channels=channels)
    assert wave_e.shape[1] == wave.shape[1]
    ae(wave[3:6, :2], wave_e[3:6, :2])

    # align()
    wave_a = we.align(wave_e, s)
    assert wave_a.shape == (3 + 5, nc)

    # Calling the extractor directly must agree with the individual steps.
    groups, s_f, wave_f, masks_f = we(component, data=data, data_t=data)
    assert s_f == s
    assert np.all(groups == 0)
    ae(masks_f, masks)
    ac(wave_f, wave_a)

    # Same component with a reordered, partial channel list.
    reordered = [1, 0, 3]
    we = _make_extractor({0: reordered})
    groups, s_f_o, wave_f_o, masks_f_o = we(component, data=data, data_t=data)
    assert np.all(groups == 0)
    assert s_f == s_f_o
    assert np.allclose(wave_f[:, [1, 0, 3]], wave_f_o)
    ac(masks_f_o, [1., 0.5, 0.])
Example #39
0
def test_waveform_view(qtbot, gui):
    """Exercise the waveform view created through the GUI controller."""
    view = gui.controller.add_waveform_view(gui)
    _select_clusters(gui)

    view.toggle_waveform_overlap()
    view.toggle_waveform_overlap()

    view.toggle_zoom_on_channels()
    view.toggle_zoom_on_channels()

    view.toggle_show_labels()
    assert view.do_show_labels

    def _assert_roundtrip(forward, backward, attr):
        # Apply an action and its inverse, then check that the given
        # attribute of ``view.boxed`` is back to its initial value.
        before = getattr(view.boxed, attr)
        forward()
        backward()
        ac(getattr(view.boxed, attr), before)

    # Box scaling.
    _assert_roundtrip(view.increase, view.decrease, 'box_size')
    _assert_roundtrip(view.widen, view.narrow, 'box_size')

    # Probe scaling.
    _assert_roundtrip(
        view.extend_horizontally, view.shrink_horizontally, 'box_pos')
    _assert_roundtrip(
        view.extend_vertically, view.shrink_vertically, 'box_pos')

    # Scaling properties must read back what was assigned.
    px, py = view.probe_scaling
    view.probe_scaling = (px, py * 2)
    ac(view.probe_scaling, (px, py * 2))

    bx, by = view.box_scaling
    view.box_scaling = (bx * 2, by)
    ac(view.box_scaling, (bx * 2, by))

    view.zoom_on_channels([0, 2, 4])

    view.filter_by_tag('test')

    # Simulate a channel selection: click while the '2' key is held.
    _clicked = []

    @view.gui.connect_
    def on_channel_click(channel_idx=None, button=None, key=None):
        _clicked.append((channel_idx, button, key))

    view.events.key_press(key=keys.Key('2'))
    view.events.mouse_press(pos=(0., 0.), button=1)
    view.events.key_release(key=keys.Key('2'))

    assert _clicked == [(0, 1, 2)]

    gui.close()
Example #40
0
def test_trace_view_1(qtbot, tempdir, gui):
    """Drive a ``TraceView`` through selection, navigation and spike clicks."""
    n_chan = 5
    n_spikes = 20
    rate = 2000.
    duration = 1.
    spike_times = np.linspace(0.1, .9, n_spikes)
    spike_clusters = artificial_spike_clusters(n_spikes, n_chan)
    raw = 10 * artificial_traces(int(round(duration * rate)), n_chan)
    colors = ClusterColorSelector(cluster_ids=list(range(n_chan)))

    # Number of samples kept on each side of a spike.
    half_width = 20

    def _spike_waveform(idx):
        # Bunch describing the trace snippet around spike number ``idx``.
        center = int(round(spike_times[idx] * rate))
        cluster = spike_clusters[idx]
        return Bunch(
            data=raw[center - half_width:center + half_width, :],
            start_time=(center - half_width) / rate,
            color=colors.get(cluster, alpha=.5),
            channel_ids=np.arange(5),
            spike_id=idx,
            spike_cluster=cluster,
        )

    def get_traces(interval):
        out = Bunch(
            data=select_traces(raw, interval, sample_rate=rate),
            color=(.75, .75, .75, 1),
        )
        # Indices of the spikes falling inside the requested interval.
        first, last = spike_times.searchsorted(interval)
        out.waveforms = [_spike_waveform(i) for i in range(first, last)]
        return out

    def get_spike_times():
        return spike_times

    view = TraceView(
        traces=get_traces,
        spike_times=get_spike_times,
        n_channels=n_chan,
        sample_rate=rate,
        duration=duration,
        channel_vertical_order=np.arange(n_chan)[::-1],
    )
    view.show()
    qtbot.waitForWindowShown(view.canvas)
    view.attach(gui)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids=selection)

    ac(view.stacked.box_size, (1., .19), atol=1e-3)

    # The view time is expected to sit at the middle of the interval.
    view.set_interval((.375, .625))
    assert view.time == .5
    qtbot.wait(1)

    view.go_to(.25)
    assert view.time == .25
    qtbot.wait(1)

    # Asking for a negative time is expected to leave the time at .125.
    view.go_to(-.5)
    assert view.time == .125
    qtbot.wait(1)

    # go_left() is expected to leave the time unchanged here.
    view.go_left()
    assert view.time == .125
    qtbot.wait(1)

    view.go_right()
    ac(view.time, .150)
    qtbot.wait(1)

    view.go_to_next_spike()
    qtbot.wait(1)

    view.go_to_previous_spike()
    qtbot.wait(1)

    # Resize the interval: widen() then narrow() must round-trip.
    view.interval = (.25, .75)
    ac(view.interval, (.25, .75))
    qtbot.wait(1)

    view.widen()
    ac(view.interval, (.1875, .8125))
    qtbot.wait(1)

    view.narrow()
    ac(view.interval, (.25, .75))
    qtbot.wait(1)

    view.go_to_start()
    qtbot.wait(1)
    assert view.interval[0] == 0

    view.go_to_end()
    qtbot.wait(1)
    assert view.interval[1] == duration

    # Try to widen an interval that already spans the full duration.
    view.set_interval((0, duration))
    view.widen()
    qtbot.wait(1)

    view.toggle_show_labels(True)
    view.go_right()

    # With auto scaling off, narrowing must keep ymin and ymax unchanged.
    bounds_before = view.data_bounds
    view.toggle_auto_scale(False)
    view.narrow()
    qtbot.wait(1)
    assert view.data_bounds[1] == bounds_before[1]
    assert view.data_bounds[3] == bounds_before[3]

    view.toggle_auto_update(True)
    assert view.do_show_labels
    qtbot.wait(1)

    view.toggle_highlighted_spikes(True)
    qtbot.wait(50)

    # Channel scaling: decrease() then increase() roughly round-trips.
    box_size_before = view.stacked.box_size
    view.decrease()
    qtbot.wait(1)

    view.increase()
    ac(view.stacked.box_size, box_size_before, atol=.05)
    qtbot.wait(1)

    view.origin = 'bottom'
    view.switch_origin()
    assert view.origin == 'top'
    qtbot.wait(1)

    # Simulate a spike selection with a Control + left click.
    _clicked = []

    @connect(sender=view)
    def on_spike_click(sender,
                       channel_id=None,
                       spike_id=None,
                       cluster_id=None,
                       key=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    mouse_click(
        qtbot,
        view.canvas,
        pos=(0., 0.),
        button='Left',
        modifiers=('Control', ),
    )

    view.set_state(view.state)

    assert len(_clicked[0]) == 3

    _stop_and_close(qtbot, view)
Example #41
0
 def _test_range(*bounds):
     """Set the range of ``pz`` to ``bounds`` and check it reads back equal."""
     requested = bounds
     pz.set_range(requested)
     reported = pz.get_range()
     ac(reported, requested)
Example #42
0
def test_session_clustering(session):
    """Check that the cluster store follows merge/split/undo/redo/save.

    After each clustering action performed through the GUI, the features
    and masks returned by the session store are compared against the
    model arrays restricted to the expected spikes.

    NOTE(review): this is a generator test; the bare ``yield`` statements
    presumably hand control back to the test runner between steps --
    confirm against the harness that consumes it.
    """

    cs = session.store
    # Snapshot of the initial assignment, used as the reference below.
    spike_clusters = session.model.spike_clusters.copy()

    f = session.model.features
    m = session.model.masks

    def _check_arrays(cluster, clusters_for_sc=None, spikes=None):
        """Check the features and masks in the cluster store
        of a given cluster.

        When `spikes` is not given, the expected spikes are those that
        belonged to `clusters_for_sc` (default: `[cluster]`) in the
        initial `spike_clusters` snapshot.
        """
        if spikes is None:
            if clusters_for_sc is None:
                clusters_for_sc = [cluster]
            spikes = _spikes_in_clusters(spike_clusters, clusters_for_sc)
        # Expected features shape: (n_spikes, n_channels, n_features).
        shape = (len(spikes),
                 len(session.model.channel_order),
                 session.model.n_features_per_channel)
        ac(cs.features(cluster), f[spikes, :].reshape(shape))
        ac(cs.masks(cluster), m[spikes])

    _check_arrays(0)
    _check_arrays(2)

    gui = session.show_gui()
    yield

    # Merge two clusters.
    clusters = [0, 2]
    gui.merge(clusters)  # Create cluster 5.
    _check_arrays(5, clusters)
    yield

    # Split some spikes.
    spikes = [2, 3, 5, 7, 11, 13]
    # clusters = np.unique(spike_clusters[spikes])
    gui.split(spikes)  # Create cluster 6 and more.
    _check_arrays(6, spikes=spikes)
    yield

    # Undo the split: cluster 5 must hold the merged spikes again.
    gui.undo()
    _check_arrays(5, clusters)
    yield

    # Undo the merge: the original clusters 0 and 2 must be back.
    gui.undo()
    _check_arrays(0)
    _check_arrays(2)
    yield

    # Redo the merge.
    gui.redo()
    _check_arrays(5, clusters)
    yield

    # Split a different set of spikes.
    spikes = [5, 7, 11, 13, 17, 19]
    # clusters = np.unique(spike_clusters[spikes])
    gui.split(spikes)  # Create cluster 6 and more.
    _check_arrays(6, spikes=spikes)
    yield

    # Test merge-undo-different-merge combo.
    # Keep the per-cluster spikes so the undo below can be verified.
    spc = gui.clustering.spikes_per_cluster.copy()
    clusters = gui.cluster_ids[:3]
    up = gui.merge(clusters)
    _check_arrays(up.added[0], spikes=up.spike_ids)
    # Undo.
    gui.undo()
    for cluster in clusters:
        _check_arrays(cluster, spikes=spc[cluster])
    # Another merge.
    clusters = gui.cluster_ids[1:5]
    up = gui.merge(clusters)
    _check_arrays(up.added[0], spikes=up.spike_ids)
    yield

    # Move a cluster to a group.
    cluster = gui.cluster_ids[0]
    gui.move([cluster], 2)
    assert len(gui.store.mean_probe_position(cluster)) == 2
    yield

    # Save.
    spike_clusters_new = gui.model.spike_clusters.copy()
    # Check that the spike clusters have changed.
    assert not np.all(spike_clusters_new == spike_clusters)
    ac(session.model.spike_clusters, gui.clustering.spike_clusters)
    session.save()
    yield

    # Re-open the file and check that the spike clusters and
    # cluster groups have correctly been saved.
    # `session` is rebound here; `_check_arrays` is not called afterwards.
    session = _start_manual_clustering(kwik_path=session.model.path,
                                       tempdir=session.tempdir)
    ac(session.model.spike_clusters, gui.clustering.spike_clusters)
    ac(session.model.spike_clusters, spike_clusters_new)
    # Check the cluster groups.
    # NOTE(review): `clusters` is not used after this assignment.
    clusters = gui.clustering.cluster_ids
    groups = session.model.cluster_groups
    # `cluster` is the one moved to group 2 above.
    assert groups[cluster] == 2
    yield

    gui.close()
Example #43
0
def test_panzoom_mouse_pos():
    """Check ``PanZoom.get_mouse_pos`` after a ``zoom_delta`` call."""
    second_arg = (.5, .25)
    panzoom = PanZoom()
    panzoom.zoom_delta((10, 10), second_arg)
    # The reported position is expected to match the second
    # ``zoom_delta`` argument within 1e-3.
    reported = panzoom.get_mouse_pos((.01, -.01))
    ac(reported, second_arg, atol=1e-3)
Example #44
0
 def _assert_equal(d_0, d_1):
     """Test the equality of two dictionaries containing NumPy arrays."""
     assert sorted(d_0.keys()) == sorted(d_1.keys())
     for key in d_0.keys():
         ac(d_0[key], d_1[key])
Example #45
0
def test_waveform_view(qtbot, tempdir, gui):
    """Exercise a standalone ``WaveformView`` fed with synthetic waveforms."""
    n_chan = 5

    synthetic = 10 + 100 * artificial_waveforms(10, 20, n_chan)

    def get_waveforms(cluster_id):
        # Every cluster gets the same synthetic waveforms.
        labels = ['%d' % (ch * 10) for ch in range(n_chan)]
        return Bunch(
            data=synthetic,
            channel_ids=np.arange(n_chan),
            channel_labels=labels,
            waveform_duration=1000,
            channel_positions=staggered_positions(n_chan),
        )

    sources = {
        'waveforms': get_waveforms,
        'mean_waveforms': get_waveforms,
    }
    view = WaveformView(waveforms=sources)
    view.show()
    qtbot.waitForWindowShown(view.canvas)
    view.attach(gui)

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids=selection)

    for flag in (True, False):
        view.toggle_waveform_overlap(flag)

    for flag in (False, True):
        view.toggle_show_labels(flag)

    view.next_waveforms_type()
    for flag in (True, False):
        view.toggle_mean_waveforms(flag)

    def _assert_roundtrip(forward, backward, attr):
        # Apply an action and its inverse, then check that the given
        # attribute of ``view.boxed`` is back to its initial value.
        before = getattr(view.boxed, attr)
        forward()
        backward()
        ac(getattr(view.boxed, attr), before)

    # Box scaling (the first round trip also resets the scaling).
    initial_size = view.boxed.box_size
    view.increase()
    view.decrease()
    view.reset_scaling()
    ac(view.boxed.box_size, initial_size)

    _assert_roundtrip(view.widen, view.narrow, 'box_size')

    # Probe scaling.
    _assert_roundtrip(
        view.extend_horizontally, view.shrink_horizontally, 'box_pos')
    _assert_roundtrip(
        view.extend_vertically, view.shrink_vertically, 'box_pos')

    # Scaling properties must read back what was assigned.
    px, py = view.probe_scaling
    view.probe_scaling = (px, py * 2)
    ac(view.probe_scaling, (px, py * 2))

    bx, by = view.box_scaling
    view.box_scaling = (bx * 2, by)
    ac(view.box_scaling, (bx * 2, by))

    # Simulate a channel selection: left click while '2' is held.
    _clicked = []

    @connect(sender=view)
    def on_channel_click(sender, channel_id=None, button=None, key=None):
        _clicked.append((channel_id, button, key))

    key_press(qtbot, view.canvas, '2')
    mouse_click(qtbot, view.canvas, pos=(0., 0.), button='Left')
    key_release(qtbot, view.canvas, '2')

    assert _clicked == [(2, 'Left', 2)]

    view.set_state(view.state)

    _stop_and_close(qtbot, view)