Ejemplo n.º 1
0
    def __init__(self, features=None, attributes=None, **kwargs):
        """Set up the feature view: grid layout, lasso, and the scatter/text/line visuals.

        `features` is required (it is asserted to be truthy) and stored as is.
        `attributes` is an optional `{name: array}` dictionary of extra features, where each
        array is a `(n_spikes,)` array. Other keyword arguments go to the parent constructor.
        """
        super(FeatureView, self).__init__(**kwargs)
        self.state_attrs += ('fixed_channels', 'feature_scaling')

        assert features
        self.features = features
        self._lim = 1

        # The default grid is a 2D array where every item is a string like `0A,1B`.
        self.grid_dim = _get_default_grid()
        grid_shape = np.array(self.grid_dim).shape
        self.n_rows, self.n_cols = grid_shape

        canvas = self.canvas
        canvas.set_layout('grid', shape=(self.n_rows, self.n_cols))
        canvas.enable_lasso()

        # No channel is shown yet.
        self.channel_ids = None

        # Extra features; an empty dictionary when none is given.
        self.attributes = attributes or {}

        # The visuals are registered in this order: scatter, text, line.
        scatter = ScatterVisual()
        self.visual = scatter
        canvas.add_visual(scatter)

        text = TextVisual()
        self.text_visual = text
        canvas.add_visual(text)

        line = LineVisual()
        self.line_visual = line
        canvas.add_visual(line)
Ejemplo n.º 2
0
    def __init__(self,
                 correlograms=None,
                 firing_rate=None,
                 sample_rate=None,
                 **kwargs):
        """Create the correlogram view.

        Parameters
        ----------
        correlograms : function
            Function clusters => CCGs; stored as is.
        firing_rate : function
            Function clusters => firing rates (same unit as CCG); stored as is.
        sample_rate : float
            Sample rate; must be provided and strictly positive. Stored as a float.
        **kwargs
            Forwarded to the parent constructor.

        Raises
        ------
        AssertionError
            If the sample rate is missing or not strictly positive.
        """
        super(CorrelogramView, self).__init__(**kwargs)
        self.state_attrs += ('bin_size', 'window_size', 'refractory_period',
                             'uniform_normalization')
        self.local_state_attrs += ()
        self.canvas.set_layout(layout='grid')

        # Outside margin to show labels.
        self.canvas.gpu_transforms.add(Scale(.9))

        # `None > 0` raises a TypeError in Python 3, so a missing sample rate is checked
        # explicitly in order to fail with a clear assertion message instead.
        assert sample_rate is not None and sample_rate > 0, (
            "The sample rate must be provided to the correlogram view.")
        self.sample_rate = float(sample_rate)

        # Function clusters => CCGs.
        self.correlograms = correlograms

        # Function clusters => firing rates (same unit as CCG).
        self.firing_rate = firing_rate

        # Set the default bin and window size.
        self._set_bin_window(bin_size=self.bin_size,
                             window_size=self.window_size)

        # The visuals are registered in this order: histogram, line, text.
        self.correlogram_visual = HistogramVisual()
        self.canvas.add_visual(self.correlogram_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)
Ejemplo n.º 3
0
    def __init__(self, amplitudes=None, amplitude_name=None, duration=None):
        """Set up the amplitude view and its histogram, scatter and text visuals.

        `amplitudes` is either a dictionary with one entry per amplitude type, or a single
        object which is then stored under the `'amplitude'` key. `amplitude_name` is the
        current amplitude type (the first key when falsy). `duration` falls back to 1.
        """
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitude_name',)

        canvas = self.canvas
        canvas.enable_axes()
        canvas.enable_lasso()

        # A single amplitude is wrapped so that `amplitudes` is always a dictionary.
        amplitudes = amplitudes if isinstance(amplitudes, dict) else {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes
        self.amplitude_names = list(amplitudes.keys())

        # Pick the current amplitude type, defaulting to the first available one.
        if not amplitude_name:
            amplitude_name = self.amplitude_names[0]
        self.amplitude_name = amplitude_name
        assert self.amplitude_name in amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1

        # Histogram visual, with extra GPU transforms: a range depending on
        # `histogram_scale`, followed by a counter-clockwise rotation.
        histogram = HistogramVisual()
        self.hist_visual = histogram
        hist_bounds = (-1, -1, 1, -1 + 2 * self.histogram_scale)
        histogram.transforms.add_on_gpu([Range(NDC, hist_bounds), Rotate('ccw')])
        canvas.add_visual(histogram)

        # Scatter plot.
        scatter = ScatterVisual()
        self.visual = scatter
        canvas.add_visual(scatter)

        # Amplitude name; the pan-zoom interact is excluded for this visual.
        label = TextVisual()
        self.text_visual = label
        canvas.add_visual(label, exclude_origins=(canvas.panzoom,))
Ejemplo n.º 4
0
    def __init__(self,
                 waveforms=None,
                 waveforms_type=None,
                 sample_rate=None,
                 **kwargs):
        """Create the waveform view.

        Parameters
        ----------
        waveforms : dict or object
            Either a `{name: value}` dictionary with one entry per waveforms type, or a single
            value which is then stored under the `'waveforms'` name. A falsy value yields an
            empty dictionary.
        waveforms_type : str
            Waveforms type selected initially; passed to the rotating property's `set()`.
        sample_rate : float
            Sample rate; must be provided and strictly positive.
        **kwargs
            Forwarded to the parent constructor.

        Raises
        ------
        AssertionError
            If the sample rate is missing or not strictly positive, or if the current
            waveforms type is not a key of `waveforms`.
        """
        # These attributes are set before the parent constructor is called.
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # `None > 0.` raises a TypeError in Python 3, which would hide the message below;
        # a missing sample rate is therefore checked explicitly.
        assert sample_rate is not None and sample_rate > 0., (
            "The sample rate must be provided to the waveform view.")

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        # The visuals are registered in this order: text, line, ticks, agg plot, plot.
        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(marker='vbar',
                                                color=self.ax_color,
                                                size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)
Ejemplo n.º 5
0
    def __init__(
            self, positions=None, best_channels=None, channel_labels=None,
            dead_channels=None, **kwargs):
        """Create the probe view.

        Parameters
        ----------
        positions : array
            2D array with two columns (asserted below), one row per channel.
        best_channels : object
            Stored as is in `self.best_channels`.
        channel_labels : sequence
            One label per channel. When missing or empty, the channel indices converted to
            strings are used. May be a list or a NumPy array.
        dead_channels : sequence
            Indices of the dead channels, used to index the rows of the color array.
            Empty by default.
        **kwargs
            Forwarded to the parent constructor.
        """
        super(ProbeView, self).__init__(**kwargs)
        self.state_attrs += ('do_show_labels',)

        # Normalize positions.
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        positions = positions.astype(np.float32)
        self.positions, self.data_bounds = _get_pos_data_bounds(positions)

        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        # NOTE: `channel_labels or [...]` raises a ValueError with a NumPy array of several
        # labels (ambiguous truth value), hence the explicit None/empty test.
        if channel_labels is None or len(channel_labels) == 0:
            channel_labels = [str(ch) for ch in range(self.n_channels)]
        self.channel_labels = channel_labels
        self.dead_channels = dead_channels if dead_channels is not None else ()

        self.probe_visual = ScatterVisual()
        self.canvas.add_visual(self.probe_visual)

        # Probe visual: gray, opaque markers.
        color = np.ones((self.n_channels, 4))
        color[:, :3] = .5
        # Change alpha value for dead channels.
        if len(self.dead_channels):
            color[self.dead_channels, 3] = self.dead_channel_alpha
        self.probe_visual.set_data(
            pos=self.positions, data_bounds=self.data_bounds,
            color=color, size=self.unselected_marker_size)

        # Cluster visual.
        self.cluster_visual = ScatterVisual()
        self.canvas.add_visual(self.cluster_visual)

        # Text visual: the same color array is reused, reset to white, with the RGB
        # components of the dead channels set to twice the dead channel alpha.
        color[:] = 1
        color[self.dead_channels, :3] = self.dead_channel_alpha * 2
        self.text_visual = TextVisual()
        # Shader snippets: `v_discard` is computed from the number of channels, the zoom
        # level and the string index; fragments with a positive value are discarded.
        self.text_visual.inserter.insert_vert('uniform float n_channels;', 'header')
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_channels >= 200 * u_zoom.y) && '
            '(mod(int(a_string_index), int(n_channels / (200 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)
        self.text_visual.set_data(
            pos=self.positions, text=self.channel_labels, anchor=[0, -1],
            data_bounds=self.data_bounds, color=color
        )
        self.text_visual.program['n_channels'] = self.n_channels
        self.canvas.update()
Ejemplo n.º 6
0
    def __init__(self, cluster_stat=None):
        """Set up the stacked layout, the axes, and the histogram/plot/text visuals.

        `cluster_stat` is stored as is in `self.cluster_stat`.
        """
        super(HistogramView, self).__init__()
        self.state_attrs += self._state_attrs
        self.local_state_attrs += self._local_state_attrs

        canvas = self.canvas
        canvas.set_layout(layout='stacked', n_plots=1)
        canvas.enable_axes()

        self.cluster_stat = cluster_stat

        # The visuals are registered in this order: histogram, plot, text.
        histogram = HistogramVisual()
        self.visual = histogram
        canvas.add_visual(histogram)

        curve = PlotVisual()
        self.plot_visual = curve
        canvas.add_visual(curve)

        white = (1., 1., 1., 1.)
        label = TextVisual(color=white)
        self.text_visual = label
        canvas.add_visual(label)
