예제 #1
0
    def create_gui(self, name=None,
                   subtitle=None,
                   config_dir=None,
                   add_default_views=True,
                   **kwargs):
        """Build and return a manual clustering GUI bound to this controller.

        Parameters
        ----------
        name : str, optional
            Window name; ``self.gui_name`` is used when this is falsy.
        subtitle : str, optional
            Forwarded unchanged to ``GUI``.
        config_dir : str or Path, optional
            Configuration directory; ``self.config_dir`` is used when falsy.
        add_default_views : bool
            When true, add the correlogram view, then the feature, waveform
            and trace views for each of ``all_features``, ``all_waveforms``
            and ``all_traces`` that is not None.
        **kwargs
            Extra keyword arguments forwarded to ``GUI``.

        Returns
        -------
        GUI
            The new GUI, after the ``create_gui`` event has been emitted.
        """
        resolved_config_dir = config_dir or self.config_dir
        window = GUI(name=name or self.gui_name,
                     subtitle=subtitle,
                     config_dir=resolved_config_dir,
                     **kwargs)
        window.controller = self

        # Let the ManualClustering component hook itself into the window.
        self.manual_clustering.attach(window)

        if add_default_views:
            # The correlogram view is unconditional; the other views are only
            # added when the corresponding data attribute is set.
            self.add_correlogram_view(window)
            optional_views = (('all_features', 'add_feature_view'),
                              ('all_waveforms', 'add_waveform_view'),
                              ('all_traces', 'add_trace_view'))
            for data_attr, adder_name in optional_views:
                if getattr(self, data_attr) is not None:
                    getattr(self, adder_name)(window)

        self.emit('create_gui', window)

        return window
예제 #2
0
    def create_gui(self,
                   name=None,
                   subtitle=None,
                   config_dir=None,
                   add_default_views=True,
                   **kwargs):
        """Create a manual clustering GUI controlled by this object.

        Parameters
        ----------
        name : str, optional
            GUI name; falls back to ``self.gui_name`` when falsy.
        subtitle : str, optional
            Passed through to ``GUI``.
        config_dir : str or Path, optional
            Configuration directory; falls back to ``self.config_dir``.
        add_default_views : bool
            If true, add the correlogram view plus the feature, waveform and
            trace views whose backing data (``all_features``,
            ``all_waveforms``, ``all_traces``) is not None.
        **kwargs
            Extra keyword arguments passed through to ``GUI``.

        Returns
        -------
        GUI
            The new GUI; ``create_gui`` is emitted with it before returning.
        """
        if not config_dir:
            config_dir = self.config_dir
        gui_name = name if name else self.gui_name
        gui = GUI(name=gui_name, subtitle=subtitle, config_dir=config_dir,
                  **kwargs)
        gui.controller = self

        # Hook the ManualClustering component into the new GUI.
        self.manual_clustering.attach(gui)

        if add_default_views:
            self.add_correlogram_view(gui)
            # Data-dependent views are only added when their data exists.
            if self.all_features is not None:
                self.add_feature_view(gui)
            if self.all_waveforms is not None:
                self.add_waveform_view(gui)
            if self.all_traces is not None:
                self.add_trace_view(gui)

        self.emit('create_gui', gui)
        return gui
예제 #3
0
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI whose configuration lives in ``tempdir``.

    On teardown the window is closed, with short ``qtbot.wait`` pauses
    before and after.
    """
    window = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir)
    window.show()
    qtbot.waitForWindowShown(window)
    yield window
    qtbot.wait(5)
    window.close()
    del window
    qtbot.wait(5)
예제 #4
0
def create_trace_gui(obj, **kwargs):
    """Create the Trace GUI.

    Parameters
    ----------

    obj : str or Path
        Path to the raw data file, or to a ``params.py`` file. In the latter
        case the first entry of its ``dat_path`` is used as the data file
        and the remaining params are forwarded as keyword arguments.
    sample_rate : float
        The data sampling rate, in Hz.
    n_channels_dat : int
        The number of columns in the raw data file.
    dtype : str
        The NumPy data type of the raw binary file.
    offset : int
        The header offset in bytes.

    Any other keyword argument is ignored.

    Returns
    -------
    GUI
        A GUI with a single TraceView attached.

    """
    from pathlib import Path

    gui_name = 'TraceGUI'

    # Support passing a params.py file.
    if str(obj).endswith('.py'):
        params = get_template_params(str(obj))
        return create_trace_gui(next(iter(params.pop('dat_path'))), **params)

    # Only forward the options understood by the raw data reader.
    reader_kwargs = {
        k: v
        for k, v in kwargs.items()
        if k in ('sample_rate', 'n_channels_dat', 'dtype', 'offset')
    }
    traces = get_ephys_reader(obj, **reader_kwargs)

    create_app()
    # `obj` may be a plain string (see docstring), which has no `.resolve()`
    # method, so wrap it in a Path before resolving it for the subtitle.
    gui = GUI(name=gui_name,
              subtitle=Path(obj).resolve(),
              enable_threading=False)
    gui.set_default_actions()

    def _get_traces(interval):
        return Bunch(data=select_traces(
            traces, interval, sample_rate=traces.sample_rate))

    # TODO: load channel information

    view = TraceView(
        traces=_get_traces,
        n_channels=traces.n_channels,
        sample_rate=traces.sample_rate,
        duration=traces.duration,
        enable_threading=False,
    )
    view.attach(gui)

    return gui
예제 #5
0
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI configured in ``tempdir``.

    While the fixture is active, ``_supervisor._show_box`` is replaced by an
    identity function. It is restored on teardown so that the stub does not
    leak into tests that do not request this fixture.
    """
    # NOTE: mock patch show box exec_
    original_show_box = _supervisor._show_box
    _supervisor._show_box = lambda _: _
    try:
        gui = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir)
        gui.show()
        qtbot.waitForWindowShown(gui)
        yield gui
        qtbot.wait(5)
        gui.close()
        del gui
        qtbot.wait(5)
    finally:
        # Undo the module-level patch even if setup or teardown fails.
        _supervisor._show_box = original_show_box
예제 #6
0
    def create_gui(self):
        """Build the spike sorting GUI for ``self.dat_path``.

        Threading is disabled and the save action is turned off. Default
        actions are installed, followed by this controller's actions, the
        params widget, the IPython view, the trace view and the probe view,
        in that order.

        Returns
        -------
        GUI
            The populated GUI instance.
        """
        window = GUI(
            name=self.gui_name,
            subtitle=self.dat_path.resolve(),
            enable_threading=False,
        )
        # Set before `set_default_actions()`; presumably that call checks this
        # flag when deciding whether to add a Save action (TODO confirm).
        window.has_save_action = False
        window.set_default_actions()

        builders = (
            self.create_actions,
            self.create_params_widget,
            self.create_ipython_view,
            self.create_trace_view,
            self.create_probe_view,
        )
        for build in builders:
            build(window)

        return window
예제 #7
0
    def create_gui(self, **kwargs):
        """Create the GUI, attach the supervisor and add the available views.

        A view whose backing model attribute is None is skipped. The
        memcache is saved when the GUI closes, and ``gui_ready`` is emitted
        before the GUI is returned.

        Parameters
        ----------
        **kwargs
            Extra keyword arguments forwarded to ``GUI``.
        """
        gui = GUI(name=self.gui_name,
                  subtitle=self.model.dat_path,
                  config_dir=self.config_dir,
                  **kwargs)

        self.supervisor.attach(gui)

        # Ordered (model attribute, view adder) pairs. A None attribute name
        # means the view is always added.
        view_plan = (
            (None, 'add_waveform_view'),
            ('traces', 'add_trace_view'),
            ('features', 'add_feature_view'),
            ('template_features', 'add_template_feature_view'),
            (None, 'add_correlogram_view'),
            ('amplitudes', 'add_amplitude_view'),
            (None, 'add_probe_view'),
        )
        for model_attr, adder in view_plan:
            if model_attr is not None and getattr(self.model, model_attr) is None:
                continue
            getattr(self, adder)(gui)

        # Save the memcache when closing the GUI.
        @gui.connect_
        def on_close():
            self.context.save_memcache()

        self.emit('gui_ready', gui)
        return gui
