Пример #1
0
    def __init__(self, features=None, attributes=None, **kwargs):
        """Set up the grid layout, the lasso and the visuals of the feature view.

        Parameters
        ----------

        features : function
            Feature loader kept on the view as `self.features`. It must be truthy.
        attributes : dict
            Optional dictionary `{name: array}` of extra features, where each array is a
            `(n_spikes,)` array. An empty dictionary is used when nothing is passed.
        **kwargs
            Forwarded unchanged to the parent class constructor.

        """
        super(FeatureView, self).__init__(**kwargs)
        self.state_attrs += ('fixed_channels', 'feature_scaling')

        assert features
        self._lim = 1
        self.features = features

        # The default grid is a 2D array where every item is a string like `0A,1B`.
        default_grid = _get_default_grid()
        self.grid_dim = default_grid
        n_rows, n_cols = np.array(default_grid).shape
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.canvas.set_layout('grid', shape=(n_rows, n_cols))
        self.canvas.enable_lasso()

        # No channel is shown until the view is plotted.
        self.channel_ids = None

        # Extra features, as a dictionary {name: array} of `(n_spikes,)` arrays.
        self.attributes = attributes if attributes else {}

        # Create every visual and register it on the canvas, in drawing order.
        for attr_name, visual_cls in (
                ('visual', ScatterVisual),
                ('text_visual', TextVisual),
                ('line_visual', LineVisual)):
            new_visual = visual_cls()
            setattr(self, attr_name, new_visual)
            self.canvas.add_visual(new_visual)
Пример #2
0
    def __init__(self,
                 waveforms=None,
                 waveforms_type=None,
                 sample_rate=None,
                 **kwargs):
        """Create the waveform view, its layout and its visuals.

        Parameters
        ----------

        waveforms : dict of functions, or single function
            Waveform loaders indexed by waveforms type. A single non-dict value is wrapped
            into `{'waveforms': value}`.
        waveforms_type : str
            Key of `waveforms` to select initially.
        sample_rate : float
            Sampling rate. It is mandatory and must be strictly positive.
        **kwargs
            Forwarded unchanged to the parent class constructor.

        Raises
        ------

        AssertionError
            If `sample_rate` is missing or not strictly positive, or if the current
            waveforms type is not a key of `waveforms`.

        """
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # `sample_rate` defaults to None: test it explicitly, otherwise `None > 0.` raises
        # a bare TypeError on Python 3 instead of showing the message below.
        assert sample_rate is not None and sample_rate > 0., (
            "The sample rate must be provided to the waveform view.")

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        if not isinstance(waveforms, dict):
            waveforms = {'waveforms': waveforms}
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(marker='vbar',
                                                color=self.ax_color,
                                                size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)
