Exemple #1
0
    def __init__(
            self, templates=None, channel_ids=None, channel_labels=None,
            cluster_ids=None, cluster_color_selector=None, **kwargs):
        """Create the template view.

        Parameters
        ----------
        templates : object
            Stored as-is on `self.templates`; presumably a function mapping cluster ids to
            template data (TODO confirm against the caller).
        channel_ids : sized sequence
            Full list of channels. Required in practice: its length is taken unconditionally.
        channel_labels : sequence of str, optional
            One label per channel. Defaults to the strings '0', '1', ..., str(n_channels - 1).
        cluster_ids : sequence, optional
            Full list of clusters, forwarded to `set_cluster_ids()` when not None.
        cluster_color_selector : object, optional
            Stored on `self.cluster_color_selector`.
        **kwargs
            Forwarded to the base class constructor.

        """
        super(TemplateView, self).__init__(**kwargs)
        self.state_attrs += ()
        self.local_state_attrs += ('scaling',)

        self.cluster_color_selector = cluster_color_selector

        # Channels: the full list, its size, and exactly one label per channel.
        self.channel_ids = channel_ids
        self.n_channels = n_channels = len(channel_ids)
        if channel_labels is None:
            channel_labels = ['%d' % idx for idx in range(n_channels)]
        self.channel_labels = channel_labels
        assert len(self.channel_labels) == self.n_channels
        # TODO: show channel and cluster labels

        # Clusters are optional at construction time.
        if cluster_ids is not None:
            self.set_cluster_ids(cluster_ids)

        canvas = self.canvas
        canvas.set_layout('grid', box_bounds=[[-1, -1, +1, +1]], has_clip=False)
        canvas.enable_axes()
        self.templates = templates

        self.visual = PlotVisual()
        canvas.add_visual(self.visual)
        # Maps cluster_id -> box_index; used to quickly reorder the clusters.
        self._cluster_box_index = {}

        self.select_visual = PlotVisual()
        canvas.add_visual(self.select_visual)
Exemple #2
0
    def __init__(self,
                 waveforms=None,
                 waveforms_type=None,
                 sample_rate=None,
                 **kwargs):
        """Create the waveform view.

        Parameters
        ----------
        waveforms : dict or object
            Either a dictionary `{waveforms_type_name: waveforms}` or a single waveforms
            object, which is then wrapped as `{'waveforms': waveforms}`. Presumably each value
            is a function returning waveform data (TODO confirm with callers).
        waveforms_type : str, optional
            Initially selected waveforms type, passed as-is to `RotatingProperty.set()`.
        sample_rate : float
            Sampling rate. Required and must be strictly positive.
        **kwargs
            Forwarded to the base class constructor.

        Raises
        ------
        AssertionError
            If `sample_rate` is missing or not positive, or if the current waveforms type is
            not one of the keys of `waveforms`.

        """
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # Test for None explicitly: in Python 3, `None > 0.` raises a TypeError, which would
        # bypass the assertion message and leave the caller with an unhelpful error.
        assert sample_rate is not None and sample_rate > 0., \
            "The sample rate must be provided to the waveform view."

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(marker='vbar',
                                                color=self.ax_color,
                                                size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)
Exemple #3
0
    def __init__(self, cluster_stat=None):
        """Set up a one-subplot stacked layout and register the view's empty visuals.

        Parameters
        ----------
        cluster_stat : function
            Stored on `self.cluster_stat`.

        """
        super(HistogramView, self).__init__()
        self.state_attrs += self._state_attrs
        self.local_state_attrs += self._local_state_attrs

        canvas = self.canvas
        canvas.set_layout(layout='stacked', n_plots=1)
        canvas.enable_axes()

        self.cluster_stat = cluster_stat

        def _register(visual):
            # Add a freshly built visual to the canvas and hand it back.
            canvas.add_visual(visual)
            return visual

        # Visuals are added in this order: histogram, plot, white text.
        self.visual = _register(HistogramVisual())
        self.plot_visual = _register(PlotVisual())
        self.text_visual = _register(TextVisual(color=(1., 1., 1., 1.)))