예제 #8
0
파일: conftest.py 프로젝트: GrohLab/phy
def gui(tempdir, qtbot):
    """Yield an 800x600 GUI with default actions, configured in ``tempdir``.

    Short ``qtbot.wait`` pauses follow showing the window, and surround
    closing it on teardown.
    """
    window = GUI(position=(200, 200), size=(800, 600), config_dir=tempdir)
    window.set_default_actions()
    window.show()
    qtbot.wait(1)
    yield window
    qtbot.wait(1)
    window.close()
    qtbot.wait(1)
예제 #9
0
def test_scatter_view(qtbot, tempdir):
    """Attach a ScatterView fed with random points and cycle selections."""
    n = 1000

    def coords(c):
        # Fresh random points on every call; the cluster argument is unused.
        return Bunch(x=np.random.randn(n),
                     y=np.random.randn(n),
                     data_bounds=None)

    view = ScatterView(coords=coords)
    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    for cluster_ids in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids)

    gui.close()
예제 #10
0
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI with default actions, configured in ``tempdir``.

    While the fixture is active, ``_supervisor._show_box`` is replaced by an
    identity function. It is restored on teardown so that the stub does not
    leak into tests that do not request this fixture.
    """
    # NOTE: mock patch show box exec_
    original_show_box = _supervisor._show_box
    _supervisor._show_box = lambda _: _
    try:
        gui = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir)
        gui.set_default_actions()
        gui.show()
        qtbot.waitForWindowShown(gui)
        yield gui
        qtbot.wait(5)
        gui.close()
        del gui
        qtbot.wait(5)
    finally:
        # Undo the module-level patch even if setup or teardown fails.
        _supervisor._show_box = original_show_box
예제 #11
0
파일: test_probe.py 프로젝트: kwikteam/phy
def test_probe_view(qtbot, tempdir):
    """Attach a ProbeView over 50 staggered channels and cycle selections."""
    n = 50

    def best_channels(cluster_id):
        # Every cluster maps to the same channels: 1, 3, 5, 7.
        return range(1, 9, 2)

    view = ProbeView(positions=staggered_positions(n),
                     best_channels=best_channels)
    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    for cluster_ids in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids)

    gui.close()
예제 #12
0
def test_plot_mpl_1(qtbot):
    """Attach a cleared Matplotlib canvas to a GUI, show it, then close it.

    If the ``PHY_TEST_STOP`` environment variable is set to a non-empty
    value, ``qtbot.stop()`` is called before closing.
    """
    gui = GUI()
    canvas = PlotCanvasMpl()

    canvas.clear()
    canvas.attach(gui)

    canvas.show()
    qtbot.waitForWindowShown(canvas.canvas)
    if os.environ.get('PHY_TEST_STOP'):
        qtbot.stop()
    canvas.close()
예제 #13
0
def test_correlogram_view(qtbot, tempdir):
    """Exercise CorrelogramView selection, normalization, bin and window."""
    ns = 50

    def get_correlograms(cluster_ids, bin_size, window_size):
        # Bin and window sizes are ignored: the artificial data depends only
        # on how many clusters are selected.
        return artificial_correlograms(len(cluster_ids), ns)

    view = CorrelogramView(correlograms=get_correlograms, sample_rate=100.)
    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    for cluster_ids in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(cluster_ids)

    view.toggle_normalization()

    view.set_bin(1)
    view.set_window(100)

    gui.close()
예제 #14
0
def gui(tempdir, qtbot):
    """Yield a shown 500x500 GUI configured in ``tempdir``.

    While the fixture is active, ``gui_component._show_box`` is replaced by an
    identity function. It is restored on teardown so that the stub does not
    leak into tests that do not request this fixture.
    """
    # NOTE: mock patch show box exec_
    original_show_box = gui_component._show_box
    gui_component._show_box = lambda _: _
    try:
        gui = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir)
        gui.show()
        qtbot.waitForWindowShown(gui)
        yield gui
        qtbot.wait(5)
        gui.close()
        del gui
        qtbot.wait(5)
    finally:
        # Undo the module-level patch even if setup or teardown fails.
        gui_component._show_box = original_show_box
예제 #15
0
def test_scatter_view(qtbot, tempdir):
    """Smoke-test ScatterView with random coordinates and a few selections."""
    n_points = 1000

    def random_coords(c):
        # Coordinates are regenerated on each call; `c` is not used.
        x = np.random.randn(n_points)
        y = np.random.randn(n_points)
        return Bunch(x=x, y=y, data_bounds=None)

    scatter = ScatterView(coords=random_coords)
    gui = GUI(config_dir=tempdir)
    gui.show()
    scatter.attach(gui)
    qtbot.addWidget(gui)

    selections = [[], [0], [0, 2, 3], [0, 2]]
    for selection in selections:
        scatter.on_select(selection)

    gui.close()
예제 #16
0
def test_probe_view(qtbot, tempdir):
    """Smoke-test ProbeView on 50 staggered channels with a few selections."""
    n_channels = 50
    channel_positions = staggered_positions(n_channels)

    def best_channels(cluster_id):
        # Constant mapping, independent of the cluster: channels 1, 3, 5, 7.
        return range(1, 9, 2)

    probe = ProbeView(positions=channel_positions, best_channels=best_channels)
    gui = GUI(config_dir=tempdir)
    gui.show()
    probe.attach(gui)
    qtbot.addWidget(gui)

    selections = [[], [0], [0, 2, 3], [0, 2]]
    for selection in selections:
        probe.on_select(selection)

    gui.close()
예제 #17
0
def test_waveform_view(qtbot, tempdir):
    """Exercise WaveformView on artificial data attached to a GUI.

    Covers cluster selection, overlap/label toggles, box and probe scaling
    round-trips, the scaling properties, and a simulated channel click.
    """
    nc = 5  # number of channels

    def get_waveforms(cluster_id):
        # Same artificial data for every cluster, on `nc` channels laid out
        # with staggered positions.
        return Bunch(
            data=artificial_waveforms(10, 20, nc),
            channel_ids=np.arange(nc),
            channel_positions=staggered_positions(nc),
        )

    v = WaveformView(waveforms=get_waveforms, )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    # Select no cluster, one, several, then two.
    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # Each toggle is called twice (presumably returning to the initial state).
    v.toggle_waveform_overlap()
    v.toggle_waveform_overlap()

    v.toggle_show_labels()
    v.toggle_show_labels()

    # Box scaling.
    # increase/decrease and widen/narrow are expected to cancel out.
    # `ac` is presumably numpy.testing.assert_allclose (imported elsewhere).
    bs = v.boxed.box_size
    v.increase()
    v.decrease()
    ac(v.boxed.box_size, bs)

    bs = v.boxed.box_size
    v.widen()
    v.narrow()
    ac(v.boxed.box_size, bs)

    # Probe scaling.
    # Extending then shrinking must leave the box positions unchanged.
    bp = v.boxed.box_pos
    v.extend_horizontally()
    v.shrink_horizontally()
    ac(v.boxed.box_pos, bp)

    bp = v.boxed.box_pos
    v.extend_vertically()
    v.shrink_vertically()
    ac(v.boxed.box_pos, bp)

    # The scaling properties must round-trip the values assigned to them.
    a, b = v.probe_scaling
    v.probe_scaling = (a, b * 2)
    ac(v.probe_scaling, (a, b * 2))

    a, b = v.box_scaling
    v.box_scaling = (a * 2, b)
    ac(v.box_scaling, (a * 2, b))

    # Simulate channel selection.
    _clicked = []

    @v.gui.connect_
    def on_channel_click(channel_id=None, button=None, key=None):
        _clicked.append((channel_id, button, key))

    # Holding key '2' while pressing button 1 at (0, 0) should report
    # channel 0, button 1, key 2 exactly once.
    v.events.key_press(key=keys.Key('2'))
    v.events.mouse_press(pos=(0., 0.), button=1)
    v.events.key_release(key=keys.Key('2'))

    assert _clicked == [(0, 1, 2)]

    # qtbot.stop()
    gui.close()
예제 #18
0
파일: test_trace.py 프로젝트: kwikteam/phy
def test_trace_view(tempdir, qtbot):
    """Exercise TraceView on artificial traces and spikes attached to a GUI.

    Covers the spike waveform iterator, cluster selection, time navigation,
    interval resizing, label toggling, channel scaling, origin, and a
    simulated Ctrl+click spike selection.
    """
    nc = 5  # number of channels
    ns = 9  # number of spikes
    sr = 1000.  # sample rate, passed as `sample_rate` below
    ch = list(range(nc))
    duration = 1.
    st = np.linspace(0.1, .9, ns)  # spike times, evenly spaced in [0.1, 0.9]
    sc = artificial_spike_clusters(ns, nc)  # spike clusters
    # `duration * sr` samples on `nc` channels, amplified by 10.
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)
    cs = ColorSelector()

    # Minimal stand-ins for the model and supervisor objects.
    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    # The spike waveform iterator must yield at least one item.
    sw = _iter_spike_waveforms(interval=[0., 1.],
                               traces_interval=traces,
                               model=m,
                               supervisor=s,
                               n_samples_waveforms=ns,
                               get_best_channels=lambda cluster_id: ch,
                               color_selector=cs,
                               )
    assert len(list(sw))

    def get_traces(interval):
        # Trace data for `interval`, plus one waveform snippet per spike
        # whose time falls inside it.
        out = Bunch(data=select_traces(traces, interval, sample_rate=sr),
                    color=(.75,) * 4,
                    )
        a, b = st.searchsorted(interval)  # index range of spikes in interval
        out.waveforms = []
        k = 20  # half-width of each snippet, in samples
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            # NOTE: this local `s` (the spike's sample index) shadows the
            # outer supervisor Bunch `s`, but only inside get_traces.
            s = int(round(t * sr))
            d = Bunch(data=traces[s - k:s + k, :],
                      start_time=t - k / sr,
                      color=cs.get(c),
                      channel_ids=np.arange(5),
                      spike_id=i,
                      spike_cluster=c,
                      )
            out.waveforms.append(d)
        return out

    # Channels are displayed in reverse order.
    v = TraceView(traces=get_traces,
                  n_channels=nc,
                  sample_rate=sr,
                  duration=duration,
                  channel_vertical_order=np.arange(nc)[::-1],
                  )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    # qtbot.waitForWindowShown(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # Navigation: `time` is the centre of the displayed interval.
    # ac(v.stacked.box_size, (1., .08181), atol=1e-3)
    v.set_interval((.375, .625))
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    # Presumably clamped so the interval (half-width .125) does not start
    # before 0 — hence time .125 rather than -.5.
    v.go_to(-.5)
    assert v.time == .125

    # Already at the left edge: going left keeps the same time.
    v.go_left()
    assert v.time == .125

    v.go_right()
    assert v.time == .175

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()

    # Labels must stay shown after navigating.
    v.toggle_show_labels()
    # v.toggle_show_labels()
    v.go_right()
    assert v.do_show_labels

    # Change channel scaling.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # Simulate spike selection.
    _clicked = []

    @v.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    # A Ctrl+click at (400, 200) should report channel 1, spike 4, cluster 1.
    v.events.key_press(key=keys.Key('Control'))
    v.events.mouse_press(pos=(400., 200.), button=1, modifiers=(keys.CONTROL,))
    v.events.key_release(key=keys.Key('Control'))

    assert _clicked == [(1, 4, 1)]

    # qtbot.stop()
    gui.close()
예제 #19
0
def test_trace_view(tempdir, qtbot):
    """Exercise TraceView on artificial traces and spikes attached to a GUI.

    Covers the spike waveform iterator, cluster selection, time navigation,
    interval resizing, label toggling, channel scaling, origin, and a
    simulated Ctrl+click spike selection.
    """
    nc = 5  # number of channels
    ns = 9  # number of spikes
    sr = 1000.  # sample rate, passed as `sample_rate` below
    ch = list(range(nc))
    duration = 1.
    st = np.linspace(0.1, .9, ns)  # spike times, evenly spaced in [0.1, 0.9]
    sc = artificial_spike_clusters(ns, nc)  # spike clusters
    # `duration * sr` samples on `nc` channels, amplified by 10.
    traces = 10 * artificial_traces(int(round(duration * sr)), nc)
    cs = ColorSelector()

    # Minimal stand-ins for the model and supervisor objects.
    m = Bunch(spike_times=st, spike_clusters=sc, sample_rate=sr)
    s = Bunch(cluster_meta={}, selected=[0])

    # The spike waveform iterator must yield at least one item.
    sw = _iter_spike_waveforms(
        interval=[0., 1.],
        traces_interval=traces,
        model=m,
        supervisor=s,
        n_samples_waveforms=ns,
        get_best_channels=lambda cluster_id: ch,
        color_selector=cs,
    )
    assert len(list(sw))

    def get_traces(interval):
        # Trace data for `interval`, plus one waveform snippet per spike
        # whose time falls inside it.
        out = Bunch(
            data=select_traces(traces, interval, sample_rate=sr),
            color=(.75, ) * 4,
        )
        a, b = st.searchsorted(interval)  # index range of spikes in interval
        out.waveforms = []
        k = 20  # half-width of each snippet, in samples
        for i in range(a, b):
            t = st[i]
            c = sc[i]
            # NOTE: this local `s` (the spike's sample index) shadows the
            # outer supervisor Bunch `s`, but only inside get_traces.
            s = int(round(t * sr))
            d = Bunch(
                data=traces[s - k:s + k, :],
                start_time=t - k / sr,
                color=cs.get(c),
                channel_ids=np.arange(5),
                spike_id=i,
                spike_cluster=c,
            )
            out.waveforms.append(d)
        return out

    # Channels are displayed in reverse order.
    v = TraceView(
        traces=get_traces,
        n_channels=nc,
        sample_rate=sr,
        duration=duration,
        channel_vertical_order=np.arange(nc)[::-1],
    )
    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    # qtbot.waitForWindowShown(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # Navigation: `time` is the centre of the displayed interval.
    # ac(v.stacked.box_size, (1., .08181), atol=1e-3)
    v.set_interval((.375, .625))
    assert v.time == .5

    v.go_to(.25)
    assert v.time == .25

    # Presumably clamped so the interval (half-width .125) does not start
    # before 0 — hence time .125 rather than -.5.
    v.go_to(-.5)
    assert v.time == .125

    # Already at the left edge: going left keeps the same time.
    v.go_left()
    assert v.time == .125

    v.go_right()
    assert v.time == .175

    # Change interval size.
    v.interval = (.25, .75)
    ac(v.interval, (.25, .75))
    v.widen()
    ac(v.interval, (.125, .875))
    v.narrow()
    ac(v.interval, (.25, .75))

    # Widen the max interval.
    v.set_interval((0, duration))
    v.widen()

    # Labels must stay shown after navigating.
    v.toggle_show_labels()
    # v.toggle_show_labels()
    v.go_right()
    assert v.do_show_labels

    # Change channel scaling.
    bs = v.stacked.box_size
    v.increase()
    v.decrease()
    ac(v.stacked.box_size, bs, atol=1e-3)

    v.origin = 'upper'
    assert v.origin == 'upper'

    # Simulate spike selection.
    _clicked = []

    @v.gui.connect_
    def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
        _clicked.append((channel_id, spike_id, cluster_id))

    # A Ctrl+click at (400, 200) should report channel 1, spike 4, cluster 1.
    v.events.key_press(key=keys.Key('Control'))
    v.events.mouse_press(pos=(400., 200.),
                         button=1,
                         modifiers=(keys.CONTROL, ))
    v.events.key_release(key=keys.Key('Control'))

    assert _clicked == [(1, 4, 1)]

    # qtbot.stop()
    gui.close()
예제 #20
0
파일: gui.py 프로젝트: ycanerol/phy
def create_trace_gui(dat_path, **kwargs):
    """Create the Trace GUI.

    Parameters
    ----------

    dat_path : str or Path
        Path to the raw data file, or to a ``params.py`` file. In the latter
        case the first entry of its ``dat_path`` is used as the data file
        and the remaining params are forwarded as keyword arguments.
    sample_rate : float
        The data sampling rate, in Hz. Required unless `dat_path` is a
        ``.cbin`` file.
    n_channels_dat : int
        The number of columns in the raw data file. Required unless
        `dat_path` is a ``.cbin`` file.
    dtype : str
        The NumPy data type of the raw binary file. Required unless
        `dat_path` is a ``.cbin`` file.
    offset : int, optional
        The header offset in bytes. Defaults to 0 when missing or falsy.
    order : str, optional
        Forwarded to ``load_raw_data``. Defaults to None.

    Returns
    -------
    GUI
        A GUI with a single TraceView attached.

    """

    gui_name = 'TraceGUI'

    dat_path = Path(dat_path)

    # Support passing a params.py file.
    if dat_path.suffix == '.py':
        params = get_template_params(str(dat_path))
        return create_trace_gui(next(iter(params.pop('dat_path'))), **params)

    if dat_path.suffix == '.cbin':  # pragma: no cover
        # The sample rate and channel count are read from the file itself.
        data = load_raw_data(path=dat_path)
        sample_rate = data.sample_rate
        n_channels_dat = data.shape[1]
    else:
        sample_rate = float(kwargs['sample_rate'])
        assert sample_rate > 0.

        n_channels_dat = int(kwargs['n_channels_dat'])

        dtype = np.dtype(kwargs['dtype'])
        # `offset` is optional: a missing key previously raised KeyError
        # even though the `or 0` clearly intends a default of 0.
        offset = int(kwargs.get('offset') or 0)
        order = kwargs.get('order', None)

        # Memmap the raw data file.
        data = load_raw_data(
            path=dat_path,
            n_channels_dat=n_channels_dat,
            dtype=dtype,
            offset=offset,
            order=order,
        )

    # Duration in seconds: number of samples (rows) divided by the rate.
    duration = data.shape[0] / sample_rate

    create_app()
    gui = GUI(name=gui_name,
              subtitle=dat_path.resolve(),
              enable_threading=False)

    gui.set_default_actions()

    def _get_traces(interval):
        return Bunch(
            data=select_traces(data, interval, sample_rate=sample_rate))

    # TODO: load channel information

    view = TraceView(
        traces=_get_traces,
        n_channels=n_channels_dat,
        sample_rate=sample_rate,
        duration=duration,
        enable_threading=False,
    )
    view.attach(gui)

    return gui
예제 #21
0
def test_feature_view(qtbot, tempdir, n_channels):
    """Exercise FeatureView on artificial features attached to a GUI.

    Covers cluster selection, the ``select`` GUI event, scaling, channel
    clicks, automatic channel selection, and split requests with and
    without a lasso.
    """
    nc = n_channels
    ns = 500  # number of spikes
    features = artificial_features(ns, nc, 4)
    spike_clusters = artificial_spike_clusters(ns, 4)
    spike_times = np.linspace(0., 1., ns)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        # Spikes of one cluster, or all spikes when no cluster is given.
        return (spc[cluster_id] if cluster_id is not None else np.arange(ns))

    def get_features(cluster_id=None,
                     channel_ids=None,
                     spike_ids=None,
                     load_all=None):
        # NOTE(review): the `spike_ids` argument is ignored (always
        # overwritten below), and `masks` always has `ns` rows whatever the
        # number of returned spikes — confirm FeatureView tolerates this.
        if load_all:
            spike_ids = spc[cluster_id]
        else:
            spike_ids = get_spike_ids(cluster_id)
        return Bunch(
            data=features[spike_ids],
            spike_ids=spike_ids,
            masks=np.random.rand(ns, nc),
            channel_ids=(channel_ids
                         if channel_ids is not None else np.arange(nc)[::-1]),
        )

    def get_time(cluster_id=None, load_all=None):
        # Spike times of the selected spikes, with fixed limits (0, 1).
        return Bunch(
            data=spike_times[get_spike_ids(cluster_id)],
            lim=(0., 1.),
        )

    v = FeatureView(
        features=get_features,
        attributes={'time': get_time},
    )

    v.set_state(GUIState(scaling=None))

    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # Same selection, this time dispatched through the GUI event system.
    gui.emit('select', [0, 2])
    qtbot.wait(10)

    v.increase()
    v.decrease()

    v.on_channel_click(channel_id=3, button=1, key=2)
    v.clear_channels()
    v.toggle_automatic_channel_selection()

    # Split without selection.
    spike_ids = v.on_request_split()
    assert len(spike_ids) == 0

    # Draw a lasso.
    # Each Ctrl+left-click on the native widget adds one point.
    def _click(x, y):
        qtbot.mouseClick(v.native,
                         Qt.LeftButton,
                         pos=QPoint(x, y),
                         modifier=Qt.ControlModifier)

    _click(10, 10)
    _click(10, 100)
    _click(100, 100)
    _click(100, 10)

    # Split lassoed points.
    spike_ids = v.on_request_split()
    assert len(spike_ids) > 0

    # qtbot.stop()
    gui.close()
예제 #22
0
def test_feature_view(qtbot, tempdir, n_channels):
    """Exercise FeatureView on artificial features attached to a GUI.

    Covers cluster selection, the ``select`` GUI event, scaling, channel
    clicks, automatic channel selection, and split requests with and
    without a lasso.
    """
    nc = n_channels
    ns = 500  # number of spikes
    features = artificial_features(ns, nc, 4)
    spike_clusters = artificial_spike_clusters(ns, 4)
    spike_times = np.linspace(0., 1., ns)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        # Spikes of one cluster, or all spikes when no cluster is given.
        return (spc[cluster_id] if cluster_id is not None else np.arange(ns))

    def get_features(cluster_id=None, channel_ids=None, spike_ids=None,
                     load_all=None):
        # NOTE(review): the `spike_ids` argument is ignored (always
        # overwritten below), and `masks` always has `ns` rows whatever the
        # number of returned spikes — confirm FeatureView tolerates this.
        if load_all:
            spike_ids = spc[cluster_id]
        else:
            spike_ids = get_spike_ids(cluster_id)
        return Bunch(data=features[spike_ids],
                     spike_ids=spike_ids,
                     masks=np.random.rand(ns, nc),
                     channel_ids=(channel_ids
                                  if channel_ids is not None
                                  else np.arange(nc)[::-1]),
                     )

    def get_time(cluster_id=None, load_all=None):
        # Spike times of the selected spikes, with fixed limits (0, 1).
        return Bunch(data=spike_times[get_spike_ids(cluster_id)],
                     lim=(0., 1.),
                     )

    v = FeatureView(features=get_features,
                    attributes={'time': get_time},
                    )

    v.set_state(GUIState(scaling=None))

    gui = GUI(config_dir=tempdir)
    gui.show()
    v.attach(gui)
    qtbot.addWidget(gui)

    v.on_select([])
    v.on_select([0])
    v.on_select([0, 2, 3])
    v.on_select([0, 2])

    # Same selection, this time dispatched through the GUI event system.
    gui.emit('select', [0, 2])
    qtbot.wait(10)

    v.increase()
    v.decrease()

    v.on_channel_click(channel_id=3, button=1, key=2)
    v.clear_channels()
    v.toggle_automatic_channel_selection()

    # Split without selection.
    spike_ids = v.on_request_split()
    assert len(spike_ids) == 0

    # Draw a lasso.
    # Each Ctrl+left-click on the native widget adds one point.
    def _click(x, y):
        qtbot.mouseClick(v.native, Qt.LeftButton, pos=QPoint(x, y),
                         modifier=Qt.ControlModifier)

    _click(10, 10)
    _click(10, 100)
    _click(100, 100)
    _click(100, 10)

    # Split lassoed points.
    spike_ids = v.on_request_split()
    assert len(spike_ids) > 0

    # qtbot.stop()
    gui.close()
예제 #23
0
파일: Pysorter.py 프로젝트: chongxi/mua
@author: chongxi lai
"""

