def __init__(
        self, templates=None, channel_ids=None, channel_labels=None,
        cluster_ids=None, cluster_color_selector=None, **kwargs):
    """Set up the template view: channels, clusters, grid layout and visuals.

    Parameters
    ----------
    templates : callable
        Stored as `self.templates` for later data loading.
    channel_ids : sized sequence
        Full list of channel ids; its length defines `self.n_channels`.
    channel_labels : list, optional
        One label per channel. Defaults to `'0'`, `'1'`, ... (channel indices).
    cluster_ids : sequence, optional
        Full list of cluster ids, forwarded to `set_cluster_ids()` when given.
    cluster_color_selector : object, optional
        Stored as-is on the instance.
    **kwargs
        Forwarded to the parent view constructor.
    """
    super(TemplateView, self).__init__(**kwargs)
    # No extra global state attributes; only the scaling is saved per view.
    self.state_attrs += ()
    self.local_state_attrs += ('scaling',)
    self.cluster_color_selector = cluster_color_selector

    # Full list of channels.
    self.channel_ids = channel_ids
    self.n_channels = len(channel_ids)

    # Channel labels, defaulting to the channel index rendered as a string.
    if channel_labels is None:
        channel_labels = ['%d' % ch for ch in range(self.n_channels)]
    self.channel_labels = channel_labels
    assert len(self.channel_labels) == self.n_channels
    # TODO: show channel and cluster labels

    # Full list of clusters.
    if cluster_ids is not None:
        self.set_cluster_ids(cluster_ids)

    canvas = self.canvas
    canvas.set_layout('grid', box_bounds=[[-1, -1, +1, +1]], has_clip=False)
    canvas.enable_axes()
    self.templates = templates

    # Main visual holding the template plots.
    self.visual = PlotVisual()
    canvas.add_visual(self.visual)
    # Mapping {cluster_id: box_index}, used to quickly reorder the subplots.
    self._cluster_box_index = {}

    # Separate visual, added after the main one.
    self.select_visual = PlotVisual()
    canvas.add_visual(self.select_visual)
def __init__(self, waveforms=None, waveforms_type=None, sample_rate=None, **kwargs):
    """Create the waveform view.

    Parameters
    ----------
    waveforms : dict or object, optional
        Either a dict `{waveforms_type_name: value}` or a single value, which is then
        wrapped as `{'waveforms': value}`.
    waveforms_type : str, optional
        Name of the initial waveforms type; must be one of the keys of `waveforms`.
    sample_rate : float
        Sample rate of the recording; required and must be strictly positive.
    **kwargs
        Forwarded to the parent view constructor.

    Raises
    ------
    AssertionError
        If `sample_rate` is missing or not positive, or if the selected waveforms type
        is not one of the provided waveforms.
    """
    # Check explicitly for None: in Python 3, `None > 0.` raises a TypeError, which
    # would hide the intended error message.
    assert sample_rate is not None and sample_rate > 0., \
        "The sample rate must be provided to the waveform view."

    self._overlap = False
    self.do_show_labels = True
    self.channel_ids = None
    self.filtered_tags = ()
    self.wave_duration = 0.  # updated in the plotting method
    self.data_bounds = None
    self.sample_rate = sample_rate
    self._status_suffix = ''

    # Initialize the view.
    super(WaveformView, self).__init__(**kwargs)
    self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
    self.local_state_attrs += ('box_scaling', 'probe_scaling')

    # Boxed layout, initialized with a single box at the origin.
    self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

    # Ensure waveforms is a dictionary, even if there is a single waveforms type.
    waveforms = waveforms or {}
    waveforms = waveforms if isinstance(waveforms, dict) else {'waveforms': waveforms}
    self.waveforms = waveforms

    # Rotating property waveforms types.
    self.waveforms_types = RotatingProperty()
    for name, value in self.waveforms.items():
        self.waveforms_types.add(name, value)
    # Current waveforms type.
    self.waveforms_types.set(waveforms_type)
    assert self.waveforms_type in self.waveforms

    self.text_visual = TextVisual()
    self.canvas.add_visual(self.text_visual)

    self.line_visual = LineVisual()
    self.canvas.add_visual(self.line_visual)

    self.tick_visual = UniformScatterVisual(
        marker='vbar', color=self.ax_color, size=self.tick_size)
    self.canvas.add_visual(self.tick_visual)

    # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
    # agg plot visual for mean and template waveforms.
    self.waveform_agg_visual = PlotAggVisual()
    self.waveform_visual = PlotVisual()
    self.canvas.add_visual(self.waveform_agg_visual)
    self.canvas.add_visual(self.waveform_visual)
def __init__(self, cluster_stat=None):
    """Create the histogram view with one stacked subplot and its three visuals.

    `cluster_stat` is stored as-is and is later called with a cluster id to obtain the
    data to display.
    """
    super(HistogramView, self).__init__()
    self.state_attrs += self._state_attrs
    self.local_state_attrs += self._local_state_attrs

    canvas = self.canvas
    canvas.set_layout(layout='stacked', n_plots=1)
    canvas.enable_axes()

    self.cluster_stat = cluster_stat

    # Histogram bars.
    self.visual = HistogramVisual()
    canvas.add_visual(self.visual)

    # Optional curve drawn on top of the histogram.
    self.plot_visual = PlotVisual()
    canvas.add_visual(self.plot_visual)

    # White text annotations.
    self.text_visual = TextVisual(color=(1., 1., 1., 1.))
    canvas.add_visual(self.text_visual)
