def on_request_split_FTT(self):
    """Return the lassoed spikes of a single cluster picked by the user.

    A Qt item dialog lists the current cluster ids; only the points of the
    chosen cluster are tested against the lasso polygon.

    Returns the unique ids of the enclosed spikes. An empty `int64` array is
    returned when the lasso has fewer than 3 points, when no cluster is
    selected, or when the dialog is not accepted. The lasso is cleared only
    on the path that actually returns a selection.
    """
    if self.lasso.count < 3 or not len(self.cluster_ids):  # pragma: no cover
        return np.array([], dtype=np.int64)

    # Ask which of the current clusters the lasso should apply to.
    labels = [str(cluster_id) for cluster_id in self.cluster_ids]
    dialog = QtGui.QInputDialog(None)
    choice, accepted = QtGui.QInputDialog.getItem(
        dialog,
        "Which cluster id should be split:",
        "Which cluster id should be split:",
        labels, 0, False)
    if not accepted:
        return np.array([], dtype=np.int64)
    chosen = labels.index(choice)

    # Data is requested for all clusters; only the chosen entry is used.
    bunch = self._get_data(self.cluster_ids)[chosen]
    points = np.vstack((bunch['x'], bunch['y'])).T

    # Normalize the points with the view's data bounds.
    normalizer = Range(self.data_bounds)
    points = normalizer.apply(points)

    # Keep the spikes falling inside the lasso polygon.
    inside = self.lasso.in_polygon(points)
    lassoed_ids = bunch['spike_ids'][inside]
    self.lasso.clear()
    return np.unique(lassoed_ids)
def on_mouse_press(self, e):
    """Raise a `spike_click` event for the displayed spike closest to the mouse.

    The handler only acts when Control is among the event modifiers or when
    the currently pressed key compares equal to a digit string ('0'-'9').
    It returns silently when no waveform is displayed, or when none of the
    displayed waveforms includes the clicked channel.
    """
    key = self._key_pressed
    if 'Control' not in e.modifiers and key not in map(str, range(10)):
        return

    # Get mouse position in NDC and the channel under the mouse.
    mouse_pos = self.panzoom.get_mouse_pos(e.pos)
    box_id = self.stacked.get_closest_box(mouse_pos)
    channel_id = self._permute_channels(box_id, inv=True)

    # Get the information about the displayed spikes.
    if not self._waveform_times:
        return
    # Each entry is (time, spike_id, cluster_id, channel_ids). Filter before
    # unpacking: zipping an empty selection yields nothing, and unpacking it
    # into four names would raise ValueError.
    candidates = [wt for wt in self._waveform_times if channel_id in wt[3]]
    if not candidates:
        return

    # Get the time coordinate of the mouse position.
    mouse_time = Range(NDC, self._data_bounds).apply(mouse_pos)[0][0]

    # Get the closest spike id.
    times, spike_ids, spike_clusters, _ = zip(*candidates)
    i = np.argmin(np.abs(np.array(times) - mouse_time))

    # Raise the spike_click event.
    self.events.spike_click(
        channel_id=channel_id,
        spike_id=spike_ids[i],
        cluster_id=spike_clusters[i],
    )
def __init__(self, amplitudes=None, amplitude_name=None, duration=None):
    """Create the view's canvas features and its three visuals.

    `amplitudes` may be a dictionary or a single object; a non-dict value is
    stored under the key 'amplitude'. `amplitude_name` selects the current
    entry and defaults to the first key. `duration` falls back to 1 when
    falsy.
    """
    super(AmplitudeView, self).__init__()
    self.state_attrs += ('amplitude_name',)

    self.canvas.enable_axes()
    self.canvas.enable_lasso()

    # Always work with a dictionary, even for a single amplitude.
    amplitudes = amplitudes if isinstance(amplitudes, dict) else {'amplitude': amplitudes}
    assert amplitudes
    self.amplitudes = amplitudes
    self.amplitude_names = list(amplitudes.keys())

    # Currently selected amplitude type.
    self.amplitude_name = amplitude_name or self.amplitude_names[0]
    assert self.amplitude_name in amplitudes

    self.cluster_ids = ()
    self.duration = duration or 1

    # Histogram visual, with its own GPU transform chain.
    self.hist_visual = HistogramVisual()
    hist_transforms = [
        Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
        Rotate('ccw'),
    ]
    self.hist_visual.transforms.add_on_gpu(hist_transforms)
    self.canvas.add_visual(self.hist_visual)

    # Scatter plot.
    self.visual = ScatterVisual()
    self.canvas.add_visual(self.visual)

    # Amplitude name; added with the canvas panzoom listed in exclude_origins.
    self.text_visual = TextVisual()
    self.canvas.add_visual(self.text_visual, exclude_origins=(self.canvas.panzoom,))
