def _update_cluster_ids(self, to_remove=None, to_add=None):
    """Refresh the non-empty cluster ids and the spikes-per-cluster mapping.

    ``self._cluster_ids`` is recomputed from ``self._spike_clusters``. The
    ``self._spikes_per_cluster`` dictionary is then patched incrementally:
    entries listed in ``to_remove`` are dropped and entries of ``to_add`` are
    inserted (overwriting existing ones). If some non-empty cluster still has
    no entry afterwards, the whole mapping is recomputed from scratch.

    Parameters
    ----------
    to_remove : iterable of cluster ids, optional
        Clusters whose entries are dropped; ids without an entry are ignored.
    to_add : dict, optional
        Mapping ``{cluster_id: spike_ids}`` of entries to insert.
    """
    # Update the list of non-empty cluster ids.
    self._cluster_ids = _unique(self._spike_clusters)
    spc = self._spikes_per_cluster

    # Clusters to remove (missing keys are silently skipped).
    if to_remove is not None:
        for clu in to_remove:
            spc.pop(clu, None)

    # Clusters to add.
    if to_add:
        for clu, spk in to_add.items():
            spc[clu] = spk

    # If a non-empty cluster is missing from spikes_per_cluster, the
    # incremental update was insufficient: recompute the entire mapping.
    # np.isin replaces np.in1d (deprecated since NumPy 2.0); membership does
    # not depend on key order, so the keys need not be sorted.
    coherent = np.all(np.isin(self._cluster_ids, list(spc)))
    if not coherent:
        logger.debug(
            "Recompute spikes_per_cluster manually: this might take a while."
        )
        self._spikes_per_cluster = _spikes_per_cluster(self._spike_clusters)
def _do_assign(self, spike_ids, new_spike_clusters):
    """Make spike-cluster assignments after the spike selection has been
    extended to full clusters.

    Parameters
    ----------
    spike_ids : array-like
        Ids of the spikes to reassign (converted with ``_as_array``).
    new_spike_clusters : array-like
        Target cluster of each spike. A length-1 sequence is broadcast to
        every spike in ``spike_ids``.

    Returns
    -------
    The value returned by ``self._do_merge(...)`` when every spike ends up in
    a single cluster; otherwise the object built by ``_assign_update_info``,
    with its ``all_cluster_ids`` attribute set to ``list(self.cluster_ids)``.
    """
    spike_ids = _as_array(spike_ids)
    n_spikes = len(spike_ids)

    # Broadcast a single target cluster to every selected spike.
    if len(new_spike_clusters) == 1 and n_spikes > 1:
        target = new_spike_clusters[0]
        new_spike_clusters = np.ones(n_spikes, dtype=np.int64) * target

    previous_clusters = self._spike_clusters[spike_ids]
    assert len(spike_ids) == len(previous_clusters)
    assert len(new_spike_clusters) == len(spike_ids)

    old_clusters = _unique(previous_clusters)
    new_clusters = _unique(new_spike_clusters)

    # NOTE: a single destination cluster means this assignment is really a
    # merge. The shortcut is only valid because the spike selection has been
    # extended to whole clusters beforehand.
    if len(new_clusters) == 1:
        return self._do_merge(spike_ids, old_clusters, new_clusters[0])

    up = _assign_update_info(spike_ids, previous_clusters, new_spike_clusters)

    # Keep the next new cluster id strictly increasing during a session.
    candidate_id = max(up.added) + 1
    self._new_cluster_id = max(self._new_cluster_id, candidate_id)

    # Apply the assignment, then patch spikes_per_cluster incrementally
    # (OPTIM) instead of recomputing it from scratch.
    self._spike_clusters[spike_ids] = new_spike_clusters
    added_spc = _spikes_per_cluster(new_spike_clusters, spike_ids)
    self._update_cluster_ids(to_remove=old_clusters, to_add=added_spc)

    up.all_cluster_ids = list(self.cluster_ids)
    return up