Exemple #4
0
    def _create_visuals(self):
        """Create a stacked layout with one subplot per channel, and the trace, waveform and
        text visuals."""
        canvas = self.canvas
        canvas.set_layout('stacked', n_plots=self.n_channels)
        canvas.enable_axes(show_y=False)

        self.trace_visual = UniformPlotVisual()
        canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        canvas.add_visual(self.waveform_visual)

        # Text: fix its x coordinate, then thin it out in the fragment shader. When
        # n_boxes >= 50 * u_zoom.y, only boxes whose index is a multiple of
        # int(n_boxes / (50 * u_zoom.y)) keep their text; the others are discarded.
        text = TextVisual()
        _fix_coordinate_in_visual(text, 'x')
        discard_expr = (
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        text.inserter.add_varying('float', 'v_discard', discard_expr)
        text.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.text_visual = text
        canvas.add_visual(text)
Exemple #5
0
    def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
        """Create the waveform view.

        Parameters
        ----------
        waveforms : dict or object
            Either a dictionary `{waveforms_type_name: waveforms}` (which must contain the
            key 'waveforms'), or a single object wrapped as `{'waveforms': waveforms}`.
        waveforms_type : str, optional
            Initially selected waveforms type; defaults to the first key of `waveforms`.
        **kwargs
            Forwarded to the base class constructor.

        """
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        canvas = self.canvas
        canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
        canvas.enable_axes()

        self._box_scaling = (1., 1.)
        self._probe_scaling = (1., 1.)

        boxed = canvas.boxed
        self.box_pos = np.array(boxed.box_pos)
        self.box_size = np.array(boxed.box_size)
        self._update_boxes()

        # A single waveforms source becomes a one-entry dictionary.
        if not isinstance(waveforms, dict):
            waveforms = {'waveforms': waveforms}
        assert waveforms
        self.waveforms = waveforms
        self.waveforms_types = list(waveforms)
        # Current waveforms type: the explicit choice, else the first available type.
        self.waveforms_type = waveforms_type or self.waveforms_types[0]
        assert self.waveforms_type in waveforms
        assert 'waveforms' in waveforms

        self.text_visual = TextVisual()
        canvas.add_visual(self.text_visual)

        self.waveform_visual = PlotVisual()
        canvas.add_visual(self.waveform_visual)
Exemple #6
0
    def _create_visuals(self):
        """Create a stacked layout with one subplot per channel, and the trace (optionally
        colour-graded), waveform and text visuals."""
        canvas = self.canvas
        canvas.set_layout('stacked', n_plots=self.n_channels)
        canvas.enable_axes(show_y=False)

        trace = UniformPlotVisual()
        # Optional colour gradient: when both colours are set, the fragment colour is mixed
        # from trace_color_0 to trace_color_1 according to v_signal_index / n_channels.
        color_start, color_end = self.trace_color_0, self.trace_color_1
        if color_start and color_end:
            gradient = 'gl_FragColor.rgb = mix(vec3%s, vec3%s, (v_signal_index / %d));' % (
                color_start, color_end, self.n_channels)
            trace.inserter.insert_frag(gradient, 'end')
        self.trace_visual = trace
        canvas.add_visual(trace)

        self.waveform_visual = PlotVisual()
        canvas.add_visual(self.waveform_visual)

        # Text: fix its x coordinate, then thin it out in the fragment shader. When
        # n_boxes >= 50 * u_zoom.y, only boxes whose index is a multiple of
        # int(n_boxes / (50 * u_zoom.y)) keep their text.
        text = TextVisual()
        _fix_coordinate_in_visual(text, 'x')
        text.inserter.add_varying(
            'float', 'v_discard',
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        text.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.text_visual = text
        canvas.add_visual(text)
    def __init__(self, model=None):
        """Create the waveform clustering view (axes and lasso enabled).

        Parameters
        ----------
        model : object, optional
            Stored on `self.model`.

        """
        super(WaveformClusteringView, self).__init__()
        canvas = self.canvas
        canvas.enable_axes()
        canvas.enable_lasso()

        # The text visual is registered with the panzoom excluded from its origins,
        # presumably so the text is not moved by pan/zoom (confirm against `add_visual`).
        self.text_visual = TextVisual()
        canvas.add_visual(self.text_visual, exclude_origins=(canvas.panzoom,))

        self.model = model

        # NOTE(review): 0.195 looks like an amplitude scale factor (possibly uV per ADC
        # unit) — confirm against the recording system.
        self.gain = 0.195
        self.Fs = 30  # kHz

        self.visual = PlotVisual()
        canvas.add_visual(self.visual)

        # Set both the current pan/zoom and the defaults used when resetting it.
        panzoom = canvas.panzoom
        panzoom.zoom = panzoom._default_zoom = (.97, .95)
        panzoom.pan = panzoom._default_pan = (-.01, 0)

        # Selection state, empty until something is selected.
        self.cluster_ids = self.cluster_id = self.channel_ids = None
        self.wavefs = self.current_channel_idx = None
Exemple #8
0
    def __init__(self, templates=None):
        """Create the view.

        Typically, the constructor takes as arguments *functions* that take as input one or
        several cluster ids, and return as many Bunch instances which contain the data as
        NumPy arrays. Many such functions are defined in the TemplateController.

        Parameters
        ----------
        templates : function
            Stored on `self.templates`.

        """
        super(MyOpenGLView, self).__init__()

        # The View instance contains a special `canvas` object which is a `PlotCanvas`
        # instance. This class derives from `BaseCanvas`, which itself derives from the PyQt5
        # `QOpenGLWindow`. The canvas represents a rectangular black window where you can draw
        # geometric objects with OpenGL.
        #
        # phy uses the notion of **Layout** that lets you organize graphical elements in
        # different subplots. These subplots can be organized in several ways:
        #
        # * Grid layout: a `(n_rows, n_cols)` grid of subplots (example: FeatureView).
        # * Boxed layout: boxes arbitrarily located (example: WaveformView, using the probe
        #   geometry).
        # * Stacked layout: one column with `n_boxes` subplots (example: TraceView, one row
        #   per channel).
        #
        # In this example, we use the stacked layout, with one subplot per cluster. This
        # number changes at each cluster selection, depending on the number of selected
        # clusters, but initially we just use 1 subplot.
        canvas = self.canvas
        canvas.set_layout('stacked', n_plots=1)

        self.templates = templates

        # phy uses the notion of **Visual**: a graphical element rendered with a single type
        # of primitive. phy provides many visuals:
        #
        # * PlotVisual (plots)
        # * ScatterVisual (points with a given marker type and different colors and sizes)
        # * LineVisual (for line segments)
        # * HistogramVisual
        # * PolygonVisual
        # * TextVisual
        # * ImageVisual
        #
        # Each visual comes with a single OpenGL program, defined by a vertex shader and a
        # fragment shader, written in a C-like language called GLSL. A visual also comes with
        # a primitive type: points, line segments, or triangles. This is all a GPU is able to
        # render, but the position and the color of these primitives can be entirely
        # customized in the shaders. The vertex shader acts on data arrays represented as
        # NumPy arrays.
        #
        # These low-level details are hidden by the visuals abstraction, so it is unlikely
        # that you'll ever need to write your own visual. In ManualClusteringViews, you
        # typically define one or several visuals. For example, to add text you would add
        # `self.text_visual = TextVisual()`.
        plot = PlotVisual()
        self.visual = plot

        # For internal reasons, all visuals (empty for now) must be added directly to the
        # canvas in the view's constructor. Later, `visual.set_data()` updates the visual's
        # data to display something in the figure.
        canvas.add_visual(plot)
Exemple #9
0
class MyOpenGLView(ManualClusteringView):
    """All OpenGL views derive from ManualClusteringView."""

    def __init__(self, templates=None):
        """Create the view.

        Typically, the constructor takes as arguments *functions* that take as input one or
        several cluster ids, and return as many Bunch instances which contain the data as
        NumPy arrays. Many such functions are defined in the TemplateController.

        Parameters
        ----------
        templates : function
            Called in `on_select()` as `templates(cluster_id)`. The returned object's `.data`
            attribute must expose a 2D `template` array whose first column is plotted.

        """
        super(MyOpenGLView, self).__init__()

        # The View instance contains a special `canvas` object which is a `PlotCanvas`
        # instance. This class derives from `BaseCanvas`, which itself derives from the PyQt5
        # `QOpenGLWindow`. The canvas represents a rectangular black window where you can draw
        # geometric objects with OpenGL.
        #
        # phy uses the notion of **Layout** that lets you organize graphical elements in
        # different subplots. These subplots can be organized in several ways:
        #
        # * Grid layout: a `(n_rows, n_cols)` grid of subplots (example: FeatureView).
        # * Boxed layout: boxes arbitrarily located (example: WaveformView, using the probe
        #   geometry).
        # * Stacked layout: one column with `n_boxes` subplots (example: TraceView, one row
        #   per channel).
        #
        # In this example, we use the stacked layout, with one subplot per cluster. This
        # number changes at each cluster selection, depending on the number of selected
        # clusters, but initially we just use 1 subplot.
        self.canvas.set_layout('stacked', n_plots=1)

        self.templates = templates

        # phy uses the notion of **Visual**: a graphical element rendered with a single type
        # of primitive. phy provides many visuals:
        #
        # * PlotVisual (plots)
        # * ScatterVisual (points with a given marker type and different colors and sizes)
        # * LineVisual (for line segments)
        # * HistogramVisual
        # * PolygonVisual
        # * TextVisual
        # * ImageVisual
        #
        # Each visual comes with a single OpenGL program, defined by a vertex shader and a
        # fragment shader, written in a C-like language called GLSL. A visual also comes with
        # a primitive type: points, line segments, or triangles. This is all a GPU is able to
        # render, but the position and the color of these primitives can be entirely
        # customized in the shaders. The vertex shader acts on data arrays represented as
        # NumPy arrays.
        #
        # These low-level details are hidden by the visuals abstraction, so it is unlikely
        # that you'll ever need to write your own visual. In ManualClusteringViews, you
        # typically define one or several visuals. For example, to add text you would add
        # `self.text_visual = TextVisual()`.
        self.visual = PlotVisual()

        # For internal reasons, all visuals (empty for now) must be added directly to the
        # canvas in the view's constructor. Later, `visual.set_data()` (or the batch API used
        # in `on_select()`) updates the visual's data to display something in the figure.
        self.canvas.add_visual(self.visual)

    def on_select(self, cluster_ids=(), **kwargs):
        """Plot the peak-channel template of every selected cluster, one subplot each.

        The main method to implement in ManualClusteringView is `on_select()`, called
        whenever new clusters are selected.

        *Note*: `cluster_ids` contains the clusters selected in the cluster view, followed by
        clusters selected in the similarity view.

        Parameters
        ----------
        cluster_ids : sequence
            Selected cluster ids (list, tuple or 1D NumPy array). Nothing is drawn when it
            is None or empty.
        **kwargs
            Ignored.

        """
        # This method should always start with these few lines of code. An explicit length
        # check (rather than `not cluster_ids`) also accepts NumPy arrays, whose truth value
        # is ambiguous when they hold more than one element.
        self.cluster_ids = cluster_ids
        if cluster_ids is None or len(cluster_ids) == 0:
            return

        # The number of boxes in the stacked layout is the number of selected clusters.
        self.canvas.stacked.n_boxes = len(cluster_ids)

        # Obtain the template data of every selected cluster.
        bunches = {
            cluster_id: self.templates(cluster_id).data
            for cluster_id in cluster_ids
        }

        # For performance reasons, it is best to use as few visuals as possible. Here we want
        # one waveform template per subplot, drawn with a single visual covering all
        # subplots at once: this is the key to good OpenGL performance in Python, at the cost
        # of a more complicated programming interface.
        #
        # In principle, we would have to concatenate the x and y coordinates of all subplots
        # and pass them to `self.visual.set_data()`, which is tedious. Instead, phy uses the
        # notion of **batch**: for each subplot we set *partial data*, which is concatenated
        # once the loop over all clusters is done, in the special call
        # `self.canvas.update_visual(self.visual)`.
        #
        # `visual.reset_batch()` must be called before constructing a batch.
        self.visual.reset_batch()

        for idx, cluster_id in enumerate(cluster_ids):
            bunch = bunches[cluster_id]

            # Keep only the peak channel. `bunch.template` is a 2D array
            # `(n_samples, n_channels)` where `n_channels` is the number of "best" channels
            # for the cluster. The channels are sorted by decreasing template amplitude, so
            # the first one is the peak channel. The channel ids are in `bunch.channel_ids`.
            y = bunch.template[:, 0]

            # On the x axis, use values ranging from -1 to 1: the standard viewport in
            # OpenGL and phy.
            x = np.linspace(-1., 1., len(y))

            # phy requires the x and y range of the plots to be explicit. `data_bounds` is an
            # `(xmin, ymin, xmax, ymax)` tuple giving the lower-left and upper-right corners
            # of a rectangle. By default, the data bounds of the entire view are
            # (-1, -1, 1, 1), also called normalized device coordinates. OpenGL eventually
            # uses this coordinate system for display, but phy provides a transform system to
            # convert between coordinate systems, both on the CPU and the GPU.
            #
            # Here, the x range is (-1, 1) and the y range is (m, M), respectively the min and
            # max of the template.
            m, M = y.min(), y.max()
            data_bounds = (-1, m, +1, M)

            # Color of the idx-th selected cluster: an RGBA 4-tuple with values between 0
            # and 1 (the alpha channel, i.e. transparency, is 1 by default).
            color = selected_cluster_color(idx)

            # The plot visual takes the x and y coordinates of the points, the color, and the
            # data bounds. The special keyword argument `box_index` is the subplot index: in
            # the stacked layout, an integer counting subplots from top to bottom (in the
            # grid layout it would be a `(row, col)` pair).
            self.visual.add_batch_data(x=x,
                                       y=y,
                                       color=color,
                                       data_bounds=data_bounds,
                                       box_index=idx)

        # Build the data to upload to the GPU by concatenating the partial data set in
        # `add_batch_data()`.
        self.canvas.update_visual(self.visual)

        # After updating the data on the GPU, refresh the canvas.
        self.canvas.update()
Exemple #10
0
class HistogramView(ScalingMixin, ManualClusteringView):
    """This view displays a histogram for every selected cluster, along with a possible plot
    and some text. To be overridden.

    Constructor
    -----------

    cluster_stat : function
        Maps `cluster_id` to `Bunch(data (1D array), plot (1D array), text)`. The bunch may
        also provide `x_min` / `x_max`, used as the histogram range when the view's own
        `x_min` / `x_max` are None.

    """

    _default_position = 'right'
    cluster_ids = ()

    # Number of bins in the histogram.
    n_bins = 100

    # Minimum value on the x axis (determines the range of the histogram).
    # If None, the cluster statistic's `x_min` (when provided) or `data.min()` is used.
    x_min = None

    # Maximum value on the x axis (determines the range of the histogram).
    # If None, the cluster statistic's `x_max` (when provided) or `data.max()` is used.
    x_max = None

    # Unit of the bin in the set_bin_size, set_x_min, set_x_max actions.
    bin_unit = 's'  # s (seconds) or ms (milliseconds)

    # The snippets to update this view are `hn` (number of bins), `hb` (bin size), and
    # `hmin` / `hmax` (x range). The character `h` can be customized by child classes.
    alias_char = 'h'

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
    }

    default_snippets = {
        'set_n_bins': '%sn' % alias_char,
        'set_bin_size (%s)' % bin_unit: '%sb' % alias_char,
        'set_x_min (%s)' % bin_unit: '%smin' % alias_char,
        'set_x_max (%s)' % bin_unit: '%smax' % alias_char,
    }

    _state_attrs = ('n_bins', 'x_min', 'x_max')
    _local_state_attrs = ()

    def __init__(self, cluster_stat=None):
        super(HistogramView, self).__init__()
        self.state_attrs += self._state_attrs
        self.local_state_attrs += self._local_state_attrs
        self.canvas.set_layout(layout='stacked', n_plots=1)
        self.canvas.enable_axes()

        self.cluster_stat = cluster_stat

        self.visual = HistogramVisual()
        self.canvas.add_visual(self.visual)

        self.plot_visual = PlotVisual()
        self.canvas.add_visual(self.plot_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)

    def _plot_cluster(self, bunch):
        """Add one cluster's histogram, optional plot and optional text to the batches."""
        assert bunch
        n_bins = self.n_bins
        assert n_bins >= 0

        # Update the visual's data.
        self.visual.add_batch_data(
            hist=bunch.histogram, ylim=bunch.ylim, color=bunch.color, box_index=bunch.index)

        # Optional plot, spread over the histogram's x range.
        plot = bunch.get('plot', None)
        if plot is not None:
            x = np.linspace(self.x_min, self.x_max, len(plot))
            self.plot_visual.add_batch_data(
                x=x, y=plot, color=(1, 1, 1, 1), data_bounds=self.data_bounds,
                box_index=bunch.index,
            )

        text = bunch.get('text', None)
        if not text:
            return
        # Support multiline text: one line per anchor, stacked downwards.
        text = text.splitlines()
        n = len(text)
        self.text_visual.add_batch_data(
            text=text, pos=[(-1, .8)] * n, anchor=[(1, -1 - 2 * i) for i in range(n)],
            box_index=bunch.index,
        )

    def get_clusters_data(self, load_all=None):
        """Return one bunch per selected cluster with non-empty data, histogram included.

        The histogram range `self.x_min` / `self.x_max` is filled in on the fly when unset
        (None), from the bunch's own `x_min` / `x_max` or else from the data extrema.
        """
        bunchs = []
        for i, cluster_id in enumerate(self.cluster_ids):
            bunch = self.cluster_stat(cluster_id)
            if not bunch.data.size:
                continue
            # Fill in the range only when it is unset. Test against None rather than
            # truthiness: a legitimate bound of 0 (e.g. after `set_x_min(0)`) must be kept
            # instead of being silently replaced by the data minimum.
            if self.x_min is None:
                x_min = bunch.get('x_min', None)
                self.x_min = x_min if x_min is not None else bunch.data.min()
            if self.x_max is None:
                x_max = bunch.get('x_max', None)
                self.x_max = x_max if x_max is not None else bunch.data.max()
            self.x_min = min(self.x_min, self.x_max)
            assert self.x_min is not None
            assert self.x_max is not None
            assert self.x_min <= self.x_max

            # Compute the histogram.
            bunch.histogram = _compute_histogram(
                bunch.data, x_min=self.x_min, x_max=self.x_max, n_bins=self.n_bins)
            bunch.ylim = bunch.histogram.max()

            bunch.color = selected_cluster_color(i)
            bunch.index = i
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        return bunchs

    def _get_data_bounds(self, bunchs):
        # Get the axes data bounds (the last subplot's extended n_cluster times on the y axis).
        ylim = max(bunch.ylim for bunch in bunchs) if bunchs else 1
        return (self.x_min, 0, self.x_max, ylim * len(self.cluster_ids))

    def plot(self, **kwargs):
        """Update the view with the selected clusters."""
        bunchs = self.get_clusters_data()
        self.data_bounds = self._get_data_bounds(bunchs)

        self.canvas.stacked.n_boxes = len(self.cluster_ids)

        self.visual.reset_batch()
        self.plot_visual.reset_batch()
        self.text_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.plot_visual)
        self.canvas.update_visual(self.text_visual)

        self._update_axes()
        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(HistogramView, self).attach(gui)

        self.actions.add(
            self.set_n_bins, alias=self.alias_char + 'n',
            prompt=True, prompt_default=lambda: self.n_bins)
        self.actions.add(
            self.set_bin_size, alias=self.alias_char + 'b',
            prompt=True, prompt_default=lambda: self.bin_size)
        self.actions.add(
            self.set_x_min, alias=self.alias_char + 'min',
            prompt=True, prompt_default=lambda: self.x_min)
        self.actions.add(
            self.set_x_max, alias=self.alias_char + 'max',
            prompt=True, prompt_default=lambda: self.x_max)
        self.actions.separator()

    @property
    def status(self):
        """Histogram x range formatted in `bin_unit`, e.g. `[0.0ms, 50.0ms]`."""
        f = 1 if self.bin_unit == 's' else 1000
        return '[{:.1f}{u}, {:.1f}{u:s}]'.format(
            (self.x_min or 0) * f, (self.x_max or 0) * f, u=self.bin_unit)

    # Histogram parameters
    # -------------------------------------------------------------------------

    def _get_scaling_value(self):
        return self.x_max

    def _set_scaling_value(self, value):
        # `value` is in seconds (like `x_max`), whereas `set_x_max()` expects the user unit,
        # hence the conversion back to milliseconds in ms mode.
        if self.bin_unit == 'ms':
            value *= 1000
        self.set_x_max(value)

    def set_n_bins(self, n_bins):
        """Set the number of bins in the histogram."""
        self.n_bins = n_bins
        logger.debug("Change number of bins to %d for %s.", n_bins, self.__class__.__name__)
        self.plot()

    @property
    def bin_size(self):
        """Return the bin size (in seconds or milliseconds depending on `self.bin_unit`)."""
        bs = (self.x_max - self.x_min) / self.n_bins
        if self.bin_unit == 'ms':
            bs *= 1000
        return bs

    def set_bin_size(self, bin_size):
        """Set the bin size in the histogram (in `bin_unit`)."""
        assert bin_size > 0
        if self.bin_unit == 'ms':
            bin_size /= 1000
        # `np.round` returns a float; the number of bins must be a proper int (it is saved
        # in the view state and used as a bin count), and at least one bin is required.
        self.n_bins = max(1, int(np.round((self.x_max - self.x_min) / bin_size)))
        logger.debug("Change number of bins to %d for %s.", self.n_bins, self.__class__.__name__)
        self.plot()

    def set_x_min(self, x_min):
        """Set the minimum value on the x axis for the histogram (in `bin_unit`)."""
        if self.bin_unit == 'ms':
            x_min /= 1000
        x_min = min(x_min, self.x_max)
        if x_min == self.x_max:
            return
        self.x_min = x_min
        logger.debug("Change x min to %s for %s.", x_min, self.__class__.__name__)
        self.plot()

    def set_x_max(self, x_max):
        """Set the maximum value on the x axis for the histogram (in `bin_unit`)."""
        if self.bin_unit == 'ms':
            x_max /= 1000
        x_max = max(x_max, self.x_min)
        if x_max == self.x_min:
            return
        self.x_max = x_max
        logger.debug("Change x max to %s for %s.", x_max, self.__class__.__name__)
        self.plot()
