コード例 #1
0
ファイル: base.py プロジェクト: GrohLab/phy
    def attach(self, gui):
        """Attach the view to the GUI.

        Perform the following:

        - Add the view to the GUI.
        - Update the view's attribute from the GUI state
        - Add the default view actions (auto_update, screenshot, close)
        - Bind the on_select() method to the select event raised by the supervisor
          (through on_select_threaded(), with this GUI bound as a keyword argument).
        - Save the view state when this view, or the whole GUI, is closed.
        - Emit the `view_attached` event.

        Parameters
        ----------
        gui : GUI
            The GUI instance the view is added to.

        """

        # Add shortcuts only for the first view of any given type.
        shortcuts = self.shortcuts if not gui.list_views(self.__class__) else None

        gui.add_view(self, position=self._default_position)
        self.gui = gui

        # Set the view state.
        self.set_state(gui.state.get_view_state(self))

        # Create the view's action group; `shortcuts` is None for every view of this
        # class except the first one (see above).
        self.actions = Actions(
            gui, name=self.name, view=self,
            default_shortcuts=shortcuts, default_snippets=self.default_snippets)

        # Freeze and unfreeze the view when selecting clusters.
        self.actions.add(
            self.toggle_auto_update, checkable=True, checked=self.auto_update, show_shortcut=False)
        self.actions.add(self.screenshot, show_shortcut=False)
        self.actions.add(self.close, show_shortcut=False)
        self.actions.separator()

        # Route `select` events to on_select_threaded() with this GUI bound. The partial is
        # kept in a local variable so that on_close_view() can disconnect exactly it.
        on_select = partial(self.on_select_threaded, gui=gui)
        connect(on_select, event='select')

        # Save the view state in the GUI state.
        # NOTE(review): the event is presumably selected from the handler name
        # (`on_close_view` -> `close_view`), so the function name must not change.
        @connect
        def on_close_view(view_, gui):
            # Only react when this particular view is the one being closed.
            if view_ != self:
                return
            logger.debug("Close view %s.", self.name)
            self._closed = True
            gui.remove_menu(self.name)
            unconnect(on_select)
            gui.state.update_view_state(self, self.state)
            self.canvas.close()
            # Run a generation-0 garbage collection, presumably to release the closed
            # view's objects promptly.
            gc.collect(0)

        # Also save the view state when this GUI emits its `close` event.
        @connect(sender=gui)
        def on_close(sender):
            gui.state.update_view_state(self, self.state)

        # HACK: Fix bug on macOS where docked OpenGL widgets were not displayed at startup.
        # Presumably AsyncCaller runs the function registered with `.set` after the delay;
        # that function re-docks the widget.
        self._set_floating = AsyncCaller(delay=5)

        @self._set_floating.set
        def _set_floating():
            self.dock.setFloating(False)

        # Notify listeners (e.g. plugins) that the view is now attached to the GUI.
        emit('view_attached', self, gui)
コード例 #2
0
 def create_actions(self, gui):
     """Register the toolbar actions of the GUI.

     Three actions are added, in this order: "Find good channels",
     "Preprocess" and "Spike sort". The `Actions` instance holding them is
     stored in `self.actions`.
     """
     self.actions = Actions(gui)
     # A "Create params.py" action (self.create_params) is currently disabled.
     toolbar_entries = (
         (self.find_good_channels, "Find good channels"),
         (self.preprocess, "Preprocess"),
         (self.spike_sort, "Spike sort"),
     )
     for callback, label in toolbar_entries:
         self.actions.add(callback, name=label, toolbar=True)
コード例 #3
0
ファイル: views.py プロジェクト: mspacek/phy
    def attach(self, gui):
        """Attach the view to the GUI.

        Adds the view to the GUI, restores its state from the GUI state, connects
        `on_select()` to the GUI, creates the view's actions (name and menu both taken
        from the view class name), forwards the view's status messages to the GUI,
        saves the view state when the GUI closes, and finally shows the view.

        Parameters
        ----------
        gui : GUI
            The GUI instance the view is added to.

        """

        # Disable keyboard pan so that we can use arrows as global shortcuts
        # in the GUI.
        self.panzoom.enable_keyboard_pan = False

        gui.add_view(self)
        self.gui = gui

        # Set the view state.
        self.set_state(gui.state.get_view_state(self))

        # Connect self.on_select to the GUI; the event is presumably inferred from the
        # method name (`on_select` -> `select`).
        gui.connect_(self.on_select)
        self.actions = Actions(gui,
                               name=self.__class__.__name__,
                               menu=self.__class__.__name__,
                               default_shortcuts=self.shortcuts)

        # Update the GUI status message when the `self.set_status()` method
        # is called, i.e. when the `status` event is raised by the VisPy
        # view.
        @self.connect
        def on_status(e):
            gui.status_message = e.message

        # Save the view state in the GUI state.
        @gui.connect_
        def on_close():
            gui.state.update_view_state(self, self.state)
            # NOTE: create_gui() already saves the state, but the event
            # is registered *before* we add all views.
            gui.state.save()

        self.show()
コード例 #4
0
        def on_gui_ready(gui):
            """Add the `ExportMeanWaveforms` action (alias `emw`) to the GUI.

            Parameters
            ----------
            gui : GUI
                The GUI that has just been created.

            """
            actions = Actions(gui)

            @actions.add(alias='emw')
            def ExportMeanWaveforms(max_waveforms_per_cluster=1E4,
                                    controller=controller):
                """Save the mean waveform of every cluster to `mean_waveforms.npy`.

                The saved array has shape (n_channels_dat, n_samples_templates, n_clusters)
                and is written in the model's `dir_path`. All channels are used for every
                cluster. Make `max_waveforms_per_cluster` a really big number to average
                over all the waveforms (slow).
                """
                print('Exporting mean waveforms')
                model = controller.model
                cluster_ids = controller.supervisor.clustering.cluster_ids
                mean_waveforms = np.zeros(
                    (model.n_channels_dat,
                     model.n_samples_templates, len(cluster_ids)))
                # Every channel is exported. This does not depend on the cluster, so it is
                # computed once instead of on every iteration.
                channel_ids = np.arange(model.n_channels_dat)
                for i, cluster_id in enumerate(cluster_ids):
                    print('i={0},cluster={1}'.format(i, cluster_id))
                    spike_ids = controller.selector.select_spikes(
                        [cluster_id], max_waveforms_per_cluster,
                        controller.batch_size_waveforms)
                    data = model.get_waveforms(spike_ids, channel_ids)
                    # The assignment target is (n_channels, n_samples), so data.mean(0) is
                    # (n_samples, n_channels); rollaxis moves the channel axis first.
                    mean_waveforms[:, :, i] = np.rollaxis(data.mean(0), 1)
                np.save(
                    op.join(model.dir_path, 'mean_waveforms.npy'),
                    mean_waveforms)
                print('Done exporting mean waveforms')
コード例 #5
0
ファイル: CellTypes.py プロジェクト: LBHB/phy-contrib
        def on_gui_ready(gui, **kwargs):
            """Add the `ExportCellTypes` action and re-export cell types on every save.

            Parameters
            ----------
            gui : GUI
                The GUI that has just been created.
            **kwargs
                Extra event arguments, ignored.

            """
            actions = Actions(gui)

            @actions.add(alias='celltypes')
            def ExportCellTypes(max_waveforms_per_cluster=1E3,
                                controller=controller):
                """Export the cell types of all clusters."""
                export_cell_types(controller, max_waveforms_per_cluster)

            @controller.supervisor.connect
            def on_request_save(*args, controller=controller):
                """Export cell types using the cluster groups that are being saved.

                Two argument layouts are recognised: six or more positional arguments
                (spike_clusters, groups, amplitude, contamination, KS_label, labels, ...)
                and exactly three (spike_clusters, groups, labels). In both, the cluster
                groups are the second argument. Any other layout is reported and the
                export is skipped.
                """
                if len(args) >= 6 or len(args) == 3:
                    groups = args[1]
                else:
                    # Unknown layout: warn and skip instead of stopping in a debugger
                    # (which would block the GUI) and then failing on an unbound `groups`.
                    import logging
                    logging.getLogger(__name__).warning(
                        "on_request_save: unexpected number of arguments (%d); "
                        "skipping cell type export.", len(args))
                    return
                export_cell_types(controller=controller, groups=groups)
