class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.

    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.

    sample_rate : float
        Number of samples per second. Mandatory and strictly positive: it converts the number
        of samples of a waveform into a duration, used to place one tick per millisecond.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    ax_color = (.75, .75, .75, 1.)
    tick_size = 5.
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'previous_waveforms_type': 'shift+w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, sample_rate=None, **kwargs):
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # Check for None explicitly: in Python 3 `None > 0.` raises a TypeError, which would
        # hide the intended error message below.
        assert sample_rate is not None and sample_rate > 0., \
            "The sample rate must be provided to the waveform view."

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        waveforms = waveforms if isinstance(waveforms, dict) else {'waveforms': waveforms}
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(
            marker='vbar', color=self.ax_color, size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    @property
    def _current_visual(self):
        """Visual used for the current waveforms type: the raw line visual for `waveforms`,
        the agg visual for any other type."""
        if self.waveforms_type == 'waveforms':
            return self.waveform_visual
        else:
            return self.waveform_agg_visual

    def _get_data_bounds(self, bunchs):
        """Return `[-1, -M, +1, M]`, where `M` is the largest absolute value of the data.

        The bounds are symmetric on the y axis. When all the data is zero, `M` falls back to 1
        so that the bounds never have a zero height.
        """
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        # Symmetrize on the y axis.
        M = max(abs(m), abs(M))
        if M == 0:
            # Avoid degenerate (zero-height) data bounds.
            M = 1.
        return [-1, -M, +1, M]

    def get_clusters_data(self):
        """Return one Bunch per selected cluster, with plotting attributes added.

        Returns None if the current waveforms type is not available.
        """
        if self.waveforms_type not in self.waveforms:
            return
        bunchs = [
            self.waveforms_types.get()(cluster_id) for cluster_id in self.cluster_ids]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster, and its axes, to the visual batches."""
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.tile(box_index, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual == self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (n_samples + 2)

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(
            x=t, y=wave, color=bunch.color, masks=masks, box_index=box_index,
            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines.
        ax_db = self.data_bounds
        a, b = _overlap_transform(
            np.array([-1, 1]), offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, 2)
        box_index = np.tile(box_index, n_spikes_clu)
        hpos = np.tile([[a, 0, b, 0]], (nw, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, x.size // len(box_index))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x, y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Add one label per channel box.

        The labels are always regenerated for the current channels. Their visibility follows
        `do_show_labels`, so re-enabling the labels later never displays labels from a
        previous selection.
        """
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)
        if self.do_show_labels:
            self.text_visual.show()
        else:
            self.text_visual.hide()

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()
        if not bunchs:
            return

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids
        if bunchs[0].data is not None:
            self.wave_duration = bunchs[0].data.shape[1] / float(self.sample_rate)
        else:  # pragma: no cover
            self.wave_duration = 1.

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.channel_ids)})

        # Update the Boxed box positions as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.update_boxes(_get_box_pos(bunchs, channel_ids))

        # The data bounds are computed once and then kept across selections.
        self.data_bounds = self.data_bounds or self._get_data_bounds(bunchs)

        self._current_visual.reset_batch()
        self.line_visual.reset_batch()
        self.tick_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.tick_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self._current_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)

        # Only show the current waveform visual.
        if self._current_visual == self.waveform_visual:
            self.waveform_visual.show()
            self.waveform_agg_visual.hide()
        elif self._current_visual == self.waveform_agg_visual:
            self.waveform_agg_visual.show()
            self.waveform_visual.hide()

        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(
            self.toggle_waveform_overlap, checkable=True, checked=self.overlap)
        self.actions.add(
            self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.previous_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    @property
    def status(self):
        return self.waveforms_type

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        self.boxed.expand_box_width()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        self.boxed.shrink_box_width()

    @property
    def box_scaling(self):
        return self.boxed._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        self.boxed._box_scaling = value

    def _get_scaling_value(self):
        return self.boxed._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self.boxed._box_scaling
        self.boxed._box_scaling = (w, value)
        self.boxed.update()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        return self.boxed._layout_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        self.boxed._layout_scaling = value

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        self.boxed.expand_layout_width()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        self.boxed.shrink_layout_width()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        self.boxed.expand_layout_height()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        self.boxed.shrink_layout_height()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view."""
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # Nothing has been plotted yet, so there is no channel to select.
            if not self.channel_ids:
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug(
                "Click on channel_id %d with key %s and button %s.", channel_id, key, b)
            emit('select_channel', self, channel_id=channel_id, key=key, button=b)

    @property
    def waveforms_type(self):
        return self.waveforms_types.current

    @waveforms_type.setter
    def waveforms_type(self, value):
        self.waveforms_types.set(value)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        self.waveforms_types.next()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def previous_waveforms_type(self):
        """Switch to the previous waveforms type."""
        self.waveforms_types.previous()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available."""
        if self.waveforms_type == 'mean_waveforms' and 'waveforms' in self.waveforms:
            self.waveforms_types.set('waveforms')
            logger.debug("Switch to raw waveforms.")
            self.plot()
        elif 'mean_waveforms' in self.waveforms:
            self.waveforms_types.set('mean_waveforms')
            logger.debug("Switch to mean waveforms.")
            self.plot()
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.

    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.

    """

    _default_position = 'right'
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
        self.canvas.enable_axes()

        self._box_scaling = (1., 1.)
        self._probe_scaling = (1., 1.)

        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms if isinstance(waveforms, dict) else {'waveforms': waveforms}
        assert waveforms
        self.waveforms = waveforms

        self.waveforms_types = list(waveforms.keys())
        # Current waveforms type.
        self.waveforms_type = waveforms_type or self.waveforms_types[0]
        assert self.waveforms_type in waveforms
        assert 'waveforms' in waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Return `[-1, m, +1, M]` with the min and max of the data of all clusters."""
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        return [-1, m, +1, M]

    def get_clusters_data(self):
        """Return one Bunch per selected cluster, with plotting attributes added."""
        bunchs = [
            self.waveforms[self.waveforms_type](cluster_id) for cluster_id in self.cluster_ids]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster to the waveform visual batch."""
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        # Build new arrays instead of updating in place. The masks may be an array owned by
        # the Bunch returned by the waveforms function, so in-place updates would shift them
        # again on every replot. An in-place float multiplication also fails on integer masks.
        masks = masks * .99999
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks = masks + bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.tile(box_index, n_spikes_clu)
        assert box_index.shape == (n_spikes_clu * n_channels * n_samples,)

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

        self.waveform_visual.add_batch_data(
            x=t, y=wave, color=bunch.color, masks=masks, box_index=box_index,
            data_bounds=self.data_bounds)

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Add one label per channel box.

        The labels are always regenerated for the current channels. Their visibility follows
        `do_show_labels`, so re-enabling the labels later never displays labels from a
        previous selection.
        """
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)
        if self.do_show_labels:
            self.text_visual.show()
        else:
            self.text_visual.hide()

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.channel_ids)})

        # Update the box bounds as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.box_bounds = _get_box_bounds(bunchs, channel_ids)
        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        self.data_bounds = self._get_data_bounds(bunchs)

        self.waveform_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.waveform_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)
        self._update_axes(bunchs)
        self.canvas.update()

    def _update_axes(self, bunchs):
        """Update the axes."""
        # Update the axes data bounds.
        _, m, _, M = self.data_bounds
        # Waveform duration, scaled by overlap factor if needed.
        wave_dur = bunchs[0].get('waveform_duration', 1.)
        wave_dur /= .5 * (1 + _overlap_transform(
            1, n=len(self.cluster_ids), overlap=self.overlap))
        x1, y1 = range_transform(self.canvas.boxed.box_bounds[0], NDC, [wave_dur, M - m])
        axes_data_bounds = (0, 0, x1, y1)
        self.canvas.axes.reset_data_bounds(axes_data_bounds, do_update=True)

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(
            self.toggle_waveform_overlap, checkable=True, checked=self.overlap)
        self.actions.add(
            self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def _update_boxes(self):
        """Push the box positions (scaled by the probe scaling) and sizes to the layout."""
        self.canvas.boxed.update_boxes(self.box_pos * self.probe_scaling, self.box_size)

    def _apply_box_scaling(self):
        self.canvas.layout.scaling = self._box_scaling

    @property
    def box_scaling(self):
        """Scaling of the channel boxes."""
        return self._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        assert len(value) == 2
        self._box_scaling = value
        self._apply_box_scaling()

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w * self._scaling_param_increment, h)
        self._apply_box_scaling()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w / self._scaling_param_increment, h)
        self._apply_box_scaling()

    def _get_scaling_value(self):
        return self._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self._box_scaling
        self.box_scaling = (w, value)
        self._update_boxes()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Scaling of the entire probe."""
        return self._probe_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        assert len(value) == 2
        self._probe_scaling = value
        self._update_boxes()

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w * self._scaling_param_increment, h)
        self._update_boxes()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w / self._scaling_param_increment, h)
        self._update_boxes()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h * self._scaling_param_increment)
        self._update_boxes()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h / self._scaling_param_increment)
        self._update_boxes()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view."""
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # Nothing has been plotted yet, so there is no channel to select.
            if not self.channel_ids:
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug(
                "Click on channel_id %d with key %s and button %s.", channel_id, key, b)
            emit('channel_click', self, channel_id=channel_id, key=key, button=b)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        i = self.waveforms_types.index(self.waveforms_type)
        n = len(self.waveforms_types)
        self.waveforms_type = self.waveforms_types[(i + 1) % n]
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available."""
        if self.waveforms_type == 'mean_waveforms':
            self.waveforms_type = 'waveforms'
            self.plot()
        elif 'mean_waveforms' in self.waveforms_types:
            self.waveforms_type = 'mean_waveforms'
            self.plot()
class CorrelogramView(ScalingMixin, ManualClusteringView):
    """A view showing the autocorrelogram of the selected clusters, and all cross-correlograms
    of cluster pairs.

    Constructor
    -----------

    correlograms : function
        Maps `(cluster_ids, bin_size, window_size)` to an `(n_clusters, n_clusters, n_bins)`
        array.

    firing_rate : function
        Maps `(cluster_ids, bin_size)` to an `(n_clusters, n_clusters)` array.

    sample_rate : float
        Sampling rate. Mandatory and strictly positive.

    """

    # Do not show too many clusters.
    max_n_clusters = 20

    _default_position = 'left'
    cluster_ids = ()

    # Bin size, in seconds.
    bin_size = 1e-3

    # Window size, in seconds.
    window_size = 50e-3

    # Refractory period, in seconds.
    refractory_period = 2e-3

    # Whether all subplots share the same vertical scale (the maximum over all correlograms),
    # instead of each subplot being scaled to its own maximum.
    uniform_normalization = False

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
        'change_bin_size': 'alt+wheel',
    }

    default_snippets = {
        'set_bin': 'cb',
        'set_window': 'cw',
        'set_refractory_period': 'cr',
    }

    def __init__(self, correlograms=None, firing_rate=None, sample_rate=None, **kwargs):
        super(CorrelogramView, self).__init__(**kwargs)
        self.state_attrs += (
            'bin_size', 'window_size', 'refractory_period', 'uniform_normalization')
        self.local_state_attrs += ()
        self.canvas.set_layout(layout='grid')

        # Outside margin to show labels.
        self.canvas.gpu_transforms.add(Scale(.9))

        # Check for None explicitly: in Python 3 `None > 0` raises a TypeError rather than
        # failing the assertion.
        assert sample_rate is not None and sample_rate > 0
        self.sample_rate = float(sample_rate)

        # Function clusters => CCGs.
        self.correlograms = correlograms

        # Function clusters => firing rates (same unit as CCG).
        self.firing_rate = firing_rate

        # Set the default bin and window size.
        self._set_bin_window(bin_size=self.bin_size, window_size=self.window_size)

        self.correlogram_visual = HistogramVisual()
        self.canvas.add_visual(self.correlogram_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)

    # -------------------------------------------------------------------------
    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self, n_clusters):
        """Yield all `(i, j)` subplot indices of an `n_clusters x n_clusters` grid."""
        for i in range(n_clusters):
            for j in range(n_clusters):
                yield i, j

    def get_clusters_data(self, load_all=None):
        """Return one Bunch per subplot `(i, j)` with the correlogram and plotting info."""
        ccg = self.correlograms(self.cluster_ids, self.bin_size, self.window_size)
        fr = self.firing_rate(self.cluster_ids, self.bin_size) if self.firing_rate else None
        assert ccg.ndim == 3
        n_bins = ccg.shape[2]
        bunchs = []
        # Maximum over all correlograms, used when the normalization is uniform.
        m = ccg.max()
        for i, j in self._iter_subplots(len(self.cluster_ids)):
            b = Bunch()
            b.correlogram = ccg[i, j, :]
            if not self.uniform_normalization:
                # Each subplot is scaled to its own maximum.
                m = ccg[i, j, :].max()
            b.firing_rate = fr[i, j] if fr is not None else None
            # An empty correlogram has a zero maximum: use a unit height instead so that the
            # data bounds never collapse to zero height.
            b.data_bounds = (0, 0, n_bins, m if m > 0 else 1)
            b.pair_index = i, j
            b.color = selected_cluster_color(i, 1)
            if i != j:
                b.color = add_alpha(_override_hsv(b.color[:3], s=.1, v=1))
            bunchs.append(b)
        return bunchs

    def _plot_pair(self, bunch):
        """Add the histogram, firing-rate line and refractory-period lines of one subplot."""
        # Plot the histogram.
        self.correlogram_visual.add_batch_data(
            hist=bunch.correlogram, color=bunch.color,
            ylim=bunch.data_bounds[3], box_index=bunch.pair_index)

        # Plot the firing rate.
        gray = (.25, .25, .25, 1.)
        if bunch.firing_rate is not None:
            # Horizontal line at the firing rate, across the whole window.
            pos = np.array([[0, bunch.firing_rate, bunch.data_bounds[2], bunch.firing_rate]])
            self.line_visual.add_batch_data(
                pos=pos, color=gray, data_bounds=bunch.data_bounds,
                box_index=bunch.pair_index)

        # Refractory period: two vertical lines, in bin units, at +/- the refractory period
        # around the center of the window.
        xrp0 = round((self.window_size * .5 - self.refractory_period) / self.bin_size)
        xrp1 = round((self.window_size * .5 + self.refractory_period) / self.bin_size) + 1
        ylim = bunch.data_bounds[3]
        pos = np.array([[xrp0, 0, xrp0, ylim], [xrp1, 0, xrp1, ylim]])
        self.line_visual.add_batch_data(
            pos=pos, color=gray, data_bounds=bunch.data_bounds, box_index=bunch.pair_index)

    def _plot_labels(self):
        """Display the cluster ids along the left column and the bottom row."""
        n = len(self.cluster_ids)

        # Display the cluster ids in the subplots.
        for k in range(n):
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(self.cluster_ids[k]),
                anchor=[-1.25, 0],
                data_bounds=None,
                box_index=(k, 0),
            )
            self.text_visual.add_batch_data(
                pos=[0, -1],
                text=str(self.cluster_ids[k]),
                anchor=[0, -1.25],
                data_bounds=None,
                box_index=(n - 1, k),
            )

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        # With no selected cluster there is nothing to plot (and the maximum of the empty
        # correlogram array in `get_clusters_data` would raise).
        if not self.cluster_ids:
            return
        self.canvas.grid.shape = (len(self.cluster_ids), len(self.cluster_ids))

        bunchs = self.get_clusters_data()

        self.correlogram_visual.reset_batch()
        self.line_visual.reset_batch()
        self.text_visual.reset_batch()

        for bunch in bunchs:
            self._plot_pair(bunch)
        self._plot_labels()

        self.canvas.update_visual(self.correlogram_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self.text_visual)

        self.canvas.update()

    # -------------------------------------------------------------------------
    # Public methods
    # -------------------------------------------------------------------------

    def toggle_normalization(self, checked):
        """Change the normalization of the correlograms."""
        self.uniform_normalization = checked
        self.plot()

    def toggle_labels(self, checked):
        """Show or hide all labels."""
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(CorrelogramView, self).attach(gui)

        self.actions.add(self.toggle_normalization, shortcut='n', checkable=True)
        self.actions.add(self.toggle_labels, checkable=True, checked=True)
        self.actions.separator()

        self.actions.add(
            self.set_bin, prompt=True, prompt_default=lambda: self.bin_size * 1000)
        self.actions.add(
            self.set_window, prompt=True, prompt_default=lambda: self.window_size * 1000)
        self.actions.add(
            self.set_refractory_period, prompt=True,
            prompt_default=lambda: self.refractory_period * 1000)
        self.actions.separator()

    # -------------------------------------------------------------------------
    # Methods for changing the parameters
    # -------------------------------------------------------------------------

    def _set_bin_window(self, bin_size=None, window_size=None):
        """Set the bin and window sizes (in seconds)."""
        bin_size = bin_size or self.bin_size
        window_size = window_size or self.window_size
        bin_size = _clip(bin_size, 1e-6, 1e3)
        window_size = _clip(window_size, 1e-6, 1e3)
        assert 1e-6 <= bin_size <= 1e3
        assert 1e-6 <= window_size <= 1e3
        assert bin_size < window_size
        self.bin_size = bin_size
        self.window_size = window_size
        self.update_status()

    @property
    def status(self):
        b, w = self.bin_size * 1000, self.window_size * 1000
        return '{:.1f} ms ({:.1f} ms)'.format(w, b)

    def set_refractory_period(self, value):
        """Set the refractory period (in milliseconds)."""
        self.refractory_period = _clip(value, .1, 100) * 1e-3
        self.plot()

    def set_bin(self, bin_size):
        """Set the correlogram bin size (in milliseconds).

        Example: `1`

        """
        self._set_bin_window(bin_size=bin_size * 1e-3)
        self.plot()

    def set_window(self, window_size):
        """Set the correlogram window size (in milliseconds).

        Example: `100`

        """
        self._set_bin_window(window_size=window_size * 1e-3)
        self.plot()

    def increase(self):
        """Increase the window size."""
        self.set_window(1000 * self.window_size * 1.1)

    def decrease(self):
        """Decrease the window size."""
        self.set_window(1000 * self.window_size / 1.1)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Change the scaling with the wheel."""
        super(CorrelogramView, self).on_mouse_wheel(e)
        if e.modifiers == ('Alt',):
            self._set_bin_window(bin_size=self.bin_size * 1.1 ** e.delta)
            self.plot()
