class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
        A single function (instead of a dict) is accepted and registered under the key
        `waveforms`.

    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.

    sample_rate : float
        Sampling rate of the waveforms (samples per second). Mandatory and strictly
        positive: it is used to compute the waveform duration and the millisecond ticks.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    ax_color = (.75, .75, .75, 1.)
    tick_size = 5.
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'previous_waveforms_type': 'shift+w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, sample_rate=None, **kwargs):
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # Test for None explicitly: in Python 3 `None > 0.` raises a TypeError, which would
        # hide the informative message below.
        assert sample_rate is not None and sample_rate > 0., \
            "The sample rate must be provided to the waveform view."

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        waveforms = waveforms if isinstance(waveforms, dict) else {'waveforms': waveforms}
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(
            marker='vbar', color=self.ax_color, size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    @property
    def _current_visual(self):
        """Visual used for the current waveforms type: the thin `PlotVisual` for raw
        `waveforms`, the antialiased `PlotAggVisual` for every other type."""
        if self.waveforms_type == 'waveforms':
            return self.waveform_visual
        else:
            return self.waveform_agg_visual

    def _get_data_bounds(self, bunchs):
        """Return `[-1, -M, 1, M]` where `M` is the largest absolute waveform value
        across all bunches (the y range is symmetric around 0)."""
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        # Symmetrize on the y axis.
        M = max(abs(m), abs(M))
        return [-1, -M, +1, M]

    def get_clusters_data(self):
        """Load the waveforms of every selected cluster for the current waveforms type.

        Returns None if the current waveforms type is unknown. Otherwise returns one Bunch
        per selected cluster, augmented with `index` (selection position), `offset` and
        `n_clu` (used to lay out clusters horizontally when they are not overlapped) and
        `color`.
        """
        if self.waveforms_type not in self.waveforms:
            return
        bunchs = [
            self.waveforms_types.get()(cluster_id) for cluster_id in self.cluster_ids]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms, y=0 axis lines and millisecond ticks of one cluster to the
        current batches. Does nothing if the cluster has no waveform data."""
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates. The cluster offset is, by default, 0, 1, 2 for the
        # first 3 clusters, but it can be customized when displaying several sets
        # of waveforms per cluster.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.tile(box_index, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual is self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (n_samples + 2)

        # Generate the waveform array: one row per (spike, channel) pair.
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(
            x=t, y=wave, color=bunch.color, masks=masks, box_index=box_index,
            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines.
        ax_db = self.data_bounds
        a, b = _overlap_transform(
            np.array([-1, 1]), offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, 2)
        box_index = np.tile(box_index, n_spikes_clu)
        hpos = np.tile([[a, 0, b, 0]], (nw, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, x.size // len(box_index))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x, y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Draw one channel label per box.

        The labels are always regenerated so that they match the current channel
        selection, even while hidden; their visibility is then synchronized with
        `do_show_labels`. (Skipping the regeneration while hidden would leave stale
        labels of a previous selection that reappear when labels are shown again.)
        """
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)
        if self.do_show_labels:
            self.text_visual.show()
        else:
            self.text_visual.hide()

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()
        if not bunchs:
            return

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids
        if bunchs[0].data is not None:
            self.wave_duration = bunchs[0].data.shape[1] / float(self.sample_rate)
        else:  # pragma: no cover
            self.wave_duration = 1.

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.channel_ids)})

        # Update the Boxed box positions as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.update_boxes(_get_box_pos(bunchs, channel_ids))

        # The data bounds are computed on the first plot and then kept.
        self.data_bounds = self.data_bounds or self._get_data_bounds(bunchs)

        self._current_visual.reset_batch()
        self.line_visual.reset_batch()
        self.tick_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.tick_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self._current_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)

        # Only show the current waveform visual.
        if self._current_visual is self.waveform_visual:
            self.waveform_visual.show()
            self.waveform_agg_visual.hide()
        elif self._current_visual is self.waveform_agg_visual:
            self.waveform_agg_visual.show()
            self.waveform_visual.hide()

        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap, checkable=True, checked=self.overlap)
        self.actions.add(
            self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.previous_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    @property
    def status(self):
        return self.waveforms_type

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        self.boxed.expand_box_width()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        self.boxed.shrink_box_width()

    @property
    def box_scaling(self):
        return self.boxed._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        self.boxed._box_scaling = value

    def _get_scaling_value(self):
        # Vertical component of the box scaling.
        return self.boxed._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self.boxed._box_scaling
        self.boxed._box_scaling = (w, value)
        self.boxed.update()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        return self.boxed._layout_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        self.boxed._layout_scaling = value

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        self.boxed.expand_layout_width()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        self.boxed.shrink_layout_width()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        self.boxed.expand_layout_height()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        self.boxed.shrink_layout_height()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        self.text_visual.show() if checked else self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view."""
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug(
                "Click on channel_id %d with key %s and button %s.", channel_id, key, b)
            emit('select_channel', self, channel_id=channel_id, key=key, button=b)

    @property
    def waveforms_type(self):
        return self.waveforms_types.current

    @waveforms_type.setter
    def waveforms_type(self, value):
        self.waveforms_types.set(value)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        self.waveforms_types.next()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def previous_waveforms_type(self):
        """Switch to the previous waveforms type."""
        self.waveforms_types.previous()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available."""
        if self.waveforms_type == 'mean_waveforms' and 'waveforms' in self.waveforms:
            self.waveforms_types.set('waveforms')
            logger.debug("Switch to raw waveforms.")
            self.plot()
        elif 'mean_waveforms' in self.waveforms:
            self.waveforms_types.set('mean_waveforms')
            logger.debug("Switch to mean waveforms.")
            self.plot()
class CorrelogramView(ScalingMixin, ManualClusteringView):
    """A view showing the autocorrelogram of the selected clusters, and all cross-correlograms
    of cluster pairs.

    Constructor
    -----------

    correlograms : function
        Maps `(cluster_ids, bin_size, window_size)` to an `(n_clusters, n_clusters, n_bins) array`.

    firing_rate : function
        Maps `(cluster_ids, bin_size)` to an `(n_clusters, n_clusters) array`

    sample_rate : float
        Strictly positive sampling rate.

    """

    # Do not show too many clusters.
    max_n_clusters = 20

    _default_position = 'left'
    cluster_ids = ()

    # Durations below are all expressed in seconds.
    bin_size = 1e-3
    window_size = 50e-3
    refractory_period = 2e-3

    # When True, every subplot shares the global maximum as its y limit; otherwise each
    # subplot is scaled to its own maximum.
    uniform_normalization = False

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
        'change_bin_size': 'alt+wheel',
    }

    default_snippets = {
        'set_bin': 'cb',
        'set_window': 'cw',
        'set_refractory_period': 'cr',
    }

    def __init__(self, correlograms=None, firing_rate=None, sample_rate=None, **kwargs):
        super(CorrelogramView, self).__init__(**kwargs)

        self.state_attrs += (
            'bin_size', 'window_size', 'refractory_period', 'uniform_normalization')
        self.local_state_attrs += ()

        self.canvas.set_layout(layout='grid')
        # Shrink the grid slightly so that the outer labels stay visible.
        self.canvas.gpu_transforms.add(Scale(.9))

        assert sample_rate > 0
        self.sample_rate = float(sample_rate)

        # Callables computing the CCGs and the firing rates (same unit as the CCGs).
        self.correlograms = correlograms
        self.firing_rate = firing_rate

        # Normalize and validate the initial bin and window sizes.
        self._set_bin_window(bin_size=self.bin_size, window_size=self.window_size)

        self.correlogram_visual = HistogramVisual()
        self.line_visual = LineVisual()
        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        for visual in (self.correlogram_visual, self.line_visual, self.text_visual):
            self.canvas.add_visual(visual)

    # -------------------------------------------------------------------------
    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self, n_clusters):
        """Iterate over all `(row, col)` subplot indices, row by row."""
        return ((row, col) for row in range(n_clusters) for col in range(n_clusters))

    def get_clusters_data(self, load_all=None):
        """Return one Bunch per subplot with the correlogram, firing rate, data bounds,
        subplot index and color."""
        cluster_ids = self.cluster_ids
        ccg = self.correlograms(cluster_ids, self.bin_size, self.window_size)
        if self.firing_rate:
            rates = self.firing_rate(cluster_ids, self.bin_size)
        else:
            rates = None
        assert ccg.ndim == 3
        n_bins = ccg.shape[2]
        global_max = ccg.max()

        out = []
        for row, col in self._iter_subplots(len(cluster_ids)):
            hist = ccg[row, col, :]
            ylim = global_max if self.uniform_normalization else hist.max()
            color = selected_cluster_color(row, 1)
            if row != col:
                # Off-diagonal subplots use a desaturated variant of the row color.
                color = add_alpha(_override_hsv(color[:3], s=.1, v=1))
            out.append(Bunch(
                correlogram=hist,
                firing_rate=None if rates is None else rates[row, col],
                data_bounds=(0, 0, n_bins, ylim),
                pair_index=(row, col),
                color=color,
            ))
        return out

    def _plot_pair(self, bunch):
        """Add the histogram, firing-rate line and refractory-period lines of a subplot."""
        bounds = bunch.data_bounds
        box = bunch.pair_index
        gray = (.25, .25, .25, 1.)

        self.correlogram_visual.add_batch_data(
            hist=bunch.correlogram, color=bunch.color, ylim=bounds[3], box_index=box)

        # Horizontal line at the firing rate level.
        rate = bunch.firing_rate
        if rate is not None:
            self.line_visual.add_batch_data(
                pos=np.array([[0, rate, bounds[2], rate]]),
                color=gray, data_bounds=bounds, box_index=box)

        # Two vertical lines delimiting the refractory period around the center bin.
        half_window = self.window_size * .5
        left = round((half_window - self.refractory_period) / self.bin_size)
        right = round((half_window + self.refractory_period) / self.bin_size) + 1
        top = bounds[3]
        self.line_visual.add_batch_data(
            pos=np.array([[left, 0, left, top], [right, 0, right, top]]),
            color=gray, data_bounds=bounds, box_index=box)

    def _plot_labels(self):
        """Write the cluster ids on the left of the first column and below the last row."""
        ids = self.cluster_ids
        last_row = len(ids) - 1
        for k, cluster_id in enumerate(ids):
            label = str(cluster_id)
            self.text_visual.add_batch_data(
                pos=[-1, 0], text=label, anchor=[-1.25, 0],
                data_bounds=None, box_index=(k, 0))
            self.text_visual.add_batch_data(
                pos=[0, -1], text=label, anchor=[0, -1.25],
                data_bounds=None, box_index=(last_row, k))

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        n = len(self.cluster_ids)
        self.canvas.grid.shape = (n, n)
        bunchs = self.get_clusters_data()

        visuals = (self.correlogram_visual, self.line_visual, self.text_visual)
        for visual in visuals:
            visual.reset_batch()
        for bunch in bunchs:
            self._plot_pair(bunch)
        self._plot_labels()
        for visual in visuals:
            self.canvas.update_visual(visual)
        self.canvas.update()

    # -------------------------------------------------------------------------
    # Public methods
    # -------------------------------------------------------------------------

    def toggle_normalization(self, checked):
        """Change the normalization of the correlograms."""
        self.uniform_normalization = checked
        self.plot()

    def toggle_labels(self, checked):
        """Show or hide all labels."""
        toggle = self.text_visual.show if checked else self.text_visual.hide
        toggle()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(CorrelogramView, self).attach(gui)
        actions = self.actions

        actions.add(self.toggle_normalization, shortcut='n', checkable=True)
        actions.add(self.toggle_labels, checkable=True, checked=True)
        actions.separator()

        # Parameter prompts, whose defaults are displayed in milliseconds.
        prompts = (
            (self.set_bin, lambda: self.bin_size * 1000),
            (self.set_window, lambda: self.window_size * 1000),
            (self.set_refractory_period, lambda: self.refractory_period * 1000),
        )
        for callback, default in prompts:
            actions.add(callback, prompt=True, prompt_default=default)
        actions.separator()

    # -------------------------------------------------------------------------
    # Methods for changing the parameters
    # -------------------------------------------------------------------------

    def _set_bin_window(self, bin_size=None, window_size=None):
        """Set the bin and window sizes (in seconds)."""
        bin_size = _clip(bin_size or self.bin_size, 1e-6, 1e3)
        window_size = _clip(window_size or self.window_size, 1e-6, 1e3)
        assert 1e-6 <= bin_size <= 1e3
        assert 1e-6 <= window_size <= 1e3
        assert bin_size < window_size
        self.bin_size, self.window_size = bin_size, window_size
        self.update_status()

    @property
    def status(self):
        window_ms = self.window_size * 1000
        bin_ms = self.bin_size * 1000
        return '{:.1f} ms ({:.1f} ms)'.format(window_ms, bin_ms)

    def set_refractory_period(self, value):
        """Set the refractory period (in milliseconds)."""
        self.refractory_period = _clip(value, .1, 100) * 1e-3
        self.plot()

    def set_bin(self, bin_size):
        """Set the correlogram bin size (in milliseconds).

        Example: `1`

        """
        self._set_bin_window(bin_size=bin_size * 1e-3)
        self.plot()

    def set_window(self, window_size):
        """Set the correlogram window size (in milliseconds).

        Example: `100`

        """
        self._set_bin_window(window_size=window_size * 1e-3)
        self.plot()

    def increase(self):
        """Increase the window size."""
        self.set_window(1000 * self.window_size * 1.1)

    def decrease(self):
        """Decrease the window size."""
        self.set_window(1000 * self.window_size / 1.1)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Change the scaling with the wheel."""
        super(CorrelogramView, self).on_mouse_wheel(e)
        if e.modifiers != ('Alt', ):
            return
        self._set_bin_window(bin_size=self.bin_size * 1.1 ** e.delta)
        self.plot()
class AmplitudeView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays an amplitude plot for all selected clusters.

    Constructor
    -----------

    amplitudes : function
        Maps `cluster_ids` to a list `[Bunch(amplitudes, spike_ids), ...]` for each cluster.
        Use `cluster_id=None` for background amplitudes.
        A dict `{name: function}` can be passed to support several amplitude types; a
        single function is registered under the name `amplitude`.

    amplitude_name : str
        Initial amplitude type (defaults to the first key).

    duration : float
        Upper bound of the x axis (spike times); defaults to 1.

    """

    _default_position = 'right'

    # Alpha channel of the markers in the scatter plot.
    marker_alpha = 1.

    # Number of bins in the histogram.
    n_bins = 100

    # Alpha channel of the histogram in the background.
    histogram_alpha = .5

    # Quantile used for scaling of the amplitudes (less than 1 to avoid outliers).
    quantile = .99

    # Size of the histogram, between 0 and 1.
    histogram_scale = .25

    default_shortcuts = {
        'change_marker_size': 'ctrl+wheel',
        'next_amplitude_type': 'a',
        'previous_amplitude_type': 'shift+a',
        'select_x_dim': 'alt+left click',
        'select_y_dim': 'alt+right click',
    }

    def __init__(self, amplitudes=None, amplitude_name=None, duration=None):
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitude_name',)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes
        self.amplitude_names = list(amplitudes.keys())
        # Current amplitude type.
        self.amplitude_name = amplitude_name or self.amplitude_names[0]
        assert self.amplitude_name in amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1

        # Histogram visual, drawn in the bottom part of the view and rotated so that the
        # amplitude axis is vertical.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add_on_gpu([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
            Rotate('ccw')])
        self.canvas.add_visual(self.hist_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # Amplitude name.
        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))

    def _get_data_bounds(self, bunchs):
        """Compute the data bounds `(x_min, y_min, x_max, y_max)`.

        The x range is `[0, duration]`. The y range spans the `1 - quantile` to `quantile`
        quantiles of all amplitudes (discarding outliers), with `y_min <= 0`.
        """
        # np.quantile raises on empty arrays: ignore bunches without any spike
        # (e.g. an empty background).
        non_empty = [np.asarray(bunch.amplitudes) for bunch in bunchs]
        non_empty = [amp for amp in non_empty if amp.size]
        if not non_empty:  # pragma: no cover
            return (0, 0, self.duration, 1)
        m = min(np.quantile(amp, 1 - self.quantile) for amp in non_empty)
        m = min(0, m)  # ensure ymin <= 0
        M = max(np.quantile(amp, self.quantile) for amp in non_empty)
        if M <= m:
            # Degenerate range (e.g. all amplitudes are equal): avoid a zero-height box.
            M = m + 1
        return (0, m, self.duration, M)

    def _add_histograms(self, bunchs):
        """Attach to every bunch the histogram of its amplitudes over the y data range."""
        # We do this after get_clusters_data because we need x_max.
        for bunch in bunchs:
            bunch.histogram = _compute_histogram(
                bunch.amplitudes,
                x_min=self.data_bounds[1],
                x_max=self.data_bounds[3],
                n_bins=self.n_bins,
                normalize=False,
                ignore_zeros=True,
            )
        return bunchs

    def _plot_cluster(self, bunch):
        """Make the scatter plot."""
        ms = self._marker_size

        # Histogram in the background.
        self.hist_visual.add_batch_data(
            hist=bunch.histogram,
            ylim=self._ylim,
            color=add_alpha(bunch.color, self.histogram_alpha))

        # Scatter plot.
        self.visual.add_batch_data(
            pos=bunch.pos, color=bunch.color, size=ms, data_bounds=self.data_bounds)

    def _plot_amplitude_name(self):
        """Show the amplitude name."""
        self.text_visual.add_batch_data(pos=[0, 1], anchor=[0, -1], text=self.amplitude_name)

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids.

        Returns None when no cluster is selected. Unless `load_all` is set (splitting),
        the first bunch holds the background spikes (`cluster_id=None`, gray color).
        """
        if not len(self.cluster_ids):
            return
        cluster_ids = list(self.cluster_ids)
        # Don't need the background when splitting.
        if not load_all:
            # Add None cluster which means background spikes.
            cluster_ids = [None] + cluster_ids
        bunchs = self.amplitudes[self.amplitude_name](cluster_ids, load_all=load_all) or ()
        # Index among the selected (non-background) clusters, used for the color. It must
        # not depend on whether the background cluster was prepended.
        color_index = 0
        # Add a pos attribute in bunchs in addition to x and y.
        for cluster_id, bunch in zip(cluster_ids, bunchs):
            spike_ids = _as_array(bunch.spike_ids)
            spike_times = _as_array(bunch.spike_times)
            amplitudes = _as_array(bunch.amplitudes)
            assert spike_ids.shape == spike_times.shape == amplitudes.shape
            # Ensure that bunch.pos exists, as it used by the LassoMixin.
            bunch.pos = np.c_[spike_times, amplitudes]
            assert bunch.pos.ndim == 2
            bunch.cluster_id = cluster_id
            if cluster_id is None:
                # Background amplitude color.
                bunch.color = (.5, .5, .5, .5)
            else:
                bunch.color = selected_cluster_color(color_index, self.marker_alpha)
                color_index += 1
        return bunchs

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data()
        if not bunchs:
            return
        self.data_bounds = self._get_data_bounds(bunchs)
        bunchs = self._add_histograms(bunchs)

        # Use the same scale for all histograms.
        self._ylim = max(bunch.histogram.max() for bunch in bunchs) if bunchs else 1.

        self.visual.reset_batch()
        self.hist_visual.reset_batch()
        self.text_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self._plot_amplitude_name()
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.hist_visual)
        self.canvas.update_visual(self.text_visual)

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(AmplitudeView, self).attach(gui)
        self.actions.add(self.next_amplitude_type, set_busy=True)
        self.actions.add(self.previous_amplitude_type, set_busy=True)

    def _change_amplitude_type(self, dir=+1):
        """Cycle through the amplitude types (`dir` is +1 or -1) and replot."""
        i = self.amplitude_names.index(self.amplitude_name)
        n = len(self.amplitude_names)
        self.amplitude_name = self.amplitude_names[(i + dir) % n]
        logger.debug("Switch to amplitude type: %s.", self.amplitude_name)
        self.plot()

    def next_amplitude_type(self):
        """Switch to the next amplitude type."""
        self._change_amplitude_type(+1)

    def previous_amplitude_type(self):
        """Switch to the previous amplitude type."""
        self._change_amplitude_type(-1)
class HistogramView(ScalingMixin, ManualClusteringView):
    """This view displays a histogram for every selected cluster, along with a possible plot
    and some text. To be overridden.

    Constructor
    -----------

    cluster_stat : function
        Maps `cluster_id` to `Bunch(data (1D array), plot (1D array), text)`.
        The Bunch may also provide `x_min` / `x_max`, used as the histogram range when the
        view's own bounds are not set yet.

    """

    _default_position = 'right'
    cluster_ids = ()

    # Number of bins in the histogram.
    n_bins = 100

    # Minimum value on the x axis (determines the range of the histogram)
    # If None, then `data.min()` is used.
    x_min = None

    # Maximum value on the x axis (determines the range of the histogram)
    # If None, then `data.max()` is used.
    x_max = None

    # Unit of the bin in the set_bin_size, set_x_min, set_x_max actions.
    bin_unit = 's'  # s (seconds) or ms (milliseconds)

    # The snippet to update this view are `hn` to change the number of bins, and `hm` to
    # change the maximum value on the x axis. The character `h` can be customized by child classes.
    alias_char = 'h'

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
    }

    default_snippets = {
        'set_n_bins': '%sn' % alias_char,
        'set_bin_size (%s)' % bin_unit: '%sb' % alias_char,
        'set_x_min (%s)' % bin_unit: '%smin' % alias_char,
        'set_x_max (%s)' % bin_unit: '%smax' % alias_char,
    }

    _state_attrs = ('n_bins', 'x_min', 'x_max')
    _local_state_attrs = ()

    def __init__(self, cluster_stat=None):
        super(HistogramView, self).__init__()
        self.state_attrs += self._state_attrs
        self.local_state_attrs += self._local_state_attrs
        self.canvas.set_layout(layout='stacked', n_plots=1)
        self.canvas.enable_axes()

        self.cluster_stat = cluster_stat

        self.visual = HistogramVisual()
        self.canvas.add_visual(self.visual)

        self.plot_visual = PlotVisual()
        self.canvas.add_visual(self.plot_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)

    def _plot_cluster(self, bunch):
        """Add the histogram, optional plot and optional (multiline) text of a cluster."""
        assert bunch
        n_bins = self.n_bins
        assert n_bins >= 0

        # Update the visual's data.
        self.visual.add_batch_data(
            hist=bunch.histogram, ylim=bunch.ylim, color=bunch.color, box_index=bunch.index)

        # Plot.
        plot = bunch.get('plot', None)
        if plot is not None:
            x = np.linspace(self.x_min, self.x_max, len(plot))
            self.plot_visual.add_batch_data(
                x=x, y=plot, color=(1, 1, 1, 1), data_bounds=self.data_bounds,
                box_index=bunch.index,
            )

        text = bunch.get('text', None)
        if not text:
            return
        # Support multiline text.
        text = text.splitlines()
        n = len(text)
        self.text_visual.add_batch_data(
            text=text, pos=[(-1, .8)] * n, anchor=[(1, -1 - 2 * i) for i in range(n)],
            box_index=bunch.index,
        )

    def get_clusters_data(self, load_all=None):
        """Compute the histogram of every selected cluster with non-empty data.

        The x range (`self.x_min`, `self.x_max`) is initialized the first time, from the
        bunch's `x_min`/`x_max` if given, else from the data min/max.
        """
        bunchs = []
        for i, cluster_id in enumerate(self.cluster_ids):
            bunch = self.cluster_stat(cluster_id)
            if not bunch.data.size:
                continue
            bmin, bmax = bunch.data.min(), bunch.data.max()
            # Initialize the bounds only if they were not set before. Compare with None
            # explicitly: 0 is a valid bound (with `or`, it would be silently replaced).
            if self.x_min is None:
                bunch_x_min = bunch.get('x_min', None)
                self.x_min = bmin if bunch_x_min is None else bunch_x_min
            if self.x_max is None:
                bunch_x_max = bunch.get('x_max', None)
                self.x_max = bmax if bunch_x_max is None else bunch_x_max
            self.x_min = min(self.x_min, self.x_max)
            assert self.x_min is not None
            assert self.x_max is not None
            assert self.x_min <= self.x_max

            # Compute the histogram.
            bunch.histogram = _compute_histogram(
                bunch.data, x_min=self.x_min, x_max=self.x_max, n_bins=self.n_bins)
            bunch.ylim = bunch.histogram.max()

            bunch.color = selected_cluster_color(i)
            bunch.index = i
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        return bunchs

    def _get_data_bounds(self, bunchs):
        # Get the axes data bounds (the last subplot's extended n_cluster times on the y axis).
        ylim = max(bunch.ylim for bunch in bunchs) if bunchs else 1
        return (self.x_min, 0, self.x_max, ylim * len(self.cluster_ids))

    def plot(self, **kwargs):
        """Update the view with the selected clusters."""
        bunchs = self.get_clusters_data()
        self.data_bounds = self._get_data_bounds(bunchs)

        self.canvas.stacked.n_boxes = len(self.cluster_ids)

        self.visual.reset_batch()
        self.plot_visual.reset_batch()
        self.text_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.plot_visual)
        self.canvas.update_visual(self.text_visual)

        self._update_axes()
        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(HistogramView, self).attach(gui)

        self.actions.add(
            self.set_n_bins, alias=self.alias_char + 'n',
            prompt=True, prompt_default=lambda: self.n_bins)
        self.actions.add(
            self.set_bin_size, alias=self.alias_char + 'b',
            prompt=True, prompt_default=lambda: self.bin_size)
        self.actions.add(
            self.set_x_min, alias=self.alias_char + 'min',
            prompt=True, prompt_default=lambda: self.x_min)
        self.actions.add(
            self.set_x_max, alias=self.alias_char + 'max',
            prompt=True, prompt_default=lambda: self.x_max)
        self.actions.separator()

    @property
    def status(self):
        # The bounds are stored in seconds; convert them for display if needed.
        f = 1 if self.bin_unit == 's' else 1000
        return '[{:.1f}{u}, {:.1f}{u:s}]'.format(
            (self.x_min or 0) * f, (self.x_max or 0) * f, u=self.bin_unit)

    # Histogram parameters
    # -------------------------------------------------------------------------

    def _get_scaling_value(self):
        return self.x_max

    def _set_scaling_value(self, value):
        # `value` is in seconds while set_x_max expects `bin_unit`.
        if self.bin_unit == 'ms':
            value *= 1000
        self.set_x_max(value)

    def set_n_bins(self, n_bins):
        """Set the number of bins in the histogram."""
        self.n_bins = n_bins
        logger.debug("Change number of bins to %d for %s.", n_bins, self.__class__.__name__)
        self.plot()

    @property
    def bin_size(self):
        """Return the bin size (in seconds or milliseconds depending on `self.bin_unit`)."""
        bs = (self.x_max - self.x_min) / self.n_bins
        if self.bin_unit == 'ms':
            bs *= 1000
        return bs

    def set_bin_size(self, bin_size):
        """Set the bin size in the histogram."""
        assert bin_size > 0
        if self.bin_unit == 'ms':
            bin_size /= 1000
        # np.round returns a float: the number of bins must be a positive integer
        # (histogramming requires an int, and bin_size divides by n_bins).
        self.n_bins = max(1, int(np.round((self.x_max - self.x_min) / bin_size)))
        logger.debug(
            "Change number of bins to %d for %s.", self.n_bins, self.__class__.__name__)
        self.plot()

    def set_x_min(self, x_min):
        """Set the minimum value on the x axis for the histogram."""
        if self.bin_unit == 'ms':
            x_min /= 1000
        x_min = min(x_min, self.x_max)
        if x_min == self.x_max:
            return
        self.x_min = x_min
        logger.debug("Change x min to %s for %s.", x_min, self.__class__.__name__)
        self.plot()

    def set_x_max(self, x_max):
        """Set the maximum value on the x axis for the histogram."""
        if self.bin_unit == 'ms':
            x_max /= 1000
        x_max = max(x_max, self.x_min)
        if x_max == self.x_min:
            return
        self.x_max = x_max
        logger.debug("Change x max to %s for %s.", x_max, self.__class__.__name__)
        self.plot()
class FeatureView(MarkerSizeMixin, ScalingMixin, ManualClusteringView):
    """This view displays a 4x4 subplot matrix with different projections of the principal
    component features. This view keeps track of which channels are currently shown.

    Constructor
    -----------

    features : function
        Maps `(cluster_id, channel_ids=None, load_all=False)` to
        `Bunch(data, channel_ids, channel_labels, spike_ids , masks)`.
        * `data` is an `(n_spikes, n_channels, n_features)` array
        * `channel_ids` contains the channel ids along the second axis of `data`
        * `channel_labels` contains the channel labels along the second axis of `data`
        * `spike_ids` is a `(n_spikes,)` array
        * `masks` is an `(n_spikes, n_channels)` array

        This allows for a sparse format.

    attributes : dict
        Maps an attribute name (for example `time`) to a function called as
        `f(cluster_id, load_all=...)`. It returns a Bunch with a `data` array of
        `n_spikes` numbers and, optionally, a `lim` tuple `(vmin, vmax)` used as axis bounds.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    cluster_ids = ()

    # Whether to disable automatic selection of channels.
    fixed_channels = False
    feature_scaling = 1.

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'increase': 'ctrl++',
        'decrease': 'ctrl+-',
        'add_lasso_point': 'ctrl+click',
        'stop_lasso': 'ctrl+right click',
        'toggle_automatic_channel_selection': 'c',
    }

    def __init__(self, features=None, attributes=None, **kwargs):
        super(FeatureView, self).__init__(**kwargs)
        self.state_attrs += ('fixed_channels', 'feature_scaling')

        assert features
        self.features = features
        self._lim = 1

        self.grid_dim = _get_default_grid()  # 2D array where every item a string like `0A,1B`
        self.n_rows, self.n_cols = np.array(self.grid_dim).shape
        self.canvas.set_layout('grid', shape=(self.n_rows, self.n_cols))
        self.canvas.enable_lasso()

        # Channels being shown.
        self.channel_ids = None

        # Attributes: extra features. This is a dictionary {name: function}; each
        # function is called with `(cluster_id, load_all=...)` and returns a Bunch.
        self.attributes = attributes or {}

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

    def set_grid_dim(self, grid_dim):
        """Change the grid dim dynamically.

        Parameters
        ----------
        grid_dim : array-like (2D)
            `grid_dim[row, col]` is a string with two values separated by a comma. Each value
            is the relative channel id (0, 1, 2...) followed by the PC (A, B, C...). For example,
            `grid_dim[row, col] = 0B,1A`. Each value can also be an attribute name, for example
            `time`. For example, `grid_dim[row, col] = time,2C`.

        """
        self.grid_dim = grid_dim
        self.n_rows, self.n_cols = np.array(grid_dim).shape
        self.canvas.grid.shape = (self.n_rows, self.n_cols)

    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self):
        """Yield (i, j, dim_x, dim_y) for every subplot of the grid."""
        for i in range(self.n_rows):
            for j in range(self.n_cols):
                dim = self.grid_dim[i][j]
                dim_x, dim_y = dim.split(',')
                yield i, j, dim_x, dim_y

    def _get_axis_label(self, dim):
        """Return the channel label from a dimension, if applicable."""
        if str(dim[:-1]).isdecimal():
            n = len(self.channel_ids)
            channel_id = self.channel_ids[int(dim[:-1]) % n]
            return self.channel_labels[channel_id] + dim[-1]
        else:
            return dim

    def _get_channel_and_pc(self, dim):
        """Return the channel_id and PC of a dim, or None if no channels are selected yet."""
        if self.channel_ids is None:
            return
        assert dim not in self.attributes  # This is called only on PC data.
        s = 'ABCDEFGHIJ'
        # Channel relative index, typically just 0 or 1.
        c_rel = int(dim[:-1])
        # Get the channel_id from the currently-selected channels.
        channel_id = self.channel_ids[c_rel % len(self.channel_ids)]
        pc = s.index(dim[-1])
        return channel_id, pc

    def _get_axis_data(self, bunch, dim, cluster_id=None, load_all=None):
        """Extract the points from the data on a given dimension.

        bunch is returned by the features() function.
        dim is the string specifying the dimensions to extract for the data.

        """
        if dim in self.attributes:
            return self.attributes[dim](cluster_id, load_all=load_all)
        masks = bunch.get('masks', None)
        channel_id, pc = self._get_channel_and_pc(dim)
        # Skip the plot if the channel id is not displayed.
        if channel_id not in bunch.channel_ids:  # pragma: no cover
            return Bunch(data=np.zeros((bunch.data.shape[0],)))
        # Get the column index of the current channel in data.
        c = list(bunch.channel_ids).index(channel_id)
        if masks is not None:
            masks = masks[:, c]
        return Bunch(data=self.feature_scaling * bunch.data[:, c, pc], masks=masks)

    def _get_axis_bounds(self, dim, bunch):
        """Return the min/max of an axis."""
        if dim in self.attributes:
            # Attribute: specified lim, or compute the min/max.
            vmin, vmax = bunch.get('lim', (0, 0))
            assert vmin is not None
            assert vmax is not None
            return vmin, vmax
        return (-self._lim, +self._lim)

    def _plot_points(self, bunch, clu_idx=None):
        """Add the points of one bunch (a cluster, or the background if clu_idx is None)."""
        if not bunch:
            return
        if clu_idx is None:
            cluster_id = None
        else:
            # Prefer the id attached by get_clusters_data(), which stays correct even
            # when some clusters returned no features and were dropped.
            cluster_id = bunch.get('cluster_id', self.cluster_ids[clu_idx])
        for i, j, dim_x, dim_y in self._iter_subplots():
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id)
            # Skip empty data.
            if px is None or py is None:  # pragma: no cover
                # %s, not %d: cluster_id is None for the background points.
                logger.warning("Skipping empty data for cluster %s.", cluster_id)
                return
            assert px.data.shape == py.data.shape
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            assert xmin <= xmax
            assert ymin <= ymax
            data_bounds = (xmin, ymin, xmax, ymax)

            masks = _get_masks_max(px, py)

            # Prepare the batch visual with all subplots
            # for the selected cluster.
            self.visual.add_batch_data(
                x=px.data, y=py.data,
                color=_get_point_color(clu_idx),
                # Reduced marker size for background features
                size=self._marker_size,
                masks=_get_point_masks(clu_idx=clu_idx, masks=masks),
                data_bounds=data_bounds,
                box_index=(i, j),
            )

            # Get the channel ids corresponding to the relative channel indices
            # specified in the dimensions. Channel 0 corresponds to the first
            # best channel for the selected cluster, and so on.
            label_x = self._get_axis_label(dim_x)
            label_y = self._get_axis_label(dim_y)
            # Add labels.
            self.text_visual.add_batch_data(
                pos=[1, 1],
                anchor=[-1, -1],
                text=label_y,
                data_bounds=None,
                box_index=(i, j),
            )
            self.text_visual.add_batch_data(
                pos=[0, -1.],
                anchor=[0, 1],
                text=label_x,
                data_bounds=None,
                box_index=(i, j),
            )

    def _plot_axes(self):
        """Draw the horizontal and vertical axis lines in every subplot."""
        self.line_visual.reset_batch()
        for i, j, dim_x, dim_y in self._iter_subplots():
            self.line_visual.add_batch_data(
                pos=[[-1., 0., +1., 0.],
                     [0., -1., 0., +1.]],
                color=(.5, .5, .5, .5),
                box_index=(i, j),
                data_bounds=None,
            )
        self.canvas.update_visual(self.line_visual)

    def _get_lim(self, bunchs):
        """Return the largest absolute feature value, used as the symmetric PC axis bound.

        Falls back to 1 when there is no data or all features are zero, which would
        otherwise produce degenerate `(-0, +0)` data bounds.

        """
        if not bunchs:  # pragma: no cover
            return 1
        peaks = [np.abs(bunch.data).max() for bunch in bunchs if bunch.data.size]
        lim = max(peaks) if peaks else 0
        return lim or 1

    def _get_scaling_value(self):
        return self.feature_scaling

    def _set_scaling_value(self, value):
        self.feature_scaling = value
        self.plot(fixed_channels=True)

    # Public methods
    # -------------------------------------------------------------------------

    def clear_channels(self):
        """Reset the current channels."""
        self.channel_ids = None
        self.plot()

    def get_clusters_data(self, fixed_channels=None, load_all=None):
        """Load the features of the selected clusters and choose the displayed channels."""
        # Get the feature data.
        # Specify the channel ids if these are fixed, otherwise
        # choose the first cluster's best channels.
        c = self.channel_ids if fixed_channels else None
        bunchs = []
        # Attach the cluster id *before* dropping empty bunchs so that ids and
        # bunchs stay aligned.
        for cluster_id in self.cluster_ids:
            bunch = self.features(cluster_id, channel_ids=c)
            if not bunch:
                continue
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        if not bunchs:  # pragma: no cover
            return []

        # Choose the channels based on the first selected cluster.
        channel_ids = list(bunchs[0].get('channel_ids', []))
        common_channels = list(channel_ids)
        # Intersection (with order kept) of channels belonging to all clusters.
        for bunch in bunchs:
            common_channels = [
                ch for ch in bunch.get('channel_ids', []) if ch in common_channels]
        # The selected channels will be (1) the channels common to all clusters, followed
        # by (2) remaining channels from the first cluster (excluding those already selected
        # in (1)).
        n = len(channel_ids)
        not_common_channels = [ch for ch in channel_ids if ch not in common_channels]
        channel_ids = common_channels + not_common_channels[:n - len(common_channels)]
        assert len(channel_ids) > 0

        # Choose the channels automatically unless fixed_channels is set.
        if (not fixed_channels or self.channel_ids is None):
            self.channel_ids = channel_ids
        assert len(self.channel_ids)

        # Channel labels.
        self.channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.get('channel_ids', [])])
            self.channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.get('channel_ids', []))})

        return bunchs

    def plot(self, **kwargs):
        """Update the view with the selected clusters."""

        # Determine whether the channels should be fixed or not.
        added = kwargs.get('up', {}).get('added', None)
        # Fix the channels if the view updates after a cluster event
        # and there are new clusters.
        fixed_channels = (
            self.fixed_channels or kwargs.get('fixed_channels', None) or added is not None)

        # Get the clusters data.
        bunchs = self.get_clusters_data(fixed_channels=fixed_channels)
        bunchs = [b for b in bunchs if b]
        if not bunchs:
            return
        self._lim = self._get_lim(bunchs)

        # Get the background data.
        background = self.features(channel_ids=self.channel_ids)

        # Plot all features.
        self._plot_axes()

        # NOTE: the columns in bunch.data are ordered by decreasing quality
        # of the associated channels. The channels corresponding to each
        # column are given in bunch.channel_ids in the same order.

        # Plot points.
        self.visual.reset_batch()
        self.text_visual.reset_batch()

        self._plot_points(background)  # background spikes

        # Plot each cluster.
        for clu_idx, bunch in enumerate(bunchs):
            self._plot_points(bunch, clu_idx=clu_idx)

        # Upload the data on the GPU.
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.text_visual)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(FeatureView, self).attach(gui)

        self.actions.add(
            self.toggle_automatic_channel_selection,
            checked=not self.fixed_channels, checkable=True)
        self.actions.add(self.clear_channels)
        self.actions.separator()

    def toggle_automatic_channel_selection(self, checked):
        """Toggle the automatic selection of channels when the cluster selection changes."""
        self.fixed_channels = not checked

    @property
    def status(self):
        if self.channel_ids is None:  # pragma: no cover
            return ''
        channel_labels = [self.channel_labels[ch] for ch in self.channel_ids[:2]]
        return 'channels: %s' % ', '.join(channel_labels)

    # Dimension selection
    # -------------------------------------------------------------------------

    def on_select_channel(self, sender=None, channel_id=None, key=None, button=None):
        """Respond to the click on a channel from another view, and update the
        relevant subplots."""
        channels = self.channel_ids
        if channels is None:
            return
        if len(channels) == 1:
            self.plot()
            return
        assert len(channels) >= 2
        # Get the axis from the pressed button (1, 2, etc.)
        if key is not None:
            d = np.clip(len(channels) - 1, 0, key - 1)
        else:
            d = 0 if button == 'Left' else 1
        # Change the first or second best channel.
        old = channels[d]
        # Avoid updating the view if the channel doesn't change.
        if channel_id == old:
            return
        channels[d] = channel_id
        # Ensure that the first two channels are different.
        if channels[1 - min(d, 1)] == channel_id:
            channels[1 - min(d, 1)] = old
        assert channels[0] != channels[1]
        # Remove duplicate channels.
        self.channel_ids = _uniq(channels)
        logger.debug("Choose channels %d and %d in feature view.", *channels[:2])
        # Fix the channels temporarily.
        self.plot(fixed_channels=True)
        self.update_status()

    def on_mouse_click(self, e):
        """Select a feature dimension by clicking on a box in the feature view."""
        b = e.button
        if 'Alt' in e.modifiers:
            # Get mouse position in NDC.
            (i, j), _ = self.canvas.grid.box_map(e.pos)
            dim = self.grid_dim[i][j]
            dim_x, dim_y = dim.split(',')
            dim = dim_x if b == 'Left' else dim_y
            other_dim = dim_y if b == 'Left' else dim_x
            if dim not in self.attributes:
                # When a regular (channel, PC) dimension is selected.
                channel_pc = self._get_channel_and_pc(dim)
                if channel_pc is None:
                    return
                channel_id, pc = channel_pc
                logger.debug("Click on feature dim %s, channel id %s, PC %s.", dim, channel_id, pc)
            else:
                # When the selected dimension is an attribute, e.g. "time".
                pc = None
                # Take the channel id in the other dimension.
                channel_pc = self._get_channel_and_pc(other_dim)
                channel_id = channel_pc[0] if channel_pc is not None else None
                logger.debug("Click on feature dim %s.", dim)
            emit('select_feature', self, dim=dim, channel_id=channel_id, pc=pc)

    def on_request_split(self, sender=None):
        """Return the spikes enclosed by the lasso."""
        if (self.canvas.lasso.count < 3 or not len(self.cluster_ids)):  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # Get the dimensions of the lassoed subplot.
        i, j = self.canvas.layout.active_box
        dim = self.grid_dim[i][j]
        dim_x, dim_y = dim.split(',')

        # Get all points from all clusters.
        pos = []
        spike_ids = []

        for cluster_id in self.cluster_ids:
            # Load all spikes.
            bunch = self.features(cluster_id, channel_ids=self.channel_ids, load_all=True)
            if not bunch:
                continue
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id, load_all=True)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id, load_all=True)
            points = np.c_[px.data, py.data]

            # Normalize the points.
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            r = Range((xmin, ymin, xmax, ymax))
            points = r.apply(points)

            pos.append(points)
            spike_ids.append(bunch.spike_ids)

        # np.vstack([]) raises, so bail out when no cluster returned any features.
        if not pos:  # pragma: no cover
            logger.warning("Empty lasso.")
            return np.array([], dtype=np.int64)
        pos = np.vstack(pos)
        spike_ids = np.concatenate(spike_ids)

        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()
        return np.unique(spike_ids[ind])
