Esempio n. 1
0
    def _create_actions(self, gui):
        """Register every clustering-related action on the GUI.

        An `Actions` collection named 'Clustering' is created on `gui`,
        stored in `self.actions`, and then populated in menu order.
        """
        actions = Actions(gui,
                          name='Clustering',
                          menu='&Clustering',
                          default_shortcuts=self.shortcuts)
        self.actions = actions

        # Selection, followed by the undo stack.
        actions.add(self.select, alias='c')
        actions.separator()
        actions.add(self.undo)
        actions.add(self.redo)
        actions.separator()

        # Merge and split.
        actions.add(self.merge, alias='g')
        actions.add(self.split, alias='k')
        actions.separator()

        # Generic move.
        actions.add(self.move)
        actions.separator()

        # One (best, similar, all) triple of move actions per target group.
        movers = (
            (self.move_best, 'move_best_to_',
             'Move the best clusters to %s.'),
            (self.move_similar, 'move_similar_to_',
             'Move the similar clusters to %s.'),
            (self.move_all, 'move_all_to_',
             'Move all selected clusters to %s.'),
        )
        for group in ('noise', 'mua', 'good'):
            for method, prefix, doc in movers:
                actions.add(partial(method, group),
                            name=prefix + group,
                            docstring=doc % group)
            actions.separator()

        # Labelling.
        actions.add(self.label, alias='l')

        # Saving is exposed in the File menu.
        actions.add(self.save, menu='&File')

        # Wizard navigation has its own menu.
        wizard = '&Wizard'
        actions.add(self.reset, menu=wizard)
        actions.separator(menu=wizard)
        actions.add(self.next, menu=wizard)
        actions.add(self.previous, menu=wizard)
        actions.separator(menu=wizard)
        actions.add(self.next_best, menu=wizard)
        actions.add(self.previous_best, menu=wizard)
        actions.separator(menu=wizard)
Esempio n. 2
0
    def attach(self, gui):
        """Create the Edit and Select menus on the GUI, then fill them."""
        # Keyword arguments shared by both menus.
        shared = dict(insert_menu_before='&View',
                      default_shortcuts=self.default_shortcuts,
                      default_snippets=self.default_snippets)
        self.edit_actions = Actions(gui, name='Edit', menu='&Edit', **shared)
        self.select_actions = Actions(
            gui, name='Select', menu='Sele&ct', **shared)

        # Populate the menus, then the toolbar.
        self._create_edit_actions()
        self._create_select_actions()
        self._create_toolbar(gui)
Esempio n. 3
0
    def attach(self, gui):
        """Create the Edit, Select and View menus on the GUI, then fill them."""
        # Keyword arguments shared by the three menus.
        shared = dict(default_shortcuts=self.default_shortcuts,
                      default_snippets=self.default_snippets)
        self.edit_actions = Actions(gui, menu='&Edit', **shared)
        self.select_actions = Actions(gui, menu='Sele&ct', **shared)
        self.view_actions = Actions(gui, menu='&View', **shared)

        # Every builder receives the GUI state.
        self._create_edit_actions(gui.state)
        self._create_select_actions(gui.state)
        self._create_view_actions(gui.state)
