Example #1
0
class ProbeView(ManualClusteringView):
    """This view displays the positions of all channels on the probe, highlighting channels
    where the selected clusters belong.

    Constructor
    -----------

    positions : array-like
        An `(n_channels, 2)` array with the channel positions
    best_channels : function
        Maps `cluster_id` to the list of the best_channel_ids.
    channel_labels : list
        List of channel label strings.
    dead_channels : list
        List of dead channel ids.

    """

    _default_position = 'right'

    # Marker size of channels without selected clusters.
    unselected_marker_size = 10

    # Marker size of channels with selected clusters.
    selected_marker_size = 15

    # Alpha value of the dead channels.
    dead_channel_alpha = .25

    do_show_labels = False

    def __init__(
            self, positions=None, best_channels=None, channel_labels=None,
            dead_channels=None, **kwargs):
        super(ProbeView, self).__init__(**kwargs)
        self.state_attrs += ('do_show_labels',)

        # Normalize positions.
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        positions = positions.astype(np.float32)
        self.positions, self.data_bounds = _get_pos_data_bounds(positions)

        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        self.channel_labels = channel_labels or [str(ch) for ch in range(self.n_channels)]
        self.dead_channels = dead_channels if dead_channels is not None else ()

        self.probe_visual = ScatterVisual()
        self.canvas.add_visual(self.probe_visual)

        # Probe visual.
        color = np.ones((self.n_channels, 4))
        color[:, :3] = .5
        # Change alpha value for dead channels.
        if len(self.dead_channels):
            color[self.dead_channels, 3] = self.dead_channel_alpha
        self.probe_visual.set_data(
            pos=self.positions, data_bounds=self.data_bounds,
            color=color, size=self.unselected_marker_size)

        # Cluster visual.
        self.cluster_visual = ScatterVisual()
        self.canvas.add_visual(self.cluster_visual)

        # Text visual
        color[:] = 1
        color[self.dead_channels, :3] = self.dead_channel_alpha * 2
        self.text_visual = TextVisual()
        self.text_visual.inserter.insert_vert('uniform float n_channels;', 'header')
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_channels >= 200 * u_zoom.y) && '
            '(mod(int(a_string_index), int(n_channels / (200 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)
        self.text_visual.set_data(
            pos=self.positions, text=self.channel_labels, anchor=[0, -1],
            data_bounds=self.data_bounds, color=color
        )
        self.text_visual.program['n_channels'] = self.n_channels
        self.canvas.update()

    def _get_clu_positions(self, cluster_ids):
        """Get the positions of the channels containing selected clusters."""

        # List of channels per cluster.
        cluster_channels = {i: self.best_channels(cl) for i, cl in enumerate(cluster_ids)}

        # List of clusters per channel.
        clusters_per_channel = defaultdict(lambda: [])
        for clu_idx, channels in cluster_channels.items():
            for channel in channels:
                clusters_per_channel[channel].append(clu_idx)

        # Enumerate the discs for each channel.
        w = self.data_bounds[2] - self.data_bounds[0]
        clu_pos = []
        clu_colors = []
        for channel_id, (x, y) in enumerate(self.positions):
            for i, clu_idx in enumerate(clusters_per_channel[channel_id]):
                n = len(clusters_per_channel[channel_id])
                # Translation.
                t = .025 * w * (i - .5 * (n - 1))
                x += t
                alpha = 1.0 if channel_id not in self.dead_channels else self.dead_channel_alpha
                clu_pos.append((x, y))
                clu_colors.append(selected_cluster_color(clu_idx, alpha=alpha))
        return np.array(clu_pos), np.array(clu_colors)

    def on_select(self, cluster_ids=(), **kwargs):
        """Update the view with the selected clusters."""
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        pos, colors = self._get_clu_positions(cluster_ids)
        self.cluster_visual.set_data(
            pos=pos, color=colors, size=self.selected_marker_size, data_bounds=self.data_bounds)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(ProbeView, self).attach(gui)
        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)

        if not self.do_show_labels:
            self.text_visual.hide()

    def toggle_show_labels(self, checked):
        """Toggle the display of the channel ids."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        self.text_visual._hidden = not checked
        self.canvas.update()
Example #2
0
class ClusterScatterView(MarkerSizeMixin, BaseColorView, BaseGlobalView,
                         ManualClusteringView):
    """This view shows all clusters in a customizable scatter plot.

    Constructor
    -----------

    cluster_ids : array-like
    cluster_info: function
        Maps cluster_id => Bunch() with attributes.
    bindings: dict
        Maps plot dimension to cluster attributes.

    """
    _default_position = 'right'
    _scaling = 1.
    _default_alpha = .75
    _min_marker_size = 5.0
    _max_marker_size = 30.0
    _dims = ('x_axis', 'y_axis', 'size')

    # NOTE: this is not the actual marker size, but a scaling factor for the normal marker size.
    _marker_size = 1.
    _default_marker_size = 1.

    x_axis = ''
    y_axis = ''
    size = ''
    x_axis_log_scale = False
    y_axis_log_scale = False
    size_log_scale = False

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'switch_color_scheme': 'shift+wheel',
        'select_cluster': 'click',
        'select_more': 'shift+click',
        'add_to_lasso': 'control+left click',
        'clear_lasso': 'control+right click',
    }

    default_snippets = {
        'set_x_axis': 'csx',
        'set_y_axis': 'csy',
        'set_size': 'css',
    }

    def __init__(self,
                 cluster_ids=None,
                 cluster_info=None,
                 bindings=None,
                 **kwargs):
        super(ClusterScatterView, self).__init__(**kwargs)
        self.state_attrs += (
            'scaling',
            'x_axis',
            'y_axis',
            'size',
            'x_axis_log_scale',
            'y_axis_log_scale',
            'size_log_scale',
        )
        self.local_state_attrs += ()

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        bindings = bindings or {}
        self.cluster_info = cluster_info
        # update self.x_axis, y_axis, size
        self.__dict__.update({(k, v)
                              for k, v in bindings.items() if k in self._dims})

        # Size range computed initially so that it doesn't change during the course of the session.
        self._size_min = self._size_max = None

        # Full list of clusters.
        self.all_cluster_ids = cluster_ids

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        self.label_visual = TextVisual()
        self.canvas.add_visual(self.label_visual,
                               exclude_origins=(self.canvas.panzoom, ))

        self.marker_positions = self.marker_colors = self.marker_sizes = None

    def _update_labels(self):
        self.label_visual.set_data(pos=[[-1, -1], [1, 1]],
                                   text=[self.x_axis, self.y_axis],
                                   anchor=[[1.25, 3], [-3, -1.25]])

    # Data access
    # -------------------------------------------------------------------------

    @property
    def bindings(self):
        return {k: getattr(self, k) for k in self._dims}

    def get_cluster_data(self, cluster_id):
        """Return the data of one cluster."""
        data = self.cluster_info(cluster_id)
        return {k: data.get(v, 0.) for k, v in self.bindings.items()}

    def get_clusters_data(self, cluster_ids):
        """Return the data of a set of clusters, as a dictionary {cluster_id: Bunch}."""
        return {
            cluster_id: self.get_cluster_data(cluster_id)
            for cluster_id in cluster_ids
        }

    def set_cluster_ids(self, all_cluster_ids):
        """Update the cluster data by specifying the list of all cluster ids."""
        self.all_cluster_ids = all_cluster_ids
        if len(all_cluster_ids) == 0:
            return
        self.prepare_position()
        self.prepare_size()
        self.prepare_color()

    # Data preparation
    # -------------------------------------------------------------------------

    def set_fields(self):
        data = self.cluster_info(self.all_cluster_ids[0])
        self.fields = sorted(data.keys())
        self.fields = [f for f in self.fields if not isinstance(data[f], str)]

    def prepare_data(self):
        """Prepare the marker position, size, and color from the cluster information."""
        self.prepare_position()
        self.prepare_size()
        self.prepare_color()

    def prepare_position(self):
        """Compute the marker positions."""
        self.cluster_data = self.get_clusters_data(self.all_cluster_ids)

        # Get the list of fields returned by cluster_info.
        self.set_fields()

        # Create the x array.
        x = np.array([
            self.cluster_data[cluster_id]['x_axis'] or 0.
            for cluster_id in self.all_cluster_ids
        ])
        if self.x_axis_log_scale:
            x = np.log(1.0 + x - x.min())

        # Create the y array.
        y = np.array([
            self.cluster_data[cluster_id]['y_axis'] or 0.
            for cluster_id in self.all_cluster_ids
        ])
        if self.y_axis_log_scale:
            y = np.log(1.0 + y - y.min())

        self.marker_positions = np.c_[x, y]

        # Update the data bounds.
        self.data_bounds = (x.min(), y.min(), x.max(), y.max())

    def prepare_size(self):
        """Compute the marker sizes."""
        size = np.array([
            self.cluster_data[cluster_id]['size'] or 1.
            for cluster_id in self.all_cluster_ids
        ])
        # Log scale for the size.
        if self.size_log_scale:
            size = np.log(1.0 + size - size.min())
        # Find the size range.
        if self._size_min is None:
            self._size_min, self._size_max = size.min(), size.max()
        m, M = self._size_min, self._size_max
        # Normalize the marker size.
        size = (size - m) / ((M - m) or 1.0)  # size is in [0, 1]
        ms, Ms = self._min_marker_size, self._max_marker_size
        size = ms + size * (Ms - ms)  # now, size is in [ms, Ms]
        self.marker_sizes = size

    def prepare_color(self):
        """Compute the marker colors."""
        colors = self.get_cluster_colors(self.all_cluster_ids,
                                         self._default_alpha)
        self.marker_colors = colors

    # Marker size
    # -------------------------------------------------------------------------

    @property
    def marker_size(self):
        """Size of the spike markers, in pixels."""
        return self._marker_size

    @marker_size.setter
    def marker_size(self, val):
        # We override this method so as to use self._marker_size as a scaling factor, not
        # as an actual fixed marker size.
        self._marker_size = val
        self._set_marker_size()
        self.canvas.update()

    def _set_marker_size(self):
        if self.marker_sizes is not None:
            self.visual.set_marker_size(self.marker_sizes * self._marker_size)

    # Plotting functions
    # -------------------------------------------------------------------------

    def update_color(self):
        """Update the cluster colors depending on the current color scheme."""
        self.prepare_color()
        self.visual.set_color(self.marker_colors)
        self.canvas.update()

    def update_select_color(self):
        """Update the cluster colors after the cluster selection changes."""
        if self.marker_colors is None:
            return
        selected_clusters = self.cluster_ids
        if selected_clusters is not None and len(selected_clusters) > 0:
            colors = _add_selected_clusters_colors(selected_clusters,
                                                   self.all_cluster_ids,
                                                   self.marker_colors.copy())
            self.visual.set_color(colors)
            self.canvas.update()

    def plot(self, **kwargs):
        """Make the scatter plot."""
        if self.marker_positions is None:
            self.prepare_data()
        self.visual.set_data(
            pos=self.marker_positions,
            color=self.marker_colors,
            size=self.marker_sizes *
            self._marker_size,  # marker size scaling factor
            data_bounds=self.data_bounds)
        self.canvas.axes.reset_data_bounds(self.data_bounds)
        self.canvas.update()

    def change_bindings(self, **kwargs):
        """Change the bindings."""
        # Ensure the specified fields are valid.
        kwargs = {k: v for k, v in kwargs.items() if v in self.fields}
        assert set(kwargs.keys()) <= set(self._dims)
        # Reset the size scaling.
        if 'size' in kwargs:
            self._size_min = self._size_max = None
        self.__dict__.update(kwargs)
        self._update_labels()
        self.update_status()
        self.prepare_data()
        self.plot()

    def toggle_log_scale(self, dim, checked):
        """Toggle logarithmic scaling for one of the dimensions."""
        self._size_min = None
        setattr(self, '%s_log_scale' % dim, checked)
        self.prepare_data()
        self.plot()
        self.canvas.update()

    def set_x_axis(self, field):
        """Set the dimension for the x axis."""
        self.change_bindings(x_axis=field)

    def set_y_axis(self, field):
        """Set the dimension for the y axis."""
        self.change_bindings(y_axis=field)

    def set_size(self, field):
        """Set the dimension for the marker size."""
        self.change_bindings(size=field)

    # Misc functions
    # -------------------------------------------------------------------------

    def attach(self, gui):
        """Attach the GUI."""
        super(ClusterScatterView, self).attach(gui)

        def _make_action(dim, name):
            def callback():
                self.change_bindings(**{dim: name})

            return callback

        def _make_log_toggle(dim):
            def callback(checked):
                self.toggle_log_scale(dim, checked)

            return callback

        # Change the bindings.
        for dim in self._dims:
            view_submenu = 'Change %s' % dim

            # Change to every cluster info.
            for name in self.fields:
                self.actions.add(_make_action(dim, name),
                                 show_shortcut=False,
                                 name='Change %s to %s' % (dim, name),
                                 view_submenu=view_submenu)

            # Toggle logarithmic scale.
            self.actions.separator(view_submenu=view_submenu)
            self.actions.add(_make_log_toggle(dim),
                             checkable=True,
                             view_submenu=view_submenu,
                             name='Toggle log scale for %s' % dim,
                             show_shortcut=False,
                             checked=getattr(self, '%s_log_scale' % dim))

        self.actions.separator()
        self.actions.add(self.set_x_axis,
                         prompt=True,
                         prompt_default=lambda: self.x_axis)
        self.actions.add(self.set_y_axis,
                         prompt=True,
                         prompt_default=lambda: self.y_axis)
        self.actions.add(self.set_size,
                         prompt=True,
                         prompt_default=lambda: self.size)

        connect(self.on_select)
        connect(self.on_cluster)

        @connect(sender=self.canvas)
        def on_lasso_updated(sender, polygon):
            if len(polygon) < 3:
                return
            pos = range_transform([self.data_bounds], [NDC],
                                  self.marker_positions)
            ind = self.canvas.lasso.in_polygon(pos)
            cluster_ids = self.all_cluster_ids[ind]
            emit("request_select", self, list(cluster_ids))

        @connect(sender=self)
        def on_close_view(view_, gui):
            """Unconnect all events when closing the view."""
            unconnect(self.on_select)
            unconnect(self.on_cluster)
            unconnect(on_lasso_updated)

        if self.all_cluster_ids is not None:
            self.set_cluster_ids(self.all_cluster_ids)
        self._update_labels()

    def on_select(self, *args, **kwargs):
        super(ClusterScatterView, self).on_select(*args, **kwargs)
        self.update_select_color()

    def on_cluster(self, sender, up):
        if 'all_cluster_ids' in up:
            self.set_cluster_ids(up.all_cluster_ids)
            self.plot()

    @property
    def status(self):
        return 'Size: %s. Color scheme: %s.' % (self.size, self.color_scheme)

    # Interactivity
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by clicking on its template waveform."""
        if 'Control' in e.modifiers:
            return
        b = e.button
        pos = self.canvas.window_to_ndc(e.pos)
        marker_pos = range_transform([self.data_bounds], [NDC],
                                     self.marker_positions)
        cluster_rel = np.argmin(((marker_pos - pos)**2).sum(axis=1))
        cluster_id = self.all_cluster_ids[cluster_rel]
        logger.debug("Click on cluster %d with button %s.", cluster_id, b)
        if 'Shift' in e.modifiers:
            emit('select_more', self, [cluster_id])
        else:
            emit('request_select', self, [cluster_id])