コード例 #1
0
ファイル: raster.py プロジェクト: zsong30/phy
    def __init__(self, spike_times, spike_clusters, cluster_ids=None, **kwargs):
        """Create the raster view.

        Parameters
        ----------
        spike_times : array-like
            Spike times; presumably sorted, since the last element is used as the maximum.
        spike_clusters : array-like
            Spike-cluster assignments, one per spike.
        cluster_ids : array-like, optional
            Clusters to show initially, forwarded to `set_cluster_ids()`.
        **kwargs
            Forwarded to the parent view constructor.

        """
        self.spike_times = spike_times
        self.n_spikes = len(spike_times)
        # The time axis extends 1% past the last spike. Without any spike,
        # `spike_times[-1]` would raise IndexError, so fall back to a unit duration.
        self.duration = spike_times[-1] * 1.01 if self.n_spikes else 1.
        self.n_clusters = 1

        assert len(spike_clusters) == self.n_spikes
        self.set_spike_clusters(spike_clusters)
        self.set_cluster_ids(cluster_ids)

        super(RasterView, self).__init__(**kwargs)

        # One stacked row per cluster, starting from the top.
        self.canvas.set_layout('stacked', origin='top', n_plots=self.n_clusters, has_clip=False)
        self.canvas.enable_axes()

        # Vertical-bar markers whose size follows the vertical zoom (u_zoom.y). Float
        # literals are used in clamp() because GLSL versions without implicit
        # int-to-float conversion reject clamp(float, int, int).
        self.visual = ScatterVisual(
            marker='vbar',
            marker_scaling='''
                point_size = v_size * u_zoom.y + 5.;
                float width = 0.2;
                float height = 0.5;
                vec2 marker_size = point_size * vec2(width, height);
                marker_size.x = clamp(marker_size.x, 1.0, 20.0);
            ''',
        )
        self.visual.inserter.insert_vert('''
                gl_PointSize = a_size * u_zoom.y + 5.0;
        ''', 'end')
        self.canvas.add_visual(self.visual)
        # Presumably limits pan/zoom to this rectangle (NDC units) — confirm in panzoom.
        self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))
コード例 #2
0
ファイル: amplitude.py プロジェクト: yingluo227/Labtools
    def __init__(self, amplitudes=None, amplitude_name=None, duration=None):
        """Create the amplitude view.

        Parameters
        ----------
        amplitudes : function or dict
            A single amplitude function, or a dict mapping amplitude type names to
            functions. Must be provided.
        amplitude_name : str, optional
            Initial amplitude type; defaults to the first key of `amplitudes`.
        duration : float, optional
            Upper bound of the time axis; defaults to 1 when falsy.

        """
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitude_name',)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()
        # Validate before wrapping: once wrapped in a dict, `amplitudes` is always truthy
        # and the check would silently accept amplitudes=None.
        assert amplitudes
        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        self.amplitudes = amplitudes
        self.amplitude_names = list(amplitudes.keys())
        # Current amplitude type.
        self.amplitude_name = amplitude_name or self.amplitude_names[0]
        assert self.amplitude_name in amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1

        # Histogram visual, range-mapped then rotated on the GPU.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add_on_gpu([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)), Rotate('ccw')])
        self.canvas.add_visual(self.hist_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # Amplitude name, excluded from the panzoom transform.
        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))
コード例 #3
0
ファイル: feature.py プロジェクト: zsong30/phy
    def __init__(self, features=None, attributes=None, **kwargs):
        """Build the feature view.

        Parameters
        ----------
        features : function
            Returns the feature Bunch of a cluster; must be provided.
        attributes : dict, optional
            Extra dimensions, `{name: array}` with one `(n_spikes,)` array per name.
        **kwargs
            Forwarded to the parent view constructor.

        """
        super(FeatureView, self).__init__(**kwargs)
        # These two attributes are saved in the view state.
        self.state_attrs += ('fixed_channels', 'feature_scaling')

        assert features
        self.features = features
        # Symmetric half-range of the PC axes, recomputed when plotting.
        self._lim = 1

        # Subplot layout: 2D array of strings such as `0A,1B`.
        grid = _get_default_grid()
        self.grid_dim = grid
        self.n_rows, self.n_cols = np.shape(grid)
        canvas = self.canvas
        canvas.set_layout('grid', shape=(self.n_rows, self.n_cols))
        canvas.enable_lasso()

        # No channel is displayed until the first plot.
        self.channel_ids = None

        # Extra dimensions, empty by default.
        self.attributes = attributes if attributes else {}

        scatter = ScatterVisual()
        self.visual = scatter
        canvas.add_visual(scatter)

        labels = TextVisual()
        self.text_visual = labels
        canvas.add_visual(labels)

        axes = LineVisual()
        self.line_visual = axes
        canvas.add_visual(axes)
コード例 #4
0
ファイル: scatter.py プロジェクト: yingluo227/Labtools
    def __init__(self, coords=None, **kwargs):
        """Build the scatter view.

        Parameters
        ----------
        coords :
            Coordinate source stored in `self.coords`; must be provided (truthy).
        **kwargs
            Forwarded to the parent view constructor.

        """
        super(ScatterView, self).__init__(**kwargs)

        canvas = self.canvas
        canvas.enable_axes()
        canvas.enable_lasso()

        assert coords
        self.coords = coords

        scatter = ScatterVisual()
        self.visual = scatter
        canvas.add_visual(scatter)
コード例 #5
0
    def __init__(
            self, positions=None, best_channels=None, channel_labels=None,
            dead_channels=None, **kwargs):
        """Create the probe view.

        Parameters
        ----------
        positions : array
            An `(n_channels, 2)` array with the channel positions.
        best_channels : optional
            Stored as-is in `self.best_channels`.
        channel_labels : sequence of str, optional
            One label per channel; defaults to the channel indices as strings.
        dead_channels : sequence of int, optional
            Indices of the channels drawn with `dead_channel_alpha`.
        **kwargs
            Forwarded to the parent view constructor.

        """
        super(ProbeView, self).__init__(**kwargs)
        self.state_attrs += ('do_show_labels',)

        # Normalize positions.
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        positions = positions.astype(np.float32)
        self.positions, self.data_bounds = _get_pos_data_bounds(positions)

        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        # Explicit None/length test instead of `channel_labels or ...`, which raises
        # ValueError when channel_labels is a NumPy array with several elements.
        if channel_labels is None or not len(channel_labels):
            channel_labels = [str(ch) for ch in range(self.n_channels)]
        self.channel_labels = channel_labels
        self.dead_channels = dead_channels if dead_channels is not None else ()

        self.probe_visual = ScatterVisual()
        self.canvas.add_visual(self.probe_visual)

        # Probe visual: gray opaque markers, one per channel.
        color = np.ones((self.n_channels, 4))
        color[:, :3] = .5
        # Change alpha value for dead channels.
        if len(self.dead_channels):
            color[self.dead_channels, 3] = self.dead_channel_alpha
        self.probe_visual.set_data(
            pos=self.positions, data_bounds=self.data_bounds,
            color=color, size=self.unselected_marker_size)

        # Cluster visual.
        self.cluster_visual = ScatterVisual()
        self.canvas.add_visual(self.cluster_visual)

        # Text visual: white labels; dead channels get RGB = 2 * dead_channel_alpha.
        color[:] = 1
        if len(self.dead_channels):
            color[self.dead_channels, :3] = self.dead_channel_alpha * 2
        self.text_visual = TextVisual()
        self.text_visual.inserter.insert_vert('uniform float n_channels;', 'header')
        # When n_channels >= 200 * zoom, only every int(n_channels / (200 * zoom))-th
        # label is kept; the others are discarded in the fragment shader.
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_channels >= 200 * u_zoom.y) && '
            '(mod(int(a_string_index), int(n_channels / (200 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)
        self.text_visual.set_data(
            pos=self.positions, text=self.channel_labels, anchor=[0, -1],
            data_bounds=self.data_bounds, color=color
        )
        self.text_visual.program['n_channels'] = self.n_channels
        self.canvas.update()
コード例 #6
0
    def __init__(
            self, spike_times, spike_clusters, cluster_ids=None, cluster_color_selector=None,
            **kwargs):
        """Create the raster view.

        Parameters
        ----------
        spike_times : array-like
            Spike times; presumably sorted, since the last element is used as the maximum.
        spike_clusters : array-like
            Spike-cluster assignments, one per spike.
        cluster_ids : array-like, optional
            Clusters to show initially, forwarded to `set_cluster_ids()`.
        cluster_color_selector : optional
            Object managing the cluster colors, stored as-is.
        **kwargs
            Forwarded to the parent view constructor.

        """
        self.spike_times = spike_times
        self.n_spikes = len(spike_times)
        # The time axis extends 1% past the last spike. Without any spike,
        # `spike_times[-1]` would raise IndexError, so fall back to a unit duration.
        self.duration = spike_times[-1] * 1.01 if self.n_spikes else 1.
        self.n_clusters = 1

        assert len(spike_clusters) == self.n_spikes
        self.set_spike_clusters(spike_clusters)
        self.set_cluster_ids(cluster_ids)
        self.cluster_color_selector = cluster_color_selector

        super(RasterView, self).__init__(**kwargs)

        # One stacked row per cluster, starting from the top.
        self.canvas.set_layout('stacked', origin='top', n_plots=self.n_clusters, has_clip=False)
        self.canvas.enable_axes()

        self.visual = ScatterVisual(marker='vbar')
        self.canvas.add_visual(self.visual)
コード例 #7
0
    def __init__(self,
                 cluster_ids=None,
                 cluster_info=None,
                 bindings=None,
                 **kwargs):
        """Create the cluster scatter view.

        Parameters
        ----------
        cluster_ids : array-like, optional
            Full list of clusters, stored in `self.all_cluster_ids`.
        cluster_info : optional
            Stored as-is in `self.cluster_info`.
        bindings : dict, optional
            Initial values for the attributes named in `self._dims` (such as `x_axis`,
            `y_axis`, `size`); keys not in `self._dims` are ignored.
        **kwargs
            Forwarded to the parent view constructor.

        """
        super(ClusterScatterView, self).__init__(**kwargs)
        self.state_attrs += (
            'scaling',
            'x_axis',
            'y_axis',
            'size',
            'x_axis_log_scale',
            'y_axis_log_scale',
            'size_log_scale',
        )
        # No additional local state attribute.
        self.local_state_attrs += ()

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        bindings = bindings or {}
        self.cluster_info = cluster_info
        # Update self.x_axis, y_axis, size from the bindings. A dict comprehension (rather
        # than a set of (key, value) pairs) accepts unhashable values such as lists.
        self.__dict__.update({k: v for k, v in bindings.items() if k in self._dims})

        # Size range computed initially so that it doesn't change during the course of the session.
        self._size_min = self._size_max = None

        # Full list of clusters.
        self.all_cluster_ids = cluster_ids

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # Labels, excluded from the panzoom transform.
        self.label_visual = TextVisual()
        self.canvas.add_visual(self.label_visual,
                               exclude_origins=(self.canvas.panzoom, ))

        self.marker_positions = self.marker_colors = self.marker_sizes = None