Esempio n. 4
0
class ManualClustering(object):
    """Component that brings manual clustering facilities to a GUI:

    * Clustering instance: merge, split, undo, redo
    * ClusterMeta instance: change cluster metadata (e.g. group)
    * Selection
    * Many manual clustering-related actions, snippets, shortcuts, etc.

    Parameters
    ----------

    spike_clusters : ndarray
    spikes_per_cluster : function `cluster_id -> spike_ids`
    cluster_groups : dictionary
    best_channel : function `cluster_id -> channel_id`, optional
    shortcuts : dict
    quality: func
    similarity: func
    new_cluster_id : optional
        Forwarded unchanged to `Clustering`.

    GUI events
    ----------

    When this component is attached to a GUI, the GUI emits the following
    events:

    select(cluster_ids)
        when clusters are selected
    cluster(up)
        when a merge or split happens
    request_save(spike_clusters, cluster_groups)
        when a save is requested by the user

    """

    # Default shortcuts, keyed by action name. `__init__` copies this
    # dictionary and overrides it with the user-supplied shortcuts.
    default_shortcuts = {
        # Clustering.
        'merge': 'g',
        'split': 'k',

        # Move.
        'move_best_to_noise': 'alt+n',
        'move_best_to_mua': 'alt+m',
        'move_best_to_good': 'alt+g',

        'move_similar_to_noise': 'ctrl+n',
        'move_similar_to_mua': 'ctrl+m',
        'move_similar_to_good': 'ctrl+g',

        'move_all_to_noise': 'ctrl+alt+n',
        'move_all_to_mua': 'ctrl+alt+m',
        'move_all_to_good': 'ctrl+alt+g',

        # Wizard.
        'reset': 'ctrl+alt+space',
        'next': 'space',
        'previous': 'shift+space',
        'next_best': 'down',
        'previous_best': 'up',

        # Misc.
        'save': 'Save',
        # NOTE(review): same binding as 'save' -- confirm this is intended.
        'show_shortcuts': 'Save',
        'undo': 'Undo',
        'redo': ('ctrl+shift+z', 'ctrl+y'),
    }

    def __init__(self,
                 spike_clusters,
                 spikes_per_cluster,
                 cluster_groups=None,
                 best_channel=None,
                 shortcuts=None,
                 quality=None,
                 similarity=None,
                 new_cluster_id=None,
                 ):
        """Create the clustering objects, the history and the cluster views.

        The component is not bound to a GUI until `attach()` is called;
        until then `self.gui` is None and no GUI event is emitted.
        """
        self.gui = None
        self.quality = quality  # function cluster => quality
        self.similarity = similarity  # function cluster => [(cl, sim), ...]
        self.best_channel = best_channel  # function cluster_id => channel_id

        assert hasattr(spikes_per_cluster, '__call__')
        self.spikes_per_cluster = spikes_per_cluster

        # Load default shortcuts, and override with any user shortcuts.
        self.shortcuts = self.default_shortcuts.copy()
        self.shortcuts.update(shortcuts or {})

        # Create Clustering and ClusterMeta.
        self.clustering = Clustering(spike_clusters,
                                     new_cluster_id=new_cluster_id)
        self.cluster_groups = cluster_groups or {}
        self.cluster_meta = create_cluster_meta(self.cluster_groups)
        self._global_history = GlobalHistory(process_ups=_process_ups)
        self._register_logging()

        # Create the cluster views.
        self._create_cluster_views()
        self._add_default_columns()

        # First selected cluster of the cluster view, and the similarity
        # values computed for it (see `_update_similarity_view()`).
        self._best = None
        self._current_similarity_values = {}

    # Internal methods
    # -------------------------------------------------------------------------

    def _register_logging(self):
        """Log clustering and metadata updates and forward them to the GUI.

        Both handlers re-emit the update as a `cluster` event on the GUI
        when one is attached.
        """
        # Log the actions.
        @self.clustering.connect
        def on_cluster(up):
            if up.history:
                logger.info(up.history.title() + " cluster assign.")
            elif up.description == 'merge':
                logger.info("Merge clusters %s to %s.",
                            ', '.join(map(str, up.deleted)),
                            up.added[0])
            else:
                logger.info("Assigned %s spikes.", len(up.spike_ids))

            if self.gui:
                self.gui.emit('cluster', up)

        @self.cluster_meta.connect  # noqa
        def on_cluster(up):
            # Update the original dictionary when groups change.
            for clu in up.metadata_changed:
                self.cluster_groups[clu] = up.metadata_value

            if up.history:
                logger.info(up.history.title() + " move.")
            else:
                logger.info("Move clusters %s to %s.",
                            ', '.join(map(str, up.metadata_changed)),
                            up.metadata_value)

            if self.gui:
                self.gui.emit('cluster', up)

    def _add_default_columns(self):
        """Register the built-in columns on the cluster views."""
        # Default columns.
        @self.add_column(name='n_spikes')
        def n_spikes(cluster_id):
            return len(self.spikes_per_cluster(cluster_id))

        # Without a `best_channel` function there is nothing to show;
        # `add_column(None, ...)` would only return an unused decorator.
        if self.best_channel is not None:
            self.add_column(self.best_channel, name='channel')

        @self.add_column(show=False)
        def skip(cluster_id):
            """Whether to skip that cluster."""
            return (self.cluster_meta.get('group', cluster_id)
                    in ('noise', 'mua'))

        @self.add_column(show=False)
        def good(cluster_id):
            """Good column for color."""
            return self.cluster_meta.get('group', cluster_id) == 'good'

        def similarity(cluster_id):
            # NOTE: there is a dictionary with the similarity to the current
            # best cluster. It is updated when the selection changes in the
            # cluster view. This is a bit of a hack: the HTML table expects
            # a function that returns a value for every row, but here we
            # cache all similarity view rows in self._current_similarity_values
            return self._current_similarity_values.get(cluster_id, 0)
        if self.similarity:
            self.similarity_view.add_column(similarity,
                                            name=self.similarity.__name__)

    def _create_actions(self, gui):
        """Create the `Clustering` actions and register them on the GUI."""
        self.actions = Actions(gui,
                               name='Clustering',
                               menu='&Clustering',
                               default_shortcuts=self.shortcuts)

        # Selection.
        self.actions.add(self.select, alias='c')
        self.actions.separator()

        # Clustering.
        self.actions.add(self.merge, alias='g')
        self.actions.add(self.split, alias='k')
        self.actions.separator()

        # Move.
        self.actions.add(self.move)

        for group in ('noise', 'mua', 'good'):
            self.actions.add(partial(self.move_best, group),
                             name='move_best_to_' + group,
                             docstring='Move the best clusters to %s.' % group)
            self.actions.add(partial(self.move_similar, group),
                             name='move_similar_to_' + group,
                             docstring='Move the similar clusters to %s.' %
                             group)
            self.actions.add(partial(self.move_all, group),
                             name='move_all_to_' + group,
                             docstring='Move all selected clusters to %s.' %
                             group)
        self.actions.separator()

        # Others.
        self.actions.add(self.undo)
        self.actions.add(self.redo)
        self.actions.add(self.save)

        # Wizard.
        self.actions.add(self.reset, menu='&Wizard')
        self.actions.add(self.next, menu='&Wizard')
        self.actions.add(self.previous, menu='&Wizard')
        self.actions.add(self.next_best, menu='&Wizard')
        self.actions.add(self.previous_best, menu='&Wizard')
        self.actions.separator()

    def _create_cluster_views(self):
        """Build the cluster and similarity views and connect their events."""
        # Create the cluster view.
        self.cluster_view = ClusterView()
        self.cluster_view.build()

        # Create the similarity view.
        self.similarity_view = ClusterView()
        self.similarity_view.build()

        # Selection in the cluster view.
        @self.cluster_view.connect_
        def on_select(cluster_ids):
            # Emit GUI.select when the selection changes in the cluster view.
            self._emit_select(cluster_ids)
            # Pin the clusters and update the similarity view.
            self._update_similarity_view()

        # Selection in the similarity view.
        @self.similarity_view.connect_  # noqa
        def on_select(cluster_ids):
            # Select the clusters from both views.
            cluster_ids = self.cluster_view.selected + cluster_ids
            self._emit_select(cluster_ids)

        # Save the current selection when an action occurs.
        def on_request_undo_state(up):
            return {'selection': (self.cluster_view.selected,
                                  self.similarity_view.selected)}

        self.clustering.connect(on_request_undo_state)
        self.cluster_meta.connect(on_request_undo_state)

        self._update_cluster_view()

    def _update_cluster_view(self):
        """Initialize the cluster view with cluster data."""
        logger.log(5, "Update the cluster view.")
        cluster_ids = [int(c) for c in self.clustering.cluster_ids]
        self.cluster_view.set_rows(cluster_ids)

    def _update_similarity_view(self):
        """Update the similarity view with matches for the specified
        clusters."""
        if not self.similarity:
            return
        selection = self.cluster_view.selected
        if not len(selection):
            return
        cluster_id = selection[0]
        # Build the set of existing clusters once, so that the membership
        # tests below do not rescan the whole list of cluster ids.
        existing = set(int(c) for c in self.clustering.cluster_ids)
        self._best = cluster_id
        logger.log(5, "Update the similarity view.")
        # This is a list of pairs (closest_cluster, similarity).
        similarities = self.similarity(cluster_id)
        # We save the similarity values wrt the currently-selected clusters.
        # Note that we keep the order of the output of the self.similarity()
        # function.
        clusters_sim = OrderedDict([(int(cl), s) for (cl, s) in similarities])
        # List of similar clusters, remove non-existing ones.
        clusters = [c for c in clusters_sim if c in existing]
        # The similarity view will use these values.
        self._current_similarity_values = clusters_sim
        # Set the rows of the similarity view.
        # TODO: instead of the self._current_similarity_values hack,
        # give the possibility to specify the values here (?).
        self.similarity_view.set_rows([c for c in clusters
                                       if c not in selection])

    def _emit_select(self, cluster_ids):
        """Choose spikes from the specified clusters and emit the
        `select` event on the GUI."""
        logger.debug("Select clusters: %s.", ', '.join(map(str, cluster_ids)))
        if self.gui:
            self.gui.emit('select', cluster_ids)

    # Public methods
    # -------------------------------------------------------------------------

    def add_column(self, func=None, name=None, show=True, default=False):
        """Add a column to both the cluster view and the similarity view.

        Can be called directly with `func`, or without it to obtain a
        decorator. The column name defaults to `func.__name__`. When
        `default` is true, the column becomes the default sort.
        """
        if func is None:
            return lambda f: self.add_column(f, name=name, show=show,
                                             default=default)
        name = name or func.__name__
        assert name
        self.cluster_view.add_column(func, name=name, show=show)
        self.similarity_view.add_column(func, name=name, show=show)
        if default:
            self.set_default_sort(name)

    def set_default_sort(self, name, sort_dir='desc'):
        """Set the default sort of the cluster view and apply it."""
        assert name
        logger.debug("Set default sort `%s` %s.", name, sort_dir)
        # Set the default sort.
        self.cluster_view.set_default_sort(name, sort_dir)
        # Reset the cluster view.
        self._update_cluster_view()
        # Sort by the default sort.
        self.cluster_view.sort_by(name, sort_dir)

    def on_cluster(self, up):
        """Update the cluster views after clustering actions."""

        # Remember the similarity selection before the views are refreshed.
        similar = self.similarity_view.selected

        # Reinitialize the cluster view if clusters have changed.
        if up.added:
            self._update_cluster_view()

        # Select all new clusters in view 1.
        if up.history == 'undo':
            # Select the clusters that were selected before the undone
            # action.
            clusters_0, clusters_1 = up.undo_state[0]['selection']
            self.cluster_view.select(clusters_0)
            self.similarity_view.select(clusters_1)
        elif up.added:
            if up.description == 'assign':
                # NOTE: we reverse the order such that the last selected
                # cluster (with a new color) is the split cluster.
                added = up.added[::-1]
            else:
                added = up.added
            self.select(added)
            if similar:
                self.similarity_view.next()
        elif up.metadata_changed:
            # Select next in similarity view if all moved are in that view.
            if set(up.metadata_changed) <= set(similar):

                # Update the cluster view, and select the clusters that
                # were selected before the action.
                selected = self.similarity_view.selected
                self._update_similarity_view()
                self.similarity_view.select(selected, do_emit=False)
                self.similarity_view.next()
            # Otherwise, select next in cluster view.
            else:
                # Update the cluster view, and select the clusters that
                # were selected before the action.
                selected = self.cluster_view.selected
                self._update_cluster_view()
                self.cluster_view.select(selected, do_emit=False)
                self.cluster_view.next()
                if similar:
                    self.similarity_view.next()

    def attach(self, gui):
        """Bind the component to a GUI: actions, views and view state.

        Returns `self` so that the call can be chained.
        """
        self.gui = gui

        # Create the actions.
        self._create_actions(gui)

        # Add the cluster views.
        gui.add_view(self.cluster_view, name='ClusterView')

        # Add the quality column in the cluster view.
        if self.quality:
            self.cluster_view.add_column(self.quality,
                                         name=self.quality.__name__,
                                         )

        # Refresh the rows of the cluster view.
        self._update_cluster_view()

        # Add the similarity view if there is a similarity function.
        if self.similarity:
            gui.add_view(self.similarity_view, name='SimilarityView')

        # Set the view state.
        cv = self.cluster_view
        cv.set_state(gui.state.get_view_state(cv))

        # Save the view state in the GUI state.
        @gui.connect_
        def on_close():
            gui.state.update_view_state(cv, cv.state)
            # NOTE: create_gui() already saves the state, but the event
            # is registered *before* we add all views.
            gui.state.save()

        # Update the cluster views and selection when a cluster event occurs.
        self.gui.connect_(self.on_cluster)
        return self

    # Selection actions
    # -------------------------------------------------------------------------

    def select(self, *cluster_ids):
        """Select a list of clusters."""
        # HACK: allow for `select(1, 2, 3)` in addition to `select([1, 2, 3])`
        # This makes it more convenient to select multiple clusters with
        # the snippet: `:c 1 2 3` instead of `:c 1,2,3`.
        if cluster_ids and isinstance(cluster_ids[0], (tuple, list)):
            cluster_ids = list(cluster_ids[0]) + list(cluster_ids[1:])
        # Update the cluster view selection.
        self.cluster_view.select(cluster_ids)

    @property
    def selected(self):
        """Clusters selected in the cluster view, then the similarity view."""
        return self.cluster_view.selected + self.similarity_view.selected

    # Clustering actions
    # -------------------------------------------------------------------------

    def merge(self, cluster_ids=None):
        """Merge the selected clusters."""
        if cluster_ids is None:
            cluster_ids = self.selected
        # Merging needs at least two clusters.
        if len(cluster_ids or []) <= 1:
            return
        self.clustering.merge(cluster_ids)
        self._global_history.action(self.clustering)

    def split(self, spike_ids=None, spike_clusters_rel=0):
        """Split the selected spikes."""
        if spike_ids is None:
            # Ask the GUI for the spikes to split; each answer is expected
            # to be an array of spike ids.
            answers = self.gui.emit('request_split')
            if answers is None or len(answers) == 0:
                # Nobody answered: `np.concatenate` would raise ValueError
                # on an empty sequence, so fall through to the message box.
                spike_ids = np.array([], dtype=np.int64)
            else:
                spike_ids = np.concatenate(answers).astype(np.int64)
        if len(spike_ids) == 0:
            msg = ("You first need to select spikes in the feature "
                   "view with a few Ctrl+Click around the spikes "
                   "that you want to split.")
            _show_box(self.gui.dialog(msg))
            return
        self.clustering.split(spike_ids,
                              spike_clusters_rel=spike_clusters_rel)
        self._global_history.action(self.clustering)

    # Move actions
    # -------------------------------------------------------------------------

    def move(self, cluster_ids, group):
        """Move clusters to a group."""
        # Accept a single cluster id as well as a sequence.
        if not hasattr(cluster_ids, '__len__'):
            cluster_ids = [cluster_ids]
        if len(cluster_ids) == 0:
            return
        self.cluster_meta.set('group', cluster_ids, group)
        self._global_history.action(self.cluster_meta)

    def move_best(self, group=None):
        """Move all selected best clusters to a group."""
        self.move(self.cluster_view.selected, group)

    def move_similar(self, group=None):
        """Move all selected similar clusters to a group."""
        self.move(self.similarity_view.selected, group)

    def move_all(self, group=None):
        """Move all selected clusters to a group."""
        self.move(self.selected, group)

    # Wizard actions
    # -------------------------------------------------------------------------

    def reset(self):
        """Reset the wizard."""
        self._update_cluster_view()
        self.cluster_view.next()

    def next_best(self):
        """Select the next best cluster."""
        self.cluster_view.next()

    def previous_best(self):
        """Select the previous best cluster."""
        self.cluster_view.previous()

    def next(self):
        """Select the next cluster."""
        # With nothing selected, start from the cluster view.
        if not self.selected:
            self.cluster_view.next()
        else:
            self.similarity_view.next()

    def previous(self):
        """Select the previous cluster."""
        self.similarity_view.previous()

    # Other actions
    # -------------------------------------------------------------------------

    def undo(self):
        """Undo the last action."""
        self._global_history.undo()

    def redo(self):
        """Redo the last undone action."""
        self._global_history.redo()

    def save(self):
        """Save the manual clustering back to disk."""
        spike_clusters = self.clustering.spike_clusters
        # Clusters without a group are reported as 'unsorted'.
        groups = {c: self.cluster_meta.get('group', c) or 'unsorted'
                  for c in self.clustering.cluster_ids}
        self.gui.emit('request_save', spike_clusters, groups)