Пример #3
0
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.
    sample_rate : float
        Sampling rate, mandatory and strictly positive. The duration of a waveform is
        computed as `n_samples / sample_rate`.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    ax_color = (.75, .75, .75, 1.)
    tick_size = 5.
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'previous_waveforms_type': 'shift+w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self,
                 waveforms=None,
                 waveforms_type=None,
                 sample_rate=None,
                 **kwargs):
        """Create the view, its boxed layout and its visuals.

        A single non-dict `waveforms` value is wrapped into `{'waveforms': value}`.
        Extra keyword arguments are forwarded unchanged to the parent class constructor.

        Raises an `AssertionError` if `sample_rate` is missing or not strictly positive, or
        if the current waveforms type is not a key of `waveforms`.

        """
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # `sample_rate` defaults to None: test it explicitly, otherwise `None > 0.` raises
        # a bare TypeError on Python 3 instead of showing the message below.
        assert sample_rate is not None and sample_rate > 0., (
            "The sample rate must be provided to the waveform view.")

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        if not isinstance(waveforms, dict):
            waveforms = {'waveforms': waveforms}
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(marker='vbar',
                                                color=self.ax_color,
                                                size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    @property
    def _current_visual(self):
        """Visual matching the current waveforms type.

        The thin `PlotVisual` is used for the `waveforms` type, and the aggregated
        `PlotAggVisual` for every other type.

        """
        if self.waveforms_type == 'waveforms':
            return self.waveform_visual
        else:
            return self.waveform_agg_visual

    def _get_data_bounds(self, bunchs):
        """Return data bounds `[-1, -M, +1, M]` that are symmetric on the y axis.

        `M` is the largest absolute value among the extrema returned by `_min` and `_max`
        on the `data` of every bunch.

        """
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        # Symmetrize on the y axis.
        M = max(abs(m), abs(M))
        return [-1, -M, +1, M]

    def get_clusters_data(self):
        """Load one bunch per selected cluster with the current waveforms function.

        Every bunch is annotated with its `index` in the selection, its `offset`, the total
        number of offsets `n_clu`, and its `color`. Returns `None` if the current waveforms
        type is not available.

        """
        if self.waveforms_type not in self.waveforms:
            return
        bunchs = [
            self.waveforms_types.get()(cluster_id)
            for cluster_id in self.cluster_ids
        ]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms, the y=0 lines and the time ticks of one cluster to the batches.

        Nothing is added if the bunch has no data. `self.channel_ids`, `self.data_bounds`
        and `self.wave_duration` must have been set beforehand (this is done in `plot()`).

        """
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        # Without masks, all channels are considered unmasked.
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.tile(box_index, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual == self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (
                n_samples + 2)

        # Generate the waveform array: one row per (spike, channel) pair.
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines.
        ax_db = self.data_bounds
        a, b = _overlap_transform(np.array([-1, 1]),
                                  offset=bunch.offset,
                                  n=bunch.n_clu,
                                  overlap=self.overlap)
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        # Two vertices per line segment.
        box_index = np.repeat(box_index, 2)
        box_index = np.tile(box_index, n_spikes_clu)
        hpos = np.tile([[a, 0, b, 0]], (nw, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, x.size // len(box_index))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x,
            y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Rebuild the channel labels, one per box, unless labels are disabled.

        `channel_labels` maps a channel id to its label. `n_clusters` is not used.

        """
        # Add channel labels.
        if not self.do_show_labels:
            return
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()
        if not bunchs:
            return

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids
        if bunchs[0].data is not None:
            self.wave_duration = bunchs[0].data.shape[1] / float(
                self.sample_rate)
        else:  # pragma: no cover
            self.wave_duration = 1.

        # Channel labels: default to the channel id when a bunch provides no labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i]
                for i, channel_id in enumerate(d.channel_ids)
            })

        # Update the Boxed box positions as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.update_boxes(_get_box_pos(bunchs, channel_ids))

        # The data bounds are computed once and then kept across plots.
        self.data_bounds = self.data_bounds or self._get_data_bounds(bunchs)

        self._current_visual.reset_batch()
        self.line_visual.reset_batch()
        self.tick_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.tick_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self._current_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)

        # Only show the current waveform visual.
        if self._current_visual == self.waveform_visual:
            self.waveform_visual.show()
            self.waveform_agg_visual.hide()
        elif self._current_visual == self.waveform_agg_visual:
            self.waveform_agg_visual.show()
            self.waveform_visual.hide()

        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap,
                         checkable=True,
                         checked=self.overlap)
        self.actions.add(self.toggle_show_labels,
                         checkable=True,
                         checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.previous_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    @property
    def status(self):
        """Name of the current waveforms type."""
        return self.waveforms_type

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        # Changing the overlap triggers a replot.
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        self.boxed.expand_box_width()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        self.boxed.shrink_box_width()

    @property
    def box_scaling(self):
        """Box scaling of the boxed layout."""
        return self.boxed._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        self.boxed._box_scaling = value

    def _get_scaling_value(self):
        """Return the second (vertical) component of the box scaling."""
        return self.boxed._box_scaling[1]

    def _set_scaling_value(self, value):
        """Set the second (vertical) component of the box scaling and update the layout."""
        w, h = self.boxed._box_scaling
        self.boxed._box_scaling = (w, value)
        self.boxed.update()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Layout scaling of the boxed layout."""
        return self.boxed._layout_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        self.boxed._layout_scaling = value

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        self.boxed.expand_layout_width()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        self.boxed.shrink_layout_width()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        self.boxed.expand_layout_height()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        self.boxed.shrink_layout_height()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        # NOTE(review): labels are only rebuilt in `_plot_labels()` while `do_show_labels`
        # is True, so labels shown again here may predate the last plot -- confirm whether
        # a replot is wanted.
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view."""
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # `channel_ids` is None until the first plot: there is nothing to select yet.
            if self.channel_ids is None or not len(self.channel_ids):
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.",
                         channel_id, key, b)
            emit('select_channel',
                 self,
                 channel_id=channel_id,
                 key=key,
                 button=b)

    @property
    def waveforms_type(self):
        """Name of the current waveforms type."""
        return self.waveforms_types.current

    @waveforms_type.setter
    def waveforms_type(self, value):
        self.waveforms_types.set(value)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        self.waveforms_types.next()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def previous_waveforms_type(self):
        """Switch to the previous waveforms type."""
        self.waveforms_types.previous()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available."""
        # `checked` is not used: the target depends only on the current waveforms type.
        if self.waveforms_type == 'mean_waveforms' and 'waveforms' in self.waveforms:
            self.waveforms_types.set('waveforms')
            logger.debug("Switch to raw waveforms.")
            self.plot()
        elif 'mean_waveforms' in self.waveforms:
            self.waveforms_types.set('mean_waveforms')
            logger.debug("Switch to mean waveforms.")
            self.plot()
Пример #4
0
class CorrelogramView(ScalingMixin, ManualClusteringView):
    """A view showing the autocorrelogram of the selected clusters, and all cross-correlograms
    of cluster pairs.

    Constructor
    -----------

    correlograms : function
        Maps `(cluster_ids, bin_size, window_size)` to an `(n_clusters, n_clusters, n_bins) array`.

    firing_rate : function
        Maps `(cluster_ids, bin_size)` to an `(n_clusters, n_clusters) array`

    sample_rate : float
        Sampling rate, mandatory and strictly positive.

    """

    # Do not show too many clusters.
    max_n_clusters = 20

    _default_position = 'left'
    cluster_ids = ()

    # Bin size, in seconds.
    bin_size = 1e-3

    # Window size, in seconds.
    window_size = 50e-3

    # Refractory period, in seconds
    refractory_period = 2e-3

    # Whether all subplots share the same normalization (global maximum), or every subplot
    # is normalized by its own maximum.
    uniform_normalization = False

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
        'change_bin_size': 'alt+wheel',
    }

    default_snippets = {
        'set_bin': 'cb',
        'set_window': 'cw',
        'set_refractory_period': 'cr',
    }

    def __init__(self,
                 correlograms=None,
                 firing_rate=None,
                 sample_rate=None,
                 **kwargs):
        """Create the view, its grid layout and its visuals.

        Extra keyword arguments are forwarded unchanged to the parent class constructor.
        Raises an `AssertionError` if `sample_rate` is missing or not strictly positive.

        """
        super(CorrelogramView, self).__init__(**kwargs)
        self.state_attrs += ('bin_size', 'window_size', 'refractory_period',
                             'uniform_normalization')
        self.local_state_attrs += ()
        self.canvas.set_layout(layout='grid')

        # Outside margin to show labels.
        self.canvas.gpu_transforms.add(Scale(.9))

        # `sample_rate` defaults to None: test it explicitly, otherwise `None > 0` raises
        # a bare TypeError on Python 3 instead of an assertion error.
        assert sample_rate is not None and sample_rate > 0
        self.sample_rate = float(sample_rate)

        # Function clusters => CCGs.
        self.correlograms = correlograms

        # Function clusters => firing rates (same unit as CCG).
        self.firing_rate = firing_rate

        # Set the default bin and window size.
        self._set_bin_window(bin_size=self.bin_size,
                             window_size=self.window_size)

        self.correlogram_visual = HistogramVisual()
        self.canvas.add_visual(self.correlogram_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)

    # -------------------------------------------------------------------------
    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self, n_clusters):
        """Yield the `(i, j)` index of every subplot, row by row."""
        for i in range(n_clusters):
            for j in range(n_clusters):
                yield i, j

    def get_clusters_data(self, load_all=None):
        """Return one bunch per pair of selected clusters.

        Every bunch has the attributes `correlogram`, `firing_rate` (None when no firing
        rate function was given), `data_bounds`, `pair_index` and `color`. `load_all` is
        not used.

        """
        ccg = self.correlograms(self.cluster_ids, self.bin_size,
                                self.window_size)
        fr = self.firing_rate(self.cluster_ids,
                              self.bin_size) if self.firing_rate else None
        assert ccg.ndim == 3
        n_bins = ccg.shape[2]
        bunchs = []
        # Global maximum, used for all subplots with the uniform normalization.
        m = ccg.max()
        for i, j in self._iter_subplots(len(self.cluster_ids)):
            b = Bunch()
            b.correlogram = ccg[i, j, :]
            if not self.uniform_normalization:
                # Every subplot is normalized by its own maximum.
                m = ccg[i, j, :].max()
            b.firing_rate = fr[i, j] if fr is not None else None
            b.data_bounds = (0, 0, n_bins, m)
            b.pair_index = i, j
            b.color = selected_cluster_color(i, 1)
            # Cross-correlograms get a different color than autocorrelograms.
            if i != j:
                b.color = add_alpha(_override_hsv(b.color[:3], s=.1, v=1))
            bunchs.append(b)
        return bunchs

    def _plot_pair(self, bunch):
        """Add the histogram, firing rate line and refractory period lines of one pair."""
        # Plot the histogram.
        self.correlogram_visual.add_batch_data(hist=bunch.correlogram,
                                               color=bunch.color,
                                               ylim=bunch.data_bounds[3],
                                               box_index=bunch.pair_index)

        # Plot the firing rate.
        gray = (.25, .25, .25, 1.)
        if bunch.firing_rate is not None:
            # Line.
            pos = np.array([[
                0, bunch.firing_rate, bunch.data_bounds[2], bunch.firing_rate
            ]])
            self.line_visual.add_batch_data(pos=pos,
                                            color=gray,
                                            data_bounds=bunch.data_bounds,
                                            box_index=bunch.pair_index)
            # # Text.
            # self.text_visual.add_batch_data(
            #     pos=[bunch.data_bounds[2], bunch.firing_rate],
            #     text='%.2f' % bunch.firing_rate,
            #     anchor=(-1, 0),
            #     box_index=bunch.pair_index,
            #     data_bounds=bunch.data_bounds,
            # )

        # Refractory period: two vertical lines around the center of the window,
        # with positions expressed in number of bins.
        xrp0 = round(
            (self.window_size * .5 - self.refractory_period) / self.bin_size)
        xrp1 = round((self.window_size * .5 + self.refractory_period) /
                     self.bin_size) + 1
        ylim = bunch.data_bounds[3]
        pos = np.array([[xrp0, 0, xrp0, ylim], [xrp1, 0, xrp1, ylim]])
        self.line_visual.add_batch_data(pos=pos,
                                        color=gray,
                                        data_bounds=bunch.data_bounds,
                                        box_index=bunch.pair_index)

    def _plot_labels(self):
        """Add the cluster ids next to the first column and the last row of the grid."""
        n = len(self.cluster_ids)

        # Display the cluster ids in the subplots.
        for k in range(n):
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(self.cluster_ids[k]),
                anchor=[-1.25, 0],
                data_bounds=None,
                box_index=(k, 0),
            )
            self.text_visual.add_batch_data(
                pos=[0, -1],
                text=str(self.cluster_ids[k]),
                anchor=[0, -1.25],
                data_bounds=None,
                box_index=(n - 1, k),
            )

        # # Display the window size in the bottom right subplot.
        # self.text_visual.add_batch_data(
        #     pos=[1, -1],
        #     anchor=[1.25, 1],
        #     text='%.1f ms' % (1000 * .5 * self.window_size),
        #     box_index=(n - 1, n - 1),
        # )

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        n_clusters = len(self.cluster_ids)
        # Nothing to plot without a selection. This method is also called when a parameter
        # changes, possibly before any cluster is selected, and taking the maximum of an
        # empty correlogram array would raise an error.
        if not n_clusters:
            return
        self.canvas.grid.shape = (n_clusters, n_clusters)

        bunchs = self.get_clusters_data()

        self.correlogram_visual.reset_batch()
        self.line_visual.reset_batch()
        self.text_visual.reset_batch()

        for bunch in bunchs:
            self._plot_pair(bunch)
        self._plot_labels()

        self.canvas.update_visual(self.correlogram_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self.text_visual)

        self.canvas.update()

    # -------------------------------------------------------------------------
    # Public methods
    # -------------------------------------------------------------------------

    def toggle_normalization(self, checked):
        """Change the normalization of the correlograms."""
        self.uniform_normalization = checked
        self.plot()

    def toggle_labels(self, checked):
        """Show or hide all labels."""
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(CorrelogramView, self).attach(gui)

        self.actions.add(self.toggle_normalization,
                         shortcut='n',
                         checkable=True)
        self.actions.add(self.toggle_labels, checkable=True, checked=True)
        self.actions.separator()

        # The prompt defaults are the current values, converted to milliseconds.
        self.actions.add(self.set_bin,
                         prompt=True,
                         prompt_default=lambda: self.bin_size * 1000)
        self.actions.add(self.set_window,
                         prompt=True,
                         prompt_default=lambda: self.window_size * 1000)
        self.actions.add(self.set_refractory_period,
                         prompt=True,
                         prompt_default=lambda: self.refractory_period * 1000)
        self.actions.separator()

    # -------------------------------------------------------------------------
    # Methods for changing the parameters
    # -------------------------------------------------------------------------

    def _set_bin_window(self, bin_size=None, window_size=None):
        """Set the bin and window sizes (in seconds)."""
        # A missing (or zero) value keeps the current one.
        bin_size = bin_size or self.bin_size
        window_size = window_size or self.window_size
        bin_size = _clip(bin_size, 1e-6, 1e3)
        window_size = _clip(window_size, 1e-6, 1e3)
        assert 1e-6 <= bin_size <= 1e3
        assert 1e-6 <= window_size <= 1e3
        assert bin_size < window_size
        self.bin_size = bin_size
        self.window_size = window_size
        self.update_status()

    @property
    def status(self):
        """Window size followed by the bin size in parentheses, both in milliseconds."""
        b, w = self.bin_size * 1000, self.window_size * 1000
        return '{:.1f} ms ({:.1f} ms)'.format(w, b)

    def set_refractory_period(self, value):
        """Set the refractory period (in milliseconds)."""
        self.refractory_period = _clip(value, .1, 100) * 1e-3
        self.plot()

    def set_bin(self, bin_size):
        """Set the correlogram bin size (in milliseconds).

        Example: `1`

        """
        self._set_bin_window(bin_size=bin_size * 1e-3)
        self.plot()

    def set_window(self, window_size):
        """Set the correlogram window size (in milliseconds).

        Example: `100`

        """
        self._set_bin_window(window_size=window_size * 1e-3)
        self.plot()

    def increase(self):
        """Increase the window size."""
        self.set_window(1000 * self.window_size * 1.1)

    def decrease(self):
        """Decrease the window size."""
        self.set_window(1000 * self.window_size / 1.1)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Change the scaling with the wheel."""
        super(CorrelogramView, self).on_mouse_wheel(e)
        if e.modifiers == ('Alt', ):
            self._set_bin_window(bin_size=self.bin_size * 1.1**e.delta)
            self.plot()
