예제 #1
0
파일: test_base.py 프로젝트: mmyros/phy
class MyModel(object):
    """Mock model exposing class-level data attributes and accessor methods.

    Every data attribute is built once, when the class body executes, from
    NumPy's global random generator and the ``artificial_*`` helpers.  The
    order of the attribute statements therefore matters: reordering them
    changes which random draws each attribute receives.

    NOTE(review): the ``artificial_*`` helpers are defined outside this
    block; their return shapes and whether they also consume the global
    random state are assumed, not verified here.
    """

    # `np.random.seed` returns None, so `seed` is None; the call is made for
    # its side effect of seeding NumPy's global generator so that the
    # `np.random.normal` draws below are reproducible.
    seed = np.random.seed(0)
    n_channels = 8
    n_spikes = 20000
    n_clusters = 32
    # One template per cluster.
    n_templates = n_clusters
    n_pcs = 5
    n_samples_waveforms = 100
    channel_positions = np.random.normal(size=(n_channels, 2))
    channel_mapping = np.arange(0, n_channels)
    channel_shanks = np.zeros(n_channels, dtype=np.int32)
    features = artificial_features(n_spikes, n_channels, n_pcs)
    metadata = {'group': {3: 'noise', 4: 'mua', 5: 'good'}}
    sample_rate = 10000
    spike_attributes = {}
    amplitudes = np.random.normal(size=n_spikes, loc=1, scale=.1)
    spike_clusters = artificial_spike_clusters(n_spikes, n_clusters)
    # Same object as `spike_clusters`, not a copy.
    spike_templates = spike_clusters
    spike_samples = artificial_spike_samples(n_spikes)
    spike_times = spike_samples / sample_rate
    # A second, independent call to `artificial_spike_samples`.
    spike_times_reordered = artificial_spike_samples(n_spikes) / sample_rate
    # The last spike time is used as the recording duration.
    duration = spike_times[-1]
    spike_waveforms = None
    traces = artificial_traces(int(sample_rate * duration), n_channels)

    def _get_some_channels(self, offset, size):
        """Return `size` channel ids starting at `offset`, wrapping around.

        The ids are taken from an endless cycle over
        ``range(self.n_channels)``, i.e. ``(offset + k) % n_channels`` for
        ``k`` in ``range(size)``.
        """
        return list(
            islice(cycle(range(self.n_channels)), offset, offset + size))

    def get_features(self, spike_ids, channel_ids):
        """Return freshly generated artificial features.

        Only the lengths of `spike_ids` and `channel_ids` are used; the
        class-level `features` attribute is not read.
        """
        return artificial_features(len(spike_ids), len(channel_ids),
                                   self.n_pcs)

    def get_waveforms(self, spike_ids, channel_ids):
        """Return freshly generated artificial waveforms.

        Only the lengths of the arguments are used.  When `channel_ids` is
        None or empty, all `self.n_channels` channels are used.
        """
        # Explicit None/length test instead of `if channel_ids`: truth
        # testing a NumPy array with more than one element raises
        # ValueError, and the one-element array `[0]` is falsy even though
        # it names a channel.
        if channel_ids is None or len(channel_ids) == 0:
            n_channels = self.n_channels
        else:
            n_channels = len(channel_ids)
        return artificial_waveforms(len(spike_ids), self.n_samples_waveforms,
                                    n_channels)

    def get_template(self, template_id):
        """Return a Bunch with an artificial template and its channel ids.

        The template spans half of the channels (integer division);
        `channel_ids` is the wrapped run of that many channels starting at
        `template_id`.
        """
        nc = self.n_channels // 2
        # A single waveform is generated and its leading axis dropped.
        return Bunch(template=artificial_waveforms(1, self.n_samples_waveforms,
                                                   nc)[0, ...],
                     channel_ids=self._get_some_channels(template_id, nc))

    def save_spike_clusters(self, spike_clusters):
        """Do nothing; the mock model does not persist spike clusters."""
        pass

    def save_metadata(self, name, values):
        """Do nothing; the mock model does not persist metadata."""
        pass