Ejemplo n.º 7
0
    def _create_visuals(self):
        """Configure the stacked layout and register the trace, waveform and text visuals."""
        canvas = self.canvas
        canvas.set_layout('stacked', n_plots=self.n_channels)
        canvas.enable_axes(show_y=False)

        # The visuals are registered in this order: traces, waveforms, text.
        traces = UniformPlotVisual()
        self.trace_visual = traces
        canvas.add_visual(traces)

        waveforms = PlotVisual()
        self.waveform_visual = waveforms
        canvas.add_visual(waveforms)

        labels = TextVisual()
        self.text_visual = labels
        _fix_coordinate_in_visual(labels, 'x')
        # GLSL expression of the `v_discard` varying, computed from the number of boxes,
        # the zoom level and the box index; fragments with a positive value are discarded.
        discard_expr = (
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        labels.inserter.add_varying('float', 'v_discard', discard_expr)
        labels.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        canvas.add_visual(labels)
Ejemplo n.º 8
0
    def __init__(self,
                 cluster_ids=None,
                 cluster_info=None,
                 bindings=None,
                 **kwargs):
        """Create the cluster scatter view.

        Parameters
        ----------
        cluster_ids : sequence
            Full list of clusters, stored in `self.all_cluster_ids`.
        cluster_info : object
            Stored as is in `self.cluster_info`.
        bindings : dict
            Optional `{name: value}` dictionary. Every entry whose name belongs to
            `self._dims` is set as an instance attribute (x_axis, y_axis, size).
        **kwargs
            Forwarded to the parent constructor.
        """
        super(ClusterScatterView, self).__init__(**kwargs)
        self.state_attrs += (
            'scaling',
            'x_axis',
            'y_axis',
            'size',
            'x_axis_log_scale',
            'y_axis_log_scale',
            'size_log_scale',
        )
        self.local_state_attrs += ()

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        bindings = bindings or {}
        self.cluster_info = cluster_info
        # Update self.x_axis, y_axis, size. A dictionary is used rather than a set of
        # `(name, value)` pairs, because a set would require every value to be hashable.
        self.__dict__.update(
            {name: value for name, value in bindings.items() if name in self._dims})

        # Size range computed initially so that it doesn't change during the course of the session.
        self._size_min = self._size_max = None

        # Full list of clusters.
        self.all_cluster_ids = cluster_ids

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # The pan-zoom interact is excluded for the label visual.
        self.label_visual = TextVisual()
        self.canvas.add_visual(self.label_visual,
                               exclude_origins=(self.canvas.panzoom, ))

        self.marker_positions = self.marker_colors = self.marker_sizes = None
Ejemplo n.º 9
0
    def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
        """Set up the waveform view: boxed layout, axes, and the text and plot visuals.

        `waveforms` is either a dictionary with one entry per waveforms type, or a single
        object which is then stored under the `'waveforms'` key (that key must be present
        in any case). `waveforms_type` is the current type (the first key when falsy).
        Other keyword arguments go to the parent constructor.
        """
        # These attributes are set before the parent constructor is called.
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()

        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Boxed layout with a single box covering the whole view.
        canvas = self.canvas
        canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
        canvas.enable_axes()

        # Box and probe scaling.
        self._box_scaling = (1., 1.)
        self._probe_scaling = (1., 1.)

        # Copies of the box positions and sizes of the layout.
        boxed = canvas.boxed
        self.box_pos = np.array(boxed.box_pos)
        self.box_size = np.array(boxed.box_size)
        self._update_boxes()

        # A single waveforms type is wrapped so that `waveforms` is always a dictionary.
        if not isinstance(waveforms, dict):
            waveforms = {'waveforms': waveforms}
        assert waveforms
        self.waveforms = waveforms
        self.waveforms_types = list(waveforms.keys())

        # Pick the current waveforms type, defaulting to the first available one.
        if not waveforms_type:
            waveforms_type = self.waveforms_types[0]
        self.waveforms_type = waveforms_type
        assert self.waveforms_type in waveforms
        assert 'waveforms' in waveforms

        labels = TextVisual()
        self.text_visual = labels
        canvas.add_visual(labels)

        plot = PlotVisual()
        self.waveform_visual = plot
        canvas.add_visual(plot)
Ejemplo n.º 10
0
Archivo: trace.py Proyecto: zsong30/phy
    def _create_visuals(self):
        """Configure the stacked layout and register the trace, waveform and text visuals.

        When both `trace_color_0` and `trace_color_1` are set, a fragment shader snippet
        mixing the two colors according to the signal index is added to the trace visual.
        """
        canvas = self.canvas
        canvas.set_layout('stacked', n_plots=self.n_channels)
        canvas.enable_axes(show_y=False)

        traces = UniformPlotVisual()
        self.trace_visual = traces
        # Gradient of color for the traces.
        if self.trace_color_0 and self.trace_color_1:
            gradient_args = (self.trace_color_0, self.trace_color_1, self.n_channels)
            gradient_glsl = (
                'gl_FragColor.rgb = mix(vec3%s, vec3%s, (v_signal_index / %d));' % gradient_args)
            traces.inserter.insert_frag(gradient_glsl, 'end')
        canvas.add_visual(traces)

        waveforms = PlotVisual()
        self.waveform_visual = waveforms
        canvas.add_visual(waveforms)

        labels = TextVisual()
        self.text_visual = labels
        _fix_coordinate_in_visual(labels, 'x')
        # GLSL expression of the `v_discard` varying, computed from the number of boxes,
        # the zoom level and the box index; fragments with a positive value are discarded.
        discard_expr = (
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        labels.inserter.add_varying('float', 'v_discard', discard_expr)
        labels.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        canvas.add_visual(labels)
Ejemplo n.º 11
0
    def __init__(self, model=None):
        """Set up the view: axes, lasso, text and plot visuals, and default pan and zoom.

        `model` is stored as is in `self.model`.
        """
        super(WaveformClusteringView, self).__init__()
        canvas = self.canvas
        canvas.enable_axes()
        canvas.enable_lasso()

        # The pan-zoom interact is excluded for the text visual.
        labels = TextVisual()
        self.text_visual = labels
        canvas.add_visual(labels, exclude_origins=(canvas.panzoom,))

        self.model = model

        self.gain = 0.195
        self.Fs = 30  # kHz

        plot = PlotVisual()
        self.visual = plot
        canvas.add_visual(plot)

        # The current zoom and pan are set first, then their default values.
        panzoom = canvas.panzoom
        initial_zoom = (.97, .95)
        panzoom.zoom = initial_zoom
        panzoom._default_zoom = initial_zoom
        initial_pan = (-.01, 0)
        panzoom.pan = initial_pan
        panzoom._default_pan = initial_pan

        # Nothing is selected or loaded yet.
        self.cluster_ids = None
        self.cluster_id = None
        self.channel_ids = None
        self.wavefs = None
        self.current_channel_idx = None
Ejemplo n.º 12
0
class AmplitudeView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays an amplitude plot for all selected clusters.

    Constructor
    -----------

    amplitudes : function or dict
        Either a function, or a `{name: function}` dictionary with one function per amplitude
        type. Each function is called with `(cluster_ids, load_all=load_all)` and returns a
        list `[Bunch(amplitudes, spike_ids, spike_times), ...]` with one item per cluster.
        Use `cluster_id=None` for background amplitudes.
    amplitude_name : str
        Name of the amplitude type shown initially (the first one by default).
    duration : float
        Upper bound of the x axis in the data bounds (1 by default).

    """

    _default_position = 'right'

    # Alpha channel of the markers in the scatter plot.
    marker_alpha = 1.

    # Number of bins in the histogram.
    n_bins = 100

    # Alpha channel of the histogram in the background.
    histogram_alpha = .5

    # Quantile used for scaling of the amplitudes (less than 1 to avoid outliers).
    quantile = .99

    # Size of the histogram, between 0 and 1.
    histogram_scale = .25

    default_shortcuts = {
        'change_marker_size': 'ctrl+wheel',
        'next_amplitude_type': 'a',
        'previous_amplitude_type': 'shift+a',
        'select_x_dim': 'alt+left click',
        'select_y_dim': 'alt+right click',
    }

    def __init__(self, amplitudes=None, amplitude_name=None, duration=None):
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitude_name',)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()
        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes
        self.amplitude_names = list(amplitudes.keys())
        # Current amplitude type.
        self.amplitude_name = amplitude_name or self.amplitude_names[0]
        assert self.amplitude_name in amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1

        # Histogram visual, with extra GPU transforms: a range depending on
        # `histogram_scale`, followed by a counter-clockwise rotation.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add_on_gpu([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)), Rotate('ccw')])
        self.canvas.add_visual(self.hist_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # Amplitude name; the pan-zoom interact is excluded for this visual.
        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))

    def _get_data_bounds(self, bunchs):
        """Compute the data bounds `(0, ymin, duration, ymax)` from the amplitude quantiles."""
        if not bunchs:  # pragma: no cover
            return (0, 0, self.duration, 1)
        m = min(np.quantile(bunch.amplitudes, 1 - self.quantile) for bunch in bunchs)
        m = min(0, m)  # ensure ymin <= 0
        M = max(np.quantile(bunch.amplitudes, self.quantile) for bunch in bunchs)
        return (0, m, self.duration, M)

    def _add_histograms(self, bunchs):
        """Add a `histogram` attribute to every bunch, and return the bunchs."""
        # We do this after get_clusters_data because we need x_max.
        for bunch in bunchs:
            bunch.histogram = _compute_histogram(
                bunch.amplitudes,
                x_min=self.data_bounds[1],
                x_max=self.data_bounds[3],
                n_bins=self.n_bins,
                normalize=False,
                ignore_zeros=True,
            )
        return bunchs

    def _plot_cluster(self, bunch):
        """Make the scatter plot."""
        ms = self._marker_size

        # Histogram in the background.
        self.hist_visual.add_batch_data(
            hist=bunch.histogram,
            ylim=self._ylim,
            color=add_alpha(bunch.color, self.histogram_alpha))

        # Scatter plot.
        self.visual.add_batch_data(
            pos=bunch.pos, color=bunch.color, size=ms, data_bounds=self.data_bounds)

    def _plot_amplitude_name(self):
        """Show the amplitude name."""
        self.text_visual.add_batch_data(pos=[0, 1], anchor=[0, -1], text=self.amplitude_name)

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids.

        Every bunch also gets the `cluster_id` and `color` attributes. Nothing is returned
        when no cluster is selected. When `load_all` is falsy, a background entry with
        `cluster_id=None` is requested in front of the selected clusters.
        """
        if not len(self.cluster_ids):
            return
        cluster_ids = list(self.cluster_ids)
        # Number of background entries put in front of the selected clusters.
        n_background = 0
        # Don't need the background when splitting.
        if not load_all:
            # Add None cluster which means background spikes.
            cluster_ids = [None] + cluster_ids
            n_background = 1
        bunchs = self.amplitudes[self.amplitude_name](cluster_ids, load_all=load_all) or ()
        # Add a pos attribute in bunchs in addition to x and y.
        for i, (cluster_id, bunch) in enumerate(zip(cluster_ids, bunchs)):
            spike_ids = _as_array(bunch.spike_ids)
            spike_times = _as_array(bunch.spike_times)
            amplitudes = _as_array(bunch.amplitudes)
            assert spike_ids.shape == spike_times.shape == amplitudes.shape
            # Ensure that bunch.pos exists, as it used by the LassoMixin.
            bunch.pos = np.c_[spike_times, amplitudes]
            assert bunch.pos.ndim == 2
            bunch.cluster_id = cluster_id
            # The color index is the position of the cluster in the selection. The background
            # entry is only subtracted when it was actually prepended: a fixed `i - 1` would
            # give the index -1 to the first cluster when `load_all` is set.
            bunch.color = (
                selected_cluster_color(i - n_background, self.marker_alpha)
                # Background amplitude color.
                if cluster_id is not None else (.5, .5, .5, .5))
        return bunchs

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data()
        if not bunchs:
            return
        self.data_bounds = self._get_data_bounds(bunchs)
        bunchs = self._add_histograms(bunchs)
        # Use the same scale for all histograms (there is at least one bunch at this point).
        self._ylim = max(bunch.histogram.max() for bunch in bunchs)

        self.visual.reset_batch()
        self.hist_visual.reset_batch()
        self.text_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self._plot_amplitude_name()
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.hist_visual)
        self.canvas.update_visual(self.text_visual)

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(AmplitudeView, self).attach(gui)
        self.actions.add(self.next_amplitude_type, set_busy=True)
        self.actions.add(self.previous_amplitude_type, set_busy=True)

    def _change_amplitude_type(self, dir=+1):
        """Switch to the amplitude type `dir` steps away (with wrap-around), and replot."""
        i = self.amplitude_names.index(self.amplitude_name)
        n = len(self.amplitude_names)
        self.amplitude_name = self.amplitude_names[(i + dir) % n]
        logger.debug("Switch to amplitude type: %s.", self.amplitude_name)
        self.plot()

    def next_amplitude_type(self):
        """Switch to the next amplitude type."""
        self._change_amplitude_type(+1)

    def previous_amplitude_type(self):
        """Switch to the previous amplitude type."""
        self._change_amplitude_type(-1)
Ejemplo n.º 13
0
class HistogramView(ScalingMixin, ManualClusteringView):
    """This view displays a histogram for every selected cluster, along with a possible plot
    and some text. To be overridden.

    Constructor
    -----------

    cluster_stat : function
        Maps `cluster_id` to `Bunch(data (1D array), plot (1D array), text)`.
        The bunch may also provide `x_min` and `x_max`, used when the view has none yet.

    """

    _default_position = 'right'
    cluster_ids = ()

    # Number of bins in the histogram.
    n_bins = 100

    # Minimum value on the x axis (determines the range of the histogram)
    # If None, then `data.min()` is used.
    x_min = None

    # Maximum value on the x axis (determines the range of the histogram)
    # If None, then `data.max()` is used.
    x_max = None

    # Unit of the bin in the set_bin_size, set_x_min, set_x_max actions.
    bin_unit = 's'  # s (seconds) or ms (milliseconds)

    # The snippet to update this view are `hn` to change the number of bins, and `hm` to
    # change the maximum value on the x axis. The character `h` can be customized by child classes.
    alias_char = 'h'

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
    }

    default_snippets = {
        'set_n_bins': '%sn' % alias_char,
        'set_bin_size (%s)' % bin_unit: '%sb' % alias_char,
        'set_x_min (%s)' % bin_unit: '%smin' % alias_char,
        'set_x_max (%s)' % bin_unit: '%smax' % alias_char,
    }

    _state_attrs = ('n_bins', 'x_min', 'x_max')
    _local_state_attrs = ()

    def __init__(self, cluster_stat=None):
        super(HistogramView, self).__init__()
        self.state_attrs += self._state_attrs
        self.local_state_attrs += self._local_state_attrs
        self.canvas.set_layout(layout='stacked', n_plots=1)
        self.canvas.enable_axes()

        self.cluster_stat = cluster_stat

        self.visual = HistogramVisual()
        self.canvas.add_visual(self.visual)

        self.plot_visual = PlotVisual()
        self.canvas.add_visual(self.plot_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)

    def _plot_cluster(self, bunch):
        """Add the histogram, and the optional plot and text, of a cluster to the batches."""
        assert bunch
        n_bins = self.n_bins
        assert n_bins >= 0

        # Update the visual's data.
        self.visual.add_batch_data(
            hist=bunch.histogram, ylim=bunch.ylim, color=bunch.color, box_index=bunch.index)

        # Plot.
        plot = bunch.get('plot', None)
        if plot is not None:
            x = np.linspace(self.x_min, self.x_max, len(plot))
            self.plot_visual.add_batch_data(
                x=x, y=plot, color=(1, 1, 1, 1), data_bounds=self.data_bounds,
                box_index=bunch.index,
            )

        text = bunch.get('text', None)
        if not text:
            return
        # Support multiline text.
        text = text.splitlines()
        n = len(text)
        self.text_visual.add_batch_data(
            text=text, pos=[(-1, .8)] * n, anchor=[(1, -1 - 2 * i) for i in range(n)],
            box_index=bunch.index,
        )

    def get_clusters_data(self, load_all=None):
        """Return the list of bunchs of the selected clusters, with their histograms.

        Clusters without data are skipped. `self.x_min` and `self.x_max` are set from the
        first bunch with data if they were not set before. `load_all` is not used here.
        """
        bunchs = []
        for i, cluster_id in enumerate(self.cluster_ids):
            bunch = self.cluster_stat(cluster_id)
            if not bunch.data.size:
                continue
            bmin, bmax = bunch.data.min(), bunch.data.max()
            # Update self.x_min and self.x_max if they were not set before. The tests are
            # explicit `is None` checks, because 0 is a valid bound that must be kept.
            if self.x_min is None:
                x_min = bunch.get('x_min', None)
                self.x_min = x_min if x_min is not None else bmin
            if self.x_max is None:
                x_max = bunch.get('x_max', None)
                self.x_max = x_max if x_max is not None else bmax
            self.x_min = min(self.x_min, self.x_max)
            assert self.x_min is not None
            assert self.x_max is not None
            assert self.x_min <= self.x_max

            # Compute the histogram.
            bunch.histogram = _compute_histogram(
                bunch.data, x_min=self.x_min, x_max=self.x_max, n_bins=self.n_bins)
            bunch.ylim = bunch.histogram.max()

            bunch.color = selected_cluster_color(i)
            bunch.index = i
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        return bunchs

    def _get_data_bounds(self, bunchs):
        """Return the data bounds `(x_min, 0, x_max, ymax)` of the axes."""
        # Get the axes data bounds (the last subplot's extended n_cluster times on the y axis).
        ylim = max(bunch.ylim for bunch in bunchs) if bunchs else 1
        return (self.x_min, 0, self.x_max, ylim * len(self.cluster_ids))

    def plot(self, **kwargs):
        """Update the view with the selected clusters."""
        bunchs = self.get_clusters_data()
        self.data_bounds = self._get_data_bounds(bunchs)

        self.canvas.stacked.n_boxes = len(self.cluster_ids)

        self.visual.reset_batch()
        self.plot_visual.reset_batch()
        self.text_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.plot_visual)
        self.canvas.update_visual(self.text_visual)

        self._update_axes()
        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(HistogramView, self).attach(gui)

        # The prompt defaults are lambdas so that the current values are read lazily.
        self.actions.add(
            self.set_n_bins, alias=self.alias_char + 'n',
            prompt=True, prompt_default=lambda: self.n_bins)
        self.actions.add(
            self.set_bin_size, alias=self.alias_char + 'b',
            prompt=True, prompt_default=lambda: self.bin_size)
        self.actions.add(
            self.set_x_min, alias=self.alias_char + 'min',
            prompt=True, prompt_default=lambda: self.x_min)
        self.actions.add(
            self.set_x_max, alias=self.alias_char + 'max',
            prompt=True, prompt_default=lambda: self.x_max)
        self.actions.separator()

    @property
    def status(self):
        """Return the x range as a string, expressed in `self.bin_unit`."""
        f = 1 if self.bin_unit == 's' else 1000
        return '[{:.1f}{u}, {:.1f}{u:s}]'.format(
            (self.x_min or 0) * f, (self.x_max or 0) * f, u=self.bin_unit)

    # Histogram parameters
    # -------------------------------------------------------------------------

    def _get_scaling_value(self):
        """Return the value changed by the scaling actions, i.e. `x_max`."""
        return self.x_max

    def _set_scaling_value(self, value):
        """Set `x_max` from a scaling value."""
        # `set_x_max()` expects a value in `bin_unit`, and divides it by 1000 for `ms`.
        if self.bin_unit == 'ms':
            value *= 1000
        self.set_x_max(value)

    def set_n_bins(self, n_bins):
        """Set the number of bins in the histogram."""
        self.n_bins = n_bins
        logger.debug("Change number of bins to %d for %s.", n_bins, self.__class__.__name__)
        self.plot()

    @property
    def bin_size(self):
        """Return the bin size (in seconds or milliseconds depending on `self.bin_unit`)."""
        bs = (self.x_max - self.x_min) / self.n_bins
        if self.bin_unit == 'ms':
            bs *= 1000
        return bs

    def set_bin_size(self, bin_size):
        """Set the bin size in the histogram."""
        assert bin_size > 0
        if self.bin_unit == 'ms':
            bin_size /= 1000
        # `np.round()` returns a float, whereas a number of bins must be an integer.
        self.n_bins = int(np.round((self.x_max - self.x_min) / bin_size))
        logger.debug("Change number of bins to %d for %s.", self.n_bins, self.__class__.__name__)
        self.plot()

    def set_x_min(self, x_min):
        """Set the minimum value on the x axis for the histogram."""
        if self.bin_unit == 'ms':
            x_min /= 1000
        x_min = min(x_min, self.x_max)
        # An empty range is refused.
        if x_min == self.x_max:
            return
        self.x_min = x_min
        logger.debug("Change x min to %s for %s.", x_min, self.__class__.__name__)
        self.plot()

    def set_x_max(self, x_max):
        """Set the maximum value on the x axis for the histogram."""
        if self.bin_unit == 'ms':
            x_max /= 1000
        x_max = max(x_max, self.x_min)
        # An empty range is refused.
        if x_max == self.x_min:
            return
        self.x_max = x_max
        logger.debug("Change x max to %s for %s.", x_max, self.__class__.__name__)
        self.plot()
