Example #1
0
 def create_probe_view(self, gui):
     """Build the probe layout view, attach it to `gui`, and return it.

     The channel positions given to the view are generated by
     `linear_positions()` from `self.n_channels_dat`.
     """
     probe_view = ProbeView(linear_positions(self.n_channels_dat))
     probe_view.attach(gui)
     # TODO: update positions dynamically when the probe view changes
     return probe_view
Example #2
0
File: gui.py Project: zsong30/phy
 def _create_model(self, **kwargs):
     """Back up the kwik file and return a `KwikModelGUI` built from it.

     Only the `clustering` and `channel_group` keyword arguments are passed on
     to the model; `kwik_path` is read from `kwargs` and given to the model as
     a string.
     """
     path = kwargs.get('kwik_path')
     _backup(path)

     # Keep only the options that are forwarded to the model.
     forwarded = ('clustering', 'channel_group')
     model_kwargs = {name: value for name, value in kwargs.items() if name in forwarded}
     model = KwikModelGUI(str(path), **model_kwargs)

     # HACK: handle badly formed channel positions
     positions_are_flat = model.channel_positions.ndim == 1
     if positions_are_flat:  # pragma: no cover
         logger.warning("Unable to read the channel positions, generating mock ones.")
         n_positions = len(model.channel_positions)
         model.probe.positions = linear_positions(n_positions)
     return model
Example #3
0
def test_get_boxes():
    """Check the output of `_get_boxes()` for several channel layouts."""
    # Two channels, default options.
    pair = [[-1, 0], [1, 0]]
    expected_default = [[-1, -.25, 0, .25],
                        [+0, -.25, 1, .25]]
    ac(_get_boxes(pair), expected_default, atol=1e-4)

    # Same two channels, without keeping the aspect ratio.
    pair_again = [[-1, 0], [1, 0]]
    expected_free = [[-1, -1, 0, 1],
                     [0, -1, 1, 1]]
    ac(_get_boxes(pair_again, keep_aspect_ratio=False), expected_free, atol=1e-4)

    # Four channels generated by `linear_positions()`.
    expected_linear = [
        [-0.5, -1.0, +0.5, -0.5],
        [-0.5, -0.5, +0.5, +0.0],
        [-0.5, +0.0, +0.5, +0.5],
        [-0.5, +0.5, +0.5, +1.0],
    ]
    ac(_get_boxes(linear_positions(4)), expected_linear, atol=1e-4)

    # Eight channels generated by `staggered_positions()`: only columns 1 and 3
    # of the result are checked.
    staggered_boxes = _get_boxes(staggered_positions(8))
    ac(staggered_boxes[:, 1], np.arange(.75, -1.1, -.25), atol=1e-6)
    ac(staggered_boxes[:, 3], np.arange(1, -.76, -.25), atol=1e-7)