def on_mouse_click(self, e):
    """Select a spike (Control+click) or a channel (Shift+click).

    With Control, the displayed waveform closest in time to the mouse on the
    clicked channel is looked up and a 'select_spike' event is emitted. If no
    displayed waveform includes that channel, the handler returns right away,
    which also skips the Shift handling below.

    With Shift, a 'select_channel' event is emitted for the clicked channel,
    forwarding the mouse button.
    """
    if 'Control' in e.modifiers:
        # Box under the mouse, mapped back to a channel index array.
        box_id, _ = self.canvas.stacked.box_map(e.pos)
        channel_id = np.nonzero(self.channel_y_ranks == box_id)[0]
        bounds = self.data_bounds

        # Displayed waveforms that include the clicked channel.
        matching = []
        for t, spike, cluster, channels in self._waveform_times:
            if channel_id in channels:
                matching.append((t, spike, cluster, channels))
        if not matching:
            return

        # Time coordinate of the mouse position.
        ndc_pos = self.canvas.panzoom.window_to_ndc(e.pos)
        mouse_time = Range(NDC, bounds).apply(ndc_pos)[0][0]

        # Waveform whose time is closest to the mouse.
        times, spike_ids, spike_clusters, channel_ids = zip(*matching)
        closest = np.argmin(np.abs(np.array(times) - mouse_time))

        # Raise the select_spike event.
        emit(
            'select_spike', self,
            channel_id=channel_id,
            spike_id=spike_ids[closest],
            cluster_id=spike_clusters[closest],
        )

    if 'Shift' in e.modifiers:
        box_id, _ = self.canvas.stacked.box_map(e.pos)
        ranks_match = np.nonzero(self.channel_y_ranks == box_id)
        channel_id = int(ranks_match[0][0])
        emit('select_channel', self, channel_id=channel_id, button=e.button)
def on_request_split(self, sender=None):
    """Return the ids of the spikes enclosed by the lasso in the active subplot.

    All spikes of every selected cluster are loaded, projected on the two
    dimensions of the canvas layout's active box, normalized with the axis
    bounds, and tested against the lasso polygon. The lasso is then cleared.

    `sender` is accepted but not used. An empty `int64` array is returned
    when the lasso has fewer than 3 points or no cluster is selected.
    """
    lasso = self.canvas.lasso
    if lasso.count < 3 or not len(self.cluster_ids):  # pragma: no cover
        return np.array([], dtype=np.int64)
    assert len(self.channel_ids)

    # Dimensions of the lassoed subplot, stored as an "x,y" string.
    row, col = self.canvas.layout.active_box
    dim_x, dim_y = self.grid_dim[row][col].split(',')

    # Accumulate normalized points and spike ids over all clusters.
    all_points = []
    all_spike_ids = []
    for cluster_id in self.cluster_ids:
        # Load all spikes of the cluster.
        bunch = self.features(cluster_id, channel_ids=self.channel_ids, load_all=True)
        px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id, load_all=True)
        py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id, load_all=True)
        raw = np.c_[px.data, py.data]

        # Normalize the points.
        xmin, xmax = self._get_axis_bounds(dim_x, px)
        ymin, ymax = self._get_axis_bounds(dim_y, py)
        normalizer = Range((xmin, ymin, xmax, ymax))
        all_points.append(normalizer.apply(raw))
        all_spike_ids.append(bunch.spike_ids)

    stacked_points = np.vstack(all_points)
    flat_spike_ids = np.concatenate(all_spike_ids)

    # Find lassoed spikes.
    inside = lasso.in_polygon(stacked_points)
    lasso.clear()
    return np.unique(flat_spike_ids[inside])
def on_mouse_click(self, e):
    """Emit 'select_time' for the time under the mouse on Shift+click.

    The mouse position is converted to NDC by the canvas panzoom, then mapped
    to data coordinates with the view's data bounds; the first coordinate of
    the result is emitted as the selected time. Clicks without the Shift
    modifier are ignored.
    """
    if 'Shift' not in e.modifiers:
        return
    ndc_pos = self.canvas.panzoom.window_to_ndc(e.pos)
    to_data = Range(NDC, self.data_bounds)
    selected_time = to_data.apply(ndc_pos)[0][0]
    emit('select_time', self, selected_time)
def on_request_split(self):
    """Return the ids of the spikes enclosed by the lasso in the lassoed subplot.

    All spikes of every selected cluster are loaded, projected on the two
    dimensions of the subplot given by the lasso's box, normalized with the
    axis bounds, and tested against the lasso polygon. The lasso is then
    cleared.

    An empty `int64` array is returned when the lasso has fewer than 3 points
    or no cluster is selected.
    """
    if self.lasso.count < 3 or not len(self.cluster_ids):  # pragma: no cover
        return np.array([], dtype=np.int64)
    assert len(self.channel_ids)

    # Dimensions of the lassoed subplot, stored as an "x,y" string.
    row, col = self.lasso.box
    dim_x, dim_y = self.grid_dim[row][col].split(',')

    # One (normalized points, spike ids) pair per cluster.
    per_cluster = []
    for cluster_id in self.cluster_ids:
        # Load all spikes of the cluster.
        bunch = self.features(cluster_id, channel_ids=self.channel_ids, load_all=True)
        px = self._get_axis_data(bunch, dim_x, cluster_id=cluster_id, load_all=True)
        py = self._get_axis_data(bunch, dim_y, cluster_id=cluster_id, load_all=True)
        raw = np.c_[px.data, py.data]

        # Normalize the points.
        xmin, xmax = self._get_axis_bounds(dim_x, px)
        ymin, ymax = self._get_axis_bounds(dim_y, py)
        scaled = Range((xmin, ymin, xmax, ymax)).apply(raw)
        per_cluster.append((scaled, bunch.spike_ids))

    pos = np.vstack([points for points, _ in per_cluster])
    spike_ids = np.concatenate([ids for _, ids in per_cluster])

    # Find lassoed spikes.
    inside = self.lasso.in_polygon(pos)
    self.lasso.clear()
    return np.unique(spike_ids[inside])
