コード例 #1
0
ファイル: conftest.py プロジェクト: pombredanne/phy
    def _init_data(self):
        """Populate ``self`` with a small synthetic dataset for the tests.

        Sizes: ``n_clusters * n_spikes_per_cluster`` spikes (4 * 50 = 200),
        ``n_samples_t`` samples at ``sample_rate`` (20000 / 20000., i.e. a
        duration of 1.0), ``n_channels`` channels, ``n_samples_waveforms``
        samples per waveform and ``n_features_per_channel`` features per
        channel. Traces, masks, waveforms and features come from the
        ``artificial_*`` helpers.

        Only sets attributes on ``self``; returns None.
        """
        self.cache_dir = self.config_dir
        self.n_samples_waveforms = 31
        self.n_samples_t = 20000
        self.n_channels = 11
        self.n_clusters = 4
        self.n_spikes_per_cluster = 50
        n_spikes_total = self.n_clusters * self.n_spikes_per_cluster
        n_features_per_channel = 4

        self.n_spikes = n_spikes_total
        self.sample_rate = 20000.
        self.duration = self.n_samples_t / float(self.sample_rate)
        # One spike every 100 samples over the whole duration, which gives
        # n_samples_t / 100 == 200 == n_spikes_total spike times.
        self.spike_times = np.arange(0, self.duration, 100. / self.sample_rate)
        # Cluster ids in contiguous runs: 0 repeated n_spikes_per_cluster
        # times, then 1, and so on.
        self.spike_clusters = np.repeat(np.arange(self.n_clusters),
                                        self.n_spikes_per_cluster)
        # Sanity check that the constants above stay consistent.
        assert len(self.spike_times) == len(self.spike_clusters)
        self.cluster_ids = np.unique(self.spike_clusters)
        self.channel_positions = staggered_positions(self.n_channels)

        # The lambdas close over this local array rather than reading
        # self.spike_clusters at call time.
        sc = self.spike_clusters
        self.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c])
        self.spike_count = lambda c: len(self.spikes_per_cluster(c))
        self.n_features_per_channel = n_features_per_channel
        # Every cluster starts with a group of None.
        self.cluster_groups = {c: None for c in range(self.n_clusters)}

        self.all_traces = artificial_traces(self.n_samples_t, self.n_channels)
        self.all_masks = artificial_masks(n_spikes_total, self.n_channels)
        self.all_waveforms = artificial_waveforms(n_spikes_total,
                                                  self.n_samples_waveforms,
                                                  self.n_channels)
        self.all_features = artificial_features(n_spikes_total,
                                                self.n_channels,
                                                self.n_features_per_channel)
コード例 #2
0
ファイル: test_waveform.py プロジェクト: stephenlenzi/phy
def test_edges():
    """Check WaveformLoader._load_at on out-of-range times and near edges.

    A filtering loader must reject a time far past the end of the trace. For
    times at or near either edge, it must still return a chunk of
    ``n_samples_waveforms + filter_margin`` rows by ``n_channels`` columns.
    """
    n_samples_trace = 100
    n_channels = 10
    n_samples_waveforms = 20
    filter_margin = 10

    traces = artificial_traces(n_samples_trace, n_channels)

    # Third-order band-pass filter applied to each loaded chunk.
    b_filter = bandpass_filter(rate=1000, low=50, high=200, order=3)

    def _filter(chunk):
        return apply_filter(chunk, b_filter)

    loader = WaveformLoader(traces,
                            n_samples_waveforms=n_samples_waveforms,
                            filter=_filter,
                            filter_margin=filter_margin)

    # A time far beyond the 100-sample trace is invalid.
    with raises(ValueError):
        loader._load_at(200000)

    # At the very start, just inside it, just before the end and at the last
    # sample, the chunk keeps its full size.
    expected_shape = (n_samples_waveforms + filter_margin, n_channels)
    for sample in (0, 5, n_samples_trace - 5, n_samples_trace - 1):
        assert loader._load_at(sample).shape == expected_shape
