Exemplo n.º 1
0
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.
    sample_rate : float
        Sampling rate of the waveforms. Mandatory and strictly positive; it is used to
        convert the number of samples into a duration for the millisecond ticks.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    # Color of the y=0 axis lines and of the millisecond ticks.
    ax_color = (.75, .75, .75, 1.)
    # Marker size of the millisecond ticks.
    tick_size = 5.
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'previous_waveforms_type': 'shift+w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self,
                 waveforms=None,
                 waveforms_type=None,
                 sample_rate=None,
                 **kwargs):
        """Create the view, its boxed layout and its visuals.

        `waveforms` may be a single function or a dict of functions (see the class
        docstring). `waveforms_type` selects the initially displayed type.
        """
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # Check for None explicitly: in Python 3, `None > 0.` raises a TypeError, which
        # would hide the intended error message.
        assert sample_rate is not None and sample_rate > 0., (
            "The sample rate must be provided to the waveform view.")

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(marker='vbar',
                                                color=self.ax_color,
                                                size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    @property
    def _current_visual(self):
        """Visual used for the current waveforms type.

        The thin `PlotVisual` is used for the raw `waveforms` type; the `PlotAggVisual`
        is used for every other type.
        """
        if self.waveforms_type == 'waveforms':
            return self.waveform_visual
        else:
            return self.waveform_agg_visual

    def _get_data_bounds(self, bunchs):
        """Return data bounds `[-1, -M, +1, M]`, symmetric on the y axis.

        `M` is the largest absolute value among the minima and maxima of the data of
        all bunches.
        """
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        # Symmetrize on the y axis.
        M = max(abs(m), abs(M))
        return [-1, -M, +1, M]

    def get_clusters_data(self):
        """Return one Bunch per selected cluster, for the current waveforms type.

        Each bunch is annotated with its selection `index`, its overlap `offset`, the
        total number of offsets `n_clu`, and its `color`. Returns None if the current
        waveforms type is not a key of `self.waveforms`.
        """
        if self.waveforms_type not in self.waveforms:
            return
        bunchs = [
            self.waveforms_types.get()(cluster_id)
            for cluster_id in self.cluster_ids
        ]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add one cluster's waveforms, y=0 axis lines and ms ticks to the visual batches.

        Does nothing when the bunch has no data. `plot()` must have set
        `self.channel_ids`, `self.data_bounds` and `self.wave_duration` beforehand.
        """
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Box index (position among all displayed channels) of every local channel of
        # this cluster; reused below for the waveforms, the axes and the ticks.
        box_index_loc = _index_of(channel_ids_loc, self.channel_ids)

        # Find the x coordinates.
        # The offset is 0, 1, 2 for the first 3 clusters by default, but it can be
        # customized when displaying several sets of waveforms per cluster.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = np.tile(box_index_loc, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual == self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (
                n_samples + 2)

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines.
        ax_db = self.data_bounds
        a, b = _overlap_transform(np.array([-1, 1]),
                                  offset=bunch.offset,
                                  n=bunch.n_clu,
                                  overlap=self.overlap)
        box_index = np.repeat(box_index_loc, 2)
        box_index = np.tile(box_index, n_spikes_clu)
        hpos = np.tile([[a, 0, b, 0]], (nw, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = np.repeat(box_index_loc, x.size // len(box_index_loc))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x,
            y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Rebuild the channel labels, one per box, then show or hide them.

        The labels are always rebuilt, even when they are hidden. Otherwise, turning
        them back on later would show labels from an older cluster selection.
        `n_clusters` is not used.
        """
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)
        if self.do_show_labels:
            self.text_visual.show()
        else:
            self.text_visual.hide()

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()
        if not bunchs:
            return

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids
        if bunchs[0].data is not None:
            self.wave_duration = bunchs[0].data.shape[1] / float(
                self.sample_rate)
        else:  # pragma: no cover
            self.wave_duration = 1.

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i]
                for i, channel_id in enumerate(d.channel_ids)
            })

        # Update the Boxed box positions as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.update_boxes(_get_box_pos(bunchs, channel_ids))

        # The data bounds are only computed once, then reused for later selections
        # (presumably to keep a constant amplitude scale).
        self.data_bounds = self.data_bounds or self._get_data_bounds(bunchs)

        self._current_visual.reset_batch()
        self.line_visual.reset_batch()
        self.tick_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.tick_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self._current_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)

        # Only show the current waveform visual.
        if self._current_visual == self.waveform_visual:
            self.waveform_visual.show()
            self.waveform_agg_visual.hide()
        elif self._current_visual == self.waveform_agg_visual:
            self.waveform_agg_visual.show()
            self.waveform_visual.hide()

        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap,
                         checkable=True,
                         checked=self.overlap)
        self.actions.add(self.toggle_show_labels,
                         checkable=True,
                         checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.previous_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    @property
    def status(self):
        """Status string of the view: the current waveforms type."""
        return self.waveforms_type

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        self.boxed.expand_box_width()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        self.boxed.shrink_box_width()

    @property
    def box_scaling(self):
        """Scaling `(width, height)` of the channel boxes, stored in the boxed layout."""
        return self.boxed._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        self.boxed._box_scaling = value

    def _get_scaling_value(self):
        """Return the vertical box scaling (presumably used by `ScalingMixin`)."""
        return self.boxed._box_scaling[1]

    def _set_scaling_value(self, value):
        """Set the vertical box scaling and refresh the boxed layout."""
        w, h = self.boxed._box_scaling
        self.boxed._box_scaling = (w, value)
        self.boxed.update()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Scaling of the whole probe layout, stored in the boxed layout."""
        return self.boxed._layout_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        self.boxed._layout_scaling = value

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        self.boxed.expand_layout_width()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        self.boxed.shrink_layout_width()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        self.boxed.expand_layout_height()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        self.boxed.shrink_layout_height()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view.

        The click is only handled when the Control modifier is held or a digit key is
        pressed; the digit, if any, is sent as `key` with the `select_channel` event.
        Clicks are ignored until some channels have been plotted.
        """
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # Nothing has been plotted yet, so there is no channel to select.
            if not self.channel_ids:
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.",
                         channel_id, key, b)
            emit('select_channel',
                 self,
                 channel_id=channel_id,
                 key=key,
                 button=b)

    @property
    def waveforms_type(self):
        """Name of the currently displayed waveforms type."""
        return self.waveforms_types.current

    @waveforms_type.setter
    def waveforms_type(self, value):
        self.waveforms_types.set(value)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        self.waveforms_types.next()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def previous_waveforms_type(self):
        """Switch to the previous waveforms type."""
        self.waveforms_types.previous()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Toggle between the `waveforms` and `mean_waveforms` types, when available.

        The `checked` argument passed by the checkable action is not used.
        """
        if self.waveforms_type == 'mean_waveforms' and 'waveforms' in self.waveforms:
            self.waveforms_types.set('waveforms')
            logger.debug("Switch to raw waveforms.")
            self.plot()
        elif 'mean_waveforms' in self.waveforms:
            self.waveforms_types.set('mean_waveforms')
            logger.debug("Switch to mean waveforms.")
            self.plot()
Exemplo n.º 2
0
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially. Defaults to the
        first key of the dictionary.

    """

    _default_position = 'right'
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
        """Create the view, its boxed layout, its axes and its visuals.

        `waveforms` may be a single function or a dict of functions (see the class
        docstring). The dict must contain the `waveforms` key.
        """
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
        self.canvas.enable_axes()

        self._box_scaling = (1., 1.)
        self._probe_scaling = (1., 1.)

        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        assert waveforms
        self.waveforms = waveforms
        self.waveforms_types = list(waveforms.keys())
        # Current waveforms type.
        self.waveforms_type = waveforms_type or self.waveforms_types[0]
        assert self.waveforms_type in waveforms
        assert 'waveforms' in waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Return data bounds `[-1, m, +1, M]`.

        `m` and `M` are the minimum and maximum of the data over all bunches.
        """
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        return [-1, m, +1, M]

    def get_clusters_data(self):
        """Return one Bunch per selected cluster, for the current waveforms type.

        Each bunch is annotated with its selection `index`, its overlap `offset`, the
        total number of offsets `n_clu`, and its `color`.
        """
        bunchs = [
            self.waveforms[self.waveforms_type](cluster_id)
            for cluster_id in self.cluster_ids
        ]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster to the waveform visual batch.

        Does nothing when the bunch has no data. `plot()` must have set
        `self.channel_ids` and `self.data_bounds` beforehand. The bunch's own arrays
        are not modified.
        """
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        # The offset is 0, 1, 2 for the first 3 clusters by default, but it can be
        # customized when displaying several sets of waveforms per cluster.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        # Build a new array instead of updating `masks` in place. The array belongs to
        # the bunch returned by the waveforms function (possibly a cached one), so
        # in-place updates would accumulate over successive plots. They would also fail
        # on integer mask arrays.
        masks = masks * .99999 + bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.tile(box_index, n_spikes_clu)
        assert box_index.shape == (n_spikes_clu * n_channels * n_samples, )

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

        self.waveform_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Rebuild the channel labels, one per box, then show or hide them.

        The labels are always rebuilt, even when they are hidden. Otherwise, turning
        them back on later would show labels from an older cluster selection.
        `n_clusters` is not used.
        """
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)
        if self.do_show_labels:
            self.text_visual.show()
        else:
            self.text_visual.hide()

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i]
                for i, channel_id in enumerate(d.channel_ids)
            })

        # Update the box bounds as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.box_bounds = _get_box_bounds(bunchs, channel_ids)
        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        self.data_bounds = self._get_data_bounds(bunchs)

        self.waveform_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.waveform_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)
        self._update_axes(bunchs)
        self.canvas.update()

    def _update_axes(self, bunchs):
        """Update the axes data bounds.

        The bounds are derived from the waveform duration and the amplitude range of
        `self.data_bounds`.
        """
        # Update the axes data bounds.
        _, m, _, M = self.data_bounds
        # Waveform duration, scaled by overlap factor if needed.
        wave_dur = bunchs[0].get('waveform_duration', 1.)
        wave_dur /= .5 * (1 + _overlap_transform(
            1, n=len(self.cluster_ids), overlap=self.overlap))
        x1, y1 = range_transform(self.canvas.boxed.box_bounds[0], NDC,
                                 [wave_dur, M - m])
        axes_data_bounds = (0, 0, x1, y1)
        self.canvas.axes.reset_data_bounds(axes_data_bounds, do_update=True)

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap,
                         checkable=True,
                         checked=self.overlap)
        self.actions.add(self.toggle_show_labels,
                         checkable=True,
                         checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def _update_boxes(self):
        """Send the box positions (scaled by the probe scaling) and sizes to the layout."""
        self.canvas.boxed.update_boxes(self.box_pos * self.probe_scaling,
                                       self.box_size)

    def _apply_box_scaling(self):
        """Apply the current box scaling to the canvas layout."""
        self.canvas.layout.scaling = self._box_scaling

    @property
    def box_scaling(self):
        """Scaling of the channel boxes."""
        return self._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        assert len(value) == 2
        self._box_scaling = value
        self._apply_box_scaling()

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w * self._scaling_param_increment, h)
        self._apply_box_scaling()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w / self._scaling_param_increment, h)
        self._apply_box_scaling()

    def _get_scaling_value(self):
        """Return the vertical box scaling (presumably used by `ScalingMixin`)."""
        return self._box_scaling[1]

    def _set_scaling_value(self, value):
        """Set the vertical box scaling and refresh the boxes."""
        w, h = self._box_scaling
        self.box_scaling = (w, value)
        self._update_boxes()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Scaling of the entire probe."""
        return self._probe_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        assert len(value) == 2
        self._probe_scaling = value
        self._update_boxes()

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w * self._scaling_param_increment, h)
        self._update_boxes()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w / self._scaling_param_increment, h)
        self._update_boxes()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h * self._scaling_param_increment)
        self._update_boxes()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h / self._scaling_param_increment)
        self._update_boxes()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view.

        The click is only handled when the Control modifier is held or a digit key is
        pressed; the digit, if any, is sent as `key` with the `channel_click` event.
        Clicks are ignored until some channels have been plotted.
        """
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # Nothing has been plotted yet, so there is no channel to select.
            if not self.channel_ids:
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.",
                         channel_id, key, b)
            emit('channel_click',
                 self,
                 channel_id=channel_id,
                 key=key,
                 button=b)

    def next_waveforms_type(self):
        """Switch to the next waveforms type, cycling back to the first one."""
        i = self.waveforms_types.index(self.waveforms_type)
        n = len(self.waveforms_types)
        self.waveforms_type = self.waveforms_types[(i + 1) % n]
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Toggle between the `waveforms` and `mean_waveforms` types, when available.

        The `checked` argument passed by the checkable action is not used.
        """
        if self.waveforms_type == 'mean_waveforms':
            self.waveforms_type = 'waveforms'
            self.plot()
        elif 'mean_waveforms' in self.waveforms_types:
            self.waveforms_type = 'mean_waveforms'
            self.plot()