コード例 #8
0
ファイル: probe.py プロジェクト: yingluo227/Labtools
    def __init__(self, positions=None, best_channels=None, **kwargs):
        """Build the probe view.

        Parameters
        ----------
        positions : array
            An `(n_channels, 2)` array with the channel positions.
        best_channels : optional
            Stored as-is in `self.best_channels`.
        **kwargs
            Forwarded to the parent view constructor.

        """
        super(ProbeView, self).__init__(**kwargs)

        # Positions must be a 2-column, 2D array; they are normalized with their bounds.
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        normalized, bounds = _get_pos_data_bounds(positions)
        self.positions = normalized
        self.data_bounds = bounds

        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        # Static layer: one gray opaque marker per channel.
        probe = ScatterVisual()
        self.probe_visual = probe
        self.canvas.add_visual(probe)
        probe.set_data(
            pos=self.positions,
            data_bounds=self.data_bounds,
            color=(.5, .5, .5, 1.),
            size=self.unselected_marker_size,
        )

        # Second scatter layer, filled elsewhere (presumably with the clusters' channels).
        clusters = ScatterVisual()
        self.cluster_visual = clusters
        self.canvas.add_visual(clusters)
コード例 #9
0
class RasterView(MarkerSizeMixin, BaseGlobalView, ManualClusteringView):
    """This view shows a raster plot of all clusters.

    Constructor
    -----------

    spike_times : array-like
        An `(n_spikes,)` array with the spike times, in seconds.
    spike_clusters : array-like
        An `(n_spikes,)` array with the spike-cluster assignments.
    cluster_ids : array-like
        The list of all clusters to show initially. Nothing is plotted until a non-empty
        list is given here or to `set_cluster_ids()`.
    cluster_color_selector : ClusterColorSelector
        The object managing the color mapping.

    """

    _default_position = 'right'

    default_shortcuts = {
        'change_marker_size': 'ctrl+wheel',
        'decrease': 'ctrl+shift+-',
        'increase': 'ctrl+shift++',
        'select_cluster': 'ctrl+click',
    }

    def __init__(
            self, spike_times, spike_clusters, cluster_ids=None, cluster_color_selector=None,
            **kwargs):
        self.spike_times = spike_times
        self.n_spikes = len(spike_times)
        # The time axis extends 1% past the last spike (presumably spike times are
        # sorted). Without any spike, `spike_times[-1]` would raise IndexError, so fall
        # back to a unit duration.
        self.duration = spike_times[-1] * 1.01 if self.n_spikes else 1.
        self.n_clusters = 1
        # Defaults until set_cluster_ids() receives clusters: no cluster shown and no
        # spike selected, so plot() and mouse clicks do not hit undefined attributes.
        self.all_cluster_ids = ()
        self.spike_ids = np.zeros(self.n_spikes, dtype=bool)

        assert len(spike_clusters) == self.n_spikes
        self.set_spike_clusters(spike_clusters)
        self.set_cluster_ids(cluster_ids)
        self.cluster_color_selector = cluster_color_selector

        super(RasterView, self).__init__(**kwargs)

        # One stacked row per cluster, starting from the top.
        self.canvas.set_layout('stacked', origin='top', n_plots=self.n_clusters, has_clip=False)
        self.canvas.enable_axes()

        self.visual = ScatterVisual(marker='vbar')
        self.canvas.add_visual(self.visual)

    # Data-related functions
    # -------------------------------------------------------------------------

    def set_spike_clusters(self, spike_clusters):
        """Set the spike clusters for all spikes."""
        self.spike_clusters = spike_clusters

    def set_cluster_ids(self, cluster_ids):
        """Set the shown clusters, which can be filtered and in any order (from top to bottom).

        A None or empty `cluster_ids` is ignored and leaves the current clusters unchanged.
        """
        if cluster_ids is None or not len(cluster_ids):
            return
        self.all_cluster_ids = cluster_ids
        self.n_clusters = len(self.all_cluster_ids)
        # Boolean mask of the spikes that belong to the selected clusters.
        self.spike_ids = np.isin(self.spike_clusters, self.all_cluster_ids)

    # Internal plotting functions
    # -------------------------------------------------------------------------

    def _get_x(self):
        """Return the x position (spike time) of the selected spikes."""
        return self.spike_times[self.spike_ids]

    def _get_y(self):
        """Return the y position of the selected spikes: 0 within their cluster's row."""
        return np.zeros(np.sum(self.spike_ids))

    def _get_box_index(self):
        """Return, for every selected spike, its row in the raster plot. This depends on
        the ordering in self.all_cluster_ids."""
        cl = self.spike_clusters[self.spike_ids]
        return _index_of(cl, self.all_cluster_ids)

    def _get_color(self, box_index, selected_clusters=None):
        """Return, for every spike, its color, based on its box index."""
        cluster_colors = self.cluster_color_selector.get_colors(self.all_cluster_ids, alpha=.75)
        # Selected cluster colors.
        if selected_clusters is not None:
            cluster_colors = _add_selected_clusters_colors(
                selected_clusters, self.all_cluster_ids, cluster_colors)
        return cluster_colors[box_index, :]

    # Main methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self):
        """Bounds of the raster plot view."""
        return (0, 0, self.duration, self.n_clusters)

    def update_cluster_sort(self, cluster_ids):
        """Update the order of all clusters."""
        self.all_cluster_ids = cluster_ids
        self.visual.set_box_index(self._get_box_index())
        self.canvas.update()

    def update_color(self, selected_clusters=None):
        """Update the color of the spikes, depending on the selected clusters."""
        box_index = self._get_box_index()
        color = self._get_color(box_index, selected_clusters=selected_clusters)
        self.visual.set_color(color)
        self.canvas.update()

    def plot(self, **kwargs):
        """Make the raster plot."""
        # Nothing to draw without spikes, or before any cluster has been set.
        if not len(self.spike_clusters) or not len(self.all_cluster_ids):
            return
        x = self._get_x()  # spike times for the selected spikes
        y = self._get_y()  # just 0
        box_index = self._get_box_index()
        color = self._get_color(box_index)
        assert x.shape == y.shape == box_index.shape
        assert color.shape[0] == len(box_index)
        self.data_bounds = self._get_data_bounds()

        self.visual.set_data(
            x=x, y=y, color=color, size=self.marker_size,
            data_bounds=(0, -1, self.duration, 1))
        self.visual.set_box_index(box_index)
        self.canvas.stacked.n_boxes = self.n_clusters
        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(RasterView, self).attach(gui)

        self.actions.add(self.increase)
        self.actions.add(self.decrease)
        self.actions.separator()

    def on_mouse_click(self, e):
        """Select a cluster by control-clicking in the raster plot."""
        b = e.button
        if 'Control' in e.modifiers:
            # Map the mouse position to the index of the stacked row.
            cluster_idx, _ = self.canvas.stacked.box_map(e.pos)
            # Ignore clicks that do not map to a displayed cluster (e.g. none set yet).
            if not 0 <= cluster_idx < len(self.all_cluster_ids):
                return
            cluster_id = self.all_cluster_ids[cluster_idx]
            logger.debug("Click on cluster %d with button %s.", cluster_id, b)
            emit('cluster_click', self, cluster_id, button=b)