class WaveformClusteringView(LassoMixin, ManualClusteringView):
    """Show the waveforms of the first selected cluster on one channel at a time, and let
    the user lasso waveforms out of that cluster.

    The `model` must provide `get_cluster_channels(cluster_id)`,
    `get_cluster_spikes(cluster_id)` and `get_waveforms(spike_ids, channel_ids=...)`.

    """
    default_shortcuts = {
        'next_channel': 'f',
        'previous_channel': 'r',
    }

    def __init__(self, model=None):
        super(WaveformClusteringView, self).__init__()
        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))

        self.model = model
        # Factor converting raw waveform values to the displayed [uV] units.
        # NOTE(review): presumably the acquisition system's uV-per-bit gain — confirm.
        self.gain = 0.195
        self.Fs = 30  # kHz

        self.visual = PlotVisual()
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.97, .95)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.01, 0)

        self.cluster_ids = None
        self.cluster_id = None
        self.channel_ids = None
        self.spike_ids = None
        self.wavefs = None
        self.current_channel_idx = None
        self.data_bounds = None

    def on_select(self, cluster_ids=(), **kwargs):
        """Load the waveforms of the first selected cluster when it changes."""
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        if self.cluster_id != cluster_ids[0]:
            self.cluster_id = cluster_ids[0]
            self.channel_ids = self.model.get_cluster_channels(self.cluster_id)
            self.spike_ids = self.model.get_cluster_spikes(self.cluster_id)
            self.wavefs = self.model.get_waveforms(self.spike_ids, channel_ids=self.channel_ids)
            self.setChannelIdx(0)

    def _time_axis(self, n_samples):
        """Return the sample times in ms, centred on zero (Fs is in kHz)."""
        return np.linspace(-n_samples / 2 / self.Fs, n_samples / 2 / self.Fs, n_samples)

    @staticmethod
    def _round_amplitude(amp):
        """Round a peak amplitude (displayed units) up to a readable axis limit.

        Always rounds up, so the largest peak fits inside the axis. Never returns 0,
        because flat waveforms would otherwise give degenerate data bounds.

        """
        if amp < 100:
            return max(10 * np.ceil(amp / 10), 10)
        elif amp < 1000:
            return 100 * np.ceil(amp / 100)
        return 1000 * np.ceil(amp / 1000)

    @staticmethod
    def _quantile(values, q, method):
        """Per-sample quantile across spikes (axis 0), e.g. with method 'higher'/'lower'."""
        try:
            return np.quantile(values, q, axis=0, method=method)
        except TypeError:  # NumPy < 1.22 only accepts the `interpolation` keyword.
            return np.quantile(values, q, axis=0, interpolation=method)

    def plotWaveforms(self):
        """Plot all waveforms of the current channel, plus summary statistics."""
        if self.wavefs is None or self.current_channel_idx is None:
            return
        Nspk, Ntime, Nchan = self.wavefs.shape
        self.visual.reset_batch()
        self.text_visual.reset_batch()

        t = self._time_axis(Ntime)
        x = np.tile(t, (Nspk, 1))
        wavefs_ch = self.wavefs[:, :, self.current_channel_idx]

        M = self._round_amplitude(self.gain * np.max(np.abs(wavefs_ch), initial=0))
        self.data_bounds = (t[0], -M, t[-1], M)

        colorwavef = selected_cluster_color(0)
        colormedian = selected_cluster_color(3)
        colorstd = (0, 1, 0, 1)
        colorqtl = (1, 1, 0, 1)

        self.visual.add_batch_data(
            x=x, y=self.gain * wavefs_ch, color=colorwavef,
            data_bounds=self.data_bounds, box_index=0)

        # Statistics are only drawn when there are enough spikes for them to be meaningful.
        if Nspk > 100:
            medianCl = np.median(wavefs_ch, axis=0)
            stdCl = np.std(wavefs_ch, axis=0)
            q1 = self._quantile(wavefs_ch, .01, 'higher')
            q9 = self._quantile(wavefs_ch, .99, 'lower')
            curves = (
                (medianCl, colormedian),
                (medianCl + 3 * stdCl, colorstd),
                (medianCl - 3 * stdCl, colorstd),
                (q1, colorqtl),
                (q9, colorqtl),
            )
            for y, color in curves:
                self.visual.add_batch_data(
                    x=t, y=self.gain * y, color=color,
                    data_bounds=self.data_bounds, box_index=0)

        # Axes labels.
        self.text_visual.add_batch_data(
            pos=[.9, .98],
            text='[uV]',
            anchor=[-1, -1],
            box_index=0,
        )
        self.text_visual.add_batch_data(
            pos=[-1, -.95],
            text='[ms]',
            anchor=[1, 1],
            box_index=0,
        )
        label = 'Ch {a}'.format(a=self.channel_ids[self.current_channel_idx])
        self.text_visual.add_batch_data(
            pos=[-.98, .98],
            text=str(label),
            anchor=[1, -1],
            box_index=0,
        )

        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.text_visual)
        self.canvas.axes.reset_data_bounds(self.data_bounds)
        self.canvas.update()

    def setChannel(self, channel_id):
        """Show the channel with the given id, if it belongs to the current cluster."""
        if self.channel_ids is None:
            return
        # asarray so that the comparison is element-wise even if channel_ids is a list.
        itemindex = np.flatnonzero(np.asarray(self.channel_ids) == channel_id)
        if len(itemindex):
            self.setChannelIdx(itemindex[0])

    def setChannelIdx(self, channel_idx):
        """Show the channel at the given position in `self.channel_ids`."""
        self.current_channel_idx = channel_idx
        self.plotWaveforms()

    def setNextChannelIdx(self):
        """Move to the next channel of the cluster (no-op on the last one)."""
        if self.channel_ids is None or self.current_channel_idx is None:
            return
        if self.current_channel_idx == len(self.channel_ids) - 1:
            return
        self.setChannelIdx(self.current_channel_idx + 1)

    def setPrevChannelIdx(self):
        """Move to the previous channel of the cluster (no-op on the first one)."""
        if self.channel_ids is None or self.current_channel_idx is None:
            return
        if self.current_channel_idx == 0:
            return
        self.setChannelIdx(self.current_channel_idx - 1)

    def on_request_split(self, sender=None):
        """Return the spikes of the current cluster that are NOT enclosed by the lasso.

        Returning the complement keeps the selected cluster the one being worked on.

        """
        if (self.canvas.lasso.count < 3 or self.cluster_ids is None
                or not len(self.cluster_ids)):  # pragma: no cover
            return np.array([], dtype=np.int64)
        if self.wavefs is None or self.spike_ids is None or not len(self.spike_ids):
            logger.warning("Empty lasso.")
            return np.array([], dtype=np.int64)

        spike_ids = np.asarray(self.spike_ids)
        n_spikes = len(spike_ids)
        Ntime = self.wavefs.shape[1]
        t = self._time_axis(Ntime)

        # One point per (spike, sample), spike-major, matching the plotted lines.
        y = self.gain * self.wavefs[:n_spikes, :, self.current_channel_idx]
        pos = np.c_[np.tile(t, n_spikes), y.ravel()]
        pos = range_transform([self.data_bounds], [NDC], pos)
        point_spike_ids = np.repeat(spike_ids, Ntime)

        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()

        spikes_to_remove = np.unique(point_spike_ids[ind])
        keepspikes = np.isin(spike_ids, spikes_to_remove, assume_unique=True, invert=True)
        kept = spike_ids[keepspikes]
        if len(kept) > 0:
            return kept
        return np.array([], dtype=np.int64)
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel
        * `waveform_duration` : optional, the duration shown on the x axis (default 1)

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.

    """

    _default_position = 'right'
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
        self.canvas.enable_axes()

        self._box_scaling = (1., 1.)
        self._probe_scaling = (1., 1.)

        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        assert waveforms
        self.waveforms = waveforms
        self.waveforms_types = list(waveforms.keys())
        # Current waveforms type.
        self.waveforms_type = waveforms_type or self.waveforms_types[0]
        assert self.waveforms_type in waveforms
        assert 'waveforms' in waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Return `[-1, ymin, +1, ymax]` over the waveforms of all bunchs."""
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        return [-1, m, +1, M]

    def get_clusters_data(self):
        """Load the waveforms of the selected clusters and set their plotting attributes."""
        bunchs = [
            self.waveforms[self.waveforms_type](cluster_id)
            for cluster_id in self.cluster_ids
        ]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster to the waveform visual batch."""
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', None)
        if masks is None:
            masks = np.ones((wave.shape[0], n_channels))
        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        # By default, the offset is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.
        t = _overlap_transform(t, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)

        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        # A new array is built instead of modifying `masks` in place: it belongs to the
        # bunch returned by the waveforms function, which may be reused across plots.
        masks = masks * .99999 + bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.tile(box_index, n_spikes_clu)
        assert box_index.shape == (n_spikes_clu * n_channels * n_samples, )

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

        self.waveform_visual.add_batch_data(
            x=t, y=wave, color=bunch.color, masks=masks, box_index=box_index,
            data_bounds=self.data_bounds)

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Add one text label per channel box, unless labels are hidden."""
        # Add channel labels.
        if not self.do_show_labels:
            return
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.channel_ids)
            })

        # Update the box bounds as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.box_bounds = _get_box_bounds(bunchs, channel_ids)
        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        self.data_bounds = self._get_data_bounds(bunchs)

        self.waveform_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.waveform_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)
        self._update_axes(bunchs)
        self.canvas.update()

    def _update_axes(self, bunchs):
        """Update the axes."""
        # Update the axes data bounds.
        _, m, _, M = self.data_bounds
        # Waveform duration, scaled by overlap factor if needed.
        wave_dur = bunchs[0].get('waveform_duration', 1.)
        wave_dur /= .5 * (1 + _overlap_transform(
            1, n=len(self.cluster_ids), overlap=self.overlap))
        x1, y1 = range_transform(self.canvas.boxed.box_bounds[0], NDC, [wave_dur, M - m])
        axes_data_bounds = (0, 0, x1, y1)
        self.canvas.axes.reset_data_bounds(axes_data_bounds, do_update=True)

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap, checkable=True, checked=self.overlap)
        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def _update_boxes(self):
        self.canvas.boxed.update_boxes(self.box_pos * self.probe_scaling, self.box_size)

    def _apply_box_scaling(self):
        self.canvas.layout.scaling = self._box_scaling

    @property
    def box_scaling(self):
        """Scaling of the channel boxes."""
        return self._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        assert len(value) == 2
        self._box_scaling = value
        self._apply_box_scaling()

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w * self._scaling_param_increment, h)
        self._apply_box_scaling()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w / self._scaling_param_increment, h)
        self._apply_box_scaling()

    def _get_scaling_value(self):
        return self._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self._box_scaling
        self.box_scaling = (w, value)
        self._update_boxes()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Scaling of the entire probe."""
        return self._probe_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        assert len(value) == 2
        self._probe_scaling = value
        self._update_boxes()

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w * self._scaling_param_increment, h)
        self._update_boxes()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w / self._scaling_param_increment, h)
        self._update_boxes()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h * self._scaling_param_increment)
        self._update_boxes()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h / self._scaling_param_increment)
        self._update_boxes()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        self.text_visual.show() if checked else self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view."""
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # Nothing has been plotted yet, so no box maps to a channel.
            if self.channel_ids is None:
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.", channel_id, key, b)
            emit('channel_click', self, channel_id=channel_id, key=key, button=b)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        i = self.waveforms_types.index(self.waveforms_type)
        n = len(self.waveforms_types)
        self.waveforms_type = self.waveforms_types[(i + 1) % n]
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available."""
        if self.waveforms_type == 'mean_waveforms':
            self.waveforms_type = 'waveforms'
            self.plot()
        elif 'mean_waveforms' in self.waveforms_types:
            self.waveforms_type = 'mean_waveforms'
            self.plot()
class TraceView(ScalingMixin, BaseColorView, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_positions : array-like
        Positions of the channels, used for displaying the channels in the right y order
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    trace_color_0 = (.353, .161, .443)
    trace_color_1 = (.133, .404, .396)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'switch_color_scheme': 'shift+wheel',
        'navigate': 'alt+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'jump_left': 'shift+alt+left',
        'jump_right': 'shift+alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'select_channel_pcA': 'shift+left click',
        'select_channel_pcB': 'shift+right click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None,
            n_channels=None, channel_positions=None, channel_labels=None, **kwargs):

        self.do_show_labels = True
        self.show_all_spikes = False

        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel y ranking.
        self.channel_positions = (
            channel_positions if channel_positions is not None else
            np.c_[np.zeros(n_channels), np.arange(n_channels)])
        # channel_y_ranks[i] is the position of channel #i in the trace view.
        self.channel_y_ranks = np.argsort(np.argsort(self.channel_positions[:, 1]))
        assert self.channel_y_ranks.shape == (n_channels,)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        # Visuals.
        self._create_visuals()

        # (start_time, spike_id, spike_cluster, channel_ids) of every displayed waveform,
        # used for spike click. Created before the first plot so that the waveforms drawn
        # at start-up remain clickable.
        self._waveform_times = []

        # Initial interval.
        self._interval = None
        self.go_to(duration / 2.)

        self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))

    def _create_visuals(self):
        """Set up the stacked layout and the trace, waveform and label visuals."""
        self.canvas.set_layout('stacked', n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        self.trace_visual = UniformPlotVisual()
        # Gradient of color for the traces.
        if self.trace_color_0 and self.trace_color_1:
            self.trace_visual.inserter.insert_frag(
                'gl_FragColor.rgb = mix(vec3%s, vec3%s, (v_signal_index / %d));' % (
                    self.trace_color_0, self.trace_color_1, self.n_channels), 'end')
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        # When there are many boxes relative to the vertical zoom (>= 50 * zoom), discard
        # the labels of all boxes but one every `n_boxes / (50 * zoom)` boxes.
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)

    @property
    def stacked(self):
        return self.canvas.stacked

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        """Draw the `(n_samples, n_channels)` traces, one channel per box."""
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        color = color or self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        # Every channel is drawn in the box given by its y rank.
        box_index = self.channel_y_ranks
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        """Add one spike waveform to the current waveform batch."""
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # Determine the spike color: selected clusters use the selection palette,
        # other clusters use the current color scheme.
        i = bunch.select_index
        c = bunch.spike_cluster
        cs = self.color_schemes.get()
        color = selected_cluster_color(i, alpha=1) if i is not None else cs.get(c, alpha=1)

        # We could tweak the color of each spike waveform depending on the template amplitude
        # on each of its best channels.
        # channel_amps = bunch.get('channel_amps', None)
        # if channel_amps is not None:
        #     color = np.tile(color, (n_channels, 1))
        #     assert color.shape == (n_channels, 4)
        #     color[:, 3] = channel_amps

        # The box index depends on the channel.
        box_index = self.channel_y_ranks[bunch.channel_ids]
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=color,
            data_bounds=self.data_bounds,
        )

    def _plot_waveforms(self, waveforms, **kwargs):
        """Plot the waveforms and record them for spike click."""
        assert isinstance(waveforms, list)
        # The click lookup table must describe only the waveforms displayed now;
        # otherwise every redraw would append duplicate entries.
        self._waveform_times = []
        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

    def _plot_labels(self, traces):
        """Draw the channel labels at the left edge of every box."""
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self.channel_y_ranks[ch]
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        """Round an interval to full samples and shift/clip it into `[0, duration]`."""
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def plot(self, update_traces=True, update_waveforms=True):
        """Redraw the view in the current interval.

        Parameters
        ----------
        update_traces : bool
            Recompute the data bounds and redraw the traces (and labels if shown).
        update_waveforms : bool
            Redraw the spike waveforms.

        """
        traces = None
        if update_traces or update_waveforms:
            # Load the traces in the interval: both branches below need them.
            traces = self.traces(self._interval)

        if update_traces:
            logger.log(5, "Redraw the entire trace view.")
            start, end = self._interval

            # Find the data bounds.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

            # Plot the labels.
            if self.do_show_labels:
                self._plot_labels(traces.data)

        if update_waveforms:
            self._plot_waveforms(traces.get('waveforms', []))

        self._update_axes()
        self.canvas.update()

    def set_interval(self, interval=None):
        """Display the traces and spikes in a given interval."""
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        if interval != self._interval:
            logger.log(5, "Redraw the entire trace view.")
            self._interval = interval
            emit('is_busy', self, True)
            self.plot(update_traces=True, update_waveforms=True)
            emit('is_busy', self, False)
            emit('time_range_selected', self, interval)
            self.update_status()
        else:
            self.plot(update_traces=False, update_waveforms=True)

    def on_select(self, cluster_ids=None, **kwargs):
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        # Make sure we call again self.traces() when the cluster selection changes.
        self.set_interval()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(TraceView, self).attach(gui)

        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(
            self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        self.actions.add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        self.actions.add(self.switch_origin)
        self.actions.separator()

        self.actions.add(
            self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        self.actions.separator()

        self.actions.add(self.go_to_start)
        self.actions.add(self.go_to_end)
        self.actions.separator()

        self.actions.add(self.shift, prompt=True)
        self.actions.add(self.go_right)
        self.actions.add(self.go_left)
        self.actions.add(self.jump_right)
        self.actions.add(self.jump_left)
        self.actions.separator()

        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        self.actions.add(self.go_to_next_spike)
        self.actions.add(self.go_to_previous_spike)
        self.actions.separator()

        self.set_interval()

    @property
    def status(self):
        a, b = self._interval
        return '[{:.2f}s - {:.2f}s]. Color scheme: {}.'.format(a, b, self.color_scheme)

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Whether to show the channels from top to bottom (`top` option, the default), or from
        bottom to top (`bottom`)."""
        return getattr(self.canvas.layout, 'origin', Stacked._origin)

    @origin.setter
    def origin(self, value):
        if value is None:
            return
        if self.canvas.layout:
            self.canvas.layout.origin = value
        else:  # pragma: no cover
            logger.warning(
                "Could not set origin to %s because the layout instance was not initialized yet.",
                value)

    def switch_origin(self):
        """Switch between top and bottom origin for the channels."""
        self.origin = 'bottom' if self.origin == 'top' else 'top'

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Time at the center of the window."""
        return sum(self._interval) * .5

    @property
    def interval(self):
        """Interval as `(tmin, tmax)`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval."""
        if self._interval is not None:
            a, b = self._interval
            return (b - a) * .5
        else:
            return self.interval_duration * .5

    def go_to(self, time):
        """Go to a specific time (in seconds)."""
        half_dur = self.half_duration
        self.set_interval((time - half_dur, time + half_dur))

    def shift(self, delay):
        """Shift the interval by a given delay (in seconds)."""
        self.go_to(self.time + delay)

    def go_to_start(self):
        """Go to the start of the recording."""
        self.go_to(0)

    def go_to_end(self):
        """Go to end of the recording."""
        self.go_to(self.duration)

    def go_right(self):
        """Go to right."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(delay)

    def go_left(self):
        """Go to left."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(-delay)

    def jump_right(self):
        """Jump to right."""
        delay = self.duration * .1
        self.shift(delay)

    def jump_left(self):
        """Jump to left."""
        delay = self.duration * .1
        self.shift(-delay)

    def _jump_to_spike(self, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        spike_times = self.get_spike_times()
        if spike_times is not None and len(spike_times):
            ind = np.searchsorted(spike_times, self.time)
            n = len(spike_times)
            self.go_to(spike_times[(ind + delta) % n])

    def go_to_next_spike(self):
        """Jump to the next spike from the first selected cluster."""
        self._jump_to_spike(+1)

    def go_to_previous_spike(self):
        """Jump to the previous spike from the first selected cluster."""
        self._jump_to_spike(-1)

    def toggle_highlighted_spikes(self, checked):
        """Toggle between showing all spikes or selected spikes."""
        self.show_all_spikes = checked
        self.set_interval()

    def widen(self):
        """Increase the interval size."""
        t, h = self.time, self.half_duration
        h *= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    def narrow(self):
        """Decrease the interval size."""
        t, h = self.time, self.half_duration
        h /= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show the channel labels if `checked` is True, hide them otherwise."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        # Set the visibility explicitly instead of toggling it, so that it cannot get out
        # of sync with `checked`.
        if checked:
            # Labels are not redrawn by plot() while hidden: refresh them before showing.
            if self._interval is not None and getattr(self, 'data_bounds', None) is not None:
                self._plot_labels(None)
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def toggle_auto_scale(self, checked):
        """Toggle automatic scaling of the traces."""
        logger.debug("Set auto scale to %s.", checked)
        self.auto_scale = checked

    def update_color(self):
        """Update the view when the color scheme changes."""
        self.plot(update_traces=False, update_waveforms=True)

    # Scaling
    # -------------------------------------------------------------------------

    @property
    def scaling(self):
        """Scaling of the channel boxes."""
        return self.stacked._box_scaling[1]

    @scaling.setter
    def scaling(self, value):
        self.stacked._box_scaling = (self.stacked._box_scaling[0], value)

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value
        self.stacked.update()

    # Spike selection
    # -------------------------------------------------------------------------

    def _channel_at_box(self, box_id):
        """Return the id (int) of the channel displayed in box `box_id`, or None."""
        matches = np.nonzero(self.channel_y_ranks == box_id)[0]
        return int(matches[0]) if len(matches) else None

    def on_mouse_click(self, e):
        """Select a spike with ctrl+click, or a channel with shift+click."""
        if 'Control' in e.modifiers:
            # Get mouse position in NDC.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            channel_id = self._channel_at_box(box_id)
            if channel_id is None:
                return

            # Find the spike and cluster closest to the mouse.
            db = self.data_bounds
            # Get the information about the displayed spikes on the clicked channel.
            wt = [(t, s, c, ch) for t, s, c, ch in self._waveform_times
                  if ch is not None and channel_id in ch]
            if not wt:
                return
            # Get the time coordinate of the mouse position.
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, db).apply(mouse_pos)[0][0]
            # Get the closest spike id.
            times, spike_ids, spike_clusters, _ = zip(*wt)
            i = np.argmin(np.abs(np.array(times) - mouse_time))
            # Raise the select_spike event.
            spike_id = spike_ids[i]
            cluster_id = spike_clusters[i]
            emit('select_spike', self, channel_id=channel_id,
                 spike_id=spike_id, cluster_id=cluster_id)

        if 'Shift' in e.modifiers:
            # Get mouse position in NDC.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            channel_id = self._channel_at_box(box_id)
            if channel_id is None:
                return
            emit('select_channel', self, channel_id=channel_id, button=e.button)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Scroll through the data with alt+wheel."""
        super(TraceView, self).on_mouse_wheel(e)
        if e.modifiers == ('Alt',):
            start, end = self._interval
            delay = e.delta * (end - start) * .1
            self.shift(-delay)
