def __init__(self, spike_times, spike_clusters, cluster_ids=None, **kwargs):
    """Initialize a raster view whose vertical-bar markers are resized with the vertical zoom.

    Parameters
    ----------
    spike_times : array-like
        Spike times; the last one (times 1.01) defines the x-range of the view.
    spike_clusters : array-like
        Cluster assignment of every spike; must have the same length as `spike_times`.
    cluster_ids : array-like, optional
        Clusters to show initially (forwarded to `set_cluster_ids()`).
    **kwargs
        Forwarded to the parent view constructor.

    """
    self.spike_times = spike_times
    self.n_spikes = len(spike_times)
    # 1% margin beyond the last spike time.
    self.duration = spike_times[-1] * 1.01
    self.n_clusters = 1

    assert len(spike_clusters) == self.n_spikes
    self.set_spike_clusters(spike_clusters)
    self.set_cluster_ids(cluster_ids)

    super(RasterView, self).__init__(**kwargs)
    self.canvas.set_layout('stacked', origin='top', n_plots=self.n_clusters, has_clip=False)
    self.canvas.enable_axes()

    # GLSL snippet computing the marker size from the vertical zoom level.
    # The clamp() bounds are float literals: GLSL 1.10 and GLSL ES have no implicit
    # int-to-float conversion, so `clamp(float, int, int)` does not compile there.
    self.visual = ScatterVisual(
        marker='vbar',
        marker_scaling='''
            point_size = v_size * u_zoom.y + 5.;
            float width = 0.2;
            float height = 0.5;
            vec2 marker_size = point_size * vec2(width, height);
            marker_size.x = clamp(marker_size.x, 1., 20.);
        ''',
    )
    # Make the rasterized point size follow the vertical zoom as well.
    self.visual.inserter.insert_vert('''
            gl_PointSize = a_size * u_zoom.y + 5.0;
    ''', 'end')
    self.canvas.add_visual(self.visual)
    # NOTE(review): presumably limits panning/zooming to this NDC box — confirm with panzoom.
    self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))
def __init__(self, amplitudes=None, amplitude_name=None, duration=None, **kwargs):
    """Initialize the amplitude view.

    Parameters
    ----------
    amplitudes : function or dict
        Amplitude loader, or a dictionary `{amplitude_name: loader}` to support several
        amplitude types. A single loader is wrapped as `{'amplitude': loader}`.
    amplitude_name : str, optional
        Amplitude type shown initially (default: the first key of `amplitudes`).
    duration : float, optional
        Upper bound of the x axis (default: 1).
    **kwargs
        Forwarded to the parent view constructor, like the other views do.

    """
    super(AmplitudeView, self).__init__(**kwargs)
    self.state_attrs += ('amplitude_name',)

    self.canvas.enable_axes()
    self.canvas.enable_lasso()

    # Ensure amplitudes is a dictionary, even if there is a single amplitude.
    if not isinstance(amplitudes, dict):
        amplitudes = {'amplitude': amplitudes}
    assert amplitudes
    self.amplitudes = amplitudes
    self.amplitude_names = list(amplitudes.keys())

    # Current amplitude type.
    self.amplitude_name = amplitude_name or self.amplitude_names[0]
    assert self.amplitude_name in amplitudes

    self.cluster_ids = ()
    self.duration = duration or 1

    # Histogram visual, drawn in the bottom part of the view and rotated counterclockwise.
    self.hist_visual = HistogramVisual()
    self.hist_visual.transforms.add_on_gpu([
        Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
        Rotate('ccw')])
    self.canvas.add_visual(self.hist_visual)

    # Scatter plot.
    self.visual = ScatterVisual()
    self.canvas.add_visual(self.visual)

    # Amplitude name, not affected by pan/zoom.
    self.text_visual = TextVisual()
    self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))
def __init__(self, features=None, attributes=None, **kwargs):
    """Set up the feature view: feature loader, subplot grid, lasso and visuals.

    Parameters
    ----------
    features : object
        Feature loader; must be truthy.
    attributes : dict, optional
        Extra named dimensions that can be displayed instead of a PC (default: `{}`).
    **kwargs
        Forwarded to the parent view constructor.

    """
    super(FeatureView, self).__init__(**kwargs)
    self.state_attrs += ('fixed_channels', 'feature_scaling')

    assert features
    self.features = features
    self._lim = 1

    # Grid of dimension specs; each cell is a string such as `0A,1B` (x dim, y dim).
    self.grid_dim = _get_default_grid()
    n_rows, n_cols = np.array(self.grid_dim).shape
    self.n_rows, self.n_cols = n_rows, n_cols
    self.canvas.set_layout('grid', shape=(self.n_rows, self.n_cols))
    self.canvas.enable_lasso()

    # No channel is selected until the first plot.
    self.channel_ids = None

    # Extra, non-PC dimensions keyed by name (empty mapping when none are given).
    self.attributes = attributes if attributes else {}

    # Visuals: scatter points, then text labels, then axis lines.
    self.visual = ScatterVisual()
    self.canvas.add_visual(self.visual)

    self.text_visual = TextVisual()
    self.canvas.add_visual(self.text_visual)

    self.line_visual = LineVisual()
    self.canvas.add_visual(self.line_visual)
def __init__(self, coords=None, **kwargs):
    """Create the scatter view with axes, a lasso and a single scatter visual.

    Parameters
    ----------
    coords : object
        Source of the point coordinates; must be truthy.
        NOTE(review): its exact type/contract is defined by the callers — confirm.
    **kwargs
        Forwarded to the parent view constructor.

    """
    super(ScatterView, self).__init__(**kwargs)

    # Interaction: axes and lasso selection.
    self.canvas.enable_axes()
    self.canvas.enable_lasso()

    assert coords
    self.coords = coords

    # All points are drawn by one scatter visual.
    self.visual = ScatterVisual()
    self.canvas.add_visual(self.visual)
def __init__(
        self, positions=None, best_channels=None, channel_labels=None, dead_channels=None,
        **kwargs):
    """Initialize the probe view.

    Parameters
    ----------
    positions : array-like
        An `(n_channels, 2)` array (or nested list) with the channel positions.
    best_channels : function
        Maps `cluster_id` to the list of the best channel ids.
    channel_labels : list or array, optional
        Channel label strings; defaults to the channel indices as strings when missing
        or empty.
    dead_channels : list, optional
        Ids of dead channels, drawn with a reduced alpha / dimmed label.
    **kwargs
        Forwarded to the parent view constructor.

    """
    super(ProbeView, self).__init__(**kwargs)
    self.state_attrs += ('do_show_labels',)

    # Normalize positions (np.array always copies, like the previous astype()).
    positions = np.array(positions, dtype=np.float32)
    assert positions.ndim == 2
    assert positions.shape[1] == 2
    self.positions, self.data_bounds = _get_pos_data_bounds(positions)
    self.n_channels = positions.shape[0]
    self.best_channels = best_channels

    # `channel_labels or ...` would raise ValueError on a NumPy array: test explicitly.
    if channel_labels is None or not len(channel_labels):
        channel_labels = [str(ch) for ch in range(self.n_channels)]
    self.channel_labels = channel_labels
    self.dead_channels = dead_channels if dead_channels is not None else ()

    self.probe_visual = ScatterVisual()
    self.canvas.add_visual(self.probe_visual)

    # Probe visual: grey markers.
    color = np.ones((self.n_channels, 4))
    color[:, :3] = .5
    # Change alpha value for dead channels.
    if len(self.dead_channels):
        color[self.dead_channels, 3] = self.dead_channel_alpha
    self.probe_visual.set_data(
        pos=self.positions, data_bounds=self.data_bounds,
        color=color, size=self.unselected_marker_size)

    # Cluster visual.
    self.cluster_visual = ScatterVisual()
    self.canvas.add_visual(self.cluster_visual)

    # Text visual: white labels, dimmed for dead channels (guarded like the alpha above).
    color[:] = 1
    if len(self.dead_channels):
        color[self.dead_channels, :3] = self.dead_channel_alpha * 2
    self.text_visual = TextVisual()
    self.text_visual.inserter.insert_vert('uniform float n_channels;', 'header')
    # Discard some labels when there are many channels relative to the vertical zoom,
    # so that only every k-th label is drawn.
    self.text_visual.inserter.add_varying(
        'float', 'v_discard',
        'float((n_channels >= 200 * u_zoom.y) && '
        '(mod(int(a_string_index), int(n_channels / (200 * u_zoom.y))) >= 1))')
    self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
    self.canvas.add_visual(self.text_visual)
    self.text_visual.set_data(
        pos=self.positions, text=self.channel_labels, anchor=[0, -1],
        data_bounds=self.data_bounds, color=color
    )
    self.text_visual.program['n_channels'] = self.n_channels
    self.canvas.update()
def __init__(
        self, spike_times, spike_clusters, cluster_ids=None, cluster_color_selector=None,
        **kwargs):
    """Build a raster view for all spikes.

    Parameters
    ----------
    spike_times : array-like
        Spike times; the last one (times 1.01) sets the x-range.
    spike_clusters : array-like
        Cluster of every spike; same length as `spike_times`.
    cluster_ids : array-like, optional
        Clusters to show initially, passed to `set_cluster_ids()`.
    cluster_color_selector : object, optional
        Object used to color the clusters.
    **kwargs
        Forwarded to the parent view constructor.

    """
    self.spike_times = spike_times
    self.n_spikes = len(spike_times)
    # Leave a 1% margin after the last spike.
    self.duration = spike_times[-1] * 1.01
    self.n_clusters = 1

    assert len(spike_clusters) == self.n_spikes
    self.set_spike_clusters(spike_clusters)
    self.set_cluster_ids(cluster_ids)
    self.cluster_color_selector = cluster_color_selector

    super(RasterView, self).__init__(**kwargs)

    # One stacked row per cluster, first cluster at the top.
    self.canvas.set_layout(
        'stacked', origin='top', n_plots=self.n_clusters, has_clip=False)
    self.canvas.enable_axes()

    self.visual = ScatterVisual(marker='vbar')
    self.canvas.add_visual(self.visual)