コード例 #3
0
ファイル: test_waveform.py プロジェクト: stephenlenzi/phy
def waveform_loader(request):
    """Yield a Bunch wrapping a WaveformLoader over artificial traces.

    ``request.param`` is a ``(scale_factor, dc_offset)`` pair that is passed
    to the loader. The yielded Bunch exposes ``loader``,
    ``n_samples_waveforms`` (20), ``n_spikes`` (1000 // 40 == 25) and the
    generated ``spike_samples``.
    """
    scale_factor, dc_offset = request.param

    n_samples_trace = 1000
    n_channels = 5
    half_width = 10
    n_samples_waveforms = half_width * 2
    n_spikes = n_samples_trace // (n_samples_waveforms * 2)

    traces = artificial_traces(n_samples_trace, n_channels)
    spike_samples = artificial_spike_samples(
        n_spikes, max_isi=n_samples_waveforms * 2)

    # A loader built from the traces alone must raise ValueError.
    with raises(ValueError):
        WaveformLoader(traces)

    loader = WaveformLoader(traces=traces,
                            n_samples_waveforms=n_samples_waveforms,
                            scale_factor=scale_factor,
                            dc_offset=dc_offset)
    yield Bunch(loader=loader,
                n_samples_waveforms=n_samples_waveforms,
                n_spikes=n_spikes,
                spike_samples=spike_samples)
コード例 #4
0
def raw_dataset(request):
    """Yield a raw-recording ``Bunch`` for the data kind in ``request.param``.

    ``request.param`` selects the data:

    * ``'real'``: the first 45000 samples of the downloaded 32-channel
      ``test-32ch-10s.dat`` recording (int16), with the ``1x32_buzsaki``
      probe and ``params['use_single_threshold']`` set to False;
    * ``'null'``: a 25000 x 4 array of zeros;
    * ``'artificial'``: 25000 x 4 artificial traces, multiplied by 5 on
      channel 1 for samples 5000:5010 and on channel 3 for samples
      15000:15010.

    ``'null'`` and ``'artificial'`` share a small two-group probe: channels
    0-2 fully connected, and channel 3 alone with an explicit geometry.

    Yields:
        Bunch with ``n_channels``, ``n_samples``, ``sample_rate``,
        ``n_samples_waveforms``, ``traces``, ``params`` and ``probe``.

    Raises:
        ValueError: if ``request.param`` is not one of the kinds above.
            Without this check the locals ``traces``/``n_channels`` would be
            unbound and the fixture would fail with an UnboundLocalError.
    """
    known_types = ('real', 'null', 'artificial')
    data_type = request.param
    if data_type not in known_types:
        raise ValueError("unknown raw_dataset type %r; expected one of %r"
                         % (data_type, known_types))

    sample_rate = 20000
    params = _load_default_settings()['spikedetekt']
    if data_type == 'real':
        path = download_test_data('test-32ch-10s.dat')
        traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32))
        traces = traces[:45000]
        n_samples, n_channels = traces.shape
        # NOTE(review): this mutates the dict returned by
        # _load_default_settings(); confirm it is a fresh object per call.
        params['use_single_threshold'] = False
        probe = load_probe('1x32_buzsaki')
    else:
        probe = {
            'channel_groups': {
                0: {
                    'channels': [0, 1, 2],
                    'graph': [[0, 1], [0, 2], [1, 2]],
                },
                1: {
                    'channels': [3],
                    'graph': [],
                    'geometry': {
                        3: [0., 0.]
                    },
                }
            }
        }
        n_samples, n_channels = 25000, 4
        if data_type == 'null':
            traces = np.zeros((n_samples, n_channels))
        else:  # 'artificial'
            traces = artificial_traces(n_samples, n_channels)
            # Two short bursts of amplified signal on channels 1 and 3.
            traces[5000:5010, 1] *= 5
            traces[15000:15010, 3] *= 5
    # Waveform length: samples extracted before plus after each spike.
    n_samples_w = params['extract_s_before'] + params['extract_s_after']
    yield Bunch(
        n_channels=n_channels,
        n_samples=n_samples,
        sample_rate=sample_rate,
        n_samples_waveforms=n_samples_w,
        traces=traces,
        params=params,
        probe=probe,
    )