#%%
from MUA import *
from Vis import ObjectWidget, view_scatter_3d
from phy.plot import View
from phy.plot.interact import Grid
from hdbscan import HDBSCAN
import phy
from phy.gui import GUI
import seaborn as sns

# Load the raw recording from a hard-coded path.
# NOTE(review): judging by the argument names this is presumably 32 channels
# sampled at 25 kHz with 4-byte samples — confirm against MUA.
mua = MUA(filename='S:/pcie.bin', nCh=32, fs=25000, numbytes=4)
# Presumably extracts spikes, then computes features with the 'pca' method.
spk = mua.tospk()
fet = spk.tofet('pca')



#%%
# Build the GUI: a params panel docked on the left, a 3D scatter view, and
# a grid-layout View.
gui = GUI(position=(0, 0), size=(600, 400), name='GUI')
props = ObjectWidget()
gui.add_view(props,position='left', name='params')
scatter_view = view_scatter_3d()
scatter_view.unfreeze()
gui.add_view(scatter_view)
spk_view = View('grid')
gui.add_view(spk_view)
# NOTE(review): the GUI is never shown in this cell; presumably it is shown
# interactively afterwards (e.g. `gui.show()`) — confirm.

예제 #24
0
            def MergeRuns(controller=controller, plugin=plugin):
                if True:
                    path2 = QtGui.QFileDialog.getExistingDirectory(
                        None,
                        "Select the results folder for the sort to be merged",
                        op.dirname(
                            op.dirname(controller.model.dir_path)
                        ),  #two folders up from the current phy's path
                        QtGui.QFileDialog.ShowDirsOnly)
                else:
                    path2 = '/home/luke/KiloSort_tmp/BOL005c_9_96clusts/results'
                params_path = op.join(path2, 'params.py')
                params = _read_python(params_path)
                params['dtype'] = np.dtype(params['dtype'])
                params['path'] = path2
                if op.realpath(params['dat_path']) != params['dat_path']:
                    params['dat_path'] = op.join(path2, params['dat_path'])
                print('Loading {}'.format(path2))
                controller2 = TemplateController(**params)
                #controller2.gui_name = 'TemplateGUI2'
                gui2 = controller2.create_gui()
                gui2.show()

                #                @gui2.connect_
                #                def on_select(clusters,controller2=controller2):
                #                    controller.supervisor.select(clusters)

                #create mean_waveforms for each controller (run)
                print('computing mean waveforms for master run...')
                controller.mean_waveforms = create_mean_waveforms(
                    controller, max_waveforms_per_cluster=100)
                print('computing mean waveforms for slave run...')
                controller2.mean_waveforms = create_mean_waveforms(
                    controller2, max_waveforms_per_cluster=100)

                groups = {
                    c: controller.supervisor.cluster_meta.get('group', c)
                    or 'unsorted'
                    for c in controller.supervisor.clustering.cluster_ids
                }
                groups2 = {
                    c: controller2.supervisor.cluster_meta.get('group', c)
                    or 'unsorted'
                    for c in controller2.supervisor.clustering.cluster_ids
                }
                su_inds = np.nonzero([
                    controller.supervisor.cluster_meta.get('group',
                                                           c) == 'good'
                    for c in controller.supervisor.clustering.cluster_ids
                ])[0]
                mu_inds = np.nonzero([
                    controller.supervisor.cluster_meta.get('group', c) == 'mua'
                    for c in controller.supervisor.clustering.cluster_ids
                ])[0]
                su_best_channels = np.array([
                    controller.get_best_channel(c) for c in
                    controller.supervisor.clustering.cluster_ids[su_inds]
                ])
                mu_best_channels = np.array([
                    controller.get_best_channel(c) for c in
                    controller.supervisor.clustering.cluster_ids[mu_inds]
                ])
                su_order = np.argsort(su_best_channels, kind='mergesort')
                mu_order = np.argsort(mu_best_channels, kind='mergesort')
                m_inds = np.concatenate((su_inds[su_order], mu_inds[mu_order]))

                filename = op.join(controller.model.dir_path,
                                   'cluster_names.ts')
                if not op.exists(filename):
                    best_channels = np.concatenate(
                        (su_best_channels[su_order],
                         mu_best_channels[mu_order]))
                    unit_type = np.concatenate(
                        (np.ones(len(su_order)), 2 * np.ones(len(mu_order))))
                    unit_number = np.zeros(len(best_channels))
                    for chan in np.unique(best_channels):
                        matched_clusts = best_channels == chan
                        unit_number[matched_clusts] = np.arange(
                            sum(matched_clusts)) + 1
                else:
                    print('{} exists, loading'.format(filename))
                    unit_types, channels, unit_numbers = load_metadata(
                        filename, controller.supervisor.clustering.cluster_ids)
                    best_channels = channels[m_inds]
                    unit_number = unit_numbers[m_inds]
                    unit_type = unit_types[m_inds]
                    unit_type_current = np.concatenate(
                        (np.ones(len(su_order)), 2 * np.ones(len(mu_order))))
                    if ~np.all(unit_type == unit_type_current):
                        raise RuntimeError(
                            'For the master phy, the unit types saved in "cluster_names.ts"'
                            'do not match those save in "cluster_groups.tsv" This likely means work was done on '
                            'this phy after merging with a previous master. Not sure how to deal with this!'
                        )
                    #re-sort to make unit numbers in order
                    # assuming unit_type is already sorted (which it should be...)
                    nsu = np.sum(unit_type == 1)
                    su_order = np.lexsort(
                        (unit_number[:nsu], best_channels[:nsu]))
                    mu_order = np.lexsort(
                        (unit_number[nsu:], best_channels[nsu:]))
                    m_inds[:nsu] = m_inds[su_order]
                    m_inds[nsu:] = m_inds[mu_order + nsu]
                    best_channels = channels[m_inds]
                    unit_number = unit_numbers[m_inds]
                    unit_type = unit_types[m_inds]

                dists = calc_dists(controller, controller2, m_inds)

                so = np.argsort(dists, 0, kind='mergesort')
                matchi = so[0, :]  #best match index to master for each slave
                sortrows = np.argsort(
                    matchi, kind='mergesort')  #sort index for best match

                def handle_item_clicked(item,
                                        controller=controller,
                                        controller2=controller2,
                                        plugin=plugin):
                    row = np.array(
                        [cell.row() for cell in table.selectedIndexes()])
                    column = np.array(
                        [cell.column() for cell in table.selectedIndexes()])
                    print("Row {} and Column {} was clicked".format(
                        row, column))
                    print("M {} S {} ".format(
                        controller.supervisor.clustering.cluster_ids[
                            plugin.m_inds[column]],
                        controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows[row]]))
                    column = column[~np.in1d(plugin.m_inds[column], (-1, -2))]
                    if len(column) == 0:
                        pass
                        #controller.supervisor.select(None)
                        # make a deselect function and call it here if feeling fancy
                    else:
                        controller.supervisor.select(
                            controller.supervisor.clustering.cluster_ids[
                                plugin.m_inds[column]].tolist())
                    controller2.supervisor.select(
                        controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows[row]].tolist())
                    #print("Row %d and Column %d was clicked" % (row, column))

                def create_table(controller, controller2, plugin):
                    """(Re)populate ``plugin.table`` from the current match state.

                    Columns are master units (``plugin.m_inds``); rows are slave
                    clusters in ``plugin.sortrows`` order. Each cell shows the
                    distance as a percentage of the largest distance, coloured
                    with the reversed viridis colormap. Negative distances are
                    shown as grey 'N/A' cells. The current match of each slave
                    (``plugin.matchi``) is drawn in red.

                    Called again after every merge / unit edit, so it must be
                    safe to call repeatedly on the same table.
                    """
                    n_masters = len(plugin.m_inds)
                    n_slaves = len(plugin.matchi)
                    plugin.table.setRowCount(n_slaves)
                    plugin.table.setColumnCount(n_masters)

                    # Cell text: distance as a percentage of the maximum.
                    dists_txt = np.round(plugin.dists / plugin.dists.max() *
                                         100)
                    # The colour scale ignores -1 placeholder distances. Guard
                    # against every entry being a placeholder, where min() on
                    # the empty selection would raise.
                    valid = plugin.dists[plugin.dists != -1]
                    vmin = valid.min() - 1 if valid.size else -1
                    normal = plt.Normalize(vmin, plugin.dists.max() + 1)
                    colors = plt.cm.viridis_r(normal(plugin.dists)) * 255

                    for col in range(n_masters):
                        for row in range(n_slaves):
                            slave = plugin.sortrows[row]
                            if plugin.dists[col, slave] < 0:
                                item = QtGui.QTableWidgetItem('N/A')
                                item.setBackground(QtGui.QColor(127, 127, 127))
                            else:
                                item = QtGui.QTableWidgetItem('{:.0f}'.format(
                                    dists_txt[col, slave]))
                                # QColor takes integer channels; truncate the
                                # 0-255 floats explicitly.
                                item.setBackground(
                                    QtGui.QColor(int(colors[col, slave, 0]),
                                                 int(colors[col, slave, 1]),
                                                 int(colors[col, slave, 2])))
                            if plugin.matchi[slave] == col:
                                item.setForeground(QtGui.QColor(255, 0, 0))
                            plugin.table.setItem(row, col, item)

                    # Column headers: master cluster id (or 'None'/'Noise'),
                    # then "<best channel>-<unit number>".
                    for col in range(plugin.dists.shape[0]):
                        m_ind = plugin.m_inds[col]
                        if m_ind == -1:
                            cluster_num = 'None'
                        elif m_ind == -2:
                            cluster_num = 'Noise'
                        else:
                            cluster_num = controller.supervisor.clustering.cluster_ids[
                                m_ind]
                        plugin.table.setHorizontalHeaderItem(
                            col,
                            QtGui.QTableWidgetItem('{}\n{:.0f}-{:.0f}'.format(
                                cluster_num, plugin.best_channels[col],
                                plugin.unit_number[col])))

                    # Row headers: "<slave cluster id>-<snr>".
                    for row in range(plugin.dists.shape[1]):
                        c_id = controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows[row]]
                        snr = controller2.supervisor.cluster_view._columns[
                            'snr']['func'](c_id)
                        plugin.table.setVerticalHeaderItem(
                            row,
                            QtGui.QTableWidgetItem('{:.0f}-{:.1f}'.format(
                                c_id, snr)))

                    plugin.table.setEditTriggers(
                        QtGui.QAbstractItemView.NoEditTriggers)
                    plugin.table.resizeColumnsToContents()
                    # Connect the click handler only once per table widget.
                    # Reconnecting on every rebuild would make a single click
                    # run the handler once per previous rebuild.
                    if not getattr(plugin.table, '_merge_click_connected',
                                   False):
                        plugin.table.itemClicked.connect(handle_item_clicked)
                        plugin.table._merge_click_connected = True