def __init__(self, cluster_ids=None, cluster_info=None, bindings=None, **kwargs):
    """Initialize the cluster scatter view.

    Parameters
    ----------
    cluster_ids : array-like
        Full list of clusters, stored as `all_cluster_ids`.
    cluster_info : object
        Cluster information source, stored as-is.
    bindings : dict, optional
        Initial values of view attributes; only keys present in `self._dims` are applied.
    **kwargs
        Forwarded to the parent view constructor.

    """
    super(ClusterScatterView, self).__init__(**kwargs)
    self.state_attrs += (
        'scaling',
        'x_axis', 'y_axis', 'size',
        'x_axis_log_scale', 'y_axis_log_scale', 'size_log_scale',
    )
    self.local_state_attrs += ()

    self.canvas.enable_axes()
    self.canvas.enable_lasso()

    bindings = bindings or {}
    self.cluster_info = cluster_info
    # Update self.x_axis, self.y_axis, self.size from the bindings. A dict comprehension
    # is used (not a set of (key, value) tuples), so unhashable values are accepted too.
    self.__dict__.update({k: v for k, v in bindings.items() if k in self._dims})

    # Size range computed initially so that it doesn't change during the course of the session.
    self._size_min = self._size_max = None

    # Full list of clusters.
    self.all_cluster_ids = cluster_ids

    self.visual = ScatterVisual()
    self.canvas.add_visual(self.visual)

    # Labels, not affected by pan/zoom.
    self.label_visual = TextVisual()
    self.canvas.add_visual(self.label_visual, exclude_origins=(self.canvas.panzoom, ))

    self.marker_positions = self.marker_colors = self.marker_sizes = None
def __init__(self, positions=None, best_channels=None, **kwargs):
    """Show the probe layout: one grey marker per channel, plus an empty cluster visual.

    Parameters
    ----------
    positions : array
        An `(n_channels, 2)` array with the channel positions.
    best_channels : function
        Maps `cluster_id` to the list of the best channel ids.
    **kwargs
        Forwarded to the parent view constructor.

    """
    super(ProbeView, self).__init__(**kwargs)

    # The positions must form an (n_channels, 2) array.
    assert positions.ndim == 2
    assert positions.shape[1] == 2
    self.positions, self.data_bounds = _get_pos_data_bounds(positions)
    self.n_channels = positions.shape[0]
    self.best_channels = best_channels

    # Channel markers, all in opaque grey.
    self.probe_visual = ScatterVisual()
    self.canvas.add_visual(self.probe_visual)
    grey = (.5, .5, .5, 1.)
    self.probe_visual.set_data(
        pos=self.positions,
        data_bounds=self.data_bounds,
        color=grey,
        size=self.unselected_marker_size,
    )

    # Second scatter visual, left without data here.
    self.cluster_visual = ScatterVisual()
    self.canvas.add_visual(self.cluster_visual)
class RasterView(MarkerSizeMixin, BaseGlobalView, ManualClusteringView):
    """This view shows a raster plot of all clusters.

    Constructor
    -----------

    spike_times : array-like
        An `(n_spikes,)` array with the spike times, in seconds.
    spike_clusters : array-like
        An `(n_spikes,)` array with the spike-cluster assignments.
    cluster_ids : array-like
        The list of all clusters to show initially.
    cluster_color_selector : ClusterColorSelector
        The object managing the color mapping.

    """

    _default_position = 'right'

    default_shortcuts = {
        'change_marker_size': 'ctrl+wheel',
        'decrease': 'ctrl+shift+-',
        'increase': 'ctrl+shift++',
        'select_cluster': 'ctrl+click',
    }

    def __init__(
            self, spike_times, spike_clusters, cluster_ids=None, cluster_color_selector=None,
            **kwargs):
        self.spike_times = spike_times
        self.n_spikes = len(spike_times)
        # 1% margin beyond the last spike time.
        self.duration = spike_times[-1] * 1.01
        self.n_clusters = 1

        assert len(spike_clusters) == self.n_spikes
        self.set_spike_clusters(spike_clusters)
        # NOTE: with no (or empty) cluster_ids, `all_cluster_ids` and `spike_ids` stay unset
        # until set_cluster_ids() is called with a non-empty list.
        self.set_cluster_ids(cluster_ids)
        self.cluster_color_selector = cluster_color_selector

        super(RasterView, self).__init__(**kwargs)
        self.canvas.set_layout('stacked', origin='top', n_plots=self.n_clusters, has_clip=False)
        self.canvas.enable_axes()

        self.visual = ScatterVisual(marker='vbar')
        self.canvas.add_visual(self.visual)

    # Data-related functions
    # -------------------------------------------------------------------------

    def set_spike_clusters(self, spike_clusters):
        """Set the spike clusters for all spikes."""
        self.spike_clusters = spike_clusters

    def set_cluster_ids(self, cluster_ids):
        """Set the shown clusters, which can be filtered and in any order (from top to bottom).

        A None or empty `cluster_ids` is ignored.

        """
        if cluster_ids is None or not len(cluster_ids):
            return
        self.all_cluster_ids = cluster_ids
        self.n_clusters = len(self.all_cluster_ids)
        # Only keep spikes that belong to the selected clusters (boolean mask).
        self.spike_ids = np.isin(self.spike_clusters, self.all_cluster_ids)

    # Internal plotting functions
    # -------------------------------------------------------------------------

    def _get_x(self):
        """Return the x position of the spikes."""
        return self.spike_times[self.spike_ids]

    def _get_y(self):
        """Return the y position of the spikes, given the relative position of the clusters."""
        return np.zeros(np.sum(self.spike_ids))

    def _get_box_index(self):
        """Return, for every spike, its row in the raster plot. This depends on the ordering
        in self.all_cluster_ids."""
        cl = self.spike_clusters[self.spike_ids]
        # Sanity check.
        # assert np.all(np.in1d(cl, self.cluster_ids))
        return _index_of(cl, self.all_cluster_ids)

    def _get_color(self, box_index, selected_clusters=None):
        """Return, for every spike, its color, based on its box index."""
        cluster_colors = self.cluster_color_selector.get_colors(self.all_cluster_ids, alpha=.75)
        # Selected cluster colors.
        if selected_clusters is not None:
            cluster_colors = _add_selected_clusters_colors(
                selected_clusters, self.all_cluster_ids, cluster_colors)
        return cluster_colors[box_index, :]

    # Main methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self):
        """Bounds of the raster plot view."""
        return (0, 0, self.duration, self.n_clusters)

    def update_cluster_sort(self, cluster_ids):
        """Update the order of all clusters."""
        self.all_cluster_ids = cluster_ids
        self.visual.set_box_index(self._get_box_index())
        self.canvas.update()

    def update_color(self, selected_clusters=None):
        """Update the color of the spikes, depending on the selected clusters."""
        box_index = self._get_box_index()
        color = self._get_color(box_index, selected_clusters=selected_clusters)
        self.visual.set_color(color)
        self.canvas.update()

    def plot(self, **kwargs):
        """Make the raster plot."""
        # Nothing to draw without spikes, or before set_cluster_ids() has defined the
        # shown clusters (otherwise `self.spike_ids` does not exist yet).
        if not len(self.spike_clusters) or getattr(self, 'spike_ids', None) is None:
            return
        x = self._get_x()  # spike times for the selected spikes
        y = self._get_y()  # just 0
        box_index = self._get_box_index()
        color = self._get_color(box_index)
        assert x.shape == y.shape == box_index.shape
        assert color.shape[0] == len(box_index)
        self.data_bounds = self._get_data_bounds()

        self.visual.set_data(
            x=x, y=y, color=color, size=self.marker_size,
            data_bounds=(0, -1, self.duration, 1))
        self.visual.set_box_index(box_index)
        self.canvas.stacked.n_boxes = self.n_clusters
        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(RasterView, self).attach(gui)

        self.actions.add(self.increase)
        self.actions.add(self.decrease)
        self.actions.separator()

    def on_mouse_click(self, e):
        """Select a cluster by clicking in the raster plot."""
        b = e.button
        if 'Control' in e.modifiers:
            # Get the row (cluster index) under the mouse.
            cluster_idx, _ = self.canvas.stacked.box_map(e.pos)
            cluster_id = self.all_cluster_ids[cluster_idx]
            logger.debug("Click on cluster %d with button %s.", cluster_id, b)
            emit('cluster_click', self, cluster_id, button=b)