def __init__(self, model):
    """Store *model* and precompute per-cluster spike information from it.

    Sets ``model``, ``dir_path`` (the model's ``dir_path`` wrapped in a
    ``pathlib.Path``), ``spc`` (``_spikes_per_cluster`` applied to the model's
    spike clusters) and ``cluster_ids`` (``_unique`` of those spike clusters).
    """
    self.model = model
    self.dir_path = Path(model.dir_path)
    clusters = model.spike_clusters
    self.spc = _spikes_per_cluster(clusters)
    self.cluster_ids = _unique(clusters)
def test_feature_view(qtbot, gui, n_channels):
    """Exercise FeatureView: cluster selection, grid and marker controls,
    channel selection, Alt+click feature selection and lasso splitting."""
    n_spikes = 10000
    features = artificial_features(n_spikes, n_channels, 4)
    spike_clusters = artificial_spike_clusters(n_spikes, 4)
    spike_times = np.linspace(0., 1., n_spikes)
    spc = _spikes_per_cluster(spike_clusters)

    def get_spike_ids(cluster_id):
        # All spikes when no cluster is given, otherwise that cluster's spikes.
        if cluster_id is None:
            return np.arange(n_spikes)
        return spc[cluster_id]

    def get_features(cluster_id=None, channel_ids=None, spike_ids=None,
                     load_all=None):
        # With load_all the cluster is looked up directly in spc.
        spike_ids = spc[cluster_id] if load_all else get_spike_ids(cluster_id)
        if channel_ids is None:
            channel_ids = np.arange(n_channels)[::-1]
        return Bunch(
            data=features[spike_ids],
            spike_ids=spike_ids,
            masks=np.random.rand(n_spikes, n_channels),
            channel_ids=channel_ids,
        )

    def get_time(cluster_id=None, load_all=None):
        return Bunch(data=spike_times[get_spike_ids(cluster_id)], lim=(0., 1.))

    v = FeatureView(features=get_features, attributes={'time': get_time})
    v.show()
    qtbot.waitForWindowShown(v.canvas)
    v.attach(gui)
    v.set_grid_dim(_get_default_grid())

    for selection in ([], [0], [0, 2, 3], [0, 2]):
        v.on_select(cluster_ids=selection)
    assert v.status

    v.increase()
    v.decrease()
    v.increase_marker_size()
    v.decrease_marker_size()

    for button, key in (('Left', None), ('Right', None), ('Right', 2)):
        v.on_select_channel(channel_id=3, button=button, key=key)
    v.clear_channels()
    v.toggle_automatic_channel_selection(True)

    # Test feature selection with Alt+click: record every select_feature event.
    selected_features = []

    @connect(sender=v)
    def on_select_feature(sender, dim=None, channel_id=None, pc=None):
        selected_features.append((dim, channel_id, pc))

    for i, j, _, _ in v._iter_subplots():
        for k, button in enumerate(('Left', 'Right')):
            # Click on the center of subplot (i, j) of the 4x4 grid.
            width, height = v.canvas.get_size()
            cell_w, cell_h = width / 4, height / 4
            pos = (cell_w / 2 + j * cell_w, cell_h / 2 + i * cell_h)
            mouse_click(qtbot, v.canvas, pos, button=button,
                        modifiers=('Alt',))
            # Left click -> first entry of the subplot's 'a,b' grid_dim spec,
            # right click -> second entry.
            assert selected_features[-1][0] == v.grid_dim[i][j].split(',')[k]

    # Split without selection.
    spike_ids = v.on_request_split()
    assert len(spike_ids) == 0

    # Draw a square lasso with Ctrl+click on its four corners.
    lo, hi = 10, 100
    for corner in ((lo, lo), (lo, hi), (hi, hi), (hi, lo)):
        mouse_click(qtbot, v.canvas, corner, modifiers=('Control',))

    # Split lassoed points.
    spike_ids = v.on_request_split()
    # HACK: the lassoed selection is not checked because qtbot.mouseClick does
    # not seem to register here; re-enable once it does:
    # assert len(spike_ids) > 0

    v.set_state(v.state)
    _stop_and_close(qtbot, v)