Esempio n. 5
0
class Supervisor(EventEmitter):
    """Component that brings manual clustering facilities to a GUI:

    * Clustering instance: merge, split, undo, redo
    * ClusterMeta instance: change cluster metadata (e.g. group)
    * Selection
    * Many manual clustering-related actions, snippets, shortcuts, etc.

    Parameters
    ----------

    spike_clusters : ndarray
    cluster_groups : dictionary
    shortcuts : dict
    quality: func
    similarity: func

    GUI events
    ----------

    When this component is attached to a GUI, the GUI emits the following
    events:

    select(cluster_ids)
        when clusters are selected
    cluster(up)
        when a merge or split happens
    request_save(spike_clusters, cluster_groups)
        when a save is requested by the user

    """

    # Default shortcuts, keyed by action name. `__init__` copies this
    # dictionary into `self.shortcuts` and overrides it with any
    # user-supplied shortcuts.
    default_shortcuts = {
        # Clustering.
        'merge': 'g',
        'split': 'k',

        'label': 'l',

        # Move.
        'move_best_to_noise': 'alt+n',
        'move_best_to_mua': 'alt+m',
        'move_best_to_good': 'alt+g',

        'move_similar_to_noise': 'ctrl+n',
        'move_similar_to_mua': 'ctrl+m',
        'move_similar_to_good': 'ctrl+g',

        'move_all_to_noise': 'ctrl+alt+n',
        'move_all_to_mua': 'ctrl+alt+m',
        'move_all_to_good': 'ctrl+alt+g',

        # Wizard.
        'reset': 'ctrl+alt+space',
        'next': 'space',
        'previous': 'shift+space',
        'next_best': 'down',
        'previous_best': 'up',

        # Misc.
        # NOTE(review): the capitalized values are presumably named standard
        # key sequences resolved by `Actions` -- confirm.
        'save': 'Save',
        # NOTE(review): same binding as 'save' -- confirm this is intended.
        'show_shortcuts': 'Save',
        'undo': 'Undo',
        # Two alternative key sequences for the same action.
        'redo': ('ctrl+shift+z', 'ctrl+y'),
    }

    def __init__(self,
                 spike_clusters,
                 cluster_groups=None,
                 shortcuts=None,
                 quality=None,
                 similarity=None,
                 new_cluster_id=None,
                 context=None,
                 ):
        """Create the clustering objects, the history and the event handlers.

        Parameters
        ----------

        spike_clusters : passed to `Clustering`
        cluster_groups : dictionary, defaults to an empty one
        shortcuts : dict, overrides entries of `default_shortcuts`
        quality : function cluster => quality, defaults to `self.n_spikes`
        similarity : function cluster => [(cl, sim), ...]
        new_cluster_id : passed to `Clustering`
        context : optional object with `load()` and `save()` methods, used
            to cache the `spikes_per_cluster` data

        """
        super(Supervisor, self).__init__()
        self.context = context
        self.quality = quality or self.n_spikes  # function cluster => quality
        self.similarity = similarity  # function cluster => [(cl, sim), ...]

        # First selected cluster and the similarity values computed for it.
        self._best = None
        self._current_similarity_values = {}

        # Load default shortcuts, and override with any user shortcuts.
        self.shortcuts = self.default_shortcuts.copy()
        self.shortcuts.update(shortcuts or {})

        # Create Clustering and ClusterMeta.
        # Load the cached spikes_per_cluster array.
        spc = context.load('spikes_per_cluster') if context else None
        self.clustering = Clustering(spike_clusters,
                                     spikes_per_cluster=spc,
                                     new_cluster_id=new_cluster_id)
        # Cache the spikes_per_cluster array.
        self._save_spikes_per_cluster()

        self.cluster_groups = cluster_groups or {}
        self.cluster_meta = create_cluster_meta(self.cluster_groups)
        self._global_history = GlobalHistory(process_ups=_process_ups)

        # Extra metadata field holding, for a new cluster, the id returned
        # by the cluster view's `get_next_id()` (see `on_cluster` below).
        self.cluster_meta.add_field('next_cluster')

        # This handler is connected before the ones created by
        # `_register_logging()`, which emit the `cluster` event.
        @self.clustering.connect
        def on_cluster(up):
            """Register the next cluster in the list before the cluster
            view is updated."""
            # Nothing to do without new clusters, or before the cluster
            # view has been created.
            if not up.added or not hasattr(self, 'cluster_view'):
                return
            cluster = up.added[0]
            next_cluster = self.cluster_view.get_next_id()
            logger.debug("Register next_cluster to %d: %s",
                         cluster, next_cluster)
            self.cluster_meta.set('next_cluster', [cluster], next_cluster,
                                  add_to_stack=False)

        # NOTE: global on_cluster() occurs here.
        self._register_logging()

    # Internal methods
    # -------------------------------------------------------------------------

    def _save_spikes_per_cluster(self):
        if self.context:
            self.context.save('spikes_per_cluster',
                              self.clustering.spikes_per_cluster,
                              kind='pickle',
                              )

    def _register_logging(self):
        """Log clustering and metadata updates and emit `cluster` events.

        Clustering updates are always re-emitted. Metadata updates are
        re-emitted only when their description is 'metadata_group'.
        """
        def _join(cluster_ids):
            return ', '.join(str(c) for c in cluster_ids)

        @self.clustering.connect
        def on_cluster(up):
            # Log the action.
            if up.history:
                logger.info(up.history.title() + " cluster assign.")
            elif up.description == 'merge':
                logger.info("Merge clusters %s to %s.",
                            _join(up.deleted), up.added[0])
            else:
                logger.info("Assigned %s spikes.", len(up.spike_ids))

            self.emit('cluster', up)

        @self.cluster_meta.connect  # noqa
        def on_cluster(up):
            # Log the change.
            if up.history:
                logger.info(up.history.title() + " move.")
            else:
                logger.info("Change %s for clusters %s to %s.",
                            up.description,
                            _join(up.metadata_changed),
                            up.metadata_value)

            # Only group changes are mirrored and re-emitted.
            if up.description == 'metadata_group':
                # Keep the original dictionary in sync with the groups.
                for clu in up.metadata_changed:
                    self.cluster_groups[clu] = up.metadata_value

                self.emit('cluster', up)

    def _add_field_column(self, field):  # pragma: no cover
        """Add a column for a given label field."""
        @self.add_column(name=field)
        def get_my_label(cluster_id):
            return self.cluster_meta.get(field, cluster_id)

    def _add_default_columns(self):
        @self.add_column(show=False)
        def skip(cluster_id):
            """Whether to skip that cluster."""
            return (self.cluster_meta.get('group', cluster_id)
                    in ('noise', 'mua'))

        @self.add_column(show=False)
        def good(cluster_id):
            """Good column for color."""
            return self.cluster_meta.get('group', cluster_id) == 'good'

        # Default columns.
        self.add_column(self.n_spikes)

        @self.add_column
        def group(cluster_id):
            g = self.cluster_meta.get('group', cluster_id)
            g = g or 'unsorted'
            return g

        # Add columns for labels.
        for field in self.fields:  # pragma: no cover
            self._add_field_column(field)

        def similarity(cluster_id):
            # NOTE: there is a dictionary with the similarity to the current
            # best cluster. It is updated when the selection changes in the
            # cluster view. This is a bit of a hack: the HTML table expects
            # a function that returns a value for every row, but here we
            # cache all similarity view rows in self._current_similarity_values
            return self._current_similarity_values.get(cluster_id, 0)
        if self.similarity:
            self.similarity_view.add_column(similarity,
                                            name=self.similarity.__name__)

    def n_spikes(self, cluster_id):
        return len(self.clustering.spikes_per_cluster[cluster_id])

    def _create_actions(self, gui):
        """Register every clustering-related action on the GUI.

        An `Actions` collection named 'Clustering' is created on `gui`,
        stored in `self.actions`, and then populated in menu order.
        """
        actions = Actions(gui,
                          name='Clustering',
                          menu='&Clustering',
                          default_shortcuts=self.shortcuts)
        self.actions = actions

        # Selection, followed by the undo stack.
        actions.add(self.select, alias='c')
        actions.separator()
        actions.add(self.undo)
        actions.add(self.redo)
        actions.separator()

        # Merge and split.
        actions.add(self.merge, alias='g')
        actions.add(self.split, alias='k')
        actions.separator()

        # Generic move.
        actions.add(self.move)
        actions.separator()

        # One (best, similar, all) triple of move actions per target group.
        movers = (
            (self.move_best, 'move_best_to_',
             'Move the best clusters to %s.'),
            (self.move_similar, 'move_similar_to_',
             'Move the similar clusters to %s.'),
            (self.move_all, 'move_all_to_',
             'Move all selected clusters to %s.'),
        )
        for group in ('noise', 'mua', 'good'):
            for method, prefix, doc in movers:
                actions.add(partial(method, group),
                            name=prefix + group,
                            docstring=doc % group)
            actions.separator()

        # Labelling.
        actions.add(self.label, alias='l')

        # Saving is exposed in the File menu.
        actions.add(self.save, menu='&File')

        # Wizard navigation has its own menu.
        wizard = '&Wizard'
        actions.add(self.reset, menu=wizard)
        actions.separator(menu=wizard)
        actions.add(self.next, menu=wizard)
        actions.add(self.previous, menu=wizard)
        actions.separator(menu=wizard)
        actions.add(self.next_best, menu=wizard)
        actions.add(self.previous_best, menu=wizard)
        actions.separator(menu=wizard)

    def _keep_existing_clusters(self, cluster_ids):
        return [c for c in cluster_ids
                if c in self.clustering.cluster_ids]

    def _emit_select(self, cluster_ids, **kwargs):
        """Emit the `select` event for the clusters that still exist.

        Non-existing clusters are dropped from `cluster_ids` first; extra
        keyword arguments are forwarded with the event.
        """
        existing = self._keep_existing_clusters(cluster_ids)
        joined = ', '.join(str(c) for c in existing)
        logger.debug("Select cluster(s): %s.", joined)
        self.emit('select', existing, **kwargs)

    def _create_cluster_views(self):
        """Build the cluster and similarity views and wire their events.

        Two `ClusterView` instances are created and stored as
        `self.cluster_view` and `self.similarity_view`. A selection handler
        is attached to each of them, a handler providing the state to save
        for undo is connected to both `self.clustering` and
        `self.cluster_meta`, and the cluster view is finally filled with
        the current clusters.
        """
        # Create the cluster view.
        self.cluster_view = ClusterView()
        self.cluster_view.build()

        # Create the similarity view.
        self.similarity_view = ClusterView()
        self.similarity_view.build()

        # Selection in the cluster view.
        # NOTE(review): the handler is presumably registered under the event
        # named after the function (`select`) -- confirm against `connect_`.
        @self.cluster_view.connect_
        def on_select(cluster_ids, **kwargs):
            # Emit `select` when the selection changes in the cluster view.
            self._emit_select(cluster_ids, **kwargs)
            # Pin the clusters and update the similarity view.
            self._update_similarity_view()

        # Selection in the similarity view.
        # The handler deliberately reuses the name `on_select` (hence the
        # noqa marker): it is attached to a different view.
        @self.similarity_view.connect_  # noqa
        def on_select(cluster_ids, **kwargs):
            # Select the clusters from both views: the cluster view selection
            # comes first, followed by the similarity view selection.
            cluster_ids = self.cluster_view.selected + cluster_ids
            self._emit_select(cluster_ids, **kwargs)

        # Save the current selection when an action occurs.
        # The returned pair is read back in `on_cluster()` when undoing.
        def on_request_undo_state(up):
            return {'selection': (self.cluster_view.selected,
                                  self.similarity_view.selected)}

        self.clustering.connect(on_request_undo_state)
        self.cluster_meta.connect(on_request_undo_state)

        # Fill the cluster view with the current clusters.
        self._update_cluster_view()

    def _update_cluster_view(self):
        """Set the rows of the cluster view to the current cluster ids."""
        logger.log(5, "Update the cluster view.")
        # The ids are converted to plain Python integers before being passed
        # to the view.
        rows = list(map(int, self.clustering.cluster_ids))
        self.cluster_view.set_rows(rows)

    def _update_similarity_view(self):
        """Refresh the similarity view for the first selected cluster.

        Nothing happens when there is no similarity function or when no
        cluster is selected in the cluster view.
        """
        if not self.similarity:
            return
        selection = self.cluster_view.selected
        if len(selection) == 0:
            return
        best = selection[0]
        existing = self.clustering.cluster_ids
        self._best = best
        logger.log(5, "Update the similarity view.")
        # self.similarity() returns pairs (closest_cluster, similarity).
        # The values are stored by cluster id, in the order in which the
        # similarity function returned them.
        values = OrderedDict()
        for cluster, value in self.similarity(best):
            values[int(cluster)] = value
        # Rows of the similarity view: similar clusters that still exist
        # and that are not already selected in the cluster view.
        rows = [c for c in values
                if c in existing and c not in selection]
        # The similarity column reads its values from this dictionary.
        # TODO: instead of the self._current_similarity_values hack,
        # give the possibility to specify the values here (?).
        self._current_similarity_values = values
        self.similarity_view.set_rows(rows)

    # Public methods
    # -------------------------------------------------------------------------

    def add_column(self, func=None, name=None, show=True, default=False):
        """Add a column to both the cluster view and the similarity view.

        Can be called directly with a function, or used as a decorator
        when `func` is omitted. The column name defaults to the name of
        the function. When `default` is true, the column also becomes the
        default sort of the cluster view.
        """
        if func is None:
            # Decorator form: defer the registration until the function is
            # known.
            def _register(f):
                return self.add_column(f, name=name, show=show,
                                       default=default)
            return _register
        name = name or func.__name__
        assert name
        for view in (self.cluster_view, self.similarity_view):
            view.add_column(func, name=name, show=show)
        if default:
            self.set_default_sort(name)

    def set_default_sort(self, name, sort_dir='desc'):
        """Make column `name` the default sort of the cluster view.

        The cluster view rows are reloaded and the view is then sorted by
        that column, in the `sort_dir` direction.
        """
        assert name
        logger.debug("Set default sort `%s` %s.", name, sort_dir)
        view = self.cluster_view
        # Register the default sort on the view.
        view.set_default_sort(name, sort_dir)
        # Reload the rows of the cluster view.
        self._update_cluster_view()
        # Apply the sort right away.
        view.sort_by(name, sort_dir)

    def on_cluster(self, up):
        """Update the cluster views after clustering actions.

        `up` is the update object describing the action. This method reads
        its `added`, `history`, `undo_state`, `description` and
        `metadata_changed` attributes.

        The first matching case applies:

        * the action is an undo: the selection stored when the undone
          action occurred is restored in both views;
        * clusters were added: the new clusters are selected in the
          cluster view;
        * metadata changed: the next cluster is selected, either in the
          similarity view or in the cluster view.

        """

        # Selection of the similarity view *before* any view is updated.
        similar = self.similarity_view.selected

        # Reinitialize the cluster view if clusters have changed.
        if up.added:
            self._update_cluster_view()

        # Select all new clusters in view 1.
        if up.history == 'undo':
            # Select the clusters that were selected before the undone
            # action.
            # The 'selection' entry is the pair (cluster view selection,
            # similarity view selection) returned by the
            # `on_request_undo_state` handler.
            clusters_0, clusters_1 = up.undo_state[0]['selection']
            # Select rows in the tables.
            self.cluster_view.select(clusters_0, up=up)
            self.similarity_view.select(clusters_1, up=up)
        elif up.added:
            if up.description == 'assign':
                # NOTE: we change the order such that the last selected
                # cluster (with a new color) is the split cluster.
                # The first added cluster is moved to the end of the list.
                added = list(up.added[1:]) + [up.added[0]]
            else:
                added = up.added
            # Select the new clusters in the cluster view.
            self.cluster_view.select(added, up=up)
            # If something was selected in the similarity view, move on to
            # the next similar cluster.
            if similar:
                self.similarity_view.next()
        elif up.metadata_changed:
            # Select next in similarity view if all moved are in that view.
            if set(up.metadata_changed) <= set(similar):
                # The next id is requested *before* the similarity view is
                # refreshed.
                next_cluster = self.similarity_view.get_next_id()
                self._update_similarity_view()
                if next_cluster is not None:
                    # Select the cluster in the similarity view.
                    self.similarity_view.select([next_cluster])
            # Otherwise, select next in cluster view.
            else:
                self._update_cluster_view()
                # Determine if there is a next cluster set from a
                # previous clustering action.
                cluster = up.metadata_changed[0]
                next_cluster = self.cluster_meta.get('next_cluster', cluster)
                logger.debug("Get next_cluster for %d: %s.",
                             cluster, next_cluster)
                # If there is not, fallback on the next cluster in the list.
                if next_cluster is None:
                    # Re-select the changed cluster without emitting, then
                    # move to the row that follows it.
                    self.cluster_view.select([cluster], do_emit=False)
                    self.cluster_view.next()
                else:
                    self.cluster_view.select([next_cluster])

    def attach(self, gui):
        """Attach this component to a GUI.

        Creates the cluster views, the default columns and the actions,
        adds the views to `gui`, restores the state of the cluster view
        from `gui.state`, and connects the event handlers that link this
        object, the clustering and the GUI.

        Emits `create_cluster_views` once the views and actions exist, and
        `attach_gui` (with the GUI as argument) at the very end.

        Returns `self`.
        """
        # Create the cluster views.
        self._create_cluster_views()
        self._add_default_columns()

        # Create the actions.
        self._create_actions(gui)

        self.emit('create_cluster_views')

        # Add the cluster views.
        gui.add_view(self.cluster_view, name='ClusterView')

        # Add the quality column in the cluster view.
        if self.quality:
            self.cluster_view.add_column(self.quality,
                                         name=self.quality.__name__,
                                         )

        # Reload the rows of the cluster view (the quality column may have
        # just been added).
        self._update_cluster_view()

        # Add the similarity view if there is a similarity function.
        if self.similarity:
            gui.add_view(self.similarity_view, name='SimilarityView')

        # Set the view state.
        cv = self.cluster_view
        cv.set_state(gui.state.get_view_state(cv))

        # Save the new cluster id on disk.
        # This only happens when a context is available.
        @self.clustering.connect
        def on_cluster(up):
            new_cluster_id = self.clustering.new_cluster_id()
            if self.context:
                logger.debug("Save the new cluster id: %d.", new_cluster_id)
                self.context.save('new_cluster_id',
                                  dict(new_cluster_id=new_cluster_id))

        # The GUI emits the select event too.
        @self.connect
        def on_select(cluster_ids, **kwargs):
            gui.emit('select', cluster_ids, **kwargs)

        # Forward split requests to the GUI and return its answer;
        # `split()` uses that answer as the spikes to split.
        @self.connect
        def on_request_split():
            return gui.emit('request_split', single=True)

        # Save the view state in the GUI state.
        @gui.connect_
        def on_close():
            gui.state.update_view_state(cv, cv.state)
            # NOTE: create_gui() already saves the state, but the event
            # is registered *before* we add all views.
            gui.state.save()

        # Update the cluster views and selection when a cluster event occurs.
        self.connect(self.on_cluster)

        self.emit('attach_gui', gui)

        return self

    # Selection actions
    # -------------------------------------------------------------------------

    def select(self, *cluster_ids):
        """Select a list of clusters."""
        # Both `select(1, 2, 3)` and `select([1, 2, 3])` are accepted, so
        # that the snippet `:c 1 2 3` works as well as `:c 1,2,3`. When the
        # first argument is a list or a tuple, it is flattened and any
        # remaining argument is appended to it.
        ids = cluster_ids
        if ids:
            head, rest = ids[0], ids[1:]
            if isinstance(head, (tuple, list)):
                ids = list(head) + list(rest)
        # Clusters that do not exist are ignored.
        ids = self._keep_existing_clusters(ids)
        # Apply the selection to the cluster view.
        self.cluster_view.select(ids)

    @property
    def selected(self):
        """Selected clusters: cluster view first, then similarity view."""
        best = self.cluster_view.selected
        similar = self.similarity_view.selected
        return best + similar

    # Clustering actions
    # -------------------------------------------------------------------------

    def merge(self, cluster_ids=None, to=None):
        """Merge the selected clusters."""
        # Without explicit clusters, the current selection of both views is
        # merged.
        if cluster_ids is None:
            cluster_ids = self.selected
        # At least two clusters are required.
        if not cluster_ids or len(cluster_ids) <= 1:
            return
        clustering = self.clustering
        clustering.merge(cluster_ids, to=to)
        # Record the action in the global history.
        self._global_history.action(clustering)

    def split(self, spike_ids=None, spike_clusters_rel=0):
        """Split the selected spikes."""
        # `spike_ids`: spikes to put in a new cluster. When omitted, they
        # are requested with the `request_split` event.
        # `spike_clusters_rel`: forwarded as is to `Clustering.split()`.
        if spike_ids is None:
            spike_ids = self.emit('request_split', single=True)
            # The event returns None when nothing answered the request:
            # treat it as "no spike selected" so that the user gets the
            # error message below instead of a TypeError from NumPy.
            if spike_ids is None:
                spike_ids = []
            spike_ids = np.asarray(spike_ids, dtype=np.int64)
            assert spike_ids.dtype == np.int64
            assert spike_ids.ndim == 1
        if len(spike_ids) == 0:
            msg = ("You first need to select spikes in the feature "
                   "view with a few Ctrl+Click around the spikes "
                   "that you want to split.")
            self.emit('error', msg)
            return
        self.clustering.split(spike_ids,
                              spike_clusters_rel=spike_clusters_rel)
        # Record the action in the global history.
        self._global_history.action(self.clustering)

    # Move actions
    # -------------------------------------------------------------------------

    @property
    def fields(self):
        """Tuple of label fields."""
        # `group` and `next_cluster` are not reported as label fields.
        reserved = ('group', 'next_cluster')
        names = []
        for name in self.cluster_meta.fields:
            if name not in reserved:
                names.append(name)
        return tuple(names)

    def get_labels(self, field):
        """Return a dictionary `{cluster_id: label}` for a given field.

        There is one entry per cluster of the current clustering; the
        label is the value stored in `self.cluster_meta` for that field.
        """
        labels = {}
        for cluster_id in self.clustering.cluster_ids:
            labels[cluster_id] = self.cluster_meta.get(field, cluster_id)
        return labels

    def label(self, name, value, cluster_ids=None):
        """Assign a label to clusters.

        Example: `quality 3`

        """
        # Default to the clusters selected in the cluster view.
        targets = (self.cluster_view.selected
                   if cluster_ids is None else cluster_ids)
        # A single cluster id is accepted as well.
        if not hasattr(targets, '__len__'):
            targets = [targets]
        # Nothing to do without clusters.
        if len(targets) > 0:
            meta = self.cluster_meta
            meta.set(name, targets, value)
            # Record the action in the global history.
            self._global_history.action(meta)

    def move(self, group, cluster_ids=None):
        """Assign a group to some clusters.

        Example: `good`

        """
        # A string is refused: the clusters must be given as integers.
        if isinstance(cluster_ids, string_types):
            # `Logger.warn()` is a deprecated alias of `Logger.warning()`
            # and no longer exists in Python 3.13.
            logger.warning("The list of clusters should be a list of "
                           "integers, not a string.")
            return
        # A group is the label stored in the `group` field.
        self.label('group', group, cluster_ids=cluster_ids)

    def move_best(self, group=None):
        """Move all selected best clusters to a group."""
        # The best clusters are those selected in the cluster view.
        best = self.cluster_view.selected
        self.move(group, best)

    def move_similar(self, group=None):
        """Move all selected similar clusters to a group."""
        # The similar clusters are those selected in the similarity view.
        similar = self.similarity_view.selected
        self.move(group, similar)

    def move_all(self, group=None):
        """Move all selected clusters to a group."""
        # `self.selected` covers both the cluster and similarity views.
        everything = self.selected
        self.move(group, everything)

    # Wizard actions
    # -------------------------------------------------------------------------

    def reset(self):
        """Reset the wizard."""
        # Reload the rows of the cluster view, then move it to its next row.
        self._update_cluster_view()
        view = self.cluster_view
        view.next()

    def next_best(self):
        """Select the next best cluster."""
        # The best clusters are the rows of the cluster view.
        view = self.cluster_view
        view.next()

    def previous_best(self):
        """Select the previous best cluster."""
        # The best clusters are the rows of the cluster view.
        view = self.cluster_view
        view.previous()

    def next(self):
        """Select the next cluster."""
        # With an existing selection, advance in the similarity view;
        # otherwise start with the cluster view.
        view = self.similarity_view if self.selected else self.cluster_view
        view.next()

    def previous(self):
        """Select the previous cluster."""
        # Always goes back in the similarity view.
        view = self.similarity_view
        view.previous()

    # Other actions
    # -------------------------------------------------------------------------

    def undo(self):
        """Undo the last action."""
        # Delegated to the global history.
        history = self._global_history
        history.undo()

    def redo(self):
        """Undo the last undone action."""
        # In other words: redo. Delegated to the global history.
        history = self._global_history
        history.redo()

    def save(self):
        """Save the manual clustering back to disk."""
        spike_clusters = self.clustering.spike_clusters
        # Clusters without a group are saved as 'unsorted'.
        groups = {c: self.cluster_meta.get('group', c) or 'unsorted'
                  for c in self.clustering.cluster_ids}
        # List of tuples (field_name, dictionary).
        # Only the `next_cluster` field is skipped. The comparison is done
        # with `!=` on purpose: `field not in ('next_cluster')` is a
        # substring test against a plain string (the parentheses do not
        # make a tuple), which also dropped fields such as `cluster`.
        labels = [(field, self.get_labels(field))
                  for field in self.cluster_meta.fields
                  if field != 'next_cluster']
        # TODO: add option in add_field to declare a field unsavable.
        self.emit('request_save', spike_clusters, groups, *labels)
        # Cache the spikes_per_cluster array.
        self._save_spikes_per_cluster()