class AmplitudeView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays an amplitude plot for all selected clusters.

    Constructor
    -----------

    amplitudes : function or dict
        Maps a list of `cluster_ids` (and a `load_all` keyword) to a list
        `[Bunch(amplitudes, spike_ids, spike_times), ...]`, one bunch per cluster.
        The list of cluster ids may start with `None`, which stands for the background
        amplitudes. A dictionary `{amplitude_name: function}` can be given to provide
        several amplitude types, cycled with `next_amplitude_type()` and
        `previous_amplitude_type()`.
    amplitude_name : str
        Amplitude type shown initially (default: the first key of `amplitudes`).
    duration : float
        Upper bound of the x axis (default: 1).

    """

    _default_position = 'right'

    # Alpha channel of the markers in the scatter plot.
    marker_alpha = 1.

    # Number of bins in the histogram.
    n_bins = 100

    # Alpha channel of the histogram in the background.
    histogram_alpha = .5

    # Quantile used for scaling of the amplitudes (less than 1 to avoid outliers).
    quantile = .99

    # Size of the histogram, between 0 and 1.
    histogram_scale = .25

    default_shortcuts = {
        'change_marker_size': 'ctrl+wheel',
        'next_amplitude_type': 'a',
        'previous_amplitude_type': 'shift+a',
        'select_x_dim': 'alt+left click',
        'select_y_dim': 'alt+right click',
    }

    def __init__(self, amplitudes=None, amplitude_name=None, duration=None, **kwargs):
        # Extra keyword arguments are forwarded to the parent view, like the other views do.
        super(AmplitudeView, self).__init__(**kwargs)
        self.state_attrs += ('amplitude_name',)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes
        self.amplitude_names = list(amplitudes.keys())

        # Current amplitude type.
        self.amplitude_name = amplitude_name or self.amplitude_names[0]
        assert self.amplitude_name in amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1

        # Histogram visual.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add_on_gpu([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
            Rotate('ccw')])
        self.canvas.add_visual(self.hist_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        # Amplitude name.
        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))

    def _get_data_bounds(self, bunchs):
        """Compute the data bounds `(0, ymin, duration, ymax)` from amplitude quantiles."""
        # Bunches without amplitudes have no quantile (np.quantile raises on empty input).
        bunchs = [bunch for bunch in (bunchs or ()) if len(bunch.amplitudes)]
        if not bunchs:  # pragma: no cover
            return (0, 0, self.duration, 1)
        m = min(np.quantile(bunch.amplitudes, 1 - self.quantile) for bunch in bunchs)
        m = min(0, m)  # ensure ymin <= 0
        M = max(np.quantile(bunch.amplitudes, self.quantile) for bunch in bunchs)
        # Avoid a degenerate (zero-height) range, e.g. when all amplitudes are zero.
        if M <= m:
            M = m + 1
        return (0, m, self.duration, M)

    def _add_histograms(self, bunchs):
        """Attach an amplitude histogram to every bunch, using the current y data bounds."""
        # We do this after get_clusters_data because we need x_max.
        for bunch in bunchs:
            bunch.histogram = _compute_histogram(
                bunch.amplitudes,
                x_min=self.data_bounds[1],
                x_max=self.data_bounds[3],
                n_bins=self.n_bins,
                normalize=False,
                ignore_zeros=True,
            )
        return bunchs

    def _plot_cluster(self, bunch):
        """Make the scatter plot."""
        ms = self._marker_size

        # Histogram in the background.
        self.hist_visual.add_batch_data(
            hist=bunch.histogram,
            ylim=self._ylim,
            color=add_alpha(bunch.color, self.histogram_alpha))

        # Scatter plot.
        self.visual.add_batch_data(
            pos=bunch.pos, color=bunch.color, size=ms, data_bounds=self.data_bounds)

    def _plot_amplitude_name(self):
        """Show the amplitude name."""
        self.text_visual.add_batch_data(pos=[0, 1], anchor=[0, -1], text=self.amplitude_name)

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids."""
        if not len(self.cluster_ids):
            return
        cluster_ids = list(self.cluster_ids)
        # Don't need the background when splitting.
        if not load_all:
            # Add None cluster which means background spikes.
            cluster_ids = [None] + cluster_ids
        # Position of the first selected cluster in `cluster_ids` (1 if background first).
        offset = 0 if load_all else 1
        bunchs = self.amplitudes[self.amplitude_name](cluster_ids, load_all=load_all) or ()
        # Add a pos attribute in bunchs in addition to x and y.
        for i, (cluster_id, bunch) in enumerate(zip(cluster_ids, bunchs)):
            spike_ids = _as_array(bunch.spike_ids)
            spike_times = _as_array(bunch.spike_times)
            amplitudes = _as_array(bunch.amplitudes)
            assert spike_ids.shape == spike_times.shape == amplitudes.shape
            # Ensure that bunch.pos exists, as it used by the LassoMixin.
            bunch.pos = np.c_[spike_times, amplitudes]
            assert bunch.pos.ndim == 2
            bunch.cluster_id = cluster_id
            bunch.color = (
                selected_cluster_color(i - offset, self.marker_alpha)
                # Background amplitude color.
                if cluster_id is not None else (.5, .5, .5, .5))
        return bunchs

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data()
        if not bunchs:
            return
        self.data_bounds = self._get_data_bounds(bunchs)
        bunchs = self._add_histograms(bunchs)

        # Use the same scale for all histograms; fall back to 1 if they are all empty.
        self._ylim = max(bunch.histogram.max() for bunch in bunchs) or 1.

        self.visual.reset_batch()
        self.hist_visual.reset_batch()
        self.text_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self._plot_amplitude_name()
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.hist_visual)
        self.canvas.update_visual(self.text_visual)

        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(AmplitudeView, self).attach(gui)
        self.actions.add(self.next_amplitude_type, set_busy=True)
        self.actions.add(self.previous_amplitude_type, set_busy=True)

    def _change_amplitude_type(self, dir=+1):
        """Cycle through the amplitude types by `dir` steps (wrapping around) and replot."""
        i = self.amplitude_names.index(self.amplitude_name)
        n = len(self.amplitude_names)
        self.amplitude_name = self.amplitude_names[(i + dir) % n]
        logger.debug("Switch to amplitude type: %s.", self.amplitude_name)
        self.plot()

    def next_amplitude_type(self):
        """Switch to the next amplitude type."""
        self._change_amplitude_type(+1)

    def previous_amplitude_type(self):
        """Switch to the previous amplitude type."""
        self._change_amplitude_type(-1)