def _iter_channel(positions, size=100, margin=5):
    """Yield an (x, y) pair per channel, rescaled into a square of side `size`.

    The channel boxes are obtained from `_get_boxes(positions,
    keep_aspect_ratio=False)`; each box is reduced to its center, with the
    y coordinate negated, and the centers are mapped from NDC to the
    rectangle `[margin, margin, size - margin, size - margin]`.

    `size` and `margin` default to the values that were previously
    hard-coded (100 and 5).
    """
    boxes = _get_boxes(positions, keep_aspect_ratio=False)
    # Box rows are read as (xmin, ymin, xmax, ymax): columns 0 and 2 are
    # averaged for x, columns 1 and 3 for y.
    x = boxes[:, [0, 2]].mean(axis=1)
    y = -boxes[:, [1, 3]].mean(axis=1)
    centers = np.c_[x, y]
    target = [margin, margin, size - margin, size - margin]
    for cx, cy in Range(NDC, target).apply(centers):
        yield cx, cy
def __init__(self, amplitudes=None, amplitudes_type=None, duration=None):
    """Create the canvas features, the visuals and the default pan/zoom.

    `amplitudes` may be a dictionary or a single object; a non-dict value is
    stored under the key 'amplitude'. Every entry is registered in a
    `RotatingProperty`, which is then set to `amplitudes_type`. `duration`
    falls back to 1.0 when falsy.

    Three visuals are added to the canvas, in this order: the histogram, the
    patch used as a vertical bar for the selected time interval, and the
    scatter plot.
    """
    super(AmplitudeView, self).__init__()
    self.state_attrs += ('amplitudes_type', )

    self.canvas.enable_axes()
    self.canvas.enable_lasso()

    # Ensure amplitudes is a dictionary, even if there is a single amplitude.
    if not isinstance(amplitudes, dict):
        amplitudes = {'amplitude': amplitudes}
    assert amplitudes
    self.amplitudes = amplitudes

    # Rotating property amplitudes types.
    self.amplitudes_types = RotatingProperty()
    for name, value in self.amplitudes.items():
        self.amplitudes_types.add(name, value)
    # Current amplitudes type.
    self.amplitudes_types.set(amplitudes_type)
    # NOTE(review): `self.amplitudes_type` is not assigned in this method;
    # presumably it is derived from `self.amplitudes_types` elsewhere in the
    # class -- confirm.
    assert self.amplitudes_type in self.amplitudes

    self.cluster_ids = ()
    self.duration = duration or 1.

    # Histogram visual.
    self.hist_visual = HistogramVisual()
    self.hist_visual.transforms.add([
        Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
        Rotate('cw'),
        Scale((1, -1)),
        Translate((2.05, 0)),
    ])
    self.canvas.add_visual(self.hist_visual)
    # The same values are used both as the current and the default pan/zoom.
    self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.75, 1)
    self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.25, 0)

    # Yellow vertical bar showing the selected time interval.
    self.patch_visual = PatchVisual(primitive_type='triangle_fan')
    self.patch_visual.inserter.insert_vert('''
        const float MIN_INTERVAL_SIZE = 0.01;
        uniform float u_interval_size;
    ''', 'header')
    # The GLSL below must keep one statement/comment per line: a `//` comment
    # runs to the end of its line and would otherwise swallow the statements
    # that follow it.
    self.patch_visual.inserter.insert_vert('''
        gl_Position.y = pos_orig.y;

        // The following is used to ensure that (1) the bar width increases with the zoom level
        // but also (2) there is a minimum absolute width so that the bar remains visible
        // at low zoom levels.
        float w = max(MIN_INTERVAL_SIZE, u_interval_size * u_zoom.x);
        // HACK: the z coordinate is used to store 0 or 1, depending on whether the current
        // vertex is on the left or right edge of the bar.
        gl_Position.x += w * (-1 + 2 * int(a_position.z == 0));
    ''', 'after_transforms')
    self.canvas.add_visual(self.patch_visual)

    # Scatter plot.
    self.visual = ScatterVisual()
    self.canvas.add_visual(self.visual)
    self.canvas.panzoom.set_constrain_bounds((-2, -2, +2, +2))