예제 #2
0
파일: test_feature.py 프로젝트: GrohLab/phy
def test_feature_view(qtbot, gui, n_channels):
    """Exercise a FeatureView end to end on artificial data.

    The view is built from closure-based data providers, then driven through
    cluster selection, scaling, marker resizing, channel selection,
    Alt+click feature selection on every subplot, and a Control+click lasso
    followed by a split request.
    """
    n_spikes = 10000
    features = artificial_features(n_spikes, n_channels, 4)
    spike_clusters = artificial_spike_clusters(n_spikes, 4)
    spike_times = np.linspace(0., 1., n_spikes)
    spikes_per_cluster = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        # Without a cluster, fall back to every spike.
        if cluster_id is None:
            return np.arange(n_spikes)
        return spikes_per_cluster[cluster_id]

    def get_features(cluster_id=None,
                     channel_ids=None,
                     spike_ids=None,
                     load_all=None):
        # The incoming `spike_ids` is ignored: ids always come from the
        # cluster.
        ids = (spikes_per_cluster[cluster_id]
               if load_all else get_spike_ids(cluster_id))
        if channel_ids is None:
            # Default to all channels, in reverse order.
            channel_ids = np.arange(n_channels)[::-1]
        return Bunch(
            data=features[ids],
            spike_ids=ids,
            masks=np.random.rand(n_spikes, n_channels),
            channel_ids=channel_ids,
        )

    def get_time(cluster_id=None, load_all=None):
        times = spike_times[get_spike_ids(cluster_id)]
        return Bunch(data=times, lim=(0., 1.))

    view = FeatureView(features=get_features, attributes={'time': get_time})
    view.show()
    qtbot.waitForWindowShown(view.canvas)
    view.attach(gui)

    view.set_grid_dim(_get_default_grid())

    # Empty, single and multiple selections, including shrinking one.
    for cluster_ids in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids=cluster_ids)

    assert view.status

    view.increase()
    view.decrease()

    view.increase_marker_size()
    view.decrease_marker_size()

    for button, key in (('Left', None), ('Right', None), ('Right', 2)):
        view.on_select_channel(channel_id=3, button=button, key=key)
    view.clear_channels()
    view.toggle_automatic_channel_selection(True)

    # Test feature selection with Alt+click.
    selected = []

    @connect(sender=view)
    def on_select_feature(sender, dim=None, channel_id=None, pc=None):
        selected.append((dim, channel_id, pc))

    for i, j, dim_x, dim_y in view._iter_subplots():
        for k, button in enumerate(('Left', 'Right')):
            # Click on the center of every subplot (cells are a quarter of
            # the canvas in each direction).
            width, height = view.canvas.get_size()
            cell_w, cell_h = width / 4, height / 4
            center = (cell_w / 2 + j * cell_w, cell_h / 2 + i * cell_h)
            mouse_click(qtbot,
                        view.canvas,
                        center,
                        button=button,
                        modifiers=('Alt', ))
            expected_dim = view.grid_dim[i][j].split(',')[k]
            assert selected[-1][0] == expected_dim

    # Split without selection.
    split_ids = view.on_request_split()
    assert len(split_ids) == 0

    # Draw a square lasso with four Control+clicks.
    lo, hi = 10, 100
    for corner in ((lo, lo), (lo, hi), (hi, hi), (hi, lo)):
        mouse_click(qtbot, view.canvas, corner, modifiers=('Control', ))

    # Split lassoed points.
    split_ids = view.on_request_split()

    # HACK: this seems to fail because qtbot.mouseClick is not working??
    # assert len(spike_ids) > 0

    view.set_state(view.state)

    _stop_and_close(qtbot, view)
예제 #3
0
 def get_features(self, spike_ids, channel_ids):
     """Return artificial features sized by the arguments' lengths.

     Only ``len(spike_ids)``, ``len(channel_ids)`` and ``self.n_pcs`` are
     passed on to ``artificial_features``.
     """
     n_spikes = len(spike_ids)
     n_channels = len(channel_ids)
     return artificial_features(n_spikes, n_channels, self.n_pcs)
예제 #4
0
def features(n_spikes, n_channels, n_features_per_channel):
    """Yield a single artificial feature array built from the arguments.

    This is a generator: the array is only created once the generator is
    advanced.
    """
    generated = artificial_features(
        n_spikes, n_channels, n_features_per_channel)
    yield generated