Ejemplo n.º 14
0
class CorrelogramView(ScalingMixin, ManualClusteringView):
    """Grid of correlograms for the selected clusters: autocorrelograms of every cluster and
    cross-correlograms of every pair of clusters.

    The subplot in box `(i, j)` shows the correlogram of the pair formed by the i-th and j-th
    selected clusters.

    Constructor
    -----------

    correlograms : function
        Called as `correlograms(cluster_ids, bin_size, window_size)`. Must return a 3D array
        indexed as `[i, j, bin]`.

    firing_rate : function
        Optional. Called as `firing_rate(cluster_ids, bin_size)`. Must return an array indexed
        as `[i, j]`; each value is drawn as a horizontal line in the matching subplot.

    sample_rate : float
        Sampling rate; must be strictly positive.

    """

    # Upper limit on the number of clusters shown at once.
    max_n_clusters = 20

    _default_position = 'left'
    cluster_ids = ()

    # Default bin size, in seconds.
    bin_size = 1e-3

    # Default window size, in seconds.
    window_size = 50e-3

    # Default refractory period, in seconds.
    refractory_period = 2e-3

    # When True, all subplots share the same vertical scale (the global maximum);
    # when False, every subplot is scaled to its own maximum.
    uniform_normalization = False

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
        'change_bin_size': 'alt+wheel',
    }

    default_snippets = {
        'set_bin': 'cb',
        'set_window': 'cw',
        'set_refractory_period': 'cr',
    }

    def __init__(self, correlograms=None, firing_rate=None, sample_rate=None, **kwargs):
        super().__init__(**kwargs)
        self.state_attrs += (
            'bin_size', 'window_size', 'refractory_period', 'uniform_normalization')
        self.local_state_attrs += ()
        self.canvas.set_layout(layout='grid')

        # Shrink the plot slightly so that the labels fit in the outer margin.
        self.canvas.gpu_transforms.add(Scale(.9))

        assert sample_rate > 0
        self.sample_rate = float(sample_rate)

        # Callable returning the correlograms of a set of clusters.
        self.correlograms = correlograms

        # Callable returning the firing rates of a set of clusters (same unit as the CCGs).
        self.firing_rate = firing_rate

        # Validate and apply the default bin and window sizes.
        self._set_bin_window(bin_size=self.bin_size, window_size=self.window_size)

        # One visual for the histograms, one for the overlaid lines, one for the labels.
        self.correlogram_visual = HistogramVisual()
        self.canvas.add_visual(self.correlogram_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)

    # -------------------------------------------------------------------------
    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self, n_clusters):
        """Yield the `(row, column)` index of every box of the square grid, row by row."""
        yield from ((row, col) for row in range(n_clusters) for col in range(n_clusters))

    def get_clusters_data(self, load_all=None):
        """Return one Bunch per subplot with the data needed to draw it.

        Every Bunch has the fields `correlogram`, `firing_rate` (None when no firing rate
        function was given), `data_bounds`, `pair_index` and `color`.

        """
        cluster_ids = self.cluster_ids
        ccg = self.correlograms(cluster_ids, self.bin_size, self.window_size)
        if self.firing_rate:
            rates = self.firing_rate(cluster_ids, self.bin_size)
        else:
            rates = None
        assert ccg.ndim == 3
        n_bins = ccg.shape[2]
        global_max = ccg.max()

        subplots = []
        for i, j in self._iter_subplots(len(cluster_ids)):
            hist = ccg[i, j, :]
            # Vertical scale: shared by all subplots, or specific to this subplot.
            y_max = global_max if self.uniform_normalization else hist.max()

            b = Bunch()
            b.correlogram = hist
            b.firing_rate = None if rates is None else rates[i, j]
            b.data_bounds = (0, 0, n_bins, y_max)
            b.pair_index = i, j
            b.color = selected_cluster_color(i, 1)
            if i != j:
                # Off-diagonal subplots get a desaturated, bright variant of the row color.
                b.color = add_alpha(_override_hsv(b.color[:3], s=.1, v=1))
            subplots.append(b)
        return subplots

    def _plot_pair(self, bunch):
        """Queue the histogram and the overlaid lines of one subplot in the batch visuals."""
        bounds = bunch.data_bounds
        box = bunch.pair_index
        gray = (.25, .25, .25, 1.)

        # Histogram.
        self.correlogram_visual.add_batch_data(
            hist=bunch.correlogram, color=bunch.color, ylim=bounds[3], box_index=box)

        # Horizontal line at the firing rate, spanning the whole x range.
        rate = bunch.firing_rate
        if rate is not None:
            rate_line = np.array([[0, rate, bounds[2], rate]])
            self.line_visual.add_batch_data(
                pos=rate_line, color=gray, data_bounds=bounds, box_index=box)
            # # Text.
            # self.text_visual.add_batch_data(
            #     pos=[bunch.data_bounds[2], bunch.firing_rate],
            #     text='%.2f' % bunch.firing_rate,
            #     anchor=(-1, 0),
            #     box_index=bunch.pair_index,
            #     data_bounds=bunch.data_bounds,
            # )

        # Two vertical lines delimiting the refractory period around the window center,
        # with x positions expressed in bins.
        half_window = self.window_size * .5
        xrp0 = round((half_window - self.refractory_period) / self.bin_size)
        xrp1 = round((half_window + self.refractory_period) / self.bin_size) + 1
        ylim = bounds[3]
        refractory_lines = np.array([[xrp0, 0, xrp0, ylim], [xrp1, 0, xrp1, ylim]])
        self.line_visual.add_batch_data(
            pos=refractory_lines, color=gray, data_bounds=bounds, box_index=box)

    def _plot_labels(self):
        """Queue the cluster id labels: along the first column and along the last row."""
        last_row = len(self.cluster_ids) - 1

        for k, cluster_id in enumerate(self.cluster_ids):
            label = str(cluster_id)
            # Label of row k, next to the first column.
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=label,
                anchor=[-1.25, 0],
                data_bounds=None,
                box_index=(k, 0),
            )
            # Label of column k, next to the last row.
            self.text_visual.add_batch_data(
                pos=[0, -1],
                text=label,
                anchor=[0, -1.25],
                data_bounds=None,
                box_index=(last_row, k),
            )

        # # Display the window size in the bottom right subplot.
        # self.text_visual.add_batch_data(
        #     pos=[1, -1],
        #     anchor=[1.25, 1],
        #     text='%.1f ms' % (1000 * .5 * self.window_size),
        #     box_index=(n - 1, n - 1),
        # )

    def plot(self, **kwargs):
        """Redraw the whole grid for the current cluster selection."""
        n_clusters = len(self.cluster_ids)
        self.canvas.grid.shape = (n_clusters, n_clusters)

        subplots = self.get_clusters_data()

        batch_visuals = (self.correlogram_visual, self.line_visual, self.text_visual)
        for visual in batch_visuals:
            visual.reset_batch()

        for bunch in subplots:
            self._plot_pair(bunch)
        self._plot_labels()

        for visual in batch_visuals:
            self.canvas.update_visual(visual)

        self.canvas.update()

    # -------------------------------------------------------------------------
    # Public methods
    # -------------------------------------------------------------------------

    def toggle_normalization(self, checked):
        """Switch between a shared and a per-subplot vertical scale, then redraw."""
        self.uniform_normalization = checked
        self.plot()

    def toggle_labels(self, checked):
        """Show (`checked` is true) or hide all text labels."""
        switch = self.text_visual.show if checked else self.text_visual.hide
        switch()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI and register the view's actions."""
        super().attach(gui)

        self.actions.add(self.toggle_normalization, shortcut='n', checkable=True)
        self.actions.add(self.toggle_labels, checkable=True, checked=True)
        self.actions.separator()

        def _default_in_ms(attr_name):
            # The attribute is read lazily, when the default is requested, and converted
            # from seconds to milliseconds.
            return lambda: getattr(self, attr_name) * 1000

        prompt_actions = (
            (self.set_bin, 'bin_size'),
            (self.set_window, 'window_size'),
            (self.set_refractory_period, 'refractory_period'),
        )
        for method, attr_name in prompt_actions:
            self.actions.add(method, prompt=True, prompt_default=_default_in_ms(attr_name))
        self.actions.separator()

    # -------------------------------------------------------------------------
    # Methods for changing the parameters
    # -------------------------------------------------------------------------

    def _set_bin_window(self, bin_size=None, window_size=None):
        """Validate and store the bin and window sizes (both in seconds).

        A missing (or zero) argument falls back to the current value. Both values are
        clipped to `[1e-6, 1e3]`, and the bin must be strictly smaller than the window.

        """
        bin_size = _clip(bin_size or self.bin_size, 1e-6, 1e3)
        window_size = _clip(window_size or self.window_size, 1e-6, 1e3)
        assert 1e-6 <= bin_size <= 1e3
        assert 1e-6 <= window_size <= 1e3
        assert bin_size < window_size
        self.bin_size = bin_size
        self.window_size = window_size
        self.update_status()

    @property
    def status(self):
        """Status string showing the window size followed by the bin size, in milliseconds."""
        window_ms = self.window_size * 1000
        bin_ms = self.bin_size * 1000
        return '{:.1f} ms ({:.1f} ms)'.format(window_ms, bin_ms)

    def set_refractory_period(self, value):
        """Set the refractory period (in milliseconds), clipped to `[0.1, 100]` ms."""
        clipped_ms = _clip(value, .1, 100)
        self.refractory_period = clipped_ms * 1e-3
        self.plot()

    def set_bin(self, bin_size):
        """Set the correlogram bin size (in milliseconds).

        Example: `1`

        """
        bin_size_s = bin_size * 1e-3
        self._set_bin_window(bin_size=bin_size_s)
        self.plot()

    def set_window(self, window_size):
        """Set the correlogram window size (in milliseconds).

        Example: `100`

        """
        window_size_s = window_size * 1e-3
        self._set_bin_window(window_size=window_size_s)
        self.plot()

    def increase(self):
        """Make the window 10% larger."""
        current_ms = 1000 * self.window_size
        self.set_window(current_ms * 1.1)

    def decrease(self):
        """Make the window smaller by a factor of 1.1."""
        current_ms = 1000 * self.window_size
        self.set_window(current_ms / 1.1)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Change the bin size with Alt+wheel, on top of the parent class wheel handling."""
        super().on_mouse_wheel(e)
        if e.modifiers != ('Alt', ):
            return
        self._set_bin_window(bin_size=self.bin_size * 1.1**e.delta)
        self.plot()
Ejemplo n.º 15
0
class ClusterScatterView(MarkerSizeMixin, BaseColorView, BaseGlobalView,
                         ManualClusteringView):
    """This view shows all clusters in a customizable scatter plot.

    Every cluster is one marker. The x position, the y position and the marker size are each
    bound to one field of the cluster information.

    Constructor
    -----------

    cluster_ids : array-like
        Ids of all clusters to show.
    cluster_info: function
        Maps cluster_id => Bunch() with attributes.
    bindings: dict
        Maps plot dimension (`x_axis`, `y_axis`, `size`) to the name of a cluster attribute.

    """
    _default_position = 'right'
    _scaling = 1.
    _default_alpha = .75
    # Range of the marker sizes obtained after normalization in `prepare_size()`.
    _min_marker_size = 5.0
    _max_marker_size = 30.0
    # Names of the plot dimensions that can be bound to a cluster field.
    _dims = ('x_axis', 'y_axis', 'size')

    # NOTE: this is not the actual marker size, but a scaling factor for the normal marker size.
    _marker_size = 1.
    _default_marker_size = 1.

    # Name of the cluster field bound to every dimension.
    x_axis = ''
    y_axis = ''
    size = ''
    # Whether every dimension is shown in logarithmic scale.
    x_axis_log_scale = False
    y_axis_log_scale = False
    size_log_scale = False

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'switch_color_scheme': 'shift+wheel',
        'select_cluster': 'click',
        'select_more': 'shift+click',
        'add_to_lasso': 'control+left click',
        'clear_lasso': 'control+right click',
    }

    default_snippets = {
        'set_x_axis': 'csx',
        'set_y_axis': 'csy',
        'set_size': 'css',
    }

    def __init__(self,
                 cluster_ids=None,
                 cluster_info=None,
                 bindings=None,
                 **kwargs):
        super(ClusterScatterView, self).__init__(**kwargs)
        self.state_attrs += (
            'scaling',
            'x_axis',
            'y_axis',
            'size',
            'x_axis_log_scale',
            'y_axis_log_scale',
            'size_log_scale',
        )
        self.local_state_attrs += ()

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        bindings = bindings or {}
        self.cluster_info = cluster_info
        # Update self.x_axis, y_axis, size from the bindings; unknown dimensions are ignored.
        # A dict (rather than a set of pairs) is used so that values need not be hashable.
        self.__dict__.update(
            {k: v for k, v in bindings.items() if k in self._dims})

        # Size range computed initially so that it doesn't change during the course of the session.
        self._size_min = self._size_max = None

        # Full list of clusters.
        self.all_cluster_ids = cluster_ids

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # The axis labels do not follow the pan and zoom of the canvas.
        self.label_visual = TextVisual()
        self.canvas.add_visual(self.label_visual,
                               exclude_origins=(self.canvas.panzoom, ))

        # Marker data, computed lazily by the `prepare_*()` methods.
        self.marker_positions = self.marker_colors = self.marker_sizes = None

    def _update_labels(self):
        """Show the names of the fields bound to the x and y axes."""
        self.label_visual.set_data(pos=[[-1, -1], [1, 1]],
                                   text=[self.x_axis, self.y_axis],
                                   anchor=[[1.25, 3], [-3, -1.25]])

    # Data access
    # -------------------------------------------------------------------------

    @property
    def bindings(self):
        """Dictionary mapping every plot dimension to the name of the bound cluster field."""
        return {k: getattr(self, k) for k in self._dims}

    def get_cluster_data(self, cluster_id):
        """Return the data of one cluster, as a dictionary `{dimension: value}`.

        The value is 0 when the cluster information has no field with the bound name.

        """
        data = self.cluster_info(cluster_id)
        return {k: data.get(v, 0.) for k, v in self.bindings.items()}

    def get_clusters_data(self, cluster_ids):
        """Return the data of a set of clusters, as a dictionary
        `{cluster_id: {dimension: value}}`."""
        return {
            cluster_id: self.get_cluster_data(cluster_id)
            for cluster_id in cluster_ids
        }

    def set_cluster_ids(self, all_cluster_ids):
        """Update the cluster data by specifying the list of all cluster ids.

        The marker positions, sizes and colors are recomputed, unless the list is empty.

        """
        self.all_cluster_ids = all_cluster_ids
        if len(all_cluster_ids) == 0:
            return
        self.prepare_position()
        self.prepare_size()
        self.prepare_color()

    # Data preparation
    # -------------------------------------------------------------------------

    def set_fields(self):
        """Set `self.fields` to the sorted names of the non-string cluster fields.

        The fields are taken from the information of the first cluster.

        """
        data = self.cluster_info(self.all_cluster_ids[0])
        self.fields = sorted(data.keys())
        self.fields = [f for f in self.fields if not isinstance(data[f], str)]

    def prepare_data(self):
        """Prepare the marker position, size, and color from the cluster information."""
        self.prepare_position()
        self.prepare_size()
        self.prepare_color()

    def prepare_position(self):
        """Compute the marker positions and the data bounds.

        This also refreshes `self.cluster_data` and `self.fields`.

        """
        self.cluster_data = self.get_clusters_data(self.all_cluster_ids)

        # Get the list of fields returned by cluster_info.
        self.set_fields()

        # Create the x array (missing values are replaced by 0).
        x = np.array([
            self.cluster_data[cluster_id]['x_axis'] or 0.
            for cluster_id in self.all_cluster_ids
        ])
        if self.x_axis_log_scale:
            # Shift the values so that the argument of the logarithm is at least 1.
            x = np.log(1.0 + x - x.min())

        # Create the y array (missing values are replaced by 0).
        y = np.array([
            self.cluster_data[cluster_id]['y_axis'] or 0.
            for cluster_id in self.all_cluster_ids
        ])
        if self.y_axis_log_scale:
            y = np.log(1.0 + y - y.min())

        self.marker_positions = np.c_[x, y]

        # Update the data bounds.
        self.data_bounds = (x.min(), y.min(), x.max(), y.max())

    def prepare_size(self):
        """Compute the marker sizes.

        This uses `self.cluster_data`, which is set by `prepare_position()`.

        """
        # Missing values are replaced by 1.
        size = np.array([
            self.cluster_data[cluster_id]['size'] or 1.
            for cluster_id in self.all_cluster_ids
        ])
        # Log scale for the size.
        if self.size_log_scale:
            size = np.log(1.0 + size - size.min())
        # Find the size range (only when it has not been computed yet, or has been reset).
        if self._size_min is None:
            self._size_min, self._size_max = size.min(), size.max()
        m, M = self._size_min, self._size_max
        # Normalize the marker size.
        size = (size - m) / ((M - m) or 1.0)  # size is in [0, 1]
        ms, Ms = self._min_marker_size, self._max_marker_size
        size = ms + size * (Ms - ms)  # now, size is in [ms, Ms]
        self.marker_sizes = size

    def prepare_color(self):
        """Compute the marker colors."""
        colors = self.get_cluster_colors(self.all_cluster_ids,
                                         self._default_alpha)
        self.marker_colors = colors

    # Marker size
    # -------------------------------------------------------------------------

    @property
    def marker_size(self):
        """Scaling factor applied to the marker sizes computed by `prepare_size()`."""
        return self._marker_size

    @marker_size.setter
    def marker_size(self, val):
        # We override this method so as to use self._marker_size as a scaling factor, not
        # as an actual fixed marker size.
        self._marker_size = val
        self._set_marker_size()
        self.canvas.update()

    def _set_marker_size(self):
        """Send the scaled marker sizes to the visual, if they have been computed."""
        if self.marker_sizes is not None:
            self.visual.set_marker_size(self.marker_sizes * self._marker_size)

    # Plotting functions
    # -------------------------------------------------------------------------

    def update_color(self):
        """Update the cluster colors depending on the current color scheme."""
        self.prepare_color()
        self.visual.set_color(self.marker_colors)
        self.canvas.update()

    def update_select_color(self):
        """Update the cluster colors after the cluster selection changes.

        Nothing happens if the colors have not been computed yet, or if no cluster is
        selected.

        """
        if self.marker_colors is None:
            return
        selected_clusters = self.cluster_ids
        if selected_clusters is not None and len(selected_clusters) > 0:
            # Work on a copy so that the base colors are left untouched.
            colors = _add_selected_clusters_colors(selected_clusters,
                                                   self.all_cluster_ids,
                                                   self.marker_colors.copy())
            self.visual.set_color(colors)
            self.canvas.update()

    def plot(self, **kwargs):
        """Make the scatter plot."""
        if self.marker_positions is None:
            self.prepare_data()
        self.visual.set_data(
            pos=self.marker_positions,
            color=self.marker_colors,
            size=self.marker_sizes *
            self._marker_size,  # marker size scaling factor
            data_bounds=self.data_bounds)
        self.canvas.axes.reset_data_bounds(self.data_bounds)
        self.canvas.update()

    def change_bindings(self, **kwargs):
        """Change the bindings.

        The keyword arguments map plot dimensions to field names. Bindings to a field that
        is not in `self.fields` are silently discarded.

        """
        # Ensure the specified fields are valid.
        kwargs = {k: v for k, v in kwargs.items() if v in self.fields}
        assert set(kwargs.keys()) <= set(self._dims)
        # Reset the size scaling.
        if 'size' in kwargs:
            self._size_min = self._size_max = None
        self.__dict__.update(kwargs)
        self._update_labels()
        self.update_status()
        self.prepare_data()
        self.plot()

    def toggle_log_scale(self, dim, checked):
        """Toggle logarithmic scaling for one of the dimensions."""
        # Force `prepare_size()` to recompute the size range.
        self._size_min = None
        setattr(self, '%s_log_scale' % dim, checked)
        self.prepare_data()
        self.plot()
        self.canvas.update()

    def set_x_axis(self, field):
        """Set the dimension for the x axis."""
        self.change_bindings(x_axis=field)

    def set_y_axis(self, field):
        """Set the dimension for the y axis."""
        self.change_bindings(y_axis=field)

    def set_size(self, field):
        """Set the dimension for the marker size."""
        self.change_bindings(size=field)

    # Misc functions
    # -------------------------------------------------------------------------

    def attach(self, gui):
        """Attach the view to the GUI: register the actions and connect the events."""
        super(ClusterScatterView, self).attach(gui)

        # The callbacks are created by factories so that every callback keeps its own
        # `dim` and `name`, instead of the last values of the loops below.
        def _make_action(dim, name):
            def callback():
                self.change_bindings(**{dim: name})

            return callback

        def _make_log_toggle(dim):
            def callback(checked):
                self.toggle_log_scale(dim, checked)

            return callback

        # Change the bindings.
        for dim in self._dims:
            view_submenu = 'Change %s' % dim

            # Change to every cluster info.
            for name in self.fields:
                self.actions.add(_make_action(dim, name),
                                 show_shortcut=False,
                                 name='Change %s to %s' % (dim, name),
                                 view_submenu=view_submenu)

            # Toggle logarithmic scale.
            self.actions.separator(view_submenu=view_submenu)
            self.actions.add(_make_log_toggle(dim),
                             checkable=True,
                             view_submenu=view_submenu,
                             name='Toggle log scale for %s' % dim,
                             show_shortcut=False,
                             checked=getattr(self, '%s_log_scale' % dim))

        self.actions.separator()
        self.actions.add(self.set_x_axis,
                         prompt=True,
                         prompt_default=lambda: self.x_axis)
        self.actions.add(self.set_y_axis,
                         prompt=True,
                         prompt_default=lambda: self.y_axis)
        self.actions.add(self.set_size,
                         prompt=True,
                         prompt_default=lambda: self.size)

        connect(self.on_select)
        connect(self.on_cluster)

        @connect(sender=self.canvas)
        def on_lasso_updated(sender, polygon):
            # Request the selection of all clusters whose marker is inside the lasso.
            if len(polygon) < 3:
                return
            pos = range_transform([self.data_bounds], [NDC],
                                  self.marker_positions)
            ind = self.canvas.lasso.in_polygon(pos)
            cluster_ids = self.all_cluster_ids[ind]
            emit("request_select", self, list(cluster_ids))

        @connect(sender=self)
        def on_close_view(view_, gui):
            """Unconnect all events when closing the view."""
            unconnect(self.on_select)
            unconnect(self.on_cluster)
            unconnect(on_lasso_updated)

        if self.all_cluster_ids is not None:
            self.set_cluster_ids(self.all_cluster_ids)
        self._update_labels()

    def on_select(self, *args, **kwargs):
        """Highlight the selected clusters after the parent class has handled the event."""
        super(ClusterScatterView, self).on_select(*args, **kwargs)
        self.update_select_color()

    def on_cluster(self, sender, up):
        """Recompute and redraw the markers when the update contains `all_cluster_ids`."""
        if 'all_cluster_ids' in up:
            self.set_cluster_ids(up.all_cluster_ids)
            self.plot()

    @property
    def status(self):
        """Status string showing the field bound to the size, and the color scheme."""
        return 'Size: %s. Color scheme: %s.' % (self.size, self.color_scheme)

    # Interactivity
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select the cluster whose marker is the closest to the click position.

        Clicks with the Control modifier are ignored here. With the Shift modifier, the
        cluster is added to the selection instead of replacing it.

        """
        if 'Control' in e.modifiers:
            return
        b = e.button
        pos = self.canvas.window_to_ndc(e.pos)
        marker_pos = range_transform([self.data_bounds], [NDC],
                                     self.marker_positions)
        # Index of the marker with the smallest squared distance to the click.
        cluster_rel = np.argmin(((marker_pos - pos)**2).sum(axis=1))
        cluster_id = self.all_cluster_ids[cluster_rel]
        logger.debug("Click on cluster %d with button %s.", cluster_id, b)
        if 'Shift' in e.modifiers:
            emit('select_more', self, [cluster_id])
        else:
            emit('request_select', self, [cluster_id])