コード例 #6
0
    def attach(self, gui, name=None):
        """Attach the view to the GUI.

        Adds the view to the GUI, restores its state from the GUI state, debounces
        cluster selections to `self.on_select()`, creates the view's actions, forwards
        the view's status messages to the GUI, saves the view state when the GUI closes,
        and finally shows the view.

        Parameters
        ----------
        gui : GUI
            The GUI instance the view is added to.
        name : str, optional
            Name of the view's action group; defaults to the view class name. The
            actions menu is always named after the view class.

        """

        # Disable keyboard pan so that we can use arrows as global shortcuts
        # in the GUI.
        self.panzoom.enable_keyboard_pan = False

        gui.add_view(self)
        self.gui = gui

        # Set the view state.
        self.set_state(gui.state.get_view_state(self))

        # Call on_select() asynchronously after a delay, and set a busy
        # cursor.
        self.async_caller = AsyncCaller(delay=self._callback_delay)

        # NOTE(review): the GUI event is presumably selected from the handler name
        # (`on_select` -> `select`), so the function name must not change.
        @gui.connect_
        def on_select(cluster_ids, **kwargs):
            # Call this function after a delay unless there is another
            # cluster selection in the meantime.
            @self.async_caller.set
            def update_view():
                with busy_cursor():
                    self.on_select(cluster_ids, **kwargs)

        self.actions = Actions(gui,
                               name=name or self.__class__.__name__,
                               menu=self.__class__.__name__,
                               default_shortcuts=self.shortcuts)

        # Update the GUI status message when the `self.set_status()` method
        # is called, i.e. when the `status` event is raised by the VisPy
        # view.
        @self.connect
        def on_status(e):
            gui.status_message = e.message

        # Save the view state in the GUI state.
        @gui.connect_
        def on_close():
            gui.state.update_view_state(self, self.state)
            # NOTE: create_gui() already saves the state, but the event
            # is registered *before* we add all views.
            gui.state.save()

        self.show()
コード例 #7
0
ファイル: CopyTraceView.py プロジェクト: LBHB/phy-contrib
        def on_gui_ready(gui, **kwargs):
            """Add a `CopyTraceView` action to the `TraceView` menu.

            The action opens a second TraceView (added as `Trace2View`) on the same
            traces, and copies the label visibility, interval, current time and
            pan/zoom of the existing trace view.

            Parameters
            ----------
            gui : GUI
                The GUI that has just been created.
            **kwargs
                Extra event arguments, ignored.

            """
            actions = Actions(gui)

            @actions.add(menu='TraceView')
            def CopyTraceView():
                """Open a copy of the current trace view."""
                tv = gui.get_view('TraceView')
                if tv is None:
                    # No trace view to copy (presumably get_view() returns None when
                    # no such view exists): do nothing rather than raise AttributeError.
                    return
                tv2 = TraceView(
                    traces=tv.traces,
                    n_channels=tv.n_channels,
                    sample_rate=tv.sample_rate,
                    duration=tv.duration,
                    channel_vertical_order=controller.model.channel_vertical_order,
                )
                gui.add_view(tv2, name='Trace2View')
                # Mirror the display settings of the original view.
                tv2.do_show_labels = tv.do_show_labels
                tv2.set_interval(tv._interval)
                tv2.go_to(tv.time)
                tv2.panzoom.set_pan_zoom(zoom=tv.panzoom._zoom,
                                         pan=tv.panzoom._pan)
コード例 #8
0
ファイル: ExportSNRs.py プロジェクト: LBHB/phy-contrib
        def on_gui_ready(gui):
            """Add the `Export SNRs` action (alias `esnr`) to the `Export` menu.

            Parameters
            ----------
            gui : GUI
                The GUI that has just been created.

            """
            actions = Actions(gui)

            @actions.add(alias='esnr', menu='Export', name='Export SNRs')
            def ExportSNRs(max_waveforms_per_cluster=1E3,
                           controller=controller):
                """Save per-channel signal-to-noise ratios of every cluster to `snrs.npy`.

                The saved array has shape (n_channels_dat, n_clusters); entries that are
                not computed stay NaN. The SNR of a channel is the standard deviation
                (over samples) of the mean waveform divided by the standard deviation of
                the first and last 10 samples of the individual waveforms.

                Make `max_waveforms_per_cluster` a really big number to use all the
                waveforms (slow). The special value 100 uses the controller's own
                `_get_waveforms()` selection and the channels it returns.
                """
                print('Exporting SNRs')
                model = controller.model
                cluster_ids = controller.supervisor.clustering.cluster_ids
                # np.nan (not np.NAN, which was removed in NumPy 2.0).
                snr = np.full((model.n_channels_dat, len(cluster_ids)), np.nan)
                all_channel_ids = np.arange(model.n_channels_dat)
                for i, cluster_id in enumerate(cluster_ids):
                    print('i={0},cluster={1}'.format(i, cluster_id))
                    if max_waveforms_per_cluster == 100:
                        all_data = controller._get_waveforms(int(cluster_id))
                        data = all_data.data
                        channel_ids = all_data.channel_ids
                    else:
                        spike_ids = controller.selector.select_spikes(
                            [cluster_id],
                            max_waveforms_per_cluster,
                            controller.batch_size_waveforms,
                            # subset='random', to get a random subset
                        )
                        # All channels are used.
                        channel_ids = all_channel_ids
                        data = model.get_waveforms(spike_ids, channel_ids)
                    # Noise estimate from the first AND last 10 samples of every waveform
                    # (presumably the edges of the snippet carry little spike signal).
                    # Axis 1 is the sample axis.
                    noise_std = np.concatenate(
                        (data[:, :10, :], data[:, -10:, :]),
                        axis=1).std(axis=(0, 1))
                    sig_std = data.mean(0).std(0)
                    snr[channel_ids, i] = sig_std / noise_std

                np.save(op.join(model.dir_path, 'snrs.npy'), snr)
                print('Done exporting snrs')