Example #4
0
    def _load_data(self):
        """Load all data.

        Calls the `_load_*()` helpers one after another, stores every result as an
        attribute of `self`, and checks the shapes of the loaded arrays with assertions.

        Raises
        ------
        ValueError
            If the spike times are not sorted in non-decreasing order.
        AssertionError
            If a loaded array does not have the expected shape.

        """
        # Spikes
        self.spike_samples, self.spike_times = self._load_spike_samples()
        # Chained assignment: `ns` and `self.n_spikes` both receive the length of
        # the (necessarily 1D) spike times array.
        ns, = self.n_spikes, = self.spike_times.shape

        # Make sure the spike times are increasing.
        # (Equal consecutive times are accepted: the check is `>= 0`.)
        if not np.all(np.diff(self.spike_times) >= 0):
            raise ValueError("The spike times must be increasing.")

        # Spike amplitudes.
        # Optional: the shape is only checked when amplitudes were loaded.
        self.amplitudes = self._load_amplitudes()
        if self.amplitudes is not None:
            assert self.amplitudes.shape == (ns,)

        # Spike templates.
        self.spike_templates = self._load_spike_templates()
        assert self.spike_templates.shape == (ns,)

        # Spike clusters.
        self.spike_clusters = self._load_spike_clusters()
        assert self.spike_clusters.shape == (ns,)

        # Spike reordering.
        # Optional: the shape is only checked when the reordering was loaded.
        self.spike_times_reordered = self._load_spike_reorder()
        if self.spike_times_reordered is not None:
            assert self.spike_times_reordered.shape == (ns,)

        # Channels.
        self.channel_mapping = self._load_channel_map()
        self.n_channels = nc = self.channel_mapping.shape[0]
        # The bound check is skipped when `n_channels_dat` is falsy (e.g. None or 0).
        if self.n_channels_dat:
            assert np.all(self.channel_mapping <= self.n_channels_dat - 1)

        # Channel positions.
        self.channel_positions = self._load_channel_positions()
        assert self.channel_positions.shape == (nc, 2)
        # If `_all_positions_distinct()` rejects the positions, log an error and
        # replace them with the output of `linear_positions(nc)`.
        if not _all_positions_distinct(self.channel_positions):  # pragma: no cover
            logger.error(
                "Some channels are on the same position, please check the channel positions file.")
            self.channel_positions = linear_positions(nc)

        # Channel shanks.
        self.channel_shanks = self._load_channel_shanks()
        assert self.channel_shanks.shape == (nc,)

        # Channel probes.
        self.channel_probes = self._load_channel_probes()
        assert self.channel_probes.shape == (nc,)
        # Distinct probe values found across the channels, and how many there are.
        self.probes = np.unique(self.channel_probes)
        self.n_probes = len(self.probes)

        # Templates.
        self.sparse_templates = self._load_templates()
        if self.sparse_templates is not None:
            # The three template dimensions are read from the 3D `data` array.
            self.n_templates, self.n_samples_waveforms, self.n_channels_loc = \
                self.sparse_templates.data.shape
            if self.sparse_templates.cols is not None:
                assert self.sparse_templates.cols.shape == (self.n_templates, self.n_channels_loc)
        else:  # pragma: no cover
            # No templates loaded: derive the template count from the largest
            # spike template value, and set the other two dimensions to zero.
            self.n_templates = self.spike_templates.max() + 1
            self.n_samples_waveforms = 0
            self.n_channels_loc = 0

        # Spike waveforms (optional, otherwise fetched from raw data as needed).
        self.spike_waveforms = self._load_spike_waveforms()

        # Whitening.
        # Fall back to the identity matrix if loading raises an IOError.
        try:
            self.wm = self._load_wm()
        except IOError:
            logger.debug("Whitening matrix file not found.")
            self.wm = np.eye(nc)
        assert self.wm.shape == (nc, nc)
        # Fall back to `_compute_wmi(self.wm)` if loading the inverse raises an IOError.
        try:
            self.wmi = self._load_wmi()
        except IOError:
            logger.debug("Whitening matrix inverse file not found, computing it.")
            self.wmi = self._compute_wmi(self.wm)
        assert self.wmi.shape == (nc, nc)

        # Similar templates.
        self.similar_templates = self._load_similar_templates()
        assert self.similar_templates.shape == (self.n_templates, self.n_templates)

        # Traces and duration.
        self.traces = self._load_traces(self.channel_mapping)
        if self.traces is not None:
            # Number of rows in the traces divided by the sample rate.
            self.duration = self.traces.shape[0] / float(self.sample_rate)
        else:
            # Without traces, the last spike time is used as the duration.
            self.duration = self.spike_times[-1]
        if self.spike_times[-1] > self.duration:  # pragma: no cover
            logger.warning(
                "There are %d/%d spikes after the end of the recording.",
                np.sum(self.spike_times > self.duration), self.n_spikes)

        # Features.
        self.sparse_features = self._load_features()
        # NOTE(review): this line tests truthiness while the next one tests
        # `is not None` -- confirm that both are meant to be equivalent here.
        self.features = self.sparse_features.data if self.sparse_features else None
        if self.sparse_features is not None:
            self.n_features_per_channel = self.sparse_features.data.shape[2]

        # Template features.
        self.sparse_template_features = self._load_template_features()
        self.template_features = (
            self.sparse_template_features.data if self.sparse_template_features else None)

        # Spike attributes.
        self.spike_attributes = self._load_spike_attributes()

        # Metadata.
        self.metadata = self._load_metadata()