Esempio n. 6
0
class ActionCreator(object):
    """Companion class to the Supervisor that manages the related GUI actions."""

    # Keyboard shortcuts, by action name.
    default_shortcuts = {
        # Clustering.
        'merge': 'g',
        'split': 'k',

        'label': 'l',

        # Move.
        'move_best_to_noise': 'alt+n',
        'move_best_to_mua': 'alt+m',
        'move_best_to_good': 'alt+g',
        'move_best_to_unsorted': 'alt+u',

        'move_similar_to_noise': 'ctrl+n',
        'move_similar_to_mua': 'ctrl+m',
        'move_similar_to_good': 'ctrl+g',
        'move_similar_to_unsorted': 'ctrl+u',

        'move_all_to_noise': 'ctrl+alt+n',
        'move_all_to_mua': 'ctrl+alt+m',
        'move_all_to_good': 'ctrl+alt+g',
        'move_all_to_unsorted': 'ctrl+alt+u',

        # Wizard.
        'first': 'home',
        'last': 'end',
        'reset': 'ctrl+alt+space',
        'next': 'space',
        'previous': 'shift+space',
        'unselect_similar': 'backspace',
        'next_best': 'down',
        'previous_best': 'up',

        # Misc.
        'undo': 'ctrl+z',
        'redo': ('ctrl+shift+z', 'ctrl+y'),

        'clear_filter': 'esc',
    }

    # Snippet aliases, by action name.
    default_snippets = {
        'merge': 'g',
        'split': 'k',
        'label': 'l',
        'select': 'c',
        'filter': 'f',
        'sort': 's',
    }

    def __init__(self, supervisor=None):
        # The supervisor is used to look up docstrings and column names.
        self.supervisor = supervisor

    def add(self, which, name, **kwargs):
        """Add an action to a given menu.

        `which` selects the `<which>_actions` attribute receiving the
        action. Triggering the action emits the `action` event with this
        object, the method name and the method arguments.

        Two special keyword arguments are consumed here: `method_name`
        (defaults to `name`) and `method_args` (defaults to no argument).
        They let several actions, e.g. the different move flavors, share
        one event name/method. Other keyword arguments are passed on.
        """
        method_name = kwargs.pop('method_name', name)
        method_args = kwargs.pop('method_args', ())
        callback = partial(emit, 'action', self, method_name, *method_args)
        # Default docstring: the one of the supervisor method if there is
        # such a method, the action name otherwise.
        target = getattr(self.supervisor, method_name, None)
        fallback = inspect.getdoc(target) if target else name
        kwargs['docstring'] = kwargs.get('docstring', None) or fallback
        actions = getattr(self, '%s_actions' % which)
        actions.add(callback, name=name, **kwargs)

    def attach(self, gui):
        """Attach the GUI and create the menus."""
        # Both menus are inserted before the View menu and share the same
        # shortcuts and snippets.
        common = dict(insert_menu_before='&View',
                      default_shortcuts=self.default_shortcuts,
                      default_snippets=self.default_snippets)
        self.edit_actions = Actions(
            gui, name='Edit', menu='&Edit', **common)
        self.select_actions = Actions(
            gui, name='Select', menu='Sele&ct', **common)

        # Create the actions.
        self._create_edit_actions()
        self._create_select_actions()
        self._create_toolbar(gui)

    def _create_edit_actions(self):
        # Fill the Edit menu, in display order.
        menu = 'edit'
        separator = self.edit_actions.separator

        self.add(menu, 'undo', set_busy=True, icon='f0e2')
        self.add(menu, 'redo', set_busy=True, icon='f01e')
        separator()

        # Clustering.
        self.add(menu, 'merge', set_busy=True, icon='f247')
        self.add(menu, 'split', set_busy=True)
        separator()

        # Move: one generic action, then one action per pair
        # (selection flavor, target group), all handled by `move`.
        self.add(menu, 'move', prompt=True, n_args=2)
        groups = ('noise', 'mua', 'good', 'unsorted')
        for which in ('best', 'similar', 'all'):
            for group in groups:
                self.add(
                    menu, 'move_%s_to_%s' % (which, group),
                    method_name='move',
                    method_args=(group, which),
                    submenu='Move %s to' % which,
                    docstring='Move %s to %s.' % (which, group))
        separator()

        # Label.
        self.add(menu, 'label', prompt=True, n_args=2)
        separator()

    def _create_select_actions(self):
        # Fill the Select menu, in display order.
        menu = 'select'
        separator = self.select_actions.separator

        # Selection.
        self.add(menu, 'select', prompt=True, n_args=1)
        self.add(menu, 'unselect_similar')
        separator()

        # Sort and filter
        self.add(menu, 'filter', prompt=True, n_args=1)
        self.add(menu, 'sort', prompt=True, n_args=1)
        self.add(menu, 'clear_filter')

        # Sort by: one action per supervisor column, if any. The alias is
        # `s` followed by the first two characters of the column name,
        # underscores excluded.
        for column in getattr(self.supervisor, 'columns', ()):
            short = column.replace('_', '')[:2]
            self.add(
                menu, 'sort_by_%s' % column.lower(),
                method_name='sort', method_args=(column,),
                docstring='Sort by %s' % column,
                submenu='Sort by', alias='s%s' % short)

        separator()
        self.add(menu, 'first')
        self.add(menu, 'last')

        separator()

        self.add(menu, 'reset_wizard', icon='f015')
        separator()

        self.add(menu, 'next', icon='f061')
        self.add(menu, 'previous', icon='f060')
        separator()

        self.add(menu, 'next_best', icon='f0a9')
        self.add(menu, 'previous_best', icon='f0a8')
        separator()

    def _create_toolbar(self, gui):
        # Put the history and wizard actions in the GUI toolbar.
        toolbar = gui._toolbar
        for name in ('undo', 'redo'):
            toolbar.addAction(self.edit_actions.get(name))
        toolbar.addSeparator()
        wizard = ('reset_wizard', 'previous_best', 'next_best',
                  'previous', 'next')
        for name in wizard:
            toolbar.addAction(self.select_actions.get(name))
        toolbar.addSeparator()
        toolbar.show()