#                self.fig = plt.figure()
#                ax = self.fig.add_axes([0.15, 0.02, 0.83, 0.975])
#                normal = plt.Normalize(dists.min()-1, dists.max()+1)
#                dists_txt=np.round(dists/dists.max()*100)
#                self.table=ax.table(cellText=dists_txt, rowLabels=controller.supervisor.clustering.cluster_ids[m_inds], colLabels=controller2.supervisor.clustering.cluster_ids,
#                    colWidths = [0.03]*dists.shape[1], loc='center',
#                    cellColours=plt.cm.hot(normal(dists)))
#                self.fig.show()
#a = QApplication(sys.argv)

                tablegui = GUI(position=(400, 200), size=(400, 300))
                table = QtGui.QTableWidget()
                table.setWindowTitle("Merge Table")
                #table.resize(600, 400)

                plugin.matchi = matchi
                plugin.sortrows = sortrows
                plugin.best_channels = best_channels
                plugin.unit_number = unit_number
                plugin.unit_type = unit_type
                plugin.m_inds = m_inds
                plugin.tablegui = tablegui  #need to keep a reference otherwide the gui is deleted by garbage collection, leading to a segfault!
                plugin.table = table
                plugin.dists = dists

                create_table(controller, controller2, plugin)

                tablegui.add_view(table)

                actions = Actions(tablegui, name='Merge', menu='Merge')

                @actions.add(menu='Merge',
                             name='Set master for selected slave',
                             shortcut='enter')
                def set_master(plugin=plugin,
                               controller=controller,
                               controller2=controller2):
                    """Make the selected cell's column the match of its slave row.

                    Exactly one cell must be selected; otherwise a status message
                    is shown. The previous match cell in that row is reset to
                    black text, the selected cell is drawn in red and
                    ``plugin.matchi`` is updated. Rows are not re-sorted here.
                    """
                    indexes = plugin.table.selectedIndexes()
                    row = np.array([cell.row() for cell in indexes])
                    column = np.array([cell.column() for cell in indexes])
                    if len(row) != 1:
                        st = 'Only one cell can be selected when setting master'
                        print(st)
                        plugin.tablegui.status_message = st
                        return

                    print("Row {} and Column {} is selected".format(
                        row, column))
                    slave = plugin.sortrows[row[0]]
                    # Reset the old match first and only then highlight the new
                    # one. In the reverse order, re-selecting the current match
                    # would leave that cell black.
                    plugin.table.item(row[0], plugin.matchi[slave]).setForeground(
                        QtGui.QColor(0, 0, 0))
                    plugin.table.item(row[0], column[0]).setForeground(
                        QtGui.QColor(255, 0, 0))
                    plugin.matchi[slave] = column[0]
                    plugin.table.show()
                    plugin.tablegui.show()

                @actions.add(menu='Merge',
                             name='Merge selected slaves',
                             shortcut='m')
                def merge_slaves_by_selection(plugin=plugin,
                                              controller=controller,
                                              controller2=controller2):
                    """Merge every selected slave row into a single cluster.

                    The selection must span exactly one column and at least two
                    rows. On success the table is rebuilt; otherwise the reason
                    is printed and shown in the table GUI's status bar.
                    """
                    selection = plugin.table.selectedIndexes()
                    rows = np.unique(np.array([idx.row() for idx in selection]))
                    cols = np.unique(
                        np.array([idx.column() for idx in selection]))

                    if len(cols) == 1 and len(rows) > 1:
                        merge_slaves(plugin, controller, controller2, rows)
                        create_table(controller, controller2, plugin)
                        plugin.tablegui.show()
                        return

                    if len(cols) > 1:
                        st = 'Only one master can be selected when merging slaves'
                    elif len(rows) == 1:
                        st = 'At least two slaves must be selected to merge slaves'
                    else:
                        st = 'Unknown slave merging problem'
                    print(st)
                    plugin.tablegui.status_message = st

                def merge_slaves_by_array(plugin, controller, controller2,
                                          merge_matchis):
                    """Merge all slaves matched to each master index in ``merge_matchis``.

                    The table is rebuilt and shown once, after all merges.
                    """
                    for target in merge_matchis:
                        # merge_slaves rewrites matchi and sortrows, so the
                        # matching rows must be looked up again on every pass.
                        matched = plugin.matchi[plugin.sortrows] == target
                        rows = np.nonzero(matched)[0]
                        merge_slaves(plugin, controller, controller2, rows)

                    create_table(controller, controller2, plugin)
                    plugin.tablegui.show()

                def merge_slaves(plugin, controller, controller2, row):
                    """Merge the slave clusters shown in table rows ``row``.

                    ``row`` is an integer array of table row positions (callers
                    pass the result of ``np.unique``/``np.where``). The merged
                    cluster inherits the master match of the first row. The
                    merged clusters' entries are removed from ``plugin.matchi``,
                    ``controller2.mean_waveforms`` and ``plugin.dists``, and one
                    entry for the new cluster is appended at the end of each.
                    ``plugin.sortrows`` is then recomputed.
                    """
                    controller2.supervisor.merge(
                        controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows[row]].tolist())
                    # Master column the merged cluster keeps: that of the first row.
                    assign_matchi = plugin.matchi[plugin.sortrows[row[0]]]
                    plugin.matchi = np.delete(plugin.matchi,
                                              plugin.sortrows[row],
                                              axis=0)
                    plugin.matchi = np.append(plugin.matchi, assign_matchi)
                    # Drop the merged clusters' mean waveforms (the code treats
                    # axis 2 as the slave-cluster axis).
                    controller2.mean_waveforms = np.delete(
                        controller2.mean_waveforms,
                        plugin.sortrows[row],
                        axis=2)
                    # Mean waveform of the new cluster. Assumes the merged
                    # cluster is the last entry of cluster_ids (presumably phy
                    # gives it the largest id -- TODO confirm).
                    new_mean_waveforms = create_mean_waveforms(
                        controller2,
                        max_waveforms_per_cluster=100,
                        cluster_ids=controller2.supervisor.clustering.
                        cluster_ids[-1])
                    controller2.mean_waveforms = np.append(
                        controller2.mean_waveforms, new_mean_waveforms, axis=2)
                    # Same bookkeeping for the distance matrix: remove the
                    # merged slave columns, then append distances for the new
                    # cluster.
                    plugin.dists = np.delete(plugin.dists,
                                             plugin.sortrows[row],
                                             axis=1)
                    # s_inds is the post-deletion column count, i.e. the index
                    # the new cluster occupies (presumably also its index in
                    # controller2.mean_waveforms -- confirm in calc_dists).
                    plugin.dists = np.append(plugin.dists,
                                             calc_dists(
                                                 controller,
                                                 controller2,
                                                 plugin.m_inds,
                                                 s_inds=plugin.dists.shape[1]),
                                             axis=1)
                    # Rows are displayed sorted by master match; mergesort is
                    # stable, so ties keep their slave order.
                    plugin.sortrows = np.argsort(plugin.matchi,
                                                 kind='mergesort')

                @actions.add(menu='Merge',
                             name='Move low-snr clusters to noise',
                             shortcut='n')
                def move_low_snr_to_noise(plugin=plugin,
                                          controller=controller,
                                          controller2=controller2):
                    """Match every slave cluster with SNR below 0.5 to a 'Noise' column.

                    SNRs come from the slave cluster view's 'snr' column
                    function. Low-SNR clusters are assigned to the existing Noise
                    column (``m_inds`` entry -2) if there is one. Otherwise a
                    Noise column is appended with unit type 3, channel 999, unit
                    number 0 and all distances -1. If no cluster is below the
                    threshold, nothing is changed and only a message is shown.
                    """
                    thresh = 0.5  # SNR threshold
                    cluster_ids = controller2.supervisor.clustering.cluster_ids
                    snr_func = controller2.supervisor.cluster_view._columns[
                        'snr']['func']
                    snrs = np.zeros(cluster_ids.shape)
                    for i, clu in enumerate(cluster_ids):
                        snrs[i] = snr_func(clu)
                    low_snr = snrs < thresh
                    noise_clusts = cluster_ids[low_snr]

                    if len(noise_clusts) == 0:
                        # Previously an empty Noise column was added here.
                        st = 'No clusters with snr below {}'.format(thresh)
                        print(st)
                        plugin.tablegui.status_message = st
                        return

                    # Reuse an existing Noise column instead of stacking a new
                    # one on every invocation.
                    noise_cols = np.nonzero(plugin.m_inds == -2)[0]
                    if len(noise_cols) > 0:
                        ind = noise_cols[0]
                    else:
                        ind = plugin.m_inds.shape[0]
                        plugin.m_inds = np.insert(plugin.m_inds, ind, -2)
                        plugin.best_channels = np.insert(
                            plugin.best_channels, ind, 999)
                        plugin.unit_number = np.insert(plugin.unit_number,
                                                       ind, 0)
                        plugin.unit_type = np.insert(plugin.unit_type, ind, 3)
                        plugin.dists = np.insert(plugin.dists, ind, -1, axis=0)

                    # matchi is indexed in the same slave order as cluster_ids,
                    # so the low-SNR positions index it directly.
                    plugin.matchi[np.nonzero(low_snr)[0]] = ind
                    plugin.sortrows = np.argsort(plugin.matchi,
                                                 kind='mergesort')

                    create_table(controller, controller2, plugin)
                    plugin.tablegui.show()
                    st = 'Cluster ids {} moved to noise'.format(noise_clusts)
                    print(st)
                    plugin.tablegui.status_message = st

                @actions.add(menu='Merge',
                             name='Add new unit label',
                             shortcut='a')
                def add_unit(plugin=plugin,
                             controller=controller,
                             controller2=controller2):
                    """Insert an empty master column for a new unit label.

                    The user is asked for:
                    - a channel;
                    - a unit number, pre-filled with one more than the largest
                      unit number already on that channel;
                    - a unit type (SU -> 1, MU -> 2).

                    The new column has no master cluster (``m_inds`` entry -1,
                    shown as 'None') and distance -1 to every slave. It is
                    inserted after the last column of the same type whose
                    channel is <= the requested channel, and the column indices
                    in ``plugin.matchi`` are shifted accordingly.
                    """
                    chan, ok = QtGui.QInputDialog.getText(
                        None, 'Adding new unit label:', '       Channel:')
                    if not ok:
                        return
                    try:
                        chan = int(chan)
                    except (TypeError, ValueError):
                        plugin.tablegui.status_message = 'Error inputting channel'
                        return

                    nums = plugin.unit_number[plugin.best_channels == chan]
                    next_unit_num = int(nums.max()) + 1 if len(nums) else 1

                    dlg = QtGui.QInputDialog(None)
                    dlg.setInputMode(QtGui.QInputDialog.TextInput)
                    dlg.setTextValue('{}'.format(next_unit_num))
                    dlg.setLabelText("Unit number:")
                    dlg.resize(300, 300)
                    # SU/MU radio buttons placed directly on the dialog.
                    b1 = QtGui.QRadioButton("SU", dlg)
                    b2 = QtGui.QRadioButton("MU", dlg)
                    b1.move(100, 0)
                    b2.move(150, 0)
                    b1.setChecked(True)
                    ok = dlg.exec_()
                    unit_number = dlg.textValue()
                    if not ok:
                        return
                    try:
                        unit_number = int(unit_number)
                    except (TypeError, ValueError):
                        plugin.tablegui.status_message = 'Error inputting unit number'
                        return

                    if b1.isChecked():
                        unit_type = 1
                    elif b2.isChecked():
                        unit_type = 2
                    else:
                        # status_message is assigned, not called, everywhere else.
                        plugin.tablegui.status_message = (
                            'Error getting unit type, must have checked either SU or MU'
                        )
                        return

                    same_type = plugin.unit_type == unit_type
                    below_inds = np.logical_and(
                        same_type, plugin.best_channels <= chan).nonzero()[0]
                    if below_inds.shape[0] == 0:
                        below_inds = same_type.nonzero()[0]
                    if below_inds.shape[0] == 0:
                        # No column of this type yet (previously an IndexError):
                        # place it after the last column with a lower type, or
                        # at the front if there is none.
                        below_inds = (plugin.unit_type < unit_type).nonzero()[0]
                    ind = below_inds[-1] + 1 if below_inds.shape[0] else 0

                    plugin.m_inds = np.insert(plugin.m_inds, ind, -1)
                    # Columns at or after the insertion point move one right.
                    plugin.matchi[plugin.matchi >= ind] += 1
                    plugin.best_channels = np.insert(plugin.best_channels, ind,
                                                     chan)
                    plugin.unit_number = np.insert(plugin.unit_number, ind,
                                                   unit_number)
                    plugin.unit_type = np.insert(plugin.unit_type, ind,
                                                 unit_type)
                    plugin.dists = np.insert(plugin.dists, ind, -1, axis=0)
                    create_table(controller, controller2, plugin)

                    plugin.tablegui.show()

                @actions.add(menu='Merge',
                             name='Save cluster associations',
                             alias='sca')
                def save_cluster_associations(plugin=plugin,
                                              controller=controller,
                                              controller2=controller2):
                    """Label slave clusters from their masters and save everything.

                    If several slave clusters are matched to the same non-noise
                    master column, the user is first asked whether to merge them:
                    - Yes merges them and returns WITHOUT saving;
                    - No continues with the save;
                    - Cancel aborts.

                    Saving does the following:
                    - moves every slave cluster to 'good', 'mua' or 'noise'
                      according to the unit type (1, 2 or 3) of its matched
                      master column;
                    - saves both supervisors;
                    - writes ``cluster_names.ts`` to both model directories;
                    - appends the slave directory plus a timestamp to
                      ``Merged_Files.txt`` in the master directory.
                    """
                    un_matchi, counts = np.unique(plugin.matchi,
                                                  return_counts=True)
                    # Many slaves sharing a noise column (unit type 3) is fine.
                    noise_cols = np.where(plugin.unit_type[un_matchi] == 3)[0]
                    if len(noise_cols) > 0:
                        un_matchi = np.delete(un_matchi, noise_cols)
                        counts = np.delete(counts, noise_cols)
                    if np.any(counts > 1):
                        msgBox = QtGui.QMessageBox()
                        msgBox.setText(
                            'There are {} master clusters that are about '
                            'to be assigned multiple slave clusters. If this slave will '
                            'be used as a master for an additional merge, in most cases '
                            'slave clusters that share the same master match should be '
                            'merged.'.format(np.sum(counts > 1)))
                        msgBox.setInformativeText(
                            'Do you want to automatically do these merges?')
                        msgBox.setStandardButtons(QtGui.QMessageBox.Yes
                                                  | QtGui.QMessageBox.No
                                                  | QtGui.QMessageBox.Cancel)
                        msgBox.setDefaultButton(QtGui.QMessageBox.Yes)
                        ret = msgBox.exec_()
                        if ret == QtGui.QMessageBox.Yes:
                            merge_slaves_by_array(plugin, controller,
                                                  controller2,
                                                  un_matchi[counts > 1])
                            msgBox = QtGui.QMessageBox()
                            msgBox.setText(
                                'Merges done. Not saving yet! Check to see that everything is okay and then save.'
                            )
                            msgBox.exec_()
                            return
                        elif ret != QtGui.QMessageBox.No:
                            return

                    # Label slave clusters from the unit type of their master.
                    slave_types = plugin.unit_type[plugin.matchi[
                        plugin.sortrows]]
                    for group, group_type in (('good', 1), ('mua', 2),
                                              ('noise', 3)):
                        group_clusts = controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows[slave_types == group_type]].tolist()
                        controller2.supervisor.move(group, group_clusts)

                    # Save both master and slave.
                    controller.supervisor.save()
                    controller2.supervisor.save()

                    # Save associations. Master columns without a real master
                    # cluster ('None' = -1, 'Noise' = -2) are left out.
                    has_master = ~np.in1d(plugin.m_inds, (-1, -2))
                    create_tsv(
                        op.join(controller.model.dir_path, 'cluster_names.ts'),
                        controller.supervisor.clustering.cluster_ids[
                            plugin.m_inds[has_master]],
                        plugin.unit_type[has_master],
                        plugin.best_channels[has_master],
                        plugin.unit_number[has_master])
                    slave_matchi = plugin.matchi[plugin.sortrows]
                    create_tsv(
                        op.join(controller2.model.dir_path,
                                'cluster_names.ts'),
                        controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows],
                        plugin.unit_type[slave_matchi],
                        plugin.best_channels[slave_matchi],
                        plugin.unit_number[slave_matchi])
                    with open(
                            op.join(controller.model.dir_path,
                                    'Merged_Files.txt'), 'a') as text_file:
                        text_file.write('{} on {}\n'.format(
                            controller2.model.dir_path, time.strftime('%c')))
                    plugin.tablegui.status_message = 'Saved cluster associations'
                    print('Saved cluster associations')

                def create_tsv(filename, cluster_id, unit_type, channel,
                               unit_number):
                    """Write a tab-separated table with one row per cluster.

                    The header is ``cluster_id, unit_type, chan, unit_number``.
                    Row ``i`` is built from element ``i`` of each sequence, for
                    every element of ``cluster_id``. The file is opened in
                    binary mode on Python 2 and with ``newline=''`` on Python 3,
                    as the csv module expects.
                    """
                    if sys.version_info[0] >= 3:
                        handle = open(filename, 'w', newline='')
                    else:
                        handle = open(filename, 'wb')
                    with handle as out:
                        tsv = csv.writer(out, delimiter='\t')
                        tsv.writerow(
                            ['cluster_id', 'unit_type', 'chan', 'unit_number'])
                        for i, cid in enumerate(cluster_id):
                            tsv.writerow(
                                [cid, unit_type[i], channel[i], unit_number[i]])

                tablegui.show()