Exemple #11
0
class TraceView(ScalingMixin, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_vertical_order : array-like
        Permutation of the channels. This 1D array gives the channel id of all channels from
        top to bottom (or conversely, depending on `origin=top|bottom`).
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None, n_channels=None,
            channel_vertical_order=None, channel_labels=None, **kwargs):

        self.do_show_labels = True
        self.show_all_spikes = False
        self._scaling = 1.

        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces
        self.waveforms = None

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel permutation. `_channel_perm` stores the argsort of the vertical order; it is
        # the box index -> channel index mapping (see `_permute_channels()`).
        self._channel_perm = np.asarray(
            np.arange(n_channels) if channel_vertical_order is None else channel_vertical_order)
        assert self._channel_perm.shape == (n_channels,)
        self._channel_perm = np.argsort(self._channel_perm)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Box and probe scaling.
        self._origin = None

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        self.canvas.set_layout('stacked', origin=self.origin, n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        # Visuals.
        self.trace_visual = UniformPlotVisual()
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        self.canvas.add_visual(self.text_visual)

        # Make a copy of the initial box pos and size. We'll apply the scaling
        # to these quantities.
        self.box_size = np.array(self.canvas.stacked.box_size)

        # Used for spike click: one `(start_time, spike_id, spike_cluster, channel_ids)` tuple
        # per displayed waveform, rebuilt by `plot()`. Initialized before the first `go_to()`
        # so that the waveforms plotted there are not forgotten.
        self._waveform_times = []

        # Initial interval.
        self._interval = None
        self.go_to(duration / 2.)

    @property
    def stacked(self):
        return self.canvas.stacked

    def _permute_channels(self, x, inv=False):
        """Map channel indices to box indices, or box indices to channel indices if `inv`.

        `x` may be a scalar index or an array of indices.
        """
        if inv:
            # `_channel_perm` is the inverse of the forward mapping below.
            return self._channel_perm[x]
        return np.argsort(self._channel_perm)[x]

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        color = color or self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        box_index = self._permute_channels(np.arange(n_ch))
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # The box index depends on the channel.
        box_index = self._permute_channels(bunch.channel_ids)
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=bunch.color,
            data_bounds=self.data_bounds,
        )

    def _plot_labels(self, traces):
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self._permute_channels(ch)
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def set_interval(self, interval=None, change_status=True):
        """Display the traces and spikes in a given interval.

        Parameters
        ----------
        interval : tuple or None
            `(start, end)` in seconds. By default, the current interval is redrawn, which
            reloads the waveforms (e.g. after the cluster selection changed).
        change_status : bool
            Whether to show the new interval in the status bar.

        """
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        # Load the traces.
        traces = self.traces(interval)
        self.waveforms = traces.get('waveforms', [])

        if interval != self._interval:
            logger.debug("Redraw the entire trace view.")
            self._interval = interval
            start, end = interval

            # Set the status message.
            if change_status:
                self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end))

            # Find the data bounds.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

        # Plot the labels. This is done even when the interval is unchanged, so that
        # `toggle_show_labels()` (which calls this method) takes effect immediately.
        if self.do_show_labels:
            self._plot_labels(traces.data)
            self.text_visual.show()
        else:
            self.text_visual.hide()

        # Plot the waveforms (this also rebuilds the spike-click lookup).
        self.plot()

    def on_select(self, cluster_ids=None, **kwargs):
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        # Make sure we call again self.traces() when the cluster selection changes.
        self.set_interval()

    def plot(self, **kwargs):
        """Plot the waveforms."""
        waveforms = self.waveforms
        assert isinstance(waveforms, list)
        # Rebuild the spike-click lookup from the waveforms displayed now. Appending to the
        # previous list would keep duplicates and stale spikes (e.g. from clusters that are no
        # longer selected) whenever the view is redrawn without changing the interval.
        self._waveform_times = []
        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(TraceView, self).attach(gui)

        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(
            self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        self.actions.add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        self.actions.add(self.switch_origin)
        self.actions.separator()

        self.actions.add(
            self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        self.actions.separator()

        self.actions.add(self.go_to_start)
        self.actions.add(self.go_to_end)
        self.actions.separator()

        self.actions.add(self.shift, prompt=True)
        self.actions.add(self.go_right)
        self.actions.add(self.go_left)
        self.actions.separator()

        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        self.actions.add(self.go_to_next_spike)
        self.actions.add(self.go_to_previous_spike)
        self.actions.separator()

        self.set_interval()

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Whether to show the channels from top to bottom (`top` option, the default), or from
        bottom to top (`bottom`)."""
        return self._origin

    @origin.setter
    def origin(self, value):
        self._origin = value
        if self.canvas.layout:
            self.canvas.layout.origin = value

    def switch_origin(self):
        """Switch between top and bottom origin for the channels."""
        self.origin = 'top' if self._origin in ('bottom', None) else 'bottom'

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Time at the center of the window."""
        return sum(self._interval) * .5

    @property
    def interval(self):
        """Interval as `(tmin, tmax)`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval."""
        if self._interval is not None:
            a, b = self._interval
            return (b - a) * .5
        else:
            return self.interval_duration * .5

    def go_to(self, time):
        """Go to a specific time (in seconds)."""
        half_dur = self.half_duration
        self.set_interval((time - half_dur, time + half_dur))

    def shift(self, delay):
        """Shift the interval by a given delay (in seconds)."""
        self.go_to(self.time + delay)

    def go_to_start(self):
        """Go to the start of the recording."""
        self.go_to(0)

    def go_to_end(self):
        """Go to end of the recording."""
        self.go_to(self.duration)

    def go_right(self):
        """Go to right."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(delay)

    def go_left(self):
        """Go to left."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(-delay)

    def _jump_to_spike(self, delta=+1):
        """Jump to next or previous spike from the selected clusters.

        With `delta > 0`, go to the `delta`-th spike strictly after the current time; with
        `delta < 0`, to the `|delta|`-th spike strictly before it. Wraps around at the ends.
        Spike times are assumed to be sorted and in seconds.
        """
        spike_times = self.get_spike_times()
        if spike_times is None or not len(spike_times):
            return
        spike_times = np.asarray(spike_times)
        n = len(spike_times)
        # The window center is rounded to full samples by `_restrict_interval()`, so a spike
        # within one sample of the current time counts as the current spike.
        tol = self.dt
        if delta > 0:
            # `searchsorted(..., side='right')` is the first spike strictly after time + tol.
            ind = np.searchsorted(spike_times, self.time + tol, side='right') + delta - 1
        else:
            # `searchsorted(..., side='left')` is the first spike at or after time - tol, so
            # the previous one is the last spike strictly before the current time.
            ind = np.searchsorted(spike_times, self.time - tol, side='left') + delta
        self.go_to(spike_times[ind % n])

    def go_to_next_spike(self):
        """Jump to the next spike from the first selected cluster."""
        self._jump_to_spike(+1)

    def go_to_previous_spike(self):
        """Jump to the previous spike from the first selected cluster."""
        self._jump_to_spike(-1)

    def toggle_highlighted_spikes(self, checked):
        """Toggle between showing all spikes or selected spikes."""
        self.show_all_spikes = checked
        self.set_interval()

    def widen(self):
        """Increase the interval size."""
        t, h = self.time, self.half_duration
        h *= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    def narrow(self):
        """Decrease the interval size."""
        t, h = self.time, self.half_duration
        h /= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Toggle the display of the channel ids."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        self.set_interval()

    def toggle_auto_scale(self, checked):
        """Toggle automatic scaling of the traces."""
        logger.debug("Set auto scale to %s.", checked)
        self.auto_scale = checked

    # Scaling
    # -------------------------------------------------------------------------

    def _apply_scaling(self):
        self.canvas.layout.scaling = (self.canvas.layout.scaling[0], self._scaling)

    @property
    def scaling(self):
        """Scaling of the channel boxes."""
        return self._scaling

    @scaling.setter
    def scaling(self, value):
        self._scaling = value
        self._apply_scaling()

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value

    # Spike selection
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by clicking on a spike."""
        if 'Control' in e.modifiers:
            # Get mouse position in NDC.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            channel_id = self._permute_channels(box_id, inv=True)
            # Find the spike and cluster closest to the mouse.
            db = self.data_bounds
            # Get the information about the displayed spikes (waveforms without
            # `channel_ids` cannot be matched to the clicked channel).
            wt = [(t, s, c, ch) for t, s, c, ch in self._waveform_times
                  if ch is not None and channel_id in ch]
            if not wt:
                return
            # Get the time coordinate of the mouse position.
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, db).apply(mouse_pos)[0][0]
            # Get the closest spike id.
            times, spike_ids, spike_clusters, channel_ids = zip(*wt)
            i = np.argmin(np.abs(np.array(times) - mouse_time))
            # Raise the spike_click event.
            spike_id = spike_ids[i]
            cluster_id = spike_clusters[i]
            emit('spike_click', self, channel_id=channel_id,
                 spike_id=spike_id, cluster_id=cluster_id)
class WaveformClusteringView(LassoMixin, ManualClusteringView):
    """Lasso-based clustering of the waveforms of one cluster, on one channel at a time.

    All the waveforms of the first selected cluster are drawn on the current channel. When
    the cluster has more than 100 spikes, the median, median ± 3 std and the 1% / 99%
    quantiles are overlaid. A split request returns the spikes NOT enclosed by the lasso.

    Constructor
    -----------

    model : object
        Must provide `get_cluster_channels(cluster_id)`, `get_cluster_spikes(cluster_id)` and
        `get_waveforms(spike_ids, channel_ids=...)`. The latter is used here as a
        `(n_spikes, n_samples, n_channels)` array.
    gain : float
        Factor applied to the waveform values before display (the y axis is labelled
        `[uV]`). Defaults to 0.195.
    sample_rate_khz : float
        Sampling rate in kHz, used to build the time axis in ms. Defaults to 30.

    """

    # NOTE(review): no `next_channel` / `previous_channel` methods are defined in this class;
    # presumably these actions are registered elsewhere (e.g. by the plugin) — verify.
    default_shortcuts = {
        'next_channel': 'f',
        'previous_channel': 'r',
    }

    def __init__(self, model=None, gain=0.195, sample_rate_khz=30):

        super(WaveformClusteringView, self).__init__()
        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))

        self.model = model

        self.gain = gain
        self.Fs = sample_rate_khz  # kHz

        self.visual = PlotVisual()

        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.97, .95)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.01, 0)
        self.cluster_ids = None
        self.cluster_id = None
        self.channel_ids = None
        self.spike_ids = None
        self.wavefs = None
        self.current_channel_idx = None

    def on_select(self, cluster_ids=(), **kwargs):
        """Load the waveforms of the first selected cluster (if it changed) and plot them."""
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return

        if self.cluster_id != cluster_ids[0]:
            self.cluster_id = cluster_ids[0]

            self.channel_ids = self.model.get_cluster_channels(self.cluster_id)
            self.spike_ids = self.model.get_cluster_spikes(self.cluster_id)
            self.wavefs = self.model.get_waveforms(self.spike_ids, channel_ids=self.channel_ids)

            self.setChannelIdx(0)

    def _time_axis(self, n_time):
        """Return the `n_time` sample times in ms, symmetric around 0."""
        half_width = n_time / 2 / self.Fs
        return np.linspace(-half_width, half_width, n_time)

    @staticmethod
    def _amplitude_limit(peak):
        """Round a peak amplitude (in display units) up to a readable y-axis limit.

        The limit is a multiple of 10 below 100, of 100 below 1000 and of 1000 otherwise.
        It is rounded up in every range (so no waveform is clipped) and never 0 (so flat
        waveforms still give valid data bounds).
        """
        if peak < 100:
            step = 10
        elif peak < 1000:
            step = 100
        else:
            step = 1000
        return max(step * np.ceil(peak / step), step)

    def plotWaveforms(self):
        """Plot the waveforms of the current cluster on the current channel."""
        n_spikes, n_time, _ = self.wavefs.shape
        waves = self.wavefs[:, :, self.current_channel_idx]

        self.visual.reset_batch()
        self.text_visual.reset_batch()

        t = self._time_axis(n_time)
        x = np.tile(t, (n_spikes, 1))

        # Symmetric y range around the largest absolute amplitude.
        limit = self._amplitude_limit(np.max(np.abs(waves)) * self.gain)
        self.data_bounds = (t[0], -limit, t[-1], limit)

        colorwavef = selected_cluster_color(0)
        colormedian = selected_cluster_color(3)
        colorstd = (0, 1, 0, 1)
        colorqtl = (1, 1, 0, 1)

        # Individual waveforms.
        self.visual.add_batch_data(
            x=x, y=self.gain * waves, color=colorwavef, data_bounds=self.data_bounds,
            box_index=0)

        # Summary statistics, only for clusters with enough spikes.
        if n_spikes > 100:
            median = np.median(waves, axis=0)
            std = np.std(waves, axis=0)
            q1 = np.quantile(waves, .01, axis=0, interpolation='higher')
            q9 = np.quantile(waves, .99, axis=0, interpolation='lower')
            for y, color in (
                    (median, colormedian),
                    (median + 3 * std, colorstd),
                    (median - 3 * std, colorstd),
                    (q1, colorqtl),
                    (q9, colorqtl)):
                self.visual.add_batch_data(
                    x=t, y=self.gain * y, color=color, data_bounds=self.data_bounds,
                    box_index=0)

        # Axes units and channel label.
        self.text_visual.add_batch_data(
            pos=[.9, .98],
            text='[uV]',
            anchor=[-1, -1],
            box_index=0,
        )
        self.text_visual.add_batch_data(
            pos=[-1, -.95],
            text='[ms]',
            anchor=[1, 1],
            box_index=0,
        )
        label = 'Ch {a}'.format(a=self.channel_ids[self.current_channel_idx])
        self.text_visual.add_batch_data(
            pos=[-.98, .98],
            text=str(label),
            anchor=[1, -1],
            box_index=0,
        )

        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.text_visual)
        self.canvas.axes.reset_data_bounds(self.data_bounds)

        self.canvas.update()

    def setChannel(self, channel_id):
        """Show the channel `channel_id`, if it is one of the current cluster's channels."""
        if self.channel_ids is None:
            return
        itemindex = np.where(np.asarray(self.channel_ids) == channel_id)[0]
        if len(itemindex):
            self.setChannelIdx(itemindex[0])

    def setChannelIdx(self, channel_idx):
        """Show the channel at position `channel_idx` in the current cluster's channels."""
        self.current_channel_idx = channel_idx
        self.plotWaveforms()

    def setNextChannelIdx(self):
        """Show the next channel of the current cluster (no-op on the last one)."""
        if self.channel_ids is None or self.current_channel_idx >= len(self.channel_ids) - 1:
            return
        self.setChannelIdx(self.current_channel_idx + 1)

    def setPrevChannelIdx(self):
        """Show the previous channel of the current cluster (no-op on the first one)."""
        if self.current_channel_idx is None or self.current_channel_idx <= 0:
            return
        self.setChannelIdx(self.current_channel_idx - 1)

    def on_request_split(self, sender=None):
        """Return the spikes of the current cluster that are NOT enclosed by the lasso.

        Returning the complement keeps the selected cluster the one we are working on.
        An empty int64 array is returned when there is nothing to split.
        """
        empty = np.array([], dtype=np.int64)
        if (self.canvas.lasso.count < 3 or self.wavefs is None or
                self.cluster_ids is None or not len(self.cluster_ids)):  # pragma: no cover
            return empty

        spike_ids = np.asarray(self.spike_ids)
        waves = self.wavefs[:, :, self.current_channel_idx]
        n_spikes, n_time = waves.shape
        if not n_spikes:  # pragma: no cover
            logger.warning("Empty lasso.")
            return empty

        # One (time, amplitude) point per sample of every waveform: spike 0 first, then
        # spike 1, etc. `point_spike_ids` gives the spike id of every point in that order.
        pos = np.c_[np.tile(self._time_axis(n_time), n_spikes), self.gain * waves.ravel()]
        pos = range_transform([self.data_bounds], [NDC], pos)
        point_spike_ids = np.repeat(spike_ids, n_time)

        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()

        # Return all spikes not lassoed, so the selected cluster is still the same we are
        # working on.
        spikes_to_remove = np.unique(point_spike_ids[ind])
        kept = spike_ids[np.isin(spike_ids, spikes_to_remove, invert=True)]
        return kept if len(kept) else empty
Exemple #13
0
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveform_type : str
        Default key of the waveforms dictionary to plot initially.

    """

    _default_position = 'right'
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
        # Attributes set before the base class initializer runs.
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()

        # Initialize the view and declare which attributes are saved in the GUI state.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # One box per channel, with axes.
        self.canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
        self.canvas.enable_axes()

        # Box and probe scaling.
        self._box_scaling = (1., 1.)
        self._probe_scaling = (1., 1.)

        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        # A single waveforms function is stored under the mandatory `waveforms` key.
        if not isinstance(waveforms, dict):
            waveforms = {'waveforms': waveforms}
        assert waveforms
        self.waveforms = waveforms
        self.waveforms_types = list(waveforms)
        # Waveforms type displayed initially: the requested one, or the first key.
        self.waveforms_type = waveforms_type or self.waveforms_types[0]
        assert self.waveforms_type in waveforms
        assert 'waveforms' in waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Return `[-1, ymin, +1, ymax]`, the y range spanning the data of all bunchs."""
        lows = [_min(b.data) for b in bunchs]
        highs = [_max(b.data) for b in bunchs]
        return [-1, min(lows), +1, max(highs)]

    def get_clusters_data(self):
        """Load one waveforms bunch per selected cluster, annotated for plotting.

        Each bunch gets its position in the selection (`index`), its overlap `offset`,
        the total number of offsets (`n_clu`) and its `color`.
        """
        get_waveforms = self.waveforms[self.waveforms_type]
        bunchs = [get_waveforms(cluster_id) for cluster_id in self.cluster_ids]
        offsets = _get_clu_offsets(bunchs)
        n_clu = max(offsets) + 1
        # Offset depending on the overlap.
        for index, (bunch, offset) in enumerate(zip(bunchs, offsets)):
            bunch.index = index
            bunch.offset = offset
            bunch.n_clu = n_clu
            alpha = bunch.get('alpha', .75)
            bunch.color = selected_cluster_color(index, alpha)
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster to the waveform visual batch.

        `bunch.data` is indexed as `(n_spikes, n_samples, n_channels_loc)`, the last axis
        following `bunch.channel_ids`. Nothing is plotted when the data is missing or empty.
        """
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = np.asarray(bunch.get('masks', np.ones((wave.shape[0], n_channels))))

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        # NOTE: we add the cluster index (`bunch.index`: by default 0, 1, 2 for the first 3
        # clusters, but it can be customized when displaying several sets of waveforms per
        # cluster), which is used for the computation of the depth on the GPU.
        # A new array is built instead of updating `masks` in place: the array may belong to
        # the bunch returned by the waveforms function, and in-place updates would accumulate
        # across redraws (and fail for integer masks).
        masks = masks * .99999 + bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.tile(box_index, n_spikes_clu)
        assert box_index.shape == (n_spikes_clu * n_channels * n_samples, )

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

        self.waveform_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Draw one label per channel box, when labels are enabled.

        `channel_labels` maps a channel id to its label; `n_clusters` is not used.
        """
        if self.do_show_labels:
            self.text_visual.reset_batch()
            for box_index, channel_id in enumerate(channel_ids):
                self.text_visual.add_batch_data(
                    pos=[-1, 0],
                    text=str(channel_labels[channel_id]),
                    anchor=[-1.25, 0],
                    box_index=box_index,
                )
            self.canvas.update_visual(self.text_visual)

    def plot(self, **kwargs):
        """Redraw the view for the current cluster selection (no-op if nothing is selected)."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()

        # Sorted union of the channels used by at least one selected cluster.
        channel_ids = sorted(set(_flatten([b.channel_ids for b in bunchs])))
        self.channel_ids = channel_ids

        # Channel id -> label; defaults to the channel id itself.
        channel_labels = {}
        for b in bunchs:
            labels = b.get('channel_labels', ['%d' % ch for ch in b.channel_ids])
            for i, channel_id in enumerate(b.channel_ids):
                channel_labels[channel_id] = labels[i]

        # The boxes depend on the selected channels.
        if channel_ids:
            self.canvas.boxed.box_bounds = _get_box_bounds(bunchs, channel_ids)
        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        self.data_bounds = self._get_data_bounds(bunchs)

        # Waveforms of all clusters in a single batch.
        self.waveform_visual.reset_batch()
        for b in bunchs:
            self._plot_cluster(b)
        self.canvas.update_visual(self.waveform_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)
        self._update_axes(bunchs)
        self.canvas.update()

    def _update_axes(self, bunchs):
        """Reset the axes data bounds from the waveform duration and the data y range."""
        ymin, ymax = self.data_bounds[1], self.data_bounds[3]
        # Duration of one waveform, rescaled by the overlap transform when several
        # clusters are displayed side by side.
        duration = bunchs[0].get('waveform_duration', 1.)
        overlap_factor = .5 * (1 + _overlap_transform(
            1, n=len(self.cluster_ids), overlap=self.overlap))
        duration /= overlap_factor
        first_box = self.canvas.boxed.box_bounds[0]
        x1, y1 = range_transform(first_box, NDC, [duration, ymax - ymin])
        self.canvas.axes.reset_data_bounds((0, 0, x1, y1), do_update=True)

    def attach(self, gui):
        """Attach the view to the GUI and register its actions.

        Actions are registered in four groups, each followed by a separator: display
        toggles, box scaling, horizontal probe scaling, and vertical probe scaling.
        """
        super(WaveformView, self).attach(gui)
        add = self.actions.add

        # Display toggles and waveform-type switching.
        add(self.toggle_waveform_overlap, checkable=True, checked=self.overlap)
        add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        add(self.next_waveforms_type)
        add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling, then probe scaling (horizontal, then vertical): plain actions.
        scaling_groups = (
            (self.widen, self.narrow),
            (self.extend_horizontally, self.shrink_horizontally),
            (self.extend_vertically, self.shrink_vertically),
        )
        for group in scaling_groups:
            for action in group:
                add(action)
            self.actions.separator()

    @property
    def boxed(self):
        """The boxed layout instance of the canvas (shortcut for `self.canvas.boxed`)."""
        canvas = self.canvas
        return canvas.boxed

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether waveforms of different clusters are overlapped rather than placed side by side."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        # Store the new flag, then redraw immediately so the change is visible.
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Set the waveform overlap from the checked state of the GUI action.

        Assigning `overlap` goes through its setter, which replots the view.
        """
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def _update_boxes(self):
        self.canvas.boxed.update_boxes(self.box_pos * self.probe_scaling,
                                       self.box_size)

    def _apply_box_scaling(self):
        self.canvas.layout.scaling = self._box_scaling

    @property
    def box_scaling(self):
        """`(horizontal, vertical)` scaling factors of the channel boxes."""
        return self._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        # A (horizontal, vertical) pair is required; it is applied to the layout right away.
        assert len(value) == 2
        self._box_scaling = value
        self._apply_box_scaling()

    def widen(self):
        """Stretch the waveforms horizontally by one scaling increment."""
        horizontal, vertical = self._box_scaling
        horizontal *= self._scaling_param_increment
        self._box_scaling = (horizontal, vertical)
        self._apply_box_scaling()

    def narrow(self):
        """Compress the waveforms horizontally by one scaling increment."""
        horizontal, vertical = self._box_scaling
        horizontal /= self._scaling_param_increment
        self._box_scaling = (horizontal, vertical)
        self._apply_box_scaling()

    def _get_scaling_value(self):
        return self._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self._box_scaling
        self.box_scaling = (w, value)
        self._update_boxes()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """`(horizontal, vertical)` factors applied to the box positions (the probe layout)."""
        return self._probe_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        # A (horizontal, vertical) pair is required; the boxes are repositioned right away.
        assert len(value) == 2
        self._probe_scaling = value
        self._update_boxes()

    def extend_horizontally(self):
        """Spread the channel boxes apart horizontally by one probe-scaling increment."""
        x_factor, y_factor = self._probe_scaling
        self._probe_scaling = (x_factor * self._scaling_param_increment, y_factor)
        self._update_boxes()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe.

        Divides the horizontal probe-scaling factor by `_scaling_param_increment`. This
        brings the channel box positions closer together horizontally. The boxes are then
        refreshed. The waveform box scaling itself is left unchanged.
        """
        w, h = self._probe_scaling
        self._probe_scaling = (w / self._scaling_param_increment, h)
        self._update_boxes()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe.

        Multiplies the vertical probe-scaling factor by `_scaling_param_increment`. This
        spreads the channel box positions apart vertically. The boxes are then refreshed.
        The waveform amplitude scaling is left unchanged.
        """
        w, h = self._probe_scaling
        self._probe_scaling = (w, h * self._scaling_param_increment)
        self._update_boxes()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe.

        Divides the vertical probe-scaling factor by `_scaling_param_increment`. This brings
        the channel box positions closer together vertically. The boxes are then refreshed.
        The waveform amplitude scaling is left unchanged.
        """
        w, h = self._probe_scaling
        self._probe_scaling = (w, h / self._scaling_param_increment)
        self._update_boxes()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show the channel labels when `checked` is true, hide them otherwise, then redraw."""
        self.do_show_labels = checked
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view.

        The click is handled only when Control is held or a digit key ('0' to '9') is
        pressed. The box under the mouse is mapped to its channel id. A `channel_click`
        event is then emitted with that channel id, the digit key as an int (None when no
        digit key was pressed), and the mouse button.

        The click is ignored while no channels are displayed (`self.channel_ids` is None or
        empty, e.g. before the first plot). Otherwise the channel lookup would fail.
        """
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # `channel_ids` is only filled in by plot(); there is nothing to select before that.
            if not self.channel_ids:
                return
            key = int(e.key) if e.key in nums else None
            # Index of the box under the mouse; boxes follow the order of `channel_ids`.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.",
                         channel_id, key, b)
            emit('channel_click',
                 self,
                 channel_id=channel_id,
                 key=key,
                 button=b)

    def next_waveforms_type(self):
        """Move to the following entry of `waveforms_types`, wrapping to the first, and replot."""
        available = self.waveforms_types
        current = available.index(self.waveforms_type)
        following = (current + 1) % len(available)
        self.waveforms_type = available[following]
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch between `waveforms` and `mean_waveforms` display.

        From `mean_waveforms`, always go back to `waveforms`. Otherwise switch to
        `mean_waveforms` only when that type is available. The view is replotted only
        when the type changes. `checked` is not used.
        """
        if self.waveforms_type == 'mean_waveforms':
            new_type = 'waveforms'
        elif 'mean_waveforms' in self.waveforms_types:
            new_type = 'mean_waveforms'
        else:
            return
        self.waveforms_type = new_type
        self.plot()
Exemple #14
0
class TraceView(ScalingMixin, BaseColorView, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data` : an `(n_samples, n_channels_loc)` array
            * `color`
            * `channel_ids` : the channels of the columns of `data`
            * `start_time` : time (in seconds) of the first sample of the waveform
            * `spike_id`
            * `spike_cluster`
            * `select_index` : index of the cluster in the selection, or None

    spike_times : function
        Returns the relevant spike times (in seconds), in increasing order. Used to jump
        to the next/previous spike.
    sample_rate : float
        Sampling rate (in Hz), used to convert between samples and seconds.
    duration : float
        Total duration of the recording, in seconds.
    n_channels : int
    channel_positions : array-like
        Positions of the channels, used for displaying the channels in the right y order
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    trace_color_0 = (.353, .161, .443)
    trace_color_1 = (.133, .404, .396)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'switch_color_scheme': 'shift+wheel',
        'navigate': 'alt+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'jump_left': 'shift+alt+left',
        'jump_right': 'shift+alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'select_channel_pcA': 'shift+left click',
        'select_channel_pcB': 'shift+right click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None,
            n_channels=None, channel_positions=None, channel_labels=None, **kwargs):

        self.do_show_labels = True
        self.show_all_spikes = False

        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel y ranking.
        self.channel_positions = (
            channel_positions if channel_positions is not None else
            np.c_[np.zeros(n_channels), np.arange(n_channels)])
        # channel_y_ranks[i] is the position of channel #i in the trace view.
        self.channel_y_ranks = np.argsort(np.argsort(self.channel_positions[:, 1]))
        assert self.channel_y_ranks.shape == (n_channels,)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        # Visuals.
        self._create_visuals()

        # Displayed spikes, used for spike click. Initialized before the first plot so that
        # the spikes recorded by the initial go_to() are kept.
        self._waveform_times = []

        # Initial interval.
        self._interval = None
        self.go_to(duration / 2.)

        self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))

    def _create_visuals(self):
        """Set up the stacked layout, the axes, and the trace, waveform and label visuals."""
        self.canvas.set_layout('stacked', n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        self.trace_visual = UniformPlotVisual()
        # Gradient of color for the traces.
        if self.trace_color_0 and self.trace_color_1:
            self.trace_visual.inserter.insert_frag(
                'gl_FragColor.rgb = mix(vec3%s, vec3%s, (v_signal_index / %d));' % (
                    self.trace_color_0, self.trace_color_1, self.n_channels), 'end')
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        # Thin out the labels when many channels are visible: when n_boxes >= 50 * zoom,
        # only one box in every int(n_boxes / (50 * zoom)) keeps its label.
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)

    @property
    def stacked(self):
        """The stacked layout of the canvas."""
        return self.canvas.stacked

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        """Upload the raw traces of the current interval to the trace visual.

        Parameters
        ----------
        traces : array
            `(n_samples, n_channels)` array of trace values.
        color : tuple or None
            Color of the traces; `default_trace_color` when None.

        """
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        # `is None` rather than `or`: an array-valued color has no truth value.
        if color is None:
            color = self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        # Every sample of channel #i goes to the box at the channel's y rank.
        box_index = self.channel_y_ranks
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        """Add one spike waveform to the waveform visual batch."""
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # Selected clusters get their selection color, the others the color scheme's color.
        i = bunch.select_index
        c = bunch.spike_cluster
        cs = self.color_schemes.get()
        color = selected_cluster_color(i, alpha=1) if i is not None else cs.get(c, alpha=1)

        # The box index depends on the channel; one entry per point, in the same
        # channel-major order as `t.ravel()`.
        box_index = self.channel_y_ranks[bunch.channel_ids]
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=color,
            data_bounds=self.data_bounds,
        )

    def _plot_waveforms(self, waveforms, **kwargs):
        """Plot the spike waveforms and record them for spike click.

        `_waveform_times` is rebuilt from scratch here, so it always describes exactly the
        waveforms currently drawn, even when the same interval is replotted.
        """
        assert isinstance(waveforms, list)
        self._waveform_times = []
        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

    def _plot_labels(self, traces):
        """Draw one label per channel at the left edge of its box (`traces` is unused)."""
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self.channel_y_ranks[ch]
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        """Snap `interval` to full samples and fit it within `[0, duration]`.

        An interval that overflows one end is shifted back inside, keeping its length when
        the recording is long enough, and then clipped.
        """
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def plot(self, update_traces=True, update_waveforms=True):
        """Redraw the view for the current interval.

        Parameters
        ----------
        update_traces : bool
            Recompute the data bounds and redraw the traces (and the labels, if shown).
        update_waveforms : bool
            Redraw the spike waveforms.

        """
        # Both redraws need the traces of the interval; load them once if either is requested.
        traces = None
        if update_traces or update_waveforms:
            traces = self.traces(self._interval)

        if update_traces:
            logger.log(5, "Redraw the entire trace view.")
            start, end = self._interval

            # Find the data bounds: y range from trace quantiles when auto-scaling (or on
            # the first plot), otherwise keep the current y range.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

            # Plot the labels.
            if self.do_show_labels:
                self._plot_labels(traces.data)

        if update_waveforms:
            self._plot_waveforms(traces.get('waveforms', []))

        self._update_axes()
        self.canvas.update()

    def set_interval(self, interval=None):
        """Display the traces and spikes in a given interval.

        With no argument, the current interval is kept and only the waveforms are redrawn.
        When the (restricted) interval changes, everything is redrawn between `is_busy`
        events, and `time_range_selected` is emitted.
        """
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        if interval != self._interval:
            logger.log(5, "Redraw the entire trace view.")
            self._interval = interval
            emit('is_busy', self, True)
            self.plot(update_traces=True, update_waveforms=True)
            emit('is_busy', self, False)
            emit('time_range_selected', self, interval)
            self.update_status()
        else:
            self.plot(update_traces=False, update_waveforms=True)

    def on_select(self, cluster_ids=None, **kwargs):
        """Store the new cluster selection and redraw the spikes of the current interval."""
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        # Make sure we call again self.traces() when the cluster selection changes.
        self.set_interval()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(TraceView, self).attach(gui)

        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(
            self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        self.actions.add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        self.actions.add(self.switch_origin)
        self.actions.separator()

        self.actions.add(
            self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        self.actions.separator()

        self.actions.add(self.go_to_start)
        self.actions.add(self.go_to_end)
        self.actions.separator()

        self.actions.add(self.shift, prompt=True)
        self.actions.add(self.go_right)
        self.actions.add(self.go_left)
        self.actions.add(self.jump_right)
        self.actions.add(self.jump_left)
        self.actions.separator()

        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        self.actions.add(self.go_to_next_spike)
        self.actions.add(self.go_to_previous_spike)
        self.actions.separator()

        self.set_interval()

    @property
    def status(self):
        """Status text: the displayed interval (in seconds) and the color scheme."""
        a, b = self._interval
        return '[{:.2f}s - {:.2f}s]. Color scheme: {}.'.format(a, b, self.color_scheme)

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Whether to show the channels from top to bottom (`top` option, the default), or from
        bottom to top (`bottom`)."""
        return getattr(self.canvas.layout, 'origin', Stacked._origin)

    @origin.setter
    def origin(self, value):
        if value is None:
            return
        if self.canvas.layout:
            self.canvas.layout.origin = value
        else:  # pragma: no cover
            logger.warning(
                "Could not set origin to %s because the layout instance was not initialized yet.",
                value)

    def switch_origin(self):
        """Switch between top and bottom origin for the channels."""
        self.origin = 'bottom' if self.origin == 'top' else 'top'

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Time at the center of the window."""
        return sum(self._interval) * .5

    @property
    def interval(self):
        """Interval as `(tmin, tmax)`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval (of the default one if none is set)."""
        if self._interval is not None:
            a, b = self._interval
            return (b - a) * .5
        else:
            return self.interval_duration * .5

    def go_to(self, time):
        """Go to a specific time (in seconds)."""
        half_dur = self.half_duration
        self.set_interval((time - half_dur, time + half_dur))

    def shift(self, delay):
        """Shift the interval by a given delay (in seconds)."""
        self.go_to(self.time + delay)

    def go_to_start(self):
        """Go to the start of the recording."""
        self.go_to(0)

    def go_to_end(self):
        """Go to end of the recording."""
        self.go_to(self.duration)

    def go_right(self):
        """Shift right by 10% of the current interval duration."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(delay)

    def go_left(self):
        """Shift left by 10% of the current interval duration."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(-delay)

    def jump_right(self):
        """Jump right by 10% of the recording duration."""
        delay = self.duration * .1
        self.shift(delay)

    def jump_left(self):
        """Jump left by 10% of the recording duration."""
        delay = self.duration * .1
        self.shift(-delay)

    def _jump_to_spike(self, delta=+1):
        """Center the view on the `delta`-th spike after (delta > 0) or before the current time.

        A spike within half a sample of the current time counts as the current spike and is
        skipped. Navigation wraps around at both ends of the spike list.
        """
        spike_times = self.get_spike_times()
        if spike_times is None or not len(spike_times):
            return
        n = len(spike_times)
        # The displayed time can be off by up to half a sample, because _restrict_interval
        # rounds the interval bounds to full samples.
        tol = .5 * self.dt
        if delta > 0:
            # First spike strictly after the current time, then (delta - 1) more.
            ind = np.searchsorted(spike_times, self.time + tol, side='right') + delta - 1
        else:
            # Last spike strictly before the current time is at index (left - 1).
            ind = np.searchsorted(spike_times, self.time - tol, side='left') + delta
        self.go_to(spike_times[ind % n])

    def go_to_next_spike(self):
        """Jump to the next spike from the selected clusters."""
        self._jump_to_spike(+1)

    def go_to_previous_spike(self):
        """Jump to the previous spike from the selected clusters."""
        self._jump_to_spike(-1)

    def toggle_highlighted_spikes(self, checked):
        """Toggle between showing all spikes or selected spikes."""
        self.show_all_spikes = checked
        self.set_interval()

    def widen(self):
        """Increase the interval size."""
        t, h = self.time, self.half_duration
        h *= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    def narrow(self):
        """Decrease the interval size."""
        t, h = self.time, self.half_duration
        h /= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show the channel labels when `checked` is true, hide them otherwise."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        if checked:
            # plot() only draws the labels while they are enabled, so they may never have
            # been drawn: redraw them once data bounds are available.
            if getattr(self, 'data_bounds', None) is not None:
                self._plot_labels(None)
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def toggle_auto_scale(self, checked):
        """Toggle automatic scaling of the traces."""
        logger.debug("Set auto scale to %s.", checked)
        self.auto_scale = checked

    def update_color(self):
        """Update the view when the color scheme changes."""
        self.plot(update_traces=False, update_waveforms=True)

    # Scaling
    # -------------------------------------------------------------------------

    @property
    def scaling(self):
        """Scaling of the channel boxes."""
        return self.stacked._box_scaling[1]

    @scaling.setter
    def scaling(self, value):
        self.stacked._box_scaling = (self.stacked._box_scaling[0], value)

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value
        self.stacked.update()

    # Spike selection
    # -------------------------------------------------------------------------

    def _channel_at(self, pos):
        """Return the id (int) of the channel displayed in the box under window position `pos`.

        Returns None when no channel is displayed in that box.
        """
        box_id, _ = self.canvas.stacked.box_map(pos)
        # channel_y_ranks maps channel id -> box index; invert it for this box.
        matches = np.nonzero(self.channel_y_ranks == box_id)[0]
        return int(matches[0]) if len(matches) else None

    def on_mouse_click(self, e):
        """Handle mouse clicks.

        * Control+click: among the displayed spikes that have a waveform on the clicked
          channel, pick the one whose start time is closest to the mouse and emit
          `select_spike`.
        * Shift+click: emit `select_channel` for the clicked channel, with the mouse button.

        In both cases `channel_id` is emitted as a plain int.
        """
        if 'Control' in e.modifiers:
            channel_id = self._channel_at(e.pos)
            if channel_id is None:
                return
            # Get the information about the displayed spikes on that channel.
            wt = [(t, s, c, ch) for t, s, c, ch in self._waveform_times
                  if ch is not None and channel_id in ch]
            if not wt:
                return
            # Get the time coordinate of the mouse position.
            db = self.data_bounds
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, db).apply(mouse_pos)[0][0]
            # Get the closest spike id.
            times, spike_ids, spike_clusters, _ = zip(*wt)
            i = np.argmin(np.abs(np.array(times) - mouse_time))
            # Raise the select_spike event.
            emit('select_spike', self, channel_id=channel_id,
                 spike_id=spike_ids[i], cluster_id=spike_clusters[i])

        if 'Shift' in e.modifiers:
            channel_id = self._channel_at(e.pos)
            if channel_id is not None:
                emit('select_channel', self, channel_id=channel_id, button=e.button)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Scroll through the data with alt+wheel."""
        super(TraceView, self).on_mouse_wheel(e)
        if e.modifiers == ('Alt',):
            start, end = self._interval
            delay = e.delta * (end - start) * .1
            self.shift(-delay)
Exemple #15
0
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.
    sample_rate : float
        Sampling rate of the waveforms (samples per second), used to compute the waveform
        duration and to place the millisecond ticks. Must be strictly positive.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    # Color and size of the axis lines and millisecond ticks.
    ax_color = (.75, .75, .75, 1.)
    tick_size = 5.
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'previous_waveforms_type': 'shift+w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self,
                 waveforms=None,
                 waveforms_type=None,
                 sample_rate=None,
                 **kwargs):
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # Test for None explicitly: `None > 0.` raises a TypeError in Python 3, which would
        # hide the assertion message below.
        assert sample_rate is not None and sample_rate > 0., (
            "The sample rate must be provided to the waveform view.")

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(marker='vbar',
                                                color=self.ax_color,
                                                size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    @property
    def _current_visual(self):
        """Visual used for the current waveforms type: the plain `PlotVisual` for raw
        `waveforms`, the agg visual for every other type."""
        if self.waveforms_type == 'waveforms':
            return self.waveform_visual
        else:
            return self.waveform_agg_visual

    def _get_data_bounds(self, bunchs):
        """Return `[-1, -M, +1, M]` where `M` is the largest absolute value over the data of
        all bunchs, so that the y axis is symmetric around 0."""
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        # Symmetrize on the y axis.
        M = max(abs(m), abs(M))
        return [-1, -M, +1, M]

    def get_clusters_data(self):
        """Load the waveforms of the selected clusters for the current waveforms type.

        For every selected cluster, call the current waveforms function (obtained through
        `self.waveforms_types.get()`) and decorate the returned Bunch with the plotting
        attributes `index`, `offset`, `n_clu` and `color`.

        Returns None when the current waveforms type is not in `self.waveforms`, an empty
        list when no cluster is selected, and the list of Bunch instances otherwise.
        """
        if self.waveforms_type not in self.waveforms:
            return
        bunchs = [
            self.waveforms_types.get()(cluster_id)
            for cluster_id in self.cluster_ids
        ]
        # Nothing selected: avoid calling max() on an empty sequence below.
        if not bunchs:
            return bunchs
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add one cluster to the batches of the current waveform visual, the axis line
        visual and the tick visual (the caller resets and updates these visuals)."""
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.tile(box_index, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual is self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (
                n_samples + 2)

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines: a single line per channel box is enough, drawing one per
        # spike would only stack identical copies of the same segment.
        ax_db = self.data_bounds
        a, b = _overlap_transform(np.array([-1, 1]),
                                  offset=bunch.offset,
                                  n=bunch.n_clu,
                                  overlap=self.overlap)
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        # Two vertices per line.
        box_index = np.repeat(box_index, 2)
        hpos = np.tile([[a, 0, b, 0]], (n_channels, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, x.size // len(box_index))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x,
            y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Draw one text label per channel box (`n_clusters` is currently unused).

        The labels are rebuilt even while they are hidden, so that turning them back on
        shows the labels of the current channels rather than those of an earlier selection.
        """
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)
        # `do_show_labels` is listed in `state_attrs`, so it may be assigned directly
        # (presumably when restoring the GUI state) without going through
        # toggle_show_labels(): keep the visibility in sync here too.
        if self.do_show_labels:
            self.text_visual.show()
        else:
            self.text_visual.hide()

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()
        if not bunchs:
            return

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids
        if bunchs[0].data is not None:
            self.wave_duration = bunchs[0].data.shape[1] / float(
                self.sample_rate)
        else:  # pragma: no cover
            self.wave_duration = 1.

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i]
                for i, channel_id in enumerate(d.channel_ids)
            })

        # Update the Boxed box positions as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.update_boxes(_get_box_pos(bunchs, channel_ids))

        # The data bounds are computed only once (first plot) and then kept.
        self.data_bounds = self.data_bounds or self._get_data_bounds(bunchs)

        self._current_visual.reset_batch()
        self.line_visual.reset_batch()
        self.tick_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.tick_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self._current_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)

        # Only show the current waveform visual.
        if self._current_visual is self.waveform_visual:
            self.waveform_visual.show()
            self.waveform_agg_visual.hide()
        elif self._current_visual is self.waveform_agg_visual:
            self.waveform_agg_visual.show()
            self.waveform_visual.hide()

        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap,
                         checkable=True,
                         checked=self.overlap)
        self.actions.add(self.toggle_show_labels,
                         checkable=True,
                         checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.previous_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    @property
    def status(self):
        """Status text: the name of the current waveforms type."""
        return self.waveforms_type

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        self.boxed.expand_box_width()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        self.boxed.shrink_box_width()

    @property
    def box_scaling(self):
        """Box scaling `(width, height)` of the boxed layout."""
        return self.boxed._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        self.boxed._box_scaling = value

    def _get_scaling_value(self):
        """Return the vertical component of the box scaling."""
        return self.boxed._box_scaling[1]

    def _set_scaling_value(self, value):
        """Set the vertical component of the box scaling and refresh the layout."""
        w, h = self.boxed._box_scaling
        self.boxed._box_scaling = (w, value)
        self.boxed.update()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Scaling of the probe layout in the boxed layout."""
        return self.boxed._layout_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        self.boxed._layout_scaling = value

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        self.boxed.expand_layout_width()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        self.boxed.shrink_layout_width()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        self.boxed.expand_layout_height()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        self.boxed.shrink_layout_height()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view."""
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # No channel box exists before the first plot.
            if not self.channel_ids:
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.",
                         channel_id, key, b)
            emit('select_channel',
                 self,
                 channel_id=channel_id,
                 key=key,
                 button=b)

    @property
    def waveforms_type(self):
        """Name of the current waveforms type."""
        return self.waveforms_types.current

    @waveforms_type.setter
    def waveforms_type(self, value):
        self.waveforms_types.set(value)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        self.waveforms_types.next()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def previous_waveforms_type(self):
        """Switch to the previous waveforms type."""
        self.waveforms_types.previous()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available."""
        if self.waveforms_type == 'mean_waveforms' and 'waveforms' in self.waveforms:
            self.waveforms_types.set('waveforms')
            logger.debug("Switch to raw waveforms.")
            self.plot()
        elif 'mean_waveforms' in self.waveforms:
            self.waveforms_types.set('mean_waveforms')
            logger.debug("Switch to mean waveforms.")
            self.plot()
Exemple #16
0
class TemplateView(ScalingMixin, BaseColorView, BaseGlobalView,
                   ManualClusteringView):
    """This view shows all template waveforms of all clusters in a large grid of shape
    `(n_channels, n_clusters)`.

    Constructor
    -----------

    templates : function
        Maps `cluster_ids` to a list of `[Bunch(template, channel_ids)]` where `template` is
        an `(n_samples, n_channels)` array, and `channel_ids` specifies the channels of the
        `template` array (sparse format).
    channel_ids : array-like
        The list of all channel ids.
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.
    cluster_ids : array-like
        The list of all clusters to show initially.

    """
    _default_position = 'right'
    _scaling = 1.

    default_shortcuts = {
        'change_template_size': 'ctrl+wheel',
        'switch_color_scheme': 'shift+wheel',
        'decrease': 'ctrl+alt+-',
        'increase': 'ctrl+alt++',
        'select_cluster': 'ctrl+click',
        # on_mouse_click() ignores clicks without Control; Shift then extends the selection.
        'select_more': 'ctrl+shift+click',
    }

    def __init__(self,
                 templates=None,
                 channel_ids=None,
                 channel_labels=None,
                 cluster_ids=None,
                 **kwargs):
        super(TemplateView, self).__init__(**kwargs)
        self.state_attrs += ()
        self.local_state_attrs += ('scaling', )

        # Full list of channels.
        self.channel_ids = channel_ids
        self.n_channels = len(channel_ids)

        # Channel labels.
        self.channel_labels = (channel_labels
                               if channel_labels is not None else
                               ['%d' % ch for ch in range(self.n_channels)])
        assert len(self.channel_labels) == self.n_channels
        # TODO: show channel and cluster labels

        # Full list of clusters.
        if cluster_ids is not None:
            self.set_cluster_ids(cluster_ids)

        self.canvas.set_layout('grid', has_clip=False)
        self.canvas.enable_axes()
        self.templates = templates

        self.visual = PlotVisual()
        self.canvas.add_visual(self.visual)
        # dict {cluster_id: box_index} used to quickly reorder the clusters without replotting.
        self._cluster_box_index = {}

        self.select_visual = PlotVisual()
        self.canvas.add_visual(self.select_visual)

    # Internal plot functions
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Get the data bounds: `[-1, -M, +1, M]` where `M` is the larger absolute value of
        the median template minimum and the median template maximum."""
        m = np.median([b.template.min() for b in bunchs])
        M = np.median([b.template.max() for b in bunchs])
        M = max(abs(m), abs(M))
        return [-1, -M, +1, M]

    def _get_box_index(self, bunch):
        """Get the box_index array for a cluster.

        Returns an array of shape `(n_channels_loc * n_samples, 2)` with, for every vertex,
        the index of its channel in `self.channel_ids` and the cluster index
        `bunch.cluster_idx`.
        """
        n_samples, _ = bunch.template.shape
        box_index = _index_of(bunch.channel_ids, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.c_[box_index.reshape((-1, 1)),
                          bunch.cluster_idx * np.ones(
                              (n_samples * len(bunch.channel_ids), 1))]
        assert box_index.shape == (len(bunch.channel_ids) * n_samples, 2)
        assert box_index.size == bunch.template.size * 2
        return box_index

    def _plot_cluster(self, bunch, color=None):
        """Return the batch data (a Bunch) needed to plot one cluster's template.

        `color` is an RGBA 4-tuple; by default, the cluster's color from
        `self.cluster_colors` is used.
        """
        wave = bunch.template  # shape: (n_samples, n_channels)
        channel_ids_loc = bunch.channel_ids
        n_channels_loc = len(channel_ids_loc)

        n_samples, nc = wave.shape
        assert nc == n_channels_loc

        # Find the x coordinates.
        t = get_linear_x(n_channels_loc, n_samples)

        # Explicit None test: `color or ...` fails for array colors (ambiguous truth value).
        if color is None:
            color = self.cluster_colors[bunch.cluster_rel]
        assert len(color) == 4

        box_index = self._get_box_index(bunch)

        return Bunch(x=t,
                     y=wave.T,
                     color=color,
                     box_index=box_index,
                     data_bounds=self.data_bounds)

    def set_cluster_ids(self, cluster_ids):
        """Update the cluster ids when their identity or order has changed."""
        if cluster_ids is None or not len(cluster_ids):
            return
        self.all_cluster_ids = np.array(cluster_ids, dtype=np.int32)
        # Permutation of the clusters.
        self.cluster_idxs = np.argsort(self.all_cluster_ids)
        self.sorted_cluster_ids = self.all_cluster_ids[self.cluster_idxs]
        # Cluster colors, ordered by cluster id.
        self.cluster_colors = self.get_cluster_colors(self.sorted_cluster_ids,
                                                      alpha=.75)

    def get_clusters_data(self, load_all=None):
        """Return all templates data, with `cluster_rel`, `cluster_idx` and `cluster_id`
        set on every Bunch, in the order given by `self._iter_clusters()`."""
        bunchs = self.templates(self.all_cluster_ids)
        out = []
        for cluster_rel, cluster_idx, cluster_id in self._iter_clusters():
            b = bunchs[cluster_id]
            b.cluster_rel = cluster_rel
            b.cluster_idx = cluster_idx
            b.cluster_id = cluster_id
            out.append(b)
        return out

    # Main methods
    # -------------------------------------------------------------------------

    def update_cluster_sort(self, cluster_ids):
        """Update the order of the clusters."""
        if not self._cluster_box_index:  # pragma: no cover
            return self.plot()
        # Only the order of the cluster_ids is supposed to change here.
        # We just have to update box_index instead of replotting everything.
        assert len(cluster_ids) == len(self.all_cluster_ids)
        # Update the cluster ids, in the new order.
        self.all_cluster_ids = np.array(cluster_ids, dtype=np.int32)
        # Update the permutation of the clusters.
        self.cluster_idxs = np.argsort(self.all_cluster_ids)
        box_index = []
        # Walk the clusters by increasing id; each gets its new column (position in the
        # new order).
        for cluster_idx in self.cluster_idxs:
            cluster_id = self.all_cluster_ids[cluster_idx]
            clu_box_index = self._cluster_box_index[cluster_id]
            clu_box_index[:, 1] = cluster_idx
            box_index.append(clu_box_index)
        box_index = np.concatenate(box_index, axis=0)
        self.visual.set_box_index(box_index)
        self.canvas.update()

    def update_color(self):
        """Update the color of the clusters, taking the selected clusters into account."""
        # This method is only used when the view has been plotted at least once,
        # such that self._cluster_box_index has been filled.
        if not self._cluster_box_index:
            return self.plot()
        # The call to set_cluster_ids() update the cluster_colors array.
        self.set_cluster_ids(self.all_cluster_ids)
        # Selected cluster colors.
        cluster_colors = self.cluster_colors
        selected_clusters = self.cluster_ids
        if selected_clusters is not None:
            cluster_colors = _add_selected_clusters_colors(
                selected_clusters, self.sorted_cluster_ids, cluster_colors)
        # Number of vertices per cluster = number of vertices per signal
        n_vertices_clu = [
            len(self._cluster_box_index[cluster_id])
            for cluster_id in self.sorted_cluster_ids
        ]
        # The argument passed to set_color() must have 1 row per vertex.
        self.visual.set_color(np.repeat(cluster_colors, n_vertices_clu,
                                        axis=0))
        self.canvas.update()

    @property
    def status(self):
        """Status text showing the current color scheme."""
        return 'Color scheme: %s' % self.color_schemes.current

    def plot(self, **kwargs):
        """Make the template plot."""

        # Retrieve the waveform data.
        bunchs = self.get_clusters_data()
        if not bunchs:
            return
        n_clusters = len(self.all_cluster_ids)
        self.canvas.grid.shape = (self.n_channels, n_clusters)

        self.visual.reset_batch()
        # Go through all clusters, ordered by cluster id.
        self.data_bounds = self._get_data_bounds(bunchs)
        for bunch in bunchs:
            data = self._plot_cluster(bunch)
            self._cluster_box_index[bunch.cluster_id] = data.box_index
            self.visual.add_batch_data(**data)
        self.canvas.update_visual(self.visual)
        self._apply_scaling()
        self.canvas.axes.reset_data_bounds((0, 0, n_clusters, self.n_channels))
        self.canvas.update()

    def on_select(self, *args, **kwargs):
        """Refresh the colors to highlight the newly selected clusters."""
        super(TemplateView, self).on_select(*args, **kwargs)
        self.update_color()

    # Scaling
    # -------------------------------------------------------------------------

    def _set_scaling_value(self, value):
        """Set the vertical template scaling and apply it to the grid layout."""
        self._scaling = value
        self._apply_scaling()

    def _apply_scaling(self):
        """Copy `self._scaling` into the vertical component of the layout scaling."""
        sx, _ = self.canvas.layout.scaling
        self.canvas.layout.scaling = (sx, self._scaling)

    @property
    def scaling(self):
        """Return the grid scaling."""
        return self._scaling

    @scaling.setter
    def scaling(self, value):
        self._scaling = value

    # Interactivity
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by clicking on its template waveform."""
        if 'Control' not in e.modifiers:
            return
        b = e.button
        # Get mouse position in NDC. The box column is the cluster's position in
        # `self.all_cluster_ids` (see update_cluster_sort()).
        (_, cluster_idx), _ = self.canvas.grid.box_map(e.pos)
        cluster_id = self.all_cluster_ids[cluster_idx]
        logger.debug("Click on cluster %d with button %s.", cluster_id, b)
        if 'Shift' in e.modifiers:
            emit('select_more', self, [cluster_id])
        else:
            emit('request_select', self, [cluster_id])
Exemple #17
0
    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None, n_channels=None,
            channel_vertical_order=None, channel_labels=None, **kwargs):

        self.do_show_labels = True
        self.show_all_spikes = False
        self._scaling = 1.

        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces
        self.waveforms = None

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel permutation.
        self._channel_perm = (
            np.arange(n_channels) if channel_vertical_order is None else channel_vertical_order)
        assert self._channel_perm.shape == (n_channels,)
        self._channel_perm = np.argsort(self._channel_perm)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Box and probe scaling.
        self._origin = None

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        self.canvas.set_layout('stacked', origin=self.origin, n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        # Visuals.
        self.trace_visual = UniformPlotVisual()
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        self.canvas.add_visual(self.text_visual)

        # Make a copy of the initial box pos and size. We'll apply the scaling
        # to these quantities.
        self.box_size = np.array(self.canvas.stacked.box_size)

        # Initial interval.
        self._interval = None
        self.go_to(duration / 2.)

        self._waveform_times = []