Exemplo n.º 3
0
class CorrelogramView(ScalingMixin, ManualClusteringView):
    """A view showing the autocorrelogram of the selected clusters, and all cross-correlograms
    of cluster pairs.

    Constructor
    -----------

    correlograms : function
        Maps `(cluster_ids, bin_size, window_size)` to an `(n_clusters, n_clusters, n_bins)` array.

    firing_rate : function
        Maps `(cluster_ids, bin_size)` to an `(n_clusters, n_clusters)` array. Optional: when
        given, a horizontal line at the firing rate is drawn in every subplot.

    sample_rate : float
        Sampling rate; must be positive.

    """

    # Do not show too many clusters.
    max_n_clusters = 20

    _default_position = 'left'
    cluster_ids = ()

    # Bin size, in seconds.
    bin_size = 1e-3

    # Window size, in seconds.
    window_size = 50e-3

    # Refractory period, in seconds (drawn as two vertical lines around the center).
    refractory_period = 2e-3

    # If True, all subplots share the maximum over all correlograms as their y-limit;
    # otherwise every subplot is scaled to its own maximum.
    uniform_normalization = False

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
        'change_bin_size': 'alt+wheel',
    }

    default_snippets = {
        'set_bin': 'cb',
        'set_window': 'cw',
        'set_refractory_period': 'cr',
    }

    def __init__(self,
                 correlograms=None,
                 firing_rate=None,
                 sample_rate=None,
                 **kwargs):
        super(CorrelogramView, self).__init__(**kwargs)
        self.state_attrs += ('bin_size', 'window_size', 'refractory_period',
                             'uniform_normalization')
        self.local_state_attrs += ()
        self.canvas.set_layout(layout='grid')

        # Outside margin to show labels.
        self.canvas.gpu_transforms.add(Scale(.9))

        assert sample_rate > 0
        self.sample_rate = float(sample_rate)

        # Function clusters => CCGs.
        self.correlograms = correlograms

        # Function clusters => firing rates (same unit as CCG).
        self.firing_rate = firing_rate

        # Set the default bin and window size.
        self._set_bin_window(bin_size=self.bin_size,
                             window_size=self.window_size)

        self.correlogram_visual = HistogramVisual()
        self.canvas.add_visual(self.correlogram_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)

    # -------------------------------------------------------------------------
    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self, n_clusters):
        """Yield every `(i, j)` subplot index of the n_clusters x n_clusters grid, row-major."""
        for i in range(n_clusters):
            for j in range(n_clusters):
                yield i, j

    def get_clusters_data(self, load_all=None):
        """Return one Bunch per subplot with the correlogram and its plotting parameters.

        Each Bunch has `correlogram`, `firing_rate` (or None), `data_bounds`
        `(0, 0, n_bins, ymax)`, `pair_index` `(i, j)` and `color`. The `load_all`
        argument is not used here.
        """
        ccg = self.correlograms(self.cluster_ids, self.bin_size,
                                self.window_size)
        fr = self.firing_rate(self.cluster_ids,
                              self.bin_size) if self.firing_rate else None
        assert ccg.ndim == 3
        n_bins = ccg.shape[2]
        bunchs = []
        # Global maximum, used as the y-limit when normalization is uniform.
        m = ccg.max()
        for i, j in self._iter_subplots(len(self.cluster_ids)):
            b = Bunch()
            b.correlogram = ccg[i, j, :]
            if not self.uniform_normalization:
                # Scale each subplot to its own maximum.
                m = ccg[i, j, :].max()
            b.firing_rate = fr[i, j] if fr is not None else None
            b.data_bounds = (0, 0, n_bins, m)
            b.pair_index = i, j
            b.color = selected_cluster_color(i, 1)
            if i != j:
                # Cross-correlograms use a desaturated version of the row color.
                b.color = add_alpha(_override_hsv(b.color[:3], s=.1, v=1))
            bunchs.append(b)
        return bunchs

    def _plot_pair(self, bunch):
        """Add the histogram, firing-rate line and refractory-period lines of one subplot."""
        # Plot the histogram.
        self.correlogram_visual.add_batch_data(hist=bunch.correlogram,
                                               color=bunch.color,
                                               ylim=bunch.data_bounds[3],
                                               box_index=bunch.pair_index)

        # Plot the firing rate.
        gray = (.25, .25, .25, 1.)
        if bunch.firing_rate is not None:
            # Horizontal line across the whole subplot.
            pos = np.array([[
                0, bunch.firing_rate, bunch.data_bounds[2], bunch.firing_rate
            ]])
            self.line_visual.add_batch_data(pos=pos,
                                            color=gray,
                                            data_bounds=bunch.data_bounds,
                                            box_index=bunch.pair_index)
            # # Text.
            # self.text_visual.add_batch_data(
            #     pos=[bunch.data_bounds[2], bunch.firing_rate],
            #     text='%.2f' % bunch.firing_rate,
            #     anchor=(-1, 0),
            #     box_index=bunch.pair_index,
            #     data_bounds=bunch.data_bounds,
            # )

        # Refractory period: two vertical lines, in bin units, placed symmetrically
        # around the center of the window (the `+ 1` accounts for the width of the bin).
        xrp0 = round(
            (self.window_size * .5 - self.refractory_period) / self.bin_size)
        xrp1 = round((self.window_size * .5 + self.refractory_period) /
                     self.bin_size) + 1
        ylim = bunch.data_bounds[3]
        pos = np.array([[xrp0, 0, xrp0, ylim], [xrp1, 0, xrp1, ylim]])
        self.line_visual.add_batch_data(pos=pos,
                                        color=gray,
                                        data_bounds=bunch.data_bounds,
                                        box_index=bunch.pair_index)

    def _plot_labels(self):
        """Write the cluster ids along the left column and the bottom row of the grid."""
        n = len(self.cluster_ids)

        # Display the cluster ids in the subplots.
        for k in range(n):
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(self.cluster_ids[k]),
                anchor=[-1.25, 0],
                data_bounds=None,
                box_index=(k, 0),
            )
            self.text_visual.add_batch_data(
                pos=[0, -1],
                text=str(self.cluster_ids[k]),
                anchor=[0, -1.25],
                data_bounds=None,
                box_index=(n - 1, k),
            )

        # # Display the window size in the bottom right subplot.
        # self.text_visual.add_batch_data(
        #     pos=[1, -1],
        #     anchor=[1.25, 1],
        #     text='%.1f ms' % (1000 * .5 * self.window_size),
        #     box_index=(n - 1, n - 1),
        # )

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        self.canvas.grid.shape = (len(self.cluster_ids), len(self.cluster_ids))

        bunchs = self.get_clusters_data()

        self.correlogram_visual.reset_batch()
        self.line_visual.reset_batch()
        self.text_visual.reset_batch()

        for bunch in bunchs:
            self._plot_pair(bunch)
        self._plot_labels()

        self.canvas.update_visual(self.correlogram_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self.text_visual)

        self.canvas.update()

    # -------------------------------------------------------------------------
    # Public methods
    # -------------------------------------------------------------------------

    def toggle_normalization(self, checked):
        """Change the normalization of the correlograms."""
        self.uniform_normalization = checked
        self.plot()

    def toggle_labels(self, checked):
        """Show or hide all labels."""
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(CorrelogramView, self).attach(gui)

        self.actions.add(self.toggle_normalization,
                         shortcut='n',
                         checkable=True)
        self.actions.add(self.toggle_labels, checkable=True, checked=True)
        self.actions.separator()

        # Prompt defaults are shown in milliseconds.
        self.actions.add(self.set_bin,
                         prompt=True,
                         prompt_default=lambda: self.bin_size * 1000)
        self.actions.add(self.set_window,
                         prompt=True,
                         prompt_default=lambda: self.window_size * 1000)
        self.actions.add(self.set_refractory_period,
                         prompt=True,
                         prompt_default=lambda: self.refractory_period * 1000)
        self.actions.separator()

    # -------------------------------------------------------------------------
    # Methods for changing the parameters
    # -------------------------------------------------------------------------

    def _set_bin_window(self, bin_size=None, window_size=None):
        """Set the bin and window sizes (in seconds).

        Missing (or zero) values fall back to the current ones, and both values
        are clipped to [1e-6, 1e3] seconds. If the resulting bin size is not
        strictly smaller than the window size, the change is rejected with a
        warning and the current settings are kept. Rejecting the change avoids
        raising from inside GUI handlers, for example when Alt+wheel keeps
        growing the bin or the window keeps shrinking.
        """
        bin_size = _clip(bin_size or self.bin_size, 1e-6, 1e3)
        window_size = _clip(window_size or self.window_size, 1e-6, 1e3)
        assert 1e-6 <= bin_size <= 1e3
        assert 1e-6 <= window_size <= 1e3
        if bin_size >= window_size:
            logger.warning(
                "Ignoring correlogram parameters: the bin size (%.3f ms) must be "
                "smaller than the window size (%.3f ms).",
                bin_size * 1e3, window_size * 1e3)
            return
        self.bin_size = bin_size
        self.window_size = window_size
        self.update_status()

    @property
    def status(self):
        """Status text: the window size followed by the bin size in parentheses, in ms."""
        b, w = self.bin_size * 1000, self.window_size * 1000
        return '{:.1f} ms ({:.1f} ms)'.format(w, b)

    def set_refractory_period(self, value):
        """Set the refractory period (in milliseconds)."""
        self.refractory_period = _clip(value, .1, 100) * 1e-3
        self.plot()

    def set_bin(self, bin_size):
        """Set the correlogram bin size (in milliseconds).

        Example: `1`

        """
        self._set_bin_window(bin_size=bin_size * 1e-3)
        self.plot()

    def set_window(self, window_size):
        """Set the correlogram window size (in milliseconds).

        Example: `100`

        """
        self._set_bin_window(window_size=window_size * 1e-3)
        self.plot()

    def increase(self):
        """Increase the window size."""
        self.set_window(1000 * self.window_size * 1.1)

    def decrease(self):
        """Decrease the window size."""
        self.set_window(1000 * self.window_size / 1.1)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Change the scaling with the wheel; Alt+wheel changes the bin size."""
        super(CorrelogramView, self).on_mouse_wheel(e)
        if e.modifiers == ('Alt', ):
            self._set_bin_window(bin_size=self.bin_size * 1.1**e.delta)
            self.plot()
Exemplo n.º 4
0
class EventMarker(IPlugin):
    """Plugin drawing event markers (vertical lines with labels) in the amplitude view.

    Event times are read from the first column of `eventmarkers.txt` in the
    controller's `dir_path`. Integer values are interpreted as sample indices
    and converted to seconds using the model's sample rate; other values are
    used as seconds. Optional labels are read from `eventmarkernames.txt`;
    events without a name are labelled by their 1-based index.
    """
    # Line color of the event markers (RGBA).
    line_color = (1, 1, 1, 0.75)

    def attach_to_controller(self, controller):
        @connect
        def on_view_attached(view, gui):
            if isinstance(view, AmplitudeView):
                # Create batch of vertical lines (full height)
                self.line_visual = LineVisual()
                _fix_coordinate_in_visual(self.line_visual, 'y')
                view.canvas.add_visual(self.line_visual)

                # Create batch of annotative text
                self.text_visual = TextVisual(self.line_color)
                _fix_coordinate_in_visual(self.text_visual, 'y')
                # Nudge the labels slightly to the right of their line.
                self.text_visual.inserter.insert_vert(
                    'gl_Position.x += 0.001;', 'after_transforms')
                view.canvas.add_visual(self.text_visual)

                @view.actions.add(shortcut='alt+b',
                                  checkable=True,
                                  name='Toggle event markers')
                def toggle(on):
                    """Toggle event markers"""
                    # Use `show` and `hide` instead of `toggle` here in
                    # case synchronization issues
                    if on:
                        logger.debug('Toggle on markers.')
                        self.line_visual.show()
                        self.text_visual.show()
                        view.show_events = True
                    else:
                        logger.debug('Toggle off markers.')
                        self.line_visual.hide()
                        self.text_visual.hide()
                        view.show_events = False
                    view.canvas.update()

                @view.actions.add(shortcut='shift+alt+e',
                                  prompt=True,
                                  name='Go to event',
                                  alias='ge')
                def Go_to_event(event_num):
                    # `event_num` is 1-based; `events` (in seconds) is bound
                    # below, before this action gets enabled.
                    trace_view = gui.get_view(TraceView)
                    if 0 < event_num <= events.size:
                        trace_view.go_to(events[event_num - 1])

                # Disable the menu until events are successfully added
                view.actions.disable('Go to event')
                view.actions.disable('Toggle event markers')
                if not hasattr(view, 'show_events'):
                    view.show_events = True
                view.state_attrs += ('show_events', )

                # Read event markers from file
                filename = controller.dir_path / 'eventmarkers.txt'
                try:
                    # genfromtxt returns a 0-d array for a single-line file;
                    # force 1-D so indexing and `repeat(..., axis=0)` work.
                    events = np.atleast_1d(
                        np.genfromtxt(filename, usecols=0, dtype=None))
                except (FileNotFoundError, OSError):
                    logger.warning('Event marker file not found: `%s`.', filename)
                    view.show_events = False
                    return

                # Create list of event names
                labels = list(map(str, range(1, events.size + 1)))

                # Read event names from file (if present)
                filename = controller.dir_path / 'eventmarkernames.txt'
                try:
                    eventnames = np.loadtxt(filename,
                                            usecols=0,
                                            dtype=str,
                                            max_rows=events.size)
                    labels[:eventnames.size] = np.atleast_1d(eventnames)
                except (FileNotFoundError, OSError):
                    logger.info(
                        'Event marker names file not found (optional):'
                        ' `%s`. Fall back to numbering.', filename)

                # Obtain seconds from samples (any integer dtype means samples).
                if np.issubdtype(events.dtype, np.integer):
                    logger.debug('Converting input from samples to seconds.')
                    events = events / controller.model.sample_rate

                logger.debug('Add event markers to amplitude view.')

                # Obtain horizontal positions in [-1, 1] over the view duration;
                # each event becomes one segment (x, 1) -> (x, -1).
                x = -1 + 2 * events / view.duration
                x = x.repeat(4, 0).reshape(-1, 4)
                x[:, 1::2] = 1, -1

                # Add lines and update view
                self.line_visual.reset_batch()
                self.line_visual.add_batch_data(pos=x, color=self.line_color)
                view.canvas.update_visual(self.line_visual)

                # Add text and update view
                self.text_visual.reset_batch()
                self.text_visual.add_batch_data(pos=x[:, :2],
                                                anchor=(1, -1),
                                                text=labels)
                view.canvas.update_visual(self.text_visual)

                # Finally enable the menu
                logger.debug('Enable menu items.')
                view.actions.enable('Go to event')
                view.actions.enable('Toggle event markers')
                if view.show_events:
                    view.actions.get('Toggle event markers').toggle()
                else:
                    self.line_visual.hide()
                    self.text_visual.hide()