예제 #25
0
파일: test.py 프로젝트: chongxi/mua
# Cluster the features of one channel with HDBSCAN and display the result in a
# phy GUI (3-D scatter view plus a grid view).
# NOTE(review): `fet`, `np`, `HDBSCAN`, `view_scatter_3d`, `View` and `sns` are
# not defined in this snippet; presumably imported/defined earlier -- confirm.
ch = 26  # index into `fet` of the channel whose features are clustered
min_cluster_size = 5
leaf_size = 10

hdbcluster = HDBSCAN(min_cluster_size=min_cluster_size, 
                     leaf_size=leaf_size,
                     gen_min_span_tree=True, 
                     algorithm='boruvka_kdtree')
clu = hdbcluster.fit_predict(fet[ch])
# Python 2 print statement: this script does not run under Python 3.
print 'get clusters', np.unique(clu)


#
from phy.gui import GUI, create_app, run_app
create_app()
gui = GUI(position=(400, 200), size=(600, 400))

scatter_view = view_scatter_3d()
scatter_view.attach(gui)
scatter_view.set_data(fet[ch], clu)


# Grid view: 3 rows by one column per distinct label.
nclu = len(np.unique(clu))
view = View(layout='grid',  shape=(3, nclu))
gui.add_view(view)
palette = sns.color_palette()