Example #5
0
def test_trace_view_1(qtbot, tempdir, gui):
    """Exercise a `TraceView` built on artificial traces and spikes.

    The test covers cluster selection, time navigation, interval resizing, the
    display toggles, channel scaling, origin switching, and mouse-driven spike
    and channel selection.

    """
    nc = 5  # number of channels
    ns = 20  # number of spikes
    sr = 2000.  # sample rate
    duration = 1.
    # Evenly spaced spike times and artificial spike clusters.
    st = np.linspace(0.1, .9, ns)
    sc = artificial_spike_clusters(ns, nc)
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)

    def get_traces(interval):
        """Return the traces within `interval`, plus one waveform per spike in it."""
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, .75, .75, 1),
        )
        # Indices of the first and last (exclusive) spikes within the interval.
        a, b = st.searchsorted(interval)
        out.waveforms = []
        # Each waveform spans `k` samples on either side of the spike sample.
        k = 20
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            s = int(round(t * sr))
            d = Bunch(
                data=traces[s - k:s + k, :],
                start_time=(s - k) / sr,
                channel_ids=np.arange(5),
                spike_id=i,
                spike_cluster=c,
                select_index=0,
            )
            out.waveforms.append(d)
        return out

    def get_spike_times():
        """Return all spike times."""
        return st

    v = TraceView(
        traces=get_traces,
        spike_times=get_spike_times,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_positions=linear_positions(nc),
    )
    v.show()
    qtbot.waitForWindowShown(v.canvas)
    v.attach(gui)

    # Select various sets of clusters, including an empty one.
    v.on_select(cluster_ids=[])
    v.on_select(cluster_ids=[0])
    v.on_select(cluster_ids=[0, 2, 3])
    v.on_select(cluster_ids=[0, 2])

    v.stacked.add_boxes(v.canvas)

    ac(v.stacked.box_size, (.950, .165), atol=1e-3)
    # `v.time` is expected to be the center of the interval.
    v.set_interval((.375, .625))
    assert v.time == .5
    qtbot.wait(1)

    v.go_to(.25)
    assert v.time == .25
    qtbot.wait(1)

    # Going to a negative time is expected to leave the view at .125.
    v.go_to(-.5)
    assert v.time == .125
    qtbot.wait(1)

    # Going left is expected to have no effect at this point.
    v.go_left()
    assert v.time == .125
    qtbot.wait(1)

    v.go_right()
    ac(v.time, .150)
    qtbot.wait(1)

    # The following navigation calls are only checked not to raise.
    v.jump_left()
    qtbot.wait(1)

    v.jump_right()
    qtbot.wait(1)

    v.go_to_next_spike()
    qtbot.wait(1)

    v.go_to_previous_spike()
    qtbot.wait(1)

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    qtbot.wait(1)

    v.widen()
    ac(v.interval, (.1875, .8125))
    qtbot.wait(1)

    # Narrowing is expected to undo the widening above.
    v.narrow()
    ac(v.interval, (.25, .75))
    qtbot.wait(1)

    v.go_to_start()
    qtbot.wait(1)
    assert v.interval[0] == 0

    v.go_to_end()
    qtbot.wait(1)
    assert v.interval[1] == duration

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()
    qtbot.wait(1)

    v.toggle_show_labels(True)
    v.go_right()

    # Check auto scaling.
    db = v.data_bounds
    v.toggle_auto_scale(False)
    v.narrow()
    qtbot.wait(1)
    # Check that ymin and ymax have not changed.
    assert v.data_bounds[1] == db[1]
    assert v.data_bounds[3] == db[3]

    v.toggle_auto_update(True)
    assert v.do_show_labels
    qtbot.wait(1)

    v.toggle_highlighted_spikes(True)
    qtbot.wait(50)

    # Change channel scaling.
    bs = v.stacked.box_size
    v.decrease()
    qtbot.wait(1)

    # Increasing after decreasing should restore the box size, within tolerance.
    v.increase()
    ac(v.stacked.box_size, bs, atol=.05)
    qtbot.wait(1)

    v.origin = 'bottom'
    v.switch_origin()
    assert v.origin == 'top'
    qtbot.wait(1)

    # Simulate spike selection.
    _clicked = []

    @connect(sender=v)
    def on_select_spike(sender,
                        channel_id=None,
                        spike_id=None,
                        cluster_id=None,
                        key=None):
        """Record the arguments of every spike selection event."""
        _clicked.append((channel_id, spike_id, cluster_id))

    mouse_click(qtbot,
                v.canvas,
                pos=(0., 0.),
                button='Left',
                modifiers=('Control', ))

    v.set_state(v.state)

    # At least one spike selection event must have been recorded.
    assert len(_clicked[0]) == 3

    # Simulate channel selection.
    # Rebinding `_clicked` also makes the `on_select_spike` closure above append
    # to this new list, since closures look the variable up by name.
    _clicked = []

    @connect(sender=v)
    def on_select_channel(sender, channel_id=None, button=None):
        """Record the arguments of every channel selection event."""
        _clicked.append((channel_id, button))

    mouse_click(qtbot,
                v.canvas,
                pos=(0., 0.),
                button='Left',
                modifiers=('Shift', ))
    mouse_click(qtbot,
                v.canvas,
                pos=(0., 0.),
                button='Right',
                modifiers=('Shift', ))

    assert _clicked == [(2, 'Left'), (2, 'Right')]

    _stop_and_close(qtbot, v)
