Example #1
class MyModel(object):
    seed = np.random.seed(0)  # seeds the global RNG; np.random.seed() returns None
    n_channels = 8
    n_spikes = 20000
    n_clusters = 32
    n_templates = n_clusters
    n_pcs = 5
    n_samples_waveforms = 100
    channel_positions = np.random.normal(size=(n_channels, 2))
    channel_mapping = np.arange(0, n_channels)
    channel_shanks = np.zeros(n_channels, dtype=np.int32)
    features = artificial_features(n_spikes, n_channels, n_pcs)
    metadata = {'group': {3: 'noise', 4: 'mua', 5: 'good'}}
    sample_rate = 10000
    spike_attributes = {}
    amplitudes = np.random.normal(size=n_spikes, loc=1, scale=.1)
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)
    spike_templates = spike_clusters
    spike_samples = artificial_spike_samples(n_spikes)
    spike_times = spike_samples / sample_rate
    spike_times_reordered = artificial_spike_samples(n_spikes) / sample_rate
    duration = spike_times[-1]
    spike_waveforms = None
    traces = artificial_traces(int(sample_rate * duration), n_channels)

    def _get_some_channels(self, offset, size):
        return list(
            islice(cycle(range(self.n_channels)), offset, offset + size))

    def get_features(self, spike_ids, channel_ids):
        return artificial_features(len(spike_ids), len(channel_ids),
                                   self.n_pcs)

    def get_waveforms(self, spike_ids, channel_ids):
        n_channels = len(channel_ids) if channel_ids is not None else self.n_channels
        return artificial_waveforms(len(spike_ids), self.n_samples_waveforms,
                                    n_channels)

    def get_template(self, template_id):
        nc = self.n_channels // 2
        return Bunch(template=artificial_waveforms(1, self.n_samples_waveforms,
                                                   nc)[0, ...],
                     channel_ids=self._get_some_channels(template_id, nc))

    def save_spike_clusters(self, spike_clusters):
        pass

    def save_metadata(self, name, values):
        pass
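
A minimal usage sketch of the mock model above (illustration only; it assumes the artificial_* helpers return arrays of the requested shapes, as the class attributes suggest):

model = MyModel()
assert model.features.shape == (model.n_spikes, model.n_channels, model.n_pcs)

# get_template() returns a Bunch holding the template waveform and its channels.
b = model.get_template(3)
assert b.template.shape == (model.n_samples_waveforms, model.n_channels // 2)
assert len(b.channel_ids) == model.n_channels // 2
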
Example #2
def test_extend_spikes():
    n_spikes = 1000
    n_clusters = 10
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)

    spike_ids = np.unique(np.random.randint(size=5, low=0, high=n_spikes))

    # These spikes belong to the following clusters.
    clusters = np.unique(spike_clusters[spike_ids])

    # These are the spikes belonging to those clusters, but not in the
    # originally-specified spikes.
    extended = _extend_spikes(spike_ids, spike_clusters)
    assert np.all(np.in1d(spike_clusters[extended], clusters))

    # The function only returns spikes that were not among the passed spike ids.
    assert len(np.intersect1d(extended, spike_ids)) == 0

    # Check that all spikes from our clusters have been selected.
    rest = np.setdiff1d(np.arange(n_spikes), extended)
    rest = np.setdiff1d(rest, spike_ids)
    assert not np.any(np.in1d(spike_clusters[rest], clusters))
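
For reference, a sketch of the contract this test checks, reconstructed from the assertions above (an assumption for illustration, not the actual _extend_spikes implementation):

def _extend_spikes_sketch(spike_ids, spike_clusters):
    # Clusters touched by the selected spikes.
    clusters = np.unique(spike_clusters[spike_ids])
    # All spikes belonging to those clusters...
    candidates = np.nonzero(np.in1d(spike_clusters, clusters))[0]
    # ...minus the spikes that were passed in.
    return np.setdiff1d(candidates, spike_ids)
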
Example #3
def test_iter_spike_waveforms():
    nc = 5
    ns = 20
    sr = 2000.
    ch = list(range(nc))
    duration = 1.
    st = np.linspace(0.1, .9, ns)
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)

    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    for w in _iter_spike_waveforms(
            interval=[0., 1.],
            traces_interval=traces,
            model=m,
            supervisor=s,
            n_samples_waveforms=ns,
            show_all_spikes=True,
            get_best_channels=lambda cluster_id: ch,
    ):
        assert w
Example #4
def test_raster_1(qtbot, gui):
    ns = 10000
    nc = 100
    spike_times = artificial_spike_samples(ns) / 20000.
    spike_clusters = artificial_spike_clusters(ns, nc)
    cluster_ids = np.arange(4)

    v = RasterView(spike_times, spike_clusters)

    @v.add_color_scheme(name='group',
                        cluster_ids=cluster_ids,
                        colormap='cluster_group',
                        categorical=True)
    def cg(cluster_id):
        return cluster_id % 4

    v.add_color_scheme(lambda cid: cid,
                       name='random',
                       cluster_ids=cluster_ids,
                       colormap='categorical',
                       categorical=True)

    v.show()
    qtbot.waitForWindowShown(v.canvas)
    v.attach(gui)

    v.set_cluster_ids(cluster_ids)
    v.plot()
    v.on_select(cluster_ids=[0])

    v.update_cluster_sort(np.arange(nc))

    v.set_cluster_ids(np.arange(0, nc, 2))
    v.update_color()
    v.plot()

    _stop_and_close(qtbot, v)
Example #5
def test_raster_1(qtbot, gui):
    ns = 10000
    nc = 100
    spike_times = artificial_spike_samples(ns) / 20000.
    spike_clusters = artificial_spike_clusters(ns, nc)
    cluster_ids = np.arange(4)

    cluster_meta = Bunch(fields=('group', ), get=lambda f, cl: cl % 4)
    cluster_metrics = {'quality': lambda c: 100 - c}
    c = ClusterColorSelector(cluster_meta=cluster_meta,
                             cluster_metrics=cluster_metrics,
                             cluster_ids=cluster_ids)

    class Supervisor(object):
        pass

    s = Supervisor()

    v = RasterView(spike_times, spike_clusters, cluster_color_selector=c)
    v.show()
    qtbot.waitForWindowShown(v.canvas)
    v.attach(gui)

    v.set_cluster_ids(cluster_ids)
    v.plot()
    v.on_select(cluster_ids=[0], sender=s)

    v.update_cluster_sort(np.arange(nc))

    c.set_color_mapping('group', 'cluster_group')
    v.update_color()

    v.set_cluster_ids(np.arange(0, nc, 2))
    v.plot()

    _stop_and_close(qtbot, v)
Example #6
def test_clustering_long():
    n_spikes = 1000
    n_clusters = 10
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)
    spike_clusters_base = spike_clusters.copy()

    # Instantiate a Clustering instance.
    clustering = Clustering(spike_clusters)
    ae(clustering.spike_clusters, spike_clusters)

    # Test the clustering.spikes_in_clusters() function.
    assert np.all(spike_clusters[clustering.spikes_in_clusters([5])] == 5)

    # Test cluster ids.
    ae(clustering.cluster_ids, np.arange(n_clusters))

    assert clustering.new_cluster_id() == n_clusters
    assert clustering.n_clusters == n_clusters

    # Updating a cluster, method 1.
    spike_clusters_new = spike_clusters.copy()
    spike_clusters_new[:10] = 100
    clustering.spike_clusters[:] = spike_clusters_new[:]
    # Need to update explicitly.
    clustering._new_cluster_id = 101
    clustering._update_cluster_ids()
    ae(clustering.cluster_ids, np.r_[np.arange(n_clusters), 100])

    # Updating a cluster, method 2.
    clustering.spike_clusters[:] = spike_clusters_base[:]
    clustering.spike_clusters[:10] = 100
    # HACK: need to update manually here.
    clustering._new_cluster_id = 101
    ae(clustering.cluster_ids, np.r_[np.arange(n_clusters), 100])

    # Assign.
    new_cluster = 101
    clustering.assign(np.arange(0, 10), new_cluster)
    assert new_cluster in clustering.cluster_ids
    assert np.all(clustering.spike_clusters[:10] == new_cluster)

    # Merge.
    my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [2, 3]))[0]
    info = clustering.merge([2, 3])
    my_spikes = info.spike_ids
    ae(my_spikes, my_spikes_0)
    assert (new_cluster + 1) in clustering.cluster_ids
    assert np.all(clustering.spike_clusters[my_spikes] == (new_cluster + 1))

    # Merge to a given cluster.
    clustering.spike_clusters[:] = spike_clusters_base[:]
    clustering._new_cluster_id = 11

    my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [4, 6]))[0]
    info = clustering.merge([4, 6], 11)
    my_spikes = info.spike_ids
    ae(my_spikes, my_spikes_0)
    assert 11 in clustering.cluster_ids
    assert np.all(clustering.spike_clusters[my_spikes] == 11)

    # Split.
    my_spikes = [1, 3, 5]
    clustering.split(my_spikes)
    assert np.all(clustering.spike_clusters[my_spikes] == 12)

    # Assign.
    clusters = [0, 1, 2]
    clustering.assign(my_spikes, clusters)
    clu = clustering.spike_clusters[my_spikes]
    ae(clu - clu[0], clusters)
Example #7
def test_clustering_assign():
    n_spikes = 1000
    n_clusters = 10
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)

    clustering = Clustering(spike_clusters)

    checkpoints = {}

    def _checkpoint(index=None):
        if index is None:
            index = len(checkpoints)
        checkpoints[index] = clustering.spike_clusters.copy()

    def _assert_is_checkpoint(index):
        ae(clustering.spike_clusters, checkpoints[index])

    @connect(sender=clustering)
    def on_request_undo_state(sender, up):
        return 'hello'

    # Checkpoint 0.
    _checkpoint()
    _assert_is_checkpoint(0)

    my_spikes_1 = np.unique(np.random.randint(low=0, high=n_spikes, size=5))
    my_spikes_2 = np.unique(np.random.randint(low=0, high=n_spikes, size=10))
    my_spikes_3 = np.unique(np.random.randint(low=0, high=n_spikes, size=1000))
    my_spikes_4 = np.arange(n_spikes - 5)

    # Edge cases.
    clustering.assign([])
    with raises(ValueError):
        clustering.merge([], 1)

    # Checkpoint 1.
    info = clustering.split(my_spikes_1)
    _checkpoint()
    assert info.description == 'assign'
    assert 10 in info.added
    assert info.history is None
    _assert_is_checkpoint(1)

    # Checkpoint 2.
    info = clustering.split(my_spikes_2)
    assert info.description == 'assign'
    assert info.history is None
    _checkpoint()
    _assert_is_checkpoint(2)

    # Checkpoint 3.
    info = clustering.assign(my_spikes_3)
    assert info.description == 'assign'
    assert info.history is None
    assert info.undo_state is None
    _checkpoint()
    _assert_is_checkpoint(3)

    # Undo checkpoint 3.
    info = clustering.undo()
    assert info.description == 'assign'
    assert info.history == 'undo'
    assert info.undo_state == ['hello']
    _checkpoint()
    _assert_is_checkpoint(2)

    # Checkpoint 4.
    info = clustering.assign(my_spikes_4)
    assert info.description == 'assign'
    assert info.history is None
    _checkpoint(4)
    assert len(info.deleted) >= 2
    _assert_is_checkpoint(4)
Example #8
def test_clustering_merge():
    n_spikes = 1000
    n_clusters = 10
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)

    clustering = Clustering(spike_clusters)
    spk0 = clustering.spikes_per_cluster[0]
    spk1 = clustering.spikes_per_cluster[1]

    checkpoints = {}

    def _checkpoint():
        index = len(checkpoints)
        checkpoints[index] = clustering.spike_clusters.copy()

    def _assert_is_checkpoint(index):
        ae(clustering.spike_clusters, checkpoints[index])

    def _assert_spikes(clusters):
        ae(info.spike_ids, _spikes_in_clusters(spike_clusters, clusters))

    @connect(sender=clustering)
    def on_request_undo_state(sender, up):
        return 'hello'

    # Checkpoint 0.
    _checkpoint()
    _assert_is_checkpoint(0)

    # Checkpoint 1.
    info = clustering.merge([0, 1], 11)
    _checkpoint()
    _assert_spikes([11])
    ae(clustering.spikes_per_cluster[11], np.sort(np.r_[spk0, spk1]))
    assert 0 not in clustering.spikes_per_cluster
    assert info.added == [11]
    assert info.deleted == [0, 1]
    _assert_is_checkpoint(1)

    # Checkpoint 2.
    info = clustering.merge([2, 3], 12)
    _checkpoint()
    _assert_spikes([12])
    assert info.added == [12]
    assert info.deleted == [2, 3]
    assert info.history is None
    assert info.undo_state is None  # undo_state is only returned when undoing.
    _assert_is_checkpoint(2)

    # Undo once.
    info = clustering.undo()
    assert info.added == [2, 3]
    assert info.deleted == [12]
    assert info.history == 'undo'
    assert info.undo_state == ['hello']
    _assert_is_checkpoint(1)
    ae(clustering.spikes_per_cluster[11], np.sort(np.r_[spk0, spk1]))

    # Redo.
    info = clustering.redo()
    _assert_spikes([12])
    assert info.added == [12]
    assert info.deleted == [2, 3]
    assert info.history == 'redo'
    assert info.undo_state is None
    _assert_is_checkpoint(2)

    # No redo.
    info = clustering.redo()
    _assert_is_checkpoint(2)

    # Merge again.
    info = clustering.merge([4, 5, 6], 13)
    _checkpoint()
    _assert_spikes([13])
    assert info.added == [13]
    assert info.deleted == [4, 5, 6]
    assert info.history is None
    _assert_is_checkpoint(3)

    # One more merge.
    info = clustering.merge([8, 7])  # merged to 14
    _checkpoint()
    _assert_spikes([14])
    assert info.added == [14]
    assert info.deleted == [7, 8]
    assert info.history is None
    _assert_is_checkpoint(4)

    # Now we undo.
    info = clustering.undo()
    assert info.added == [7, 8]
    assert info.deleted == [14]
    assert info.history == 'undo'
    _assert_is_checkpoint(3)

    # We merge again.
    # NOTE: cluster id 14 has been wasted, so we move on to 15. This is
    # necessary to avoid explicit cache invalidation when caching functions
    # keyed on cluster id.
    assert clustering.new_cluster_id() == 15
    assert any(clustering.spike_clusters == 13)
    assert all(clustering.spike_clusters != 14)
    info = clustering.merge([8, 7], 15)
    _assert_spikes([15])
    assert info.added == [15]
    assert info.deleted == [7, 8]
    assert info.history is None
    # Same as checkpoint with 4, but replace 14 with 15.
    res = checkpoints[4]
    res[res == 14] = 15
    ae(clustering.spike_clusters, res)

    # Undo all.
    for i in range(3, -1, -1):
        info = clustering.undo()
        _assert_is_checkpoint(i)

    _assert_is_checkpoint(0)

    # Redo all.
    for i in range(5):
        _assert_is_checkpoint(i)
        info = clustering.redo()
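
Condensed sketch of the undo/redo pattern exercised above (assumes Clustering, connect, and artificial_spike_clusters behave as in these tests; the returned state dict is a hypothetical example):

sc = artificial_spike_clusters(100, 5)
clustering = Clustering(sc)

@connect(sender=clustering)
def on_request_undo_state(sender, up):
    # Whatever is returned here is stored with the action and handed back
    # in info.undo_state when that action is undone.
    return {'selection': [0]}

info = clustering.merge([0, 1])   # merged into a fresh cluster id (info.added)
info = clustering.undo()          # info.undo_state == [{'selection': [0]}]
info = clustering.redo()          # info.undo_state is None on redo
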
Example #9
def test_feature_view(qtbot, gui, n_channels):
    nc = n_channels
    ns = 10000
    features = artificial_features(ns, nc, 4)
    spike_clusters = artificial_spike_clusters(ns, 4)
    spike_times = np.linspace(0., 1., ns)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        return (spc[cluster_id] if cluster_id is not None else np.arange(ns))

    def get_features(cluster_id=None,
                     channel_ids=None,
                     spike_ids=None,
                     load_all=None):
        if load_all:
            spike_ids = spc[cluster_id]
        else:
            spike_ids = get_spike_ids(cluster_id)
        return Bunch(
            data=features[spike_ids],
            spike_ids=spike_ids,
            masks=np.random.rand(ns, nc),
            channel_ids=(channel_ids
                         if channel_ids is not None else np.arange(nc)[::-1]),
        )

    def get_time(cluster_id=None, load_all=None):
        return Bunch(data=spike_times[get_spike_ids(cluster_id)], lim=(0., 1.))

    v = FeatureView(features=get_features, attributes={'time': get_time})
    v.show()
    qtbot.waitForWindowShown(v.canvas)
    v.attach(gui)

    v.set_grid_dim(_get_default_grid())

    v.on_select(cluster_ids=[])
    v.on_select(cluster_ids=[0])
    v.on_select(cluster_ids=[0, 2, 3])
    v.on_select(cluster_ids=[0, 2])

    assert v.status

    v.increase()
    v.decrease()

    v.increase_marker_size()
    v.decrease_marker_size()

    v.on_select_channel(channel_id=3, button='Left', key=None)
    v.on_select_channel(channel_id=3, button='Right', key=None)
    v.on_select_channel(channel_id=3, button='Right', key=2)
    v.clear_channels()
    v.toggle_automatic_channel_selection(True)

    # Test feature selection with Alt+click.
    _l = []

    @connect(sender=v)
    def on_select_feature(sender, dim=None, channel_id=None, pc=None):
        _l.append((dim, channel_id, pc))

    for i, j, dim_x, dim_y in v._iter_subplots():
        for k, button in enumerate(('Left', 'Right')):
            # Click on the center of every subplot.
            w, h = v.canvas.get_size()
            w, h = w / 4, h / 4
            x, y = w / 2, h / 2
            mouse_click(qtbot,
                        v.canvas, (x + j * w, y + i * h),
                        button=button,
                        modifiers=('Alt', ))
            assert _l[-1][0] == v.grid_dim[i][j].split(',')[k]

    # Split without selection.
    spike_ids = v.on_request_split()
    assert len(spike_ids) == 0

    a, b = 10, 100
    mouse_click(qtbot, v.canvas, (a, a), modifiers=('Control', ))
    mouse_click(qtbot, v.canvas, (a, b), modifiers=('Control', ))
    mouse_click(qtbot, v.canvas, (b, b), modifiers=('Control', ))
    mouse_click(qtbot, v.canvas, (b, a), modifiers=('Control', ))

    # Split lassoed points.
    spike_ids = v.on_request_split()

    # HACK: this seems to fail because qtbot.mouseClick is not working??
    # assert len(spike_ids) > 0

    v.set_state(v.state)

    _stop_and_close(qtbot, v)
Example #10
def test_trace_view_1(qtbot, tempdir, gui):
    nc = 5
    ns = 20
    sr = 2000.
    duration = 1.
    st = np.linspace(0.1, .9, ns)
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)

    def get_traces(interval):
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, .75, .75, 1),
        )
        a, b = st.searchsorted(interval)
        out.waveforms = []
        k = 20
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            s = int(round(t * sr))
            d = Bunch(
                data=traces[s - k:s + k, :],
                start_time=(s - k) / sr,
                channel_ids=np.arange(5),
                spike_id=i,
                spike_cluster=c,
                select_index=0,
            )
            out.waveforms.append(d)
        return out

    def get_spike_times():
        return st

    v = TraceView(
        traces=get_traces,
        spike_times=get_spike_times,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_positions=linear_positions(nc),
    )
    v.show()
    qtbot.waitForWindowShown(v.canvas)
    v.attach(gui)

    v.on_select(cluster_ids=[])
    v.on_select(cluster_ids=[0])
    v.on_select(cluster_ids=[0, 2, 3])
    v.on_select(cluster_ids=[0, 2])

    v.stacked.add_boxes(v.canvas)

    ac(v.stacked.box_size, (.950, .165), atol=1e-3)
    v.set_interval((.375, .625))
    assert v.time == .5
    qtbot.wait(1)

    v.go_to(.25)
    assert v.time == .25
    qtbot.wait(1)

    v.go_to(-.5)
    assert v.time == .125
    qtbot.wait(1)

    v.go_left()
    assert v.time == .125
    qtbot.wait(1)

    v.go_right()
    ac(v.time, .150)
    qtbot.wait(1)

    v.jump_left()
    qtbot.wait(1)

    v.jump_right()
    qtbot.wait(1)

    v.go_to_next_spike()
    qtbot.wait(1)

    v.go_to_previous_spike()
    qtbot.wait(1)

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    qtbot.wait(1)

    v.widen()
    ac(v.interval, (.1875, .8125))
    qtbot.wait(1)

    v.narrow()
    ac(v.interval, (.25, .75))
    qtbot.wait(1)

    v.go_to_start()
    qtbot.wait(1)
    assert v.interval[0] == 0

    v.go_to_end()
    qtbot.wait(1)
    assert v.interval[1] == duration

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()
    qtbot.wait(1)

    v.toggle_show_labels(True)
    v.go_right()

    # Check auto scaling.
    db = v.data_bounds
    v.toggle_auto_scale(False)
    v.narrow()
    qtbot.wait(1)
    # Check that ymin and ymax have not changed.
    assert v.data_bounds[1] == db[1]
    assert v.data_bounds[3] == db[3]

    v.toggle_auto_update(True)
    assert v.do_show_labels
    qtbot.wait(1)

    v.toggle_highlighted_spikes(True)
    qtbot.wait(50)

    # Change channel scaling.
    bs = v.stacked.box_size
    v.decrease()
    qtbot.wait(1)

    v.increase()
    ac(v.stacked.box_size, bs, atol=.05)
    qtbot.wait(1)

    v.origin = 'bottom'
    v.switch_origin()
    assert v.origin == 'top'
    qtbot.wait(1)

    # Simulate spike selection.
    _clicked = []

    @connect(sender=v)
    def on_select_spike(sender,
                        channel_id=None,
                        spike_id=None,
                        cluster_id=None,
                        key=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    mouse_click(qtbot,
                v.canvas,
                pos=(0., 0.),
                button='Left',
                modifiers=('Control', ))

    v.set_state(v.state)

    assert len(_clicked[0]) == 3

    # Simulate channel selection.
    _clicked = []

    @connect(sender=v)
    def on_select_channel(sender, channel_id=None, button=None):
        _clicked.append((channel_id, button))

    mouse_click(qtbot,
                v.canvas,
                pos=(0., 0.),
                button='Left',
                modifiers=('Shift', ))
    mouse_click(qtbot,
                v.canvas,
                pos=(0., 0.),
                button='Right',
                modifiers=('Shift', ))

    assert _clicked == [(2, 'Left'), (2, 'Right')]

    _stop_and_close(qtbot, v)