Ejemplo n.º 16
0
        def on_view_attached(view, gui):
            """Add event markers (vertical lines and text labels) to an attached AmplitudeView.

            Views of any other type are ignored. The event positions are read from
            `eventmarkers.txt` and the optional event names from `eventmarkernames.txt`, both
            looked up in `controller.dir_path`. `self` and `controller` come from the
            enclosing scope.

            Two actions are added to the view: one toggling the markers, and one moving the
            trace view to a given event. They stay disabled if the event file cannot be read.

            """
            if isinstance(view, AmplitudeView):
                # Create batch of vertical lines (full height)
                self.line_visual = LineVisual()
                _fix_coordinate_in_visual(self.line_visual, 'y')
                view.canvas.add_visual(self.line_visual)

                # Create batch of annotative text
                self.text_visual = TextVisual(self.line_color)
                _fix_coordinate_in_visual(self.text_visual, 'y')
                self.text_visual.inserter.insert_vert(
                    'gl_Position.x += 0.001;', 'after_transforms')
                view.canvas.add_visual(self.text_visual)

                @view.actions.add(shortcut='alt+b',
                                  checkable=True,
                                  name='Toggle event markers')
                def toggle(on):
                    """Toggle event markers"""
                    # Use `show` and `hide` instead of `toggle` here, in case the state of
                    # the action and the visibility of the visuals are out of sync.
                    if on:
                        logger.debug('Toggle on markers.')
                        self.line_visual.show()
                        self.text_visual.show()
                        view.show_events = True
                    else:
                        logger.debug('Toggle off markers.')
                        self.line_visual.hide()
                        self.text_visual.hide()
                        view.show_events = False
                    view.canvas.update()

                @view.actions.add(shortcut='shift+alt+e',
                                  prompt=True,
                                  name='Go to event',
                                  alias='ge')
                def Go_to_event(event_num):
                    # `events` is assigned further below, once the event file has been
                    # read. Event numbers are 1-based; out-of-range numbers are ignored.
                    trace_view = gui.get_view(TraceView)
                    if 0 < event_num <= events.size:
                        trace_view.go_to(events[event_num - 1])

                # Disable the menu until events are successfully added
                view.actions.disable('Go to event')
                view.actions.disable('Toggle event markers')
                if not hasattr(view, 'show_events'):
                    view.show_events = True
                view.state_attrs += ('show_events', )

                # Read event markers from file
                filename = controller.dir_path / 'eventmarkers.txt'
                try:
                    # A file with a single event yields a 0-d array: always work with a
                    # 1D array so that the events can be indexed.
                    events = np.atleast_1d(
                        np.genfromtxt(filename, usecols=0, dtype=None))
                except (FileNotFoundError, OSError):
                    # NOTE: `Logger.warn()` is deprecated and has been removed in
                    # Python 3.13, hence `warning()`.
                    logger.warning('Event marker file not found: `%s`.', filename)
                    view.show_events = False
                    return

                # Create list of event names
                labels = list(map(str, range(1, events.size + 1)))

                # Read event names from file (if present)
                filename = controller.dir_path / 'eventmarkernames.txt'
                try:
                    eventnames = np.loadtxt(filename,
                                            usecols=0,
                                            dtype=str,
                                            max_rows=events.size)
                    labels[:eventnames.size] = np.atleast_1d(eventnames)
                except (FileNotFoundError, OSError):
                    logger.info(
                        'Event marker names file not found (optional):'
                        ' `%s`. Fall back to numbering.', filename)

                # Obtain seconds from samples: integer values (of any integer type) are
                # interpreted as sample indices.
                if np.issubdtype(events.dtype, np.integer):
                    logger.debug('Converting input from samples to seconds.')
                    events = events / controller.model.sample_rate

                logger.debug('Add event markers to amplitude view.')

                # Obtain horizontal positions: map [0, duration] to [-1, 1], then build one
                # vertical segment (x, 1, x, -1) per event.
                x = -1 + 2 * events / view.duration
                x = x.repeat(4, 0).reshape(-1, 4)
                x[:, 1::2] = 1, -1

                # Add lines and update view
                self.line_visual.reset_batch()
                self.line_visual.add_batch_data(pos=x, color=self.line_color)
                view.canvas.update_visual(self.line_visual)

                # Add text and update view (the labels are anchored at the top of the lines)
                self.text_visual.reset_batch()
                self.text_visual.add_batch_data(pos=x[:, :2],
                                                anchor=(1, -1),
                                                text=labels)
                view.canvas.update_visual(self.text_visual)

                # Finally enable the menu
                logger.debug('Enable menu items.')
                view.actions.enable('Go to event')
                view.actions.enable('Toggle event markers')
                if view.show_events:
                    view.actions.get('Toggle event markers').toggle()
                else:
                    self.line_visual.hide()
                    self.text_visual.hide()