コード例 #10
0
ファイル: amplitude.py プロジェクト: yingluo227/Labtools
class AmplitudeView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays an amplitude plot for all selected clusters.

    Constructor
    -----------

    amplitudes : function or dict
        Maps `cluster_ids` to a list `[Bunch(amplitudes, spike_ids, spike_times), ...]` for
        each cluster. Use `cluster_id=None` for background amplitudes. A dict maps amplitude
        type names to such functions; a single function is stored under `'amplitude'`.
    amplitude_name : str
        The amplitude type shown initially (defaults to the first one).
    duration : float
        Upper bound of the time axis (defaults to 1).

    """

    _default_position = 'right'

    # Alpha channel of the markers in the scatter plot.
    marker_alpha = 1.

    # Number of bins in the histogram.
    n_bins = 100

    # Alpha channel of the histogram in the background.
    histogram_alpha = .5

    # Quantile used for scaling of the amplitudes (less than 1 to avoid outliers).
    quantile = .99

    # Size of the histogram, between 0 and 1.
    histogram_scale = .25

    default_shortcuts = {
        'change_marker_size': 'ctrl+wheel',
        'next_amplitude_type': 'a',
        'previous_amplitude_type': 'shift+a',
        'select_x_dim': 'alt+left click',
        'select_y_dim': 'alt+right click',
    }

    def __init__(self, amplitudes=None, amplitude_name=None, duration=None):
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitude_name',)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()
        # Validate before wrapping: once wrapped in a dict, `amplitudes` is always truthy
        # and the check would silently accept amplitudes=None.
        assert amplitudes
        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        self.amplitudes = amplitudes
        self.amplitude_names = list(amplitudes.keys())
        # Current amplitude type.
        self.amplitude_name = amplitude_name or self.amplitude_names[0]
        assert self.amplitude_name in amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1

        # Histogram visual, range-mapped then rotated on the GPU.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add_on_gpu([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)), Rotate('ccw')])
        self.canvas.add_visual(self.hist_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # Amplitude name, excluded from the panzoom transform.
        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))

    def _get_data_bounds(self, bunchs):
        """Compute the data bounds: full duration in x, amplitude quantiles in y.

        The lower y bound is clipped so that it is never above 0.
        """
        if not bunchs:  # pragma: no cover
            return (0, 0, self.duration, 1)
        m = min(np.quantile(bunch.amplitudes, 1 - self.quantile) for bunch in bunchs)
        m = min(0, m)  # ensure ymin <= 0
        M = max(np.quantile(bunch.amplitudes, self.quantile) for bunch in bunchs)
        return (0, m, self.duration, M)

    def _add_histograms(self, bunchs):
        """Attach an amplitude histogram to every bunch, binned over the y data bounds."""
        # We do this after get_clusters_data because we need x_max.
        for bunch in bunchs:
            bunch.histogram = _compute_histogram(
                bunch.amplitudes,
                x_min=self.data_bounds[1],
                x_max=self.data_bounds[3],
                n_bins=self.n_bins,
                normalize=False,
                ignore_zeros=True,
            )
        return bunchs

    def _plot_cluster(self, bunch):
        """Add one cluster's background histogram and scatter points to the batches."""
        ms = self._marker_size

        # Histogram in the background.
        self.hist_visual.add_batch_data(
            hist=bunch.histogram,
            ylim=self._ylim,
            color=add_alpha(bunch.color, self.histogram_alpha))

        # Scatter plot.
        self.visual.add_batch_data(
            pos=bunch.pos, color=bunch.color, size=ms, data_bounds=self.data_bounds)

    def _plot_amplitude_name(self):
        """Show the amplitude name."""
        self.text_visual.add_batch_data(pos=[0, 1], anchor=[0, -1], text=self.amplitude_name)

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids.

        Returns None when no cluster is selected. Unless `load_all` is true, a background
        bunch (cluster_id None) is requested first.
        """
        if not len(self.cluster_ids):
            return
        cluster_ids = list(self.cluster_ids)
        # Don't need the background when splitting.
        if not load_all:
            # Add None cluster which means background spikes.
            cluster_ids = [None] + cluster_ids
        bunchs = self.amplitudes[self.amplitude_name](cluster_ids, load_all=load_all) or ()
        # Position of the first selected cluster in cluster_ids: 1 when the background was
        # prepended, 0 otherwise. Using a fixed `i - 1` gave the first cluster color
        # index -1 when load_all was true.
        offset = 0 if load_all else 1
        # Add a pos attribute in bunchs in addition to x and y.
        for i, (cluster_id, bunch) in enumerate(zip(cluster_ids, bunchs)):
            spike_ids = _as_array(bunch.spike_ids)
            spike_times = _as_array(bunch.spike_times)
            amplitudes = _as_array(bunch.amplitudes)
            assert spike_ids.shape == spike_times.shape == amplitudes.shape
            # Ensure that bunch.pos exists, as it used by the LassoMixin.
            bunch.pos = np.c_[spike_times, amplitudes]
            assert bunch.pos.ndim == 2
            bunch.cluster_id = cluster_id
            bunch.color = (
                selected_cluster_color(i - offset, self.marker_alpha)
                # Background amplitude color.
                if cluster_id is not None else (.5, .5, .5, .5))
        return bunchs

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data()
        if not bunchs:
            return
        self.data_bounds = self._get_data_bounds(bunchs)
        bunchs = self._add_histograms(bunchs)
        # Use the same scale for all histograms.
        self._ylim = max(bunch.histogram.max() for bunch in bunchs) if bunchs else 1.

        self.visual.reset_batch()
        self.hist_visual.reset_batch()
        self.text_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self._plot_amplitude_name()
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.hist_visual)
        self.canvas.update_visual(self.text_visual)

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(AmplitudeView, self).attach(gui)
        self.actions.add(self.next_amplitude_type, set_busy=True)
        self.actions.add(self.previous_amplitude_type, set_busy=True)

    def _change_amplitude_type(self, dir=+1):
        """Cycle through the amplitude types by `dir` steps (wrapping around) and replot."""
        i = self.amplitude_names.index(self.amplitude_name)
        n = len(self.amplitude_names)
        self.amplitude_name = self.amplitude_names[(i + dir) % n]
        logger.debug("Switch to amplitude type: %s.", self.amplitude_name)
        self.plot()

    def next_amplitude_type(self):
        """Switch to the next amplitude type."""
        self._change_amplitude_type(+1)

    def previous_amplitude_type(self):
        """Switch to the previous amplitude type."""
        self._change_amplitude_type(-1)
コード例 #11
0
ファイル: feature.py プロジェクト: zsong30/phy
class FeatureView(MarkerSizeMixin, ScalingMixin, ManualClusteringView):
    """This view displays a 4x4 subplot matrix with different projections of the principal
    component features. This view keeps track of which channels are currently shown.

    Constructor
    -----------

    features : function
        Maps `(cluster_id, channel_ids=None, load_all=False)` to
        `Bunch(data, channel_ids, channel_labels, spike_ids , masks)`.
        * `data` is an `(n_spikes, n_channels, n_features)` array
        * `channel_ids` contains the channel ids of every row in `data`
        * `channel_labels` contains the channel labels of every row in `data`
        * `spike_ids` is a `(n_spikes,)` array
        * `masks` is an `(n_spikes, n_channels)` array

        This allows for a sparse format.

    attributes : dict
        Maps an attribute name to a 1D array with `n_spikes` numbers (for example, spike times).

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    cluster_ids = ()

    # Whether to disable automatic selection of channels.
    fixed_channels = False
    feature_scaling = 1.

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'increase': 'ctrl++',
        'decrease': 'ctrl+-',
        'add_lasso_point': 'ctrl+click',
        'stop_lasso': 'ctrl+right click',
        'toggle_automatic_channel_selection': 'c',
    }

    def __init__(self, features=None, attributes=None, **kwargs):
        """Build the feature view.

        Parameters
        ----------
        features : function
            Returns the feature Bunch of a cluster (see the class docstring); required.
        attributes : dict, optional
            Extra dimensions, `{name: array}` with one `(n_spikes,)` array per name.
        **kwargs
            Forwarded to the parent view constructor.

        """
        super(FeatureView, self).__init__(**kwargs)
        # These two attributes are saved in the view state.
        self.state_attrs += ('fixed_channels', 'feature_scaling')

        assert features
        self.features = features
        # Symmetric half-range of the PC axes, recomputed when plotting.
        self._lim = 1

        # Subplot layout: 2D array of strings such as `0A,1B`.
        grid = _get_default_grid()
        self.grid_dim = grid
        self.n_rows, self.n_cols = np.shape(grid)
        canvas = self.canvas
        canvas.set_layout('grid', shape=(self.n_rows, self.n_cols))
        canvas.enable_lasso()

        # No channel is displayed until the first plot.
        self.channel_ids = None

        # Extra dimensions, empty by default.
        self.attributes = attributes if attributes else {}

        scatter = ScatterVisual()
        self.visual = scatter
        canvas.add_visual(scatter)

        labels = TextVisual()
        self.text_visual = labels
        canvas.add_visual(labels)

        axes = LineVisual()
        self.line_visual = axes
        canvas.add_visual(axes)

    def set_grid_dim(self, grid_dim):
        """Replace the subplot grid layout at runtime.

        Parameters
        ----------
        grid_dim : array-like (2D)
            Each cell `grid_dim[row, col]` holds two comma-separated dimensions. A
            dimension is either a relative channel index (0, 1, 2...) followed by a PC
            letter (A, B, C...), e.g. `0B,1A`, or an attribute name, e.g. `time,2C`.

        """
        self.grid_dim = grid_dim
        shape = np.shape(grid_dim)
        self.n_rows, self.n_cols = shape
        self.canvas.grid.shape = shape

    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self):
        """Yield `(i, j, dim_x, dim_y)` for every subplot, row by row.

        `dim_x` and `dim_y` are the two comma-separated parts of `self.grid_dim[i][j]`.
        """
        cells = ((row, col) for row in range(self.n_rows) for col in range(self.n_cols))
        for row, col in cells:
            dim_x, dim_y = self.grid_dim[row][col].split(',')
            yield row, col, dim_x, dim_y

    def _get_axis_label(self, dim):
        """Return the text label of a subplot axis.

        For a PC dimension such as `1B`, the relative channel index is wrapped around the
        displayed channels and the channel label is followed by the PC letter. Any other
        dimension (an attribute name) is returned unchanged.
        """
        relative = dim[:-1]
        if not str(relative).isdecimal():
            return dim
        channel_id = self.channel_ids[int(relative) % len(self.channel_ids)]
        return self.channel_labels[channel_id] + dim[-1]

    def _get_channel_and_pc(self, dim):
        """Return `(channel_id, pc)` for a PC dimension such as `0A`.

        Returns None when no channel is displayed yet. `pc` is the 0-based position of the
        trailing letter in `ABCDEFGHIJ`.
        """
        if self.channel_ids is None:
            return None
        # Only PC dimensions reach this method, never attribute names.
        assert dim not in self.attributes
        # Relative channel index (typically 0 or 1), wrapped around the displayed channels.
        relative = int(dim[:-1])
        channel_id = self.channel_ids[relative % len(self.channel_ids)]
        pc = 'ABCDEFGHIJ'.index(dim[-1])
        return channel_id, pc

    def _get_axis_data(self, bunch, dim, cluster_id=None, load_all=None):
        """Extract the values plotted along one axis.

        Parameters
        ----------
        bunch : Bunch
            Returned by the `features` function.
        dim : str
            An attribute name, or a PC dimension such as `0A`.
        cluster_id, load_all :
            Forwarded to the attribute function when `dim` is an attribute.

        Returns
        -------
        The attribute function's result, or a Bunch with the scaled feature column in
        `data` and the matching column of `masks` (or None). When the channel is absent
        from the bunch, a Bunch of zeros without masks is returned.

        """
        if dim in self.attributes:
            attribute = self.attributes[dim]
            return attribute(cluster_id, load_all=load_all)
        masks = bunch.get('masks', None)
        channel_id, pc = self._get_channel_and_pc(dim)
        # Skip the plot if the channel id is not displayed.
        if channel_id not in bunch.channel_ids:  # pragma: no cover
            n_points = bunch.data.shape[0]
            return Bunch(data=np.zeros((n_points,)))
        # Column of the current channel in the data array.
        column = list(bunch.channel_ids).index(channel_id)
        column_masks = masks[:, column] if masks is not None else None
        values = self.feature_scaling * bunch.data[:, column, pc]
        return Bunch(data=values, masks=column_masks)

    def _get_axis_bounds(self, dim, bunch):
        """Return the `(min, max)` bounds of a subplot axis.

        PC dimensions share the symmetric range `(-self._lim, +self._lim)`. Attribute
        dimensions use the `lim` stored in their bunch, or `(0, 0)` when it is absent; no
        min/max is computed from the data here.
        """
        if dim not in self.attributes:
            return (-self._lim, +self._lim)
        vmin, vmax = bunch.get('lim', (0, 0))
        assert vmin is not None
        assert vmax is not None
        return vmin, vmax

    def _plot_points(self, bunch, clu_idx=None):
        """Add one cluster's points and axis labels to every subplot batch.

        Parameters
        ----------
        bunch : Bunch
            The cluster's features, as returned by the `features` function.
        clu_idx : int or None
            Index of the cluster in `self.cluster_ids`; None gives `cluster_id=None`
            (presumably the background points).

        """
        if not bunch:
            return
        cluster_id = self.cluster_ids[clu_idx] if clu_idx is not None else None
        for i, j, dim_x, dim_y in self._iter_subplots():
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id)
            # Skip empty data.
            if px is None or py is None:  # pragma: no cover
                # %s rather than %d: cluster_id is None when clu_idx is None.
                logger.warning("Skipping empty data for cluster %s.", cluster_id)
                return
            assert px.data.shape == py.data.shape
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            assert xmin <= xmax
            assert ymin <= ymax
            data_bounds = (xmin, ymin, xmax, ymax)
            masks = _get_masks_max(px, py)
            # Prepare the batch visual with all subplots
            # for the selected cluster.
            self.visual.add_batch_data(
                x=px.data, y=py.data,
                color=_get_point_color(clu_idx),
                # Same marker size for all points, including those with clu_idx None.
                size=self._marker_size,
                masks=_get_point_masks(clu_idx=clu_idx, masks=masks),
                data_bounds=data_bounds,
                box_index=(i, j),
            )
            # Get the channel ids corresponding to the relative channel indices
            # specified in the dimensions. Channel 0 corresponds to the first
            # best channel for the selected cluster, and so on.
            label_x = self._get_axis_label(dim_x)
            label_y = self._get_axis_label(dim_y)
            # Add labels: y label at the top-right corner, x label at the bottom center.
            self.text_visual.add_batch_data(
                pos=[1, 1],
                anchor=[-1, -1],
                text=label_y,
                data_bounds=None,
                box_index=(i, j),
            )
            self.text_visual.add_batch_data(
                pos=[0, -1.],
                anchor=[0, 1],
                text=label_x,
                data_bounds=None,
                box_index=(i, j),
            )

    def _plot_axes(self):
        """Draw the gray x=0 and y=0 lines in every subplot (in NDC, no data bounds)."""
        lines = self.line_visual
        lines.reset_batch()
        for row, col, _, _ in self._iter_subplots():
            horizontal = [-1., 0., +1., 0.]
            vertical = [0., -1., 0., +1.]
            lines.add_batch_data(
                pos=[horizontal, vertical],
                color=(.5, .5, .5, .5),
                box_index=(row, col),
                data_bounds=None,
            )
        self.canvas.update_visual(lines)

    def _get_lim(self, bunchs):
        """Return the largest absolute feature value across the clusters.

        Used as the symmetric half-range `(-lim, +lim)` of the PC axes. Falls back to 1
        when there is no bunch, when every bunch has empty data, or when all values are
        zero, so that the axis bounds never collapse to a single point.
        """
        if not bunchs:  # pragma: no cover
            return 1
        # Skip clusters without spikes: min()/max() of an empty array raise ValueError.
        maxima = [np.abs(bunch.data).max() for bunch in bunchs if bunch.data.size]
        lim = max(maxima) if maxima else 0
        return lim if lim > 0 else 1

    def _get_scaling_value(self):
        """Return the factor currently applied to the PC features."""
        scaling = self.feature_scaling
        return scaling

    def _set_scaling_value(self, value):
        """Store a new feature scaling factor, then replot with the channels kept fixed."""
        self.feature_scaling = value
        self.plot(fixed_channels=True)

    # Public methods
    # -------------------------------------------------------------------------

    def clear_channels(self):
        """Forget the displayed channels and replot so that they are chosen again."""
        self.channel_ids = None
        self.plot()

    def get_clusters_data(self, fixed_channels=None, load_all=None):
        """Load the features of the selected clusters and choose the displayed channels.

        Parameters
        ----------
        fixed_channels : bool
            If true and channels are already displayed, request the features on
            `self.channel_ids` and keep them; otherwise `channel_ids=None` is passed to the
            features function and the channels are chosen from the loaded bunches.
        load_all : bool
            Accepted for interface compatibility; not used here.

        Returns
        -------
        bunchs : list
            The non-empty feature bunches, each tagged with its own `cluster_id`.

        """
        requested = self.channel_ids if fixed_channels else None
        bunchs = []
        for cluster_id in self.cluster_ids:
            bunch = self.features(cluster_id, channel_ids=requested)
            # Tag each bunch before dropping empty ones: filtering first and zipping with
            # self.cluster_ids afterwards shifted the ids of all following clusters.
            if bunch:
                bunch.cluster_id = cluster_id
                bunchs.append(bunch)
        if not bunchs:  # pragma: no cover
            return []

        # Choose the channels based on the first selected cluster.
        channel_ids = list(bunchs[0].get('channel_ids', []))
        common_channels = list(channel_ids)
        # Intersection of the channels belonging to all clusters; the resulting order is
        # that of the last bunch processed.
        for bunch in bunchs:
            common_channels = [ch for ch in bunch.get('channel_ids', []) if ch in common_channels]
        # The selected channels will be (1) the channels common to all clusters, followed
        # by (2) remaining channels from the first cluster (excluding those already selected
        # in (1)).
        n = len(channel_ids)
        not_common_channels = [ch for ch in channel_ids if ch not in common_channels]
        channel_ids = common_channels + not_common_channels[:n - len(common_channels)]
        assert len(channel_ids) > 0

        # Choose the channels automatically unless fixed_channels is set.
        if not fixed_channels or self.channel_ids is None:
            self.channel_ids = channel_ids
        assert len(self.channel_ids)

        # Channel labels, defaulting to the channel ids formatted as integers.
        self.channel_labels = {}
        for bunch in bunchs:
            bunch_channels = bunch.get('channel_ids', [])
            labels = bunch.get('channel_labels', ['%d' % ch for ch in bunch_channels])
            self.channel_labels.update({
                channel_id: labels[i] for i, channel_id in enumerate(bunch_channels)})

        return bunchs

    def plot(self, **kwargs):
        """Redraw the view for the currently selected clusters.

        Keyword arguments
        -----------------
        fixed_channels : bool
            Keep the current channels instead of choosing them automatically.
        up : dict-like
            Cluster update information. When it has an ``added`` entry, the channels
            are kept fixed as well.
        """
        # Keep the channels fixed when requested (globally or for this call), or when
        # the redraw follows a cluster event that added clusters.
        up = kwargs.get('up', {})
        added = up.get('added', None)
        fixed_channels = (
            self.fixed_channels or kwargs.get('fixed_channels', None) or added is not None)

        bunchs = [b for b in self.get_clusters_data(fixed_channels=fixed_channels) if b]
        if not bunchs:
            return
        self._lim = self._get_lim(bunchs)

        # Background spikes on the chosen channels.
        background = self.features(channel_ids=self.channel_ids)

        self._plot_axes()

        # NOTE: the columns of bunch.data are ordered by decreasing quality of the
        # associated channels; bunch.channel_ids lists these channels in the same order.
        self.visual.reset_batch()
        self.text_visual.reset_batch()
        self._plot_points(background)
        for clu_idx, bunch in enumerate(bunchs):
            self._plot_points(bunch, clu_idx=clu_idx)

        # Upload the batched data on the GPU, then refresh the canvas.
        for batched_visual in (self.visual, self.text_visual):
            self.canvas.update_visual(batched_visual)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI and register the channel-selection actions."""
        super(FeatureView, self).attach(gui)

        # The toggle is checked when channels are chosen automatically.
        automatic = not self.fixed_channels
        self.actions.add(
            self.toggle_automatic_channel_selection, checkable=True, checked=automatic)
        self.actions.add(self.clear_channels)
        self.actions.separator()

    def toggle_automatic_channel_selection(self, checked):
        """Toggle the automatic selection of channels when the cluster selection changes.

        ``checked`` enables automatic selection, i.e. clears ``fixed_channels``.
        """
        self.fixed_channels = False if checked else True

    @property
    def status(self):
        """Status text showing the labels of the first two selected channels."""
        if self.channel_ids is None:  # pragma: no cover
            return ''
        # Fall back to the channel id when a channel has no known label. This can happen
        # when channel_ids was changed but the last redraw returned before refreshing
        # channel_labels; a plain lookup would then raise a KeyError.
        labels = [self.channel_labels.get(ch, str(ch)) for ch in self.channel_ids[:2]]
        return 'channels: %s' % ', '.join(labels)

    # Dimension selection
    # -------------------------------------------------------------------------

    def on_select_channel(self, sender=None, channel_id=None, key=None, button=None):
        """Respond to the click on a channel from another view, and update the
        relevant subplots.

        Parameters
        ----------
        sender : object
            The emitting view (unused).
        channel_id : int
            The clicked channel.
        key : int or None
            Pressed key number (1, 2, etc.) selecting which channel slot to replace.
            It is clipped to the existing slots.
        button : str or None
            Used when ``key`` is None: 'Left' replaces the first channel, any other
            button replaces the second.
        """
        if self.channel_ids is None:
            return
        # Work on a copy so that self.channel_ids is only replaced once the new
        # selection has been validated.
        channels = list(self.channel_ids)
        if len(channels) == 1:
            self.plot()
            return
        assert len(channels) >= 2
        if key is not None:
            # Map the 1-based key to a 0-based slot within the channel list.
            d = int(np.clip(key - 1, 0, len(channels) - 1))
        else:
            d = 0 if button == 'Left' else 1
        # Change the first or second best channel.
        old = channels[d]
        # Avoid updating the view if the channel doesn't change.
        if channel_id == old:
            return
        channels[d] = channel_id
        # Ensure that the first two channels are different: if the new channel already
        # sits in the other one of the first two slots (slot 0 when d >= 1), swap the old
        # channel into it.
        other = 1 - min(d, 1)
        if channels[other] == channel_id:
            channels[other] = old
        assert channels[0] != channels[1]
        # Remove duplicate channels.
        self.channel_ids = _uniq(channels)
        logger.debug("Choose channels %d and %d in feature view.", *channels[:2])
        # Fix the channels temporarily.
        self.plot(fixed_channels=True)
        self.update_status()

    def on_mouse_click(self, e):
        """Select a feature dimension by clicking on a box in the feature view.

        Only Alt+clicks are handled. A left click picks the x dimension of the clicked
        subplot and any other button picks the y dimension. A ``select_feature`` event
        is then emitted with the dimension, its channel id and its PC. The PC is None for
        attribute dimensions, which take their channel from the other axis.
        """
        if 'Alt' not in e.modifiers:
            return
        is_left = e.button == 'Left'
        # Locate the clicked subplot and split its "x,y" dimension string.
        (i, j), _ = self.canvas.grid.box_map(e.pos)
        dim_x, dim_y = self.grid_dim[i][j].split(',')
        dim, other_dim = (dim_x, dim_y) if is_left else (dim_y, dim_x)

        if dim in self.attributes:
            # Attribute dimension (e.g. "time"): take the channel of the other dimension.
            pc = None
            channel_pc = self._get_channel_and_pc(other_dim)
            channel_id = None if channel_pc is None else channel_pc[0]
            logger.debug("Click on feature dim %s.", dim)
        else:
            # Regular (channel, PC) dimension.
            channel_pc = self._get_channel_and_pc(dim)
            if channel_pc is None:
                return
            channel_id, pc = channel_pc
            logger.debug("Click on feature dim %s, channel id %s, PC %s.", dim, channel_id, pc)
        emit('select_feature', self, dim=dim, channel_id=channel_id, pc=pc)

    def on_request_split(self, sender=None):
        """Return the spikes enclosed by the lasso.

        The spikes of every selected cluster are loaded with ``load_all=True``. They are
        projected onto the two dimensions of the active (lassoed) subplot, normalized
        like the displayed points, and tested against the lasso polygon.

        Returns
        -------
        spike_ids : ndarray
            Sorted unique ids of the lassoed spikes. An empty int64 array is returned
            when there is no valid lasso, no selected cluster, or no data.
        """
        if (self.canvas.lasso.count < 3 or not len(self.cluster_ids)):  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # Get the dimensions of the lassoed subplot.
        i, j = self.canvas.layout.active_box
        dim_x, dim_y = self.grid_dim[i][j].split(',')

        # Get all points from all clusters.
        pos = []
        spike_ids = []
        for cluster_id in self.cluster_ids:
            # Load all spikes.
            bunch = self.features(cluster_id, channel_ids=self.channel_ids, load_all=True)
            if not bunch:
                continue
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id, load_all=True)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id, load_all=True)
            points = np.c_[px.data, py.data]

            # Normalize the points.
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            points = Range((xmin, ymin, xmax, ymax)).apply(points)

            pos.append(points)
            spike_ids.append(bunch.spike_ids)

        if not pos:
            # No selected cluster returned any data (np.vstack would raise on an empty
            # list): nothing can be inside the lasso.
            self.canvas.lasso.clear()
            return np.array([], dtype=np.int64)
        pos = np.vstack(pos)
        spike_ids = np.concatenate(spike_ids)

        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()
        return np.unique(spike_ids[ind])
コード例 #12
0
class ProbeView(ManualClusteringView):
    """This view displays the positions of all channels on the probe, highlighting channels
    where the selected clusters belong.

    Constructor
    -----------

    positions : array-like
        An `(n_channels, 2)` array with the channel positions
    best_channels : function
        Maps `cluster_id` to the list of the best_channel_ids. These ids are matched
        against the channel indices, i.e. the row numbers of `positions`.
    channel_labels : list
        List of channel label strings (defaults to the channel indices).
    dead_channels : list
        List of dead channel ids, used as row indices into `positions`.

    """

    _default_position = 'right'

    # Marker size of channels without selected clusters.
    unselected_marker_size = 10

    # Marker size of channels with selected clusters.
    selected_marker_size = 15

    # Alpha value of the dead channels.
    dead_channel_alpha = .25

    # Whether the channel labels are shown (part of the view state).
    do_show_labels = False

    def __init__(
            self, positions=None, best_channels=None, channel_labels=None,
            dead_channels=None, **kwargs):
        super(ProbeView, self).__init__(**kwargs)
        self.state_attrs += ('do_show_labels',)

        # Normalize positions.
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        positions = positions.astype(np.float32)
        self.positions, self.data_bounds = _get_pos_data_bounds(positions)

        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        # Explicit test rather than `channel_labels or ...`, which raises on numpy arrays.
        if channel_labels is None or not len(channel_labels):
            channel_labels = [str(ch) for ch in range(self.n_channels)]
        self.channel_labels = channel_labels
        self.dead_channels = dead_channels if dead_channels is not None else ()

        self.probe_visual = ScatterVisual()
        self.canvas.add_visual(self.probe_visual)

        # Probe visual: grey markers, more transparent for dead channels.
        color = np.ones((self.n_channels, 4))
        color[:, :3] = .5
        if len(self.dead_channels):
            color[self.dead_channels, 3] = self.dead_channel_alpha
        self.probe_visual.set_data(
            pos=self.positions, data_bounds=self.data_bounds,
            color=color, size=self.unselected_marker_size)

        # Cluster visual.
        self.cluster_visual = ScatterVisual()
        self.canvas.add_visual(self.cluster_visual)

        # Text visual: white labels, darker for dead channels. A separate array is used
        # so that the colors already given to the probe visual are not overwritten.
        text_color = np.ones((self.n_channels, 4))
        if len(self.dead_channels):
            text_color[self.dead_channels, :3] = self.dead_channel_alpha * 2
        self.text_visual = TextVisual()
        self.text_visual.inserter.insert_vert('uniform float n_channels;', 'header')
        # When n_channels >= 200 * zoom, only keep one label every
        # int(n_channels / (200 * zoom)) strings, to avoid overlapping labels.
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_channels >= 200 * u_zoom.y) && '
            '(mod(int(a_string_index), int(n_channels / (200 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)
        self.text_visual.set_data(
            pos=self.positions, text=self.channel_labels, anchor=[0, -1],
            data_bounds=self.data_bounds, color=text_color
        )
        self.text_visual.program['n_channels'] = self.n_channels
        self.canvas.update()

    def _get_clu_positions(self, cluster_ids):
        """Get the positions and colors of the markers of the selected clusters.

        Each cluster gets one marker on each of its best channels. The markers of the
        clusters sharing a channel are spread horizontally and centered on that channel.

        Returns
        -------
        (pos, colors) : tuple of ndarrays
        """
        # List of cluster indices per channel.
        clusters_per_channel = defaultdict(list)
        for clu_idx, cluster_id in enumerate(cluster_ids):
            for channel in self.best_channels(cluster_id):
                clusters_per_channel[channel].append(clu_idx)

        # Enumerate the discs for each channel.
        w = self.data_bounds[2] - self.data_bounds[0]
        clu_pos = []
        clu_colors = []
        for channel_id, (x, y) in enumerate(self.positions):
            clu_indices = clusters_per_channel.get(channel_id, ())
            if not clu_indices:
                continue
            n = len(clu_indices)
            alpha = 1.0 if channel_id not in self.dead_channels else self.dead_channel_alpha
            for i, clu_idx in enumerate(clu_indices):
                # Offset from the channel position itself (not cumulatively), so that
                # the n markers are evenly spaced and centered on the channel.
                t = .025 * w * (i - .5 * (n - 1))
                clu_pos.append((x + t, y))
                clu_colors.append(selected_cluster_color(clu_idx, alpha=alpha))
        return np.array(clu_pos), np.array(clu_colors)

    def on_select(self, cluster_ids=(), **kwargs):
        """Update the view with the selected clusters.

        Returns without redrawing when no cluster is selected.
        """
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        pos, colors = self._get_clu_positions(cluster_ids)
        self.cluster_visual.set_data(
            pos=pos, color=colors, size=self.selected_marker_size, data_bounds=self.data_bounds)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(ProbeView, self).attach(gui)
        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)

        if not self.do_show_labels:
            self.text_visual.hide()

    def toggle_show_labels(self, checked):
        """Toggle the display of the channel ids."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        self.text_visual._hidden = not checked
        self.canvas.update()