コード例 #9
0
class ManualClusteringView(object):
    """Base class for clustering views.

    Typical property objects:

    - `self.canvas`: a `PlotCanvas` instance by default (can also be a `PlotCanvasMpl` instance).
    - `self.default_shortcuts`: a dictionary with the default keyboard shortcuts for the view
    - `self.shortcuts`: a dictionary with the actual keyboard shortcuts for the view (can be passed
      to the view's constructor).
    - `self.state_attrs`: a tuple with all attributes that should be automatically saved in the
      view's global GUI state.
    - `self.local_state_attrs`: like above, but for the local GUI state (dataset-dependent).

    Events raised:

    - `view_attached(view, gui)`: this is the event to connect to if you write a plugin that
      needs to modify a view.
    - `is_busy(view, busy)`: `busy` is True when a view update starts, False when it ends.
    - `toggle_auto_update(view, checked)`

    """
    default_shortcuts = {}
    default_snippets = {}
    auto_update = True  # automatically update the view when the cluster selection changes
    _default_position = None
    plot_canvas_class = PlotCanvas
    ex_status = ''  # extra text appended to the view status by update_status(); the GUI can set it
    max_n_clusters = 0  # By default, show all clusters.

    def __init__(self, shortcuts=None, **kwargs):
        self._lock = None
        self._closed = False
        self.cluster_ids = ()

        # Load default shortcuts, and override with any user shortcuts.
        self.shortcuts = self.default_shortcuts.copy()
        self.shortcuts.update(shortcuts or {})

        # Whether to enable threading. Disabled in tests.
        self._enable_threading = kwargs.get('enable_threading', True)

        # List of attributes to save in the GUI view state.
        self.state_attrs = ('auto_update', )

        # List of attributes to save in the local GUI state as well.
        self.local_state_attrs = ()

        # Attached GUI.
        self.gui = None

        self.canvas = self.plot_canvas_class()

        # Attach the Qt events to this class, so that derived class
        # can override on_mouse_click() and so on.
        self.canvas.attach_events(self)

    # -------------------------------------------------------------------------
    # Internal methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Compute the data bounds enclosing all the given cluster bunches."""
        # Combine the bounds of every bunch into a single bounds tuple with extend_bounds().
        return extend_bounds([_get_bunch_bounds(bunch) for bunch in bunchs])

    def _plot_cluster(self, bunch):
        """Plot one cluster.

        To override.

        """
        pass

    def _update_axes(self):
        """Update the axes."""
        self.canvas.axes.reset_data_bounds(self.data_bounds)

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids.

        To override.

        """
        return

    def plot(self, **kwargs):  # pragma: no cover
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data()
        self.data_bounds = self._get_data_bounds(bunchs)
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self._update_axes()
        self.canvas.update()

    # -------------------------------------------------------------------------
    # Main public methods
    # -------------------------------------------------------------------------

    def on_select(self, cluster_ids=None, **kwargs):
        """Callback function when clusters are selected. May be overridden."""
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        self.plot(**kwargs)

    def on_select_threaded(self, sender, cluster_ids, gui=None, **kwargs):
        """Handle a `select` event by updating the view off the GUI thread.

        The event is ignored when auto-update is off, the view is closed, the sender is a
        ClusterView or SimilarityView, the selection is empty or larger than
        `max_n_clusters`, or another update is already in progress. Otherwise `on_select()`
        runs in the Qt thread pool (or synchronously when the GUI disables threading) and
        the recorded OpenGL updates are applied afterwards on the GUI thread.
        """
        # Decide whether the view should react to the select event or not.
        if not self.auto_update or self._closed:
            return
        # Only the Supervisor and some specific views can trigger a proper select event.
        if sender.__class__.__name__ in ('ClusterView', 'SimilarityView'):
            return
        assert isinstance(cluster_ids, list)
        if not cluster_ids:
            return
        # Maximum number of clusters that can be displayed in the view, for performance reasons.
        if self.max_n_clusters and len(cluster_ids) > self.max_n_clusters:
            return

        # The lock is used so that two different background threads do not access the same
        # view simultaneously, which can lead to conflicts, errors in the plotting code,
        # and QTimer thread exceptions that lead to frozen OpenGL views.
        if self._lock:
            return
        self._lock = True

        # The view update occurs in a thread in order not to block the main GUI thread.
        # A complication is that OpenGL updates should only occur in the main GUI thread,
        # whereas the computation of the data buffers to upload to the GPU should happen
        # in a thread. Finally, the select events are throttled (more precisely, debounced)
        # to avoid clogging the GUI when many clusters are successively selected, but this
        # is implemented at the level of the table widget, not here.

        # This function executes in the Qt thread pool.
        def _worker():  # pragma: no cover
            self.on_select(cluster_ids=cluster_ids, **kwargs)

        # We launch this function in the thread pool.
        worker = Worker(_worker)

        # Once the worker has finished in the thread, the finished signal is raised,
        # and the callback function below runs on the main GUI thread.
        # All OpenGL updates triggered in the worker (background thread) were recorded
        # instead of being immediately executed (which would have caused errors because
        # OpenGL updates should not be executed from a background thread).
        # Once these updates have been collected in the right order, we execute all of
        # them here, in the main GUI thread.
        @worker.signals.finished.connect
        def finished():
            # HACK: work-around for https://github.com/cortex-lab/phy/issues/1016
            try:
                self
            except NameError as e:  # pragma: no cover
                logger.warning(str(e))
                return
            # When the task has finished in the thread pool, we recover all program
            # updates of the view, and we execute them on the GPU.
            if isinstance(self.canvas, PlotCanvas):
                self.canvas.set_lazy(False)
                # We go through all collected OpenGL updates.
                for program, name, data in self.canvas.iter_update_queue():
                    # We update data buffers in OpenGL programs.
                    program[name] = data
            # Finally, we update the canvas.
            self.canvas.update()
            emit('is_busy', self, False)
            self._lock = None
            self.update_status()

        # Start the task on the thread pool, and let the OpenGL canvas know that we're
        # starting to record all OpenGL calls instead of executing them immediately.
        # This is what we call the "lazy" mode.
        emit('is_busy', self, True)
        if getattr(gui, '_enable_threading', True):
            # This is only for OpenGL views.
            self.canvas.set_lazy(True)
            thread_pool().start(worker)
        else:  # pragma: no cover
            # This is for OpenGL views, without threading.
            try:
                worker.run()
            finally:
                # Release the lock even if the update raised; otherwise the view would
                # silently ignore every subsequent selection.
                self._lock = None

    def on_cluster(self, up):
        """Callback function when a clustering action occurs. May be overridden.

        Note: this method is called *before* on_select() so as to give a chance to the view
        to update itself before the selection of the new clusters.

        This method is mostly only useful to views that show all clusters and not just the
        selected clusters (template view, raster view).

        """

    def attach(self, gui):
        """Attach the view to the GUI.

        Perform the following:

        - Add the view to the GUI.
        - Update the view's attribute from the GUI state
        - Add the default view actions (auto_update, screenshot, close)
        - Bind the on_select() method to the select event raised by the supervisor.
        - Save the view state when this view, or the whole GUI, is closed.
        - Emit the `view_attached` event.

        """

        # Add shortcuts only for the first view of any given type.
        shortcuts = self.shortcuts if not gui.list_views(
            self.__class__) else None

        gui.add_view(self, position=self._default_position)
        self.gui = gui

        # Set the view state.
        self.set_state(gui.state.get_view_state(self))

        self.actions = Actions(gui,
                               name=self.name,
                               view=self,
                               default_shortcuts=shortcuts,
                               default_snippets=self.default_snippets)

        # Freeze and unfreeze the view when selecting clusters.
        self.actions.add(self.toggle_auto_update,
                         checkable=True,
                         checked=self.auto_update,
                         show_shortcut=False)
        self.actions.add(self.screenshot, show_shortcut=False)
        self.actions.add(self.close, show_shortcut=False)
        self.actions.separator()

        # Kept in a local variable so that on_close_view() can disconnect exactly it.
        on_select = partial(self.on_select_threaded, gui=gui)
        connect(on_select, event='select')

        # Save the view state in the GUI state.
        @connect
        def on_close_view(view_, gui):
            # Only tear down when *this* view is closed; without this guard, closing any
            # other view would close this view's canvas and disconnect its selection.
            if view_ != self:
                return
            logger.debug("Close view %s.", self.name)
            self._closed = True
            gui.remove_menu(self.name)
            unconnect(on_select)
            gui.state.update_view_state(self, self.state)
            self.canvas.close()
            gc.collect(0)

        @connect(sender=gui)
        def on_close(sender):
            gui.state.update_view_state(self, self.state)

        # HACK: Fix bug on macOS where docked OpenGL widgets were not displayed at startup.
        self._set_floating = AsyncCaller(delay=5)

        @self._set_floating.set
        def _set_floating():
            self.dock.setFloating(False)

        emit('view_attached', self, gui)

    @property
    def status(self):
        """To be overridden."""
        return ''

    def update_status(self):
        """Show `status` followed by `ex_status` in the view's dock, if it has one."""
        if hasattr(self, 'dock'):
            self.dock.set_status('%s %s' % (self.status, self.ex_status))

    # -------------------------------------------------------------------------
    # Misc public methods
    # -------------------------------------------------------------------------

    def toggle_auto_update(self, checked):
        """When on, the view is automatically updated when the cluster selection changes."""
        logger.debug("%s auto update for %s.",
                     'Enable' if checked else 'Disable', self.name)
        self.auto_update = checked
        emit('toggle_auto_update', self, checked)

    def screenshot(self, dir=None):
        """Save a PNG screenshot of the view into a given directory. By default, the screenshots
        are saved in `~/.phy/screenshots/`."""
        path = screenshot_default_path(self, dir=dir)
        return screenshot(self.canvas, path=path)

    @property
    def state(self):
        """View state, a Bunch instance automatically persisted in the GUI state when the
        GUI is closed. To be overridden."""
        attrs = set(self.state_attrs + self.local_state_attrs)
        return Bunch({key: getattr(self, key, None) for key in attrs})

    def set_state(self, state):
        """Set the view state.

        The passed object is the persisted `self.state` bunch.

        May be overridden.

        """
        logger.debug("Set state for %s.",
                     getattr(self, 'name', self.__class__.__name__))
        for k, v in state.items():
            setattr(self, k, v)

    def show(self):
        """Show the underlying canvas."""
        return self.canvas.show()

    def close(self):
        """Close the view.

        When the view is docked, closing is delegated to the dock widget. Otherwise the
        canvas is closed and all event callbacks of this view are disconnected.
        """
        if hasattr(self, 'dock'):
            return self.dock.close()
        self.canvas.close()
        self._closed = True
        unconnect(self)
        gc.collect(0)