Example #6
0
def test_trace_image_view_1(qtbot, tempdir, gui):
    """Exercise a `TraceImageView` built on artificial traces.

    The test covers time navigation, interval resizing, the auto-update toggle,
    channel scaling and origin switching.

    """
    nc = 350  # number of channels
    sr = 2000.  # sample rate
    duration = 1.
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)

    def get_traces(interval):
        """Return the traces within `interval`, along with a display color."""
        return Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, .75, .75, 1),
        )

    v = TraceImageView(
        traces=get_traces,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_positions=linear_positions(nc),
    )
    v.show()
    qtbot.waitForWindowShown(v.canvas)
    v.attach(gui)

    v.update_color()

    # `v.time` is expected to be the center of the interval.
    v.set_interval((.375, .625))
    assert v.time == .5
    qtbot.wait(1)

    v.go_to(.25)
    assert v.time == .25
    qtbot.wait(1)

    # Going to a negative time is expected to leave the view at .125.
    v.go_to(-.5)
    assert v.time == .125
    qtbot.wait(1)

    # Going left is expected to have no effect at this point.
    v.go_left()
    assert v.time == .125
    qtbot.wait(1)

    v.go_right()
    ac(v.time, .150)
    qtbot.wait(1)

    # The two jumps below are only checked not to raise.
    v.jump_left()
    qtbot.wait(1)

    v.jump_right()
    qtbot.wait(1)

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    qtbot.wait(1)

    v.widen()
    ac(v.interval, (.1875, .8125))
    qtbot.wait(1)

    # Narrowing is expected to undo the widening above.
    v.narrow()
    ac(v.interval, (.25, .75))
    qtbot.wait(1)

    v.go_to_start()
    qtbot.wait(1)
    assert v.interval[0] == 0

    v.go_to_end()
    qtbot.wait(1)
    assert v.interval[1] == duration

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()
    qtbot.wait(1)

    v.toggle_auto_update(True)
    assert v.do_show_labels
    qtbot.wait(1)

    # Change channel scaling.
    v.decrease()
    qtbot.wait(1)

    v.increase()
    qtbot.wait(1)

    # NOTE(review): the origin assertion below is commented out -- confirm
    # whether it should be re-enabled for this view.
    v.origin = 'bottom'
    v.switch_origin()
    # assert v.origin == 'top'
    qtbot.wait(1)

    _stop_and_close(qtbot, v)