コード例 #13
0
ファイル: probe.py プロジェクト: yingluo227/Labtools
class ProbeView(ManualClusteringView):
    """This view displays the positions of all channels on the probe, highlighting channels
    where the selected clusters belong.

    Constructor
    -----------

    positions : array-like
        An `(n_channels, 2)` array with the channel positions
    best_channels : function
        Maps `cluster_id` to the list of the best_channel_ids. These ids are matched
        against the channel indices, i.e. the row numbers of `positions`.

    """

    _default_position = 'right'

    # Marker size of channels without selected clusters.
    unselected_marker_size = 10

    # Marker size of channels with selected clusters.
    selected_marker_size = 15

    def __init__(self, positions=None, best_channels=None, **kwargs):
        super(ProbeView, self).__init__(**kwargs)

        # Normalize positions.
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        self.positions, self.data_bounds = _get_pos_data_bounds(positions)

        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        self.probe_visual = ScatterVisual()
        self.canvas.add_visual(self.probe_visual)

        # Probe visual: one grey marker per channel.
        self.probe_visual.set_data(pos=self.positions,
                                   data_bounds=self.data_bounds,
                                   color=(.5, .5, .5, 1.),
                                   size=self.unselected_marker_size)

        # Cluster visual.
        self.cluster_visual = ScatterVisual()
        self.canvas.add_visual(self.cluster_visual)

    def _get_clu_positions(self, cluster_ids):
        """Get the positions and colors of the markers of the selected clusters.

        Each cluster gets one marker on each of its best channels. The markers of the
        clusters sharing a channel are spread horizontally and centered on that channel.

        Returns
        -------
        (pos, colors) : tuple of ndarrays
        """
        # List of cluster indices per channel.
        clusters_per_channel = defaultdict(list)
        for clu_idx, cluster_id in enumerate(cluster_ids):
            for channel in self.best_channels(cluster_id):
                clusters_per_channel[channel].append(clu_idx)

        # Enumerate the discs for each channel.
        w = self.data_bounds[2] - self.data_bounds[0]
        clu_pos = []
        clu_colors = []
        for channel_id, (x, y) in enumerate(self.positions):
            clu_indices = clusters_per_channel.get(channel_id, ())
            n = len(clu_indices)
            for i, clu_idx in enumerate(clu_indices):
                # Offset from the channel position itself (not cumulatively), so that
                # the n markers are evenly spaced and centered on the channel.
                t = .025 * w * (i - .5 * (n - 1))
                clu_pos.append((x + t, y))
                clu_colors.append(selected_cluster_color(clu_idx))
        return np.array(clu_pos), np.array(clu_colors)

    def on_select(self, cluster_ids=(), **kwargs):
        """Update the view with the selected clusters.

        Returns without redrawing when no cluster is selected.
        """
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        pos, colors = self._get_clu_positions(cluster_ids)
        self.cluster_visual.set_data(pos=pos,
                                     color=colors,
                                     size=self.selected_marker_size,
                                     data_bounds=self.data_bounds)
        self.canvas.update()