class EventMarker(IPlugin):
    """Plugin drawing event markers (vertical lines with labels) in the amplitude view.

    Events are read from the first column of `eventmarkers.txt` in the
    controller's `dir_path`. Integer values are interpreted as sample indices
    and converted to seconds with the model's sample rate; other values are used
    as-is. Optional labels are read from the first column of
    `eventmarkernames.txt`. Events without a name are labelled with their
    1-based index.
    """

    # Line color of the event markers (RGBA).
    line_color = (1, 1, 1, 0.75)

    def attach_to_controller(self, controller):
        @connect
        def on_view_attached(view, gui):
            if not isinstance(view, AmplitudeView):
                return

            # Create a batch of vertical lines (full height: y is fixed).
            self.line_visual = LineVisual()
            _fix_coordinate_in_visual(self.line_visual, 'y')
            view.canvas.add_visual(self.line_visual)

            # Create a batch of annotation texts, nudged slightly to the right of
            # the lines in the vertex shader.
            self.text_visual = TextVisual(self.line_color)
            _fix_coordinate_in_visual(self.text_visual, 'y')
            self.text_visual.inserter.insert_vert(
                'gl_Position.x += 0.001;', 'after_transforms')
            view.canvas.add_visual(self.text_visual)

            @view.actions.add(shortcut='alt+b', checkable=True, name='Toggle event markers')
            def toggle(on):
                """Toggle event markers"""
                # Use `show` and `hide` instead of `toggle` here in case of
                # synchronization issues between the action state and the visuals.
                if on:
                    logger.debug('Toggle on markers.')
                    self.line_visual.show()
                    self.text_visual.show()
                    view.show_events = True
                else:
                    logger.debug('Toggle off markers.')
                    self.line_visual.hide()
                    self.text_visual.hide()
                    view.show_events = False
                view.canvas.update()

            @view.actions.add(shortcut='shift+alt+e', prompt=True, name='Go to event', alias='ge')
            def Go_to_event(event_num):
                """Move the trace view to the given 1-based event number."""
                trace_view = gui.get_view(TraceView)
                if 0 < event_num <= events.size:
                    trace_view.go_to(events[event_num - 1])

            # Disable the menu until events are successfully added.
            view.actions.disable('Go to event')
            view.actions.disable('Toggle event markers')

            # Markers are shown by default; the flag is persisted in the view state.
            if not hasattr(view, 'show_events'):
                view.show_events = True
            view.state_attrs += ('show_events', )

            # Read event markers from file.
            filename = controller.dir_path / 'eventmarkers.txt'
            try:
                # `atleast_1d`: a file with a single event would otherwise yield a
                # 0-d array, which breaks indexing and `repeat(4, 0)` below.
                events = np.atleast_1d(np.genfromtxt(filename, usecols=0, dtype=None))
            except OSError:  # includes FileNotFoundError
                logger.warning('Event marker file not found: `%s`.', filename)
                view.show_events = False
                return
            if events.size == 0:
                # Nothing to draw: keep the menu disabled.
                logger.warning('Event marker file is empty: `%s`.', filename)
                view.show_events = False
                return

            # Create list of event names (default: 1-based numbering).
            labels = list(map(str, range(1, events.size + 1)))

            # Read event names from file (if present); they replace the default
            # labels of the first events.
            filename = controller.dir_path / 'eventmarkernames.txt'
            try:
                eventnames = np.loadtxt(
                    filename, usecols=0, dtype=str, max_rows=events.size)
                labels[:eventnames.size] = np.atleast_1d(eventnames)
            except OSError:  # includes FileNotFoundError
                logger.info(
                    'Event marker names file not found (optional):'
                    ' `%s`. Fall back to numbering.', filename)

            # Obtain seconds from samples (any integer dtype means sample indices).
            if np.issubdtype(events.dtype, np.integer):
                logger.debug('Converting input from samples to seconds.')
                events = events / controller.model.sample_rate

            logger.debug('Add event markers to amplitude view.')

            # Map event times from [0, duration] to horizontal positions in [-1, 1].
            x = -1 + 2 * events / view.duration
            # One segment per event: (x, 1) -> (x, -1).
            x = x.repeat(4, 0).reshape(-1, 4)
            x[:, 1::2] = 1, -1

            # Add lines and update view.
            self.line_visual.reset_batch()
            self.line_visual.add_batch_data(pos=x, color=self.line_color)
            view.canvas.update_visual(self.line_visual)

            # Add text at the first endpoint (x, 1) of each line and update view.
            self.text_visual.reset_batch()
            self.text_visual.add_batch_data(pos=x[:, :2], anchor=(1, -1), text=labels)
            view.canvas.update_visual(self.text_visual)

            # Finally enable the menu.
            logger.debug('Enable menu items.')
            view.actions.enable('Go to event')
            view.actions.enable('Toggle event markers')
            if view.show_events:
                # Presumably checks the action, which runs `toggle(True)`; verify
                # against the Actions implementation.
                view.actions.get('Toggle event markers').toggle()
            else:
                self.line_visual.hide()
                self.text_visual.hide()