Пример #5
0
class FeatureView(MarkerSizeMixin, ScalingMixin, ManualClusteringView):
    """This view displays a 4x4 subplot matrix with different projections of the principal
    component features. This view keeps track of which channels are currently shown.

    Constructor
    -----------

    features : function
        Maps `(cluster_id, channel_ids=None, load_all=False)` to
        `Bunch(data, channel_ids, channel_labels, spike_ids, masks)`.
        * `data` is an `(n_spikes, n_channels, n_features)` array
        * `channel_ids` contains the channel ids of every row in `data`
        * `channel_labels` contains the channel labels of every row in `data`
        * `spike_ids` is a `(n_spikes,)` array
        * `masks` is an `(n_spikes, n_channels)` array

        This allows for a sparse format.

    attributes : dict
        Maps an attribute name (for example, `time`) to a function called as
        `f(cluster_id, load_all=...)`. The view reads `.data` on the returned object, and
        optionally its `lim` item (a `(vmin, vmax)` pair) for the axis bounds.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    cluster_ids = ()

    # Whether to disable automatic selection of channels.
    fixed_channels = False
    feature_scaling = 1.

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'increase': 'ctrl++',
        'decrease': 'ctrl+-',
        'add_lasso_point': 'ctrl+click',
        'stop_lasso': 'ctrl+right click',
        'toggle_automatic_channel_selection': 'c',
    }

    def __init__(self, features=None, attributes=None, **kwargs):
        """Create the grid layout, the lasso, and the scatter/text/line visuals."""
        super(FeatureView, self).__init__(**kwargs)
        self.state_attrs += ('fixed_channels', 'feature_scaling')

        assert features
        self.features = features
        self._lim = 1

        self.grid_dim = _get_default_grid()  # 2D array where every item a string like `0A,1B`
        self.n_rows, self.n_cols = np.array(self.grid_dim).shape
        self.canvas.set_layout('grid', shape=(self.n_rows, self.n_cols))
        self.canvas.enable_lasso()

        # Channels being shown.
        self.channel_ids = None

        # Attributes: extra features. This is a dictionary
        # {name: function}
        # where each function is called as `f(cluster_id, load_all=...)` (see `_get_axis_data()`).
        self.attributes = attributes or {}

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

    def set_grid_dim(self, grid_dim):
        """Change the grid dim dynamically.

        Parameters
        ----------
        grid_dim : array-like (2D)
            `grid_dim[row, col]` is a string with two values separated by a comma. Each value
            is the relative channel id (0, 1, 2...) followed by the PC (A, B, C...). For example,
            `grid_dim[row, col] = 0B,1A`. Each value can also be an attribute name, for example
            `time`. For example, `grid_dim[row, col] = time,2C`.

        """
        self.grid_dim = grid_dim
        self.n_rows, self.n_cols = np.array(grid_dim).shape
        self.canvas.grid.shape = (self.n_rows, self.n_cols)

    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self):
        """Yield `(i, j, dim_x, dim_y)` for every subplot of the grid."""
        for i in range(self.n_rows):
            for j in range(self.n_cols):
                dim = self.grid_dim[i][j]
                dim_x, dim_y = dim.split(',')
                yield i, j, dim_x, dim_y

    def _get_axis_label(self, dim):
        """Return the channel label from a dimension, if applicable."""
        if str(dim[:-1]).isdecimal():
            # `dim` is a relative channel index followed by a PC letter, e.g. `0A`.
            n = len(self.channel_ids)
            channel_id = self.channel_ids[int(dim[:-1]) % n]
            return self.channel_labels[channel_id] + dim[-1]
        else:
            return dim

    def _get_channel_and_pc(self, dim):
        """Return the channel_id and PC of a dim, or None if no channels are selected yet."""
        if self.channel_ids is None:
            return
        assert dim not in self.attributes  # This is called only on PC data.
        s = 'ABCDEFGHIJ'
        # Channel relative index, typically just 0 or 1.
        c_rel = int(dim[:-1])
        # Get the channel_id from the currently-selected channels.
        channel_id = self.channel_ids[c_rel % len(self.channel_ids)]
        pc = s.index(dim[-1])
        return channel_id, pc

    def _get_axis_data(self, bunch, dim, cluster_id=None, load_all=None):
        """Extract the points from the data on a given dimension.

        bunch is returned by the features() function.
        dim is the string specifying the dimensions to extract for the data.

        """
        if dim in self.attributes:
            return self.attributes[dim](cluster_id, load_all=load_all)
        masks = bunch.get('masks', None)
        channel_id, pc = self._get_channel_and_pc(dim)
        # Skip the plot if the channel id is not displayed.
        if channel_id not in bunch.channel_ids:  # pragma: no cover
            return Bunch(data=np.zeros((bunch.data.shape[0],)))
        # Get the column index of the current channel in data.
        c = list(bunch.channel_ids).index(channel_id)
        if masks is not None:
            masks = masks[:, c]
        return Bunch(data=self.feature_scaling * bunch.data[:, c, pc], masks=masks)

    def _get_axis_bounds(self, dim, bunch):
        """Return the min/max of an axis."""
        if dim in self.attributes:
            # Attribute: use the `lim` item of the bunch, `(0, 0)` if there is none.
            vmin, vmax = bunch.get('lim', (0, 0))
            assert vmin is not None
            assert vmax is not None
            return vmin, vmax
        return (-self._lim, +self._lim)

    def _plot_points(self, bunch, clu_idx=None):
        """Add the points and axis labels of one bunch to the batch visuals, for all subplots.

        `clu_idx` is the index of the bunch among the plotted clusters, or None for the
        background spikes.

        """
        if not bunch:
            return
        if clu_idx is None:
            cluster_id = None
        else:
            # Prefer the id stored on the bunch by `get_clusters_data()`: `clu_idx` indexes the
            # non-empty bunches only, so it may not line up with `self.cluster_ids`.
            cluster_id = bunch.get('cluster_id', self.cluster_ids[clu_idx])
        for i, j, dim_x, dim_y in self._iter_subplots():
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id)
            # Skip empty data.
            if px is None or py is None:  # pragma: no cover
                # NOTE: `%s` because cluster_id is None for the background spikes.
                logger.warning("Skipping empty data for cluster %s.", cluster_id)
                return
            assert px.data.shape == py.data.shape
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            assert xmin <= xmax
            assert ymin <= ymax
            data_bounds = (xmin, ymin, xmax, ymax)
            masks = _get_masks_max(px, py)
            # Prepare the batch visual with all subplots
            # for the selected cluster.
            self.visual.add_batch_data(
                x=px.data, y=py.data,
                color=_get_point_color(clu_idx),
                # Reduced marker size for background features
                size=self._marker_size,
                masks=_get_point_masks(clu_idx=clu_idx, masks=masks),
                data_bounds=data_bounds,
                box_index=(i, j),
            )
            # Get the channel ids corresponding to the relative channel indices
            # specified in the dimensions. Channel 0 corresponds to the first
            # best channel for the selected cluster, and so on.
            label_x = self._get_axis_label(dim_x)
            label_y = self._get_axis_label(dim_y)
            # Add labels.
            self.text_visual.add_batch_data(
                pos=[1, 1],
                anchor=[-1, -1],
                text=label_y,
                data_bounds=None,
                box_index=(i, j),
            )
            self.text_visual.add_batch_data(
                pos=[0, -1.],
                anchor=[0, 1],
                text=label_x,
                data_bounds=None,
                box_index=(i, j),
            )

    def _plot_axes(self):
        """Draw one horizontal and one vertical line through the origin of every subplot."""
        self.line_visual.reset_batch()
        for i, j, dim_x, dim_y in self._iter_subplots():
            self.line_visual.add_batch_data(
                pos=[[-1., 0., +1., 0.],
                     [0., -1., 0., +1.]],
                color=(.5, .5, .5, .5),
                box_index=(i, j),
                data_bounds=None,
            )
        self.canvas.update_visual(self.line_visual)

    def _get_lim(self, bunchs):
        """Return the largest absolute value found in the data of the bunches (1 if none)."""
        if not bunchs:  # pragma: no cover
            return 1
        m, M = min(bunch.data.min() for bunch in bunchs), max(bunch.data.max() for bunch in bunchs)
        M = max(abs(m), abs(M))
        return M

    def _get_scaling_value(self):
        """Return the current feature scaling."""
        return self.feature_scaling

    def _set_scaling_value(self, value):
        """Set the feature scaling and replot while keeping the current channels."""
        self.feature_scaling = value
        self.plot(fixed_channels=True)

    # Public methods
    # -------------------------------------------------------------------------

    def clear_channels(self):
        """Reset the current channels."""
        self.channel_ids = None
        self.plot()

    def get_clusters_data(self, fixed_channels=None, load_all=None):
        """Load the features of the selected clusters and update the displayed channels.

        Parameters
        ----------
        fixed_channels : bool
            When true, request the features on the currently-displayed channels and keep them
            (unless no channels are selected yet). Otherwise, the channels are chosen
            automatically from the returned data.
        load_all : bool
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        bunchs : list
            The non-empty bunches returned by `features()`, each with a `cluster_id` attribute.
            Also sets `self.channel_ids` and `self.channel_labels`.

        """
        # Get the feature data.
        # Specify the channel ids if these are fixed, otherwise
        # choose the first cluster's best channels.
        c = self.channel_ids if fixed_channels else None
        bunchs = []
        for cluster_id in self.cluster_ids:
            bunch = self.features(cluster_id, channel_ids=c)
            if not bunch:
                continue
            # Tag the bunch while the cluster id is still known: pairing ids with the filtered
            # list afterwards would shift them whenever an earlier cluster had no data.
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        if not bunchs:  # pragma: no cover
            return []

        # Choose the channels based on the first selected cluster.
        channel_ids = list(bunchs[0].get('channel_ids', []))
        common_channels = list(channel_ids)
        # Intersection (with order kept) of channels belonging to all clusters.
        for bunch in bunchs:
            common_channels = [ch for ch in bunch.get('channel_ids', []) if ch in common_channels]
        # The selected channels will be (1) the channels common to all clusters, followed
        # by (2) remaining channels from the first cluster (excluding those already selected
        # in (1)).
        n = len(channel_ids)
        not_common_channels = [ch for ch in channel_ids if ch not in common_channels]
        channel_ids = common_channels + not_common_channels[:n - len(common_channels)]
        assert len(channel_ids) > 0

        # Choose the channels automatically unless fixed_channels is set.
        if (not fixed_channels or self.channel_ids is None):
            self.channel_ids = channel_ids
        assert len(self.channel_ids)

        # Channel labels.
        self.channel_labels = {}
        for d in bunchs:
            # Default labels are the channel ids formatted as strings.
            chl = d.get('channel_labels', ['%d' % ch for ch in d.get('channel_ids', [])])
            self.channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.get('channel_ids', []))})

        return bunchs

    def plot(self, **kwargs):
        """Update the view with the selected clusters."""

        # Determine whether the channels should be fixed or not.
        added = kwargs.get('up', {}).get('added', None)
        # Fix the channels if the view updates after a cluster event
        # and there are new clusters.
        fixed_channels = (
            self.fixed_channels or kwargs.get('fixed_channels', None) or added is not None)

        # Get the clusters data.
        bunchs = self.get_clusters_data(fixed_channels=fixed_channels)
        bunchs = [b for b in bunchs if b]
        if not bunchs:
            return
        self._lim = self._get_lim(bunchs)

        # Get the background data.
        background = self.features(channel_ids=self.channel_ids)

        # Plot all features.
        self._plot_axes()

        # NOTE: the columns in bunch.data are ordered by decreasing quality
        # of the associated channels. The channels corresponding to each
        # column are given in bunch.channel_ids in the same order.

        # Plot points.
        self.visual.reset_batch()
        self.text_visual.reset_batch()

        self._plot_points(background)  # background spikes

        # Plot each cluster.
        for clu_idx, bunch in enumerate(bunchs):
            self._plot_points(bunch, clu_idx=clu_idx)

        # Upload the data on the GPU.
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.text_visual)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(FeatureView, self).attach(gui)

        self.actions.add(
            self.toggle_automatic_channel_selection,
            checked=not self.fixed_channels, checkable=True)
        self.actions.add(self.clear_channels)
        self.actions.separator()

    def toggle_automatic_channel_selection(self, checked):
        """Toggle the automatic selection of channels when the cluster selection changes."""
        self.fixed_channels = not checked

    @property
    def status(self):
        """Status string with the labels of the first two displayed channels."""
        if self.channel_ids is None:  # pragma: no cover
            return ''
        channel_labels = [self.channel_labels[ch] for ch in self.channel_ids[:2]]
        return 'channels: %s' % ', '.join(channel_labels)

    # Dimension selection
    # -------------------------------------------------------------------------

    def on_select_channel(self, sender=None, channel_id=None, key=None, button=None):
        """Respond to the click on a channel from another view, and update the
        relevant subplots.

        `key` is the number of the pressed key (1, 2, etc.) selecting which displayed channel
        to replace. When it is None, `button` is used instead: `'Left'` replaces the first
        channel, anything else the second one.

        """
        channels = self.channel_ids
        if channels is None:
            return
        if len(channels) == 1:
            self.plot()
            return
        assert len(channels) >= 2
        # Get the axis from the pressed button (1, 2, etc.)
        if key is not None:
            # Clip the 0-based index to the valid range. The arguments used to be swapped,
            # which yielded a negative index for `key=0`.
            d = int(np.clip(key - 1, 0, len(channels) - 1))
        else:
            d = 0 if button == 'Left' else 1
        # Change the first or second best channel.
        old = channels[d]
        # Avoid updating the view if the channel doesn't change.
        if channel_id == old:
            return
        channels[d] = channel_id
        # Ensure that the first two channels are different.
        if channels[1 - min(d, 1)] == channel_id:
            channels[1 - min(d, 1)] = old
        assert channels[0] != channels[1]
        # Remove duplicate channels.
        self.channel_ids = _uniq(channels)
        logger.debug("Choose channels %d and %d in feature view.", *channels[:2])
        # Fix the channels temporarily.
        self.plot(fixed_channels=True)
        self.update_status()

    def on_mouse_click(self, e):
        """Select a feature dimension by clicking on a box in the feature view."""
        b = e.button
        if 'Alt' in e.modifiers:
            # Find the subplot under the mouse.
            (i, j), _ = self.canvas.grid.box_map(e.pos)
            dim = self.grid_dim[i][j]
            dim_x, dim_y = dim.split(',')
            # Left button selects the x dimension, any other button the y dimension.
            dim = dim_x if b == 'Left' else dim_y
            other_dim = dim_y if b == 'Left' else dim_x
            if dim not in self.attributes:
                # When a regular (channel, PC) dimension is selected.
                channel_pc = self._get_channel_and_pc(dim)
                if channel_pc is None:
                    return
                channel_id, pc = channel_pc
                logger.debug("Click on feature dim %s, channel id %s, PC %s.", dim, channel_id, pc)
            else:
                # When the selected dimension is an attribute, e.g. "time".
                pc = None
                # Take the channel id in the other dimension.
                channel_pc = self._get_channel_and_pc(other_dim)
                channel_id = channel_pc[0] if channel_pc is not None else None
                logger.debug("Click on feature dim %s.", dim)
            emit('select_feature', self, dim=dim, channel_id=channel_id, pc=pc)

    def on_request_split(self, sender=None):
        """Return the spikes enclosed by the lasso."""
        if (self.canvas.lasso.count < 3 or not len(self.cluster_ids)):  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # Get the dimensions of the lassoed subplot.
        i, j = self.canvas.layout.active_box
        dim = self.grid_dim[i][j]
        dim_x, dim_y = dim.split(',')

        # Get all points from all clusters.
        pos = []
        spike_ids = []

        for cluster_id in self.cluster_ids:
            # Load all spikes.
            bunch = self.features(cluster_id, channel_ids=self.channel_ids, load_all=True)
            if not bunch:
                continue
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id, load_all=True)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id, load_all=True)
            points = np.c_[px.data, py.data]

            # Normalize the points.
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            r = Range((xmin, ymin, xmax, ymax))
            points = r.apply(points)

            pos.append(points)
            spike_ids.append(bunch.spike_ids)

        # No cluster had any data: `np.vstack()` would raise on an empty list.
        if not pos:  # pragma: no cover
            return np.array([], dtype=np.int64)
        pos = np.vstack(pos)
        spike_ids = np.concatenate(spike_ids)

        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()
        return np.unique(spike_ids[ind])
Пример #6
0
        def on_view_attached(view, gui):
            """Add event markers (vertical lines and text labels) to an attached AmplitudeView.

            Event times are read from the first column of `eventmarkers.txt` in
            `controller.dir_path`. Integer values are divided by the model's sample rate;
            other values are used as they are. Optional labels are read from the first column
            of `eventmarkernames.txt`; events without a label are numbered from 1.

            Two actions are added to the view ('Toggle event markers' and 'Go to event'). They
            stay disabled if the event marker file cannot be read. Views of any other type are
            ignored.

            """
            if not isinstance(view, AmplitudeView):
                return

            # Create batch of vertical lines (full height)
            self.line_visual = LineVisual()
            _fix_coordinate_in_visual(self.line_visual, 'y')
            view.canvas.add_visual(self.line_visual)

            # Create batch of annotative text
            self.text_visual = TextVisual(self.line_color)
            _fix_coordinate_in_visual(self.text_visual, 'y')
            self.text_visual.inserter.insert_vert(
                'gl_Position.x += 0.001;', 'after_transforms')
            view.canvas.add_visual(self.text_visual)

            @view.actions.add(shortcut='alt+b',
                              checkable=True,
                              name='Toggle event markers')
            def toggle(on):
                """Toggle event markers"""
                # Use `show` and `hide` instead of `toggle` here in
                # case synchronization issues
                if on:
                    logger.debug('Toggle on markers.')
                    self.line_visual.show()
                    self.text_visual.show()
                    view.show_events = True
                else:
                    logger.debug('Toggle off markers.')
                    self.line_visual.hide()
                    self.text_visual.hide()
                    view.show_events = False
                view.canvas.update()

            @view.actions.add(shortcut='shift+alt+e',
                              prompt=True,
                              name='Go to event',
                              alias='ge')
            def Go_to_event(event_num):
                """Move the trace view to the event with the given 1-based number."""
                # NOTE: `events` is assigned further down in the enclosing function; this
                # action is only enabled once the events have been loaded.
                trace_view = gui.get_view(TraceView)
                if 0 < event_num <= events.size:
                    trace_view.go_to(events[event_num - 1])

            # Disable the menu until events are successfully added
            view.actions.disable('Go to event')
            view.actions.disable('Toggle event markers')
            if not hasattr(view, 'show_events'):
                view.show_events = True
            view.state_attrs += ('show_events', )

            # Read event markers from file
            filename = controller.dir_path / 'eventmarkers.txt'
            try:
                events = np.genfromtxt(filename, usecols=0, dtype=None)
            except (FileNotFoundError, OSError):
                # `Logger.warn()` is deprecated (removed in Python 3.13).
                logger.warning('Event marker file not found: `%s`.', filename)
                view.show_events = False
                return
            # A file with a single event yields a 0-d array, which cannot be indexed.
            events = np.atleast_1d(events)

            # Create list of event names
            labels = list(map(str, range(1, events.size + 1)))

            # Read event names from file (if present)
            filename = controller.dir_path / 'eventmarkernames.txt'
            try:
                eventnames = np.loadtxt(filename,
                                        usecols=0,
                                        dtype=str,
                                        max_rows=events.size)
                labels[:eventnames.size] = np.atleast_1d(eventnames)
            except (FileNotFoundError, OSError):
                logger.info(
                    'Event marker names file not found (optional):'
                    ' `%s`. Fall back to numbering.', filename)

            # Obtain seconds from samples. Any integer dtype counts as samples, not only the
            # platform's default integer type.
            if np.issubdtype(events.dtype, np.integer):
                logger.debug('Converting input from samples to seconds.')
                events = events / controller.model.sample_rate

            logger.debug('Add event markers to amplitude view.')

            # Obtain horizontal positions: one `(x, 1, x, -1)` row per event.
            x = -1 + 2 * events / view.duration
            x = x.repeat(4, 0).reshape(-1, 4)
            x[:, 1::2] = 1, -1

            # Add lines and update view
            self.line_visual.reset_batch()
            self.line_visual.add_batch_data(pos=x, color=self.line_color)
            view.canvas.update_visual(self.line_visual)

            # Add text and update view
            self.text_visual.reset_batch()
            self.text_visual.add_batch_data(pos=x[:, :2],
                                            anchor=(1, -1),
                                            text=labels)
            view.canvas.update_visual(self.text_visual)

            # Finally enable the menu
            logger.debug('Enable menu items.')
            view.actions.enable('Go to event')
            view.actions.enable('Toggle event markers')
            if view.show_events:
                view.actions.get('Toggle event markers').toggle()
            else:
                self.line_visual.hide()
                self.text_visual.hide()
Пример #7
0
    def __init__(
            self, amplitudes=None, amplitudes_type=None, duration=None, path=None,
            sample_rate=None):
        """Create the amplitude view and its visuals.

        Parameters
        ----------
        amplitudes : dict or function
            Dictionary `{amplitudes_type: function}`. A single non-dict value is stored under
            the `amplitude` key.
        amplitudes_type : str
            Amplitudes type to select initially.
        duration : float
            Stored as `self.duration`; defaults to 1 when not given.
        path : str
            Directory searched for `blocksizes.npy` and `blockstarts.npy`. When it is None or
            `blocksizes.npy` does not exist, block lines are disabled.
        sample_rate : float
            Divides the block sizes and starts to obtain the `*_time` attributes. It must be
            given when the block files are present.

        """
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitudes_type',)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes

        # Rotating property amplitudes types.
        self.amplitudes_types = RotatingProperty()
        for name, value in self.amplitudes.items():
            self.amplitudes_types.add(name, value)
        # Current amplitudes type.
        self.amplitudes_types.set(amplitudes_type)
        assert self.amplitudes_type in self.amplitudes

        self.show_background_clusters = True

        self.cluster_ids = ()
        self.duration = duration or 1.

        # Histogram visual.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
            Rotate('cw'),
            Scale((1, -1)),
            Translate((2.05, 0)),
        ])
        self.canvas.add_visual(self.hist_visual)
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.75, 1)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.25, 0)

        # Yellow vertical bar showing the selected time interval.
        self.patch_visual = PatchVisual(primitive_type='triangle_fan')
        self.patch_visual.inserter.insert_vert('''
            const float MIN_INTERVAL_SIZE = 0.01;
            uniform float u_interval_size;
        ''', 'header')
        self.patch_visual.inserter.insert_vert('''
            gl_Position.y = pos_orig.y;

            // The following is used to ensure that (1) the bar width increases with the zoom level
            // but also (2) there is a minimum absolute width so that the bar remains visible
            // at low zoom levels.
            float w = max(MIN_INTERVAL_SIZE, u_interval_size * u_zoom.x);
            // HACK: the z coordinate is used to store 0 or 1, depending on whether the current
            // vertex is on the left or right edge of the bar.
            gl_Position.x += w * (-1 + 2 * int(a_position.z == 0));

        ''', 'after_transforms')
        self.canvas.add_visual(self.patch_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        # Optional block files. `path` defaults to None, which `op.join()` rejects with a
        # TypeError, so only build the file names when a directory was given.
        self.show_block_gap = False
        self.show_block_lines = False
        blocksizes_path = op.join(path, 'blocksizes.npy') if path is not None else None
        if blocksizes_path is not None and op.exists(blocksizes_path):
            # Only the first row of each file is used.
            self.blocksizes = np.load(blocksizes_path)[0]
            self.blockstarts = np.load(op.join(path, 'blockstarts.npy'))[0]
            # NOTE(review): presumably sizes/starts are in samples, hence the division by the
            # sample rate -- confirm against the code writing these files.
            self.blocksizes_time = self.blocksizes / sample_rate
            self.blockstarts_time = self.blockstarts / sample_rate
            # Difference between consecutive block starts, minus the size of the earlier block.
            self.gap = np.diff(self.blockstarts) - self.blocksizes[:-1]
            self.gap_time = np.diff(self.blockstarts_time) - self.blocksizes_time[:-1]
            self.show_block_lines = True

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.set_constrain_bounds((-2, -2, +2, +2))
Пример #8
0
class AmplitudeView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays an amplitude plot for all selected clusters.

    Constructor
    -----------

    amplitudes : dict
        Dictionary `{amplitudes_type: function}`, for different types of amplitudes.

        Each function maps `cluster_ids` to a list
        `[Bunch(amplitudes, spike_ids, spike_times), ...]` for each cluster.
        Use `cluster_id=None` for background amplitudes.

        A value that is not a dictionary is wrapped as `{'amplitude': value}`.

    amplitudes_type : str
        Key of `amplitudes` that is selected when the view is created.

    duration : float
        Upper bound of the time axis. 1 is used when no duration is given.

    path : str
        Directory in which `blocksizes.npy` and `blockstarts.npy` are looked up.
        When `blocksizes.npy` exists there, vertical lines are drawn at the cumulated
        block sizes every time the view is plotted.

    sample_rate : float
        The loaded block sizes and block starts are divided by this value
        (presumably to convert samples to seconds -- to be confirmed).

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'

    # Alpha channel of the markers in the scatter plot.
    marker_alpha = 1.
    # RGBA color of the bar showing the selected time interval.
    time_range_color = (1., 1., 0., .25)

    # Number of bins in the histogram.
    n_bins = 100

    # Alpha channel of the histogram in the background.
    histogram_alpha = .5

    # Quantile used for scaling of the amplitudes (less than 1 to avoid outliers).
    quantile = .99

    # Size of the histogram, between 0 and 1.
    histogram_scale = .25

    # NOTE(review): `select_time` is listed as `alt+click` here, but `on_mouse_click`
    # tests for the `Shift` modifier -- confirm which one is intended.
    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'next_amplitudes_type': 'a',
        'toggle_other_clusters': 'shift+a',
        'select_x_dim': 'shift+left click',
        'select_y_dim': 'shift+right click',
        'select_time': 'alt+click',
    }

    def __init__(self, amplitudes=None, amplitudes_type=None, duration=None, path=None,
                 sample_rate=None):
        """Create the view, its visuals, and load the optional block layout.

        Parameters
        ----------
        amplitudes : dict or callable
            Dictionary `{amplitudes_type: function}`. A non-dict value is wrapped as
            `{'amplitude': value}`.
        amplitudes_type : str
            Key of `amplitudes` to select initially.
        duration : float
            Upper bound of the time axis; 1 is used when it is not given.
        path : str
            Directory searched for `blocksizes.npy` and `blockstarts.npy`. When it is
            None, or the files are missing, block lines are disabled.
        sample_rate : float
            Divisor applied to the block sizes and starts (presumably samples to
            seconds -- to be confirmed). Block lines are disabled when it is missing.

        """
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitudes_type',)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes

        # Rotating property amplitudes types.
        self.amplitudes_types = RotatingProperty()
        for name, value in self.amplitudes.items():
            self.amplitudes_types.add(name, value)
        # Current amplitudes type.
        self.amplitudes_types.set(amplitudes_type)
        assert self.amplitudes_type in self.amplitudes

        self.show_background_clusters = True

        self.cluster_ids = ()
        self.duration = duration or 1.

        # Histogram visual.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
            Rotate('cw'),
            Scale((1, -1)),
            Translate((2.05, 0)),
        ])
        self.canvas.add_visual(self.hist_visual)
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.75, 1)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.25, 0)

        # Yellow vertical bar showing the selected time interval.
        self.patch_visual = PatchVisual(primitive_type='triangle_fan')
        self.patch_visual.inserter.insert_vert('''
            const float MIN_INTERVAL_SIZE = 0.01;
            uniform float u_interval_size;
        ''', 'header')
        self.patch_visual.inserter.insert_vert('''
            gl_Position.y = pos_orig.y;

            // The following is used to ensure that (1) the bar width increases with the zoom level
            // but also (2) there is a minimum absolute width so that the bar remains visible
            // at low zoom levels.
            float w = max(MIN_INTERVAL_SIZE, u_interval_size * u_zoom.x);
            // HACK: the z coordinate is used to store 0 or 1, depending on whether the current
            // vertex is on the left or right edge of the bar.
            gl_Position.x += w * (-1 + 2 * int(a_position.z == 0));

        ''', 'after_transforms')
        self.canvas.add_visual(self.patch_visual)

        # Vertical lines marking the block boundaries.
        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        # Block layout. The lines are only enabled when the directory, both files and
        # the sample rate are all available; `path` and `sample_rate` default to None,
        # which must not make the constructor fail.
        self.show_block_gap = False
        self.show_block_lines = False
        if path is not None and op.exists(op.join(path, 'blocksizes.npy')):
            blocksizes_path = op.join(path, 'blocksizes.npy')
            blockstarts_path = op.join(path, 'blockstarts.npy')
            if not sample_rate:
                logger.warning(
                    "Found %s but no sample rate was given: block lines are disabled.",
                    blocksizes_path)
            elif not op.exists(blockstarts_path):
                logger.warning(
                    "Found %s but %s is missing: block lines are disabled.",
                    blocksizes_path, blockstarts_path)
            else:
                # Only the first row of each file is used.
                self.blocksizes = np.load(blocksizes_path)[0]
                self.blockstarts = np.load(blockstarts_path)[0]
                self.blocksizes_time = self.blocksizes / sample_rate
                self.blockstarts_time = self.blockstarts / sample_rate
                # Distance between the end of a block and the start of the next one.
                self.gap = np.diff(self.blockstarts) - self.blocksizes[:-1]
                self.gap_time = np.diff(self.blockstarts_time) - self.blocksizes_time[:-1]
                self.show_block_lines = True

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.set_constrain_bounds((-2, -2, +2, +2))

    def _get_data_bounds(self, bunchs):
        """Compute the data bounds `(xmin, ymin, xmax, ymax)` of the scatter plot.

        The x axis spans `[0, self.duration]`. The y axis spans the smallest
        `1 - self.quantile` quantile (clipped so that it is at most 0) to the largest
        `self.quantile` quantile over all bunches with at least one amplitude.

        When there is no bunch, or every bunch is empty, `(0, 0, self.duration, 1)` is
        returned instead of failing on an empty `min()`/`max()`.

        """
        # Bunches without any amplitude do not contribute to the bounds.
        non_empty = [bunch.amplitudes for bunch in (bunchs or ()) if len(bunch.amplitudes)]
        if not non_empty:
            return (0, 0, self.duration, 1)
        m = min(np.quantile(amplitudes, 1 - self.quantile) for amplitudes in non_empty)
        m = min(0, m)  # ensure ymin <= 0
        M = max(np.quantile(amplitudes, self.quantile) for amplitudes in non_empty)
        return (0, m, self.duration, M)

    def _add_histograms(self, bunchs):
        """Attach a normalized amplitude histogram to every bunch and return the bunches.

        The histogram range is the vertical extent of `self.data_bounds`, which must
        therefore have been computed beforehand.

        """
        for item in bunchs:
            values = item.amplitudes
            # The amplitude axis is the y axis of the data bounds.
            y_min, y_max = self.data_bounds[1], self.data_bounds[3]
            item.histogram = _compute_histogram(
                values, x_min=y_min, x_max=y_max,
                n_bins=self.n_bins, normalize=True, ignore_zeros=True)
        return bunchs

    def show_time_range(self, interval=(0, 0)):
        """Display the vertical bar representing the time interval `(start, end)`."""
        t_start, t_end = interval

        # Interval edges, mapped from [0, duration] to [-1, +1].
        left = -1 + 2 * (t_start / self.duration)
        right = -1 + 2 * (t_end / self.duration)
        center = .5 * (left + right)

        # All four vertices sit on the center of the interval: the half-width is passed
        # as a uniform, and the depth value tells the two edges of the bar apart.
        corners = np.array([[center, y] for y in (-1, +1, +1, -1)])
        self.patch_visual.program['u_interval_size'] = .5 * (right - left)
        self.patch_visual.set_data(
            pos=corners, color=self.time_range_color, depth=[0, 0, 1, 1])
        self.canvas.update()

    def _plot_cluster(self, bunch):
        """Add the histogram and the scatter points of one bunch to the batches.

        Nothing is added when the histogram of the bunch is empty.

        """
        marker_size = self._marker_size
        histogram = bunch.histogram
        if len(histogram) == 0:
            return

        # Background histogram, drawn with a dedicated transparency.
        self.hist_visual.add_batch_data(
            hist=histogram,
            ylim=self._ylim,
            color=add_alpha(bunch.color, self.histogram_alpha),
        )

        # Amplitude of every spike as a function of time.
        self.visual.add_batch_data(
            pos=bunch.pos,
            color=bunch.color,
            size=marker_size,
            data_bounds=self.data_bounds,
        )

    def get_clusters_data(self, load_all=None):
        """Return the bunches of the selected clusters, with `pos`, `cluster_id`, `color`.

        Nothing is returned when no cluster is selected. Unless `load_all` is set or the
        background is hidden, a `None` cluster standing for the background spikes is
        requested first.

        """
        if len(self.cluster_ids) == 0:
            return

        # The background is not needed when splitting.
        with_background = not load_all and self.show_background_clusters
        cluster_ids = list(self.cluster_ids)
        if with_background:
            # The None cluster means background spikes.
            cluster_ids.insert(0, None)
        # The background entry must not consume a color index.
        color_shift = -1 if with_background else 0

        fetch = self.amplitudes[self.amplitudes_type]
        bunchs = fetch(cluster_ids, load_all=load_all) or ()

        background_color = (.5, .5, .5, .5)
        for index, (cluster_id, bunch) in enumerate(zip(cluster_ids, bunchs)):
            ids = _as_array(bunch.spike_ids)
            times = _as_array(bunch.spike_times)
            values = _as_array(bunch.amplitudes)
            assert ids.shape == times.shape == values.shape

            # The LassoMixin relies on bunch.pos: one (time, amplitude) row per spike.
            bunch.pos = np.c_[times, values]
            assert bunch.pos.ndim == 2

            bunch.cluster_id = cluster_id
            if cluster_id is None:
                bunch.color = background_color
            else:
                bunch.color = selected_cluster_color(index + color_shift, self.marker_alpha)
        return bunchs

    def plot(self, **kwargs):
        """Update the view with the current cluster selection.

        The keyword arguments are forwarded to `get_clusters_data()`. Nothing is drawn
        when it returns no bunch.

        """
        bunchs = self.get_clusters_data(**kwargs)
        if not bunchs:
            return
        self.data_bounds = self._get_data_bounds(bunchs)
        bunchs = self._add_histograms(bunchs)
        # Use the same scale for all histograms.
        self._ylim = max(bunch.histogram.max() for bunch in bunchs)

        self.visual.reset_batch()
        self.hist_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.hist_visual)

        if self.show_block_lines:
            # Start from an empty batch, like the other visuals: otherwise the lines of
            # the previous calls would pile up in the batch at every replot.
            self.line_visual.reset_batch()
            _, y_min, x_max, y_max = self.data_bounds
            # One vertical line at the end of every block but the last, plus one at
            # each end of the time axis.
            line_times = np.append(self.blocksizes_time[:-1].cumsum(), [0, x_max])
            n_lines = len(line_times)
            # Every row is a segment (x0, y0, x1, y1) spanning the full height.
            pos = np.column_stack((
                line_times, np.full(n_lines, y_min),
                line_times, np.full(n_lines, y_max),
            ))
            self.line_visual.add_batch_data(
                pos=pos,
                color=(1, .5, 0, 1.),
                data_bounds=self.data_bounds)
            self.canvas.update_visual(self.line_visual)

        self._update_axes()
        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI and register the amplitude-related actions."""
        super(AmplitudeView, self).attach(gui)

        def _switch_to(amplitudes_type):
            """Build a callback selecting the given amplitudes type and replotting."""
            def callback():
                self.amplitudes_type = amplitudes_type
                self.plot()
            return callback

        # One menu entry per amplitudes type.
        for amplitudes_type in self.amplitudes_types.keys():
            self.actions.add(
                _switch_to(amplitudes_type),
                show_shortcut=False,
                name='Change amplitudes type to %s' % amplitudes_type,
                view_submenu='Change amplitudes type',
            )

        # Actions marking the GUI as busy while they run.
        self.actions.add(self.next_amplitudes_type, set_busy=True)
        self.actions.add(self.previous_amplitudes_type, set_busy=True)
        self.actions.add(self.toggle_other_clusters, set_busy=True)

    @property
    def status(self):
        """Status of the view: the name of the current amplitudes type."""
        return self.amplitudes_type

    @property
    def amplitudes_type(self):
        """Currently selected amplitudes type, as held by `self.amplitudes_types`."""
        return self.amplitudes_types.current

    @amplitudes_type.setter
    def amplitudes_type(self, value):
        # Selecting a type only updates the rotating property; the caller replots.
        self.amplitudes_types.set(value)

    def next_amplitudes_type(self):
        """Select the following amplitudes type, then redraw the view."""
        types = self.amplitudes_types
        types.next()
        logger.debug("Switch to amplitudes type: %s.", types.current)
        self.plot()

    def toggle_other_clusters(self):
        """Show or hide the background spikes, then redraw the view."""
        self.show_background_clusters = not self.show_background_clusters
        # Log through the module logger, like the other actions, rather than stdout.
        logger.debug("Show background clusters: %s.", self.show_background_clusters)
        self.plot()

    def previous_amplitudes_type(self):
        """Select the preceding amplitudes type, then redraw the view."""
        types = self.amplitudes_types
        types.previous()
        logger.debug("Switch to amplitudes type: %s.", types.current)
        self.plot()

    def on_mouse_click(self, e):
        """Select a time from the amplitude view to display in the trace view.

        On a click with the Shift modifier, the horizontal position of the mouse is
        converted to data coordinates and sent with the `select_time` event. Nothing
        happens as long as no data bounds are available.

        """
        # NOTE(review): `default_shortcuts` advertises `alt+click` for `select_time`,
        # whereas Shift is tested here -- confirm which modifier is intended.
        if 'Shift' not in e.modifiers:
            return
        # The data bounds are only set by plot(): ignore clicks made before that.
        bounds = getattr(self, 'data_bounds', None)
        if bounds is None:
            return
        mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
        # Map the click from normalized device coordinates to data coordinates and
        # keep the x component, i.e. the time.
        time = Range(NDC, bounds).apply(mouse_pos)[0][0]
        emit('select_time', self, time)