class EventMarker(IPlugin):
    """Plugin drawing event markers (vertical lines with labels) in the amplitude view.

    Events are read from the first column of `eventmarkers.txt` in the controller's
    directory; integer values are interpreted as samples, other values as seconds.
    Labels are read from the first column of the optional `eventmarkernames.txt`,
    otherwise events are numbered from 1.
    """

    # Line color of the event markers
    line_color = (1, 1, 1, 0.75)

    def attach_to_controller(self, controller):
        @connect
        def on_view_attached(view, gui):
            if not isinstance(view, AmplitudeView):
                return

            # Create batch of vertical lines (full height)
            line_visual = LineVisual()
            _fix_coordinate_in_visual(line_visual, 'y')
            view.canvas.add_visual(line_visual)

            # Create batch of annotative text
            text_visual = TextVisual(self.line_color)
            _fix_coordinate_in_visual(text_visual, 'y')
            text_visual.inserter.insert_vert(
                'gl_Position.x += 0.001;', 'after_transforms')
            view.canvas.add_visual(text_visual)

            # Keep references on the plugin for backward compatibility. The closures below
            # use the local variables so that each attached view controls its own visuals.
            self.line_visual = line_visual
            self.text_visual = text_visual

            @view.actions.add(shortcut='alt+b', checkable=True,
                              name='Toggle event markers')
            def toggle(on):
                """Toggle event markers"""
                # Use `show` and `hide` instead of `toggle` here in
                # case synchronization issues
                if on:
                    logger.debug('Toggle on markers.')
                    line_visual.show()
                    text_visual.show()
                    view.show_events = True
                else:
                    logger.debug('Toggle off markers.')
                    line_visual.hide()
                    text_visual.hide()
                    view.show_events = False
                view.canvas.update()

            @view.actions.add(shortcut='shift+alt+e', prompt=True,
                              name='Go to event', alias='ge')
            def Go_to_event(event_num):
                """Center the trace view on the event with the given 1-based number."""
                trace_view = gui.get_view(TraceView)
                if trace_view is None:
                    logger.warning('No trace view to go to event %s.', event_num)
                    return
                if 0 < event_num <= events.size:
                    trace_view.go_to(events[event_num - 1])

            # Disable the menu until events are successfully added
            view.actions.disable('Go to event')
            view.actions.disable('Toggle event markers')
            if not hasattr(view, 'show_events'):
                view.show_events = True
            view.state_attrs += ('show_events', )

            # Read event markers from file
            filename = controller.dir_path / 'eventmarkers.txt'
            try:
                events = np.genfromtxt(filename, usecols=0, dtype=None)
            except (FileNotFoundError, OSError):
                logger.warning('Event marker file not found: `%s`.', filename)
                view.show_events = False
                return
            # A single-line file yields a 0-d array: always work with a 1-D array.
            events = np.atleast_1d(events)
            if events.size == 0:
                logger.warning('Event marker file is empty: `%s`.', filename)
                view.show_events = False
                return

            # Create list of event names
            labels = [str(i) for i in range(1, events.size + 1)]

            # Read event names from file (if present)
            filename = controller.dir_path / 'eventmarkernames.txt'
            try:
                eventnames = np.loadtxt(filename, usecols=0, dtype=str,
                                        max_rows=events.size)
                eventnames = np.atleast_1d(eventnames)
                labels[:eventnames.size] = [str(name) for name in eventnames]
            except (FileNotFoundError, OSError):
                logger.info(
                    'Event marker names file not found (optional):'
                    ' `%s`. Fall back to numbering.', filename)

            # Obtain seconds from samples (integer values are sample indices).
            if np.issubdtype(events.dtype, np.integer):
                logger.debug('Converting input from samples to seconds.')
                events = events / controller.model.sample_rate

            logger.debug('Add event markers to amplitude view.')

            # Obtain horizontal positions: each row is a segment (x, 1) -> (x, -1).
            x = -1 + 2 * events / view.duration
            x = x.repeat(4, 0).reshape(-1, 4)
            x[:, 1::2] = 1, -1

            # Add lines and update view
            line_visual.reset_batch()
            line_visual.add_batch_data(pos=x, color=self.line_color)
            view.canvas.update_visual(line_visual)

            # Add text and update view
            text_visual.reset_batch()
            text_visual.add_batch_data(pos=x[:, :2], anchor=(1, -1), text=labels)
            view.canvas.update_visual(text_visual)

            # Finally enable the menu
            logger.debug('Enable menu items.')
            view.actions.enable('Go to event')
            view.actions.enable('Toggle event markers')

            if view.show_events:
                view.actions.get('Toggle event markers').toggle()
            else:
                line_visual.hide()
                text_visual.hide()