def _create_visuals(self):
    """Set up the stacked layout (one box per channel) and the trace, waveform and
    channel-label visuals."""
    canvas = self.canvas
    canvas.set_layout('stacked', n_plots=self.n_channels)
    canvas.enable_axes(show_y=False)

    self.trace_visual = UniformPlotVisual()
    canvas.add_visual(self.trace_visual)

    self.waveform_visual = PlotVisual()
    canvas.add_visual(self.waveform_visual)

    self.text_visual = text_visual = TextVisual()
    _fix_coordinate_in_visual(text_visual, 'x')
    # v_discard is 1 when n_boxes >= 50 * u_zoom.y and the box index is not a multiple of
    # int(n_boxes / (50 * u_zoom.y)); those fragments are discarded. This appears
    # intended to thin out labels when many boxes are visible at the current zoom.
    discard_expression = (
        'float((n_boxes >= 50 * u_zoom.y) && '
        '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
    text_visual.inserter.add_varying('float', 'v_discard', discard_expression)
    text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
    canvas.add_visual(text_visual)
def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
    """Create the waveform view.

    Parameters
    ----------
    waveforms : dict or object, optional
        Either a dict `{waveforms_type_name: value}` or a single value, which is then
        wrapped as `{'waveforms': value}`. A `'waveforms'` key is required.
    waveforms_type : str, optional
        Initial waveforms type; defaults to the first key of `waveforms`.
    **kwargs
        Forwarded to the parent view constructor.
    """
    self._overlap = False
    self.do_show_labels = True
    self.channel_ids = None
    self.filtered_tags = ()

    # Initialize the view.
    super(WaveformView, self).__init__(**kwargs)
    self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
    self.local_state_attrs += ('box_scaling', 'probe_scaling')

    # Boxed layout spanning the whole viewport, plus box and probe scaling.
    canvas = self.canvas
    canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
    canvas.enable_axes()

    self._box_scaling = (1., 1.)
    self._probe_scaling = (1., 1.)
    # Copies of the layout's initial box positions and sizes.
    self.box_pos = np.array(canvas.boxed.box_pos)
    self.box_size = np.array(canvas.boxed.box_size)
    self._update_boxes()

    # Normalize to a dictionary, even if there is a single waveforms type.
    if not isinstance(waveforms, dict):
        waveforms = {'waveforms': waveforms}
    assert waveforms
    self.waveforms = waveforms
    self.waveforms_types = list(waveforms)

    # Current waveforms type (falls back to the first available type).
    self.waveforms_type = waveforms_type or self.waveforms_types[0]
    assert self.waveforms_type in waveforms
    assert 'waveforms' in waveforms

    self.text_visual = TextVisual()
    canvas.add_visual(self.text_visual)

    self.waveform_visual = PlotVisual()
    canvas.add_visual(self.waveform_visual)
def _create_visuals(self):
    """Set up the stacked layout (one box per channel) and the trace, waveform and
    channel-label visuals, with an optional color gradient on the traces."""
    canvas = self.canvas
    canvas.set_layout('stacked', n_plots=self.n_channels)
    canvas.enable_axes(show_y=False)

    self.trace_visual = trace_visual = UniformPlotVisual()
    if self.trace_color_0 and self.trace_color_1:
        # Fragment-shader color gradient between the two trace colors, driven by
        # v_signal_index / n_channels. The colors are formatted into GLSL vec3 literals.
        gradient_snippet = (
            'gl_FragColor.rgb = mix(vec3%s, vec3%s, (v_signal_index / %d));'
            % (self.trace_color_0, self.trace_color_1, self.n_channels))
        trace_visual.inserter.insert_frag(gradient_snippet, 'end')
    canvas.add_visual(trace_visual)

    self.waveform_visual = PlotVisual()
    canvas.add_visual(self.waveform_visual)

    self.text_visual = text_visual = TextVisual()
    _fix_coordinate_in_visual(text_visual, 'x')
    # v_discard is 1 when n_boxes >= 50 * u_zoom.y and the box index is not a multiple of
    # int(n_boxes / (50 * u_zoom.y)); those fragments are discarded. This appears
    # intended to thin out labels when many boxes are visible at the current zoom.
    discard_expression = (
        'float((n_boxes >= 50 * u_zoom.y) && '
        '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
    text_visual.inserter.add_varying('float', 'v_discard', discard_expression)
    text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
    canvas.add_visual(text_visual)
def __init__(self, model=None):
    """Create the waveform clustering view with lasso, axes, text and plot visuals.

    `model` is stored as-is; the selection-related attributes start as None.
    """
    super(WaveformClusteringView, self).__init__()
    canvas = self.canvas
    canvas.enable_axes()
    canvas.enable_lasso()

    # Text overlay, excluded from the pan/zoom transform.
    self.text_visual = TextVisual()
    canvas.add_visual(self.text_visual, exclude_origins=(canvas.panzoom,))

    self.model = model
    # Multiplier applied to raw waveform values; presumably converts ADC units to
    # microvolts — TODO confirm against the acquisition system.
    self.gain = 0.195
    self.Fs = 30  # kHz

    self.visual = PlotVisual()
    canvas.add_visual(self.visual)

    # Set both the current and the default (reset) zoom and pan.
    panzoom = canvas.panzoom
    panzoom.zoom = panzoom._default_zoom = (.97, .95)
    panzoom.pan = panzoom._default_pan = (-.01, 0)

    # Selection state, filled in when clusters get selected.
    self.cluster_ids = None
    self.cluster_id = None
    self.channel_ids = None
    self.wavefs = None
    self.current_channel_idx = None
def __init__(self, templates=None):
    """
    Typically, the constructor takes as arguments *functions* that take as input
    one or several cluster ids, and return as many Bunch instances which contain
    the data as NumPy arrays. Many such functions are defined in the
    TemplateController.
    """
    super(MyOpenGLView, self).__init__()

    # The View instance contains a special `canvas` object which is a `PlotCanvas`
    # instance. This class derives from `BaseCanvas` which itself derives from the
    # PyQt5 `QOpenGLWindow`. The canvas represents a rectangular black window where
    # you can draw geometric objects with OpenGL.
    #
    # phy uses the notion of **Layout** that lets you organize graphical elements in
    # different subplots. These subplots can be organized in several ways:
    #
    # * Grid layout: a `(n_rows, n_cols)` grid of subplots (example: FeatureView).
    # * Boxed layout: boxes arbitrarily located (example: WaveformView, using the
    #   probe geometry).
    # * Stacked layout: one column with `n_boxes` subplots (example: TraceView, one
    #   row per channel).
    #
    # In this example, we use the stacked layout, with one subplot per cluster. This
    # number will change at each cluster selection, depending on the number of
    # selected clusters. Initially, we just use 1 subplot.
    canvas = self.canvas
    canvas.set_layout('stacked', n_plots=1)

    self.templates = templates

    # phy uses the notion of **Visual**: a graphical element that is rendered with a
    # single type of primitive. phy provides many visuals:
    #
    # * PlotVisual (plots)
    # * ScatterVisual (points with a given marker type and different colors and sizes)
    # * LineVisual (for line segments)
    # * HistogramVisual
    # * PolygonVisual
    # * TextVisual
    # * ImageVisual
    #
    # Each visual comes with a single OpenGL program, which is defined by a vertex
    # shader and a fragment shader. These are programs written in a C-like language
    # called GLSL. A visual also comes with a primitive type, which can be points,
    # line segments, or triangles. This is all a GPU is able to render, but the
    # position and the color of these primitives can be entirely customized in the
    # shaders.
    #
    # The vertex shader acts on data arrays represented as NumPy arrays.
    #
    # These low-level details are hidden by the visuals abstraction, so it is unlikely
    # that you'll ever need to write your own visual.
    #
    # In ManualClusteringViews, you typically define one or several visuals. For
    # example, if you need to add text, you would add `self.text_visual = TextVisual()`.
    self.visual = plot_visual = PlotVisual()

    # For internal reasons, you need to add all visuals (empty for now) directly to the
    # canvas, in the view's constructor. Later, we will use the `visual.set_data()`
    # method to update the visual's data and display something in the figure.
    canvas.add_visual(plot_visual)
class MyOpenGLView(ManualClusteringView):
    """All OpenGL views derive from ManualClusteringView."""

    def __init__(self, templates=None):
        """
        Typically, the constructor takes as arguments *functions* that take as input
        one or several cluster ids, and return as many Bunch instances which contain
        the data as NumPy arrays. Many such functions are defined in the
        TemplateController.
        """
        super(MyOpenGLView, self).__init__()

        # The View instance contains a special `canvas` object which is a `PlotCanvas`
        # instance. This class derives from `BaseCanvas` which itself derives from the
        # PyQt5 `QOpenGLWindow`. The canvas represents a rectangular black window where
        # you can draw geometric objects with OpenGL.
        #
        # phy uses the notion of **Layout** that lets you organize graphical elements in
        # different subplots. These subplots can be organized in several ways:
        #
        # * Grid layout: a `(n_rows, n_cols)` grid of subplots (example: FeatureView).
        # * Boxed layout: boxes arbitrarily located (example: WaveformView, using the
        #   probe geometry).
        # * Stacked layout: one column with `n_boxes` subplots (example: TraceView, one
        #   row per channel).
        #
        # In this example, we use the stacked layout, with one subplot per cluster. This
        # number will change at each cluster selection, depending on the number of
        # selected clusters. Initially, we just use 1 subplot.
        canvas = self.canvas
        canvas.set_layout('stacked', n_plots=1)

        self.templates = templates

        # phy uses the notion of **Visual**: a graphical element that is rendered with a
        # single type of primitive. phy provides many visuals:
        #
        # * PlotVisual (plots)
        # * ScatterVisual (points with a given marker type and different colors and sizes)
        # * LineVisual (for line segments)
        # * HistogramVisual
        # * PolygonVisual
        # * TextVisual
        # * ImageVisual
        #
        # Each visual comes with a single OpenGL program, which is defined by a vertex
        # shader and a fragment shader. These are programs written in a C-like language
        # called GLSL. A visual also comes with a primitive type, which can be points,
        # line segments, or triangles. This is all a GPU is able to render, but the
        # position and the color of these primitives can be entirely customized in the
        # shaders.
        #
        # The vertex shader acts on data arrays represented as NumPy arrays.
        #
        # These low-level details are hidden by the visuals abstraction, so it is
        # unlikely that you'll ever need to write your own visual.
        #
        # In ManualClusteringViews, you typically define one or several visuals. For
        # example, if you need to add text, you would add
        # `self.text_visual = TextVisual()`.
        self.visual = plot_visual = PlotVisual()

        # For internal reasons, you need to add all visuals (empty for now) directly to
        # the canvas, in the view's constructor. Later, we will use the
        # `visual.set_data()` method to update the visual's data and display something
        # in the figure.
        canvas.add_visual(plot_visual)

    def on_select(self, cluster_ids=(), **kwargs):
        """Redraw the view whenever new clusters are selected.

        This is the main method to implement in a ManualClusteringView.

        *Note*: `cluster_ids` contains the clusters selected in the cluster view,
        followed by clusters selected in the similarity view.
        """
        # This method should always start with these few lines of code.
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return

        # The number of boxes in the stacked layout is the number of selected clusters.
        self.canvas.stacked.n_boxes = len(cluster_ids)

        # Template data of every selected cluster.
        template_data = {cid: self.templates(cid).data for cid in cluster_ids}

        # For performance reasons, it is best to use as few visuals as possible. In this
        # example, we want 1 waveform template per subplot, and we use a single visual
        # covering all subplots at once. This is the key to good performance with
        # OpenGL in Python, at the cost of a more complicated programming interface.
        #
        # In principle, we would have to concatenate all data (x and y coordinates) of
        # all subplots and pass it to `self.visual.set_data()` to draw all subplots at
        # once. This is tedious, so phy uses the notion of **batch**: for each subplot,
        # we set *partial data*, which is concatenated *after* we are done looping
        # through all clusters. The concatenation happens in the special call
        # `self.canvas.update_visual(self.visual)`.
        #
        # We need to call `visual.reset_batch()` before constructing a batch.
        visual = self.visual
        visual.reset_batch()

        for box, cid in enumerate(cluster_ids):
            data = template_data[cid]

            # In this example, we just keep the peak channel. `data.template` is a 2D
            # array `(n_samples, n_channels)` where `n_channels` is the number of "best"
            # channels for the cluster. The channels are sorted by decreasing template
            # amplitude, so the first one is the peak channel. The channel ids can be
            # found in `data.channel_ids`.
            y = data.template[:, 0]

            # On the x axis, we use values ranging from -1 to 1: this is the standard
            # viewport in OpenGL and phy.
            x = np.linspace(-1., 1., len(y))

            # phy requires you to specify explicitly the x and y range of the plots. The
            # `data_bounds` variable is a `(xmin, ymin, xmax, ymax)` tuple representing
            # the lower-left and upper-right corners of a rectangle. By default, the
            # data bounds of the entire view is (-1, -1, 1, 1), also called normalized
            # device coordinates. OpenGL uses this coordinate system for display, but
            # phy provides a transform system to convert from different coordinate
            # systems, both on the CPU and the GPU.
            #
            # Here, the x range is (-1, 1), and the y range is (min, max) of the
            # template.
            data_bounds = (-1, y.min(), +1, y.max())

            # Color of the i-th selected cluster: a 4-tuple with values between 0 and 1
            # for RGBA (red, green, blue, alpha transparency, 1 by default).
            color = selected_cluster_color(box)

            # The plot visual takes the x and y coordinates of the points, the color,
            # and the data bounds. The special keyword argument `box_index` is the
            # subplot index. In the stacked layout, this is just an integer identifying
            # the subplot, from top to bottom. In the grid layout, the box index is a
            # pair (row, col).
            visual.add_batch_data(
                x=x, y=y, color=color, data_bounds=data_bounds, box_index=box)

        # After the loop, this special call builds the data to upload to the GPU by
        # concatenating the partial data set in `add_batch_data()`.
        self.canvas.update_visual(visual)

        # After updating the data on the GPU, we need to refresh the canvas.
        self.canvas.update()
class HistogramView(ScalingMixin, ManualClusteringView):
    """This view displays a histogram for every selected cluster, along with a possible plot
    and some text. To be overridden.

    Constructor
    -----------

    cluster_stat : function
        Maps `cluster_id` to `Bunch(data (1D array), plot (1D array), text)`.

    """

    _default_position = 'right'
    cluster_ids = ()

    # Number of bins in the histogram.
    n_bins = 100

    # Minimum value on the x axis (determines the range of the histogram).
    # If None, then the bunch's `x_min` (if any), else `data.min()` is used.
    x_min = None

    # Maximum value on the x axis (determines the range of the histogram).
    # If None, then the bunch's `x_max` (if any), else `data.max()` is used.
    x_max = None

    # Unit of the bin in the set_bin_size, set_x_min, set_x_max actions.
    bin_unit = 's'  # s (seconds) or ms (milliseconds)

    # The snippets to update this view are `hn` (number of bins), `hb` (bin size), `hmin`
    # (minimum x value) and `hmax` (maximum x value). The character `h` can be customized
    # by child classes.
    alias_char = 'h'

    default_shortcuts = {
        'change_window_size': 'ctrl+wheel',
    }

    default_snippets = {
        'set_n_bins': '%sn' % alias_char,
        'set_bin_size (%s)' % bin_unit: '%sb' % alias_char,
        'set_x_min (%s)' % bin_unit: '%smin' % alias_char,
        'set_x_max (%s)' % bin_unit: '%smax' % alias_char,
    }

    _state_attrs = ('n_bins', 'x_min', 'x_max')
    _local_state_attrs = ()

    def __init__(self, cluster_stat=None):
        super(HistogramView, self).__init__()
        self.state_attrs += self._state_attrs
        self.local_state_attrs += self._local_state_attrs
        self.canvas.set_layout(layout='stacked', n_plots=1)
        self.canvas.enable_axes()

        self.cluster_stat = cluster_stat

        self.visual = HistogramVisual()
        self.canvas.add_visual(self.visual)

        self.plot_visual = PlotVisual()
        self.canvas.add_visual(self.plot_visual)

        self.text_visual = TextVisual(color=(1., 1., 1., 1.))
        self.canvas.add_visual(self.text_visual)

    def _plot_cluster(self, bunch):
        """Add the histogram, optional plot and optional text of one cluster to the
        current batches."""
        assert bunch
        n_bins = self.n_bins
        assert n_bins >= 0

        # Update the visual's data.
        self.visual.add_batch_data(
            hist=bunch.histogram, ylim=bunch.ylim, color=bunch.color, box_index=bunch.index)

        # Plot.
        plot = bunch.get('plot', None)
        if plot is not None:
            x = np.linspace(self.x_min, self.x_max, len(plot))
            self.plot_visual.add_batch_data(
                x=x,
                y=plot,
                color=(1, 1, 1, 1),
                data_bounds=self.data_bounds,
                box_index=bunch.index,
            )

        text = bunch.get('text', None)
        if not text:
            return
        # Support multiline text.
        text = text.splitlines()
        n = len(text)
        self.text_visual.add_batch_data(
            text=text,
            pos=[(-1, .8)] * n,
            anchor=[(1, -1 - 2 * i) for i in range(n)],
            box_index=bunch.index,
        )

    def get_clusters_data(self, load_all=None):
        """Compute the histogram of every selected cluster.

        Clusters with empty data are skipped. Each returned bunch gets the extra fields
        `histogram`, `ylim`, `color`, `index` and `cluster_id`. The x range is fixed the
        first time it is needed: an explicit `self.x_min`/`self.x_max` wins, otherwise
        the bunch's own `x_min`/`x_max`, otherwise the data min/max.
        """
        bunchs = []
        for i, cluster_id in enumerate(self.cluster_ids):
            bunch = self.cluster_stat(cluster_id)
            if not bunch.data.size:
                continue
            bmin, bmax = bunch.data.min(), bunch.data.max()
            # Set the x range only if it was not set before. Compare with None
            # explicitly: 0 is a valid (and common) bound that must not be replaced.
            if self.x_min is None:
                self.x_min = bunch.get('x_min', None)
                if self.x_min is None:
                    self.x_min = bmin
            if self.x_max is None:
                self.x_max = bunch.get('x_max', None)
                if self.x_max is None:
                    self.x_max = bmax
            self.x_min = min(self.x_min, self.x_max)
            assert self.x_min is not None
            assert self.x_max is not None
            assert self.x_min <= self.x_max

            # Compute the histogram.
            bunch.histogram = _compute_histogram(
                bunch.data, x_min=self.x_min, x_max=self.x_max, n_bins=self.n_bins)
            bunch.ylim = bunch.histogram.max()

            bunch.color = selected_cluster_color(i)
            bunch.index = i
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        return bunchs

    def _get_data_bounds(self, bunchs):
        # Get the axes data bounds (the last subplot's extended n_cluster times on the y axis).
        ylim = max(bunch.ylim for bunch in bunchs) if bunchs else 1
        return (self.x_min, 0, self.x_max, ylim * len(self.cluster_ids))

    def plot(self, **kwargs):
        """Update the view with the selected clusters."""
        bunchs = self.get_clusters_data()
        self.data_bounds = self._get_data_bounds(bunchs)

        self.canvas.stacked.n_boxes = len(self.cluster_ids)

        self.visual.reset_batch()
        self.plot_visual.reset_batch()
        self.text_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.plot_visual)
        self.canvas.update_visual(self.text_visual)

        self._update_axes()
        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(HistogramView, self).attach(gui)

        self.actions.add(
            self.set_n_bins, alias=self.alias_char + 'n',
            prompt=True, prompt_default=lambda: self.n_bins)
        self.actions.add(
            self.set_bin_size, alias=self.alias_char + 'b',
            prompt=True, prompt_default=lambda: self.bin_size)
        self.actions.add(
            self.set_x_min, alias=self.alias_char + 'min',
            prompt=True, prompt_default=lambda: self.x_min)
        self.actions.add(
            self.set_x_max, alias=self.alias_char + 'max',
            prompt=True, prompt_default=lambda: self.x_max)
        self.actions.separator()

    @property
    def status(self):
        f = 1 if self.bin_unit == 's' else 1000
        return '[{:.1f}{u}, {:.1f}{u:s}]'.format(
            (self.x_min or 0) * f, (self.x_max or 0) * f, u=self.bin_unit)

    # Histogram parameters
    # -------------------------------------------------------------------------

    def _get_scaling_value(self):
        return self.x_max

    def _set_scaling_value(self, value):
        # `value` is in seconds (as returned by `_get_scaling_value`), while `set_x_max`
        # expects the value in `bin_unit`.
        if self.bin_unit == 'ms':
            value *= 1000
        self.set_x_max(value)

    def set_n_bins(self, n_bins):
        """Set the number of bins in the histogram."""
        self.n_bins = n_bins
        logger.debug("Change number of bins to %d for %s.", n_bins, self.__class__.__name__)
        self.plot()

    @property
    def bin_size(self):
        """Return the bin size (in seconds or milliseconds depending on `self.bin_unit`)."""
        bs = (self.x_max - self.x_min) / self.n_bins
        if self.bin_unit == 'ms':
            bs *= 1000
        return bs

    def set_bin_size(self, bin_size):
        """Set the bin size in the histogram."""
        assert bin_size > 0
        if self.bin_unit == 'ms':
            bin_size /= 1000
        # `np.round` returns a float; the number of bins is a count, so store an int.
        self.n_bins = int(np.round((self.x_max - self.x_min) / bin_size))
        logger.debug("Change number of bins to %d for %s.", self.n_bins, self.__class__.__name__)
        self.plot()

    def set_x_min(self, x_min):
        """Set the minimum value on the x axis for the histogram."""
        if self.bin_unit == 'ms':
            x_min /= 1000
        x_min = min(x_min, self.x_max)
        if x_min == self.x_max:
            return
        self.x_min = x_min
        logger.debug("Change x min to %s for %s.", x_min, self.__class__.__name__)
        self.plot()

    def set_x_max(self, x_max):
        """Set the maximum value on the x axis for the histogram."""
        if self.bin_unit == 'ms':
            x_max /= 1000
        x_max = max(x_max, self.x_min)
        if x_max == self.x_min:
            return
        self.x_max = x_max
        logger.debug("Change x max to %s for %s.", x_max, self.__class__.__name__)
        self.plot()
class TraceView(ScalingMixin, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_vertical_order : array-like
        Permutation of the channels. This 1D array gives the channel id of all channels from
        top to bottom (or conversely, depending on `origin=top|bottom`).
        NOTE(review): `_permute_channels` draws channel `c` in box
        `channel_vertical_order[c]`; confirm this matches the intended semantics.
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """
    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None,
            n_channels=None, channel_vertical_order=None, channel_labels=None, **kwargs):

        self.do_show_labels = True
        self.show_all_spikes = False

        self._scaling = 1.

        self.get_spike_times = spike_times

        # Sample rate. Check for None explicitly: `None > 0` raises a TypeError in Python 3.
        assert sample_rate is not None and sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces
        self.waveforms = None
        # (start_time, spike_id, spike_cluster, channel_ids) of the displayed spikes,
        # used for spike click. Filled in by `plot()`.
        self._waveform_times = []

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel permutation (accept any array-like, e.g. a list).
        channel_order = (
            np.arange(n_channels) if channel_vertical_order is None
            else np.asarray(channel_vertical_order))
        assert channel_order.shape == (n_channels,)
        self._channel_perm = np.argsort(channel_order)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # Channel origin ('top' or 'bottom'); None until set.
        self._origin = None

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        self.canvas.set_layout('stacked', origin=self.origin, n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        # Visuals.
        self.trace_visual = UniformPlotVisual()
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        self.canvas.add_visual(self.text_visual)

        # Make a copy of the initial box pos and size. We'll apply the scaling
        # to these quantities.
        self.box_size = np.array(self.canvas.stacked.box_size)

        # Initial interval. This draws the traces and the spikes, and fills in
        # `self._waveform_times` (it is no longer cleared afterwards).
        self._interval = None
        self.go_to(duration / 2.)

    @property
    def stacked(self):
        return self.canvas.stacked

    def _permute_channels(self, x, inv=False):
        """Map channel indices to box indices, or box indices back to channel indices
        when `inv=True`.

        `self._channel_perm` is `argsort(channel_vertical_order)`. Its own argsort is the
        channel -> box mapping used for drawing; since the argsort of a permutation is
        its inverse, `self._channel_perm` itself is the box -> channel mapping.
        """
        cp = self._channel_perm
        if inv:
            return cp[x]
        return np.argsort(cp)[x]

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        color = color or self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        box_index = self._permute_channels(np.arange(n_ch))
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # The box index depends on the channel.
        box_index = self._permute_channels(bunch.channel_ids)
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=bunch.color,
            data_bounds=self.data_bounds,
        )

    def _plot_labels(self, traces):
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self._permute_channels(ch)
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def set_interval(self, interval=None, change_status=True):
        """Display the traces and spikes in a given interval."""
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        # Load the traces.
        traces = self.traces(interval)
        self.waveforms = traces.get('waveforms', [])

        if interval != self._interval:
            logger.debug("Redraw the entire trace view.")
            self._interval = interval
            start, end = interval

            # Set the status message.
            if change_status:
                self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end))

            # Find the data bounds.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

            # Plot the labels.
            if self.do_show_labels:
                self._plot_labels(traces.data)

        # Plot the waveforms (this also refreshes the spike-click data).
        self.plot()

    def on_select(self, cluster_ids=None, **kwargs):
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        # Make sure we call again self.traces() when the cluster selection changes.
        self.set_interval()

    def plot(self, **kwargs):
        """Plot the waveforms."""
        waveforms = self.waveforms
        assert isinstance(waveforms, list)
        # Used for spike click. Reset here so that repeated calls to `plot()` do not
        # accumulate duplicate entries.
        self._waveform_times = []
        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(TraceView, self).attach(gui)

        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(
            self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        self.actions.add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        self.actions.add(self.switch_origin)
        self.actions.separator()

        self.actions.add(
            self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        self.actions.separator()

        self.actions.add(self.go_to_start)
        self.actions.add(self.go_to_end)
        self.actions.separator()

        self.actions.add(self.shift, prompt=True)
        self.actions.add(self.go_right)
        self.actions.add(self.go_left)
        self.actions.separator()

        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        self.actions.add(self.go_to_next_spike)
        self.actions.add(self.go_to_previous_spike)
        self.actions.separator()

        self.set_interval()

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Whether to show the channels from top to bottom (`top` option, the default), or from
        bottom to top (`bottom`)."""
        return self._origin

    @origin.setter
    def origin(self, value):
        self._origin = value
        if self.canvas.layout:
            self.canvas.layout.origin = value

    def switch_origin(self):
        """Switch between top and bottom origin for the channels."""
        self.origin = 'top' if self._origin in ('bottom', None) else 'bottom'

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Time at the center of the window."""
        return sum(self._interval) * .5

    @property
    def interval(self):
        """Interval as `(tmin, tmax)`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval."""
        if self._interval is not None:
            a, b = self._interval
            return (b - a) * .5
        else:
            return self.interval_duration * .5

    def go_to(self, time):
        """Go to a specific time (in seconds)."""
        half_dur = self.half_duration
        self.set_interval((time - half_dur, time + half_dur))

    def shift(self, delay):
        """Shift the interval by a given delay (in seconds)."""
        self.go_to(self.time + delay)

    def go_to_start(self):
        """Go to the start of the recording."""
        self.go_to(0)

    def go_to_end(self):
        """Go to end of the recording."""
        self.go_to(self.duration)

    def go_right(self):
        """Go to right."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(delay)

    def go_left(self):
        """Go to left."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(-delay)

    def _jump_to_spike(self, delta=+1):
        """Jump to the next (`delta > 0`) or previous (`delta < 0`) spike from the selected
        clusters, wrapping around at both ends.

        `delta=+1` goes to the first spike after the current time, even when the window
        center lies between two spikes (the previous `searchsorted(...) + delta` skipped
        one spike in that case).
        """
        spike_times = self.get_spike_times()
        if spike_times is None or not len(spike_times):
            return
        spike_times = np.asarray(spike_times)
        n = len(spike_times)
        # The window center is rounded to full samples, so it may sit slightly before or
        # after the spike we last jumped to. A one-sample tolerance keeps that spike
        # "current" instead of treating it as the next/previous one.
        tol = self.dt
        if delta > 0:
            # Index of the first spike strictly after the current time, then further.
            first_after = np.searchsorted(spike_times, self.time + tol, side='right')
            ind = first_after + delta - 1
        else:
            # `np.searchsorted(..., side='left') - 1` is the last spike strictly before
            # the current time; `+ delta + 1` moves further back when delta < -1.
            ind = np.searchsorted(spike_times, self.time - tol, side='left') + delta
        self.go_to(spike_times[ind % n])

    def go_to_next_spike(self):
        """Jump to the next spike from the first selected cluster."""
        self._jump_to_spike(+1)

    def go_to_previous_spike(self):
        """Jump to the previous spike from the first selected cluster."""
        self._jump_to_spike(-1)

    def toggle_highlighted_spikes(self, checked):
        """Toggle between showing all spikes or selected spikes."""
        self.show_all_spikes = checked
        self.set_interval()

    def widen(self):
        """Increase the interval size."""
        t, h = self.time, self.half_duration
        h *= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    def narrow(self):
        """Decrease the interval size."""
        t, h = self.time, self.half_duration
        h /= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Toggle the display of the channel ids."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        self.set_interval()

    def toggle_auto_scale(self, checked):
        """Toggle automatic scaling of the traces."""
        logger.debug("Set auto scale to %s.", checked)
        self.auto_scale = checked

    # Scaling
    # -------------------------------------------------------------------------

    def _apply_scaling(self):
        self.canvas.layout.scaling = (self.canvas.layout.scaling[0], self._scaling)

    @property
    def scaling(self):
        """Scaling of the channel boxes."""
        return self._scaling

    @scaling.setter
    def scaling(self, value):
        self._scaling = value
        self._apply_scaling()

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value

    # Spike selection
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by clicking on a spike."""
        if 'Control' in e.modifiers:
            # Get mouse position in NDC.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            # Box -> channel (inverse of the mapping used for drawing).
            channel_id = self._permute_channels(box_id, inv=True)
            # Find the spike and cluster closest to the mouse.
            db = self.data_bounds
            # Get the information about the displayed spikes. Waveforms without
            # `channel_ids` (stored as None) cannot match any channel.
            wt = [(t, s, c, ch) for t, s, c, ch in self._waveform_times
                  if ch is not None and channel_id in ch]
            if not wt:
                return
            # Get the time coordinate of the mouse position.
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, db).apply(mouse_pos)[0][0]
            # Get the closest spike id.
            times, spike_ids, spike_clusters, channel_ids = zip(*wt)
            i = np.argmin(np.abs(np.array(times) - mouse_time))
            # Raise the spike_click event.
            spike_id = spike_ids[i]
            cluster_id = spike_clusters[i]
            emit('spike_click', self, channel_id=channel_id,
                 spike_id=spike_id, cluster_id=cluster_id)
class WaveformClusteringView(LassoMixin, ManualClusteringView):
    """Show the waveforms of the first selected cluster, one channel at a time, and let
    the user split the cluster by drawing a lasso around the waveforms to remove.

    Waveform values are multiplied by `gain` (the y axis is labelled `[uV]`) and plotted
    against a time axis in ms computed with `Fs` (kHz). When the cluster has more than
    100 spikes, the median, median +/- 3 std, and the 1% / 99% quantiles are overlaid.
    """

    default_shortcuts = {
        'next_channel': 'f',
        'previous_channel': 'r',
    }

    def __init__(self, model=None):
        super(WaveformClusteringView, self).__init__()
        self.canvas.enable_axes()
        self.canvas.enable_lasso()
        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))
        # Object providing get_cluster_channels(), get_cluster_spikes() and get_waveforms().
        self.model = model
        # Multiplier applied to the raw waveform values before plotting (axis labelled uV).
        self.gain = 0.195
        self.Fs = 30  # kHz
        self.visual = PlotVisual()
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.97, .95)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.01, 0)
        self.cluster_ids = None
        self.cluster_id = None
        self.channel_ids = None
        self.spike_ids = None
        self.wavefs = None
        self.current_channel_idx = None

    def on_select(self, cluster_ids=(), **kwargs):
        """Load the waveforms of the first selected cluster when it changes."""
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        if self.cluster_id != cluster_ids[0]:
            self.cluster_id = cluster_ids[0]
            self.channel_ids = self.model.get_cluster_channels(self.cluster_id)
            self.spike_ids = self.model.get_cluster_spikes(self.cluster_id)
            self.wavefs = self.model.get_waveforms(self.spike_ids, channel_ids=self.channel_ids)
            self.setChannelIdx(0)

    @staticmethod
    def _round_amplitude(peak):
        """Round a (gain-scaled) peak amplitude *up* to a readable y bound.

        The bound is a multiple of 10 below 100, of 100 below 1000 and of 1000 above, and
        is never smaller than the peak, so no waveform is clipped. A zero (or NaN) peak
        gives 10 so the data bounds never collapse to an empty range.
        """
        if peak < 100:
            step = 10
        elif peak < 1000:
            step = 100
        else:
            step = 1000
        bound = step * np.ceil(peak / step)
        return bound if bound > 0 else 10.

    def _time_axis(self, n_samples):
        """Time coordinates (ms) of the samples of one waveform, centred on zero."""
        half = n_samples / 2 / self.Fs
        return np.linspace(-half, half, n_samples)

    def plotWaveforms(self):
        """Redraw all waveforms of the current cluster on the current channel."""
        Nspk, Ntime = self.wavefs.shape[:2]
        waves = self.wavefs[:, :, self.current_channel_idx]
        self.visual.reset_batch()
        self.text_visual.reset_batch()
        x = np.tile(self._time_axis(Ntime), (Nspk, 1))
        M = self._round_amplitude(np.max(np.abs(waves)) * self.gain)
        self.data_bounds = (x[0][0], -M, x[0][-1], M)
        colorwavef = selected_cluster_color(0)
        colormedian = selected_cluster_color(3)
        colorstd = (0, 1, 0, 1)
        colorqtl = (1, 1, 0, 1)
        self.visual.add_batch_data(
            x=x, y=self.gain * waves, color=colorwavef,
            data_bounds=self.data_bounds, box_index=0)
        # Summary statistics, only meaningful with enough spikes.
        if Nspk > 100:
            medianCl = np.median(waves, axis=0)
            stdCl = np.std(waves, axis=0)
            q1 = np.quantile(waves, .01, axis=0, interpolation='higher')
            q9 = np.quantile(waves, .99, axis=0, interpolation='lower')
            x1 = x[0]
            for y, color in (
                    (medianCl, colormedian),
                    (medianCl + 3 * stdCl, colorstd),
                    (medianCl - 3 * stdCl, colorstd),
                    (q1, colorqtl),
                    (q9, colorqtl)):
                self.visual.add_batch_data(
                    x=x1, y=self.gain * y, color=color,
                    data_bounds=self.data_bounds, box_index=0)
        # Axes units and channel label.
        self.text_visual.add_batch_data(
            pos=[.9, .98], text='[uV]', anchor=[-1, -1], box_index=0,
        )
        self.text_visual.add_batch_data(
            pos=[-1, -.95], text='[ms]', anchor=[1, 1], box_index=0,
        )
        label = 'Ch {a}'.format(a=self.channel_ids[self.current_channel_idx])
        self.text_visual.add_batch_data(
            pos=[-.98, .98], text=str(label), anchor=[1, -1], box_index=0,
        )
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.text_visual)
        self.canvas.axes.reset_data_bounds(self.data_bounds)
        self.canvas.update()

    def setChannel(self, channel_id):
        """Show the given channel id, if it belongs to the current cluster's channels."""
        if self.channel_ids is None:
            return
        itemindex = np.where(np.asarray(self.channel_ids) == channel_id)[0]
        if len(itemindex):
            self.setChannelIdx(itemindex[0])

    def setChannelIdx(self, channel_idx):
        """Show the channel at the given position in `channel_ids`."""
        self.current_channel_idx = channel_idx
        self.plotWaveforms()

    def setNextChannelIdx(self):
        """Move to the next channel, stopping at the last one (no-op before any selection)."""
        if self.current_channel_idx is None or self.current_channel_idx >= len(self.channel_ids) - 1:
            return
        self.setChannelIdx(self.current_channel_idx + 1)

    def setPrevChannelIdx(self):
        """Move to the previous channel, stopping at the first one (no-op before any selection)."""
        if self.current_channel_idx is None or self.current_channel_idx <= 0:
            return
        self.setChannelIdx(self.current_channel_idx - 1)

    def on_request_split(self, sender=None):
        """Return the spikes to keep after a lasso split on the current channel.

        The lasso encloses the waveforms to *remove*: the returned array holds the spikes of
        the current cluster that are NOT enclosed, so the selected cluster remains the one
        being worked on. An empty int64 array means "no split".
        """
        if (self.canvas.lasso.count < 3 or
                self.cluster_ids is None or not len(self.cluster_ids)):  # pragma: no cover
            return np.array([], dtype=np.int64)
        spike_ids = np.asarray(self.spike_ids)
        if not len(spike_ids):  # pragma: no cover
            logger.warning("Empty lasso.")
            return np.array([], dtype=np.int64)
        Ntime = self.wavefs.shape[1]
        x = self._time_axis(Ntime)
        # One (time, amplitude) point per sample of every waveform, spike-major order,
        # matching the coordinates used in plotWaveforms().
        ys = (self.gain * self.wavefs[:, :, self.current_channel_idx]).ravel()
        pos = np.c_[np.tile(x, len(spike_ids)), ys]
        pos = range_transform([self.data_bounds], [NDC], pos)
        point_spike_ids = np.repeat(spike_ids, Ntime)
        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()
        spikes_to_remove = np.unique(point_spike_ids[ind])
        keep = np.isin(spike_ids, spikes_to_remove, invert=True)
        if not keep.any():
            return np.array([], dtype=np.int64)
        return spike_ids[keep]
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.

    """

    _default_position = 'right'
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }

    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, **kwargs):
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_bounds=[[-1, -1, +1, +1]])
        self.canvas.enable_axes()

        self._box_scaling = (1., 1.)
        self._probe_scaling = (1., 1.)

        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms if isinstance(waveforms, dict) else {
            'waveforms': waveforms
        }
        assert waveforms
        self.waveforms = waveforms
        self.waveforms_types = list(waveforms.keys())
        # Current waveforms type.
        self.waveforms_type = waveforms_type or self.waveforms_types[0]
        assert self.waveforms_type in waveforms
        assert 'waveforms' in waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Data bounds `[-1, ymin, +1, ymax]` spanning the waveforms of all bunchs."""
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        return [-1, m, +1, M]

    def get_clusters_data(self):
        """Load the current waveforms type for every selected cluster, with plot metadata."""
        bunchs = [
            self.waveforms[self.waveforms_type](cluster_id)
            for cluster_id in self.cluster_ids
        ]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1
        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms of one cluster to the waveform visual batch."""
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        n_spikes_clu, n_samples = wave.shape[:2]
        masks = bunch.get('masks', None)
        if masks is None:
            masks = np.ones((n_spikes_clu, n_channels))
        masks = np.asarray(masks)
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0. The cluster index is added for the depth computation on the GPU.
        # Build a new array: modifying `masks` in place would corrupt (and, on replot,
        # compound into) an array owned by the caller, and fails for integer masks.
        masks = masks * .99999 + bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.tile(box_index, n_spikes_clu)
        assert box_index.shape == (n_spikes_clu * n_channels * n_samples, )

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

        self.waveform_visual.add_batch_data(
            x=t, y=wave, color=bunch.color, masks=masks, box_index=box_index,
            data_bounds=self.data_bounds)

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Rebuild the channel labels (one per box) and apply `do_show_labels`."""
        # Labels are rebuilt even while hidden, so that turning them back on never
        # shows the labels of a previous cluster selection.
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)
        if self.do_show_labels:
            self.text_visual.show()
        else:
            self.text_visual.hide()

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.channel_ids)
            })

        # Update the box bounds as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.box_bounds = _get_box_bounds(bunchs, channel_ids)
        self.box_pos = np.array(self.canvas.boxed.box_pos)
        self.box_size = np.array(self.canvas.boxed.box_size)
        self._update_boxes()

        self.data_bounds = self._get_data_bounds(bunchs)

        self.waveform_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.waveform_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)
        self._update_axes(bunchs)
        self.canvas.update()

    def _update_axes(self, bunchs):
        """Update the axes."""
        # Update the axes data bounds.
        _, m, _, M = self.data_bounds
        # Waveform duration, scaled by overlap factor if needed.
        wave_dur = bunchs[0].get('waveform_duration', 1.)
        wave_dur /= .5 * (1 + _overlap_transform(
            1, n=len(self.cluster_ids), overlap=self.overlap))
        x1, y1 = range_transform(self.canvas.boxed.box_bounds[0], NDC, [wave_dur, M - m])
        axes_data_bounds = (0, 0, x1, y1)
        self.canvas.axes.reset_data_bounds(axes_data_bounds, do_update=True)

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap, checkable=True, checked=self.overlap)
        self.actions.add(
            self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def _update_boxes(self):
        self.canvas.boxed.update_boxes(self.box_pos * self.probe_scaling, self.box_size)

    def _apply_box_scaling(self):
        self.canvas.layout.scaling = self._box_scaling

    @property
    def box_scaling(self):
        """Scaling of the channel boxes."""
        return self._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        assert len(value) == 2
        self._box_scaling = value
        self._apply_box_scaling()

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w * self._scaling_param_increment, h)
        self._apply_box_scaling()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        w, h = self._box_scaling
        self._box_scaling = (w / self._scaling_param_increment, h)
        self._apply_box_scaling()

    def _get_scaling_value(self):
        return self._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self._box_scaling
        self.box_scaling = (w, value)
        self._update_boxes()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        """Scaling of the entire probe."""
        return self._probe_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        assert len(value) == 2
        self._probe_scaling = value
        self._update_boxes()

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w * self._scaling_param_increment, h)
        self._update_boxes()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w / self._scaling_param_increment, h)
        self._update_boxes()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h * self._scaling_param_increment)
        self._update_boxes()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        w, h = self._probe_scaling
        self._probe_scaling = (w, h / self._scaling_param_increment)
        self._update_boxes()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        self.text_visual.show() if checked else self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view."""
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # Nothing is plotted yet: no channel to click on.
            if self.channel_ids is None:
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            if not 0 <= channel_idx < len(self.channel_ids):
                return
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.", channel_id, key, b)
            emit('channel_click', self, channel_id=channel_id, key=key, button=b)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        i = self.waveforms_types.index(self.waveforms_type)
        n = len(self.waveforms_types)
        self.waveforms_type = self.waveforms_types[(i + 1) % n]
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available."""
        if self.waveforms_type == 'mean_waveforms':
            self.waveforms_type = 'waveforms'
            self.plot()
        elif 'mean_waveforms' in self.waveforms_types:
            self.waveforms_type = 'mean_waveforms'
            self.plot()