コード例 #14
0
ファイル: raster.py プロジェクト: zsong30/phy
class RasterView(MarkerSizeMixin, BaseColorView, BaseGlobalView, ManualClusteringView):
    """This view shows a raster plot of all clusters.

    Constructor
    -----------

    spike_times : array-like
        An `(n_spikes,)` array with the spike times, in seconds.
    spike_clusters : array-like
        An `(n_spikes,)` array with the spike-cluster assignments.
    cluster_ids : array-like
        The list of all clusters to show initially.

    """

    _default_position = 'right'

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'switch_color_scheme': 'shift+wheel',
        'decrease_marker_size': 'ctrl+shift+-',
        'increase_marker_size': 'ctrl+shift++',
        'select_cluster': 'ctrl+click',
        'select_more': 'shift+click',
    }

    def __init__(self, spike_times, spike_clusters, cluster_ids=None, **kwargs):
        self.spike_times = spike_times
        self.n_spikes = len(spike_times)
        # Pad the last spike time by 1% (assumes spike_times is sorted). Fall back to 1
        # when there are no spikes or the last time is 0, because the duration is used
        # as a divisor when zooming.
        last_time = spike_times[-1] if self.n_spikes else 0
        self.duration = (last_time * 1.01) or 1.
        self.n_clusters = 1

        assert len(spike_clusters) == self.n_spikes
        self.set_spike_clusters(spike_clusters)
        self.set_cluster_ids(cluster_ids)

        super(RasterView, self).__init__(**kwargs)

        self.canvas.set_layout('stacked', origin='top', n_plots=self.n_clusters, has_clip=False)
        self.canvas.enable_axes()

        # Vertical bar markers whose size grows with the vertical zoom level.
        self.visual = ScatterVisual(
            marker='vbar',
            marker_scaling='''
                point_size = v_size * u_zoom.y + 5.;
                float width = 0.2;
                float height = 0.5;
                vec2 marker_size = point_size * vec2(width, height);
                marker_size.x = clamp(marker_size.x, 1, 20);
            ''',
        )
        self.visual.inserter.insert_vert('''
                gl_PointSize = a_size * u_zoom.y + 5.0;
        ''', 'end')
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))

    # Data-related functions
    # -------------------------------------------------------------------------

    def set_spike_clusters(self, spike_clusters):
        """Set the spike clusters for all spikes."""
        self.spike_clusters = spike_clusters

    def set_cluster_ids(self, cluster_ids):
        """Set the shown clusters, which can be filtered and in any order (from top to bottom).

        Does nothing when `cluster_ids` is None or empty.
        """
        if cluster_ids is None or not len(cluster_ids):
            return
        self.all_cluster_ids = cluster_ids
        self.n_clusters = len(self.all_cluster_ids)
        # Boolean mask of the spikes that belong to the shown clusters.
        self.spike_ids = np.isin(self.spike_clusters, self.all_cluster_ids)

    # Internal plotting functions
    # -------------------------------------------------------------------------

    def _get_x(self):
        """Return the x position of the spikes."""
        return self.spike_times[self.spike_ids]

    def _get_y(self):
        """Return the y position of the spikes, given the relative position of the clusters."""
        return np.zeros(np.sum(self.spike_ids))

    def _get_box_index(self):
        """Return, for every spike, its row in the raster plot. This depends on the ordering
        in self.all_cluster_ids."""
        cl = self.spike_clusters[self.spike_ids]
        return _index_of(cl, self.all_cluster_ids)

    def _get_color(self, box_index, selected_clusters=None):
        """Return, for every spike, its color, based on its box index."""
        cluster_colors = self.get_cluster_colors(self.all_cluster_ids, alpha=.75)
        # Selected cluster colors.
        if selected_clusters is not None:
            cluster_colors = _add_selected_clusters_colors(
                selected_clusters, self.all_cluster_ids, cluster_colors)
        return cluster_colors[box_index, :]

    # Main methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self):
        """Bounds of the raster plot view."""
        return (0, 0, self.duration, self.n_clusters)

    def update_cluster_sort(self, cluster_ids):
        """Update the order of all clusters."""
        self.all_cluster_ids = cluster_ids
        self.visual.set_box_index(self._get_box_index())
        self.canvas.update()

    def update_color(self):
        """Update the color of the spikes, depending on the selected clusters."""
        box_index = self._get_box_index()
        color = self._get_color(box_index, selected_clusters=self.cluster_ids)
        self.visual.set_color(color)
        self.canvas.update()

    @property
    def status(self):
        return 'Color scheme: %s' % self.color_scheme

    def plot(self, **kwargs):
        """Make the raster plot."""
        if not len(self.spike_clusters):
            return
        x = self._get_x()  # spike times for the selected spikes
        y = self._get_y()  # just 0
        box_index = self._get_box_index()
        color = self._get_color(box_index)
        assert x.shape == y.shape == box_index.shape
        assert color.shape[0] == len(box_index)
        self.data_bounds = self._get_data_bounds()

        self.visual.set_data(
            x=x, y=y, color=color, size=self.marker_size,
            data_bounds=(0, -1, self.duration, 1))
        self.visual.set_box_index(box_index)
        self.canvas.stacked.n_boxes = self.n_clusters
        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(RasterView, self).attach(gui)

        self.actions.add(self.increase_marker_size)
        self.actions.add(self.decrease_marker_size)
        self.actions.separator()

    def on_select(self, *args, **kwargs):
        super(RasterView, self).on_select(*args, **kwargs)
        self.update_color()

    def zoom_to_time_range(self, interval):
        """Zoom to a time interval `(t0, t1)`, in seconds.

        The zoom box is centered on the interval, with a half width of at least 5 s so
        that it never degenerates (e.g. for a zero-length interval).
        """
        if not interval:
            return
        t0, t1 = interval
        tm = .5 * (t0 + t1)  # center
        w = max(5, .5 * (t1 - t0))  # half width, at least 5 s
        t0, t1 = tm - w, tm + w
        x0 = -1 + 2 * t0 / self.duration
        x1 = -1 + 2 * t1 / self.duration
        box = (x0, -1, x1, +1)
        self.canvas.panzoom.set_range(box)

    # Interactivity
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by clicking in the raster plot."""
        if 'Control' not in e.modifiers:
            return
        b = e.button
        # Get mouse position in NDC.
        cluster_idx, _ = self.canvas.stacked.box_map(e.pos)
        cluster_id = self.all_cluster_ids[cluster_idx]
        logger.debug("Click on cluster %d with button %s.", cluster_id, b)
        if 'Shift' in e.modifiers:
            emit('select_more', self, [cluster_id])
        else:
            emit('request_select', self, [cluster_id])
コード例 #15
0
    def __init__(self, amplitudes=None, amplitudes_type=None, duration=None):
        """Create the amplitude view.

        Parameters
        ----------
        amplitudes : dict or object
            Dictionary ``{amplitudes_type: value}``. Any non-dict value is wrapped as
            ``{'amplitude': amplitudes}``. The values are presumably functions returning
            amplitude bunches for cluster ids (TODO confirm against the class docstring).
        amplitudes_type : str or None
            Initially selected key of ``amplitudes``, passed to ``RotatingProperty.set``.
        duration : float or None
            Duration of the time axis. 1 is used when it is falsy.
        """
        super(AmplitudeView, self).__init__()
        # Presumably saved/restored with the view state (TODO confirm).
        self.state_attrs += ('amplitudes_type', )

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes

        # Rotating property amplitudes types.
        self.amplitudes_types = RotatingProperty()
        for name, value in self.amplitudes.items():
            self.amplitudes_types.add(name, value)
        # Current amplitudes type.
        self.amplitudes_types.set(amplitudes_type)
        # NOTE(review): `amplitudes_type` is presumably a property of the enclosing class
        # returning `self.amplitudes_types.current`; confirm.
        assert self.amplitudes_type in self.amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1.

        # Histogram visual. Extra transforms map it into a band of relative size
        # `histogram_scale`, rotate it clockwise, flip it vertically and shift it by
        # 2.05 in x. This presumably places it to the right of the scatter plot; the
        # default zoom/pan below leave room there (TODO confirm).
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
            Rotate('cw'),
            Scale((1, -1)),
            Translate((2.05, 0)),
        ])
        self.canvas.add_visual(self.hist_visual)
        # Default zoom/pan, also stored as the values to reset to.
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.75, 1)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.25, 0)

        # Yellow vertical bar showing the selected time interval.
        self.patch_visual = PatchVisual(primitive_type='triangle_fan')
        self.patch_visual.inserter.insert_vert(
            '''
            const float MIN_INTERVAL_SIZE = 0.01;
            uniform float u_interval_size;
        ''', 'header')
        self.patch_visual.inserter.insert_vert(
            '''
            gl_Position.y = pos_orig.y;

            // The following is used to ensure that (1) the bar width increases with the zoom level
            // but also (2) there is a minimum absolute width so that the bar remains visible
            // at low zoom levels.
            float w = max(MIN_INTERVAL_SIZE, u_interval_size * u_zoom.x);
            // HACK: the z coordinate is used to store 0 or 1, depending on whether the current
            // vertex is on the left or right edge of the bar.
            gl_Position.x += w * (-1 + 2 * int(a_position.z == 0));

        ''', 'after_transforms')
        self.canvas.add_visual(self.patch_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.set_constrain_bounds((-2, -2, +2, +2))
コード例 #16
0
class AmplitudeView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays an amplitude plot for all selected clusters.

    Constructor
    -----------

    amplitudes : dict
        Dictionary `{amplitudes_type: function}`, for different types of amplitudes.
        A single function (not a dict) is registered as the `'amplitude'` type.

        Each function maps `cluster_ids` to a list
        `[Bunch(amplitudes, spike_ids, spike_times), ...]` for each cluster.
        Use `cluster_id=None` for background amplitudes.
    amplitudes_type : str
        The amplitudes type selected initially.
    duration : float
        Duration of the time axis (defaults to 1).

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'

    # Alpha channel of the markers in the scatter plot.
    marker_alpha = 1.
    time_range_color = (1., 1., 0., .25)

    # Number of bins in the histogram.
    n_bins = 100

    # Alpha channel of the histogram in the background.
    histogram_alpha = .5

    # Quantile used for scaling of the amplitudes (less than 1 to avoid outliers).
    quantile = .99

    # Size of the histogram, between 0 and 1.
    histogram_scale = .25

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'next_amplitudes_type': 'a',
        'previous_amplitudes_type': 'shift+a',
        'select_x_dim': 'shift+left click',
        'select_y_dim': 'shift+right click',
        'select_time': 'alt+click',
    }

    def __init__(self, amplitudes=None, amplitudes_type=None, duration=None):
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitudes_type', )

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes

        # Rotating property amplitudes types.
        self.amplitudes_types = RotatingProperty()
        for name, value in self.amplitudes.items():
            self.amplitudes_types.add(name, value)
        # Current amplitudes type.
        self.amplitudes_types.set(amplitudes_type)
        assert self.amplitudes_type in self.amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1.

        # Histogram visual, mapped into a band of relative size `histogram_scale`,
        # rotated, flipped and shifted to x = 2.05.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
            Rotate('cw'),
            Scale((1, -1)),
            Translate((2.05, 0)),
        ])
        self.canvas.add_visual(self.hist_visual)
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.75, 1)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.25, 0)

        # Yellow vertical bar showing the selected time interval.
        self.patch_visual = PatchVisual(primitive_type='triangle_fan')
        self.patch_visual.inserter.insert_vert(
            '''
            const float MIN_INTERVAL_SIZE = 0.01;
            uniform float u_interval_size;
        ''', 'header')
        self.patch_visual.inserter.insert_vert(
            '''
            gl_Position.y = pos_orig.y;

            // The following is used to ensure that (1) the bar width increases with the zoom level
            // but also (2) there is a minimum absolute width so that the bar remains visible
            // at low zoom levels.
            float w = max(MIN_INTERVAL_SIZE, u_interval_size * u_zoom.x);
            // HACK: the z coordinate is used to store 0 or 1, depending on whether the current
            // vertex is on the left or right edge of the bar.
            gl_Position.x += w * (-1 + 2 * int(a_position.z == 0));

        ''', 'after_transforms')
        self.canvas.add_visual(self.patch_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.set_constrain_bounds((-2, -2, +2, +2))

    def _get_data_bounds(self, bunchs):
        """Compute the data bounds `(0, ymin, duration, ymax)`.

        The y bounds are the (1 - quantile) and quantile values of the amplitudes, to
        ignore outliers, with ymin clamped to at most 0. Bunches with no amplitudes are
        ignored. Falls back to `(0, 0, duration, 1)` when no amplitudes are available.
        """
        nonempty = [bunch.amplitudes for bunch in bunchs if len(bunch.amplitudes)]
        if not nonempty:  # pragma: no cover
            return (0, 0, self.duration, 1)
        m = min(np.quantile(a, 1 - self.quantile) for a in nonempty)
        m = min(0, m)  # ensure ymin <= 0
        M = max(np.quantile(a, self.quantile) for a in nonempty)
        return (0, m, self.duration, M)

    def _add_histograms(self, bunchs):
        """Attach to each bunch a normalized histogram of its amplitudes over the y bounds."""
        # We do this after get_clusters_data because we need x_max.
        for bunch in bunchs:
            bunch.histogram = _compute_histogram(
                bunch.amplitudes,
                x_min=self.data_bounds[1],
                x_max=self.data_bounds[3],
                n_bins=self.n_bins,
                normalize=True,
                ignore_zeros=True,
            )
        return bunchs

    def show_time_range(self, interval=(0, 0)):
        """Show the `(start, end)` time interval, in seconds, as a vertical bar."""
        start, end = interval
        x0 = -1 + 2 * (start / self.duration)
        x1 = -1 + 2 * (end / self.duration)
        xm = .5 * (x0 + x1)
        pos = np.array([
            [xm, -1],
            [xm, +1],
            [xm, +1],
            [xm, -1],
        ])
        self.patch_visual.program['u_interval_size'] = .5 * (x1 - x0)
        self.patch_visual.set_data(pos=pos,
                                   color=self.time_range_color,
                                   depth=[0, 0, 1, 1])
        self.canvas.update()

    def _plot_cluster(self, bunch):
        """Add the histogram and the scatter points of one bunch to the current batches."""
        ms = self._marker_size
        if not len(bunch.histogram):
            return

        # Histogram in the background.
        self.hist_visual.add_batch_data(hist=bunch.histogram,
                                        ylim=self._ylim,
                                        color=add_alpha(
                                            bunch.color, self.histogram_alpha))

        # Scatter plot.
        self.visual.add_batch_data(pos=bunch.pos,
                                   color=bunch.color,
                                   size=ms,
                                   data_bounds=self.data_bounds)

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids.

        Unless `load_all` is set (used when splitting), a background bunch
        (`cluster_id=None`) is prepended. Returns None when no cluster is selected.
        """
        if not len(self.cluster_ids):
            return
        cluster_ids = list(self.cluster_ids)
        # Don't need the background when splitting.
        if not load_all:
            # Add None cluster which means background spikes.
            cluster_ids = [None] + cluster_ids
        bunchs = self.amplitudes[self.amplitudes_type](cluster_ids,
                                                       load_all=load_all) or ()
        # Position of the first selected cluster in `bunchs`: 1 after the background
        # bunch, 0 when there is no background.
        offset = 0 if load_all else 1
        # Add a pos attribute in bunchs in addition to x and y.
        for i, (cluster_id, bunch) in enumerate(zip(cluster_ids, bunchs)):
            spike_ids = _as_array(bunch.spike_ids)
            spike_times = _as_array(bunch.spike_times)
            amplitudes = _as_array(bunch.amplitudes)
            assert spike_ids.shape == spike_times.shape == amplitudes.shape
            # Ensure that bunch.pos exists, as it used by the LassoMixin.
            bunch.pos = np.c_[spike_times, amplitudes]
            assert bunch.pos.ndim == 2
            bunch.cluster_id = cluster_id
            bunch.color = (
                selected_cluster_color(i - offset, self.marker_alpha)
                # Background amplitude color.
                if cluster_id is not None else (.5, .5, .5, .5))
        return bunchs

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data(**kwargs)
        if not bunchs:
            return
        self.data_bounds = self._get_data_bounds(bunchs)
        bunchs = self._add_histograms(bunchs)
        # Use the same scale for all histograms; empty histograms (skipped by
        # _plot_cluster) have no maximum and are ignored.
        self._ylim = max(
            (bunch.histogram.max() for bunch in bunchs if len(bunch.histogram)),
            default=1.)

        self.visual.reset_batch()
        self.hist_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.hist_visual)

        self._update_axes()
        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(AmplitudeView, self).attach(gui)

        # Amplitude type actions.
        def _make_amplitude_action(a):
            # Bind `a` now to avoid the late-binding closure pitfall in the loop below.
            def callback():
                self.amplitudes_type = a
                self.plot()

            return callback

        for a in self.amplitudes_types.keys():
            name = 'Change amplitudes type to %s' % a
            self.actions.add(_make_amplitude_action(a),
                             show_shortcut=False,
                             name=name,
                             view_submenu='Change amplitudes type')

        self.actions.add(self.next_amplitudes_type, set_busy=True)
        self.actions.add(self.previous_amplitudes_type, set_busy=True)

    @property
    def status(self):
        return self.amplitudes_type

    @property
    def amplitudes_type(self):
        """Name of the current amplitudes type."""
        return self.amplitudes_types.current

    @amplitudes_type.setter
    def amplitudes_type(self, value):
        self.amplitudes_types.set(value)

    def next_amplitudes_type(self):
        """Switch to the next amplitudes type."""
        self.amplitudes_types.next()
        logger.debug("Switch to amplitudes type: %s.",
                     self.amplitudes_types.current)
        self.plot()

    def previous_amplitudes_type(self):
        """Switch to the previous amplitudes type."""
        self.amplitudes_types.previous()
        logger.debug("Switch to amplitudes type: %s.",
                     self.amplitudes_types.current)
        self.plot()

    def on_mouse_click(self, e):
        """Select a time from the amplitude view to display in the trace view."""
        if 'Alt' in e.modifiers:
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            time = Range(NDC, self.data_bounds).apply(mouse_pos)[0][0]
            emit('select_time', self, time)
