Beispiel #1
0
class TraceView(ScalingMixin, BaseColorView, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_positions : array-like
        Positions of the channels, used for displaying the channels in the right y order
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    trace_color_0 = (.353, .161, .443)
    trace_color_1 = (.133, .404, .396)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'switch_color_scheme': 'shift+wheel',
        'navigate': 'alt+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'jump_left': 'shift+alt+left',
        'jump_right': 'shift+alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'select_channel_pcA': 'shift+left click',
        'select_channel_pcB': 'shift+right click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None,
            n_channels=None, channel_positions=None, channel_labels=None, **kwargs):
        """Create the view, its visuals, and display the middle of the recording.

        Parameters
        ----------
        traces : function
            Maps a time interval `(t0, t1)` to a bunch with the trace data (mandatory).
        sample_rate : float
            Sampling rate, must be strictly positive.
        spike_times : function
            Returns the spike times used by the "next/previous spike" actions.
        duration : float
            Total duration of the recording, must be non-negative.
        n_channels : int
            Number of channels, must be non-negative.
        channel_positions : array-like
            `(n_channels, 2)` positions; the second column determines the vertical order.
            Defaults to channels stacked in increasing channel order.
        channel_labels : list
            One label per channel. Defaults to the channel indices.
        **kwargs
            Passed to the parent constructor.

        """

        self.do_show_labels = True
        self.show_all_spikes = False

        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel y ranking.
        self.channel_positions = (
            channel_positions if channel_positions is not None else
            np.c_[np.zeros(n_channels), np.arange(n_channels)])
        # channel_y_ranks[i] is the position of channel #i in the trace view.
        # np.asarray() makes the column indexing work for any array-like input
        # (e.g. a list of (x, y) pairs), as advertised in the class docstring.
        y_positions = np.asarray(self.channel_positions)[:, 1]
        self.channel_y_ranks = np.argsort(np.argsort(y_positions))
        assert self.channel_y_ranks.shape == (n_channels,)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        # Visuals.
        self._create_visuals()

        # Lookup table used for spike clicks. It must exist *before* the initial plot below,
        # which fills it; resetting it afterwards would discard the displayed spikes.
        self._waveform_times = []

        # Initial interval.
        self._interval = None
        self.go_to(duration / 2.)

        self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))

    def _create_visuals(self):
        """Set up the stacked layout and register the trace, waveform and label visuals."""
        canvas = self.canvas
        canvas.set_layout('stacked', n_plots=self.n_channels)
        canvas.enable_axes(show_y=False)

        # Raw traces. When both gradient colors are set, the fragment color is
        # interpolated between them as a function of the signal index.
        self.trace_visual = UniformPlotVisual()
        if self.trace_color_0 and self.trace_color_1:
            gradient_code = 'gl_FragColor.rgb = mix(vec3%s, vec3%s, (v_signal_index / %d));' % (
                self.trace_color_0, self.trace_color_1, self.n_channels)
            self.trace_visual.inserter.insert_frag(gradient_code, 'end')
        canvas.add_visual(self.trace_visual)

        # Spike waveforms drawn on top of the traces.
        self.waveform_visual = PlotVisual()
        canvas.add_visual(self.waveform_visual)

        # Channel labels. The `v_discard` varying is non-zero for the labels that the
        # fragment shader has to drop; it depends on the number of boxes and on the zoom.
        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        discard_expr = (
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.add_varying('float', 'v_discard', discard_expr)
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        canvas.add_visual(self.text_visual)

    @property
    def stacked(self):
        return self.canvas.stacked

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        color = color or self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        box_index = self.channel_y_ranks
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # Determine the spike color.
        i = bunch.select_index
        c = bunch.spike_cluster
        cs = self.color_schemes.get()
        color = selected_cluster_color(i, alpha=1) if i is not None else cs.get(c, alpha=1)

        # We could tweak the color of each spike waveform depending on the template amplitude
        # on each of its best channels.
        # channel_amps = bunch.get('channel_amps', None)
        # if channel_amps is not None:
        #     color = np.tile(color, (n_channels, 1))
        #     assert color.shape == (n_channels, 4)
        #     color[:, 3] = channel_amps

        # The box index depends on the channel.
        box_index = self.channel_y_ranks[bunch.channel_ids]
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=color,
            data_bounds=self.data_bounds,
        )

    def _plot_waveforms(self, waveforms, **kwargs):
        """Plot the waveforms."""
        # waveforms = self.waveforms
        assert isinstance(waveforms, list)
        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

    def _plot_labels(self, traces):
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self.channel_y_ranks[ch]
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def plot(self, update_traces=True, update_waveforms=True):
        """Redraw the traces and/or the spike waveforms of the current interval.

        Parameters
        ----------
        update_traces : bool
            Recompute the data bounds, then redraw the raw traces and the channel labels.
        update_waveforms : bool
            Redraw the spike waveforms found in the bunch returned by `self.traces()`.

        """
        # Both drawing steps read the bunch returned by `self.traces()`, so it has to be
        # loaded as soon as one of them is requested.
        if update_traces or update_waveforms:
            traces = self.traces(self._interval)

        if update_traces:
            logger.log(5, "Redraw the entire trace view.")
            start, end = self._interval

            # Find the data bounds: quantile-based when auto-scaling or when no bounds
            # have been computed yet, otherwise keep the current vertical range.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Used for spike click.
            self._waveform_times = []

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

            # Plot the labels.
            if self.do_show_labels:
                self._plot_labels(traces.data)

        if update_waveforms:
            self._plot_waveforms(traces.get('waveforms', []))

        self._update_axes()
        self.canvas.update()

    def set_interval(self, interval=None):
        """Display the traces and spikes in a given interval.

        Parameters
        ----------
        interval : tuple
            `(start, end)` in seconds. It is restricted to the recording by
            `_restrict_interval()`. When None, the current interval is redisplayed.

        If the restricted interval differs from the current one, everything is redrawn and
        the `time_range_selected` event is emitted; otherwise only the waveforms are redrawn.

        """
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        if interval != self._interval:
            logger.log(5, "Redraw the entire trace view.")
            self._interval = interval
            emit('is_busy', self, True)
            try:
                self.plot(update_traces=True, update_waveforms=True)
            finally:
                # Always clear the busy flag, even if the redraw raised an exception.
                emit('is_busy', self, False)
            emit('time_range_selected', self, interval)
            self.update_status()
        else:
            self.plot(update_traces=False, update_waveforms=True)

    def on_select(self, cluster_ids=None, **kwargs):
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        # Make sure we call again self.traces() when the cluster selection changes.
        self.set_interval()

    def attach(self, gui):
        """Attach the view to the GUI, register its actions, and display the interval."""
        super(TraceView, self).attach(gui)
        actions = self.actions

        # Display options.
        actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        actions.add(self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        actions.add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        actions.add(self.switch_origin)
        actions.separator()

        # Go to a given time; the prompt is pre-filled with the current time.
        actions.add(self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        actions.separator()

        for action in (self.go_to_start, self.go_to_end):
            actions.add(action)
        actions.separator()

        # Relative navigation.
        actions.add(self.shift, prompt=True)
        for action in (self.go_right, self.go_left, self.jump_right, self.jump_left):
            actions.add(action)
        actions.separator()

        # Interval size.
        for action in (self.widen, self.narrow):
            actions.add(action)
        actions.separator()

        # Spike navigation.
        for action in (self.go_to_next_spike, self.go_to_previous_spike):
            actions.add(action)
        actions.separator()

        self.set_interval()

    @property
    def status(self):
        a, b = self._interval
        return '[{:.2f}s - {:.2f}s]. Color scheme: {}.'.format(a, b, self.color_scheme)

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Whether the channels are shown from top to bottom (`top`) or from bottom to top
        (`bottom`). Falls back to `Stacked._origin` when the layout has no origin."""
        layout = self.canvas.layout
        return getattr(layout, 'origin', Stacked._origin)

    @origin.setter
    def origin(self, value):
        # None means "no value to apply", e.g. nothing was saved for this attribute.
        if value is None:
            return
        if not self.canvas.layout:  # pragma: no cover
            logger.warning(
                "Could not set origin to %s because the layout instance was not initialized yet.",
                value)
            return
        self.canvas.layout.origin = value

    def switch_origin(self):
        """Flip the channel origin: `top` becomes `bottom`, anything else becomes `top`."""
        if self.origin == 'top':
            flipped = 'bottom'
        else:
            flipped = 'top'
        self.origin = flipped

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Time at the center of the window."""
        return sum(self._interval) * .5

    @property
    def interval(self):
        """Interval as `(tmin, tmax)`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval."""
        if self._interval is not None:
            a, b = self._interval
            return (b - a) * .5
        else:
            return self.interval_duration * .5

    def go_to(self, time):
        """Go to a specific time (in seconds)."""
        half_dur = self.half_duration
        self.set_interval((time - half_dur, time + half_dur))

    def shift(self, delay):
        """Shift the interval by a given delay (in seconds)."""
        self.go_to(self.time + delay)

    def go_to_start(self):
        """Go to the start of the recording."""
        self.go_to(0)

    def go_to_end(self):
        """Go to end of the recording."""
        self.go_to(self.duration)

    def go_right(self):
        """Go to right."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(delay)

    def go_left(self):
        """Go to left."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(-delay)

    def jump_right(self):
        """Jump to right."""
        delay = self.duration * .1
        self.shift(delay)

    def jump_left(self):
        """Jump to left."""
        delay = self.duration * .1
        self.shift(-delay)

    def _jump_to_spike(self, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        spike_times = self.get_spike_times()
        if spike_times is not None and len(spike_times):
            ind = np.searchsorted(spike_times, self.time)
            n = len(spike_times)
            self.go_to(spike_times[(ind + delta) % n])

    def go_to_next_spike(self, ):
        """Jump to the next spike from the first selected cluster."""
        self._jump_to_spike(+1)

    def go_to_previous_spike(self, ):
        """Jump to the previous spike from the first selected cluster."""
        self._jump_to_spike(-1)

    def toggle_highlighted_spikes(self, checked):
        """Toggle between showing all spikes or selected spikes."""
        self.show_all_spikes = checked
        self.set_interval()

    def widen(self):
        """Increase the interval size."""
        t, h = self.time, self.half_duration
        h *= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    def narrow(self):
        """Decrease the interval size."""
        t, h = self.time, self.half_duration
        h /= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show (`checked=True`) or hide (`checked=False`) the channel labels.

        The visibility of the label visual is set explicitly from `checked` rather than
        blindly toggled, so that it can never get out of sync with `do_show_labels`
        (for instance when the view starts with the labels disabled).

        """
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        if checked:
            # The labels are skipped by `plot()` while they are disabled, so they may
            # never have been built: rebuild them as soon as data bounds are available.
            if getattr(self, 'data_bounds', None) is not None:
                self._plot_labels(None)
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def toggle_auto_scale(self, checked):
        """Enable or disable the automatic scaling of the traces.

        Only the `auto_scale` flag is updated here; it is read by `plot()` when the
        traces are redrawn.

        """
        logger.debug("Set auto scale to %s.", checked)
        self.auto_scale = checked

    def update_color(self):
        """Update the view when the color scheme changes."""
        self.plot(update_traces=False, update_waveforms=True)

    # Scaling
    # -------------------------------------------------------------------------

    @property
    def scaling(self):
        """Scaling of the channel boxes."""
        return self.stacked._box_scaling[1]

    @scaling.setter
    def scaling(self, value):
        self.stacked._box_scaling = (self.stacked._box_scaling[0], value)

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value
        self.stacked.update()

    # Spike selection
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Handle spike selection (Control+click) and channel selection (Shift+click).

        With Control, the `select_spike` event is emitted for the displayed spike that is
        the closest in time to the mouse, among the spikes with a waveform on the clicked
        channel. With Shift, the `select_channel` event is emitted for the clicked channel.

        """
        if 'Control' in e.modifiers:
            # Channel(s) displayed in the clicked box.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            channel_id = np.nonzero(self.channel_y_ranks == box_id)[0]
            bounds = self.data_bounds
            # Displayed spikes that have a waveform on that channel.
            candidates = [
                (t, s, c, ch) for (t, s, c, ch) in self._waveform_times if channel_id in ch]
            if not candidates:
                return
            # Time coordinate of the mouse position.
            ndc_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, bounds).apply(ndc_pos)[0][0]
            # Spike with the closest time.
            times, spike_ids, spike_clusters, _channel_ids = zip(*candidates)
            closest = np.argmin(np.abs(np.array(times) - mouse_time))
            emit('select_spike', self, channel_id=channel_id,
                 spike_id=spike_ids[closest], cluster_id=spike_clusters[closest])

        if 'Shift' in e.modifiers:
            # Channel displayed in the clicked box.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            matching_channels = np.nonzero(self.channel_y_ranks == box_id)[0]
            emit('select_channel', self, channel_id=int(matching_channels[0]), button=e.button)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Scroll through the data with alt+wheel, after the parent class handled the event."""
        super(TraceView, self).on_mouse_wheel(e)
        if e.modifiers != ('Alt',):
            return
        start, end = self._interval
        # Each wheel unit moves the view by a tenth of the displayed interval.
        delay = e.delta * (end - start) * .1
        self.shift(-delay)
Beispiel #2
0
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    ax_color = (.75, .75, .75, 1.)
    tick_size = 5.
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'previous_waveforms_type': 'shift+w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self,
                 waveforms=None,
                 waveforms_type=None,
                 sample_rate=None,
                 **kwargs):
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        assert sample_rate > 0., "The sample rate must be provided to the waveform view."

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(marker='vbar',
                                                color=self.ax_color,
                                                size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    @property
    def _current_visual(self):
        if self.waveforms_type == 'waveforms':
            return self.waveform_visual
        else:
            return self.waveform_agg_visual

    def _get_data_bounds(self, bunchs):
        """Return data bounds `[-1, -M, +1, M]` where `M` is the largest absolute value
        found in the data of the bunches (the y axis is symmetric around zero)."""
        lowest = min(_min(bunch.data) for bunch in bunchs)
        highest = max(_max(bunch.data) for bunch in bunchs)
        extent = max(abs(lowest), abs(highest))
        return [-1, -extent, +1, extent]

    def get_clusters_data(self):
        if self.waveforms_type not in self.waveforms:
            return
        bunchs = [
            self.waveforms_types.get()(cluster_id)
            for cluster_id in self.cluster_ids
        ]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster, and their axes, to the visuals' batches.

        `bunch.data` is an `(n_spikes, n_samples, n_channels)` array, with one column per
        channel of `bunch.channel_ids`. Nothing is added when the data is None or empty.
        The bunch must also provide the `index`, `offset`, `n_clu` and `color` attributes
        set by `get_clusters_data()`, and `self.data_bounds` and `self.channel_ids` must
        have been set (this is done by `plot()`).

        The batches are only filled here: the caller is responsible for resetting them
        beforehand and for updating the visuals afterwards.

        """
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        # Default masks: all ones, i.e. one value per (spike, channel) pair.
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.
        # NOTE(review): the comment above presumably refers to the cluster index/offset
        # stored in the bunch, not to the masks -- confirm.

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        # Rescale the masks from [0, 1] to [eps, 1 - eps]. This creates a new array, so
        # the in-place addition below does not modify the masks stored in the bunch.
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.tile(box_index, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual == self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (
                n_samples + 2)

        # Generate the waveform array.
        # One row per (spike, channel) pair: (n_spikes * n_channels, n_samples).
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines.
        ax_db = self.data_bounds
        # Horizontal extent [-1, 1] of a waveform, after the overlap transform.
        a, b = _overlap_transform(np.array([-1, 1]),
                                  offset=bunch.offset,
                                  n=bunch.n_clu,
                                  overlap=self.overlap)
        # Two vertices per line, hence every box index is repeated twice.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, 2)
        box_index = np.tile(box_index, n_spikes_clu)
        hpos = np.tile([[a, 0, b, 0]], (nw, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # The same tick positions are used on every channel.
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        # x is channel-major after the tile above, so every box index is repeated once
        # per tick.
        box_index = np.repeat(box_index, x.size // len(box_index))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x,
            y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        # Add channel labels.
        if not self.do_show_labels:
            return
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)

    def plot(self, **kwargs):
        """Update the view with the current cluster selection.

        Nothing is drawn when no cluster is selected or when
        `get_clusters_data()` returns nothing. Otherwise the box layout, the
        tick/line/waveform visuals and the channel labels are rebuilt, the
        canvas is refreshed and the status is updated.

        `**kwargs` is accepted but not used by this method.
        """
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()
        if not bunchs:
            return

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids
        # Waveform duration: size of axis 1 of the first bunch's data divided
        # by the sample rate, with a fallback of 1 when there is no data.
        if bunchs[0].data is not None:
            self.wave_duration = bunchs[0].data.shape[1] / float(
                self.sample_rate)
        else:  # pragma: no cover
            self.wave_duration = 1.

        # Channel labels: map every channel id to its label. A bunch without
        # `channel_labels` falls back to the channel ids formatted as strings.
        # Later bunchs override the labels of earlier ones for shared channels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i]
                for i, channel_id in enumerate(d.channel_ids)
            })

        # Update the Boxed box positions as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.update_boxes(_get_box_pos(bunchs, channel_ids))

        # Keep the existing data bounds if there are any; compute them from the
        # bunchs only when they are not set (falsy).
        self.data_bounds = self.data_bounds or self._get_data_bounds(bunchs)

        # Clear the batches, add every cluster, then upload each visual once.
        self._current_visual.reset_batch()
        self.line_visual.reset_batch()
        self.tick_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.tick_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self._current_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)

        # Only show the current waveform visual.
        if self._current_visual == self.waveform_visual:
            self.waveform_visual.show()
            self.waveform_agg_visual.hide()
        elif self._current_visual == self.waveform_agg_visual:
            self.waveform_agg_visual.show()
            self.waveform_visual.hide()

        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI and register the view's actions.

        The actions are added in a fixed order, with a separator after each
        group: toggles and waveform-type switches, box scaling, horizontal
        probe scaling, vertical probe scaling.
        """
        super(WaveformView, self).attach(gui)
        actions = self.actions

        # Toggles and waveform type switches.
        actions.add(self.toggle_waveform_overlap, checked=self.overlap, checkable=True)
        actions.add(self.toggle_show_labels, checked=self.do_show_labels, checkable=True)
        for switcher in (self.next_waveforms_type, self.previous_waveforms_type):
            actions.add(switcher)
        actions.add(self.toggle_mean_waveforms, checkable=True)
        actions.separator()

        # Box scaling first, then horizontal and vertical probe scaling.
        scaling_groups = (
            (self.widen, self.narrow),
            (self.extend_horizontally, self.shrink_horizontally),
            (self.extend_vertically, self.shrink_vertically),
        )
        for group in scaling_groups:
            for action in group:
                actions.add(action)
            actions.separator()

    @property
    def boxed(self):
        """The `boxed` layout object held by the canvas."""
        canvas = self.canvas
        return canvas.boxed

    @property
    def status(self):
        """Status of the view, which is the current waveforms type."""
        current_type = self.waveforms_type
        return current_type

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Flag telling whether waveforms of different clusters are overlapped."""
        overlap_enabled = self._overlap
        return overlap_enabled

    @overlap.setter
    def overlap(self, value):
        # Store the flag first: the replot below reads the new value.
        setattr(self, '_overlap', value)
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Set the `overlap` attribute from the state of the checkable action."""
        setattr(self, 'overlap', checked)

    # Box scaling
    # -------------------------------------------------------------------------

    def widen(self):
        """Make the waveforms wider by expanding the box width of the layout."""
        layout = self.boxed
        layout.expand_box_width()

    def narrow(self):
        """Make the waveforms narrower by shrinking the box width of the layout."""
        layout = self.boxed
        layout.shrink_box_width()

    @property
    def box_scaling(self):
        """Box scaling stored on the layout (its `_box_scaling` attribute)."""
        layout = self.boxed
        return layout._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        # Written straight to the layout; no layout update is triggered here.
        setattr(self.boxed, '_box_scaling', value)

    def _get_scaling_value(self):
        return self.boxed._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self.boxed._box_scaling
        self.boxed._box_scaling = (w, value)
        self.boxed.update()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Layout scaling stored on the layout (its `_layout_scaling` attribute)."""
        layout = self.boxed
        return layout._layout_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        # Written straight to the layout; no layout update is triggered here.
        setattr(self.boxed, '_layout_scaling', value)

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe (expands the layout width)."""
        layout = self.boxed
        layout.expand_layout_width()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe (shrinks the layout width)."""
        layout = self.boxed
        layout.shrink_layout_width()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe (expands the layout height)."""
        layout = self.boxed
        layout.expand_layout_height()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe (shrinks the layout height)."""
        layout = self.boxed
        layout.shrink_layout_height()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show or hide the channel labels and refresh the canvas."""
        self.do_show_labels = checked
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Emit `select_channel` for the box that was clicked.

        The event is only handled when the Control modifier is held or when
        the pressed key is a single digit. The digit (as an int, or None) and
        the mouse button are forwarded with the channel id.
        """
        button = e.button
        digit_keys = tuple(str(i) for i in range(10))
        with_control = 'Control' in e.modifiers
        with_digit = e.key in digit_keys
        if not (with_control or with_digit):
            return

        key = int(e.key) if with_digit else None
        # The layout maps the mouse position to the index of the clicked box.
        channel_idx, _ = self.canvas.boxed.box_map(e.pos)
        channel_id = self.channel_ids[channel_idx]
        logger.debug("Click on channel_id %d with key %s and button %s.",
                     channel_id, key, button)
        emit('select_channel', self, channel_id=channel_id, key=key, button=button)

    @property
    def waveforms_type(self):
        """Currently selected waveforms type (`waveforms_types.current`)."""
        types_selector = self.waveforms_types
        return types_selector.current

    @waveforms_type.setter
    def waveforms_type(self, value):
        # Selecting a type does not replot by itself.
        types_selector = self.waveforms_types
        types_selector.set(value)

    def next_waveforms_type(self):
        """Advance to the next waveforms type and replot the view."""
        types_selector = self.waveforms_types
        types_selector.next()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def previous_waveforms_type(self):
        """Go back to the previous waveforms type and replot the view."""
        types_selector = self.waveforms_types
        types_selector.previous()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch between the `mean_waveforms` and `waveforms` types.

        If the mean waveforms are displayed and raw waveforms are available,
        switch to the raw waveforms. Otherwise switch to the mean waveforms if
        they are available. Nothing happens when neither case applies. The
        `checked` argument is not used.
        """
        available = self.waveforms
        if self.waveforms_type == 'mean_waveforms' and 'waveforms' in available:
            target, message = 'waveforms', "Switch to raw waveforms."
        elif 'mean_waveforms' in available:
            target, message = 'mean_waveforms', "Switch to mean waveforms."
        else:
            return
        self.waveforms_types.set(target)
        logger.debug(message)
        self.plot()
Beispiel #3
0
class TraceView(ScalingMixin, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_vertical_order : array-like
        Permutation of the channels. This 1D array gives the channel id of all channels from
        top to bottom (or conversely, depending on `origin=top|bottom`).
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    # NOTE(review): presumably the default docking position of the view in the
    # GUI -- it is not read in this part of the file, confirm in the base class.
    _default_position = 'left'
    auto_update = True
    # When True, `set_interval()` recomputes the vertical data bounds from the
    # trace quantiles each time the interval changes.
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    # NOTE(review): not referenced by the methods visible here (`go_left` and
    # `go_right` hard-code a factor of .1) -- confirm where it is used.
    shift_amount = .1
    # Factor applied to the half-duration of the interval by `widen()` (multiply)
    # and `narrow()` (divide).
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    # Color used by `_plot_traces()` when no color is given.
    default_trace_color = (.5, .5, .5, 1)
    # Keyboard/mouse shortcuts, keyed by action name.
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    # Snippets, keyed by action name.
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None, n_channels=None,
            channel_vertical_order=None, channel_labels=None, **kwargs):
        """Create the view, its visuals, and display the middle of the recording.

        See the class docstring for the meaning of the parameters. Extra keyword
        arguments are forwarded to the parent constructor.

        `channel_vertical_order` may be any array-like (e.g. a list) with one
        entry per channel; it is converted to an array before use.
        """

        self.do_show_labels = True
        self.show_all_spikes = False
        self._scaling = 1.

        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces
        self.waveforms = None

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel permutation. The order is documented as array-like, so it is
        # converted to an array before its shape is checked (a plain list has
        # no `shape` attribute).
        if channel_vertical_order is None:
            channel_order = np.arange(n_channels)
        else:
            channel_order = np.asarray(channel_vertical_order)
        assert channel_order.shape == (n_channels,)
        self._channel_perm = np.argsort(channel_order)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Box and probe scaling.
        self._origin = None

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        self.canvas.set_layout('stacked', origin=self.origin, n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        # Visuals.
        self.trace_visual = UniformPlotVisual()
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        self.canvas.add_visual(self.text_visual)

        # Make a copy of the initial box pos and size. We'll apply the scaling
        # to these quantities.
        self.box_size = np.array(self.canvas.stacked.box_size)

        # Initial interval: no interval is set yet, so `go_to()` uses the
        # default interval duration, centered on the middle of the recording.
        self._interval = None
        self.go_to(duration / 2.)

        self._waveform_times = []

    @property
    def stacked(self):
        """The `stacked` layout object held by the canvas."""
        canvas = self.canvas
        return canvas.stacked

    def _permute_channels(self, x, inv=False):
        cp = self._channel_perm
        cp = np.argsort(cp)
        return cp[x]

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        color = color or self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        box_index = self._permute_channels(np.arange(n_ch))
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # The box index depends on the channel.
        box_index = self._permute_channels(bunch.channel_ids)
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=bunch.color,
            data_bounds=self.data_bounds,
        )

    def _plot_labels(self, traces):
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self._permute_channels(ch)
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def set_interval(self, interval=None, change_status=True):
        """Display the traces and spikes in a given interval.

        Parameters
        ----------

        interval : tuple or None
            Requested `(start, end)` interval. With None, the current interval
            is used again. The interval is passed through
            `_restrict_interval()` before use.
        change_status : bool
            Whether to update the status message when the interval changes.

        The traces are always reloaded with `self.traces(interval)` and the
        waveforms are always re-plotted. The traces, labels and data bounds are
        only redrawn when the restricted interval differs from the current one.

        """
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        # Load the traces.
        traces = self.traces(interval)
        self.waveforms = traces.get('waveforms', [])

        if interval != self._interval:
            logger.debug("Redraw the entire trace view.")
            # The interval must be stored before plotting: `_plot_traces()`
            # reads `self._interval` to compute the sample times.
            self._interval = interval
            start, end = interval

            # Set the status message.
            if change_status:
                self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end))

            # Find the data bounds. The vertical bounds are recomputed from the
            # trace quantiles when auto-scaling is on, or when the data bounds
            # are missing or still equal to NDC; otherwise they are kept.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Used for spike click.
            self._waveform_times = []

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

            # Plot the labels.
            if self.do_show_labels:
                self._plot_labels(traces.data)

        # Plot the waveforms.
        self.plot()

    def on_select(self, cluster_ids=None, **kwargs):
        """Store the selected clusters and refresh the view if there are any."""
        self.cluster_ids = cluster_ids
        if cluster_ids:
            # Refreshing the current interval calls self.traces() again, which is
            # needed whenever the cluster selection changes.
            self.set_interval()

    def plot(self, **kwargs):
        """Plot the waveforms of `self.waveforms` on top of the traces.

        `self.waveforms` must be a list of spike bunchs. The waveform visual is
        hidden when the list is empty. The table used for spike clicks
        (`self._waveform_times`) is rebuilt on every call so that it always
        matches the waveforms that are displayed.

        `**kwargs` is accepted but not used by this method.
        """
        waveforms = self.waveforms
        assert isinstance(waveforms, list)

        # Rebuild the click table from scratch. `set_interval()` only clears it
        # when the interval changes, so appending here without a reset would
        # accumulate duplicated and stale spikes on every replot.
        self._waveform_times = []

        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI, register its actions, and draw the interval.

        The actions are added in a fixed order, with a separator after each
        group: toggles and origin switch, go to a time, go to start/end,
        shifts, interval size, spike navigation.
        """
        super(TraceView, self).attach(gui)
        actions = self.actions

        # Checkable toggles, initialized from the current state of the view.
        toggles = (
            (self.toggle_show_labels, self.do_show_labels),
            (self.toggle_highlighted_spikes, self.show_all_spikes),
            (self.toggle_auto_scale, self.auto_scale),
        )
        for callback, state in toggles:
            actions.add(callback, checkable=True, checked=state)
        actions.add(self.switch_origin)
        actions.separator()

        # The prompt is pre-filled with the current time.
        actions.add(self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        actions.separator()

        for callback in (self.go_to_start, self.go_to_end):
            actions.add(callback)
        actions.separator()

        actions.add(self.shift, prompt=True)
        for callback in (self.go_right, self.go_left):
            actions.add(callback)
        actions.separator()

        remaining_groups = (
            (self.widen, self.narrow),
            (self.go_to_next_spike, self.go_to_previous_spike),
        )
        for group in remaining_groups:
            for callback in group:
                actions.add(callback)
            actions.separator()

        self.set_interval()

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Vertical order of the channels: from top to bottom (`top` option, the default),
        or from bottom to top (`bottom`)."""
        current_origin = self._origin
        return current_origin

    @origin.setter
    def origin(self, value):
        self._origin = value
        # Forward the origin to the canvas layout, when there is one.
        layout = self.canvas.layout
        if layout:
            layout.origin = value

    def switch_origin(self):
        """Flip the channel origin between `top` and `bottom`.

        An unset origin (None) is handled like `bottom`, so it becomes `top`.
        """
        if self._origin in (None, 'bottom'):
            flipped = 'top'
        else:
            flipped = 'bottom'
        self.origin = flipped

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Center of the current interval (half the sum of its bounds)."""
        bounds_total = sum(self._interval)
        return bounds_total * .5

    @property
    def interval(self):
        """Current interval, as `(tmin, tmax)`."""
        current_interval = self._interval
        return current_interval

    @interval.setter
    def interval(self, value):
        # Assigning the interval displays it, with the default arguments of set_interval().
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval.

        Half of the default `interval_duration` when no interval is set yet.
        """
        if self._interval is None:
            return self.interval_duration * .5
        start, end = self._interval
        return (end - start) * .5

    def go_to(self, time):
        """Center the interval on a specific time (in seconds), keeping its duration."""
        half = self.half_duration
        centered = (time - half, time + half)
        self.set_interval(centered)

    def shift(self, delay):
        """Move the center of the interval by a given delay (in seconds)."""
        new_center = self.time + delay
        self.go_to(new_center)

    def go_to_start(self):
        """Center the interval on time 0, the start of the recording."""
        recording_start = 0
        self.go_to(recording_start)

    def go_to_end(self):
        """Center the interval on `self.duration`, the end of the recording."""
        recording_end = self.duration
        self.go_to(recording_end)

    def go_right(self):
        """Shift the interval to the right by a tenth of its duration."""
        start, end = self._interval
        window = end - start
        self.shift(window * .1)

    def go_left(self):
        """Shift the interval to the left by a tenth of its duration."""
        start, end = self._interval
        window = end - start
        self.shift(-(window * .1))

    def _jump_to_spike(self, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        spike_times = self.get_spike_times()
        if spike_times is not None and len(spike_times):
            ind = np.searchsorted(spike_times, self.time)
            n = len(spike_times)
            self.go_to(spike_times[(ind + delta) % n])

    def go_to_next_spike(self):
        """Jump forward by one spike among the times given by `get_spike_times()`."""
        forward = +1
        self._jump_to_spike(forward)

    def go_to_previous_spike(self):
        """Jump backward by one spike among the times given by `get_spike_times()`."""
        backward = -1
        self._jump_to_spike(backward)

    def toggle_highlighted_spikes(self, checked):
        """Store whether all spikes or only selected spikes are shown, then refresh."""
        setattr(self, 'show_all_spikes', checked)
        # Refresh the current interval so that the traces are loaded again.
        self.set_interval()

    def widen(self):
        """Increase the interval size by `scaling_coeff_x`, keeping its center."""
        center = self.time
        half = self.half_duration
        half *= self.scaling_coeff_x
        self.set_interval((center - half, center + half))

    def narrow(self):
        """Decrease the interval size by `scaling_coeff_x`, keeping its center."""
        center = self.time
        half = self.half_duration
        half /= self.scaling_coeff_x
        self.set_interval((center - half, center + half))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Toggle the display of the channel labels.

        Parameters
        ----------

        checked : bool
            Whether the labels should be displayed.

        The labels are shown or hidden immediately. Refreshing the current
        interval is not enough for that, because `set_interval()` only redraws
        the labels when the interval changes.

        """
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        if checked:
            # The labels may be missing or outdated if the interval changed
            # while they were disabled, so rebuild them before showing them.
            # `_plot_labels()` does not use its `traces` argument.
            self._plot_labels(None)
            self.text_visual.show()
        else:
            self.text_visual.hide()
        # Refresh the current interval, which also updates the canvas.
        self.set_interval()

    def toggle_auto_scale(self, checked):
        """Store whether the traces are scaled automatically.

        Only the flag is changed here; nothing is redrawn by this method.
        """
        logger.debug("Set auto scale to %s.", checked)
        setattr(self, 'auto_scale', checked)

    # Scaling
    # -------------------------------------------------------------------------

    def _apply_scaling(self):
        self.canvas.layout.scaling = (self.canvas.layout.scaling[0], self._scaling)

    @property
    def scaling(self):
        """Scaling of the channel boxes."""
        current_scaling = self._scaling
        return current_scaling

    @scaling.setter
    def scaling(self, value):
        # Store the value first: `_apply_scaling()` reads it back.
        setattr(self, '_scaling', value)
        self._apply_scaling()

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value

    # Spike selection
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Emit `spike_click` for the displayed spike closest to a Control+click.

        Only the spikes whose channels include the clicked channel are
        considered; the closest one is the spike whose start time is nearest
        to the time under the mouse. Nothing is emitted without the Control
        modifier, or when no displayed spike covers the clicked channel.
        """
        if 'Control' not in e.modifiers:
            return

        # Find the clicked box, then the corresponding channel.
        box_id, _ = self.canvas.stacked.box_map(e.pos)
        channel_id = self._permute_channels(box_id, inv=True)

        # Displayed spikes, as (start_time, spike_id, spike_cluster, channel_ids),
        # restricted to those that include the clicked channel.
        candidates = [entry for entry in self._waveform_times if channel_id in entry[3]]
        if not candidates:
            return

        # Convert the mouse position to NDC, then to a time with the data bounds.
        mouse_ndc = self.canvas.panzoom.window_to_ndc(e.pos)
        mouse_time = Range(NDC, self.data_bounds).apply(mouse_ndc)[0][0]

        # Pick the spike whose start time is the closest to the mouse time.
        start_times = np.array([entry[0] for entry in candidates])
        closest = candidates[np.argmin(np.abs(start_times - mouse_time))]
        _, spike_id, cluster_id, _ = closest

        # Raise the spike_click event.
        emit('spike_click', self, channel_id=channel_id,
             spike_id=spike_id, cluster_id=cluster_id)