コード例 #10
0
class KilosortGUICreator(object):
    """Create a GUI to run Kilosort on a raw data file.

    Extra keyword arguments are stored as instance attributes. The methods read
    `sample_rate`, `n_channels_dat`, `dtype`, `offset` and `order` (for non-.cbin
    files) and `probe` and `dir_path` (for running the sorter) from them.
    """
    def __init__(self, dat_path, **kwargs):
        self.dat_path = Path(dat_path).resolve()
        self.gui_name = 'PythonKilosortGUI'
        self.__dict__.update(kwargs)
        self.load_data()

    def load_data(self):
        """Open the raw data and derive the recording parameters.

        `.cbin` files provide their own sample rate and channel count. Other files are
        memory-mapped using the `sample_rate`, `n_channels_dat`, `dtype`, `offset` and
        optional `order` attributes.

        Sets `self.data`, `self.sample_rate` (float), `self.n_channels_dat`,
        `self.offset`, `self.dtype` and `self.duration` (in seconds).
        """
        # TODO: use EphysTraces
        dat_path = self.dat_path
        if dat_path.suffix == '.cbin':
            data = load_raw_data(path=dat_path)
            sample_rate = data.sample_rate
            n_channels_dat = data.shape[1]
            # These must be bound for the params.py attributes below (they used to be
            # left undefined here, raising UnboundLocalError). No user-supplied offset
            # applies to a .cbin file.
            # NOTE(review): assumes the reader exposes `.dtype` — confirm.
            offset = 0
            dtype = getattr(data, 'dtype', None)
        else:
            sample_rate = float(self.sample_rate)
            assert sample_rate > 0.

            n_channels_dat = int(self.n_channels_dat)

            dtype = np.dtype(self.dtype)
            offset = int(self.offset or 0)
            order = getattr(self, 'order', None)

            # Memmap the raw data file.
            data = load_raw_data(
                path=dat_path,
                n_channels_dat=n_channels_dat,
                dtype=dtype,
                offset=offset,
                order=order,
            )

        # Parameters for the creation of params.py
        self.n_channels_dat = n_channels_dat
        self.offset = offset
        self.dtype = dtype
        # create_trace_view() reads self.sample_rate; store the normalized float so that
        # it is also defined for .cbin files (where no sample_rate is passed in).
        self.sample_rate = sample_rate

        self.data = data
        self.duration = self.data.shape[0] / sample_rate

    # def create_params(self):
    # paramspy = DEFAULT_PARAMS_PY.format(
    #     dat_path='["%s"]' % str(self.dat_path),
    #     n_channels_dat=self.n_channels_dat,
    #     offset=self.offset,
    #     dtype=self.dtype,
    #     sample_rate=self.sample_rate,
    # )

    def _run(self, stop_after=None):
        """Call `run()` on the data file with the probe and output directory.

        `stop_after` names the step after which to stop (None runs everything).
        """
        # TODO: test
        run(self.dat_path,
            self.probe,
            dir_path=self.dir_path,
            stop_after=stop_after)

    def find_good_channels(self):
        """Run up to the `good_channels` step."""
        self._run('good_channels')
        # TODO: update probe view

    def preprocess(self):
        """Run up to the `preprocess` step."""
        self._run('preprocess')
        # TODO: update trace view

    def spike_sort(self):
        """Run all the steps."""
        self._run()
        # TODO: create custom logging handler that redirectors to ipython view
        # view.append_stream(...)

    def create_actions(self, gui):
        """Create the actions."""
        self.actions = Actions(gui)
        # self.actions.add(self.create_params, name="Create params.py", toolbar=True)
        self.actions.add(self.find_good_channels,
                         name="Find good channels",
                         toolbar=True)
        self.actions.add(self.preprocess, name="Preprocess", toolbar=True)
        self.actions.add(self.spike_sort, name="Spike sort", toolbar=True)

    def create_ipython_view(self, gui):
        """Add the IPython view, with `gui`, `creator` and `data` injected."""
        view = IPythonView()
        view.attach(gui)

        view.inject(gui=gui, creator=self, data=self.data)

        return view

    def create_trace_view(self, gui):
        """Add the trace view, with a toolbar slider to navigate in time."""
        gui._toolbar.addWidget(QLabel("Time selection: "))
        time_slider = QSlider(Qt.Horizontal, gui)
        # The slider position is a percentage of the recording duration.
        time_slider.setRange(0, 100)
        time_slider.setTracking(False)
        gui._toolbar.addWidget(time_slider)
        self.time_slider = time_slider

        def _get_traces(interval):
            return Bunch(data=select_traces(
                self.data, interval, sample_rate=self.sample_rate))

        view = TraceImageView(
            traces=_get_traces,
            n_channels=self.n_channels_dat,
            sample_rate=self.sample_rate,
            duration=self.duration,
            enable_threading=False,
        )
        view.attach(gui)

        self.move_time_slider_to(view.time)

        @time_slider.valueChanged.connect
        def time_slider_changed():
            view.go_to(
                float(time_slider.sliderPosition()) / 100 * self.duration)

        @connect(sender=view)
        def on_time_range_selected(sender, interval):
            self.move_time_slider_to(.5 * (interval[0] + interval[1]))

        return view

    def move_time_slider_to(self, time):
        """Move the time slider to `time` (seconds), as a percentage of the duration."""
        if not self.duration:
            # Empty recording: there is no meaningful position, and dividing by zero
            # would raise.
            return
        self.time_slider.setSliderPosition(int(time / self.duration * 100))

    def create_probe_view(self, gui):
        """Add the view that shows the probe layout."""
        channel_positions = linear_positions(self.n_channels_dat)
        view = ProbeView(channel_positions)
        view.attach(gui)
        # TODO: update positions dynamically when the probe view changes
        return view

    def create_params_widget(self, gui):
        """Create the widget that allows to enter parameters for KS2."""
        widget = KeyValueWidget(gui)
        for name, default in default_params.items():
            # HACK: None default params in KS2 are floats
            vtype = 'float' if default is None else None
            widget.add_pair(name, default, vtype=vtype)
        # Time interval (TODO: take it into account with EphysTraces).
        widget.add_pair('time interval', [0.0, self.duration])
        widget.add_pair(
            'custom probe',
            dedent('''
        # Python code that returns a probe variable which is a Bunch instance,
        # with the following variables: NchanTOT, chanMap, xc, yc, kcoords.
        ''').strip(), 'multiline')

        # Make the key/value form scrollable.
        scroll = QScrollArea()
        scroll.setWidget(widget)
        scroll.setWidgetResizable(True)

        # Wrap the scroll area in a Parameters widget, which is what gets docked.
        widget = Parameters(gui)
        layout = QVBoxLayout(widget)
        layout.addWidget(scroll)

        gui.add_view(widget)
        return widget

    def create_gui(self):
        """Create the spike sorting GUI."""

        gui = GUI(name=self.gui_name,
                  subtitle=self.dat_path.resolve(),
                  enable_threading=False)
        gui.has_save_action = False
        gui.set_default_actions()
        self.create_actions(gui)
        self.create_params_widget(gui)
        self.create_ipython_view(gui)
        self.create_trace_view(gui)
        self.create_probe_view(gui)

        return gui
コード例 #11
0
    def attach(self, gui):
        """Attach the view to the GUI.

        Perform the following:

        - Add the view to the GUI.
        - Update the view's attribute from the GUI state
        - Add the default view actions (auto_update, screenshot)
        - Bind the on_select() method to the select event raised by the supervisor.
          This runs on a background thread not to block the GUI thread.

        """

        # Add shortcuts only for the first view of any given type.
        shortcuts = self.shortcuts if not gui.list_views(self.__class__) else None

        gui.add_view(self, position=self._default_position)
        self.gui = gui

        # Set the view state.
        self.set_state(gui.state.get_view_state(self))

        self.actions = Actions(
            gui, name=self.name, menu='&View', submenu=self.name,
            default_shortcuts=shortcuts, default_snippets=self.default_snippets)

        # Freeze and unfreeze the view when selecting clusters.
        self.actions.add(
            self.toggle_auto_update, checkable=True, checked=self.auto_update, show_shortcut=False)
        self.actions.add(self.screenshot, show_shortcut=False)
        self.actions.separator()

        emit('view_actions_created', self)

        @connect
        def on_select(sender, cluster_ids, **kwargs):
            """React to a `select` event by refreshing this view with the selected clusters.

            The event is ignored when auto-update is off, when the sender is not the
            Supervisor, when `cluster_ids` is empty, or while a previous update of this
            view is still running (``self._lock`` is set). Otherwise ``self.on_select`` runs
            either in the Qt thread pool (recording OpenGL updates lazily and replaying them
            in ``finished`` on the GUI thread) or synchronously when the GUI has
            ``_enable_threading`` set to False.

            NOTE(review): ``self._lock`` is only cleared in ``finished()`` or right after
            ``worker.run()``. If ``self.on_select`` raises and Worker does not emit
            ``finished`` (or ``worker.run()`` propagates the exception), the view stays
            locked and ignores all further selections. Confirm against Worker's error
            handling.
            """
            # Decide whether the view should react to the select event or not.
            if not self.auto_update:
                return
            if sender.__class__.__name__ != 'Supervisor':
                return
            assert isinstance(cluster_ids, list)
            if not cluster_ids:
                return

            # The lock is used so that two different background threads do not access the same
            # view simultaneously, which can lead to conflicts, errors in the plotting code,
            # and QTimer thread exceptions that lead to frozen OpenGL views.
            if self._lock:
                return
            self._lock = True

            # The view update occurs in a thread in order not to block the main GUI thread.
            # A complication is that OpenGL updates should only occur in the main GUI thread,
            # whereas the computation of the data buffers to upload to the GPU should happen
            # in a thread. Finally, the select events are throttled (more precisely, debounced)
            # to avoid clogging the GUI when many clusters are successively selected, but this
            # is implemented at the level of the table widget, not here.

            # This function executes in the Qt thread pool.
            def _worker():  # pragma: no cover
                self.on_select(cluster_ids=cluster_ids, **kwargs)

            # We launch this function in the thread pool.
            worker = Worker(_worker)

            # Once the worker has finished in the thread, the finished signal is raised,
            # and the callback function below runs on the main GUI thread.
            # All OpenGL updates triggered in the worker (background thread) were recorded
            # instead of being immediately executed (which would have caused errors because
            # OpenGL updates should not be executed from a background thread).
            # Once these updates have been collected in the right order, we execute all of
            # them here, in the main GUI thread.
            @worker.signals.finished.connect
            def finished():
                # When the task has finished in the thread pool, we recover all program
                # updates of the view, and we execute them on the GPU.
                if isinstance(self.canvas, PlotCanvas):
                    self.canvas.set_lazy(False)
                    # We go through all collected OpenGL updates.
                    for program, name, data in self.canvas.iter_update_queue():
                        # We update data buffers in OpenGL programs.
                        program[name] = data
                # Finally, we update the canvas.
                self.canvas.update()
                emit('is_busy', self, False)
                # Release the lock so the next select event is processed.
                self._lock = None

            # Start the task on the thread pool, and let the OpenGL canvas know that we're
            # starting to record all OpenGL calls instead of executing them immediately.
            # This is what we call the "lazy" mode.
            emit('is_busy', self, True)
            if getattr(gui, '_enable_threading', True):
                # This is only for OpenGL views.
                self.canvas.set_lazy(True)
                thread_pool().start(worker)
            else:
                # This is for OpenGL views, without threading: run synchronously on the
                # current thread, then release the lock.
                worker.run()
                self._lock = None

        # Update the GUI status message when the `self.set_status()` method
        # is called, i.e. when the `status` event is raised by the view.
        @connect(sender=self)  # pragma: no cover
        def on_status(sender=None, e=None):
            """Copy the message of this view's `status` event to ``gui.status_message``."""
            gui.status_message = e.message

        # When this view is closed: save its state in the GUI state and release its resources.
        @connect(sender=gui)
        def on_close_view(sender, view):
            """Clean up after ``view`` is closed, if ``view`` is this view.

            Removes the view's menu, disconnects its ``on_select`` handler, saves its
            state into the GUI state, closes its canvas and runs a generation-0 garbage
            collection.
            """
            if view != self:
                # Only react to the closing of this particular view.
                return
            logger.debug("Close view %s.", self.name)
            gui.remove_menu(self.name)
            # A closed view must no longer react to cluster selection events.
            unconnect(on_select)
            gui.state.update_view_state(self, self.state)
            self.canvas.close()
            # Presumably frees the closed view's objects early; collects generation 0 only.
            gc.collect(0)

        @connect(sender=gui)
        def on_close(sender):
            """On the GUI's `close` event, save this view's state into the GUI state."""
            gui.state.update_view_state(self, self.state)

        # HACK: Fix bug on macOS where docked OpenGL widgets were not displayed at startup.
        self._set_floating = AsyncCaller(delay=1)
        @self._set_floating.set
        def _set_floating():
            """Re-dock the view's dock widget (``setFloating(False)``).

            Registered as the callback of ``self._set_floating``; presumably invoked once
            after the AsyncCaller delay (TODO confirm against AsyncCaller).
            """
            self.dock_widget.setFloating(False)