Esempio n. 7
0
class ManualClustering(object):
    """Component that brings manual clustering facilities to a GUI:

    * Clustering instance: merge, split, undo, redo
    * ClusterMeta instance: change cluster metadata (e.g. group)
    * Selection
    * Many manual clustering-related actions, snippets, shortcuts, etc.

    Parameters
    ----------

    spike_clusters : ndarray
    spikes_per_cluster : function `cluster_id -> spike_ids`
    cluster_groups : dictionary
    shortcuts : dict
    quality: func
    similarity: func

    GUI events
    ----------

    When this component is attached to a GUI, the GUI emits the following
    events:

    select(cluster_ids)
        when clusters are selected
    cluster(up)
        when a merge or split happens
    request_save(spike_clusters, cluster_groups)
        when a save is requested by the user

    """

    default_shortcuts = {
        # Clustering.
        'merge': 'g',
        'split': 'k',

        # Move.
        'move_best_to_noise': 'alt+n',
        'move_best_to_mua': 'alt+m',
        'move_best_to_good': 'alt+g',

        'move_similar_to_noise': 'ctrl+n',
        'move_similar_to_mua': 'ctrl+m',
        'move_similar_to_good': 'ctrl+g',

        'move_all_to_noise': 'ctrl+alt+n',
        'move_all_to_mua': 'ctrl+alt+m',
        'move_all_to_good': 'ctrl+alt+g',

        # Wizard.
        'reset': 'ctrl+alt+space',
        'next': 'space',
        'previous': 'shift+space',
        'next_best': 'down',
        'previous_best': 'up',

        # Misc.
        'save': 'Save',
        'show_shortcuts': 'Save',
        'undo': 'Undo',
        'redo': ('ctrl+shift+z', 'ctrl+y'),
    }

    def __init__(self,
                 spike_clusters,
                 spikes_per_cluster,
                 cluster_groups=None,
                 best_channel=None,
                 shortcuts=None,
                 quality=None,
                 similarity=None,
                 new_cluster_id=None,
                 ):
        # No GUI until attach() is called.
        self.gui = None
        # Optional user-provided functions:
        #   quality: cluster => quality
        #   similarity: cluster => [(cl, sim), ...]
        #   best_channel: cluster_id => channel_id
        self.quality = quality
        self.similarity = similarity
        self.best_channel = best_channel

        # spikes_per_cluster must be a function cluster_id => spike_ids.
        assert hasattr(spikes_per_cluster, '__call__')
        self.spikes_per_cluster = spikes_per_cluster

        # Start from a copy of the default shortcuts, user shortcuts
        # take precedence.
        merged_shortcuts = dict(self.default_shortcuts)
        if shortcuts:
            merged_shortcuts.update(shortcuts)
        self.shortcuts = merged_shortcuts

        # Clustering, cluster metadata and the history shared by both.
        self.clustering = Clustering(spike_clusters,
                                     new_cluster_id=new_cluster_id)
        self.cluster_groups = cluster_groups if cluster_groups else {}
        self.cluster_meta = create_cluster_meta(self.cluster_groups)
        self._global_history = GlobalHistory(process_ups=_process_ups)
        self._register_logging()

        # Cluster and similarity views, with their default columns.
        self._create_cluster_views()
        self._add_default_columns()

        # State of the similarity view: current best cluster and
        # similarity values of the clusters similar to it.
        self._best = None
        self._current_similarity_values = {}

    # Internal methods
    # -------------------------------------------------------------------------

    def _register_logging(self):
        """Connect handlers that log clustering and metadata actions.

        Both handlers also forward the update to the GUI with the
        `cluster` event when a GUI is attached. The metadata handler
        additionally keeps `self.cluster_groups` up to date.
        """
        # Log the actions.
        @self.clustering.connect
        def on_cluster(up):
            # `up.history` is set for undo/redo; otherwise the action is
            # either a merge or another kind of spike assignment.
            if up.history:
                logger.info(up.history.title() + " cluster assign.")
            elif up.description == 'merge':
                logger.info("Merge clusters %s to %s.",
                            ', '.join(map(str, up.deleted)),
                            up.added[0])
            else:
                logger.info("Assigned %s spikes.", len(up.spike_ids))

            if self.gui:
                self.gui.emit('cluster', up)

        # The handler deliberately reuses the name `on_cluster` (hence the
        # noqa marker): it is connected to a different object.
        @self.cluster_meta.connect  # noqa
        def on_cluster(up):
            # Update the original dictionary when groups change.
            for clu in up.metadata_changed:
                self.cluster_groups[clu] = up.metadata_value

            if up.history:
                logger.info(up.history.title() + " move.")
            else:
                logger.info("Move clusters %s to %s.",
                            ', '.join(map(str, up.metadata_changed)),
                            up.metadata_value)

            if self.gui:
                self.gui.emit('cluster', up)

    def _add_default_columns(self):
        # Default columns.
        @self.add_column(name='n_spikes')
        def n_spikes(cluster_id):
            return len(self.spikes_per_cluster(cluster_id))

        self.add_column(self.best_channel, name='channel')

        @self.add_column(show=False)
        def skip(cluster_id):
            """Whether to skip that cluster."""
            return (self.cluster_meta.get('group', cluster_id)
                    in ('noise', 'mua'))

        @self.add_column(show=False)
        def good(cluster_id):
            """Good column for color."""
            return self.cluster_meta.get('group', cluster_id) == 'good'

        def similarity(cluster_id):
            # NOTE: there is a dictionary with the similarity to the current
            # best cluster. It is updated when the selection changes in the
            # cluster view. This is a bit of a hack: the HTML table expects
            # a function that returns a value for every row, but here we
            # cache all similarity view rows in self._current_similarity_values
            return self._current_similarity_values.get(cluster_id, 0)
        if self.similarity:
            self.similarity_view.add_column(similarity,
                                            name=self.similarity.__name__)

    def _create_actions(self, gui):
        """Create the `Clustering` actions and register them, in menu order.

        The wizard actions are added to the Wizard menu instead of the
        Clustering menu.
        """
        actions = Actions(gui,
                          name='Clustering',
                          menu='&Clustering',
                          default_shortcuts=self.shortcuts)
        self.actions = actions

        # Selection.
        actions.add(self.select, alias='c')
        actions.separator()

        # Clustering.
        actions.add(self.merge, alias='g')
        actions.add(self.split, alias='k')
        actions.separator()

        # Move: the generic action first, then, for every group, one
        # action per selection flavor (best, similar, all).
        actions.add(self.move)
        flavors = (
            (self.move_best, 'move_best_to_',
             'Move the best clusters to %s.'),
            (self.move_similar, 'move_similar_to_',
             'Move the similar clusters to %s.'),
            (self.move_all, 'move_all_to_',
             'Move all selected clusters to %s.'),
        )
        for group in ('noise', 'mua', 'good'):
            for method, prefix, doc in flavors:
                actions.add(partial(method, group),
                            name=prefix + group,
                            docstring=doc % group)
        actions.separator()

        # Others.
        for method in (self.undo, self.redo, self.save):
            actions.add(method)

        # Wizard.
        wizard = (self.reset, self.next, self.previous,
                  self.next_best, self.previous_best)
        for method in wizard:
            actions.add(method, menu='&Wizard')
        actions.separator()

    def _create_cluster_views(self):
        """Build the cluster and similarity views and wire their events.

        Two `ClusterView` instances are created and stored as
        `self.cluster_view` and `self.similarity_view`. A selection handler
        is attached to each of them, a handler providing the state to save
        for undo is connected to both `self.clustering` and
        `self.cluster_meta`, and the cluster view is finally filled with
        the current clusters.
        """
        # Create the cluster view.
        self.cluster_view = ClusterView()
        self.cluster_view.build()

        # Create the similarity view.
        self.similarity_view = ClusterView()
        self.similarity_view.build()

        # Selection in the cluster view.
        # NOTE(review): the handler is presumably registered under the event
        # named after the function (`select`) -- confirm against `connect_`.
        @self.cluster_view.connect_
        def on_select(cluster_ids):
            # Emit GUI.select when the selection changes in the cluster view.
            self._emit_select(cluster_ids)
            # Pin the clusters and update the similarity view.
            self._update_similarity_view()

        # Selection in the similarity view.
        # The handler deliberately reuses the name `on_select` (hence the
        # noqa marker): it is attached to a different view.
        @self.similarity_view.connect_  # noqa
        def on_select(cluster_ids):
            # Select the clusters from both views: the cluster view selection
            # comes first, followed by the similarity view selection.
            cluster_ids = self.cluster_view.selected + cluster_ids
            self._emit_select(cluster_ids)

        # Save the current selection when an action occurs.
        # The returned pair is read back in `on_cluster()` when undoing.
        def on_request_undo_state(up):
            return {'selection': (self.cluster_view.selected,
                                  self.similarity_view.selected)}

        self.clustering.connect(on_request_undo_state)
        self.cluster_meta.connect(on_request_undo_state)

        # Fill the cluster view with the current clusters.
        self._update_cluster_view()

    def _update_cluster_view(self):
        """Set the rows of the cluster view to the current cluster ids."""
        logger.log(5, "Update the cluster view.")
        # The ids are converted to plain Python integers before being passed
        # to the view.
        rows = list(map(int, self.clustering.cluster_ids))
        self.cluster_view.set_rows(rows)

    def _update_similarity_view(self):
        """Refresh the similarity view for the first selected cluster.

        Nothing happens when there is no similarity function or when no
        cluster is selected in the cluster view.
        """
        if not self.similarity:
            return
        selection = self.cluster_view.selected
        if len(selection) == 0:
            return
        best = selection[0]
        existing = self.clustering.cluster_ids
        self._best = best
        logger.log(5, "Update the similarity view.")
        # self.similarity() returns pairs (closest_cluster, similarity).
        # The values are stored by cluster id, in the order in which the
        # similarity function returned them.
        values = OrderedDict()
        for cluster, value in self.similarity(best):
            values[int(cluster)] = value
        # Rows of the similarity view: similar clusters that still exist
        # and that are not already selected in the cluster view.
        rows = [c for c in values
                if c in existing and c not in selection]
        # The similarity column reads its values from this dictionary.
        # TODO: instead of the self._current_similarity_values hack,
        # give the possibility to specify the values here (?).
        self._current_similarity_values = values
        self.similarity_view.set_rows(rows)

    def _emit_select(self, cluster_ids):
        """Log the selection and emit the `select` event on the GUI.

        Nothing is emitted as long as no GUI is attached.
        """
        listing = ', '.join(str(c) for c in cluster_ids)
        logger.debug("Select clusters: %s.", listing)
        gui = self.gui
        if not gui:
            return
        gui.emit('select', cluster_ids)

    # Public methods
    # -------------------------------------------------------------------------

    def add_column(self, func=None, name=None, show=True, default=False):
        """Add a column to both the cluster view and the similarity view.

        Can be called directly with a function, or used as a decorator
        when `func` is omitted. The column name defaults to the name of
        the function. When `default` is true, the column also becomes the
        default sort of the cluster view.
        """
        if func is None:
            # Decorator form: defer the registration until the function is
            # known.
            def _register(f):
                return self.add_column(f, name=name, show=show,
                                       default=default)
            return _register
        name = name or func.__name__
        assert name
        for view in (self.cluster_view, self.similarity_view):
            view.add_column(func, name=name, show=show)
        if default:
            self.set_default_sort(name)

    def set_default_sort(self, name, sort_dir='desc'):
        """Make column `name` the default sort of the cluster view.

        The cluster view rows are reloaded and the view is then sorted by
        that column, in the `sort_dir` direction.
        """
        assert name
        logger.debug("Set default sort `%s` %s.", name, sort_dir)
        view = self.cluster_view
        # Register the default sort on the view.
        view.set_default_sort(name, sort_dir)
        # Reload the rows of the cluster view.
        self._update_cluster_view()
        # Apply the sort right away.
        view.sort_by(name, sort_dir)

    def on_cluster(self, up):
        """Update the cluster views after clustering actions.

        `up` is the update object describing the action. This method reads
        its `added`, `history`, `undo_state`, `description` and
        `metadata_changed` attributes.

        The first matching case applies:

        * the action is an undo: the selection stored when the undone
          action occurred is restored in both views;
        * clusters were added: the new clusters are selected;
        * metadata changed: the views move on to their next cluster.

        """

        # Selection of the similarity view *before* any view is updated.
        similar = self.similarity_view.selected

        # Reinitialize the cluster view if clusters have changed.
        if up.added:
            self._update_cluster_view()

        # Select all new clusters in view 1.
        if up.history == 'undo':
            # Select the clusters that were selected before the undone
            # action.
            # The 'selection' entry is the pair (cluster view selection,
            # similarity view selection) returned by the
            # `on_request_undo_state` handler.
            clusters_0, clusters_1 = up.undo_state[0]['selection']
            self.cluster_view.select(clusters_0)
            self.similarity_view.select(clusters_1)
        elif up.added:
            if up.description == 'assign':
                # NOTE: we reverse the order such that the last selected
                # cluster (with a new color) is the split cluster.
                added = up.added[::-1]
            else:
                added = up.added
            # Select the new clusters through `select()`.
            self.select(added)
            # If something was selected in the similarity view, move on to
            # the next similar cluster.
            if similar:
                self.similarity_view.next()
        elif up.metadata_changed:
            # Select next in similarity view if all moved are in that view.
            if set(up.metadata_changed) <= set(similar):

                # Update the similarity view, and select the clusters that
                # were selected before the action.
                selected = self.similarity_view.selected
                self._update_similarity_view()
                # Restore the selection without emitting, then move to the
                # next row.
                self.similarity_view.select(selected, do_emit=False)
                self.similarity_view.next()
            # Otherwise, select next in cluster view.
            else:
                # Update the cluster view, and select the clusters that
                # were selected before the action.
                selected = self.cluster_view.selected
                self._update_cluster_view()
                # Restore the selection without emitting, then move to the
                # next row.
                self.cluster_view.select(selected, do_emit=False)
                self.cluster_view.next()
                if similar:
                    self.similarity_view.next()

    def attach(self, gui):
        """Attach this object to a GUI and return `self`.

        Creates the actions, adds the cluster view (and the similarity view
        when a similarity function is set), restores the cluster view state
        from the GUI state, saves it back when the GUI closes, and connects
        `on_cluster` to the GUI.
        """
        self.gui = gui

        # Create the actions.
        self._create_actions(gui)

        # Add the cluster views.
        gui.add_view(self.cluster_view, name='ClusterView')

        # Add the quality column in the cluster view.
        if self.quality:
            self.cluster_view.add_column(self.quality,
                                         name=self.quality.__name__,
                                         )

        # Populate the cluster view.
        self._update_cluster_view()
        # NOTE: no explicit initial sort is applied here; a former sort by
        # `n_spikes` (used when there is no quality function) is disabled.

        # Add the similarity view if there is a similarity function.
        if self.similarity:
            gui.add_view(self.similarity_view, name='SimilarityView')

        # Set the view state.
        cv = self.cluster_view
        cv.set_state(gui.state.get_view_state(cv))

        # Save the view state in the GUI state.
        @gui.connect_
        def on_close():
            gui.state.update_view_state(cv, cv.state)
            # NOTE: create_gui() already saves the state, but the event
            # is registered *before* we add all views.
            gui.state.save()

        # Update the cluster views and selection when a cluster event occurs.
        self.gui.connect_(self.on_cluster)
        return self

    # Selection actions
    # -------------------------------------------------------------------------

    def select(self, *cluster_ids):
        """Select a list of clusters."""
        # HACK: both `select(1, 2, 3)` and `select([1, 2, 3])` are accepted,
        # so that the snippet `:c 1 2 3` works as well as `:c 1,2,3`.
        # A leading list/tuple is flattened in front of the remaining ids.
        first = cluster_ids[0] if cluster_ids else None
        if isinstance(first, (tuple, list)):
            remaining = cluster_ids[1:]
            cluster_ids = list(first) + list(remaining)
        # Forward the selection to the cluster view.
        self.cluster_view.select(cluster_ids)

    @property
    def selected(self):
        """Cluster view selection followed by the similarity view selection."""
        best = self.cluster_view.selected
        similar = self.similarity_view.selected
        return best + similar

    # Clustering actions
    # -------------------------------------------------------------------------

    def merge(self, cluster_ids=None):
        """Merge the selected clusters."""
        # Default to the current selection of both views.
        to_merge = self.selected if cluster_ids is None else cluster_ids
        # Nothing to do unless there are at least two clusters.
        if len(to_merge or []) < 2:
            return
        clustering = self.clustering
        clustering.merge(to_merge)
        # Record the action so that it can be undone.
        self._global_history.action(clustering)

    def split(self, spike_ids=None, spike_clusters_rel=0):
        """Split the selected spikes.

        Parameters
        ----------
        spike_ids : array-like, optional
            Spikes to split. When omitted, they are requested from the GUI
            through the `request_split` event and the responses are
            concatenated into a single int64 array.
        spike_clusters_rel : optional
            Forwarded unchanged to `self.clustering.split`.

        If there are no spikes to split, a dialog explaining how to select
        spikes is shown and nothing else happens.
        """
        if spike_ids is None:
            # Presumably one array of spike ids per responding view; confirm
            # against the `request_split` handlers.
            responses = self.gui.emit('request_split')
            if not responses:
                # `np.concatenate` raises ValueError on an empty sequence, so
                # treat "no response" as an empty selection.
                spike_ids = np.array([], dtype=np.int64)
            else:
                spike_ids = np.concatenate(responses).astype(np.int64)
        if len(spike_ids) == 0:
            msg = ("You first need to select spikes in the feature "
                   "view with a few Ctrl+Click around the spikes "
                   "that you want to split.")
            _show_box(self.gui.dialog(msg))
            return
        self.clustering.split(spike_ids,
                              spike_clusters_rel=spike_clusters_rel)
        # Record the action so that it can be undone.
        self._global_history.action(self.clustering)

    # Move actions
    # -------------------------------------------------------------------------

    def move(self, cluster_ids, group):
        """Move clusters to a group."""
        # An empty list of clusters is a no-op.
        if len(cluster_ids) > 0:
            meta = self.cluster_meta
            meta.set('group', cluster_ids, group)
            # Record the action so that it can be undone.
            self._global_history.action(meta)

    def move_best(self, group=None):
        """Move the clusters selected in the cluster view to a group."""
        best = self.cluster_view.selected
        self.move(best, group)

    def move_similar(self, group=None):
        """Move the clusters selected in the similarity view to a group."""
        similar = self.similarity_view.selected
        self.move(similar, group)

    def move_all(self, group=None):
        """Move all selected clusters (from both views) to a group."""
        everything = self.selected
        self.move(everything, group)

    # Wizard actions
    # -------------------------------------------------------------------------

    def reset(self):
        """Reset the wizard."""
        # Rebuild the cluster view, then move to its next cluster.
        self._update_cluster_view()
        rebuilt_view = self.cluster_view
        rebuilt_view.next()

    def next_best(self):
        """Select the next best cluster."""
        best_view = self.cluster_view
        best_view.next()

    def previous_best(self):
        """Select the previous best cluster."""
        best_view = self.cluster_view
        best_view.previous()

    def next(self):
        """Select the next cluster."""
        # With nothing selected, start in the cluster view; otherwise
        # advance in the similarity view.
        view = self.similarity_view if self.selected else self.cluster_view
        view.next()

    def previous(self):
        """Select the previous cluster in the similarity view."""
        similar_view = self.similarity_view
        similar_view.previous()

    # Other actions
    # -------------------------------------------------------------------------

    def undo(self):
        """Undo the last action."""
        history = self._global_history
        history.undo()

    def redo(self):
        """Redo the last undone action."""
        history = self._global_history
        history.redo()

    def save(self):
        """Request that the manual clustering be saved.

        Emits `request_save` on the GUI with the spike clusters and a
        `{cluster_id: group}` mapping; clusters without a group are reported
        as 'unsorted'.
        """
        clustering = self.clustering
        spike_clusters = clustering.spike_clusters
        meta = self.cluster_meta
        groups = {}
        for cluster_id in clustering.cluster_ids:
            groups[cluster_id] = meta.get('group', cluster_id) or 'unsorted'
        self.gui.emit('request_save', spike_clusters, groups)