Ejemplo n.º 17
0
class FeatureView(MarkerSizeMixin, ScalingMixin, ManualClusteringView):
    """This view displays a 4x4 subplot matrix with different projections of the principal
    component features. This view keeps track of which channels are currently shown.

    Constructor
    -----------

    features : function
        Maps `(cluster_id, channel_ids=None, load_all=False)` to
        `Bunch(data, channel_ids, channel_labels, spike_ids, masks)`.
        * `data` is an `(n_spikes, n_channels, n_features)` array
        * `channel_ids` contains the channel ids of every row in `data`
        * `channel_labels` contains the channel labels of every row in `data`
        * `spike_ids` is a `(n_spikes,)` array
        * `masks` is an `(n_spikes, n_channels)` array

        This allows for a sparse format.

    attributes : dict
        Maps an attribute name to a 1D array with `n_spikes` numbers (for example, spike times).

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    cluster_ids = ()

    # Whether to disable automatic selection of channels.
    fixed_channels = False
    feature_scaling = 1.

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'increase': 'ctrl++',
        'decrease': 'ctrl+-',
        'add_lasso_point': 'ctrl+click',
        'stop_lasso': 'ctrl+right click',
        'toggle_automatic_channel_selection': 'c',
    }

    def __init__(self, features=None, attributes=None, **kwargs):
        """Set up the grid layout, the lasso, and the three visuals of the view.

        Parameters
        ----------
        features : function
            Callable returning the feature data. It must be truthy.
        attributes : dict, optional
            Extra dimensions, stored in `self.attributes` (an empty dictionary if not given).
        **kwargs
            Forwarded to the parent constructor.

        """
        super(FeatureView, self).__init__(**kwargs)
        self.state_attrs += ('fixed_channels', 'feature_scaling')

        assert features
        self.features = features
        self._lim = 1

        # 2D array where every item is a string like `0A,1B`.
        self.grid_dim = _get_default_grid()
        grid_shape = np.array(self.grid_dim).shape
        self.n_rows, self.n_cols = grid_shape
        self.canvas.set_layout('grid', shape=(self.n_rows, self.n_cols))
        self.canvas.enable_lasso()

        # Channels being shown: not selected yet.
        self.channel_ids = None

        # Attributes: extra features. This is a dictionary `{name: value}`; `_get_axis_data()`
        # calls each value as `value(cluster_id, load_all=...)`.
        self.attributes = attributes or {}

        # Create every visual and register it on the canvas, one after the other.
        for attr_name, visual_cls in (
                ('visual', ScatterVisual),
                ('text_visual', TextVisual),
                ('line_visual', LineVisual)):
            visual = visual_cls()
            setattr(self, attr_name, visual)
            self.canvas.add_visual(visual)

    def set_grid_dim(self, grid_dim):
        """Replace the subplot specification and resize the canvas grid accordingly.

        Parameters
        ----------
        grid_dim : array-like (2D)
            `grid_dim[row, col]` is a string with two values separated by a comma. Each value
            is either a relative channel id (0, 1, 2...) followed by a PC letter (A, B, C...),
            or an attribute name. Examples: `0B,1A` and `time,2C`.

        """
        self.grid_dim = grid_dim
        shape = np.array(grid_dim).shape
        n_rows, n_cols = shape
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.canvas.grid.shape = (n_rows, n_cols)

    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self):
        """Yield (i, j, dim)."""
        for i in range(self.n_rows):
            for j in range(self.n_cols):
                dim = self.grid_dim[i][j]
                dim_x, dim_y = dim.split(',')
                yield i, j, dim_x, dim_y

    def _get_axis_label(self, dim):
        """Return the channel label from a dimension, if applicable."""
        if str(dim[:-1]).isdecimal():
            n = len(self.channel_ids)
            channel_id = self.channel_ids[int(dim[:-1]) % n]
            return self.channel_labels[channel_id] + dim[-1]
        else:
            return dim

    def _get_channel_and_pc(self, dim):
        """Return the channel_id and PC of a dim."""
        if self.channel_ids is None:
            return
        assert dim not in self.attributes  # This is called only on PC data.
        s = 'ABCDEFGHIJ'
        # Channel relative index, typically just 0 or 1.
        c_rel = int(dim[:-1])
        # Get the channel_id from the currently-selected channels.
        channel_id = self.channel_ids[c_rel % len(self.channel_ids)]
        pc = s.index(dim[-1])
        return channel_id, pc

    def _get_axis_data(self, bunch, dim, cluster_id=None, load_all=None):
        """Extract the points from the data on a given dimension.

        bunch is returned by the features() function.
        dim is the string specifying the dimensions to extract for the data.

        """
        if dim in self.attributes:
            return self.attributes[dim](cluster_id, load_all=load_all)
        masks = bunch.get('masks', None)
        channel_id, pc = self._get_channel_and_pc(dim)
        # Skip the plot if the channel id is not displayed.
        if channel_id not in bunch.channel_ids:  # pragma: no cover
            return Bunch(data=np.zeros((bunch.data.shape[0],)))
        # Get the column index of the current channel in data.
        c = list(bunch.channel_ids).index(channel_id)
        if masks is not None:
            masks = masks[:, c]
        return Bunch(data=self.feature_scaling * bunch.data[:, c, pc], masks=masks)

    def _get_axis_bounds(self, dim, bunch):
        """Return the min/max of an axis."""
        if dim in self.attributes:
            # Attribute: specified lim, or compute the min/max.
            vmin, vmax = bunch.get('lim', (0, 0))
            assert vmin is not None
            assert vmax is not None
            return vmin, vmax
        return (-self._lim, +self._lim)

    def _plot_points(self, bunch, clu_idx=None):
        if not bunch:
            return
        cluster_id = self.cluster_ids[clu_idx] if clu_idx is not None else None
        for i, j, dim_x, dim_y in self._iter_subplots():
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id)
            # Skip empty data.
            if px is None or py is None:  # pragma: no cover
                logger.warning("Skipping empty data for cluster %d.", cluster_id)
                return
            assert px.data.shape == py.data.shape
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            assert xmin <= xmax
            assert ymin <= ymax
            data_bounds = (xmin, ymin, xmax, ymax)
            masks = _get_masks_max(px, py)
            # Prepare the batch visual with all subplots
            # for the selected cluster.
            self.visual.add_batch_data(
                x=px.data, y=py.data,
                color=_get_point_color(clu_idx),
                # Reduced marker size for background features
                size=self._marker_size,
                masks=_get_point_masks(clu_idx=clu_idx, masks=masks),
                data_bounds=data_bounds,
                box_index=(i, j),
            )
            # Get the channel ids corresponding to the relative channel indices
            # specified in the dimensions. Channel 0 corresponds to the first
            # best channel for the selected cluster, and so on.
            label_x = self._get_axis_label(dim_x)
            label_y = self._get_axis_label(dim_y)
            # Add labels.
            self.text_visual.add_batch_data(
                pos=[1, 1],
                anchor=[-1, -1],
                text=label_y,
                data_bounds=None,
                box_index=(i, j),
            )
            self.text_visual.add_batch_data(
                pos=[0, -1.],
                anchor=[0, 1],
                text=label_x,
                data_bounds=None,
                box_index=(i, j),
            )

    def _plot_axes(self):
        self.line_visual.reset_batch()
        for i, j, dim_x, dim_y in self._iter_subplots():
            self.line_visual.add_batch_data(
                pos=[[-1., 0., +1., 0.],
                     [0., -1., 0., +1.]],
                color=(.5, .5, .5, .5),
                box_index=(i, j),
                data_bounds=None,
            )
        self.canvas.update_visual(self.line_visual)

    def _get_lim(self, bunchs):
        if not bunchs:  # pragma: no cover
            return 1
        m, M = min(bunch.data.min() for bunch in bunchs), max(bunch.data.max() for bunch in bunchs)
        M = max(abs(m), abs(M))
        return M

    def _get_scaling_value(self):
        """Return the current feature scaling factor (`self.feature_scaling`)."""
        return self.feature_scaling

    def _set_scaling_value(self, value):
        """Set the feature scaling factor and replot the view."""
        self.feature_scaling = value
        # Keep the current channels while replotting with the new scaling.
        self.plot(fixed_channels=True)

    # Public methods
    # -------------------------------------------------------------------------

    def clear_channels(self):
        """Reset the current channels and replot the view."""
        # With no channels stored, `get_clusters_data()` selects them again.
        self.channel_ids = None
        self.plot()

    def get_clusters_data(self, fixed_channels=None, load_all=None):
        """Return the feature bunchs of the selected clusters, and update the shown channels.

        As side effects, this sets `self.channel_ids` (unless the channels are fixed and
        already defined) and rebuilds `self.channel_labels`.

        Parameters
        ----------
        fixed_channels : bool, optional
            If true, the current `self.channel_ids` are requested from `features()` and kept.
        load_all : bool, optional
            Currently unused by this method.

        Returns
        -------
        list
            One bunch per selected cluster for which `features()` returned something, each
            with a `cluster_id` attribute. The list is empty if there is no data at all.

        """
        # Specify the channel ids if these are fixed, otherwise let `features()` choose them.
        requested = self.channel_ids if fixed_channels else None
        bunchs = []
        for cluster_id in self.cluster_ids:
            bunch = self.features(cluster_id, channel_ids=requested)
            if not bunch:
                continue
            # Tag the bunch while its cluster is known: pairing the clusters with the bunchs
            # after dropping the empty ones would shift the ids.
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        if not bunchs:  # pragma: no cover
            return []

        # Choose the channels based on the first cluster that has data.
        channel_ids = list(bunchs[0].get('channel_ids', []))
        common_channels = list(channel_ids)
        # Intersection of the channels belonging to all clusters.
        for bunch in bunchs:
            common_channels = [c for c in bunch.get('channel_ids', []) if c in common_channels]
        # The selected channels will be (1) the channels common to all clusters, followed
        # by (2) remaining channels from the first cluster (excluding those already selected
        # in (1)).
        n = len(channel_ids)
        not_common_channels = [c for c in channel_ids if c not in common_channels]
        channel_ids = common_channels + not_common_channels[:n - len(common_channels)]
        assert len(channel_ids) > 0

        # Choose the channels automatically unless fixed_channels is set.
        if (not fixed_channels or self.channel_ids is None):
            self.channel_ids = channel_ids
        assert len(self.channel_ids)

        # Channel labels: default to the channel ids formatted as strings.
        self.channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.get('channel_ids', [])])
            self.channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.get('channel_ids', []))})

        return bunchs

    def plot(self, **kwargs):
        """Update the view with the selected clusters.

        Recognized keyword arguments: `fixed_channels` (keep the current channels) and `up`
        (a dictionary; the channels are also kept when it has a non-None `added` item).

        """
        # Fix the channels if requested, or if the view updates after a cluster event
        # with new clusters.
        added = kwargs.get('up', {}).get('added', None)
        fixed_channels = (
            self.fixed_channels or kwargs.get('fixed_channels', None) or added is not None)

        # Get the clusters data, and stop here if there is nothing to show.
        bunchs = [b for b in self.get_clusters_data(fixed_channels=fixed_channels) if b]
        if not bunchs:
            return
        self._lim = self._get_lim(bunchs)

        # Get the background data.
        background = self.features(channel_ids=self.channel_ids)

        self._plot_axes()

        # NOTE: the columns in bunch.data are ordered by decreasing quality
        # of the associated channels. The channels corresponding to each
        # column are given in bunch.channel_ids in the same order.

        # Start new batches, then queue the background spikes followed by every cluster.
        for visual in (self.visual, self.text_visual):
            visual.reset_batch()
        self._plot_points(background)
        for clu_idx, bunch in enumerate(bunchs):
            self._plot_points(bunch, clu_idx=clu_idx)

        # Upload the data on the GPU.
        for visual in (self.visual, self.text_visual):
            self.canvas.update_visual(visual)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI and register the actions specific to this view."""
        super(FeatureView, self).attach(gui)

        # Checkable action: it is checked when the channels are NOT fixed, i.e. when they
        # are selected automatically.
        self.actions.add(
            self.toggle_automatic_channel_selection,
            checked=not self.fixed_channels, checkable=True)
        self.actions.add(self.clear_channels)
        self.actions.separator()

    def toggle_automatic_channel_selection(self, checked):
        """Toggle the automatic selection of channels when the cluster selection changes.

        The channels are fixed (`self.fixed_channels` is True) when `checked` is false.

        """
        self.fixed_channels = not checked

    @property
    def status(self):
        if self.channel_ids is None:  # pragma: no cover
            return ''
        channel_labels = [self.channel_labels[ch] for ch in self.channel_ids[:2]]
        return 'channels: %s' % ', '.join(channel_labels)

    # Dimension selection
    # -------------------------------------------------------------------------

    def on_select_channel(self, sender=None, channel_id=None, key=None, button=None):
        """Respond to the click on a channel from another view, and update the
        relevant subplots.

        Parameters
        ----------
        sender : object, optional
            Emitter of the event (unused here).
        channel_id : int
            The channel that was clicked.
        key : int, optional
            Number key pressed during the click (1, 2, etc.). It selects which of the shown
            channels is replaced, and is clamped to the valid positions.
        button : str, optional
            Used when `key` is None: `'Left'` replaces the first channel, anything else the
            second one.

        """
        channels = self.channel_ids
        if channels is None:
            return
        if len(channels) == 1:
            self.plot()
            return
        assert len(channels) >= 2
        # Get the axis from the pressed button (1, 2, etc.)
        if key is not None:
            # Clamp the 0-based position `key - 1` to the shown channels. The value to clip
            # comes first in `np.clip(value, min, max)`; with the bounds in the wrong place,
            # key 0 gave the index -1 and an out-of-range access below.
            d = int(np.clip(key - 1, 0, len(channels) - 1))
        else:
            d = 0 if button == 'Left' else 1
        # Change the first or second best channel.
        old = channels[d]
        # Avoid updating the view if the channel doesn't change.
        if channel_id == old:
            return
        channels[d] = channel_id
        # Ensure that the first two channels are different.
        if channels[1 - min(d, 1)] == channel_id:
            channels[1 - min(d, 1)] = old
        assert channels[0] != channels[1]
        # Remove duplicate channels.
        self.channel_ids = _uniq(channels)
        logger.debug("Choose channels %d and %d in feature view.", *channels[:2])
        # Fix the channels temporarily.
        self.plot(fixed_channels=True)
        self.update_status()

    def on_mouse_click(self, e):
        """Select a feature dimension by Alt-clicking on a box in the feature view.

        The left button selects the x dimension of the clicked box, any other button the y
        dimension. A `select_feature` event is then emitted with the dimension, the channel
        id, and the PC (None when the dimension is an attribute).

        """
        button = e.button
        if 'Alt' not in e.modifiers:
            return
        # Find the clicked box, and its two dimensions.
        (i, j), _ = self.canvas.grid.box_map(e.pos)
        dim_x, dim_y = self.grid_dim[i][j].split(',')
        if button == 'Left':
            dim, other_dim = dim_x, dim_y
        else:
            dim, other_dim = dim_y, dim_x

        if dim in self.attributes:
            # When the selected dimension is an attribute, e.g. "time".
            pc = None
            # Take the channel id in the other dimension.
            channel_pc = self._get_channel_and_pc(other_dim)
            channel_id = channel_pc[0] if channel_pc is not None else None
            logger.debug("Click on feature dim %s.", dim)
        else:
            # When a regular (channel, PC) dimension is selected.
            channel_pc = self._get_channel_and_pc(dim)
            if channel_pc is None:
                return
            channel_id, pc = channel_pc
            logger.debug("Click on feature dim %s, channel id %s, PC %s.", dim, channel_id, pc)
        emit('select_feature', self, dim=dim, channel_id=channel_id, pc=pc)

    def on_request_split(self, sender=None):
        """Return the ids of the spikes enclosed by the lasso.

        All spikes of the selected clusters are loaded (`load_all=True`) on the dimensions
        of the active subplot, normalized with the axis bounds, and tested against the lasso
        polygon. The lasso is cleared afterwards.

        Returns
        -------
        ndarray
            Sorted unique spike ids. The array is empty (dtype int64) if the lasso has less
            than 3 points, if no cluster is selected, or if no cluster has any data.

        """
        if (self.canvas.lasso.count < 3 or not len(self.cluster_ids)):  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # Get the dimensions of the lassoed subplot.
        i, j = self.canvas.layout.active_box
        dim = self.grid_dim[i][j]
        dim_x, dim_y = dim.split(',')

        # Get all points from all clusters.
        pos = []
        spike_ids = []

        for cluster_id in self.cluster_ids:
            # Load all spikes.
            bunch = self.features(cluster_id, channel_ids=self.channel_ids, load_all=True)
            if not bunch:
                continue
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id, load_all=True)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id, load_all=True)
            points = np.c_[px.data, py.data]

            # Normalize the points.
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            r = Range((xmin, ymin, xmax, ymax))
            points = r.apply(points)

            pos.append(points)
            spike_ids.append(bunch.spike_ids)

        # Nothing to split if every cluster was skipped: `np.vstack()` and `np.concatenate()`
        # raise a ValueError on an empty list.
        if not pos:
            return np.array([], dtype=np.int64)
        pos = np.vstack(pos)
        spike_ids = np.concatenate(spike_ids)

        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()
        return np.unique(spike_ids[ind])