Example #7
0
    def _load_data(self):
        """Load all data.

        Calls the `_load_*()` helpers one after another, stores every result as an
        attribute of `self`, and checks the shapes of the loaded arrays with assertions.

        Raises
        ------
        AssertionError
            If a loaded array does not have the expected shape.

        """
        sr = self.sample_rate

        # Spikes.
        # Spike times are the spike samples divided by the sample rate.
        self.spike_samples = self._load_spike_samples()
        self.spike_times = self.spike_samples / sr
        # Chained assignment: `ns` and `self.n_spikes` both receive the length of
        # the (necessarily 1D) spike times array.
        ns, = self.n_spikes, = self.spike_times.shape

        # Spike amplitudes.
        # Optional: the shape is only checked when amplitudes were loaded.
        self.amplitudes = self._load_amplitudes()
        if self.amplitudes is not None:
            assert self.amplitudes.shape == (ns, )

        # Spike templates.
        self.spike_templates = self._load_spike_templates()
        assert self.spike_templates.shape == (ns, )

        # Spike clusters.
        self.spike_clusters = self._load_spike_clusters()
        assert self.spike_clusters.shape == (ns, )

        # Channels.
        self.channel_mapping = self._load_channel_map()
        self.n_channels = nc = self.channel_mapping.shape[0]
        # Every mapped channel must be lower than `n_channels_dat`.
        assert np.all(self.channel_mapping <= self.n_channels_dat - 1)

        # Channel positions.
        self.channel_positions = self._load_channel_positions()
        assert self.channel_positions.shape == (nc, 2)
        # If `_all_positions_distinct()` rejects the positions, log an error and
        # replace them with the output of `linear_positions(nc)`.
        if not _all_positions_distinct(
                self.channel_positions):  # pragma: no cover
            logger.error(
                "Some channels are on the same position, please check the channel positions file."
            )
            self.channel_positions = linear_positions(nc)

        # Channel shanks.
        # Optional: the shape is only checked when shanks were loaded.
        self.channel_shanks = self._load_channel_shanks()
        if self.channel_shanks is not None:
            assert self.channel_shanks.shape == (nc, )

        # Ordering of the channels in the trace view.
        # Channels are sorted by the second column of their positions (presumably
        # the vertical coordinate). 'mergesort' is a stable sort kind in NumPy, so
        # channels with equal values keep their original relative order.
        self.channel_vertical_order = np.argsort(self.channel_positions[:, 1],
                                                 kind='mergesort')

        # Templates.
        self.sparse_templates = self._load_templates()
        # The three template dimensions are read from the 3D `data` array.
        self.n_templates, self.n_samples_waveforms, self.n_channels_loc = \
            self.sparse_templates.data.shape
        if self.sparse_templates.cols is not None:
            assert self.sparse_templates.cols.shape == (self.n_templates,
                                                        self.n_channels_loc)

        # Whitening.
        # Fall back to the identity matrix if loading raises an IOError.
        try:
            self.wm = self._load_wm()
        except IOError:
            logger.warning("Whitening matrix is not available.")
            self.wm = np.eye(nc)
        assert self.wm.shape == (nc, nc)
        # Fall back to `_compute_wmi(self.wm)` if loading the inverse raises an IOError.
        try:
            self.wmi = self._load_wmi()
        except IOError:
            self.wmi = self._compute_wmi(self.wm)
        assert self.wmi.shape == (nc, nc)

        # Similar templates.
        self.similar_templates = self._load_similar_templates()
        assert self.similar_templates.shape == (self.n_templates,
                                                self.n_templates)

        # Traces and duration.
        self.traces = self._load_traces(self.channel_mapping)
        if self.traces is not None:
            # Number of rows in the traces divided by the sample rate.
            self.duration = self.traces.shape[0] / float(self.sample_rate)
        else:
            # Without traces, the last spike time is used as the duration.
            self.duration = self.spike_times[-1]
        if self.spike_times[-1] > self.duration:  # pragma: no cover
            logger.debug(
                "There are %d/%d spikes after the end of the recording.",
                np.sum(self.spike_times > self.duration), self.n_spikes)

        # Features.
        self.sparse_features = self._load_features()
        # NOTE(review): this line tests truthiness while the next one tests
        # `is not None` -- confirm that both are meant to be equivalent here.
        self.features = self.sparse_features.data if self.sparse_features else None
        if self.sparse_features is not None:
            self.n_features_per_channel = self.sparse_features.data.shape[2]

        # Template features.
        self.sparse_template_features = self._load_template_features()
        self.template_features = (self.sparse_template_features.data
                                  if self.sparse_template_features else None)

        # Spike attributes.
        self.spike_attributes = self._load_spike_attributes()

        # Metadata.
        self.metadata = self._load_metadata()