class FeatureView(MarkerSizeMixin, ScalingMixin, ManualClusteringView):
    """This view displays a matrix of subplots with different projections of the principal
    component features (the layout is given by `_get_default_grid()` and can be changed
    with `set_grid_dim()`).

    This view keeps track of which channels are currently shown.

    Constructor
    -----------

    features : function
        Maps `(cluster_id, channel_ids=None, load_all=False)` to
        `Bunch(data, channel_ids, channel_labels, spike_ids , masks)`. It is also called
        without `cluster_id` to get the background spikes.
        * `data` is an `(n_spikes, n_channels, n_features)` array
        * `channel_ids` contains the channel ids of every row in `data`
        * `channel_labels` contains the channel labels of every row in `data`
        * `spike_ids` is a `(n_spikes,)` array
        * `masks` is an `(n_spikes, n_channels)` array

        This allows for a sparse format.

    attributes : dict
        Maps an attribute name (for example `time`) to a function
        `(cluster_id, load_all=None) -> Bunch(data, lim)` returning the values of that
        extra dimension and, optionally, its `(min, max)` limits.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'
    cluster_ids = ()

    # Whether to disable automatic selection of channels.
    fixed_channels = False
    feature_scaling = 1.

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'increase': 'ctrl++',
        'decrease': 'ctrl+-',
        'add_lasso_point': 'ctrl+click',
        'stop_lasso': 'ctrl+right click',
        'toggle_automatic_channel_selection': 'c',
    }

    def __init__(self, features=None, attributes=None, **kwargs):
        super(FeatureView, self).__init__(**kwargs)
        self.state_attrs += ('fixed_channels', 'feature_scaling')

        assert features
        self.features = features
        self._lim = 1

        self.grid_dim = _get_default_grid()  # 2D array where every item is a string like `0A,1B`
        self.n_rows, self.n_cols = np.array(self.grid_dim).shape
        self.canvas.set_layout('grid', shape=(self.n_rows, self.n_cols))
        self.canvas.enable_lasso()

        # Channels being shown.
        self.channel_ids = None

        # Attributes: extra dimensions. This is a dictionary
        # {name: function}
        # where each function returns a Bunch with the attribute values (see the class
        # docstring and `_get_axis_data()`).
        self.attributes = attributes or {}

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        self.text_visual = TextVisual()
        self.canvas.add_visual(self.text_visual)

        self.line_visual = LineVisual()
        self.canvas.add_visual(self.line_visual)

    def set_grid_dim(self, grid_dim):
        """Change the grid dim dynamically.

        Parameters
        ----------
        grid_dim : array-like (2D)
            `grid_dim[row, col]` is a string with two values separated by a comma. Each value
            is the relative channel id (0, 1, 2...) followed by the PC (A, B, C...). For example,
            `grid_dim[row, col] = 0B,1A`. Each value can also be an attribute name, for example
            `time`. For example, `grid_dim[row, col] = time,2C`.

        """
        self.grid_dim = grid_dim
        self.n_rows, self.n_cols = np.array(grid_dim).shape
        self.canvas.grid.shape = (self.n_rows, self.n_cols)

    # Internal methods
    # -------------------------------------------------------------------------

    def _iter_subplots(self):
        """Yield `(i, j, dim_x, dim_y)` for every subplot of the grid."""
        for i in range(self.n_rows):
            for j in range(self.n_cols):
                dim = self.grid_dim[i][j]
                dim_x, dim_y = dim.split(',')
                yield i, j, dim_x, dim_y

    def _get_axis_label(self, dim):
        """Return the channel label from a dimension, if applicable."""
        if str(dim[:-1]).isdecimal():
            n = len(self.channel_ids)
            channel_id = self.channel_ids[int(dim[:-1]) % n]
            return self.channel_labels[channel_id] + dim[-1]
        else:
            return dim

    def _get_channel_and_pc(self, dim):
        """Return the channel_id and PC of a dim (None if no channel is selected)."""
        if self.channel_ids is None:
            return
        assert dim not in self.attributes  # This is called only on PC data.
        s = 'ABCDEFGHIJ'
        # Channel relative index, typically just 0 or 1.
        c_rel = int(dim[:-1])
        # Get the channel_id from the currently-selected channels.
        channel_id = self.channel_ids[c_rel % len(self.channel_ids)]
        pc = s.index(dim[-1])
        return channel_id, pc

    def _get_axis_data(self, bunch, dim, cluster_id=None, load_all=None):
        """Extract the points from the data on a given dimension.

        bunch is returned by the features() function.
        dim is the string specifying the dimensions to extract for the data.

        """
        if dim in self.attributes:
            return self.attributes[dim](cluster_id, load_all=load_all)
        masks = bunch.get('masks', None)
        channel_id, pc = self._get_channel_and_pc(dim)
        # Skip the plot if the channel id is not displayed.
        if channel_id not in bunch.channel_ids:  # pragma: no cover
            return Bunch(data=np.zeros((bunch.data.shape[0],)))
        # Get the column index of the current channel in data.
        c = list(bunch.channel_ids).index(channel_id)
        if masks is not None:
            masks = masks[:, c]
        return Bunch(data=self.feature_scaling * bunch.data[:, c, pc], masks=masks)

    def _get_axis_bounds(self, dim, bunch):
        """Return the min/max of an axis."""
        if dim in self.attributes:
            # Attribute: specified lim, or compute the min/max.
            vmin, vmax = bunch.get('lim', (0, 0))
            assert vmin is not None
            assert vmax is not None
            return vmin, vmax
        return (-self._lim, +self._lim)

    def _plot_points(self, bunch, clu_idx=None):
        """Add the points of one bunch to every subplot, with the axis labels.

        `clu_idx` is the index of the cluster in `self.cluster_ids`, or None for the
        background spikes.

        """
        if not bunch:
            return
        cluster_id = self.cluster_ids[clu_idx] if clu_idx is not None else None
        for i, j, dim_x, dim_y in self._iter_subplots():
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id)
            # Skip empty data.
            if px is None or py is None:  # pragma: no cover
                # %s rather than %d: cluster_id is None for the background spikes.
                logger.warning("Skipping empty data for cluster %s.", cluster_id)
                return
            assert px.data.shape == py.data.shape
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            assert xmin <= xmax
            assert ymin <= ymax
            data_bounds = (xmin, ymin, xmax, ymax)
            masks = _get_masks_max(px, py)
            # Prepare the batch visual with all subplots
            # for the selected cluster.
            self.visual.add_batch_data(
                x=px.data, y=py.data,
                color=_get_point_color(clu_idx),
                # Reduced marker size for background features
                size=self._marker_size,
                masks=_get_point_masks(clu_idx=clu_idx, masks=masks),
                data_bounds=data_bounds,
                box_index=(i, j),
            )
            # Get the channel ids corresponding to the relative channel indices
            # specified in the dimensions. Channel 0 corresponds to the first
            # best channel for the selected cluster, and so on.
            label_x = self._get_axis_label(dim_x)
            label_y = self._get_axis_label(dim_y)
            # Add labels.
            self.text_visual.add_batch_data(
                pos=[1, 1],
                anchor=[-1, -1],
                text=label_y,
                data_bounds=None,
                box_index=(i, j),
            )
            self.text_visual.add_batch_data(
                pos=[0, -1.],
                anchor=[0, 1],
                text=label_x,
                data_bounds=None,
                box_index=(i, j),
            )

    def _plot_axes(self):
        """Draw the horizontal and vertical axis lines in every subplot."""
        self.line_visual.reset_batch()
        for i, j, dim_x, dim_y in self._iter_subplots():
            self.line_visual.add_batch_data(
                pos=[[-1., 0., +1., 0.],
                     [0., -1., 0., +1.]],
                color=(.5, .5, .5, .5),
                box_index=(i, j),
                data_bounds=None,
            )
        self.canvas.update_visual(self.line_visual)

    def _get_lim(self, bunchs):
        """Return the largest absolute feature value, used as the symmetric axis limit."""
        if not bunchs:  # pragma: no cover
            return 1
        m = min(bunch.data.min() for bunch in bunchs)
        M = max(bunch.data.max() for bunch in bunchs)
        M = max(abs(m), abs(M))
        # Fall back to 1 when all features are zero, to avoid a zero-width range.
        return M if M > 0 else 1

    def _get_scaling_value(self):
        return self.feature_scaling

    def _set_scaling_value(self, value):
        self.feature_scaling = value
        self.plot(fixed_channels=True)

    # Public methods
    # -------------------------------------------------------------------------

    def clear_channels(self):
        """Reset the current channels."""
        self.channel_ids = None
        self.plot()

    def get_clusters_data(self, fixed_channels=None, load_all=None):
        """Load the features of the selected clusters and choose the channels to show.

        Every returned bunch has a `cluster_id` attribute. Clusters for which `features()`
        returns an empty value are skipped.

        """
        # Specify the channel ids if these are fixed, otherwise
        # choose the first cluster's best channels.
        requested_channels = self.channel_ids if fixed_channels else None
        bunchs = []
        for cluster_id in self.cluster_ids:
            bunch = self.features(cluster_id, channel_ids=requested_channels)
            if not bunch:
                continue
            # Tag each bunch with its own cluster id while they are still paired, so that
            # skipping an empty bunch does not shift the ids of the following ones.
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        if not bunchs:  # pragma: no cover
            return []

        # Choose the channels based on the first selected cluster.
        channel_ids = list(bunchs[0].get('channel_ids', []))
        common_channels = list(channel_ids)
        # Intersection (with order kept) of channels belonging to all clusters.
        for bunch in bunchs:
            common_channels = [
                ch for ch in bunch.get('channel_ids', []) if ch in common_channels]
        # The selected channels will be (1) the channels common to all clusters, followed
        # by (2) remaining channels from the first cluster (excluding those already selected
        # in (1)).
        n = len(channel_ids)
        not_common_channels = [ch for ch in channel_ids if ch not in common_channels]
        channel_ids = common_channels + not_common_channels[:n - len(common_channels)]
        assert len(channel_ids) > 0

        # Choose the channels automatically unless fixed_channels is set.
        if (not fixed_channels or self.channel_ids is None):
            self.channel_ids = channel_ids
        assert len(self.channel_ids)

        # Channel labels.
        self.channel_labels = {}
        for d in bunchs:
            chl = d.get('channel_labels', ['%d' % ch for ch in d.get('channel_ids', [])])
            self.channel_labels.update({
                channel_id: chl[i] for i, channel_id in enumerate(d.get('channel_ids', []))})

        return bunchs

    def plot(self, **kwargs):
        """Update the view with the selected clusters."""

        # Determine whether the channels should be fixed or not.
        added = kwargs.get('up', {}).get('added', None)
        # Fix the channels if the view updates after a cluster event
        # and there are new clusters.
        fixed_channels = (
            self.fixed_channels or kwargs.get('fixed_channels', None) or added is not None)

        # Get the clusters data.
        bunchs = self.get_clusters_data(fixed_channels=fixed_channels)
        bunchs = [b for b in bunchs if b]
        if not bunchs:
            return
        self._lim = self._get_lim(bunchs)

        # Get the background data.
        background = self.features(channel_ids=self.channel_ids)

        # Plot all features.
        self._plot_axes()

        # NOTE: the columns in bunch.data are ordered by decreasing quality
        # of the associated channels. The channels corresponding to each
        # column are given in bunch.channel_ids in the same order.

        # Plot points.
        self.visual.reset_batch()
        self.text_visual.reset_batch()

        self._plot_points(background)  # background spikes

        # Plot each cluster, using its position in the current selection (which drives the
        # color), even if some clusters were skipped by get_clusters_data().
        clu_indices = {cluster_id: idx for idx, cluster_id in enumerate(self.cluster_ids)}
        for bunch in bunchs:
            self._plot_points(bunch, clu_idx=clu_indices[bunch.cluster_id])

        # Upload the data on the GPU.
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.text_visual)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(FeatureView, self).attach(gui)

        self.actions.add(
            self.toggle_automatic_channel_selection,
            checked=not self.fixed_channels, checkable=True)
        self.actions.add(self.clear_channels)
        self.actions.separator()

    def toggle_automatic_channel_selection(self, checked):
        """Toggle the automatic selection of channels when the cluster selection changes."""
        self.fixed_channels = not checked

    @property
    def status(self):
        if self.channel_ids is None:  # pragma: no cover
            return ''
        channel_labels = [self.channel_labels[ch] for ch in self.channel_ids[:2]]
        return 'channels: %s' % ', '.join(channel_labels)

    # Dimension selection
    # -------------------------------------------------------------------------

    def on_select_channel(self, sender=None, channel_id=None, key=None, button=None):
        """Respond to the click on a channel from another view, and update the
        relevant subplots."""
        channels = self.channel_ids
        if channels is None:
            return
        if len(channels) == 1:
            self.plot()
            return
        assert len(channels) >= 2
        # Get the axis from the pressed button (1, 2, etc.)
        if key is not None:
            # Key 1 targets the first shown channel, key 2 the second one, and so on,
            # clipped to the available channels.
            d = int(np.clip(key - 1, 0, len(channels) - 1))
        else:
            d = 0 if button == 'Left' else 1
        # Change the first or second best channel.
        old = channels[d]
        # Avoid updating the view if the channel doesn't change.
        if channel_id == old:
            return
        channels[d] = channel_id
        # Ensure that the first two channels are different.
        if channels[1 - min(d, 1)] == channel_id:
            channels[1 - min(d, 1)] = old
        assert channels[0] != channels[1]
        # Remove duplicate channels.
        self.channel_ids = _uniq(channels)
        logger.debug("Choose channels %d and %d in feature view.", *channels[:2])
        # Fix the channels temporarily.
        self.plot(fixed_channels=True)
        self.update_status()

    def on_mouse_click(self, e):
        """Select a feature dimension by clicking on a box in the feature view."""
        b = e.button
        if 'Alt' in e.modifiers:
            # Get mouse position in NDC.
            (i, j), _ = self.canvas.grid.box_map(e.pos)
            dim = self.grid_dim[i][j]
            dim_x, dim_y = dim.split(',')
            dim = dim_x if b == 'Left' else dim_y
            other_dim = dim_y if b == 'Left' else dim_x
            if dim not in self.attributes:
                # When a regular (channel, PC) dimension is selected.
                channel_pc = self._get_channel_and_pc(dim)
                if channel_pc is None:
                    return
                channel_id, pc = channel_pc
                logger.debug("Click on feature dim %s, channel id %s, PC %s.", dim, channel_id, pc)
            else:
                # When the selected dimension is an attribute, e.g. "time".
                pc = None
                # Take the channel id in the other dimension.
                channel_pc = self._get_channel_and_pc(other_dim)
                channel_id = channel_pc[0] if channel_pc is not None else None
                logger.debug("Click on feature dim %s.", dim)
            emit('select_feature', self, dim=dim, channel_id=channel_id, pc=pc)

    def on_request_split(self, sender=None):
        """Return the spikes enclosed by the lasso."""
        if (self.canvas.lasso.count < 3 or not len(self.cluster_ids)):  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # Get the dimensions of the lassoed subplot.
        i, j = self.canvas.layout.active_box
        dim = self.grid_dim[i][j]
        dim_x, dim_y = dim.split(',')

        # Get all points from all clusters.
        pos = []
        spike_ids = []

        for cluster_id in self.cluster_ids:
            # Load all spikes.
            bunch = self.features(cluster_id, channel_ids=self.channel_ids, load_all=True)
            if not bunch:
                continue
            px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id, load_all=True)
            py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id, load_all=True)
            points = np.c_[px.data, py.data]

            # Normalize the points.
            xmin, xmax = self._get_axis_bounds(dim_x, px)
            ymin, ymax = self._get_axis_bounds(dim_y, py)
            r = Range((xmin, ymin, xmax, ymax))
            points = r.apply(points)

            pos.append(points)
            spike_ids.append(bunch.spike_ids)

        if not pos:  # pragma: no cover
            # No cluster returned any feature (np.vstack would fail on an empty list).
            self.canvas.lasso.clear()
            return np.array([], dtype=np.int64)
        pos = np.vstack(pos)
        spike_ids = np.concatenate(spike_ids)

        # Find lassoed spikes.
        ind = self.canvas.lasso.in_polygon(pos)
        self.canvas.lasso.clear()
        return np.unique(spike_ids[ind])
class ProbeView(ManualClusteringView):
    """This view displays the positions of all channels on the probe, highlighting channels
    where the selected clusters belong.

    Constructor
    -----------

    positions : array-like
        An `(n_channels, 2)` array with the channel positions
    best_channels : function
        Maps `cluster_id` to the list of the best_channel_ids.
    channel_labels : list
        List of channel label strings. Defaults to the channel indices.
    dead_channels : list
        List of dead channel ids.

    """

    _default_position = 'right'

    # Marker size of channels without selected clusters.
    unselected_marker_size = 10

    # Marker size of channels with selected clusters.
    selected_marker_size = 15

    # Alpha value of the dead channels.
    dead_channel_alpha = .25

    do_show_labels = False

    def __init__(
            self, positions=None, best_channels=None, channel_labels=None,
            dead_channels=None, **kwargs):
        super(ProbeView, self).__init__(**kwargs)
        self.state_attrs += ('do_show_labels',)

        # Normalize positions.
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        positions = positions.astype(np.float32)
        self.positions, self.data_bounds = _get_pos_data_bounds(positions)
        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        # Check for None/empty explicitly: `channel_labels or [...]` raises on a
        # multi-element NumPy array of labels.
        if channel_labels is None or not len(channel_labels):
            channel_labels = [str(ch) for ch in range(self.n_channels)]
        self.channel_labels = channel_labels
        self.dead_channels = dead_channels if dead_channels is not None else ()
        # Integer index array of the dead channels; indexing with it is a no-op when empty.
        dead_idx = np.asarray(self.dead_channels, dtype=np.int64)

        self.probe_visual = ScatterVisual()
        self.canvas.add_visual(self.probe_visual)

        # Probe visual: gray markers, with a lower alpha value for the dead channels.
        color = np.ones((self.n_channels, 4))
        color[:, :3] = .5
        color[dead_idx, 3] = self.dead_channel_alpha
        self.probe_visual.set_data(
            pos=self.positions, data_bounds=self.data_bounds,
            color=color, size=self.unselected_marker_size)

        # Cluster visual.
        self.cluster_visual = ScatterVisual()
        self.canvas.add_visual(self.cluster_visual)

        # Text visual: white labels, dimmed for the dead channels.
        color[:] = 1
        color[dead_idx, :3] = self.dead_channel_alpha * 2

        self.text_visual = TextVisual()
        # Discard some labels when there are many channels relative to the vertical zoom,
        # so that the labels do not overlap.
        self.text_visual.inserter.insert_vert('uniform float n_channels;', 'header')
        self.text_visual.inserter.add_varying(
            'float', 'v_discard',
            'float((n_channels >= 200 * u_zoom.y) && '
            '(mod(int(a_string_index), int(n_channels / (200 * u_zoom.y))) >= 1))')
        self.text_visual.inserter.insert_frag('if (v_discard > 0) discard;', 'end')
        self.canvas.add_visual(self.text_visual)
        self.text_visual.set_data(
            pos=self.positions, text=self.channel_labels, anchor=[0, -1],
            data_bounds=self.data_bounds, color=color
        )
        self.text_visual.program['n_channels'] = self.n_channels
        self.canvas.update()

    def _get_clu_positions(self, cluster_ids):
        """Get the positions of the channels containing selected clusters.

        Returns
        -------
        clu_pos : array
            One `(x, y)` row per (channel, selected cluster) pair. When several selected
            clusters share a channel, their discs are spread horizontally, centered on the
            channel position.
        clu_colors : array
            The matching cluster colors, with a lower alpha value on dead channels.

        """
        # List of channels per cluster.
        cluster_channels = {i: self.best_channels(cl) for i, cl in enumerate(cluster_ids)}

        # List of clusters per channel.
        clusters_per_channel = defaultdict(list)
        for clu_idx, channels in cluster_channels.items():
            for channel in channels:
                clusters_per_channel[channel].append(clu_idx)

        # Enumerate the discs for each channel.
        w = self.data_bounds[2] - self.data_bounds[0]
        clu_pos = []
        clu_colors = []
        for channel_id, (x, y) in enumerate(self.positions):
            clu_indices = clusters_per_channel.get(channel_id, ())
            n = len(clu_indices)
            if not n:
                continue
            alpha = 1.0 if channel_id not in self.dead_channels else self.dead_channel_alpha
            for i, clu_idx in enumerate(clu_indices):
                # Offset of the i-th disc relative to the channel position. It is applied to
                # the channel's x, not accumulated across discs (which would stack discs).
                t = .025 * w * (i - .5 * (n - 1))
                clu_pos.append((x + t, y))
                clu_colors.append(selected_cluster_color(clu_idx, alpha=alpha))
        return np.array(clu_pos), np.array(clu_colors)

    def on_select(self, cluster_ids=(), **kwargs):
        """Update the view with the selected clusters."""
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        pos, colors = self._get_clu_positions(cluster_ids)
        self.cluster_visual.set_data(
            pos=pos, color=colors, size=self.selected_marker_size, data_bounds=self.data_bounds)
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(ProbeView, self).attach(gui)

        self.actions.add(self.toggle_show_labels, checkable=True, checked=self.do_show_labels)

        if not self.do_show_labels:
            self.text_visual.hide()

    def toggle_show_labels(self, checked):
        """Toggle the display of the channel ids."""
        logger.debug("Set show labels to %s.", checked)
        self.do_show_labels = checked
        # Use the public visual API, as `attach()` does.
        if checked:
            self.text_visual.show()
        else:
            self.text_visual.hide()
        self.canvas.update()
class ProbeView(ManualClusteringView):
    """This view displays the positions of all channels on the probe, highlighting channels
    where the selected clusters belong.

    Constructor
    -----------

    positions : array-like
        An `(n_channels, 2)` array with the channel positions
    best_channels : function
        Maps `cluster_id` to the list of the best_channel_ids.

    """

    _default_position = 'right'

    # Marker size of channels without selected clusters.
    unselected_marker_size = 10

    # Marker size of channels with selected clusters.
    selected_marker_size = 15

    def __init__(self, positions=None, best_channels=None, **kwargs):
        super(ProbeView, self).__init__(**kwargs)

        # Normalize positions.
        assert positions.ndim == 2
        assert positions.shape[1] == 2
        self.positions, self.data_bounds = _get_pos_data_bounds(positions)
        self.n_channels = positions.shape[0]
        self.best_channels = best_channels

        self.probe_visual = ScatterVisual()
        self.canvas.add_visual(self.probe_visual)

        # Probe visual: all channels as gray markers.
        self.probe_visual.set_data(
            pos=self.positions, data_bounds=self.data_bounds,
            color=(.5, .5, .5, 1.), size=self.unselected_marker_size)

        # Cluster visual.
        self.cluster_visual = ScatterVisual()
        self.canvas.add_visual(self.cluster_visual)

    def _get_clu_positions(self, cluster_ids):
        """Get the positions of the channels containing selected clusters.

        Returns
        -------
        clu_pos : array
            One `(x, y)` row per (channel, selected cluster) pair. When several selected
            clusters share a channel, their discs are spread horizontally, centered on the
            channel position.
        clu_colors : array
            The matching cluster colors.

        """
        # List of channels per cluster.
        cluster_channels = {i: self.best_channels(cl) for i, cl in enumerate(cluster_ids)}

        # List of clusters per channel.
        clusters_per_channel = defaultdict(list)
        for clu_idx, channels in cluster_channels.items():
            for channel in channels:
                clusters_per_channel[channel].append(clu_idx)

        # Enumerate the discs for each channel.
        w = self.data_bounds[2] - self.data_bounds[0]
        clu_pos = []
        clu_colors = []
        for channel_id, (x, y) in enumerate(self.positions):
            clu_indices = clusters_per_channel.get(channel_id, ())
            n = len(clu_indices)
            for i, clu_idx in enumerate(clu_indices):
                # Offset of the i-th disc relative to the channel position. It is applied to
                # the channel's x, not accumulated across discs (which would stack discs).
                t = .025 * w * (i - .5 * (n - 1))
                clu_pos.append((x + t, y))
                clu_colors.append(selected_cluster_color(clu_idx))
        return np.array(clu_pos), np.array(clu_colors)

    def on_select(self, cluster_ids=(), **kwargs):
        """Update the view with the selected clusters."""
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return
        pos, colors = self._get_clu_positions(cluster_ids)
        self.cluster_visual.set_data(
            pos=pos, color=colors, size=self.selected_marker_size, data_bounds=self.data_bounds)
        self.canvas.update()
class RasterView(MarkerSizeMixin, BaseColorView, BaseGlobalView, ManualClusteringView):
    """This view shows a raster plot of all clusters.

    Constructor
    -----------

    spike_times : array-like
        An `(n_spikes,)` array with the spike times, in seconds.
    spike_clusters : array-like
        An `(n_spikes,)` array with the spike-cluster assignments.
    cluster_ids : array-like
        The list of all clusters to show initially.

    """

    _default_position = 'right'

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'switch_color_scheme': 'shift+wheel',
        'decrease_marker_size': 'ctrl+shift+-',
        'increase_marker_size': 'ctrl+shift++',
        'select_cluster': 'ctrl+click',
        # `on_mouse_click()` ignores clicks without Control, so "select more" needs it too.
        'select_more': 'ctrl+shift+click',
    }

    def __init__(self, spike_times, spike_clusters, cluster_ids=None, **kwargs):
        self.spike_times = spike_times
        self.n_spikes = len(spike_times)
        # Small margin after the last spike (presumably spike_times is sorted; TODO confirm).
        self.duration = spike_times[-1] * 1.01
        self.n_clusters = 1
        # Set by `set_cluster_ids()`; None until the list of shown clusters is known.
        self.all_cluster_ids = None
        self.spike_ids = None

        assert len(spike_clusters) == self.n_spikes
        self.set_spike_clusters(spike_clusters)
        self.set_cluster_ids(cluster_ids)

        super(RasterView, self).__init__(**kwargs)
        self.canvas.set_layout(
            'stacked', origin='top', n_plots=self.n_clusters, has_clip=False)
        self.canvas.enable_axes()

        self.visual = ScatterVisual(
            marker='vbar',
            marker_scaling='''
                point_size = v_size * u_zoom.y + 5.;
                float width = 0.2;
                float height = 0.5;
                vec2 marker_size = point_size * vec2(width, height);
                marker_size.x = clamp(marker_size.x, 1, 20);
            ''',
        )
        self.visual.inserter.insert_vert('''
                gl_PointSize = a_size * u_zoom.y + 5.0;
            ''', 'end')
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.set_constrain_bounds((-1, -2, +1, +2))

    # Data-related functions
    # -------------------------------------------------------------------------

    def set_spike_clusters(self, spike_clusters):
        """Set the spike clusters for all spikes."""
        self.spike_clusters = spike_clusters

    def set_cluster_ids(self, cluster_ids):
        """Set the shown clusters, which can be filtered and in any order (from top to bottom).

        A None or empty `cluster_ids` is ignored and keeps the current clusters.
        """
        if cluster_ids is None or not len(cluster_ids):
            return
        self.all_cluster_ids = cluster_ids
        self.n_clusters = len(self.all_cluster_ids)
        # Boolean mask of the spikes that belong to the shown clusters.
        self.spike_ids = np.isin(self.spike_clusters, self.all_cluster_ids)

    # Internal plotting functions
    # -------------------------------------------------------------------------

    def _get_x(self):
        """Return the x position of the spikes."""
        return self.spike_times[self.spike_ids]

    def _get_y(self):
        """Return the y position of the spikes, given the relative position of the clusters."""
        return np.zeros(np.sum(self.spike_ids))

    def _get_box_index(self):
        """Return, for every spike, its row in the raster plot. This depends on the ordering
        in self.all_cluster_ids."""
        cl = self.spike_clusters[self.spike_ids]
        return _index_of(cl, self.all_cluster_ids)

    def _get_color(self, box_index, selected_clusters=None):
        """Return, for every spike, its color, based on its box index."""
        cluster_colors = self.get_cluster_colors(self.all_cluster_ids, alpha=.75)
        # Selected cluster colors.
        if selected_clusters is not None:
            cluster_colors = _add_selected_clusters_colors(
                selected_clusters, self.all_cluster_ids, cluster_colors)
        return cluster_colors[box_index, :]

    # Main methods
    # -------------------------------------------------------------------------

    def _get_data_bounds(self):
        """Bounds of the raster plot view."""
        return (0, 0, self.duration, self.n_clusters)

    def update_cluster_sort(self, cluster_ids):
        """Update the order of all clusters."""
        self.all_cluster_ids = cluster_ids
        # Nothing is plotted yet if the shown clusters were never set.
        if self.spike_ids is None:
            return
        self.visual.set_box_index(self._get_box_index())
        self.canvas.update()

    def update_color(self):
        """Update the color of the spikes, depending on the selected clusters."""
        if self.spike_ids is None:
            return
        box_index = self._get_box_index()
        color = self._get_color(box_index, selected_clusters=self.cluster_ids)
        self.visual.set_color(color)
        self.canvas.update()

    @property
    def status(self):
        return 'Color scheme: %s' % self.color_scheme

    def plot(self, **kwargs):
        """Make the raster plot (no-op until there are spikes and shown clusters)."""
        if not len(self.spike_clusters) or self.spike_ids is None:
            return
        x = self._get_x()  # spike times for the selected spikes
        y = self._get_y()  # just 0
        box_index = self._get_box_index()
        color = self._get_color(box_index)
        assert x.shape == y.shape == box_index.shape
        assert color.shape[0] == len(box_index)
        self.data_bounds = self._get_data_bounds()

        self.visual.set_data(
            x=x, y=y, color=color, size=self.marker_size,
            data_bounds=(0, -1, self.duration, 1))
        self.visual.set_box_index(box_index)
        self.canvas.stacked.n_boxes = self.n_clusters
        self._update_axes()
        self.canvas.update()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(RasterView, self).attach(gui)

        self.actions.add(self.increase_marker_size)
        self.actions.add(self.decrease_marker_size)
        self.actions.separator()

    def on_select(self, *args, **kwargs):
        super(RasterView, self).on_select(*args, **kwargs)
        self.update_color()

    def zoom_to_time_range(self, interval):
        """Zoom to a time interval, centered on the middle of the interval."""
        if not interval:
            return
        t0, t1 = interval
        w = .5 * (t1 - t0)  # half width
        tm = .5 * (t0 + t1)
        # Cap the half-width at 5 s, i.e. show at most a 10 s window.
        # NOTE(review): the former comment read "minimum 5s time range"; confirm whether
        # `max(5, w)` was intended instead.
        w = min(5, w)
        t0, t1 = tm - w, tm + w
        # Convert the times to NDC along x.
        x0 = -1 + 2 * t0 / self.duration
        x1 = -1 + 2 * t1 / self.duration
        box = (x0, -1, x1, +1)
        self.canvas.panzoom.set_range(box)

    # Interactivity
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select a cluster by control-clicking in the raster plot (add shift to select more)."""
        if 'Control' not in e.modifiers:
            return
        if self.all_cluster_ids is None:
            return
        b = e.button
        # Get mouse position in NDC.
        cluster_idx, _ = self.canvas.stacked.box_map(e.pos)
        cluster_id = self.all_cluster_ids[cluster_idx]
        logger.debug("Click on cluster %d with button %s.", cluster_id, b)
        if 'Shift' in e.modifiers:
            emit('select_more', self, [cluster_id])
        else:
            emit('request_select', self, [cluster_id])