コード例 #12
0
        def on_gui_ready(gui):
            """Attach the cluster-location plot to ``gui``.

            Registers two things on the GUI:

            - ``on_select``: once the plot has been drawn, colour the labels of the
              selected clusters red and restore every other label's colour.
            - the ``PlotClusterLocations`` action (alias ``pcl``): compute an
              amplitude-weighted centre of mass of every cluster over the channel
              positions, draw it in a separate matplotlib figure (marker size scaled by
              the cluster's waveform height, colours by cluster group) and save the
              centres to ``cluster_centers_of_mass.npy`` in the model directory.
            """
            # (marker colour, label colour) for each known cluster group; '' = unsorted.
            group_colors = {
                '': ((1, 1, 1), (1, 1, 1)),
                'good': ((0.5255, 0.8196, 0.4275), (1, 1, 1)),
                'mua': ((0, .7333, 1), (1, 1, 1)),
                'noise': ((.5, .5, .5), (.5, .5, .5)),
            }

            cluster_ids = controller.supervisor.clustering.cluster_ids
            self.text_handles = [None] * len(cluster_ids)
            self.color = [None] * len(cluster_ids)
            self.text_color = [None] * len(cluster_ids)
            # Cluster ids in the same order as self.text_handles. Refreshed on every redraw
            # so on_select stays consistent after merges/splits change the cluster list.
            self.plotted_cluster_ids = np.asarray(cluster_ids).tolist()
            self.fig = [None]
            self.drawn = False

            @gui.connect_
            def on_select(clusters, controller=controller, **kwargs):
                """Highlight the labels of the selected ``clusters`` in red."""
                if not self.drawn:
                    return
                # atleast_1d accepts a list, an array or a single id alike.
                selected = set(np.atleast_1d(clusters).tolist())
                for cid, handle, color in zip(self.plotted_cluster_ids,
                                              self.text_handles, self.text_color):
                    handle.set_color('r' if cid in selected else color)
                self.fig.canvas.draw()
                self.fig.show()

            actions = Actions(gui)

            @actions.add(alias='pcl')  #shortcut='ctrl+p',
            def PlotClusterLocations(controller=controller):
                """Compute, plot and save the centre of mass of every cluster.

                Raises:
                    RuntimeError: if a cluster belongs to a group other than unsorted,
                        'good', 'mua' or 'noise'.
                """
                cluster_ids = controller.supervisor.clustering.cluster_ids
                n_clusters = len(cluster_ids)
                if n_clusters == 0:
                    print('No clusters to plot')
                    return
                print('Plotting Cluster Locations')
                self.drawn = True
                self.text_handles = [None] * n_clusters
                self.color = [None] * n_clusters
                self.text_color = [None] * n_clusters
                self.plotted_cluster_ids = np.asarray(cluster_ids).tolist()
                height = np.zeros(n_clusters)
                center_of_mass = np.zeros((n_clusters, 2))
                cluster_groups = controller.supervisor.cluster_groups
                channel_positions_t = controller.model.channel_positions.T

                for i, cid in enumerate(cluster_ids):
                    data = controller._get_waveforms(int(cid))
                    mean_waveform = data.data.mean(0)
                    height[i] = abs(mean_waveform.min())
                    # Negated per-channel minimum of the mean waveform, used as weights;
                    # negative weights are clipped to zero.
                    peaks = -mean_waveform.min(0)
                    mv = np.zeros(controller.model.n_channels_dat)
                    mv[data.channel_ids] = peaks
                    mv[mv < 0] = 0
                    if not np.any(mv) or not np.any(peaks > mean_waveform[:2].mean(0)):
                        # Quick fix for small-amplitude (usually noise) clusters: when no
                        # channel has a positive weight, or no peak exceeds the mean of the
                        # first two samples, give the best channel a weight of 1.
                        mv[controller.get_best_channel(cid)] = 1
                    mv = mv / mv.sum()
                    center_of_mass[i, :] = (mv * channel_positions_t).sum(1)

                    if controller.model.n_channels == 128:
                        # LAS hack to make 128D cluster locations look better: shift the
                        # location by a fixed (x, y) offset depending on its x range.
                        x = center_of_mass[i, 0]
                        if 200 < x < 500:
                            center_of_mass[i] -= (200, 1000)
                        elif 600 < x < 800:
                            center_of_mass[i] -= (400, 2000)
                        elif x > 800:
                            center_of_mass[i] -= (600, 3000)

                    group = cluster_groups[cid] if cid in cluster_groups.keys() else ''
                    if group not in group_colors:
                        raise RuntimeError(
                            'Cluster group ({}) of cluster {} is unknown.'.format(
                                group, cid))
                    self.color[i], self.text_color[i] = group_colors[group]

                py.rc('xtick', color='w')
                py.rc('ytick', color='w')
                py.rc('axes', edgecolor='w')
                self.fig = py.figure()
                ax = self.fig.add_axes([0.15, 0.02, 0.83, 0.975])
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                mngr = py.get_current_fig_manager()
                mngr.window.setGeometry(1700, 20, 220, 1180)
                self.fig.patch.set_facecolor('k')

                # Marker size (scatter `s`) scales with waveform height; the largest -> 200.
                ax.scatter(center_of_mass[:, 0],
                           center_of_mass[:, 1],
                           height / height.max() * 200,
                           facecolors='none',
                           edgecolors=self.color)
                for i, cid in enumerate(cluster_ids):
                    self.text_handles[i] = ax.annotate(
                        cid,
                        (center_of_mass[i, 0], center_of_mass[i, 1]),
                        color=self.text_color[i],
                        fontsize=10)

                # Axes.set_axis_bgcolor was removed in matplotlib 2.2; set_facecolor
                # replaces it (available since 2.0). Keep a fallback for older versions.
                if hasattr(ax, 'set_facecolor'):
                    ax.set_facecolor('k')
                else:
                    ax.set_axis_bgcolor('k')
                # Pad the axis limits by 10% of the extreme coordinates.
                mins = center_of_mass.min(0) - abs(center_of_mass.min(0) * .1)
                maxs = center_of_mass.max(0) + abs(center_of_mass.max(0) * .1)
                ax.axis((mins[0], maxs[0], mins[1], maxs[1]))
                self.fig.show()
                save_path = op.join(controller.model.dir_path,
                                    'cluster_centers_of_mass.npy')
                np.save(save_path, center_of_mass)
                print(
                    'Cluster centers of mass exported to {}'.format(save_path))