コード例 #17
0
ファイル: scatter.py プロジェクト: yingluo227/Labtools
class ScatterView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays a scatter plot for all selected clusters.

    Constructor
    -----------

    coords : function
        Maps `cluster_ids` to a list `[Bunch(x, y, spike_ids, data_bounds), ...]` for each cluster.

    """

    _default_position = 'right'

    default_shortcuts = {
        'change_marker_size': 'ctrl+wheel',
    }

    def __init__(self, coords=None, **kwargs):
        super(ScatterView, self).__init__(**kwargs)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()
        # `coords` is the only data source of this view, so it is mandatory.
        assert coords
        self.coords = coords
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

    def _plot_cluster(self, bunch):
        """Add one batch of markers (the Bunch's `pos` and `color`) to the visual."""
        ms = self._marker_size
        self.visual.add_batch_data(
            pos=bunch.pos, color=bunch.color, size=ms, data_bounds=self.data_bounds)

    @staticmethod
    def _ensure_pos(bunch):
        """Build `bunch.pos` by stacking the 1D `x` and `y` arrays when it is missing."""
        if 'pos' not in bunch:
            assert bunch.x.ndim == 1
            assert bunch.x.shape == bunch.y.shape
            bunch.pos = np.c_[bunch.x, bunch.y]
        assert bunch.pos.ndim == 2

    def _get_split_cluster_data(self, bunchs):
        """Get the data when there is one Bunch per cluster.

        Each Bunch (paired with the selected cluster at the same index) is given
        a `cluster_id`, a `pos` array if missing, and a per-cluster `color`.
        """
        for i, (cluster_id, bunch) in enumerate(zip(self.cluster_ids, bunchs)):
            bunch.cluster_id = cluster_id
            self._ensure_pos(bunch)
            assert 'spike_ids' in bunch
            bunch.color = selected_cluster_color(i, .75)
        return bunchs

    def _get_collated_cluster_data(self, bunch):
        """Get the data when there is a single Bunch for all selected clusters.

        The Bunch's `spike_clusters` attribute is used to color each point
        according to its cluster.
        """
        assert 'spike_ids' in bunch
        self._ensure_pos(bunch)
        bunch.color = spike_colors(bunch.spike_clusters, self.cluster_ids)
        return bunch

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids.

        Parameters
        ----------
        load_all : bool
            Forwarded to `coords` when it accepts a `load_all` keyword; otherwise
            a warning is logged and `coords` is called without it.

        Raises
        ------
        ValueError
            If `coords` returns something other than a Bunch/dict, a list, a
            tuple, or a falsy value (treated as "no data").
        """
        if not load_all:
            bunchs = self.coords(self.cluster_ids) or ()
        elif 'load_all' in inspect.signature(self.coords).parameters:
            bunchs = self.coords(self.cluster_ids, load_all=load_all) or ()
        else:
            logger.warning(
                "The view `%s` may not load all spikes when using the lasso for splitting.",
                self.__class__.__name__)
            # Normalize a `None`/empty result like the branches above, so that
            # "no data" does not end up in the ValueError below.
            bunchs = self.coords(self.cluster_ids) or ()
        if isinstance(bunchs, dict):
            return [self._get_collated_cluster_data(bunchs)]
        elif isinstance(bunchs, (list, tuple)):
            return self._get_split_cluster_data(bunchs)
        raise ValueError("The output of `coords()` should be either a list of Bunch, or a Bunch.")

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data()
        # Hide the visual if there is no data.
        if not bunchs:
            self.visual.hide()
            self.canvas.update()
            return
        self.data_bounds = self._get_data_bounds(bunchs)

        self.visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.visual.show()

        self._update_axes()
        self.canvas.update()
# Code example #18
# 0
class ClusterScatterView(MarkerSizeMixin, BaseColorView, BaseGlobalView,
                         ManualClusteringView):
    """This view shows all clusters in a customizable scatter plot.

    Constructor
    -----------

    cluster_ids : array-like
        Full list of cluster ids shown in the view.
    cluster_info: function
        Maps cluster_id => Bunch() with attributes.
    bindings: dict
        Maps plot dimension to cluster attributes.

    """
    _default_position = 'right'
    _scaling = 1.
    _default_alpha = .75
    _min_marker_size = 5.0
    _max_marker_size = 30.0
    # Plot dimensions that can be bound to a cluster attribute.
    _dims = ('x_axis', 'y_axis', 'size')

    # NOTE: this is not the actual marker size, but a scaling factor for the normal marker size.
    _marker_size = 1.
    _default_marker_size = 1.

    # Cluster attribute bound to each dimension, and per-dimension log scaling.
    x_axis = ''
    y_axis = ''
    size = ''
    x_axis_log_scale = False
    y_axis_log_scale = False
    size_log_scale = False

    # Non-string cluster attributes available for the bindings; replaced by
    # `set_fields()`. The class-level default keeps `attach()` and
    # `change_bindings()` from failing before any cluster data was loaded.
    fields = ()

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'switch_color_scheme': 'shift+wheel',
        'select_cluster': 'click',
        'select_more': 'shift+click',
        'add_to_lasso': 'control+left click',
        'clear_lasso': 'control+right click',
    }

    default_snippets = {
        'set_x_axis': 'csx',
        'set_y_axis': 'csy',
        'set_size': 'css',
    }

    def __init__(self,
                 cluster_ids=None,
                 cluster_info=None,
                 bindings=None,
                 **kwargs):
        super(ClusterScatterView, self).__init__(**kwargs)
        # Register the bindings and scaling options as view state attributes.
        self.state_attrs += (
            'scaling',
            'x_axis',
            'y_axis',
            'size',
            'x_axis_log_scale',
            'y_axis_log_scale',
            'size_log_scale',
        )
        self.local_state_attrs += ()

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        bindings = bindings or {}
        self.cluster_info = cluster_info
        # Update self.x_axis, self.y_axis and self.size; binding keys that are
        # not plot dimensions are ignored.
        self.__dict__.update(
            {k: v for k, v in bindings.items() if k in self._dims})

        # Size range computed initially so that it doesn't change during the course of the session.
        self._size_min = self._size_max = None

        # Full list of clusters.
        self.all_cluster_ids = cluster_ids

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # The labels are not affected by pan/zoom.
        self.label_visual = TextVisual()
        self.canvas.add_visual(self.label_visual,
                               exclude_origins=(self.canvas.panzoom, ))

        self.marker_positions = self.marker_colors = self.marker_sizes = None

    def _update_labels(self):
        """Draw the names of the x and y axis fields as text labels."""
        self.label_visual.set_data(pos=[[-1, -1], [1, 1]],
                                   text=[self.x_axis, self.y_axis],
                                   anchor=[[1.25, 3], [-3, -1.25]])

    # Data access
    # -------------------------------------------------------------------------

    @property
    def bindings(self):
        """Dictionary {dimension: cluster attribute name} for every plot dimension."""
        return {k: getattr(self, k) for k in self._dims}

    def get_cluster_data(self, cluster_id):
        """Return the data of one cluster.

        Returns a dict {dimension: value}; attributes missing from the cluster
        info default to 0.
        """
        data = self.cluster_info(cluster_id)
        return {k: data.get(v, 0.) for k, v in self.bindings.items()}

    def get_clusters_data(self, cluster_ids):
        """Return the data of a set of clusters, as a dictionary {cluster_id: Bunch}."""
        return {
            cluster_id: self.get_cluster_data(cluster_id)
            for cluster_id in cluster_ids
        }

    def set_cluster_ids(self, all_cluster_ids):
        """Update the cluster data by specifying the list of all cluster ids.

        An empty list only records the ids; the marker data is left unchanged.
        """
        self.all_cluster_ids = all_cluster_ids
        if len(all_cluster_ids) == 0:
            return
        self.prepare_data()

    # Data preparation
    # -------------------------------------------------------------------------

    def set_fields(self):
        """Set `self.fields` to the sorted non-string attributes of the first cluster."""
        data = self.cluster_info(self.all_cluster_ids[0])
        self.fields = sorted(data.keys())
        self.fields = [f for f in self.fields if not isinstance(data[f], str)]

    def prepare_data(self):
        """Prepare the marker position, size, and color from the cluster information."""
        self.prepare_position()
        self.prepare_size()
        self.prepare_color()

    def prepare_position(self):
        """Compute the marker positions and the data bounds."""
        self.cluster_data = self.get_clusters_data(self.all_cluster_ids)

        # Get the list of fields returned by cluster_info.
        self.set_fields()

        # Create the x array (None values are treated as 0).
        x = np.array([
            self.cluster_data[cluster_id]['x_axis'] or 0.
            for cluster_id in self.all_cluster_ids
        ])
        if self.x_axis_log_scale:
            # Shift so that the minimum maps to log(1) = 0.
            x = np.log(1.0 + x - x.min())

        # Create the y array (None values are treated as 0).
        y = np.array([
            self.cluster_data[cluster_id]['y_axis'] or 0.
            for cluster_id in self.all_cluster_ids
        ])
        if self.y_axis_log_scale:
            y = np.log(1.0 + y - y.min())

        self.marker_positions = np.c_[x, y]

        # Update the data bounds.
        self.data_bounds = (x.min(), y.min(), x.max(), y.max())

    def prepare_size(self):
        """Compute the marker sizes, mapped into [_min_marker_size, _max_marker_size]."""
        size = np.array([
            self.cluster_data[cluster_id]['size'] or 1.
            for cluster_id in self.all_cluster_ids
        ])
        # Log scale for the size.
        if self.size_log_scale:
            size = np.log(1.0 + size - size.min())
        # Find the size range (only once, unless reset by the caller).
        if self._size_min is None:
            self._size_min, self._size_max = size.min(), size.max()
        m, M = self._size_min, self._size_max
        # Normalize the marker size; `or 1.0` avoids dividing by zero.
        size = (size - m) / ((M - m) or 1.0)  # size is in [0, 1]
        ms, Ms = self._min_marker_size, self._max_marker_size
        size = ms + size * (Ms - ms)  # now, size is in [ms, Ms]
        self.marker_sizes = size

    def prepare_color(self):
        """Compute the marker colors."""
        colors = self.get_cluster_colors(self.all_cluster_ids,
                                         self._default_alpha)
        self.marker_colors = colors

    # Marker size
    # -------------------------------------------------------------------------

    @property
    def marker_size(self):
        """Size of the spike markers, in pixels."""
        return self._marker_size

    @marker_size.setter
    def marker_size(self, val):
        # We override this method so as to use self._marker_size as a scaling factor, not
        # as an actual fixed marker size.
        self._marker_size = val
        self._set_marker_size()
        self.canvas.update()

    def _set_marker_size(self):
        """Apply the scaling factor to the per-cluster sizes, if they are computed."""
        if self.marker_sizes is not None:
            self.visual.set_marker_size(self.marker_sizes * self._marker_size)

    # Plotting functions
    # -------------------------------------------------------------------------

    def update_color(self):
        """Update the cluster colors depending on the current color scheme."""
        self.prepare_color()
        self.visual.set_color(self.marker_colors)
        self.canvas.update()

    def update_select_color(self):
        """Update the cluster colors after the cluster selection changes."""
        if self.marker_colors is None:
            return
        selected_clusters = self.cluster_ids
        if selected_clusters is not None and len(selected_clusters) > 0:
            # Work on a copy so that the base colors are preserved.
            colors = _add_selected_clusters_colors(selected_clusters,
                                                   self.all_cluster_ids,
                                                   self.marker_colors.copy())
            self.visual.set_color(colors)
            self.canvas.update()

    def plot(self, **kwargs):
        """Make the scatter plot.

        Does nothing when no marker data exists and there are no clusters to
        compute it from.
        """
        if self.marker_positions is None:
            if self.all_cluster_ids is None or len(self.all_cluster_ids) == 0:
                return
            self.prepare_data()
        self.visual.set_data(
            pos=self.marker_positions,
            color=self.marker_colors,
            size=self.marker_sizes *
            self._marker_size,  # marker size scaling factor
            data_bounds=self.data_bounds)
        self.canvas.axes.reset_data_bounds(self.data_bounds)
        self.canvas.update()

    def change_bindings(self, **kwargs):
        """Change the bindings."""
        # Ensure the specified fields are valid (unknown fields are ignored).
        kwargs = {k: v for k, v in kwargs.items() if v in self.fields}
        assert set(kwargs.keys()) <= set(self._dims)
        # Reset the size scaling.
        if 'size' in kwargs:
            self._size_min = self._size_max = None
        self.__dict__.update(kwargs)
        self._update_labels()
        self.update_status()
        self.prepare_data()
        self.plot()

    def toggle_log_scale(self, dim, checked):
        """Toggle logarithmic scaling for one of the dimensions."""
        # Force the size range to be recomputed by prepare_size().
        self._size_min = None
        setattr(self, '%s_log_scale' % dim, checked)
        self.prepare_data()
        self.plot()
        self.canvas.update()

    def set_x_axis(self, field):
        """Set the dimension for the x axis."""
        self.change_bindings(x_axis=field)

    def set_y_axis(self, field):
        """Set the dimension for the y axis."""
        self.change_bindings(y_axis=field)

    def set_size(self, field):
        """Set the dimension for the marker size."""
        self.change_bindings(size=field)

    # Misc functions
    # -------------------------------------------------------------------------

    def attach(self, gui):
        """Attach the GUI."""
        super(ClusterScatterView, self).attach(gui)

        # Load the cluster data first: it computes `self.fields`, which the
        # binding menus below are built from.
        if self.all_cluster_ids is not None:
            self.set_cluster_ids(self.all_cluster_ids)

        def _make_action(dim, name):
            def callback():
                self.change_bindings(**{dim: name})

            return callback

        def _make_log_toggle(dim):
            def callback(checked):
                self.toggle_log_scale(dim, checked)

            return callback

        # Change the bindings.
        for dim in self._dims:
            view_submenu = 'Change %s' % dim

            # Change to every cluster info.
            for name in self.fields:
                self.actions.add(_make_action(dim, name),
                                 show_shortcut=False,
                                 name='Change %s to %s' % (dim, name),
                                 view_submenu=view_submenu)

            # Toggle logarithmic scale.
            self.actions.separator(view_submenu=view_submenu)
            self.actions.add(_make_log_toggle(dim),
                             checkable=True,
                             view_submenu=view_submenu,
                             name='Toggle log scale for %s' % dim,
                             show_shortcut=False,
                             checked=getattr(self, '%s_log_scale' % dim))

        self.actions.separator()
        self.actions.add(self.set_x_axis,
                         prompt=True,
                         prompt_default=lambda: self.x_axis)
        self.actions.add(self.set_y_axis,
                         prompt=True,
                         prompt_default=lambda: self.y_axis)
        self.actions.add(self.set_size,
                         prompt=True,
                         prompt_default=lambda: self.size)

        connect(self.on_select)
        connect(self.on_cluster)

        @connect(sender=self.canvas)
        def on_lasso_updated(sender, polygon):
            if len(polygon) < 3 or self.marker_positions is None:
                return
            pos = range_transform([self.data_bounds], [NDC],
                                  self.marker_positions)
            ind = self.canvas.lasso.in_polygon(pos)
            # np.asarray so that the lasso result can index a plain list too.
            cluster_ids = np.asarray(self.all_cluster_ids)[ind]
            emit("request_select", self, list(cluster_ids))

        @connect(sender=self)
        def on_close_view(view_, gui):
            """Unconnect all events when closing the view."""
            unconnect(self.on_select)
            unconnect(self.on_cluster)
            unconnect(on_lasso_updated)

        self._update_labels()

    def on_select(self, *args, **kwargs):
        super(ClusterScatterView, self).on_select(*args, **kwargs)
        self.update_select_color()

    def on_cluster(self, sender, up):
        """Refresh the plot when the full list of clusters changes."""
        if 'all_cluster_ids' in up:
            self.set_cluster_ids(up.all_cluster_ids)
            self.plot()

    @property
    def status(self):
        return 'Size: %s. Color scheme: %s.' % (self.size, self.color_scheme)

    # Interactivity
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by clicking on its template waveform."""
        # Control+click is reserved for the lasso.
        if 'Control' in e.modifiers:
            return
        # No markers yet: there is no cluster to pick.
        if self.marker_positions is None or len(self.marker_positions) == 0:
            return
        b = e.button
        pos = self.canvas.window_to_ndc(e.pos)
        marker_pos = range_transform([self.data_bounds], [NDC],
                                     self.marker_positions)
        # Pick the marker closest to the click (squared Euclidean distance in NDC).
        cluster_rel = np.argmin(((marker_pos - pos)**2).sum(axis=1))
        cluster_id = self.all_cluster_ids[cluster_rel]
        logger.debug("Click on cluster %d with button %s.", cluster_id, b)
        if 'Shift' in e.modifiers:
            emit('select_more', self, [cluster_id])
        else:
            emit('request_select', self, [cluster_id])