def __init__(self, amplitudes=None, amplitudes_type=None, duration=None, **kwargs):
    """Create the amplitude view.

    Parameters
    ----------
    amplitudes : dict or function
        `{amplitudes_type: function}`; a single function is wrapped as `{'amplitude': f}`.
    amplitudes_type : str
        The initially shown amplitudes type.
    duration : float
        Total duration, used as the x range (defaults to 1).
    **kwargs
        Forwarded to the base view, as the other views do.

    """
    super(AmplitudeView, self).__init__(**kwargs)
    self.state_attrs += ('amplitudes_type', )

    self.canvas.enable_axes()
    self.canvas.enable_lasso()

    # Ensure amplitudes is a dictionary, even if there is a single amplitude.
    if not isinstance(amplitudes, dict):
        amplitudes = {'amplitude': amplitudes}
    assert amplitudes
    self.amplitudes = amplitudes

    # Rotating property amplitudes types.
    self.amplitudes_types = RotatingProperty()
    for name, value in self.amplitudes.items():
        self.amplitudes_types.add(name, value)
    # Current amplitudes type.
    self.amplitudes_types.set(amplitudes_type)
    assert self.amplitudes_type in self.amplitudes

    self.cluster_ids = ()
    self.duration = duration or 1.

    # Histogram visual.
    self.hist_visual = HistogramVisual()
    self.hist_visual.transforms.add([
        Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
        Rotate('cw'),
        Scale((1, -1)),
        Translate((2.05, 0)),
    ])
    self.canvas.add_visual(self.hist_visual)
    self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.75, 1)
    self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.25, 0)

    # Yellow vertical bar showing the selected time interval.
    self.patch_visual = PatchVisual(primitive_type='triangle_fan')
    self.patch_visual.inserter.insert_vert(
        '''
        const float MIN_INTERVAL_SIZE = 0.01;
        uniform float u_interval_size;
        ''',
        'header')
    self.patch_visual.inserter.insert_vert(
        '''
        gl_Position.y = pos_orig.y;

        // The following is used to ensure that (1) the bar width increases with the zoom level
        // but also (2) there is a minimum absolute width so that the bar remains visible
        // at low zoom levels.
        float w = max(MIN_INTERVAL_SIZE, u_interval_size * u_zoom.x);
        // HACK: the z coordinate is used to store 0 or 1, depending on whether the current
        // vertex is on the left or right edge of the bar.
        gl_Position.x += w * (-1 + 2 * int(a_position.z == 0));

        ''',
        'after_transforms')
    self.canvas.add_visual(self.patch_visual)

    # Scatter plot.
    self.visual = ScatterVisual()
    self.canvas.add_visual(self.visual)
    self.canvas.panzoom.set_constrain_bounds((-2, -2, +2, +2))