コード例 #13
0
            def MergeRuns(controller=controller, plugin=plugin):
                if True:
                    path2 = QtGui.QFileDialog.getExistingDirectory(
                        None,
                        "Select the results folder for the sort to be merged",
                        op.dirname(
                            op.dirname(controller.model.dir_path)
                        ),  #two folders up from the current phy's path
                        QtGui.QFileDialog.ShowDirsOnly)
                else:
                    path2 = '/home/luke/KiloSort_tmp/BOL005c_9_96clusts/results'
                params_path = op.join(path2, 'params.py')
                params = _read_python(params_path)
                params['dtype'] = np.dtype(params['dtype'])
                params['path'] = path2
                if op.realpath(params['dat_path']) != params['dat_path']:
                    params['dat_path'] = op.join(path2, params['dat_path'])
                print('Loading {}'.format(path2))
                controller2 = TemplateController(**params)
                #controller2.gui_name = 'TemplateGUI2'
                gui2 = controller2.create_gui()
                gui2.show()

                #                @gui2.connect_
                #                def on_select(clusters,controller2=controller2):
                #                    controller.supervisor.select(clusters)

                #create mean_waveforms for each controller (run)
                print('computing mean waveforms for master run...')
                controller.mean_waveforms = create_mean_waveforms(
                    controller, max_waveforms_per_cluster=100)
                print('computing mean waveforms for slave run...')
                controller2.mean_waveforms = create_mean_waveforms(
                    controller2, max_waveforms_per_cluster=100)

                groups = {
                    c: controller.supervisor.cluster_meta.get('group', c)
                    or 'unsorted'
                    for c in controller.supervisor.clustering.cluster_ids
                }
                groups2 = {
                    c: controller2.supervisor.cluster_meta.get('group', c)
                    or 'unsorted'
                    for c in controller2.supervisor.clustering.cluster_ids
                }
                su_inds = np.nonzero([
                    controller.supervisor.cluster_meta.get('group',
                                                           c) == 'good'
                    for c in controller.supervisor.clustering.cluster_ids
                ])[0]
                mu_inds = np.nonzero([
                    controller.supervisor.cluster_meta.get('group', c) == 'mua'
                    for c in controller.supervisor.clustering.cluster_ids
                ])[0]
                su_best_channels = np.array([
                    controller.get_best_channel(c) for c in
                    controller.supervisor.clustering.cluster_ids[su_inds]
                ])
                mu_best_channels = np.array([
                    controller.get_best_channel(c) for c in
                    controller.supervisor.clustering.cluster_ids[mu_inds]
                ])
                su_order = np.argsort(su_best_channels, kind='mergesort')
                mu_order = np.argsort(mu_best_channels, kind='mergesort')
                m_inds = np.concatenate((su_inds[su_order], mu_inds[mu_order]))

                filename = op.join(controller.model.dir_path,
                                   'cluster_names.ts')
                if not op.exists(filename):
                    best_channels = np.concatenate(
                        (su_best_channels[su_order],
                         mu_best_channels[mu_order]))
                    unit_type = np.concatenate(
                        (np.ones(len(su_order)), 2 * np.ones(len(mu_order))))
                    unit_number = np.zeros(len(best_channels))
                    for chan in np.unique(best_channels):
                        matched_clusts = best_channels == chan
                        unit_number[matched_clusts] = np.arange(
                            sum(matched_clusts)) + 1
                else:
                    print('{} exists, loading'.format(filename))
                    unit_types, channels, unit_numbers = load_metadata(
                        filename, controller.supervisor.clustering.cluster_ids)
                    best_channels = channels[m_inds]
                    unit_number = unit_numbers[m_inds]
                    unit_type = unit_types[m_inds]
                    unit_type_current = np.concatenate(
                        (np.ones(len(su_order)), 2 * np.ones(len(mu_order))))
                    if ~np.all(unit_type == unit_type_current):
                        raise RuntimeError(
                            'For the master phy, the unit types saved in "cluster_names.ts"'
                            'do not match those save in "cluster_groups.tsv" This likely means work was done on '
                            'this phy after merging with a previous master. Not sure how to deal with this!'
                        )
                    #re-sort to make unit numbers in order
                    # assuming unit_type is already sorted (which it should be...)
                    nsu = np.sum(unit_type == 1)
                    su_order = np.lexsort(
                        (unit_number[:nsu], best_channels[:nsu]))
                    mu_order = np.lexsort(
                        (unit_number[nsu:], best_channels[nsu:]))
                    m_inds[:nsu] = m_inds[su_order]
                    m_inds[nsu:] = m_inds[mu_order + nsu]
                    best_channels = channels[m_inds]
                    unit_number = unit_numbers[m_inds]
                    unit_type = unit_types[m_inds]

                dists = calc_dists(controller, controller2, m_inds)

                so = np.argsort(dists, 0, kind='mergesort')
                matchi = so[0, :]  #best match index to master for each slave
                sortrows = np.argsort(
                    matchi, kind='mergesort')  #sort index for best match

                def handle_item_clicked(item,
                                        controller=controller,
                                        controller2=controller2,
                                        plugin=plugin):
                    """Select, in both GUIs, the clusters behind the selected table cells.

                    Columns are master units (``controller``, via ``plugin.m_inds``) and rows
                    are slave clusters (``controller2``, in ``plugin.sortrows`` order).
                    Placeholder master columns (``m_inds`` of -1 for "None", -2 for "Noise")
                    have no cluster and are left out of the master selection.
                    """
                    indexes = plugin.table.selectedIndexes()
                    if not indexes:
                        # Nothing selected (e.g. a ctrl-click deselected the last cell).
                        return
                    # dtype=int so these arrays are always valid as index arrays.
                    row = np.array([cell.row() for cell in indexes], dtype=int)
                    column = np.array([cell.column() for cell in indexes], dtype=int)
                    print("Row {} and Column {} was clicked".format(row, column))
                    # Drop placeholder columns before indexing cluster_ids: a negative m_ind
                    # would silently pick an unrelated cluster from the end of the array.
                    column = column[~np.isin(plugin.m_inds[column], (-1, -2))]
                    master_ids = controller.supervisor.clustering.cluster_ids[
                        plugin.m_inds[column]]
                    slave_ids = controller2.supervisor.clustering.cluster_ids[
                        plugin.sortrows[row]]
                    print("M {} S {} ".format(master_ids, slave_ids))
                    if len(column) > 0:
                        controller.supervisor.select(master_ids.tolist())
                    controller2.supervisor.select(slave_ids.tolist())
                    #print("Row %d and Column %d was clicked" % (row, column))

                def create_table(controller, controller2, plugin):
                    """(Re)fill ``plugin.table`` with the master/slave distance matrix.

                    Columns are master units (``plugin.m_inds``); rows are slave clusters in
                    ``plugin.sortrows`` order. Each cell shows the distance as a percentage
                    of the maximum distance, coloured with the reversed viridis colormap.
                    Negative distances are shown as grey "N/A". The current master of each
                    slave (``plugin.matchi``) is written in red. Column headers show
                    ``<cluster>\\n<best channel>-<unit number>``; row headers show
                    ``<cluster>-<snr>``.
                    """
                    dists = plugin.dists
                    plugin.table.setRowCount(len(plugin.matchi))
                    plugin.table.setColumnCount(len(plugin.m_inds))

                    # set data
                    dists_txt = np.round(dists / dists.max() * 100)
                    normal = plt.Normalize(dists[dists != -1].min() - 1,
                                           dists.max() + 1)
                    # Qt's QColor wants ints; numpy floats are rejected on Python >= 3.10.
                    colors = (plt.cm.viridis_r(normal(dists)) * 255).astype(int)
                    for col in range(len(plugin.m_inds)):
                        for row in range(len(plugin.matchi)):
                            slave = plugin.sortrows[row]
                            if dists[col, slave] < 0:
                                item = QtGui.QTableWidgetItem('N/A')
                                item.setBackground(QtGui.QColor(127, 127, 127))
                            else:
                                item = QtGui.QTableWidgetItem('{:.0f}'.format(
                                    dists_txt[col, slave]))
                                item.setBackground(
                                    QtGui.QColor(*colors[col, slave, :3].tolist()))
                            if plugin.matchi[slave] == col:
                                item.setForeground(QtGui.QColor(255, 0, 0))
                            plugin.table.setItem(row, col, item)

                    master_ids = controller.supervisor.clustering.cluster_ids
                    for col in range(dists.shape[0]):
                        m_ind = plugin.m_inds[col]
                        if m_ind == -1:
                            cluster_num = 'None'
                        elif m_ind == -2:
                            cluster_num = 'Noise'
                        else:
                            cluster_num = master_ids[m_ind]
                        plugin.table.setHorizontalHeaderItem(
                            col,
                            QtGui.QTableWidgetItem('{}\n{:.0f}-{:.0f}'.format(
                                cluster_num, plugin.best_channels[col],
                                plugin.unit_number[col])))

                    slave_ids = controller2.supervisor.clustering.cluster_ids
                    snr_func = controller2.supervisor.cluster_view._columns['snr']['func']
                    for row in range(dists.shape[1]):
                        c_id = slave_ids[plugin.sortrows[row]]
                        plugin.table.setVerticalHeaderItem(
                            row,
                            QtGui.QTableWidgetItem('{:.0f}-{:.1f}'.format(
                                c_id, snr_func(c_id))))
                    plugin.table.setEditTriggers(
                        QtGui.QAbstractItemView.NoEditTriggers)
                    plugin.table.resizeColumnsToContents()
                    # This function runs again after every merge / noise move. Connect the
                    # click handler once per table; otherwise each click would run it once
                    # per rebuild.
                    if getattr(plugin, '_click_connected_table', None) is not plugin.table:
                        plugin.table.itemClicked.connect(handle_item_clicked)
                        plugin._click_connected_table = plugin.table