Ejemplo n.º 18
0
Archivo: trace.py Proyecto: zsong30/phy
class TraceView(ScalingMixin, BaseColorView, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_positions : array-like
        Positions of the channels, used for displaying the channels in the right y order
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    trace_color_0 = (.353, .161, .443)
    trace_color_1 = (.133, .404, .396)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'switch_color_scheme': 'shift+wheel',
        'navigate': 'alt+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'jump_left': 'shift+alt+left',
        'jump_right': 'shift+alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'select_channel_pcA': 'shift+left click',
        'select_channel_pcB': 'shift+right click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None,
            n_channels=None, channel_positions=None, channel_labels=None, **kwargs):
        """Store the data accessors, create the visuals, and go to the middle of the recording.

        The parameters are described in the class docstring. The extra keyword arguments are
        forwarded to the parent constructor, which is only called once the attributes of
        this view are set.

        """
        self.do_show_labels = True
        self.show_all_spikes = False
        self.get_spike_times = spike_times

        # Sample rate, and duration of one sample.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Function returning the traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel positions: by default, channel #i is at `(0, i)`.
        if channel_positions is None:
            channel_positions = np.c_[np.zeros(n_channels), np.arange(n_channels)]
        self.channel_positions = channel_positions
        # channel_y_ranks[i] is the position of channel #i in the trace view.
        y_order = np.argsort(self.channel_positions[:, 1])
        self.channel_y_ranks = np.argsort(y_order)
        assert self.channel_y_ranks.shape == (n_channels,)

        # Channel labels: by default, the channel indices.
        if channel_labels is None:
            channel_labels = ['%d' % ch for ch in range(n_channels)]
        self.channel_labels = channel_labels
        assert len(self.channel_labels) == n_channels

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        self._create_visuals()

        # Initial interval, centered on the middle of the recording.
        self._interval = None
        self.go_to(duration / 2.)

        self._waveform_times = []
        self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))

    def _create_visuals(self):
        """Set up the stacked layout and create the trace, waveform, and text visuals."""
        # One box per channel.
        self.canvas.set_layout('stacked', n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        self.trace_visual = UniformPlotVisual()
        # Gradient of color for the traces: the fragment color is interpolated between the
        # two trace colors according to `v_signal_index / n_channels`.
        if self.trace_color_0 and self.trace_color_1:
            self.trace_visual.inserter.insert_frag(
                'gl_FragColor.rgb = mix(vec3%s, vec3%s, (v_signal_index / %d));' % (
                    self.trace_color_0, self.trace_color_1, self.n_channels), 'end')
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        # NOTE(review): presumably keeps the x position of the labels fixed -- confirm
        # against `_fix_coordinate_in_visual()`.
        _fix_coordinate_in_visual(self.text_visual, 'x')
        # `v_discard` is set when `n_boxes >= 50 * u_zoom.y` and the box index is not a
        # multiple of `int(n_boxes / (50 * u_zoom.y))`; the corresponding fragments are
        # discarded, so that only some of the labels are drawn in that case.
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)

    @property
    def stacked(self):
        """The `stacked` attribute of the canvas."""
        return self.canvas.stacked

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        color = color or self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        box_index = self.channel_y_ranks
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # Determine the spike color.
        i = bunch.select_index
        c = bunch.spike_cluster
        cs = self.color_schemes.get()
        color = selected_cluster_color(i, alpha=1) if i is not None else cs.get(c, alpha=1)

        # We could tweak the color of each spike waveform depending on the template amplitude
        # on each of its best channels.
        # channel_amps = bunch.get('channel_amps', None)
        # if channel_amps is not None:
        #     color = np.tile(color, (n_channels, 1))
        #     assert color.shape == (n_channels, 4)
        #     color[:, 3] = channel_amps

        # The box index depends on the channel.
        box_index = self.channel_y_ranks[bunch.channel_ids]
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=color,
            data_bounds=self.data_bounds,
        )

    def _plot_waveforms(self, waveforms, **kwargs):
        """Plot the waveforms."""
        # waveforms = self.waveforms
        assert isinstance(waveforms, list)
        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

    def _plot_labels(self, traces):
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self.channel_y_ranks[ch]
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def plot(self, update_traces=True, update_waveforms=True):
        """Redraw the view for the current interval.

        Parameters
        ----------
        update_traces : bool
            Recompute the data bounds and redraw the traces (and the labels, if shown).
        update_waveforms : bool
            Redraw the spike waveforms.

        """
        traces = None
        # Load the traces in the interval. They are needed for the traces and for the
        # waveforms, so they must also be loaded when only `update_traces` is set.
        if update_traces or update_waveforms:
            traces = self.traces(self._interval)

        if update_traces:
            logger.log(5, "Redraw the entire trace view.")
            start, end = self._interval

            # Find the data bounds: quantiles of the data when auto-scaling (or when there
            # are no bounds yet), otherwise the current vertical bounds are kept.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Used for spike click.
            self._waveform_times = []

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

            # Plot the labels.
            if self.do_show_labels:
                self._plot_labels(traces.data)

        if update_waveforms:
            self._plot_waveforms(traces.get('waveforms', []))

        self._update_axes()
        self.canvas.update()

    def set_interval(self, interval=None):
        """Display the traces and spikes in a given interval.

        The current interval is used if `interval` is None. The interval is first restricted
        to the recording. If it did not change, only the waveforms are redrawn; otherwise the
        whole view is redrawn, `is_busy` events are emitted around the redraw, and a
        `time_range_selected` event is emitted afterwards.

        """
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        if interval == self._interval:
            # Same interval: the traces are already displayed.
            self.plot(update_traces=False, update_waveforms=True)
            return

        logger.log(5, "Redraw the entire trace view.")
        self._interval = interval
        emit('is_busy', self, True)
        self.plot(update_traces=True, update_waveforms=True)
        emit('is_busy', self, False)
        emit('time_range_selected', self, interval)
        self.update_status()

    def on_select(self, cluster_ids=None, **kwargs):
        """Store the selected clusters, and refresh the view if there are any."""
        self.cluster_ids = cluster_ids
        if cluster_ids:
            # Make sure we call again self.traces() when the cluster selection changes.
            self.set_interval()

    def attach(self, gui):
        """Attach the view to the GUI, register the view actions, and display the interval.

        The actions are registered in groups delimited by separators.

        """
        super(TraceView, self).attach(gui)

        # Display options.
        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(
            self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        self.actions.add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        self.actions.add(self.switch_origin)
        self.actions.separator()

        # Go to a given time: the prompt defaults to the current time.
        self.actions.add(
            self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        self.actions.separator()

        self.actions.add(self.go_to_start)
        self.actions.add(self.go_to_end)
        self.actions.separator()

        # Shifts and jumps.
        self.actions.add(self.shift, prompt=True)
        self.actions.add(self.go_right)
        self.actions.add(self.go_left)
        self.actions.add(self.jump_right)
        self.actions.add(self.jump_left)
        self.actions.separator()

        # Interval size.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Spike navigation.
        self.actions.add(self.go_to_next_spike)
        self.actions.add(self.go_to_previous_spike)
        self.actions.separator()

        # Display the current interval.
        self.set_interval()

    @property
    def status(self):
        """Status text with the bounds of the current interval and the color scheme."""
        a, b = self._interval
        return '[{:.2f}s - {:.2f}s]. Color scheme: {}.'.format(a, b, self.color_scheme)

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Whether to show the channels from top to bottom (`top` option, the default), or from
        bottom to top (`bottom`)."""
        return getattr(self.canvas.layout, 'origin', Stacked._origin)

    @origin.setter
    def origin(self, value):
        if value is None:
            return
        if self.canvas.layout:
            self.canvas.layout.origin = value
        else:  # pragma: no cover
            logger.warning(
                "Could not set origin to %s because the layout instance was not initialized yet.",
                value)

    def switch_origin(self):
        """Switch the origin of the channels: `top` becomes `bottom`, anything else becomes `top`."""
        if self.origin == 'top':
            self.origin = 'bottom'
        else:
            self.origin = 'top'

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Time at the center of the window (half the sum of the bounds of the interval)."""
        return sum(self._interval) * .5

    @property
    def interval(self):
        """Interval as `(tmin, tmax)`. Setting it calls `set_interval()`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        # Go through `set_interval()` so that the interval is restricted and the view redrawn.
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval."""
        if self._interval is not None:
            a, b = self._interval
            return (b - a) * .5
        else:
            return self.interval_duration * .5

    def go_to(self, time):
        """Center the interval on `time` (in seconds), keeping the current half duration."""
        radius = self.half_duration
        window = (time - radius, time + radius)
        self.set_interval(window)

    def shift(self, delay):
        """Move the interval by `delay` seconds (negative values move it backwards)."""
        target = self.time + delay
        self.go_to(target)

    def go_to_start(self):
        """Center the view on time 0, the start of the recording."""
        start_time = 0
        self.go_to(start_time)

    def go_to_end(self):
        """Center the view on `self.duration`, the end of the recording."""
        end_time = self.duration
        self.go_to(end_time)

    def go_right(self):
        """Shift the interval to the right by a tenth of its duration."""
        t0, t1 = self._interval
        self.shift((t1 - t0) * .1)

    def go_left(self):
        """Shift the interval to the left by a tenth of its duration."""
        t0, t1 = self._interval
        self.shift(-((t1 - t0) * .1))

    def jump_right(self):
        """Shift the interval to the right by a tenth of the total duration."""
        self.shift(self.duration * .1)

    def jump_left(self):
        """Shift the interval to the left by a tenth of the total duration."""
        self.shift(-(self.duration * .1))

    def _jump_to_spike(self, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        spike_times = self.get_spike_times()
        if spike_times is not None and len(spike_times):
            ind = np.searchsorted(spike_times, self.time)
            n = len(spike_times)
            self.go_to(spike_times[(ind + delta) % n])

    def go_to_next_spike(self):
        """Move forward by one spike (delegates to `_jump_to_spike(+1)`)."""
        step = +1
        self._jump_to_spike(step)

    def go_to_previous_spike(self):
        """Move backward by one spike (delegates to `_jump_to_spike(-1)`)."""
        step = -1
        self._jump_to_spike(step)

    def toggle_highlighted_spikes(self, checked):
        """Switch between showing all spikes and only the selected ones.

        `checked` is stored in `show_all_spikes`, then the current interval is redrawn.
        """
        self.show_all_spikes = checked
        # No argument: `set_interval()` reuses the current interval.
        self.set_interval()

    def widen(self):
        """Enlarge the interval around its center by the factor `scaling_coeff_x`."""
        center = self.time
        radius = self.half_duration * self.scaling_coeff_x
        self.set_interval((center - radius, center + radius))

    def narrow(self):
        """Shrink the interval around its center by the factor `scaling_coeff_x`."""
        center = self.time
        radius = self.half_duration / self.scaling_coeff_x
        self.set_interval((center - radius, center + radius))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show or hide the channel labels.

        Parameters
        ----------

        checked : bool
            Whether the labels should be visible. The value is also stored in
            `do_show_labels`.

        """
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        # Set the visibility explicitly from `checked` instead of blindly flipping the
        # current state, so that the visual can never get out of sync with the checkbox.
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def toggle_auto_scale(self, checked):
        """Enable or disable automatic scaling of the traces.

        Only the `auto_scale` flag is updated here; nothing is redrawn.
        """
        message = "Set auto scale to %s."
        logger.debug(message, checked)
        self.auto_scale = checked

    def update_color(self):
        """Replot after a color scheme change, refreshing the waveforms but not the traces."""
        options = dict(update_traces=False, update_waveforms=True)
        self.plot(**options)

    # Scaling
    # -------------------------------------------------------------------------

    @property
    def scaling(self):
        """Scaling of the channel boxes (second component of the layout's box scaling)."""
        box_scaling = self.stacked._box_scaling
        return box_scaling[1]

    @scaling.setter
    def scaling(self, value):
        # Keep the first component as it is and only replace the second one.
        first = self.stacked._box_scaling[0]
        self.stacked._box_scaling = (first, value)

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value
        self.stacked.update()

    # Spike selection
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Handle clicks: Control selects the closest spike, Shift selects a channel.

        With Control, a `select_spike` event is emitted for the displayed spike, on the
        clicked channel, whose time is closest to the mouse. With Shift, a `select_channel`
        event is emitted for the clicked channel.
        """
        modifiers = e.modifiers

        if 'Control' in modifiers:
            # Box under the mouse, and the channel(s) having that vertical rank.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            channel_id = np.nonzero(self.channel_y_ranks == box_id)[0]
            bounds = self.data_bounds
            # Displayed spikes which include the clicked channel.
            candidates = [
                (time, spike, cluster, channels)
                for (time, spike, cluster, channels) in self._waveform_times
                if channel_id in channels
            ]
            # NOTE: this returns from the whole handler, the Shift branch below is skipped.
            if not candidates:
                return
            # Time coordinate of the mouse position.
            ndc_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, bounds).apply(ndc_pos)[0][0]
            # Spike whose time is the closest to the mouse.
            spike_times = np.array([candidate[0] for candidate in candidates])
            nearest = np.argmin(np.abs(spike_times - mouse_time))
            _, spike_id, cluster_id, _ = candidates[nearest]
            # Raise the select_spike event.
            emit('select_spike', self, channel_id=channel_id,
                 spike_id=spike_id, cluster_id=cluster_id)

        if 'Shift' in modifiers:
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            matches = np.nonzero(self.channel_y_ranks == box_id)[0]
            channel_id = int(matches[0])
            emit('select_channel', self, channel_id=channel_id, button=e.button)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Scroll through the data with alt+wheel.

        The parent handler always runs first; the interval is then shifted by a tenth of
        its duration per wheel delta unit, in the direction opposite to the delta.
        """
        super(TraceView, self).on_mouse_wheel(e)
        if e.modifiers != ('Alt',):
            return
        t0, t1 = self._interval
        self.shift(-(e.delta * (t1 - t0) * .1))
Ejemplo n.º 19
0
class ProbeView(ManualClusteringView):
    """This view displays the positions of all channels on the probe, highlighting channels
    where the selected clusters belong.

    Constructor
    -----------

    positions : array-like
        An `(n_channels, 2)` array with the channel positions
    best_channels : function
        Maps `cluster_id` to the list of the best_channel_ids.
    channel_labels : list
        List of channel label strings. Defaults to the channel indices as strings.
    dead_channels : list
        List of dead channel ids. Defaults to no dead channel.

    """

    _default_position = 'right'

    # Marker size of channels without selected clusters.
    unselected_marker_size = 10

    # Marker size of channels with selected clusters.
    selected_marker_size = 15

    # Alpha value of the dead channels.
    dead_channel_alpha = .25

    # Whether the channel labels are displayed.
    do_show_labels = False

    def __init__(
            self, positions=None, best_channels=None, channel_labels=None,
            dead_channels=None, **kwargs):
        super(ProbeView, self).__init__(**kwargs)
        self.state_attrs += ('do_show_labels',)

        # Normalize positions. `np.asarray` makes the documented "array-like" contract
        # hold for plain nested lists too (they have no `ndim` attribute).
        positions = np.asarray(positions)
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        positions = positions.astype(np.float32)
        self.positions, self.data_bounds = _get_pos_data_bounds(positions)

        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        self.channel_labels = channel_labels or [str(ch) for ch in range(self.n_channels)]
        self.dead_channels = dead_channels if dead_channels is not None else ()
        has_dead_channels = len(self.dead_channels) > 0

        self.probe_visual = ScatterVisual()
        self.canvas.add_visual(self.probe_visual)

        # Probe visual: opaque gray markers.
        color = np.ones((self.n_channels, 4))
        color[:, :3] = .5
        # Change alpha value for dead channels.
        if has_dead_channels:
            color[self.dead_channels, 3] = self.dead_channel_alpha
        self.probe_visual.set_data(
            pos=self.positions, data_bounds=self.data_bounds,
            color=color, size=self.unselected_marker_size)

        # Cluster visual.
        self.cluster_visual = ScatterVisual()
        self.canvas.add_visual(self.cluster_visual)

        # Text visual: white labels, darker for dead channels. A fresh array is used so
        # that the array given to the probe visual above is not modified afterwards.
        text_color = np.ones((self.n_channels, 4))
        if has_dead_channels:
            text_color[self.dead_channels, :3] = self.dead_channel_alpha * 2
        self.text_visual = TextVisual()
        # The shader snippets below discard a subset of the labels depending on the
        # number of channels and the vertical zoom level.
        self.text_visual.inserter.insert_vert('uniform float n_channels;', 'header')
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_channels >= 200 * u_zoom.y) && '
            '(mod(int(a_string_index), int(n_channels / (200 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)
        self.text_visual.set_data(
            pos=self.positions, text=self.channel_labels, anchor=[0, -1],
            data_bounds=self.data_bounds, color=text_color
        )
        self.text_visual.program['n_channels'] = self.n_channels
        self.canvas.update()

    def _get_clu_positions(self, cluster_ids):
        """Get the positions and colors of the discs marking the selected clusters.

        Every selected cluster gets one disc on each of its best channels. When several
        clusters share a channel, their discs are spread horizontally and centered on the
        channel position.

        Returns a pair of arrays `(positions, colors)` with one row per disc.
        """

        # List of channels per cluster, keyed by the cluster index in the selection.
        cluster_channels = {i: self.best_channels(cl) for i, cl in enumerate(cluster_ids)}

        # List of clusters per channel.
        clusters_per_channel = defaultdict(list)
        for clu_idx, channels in cluster_channels.items():
            for channel in channels:
                clusters_per_channel[channel].append(clu_idx)

        # Enumerate the discs for each channel.
        w = self.data_bounds[2] - self.data_bounds[0]
        clu_pos = []
        clu_colors = []
        for channel_id, (x, y) in enumerate(self.positions):
            clusters = clusters_per_channel.get(channel_id, ())
            n = len(clusters)
            if not n:
                continue
            alpha = 1.0 if channel_id not in self.dead_channels else self.dead_channel_alpha
            for i, clu_idx in enumerate(clusters):
                # Translation relative to the channel position. It must not be accumulated
                # into `x`, otherwise the discs drift and are no longer centered.
                t = .025 * w * (i - .5 * (n - 1))
                clu_pos.append((x + t, y))
                clu_colors.append(selected_cluster_color(clu_idx, alpha=alpha))
        return np.array(clu_pos), np.array(clu_colors)

    def on_select(self, cluster_ids=(), **kwargs):
        """Update the view with the selected clusters.

        Nothing is redrawn when the selection is empty.
        """
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        pos, colors = self._get_clu_positions(cluster_ids)
        self.cluster_visual.set_data(
            pos=pos, color=colors, size=self.selected_marker_size, data_bounds=self.data_bounds)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(ProbeView, self).attach(gui)
        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)

        # Labels are hidden unless `do_show_labels` is set.
        if not self.do_show_labels:
            self.text_visual.hide()

    def toggle_show_labels(self, checked):
        """Toggle the display of the channel ids."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        # Use the visual's public methods rather than its private `_hidden` attribute.
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()
Ejemplo n.º 20
0
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.

    """

    _default_position = 'right'
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
        # These attributes are set before the parent constructor is called.
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
        self.canvas.enable_axes()

        self._box_scaling = (1., 1.)
        self._probe_scaling = (1., 1.)

        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        assert waveforms
        self.waveforms = waveforms
        self.waveforms_types = list(waveforms.keys())
        # Current waveforms type.
        self.waveforms_type = waveforms_type or self.waveforms_types[0]
        assert self.waveforms_type in waveforms
        assert 'waveforms' in waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Return `[-1, m, +1, M]` where `m` and `M` are the extrema of all waveform data."""
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        return [-1, m, +1, M]

    def get_clusters_data(self):
        """Return one waveform bunch per selected cluster, for the current waveforms type.

        Every bunch is annotated with its `index` in the selection, its `offset`, the
        number of distinct offsets `n_clu`, and its `color`.
        """
        bunchs = [
            self.waveforms[self.waveforms_type](cluster_id)
            for cluster_id in self.cluster_ids
        ]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster to the batch of the waveform visual.

        Nothing is added when the bunch has no waveform data.
        """
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels

        # Masks default to 1 for every spike and channel.
        masks = bunch.get('masks', None)
        if masks is None:
            masks = np.ones((n_spikes_clu, n_channels))
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        # NOTE(review): the original comment here read "By default, this is 0, 1, 2 for
        # the first 3 clusters. But it can be customized when displaying several sets of
        # waveforms per cluster." -- presumably it describes `bunch.offset`.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        # A new array is created on purpose: modifying `bunch.masks` in place would
        # compound the transformation if the same bunch object is plotted again, and
        # would fail on integer arrays.
        masks = masks * .99999 + bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.tile(box_index, n_spikes_clu)
        assert box_index.shape == (n_spikes_clu * n_channels * n_samples, )

        # Generate the waveform array, one row per (spike, channel) pair.
        wave = np.transpose(wave, (0, 2, 1))
        wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

        self.waveform_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Add one text label per displayed channel, unless labels are disabled.

        `channel_labels` maps a channel id to its label. `n_clusters` is not used.
        """
        # Add channel labels.
        if not self.do_show_labels:
            return
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids

        # Channel labels: mapping channel id => label. Default labels are the channel ids.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i]
                for i, channel_id in enumerate(d.channel_ids)
            })

        # Update the box bounds as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.box_bounds = _get_box_bounds(bunchs, channel_ids)
        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        self.data_bounds = self._get_data_bounds(bunchs)

        self.waveform_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.waveform_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)
        self._update_axes(bunchs)
        self.canvas.update()

    def _update_axes(self, bunchs):
        """Update the axes."""
        # Update the axes data bounds.
        _, m, _, M = self.data_bounds
        # Waveform duration, scaled by overlap factor if needed.
        wave_dur = bunchs[0].get('waveform_duration', 1.)
        wave_dur /= .5 * (1 + _overlap_transform(
            1, n=len(self.cluster_ids), overlap=self.overlap))
        x1, y1 = range_transform(self.canvas.boxed.box_bounds[0], NDC,
                                 [wave_dur, M - m])
        axes_data_bounds = (0, 0, x1, y1)
        self.canvas.axes.reset_data_bounds(axes_data_bounds, do_update=True)

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap,
                         checkable=True,
                         checked=self.overlap)
        self.actions.add(self.toggle_show_labels,
                         checkable=True,
                         checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        # Changing the overlap triggers a full replot.
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def _update_boxes(self):
        """Send the box positions (multiplied by the probe scaling) and sizes to the layout."""
        self.canvas.boxed.update_boxes(self.box_pos * self.probe_scaling,
                                       self.box_size)

    def _apply_box_scaling(self):
        """Send the current box scaling to the layout."""
        self.canvas.layout.scaling = self._box_scaling

    @property
    def box_scaling(self):
        """Scaling of the channel boxes."""
        return self._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        assert len(value) == 2
        self._box_scaling = value
        self._apply_box_scaling()

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w * self._scaling_param_increment, h)
        self._apply_box_scaling()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w / self._scaling_param_increment, h)
        self._apply_box_scaling()

    def _get_scaling_value(self):
        """Return the second component of the box scaling."""
        return self._box_scaling[1]

    def _set_scaling_value(self, value):
        """Set the second component of the box scaling and update the boxes."""
        w, h = self._box_scaling
        self.box_scaling = (w, value)
        self._update_boxes()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Scaling of the entire probe."""
        return self._probe_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        assert len(value) == 2
        self._probe_scaling = value
        self._update_boxes()

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w * self._scaling_param_increment, h)
        self._update_boxes()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w / self._scaling_param_increment, h)
        self._update_boxes()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h * self._scaling_param_increment)
        self._update_boxes()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h / self._scaling_param_increment)
        self._update_boxes()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view.

        A `channel_click` event is emitted when Control is pressed, or when a digit key
        is held during the click. The digit is passed as `key`, which is `None` otherwise.
        """
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            key = int(e.key) if e.key in nums else None
            # No channel is displayed before the first plot: there is nothing to select.
            if not self.channel_ids:
                return
            # Index of the box under the mouse.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.",
                         channel_id, key, b)
            emit('channel_click',
                 self,
                 channel_id=channel_id,
                 key=key,
                 button=b)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        i = self.waveforms_types.index(self.waveforms_type)
        n = len(self.waveforms_types)
        self.waveforms_type = self.waveforms_types[(i + 1) % n]
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available.

        If the mean waveforms are already displayed, switch back to `waveforms`.
        NOTE: `checked` is ignored, the decision only depends on the current type.
        """
        if self.waveforms_type == 'mean_waveforms':
            self.waveforms_type = 'waveforms'
            self.plot()
        elif 'mean_waveforms' in self.waveforms_types:
            self.waveforms_type = 'mean_waveforms'
            self.plot()