class AmplitudeView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays an amplitude plot for all selected clusters.

    Constructor
    -----------

    amplitudes : dict
        Dictionary `{amplitudes_type: function}`, for different types of amplitudes.

        Each function maps `cluster_ids` to a list
        `[Bunch(amplitudes, spike_ids, spike_times), ...]` for each cluster.
        Use `cluster_id=None` for background amplitudes.

    """

    # Do not show too many clusters.
    max_n_clusters = 8

    _default_position = 'right'

    # Alpha channel of the markers in the scatter plot.
    marker_alpha = 1.
    time_range_color = (1., 1., 0., .25)

    # Number of bins in the histogram.
    n_bins = 100

    # Alpha channel of the histogram in the background.
    histogram_alpha = .5

    # Quantile used for scaling of the amplitudes (less than 1 to avoid outliers).
    quantile = .99

    # Size of the histogram, between 0 and 1.
    histogram_scale = .25

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'next_amplitudes_type': 'a',
        'previous_amplitudes_type': 'shift+a',
        'select_x_dim': 'shift+left click',
        'select_y_dim': 'shift+right click',
        'select_time': 'alt+click',
    }

    def __init__(self, amplitudes=None, amplitudes_type=None, duration=None, **kwargs):
        # Extra keyword arguments are forwarded to the base view, as the other views do.
        super(AmplitudeView, self).__init__(**kwargs)
        self.state_attrs += ('amplitudes_type', )

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes

        # Rotating property amplitudes types.
        self.amplitudes_types = RotatingProperty()
        for name, value in self.amplitudes.items():
            self.amplitudes_types.add(name, value)
        # Current amplitudes type.
        self.amplitudes_types.set(amplitudes_type)
        assert self.amplitudes_type in self.amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1.

        # Histogram visual.
        self.hist_visual = HistogramVisual()
        self.hist_visual.transforms.add([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
            Rotate('cw'),
            Scale((1, -1)),
            Translate((2.05, 0)),
        ])
        self.canvas.add_visual(self.hist_visual)
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.75, 1)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.25, 0)

        # Yellow vertical bar showing the selected time interval.
        self.patch_visual = PatchVisual(primitive_type='triangle_fan')
        self.patch_visual.inserter.insert_vert(
            '''
            const float MIN_INTERVAL_SIZE = 0.01;
            uniform float u_interval_size;
            ''',
            'header')
        self.patch_visual.inserter.insert_vert(
            '''
            gl_Position.y = pos_orig.y;

            // The following is used to ensure that (1) the bar width increases with the zoom level
            // but also (2) there is a minimum absolute width so that the bar remains visible
            // at low zoom levels.
            float w = max(MIN_INTERVAL_SIZE, u_interval_size * u_zoom.x);
            // HACK: the z coordinate is used to store 0 or 1, depending on whether the current
            // vertex is on the left or right edge of the bar.
            gl_Position.x += w * (-1 + 2 * int(a_position.z == 0));

            ''',
            'after_transforms')
        self.canvas.add_visual(self.patch_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)
        self.canvas.panzoom.set_constrain_bounds((-2, -2, +2, +2))

    def _get_data_bounds(self, bunchs):
        """Compute the data bounds `(0, ymin, duration, ymax)` from amplitude quantiles.

        Falls back to `(0, 0, duration, 1)` when there is no amplitude at all.
        """
        if not bunchs:  # pragma: no cover
            return (0, 0, self.duration, 1)
        # Only non-empty amplitude arrays contribute; min()/max() would raise otherwise.
        nonempty = [bunch.amplitudes for bunch in bunchs if len(bunch.amplitudes)]
        if not nonempty:
            return (0, 0, self.duration, 1)
        m = min(np.quantile(a, 1 - self.quantile) for a in nonempty)
        m = min(0, m)  # ensure ymin <= 0
        M = max(np.quantile(a, self.quantile) for a in nonempty)
        # Avoid a zero-height range (e.g. all amplitudes equal to 0).
        if M <= m:
            M = m + 1.
        return (0, m, self.duration, M)

    def _add_histograms(self, bunchs):
        """Attach to each bunch a normalized histogram of its amplitudes."""
        # We do this after get_clusters_data because we need x_max.
        for bunch in bunchs:
            bunch.histogram = _compute_histogram(
                bunch.amplitudes,
                x_min=self.data_bounds[1],
                x_max=self.data_bounds[3],
                n_bins=self.n_bins,
                normalize=True,
                ignore_zeros=True,
            )
        return bunchs

    def show_time_range(self, interval=(0, 0)):
        """Show a vertical bar over the time interval `(start, end)`, in seconds."""
        start, end = interval
        x0 = -1 + 2 * (start / self.duration)
        x1 = -1 + 2 * (end / self.duration)
        xm = .5 * (x0 + x1)
        pos = np.array([
            [xm, -1],
            [xm, +1],
            [xm, +1],
            [xm, -1],
        ])
        self.patch_visual.program['u_interval_size'] = .5 * (x1 - x0)
        self.patch_visual.set_data(pos=pos, color=self.time_range_color, depth=[0, 0, 1, 1])
        self.canvas.update()

    def _plot_cluster(self, bunch):
        """Make the scatter plot."""
        ms = self._marker_size
        if not len(bunch.histogram):
            return

        # Histogram in the background.
        self.hist_visual.add_batch_data(
            hist=bunch.histogram,
            ylim=self._ylim,
            color=add_alpha(bunch.color, self.histogram_alpha))

        # Scatter plot.
        self.visual.add_batch_data(
            pos=bunch.pos, color=bunch.color, size=ms, data_bounds=self.data_bounds)

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids.

        Unless `load_all` is set (lasso splitting), a background pseudo-cluster `None` is
        prepended. Returns None when no cluster is selected.
        """
        if not len(self.cluster_ids):
            return
        cluster_ids = list(self.cluster_ids)
        # Don't need the background when splitting.
        if not load_all:
            # Add None cluster which means background spikes.
            cluster_ids = [None] + cluster_ids
        # Position of the first selected cluster in `cluster_ids`, so that the selected
        # clusters get the colors 0, 1, ... whether or not the background is present.
        color_offset = 0 if load_all else 1
        bunchs = self.amplitudes[self.amplitudes_type](cluster_ids, load_all=load_all) or ()
        # Add a pos attribute in bunchs in addition to x and y.
        for i, (cluster_id, bunch) in enumerate(zip(cluster_ids, bunchs)):
            spike_ids = _as_array(bunch.spike_ids)
            spike_times = _as_array(bunch.spike_times)
            amplitudes = _as_array(bunch.amplitudes)
            assert spike_ids.shape == spike_times.shape == amplitudes.shape
            # Ensure that bunch.pos exists, as it used by the LassoMixin.
            bunch.pos = np.c_[spike_times, amplitudes]
            assert bunch.pos.ndim == 2
            bunch.cluster_id = cluster_id
            bunch.color = (
                selected_cluster_color(i - color_offset, self.marker_alpha)
                if cluster_id is not None
                # Background amplitude color.
                else (.5, .5, .5, .5))
        return bunchs

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data(**kwargs)
        if not bunchs:
            return
        self.data_bounds = self._get_data_bounds(bunchs)
        bunchs = self._add_histograms(bunchs)

        # Use the same scale for all histograms.
        self._ylim = max(bunch.histogram.max() for bunch in bunchs) if bunchs else 1.

        self.visual.reset_batch()
        self.hist_visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.hist_visual)

        self._update_axes()
        self.canvas.update()
        self.update_status()

    def attach(self, gui):
        """Attach the view to the GUI."""
        super(AmplitudeView, self).attach(gui)

        # Amplitude type actions.
        def _make_amplitude_action(a):
            def callback():
                self.amplitudes_type = a
                self.plot()
            return callback

        for a in self.amplitudes_types.keys():
            name = 'Change amplitudes type to %s' % a
            self.actions.add(
                _make_amplitude_action(a), show_shortcut=False,
                name=name, view_submenu='Change amplitudes type')

        self.actions.add(self.next_amplitudes_type, set_busy=True)
        self.actions.add(self.previous_amplitudes_type, set_busy=True)

    @property
    def status(self):
        return self.amplitudes_type

    @property
    def amplitudes_type(self):
        return self.amplitudes_types.current

    @amplitudes_type.setter
    def amplitudes_type(self, value):
        self.amplitudes_types.set(value)

    def next_amplitudes_type(self):
        """Switch to the next amplitudes type."""
        self.amplitudes_types.next()
        logger.debug("Switch to amplitudes type: %s.", self.amplitudes_types.current)
        self.plot()

    def previous_amplitudes_type(self):
        """Switch to the previous amplitudes type."""
        self.amplitudes_types.previous()
        logger.debug("Switch to amplitudes type: %s.", self.amplitudes_types.current)
        self.plot()

    def on_mouse_click(self, e):
        """Select a time from the amplitude view to display in the trace view."""
        if 'Alt' in e.modifiers:
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            time = Range(NDC, self.data_bounds).apply(mouse_pos)[0][0]
            emit('select_time', self, time)