#                self.fig = plt.figure()
#                ax = self.fig.add_axes([0.15, 0.02, 0.83, 0.975])
#                normal = plt.Normalize(dists.min()-1, dists.max()+1)
#                dists_txt=np.round(dists/dists.max()*100)
#                self.table=ax.table(cellText=dists_txt, rowLabels=controller.supervisor.clustering.cluster_ids[m_inds], colLabels=controller2.supervisor.clustering.cluster_ids,
#                    colWidths = [0.03]*dists.shape[1], loc='center',
#                    cellColours=plt.cm.hot(normal(dists)))
#                self.fig.show()
#a = QApplication(sys.argv)

                tablegui = GUI(position=(400, 200), size=(400, 300))
                table = QtGui.QTableWidget()
                table.setWindowTitle("Merge Table")
                #table.resize(600, 400)

                plugin.matchi = matchi
                plugin.sortrows = sortrows
                plugin.best_channels = best_channels
                plugin.unit_number = unit_number
                plugin.unit_type = unit_type
                plugin.m_inds = m_inds
                plugin.tablegui = tablegui  #need to keep a reference otherwide the gui is deleted by garbage collection, leading to a segfault!
                plugin.table = table
                plugin.dists = dists

                create_table(controller, controller2, plugin)

                tablegui.add_view(table)

                actions = Actions(tablegui, name='Merge', menu='Merge')

                @actions.add(menu='Merge',
                             name='Set master for selected slave',
                             shortcut='enter')
                def set_master(plugin=plugin,
                               controller=controller,
                               controller2=controller2):
                    """Make the selected cell's column the master of that cell's slave row.

                    Exactly one cell must be selected; otherwise a message is printed and
                    shown in the merge window's status bar. The previous master's cell is
                    reset to black and the new master's cell is drawn in red.
                    """
                    indexes = plugin.table.selectedIndexes()
                    row = np.array([cell.row() for cell in indexes])
                    column = np.array([cell.column() for cell in indexes])
                    if len(row) != 1:
                        st = 'Only one cell can be selected when setting master'
                        print(st)
                        plugin.tablegui.status_message = st
                        return
                    print("Row {} and Column {} is selected".format(row, column))
                    r, c = int(row[0]), int(column[0])
                    slave = plugin.sortrows[r]
                    # Un-highlight the old master BEFORE highlighting the new one, so that
                    # re-selecting the current master leaves its cell red, not black.
                    plugin.table.item(r, int(plugin.matchi[slave])).setForeground(
                        QtGui.QColor(0, 0, 0))
                    plugin.table.item(r, c).setForeground(QtGui.QColor(255, 0, 0))
                    plugin.matchi[slave] = c
                    plugin.table.show()
                    plugin.tablegui.show()

                @actions.add(menu='Merge',
                             name='Merge selected slaves',
                             shortcut='m')
                def merge_slaves_by_selection(plugin=plugin,
                                              controller=controller,
                                              controller2=controller2):
                    """Merge the slave clusters of every selected row.

                    The selection must span exactly one column and at least two rows. On
                    success the table is rebuilt and the merge window re-shown. Otherwise
                    the reason is printed and shown in the merge window's status bar.
                    """
                    selection = plugin.table.selectedIndexes()
                    rows = np.unique(np.array([idx.row() for idx in selection]))
                    cols = np.unique(np.array([idx.column() for idx in selection]))

                    if len(cols) == 1 and len(rows) > 1:
                        merge_slaves(plugin, controller, controller2, rows)
                        create_table(controller, controller2, plugin)
                        plugin.tablegui.show()
                        return

                    if len(cols) > 1:
                        message = 'Only one master can be selected when merging slaves'
                    elif len(rows) == 1:
                        message = 'At least two slaves must be selected to merge slaves'
                    else:
                        message = 'Unknown slave merging problem'
                    print(message)
                    plugin.tablegui.status_message = message

                def merge_slaves_by_array(plugin, controller, controller2,
                                          merge_matchis):
                    """For each master index in ``merge_matchis``, merge all slaves matched to it.

                    Rows are recomputed on every iteration because each merge rewrites
                    ``plugin.matchi`` and ``plugin.sortrows``. Masters with fewer than two
                    matched slaves are skipped. With none there is nothing to merge, and
                    ``merge_slaves`` would fail on ``row[0]``. With one, ``merge_slaves``
                    would rebuild the state as if a new cluster had been appended.
                    The table is rebuilt and the merge window re-shown at the end.
                    """
                    for merge_matchi in merge_matchis:
                        row = np.where(
                            plugin.matchi[plugin.sortrows] == merge_matchi)[0]
                        if len(row) < 2:
                            continue
                        merge_slaves(plugin, controller, controller2, row)

                    create_table(controller, controller2, plugin)
                    plugin.tablegui.show()

                def merge_slaves(plugin, controller, controller2, row):
                    """Merge the slave clusters shown in table rows ``row`` and update state.

                    Calls ``controller2.supervisor.merge`` on those clusters, then removes
                    their entries from ``plugin.matchi``, ``controller2.mean_waveforms``
                    (axis 2) and ``plugin.dists`` (axis 1). It appends one entry for the
                    merged cluster to each: the master match of the first selected row, a
                    freshly computed mean waveform, and freshly computed distances.
                    Finally rebuilds ``plugin.sortrows`` from ``plugin.matchi``.
                    ``row`` must be non-empty (``row[0]`` is read).

                    NOTE(review): assumes the merged cluster is the last entry of
                    ``controller2.supervisor.clustering.cluster_ids`` after the merge and
                    that the other ids keep their relative order. Confirm against the
                    clustering implementation.
                    """
                    # Merge the slave clusters displayed in the given table rows.
                    controller2.supervisor.merge(
                        controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows[row]].tolist())
                    # The merged cluster inherits the master match of the first selected row.
                    assign_matchi = plugin.matchi[plugin.sortrows[row[0]]]
                    # Drop the merged clusters' entries (old sortrows still index the old
                    # cluster order) and append one for the new cluster.
                    plugin.matchi = np.delete(plugin.matchi,
                                              plugin.sortrows[row],
                                              axis=0)
                    plugin.matchi = np.append(plugin.matchi, assign_matchi)
                    controller2.mean_waveforms = np.delete(
                        controller2.mean_waveforms,
                        plugin.sortrows[row],
                        axis=2)
                    new_mean_waveforms = create_mean_waveforms(
                        controller2,
                        max_waveforms_per_cluster=100,
                        cluster_ids=controller2.supervisor.clustering.
                        cluster_ids[-1])
                    controller2.mean_waveforms = np.append(
                        controller2.mean_waveforms, new_mean_waveforms, axis=2)
                    plugin.dists = np.delete(plugin.dists,
                                             plugin.sortrows[row],
                                             axis=1)
                    # After the deletion, dists.shape[1] is the slave index the new
                    # cluster's column will occupy.
                    plugin.dists = np.append(plugin.dists,
                                             calc_dists(
                                                 controller,
                                                 controller2,
                                                 plugin.m_inds,
                                                 s_inds=plugin.dists.shape[1]),
                                             axis=1)
                    plugin.sortrows = np.argsort(plugin.matchi,
                                                 kind='mergesort')

                @actions.add(menu='Merge',
                             name='Move low-snr clusters to noise',
                             shortcut='n')
                def move_low_snr_to_noise(plugin=plugin,
                                          controller=controller,
                                          controller2=controller2):
                    """Assign every slave cluster with an SNR below 0.5 to the "Noise" column.

                    The Noise master column (``m_inds`` == -2, best channel 999, unit
                    number 0, unit type 3, distances -1) is appended the first time and
                    reused on later calls. This avoids stacking up a new, orphaned Noise
                    column on every invocation. The table is then rebuilt and the moved
                    cluster ids are reported.
                    """
                    cluster_ids = controller2.supervisor.clustering.cluster_ids
                    snr_func = controller2.supervisor.cluster_view._columns['snr']['func']
                    snrs = np.array([snr_func(c) for c in cluster_ids], dtype=float)
                    thresh = 0.5  # SNR threshold (an amplitude-based one would be ~0.2)
                    is_noise = snrs < thresh
                    noise_clusts = cluster_ids[is_noise]

                    existing = np.nonzero(plugin.m_inds == -2)[0]
                    if len(existing) > 0:
                        ind = existing[0]
                    else:
                        # Append a new Noise column at the end of the master columns.
                        ind = plugin.m_inds.shape[0]
                        plugin.m_inds = np.insert(plugin.m_inds, ind, -2)
                        plugin.best_channels = np.insert(plugin.best_channels, ind,
                                                         999)
                        plugin.unit_number = np.insert(plugin.unit_number, ind, 0)
                        plugin.unit_type = np.insert(plugin.unit_type, ind, 3)
                        plugin.dists = np.insert(plugin.dists, ind, -1, axis=0)

                    # matchi is indexed in cluster_ids order, so the noise clusters' slave
                    # indices are simply the positions where is_noise is True.
                    plugin.matchi[np.nonzero(is_noise)[0]] = ind
                    plugin.sortrows = np.argsort(plugin.matchi,
                                                 kind='mergesort')

                    create_table(controller, controller2, plugin)
                    plugin.tablegui.show()
                    st = 'Cluster ids {} moved to noise'.format(noise_clusts)
                    print(st)
                    plugin.tablegui.status_message = st

                @actions.add(menu='Merge',
                             name='Add new unit label',
                             shortcut='a')
                def add_unit(plugin=plugin,
                             controller=controller,
                             controller2=controller2):
                    """Insert a new, unmatched unit-label row into the association table.

                    Prompts for a channel, then for a unit number (pre-filled with
                    one more than the highest unit number already labelled on that
                    channel, or 1) plus an SU/MU choice. A row with ``m_inds == -1``
                    and ``dists == -1`` is inserted into every per-row array of
                    ``plugin``; entries of ``plugin.matchi`` at or after the
                    insertion point are shifted by one so slave clusters keep
                    pointing at the same rows. The table is then rebuilt and shown.
                    Invalid input sets a status message and aborts.
                    """
                    chan, ok = QtGui.QInputDialog.getText(
                        None, 'Adding new unit label:', '       Channel:')
                    if not ok:
                        return
                    try:
                        chan = int(chan)
                    except (TypeError, ValueError):
                        plugin.tablegui.status_message = 'Error inputting channel'
                        return

                    # Suggest the next free unit number on this channel.
                    nums = plugin.unit_number[plugin.best_channels == chan]
                    next_unit_num = int(nums.max()) + 1 if len(nums) else 1

                    dlg = QtGui.QInputDialog(None)
                    dlg.setInputMode(QtGui.QInputDialog.TextInput)
                    dlg.setTextValue('{}'.format(next_unit_num))
                    dlg.setLabelText("Unit number:")
                    dlg.resize(300, 300)
                    # Radio buttons placed directly on the dialog choose between
                    # SU (unit_type 1) and MU (unit_type 2); SU is the default.
                    b1 = QtGui.QRadioButton("SU", dlg)
                    b2 = QtGui.QRadioButton("MU", dlg)
                    b1.move(100, 0)
                    b2.move(150, 0)
                    b1.setChecked(True)
                    ok = dlg.exec_()
                    unit_number = dlg.textValue()
                    if not ok:
                        return
                    try:
                        unit_number = int(unit_number)
                    except (TypeError, ValueError):
                        plugin.tablegui.status_message = 'Error inputting unit number'
                        return
                    if b1.isChecked():
                        unit_type = 1
                    elif b2.isChecked():
                        unit_type = 2
                    else:
                        # status_message is an attribute (assigned everywhere
                        # else), not a method: calling it would raise TypeError.
                        plugin.tablegui.status_message = (
                            'Error getting unit type, must have checked either SU or MU'
                        )
                        return

                    # Insert right after the last row of the same type whose
                    # channel is <= chan; failing that, after the last row of
                    # that type.
                    same_type = plugin.unit_type == unit_type
                    below_inds = np.logical_and(
                        same_type, plugin.best_channels <= chan).nonzero()[0]
                    if below_inds.shape[0] == 0:
                        below_inds = same_type.nonzero()[0]
                    if below_inds.shape[0] > 0:
                        ind = below_inds[-1] + 1
                    else:
                        # No row of this type exists yet (this used to raise
                        # IndexError). NOTE(review): assumes rows are grouped by
                        # ascending unit type for display; any index in range
                        # keeps the per-row arrays aligned with each other.
                        ind = int(np.count_nonzero(plugin.unit_type < unit_type))

                    plugin.m_inds = np.insert(plugin.m_inds, ind, -1)
                    # Rows at or after `ind` moved down by one; keep slave
                    # clusters matched to the same rows.
                    plugin.matchi[plugin.matchi >= ind] += 1
                    plugin.best_channels = np.insert(plugin.best_channels, ind,
                                                     chan)
                    plugin.unit_number = np.insert(plugin.unit_number, ind,
                                                   unit_number)
                    plugin.unit_type = np.insert(plugin.unit_type, ind,
                                                 unit_type)
                    plugin.dists = np.insert(plugin.dists, ind, -1, axis=0)
                    create_table(controller, controller2, plugin)

                    tablegui.show()

                @actions.add(menu='Merge',
                             name='Save cluster associations',
                             alias='sca')
                def save_cluster_associations(plugin=plugin,
                                              controller=controller,
                                              controller2=controller2):
                    """Label slave clusters from their matches, save both datasets and
                    write the association tables.

                    If several slave clusters are matched to the same row whose
                    unit type is not 3 (noise), the user is asked first: Yes calls
                    ``merge_slaves_by_array`` for those rows and returns without
                    saving, Cancel returns without saving, and No proceeds.

                    When proceeding: slave clusters are moved to 'good', 'mua' or
                    'noise' according to the unit type (1, 2 or 3) of their matched
                    row, both supervisors are saved, ``cluster_names.ts`` is written
                    into each dataset directory, and a line naming the slave
                    directory and the current time is appended to the master's
                    ``Merged_Files.txt``.
                    """
                    un_matchi, counts = np.unique(plugin.matchi,
                                                  return_counts=True)
                    # Rows labelled noise (type 3) may collect many slave
                    # clusters; leave them out of the duplicate check.
                    not_noise = plugin.unit_type[un_matchi] != 3
                    un_matchi = un_matchi[not_noise]
                    counts = counts[not_noise]
                    if np.any(counts > 1):
                        msgBox = QtGui.QMessageBox()
                        msgBox.setText(
                            'There are {} master clusters that are about '
                            'to be assigned multiple slave clusters. If this slave will '
                            'be used as a master for an additional merge, in most cases '
                            'slave clusters that share the same master match should be '
                            'merged.'.format(np.sum(counts > 1)))
                        msgBox.setInformativeText(
                            'Do you want to automatically do these merges?')
                        msgBox.setStandardButtons(QtGui.QMessageBox.Yes
                                                  | QtGui.QMessageBox.No
                                                  | QtGui.QMessageBox.Cancel)
                        msgBox.setDefaultButton(QtGui.QMessageBox.Yes)
                        ret = msgBox.exec_()
                        if ret == QtGui.QMessageBox.Yes:
                            merge_slaves_by_array(plugin, controller,
                                                  controller2,
                                                  un_matchi[counts > 1])
                            msgBox = QtGui.QMessageBox()
                            msgBox.setText(
                                'Merges done. Not saving yet! Check to see that everything is okay and then save.'
                            )
                            msgBox.exec_()
                            return
                        if ret != QtGui.QMessageBox.No:
                            # Cancel (or dialog closed): abort without saving.
                            return

                    # Row index of the match for each slave cluster, in the order
                    # given by plugin.sortrows.
                    slave_rows = plugin.matchi[plugin.sortrows]
                    slave_types = plugin.unit_type[slave_rows]

                    # Assign labels to the slave phy's clusters from the unit
                    # type of the master row each one is matched to.
                    for label, utype in (('good', 1), ('mua', 2), ('noise', 3)):
                        clusts = controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows[slave_types == utype]].tolist()
                        controller2.supervisor.move(label, clusts)

                    # Save both master and slave.
                    controller.supervisor.save()
                    controller2.supervisor.save()

                    # Rows whose m_inds is -1 (added via "Add new unit label")
                    # or -2 (moved to noise) have no master cluster; skip them.
                    has_master = ~np.isin(plugin.m_inds, (-1, -2))
                    create_tsv(
                        op.join(controller.model.dir_path, 'cluster_names.ts'),
                        controller.supervisor.clustering.cluster_ids[
                            plugin.m_inds[has_master]],
                        plugin.unit_type[has_master],
                        plugin.best_channels[has_master],
                        plugin.unit_number[has_master])
                    create_tsv(
                        op.join(controller2.model.dir_path,
                                'cluster_names.ts'),
                        controller2.supervisor.clustering.cluster_ids[
                            plugin.sortrows],
                        slave_types,
                        plugin.best_channels[slave_rows],
                        plugin.unit_number[slave_rows])
                    with open(
                            op.join(controller.model.dir_path,
                                    'Merged_Files.txt'), 'a') as text_file:
                        text_file.write('{} on {}\n'.format(
                            controller2.model.dir_path, time.strftime('%c')))
                    plugin.tablegui.status_message = 'Saved cluster associations'
                    print('Saved cluster associations')

                def create_tsv(filename, cluster_id, unit_type, channel,
                               unit_number):
                    # Write a tab-separated table: one header row, then one row
                    # per element of ``cluster_id``. The other three sequences
                    # are indexed in parallel with it.
                    if sys.version_info[0] >= 3:
                        # Python 3: text mode with newline='' as csv requires.
                        out = open(filename, 'w', newline='')
                    else:
                        # Python 2: the csv module expects a binary-mode file.
                        out = open(filename, 'wb')
                    with out as handle:
                        tsv = csv.writer(handle, delimiter='\t')
                        header = ['cluster_id', 'unit_type', 'chan', 'unit_number']
                        tsv.writerow(header)
                        for row, cid in enumerate(cluster_id):
                            tsv.writerow([cid, unit_type[row], channel[row],
                                          unit_number[row]])

                tablegui.show()