view.clear()
for chNo in range(3):
    for clu_id in np.unique(clu):
        # Negative labels (presumably HDBSCAN's noise label) get [1, 1, 1].
        # NOTE(review): palette[clu_id] raises IndexError if there are more
        # labels than palette entries; the loop body appears truncated here
        # and `color` is unused in the visible code.
        color = palette[clu_id] if clu_id>=0 else np.array([1,1,1])
예제 #26
0
파일: xview.py 프로젝트: chongxi/mua
        if 1 in event.buttons and modifiers is not ():
            p1 = event.press_event.pos
            p2 = event.last_event.pos
            if modifiers[0].name == 'Shift':
                self.cross.ref_enable(p2)

        elif self.cross.cross_state:
            if event.press_event is None:
                self.cross.moveto(event.pos)
                self.cross.ref_disable()


if __name__ == '__main__':
    # Manual demo: open a phy GUI and show the 3-D scatter view on random data.
    from phy.gui import GUI, create_app, run_app
    create_app()
    gui = GUI(position=(0, 0), size=(600, 400), name='GUI')
    ##############################################
    ### Test scatter_view
    from sklearn.preprocessing import normalize
    n = 1000000
    # n random 3-D points, each row normalized by sklearn's `normalize`
    # (L2 norm by default), i.e. points on the unit sphere.
    fet = np.random.randn(n,3)
    fet = normalize(fet,axis=1)
    # Python 2 print statement: this block does not run under Python 3.
    print fet.shape
    # Random labels in {0, 1, 2}, shape (n, 1).
    clu = np.random.randint(3,size=(n,1))
    scatter_view = view_scatter_3d()
    scatter_view.attach(gui)
    scatter_view.set_data(fet, clu)
    #############################################################################################
    # NOTE(review): the snippet appears truncated below; `run_app` and
    # `Binload` are unused in the visible code and the path is machine-specific.
    from Binload import Binload
    ### Set Parameters ###
    filename  = 'S:/pcie.bin'