class ScatterView(MarkerSizeMixin, LassoMixin, ManualClusteringView):
    """This view displays a scatter plot for all selected clusters.

    Constructor
    -----------

    coords : function
        Maps `cluster_ids` to a list `[Bunch(x, y, spike_ids, data_bounds), ...]` for each
        cluster.

    """

    _default_position = 'right'

    default_shortcuts = {
        'change_marker_size': 'ctrl+wheel',
    }

    def __init__(self, coords=None, **kwargs):
        super(ScatterView, self).__init__(**kwargs)
        self.canvas.enable_axes()
        self.canvas.enable_lasso()
        assert coords
        self.coords = coords
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

    def _plot_cluster(self, bunch):
        """Add the points of one bunch to the scatter visual batch."""
        ms = self._marker_size
        self.visual.add_batch_data(
            pos=bunch.pos, color=bunch.color, size=ms, data_bounds=self.data_bounds)

    def _get_split_cluster_data(self, bunchs):
        """Get the data when there is one Bunch per cluster."""
        # Add a pos attribute in bunchs in addition to x and y.
        for i, (cluster_id, bunch) in enumerate(zip(self.cluster_ids, bunchs)):
            bunch.cluster_id = cluster_id
            if 'pos' not in bunch:
                assert bunch.x.ndim == 1
                assert bunch.x.shape == bunch.y.shape
                bunch.pos = np.c_[bunch.x, bunch.y]
            assert bunch.pos.ndim == 2
            assert 'spike_ids' in bunch
            bunch.color = selected_cluster_color(i, .75)
        return bunchs

    def _get_collated_cluster_data(self, bunch):
        """Get the data when there is a single Bunch for all selected clusters."""
        assert 'spike_ids' in bunch
        if 'pos' not in bunch:
            assert bunch.x.ndim == 1
            assert bunch.x.shape == bunch.y.shape
            bunch.pos = np.c_[bunch.x, bunch.y]
        assert bunch.pos.ndim == 2
        bunch.color = spike_colors(bunch.spike_clusters, self.cluster_ids)
        return bunch

    def get_clusters_data(self, load_all=None):
        """Return a list of Bunch instances, with attributes pos and spike_ids.

        Raises
        ------
        ValueError
            If `coords()` returns neither a Bunch nor a list/tuple of Bunch.

        """
        if not load_all:
            bunchs = self.coords(self.cluster_ids) or ()
        elif 'load_all' in inspect.signature(self.coords).parameters:
            bunchs = self.coords(self.cluster_ids, load_all=load_all) or ()
        else:
            logger.warning(
                "The view `%s` may not load all spikes when using the lasso for splitting.",
                self.__class__.__name__)
            # Same empty/None handling as the other branches.
            bunchs = self.coords(self.cluster_ids) or ()
        if isinstance(bunchs, dict):
            return [self._get_collated_cluster_data(bunchs)]
        elif isinstance(bunchs, (list, tuple)):
            return self._get_split_cluster_data(bunchs)
        raise ValueError("The output of `coords()` should be either a list of Bunch, or a Bunch.")

    def plot(self, **kwargs):
        """Update the view with the current cluster selection."""
        bunchs = self.get_clusters_data()
        # Hide the visual if there is no data.
        if not bunchs:
            self.visual.hide()
            self.canvas.update()
            return
        self.data_bounds = self._get_data_bounds(bunchs)

        self.visual.reset_batch()
        for bunch in bunchs:
            self._plot_cluster(bunch)
        self.canvas.update_visual(self.visual)
        self.visual.show()

        self._update_axes()
        self.canvas.update()