コード例 #5
0
ファイル: test_waveform.py プロジェクト: kwikteam/phy
def waveform_loader(do_filter=False, mask_threshold=None):
    """Return a WaveformLoader over small artificial data.

    The data are 1000 samples x 5 channels of artificial traces at a sample
    rate of 2000., with ``1000 // 40 == 25`` artificial spike samples
    (``max_isi=40``) and 20-sample waveforms.

    Args:
        do_filter: when true the loader is given ``filter_order=3``,
            otherwise ``filter_order=None``.
        mask_threshold: accepted for call compatibility; not used here.
    """
    sample_rate = 2000.
    n_samples_trace = 1000
    n_channels = 5
    half_width = 10
    n_samples_waveforms = half_width * 2
    n_spikes = n_samples_trace // (n_samples_waveforms * 2)

    traces = artificial_traces(n_samples_trace, n_channels)
    spike_samples = artificial_spike_samples(
        n_spikes, max_isi=n_samples_waveforms * 2)

    filter_order = None
    if do_filter:
        filter_order = 3

    return WaveformLoader(traces=traces,
                          spike_samples=spike_samples,
                          n_samples_waveforms=n_samples_waveforms,
                          filter_order=filter_order,
                          sample_rate=sample_rate)
コード例 #6
0
def waveform_loader(do_filter=False, mask_threshold=None):
    """Build a WaveformLoader over small artificial data for the tests.

    Uses 1000 samples x 5 channels of artificial traces sampled at 2000., with
    ``1000 // 40 == 25`` artificial spike samples (``max_isi=40``) and
    20-sample waveforms.

    Args:
        do_filter: if true, ``filter_order=3`` is passed to the loader;
            otherwise ``filter_order=None``.
        mask_threshold: accepted for call compatibility; not used here.

    Returns:
        The configured WaveformLoader.
    """
    half_window = 10
    waveform_len = 2 * half_window
    trace_len, channel_count = 1000, 5
    spike_count = trace_len // (2 * waveform_len)

    traces = artificial_traces(trace_len, channel_count)
    spike_samples = artificial_spike_samples(spike_count,
                                             max_isi=2 * waveform_len)

    loader_kwargs = dict(
        traces=traces,
        spike_samples=spike_samples,
        n_samples_waveforms=waveform_len,
        filter_order=3 if do_filter else None,
        sample_rate=2000.,
    )
    return WaveformLoader(**loader_kwargs)
コード例 #7
0
ファイル: test_waveform.py プロジェクト: stephenlenzi/phy
def test_loader_channels():
    """Check that a channel subset restricts the loaded waveforms.

    With ``channels`` set to three channels, indexing by a single spike or by
    a list of spikes returns arrays of shape (n, n_samples_waveforms, 3),
    including near the trace edges. Slicing raises NotImplementedError.
    """
    n_samples_trace = 1000
    n_channels = 10
    n_samples_waveforms = 20
    traces = artificial_traces(n_samples_trace, n_channels)

    loader = WaveformLoader(traces, n_samples_waveforms=n_samples_waveforms)
    # Also exercise assignment through the ``traces`` attribute.
    loader.traces = traces

    selected = [2, 5, 7]
    loader.channels = selected
    assert loader.channels == selected

    def expected(n_items):
        return (n_items, n_samples_waveforms, len(selected))

    # A single index and a list of indices.
    assert loader[500].shape == expected(1)
    assert loader[[500, 501, 600, 300]].shape == expected(4)

    # Edge effects near both ends of the 1000-sample trace.
    for sample in (3, 995):
        assert loader[sample].shape == expected(1)

    # Slicing is not supported.
    with raises(NotImplementedError):
        loader[500:510]