Esempio n. 8
0
class ActionCreator(object):
    """Companion class to the Supervisor that manages the related GUI actions.

    The two class-level dictionaries below are handed to every `Actions`
    menu created in `attach()`:

    * `default_shortcuts` maps an action name to a key sequence, or to a
      tuple of alternative key sequences.
    * `default_snippets` maps an action name to its snippet alias.
    """

    default_shortcuts = {
        # Clustering.
        'merge': 'g',
        'split': 'k',

        'label': 'l',

        # Move.
        'move_best_to_noise': 'alt+n',
        'move_best_to_mua': 'alt+m',
        'move_best_to_good': 'alt+g',
        'move_best_to_unsorted': 'alt+u',

        'move_similar_to_noise': 'ctrl+n',
        'move_similar_to_mua': 'ctrl+m',
        'move_similar_to_good': 'ctrl+g',
        'move_similar_to_unsorted': 'ctrl+u',

        'move_all_to_noise': 'ctrl+alt+n',
        'move_all_to_mua': 'ctrl+alt+m',
        'move_all_to_good': 'ctrl+alt+g',
        'move_all_to_unsorted': 'ctrl+alt+u',

        # Wizard.
        # NOTE: the wizard reset action is registered under the name
        # `reset_wizard`, so the shortcut must be keyed by that name. The
        # historical `reset` key is kept for backward compatibility.
        'reset': 'ctrl+alt+space',
        'reset_wizard': 'ctrl+alt+space',
        'next': 'space',
        'previous': 'shift+space',
        'unselect_similar': 'backspace',
        'next_best': 'down',
        'previous_best': 'up',

        # Misc.
        'undo': 'ctrl+z',
        'redo': ('ctrl+shift+z', 'ctrl+y'),
    }

    default_snippets = {
        'merge': 'g',
        'split': 'k',
        'label': 'l',
        'select': 'c',
        'filter': 'f',
        'sort': 's',
    }

    def __init__(self, supervisor=None):
        """Store the supervisor whose methods the created actions refer to.

        `supervisor` may be None; it is only looked up by attribute (method
        docstrings, `columns`, `cluster_labels`, `cluster_metrics`).
        """
        self.supervisor = supervisor

    def add(self, which, name, **kwargs):
        """Add an action to a given menu."""
        # This special keyword argument lets us use a different name for the
        # action and the event name/method (used for different move flavors).
        method_name = kwargs.pop('method_name', name)
        method_args = kwargs.pop('method_args', ())
        emit_fun = partial(emit, 'action', self, method_name, *method_args)
        f = getattr(self.supervisor, method_name, None)
        docstring = inspect.getdoc(f) if f else name
        if not kwargs.get('docstring', None):
            kwargs['docstring'] = docstring
        getattr(self, '%s_actions' % which).add(emit_fun, name=name, **kwargs)

    def attach(self, gui):
        """Attach the GUI and create the menus."""
        # Every menu shares the same default shortcuts and snippets.
        shared = {
            'default_shortcuts': self.default_shortcuts,
            'default_snippets': self.default_snippets,
        }
        self.edit_actions = Actions(gui, menu='&Edit', **shared)
        self.select_actions = Actions(gui, menu='Sele&ct', **shared)
        self.view_actions = Actions(gui, menu='&View', **shared)

        # Fill the menus, in menu order, from the GUI state.
        for populate in (
                self._create_edit_actions,
                self._create_select_actions,
                self._create_view_actions):
            populate(gui.state)

    def _create_edit_actions(self, state):
        w = 'edit'
        self.add(w, 'undo', set_busy=True)
        self.add(w, 'redo', set_busy=True)
        self.edit_actions.separator()

        # Clustering.
        self.add(w, 'merge', set_busy=True)
        self.add(w, 'split', set_busy=True)
        self.edit_actions.separator()

        # Move.
        self.add(w, 'move', prompt=True, n_args=2)
        for which in ('best', 'similar', 'all'):
            for group in ('noise', 'mua', 'good', 'unsorted'):
                self.add(
                    w, 'move_%s_to_%s' % (which, group),
                    method_name='move',
                    method_args=(group, which),
                    submenu='Move to %s' % which,
                    docstring='Move %s to %s.' % (which, group))
        self.edit_actions.separator()

        # Label.
        self.add(w, 'label', prompt=True, n_args=2)
        self.edit_actions.separator()

    def _create_select_actions(self, state):
        w = 'select'

        # Selection.
        self.add(w, 'select', prompt=True, n_args=1)
        self.add(w, 'unselect_similar')
        self.select_actions.separator()

        # Sort and filter
        self.add(w, 'filter', prompt=True, n_args=1)
        self.add(w, 'sort', prompt=True, n_args=1)

        # Sort by:
        for column in getattr(self.supervisor, 'columns', ()):
            self.add(
                w, 'sort_by_%s' % column.lower(), method_name='sort', method_args=(column,),
                docstring='Sort by %s' % column,
                submenu='Sort by', alias='s%s' % column.replace('_', '')[:2])

        self.select_actions.separator()

        self.add(w, 'reset_wizard')
        self.select_actions.separator()

        self.add(w, 'next')
        self.add(w, 'previous')
        self.select_actions.separator()

        self.add(w, 'next_best')
        self.add(w, 'previous_best')
        self.select_actions.separator()

    def _create_view_actions(self, state):
        """Fill the View menu: color field, colormap and colormap toggles.

        The color fields offered are 'cluster', 'group' and 'n_spikes',
        followed by the keys of the supervisor's `cluster_labels` and
        `cluster_metrics` (when it has them). The initial checked state of
        the two toggles is read from `state['color_selector']`.
        """
        w = 'view'
        # Both default to an empty dict when the supervisor lacks them.
        cluster_labels_keys = getattr(self.supervisor, 'cluster_labels', {}).keys()
        cluster_metrics_keys = getattr(self.supervisor, 'cluster_metrics', {}).keys()

        # Change color field action.
        # The alias is `cf` followed by the first two characters of the field
        # name with underscores removed.
        for field in chain(
                ('cluster', 'group', 'n_spikes'), cluster_labels_keys, cluster_metrics_keys):
            self.add(
                w, name='color_field_%s' % field.lower(),
                method_name='change_color_field',
                method_args=(field,),
                docstring='Change color field to %s' % field,
                alias='cf%s' % field.replace('_', '')[:2],
                submenu='Change color field')

        # Change color map action.
        # The alias is `cm` followed by the first two characters of the name.
        for colormap in ('categorical', 'cluster_group', 'diverging', 'linear', 'rainbow'):
            self.add(
                w, name='colormap_%s' % colormap.lower(),
                method_name='change_colormap',
                method_args=(colormap,),
                docstring='Change colormap to %s' % colormap,
                alias='cm%s' % colormap[:2],
                submenu='Change colormap')

        # Change colormap categorical or continuous.
        # Checked only when the stored value is exactly True.
        categorical = state.get('color_selector', Bunch()).get('categorical', None)
        self.add(w, 'toggle_categorical_colormap', checkable=True, checked=categorical is True)

        # Change colormap logarithmic.
        # Checked only when the stored value is exactly True.
        logarithmic = state.get('color_selector', Bunch()).get('logarithmic', None)
        self.add(w, 'toggle_logarithmic_colormap', checkable=True, checked=logarithmic is True)

        self.view_actions.separator()