class ClusterScatterView(MarkerSizeMixin, BaseColorView, BaseGlobalView, ManualClusteringView):
    """This view shows all clusters in a customizable scatter plot.

    Constructor
    -----------

    cluster_ids : array-like
    cluster_info: function
        Maps cluster_id => Bunch() with attributes.
    bindings: dict
        Maps plot dimension to cluster attributes.

    """
    _default_position = 'right'
    _scaling = 1.
    _default_alpha = .75
    _min_marker_size = 5.0
    _max_marker_size = 30.0
    _dims = ('x_axis', 'y_axis', 'size')

    # NOTE: this is not the actual marker size, but a scaling factor for the normal marker size.
    _marker_size = 1.
    _default_marker_size = 1.

    x_axis = ''
    y_axis = ''
    size = ''
    x_axis_log_scale = False
    y_axis_log_scale = False
    size_log_scale = False

    default_shortcuts = {
        'change_marker_size': 'alt+wheel',
        'switch_color_scheme': 'shift+wheel',
        'select_cluster': 'click',
        'select_more': 'shift+click',
        'add_to_lasso': 'control+left click',
        'clear_lasso': 'control+right click',
    }

    default_snippets = {
        'set_x_axis': 'csx',
        'set_y_axis': 'csy',
        'set_size': 'css',
    }

    def __init__(self, cluster_ids=None, cluster_info=None, bindings=None, **kwargs):
        super(ClusterScatterView, self).__init__(**kwargs)
        self.state_attrs += (
            'scaling',
            'x_axis', 'y_axis', 'size',
            'x_axis_log_scale', 'y_axis_log_scale', 'size_log_scale',
        )
        self.local_state_attrs += ()

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        bindings = bindings or {}
        self.cluster_info = cluster_info
        # Update self.x_axis, y_axis, size from the bindings of the known dimensions.
        self.__dict__.update({k: v for k, v in bindings.items() if k in self._dims})

        # Size range computed initially so that it doesn't change during the course of the session.
        self._size_min = self._size_max = None

        # Full list of clusters.
        self.all_cluster_ids = cluster_ids

        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)

        self.label_visual = TextVisual()
        self.canvas.add_visual(self.label_visual, exclude_origins=(self.canvas.panzoom, ))

        self.marker_positions = self.marker_colors = self.marker_sizes = None

    def _update_labels(self):
        """Show the names of the x and y fields next to the axes."""
        self.label_visual.set_data(
            pos=[[-1, -1], [1, 1]], text=[self.x_axis, self.y_axis],
            anchor=[[1.25, 3], [-3, -1.25]])

    # Data access
    # -------------------------------------------------------------------------

    @property
    def bindings(self):
        return {k: getattr(self, k) for k in self._dims}

    def get_cluster_data(self, cluster_id):
        """Return the data of one cluster, as a dict {dimension: value} (0. if missing)."""
        data = self.cluster_info(cluster_id)
        return {k: data.get(v, 0.) for k, v in self.bindings.items()}

    def get_clusters_data(self, cluster_ids):
        """Return the data of a set of clusters, as a dictionary {cluster_id: Bunch}."""
        return {
            cluster_id: self.get_cluster_data(cluster_id)
            for cluster_id in cluster_ids
        }

    def set_cluster_ids(self, all_cluster_ids):
        """Update the cluster data by specifying the list of all cluster ids."""
        self.all_cluster_ids = all_cluster_ids
        if len(all_cluster_ids) == 0:
            return
        self.prepare_position()
        self.prepare_size()
        self.prepare_color()

    # Data preparation
    # -------------------------------------------------------------------------

    def set_fields(self):
        """Set the list of numeric (non-string) fields returned by cluster_info."""
        data = self.cluster_info(self.all_cluster_ids[0])
        self.fields = sorted(data.keys())
        self.fields = [f for f in self.fields if not isinstance(data[f], str)]

    def prepare_data(self):
        """Prepare the marker position, size, and color from the cluster information."""
        self.prepare_position()
        self.prepare_size()
        self.prepare_color()

    def prepare_position(self):
        """Compute the marker positions and the data bounds."""
        self.cluster_data = self.get_clusters_data(self.all_cluster_ids)

        # Get the list of fields returned by cluster_info.
        self.set_fields()

        # Create the x array.
        x = np.array([
            self.cluster_data[cluster_id]['x_axis'] or 0.
            for cluster_id in self.all_cluster_ids
        ])
        if self.x_axis_log_scale:
            x = np.log(1.0 + x - x.min())

        # Create the y array.
        y = np.array([
            self.cluster_data[cluster_id]['y_axis'] or 0.
            for cluster_id in self.all_cluster_ids
        ])
        if self.y_axis_log_scale:
            y = np.log(1.0 + y - y.min())

        self.marker_positions = np.c_[x, y]

        # Update the data bounds. Pad a zero-width/height range (a single cluster, or a
        # constant field) so that the normalization to NDC does not divide by zero.
        xmin, xmax = x.min(), x.max()
        ymin, ymax = y.min(), y.max()
        if xmin == xmax:
            xmin, xmax = xmin - 1, xmax + 1
        if ymin == ymax:
            ymin, ymax = ymin - 1, ymax + 1
        self.data_bounds = (xmin, ymin, xmax, ymax)

    def prepare_size(self):
        """Compute the marker sizes."""
        size = np.array([
            self.cluster_data[cluster_id]['size'] or 1.
            for cluster_id in self.all_cluster_ids
        ])
        # Log scale for the size.
        if self.size_log_scale:
            size = np.log(1.0 + size - size.min())
        # Find the size range.
        if self._size_min is None:
            self._size_min, self._size_max = size.min(), size.max()
        m, M = self._size_min, self._size_max
        # Normalize the marker size.
        size = (size - m) / ((M - m) or 1.0)  # size is in [0, 1]
        ms, Ms = self._min_marker_size, self._max_marker_size
        size = ms + size * (Ms - ms)  # now, size is in [ms, Ms]
        self.marker_sizes = size

    def prepare_color(self):
        """Compute the marker colors."""
        colors = self.get_cluster_colors(self.all_cluster_ids, self._default_alpha)
        self.marker_colors = colors

    # Marker size
    # -------------------------------------------------------------------------

    @property
    def marker_size(self):
        """Size of the spike markers, in pixels."""
        return self._marker_size

    @marker_size.setter
    def marker_size(self, val):
        # We override this method so as to use self._marker_size as a scaling factor, not
        # as an actual fixed marker size.
        self._marker_size = val
        self._set_marker_size()
        self.canvas.update()

    def _set_marker_size(self):
        if self.marker_sizes is not None:
            self.visual.set_marker_size(self.marker_sizes * self._marker_size)

    # Plotting functions
    # -------------------------------------------------------------------------

    def update_color(self):
        """Update the cluster colors depending on the current color scheme."""
        self.prepare_color()
        self.visual.set_color(self.marker_colors)
        self.canvas.update()

    def update_select_color(self):
        """Update the cluster colors after the cluster selection changes."""
        if self.marker_colors is None:
            return
        selected_clusters = self.cluster_ids
        if selected_clusters is not None and len(selected_clusters) > 0:
            colors = _add_selected_clusters_colors(
                selected_clusters, self.all_cluster_ids, self.marker_colors.copy())
        else:
            # Empty selection: restore the plain colors so no stale highlight remains.
            colors = self.marker_colors
        self.visual.set_color(colors)
        self.canvas.update()

    def plot(self, **kwargs):
        """Make the scatter plot."""
        if self.marker_positions is None:
            self.prepare_data()
        self.visual.set_data(
            pos=self.marker_positions,
            color=self.marker_colors,
            size=self.marker_sizes * self._marker_size,  # marker size scaling factor
            data_bounds=self.data_bounds)
        self.canvas.axes.reset_data_bounds(self.data_bounds)
        self.canvas.update()

    def change_bindings(self, **kwargs):
        """Change the bindings; values that are not known fields are ignored."""
        # Ensure the specified fields are valid.
        kwargs = {k: v for k, v in kwargs.items() if v in self.fields}
        assert set(kwargs.keys()) <= set(self._dims)
        # Reset the size scaling.
        if 'size' in kwargs:
            self._size_min = self._size_max = None
        self.__dict__.update(kwargs)
        self._update_labels()
        self.update_status()
        self.prepare_data()
        self.plot()

    def toggle_log_scale(self, dim, checked):
        """Toggle logarithmic scaling for one of the dimensions."""
        self._size_min = None
        setattr(self, '%s_log_scale' % dim, checked)
        self.prepare_data()
        self.plot()
        self.canvas.update()

    def set_x_axis(self, field):
        """Set the dimension for the x axis."""
        self.change_bindings(x_axis=field)

    def set_y_axis(self, field):
        """Set the dimension for the y axis."""
        self.change_bindings(y_axis=field)

    def set_size(self, field):
        """Set the dimension for the marker size."""
        self.change_bindings(size=field)

    # Misc functions
    # -------------------------------------------------------------------------

    def attach(self, gui):
        """Attach the GUI."""
        super(ClusterScatterView, self).attach(gui)

        def _make_action(dim, name):
            def callback():
                self.change_bindings(**{dim: name})
            return callback

        def _make_log_toggle(dim):
            def callback(checked):
                self.toggle_log_scale(dim, checked)
            return callback

        # Change the bindings.
        for dim in self._dims:
            view_submenu = 'Change %s' % dim

            # Change to every cluster info.
            for name in self.fields:
                self.actions.add(
                    _make_action(dim, name), show_shortcut=False,
                    name='Change %s to %s' % (dim, name), view_submenu=view_submenu)

            # Toggle logarithmic scale.
            self.actions.separator(view_submenu=view_submenu)
            self.actions.add(
                _make_log_toggle(dim), checkable=True, view_submenu=view_submenu,
                name='Toggle log scale for %s' % dim, show_shortcut=False,
                checked=getattr(self, '%s_log_scale' % dim))

        self.actions.separator()
        self.actions.add(self.set_x_axis, prompt=True, prompt_default=lambda: self.x_axis)
        self.actions.add(self.set_y_axis, prompt=True, prompt_default=lambda: self.y_axis)
        self.actions.add(self.set_size, prompt=True, prompt_default=lambda: self.size)

        connect(self.on_select)
        connect(self.on_cluster)

        @connect(sender=self.canvas)
        def on_lasso_updated(sender, polygon):
            if len(polygon) < 3 or self.marker_positions is None:
                return
            pos = range_transform([self.data_bounds], [NDC], self.marker_positions)
            ind = self.canvas.lasso.in_polygon(pos)
            # Boolean/fancy indexing requires an array, even if a list of ids was given.
            cluster_ids = np.asarray(self.all_cluster_ids)[ind]
            emit("request_select", self, list(cluster_ids))

        @connect(sender=self)
        def on_close_view(view_, gui):
            """Unconnect all events when closing the view."""
            unconnect(self.on_select)
            unconnect(self.on_cluster)
            unconnect(on_lasso_updated)

        if self.all_cluster_ids is not None:
            self.set_cluster_ids(self.all_cluster_ids)
        self._update_labels()

    def on_select(self, *args, **kwargs):
        super(ClusterScatterView, self).on_select(*args, **kwargs)
        self.update_select_color()

    def on_cluster(self, sender, up):
        if 'all_cluster_ids' in up:
            self.set_cluster_ids(up.all_cluster_ids)
            self.plot()

    @property
    def status(self):
        return 'Size: %s. Color scheme: %s.' % (self.size, self.color_scheme)

    # Interactivity
    # -------------------------------------------------------------------------

    def on_mouse_click(self, e):
        """Select the cluster whose marker is closest to the click (shift to select more)."""
        if 'Control' in e.modifiers:
            return
        # Nothing to select before the marker positions have been computed.
        if self.marker_positions is None or not len(self.marker_positions):
            return
        b = e.button
        pos = self.canvas.window_to_ndc(e.pos)
        marker_pos = range_transform([self.data_bounds], [NDC], self.marker_positions)
        cluster_rel = np.argmin(((marker_pos - pos) ** 2).sum(axis=1))
        cluster_id = self.all_cluster_ids[cluster_rel]
        logger.debug("Click on cluster %d with button %s.", cluster_id, b)
        if 'Shift' in e.modifiers:
            emit('select_more', self, [cluster_id])
        else:
            emit('request_select', self, [cluster_id])