コード例 #8
0
ファイル: conftest.py プロジェクト: mspacek/phy
    def _init_data(self):
        """Populate ``self`` with a small synthetic dataset for the tests.

        Sizes: ``n_clusters * n_spikes_per_cluster`` spikes (4 * 200 = 800),
        ``n_samples_t`` samples at ``sample_rate`` (20000 / 20000., i.e. a
        duration of 1.0), ``n_channels`` channels with an identity
        ``channel_order``, ``n_samples_waveforms`` samples per waveform and
        ``n_features_per_channel`` features per channel. Traces, masks,
        waveforms and features come from the ``artificial_*`` helpers.

        Only sets attributes on ``self``; returns None.
        """
        self.cache_dir = self.config_dir
        self.n_samples_waveforms = 31
        self.n_samples_t = 20000
        self.n_channels = 11
        self.n_clusters = 4
        self.n_spikes_per_cluster = 200
        n_spikes_total = self.n_clusters * self.n_spikes_per_cluster
        n_features_per_channel = 4

        self.n_spikes = n_spikes_total
        self.sample_rate = 20000.
        self.duration = self.n_samples_t / float(self.sample_rate)
        # Step = 5000 samples / n_spikes_per_cluster. With n_samples_t = 20000
        # this yields 4 * n_spikes_per_cluster == n_spikes_total spike times.
        self.spike_times = np.arange(
            0, self.duration,
            5000. / (self.sample_rate * self.n_spikes_per_cluster))
        # Cluster ids in contiguous runs: 0 repeated n_spikes_per_cluster
        # times, then 1, and so on.
        self.spike_clusters = np.repeat(np.arange(self.n_clusters),
                                        self.n_spikes_per_cluster)
        # Sanity check that the constants above stay consistent.
        assert len(self.spike_times) == len(self.spike_clusters)
        self.cluster_ids = np.unique(self.spike_clusters)
        self.channel_positions = staggered_positions(self.n_channels)
        self.channel_order = np.arange(self.n_channels)

        # The lambdas close over this local array rather than reading
        # self.spike_clusters at call time.
        sc = self.spike_clusters
        self.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c])
        self.spike_count = lambda c: len(self.spikes_per_cluster(c))
        self.n_features_per_channel = n_features_per_channel
        # Every cluster starts with a group of None.
        self.cluster_groups = {c: None for c in range(self.n_clusters)}

        self.all_traces = artificial_traces(self.n_samples_t, self.n_channels)
        self.all_masks = artificial_masks(n_spikes_total, self.n_channels)
        self.all_waveforms = artificial_waveforms(n_spikes_total,
                                                  self.n_samples_waveforms,
                                                  self.n_channels)
        self.all_features = artificial_features(n_spikes_total,
                                                self.n_channels,
                                                self.n_features_per_channel)
コード例 #9
0
ファイル: conftest.py プロジェクト: Peichao/phy
def raw_dataset(request):
    """Yield a raw-recording ``Bunch`` for the data kind in ``request.param``.

    ``request.param`` selects the data:

    * ``'real'``: the first 45000 samples of the downloaded 32-channel
      ``test-32ch-10s.dat`` recording (int16), with the ``1x32_buzsaki``
      probe and ``params['use_single_threshold']`` set to False;
    * ``'null'``: a 25000 x 4 array of zeros;
    * ``'artificial'``: 25000 x 4 artificial traces, multiplied by 5 on
      channel 1 for samples 5000:5010 and on channel 3 for samples
      15000:15010.

    ``'null'`` and ``'artificial'`` share a small two-group probe: channels
    0-2 fully connected, and channel 3 alone with an explicit geometry.

    Yields:
        Bunch with ``n_channels``, ``n_samples``, ``sample_rate``,
        ``n_samples_waveforms``, ``traces``, ``params`` and ``probe``.

    Raises:
        ValueError: if ``request.param`` is not one of the kinds above.
            Without this check the locals ``traces``/``n_channels`` would be
            unbound and the fixture would fail with an UnboundLocalError.
    """
    known_types = ('real', 'null', 'artificial')
    data_type = request.param
    if data_type not in known_types:
        raise ValueError("unknown raw_dataset type %r; expected one of %r"
                         % (data_type, known_types))

    sample_rate = 20000
    params = _load_default_settings()['spikedetekt']
    if data_type == 'real':
        path = download_test_data('test-32ch-10s.dat')
        traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32))
        traces = traces[:45000]
        n_samples, n_channels = traces.shape
        # NOTE(review): this mutates the dict returned by
        # _load_default_settings(); confirm it is a fresh object per call.
        params['use_single_threshold'] = False
        probe = load_probe('1x32_buzsaki')
    else:
        probe = {'channel_groups': {
            0: {'channels': [0, 1, 2],
                'graph': [[0, 1], [0, 2], [1, 2]],
                },
            1: {'channels': [3],
                'graph': [],
                'geometry': {3: [0., 0.]},
                }
        }}
        n_samples, n_channels = 25000, 4
        if data_type == 'null':
            traces = np.zeros((n_samples, n_channels))
        else:  # 'artificial'
            traces = artificial_traces(n_samples, n_channels)
            # Two short bursts of amplified signal on channels 1 and 3.
            traces[5000:5010, 1] *= 5
            traces[15000:15010, 3] *= 5
    # Waveform length: samples extracted before plus after each spike.
    n_samples_w = params['extract_s_before'] + params['extract_s_after']
    yield Bunch(n_channels=n_channels,
                n_samples=n_samples,
                sample_rate=sample_rate,
                n_samples_waveforms=n_samples_w,
                traces=traces,
                params=params,
                probe=probe,
                )
