def test_extend_spikes():
    n_spikes = 1000
    n_clusters = 10
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)

    spike_ids = np.unique(np.random.randint(size=5, low=0, high=n_spikes))

    # These spikes belong to the following clusters.
    clusters = np.unique(spike_clusters[spike_ids])

    # These are the spikes belonging to those clusters, but not in the
    # originally-specified spikes.
    extended = _extend_spikes(spike_ids, spike_clusters)
    assert np.all(np.in1d(spike_clusters[extended], clusters))

    # The function only returns spikes that weren't in the passed spikes.
    assert len(np.intersect1d(extended, spike_ids)) == 0

    # Check that all spikes from our clusters have been selected.
    rest = np.setdiff1d(np.arange(n_spikes), extended)
    rest = np.setdiff1d(rest, spike_ids)
    assert not np.any(np.in1d(spike_clusters[rest], clusters))
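

# A minimal sketch of what the assertions above expect `_extend_spikes` to
# compute, reconstructed from this test for illustration (not the actual
# implementation): every spike whose cluster is touched by `spike_ids`, minus
# the spikes that were passed in.
def _extend_spikes_sketch(spike_ids, spike_clusters):
    clusters = np.unique(spike_clusters[spike_ids])
    in_clusters = np.nonzero(np.in1d(spike_clusters, clusters))[0]
    return np.setdiff1d(in_clusters, spike_ids)

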
def test_clustering_long():
    n_spikes = 1000
    n_clusters = 10
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)
    spike_clusters_base = spike_clusters.copy()

    # Instantiate a Clustering instance.
    clustering = Clustering(spike_clusters)
    ae(clustering.spike_clusters, spike_clusters)

    # Test the clustering.spikes_in_clusters() method.
    assert np.all(spike_clusters[clustering.spikes_in_clusters([5])] == 5)

    # Test cluster ids.
    ae(clustering.cluster_ids, np.arange(n_clusters))

    assert clustering.new_cluster_id() == n_clusters
    assert clustering.n_clusters == n_clusters

    # Updating a cluster, method 1.
    spike_clusters_new = spike_clusters.copy()
    spike_clusters_new[:10] = 100
    clustering.spike_clusters[:] = spike_clusters_new[:]
    # The cluster ids need to be updated explicitly.
    clustering._new_cluster_id = 101
    clustering._update_cluster_ids()
    ae(clustering.cluster_ids, np.r_[np.arange(n_clusters), 100])

    # Updating a cluster, method 2.
    clustering.spike_clusters[:] = spike_clusters_base[:]
    clustering.spike_clusters[:10] = 100
    # HACK: need to update manually here.
    clustering._new_cluster_id = 101
    ae(clustering.cluster_ids, np.r_[np.arange(n_clusters), 100])

    # Assign.
    new_cluster = 101
    clustering.assign(np.arange(0, 10), new_cluster)
    assert new_cluster in clustering.cluster_ids
    assert np.all(clustering.spike_clusters[:10] == new_cluster)

    # Merge.
    my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [2, 3]))[0]
    info = clustering.merge([2, 3])
    my_spikes = info.spike_ids
    ae(my_spikes, my_spikes_0)
    assert (new_cluster + 1) in clustering.cluster_ids
    assert np.all(clustering.spike_clusters[my_spikes] == (new_cluster + 1))

    # Merge to a given cluster.
    clustering.spike_clusters[:] = spike_clusters_base[:]
    clustering._new_cluster_id = 11

    my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [4, 6]))[0]
    info = clustering.merge([4, 6], 11)
    my_spikes = info.spike_ids
    ae(my_spikes, my_spikes_0)
    assert 11 in clustering.cluster_ids
    assert np.all(clustering.spike_clusters[my_spikes] == 11)

    # Split.
    my_spikes = [1, 3, 5]
    clustering.split(my_spikes)
    assert np.all(clustering.spike_clusters[my_spikes] == 12)

    # Assign.
    clusters = [0, 1, 2]
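    # When a list of relative labels is passed, assign() maps each distinct
    # label to a fresh cluster id, so only the relative structure of the
    # labels is preserved (checked below via offsets from the first new id).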
    clustering.assign(my_spikes, clusters)
    clu = clustering.spike_clusters[my_spikes]
    ae(clu - clu[0], clusters)
def test_clustering_assign():
    n_spikes = 1000
    n_clusters = 10
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)

    clustering = Clustering(spike_clusters)

    checkpoints = {}

    def _checkpoint(index=None):
        if index is None:
            index = len(checkpoints)
        checkpoints[index] = clustering.spike_clusters.copy()

    def _assert_is_checkpoint(index):
        ae(clustering.spike_clusters, checkpoints[index])

    @clustering.connect
    def on_request_undo_state(up):
        return 'hello'
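    # The value returned by each connected `on_request_undo_state` callback is
    # collected into a list and attached to the returned info object, but only
    # when the corresponding action is undone (see the `undo_state` assertions
    # below).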

    # Checkpoint 0.
    _checkpoint()
    _assert_is_checkpoint(0)

    my_spikes_1 = np.unique(np.random.randint(low=0, high=n_spikes, size=5))
    my_spikes_2 = np.unique(np.random.randint(low=0, high=n_spikes, size=10))
    my_spikes_3 = np.unique(np.random.randint(low=0, high=n_spikes, size=1000))
    my_spikes_4 = np.arange(n_spikes - 5)

    # Edge cases.
    clustering.assign([])
    with raises(ValueError):
        clustering.merge([], 1)

    # Checkpoint 1.
    info = clustering.split(my_spikes_1)
    _checkpoint()
    assert info.description == 'assign'
    assert 10 in info.added
    assert info.history is None
    _assert_is_checkpoint(1)

    # Checkpoint 2.
    info = clustering.split(my_spikes_2)
    assert info.description == 'assign'
    assert info.history is None
    _checkpoint()
    _assert_is_checkpoint(2)

    # Checkpoint 3.
    info = clustering.assign(my_spikes_3)
    assert info.description == 'assign'
    assert info.history is None
    assert info.undo_state is None
    _checkpoint()
    _assert_is_checkpoint(3)

    # Undo checkpoint 3.
    info = clustering.undo()
    assert info.description == 'assign'
    assert info.history == 'undo'
    assert info.undo_state == ['hello']
    _checkpoint()
    _assert_is_checkpoint(2)

    # Checkpoint 4.
    info = clustering.assign(my_spikes_4)
    assert info.description == 'assign'
    assert info.history is None
    _checkpoint(4)
    assert len(info.deleted) >= 2
    _assert_is_checkpoint(4)
def test_clustering_merge():
    n_spikes = 1000
    n_clusters = 10
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)

    clustering = Clustering(spike_clusters)
    spk0 = clustering.spikes_per_cluster[0]
    spk1 = clustering.spikes_per_cluster[1]

    checkpoints = {}

    def _checkpoint():
        index = len(checkpoints)
        checkpoints[index] = clustering.spike_clusters.copy()

    def _assert_is_checkpoint(index):
        ae(clustering.spike_clusters, checkpoints[index])

    def _assert_spikes(clusters):
        ae(info.spike_ids, _spikes_in_clusters(spike_clusters, clusters))

    @clustering.connect
    def on_request_undo_state(up):
        return 'hello'

    # Checkpoint 0.
    _checkpoint()
    _assert_is_checkpoint(0)

    # Checkpoint 1.
    info = clustering.merge([0, 1], 11)
    _checkpoint()
    _assert_spikes([11])
    ae(clustering.spikes_per_cluster[11], np.sort(np.r_[spk0, spk1]))
    assert 0 not in clustering.spikes_per_cluster
    assert info.added == [11]
    assert info.deleted == [0, 1]
    _assert_is_checkpoint(1)

    # Checkpoint 2.
    info = clustering.merge([2, 3], 12)
    _checkpoint()
    _assert_spikes([12])
    assert info.added == [12]
    assert info.deleted == [2, 3]
    assert info.history is None
    assert info.undo_state is None  # undo_state is only returned when undoing.
    _assert_is_checkpoint(2)

    # Undo once.
    info = clustering.undo()
    assert info.added == [2, 3]
    assert info.deleted == [12]
    assert info.history == 'undo'
    assert info.undo_state == ['hello']
    _assert_is_checkpoint(1)
    ae(clustering.spikes_per_cluster[11], np.sort(np.r_[spk0, spk1]))

    # Redo.
    info = clustering.redo()
    _assert_spikes([12])
    assert info.added == [12]
    assert info.deleted == [2, 3]
    assert info.history == 'redo'
    assert info.undo_state is None
    _assert_is_checkpoint(2)

    # No redo.
    info = clustering.redo()
    _assert_is_checkpoint(2)

    # Merge again.
    info = clustering.merge([4, 5, 6], 13)
    _checkpoint()
    _assert_spikes([13])
    assert info.added == [13]
    assert info.deleted == [4, 5, 6]
    assert info.history is None
    _assert_is_checkpoint(3)

    # One more merge.
    info = clustering.merge([8, 7])  # merged to 14
    _checkpoint()
    _assert_spikes([14])
    assert info.added == [14]
    assert info.deleted == [7, 8]
    assert info.history is None
    _assert_is_checkpoint(4)

    # Now we undo.
    info = clustering.undo()
    assert info.added == [7, 8]
    assert info.deleted == [14]
    assert info.history == 'undo'
    _assert_is_checkpoint(3)

    # We merge again.
    # NOTE: cluster id 14 has been wasted by the undo, so the next merge moves
    # on to 15: ids are never reused, which avoids explicit cache invalidation
    # when caching cluster-id-based functions.
    assert clustering.new_cluster_id() == 15
    assert any(clustering.spike_clusters == 13)
    assert all(clustering.spike_clusters != 14)
    info = clustering.merge([8, 7], 15)
    _assert_spikes([15])
    assert info.added == [15]
    assert info.deleted == [7, 8]
    assert info.history is None
    # Same as checkpoint with 4, but replace 14 with 15.
    res = checkpoints[4]
    res[res == 14] = 15
    ae(clustering.spike_clusters, res)

    # Undo all.
    for i in range(3, -1, -1):
        info = clustering.undo()
        _assert_is_checkpoint(i)

    _assert_is_checkpoint(0)

    # Redo all.
    for i in range(5):
        _assert_is_checkpoint(i)
        info = clustering.redo()
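

# The undo/redo behaviour exercised above follows the usual checkpoint-stack
# pattern: undo steps back, redo replays, and a new action taken after an undo
# discards the remaining redo branch. Cluster ids are additionally never reused
# after an undo (hence the jump from 14 to 15 above), as noted in the test.
# A minimal, hypothetical sketch of the checkpoint-stack pattern -- not phy's
# actual history class:
class _HistorySketch(object):
    def __init__(self, initial_state):
        self._states = [initial_state]
        self._index = 0

    def add(self, state):
        # A new action drops every state that could still have been redone.
        self._states = self._states[:self._index + 1]
        self._states.append(state)
        self._index += 1

    def undo(self):
        self._index = max(self._index - 1, 0)
        return self._states[self._index]

    def redo(self):
        self._index = min(self._index + 1, len(self._states) - 1)
        return self._states[self._index]

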
def test_trace_view(tempdir, qtbot):
    nc = 5
    ns = 9
    sr = 1000.
    ch = list(range(nc))
    duration = 1.
    st = np.linspace(0.1, .9, ns)
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)
    cs = ColorSelector()

    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    sw = _iter_spike_waveforms(
        interval=[0., 1.],
        traces_interval=traces,
        model=m,
        supervisor=s,
        n_samples_waveforms=ns,
        get_best_channels=lambda cluster_id: ch,
        color_selector=cs,
    )
    assert len(list(sw))
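
    # The traces callback below returns a Bunch with the raw `data` for the
    # requested interval, a base `color`, and a `waveforms` list of per-spike
    # Bunches (data, start_time, color, channel_ids, spike_id, spike_cluster),
    # which TraceView presumably uses to overlay spike waveforms on the traces.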

    def get_traces(interval):
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, ) * 4,
        )
        a, b = st.searchsorted(interval)
        out.waveforms = []
        k = 20
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            s = int(round(t * sr))
            d = Bunch(
                data=traces[s - k:s + k, :],
                start_time=t - k / sr,
                color=cs.get(c),
                channel_ids=np.arange(5),
                spike_id=i,
                spike_cluster=c,
            )
            out.waveforms.append(d)
        return out

    v = TraceView(
        traces=get_traces,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_vertical_order=np.arange(nc)[::-1],
    )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    # qtbot.waitForWindowShown(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # ac(v.stacked.box_size, (1., .08181), atol=1e-3)
    v.set_interval((.375, .625))
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    v.go_to(-.5)
    assert v.time == .125

    v.go_left()
    assert v.time == .125

    v.go_right()
    assert v.time == .175

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()

    v.toggle_show_labels()
    # v.toggle_show_labels()
    v.go_right()
    assert v.do_show_labels

    # Change channel scaling.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # Simulate spike selection.
    _clicked = []

    @v.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    v.events.key_press(key=keys.Key('Control'))
    v.events.mouse_press(pos=(400., 200.),
                         button=1,
                         modifiers=(keys.CONTROL, ))
    v.events.key_release(key=keys.Key('Control'))

    assert _clicked == [(1, 4, 1)]

    # qtbot.stop()
    gui.close()
def test_feature_view(qtbot, tempdir, n_channels):
    nc = n_channels
    ns = 500
    features = artificial_features(ns, nc, 4)
    spike_clusters = artificial_spike_clusters(ns, 4)
    spike_times = np.linspace(0., 1., ns)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        return (spc[cluster_id] if cluster_id is not None else np.arange(ns))

    def get_features(cluster_id=None,
                     channel_ids=None,
                     spike_ids=None,
                     load_all=None):
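        # Assumption based on how this mock is written: `load_all` asks for
        # every spike of the cluster (e.g. when requesting a split), rather
        # than the subselection used for display.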
        if load_all:
            spike_ids = spc[cluster_id]
        else:
            spike_ids = get_spike_ids(cluster_id)
        return Bunch(
            data=features[spike_ids],
            spike_ids=spike_ids,
            masks=np.random.rand(ns, nc),
            channel_ids=(channel_ids
                         if channel_ids is not None else np.arange(nc)[::-1]),
        )

    def get_time(cluster_id=None, load_all=None):
        return Bunch(
            data=spike_times[get_spike_ids(cluster_id)],
            lim=(0., 1.),
        )

    v = FeatureView(
        features=get_features,
        attributes={'time': get_time},
    )

    v.set_state(GUIState(scaling=None))

    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    gui.emit('select', [0, 2])
    qtbot.wait(10)

    v.increase()
    v.decrease()

    v.on_channel_click(channel_id=3, button=1, key=2)
    v.clear_channels()
    v.toggle_automatic_channel_selection()

    # Split without selection.
    spike_ids = v.on_request_split()
    assert len(spike_ids) == 0

    # Draw a lasso.
    def _click(x, y):
        qtbot.mouseClick(v.native,
                         Qt.LeftButton,
                         pos=QPoint(x, y),
                         modifier=Qt.ControlModifier)

    _click(10, 10)
    _click(10, 100)
    _click(100, 100)
    _click(100, 10)

    # Split lassoed points.
    spike_ids = v.on_request_split()
    assert len(spike_ids) > 0

    # qtbot.stop()
    gui.close()