class TraceView(ScalingMixin, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_vertical_order : array-like
        Permutation of the channels. This 1D array gives the channel id of all channels from
        top to bottom (or conversely, depending on `origin=top|bottom`).
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None,
            n_channels=None, channel_vertical_order=None, channel_labels=None, **kwargs):

        self.do_show_labels = True
        self.show_all_spikes = False
        self._scaling = 1.

        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces
        self.waveforms = None

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel permutation (lists are accepted too).
        self._channel_perm = (
            np.arange(n_channels) if channel_vertical_order is None
            else np.asarray(channel_vertical_order))
        assert self._channel_perm.shape == (n_channels,)
        # Inverse permutation: _channel_perm[channel_id] is the box index of that channel.
        self._channel_perm = np.argsort(self._channel_perm)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Box and probe scaling.
        self._origin = None

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        self.canvas.set_layout('stacked', origin=self.origin, n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        # Visuals.
        self.trace_visual = UniformPlotVisual()
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        self.canvas.add_visual(self.text_visual)

        # Make a copy of the initial box pos and size. We'll apply the scaling
        # to these quantities.
        self.box_size = np.array(self.canvas.stacked.box_size)

        # (start_time, spike_id, spike_cluster, channel_ids) of every displayed waveform,
        # used for spike click. Created before the first plot so that the waveforms drawn
        # at start-up remain clickable.
        self._waveform_times = []

        # Initial interval.
        self._interval = None
        self.go_to(duration / 2.)

    @property
    def stacked(self):
        return self.canvas.stacked

    def _permute_channels(self, x, inv=False):
        """Map channel ids to box indices, or box indices to channel ids if `inv` is True.

        `channel_vertical_order[k]` is the channel displayed in box `k`, so the box of
        channel `c` is `argsort(channel_vertical_order)[c]`, which `self._channel_perm`
        stores.
        """
        cp = self._channel_perm
        if inv:
            # The argsort of a permutation is its inverse: back to channel_vertical_order.
            cp = np.argsort(cp)
        return cp[x]

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        """Draw the `(n_samples, n_channels)` traces, one channel per box."""
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        color = color or self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        box_index = self._permute_channels(np.arange(n_ch))
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        """Add one spike waveform to the current waveform batch."""
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # The box index depends on the channel.
        box_index = self._permute_channels(bunch.channel_ids)
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=bunch.color,
            data_bounds=self.data_bounds,
        )

    def _plot_labels(self, traces):
        """Draw the channel labels at the left edge of every box."""
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self._permute_channels(ch)
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        """Round an interval to full samples and shift/clip it into `[0, duration]`."""
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def set_interval(self, interval=None, change_status=True):
        """Display the traces and spikes in a given interval."""
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        # Load the traces.
        traces = self.traces(interval)
        self.waveforms = traces.get('waveforms', [])

        if interval != self._interval:
            logger.debug("Redraw the entire trace view.")
            self._interval = interval
            start, end = interval

            # Set the status message.
            if change_status:
                self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end))

            # Find the data bounds.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

            # Plot the labels.
            if self.do_show_labels:
                self._plot_labels(traces.data)

        # Plot the waveforms (this also refreshes the spike-click lookup table).
        self.plot()

    def on_select(self, cluster_ids=None, **kwargs):
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        # Make sure we call again self.traces() when the cluster selection changes.
        self.set_interval()

    def plot(self, **kwargs):
        """Plot the waveforms."""
        waveforms = self.waveforms
        assert isinstance(waveforms, list)
        # The click lookup table must describe only the waveforms displayed now;
        # otherwise every redraw would append duplicate entries.
        self._waveform_times = []
        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(TraceView, self).attach(gui)

        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(
            self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        self.actions.add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        self.actions.add(self.switch_origin)
        self.actions.separator()

        self.actions.add(
            self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        self.actions.separator()

        self.actions.add(self.go_to_start)
        self.actions.add(self.go_to_end)
        self.actions.separator()

        self.actions.add(self.shift, prompt=True)
        self.actions.add(self.go_right)
        self.actions.add(self.go_left)
        self.actions.separator()

        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        self.actions.add(self.go_to_next_spike)
        self.actions.add(self.go_to_previous_spike)
        self.actions.separator()

        self.set_interval()

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Whether to show the channels from top to bottom (`top` option, the default), or from
        bottom to top (`bottom`)."""
        return self._origin

    @origin.setter
    def origin(self, value):
        self._origin = value
        if self.canvas.layout:
            self.canvas.layout.origin = value

    def switch_origin(self):
        """Switch between top and bottom origin for the channels."""
        self.origin = 'top' if self._origin in ('bottom', None) else 'bottom'

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Time at the center of the window."""
        return sum(self._interval) * .5

    @property
    def interval(self):
        """Interval as `(tmin, tmax)`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval."""
        if self._interval is not None:
            a, b = self._interval
            return (b - a) * .5
        else:
            return self.interval_duration * .5

    def go_to(self, time):
        """Go to a specific time (in seconds)."""
        half_dur = self.half_duration
        self.set_interval((time - half_dur, time + half_dur))

    def shift(self, delay):
        """Shift the interval by a given delay (in seconds)."""
        self.go_to(self.time + delay)

    def go_to_start(self):
        """Go to the start of the recording."""
        self.go_to(0)

    def go_to_end(self):
        """Go to end of the recording."""
        self.go_to(self.duration)

    def go_right(self):
        """Go to right."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(delay)

    def go_left(self):
        """Go to left."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(-delay)

    def _jump_to_spike(self, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        spike_times = self.get_spike_times()
        if spike_times is not None and len(spike_times):
            ind = np.searchsorted(spike_times, self.time)
            n = len(spike_times)
            self.go_to(spike_times[(ind + delta) % n])

    def go_to_next_spike(self):
        """Jump to the next spike from the first selected cluster."""
        self._jump_to_spike(+1)

    def go_to_previous_spike(self):
        """Jump to the previous spike from the first selected cluster."""
        self._jump_to_spike(-1)

    def toggle_highlighted_spikes(self, checked):
        """Toggle between showing all spikes or selected spikes."""
        self.show_all_spikes = checked
        self.set_interval()

    def widen(self):
        """Increase the interval size."""
        t, h = self.time, self.half_duration
        h *= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    def narrow(self):
        """Decrease the interval size."""
        t, h = self.time, self.half_duration
        h /= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show the channel labels if `checked` is True, hide them otherwise."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        # set_interval() only redraws labels when the interval changes, so update the
        # label visual directly to make the toggle take effect immediately.
        if checked:
            if self._interval is not None and getattr(self, 'data_bounds', None) is not None:
                self._plot_labels(None)
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def toggle_auto_scale(self, checked):
        """Toggle automatic scaling of the traces."""
        logger.debug("Set auto scale to %s.", checked)
        self.auto_scale = checked

    # Scaling
    # -------------------------------------------------------------------------

    def _apply_scaling(self):
        self.canvas.layout.scaling = (self.canvas.layout.scaling[0], self._scaling)

    @property
    def scaling(self):
        """Scaling of the channel boxes."""
        return self._scaling

    @scaling.setter
    def scaling(self, value):
        self._scaling = value
        self._apply_scaling()

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value

    # Spike selection
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by clicking on a spike."""
        if 'Control' in e.modifiers:
            # Get mouse position in NDC.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            channel_id = int(self._permute_channels(box_id, inv=True))

            # Find the spike and cluster closest to the mouse.
            db = self.data_bounds
            # Get the information about the displayed spikes on the clicked channel.
            wt = [(t, s, c, ch) for t, s, c, ch in self._waveform_times
                  if ch is not None and channel_id in ch]
            if not wt:
                return
            # Get the time coordinate of the mouse position.
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, db).apply(mouse_pos)[0][0]
            # Get the closest spike id.
            times, spike_ids, spike_clusters, _ = zip(*wt)
            i = np.argmin(np.abs(np.array(times) - mouse_time))
            # Raise the spike_click event.
            spike_id = spike_ids[i]
            cluster_id = spike_clusters[i]
            emit('spike_click', self, channel_id=channel_id,
                 spike_id=spike_id, cluster_id=cluster_id)