class TraceView(ScalingMixin, BaseColorView, ManualClusteringView):
    """This view shows the raw traces along with spike waveforms.

    Constructor
    -----------

    traces : function
        Maps a time interval `(t0, t1)` to a `Bunch(data, color, waveforms)` where
        * `data` is an `(n_samples, n_channels)` array
        * `waveforms` is a list of bunchs with the following attributes:
            * `data`
            * `color`
            * `channel_ids`
            * `start_time`
            * `spike_id`
            * `spike_cluster`

    spike_times : function
        Returns the list of relevant spike times.
    sample_rate : float
    duration : float
    n_channels : int
    channel_positions : array-like
        Positions of the channels, used for displaying the channels in the right y order
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.

    """

    _default_position = 'left'
    auto_update = True
    auto_scale = True
    interval_duration = .25  # default duration of the interval
    shift_amount = .1
    scaling_coeff_x = 1.25
    trace_quantile = .01  # quantile for auto-scaling
    default_trace_color = (.5, .5, .5, 1)
    trace_color_0 = (.353, .161, .443)
    trace_color_1 = (.133, .404, .396)
    default_shortcuts = {
        'change_trace_size': 'ctrl+wheel',
        'switch_color_scheme': 'shift+wheel',
        'navigate': 'alt+wheel',
        'decrease': 'alt+down',
        'increase': 'alt+up',
        'go_left': 'alt+left',
        'go_right': 'alt+right',
        'jump_left': 'shift+alt+left',
        'jump_right': 'shift+alt+right',
        'go_to_start': 'alt+home',
        'go_to_end': 'alt+end',
        'go_to': 'alt+t',
        'go_to_next_spike': 'alt+pgdown',
        'go_to_previous_spike': 'alt+pgup',
        'narrow': 'alt++',
        'select_spike': 'ctrl+click',
        'select_channel_pcA': 'shift+left click',
        'select_channel_pcB': 'shift+right click',
        'switch_origin': 'alt+o',
        'toggle_highlighted_spikes': 'alt+s',
        'toggle_show_labels': 'alt+l',
        'widen': 'alt+-',
    }
    default_snippets = {
        'go_to': 'tg',
        'shift': 'ts',
    }

    def __init__(
            self, traces=None, sample_rate=None, spike_times=None, duration=None,
            n_channels=None, channel_positions=None, channel_labels=None, **kwargs):

        self.do_show_labels = True
        self.show_all_spikes = False

        self.get_spike_times = spike_times

        # Sample rate.
        assert sample_rate > 0
        self.sample_rate = float(sample_rate)
        self.dt = 1. / self.sample_rate

        # Traces and spikes.
        assert hasattr(traces, '__call__')
        self.traces = traces

        assert duration >= 0
        self.duration = duration

        assert n_channels >= 0
        self.n_channels = n_channels

        # Channel y ranking.
        self.channel_positions = (
            channel_positions if channel_positions is not None else
            np.c_[np.zeros(n_channels), np.arange(n_channels)])
        # channel_y_ranks[i] is the position of channel #i in the trace view.
        self.channel_y_ranks = np.argsort(np.argsort(self.channel_positions[:, 1]))
        assert self.channel_y_ranks.shape == (n_channels,)

        # Channel labels.
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in range(n_channels)])
        assert len(self.channel_labels) == n_channels

        # (start_time, spike_id, spike_cluster, channel_ids) of the displayed waveforms,
        # used for spike clicking. Must exist before the first plot below, which fills it.
        self._waveform_times = []

        # Initialize the view.
        super(TraceView, self).__init__(**kwargs)
        self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
        self.local_state_attrs += ('interval', 'scaling',)

        # Visuals.
        self._create_visuals()

        # Initial interval.
        self._interval = None
        self.go_to(duration / 2.)

        self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))

    def _create_visuals(self):
        self.canvas.set_layout('stacked', n_plots=self.n_channels)
        self.canvas.enable_axes(show_y=False)

        self.trace_visual = UniformPlotVisual()
        # Gradient of color for the traces.
        if self.trace_color_0 and self.trace_color_1:
            self.trace_visual.inserter.insert_frag(
                'gl_FragColor.rgb = mix(vec3%s, vec3%s, (v_signal_index / %d));' % (
                    self.trace_color_0, self.trace_color_1, self.n_channels), 'end')
        self.canvas.add_visual(self.trace_visual)

        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_visual)

        self.text_visual = TextVisual()
        _fix_coordinate_in_visual(self.text_visual, 'x')
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_boxes >= 50 * u_zoom.y) && '
            '(mod(int(a_box_index), int(n_boxes / (50 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)

    @property
    def stacked(self):
        return self.canvas.stacked

    # Internal methods
    # -------------------------------------------------------------------------

    def _plot_traces(self, traces, color=None):
        """Plot the `(n_samples, n_channels)` traces of the current interval."""
        traces = traces.T
        n_samples = traces.shape[1]
        n_ch = self.n_channels
        assert traces.shape == (n_ch, n_samples)
        # `is None` rather than `or`: `or` is ambiguous for array-valued colors.
        if color is None:
            color = self.default_trace_color

        t = self._interval[0] + np.arange(n_samples) * self.dt
        t = np.tile(t, (n_ch, 1))

        box_index = self.channel_y_ranks
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=1)

        assert t.shape == (n_ch, n_samples)
        assert traces.shape == (n_ch, n_samples)
        assert box_index.shape == (n_ch, n_samples)

        self.trace_visual.color = color
        self.canvas.update_visual(
            self.trace_visual,
            t, traces,
            data_bounds=self.data_bounds,
            box_index=box_index.ravel(),
        )

    def _plot_spike(self, bunch):
        """Add one spike waveform to the waveform visual batch."""
        # The spike time corresponds to the first sample of the waveform.
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Generate the x coordinates of the waveform.
        t = bunch.start_time + self.dt * np.arange(n_samples)
        t = np.tile(t, (n_channels, 1))  # (n_unmasked_channels, n_samples)

        # Determine the spike color.
        i = bunch.select_index
        c = bunch.spike_cluster
        cs = self.color_schemes.get()
        color = selected_cluster_color(i, alpha=1) if i is not None else cs.get(c, alpha=1)

        # The box index depends on the channel.
        box_index = self.channel_y_ranks[bunch.channel_ids]
        box_index = np.repeat(box_index[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=box_index,
            x=t, y=bunch.data.T, color=color,
            data_bounds=self.data_bounds,
        )

    def _plot_waveforms(self, waveforms, **kwargs):
        """Plot the spike waveforms and record their times for spike clicking."""
        assert isinstance(waveforms, list)
        # Reset the click lookup table together with the waveforms, so that redrawing
        # only the waveforms does not accumulate duplicate entries.
        self._waveform_times = []
        if waveforms:
            self.waveform_visual.show()
            self.waveform_visual.reset_batch()
            for w in waveforms:
                self._plot_spike(w)
                self._waveform_times.append(
                    (w.start_time, w.spike_id, w.spike_cluster, w.get('channel_ids', None)))
            self.canvas.update_visual(self.waveform_visual)
        else:  # pragma: no cover
            self.waveform_visual.hide()

    def _plot_labels(self, traces):
        """Plot one label per channel at the left edge of its box (`traces` is unused)."""
        self.text_visual.reset_batch()
        for ch in range(self.n_channels):
            bi = self.channel_y_ranks[ch]
            ch_label = self.channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[self.data_bounds[0], 0],
                text=ch_label,
                anchor=[+1., 0],
                data_bounds=self.data_bounds,
                box_index=bi,
            )
        self.canvas.update_visual(self.text_visual)

    # Public methods
    # -------------------------------------------------------------------------

    def _restrict_interval(self, interval):
        start, end = interval
        # Round the times to full samples to avoid subsampling shifts
        # in the traces.
        start = int(round(start * self.sample_rate)) / self.sample_rate
        end = int(round(end * self.sample_rate)) / self.sample_rate
        # Restrict the interval to the boundaries of the traces.
        if start < 0:
            end += (-start)
            start = 0
        elif end >= self.duration:
            start -= (end - self.duration)
            end = self.duration
        start = np.clip(start, 0, end)
        end = np.clip(end, start, self.duration)
        assert 0 <= start < end <= self.duration
        return start, end

    def plot(self, update_traces=True, update_waveforms=True):
        """Redraw the traces and/or the spike waveforms of the current interval."""
        # Both branches need the traces bunch: load it if either is requested.
        if update_traces or update_waveforms:
            traces = self.traces(self._interval)

        if update_traces:
            logger.log(5, "Redraw the entire trace view.")
            start, end = self._interval

            # Find the data bounds.
            if self.auto_scale or getattr(self, 'data_bounds', NDC) == NDC:
                ymin = np.quantile(traces.data, self.trace_quantile)
                ymax = np.quantile(traces.data, 1. - self.trace_quantile)
            else:
                ymin, ymax = self.data_bounds[1], self.data_bounds[3]
            self.data_bounds = (start, ymin, end, ymax)

            # Plot the traces.
            self._plot_traces(
                traces.data, color=traces.get('color', None))

            # Plot the labels.
            if self.do_show_labels:
                self._plot_labels(traces.data)

        if update_waveforms:
            self._plot_waveforms(traces.get('waveforms', []))

        self._update_axes()
        self.canvas.update()

    def set_interval(self, interval=None):
        """Display the traces and spikes in a given interval."""
        if interval is None:
            interval = self._interval
        interval = self._restrict_interval(interval)

        if interval != self._interval:
            logger.log(5, "Redraw the entire trace view.")
            self._interval = interval
            emit('is_busy', self, True)
            self.plot(update_traces=True, update_waveforms=True)
            emit('is_busy', self, False)
            emit('time_range_selected', self, interval)
            self.update_status()
        else:
            self.plot(update_traces=False, update_waveforms=True)

    def on_select(self, cluster_ids=None, **kwargs):
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        # Make sure we call again self.traces() when the cluster selection changes.
        self.set_interval()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(TraceView, self).attach(gui)

        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(
            self.toggle_highlighted_spikes, checkable=True, checked=self.show_all_spikes)
        self.actions.add(self.toggle_auto_scale, checkable=True, checked=self.auto_scale)
        self.actions.add(self.switch_origin)
        self.actions.separator()

        self.actions.add(
            self.go_to, prompt=True, prompt_default=lambda: str(self.time))
        self.actions.separator()

        self.actions.add(self.go_to_start)
        self.actions.add(self.go_to_end)
        self.actions.separator()

        self.actions.add(self.shift, prompt=True)
        self.actions.add(self.go_right)
        self.actions.add(self.go_left)
        self.actions.add(self.jump_right)
        self.actions.add(self.jump_left)
        self.actions.separator()

        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        self.actions.add(self.go_to_next_spike)
        self.actions.add(self.go_to_previous_spike)
        self.actions.separator()

        self.set_interval()

    @property
    def status(self):
        a, b = self._interval
        return '[{:.2f}s - {:.2f}s]. Color scheme: {}.'.format(a, b, self.color_scheme)

    # Origin
    # -------------------------------------------------------------------------

    @property
    def origin(self):
        """Whether to show the channels from top to bottom (`top` option, the default), or from
        bottom to top (`bottom`)."""
        return getattr(self.canvas.layout, 'origin', Stacked._origin)

    @origin.setter
    def origin(self, value):
        if value is None:
            return
        if self.canvas.layout:
            self.canvas.layout.origin = value
        else:  # pragma: no cover
            logger.warning(
                "Could not set origin to %s because the layout instance was not initialized yet.",
                value)

    def switch_origin(self):
        """Switch between top and bottom origin for the channels."""
        self.origin = 'bottom' if self.origin == 'top' else 'top'

    # Navigation
    # -------------------------------------------------------------------------

    @property
    def time(self):
        """Time at the center of the window."""
        return sum(self._interval) * .5

    @property
    def interval(self):
        """Interval as `(tmin, tmax)`."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self.set_interval(value)

    @property
    def half_duration(self):
        """Half of the duration of the current interval."""
        if self._interval is not None:
            a, b = self._interval
            return (b - a) * .5
        else:
            return self.interval_duration * .5

    def go_to(self, time):
        """Go to a specific time (in seconds)."""
        half_dur = self.half_duration
        self.set_interval((time - half_dur, time + half_dur))

    def shift(self, delay):
        """Shift the interval by a given delay (in seconds)."""
        self.go_to(self.time + delay)

    def go_to_start(self):
        """Go to the start of the recording."""
        self.go_to(0)

    def go_to_end(self):
        """Go to end of the recording."""
        self.go_to(self.duration)

    def go_right(self):
        """Go to right."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(delay)

    def go_left(self):
        """Go to left."""
        start, end = self._interval
        delay = (end - start) * .1
        self.shift(-delay)

    def jump_right(self):
        """Jump to right."""
        delay = self.duration * .1
        self.shift(delay)

    def jump_left(self):
        """Jump to left."""
        delay = self.duration * .1
        self.shift(-delay)

    def _jump_to_spike(self, delta=+1):
        """Jump to the `delta`-th spike after (delta > 0) or before (delta <= 0) the
        current time, wrapping around at both ends of the spike list.

        Spikes within one sample of the current time count as the current position. This
        way the sample rounding done by `go_to()` never makes a jump skip the immediately
        next spike or land on the spike already displayed.
        """
        spike_times = self.get_spike_times()
        if spike_times is None or not len(spike_times):
            return
        n = len(spike_times)
        if delta > 0:
            # Index of the first spike strictly after the current position.
            ind = np.searchsorted(spike_times, self.time + self.dt, side='right')
            target = ind + delta - 1
        else:
            # Index of the first spike at or after the current position.
            ind = np.searchsorted(spike_times, self.time - self.dt, side='left')
            target = ind + delta
        self.go_to(spike_times[target % n])

    def go_to_next_spike(self, ):
        """Jump to the next spike from the first selected cluster."""
        self._jump_to_spike(+1)

    def go_to_previous_spike(self, ):
        """Jump to the previous spike from the first selected cluster."""
        self._jump_to_spike(-1)

    def toggle_highlighted_spikes(self, checked):
        """Toggle between showing all spikes or selected spikes."""
        self.show_all_spikes = checked
        self.set_interval()

    def widen(self):
        """Increase the interval size."""
        t, h = self.time, self.half_duration
        h *= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    def narrow(self):
        """Decrease the interval size."""
        t, h = self.time, self.half_duration
        h /= self.scaling_coeff_x
        self.set_interval((t - h, t + h))

    # Misc
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Show or hide the channel labels according to `checked`."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        if checked:
            # Labels are not refreshed while hidden: rebuild them for the current bounds.
            self._plot_labels(None)
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()

    def toggle_auto_scale(self, checked):
        """Toggle automatic scaling of the traces."""
        logger.debug("Set auto scale to %s.", checked)
        self.auto_scale = checked

    def update_color(self):
        """Update the view when the color scheme changes."""
        self.plot(update_traces=False, update_waveforms=True)

    # Scaling
    # -------------------------------------------------------------------------

    @property
    def scaling(self):
        """Scaling of the channel boxes."""
        return self.stacked._box_scaling[1]

    @scaling.setter
    def scaling(self, value):
        self.stacked._box_scaling = (self.stacked._box_scaling[0], value)

    def _get_scaling_value(self):
        return self.scaling

    def _set_scaling_value(self, value):
        self.scaling = value
        self.stacked.update()

    # Spike selection
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a spike with ctrl+click, or a channel with shift+click."""
        if 'Control' in e.modifiers:
            # Get mouse position in NDC.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            # 1-element array with the channel displayed in the clicked box.
            channel_id = np.nonzero(self.channel_y_ranks == box_id)[0]
            # Find the spike and cluster closest to the mouse.
            db = self.data_bounds
            # Displayed spikes on that channel (waveforms without channel ids are skipped).
            wt = [(t, s, c, ch) for t, s, c, ch in self._waveform_times
                  if ch is not None and channel_id in ch]
            if not wt:
                return
            # Get the time coordinate of the mouse position.
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, db).apply(mouse_pos)[0][0]
            # Get the closest spike id.
            times, spike_ids, spike_clusters, channel_ids = zip(*wt)
            i = np.argmin(np.abs(np.array(times) - mouse_time))
            # Raise the select_spike event.
            spike_id = spike_ids[i]
            cluster_id = spike_clusters[i]
            emit('select_spike', self, channel_id=channel_id,
                 spike_id=spike_id, cluster_id=cluster_id)

        if 'Shift' in e.modifiers:
            # Get mouse position in NDC.
            box_id, _ = self.canvas.stacked.box_map(e.pos)
            channel_id = int(np.nonzero(self.channel_y_ranks == box_id)[0][0])
            emit('select_channel', self, channel_id=channel_id, button=e.button)

    def on_mouse_wheel(self, e):  # pragma: no cover
        """Scroll through the data with alt+wheel."""
        super(TraceView, self).on_mouse_wheel(e)
        if e.modifiers == ('Alt',):
            start, end = self._interval
            delay = e.delta * (end - start) * .1
            self.shift(-delay)
class WaveformView(ScalingMixin, ManualClusteringView):
    """This view shows the waveforms of the selected clusters, on relevant channels,
    following the probe geometry.

    Constructor
    -----------

    waveforms : dict of functions
        Every function maps a cluster id to a Bunch with the following attributes:

        * `data` : a 3D array `(n_spikes, n_samples, n_channels_loc)`
        * `channel_ids` : the channel ids corresponding to the third dimension in `data`
        * `channel_labels` : a list of channel labels for every channel in `channel_ids`
        * `channel_positions` : a 2D array with the coordinates of the channels on the probe
        * `masks` : a 2D array `(n_spikes, n_channels)` with the waveforms masks
        * `alpha` : the alpha transparency channel

        The keys of the dictionary are called **waveform types**. The `next_waveforms_type`
        action cycles through all available waveform types. The key `waveforms` is mandatory.
    waveforms_type : str
        Default key of the waveforms dictionary to plot initially.
    sample_rate : float
        Sampling rate (must be positive); used to compute the waveform duration and the
        position of the millisecond ticks.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    ax_color = (.75, .75, .75, 1.)
    tick_size = 5.
    cluster_ids = ()

    default_shortcuts = {
        'toggle_waveform_overlap': 'o',
        'toggle_show_labels': 'ctrl+l',
        'next_waveforms_type': 'w',
        'previous_waveforms_type': 'shift+w',
        'toggle_mean_waveforms': 'm',

        # Box scaling.
        'widen': 'ctrl+right',
        'narrow': 'ctrl+left',
        'increase': 'ctrl+up',
        'decrease': 'ctrl+down',
        'change_box_size': 'ctrl+wheel',

        # Probe scaling.
        'extend_horizontally': 'shift+right',
        'shrink_horizontally': 'shift+left',
        'extend_vertically': 'shift+up',
        'shrink_vertically': 'shift+down',
    }
    default_snippets = {
        'change_n_spikes_waveforms': 'wn',
    }

    def __init__(self, waveforms=None, waveforms_type=None, sample_rate=None, **kwargs):
        self._overlap = False
        self.do_show_labels = True
        self.channel_ids = None
        self.filtered_tags = ()
        self.wave_duration = 0.  # updated in the plotting method
        self.data_bounds = None
        self.sample_rate = sample_rate
        self._status_suffix = ''
        # Check for None explicitly: in Python 3, `None > 0.` raises a TypeError, which
        # would hide the intended error message below.
        assert sample_rate is not None and sample_rate > 0., \
            "The sample rate must be provided to the waveform view."

        # Initialize the view.
        super(WaveformView, self).__init__(**kwargs)
        self.state_attrs += ('waveforms_type', 'overlap', 'do_show_labels')
        self.local_state_attrs += ('box_scaling', 'probe_scaling')

        # Box and probe scaling.
        self.canvas.set_layout('boxed', box_pos=np.zeros((1, 2)))

        # Ensure waveforms is a dictionary, even if there is a single waveforms type.
        waveforms = waveforms or {}
        waveforms = waveforms if isinstance(waveforms, dict) else {'waveforms': waveforms}
        self.waveforms = waveforms

        # Rotating property waveforms types.
        self.waveforms_types = RotatingProperty()
        for name, value in self.waveforms.items():
            self.waveforms_types.add(name, value)
        # Current waveforms type.
        self.waveforms_types.set(waveforms_type)
        assert self.waveforms_type in self.waveforms

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

        self.tick_visual = UniformScatterVisual(
            marker='vbar', color=self.ax_color, size=self.tick_size)
        self.canvas.add_visual(self.tick_visual)

        # Two types of visuals: thin raw line visual for normal waveforms, thick antialiased
        # agg plot visual for mean and template waveforms.
        self.waveform_agg_visual = PlotAggVisual()
        self.waveform_visual = PlotVisual()
        self.canvas.add_visual(self.waveform_agg_visual)
        self.canvas.add_visual(self.waveform_visual)

    # Internal methods
    # -------------------------------------------------------------------------

    @property
    def _current_visual(self):
        """Visual used for the current waveforms type: raw lines for `waveforms`,
        antialiased agg lines for every other type."""
        if self.waveforms_type == 'waveforms':
            return self.waveform_visual
        else:
            return self.waveform_agg_visual

    def _get_data_bounds(self, bunchs):
        """Return `[-1, -M, +1, M]` where M is the largest absolute data value."""
        m = min(_min(b.data) for b in bunchs)
        M = max(_max(b.data) for b in bunchs)
        # Symmetrize on the y axis.
        M = max(abs(m), abs(M))
        return [-1, -M, +1, M]

    def get_clusters_data(self):
        """Return one Bunch per selected cluster, annotated with its plotting index,
        overlap offset, number of offsets and color (None if the type is unknown)."""
        if self.waveforms_type not in self.waveforms:
            return
        bunchs = [
            self.waveforms_types.get()(cluster_id) for cluster_id in self.cluster_ids]
        clu_offsets = _get_clu_offsets(bunchs)
        n_clu = max(clu_offsets) + 1

        # Offset depending on the overlap.
        for i, (bunch, offset) in enumerate(zip(bunchs, clu_offsets)):
            bunch.index = i
            bunch.offset = offset
            bunch.n_clu = n_clu
            bunch.color = selected_cluster_color(i, bunch.get('alpha', .75))
        return bunchs

    def _plot_cluster(self, bunch):
        """Add the waveforms, y=0 axes and millisecond ticks of one cluster to the batches."""
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.tile(box_index, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual is self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (n_samples + 2)

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(
            x=t, y=wave, color=bunch.color, masks=masks, box_index=box_index,
            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines.
        ax_db = self.data_bounds
        a, b = _overlap_transform(
            np.array([-1, 1]), offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, 2)
        box_index = np.tile(box_index, n_spikes_clu)
        hpos = np.tile([[a, 0, b, 0]], (nw, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x, offset=bunch.offset, n=bunch.n_clu, overlap=self.overlap)
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, x.size // len(box_index))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x, y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )

    def _plot_labels(self, channel_ids, n_clusters, channel_labels):
        """Rebuild one text label per channel box and apply the label visibility."""
        # Always rebuild the labels, even when they are hidden, so that toggling them on
        # later shows the labels of the current channels instead of stale (or no) labels.
        self.text_visual.reset_batch()
        for i, ch in enumerate(channel_ids):
            label = channel_labels[ch]
            self.text_visual.add_batch_data(
                pos=[-1, 0],
                text=str(label),
                anchor=[-1.25, 0],
                box_index=i,
            )
        self.canvas.update_visual(self.text_visual)
        if self.do_show_labels:
            self.text_visual.show()
        else:
            self.text_visual.hide()

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        if not self.cluster_ids:
            return
        bunchs = self.get_clusters_data()
        if not bunchs:
            return

        # All channel ids appearing in all selected clusters.
        channel_ids = sorted(set(_flatten([d.channel_ids for d in bunchs])))
        self.channel_ids = channel_ids
        if bunchs[0].data is not None:
            self.wave_duration = bunchs[0].data.shape[1] / float(self.sample_rate)
        else:  # pragma: no cover
            self.wave_duration = 1.

        # Channel labels.
        channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.channel_ids])
            channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.channel_ids)})

        # Update the Boxed box positions as a function of the selected channels.
        if channel_ids:
            self.canvas.boxed.update_boxes(_get_box_pos(bunchs, channel_ids))

        # The data bounds are only computed the first time (kept afterwards).
        self.data_bounds = self.data_bounds or self._get_data_bounds(bunchs)

        self._current_visual.reset_batch()
        self.line_visual.reset_batch()
        self.tick_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.tick_visual)
        self.canvas.update_visual(self.line_visual)
        self.canvas.update_visual(self._current_visual)

        self._plot_labels(channel_ids, len(self.cluster_ids), channel_labels)

        # Only show the current waveform visual.
        if self._current_visual is self.waveform_visual:
            self.waveform_visual.show()
            self.waveform_agg_visual.hide()
        elif self._current_visual is self.waveform_agg_visual:
            self.waveform_agg_visual.show()
            self.waveform_visual.hide()

        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(WaveformView, self).attach(gui)

        self.actions.add(self.toggle_waveform_overlap, checkable=True, checked=self.overlap)
        self.actions.add(
            self.toggle_show_labels, checkable=True, checked=self.do_show_labels)
        self.actions.add(self.next_waveforms_type)
        self.actions.add(self.previous_waveforms_type)
        self.actions.add(self.toggle_mean_waveforms, checkable=True)
        self.actions.separator()

        # Box scaling.
        self.actions.add(self.widen)
        self.actions.add(self.narrow)
        self.actions.separator()

        # Probe scaling.
        self.actions.add(self.extend_horizontally)
        self.actions.add(self.shrink_horizontally)
        self.actions.separator()

        self.actions.add(self.extend_vertically)
        self.actions.add(self.shrink_vertically)
        self.actions.separator()

    @property
    def boxed(self):
        """Layout instance."""
        return self.canvas.boxed

    @property
    def status(self):
        return self.waveforms_type

    # Overlap
    # -------------------------------------------------------------------------

    @property
    def overlap(self):
        """Whether to overlap the waveforms belonging to different clusters."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        self._overlap = value
        self.plot()

    def toggle_waveform_overlap(self, checked):
        """Toggle the overlap of the waveforms."""
        self.overlap = checked

    # Box scaling
    # -------------------------------------------------------------------------

    def widen(self):
        """Increase the horizontal scaling of the waveforms."""
        self.boxed.expand_box_width()

    def narrow(self):
        """Decrease the horizontal scaling of the waveforms."""
        self.boxed.shrink_box_width()

    @property
    def box_scaling(self):
        return self.boxed._box_scaling

    @box_scaling.setter
    def box_scaling(self, value):
        self.boxed._box_scaling = value

    def _get_scaling_value(self):
        return self.boxed._box_scaling[1]

    def _set_scaling_value(self, value):
        w, h = self.boxed._box_scaling
        self.boxed._box_scaling = (w, value)
        self.boxed.update()

    # Probe scaling
    # -------------------------------------------------------------------------

    @property
    def probe_scaling(self):
        return self.boxed._layout_scaling

    @probe_scaling.setter
    def probe_scaling(self, value):
        self.boxed._layout_scaling = value

    def extend_horizontally(self):
        """Increase the horizontal scaling of the probe."""
        self.boxed.expand_layout_width()

    def shrink_horizontally(self):
        """Decrease the horizontal scaling of the probe."""
        self.boxed.shrink_layout_width()

    def extend_vertically(self):
        """Increase the vertical scaling of the probe."""
        self.boxed.expand_layout_height()

    def shrink_vertically(self):
        """Decrease the vertical scaling of the probe."""
        self.boxed.shrink_layout_height()

    # Navigation
    # -------------------------------------------------------------------------

    def toggle_show_labels(self, checked):
        """Whether to show the channel ids or not."""
        self.do_show_labels = checked
        self.text_visual.show() if checked else self.text_visual.hide()
        self.canvas.update()

    def on_mouse_click(self, e):
        """Select a channel by clicking on a box in the waveform view."""
        b = e.button
        nums = tuple('%d' % i for i in range(10))
        if 'Control' in e.modifiers or e.key in nums:
            # Before the first plot, there is no box -> channel mapping yet.
            if self.channel_ids is None:
                return
            key = int(e.key) if e.key in nums else None
            # Get mouse position in NDC.
            channel_idx, _ = self.canvas.boxed.box_map(e.pos)
            channel_id = self.channel_ids[channel_idx]
            logger.debug("Click on channel_id %d with key %s and button %s.", channel_id, key, b)
            emit('select_channel', self, channel_id=channel_id, key=key, button=b)

    @property
    def waveforms_type(self):
        return self.waveforms_types.current

    @waveforms_type.setter
    def waveforms_type(self, value):
        self.waveforms_types.set(value)

    def next_waveforms_type(self):
        """Switch to the next waveforms type."""
        self.waveforms_types.next()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def previous_waveforms_type(self):
        """Switch to the previous waveforms type."""
        self.waveforms_types.previous()
        logger.debug("Switch to waveforms type %s.", self.waveforms_type)
        self.plot()

    def toggle_mean_waveforms(self, checked):
        """Switch to the `mean_waveforms` type, if it is available."""
        if self.waveforms_type == 'mean_waveforms' and 'waveforms' in self.waveforms:
            self.waveforms_types.set('waveforms')
            logger.debug("Switch to raw waveforms.")
            self.plot()
        elif 'mean_waveforms' in self.waveforms:
            self.waveforms_types.set('mean_waveforms')
            logger.debug("Switch to mean waveforms.")
            self.plot()
class TemplateView(ScalingMixin, BaseColorView, BaseGlobalView, ManualClusteringView):
    """This view shows all template waveforms of all clusters in a large grid of shape
    `(n_channels, n_clusters)`.

    Constructor
    -----------

    templates : function
        Maps `cluster_ids` to a container indexed by cluster id (e.g.
        `{cluster_id: Bunch(template, channel_ids)}`), where `template` is an
        `(n_samples, n_channels)` array, and `channel_ids` specifies the channels of the
        `template` array (sparse format).
    channel_ids : array-like
        The list of all channel ids.
    channel_labels : list
        Labels of all shown channels. By default, this is just the channel ids.
    cluster_ids : array-like
        The list of all clusters to show initially.

    """
    _default_position = 'right'
    _scaling = 1.

    default_shortcuts = {
        'change_template_size': 'ctrl+wheel',
        'switch_color_scheme': 'shift+wheel',
        'decrease': 'ctrl+alt+-',
        'increase': 'ctrl+alt++',
        'select_cluster': 'ctrl+click',
        'select_more': 'shift+click',
    }

    def __init__(
            self, templates=None, channel_ids=None, channel_labels=None,
            cluster_ids=None, **kwargs):
        super(TemplateView, self).__init__(**kwargs)
        self.state_attrs += ()
        self.local_state_attrs += ('scaling',)

        # Full list of channels.
        self.channel_ids = channel_ids
        self.n_channels = len(channel_ids)

        # Channel labels: default to the channel ids themselves, as documented (not to
        # the positional indices, which differ when channel_ids is not 0..n-1).
        self.channel_labels = (
            channel_labels if channel_labels is not None else
            ['%d' % ch for ch in channel_ids])
        assert len(self.channel_labels) == self.n_channels
        # TODO: show channel and cluster labels

        # Full list of clusters.
        if cluster_ids is not None:
            self.set_cluster_ids(cluster_ids)

        self.canvas.set_layout('grid', has_clip=False)
        self.canvas.enable_axes()
        self.templates = templates

        self.visual = PlotVisual()
        self.canvas.add_visual(self.visual)
        # dict {cluster_id: box_index} used to quickly reorder
        self._cluster_box_index = {}

        self.select_visual = PlotVisual()
        self.canvas.add_visual(self.select_visual)

    # Internal plot functions
    # -------------------------------------------------------------------------

    def _get_data_bounds(self, bunchs):
        """Get the data bounds (symmetric in y, from the median template extrema)."""
        m = np.median([b.template.min() for b in bunchs])
        M = np.median([b.template.max() for b in bunchs])
        M = max(abs(m), abs(M))
        return [-1, -M, +1, M]

    def _get_box_index(self, bunch):
        """Get the box_index array for a cluster."""
        # Generate the box index (channel_idx, cluster_idx) per vertex.
        n_samples, _ = bunch.template.shape
        box_index = _index_of(bunch.channel_ids, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.c_[
            box_index.reshape((-1, 1)),
            bunch.cluster_idx * np.ones((n_samples * len(bunch.channel_ids), 1))]
        assert box_index.shape == (len(bunch.channel_ids) * n_samples, 2)
        assert box_index.size == bunch.template.size * 2
        return box_index

    def _plot_cluster(self, bunch, color=None):
        """Return the batch data (a Bunch) needed to plot one cluster."""
        wave = bunch.template  # shape: (n_samples, n_channels)
        channel_ids_loc = bunch.channel_ids
        n_channels_loc = len(channel_ids_loc)

        n_samples, nc = wave.shape
        assert nc == n_channels_loc

        # Find the x coordinates.
        t = get_linear_x(n_channels_loc, n_samples)

        # Explicit None check: `color or ...` is ambiguous (and raises) for array colors.
        if color is None:
            color = self.cluster_colors[bunch.cluster_rel]
        assert len(color) == 4

        box_index = self._get_box_index(bunch)

        return Bunch(
            x=t, y=wave.T, color=color, box_index=box_index, data_bounds=self.data_bounds)

    def set_cluster_ids(self, cluster_ids):
        """Update the cluster ids when their identity or order has changed."""
        if cluster_ids is None or not len(cluster_ids):
            return
        self.all_cluster_ids = np.array(cluster_ids, dtype=np.int32)
        # Permutation of the clusters.
        self.cluster_idxs = np.argsort(self.all_cluster_ids)
        self.sorted_cluster_ids = self.all_cluster_ids[self.cluster_idxs]
        # Cluster colors, ordered by cluster id.
        self.cluster_colors = self.get_cluster_colors(self.sorted_cluster_ids, alpha=.75)

    def get_clusters_data(self, load_all=None):
        """Return all templates data (an empty list if no cluster ids have been set)."""
        all_cluster_ids = getattr(self, 'all_cluster_ids', None)
        if all_cluster_ids is None or not len(all_cluster_ids):
            return []
        bunchs = self.templates(all_cluster_ids)
        out = []
        for cluster_rel, cluster_idx, cluster_id in self._iter_clusters():
            b = bunchs[cluster_id]
            b.cluster_rel = cluster_rel
            b.cluster_idx = cluster_idx
            b.cluster_id = cluster_id
            out.append(b)
        return out

    # Main methods
    # -------------------------------------------------------------------------

    def update_cluster_sort(self, cluster_ids):
        """Update the order of the clusters."""
        if not self._cluster_box_index:  # pragma: no cover
            return self.plot()
        # Only the order of the cluster_ids is supposed to change here.
        # We just have to update box_index instead of replotting everything.
        assert len(cluster_ids) == len(self.all_cluster_ids)
        # Update the cluster ids, in the new order.
        self.all_cluster_ids = np.array(cluster_ids, dtype=np.int32)
        # Update the permutation of the clusters.
        self.cluster_idxs = np.argsort(self.all_cluster_ids)
        box_index = []
        for cluster_idx in self.cluster_idxs:
            cluster_id = self.all_cluster_ids[cluster_idx]
            clu_box_index = self._cluster_box_index[cluster_id]
            # The grid column of a cluster is its position in the new order.
            clu_box_index[:, 1] = cluster_idx
            box_index.append(clu_box_index)
        box_index = np.concatenate(box_index, axis=0)
        self.visual.set_box_index(box_index)
        self.canvas.update()

    def update_color(self):
        """Update the color of the clusters, taking the selected clusters into account."""
        # This method is only used when the view has been plotted at least once,
        # such that self._cluster_box_index has been filled.
        if not self._cluster_box_index:
            return self.plot()
        # The call to set_cluster_ids() update the cluster_colors array.
        self.set_cluster_ids(self.all_cluster_ids)
        # Selected cluster colors.
        cluster_colors = self.cluster_colors
        selected_clusters = self.cluster_ids
        if selected_clusters is not None:
            cluster_colors = _add_selected_clusters_colors(
                selected_clusters, self.sorted_cluster_ids, cluster_colors)
        # Number of vertices per cluster = number of vertices per signal
        n_vertices_clu = [
            len(self._cluster_box_index[cluster_id]) for cluster_id in self.sorted_cluster_ids]
        # The argument passed to set_color() must have 1 row per vertex.
        self.visual.set_color(np.repeat(cluster_colors, n_vertices_clu, axis=0))
        self.canvas.update()

    @property
    def status(self):
        return 'Color scheme: %s' % self.color_schemes.current

    def plot(self, **kwargs):
        """Make the template plot."""

        # Retrieve the waveform data.
        bunchs = self.get_clusters_data()
        if not bunchs:
            return
        n_clusters = len(self.all_cluster_ids)
        self.canvas.grid.shape = (self.n_channels, n_clusters)

        self.visual.reset_batch()
        # Go through all clusters, ordered by cluster id.
        self.data_bounds = self._get_data_bounds(bunchs)
        for bunch in bunchs:
            data = self._plot_cluster(bunch)
            self._cluster_box_index[bunch.cluster_id] = data.box_index
            self.visual.add_batch_data(**data)
        self.canvas.update_visual(self.visual)
        self._apply_scaling()
        self.canvas.axes.reset_data_bounds((0, 0, n_clusters, self.n_channels))
        self.canvas.update()

    def on_select(self, *args, **kwargs):
        super(TemplateView, self).on_select(*args, **kwargs)
        self.update_color()

    # Scaling
    # -------------------------------------------------------------------------

    def _set_scaling_value(self, value):
        self._scaling = value
        self._apply_scaling()

    def _apply_scaling(self):
        # Only the second component of the layout scaling is driven by `_scaling`.
        sx, sy = self.canvas.layout.scaling
        self.canvas.layout.scaling = (sx, self._scaling)

    @property
    def scaling(self):
        """Return the grid scaling."""
        return self._scaling

    @scaling.setter
    def scaling(self, value):
        self._scaling = value

    # Interactivity
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by clicking on its template waveform.

        Control+click selects the clicked cluster; Shift+click (with or without Control)
        adds it to the current selection, matching `default_shortcuts`.

        """
        # 'select_more' is documented as 'shift+click', so Shift alone must be accepted.
        if 'Control' not in e.modifiers and 'Shift' not in e.modifiers:
            return
        b = e.button
        # Get mouse position in NDC.
        (channel_idx, cluster_rel), _ = self.canvas.grid.box_map(e.pos)
        cluster_id = self.all_cluster_ids[cluster_rel]
        logger.debug("Click on cluster %d with button %s.", cluster_id, b)
        if 'Shift' in e.modifiers:
            emit('select_more', self, [cluster_id])
        else:
            emit('request_select', self, [cluster_id])