예제 #27
0
def test_waveform_view(qtbot, tempdir):
    """Exercise WaveformView selection, toggles, scaling and channel clicks."""
    n_channels = 5

    def get_waveforms(cluster_id):
        # Same fake data for every cluster: 10 waveforms x 20 samples.
        return Bunch(
            data=artificial_waveforms(10, 20, n_channels),
            channel_ids=np.arange(n_channels),
            channel_positions=staggered_positions(n_channels),
        )

    view = WaveformView(waveforms=get_waveforms)
    gui = GUI(config_dir=tempdir)
    gui.show()
    view.attach(gui)
    qtbot.addWidget(gui)

    # Cluster selection, including the empty selection.
    for selection in ([], [0], [0, 2, 3], [0, 2]):
        view.on_select(selection)

    # Each toggle is applied twice so the view returns to its initial state.
    for toggle_name in ('toggle_waveform_overlap', 'toggle_show_labels'):
        for _ in range(2):
            getattr(view, toggle_name)()

    # Box scaling (box_size) and probe scaling (box_pos): each grow/shrink
    # pair must round-trip to the original value.
    round_trips = (
        ('box_size', 'increase', 'decrease'),
        ('box_size', 'widen', 'narrow'),
        ('box_pos', 'extend_horizontally', 'shrink_horizontally'),
        ('box_pos', 'extend_vertically', 'shrink_vertically'),
    )
    for attr, grow, shrink in round_trips:
        before = getattr(view.boxed, attr)
        getattr(view, grow)()
        getattr(view, shrink)()
        ac(getattr(view.boxed, attr), before)

    px, py = view.probe_scaling
    view.probe_scaling = (px, py * 2)
    ac(view.probe_scaling, (px, py * 2))

    bx, by = view.box_scaling
    view.box_scaling = (bx * 2, by)
    ac(view.box_scaling, (bx * 2, by))

    # Simulate channel selection: key "2" held while clicking channel 0.
    clicks = []

    @view.gui.connect_
    def on_channel_click(channel_id=None, button=None, key=None):
        clicks.append((channel_id, button, key))

    view.events.key_press(key=keys.Key("2"))
    view.events.mouse_press(pos=(0.0, 0.0), button=1)
    view.events.key_release(key=keys.Key("2"))

    assert clicks == [(0, 1, 2)]

    gui.close()