Ejemplo n.º 21
0
class TraceView(ScalingMixin, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_vertical_order : array-like
        Permutation of the channels. This 1D array gives the channel id of all channels from
        top to bottom (or conversely, depending on `origin=top|bottom`).
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None, n_channels=None,
            channel_vertical_order=None, channel_labels=None, **kwargs):
        """Create the view, set up the layout and visuals, and go to the middle of the data.

        `traces` must be callable, `sample_rate` must be positive, and `duration` and
        `n_channels` must be non-negative. These conditions are checked with assertions.
        Extra keyword arguments are passed to the parent constructor.
        """

        # NOTE: the attributes below are set before the parent constructor is called.
        self.do_show_labels = True
        self.show_all_spikes = False
        self._scaling = 1.

        # Function returning the spike times, used to jump between spikes.
        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        # Duration of one sample, in seconds.
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces
        self.waveforms = None

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel permutation.
        self._channel_perm = (
            np.arange(n_channels) if channel_vertical_order is None else channel_vertical_order)
        assert self._channel_perm.shape == (n_channels,)
        # Only the argsort of the requested order is kept.
        self._channel_perm = np.argsort(self._channel_perm)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Box and probe scaling.
        self._origin = None

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        # One stacked plot per channel; only the x axis is shown.
        self.canvas.set_layout('stacked', origin=self.origin, n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        # Visuals.
        self.trace_visual = UniformPlotVisual()
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        self.canvas.add_visual(self.text_visual)

        # Make a copy of the initial box pos and size. We'll apply the scaling
        # to these quantities.
        self.box_size = np.array(self.canvas.stacked.box_size)

        # Initial interval: centered on the middle of the recording.
        self._interval = None
        self.go_to(duration / 2.)

        # NOTE: this is reset only after the initial `go_to()` call above.
        self._waveform_times = []

    @property
    def stacked(self):
        """Stacked layout instance of the canvas."""
        canvas = self.canvas
        return canvas.stacked

    def _permute_channels(self, x, inv=False):
        cp = self._channel_perm
        cp = np.argsort(cp)
        return cp[x]

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        color = color or self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        box_index = self._permute_channels(np.arange(n_ch))
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # The box index depends on the channel.
        box_index = self._permute_channels(bunch.channel_ids)
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=bunch.color,
            data_bounds=self.data_bounds,
        )

    def _plot_labels(self, traces):
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self._permute_channels(ch)
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def set_interval(self, interval=None, change_status=True):
        """Display the traces and spikes in a given interval.

        Parameters
        ----------

        interval : tuple or None
            `(start, end)` in seconds. `None` means the current interval.
        change_status : bool
            Whether to update the status message when the interval changes.

        The traces are always reloaded and the waveforms always replotted. The traces
        themselves, the data bounds and the labels are only redrawn when the restricted
        interval differs from the current one.
        """
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        # Reload the traces, even when the interval itself has not changed.
        traces = self.traces(interval)
        self.waveforms = traces.get('waveforms', [])

        interval_changed = interval != self._interval
        if interval_changed:
            logger.debug("Redraw the entire trace view.")
            self._interval = interval
            start, end = interval

            if change_status:
                self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end))

            # Vertical bounds: recomputed from the data quantiles when auto-scaling, or when
            # the bounds are missing or still equal to NDC; kept as they are otherwise.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                quantile = self.trace_quantile
                ymin = np.quantile(traces.data, quantile)
                ymax = np.quantile(traces.data, 1. - quantile)
            else:
                bounds = self.data_bounds
                ymin, ymax = bounds[1], bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Spike information used by the spike click handler.
            self._waveform_times = []

            self._plot_traces(traces.data, color=traces.get('color', None))

            if self.do_show_labels:
                self._plot_labels(traces.data)

        # The waveforms are replotted in all cases.
        self.plot()

    def on_select(self, cluster_ids=None, **kwargs):
        """Store the new cluster selection and refresh the view if it is not empty."""
        self.cluster_ids = cluster_ids
        if cluster_ids:
            # Refreshing the current interval calls self.traces() again, which is needed
            # when the cluster selection changes.
            self.set_interval()

    def plot(self, **kwargs):
        """Plot the waveforms stored in `self.waveforms`.

        The waveform visual is rebuilt from scratch and shown, or hidden when there is no
        waveform. `self._waveform_times`, the lookup table used by the spike click handler,
        is rebuilt at the same time so that it always describes exactly the spikes that are
        currently displayed.
        """
        waveforms = self.waveforms
        assert isinstance(waveforms, list)

        # Rebuild the click lookup on every plot. It used to be reset only when the
        # interval changed, so replotting the same interval (e.g. after a new cluster
        # selection) kept appending stale and duplicate entries.
        self._waveform_times = []

        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI, register the view actions and draw the view."""
        super(TraceView, self).attach(gui)

        add = self.actions.add
        separator = self.actions.separator

        # Display options.
        add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        add(self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        add(self.switch_origin)
        separator()

        # Jump to a time; the prompt defaults to the current time.
        add(self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        separator()

        add(self.go_to_start)
        add(self.go_to_end)
        separator()

        add(self.shift, prompt=True)
        add(self.go_right)
        add(self.go_left)
        separator()

        # Interval size, then spike navigation: one separator after each group.
        for group in (
                (self.widen, self.narrow),
                (self.go_to_next_spike, self.go_to_previous_spike)):
            for action in group:
                add(action)
            separator()

        self.set_interval()

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Channel ordering: `top` (the default) shows the channels from top to bottom,
        `bottom` from bottom to top."""
        return self._origin

    @origin.setter
    def origin(self, value):
        self._origin = value
        # Forward the value to the canvas layout, when there is one.
        layout = self.canvas.layout
        if layout:
            layout.origin = value

    def switch_origin(self):
        """Flip the channel origin between `top` and `bottom` (an unset origin becomes
        `top`)."""
        if self._origin in (None, 'bottom'):
            self.origin = 'top'
        else:
            self.origin = 'bottom'

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Midpoint of the current interval, in the same unit as the interval."""
        total = sum(self._interval)
        return total * .5

    @property
    def interval(self):
        """Currently displayed interval, as a `(tmin, tmax)` pair. Assigning to it is the
        same as calling `set_interval()`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half the length of the current interval, or half of `interval_duration` when no
        interval has been set yet."""
        if self._interval is None:
            return self.interval_duration * .5
        start, end = self._interval
        return (end - start) * .5

    def go_to(self, time):
        """Center the view on a given time (in seconds), keeping the interval length."""
        half = self.half_duration
        new_interval = (time - half, time + half)
        self.set_interval(new_interval)

    def shift(self, delay):
        """Move the interval by `delay` seconds (a negative delay moves it backwards)."""
        target = self.time + delay
        self.go_to(target)

    def go_to_start(self):
        """Jump to the beginning of the recording (time 0)."""
        self.go_to(0)

    def go_to_end(self):
        """Jump to the end of the recording (time `self.duration`)."""
        self.go_to(self.duration)

    def go_right(self):
        """Move the interval forward by a tenth of its length."""
        start, end = self._interval
        self.shift((end - start) * .1)

    def go_left(self):
        """Move the interval backward by a tenth of its length."""
        start, end = self._interval
        step = (end - start) * .1
        self.shift(-step)

    def _jump_to_spike(self, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        spike_times = self.get_spike_times()
        if spike_times is not None and len(spike_times):
            ind = np.searchsorted(spike_times, self.time)
            n = len(spike_times)
            self.go_to(spike_times[(ind + delta) % n])

    def go_to_next_spike(self):
        """Jump forward to the next spike (see `_jump_to_spike`)."""
        self._jump_to_spike(+1)

    def go_to_previous_spike(self):
        """Jump backward to the previous spike (see `_jump_to_spike`)."""
        self._jump_to_spike(-1)

    def toggle_highlighted_spikes(self, checked):
        """Choose between showing all spikes (`checked` is true) or only the selected ones,
        then refresh the current interval."""
        self.show_all_spikes = checked
        self.set_interval()

    def widen(self):
        """Enlarge the interval around its center by a factor `scaling_coeff_x`."""
        center = self.time
        half = self.half_duration * self.scaling_coeff_x
        self.set_interval((center - half, center + half))

    def narrow(self):
        """Shrink the interval around its center by a factor `scaling_coeff_x`."""
        center = self.time
        half = self.half_duration / self.scaling_coeff_x
        self.set_interval((center - half, center + half))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Enable or disable the channel labels, then refresh the current interval."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        self.set_interval()

    def toggle_auto_scale(self, checked):
        """Enable or disable automatic scaling of the traces.

        Only the flag is stored here; it is read by `set_interval()` when it computes the
        data bounds.
        """
        logger.debug("Set auto scale to %s.", checked)
        self.auto_scale = checked

    # Scaling
    # -------------------------------------------------------------------------

    def _apply_scaling(self):
        self.canvas.layout.scaling = (self.canvas.layout.scaling[0], self._scaling)

    @property
    def scaling(self):
        """Scaling of the channel boxes. Setting it immediately updates the layout."""
        return self._scaling

    @scaling.setter
    def scaling(self, value):
        self._scaling = value
        self._apply_scaling()

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value

    # Spike selection
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by Control-clicking on a spike.

        The spike closest in time to the mouse position is searched among the displayed
        spikes that include the clicked channel, and a `spike_click` event is emitted
        with its channel, spike and cluster ids.
        """
        if 'Control' not in e.modifiers:
            return

        # Channel corresponding to the clicked box.
        box_id, _ = self.canvas.stacked.box_map(e.pos)
        channel_id = self._permute_channels(box_id, inv=True)
        bounds = self.data_bounds

        # Displayed spikes that have a waveform on the clicked channel.
        candidates = [
            (t, s, c, ch) for t, s, c, ch in self._waveform_times if channel_id in ch]
        if not candidates:
            return

        # Time coordinate of the mouse position: window -> NDC -> data coordinates.
        ndc_pos = self.canvas.panzoom.window_to_ndc(e.pos)
        mouse_time = Range(NDC, bounds).apply(ndc_pos)[0][0]

        # Closest spike in time.
        times, spike_ids, spike_clusters, channel_ids = zip(*candidates)
        closest = np.argmin(np.abs(np.array(times) - mouse_time))

        emit('spike_click', self, channel_id=channel_id,
             spike_id=spike_ids[closest], cluster_id=spike_clusters[closest])
Ejemplo n.º 22
0
class WaveformClusteringView(LassoMixin, ManualClusteringView):
    """Show every waveform of the first selected cluster, one channel at a time.

    The displayed channel is changed with `setChannel`, `setChannelIdx`,
    `setNextChannelIdx` and `setPrevChannelIdx`. Spikes can be lassoed and removed from
    the cluster with `on_request_split`.

    Constructor
    -----------

    model : object
        Data source. It must provide `get_cluster_channels(cluster_id)`,
        `get_cluster_spikes(cluster_id)` and `get_waveforms(spike_ids, channel_ids=...)`,
        the latter returning an array that is unpacked as
        `(n_spikes, n_samples, n_channels)`.

    """

    default_shortcuts = {
        'next_channel': 'f',
        'previous_channel': 'r',
    }

    def __init__(self, model=None):
        super(WaveformClusteringView, self).__init__()
        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))

        self.model = model

        # Factor applied to the waveforms before display (the y axis is labelled `[uV]`).
        self.gain = 0.195
        # Sampling rate in kHz, so that the x axis is in milliseconds.
        self.Fs = 30

        self.visual = PlotVisual()

        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.97, .95)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.01, 0)

        # Selection state, filled by on_select().
        self.cluster_ids = None
        self.cluster_id = None
        self.channel_ids = None
        self.spike_ids = None
        self.wavefs = None
        self.current_channel_idx = None

    # Internal helpers
    # -------------------------------------------------------------------------

    def _time_axis(self, n_time):
        """Return the x coordinates of a waveform of `n_time` samples, centered on 0."""
        half = n_time / 2 / self.Fs
        return np.linspace(-half, half, n_time)

    @staticmethod
    def _quantile(data, q, method):
        """Return the quantile `q` of `data` along its first axis.

        Recent NumPy versions name the estimation keyword `method`; older ones only
        know the deprecated `interpolation` keyword.
        """
        try:
            return np.quantile(data, q, axis=0, method=method)
        except TypeError:
            return np.quantile(data, q, axis=0, interpolation=method)

    # Selection and plotting
    # -------------------------------------------------------------------------

    def on_select(self, cluster_ids=(), **kwargs):
        """Load and display the waveforms of the first selected cluster.

        The data are only reloaded when the first selected cluster changes; the first
        channel of the cluster is then displayed.
        """
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return

        if self.cluster_id != cluster_ids[0]:
            self.cluster_id = cluster_ids[0]

            self.channel_ids = self.model.get_cluster_channels(self.cluster_id)
            self.spike_ids = self.model.get_cluster_spikes(self.cluster_id)
            self.wavefs = self.model.get_waveforms(
                self.spike_ids, channel_ids=self.channel_ids)

            self.setChannelIdx(0)

    def plotWaveforms(self):
        """Redraw the waveforms of the current cluster on the current channel.

        With more than 100 spikes, the median, the median +/- 3 standard deviations and
        the 1% / 99% quantiles are drawn on top of the waveforms. Nothing is drawn
        before a cluster and a channel have been selected.
        """
        if self.wavefs is None or self.current_channel_idx is None:
            return

        n_spikes, n_time, _ = self.wavefs.shape
        # Waveforms on the displayed channel: (n_spikes, n_time).
        chan_wavefs = self.wavefs[:, :, self.current_channel_idx]

        self.visual.reset_batch()
        self.text_visual.reset_batch()

        t = self._time_axis(n_time)
        x = np.tile(t, (n_spikes, 1))

        # Symmetric y range, rounded to a multiple of 10, 100 or 1000 depending on the
        # scaled peak amplitude.
        peak = np.max(np.abs(chan_wavefs)) * self.gain
        if peak < 100:
            M = 10 * np.ceil(peak / 10)
        elif peak < 1000:
            M = 100 * np.ceil(peak / 100)
        else:
            # NOTE(review): this branch rounds down, unlike the two above, so the largest
            # samples can fall outside the bounds -- confirm that this is intended.
            M = 1000 * np.floor(peak / 1000)
        self.data_bounds = (t[0], -M, t[-1], M)

        color_wavef = selected_cluster_color(0)
        color_median = selected_cluster_color(3)
        color_std = (0, 1, 0, 1)
        color_qtl = (1, 1, 0, 1)

        show_stats = n_spikes > 100

        self.visual.add_batch_data(
            x=x, y=self.gain * chan_wavefs, color=color_wavef,
            data_bounds=self.data_bounds, box_index=0)

        if show_stats:
            median = np.median(chan_wavefs, axis=0)
            std = np.std(chan_wavefs, axis=0)
            q01 = self._quantile(chan_wavefs, .01, 'higher')
            q99 = self._quantile(chan_wavefs, .99, 'lower')
            for y, color in (
                    (median, color_median),
                    (median + 3 * std, color_std),
                    (median - 3 * std, color_std),
                    (q01, color_qtl),
                    (q99, color_qtl)):
                self.visual.add_batch_data(
                    x=t, y=self.gain * y, color=color,
                    data_bounds=self.data_bounds, box_index=0)

        # Axis units.
        self.text_visual.add_batch_data(
            pos=[.9, .98],
            text='[uV]',
            anchor=[-1, -1],
            box_index=0,
        )
        self.text_visual.add_batch_data(
            pos=[-1, -.95],
            text='[ms]',
            anchor=[1, 1],
            box_index=0,
        )

        # Label of the displayed channel.
        label = 'Ch {a}'.format(a=self.channel_ids[self.current_channel_idx])
        self.text_visual.add_batch_data(
            pos=[-.98, .98],
            text=str(label),
            anchor=[1, -1],
            box_index=0,
        )
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.text_visual)
        self.canvas.axes.reset_data_bounds(self.data_bounds)

        self.canvas.update()

    # Channel selection
    # -------------------------------------------------------------------------

    def setChannel(self, channel_id):
        """Display the channel `channel_id`, if it belongs to the current cluster."""
        if self.channel_ids is None:
            return
        # np.asarray makes the comparison elementwise even if channel_ids is a list.
        itemindex = np.where(np.asarray(self.channel_ids) == channel_id)[0]
        if len(itemindex):
            self.setChannelIdx(itemindex[0])

    def setChannelIdx(self, channel_idx):
        """Display the channel at position `channel_idx` in `self.channel_ids`."""
        self.current_channel_idx = channel_idx
        self.plotWaveforms()

    def setNextChannelIdx(self):
        """Display the next channel of the cluster, if there is one."""
        if self.current_channel_idx is None or self.channel_ids is None:
            return
        if self.current_channel_idx >= len(self.channel_ids) - 1:
            return
        self.setChannelIdx(self.current_channel_idx + 1)

    def setPrevChannelIdx(self):
        """Display the previous channel of the cluster, if there is one."""
        if self.current_channel_idx is None or self.current_channel_idx <= 0:
            return
        self.setChannelIdx(self.current_channel_idx - 1)

    # Split
    # -------------------------------------------------------------------------

    def on_request_split(self, sender=None):
        """Return the spikes of the cluster that are NOT touched by the lasso.

        A spike is removed as soon as one sample of its waveform on the displayed
        channel falls inside the lasso polygon. The spikes that are kept are returned,
        so that the cluster resulting from the split is the one being cleaned. An empty
        `int64` array is returned when there is nothing to split.
        """
        empty = np.array([], dtype=np.int64)
        if (self.canvas.lasso.count < 3 or self.cluster_ids is None or
                not len(self.cluster_ids) or self.wavefs is None):  # pragma: no cover
            return empty

        spike_ids = np.asarray(self.spike_ids)
        if not len(spike_ids):  # pragma: no cover
            logger.warning("Empty lasso.")
            return empty

        # One (x, y) point per waveform sample, spike after spike.
        n_time = self.wavefs.shape[1]
        x = self._time_axis(n_time)
        y = self.gain * self.wavefs[:len(spike_ids), :, self.current_channel_idx]
        pos = np.c_[np.tile(x, y.shape[0]), y.ravel()]
        pos = range_transform([self.data_bounds], [NDC], pos)
        # Spike id of every point.
        point_spike_ids = np.repeat(spike_ids, n_time)

        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()

        # Keep the spikes that have no point inside the lasso.
        spikes_to_remove = np.unique(point_spike_ids[ind])
        keep = np.isin(spike_ids, spikes_to_remove, assume_unique=True, invert=True)
        kept = spike_ids[keep]
        return kept if len(kept) else empty