コード例 #10
0
def test_trace_view(tempdir, qtbot):
    """Exercise TraceView navigation, scaling, labels and spike clicking.

    Builds a small synthetic dataset, attaches a TraceView to a GUI and checks
    the view's time/interval state after a fixed sequence of navigation calls.
    It then simulates a Ctrl+click and checks the (channel_id, spike_id,
    cluster_id) reported to the ``on_spike_click`` handler. The assertions
    depend on the exact order of the calls below.
    """
    nc = 5  # number of channels
    ns = 9  # number of spikes (also used as n_samples_waveforms below)
    sr = 1000.  # sample rate
    ch = list(range(nc))  # "best channels" returned for every cluster
    duration = 1.
    st = np.linspace(0.1, .9, ns)  # spike times, evenly spaced in [0.1, 0.9]
    # One cluster id per spike; presumably drawn from nc clusters -- confirm.
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)
    cs = ColorSelector()

    # Lightweight Bunch objects passed below as `model` and `supervisor`.
    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    # The spike-waveform iterator must yield at least one item over [0, 1].
    sw = _iter_spike_waveforms(
        interval=[0., 1.],
        traces_interval=traces,
        model=m,
        supervisor=s,
        n_samples_waveforms=ns,
        get_best_channels=lambda cluster_id: ch,
        color_selector=cs,
    )
    assert len(list(sw))

    def get_traces(interval):
        # Passed to TraceView as `traces`: returns the traces in `interval`
        # plus one waveform Bunch per spike whose time falls inside it.
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, ) * 4,
        )
        # a: first spike with time >= interval[0];
        # b: first spike with time >= interval[1].
        a, b = st.searchsorted(interval)
        out.waveforms = []
        k = 20  # half-width, in samples, of each spike's waveform window
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            # NOTE(review): this local `s` (the spike's sample index) shadows
            # the outer supervisor Bunch `s`, but only inside get_traces.
            s = int(round(t * sr))
            d = Bunch(
                data=traces[s - k:s + k, :],
                start_time=t - k / sr,
                color=cs.get(c),
                channel_ids=np.arange(5),  # all nc == 5 channels
                spike_id=i,
                spike_cluster=c,
            )
            out.waveforms.append(d)
        return out

    # channel_vertical_order is the channel range reversed.
    v = TraceView(
        traces=get_traces,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_vertical_order=np.arange(nc)[::-1],
    )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    # qtbot.waitForWindowShown(gui)

    # Cluster selection: empty, single and multiple clusters.
    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # ac(v.stacked.box_size, (1., .08181), atol=1e-3)
    # `time` is expected to be the centre of the interval.
    v.set_interval((.375, .625))
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    # Going before the start: the assertion expects the centre to be clamped
    # to half the current .25-wide interval.
    v.go_to(-.5)
    assert v.time == .125

    # Already at the left edge, so going left does not move.
    v.go_left()
    assert v.time == .125

    # Going right moves the centre by .05 here.
    v.go_right()
    assert v.time == .175

    # Change interval size.
    # widen() turns the .5-wide interval into a .75-wide one around the same
    # centre; narrow() undoes it.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    # Only checks that widening an interval already spanning the whole
    # duration does not raise.
    v.set_interval((0, duration))
    v.widen()

    # After a single toggle, labels are expected to be shown.
    v.toggle_show_labels()
    # v.toggle_show_labels()
    v.go_right()
    assert v.do_show_labels

    # Change channel scaling.
    # increase() followed by decrease() must restore the box size.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # Simulate spike selection.
    _clicked = []

    # Records each (channel_id, spike_id, cluster_id) the GUI reports;
    # presumably dispatched by the function name -- confirm.
    @v.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    # Ctrl+click at pixel (400, 200) is expected to hit channel 1,
    # spike 4, cluster 1.
    v.events.key_press(key=keys.Key('Control'))
    v.events.mouse_press(pos=(400., 200.),
                         button=1,
                         modifiers=(keys.CONTROL, ))
    v.events.key_release(key=keys.Key('Control'))

    assert _clicked == [(1, 4, 1)]

    # qtbot.stop()
    gui.close()