def __init__(
        self, traces=None, sample_rate=None, spike_times=None, duration=None,
        n_channels=None, channel_vertical_order=None, channel_labels=None, **kwargs):
    """Create the trace view.

    Parameters
    ----------
    traces : callable
        Provider of the traces to display (only its callability is checked here).
    sample_rate : float
        Sampling rate; must be positive.
    spike_times : callable
        Stored as `get_spike_times`; called to obtain the spike times to navigate through.
    duration : float
        Total duration; must be non-negative. The view initially goes to its middle.
    n_channels : int
        Number of channels; must be non-negative.
    channel_vertical_order : array-like, optional
        Vertical order of the channels, of length `n_channels` (default
        `0, 1, ..., n_channels - 1`). Any array-like (e.g. a list) is accepted; its
        `argsort` is stored.
    channel_labels : list, optional
        One label per channel; defaults to the channel indices as strings.

    """
    self.do_show_labels = True
    self.show_all_spikes = False

    self._scaling = 1.

    self.get_spike_times = spike_times

    # Sample rate. Check for None explicitly: `None > 0` raises a TypeError in Python 3.
    assert sample_rate is not None and sample_rate > 0
    self.sample_rate = float(sample_rate)
    self.dt = 1. / self.sample_rate

    # Traces and spikes.
    assert callable(traces)
    self.traces = traces
    self.waveforms = None

    assert duration >= 0
    self.duration = duration

    assert n_channels >= 0
    self.n_channels = n_channels

    # Channel permutation. `np.asarray` lets callers pass a plain list, which has no
    # `.shape` attribute.
    self._channel_perm = (
        np.arange(n_channels) if channel_vertical_order is None else
        np.asarray(channel_vertical_order))
    assert self._channel_perm.shape == (n_channels,)
    self._channel_perm = np.argsort(self._channel_perm)

    # Channel labels.
    self.channel_labels = (
        channel_labels if channel_labels is not None else
        ['%d' % ch for ch in range(n_channels)])
    assert len(self.channel_labels) == n_channels

    # Box and probe scaling.
    self._origin = None

    # Initialize the view.
    super(TraceView, self).__init__(**kwargs)
    self.state_attrs += ('origin', 'do_show_labels', 'show_all_spikes', 'auto_scale')
    self.local_state_attrs += ('interval', 'scaling',)

    self.canvas.set_layout('stacked', origin=self.origin, n_plots=self.n_channels)
    self.canvas.enable_axes(show_y=False)

    # Visuals.
    self.trace_visual = UniformPlotVisual()
    self.canvas.add_visual(self.trace_visual)

    self.waveform_visual = PlotVisual()
    self.canvas.add_visual(self.waveform_visual)

    self.text_visual = TextVisual()
    _fix_coordinate_in_visual(self.text_visual, 'x')
    self.canvas.add_visual(self.text_visual)

    # Make a copy of the initial box pos and size. We'll apply the scaling
    # to these quantities.
    self.box_size = np.array(self.canvas.stacked.box_size)

    # Initial interval.
    self._interval = None
    self.go_to(duration / 2.)

    self._waveform_times = []