Ejemplo n.º 23
0
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    ax_color = (.75, .75, .75, 1.)
    tick_size = 5.
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'previous_waveforms_type': 'shift+w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self,
                 waveforms=None,
                 waveforms_type=None,
                 sample_rate=None,
                 **kwargs):
        """Create the view, its waveform types and its visuals.

        `waveforms` is either a dictionary `{waveforms_type: function}` or a single
        function, which is then registered under the key `waveforms`. `sample_rate` must
        be strictly positive.
        """
        # These attributes are set before the base constructor is called.
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        assert sample_rate > 0., "The sample rate must be provided to the waveform view."

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Boxed layout, starting with a single box at the origin.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Always work with a dictionary of waveform functions.
        waveforms = waveforms or {}
        if not isinstance(waveforms, dict):
            waveforms = {'waveforms': waveforms}
        self.waveforms = waveforms

        # Rotating property holding the waveform types, set to the requested one.
        self.waveforms_types = RotatingProperty()
        for type_name, waveform_fun in self.waveforms.items():
            self.waveforms_types.add(type_name, waveform_fun)
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        # Channel labels.
        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        # Waveform axes.
        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        # Ticks on the waveform axes.
        self.tick_visual = UniformScatterVisual(
            marker='vbar', color=self.ax_color, size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two waveform visuals: an aggregated plot visual, used for every waveform type
        # except `waveforms`, and a plain plot visual for the `waveforms` type.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    @property
    def _current_visual(self):
        if self.waveforms_type == 'waveforms':
            return self.waveform_visual
        else:
            return self.waveform_agg_visual

    def _get_data_bounds(self, bunchs):
        """Return `[-1, -M, +1, M]`, where `M` is the largest absolute value found in the
        `data` of all bunchs (the y range is symmetric around 0)."""
        lowest = min(_min(bunch.data) for bunch in bunchs)
        highest = max(_max(bunch.data) for bunch in bunchs)
        bound = max(abs(lowest), abs(highest))
        return [-1, -bound, +1, bound]

    def get_clusters_data(self):
        """Return one waveform bunch per selected cluster, for the current waveforms type.

        Every bunch is annotated with its `index` in the selection, its `offset`, the
        number of offsets `n_clu`, and its `color`. Return `None` when the current
        waveforms type is not available.
        """
        if self.waveforms_type not in self.waveforms:
            return
        bunchs = [
            self.waveforms_types.get()(cluster_id) for cluster_id in self.cluster_ids]

        offsets = _get_clu_offsets(bunchs)
        n_offsets = max(offsets) + 1
        for index, (offset, bunch) in enumerate(zip(offsets, bunchs)):
            bunch.index = index
            # Offset depending on the overlap.
            bunch.offset = offset
            bunch.n_clu = n_offsets
            alpha = bunch.get('alpha', .75)
            bunch.color = selected_cluster_color(index, alpha)
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster, and their axes, to the visual batches.

        `bunch` is one of the items returned by `get_clusters_data()`. Its `data` is a
        `(n_spikes, n_samples, n_channels)` array and its `channel_ids` gives the channel
        of every item of the last dimension. Nothing is added when there is no data.

        Three batches are filled: the current waveform visual, the line visual
        (horizontal y=0 lines) and the tick visual (one tick every millisecond).
        `self.data_bounds` must have been set beforehand.
        """
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        # Without masks in the bunch, use a mask of 1 for every spike and channel.
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        # NOTE(review): presumably one row of x values per (spike, channel) waveform,
        # then moved according to the cluster offset -- confirm against the helpers.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        # Same channel sequence for every spike.
        box_index = np.tile(box_index, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual == self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (
                n_samples + 2)

        # Generate the waveform array.
        # (n_spikes, n_samples, n_channels) -> (n_spikes * n_channels, n_samples), in the
        # same spike-major order as the box index above.
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines.
        ax_db = self.data_bounds
        # x coordinates of the two ends of the line, for this cluster's offset.
        a, b = _overlap_transform(np.array([-1, 1]),
                                  offset=bunch.offset,
                                  n=bunch.n_clu,
                                  overlap=self.overlap)
        # Two vertices per line, one line per (spike, channel) waveform.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, 2)
        box_index = np.tile(box_index, n_spikes_clu)
        hpos = np.tile([[a, 0, b, 0]], (nw, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        # `self.wave_duration` is in seconds (it is set in `plot()`).
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # Same ticks on every channel of the cluster.
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, x.size // len(box_index))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x,
            y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        # Add channel labels.
        if not self.do_show_labels:
            return
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)

    def plot(self, **kwargs):
        """Redraw the whole view for the current cluster selection.

        Nothing is done when no cluster is selected or when no data is available for the
        current waveforms type.
        """
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()
        if not bunchs:
            return

        # Sorted union of the channels of all selected clusters.
        channel_ids = sorted(set(_flatten([bunch.channel_ids for bunch in bunchs])))
        self.channel_ids = channel_ids

        # Waveform duration in seconds, from the number of samples of the first bunch.
        first_data = bunchs[0].data
        if first_data is not None:
            self.wave_duration = first_data.shape[1] / float(
                self.sample_rate)
        else:  # pragma: no cover
            self.wave_duration = 1.

        # Label of every channel; the channel id itself when the bunch has no labels.
        channel_labels = {}
        for bunch in bunchs:
            labels = bunch.get('channel_labels', ['%d' % ch for ch in bunch.channel_ids])
            channel_labels.update({
                channel_id: labels[i]
                for i, channel_id in enumerate(bunch.channel_ids)
            })

        # Move the boxes to the positions of the selected channels.
        if channel_ids:
            self.canvas.boxed.update_boxes(_get_box_pos(bunchs, channel_ids))

        # The data bounds are computed once and then kept.
        self.data_bounds = self.data_bounds or self._get_data_bounds(bunchs)

        # Rebuild the batches of the waveforms and of their axes.
        self._current_visual.reset_batch()
        self.line_visual.reset_batch()
        self.tick_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        for visual in (self.tick_visual, self.line_visual, self._current_visual):
            self.canvas.update_visual(visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)

        # Show the waveform visual in use and hide the other one.
        current = self._current_visual
        if current == self.waveform_visual:
            self.waveform_visual.show()
            self.waveform_agg_visual.hide()
        elif current == self.waveform_agg_visual:
            self.waveform_agg_visual.show()
            self.waveform_visual.hide()

        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI and register the view actions."""
        super(WaveformView, self).attach(gui)

        add = self.actions.add
        separator = self.actions.separator

        # Display options and waveform types.
        add(self.toggle_waveform_overlap, checkable=True, checked=self.overlap)
        add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        add(self.next_waveforms_type)
        add(self.previous_waveforms_type)
        add(self.toggle_mean_waveforms, checkable=True)
        separator()

        # Box scaling, then horizontal and vertical probe scaling: one separator after
        # each pair of actions.
        for grow, shrink in (
                (self.widen, self.narrow),
                (self.extend_horizontally, self.shrink_horizontally),
                (self.extend_vertically, self.shrink_vertically)):
            add(grow)
            add(shrink)
            separator()

    @property
    def boxed(self):
        """The boxed layout of the canvas (`self.canvas.boxed`)."""
        return self.canvas.boxed

    @property
    def status(self):
        """Status of the view: the name of the current waveforms type."""
        return self.waveforms_type

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether the waveforms of different clusters are drawn on top of each other.
        Setting this property replots the view."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Enable or disable the overlap of the waveforms of different clusters."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def widen(self):
        """Make the waveform boxes wider."""
        self.boxed.expand_box_width()

    def narrow(self):
        """Make the waveform boxes narrower."""
        self.boxed.shrink_box_width()

    @property
    def box_scaling(self):
        """Scaling of the boxes, stored in the `_box_scaling` attribute of the layout."""
        return self.boxed._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        self.boxed._box_scaling = value

    def _get_scaling_value(self):
        return self.boxed._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self.boxed._box_scaling
        self.boxed._box_scaling = (w, value)
        self.boxed.update()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Scaling of the probe, stored in the `_layout_scaling` attribute of the
        layout."""
        return self.boxed._layout_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        self.boxed._layout_scaling = value

    def extend_horizontally(self):
        """Expand the width of the probe layout."""
        self.boxed.expand_layout_width()

    def shrink_horizontally(self):
        """Shrink the width of the probe layout."""
        self.boxed.shrink_layout_width()

    def extend_vertically(self):
        """Expand the height of the probe layout."""
        self.boxed.expand_layout_height()

    def shrink_vertically(self):
        """Shrink the height of the probe layout."""
        self.boxed.shrink_layout_height()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show or hide the text visual (channel labels) and refresh the canvas.

        The `checked` value is also stored in `self.do_show_labels`.

        """
        self.do_show_labels = checked
        # Explicit branch: `show()` and `hide()` are called for their side
        # effects only, so a conditional expression is not appropriate here.
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Emit a `select_channel` event for the box that was clicked.

        The click is handled only if the `Control` modifier is active or if the
        event key is a single digit (`0` to `9`); any other click is ignored.
        The digit, when there is one, is forwarded as an integer `key`
        (`None` otherwise), along with the mouse button.

        """
        button = e.button
        digits = tuple(str(i) for i in range(10))
        if 'Control' not in e.modifiers and e.key not in digits:
            return
        pressed_digit = int(e.key) if e.key in digits else None
        # `box_map` gives the index of the box under the mouse position (the
        # second returned value is not needed here).
        box_idx, _ = self.canvas.boxed.box_map(e.pos)
        channel_id = self.channel_ids[box_idx]
        logger.debug("Click on channel_id %d with key %s and button %s.",
                     channel_id, pressed_digit, button)
        emit('select_channel', self, channel_id=channel_id, key=pressed_digit, button=button)

    @property
    def waveforms_type(self):
        """Currently selected waveforms type (the `current` item of `waveforms_types`)."""
        selector = self.waveforms_types
        return selector.current

    @waveforms_type.setter
    def waveforms_type(self, value):
        # Delegate to `waveforms_types.set()`; no replot is triggered here.
        selector = self.waveforms_types
        selector.set(value)

    def next_waveforms_type(self):
        """Select the next waveforms type and replot the view."""
        selector = self.waveforms_types
        selector.next()
        # Log the type that is selected after the switch.
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def previous_waveforms_type(self):
        """Select the previous waveforms type and replot the view."""
        selector = self.waveforms_types
        selector.previous()
        # Log the type that is selected after the switch.
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Toggle between the `mean_waveforms` and `waveforms` types, when available.

        The `checked` argument is not used: the target type only depends on the
        current type and on the keys present in `self.waveforms`. Nothing
        happens if no suitable type is available.

        """
        if self.waveforms_type == 'mean_waveforms' and 'waveforms' in self.waveforms:
            target, message = 'waveforms', "Switch to raw waveforms."
        elif 'mean_waveforms' in self.waveforms:
            target, message = 'mean_waveforms', "Switch to mean waveforms."
        else:
            # Neither switch is possible: leave the view untouched.
            return
        self.waveforms_types.set(target)
        logger.debug(message)
        self.plot()
Ejemplo n.º 24
0
    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None, n_channels=None,
            channel_vertical_order=None, channel_labels=None, **kwargs):
        """Validate the inputs, initialize the parent view, and set up the canvas and visuals.

        Every parameter defaults to `None`, but the checks below fail when
        `traces`, `sample_rate`, `duration` or `n_channels` is left to `None`.

        Parameters
        ----------
        traces : callable
            Stored as `self.traces`; only checked for being callable here.
        sample_rate : number
            Must be strictly positive. Stored as a float; `self.dt` is its inverse.
        spike_times : object
            Stored unchanged as `self.get_spike_times` (presumably a function
            returning spike times -- to be confirmed against the callers).
        duration : number
            Must be non-negative. The view initially goes to `duration / 2`.
        n_channels : int
            Must be non-negative. Used as the number of plots of the stacked layout.
        channel_vertical_order : array, optional
            Must have a `shape` attribute equal to `(n_channels,)`.
            Defaults to `np.arange(n_channels)`.
        channel_labels : sequence, optional
            Must have `n_channels` items. Defaults to the channel indices
            formatted as strings.
        **kwargs
            Forwarded to the initializer of the parent class.

        """

        # Initial values of the display options and of the scaling.
        self.do_show_labels = True
        self.show_all_spikes = False
        self._scaling = 1.

        self.get_spike_times = spike_times

        # NOTE: the validation below relies on `assert` statements, which are
        # skipped when Python runs with the `-O` option.

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces
        self.waveforms = None

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel permutation.
        # The argsort of the vertical order is what is eventually stored (when the
        # order is a permutation of the channel indices, this is its inverse).
        self._channel_perm = (
            np.arange(n_channels) if channel_vertical_order is None else channel_vertical_order)
        assert self._channel_perm.shape == (n_channels,)
        self._channel_perm = np.argsort(self._channel_perm)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Box and probe scaling.
        self._origin = None

        # Initialize the view.
        # NOTE(review): the attributes above are set before the parent initializer
        # runs; presumably it relies on some of them -- confirm before reordering.
        super(TraceView, self).__init__(**kwargs)
        # Add this view's attribute names to the ones declared by the parent class.
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        # Stacked layout with one plot per channel; the y axis is not shown.
        self.canvas.set_layout('stacked', origin=self.origin, n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        # Visuals.
        self.trace_visual = UniformPlotVisual()
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        # NOTE(review): presumably keeps the x coordinate of the text fixed -- confirm
        # against the definition of `_fix_coordinate_in_visual`.
        _fix_coordinate_in_visual(self.text_visual, 'x')
        self.canvas.add_visual(self.text_visual)

        # Make a copy of the initial box size of the stacked layout. We'll apply
        # the scaling to this quantity.
        self.box_size = np.array(self.canvas.stacked.box_size)

        # Initial interval.
        self._interval = None
        # Go to the middle of the recording.
        self.go_to(duration / 2.)

        self._waveform_times = []