コード例 #11
0
ファイル: test_trace.py プロジェクト: kwikteam/phy
def test_trace_view(tempdir, qtbot):
    """Exercise TraceView navigation, scaling, labels and spike clicking.

    Builds a small synthetic dataset, attaches a TraceView to a GUI and checks
    the view's time/interval state after a fixed sequence of navigation calls.
    It then simulates a Ctrl+click and checks the (channel_id, spike_id,
    cluster_id) reported to the ``on_spike_click`` handler. The assertions
    depend on the exact order of the calls below.
    """
    nc = 5  # number of channels
    ns = 9  # number of spikes (also used as n_samples_waveforms below)
    sr = 1000.  # sample rate
    ch = list(range(nc))  # "best channels" returned for every cluster
    duration = 1.
    st = np.linspace(0.1, .9, ns)  # spike times, evenly spaced in [0.1, 0.9]
    # One cluster id per spike; presumably drawn from nc clusters -- confirm.
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)
    cs = ColorSelector()

    # Lightweight Bunch objects passed below as `model` and `supervisor`.
    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    # The spike-waveform iterator must yield at least one item over [0, 1].
    sw = _iter_spike_waveforms(interval=[0., 1.],
                               traces_interval=traces,
                               model=m,
                               supervisor=s,
                               n_samples_waveforms=ns,
                               get_best_channels=lambda cluster_id: ch,
                               color_selector=cs,
                               )
    assert len(list(sw))

    def get_traces(interval):
        # Passed to TraceView as `traces`: returns the traces in `interval`
        # plus one waveform Bunch per spike whose time falls inside it.
        out = Bunch(data=select_traces(traces, interval, sample_rate=sr),
                    color=(.75,) * 4,
                    )
        # a: first spike with time >= interval[0];
        # b: first spike with time >= interval[1].
        a, b = st.searchsorted(interval)
        out.waveforms = []
        k = 20  # half-width, in samples, of each spike's waveform window
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            # NOTE(review): this local `s` (the spike's sample index) shadows
            # the outer supervisor Bunch `s`, but only inside get_traces.
            s = int(round(t * sr))
            d = Bunch(data=traces[s - k:s + k, :],
                      start_time=t - k / sr,
                      color=cs.get(c),
                      channel_ids=np.arange(5),  # all nc == 5 channels
                      spike_id=i,
                      spike_cluster=c,
                      )
            out.waveforms.append(d)
        return out

    # channel_vertical_order is the channel range reversed.
    v = TraceView(traces=get_traces,
                  n_channels=nc,
                  sample_rate=sr,
                  duration=duration,
                  channel_vertical_order=np.arange(nc)[::-1],
                  )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    # qtbot.waitForWindowShown(gui)

    # Cluster selection: empty, single and multiple clusters.
    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # ac(v.stacked.box_size, (1., .08181), atol=1e-3)
    # `time` is expected to be the centre of the interval.
    v.set_interval((.375, .625))
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    # Going before the start: the assertion expects the centre to be clamped
    # to half the current .25-wide interval.
    v.go_to(-.5)
    assert v.time == .125

    # Already at the left edge, so going left does not move.
    v.go_left()
    assert v.time == .125

    # Going right moves the centre by .05 here.
    v.go_right()
    assert v.time == .175

    # Change interval size.
    # widen() turns the .5-wide interval into a .75-wide one around the same
    # centre; narrow() undoes it.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    # Only checks that widening an interval already spanning the whole
    # duration does not raise.
    v.set_interval((0, duration))
    v.widen()

    # After a single toggle, labels are expected to be shown.
    v.toggle_show_labels()
    # v.toggle_show_labels()
    v.go_right()
    assert v.do_show_labels

    # Change channel scaling.
    # increase() followed by decrease() must restore the box size.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # Simulate spike selection.
    _clicked = []

    # Records each (channel_id, spike_id, cluster_id) the GUI reports;
    # presumably dispatched by the function name -- confirm.
    @v.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    # Ctrl+click at pixel (400, 200) is expected to hit channel 1,
    # spike 4, cluster 1.
    v.events.key_press(key=keys.Key('Control'))
    v.events.mouse_press(pos=(400., 200.), button=1, modifiers=(keys.CONTROL,))
    v.events.key_release(key=keys.Key('Control'))

    assert _clicked == [(1, 4, 1)]

    # qtbot.stop()
    gui.close()