def create_threads(self):
    """Build the external task threads and route their signals to callbacks.

    The selection task receives the loader; the three completion signals
    (selection done, correlograms computed, similarity matrix computed) are
    connected, in that order, to the matching ``*_callback`` methods.
    """
    tasks = ThreadedTasks()
    self.tasks = tasks
    tasks.selection_task.set_loader(self.loader)
    # (signal, slot) pairs, connected in order.
    wiring = (
        (tasks.selection_task.selectionDone,
         self.selection_done_callback),
        (tasks.correlograms_task.correlogramsComputed,
         self.correlograms_computed_callback),
        (tasks.similarity_matrix_task.correlationMatrixComputed,
         self.similarity_matrix_computed_callback),
    )
    for signal, slot in wiring:
        signal.connect(slot)
def create_threads(self):
    """Build the external task threads and route their signals to callbacks.

    The selection task receives the loader; the selection, recluster,
    correlograms and similarity-matrix completion signals are connected, in
    that order, to the matching ``*_callback`` methods.
    """
    tasks = ThreadedTasks()
    self.tasks = tasks
    tasks.selection_task.set_loader(self.loader)
    # (signal, slot) pairs, connected in order.
    wiring = (
        (tasks.selection_task.selectionDone,
         self.selection_done_callback),
        (tasks.recluster_task.reclusterDone,
         self.recluster_done_callback),
        (tasks.correlograms_task.correlogramsComputed,
         self.correlograms_computed_callback),
        (tasks.similarity_matrix_task.correlationMatrixComputed,
         self.similarity_matrix_computed_callback),
    )
    for signal, slot in wiring:
        signal.connect(slot)
class TaskGraph(AbstractTaskGraph):
    """Task graph driving the views and background tasks of the main window.

    Methods whose name starts with an underscore are graph nodes. A node may
    return the node(s) to execute next: a node name, a ``(name, args)`` or
    ``(name, args, kwargs)`` tuple, or a list of those. The callbacks call
    underscore-less names (``select_done``, ``correlograms_computed``, ...)
    that are not defined here -- presumably ``AbstractTaskGraph`` derives
    them from the ``_``-prefixed nodes; NOTE(review): confirm.
    """

    def __init__(self, mainwindow):
        # Shortcuts for the main window.
        self.set(mainwindow)
        # Create external threads/processes for long-lasting tasks.
        self.create_threads()

    def set(self, mainwindow):
        """Keep shortcuts to the main window and the objects it owns."""
        self.mainwindow = mainwindow
        self.get_view = self.mainwindow.get_view
        self.get_views = self.mainwindow.get_views
        self.loader = self.mainwindow.loader
        self.wizard = self.mainwindow.wizard
        self.controller = self.mainwindow.controller
        self.statscache = self.mainwindow.statscache

    def create_threads(self):
        """Create the external tasks and connect their signals to callbacks."""
        self.tasks = ThreadedTasks()
        self.tasks.selection_task.set_loader(self.loader)
        self.tasks.selection_task.selectionDone.connect(
            self.selection_done_callback)
        self.tasks.correlograms_task.correlogramsComputed.connect(
            self.correlograms_computed_callback)
        self.tasks.similarity_matrix_task.correlationMatrixComputed.connect(
            self.similarity_matrix_computed_callback)

    def join(self):
        """Join the external tasks (delegates to ``self.tasks.join()``)."""
        self.tasks.join()

    # Selection.
    # ----------
    def _select(self, clusters, wizard=False):
        """Hand the selection of `clusters` over to the selection task."""
        self.tasks.selection_task.select(clusters, wizard)

    def _select_done(self, clusters, wizard=False):
        """Return the view-update nodes to run once `clusters` are selected."""
        # In wizard mode the current wizard target is passed positionally to
        # `_update_feature_view`, i.e. as its `autozoom` argument.
        if wizard:
            target = (self.wizard.current_target(),)
        else:
            target = ()
        log.debug("Selected clusters {0:s}.".format(str(clusters)))
        return [
            ('_update_feature_view', target),
            ('_update_waveform_view', (), dict(wizard=wizard)),
            ('_show_selection_in_matrix', (clusters,)),
            ('_compute_correlograms', (clusters,)),
        ]

    def _select_in_cluster_view(self, clusters, groups=None, wizard=False):
        """Select `clusters` and `groups` in the cluster view.

        `groups` defaults to an empty list. ``None`` replaces the former
        mutable ``[]`` default so that no list object is shared across calls.
        """
        if groups is None:
            groups = []
        self.get_view('ClusterView').select(clusters, groups=groups,
                                            wizard=wizard)

    # Callbacks.
    # ----------
    def selection_done_callback(self, clusters, wizard):
        """Slot connected to the selection task's `selectionDone` signal."""
        self.select_done(clusters, wizard=wizard)

    def correlograms_computed_callback(self, clusters, correlograms,
                                       ncorrbins, corrbin):
        """Slot connected to the `correlogramsComputed` signal."""
        # Execute the callback function under the control of the task manager
        # (which handles the graph dependency).
        self.correlograms_computed(clusters, correlograms, ncorrbins, corrbin)

    def similarity_matrix_computed_callback(self, clusters_selected, matrix,
                                            clusters, cluster_groups,
                                            target_next=None):
        """Slot connected to the `correlationMatrixComputed` signal."""
        # Execute the callback function under the control of the task manager
        # (which handles the graph dependency).
        self.similarity_matrix_computed(clusters_selected, matrix, clusters,
                                        cluster_groups,
                                        target_next=target_next)

    # Computations.
    # -------------
    def _compute_correlograms(self, clusters_selected):
        """Compute the correlograms of the selected clusters.

        Launches the correlograms task for the clusters reported by the cache
        as missing (``not_in_key_indices``) and returns None. When nothing is
        missing, returns the node updating the correlograms view directly.
        """
        # Get the correlograms parameters.
        spiketimes = get_array(self.loader.get_spiketimes('all'))
        # Make a copy of the array so that it does not change before the
        # computation of the correlograms begins.
        clusters = np.array(get_array(self.loader.get_clusters('all')))
        corrbin = SETTINGS.get('correlograms.corrbin', .001)
        ncorrbins = SETTINGS.get('correlograms.ncorrbins', 100)
        # Get cluster indices that need to be updated.
        clusters_to_update = (self.statscache.correlograms.
                              not_in_key_indices(clusters_selected))
        # Nothing to compute: update the view directly, without launching the
        # task in the external process.
        if len(clusters_to_update) == 0:
            return '_update_correlograms_view'
        # Set wait cursor, then launch the task.
        self.mainwindow.set_busy(computing_correlograms=True)
        self.tasks.correlograms_task.compute(
            spiketimes, clusters,
            clusters_to_update=clusters_to_update,
            clusters_selected=clusters_selected,
            ncorrbins=ncorrbins, corrbin=corrbin)

    def _compute_similarity_matrix(self, target_next=None):
        """Compute the similarity matrix for the clusters missing from cache.

        Launches the similarity-matrix task and returns None. When no cluster
        needs updating, returns the wizard-update and matrix-view nodes.
        """
        similarity_measure = self.loader.similarity_measure
        features = self.loader.background_features
        masks = self.loader.background_masks
        clusters = get_array(self.loader.get_clusters(
            spikes=self.loader.background_spikes))
        cluster_groups = get_array(self.loader.get_cluster_groups('all'))
        clusters_all = self.loader.get_clusters_unique()
        # NOTE: not specifying clusters_to_update explicitly ensures that all
        # clusters that need to be updated are updated. This fixes a bug
        # where the matrix was not updated correctly when this function was
        # called several times in quick succession.
        clusters_to_update = (self.statscache.similarity_matrix.
                              not_in_key_indices(clusters_all))
        log.debug("Clusters to update: {0:s}".format(str(clusters_to_update)))
        if len(clusters_to_update) == 0:
            # Nothing to compute: update the wizard and the similarity matrix
            # view directly, without launching the external task.
            return [('_wizard_update', (target_next,)),
                    ('_update_similarity_matrix_view',),
                    ]
        self.mainwindow.set_busy(computing_matrix=True)
        # Launch the task.
        self.tasks.similarity_matrix_task.compute(
            features, clusters, cluster_groups, masks, clusters_to_update,
            target_next=target_next, similarity_measure=similarity_measure)

    def _correlograms_computed(self, clusters, correlograms, ncorrbins,
                               corrbin):
        """Store freshly computed correlograms and refresh the view.

        The result is discarded (returns None) when the selection or the
        cached number of bins changed while the task was running.
        """
        clusters_selected = self.loader.get_clusters_selected()
        # Reset the cursor.
        self.mainwindow.set_busy(computing_correlograms=False)
        # Abort if the selection has changed during the computation of the
        # correlograms.
        if not np.array_equal(clusters, clusters_selected):
            # str() is required: lists/arrays reject the `s` format spec
            # (TypeError on Python 3).
            log.debug("Skip update correlograms with clusters selected={0:s}"
                      " and clusters updated={1:s}.".format(
                          str(clusters_selected), str(clusters)))
            return
        if self.statscache.ncorrbins != ncorrbins:
            # Plain fields: the cached value is whatever was given to
            # `statscache.reset`, which may be None (see
            # `_change_correlograms_parameters`), so `:d` could raise.
            log.debug(("Skip updating correlograms because ncorrbins has "
                       "changed (from {0} to {1})".format(
                           ncorrbins, self.statscache.ncorrbins)))
            return
        # Put the computed correlograms in the cache, then update the view.
        self.statscache.correlograms.update(clusters, correlograms)
        return '_update_correlograms_view'

    def _similarity_matrix_computed(self, clusters_selected, matrix, clusters,
                                    cluster_groups, target_next=None):
        """Store a computed similarity matrix and update dependent views.

        Returns False (discarding `matrix`) when the background clusters
        changed while the task was running.
        """
        self.mainwindow.set_busy(computing_matrix=False)
        clusters_now = self.loader.get_clusters(
            spikes=self.loader.background_spikes)
        if not np.array_equal(clusters, clusters_now):
            return False
        self.statscache.similarity_matrix.update(clusters_selected, matrix)
        self.statscache.similarity_matrix_normalized = normalize(
            self.statscache.similarity_matrix.to_array(copy=True))
        # Update the cluster view with cluster quality (the diagonal of the
        # normalized matrix). `np.diag` returns a read-only view of a 2-D
        # array, so copy it before wrapping it in a Series.
        quality = np.diag(self.statscache.similarity_matrix_normalized).copy()
        self.statscache.cluster_quality = pd.Series(
            quality,
            index=self.statscache.similarity_matrix.indices,
        )
        self.get_view('ClusterView').set_quality(
            self.statscache.cluster_quality)
        return [('_wizard_update', (target_next,)),
                ('_update_similarity_matrix_view',),
                ]

    def _invalidate(self, clusters):
        """Invalidate the cached statistics of `clusters`."""
        self.statscache.invalidate(clusters)

    # View updates.
    # -------------
    def _update_correlograms_view(self):
        data = vd.get_correlogramsview_data(self.loader, self.statscache)
        for view in self.get_views('CorrelogramsView'):
            view.set_data(**data)

    def _update_similarity_matrix_view(self):
        data = vd.get_similaritymatrixview_data(self.loader, self.statscache)
        for view in self.get_views('SimilarityMatrixView'):
            view.set_data(**data)
        # Show selected clusters when the matrix has been updated.
        clusters = self.loader.get_clusters_selected()
        return ('_show_selection_in_matrix', (clusters,))

    def _update_feature_view(self, autozoom=None):
        data = vd.get_featureview_data(self.loader, autozoom=autozoom)
        for view in self.get_views('FeatureView'):
            view.set_data(**data)

    def _update_waveform_view(self, autozoom=None, wizard=None):
        data = vd.get_waveformview_data(self.loader, autozoom=autozoom,
                                        wizard=wizard)
        for view in self.get_views('WaveformView'):
            view.set_data(**data)

    def _update_trace_view(self):
        data = vd.get_traceview_data(self.loader)
        for view in self.get_views('TraceView'):
            view.set_data(**data)

    def _update_cluster_view(self, clusters=None):
        """Update the cluster view using the data stored in the loader object."""
        data = vd.get_clusterview_data(self.loader, self.statscache,
                                       clusters=clusters)
        self.get_view('ClusterView').set_data(**data)

    def _show_selection_in_matrix(self, clusters):
        """Highlight a selection of one or two clusters in the matrix views."""
        if clusters is not None and 1 <= len(clusters) <= 2:
            for view in self.get_views('SimilarityMatrixView'):
                view.show_selection(clusters[0], clusters[-1])

    # Override colors.
    # ----------------
    def _override_color(self, override_color):
        self.loader.set_override_color(override_color)
        return ['_update_feature_view', '_update_waveform_view',
                '_update_correlograms_view']

    # Change correlograms parameter.
    # ------------------------------
    def _change_correlograms_parameters(self, ncorrbins=None, corrbin=None):
        """Store new correlogram parameters and recompute the correlograms."""
        # Update the correlograms parameters.
        if ncorrbins is not None:
            SETTINGS['correlograms.ncorrbins'] = ncorrbins
        if corrbin is not None:
            SETTINGS['correlograms.corrbin'] = corrbin
        # Reset the cache.
        # NOTE(review): when only `corrbin` changes this passes None; if
        # `reset` stores it as `statscache.ncorrbins`, `_correlograms_computed`
        # will skip every later result. Confirm against StatsCache.reset.
        self.statscache.reset(ncorrbins)
        # Update the correlograms.
        clusters = self.loader.get_clusters_selected()
        return ('_compute_correlograms', (clusters,))

    # Merge/split actions.
    # --------------------
    def _merge(self, clusters, wizard=False):
        """Merge `clusters` (at least two) and return the follow-up nodes."""
        if len(clusters) < 2:
            return
        action, output = self.controller.merge_clusters(clusters)
        # Tell the next nodes whether the merge occurred after a wizard
        # selection or not, so that the merged cluster background is
        # highlighted or not.
        output['wizard'] = wizard
        return after_merge(output)

    def _split(self, clusters, spikes_selected, wizard=False):
        """Split `spikes_selected` out of `clusters` (needs >= 1 spike)."""
        if len(spikes_selected) < 1:
            return
        action, output = self.controller.split_clusters(clusters,
                                                        spikes_selected)
        output['wizard'] = wizard
        return after_split(output)

    def _undo(self, wizard=False):
        """Undo the last controller action and return its follow-up nodes.

        Returns None when there is nothing to undo or the action is unknown.
        As before, undoing group add/rename/remove and group colour changes
        reuses the forward ``after_*`` handlers.
        """
        undo = self.controller.undo()
        if undo is None:
            return
        action, output = undo
        output['wizard'] = wizard
        handlers = {
            'merge_clusters_undo': after_merge_undo,
            'split_clusters_undo': after_split_undo,
            'change_cluster_color_undo': after_cluster_color_changed_undo,
            'change_group_color_undo': after_group_color_changed,
            'move_clusters_undo': after_clusters_moved_undo,
            'add_group_undo': after_group_added,
            'rename_group_undo': after_group_renamed,
            'remove_group_undo': after_group_removed,
        }
        handler = handlers.get(action)
        if handler is not None:
            return handler(output)

    def _redo(self, wizard=False):
        """Redo the last undone action and return its follow-up nodes.

        Returns None when there is nothing to redo or the action is unknown.
        """
        redo = self.controller.redo()
        if redo is None:
            return
        action, output = redo
        output['wizard'] = wizard
        handlers = {
            'merge_clusters': after_merge,
            'split_clusters': after_split,
            'change_cluster_color': after_cluster_color_changed,
            'change_group_color': after_group_color_changed,
            'move_clusters': after_clusters_moved,
            'add_group': after_group_added,
            'rename_group': after_group_renamed,
            'remove_group': after_group_removed,
        }
        handler = handlers.get(action)
        if handler is not None:
            return handler(output)

    # Other actions.
    # --------------
    def _cluster_color_changed(self, cluster, color, wizard=True):
        action, output = self.controller.change_cluster_color(cluster, color)
        output['wizard'] = wizard
        return after_cluster_color_changed(output)

    def _group_color_changed(self, group, color):
        action, output = self.controller.change_group_color(group, color)
        return after_group_color_changed(output)

    def _group_renamed(self, group, name):
        action, output = self.controller.rename_group(group, name)
        return after_group_renamed(output)

    def _clusters_moved(self, clusters, group, wizard=False):
        action, output = self.controller.move_clusters(clusters, group)
        output['wizard'] = wizard
        return after_clusters_moved(output)

    def _group_removed(self, group):
        action, output = self.controller.remove_group(group)
        return after_group_removed(output)

    def _group_added(self, group, name, color):
        action, output = self.controller.add_group(group, name, color)
        return after_group_added(output)

    # Wizard.
    # -------
    def _wizard_update(self, target=None, update_matrix=True):
        """Refresh the wizard data, then its candidates for `target`."""
        if update_matrix:
            self.wizard.set_data(
                cluster_groups=self.loader.get_cluster_groups('all'),
                similarity_matrix=self.statscache.similarity_matrix_normalized,
            )
        else:
            self.wizard.set_data(
                cluster_groups=self.loader.get_cluster_groups('all'),
            )
        self.wizard.update_candidates(target)

    def _wizard_change_color(self, clusters):
        """Mark the first two `clusters` as target/candidate in the cluster view."""
        if clusters is not None:
            roles = {0: 'target', 1: 'candidate'}
            self.get_view('ClusterView').set_background(
                {cluster: roles.get(i, None)
                 for i, cluster in enumerate(clusters[:2])})

    def _wizard_change_candidate_color(self):
        """Give the current wizard candidate a random colour."""
        candidate = self.wizard.current_candidate()
        return ('_cluster_color_changed', (candidate, random_color(),))

    def _wizard_show_pair(self, target=None, candidate=None):
        """Show the (cluster, colour) target/candidate pair in feature views."""
        if target is None:
            target = (self.wizard.current_target(),
                      self.loader.get_cluster_color(
                          self.wizard.current_target()))
        if candidate is None:
            candidate = (self.wizard.current_candidate(),
                         self.loader.get_cluster_color(
                             self.wizard.current_candidate()))
        for view in self.get_views('FeatureView'):
            view.set_wizard_pair(target, candidate)

    # Navigation.
    def _wizard_reset(self):
        self.wizard.reset()
        return ['_wizard_update', '_wizard_current_candidate']

    def _wizard_previous_candidate(self):
        clusters = self.wizard.previous_pair()
        return after_wizard_selection(clusters)

    def _wizard_current_candidate(self):
        clusters = self.wizard.current_pair()
        return after_wizard_selection(clusters)

    def _wizard_next_candidate(self):
        clusters = self.wizard.next_pair()
        return after_wizard_selection(clusters)

    def _wizard_skip_target(self):
        # Skip the current target and go the next target.
        self.wizard.skip_target()
        return [('_wizard_update', ()),
                ('_wizard_next_candidate',), ]

    def _wizard_reset_skipped(self):
        self.wizard.reset_skipped()

    # Control.
    def _wizard_move_and_next(self, what, group):
        """Move target, candidate, or both, to a given group, and go to the
        next proposition.

        `what` must be 'candidate', 'target' or 'both'; any other value
        raises ValueError (previously an UnboundLocalError).
        """
        # Current proposition.
        clusters = self.wizard.current_pair()
        if clusters is None:
            return
        target, candidate = clusters
        # Select appropriate clusters to move.
        if what == 'candidate':
            clusters = [candidate]
            # Keep the current target.
            target_next = target
            reset_skipped = False
        elif what == 'target':
            clusters = [target]
            # Go to the next best target cluster.
            target_next = None
            reset_skipped = True
        elif what == 'both':
            clusters = [candidate, target]
            # Go to the next best target cluster.
            target_next = None
            reset_skipped = True
        else:
            raise ValueError(
                "`what` should be 'candidate', 'target' or 'both', "
                "not {0!r}.".format(what))
        # Move clusters, and select next proposition.
        r = [('_clusters_moved', (clusters, group, True)), ]
        if reset_skipped:
            r += [('_wizard_reset_skipped',), ]
        r += [('_wizard_update', (target_next,)),
              ('_wizard_next_candidate',),
              ]
        return r
class TaskGraph(AbstractTaskGraph): def __init__(self, mainwindow): # Shortcuts for the main window. self.set(mainwindow) # Create external threads/processes for long-lasting tasks. self.create_threads() def set(self, mainwindow): # Shortcuts for the main window. self.mainwindow = mainwindow self.get_view = self.mainwindow.get_view self.get_views = self.mainwindow.get_views self.loader = self.mainwindow.loader self.experiment = self.loader.experiment self.wizard = self.mainwindow.wizard self.controller = self.mainwindow.controller self.statscache = self.mainwindow.statscache def create_threads(self): # Create the external threads. self.tasks = ThreadedTasks() self.tasks.selection_task.set_loader(self.loader) self.tasks.selection_task.selectionDone.connect(self.selection_done_callback) self.tasks.recluster_task.reclusterDone.connect(self.recluster_done_callback) self.tasks.correlograms_task.correlogramsComputed.connect(self.correlograms_computed_callback) self.tasks.similarity_matrix_task.correlationMatrixComputed.connect(self.similarity_matrix_computed_callback) def join(self): self.tasks.join() # Selection. # ---------- def _select(self, clusters, wizard=False): self.tasks.selection_task.select(clusters, wizard) def _select_done(self, clusters, wizard=False): if wizard: target = (self.wizard.current_target(),) else: target = () # self.loader.select(clusters=clusters) log.debug("Selected clusters {0:s}.".format(str(clusters))) return [ ("_update_feature_view", target, dict()), ("_update_waveform_view", (), dict(wizard=wizard)), ("_show_selection_in_matrix", (clusters,)), ("_compute_correlograms", (clusters,), dict(wizard=wizard)), ] def _select_in_cluster_view(self, clusters, groups=[], wizard=False): self.get_view("ClusterView").select(clusters, groups=groups, wizard=wizard) # Callbacks. 
# ---------- def selection_done_callback(self, clusters, wizard): self.select_done(clusters, wizard=wizard) def recluster_done_callback(self, channel_group, clusters, spikes, clu, wizard): self.recluster_done(channel_group=channel_group, clusters=clusters, spikes=spikes, clu=clu, wizard=wizard) def correlograms_computed_callback(self, clusters, correlograms, ncorrbins, corrbin, sample_rate, wizard): # Execute the callback function under the control of the task manager # (which handles the graph dependency). self.correlograms_computed(clusters, correlograms, ncorrbins, corrbin, sample_rate, wizard) def similarity_matrix_computed_callback( self, clusters_selected, matrix, clusters, cluster_groups, target_next=None ): # Execute the callback function under the control of the task manager # (which handles the graph dependency). self.similarity_matrix_computed(clusters_selected, matrix, clusters, cluster_groups, target_next=target_next) # Computations. # ------------- def _compute_correlograms(self, clusters_selected, wizard=None): # Get the correlograms parameters. spiketimes = get_array(self.loader.get_spiketimes("all")) sample_rate = self.loader.freq # print spiketimes.dtype # Make a copy of the array so that it does not change before the # computation of the correlograms begins. clusters = np.array(get_array(self.loader.get_clusters("all"))) # Get excerpts nexcerpts = USERPREF.get("correlograms_nexcerpts", 50) excerpt_size = USERPREF.get("correlograms_excerpt_size", 10000) spiketimes_excerpts = get_excerpts(spiketimes, nexcerpts=nexcerpts, excerpt_size=excerpt_size) clusters_excerpts = get_excerpts(clusters, nexcerpts=nexcerpts, excerpt_size=excerpt_size) # corrbin = self.loader.corrbin # ncorrbins = self.loader.ncorrbins corrbin = SETTINGS.get("correlograms.corrbin", 0.001) ncorrbins = SETTINGS.get("correlograms.ncorrbins", 101) # Ensure ncorrbins is odd. if ncorrbins % 2 == 0: ncorrbins += 1 # Get cluster indices that need to be updated. 
# clusters_to_update = (self.statscache.correlograms. # not_in_key_indices(clusters_selected)) clusters_to_update = clusters_selected # If there are pairs that need to be updated, launch the task. if len(clusters_to_update) > 0: # Set wait cursor. self.mainwindow.set_busy(computing_correlograms=True) # Launch the task. self.tasks.correlograms_task.compute( spiketimes_excerpts, clusters_excerpts, clusters_to_update=clusters_to_update, clusters_selected=clusters_selected, ncorrbins=ncorrbins, corrbin=corrbin, sample_rate=sample_rate, wizard=wizard, ) # Otherwise, update directly the correlograms view without launching # the task in the external process. else: # self.update_correlograms_view() return ("_update_correlograms_view", (wizard,), {}) def _recluster(self): exp = self.loader.experiment channel_group = self.loader.shank clusters_selected = self.loader.get_clusters_selected() self.tasks.recluster_task.recluster(exp, channel_group=channel_group, clusters=clusters_selected) def _recluster_done(self, channel_group=0, clusters=None, spikes=None, clu=None, wizard=False): return [("_split2", (spikes, clu, wizard))] def _compute_similarity_matrix(self, target_next=None): exp = self.experiment channel_group = self.loader.shank clustering = "main" # TODO fetdim = exp.application_data.spikedetekt.n_features_per_channel clusters_data = getattr(exp.channel_groups[channel_group].clusters, clustering) spikes_data = exp.channel_groups[channel_group].spikes cluster_groups_data = getattr(exp.channel_groups[channel_group].cluster_groups, clustering) clusters_all = sorted(clusters_data.keys()) cluster_groups = pd.Series([clusters_data[cl].cluster_group or 0 for cl in clusters_all], index=clusters_all) spikes_selected, fm = spikes_data.load_features_masks(fraction=0.1) clusters = getattr(spikes_data.clusters, clustering)[:][spikes_selected] fm = np.atleast_3d(fm) features = fm[:, :, 0] if features.shape[1] <= 1: return [] # masks = fm[:, ::fetdim, 1] if fm.shape[2] > 1: masks = 
fm[:, :, 1] else: masks = None # features = pandaize(features, spikes_selected) # masks = pandaize(masks, spikes_selected) # Get cluster indices that need to be updated. # if clusters_to_update is None: # NOTE: not specifying explicitely clusters_to_update ensures that # all clusters that need to be updated are updated. # Allows to fix a bug where the matrix is not updated correctly # when multiple calls to this functions are called quickly. clusters_to_update = self.statscache.similarity_matrix.not_in_key_indices(clusters_all) log.debug("Clusters to update: {0:s}".format(str(clusters_to_update))) # If there are pairs that need to be updated, launch the task. if len(clusters_to_update) > 0: self.mainwindow.set_busy(computing_matrix=True) # Launch the task. self.tasks.similarity_matrix_task.compute( features, clusters, cluster_groups, masks, clusters_to_update, target_next=target_next, similarity_measure=None, ) # Otherwise, update directly the correlograms view without launching # the task in the external process. else: return [("_wizard_update", (target_next,)), ("_update_similarity_matrix_view",)] def _correlograms_computed(self, clusters, correlograms, ncorrbins, corrbin, sample_rate, wizard): clusters_selected = self.loader.get_clusters_selected() # Abort if the selection has changed during the computation of the # correlograms. # Reset the cursor. self.mainwindow.set_busy(computing_correlograms=False) if not np.array_equal(clusters, clusters_selected): log.debug( "Skip update correlograms with clusters selected={0:s}" " and clusters updated={1:s}.".format(clusters_selected, clusters) ) return if self.statscache.ncorrbins != ncorrbins: log.debug( ( "Skip updating correlograms because ncorrbins has " "changed (from {0:d} to {1:d})".format(ncorrbins, self.statscache.ncorrbins) ) ) return # Put the computed correlograms in the cache. self.statscache.correlograms.update(clusters, correlograms) # Update the view. 
# self.update_correlograms_view() return ("_update_correlograms_view", (), dict(wizard=wizard)) def _similarity_matrix_computed(self, clusters_selected, matrix, clusters, cluster_groups, target_next=None): self.mainwindow.set_busy(computing_matrix=False) # spikes_slice = _get_similarity_matrix_slice( # self.loader.nspikes, # len(self.loader.get_clusters_unique())) # clusters_now = self.loader.get_clusters( # spikes=self.loader.background_spikes) # if not np.array_equal(clusters, clusters_now): # return False if len(matrix) == 0: return [] self.statscache.similarity_matrix.update(clusters_selected, matrix) self.statscache.similarity_matrix_normalized = normalize(self.statscache.similarity_matrix.to_array(copy=True)) # Update the cluster view with cluster quality. quality = np.diag(self.statscache.similarity_matrix_normalized).copy() self.statscache.cluster_quality = pd.Series(quality, index=self.statscache.similarity_matrix.indices) self.get_view("ClusterView").set_quality(self.statscache.cluster_quality) return [("_wizard_update", (target_next,)), ("_update_similarity_matrix_view",)] def _invalidate(self, clusters): self.statscache.invalidate(clusters) # View updates. # ------------- def _update_correlograms_view(self, wizard=None): clu = self.loader.get_clusters_selected() # HACK: work around a bug with some GPU drivers and empty selections if len(clu) == 0: return data = vd.get_correlogramsview_data( self.experiment, self.statscache.correlograms, clusters=clu, channel_group=self.loader.shank, wizard=wizard ) [view.set_data(**data) for view in self.get_views("CorrelogramsView")] def _update_similarity_matrix_view(self): data = vd.get_similaritymatrixview_data( self.experiment, self.statscache.similarity_matrix_normalized, channel_group=self.loader.shank ) [view.set_data(**data) for view in self.get_views("SimilarityMatrixView")] # Show selected clusters when the matrix has been updated. 
clusters = self.loader.get_clusters_selected() return ("_show_selection_in_matrix", (clusters,)) def _update_feature_view(self, autozoom=None): clu = self.loader.clusters_selected # HACK: work around a bug with some GPU drivers and empty selections if len(clu) == 0: return data = vd.get_featureview_data( self.experiment, clusters=clu, autozoom=autozoom, channel_group=self.loader.shank ) [view.set_data(**data) for view in self.get_views("FeatureView")] def _update_waveform_view(self, autozoom=None, wizard=None): clu = self.loader.clusters_selected # HACK: work around a bug with some GPU drivers and empty selections if len(clu) == 0: return data = vd.get_waveformview_data( self.experiment, clusters=clu, autozoom=autozoom, wizard=wizard, channel_group=self.loader.shank ) [view.set_data(**data) for view in self.get_views("WaveformView")] def _update_trace_view(self): data = vd.get_traceview_data(self.experiment, channel_group=self.loader.shank) [view.set_data(**data) for view in self.get_views("TraceView")] def _update_cluster_view(self, clusters=None): """Update the cluster view using the data stored in the loader object.""" data = vd.get_clusterview_data(self.experiment, self.statscache, channel_group=self.loader.shank) self.get_view("ClusterView").set_data(**data) if clusters is not None: return def _show_selection_in_matrix(self, clusters): if clusters is not None and 1 <= len(clusters) <= 2: [view.show_selection(clusters[0], clusters[-1]) for view in self.get_views("SimilarityMatrixView")] # Override colors. # ---------------- def _override_color(self, override_color): self.loader.set_override_color(override_color) return ["_update_feature_view", "_update_waveform_view", "_update_correlograms_view"] # Change correlograms parameter. # ------------------------------ def _change_correlograms_parameters(self, ncorrbins=None, corrbin=None): if ncorrbins % 2 == 0: ncorrbins += 1 # Update the correlograms parameters. 
if ncorrbins is not None: SETTINGS["correlograms.ncorrbins"] = ncorrbins if corrbin is not None: SETTINGS["correlograms.corrbin"] = corrbin # Reset the cache. self.statscache.reset(ncorrbins) # Update the correlograms. clusters = self.loader.get_clusters_selected() return ("_compute_correlograms", (clusters,)) # Merge/split actions. # -------------------- def _merge(self, clusters, wizard=False): if len(clusters) >= 2: action, output = self.controller.merge_clusters(clusters) # Tell the next nodes whether the merge occurred after a wizard # selection or not, so that the merged cluster background is # highlighted or not. output["wizard"] = wizard return after_merge(output) def _split(self, clusters, spikes_selected, wizard=False): if len(spikes_selected) >= 1: action, output = self.controller.split_clusters(clusters, spikes_selected) output["wizard"] = wizard return after_split(output) def _split2(self, spikes, clusters, wizard=False): if len(spikes) >= 1: action, output = self.controller.split2_clusters(spikes, clusters) output["wizard"] = wizard return after_split(output) def _undo(self, wizard=False): undo = self.controller.undo() if undo is None: return action, output = undo output["wizard"] = wizard if action == "merge_clusters_undo": return after_merge_undo(output) elif action == "split_clusters_undo": return after_split_undo(output) elif action == "split2_clusters_undo": return after_split_undo(output) elif action == "change_cluster_color_undo": return after_cluster_color_changed_undo(output) elif action == "change_group_color_undo": return after_group_color_changed(output) elif action == "move_clusters_undo": return after_clusters_moved_undo(output) elif action == "add_group_undo": return after_group_added(output) elif action == "rename_group_undo": return after_group_renamed(output) elif action == "remove_group_undo": return after_group_removed(output) def _redo(self, wizard=False): redo = self.controller.redo() if redo is None: return action, output = redo 
output["wizard"] = wizard if action == "merge_clusters": return after_merge(output) elif action == "split_clusters": return after_split(output) elif action == "split2_clusters": return after_split(output) elif action == "change_cluster_color": return after_cluster_color_changed(output) elif action == "change_group_color": return after_group_color_changed(output) elif action == "move_clusters": return after_clusters_moved(output) elif action == "add_group": return after_group_added(output) elif action == "rename_group": return after_group_renamed(output) elif action == "remove_group": return after_group_removed(output) # Other actions. # -------------- def _cluster_color_changed(self, cluster, color, wizard=True): action, output = self.controller.change_cluster_color(cluster, color) # if cluster == self.wizard.current_target(): output["wizard"] = wizard return after_cluster_color_changed(output) def _group_color_changed(self, group, color): action, output = self.controller.change_group_color(group, color) return after_group_color_changed(output) def _group_renamed(self, group, name): action, output = self.controller.rename_group(group, name) return after_group_renamed(output) def _clusters_moved(self, clusters, group, wizard=False): action, output = self.controller.move_clusters(clusters, group) output["wizard"] = wizard return after_clusters_moved(output) def _group_removed(self, group): action, output = self.controller.remove_group(group) return after_group_removed(output) def _group_added(self, group, name, color): action, output = self.controller.add_group(group, name, color) return after_group_added(output) # Wizard. 
# ------- def _wizard_update(self, target=None, update_matrix=True): if update_matrix: self.wizard.set_data( cluster_groups=self.loader.get_cluster_groups("all"), similarity_matrix=self.statscache.similarity_matrix_normalized, ) else: self.wizard.set_data(cluster_groups=self.loader.get_cluster_groups("all")) self.wizard.update_candidates(target) def _wizard_change_color(self, clusters): if clusters is not None: # Set the background color in the cluster view for the wizard # target and candidate. self.get_view("ClusterView").set_background( {cluster: {0: "target", 1: "candidate"}.get(i, None) for i, cluster in enumerate(clusters[:2])} ) def _wizard_change_candidate_color(self): candidate = self.wizard.current_candidate() target = self.wizard.current_target() # color = self.loader.get_cluster_color(candidate) return ("_cluster_color_changed", (candidate, random_color())) def _wizard_show_pair(self, target=None, candidate=None): if target is None: target = (self.wizard.current_target(), self.loader.get_cluster_color(self.wizard.current_target())) if candidate is None: try: candidate = ( self.wizard.current_candidate(), get_array(self.loader.get_cluster_color(self.wizard.current_candidate()))[0], ) # HACK: this can fail because when merging clusters, the merged # cluster (candidate) is deleted, and its color does not exist # anymore. except: candidate = (self.wizard.current_candidate(), 0) [view.set_wizard_pair(target, candidate) for view in self.get_views("FeatureView")] # Navigation. 
def _wizard_reset(self):
    """Reset the wizard, then queue a wizard update and reselection of the current pair."""
    self.wizard.reset()
    return ["_wizard_update", "_wizard_current_candidate"]


def _wizard_previous_candidate(self):
    """Step the wizard back one pair; return the specs from ``after_wizard_selection``."""
    clusters = self.wizard.previous_pair()
    return after_wizard_selection(clusters)


def _wizard_current_candidate(self):
    """Re-read the wizard's current pair; return the specs from ``after_wizard_selection``."""
    clusters = self.wizard.current_pair()
    return after_wizard_selection(clusters)


def _wizard_next_candidate(self):
    """Step the wizard forward one pair; return the specs from ``after_wizard_selection``."""
    clusters = self.wizard.next_pair()
    return after_wizard_selection(clusters)


def _wizard_skip_target(self):
    """Skip the current target and go to the next target."""
    self.wizard.skip_target()
    return [("_wizard_update", ()), ("_wizard_next_candidate",)]


def _wizard_reset_skipped(self):
    """Clear the wizard's list of skipped targets."""
    self.wizard.reset_skipped()


# Control.
def _wizard_move_and_next(self, what, group):
    """Move target, candidate, or both, to a given group, and go to the next
    proposition.

    Parameters
    ----------
    what : str
        One of "candidate", "target" or "both".
    group : object
        Destination group, forwarded to ``_clusters_moved``.

    Returns
    -------
    list or None
        Follow-up task specs, or None when the wizard has no current pair.

    Raises
    ------
    ValueError
        If ``what`` is not one of the accepted values. Previously this
        surfaced as an UnboundLocalError.
    """
    # Current proposition.
    clusters = self.wizard.current_pair()
    if clusters is None:
        return
    target, candidate = clusters
    # Select appropriate clusters to move.
    if what == "candidate":
        clusters = [candidate]
        # Keep the current target.
        target_next = target
        reset_skipped = False
    elif what == "target":
        clusters = [target]
        # Go to the next best target cluster.
        target_next = None
        reset_skipped = True
    elif what == "both":
        clusters = [candidate, target]
        # Go to the next best target cluster.
        target_next = None
        reset_skipped = True
    else:
        raise ValueError(
            "what must be 'candidate', 'target' or 'both', got {0!r}".format(what)
        )
    # Move clusters, and select next proposition.
    r = [("_clusters_moved", (clusters, group, True))]
    if reset_skipped:
        r += [("_wizard_reset_skipped",)]
    r += [("_wizard_update", (target_next,)), ("_wizard_next_candidate",)]
    return r
class TaskGraph(AbstractTaskGraph):
    """Task graph coordinating the main window's views, loader, wizard,
    controller and statistics cache.

    Methods prefixed with ``_`` are tasks. A task may return follow-up task
    specifications, each being one of:

    - a task name;
    - a ``(name,)``, ``(name, args)`` or ``(name, args, kwargs)`` tuple;
    - a list of any of these.

    NOTE(review): how these specs are scheduled, and the public wrappers
    called below (``select_done``, ``correlograms_computed``,
    ``similarity_matrix_computed``), presumably come from AbstractTaskGraph.
    That class is not visible here; confirm there.
    """

    def __init__(self, mainwindow):
        # Shortcuts for the main window.
        self.set(mainwindow)
        # Create external threads/processes for long-lasting tasks.
        self.create_threads()

    def set(self, mainwindow):
        """Store shortcuts to ``mainwindow`` and the objects it exposes."""
        self.mainwindow = mainwindow
        self.get_view = self.mainwindow.get_view
        self.get_views = self.mainwindow.get_views
        self.loader = self.mainwindow.loader
        self.wizard = self.mainwindow.wizard
        self.controller = self.mainwindow.controller
        self.statscache = self.mainwindow.statscache

    def create_threads(self):
        """Create the external tasks and connect their result signals."""
        self.tasks = ThreadedTasks()
        self.tasks.selection_task.set_loader(self.loader)
        self.tasks.selection_task.selectionDone.connect(
            self.selection_done_callback)
        self.tasks.correlograms_task.correlogramsComputed.connect(
            self.correlograms_computed_callback)
        self.tasks.similarity_matrix_task.correlationMatrixComputed.connect(
            self.similarity_matrix_computed_callback)

    def join(self):
        """Delegate to ``self.tasks.join()``."""
        self.tasks.join()

    # Selection.
    # ----------
    def _select(self, clusters, wizard=False):
        """Ask the external selection task to select ``clusters``."""
        self.tasks.selection_task.select(clusters, wizard)

    def _select_done(self, clusters, wizard=False):
        """Queue the view updates that follow a completed selection.

        When ``wizard`` is true, the current wizard target is passed as the
        positional argument of ``_update_feature_view``.
        """
        if wizard:
            target = (self.wizard.current_target(), )
        else:
            target = ()
        log.debug("Selected clusters {0:s}.".format(str(clusters)))
        return [
            ('_update_feature_view', target),
            ('_update_waveform_view', (), dict(wizard=wizard)),
            ('_show_selection_in_matrix', (clusters, )),
            ('_compute_correlograms', (clusters, )),
        ]

    def _select_in_cluster_view(self, clusters, groups=None, wizard=False):
        """Select ``clusters`` and ``groups`` (default: none) in the cluster view."""
        # Build a fresh list per call instead of sharing a mutable default.
        if groups is None:
            groups = []
        self.get_view('ClusterView').select(clusters, groups=groups,
                                            wizard=wizard)

    # Callbacks.
    # ----------
    def selection_done_callback(self, clusters, wizard):
        """Slot for ``selectionDone``; forwards to ``select_done``."""
        self.select_done(clusters, wizard=wizard)

    def correlograms_computed_callback(self, clusters, correlograms,
                                       ncorrbins, corrbin):
        """Slot for ``correlogramsComputed``; forwards to ``correlograms_computed``."""
        # Execute the callback function under the control of the task manager
        # (which handles the graph dependency).
        self.correlograms_computed(clusters, correlograms, ncorrbins, corrbin)

    def similarity_matrix_computed_callback(self, clusters_selected, matrix,
                                            clusters, cluster_groups,
                                            target_next=None):
        """Slot for ``correlationMatrixComputed``; forwards to
        ``similarity_matrix_computed``."""
        # Execute the callback function under the control of the task manager
        # (which handles the graph dependency).
        self.similarity_matrix_computed(clusters_selected, matrix, clusters,
                                        cluster_groups,
                                        target_next=target_next)

    # Computations.
    # -------------
    def _compute_correlograms(self, clusters_selected):
        """Compute the correlograms of ``clusters_selected``.

        If the cache's ``not_in_key_indices`` reports clusters to update,
        this sets the busy state and launches the external correlograms
        task, returning None. Otherwise it returns
        ``'_update_correlograms_view'`` directly.
        """
        # Get the correlograms parameters.
        spiketimes = get_array(self.loader.get_spiketimes('all'))
        # Make a copy of the array so that it does not change before the
        # computation of the correlograms begins.
        clusters = np.array(get_array(self.loader.get_clusters('all')))
        corrbin = SETTINGS.get('correlograms.corrbin', .001)
        ncorrbins = SETTINGS.get('correlograms.ncorrbins', 100)
        # Get cluster indices that need to be updated.
        clusters_to_update = (
            self.statscache.correlograms.not_in_key_indices(clusters_selected))
        # If there are pairs that need to be updated, launch the task.
        if len(clusters_to_update) > 0:
            # Set wait cursor.
            self.mainwindow.set_busy(computing_correlograms=True)
            # Launch the task.
            self.tasks.correlograms_task.compute(
                spiketimes, clusters,
                clusters_to_update=clusters_to_update,
                clusters_selected=clusters_selected,
                ncorrbins=ncorrbins, corrbin=corrbin)
        # Otherwise, update directly the correlograms view without launching
        # the task in the external process.
        else:
            return '_update_correlograms_view'

    def _compute_similarity_matrix(self, target_next=None):
        """Compute the similarity matrix on the background spikes.

        Launches the external similarity-matrix task when some clusters need
        updating. Otherwise returns the wizard-update and matrix-view-update
        task specs directly.
        """
        similarity_measure = self.loader.similarity_measure
        features = self.loader.background_features
        masks = self.loader.background_masks
        clusters = get_array(
            self.loader.get_clusters(spikes=self.loader.background_spikes))
        cluster_groups = get_array(self.loader.get_cluster_groups('all'))
        clusters_all = self.loader.get_clusters_unique()
        # NOTE: not specifying clusters_to_update explicitly ensures that
        # all clusters that need to be updated are updated. This fixes a bug
        # where the matrix was not updated correctly when this function was
        # called several times in quick succession.
        clusters_to_update = (
            self.statscache.similarity_matrix.not_in_key_indices(clusters_all))
        log.debug("Clusters to update: {0:s}".format(str(clusters_to_update)))
        # If there are pairs that need to be updated, launch the task.
        if len(clusters_to_update) > 0:
            self.mainwindow.set_busy(computing_matrix=True)
            # Launch the task.
            self.tasks.similarity_matrix_task.compute(
                features, clusters, cluster_groups, masks, clusters_to_update,
                target_next=target_next,
                similarity_measure=similarity_measure)
        # Otherwise, update the wizard and the similarity matrix view
        # directly, without launching the task in the external process.
        else:
            return [
                ('_wizard_update', (target_next, )),
                ('_update_similarity_matrix_view', ),
            ]

    def _correlograms_computed(self, clusters, correlograms, ncorrbins,
                               corrbin):
        """Cache freshly computed correlograms and queue a view refresh.

        The result is discarded (returning None) if the selection or the
        cached ``ncorrbins`` changed while the computation was running.
        """
        clusters_selected = self.loader.get_clusters_selected()
        # Reset the cursor.
        self.mainwindow.set_busy(computing_correlograms=False)
        # Abort if the selection has changed during the computation of the
        # correlograms.
        if not np.array_equal(clusters, clusters_selected):
            # str() is required: non-str objects such as lists or NumPy
            # arrays raise TypeError when given the 's' format spec.
            log.debug("Skip update correlograms with clusters selected={0:s}"
                      " and clusters updated={1:s}.".format(
                          str(clusters_selected), str(clusters)))
            return
        if self.statscache.ncorrbins != ncorrbins:
            log.debug(("Skip updating correlograms because ncorrbins has "
                       "changed (from {0:d} to {1:d})".format(
                           ncorrbins, self.statscache.ncorrbins)))
            return
        # Put the computed correlograms in the cache.
        self.statscache.correlograms.update(clusters, correlograms)
        # Update the view.
        return '_update_correlograms_view'

    def _similarity_matrix_computed(self, clusters_selected, matrix, clusters,
                                    cluster_groups, target_next=None):
        """Cache a freshly computed similarity matrix and refresh dependents.

        Returns False, storing nothing, if the background spikes' cluster
        assignment changed during the computation. Otherwise:

        - updates the cached matrix and its normalized version;
        - derives per-cluster quality from the diagonal of the normalized
          matrix and pushes it to the cluster view;
        - queues the wizard and matrix-view updates.
        """
        self.mainwindow.set_busy(computing_matrix=False)
        clusters_now = self.loader.get_clusters(
            spikes=self.loader.background_spikes)
        if not np.array_equal(clusters, clusters_now):
            return False
        self.statscache.similarity_matrix.update(clusters_selected, matrix)
        self.statscache.similarity_matrix_normalized = normalize(
            self.statscache.similarity_matrix.to_array(copy=True))
        # Update the cluster view with cluster quality.
        quality = np.diag(self.statscache.similarity_matrix_normalized)
        self.statscache.cluster_quality = pd.Series(
            quality,
            index=self.statscache.similarity_matrix.indices,
        )
        self.get_view('ClusterView').set_quality(
            self.statscache.cluster_quality)
        return [
            ('_wizard_update', (target_next, )),
            ('_update_similarity_matrix_view', ),
        ]

    def _invalidate(self, clusters):
        """Invalidate the cached statistics of ``clusters``."""
        self.statscache.invalidate(clusters)

    # View updates.
    # -------------
    def _update_correlograms_view(self):
        """Push the correlograms data to every CorrelogramsView."""
        data = vd.get_correlogramsview_data(self.loader, self.statscache)
        for view in self.get_views('CorrelogramsView'):
            view.set_data(**data)

    def _update_similarity_matrix_view(self):
        """Push the matrix data to every SimilarityMatrixView, then queue
        showing the current selection in it."""
        data = vd.get_similaritymatrixview_data(self.loader, self.statscache)
        for view in self.get_views('SimilarityMatrixView'):
            view.set_data(**data)
        # Show selected clusters when the matrix has been updated.
        clusters = self.loader.get_clusters_selected()
        return ('_show_selection_in_matrix', (clusters, ))

    def _update_feature_view(self, autozoom=None):
        """Push the feature data to every FeatureView."""
        data = vd.get_featureview_data(self.loader, autozoom=autozoom)
        for view in self.get_views('FeatureView'):
            view.set_data(**data)

    def _update_waveform_view(self, autozoom=None, wizard=None):
        """Push the waveform data to every WaveformView."""
        data = vd.get_waveformview_data(self.loader, autozoom=autozoom,
                                        wizard=wizard)
        for view in self.get_views('WaveformView'):
            view.set_data(**data)

    def _update_trace_view(self):
        """Push the trace data to every TraceView."""
        data = vd.get_traceview_data(self.loader)
        for view in self.get_views('TraceView'):
            view.set_data(**data)

    def _update_cluster_view(self, clusters=None):
        """Update the cluster view using the data stored in the loader object."""
        data = vd.get_clusterview_data(self.loader, self.statscache,
                                       clusters=clusters)
        self.get_view('ClusterView').set_data(**data)

    def _show_selection_in_matrix(self, clusters):
        """Highlight a selection of one or two clusters in the matrix views."""
        if clusters is not None and 1 <= len(clusters) <= 2:
            for view in self.get_views('SimilarityMatrixView'):
                view.show_selection(clusters[0], clusters[-1])

    # Override colors.
    # ----------------
    def _override_color(self, override_color):
        """Set the loader's override color and queue the affected view updates."""
        self.loader.set_override_color(override_color)
        return [
            '_update_feature_view', '_update_waveform_view',
            '_update_correlograms_view'
        ]

    # Change correlograms parameter.
    # ------------------------------
    def _change_correlograms_parameters(self, ncorrbins=None, corrbin=None):
        """Store new correlogram parameters, reset the cache and recompute."""
        # Update the correlograms parameters.
        if ncorrbins is not None:
            SETTINGS['correlograms.ncorrbins'] = ncorrbins
        if corrbin is not None:
            SETTINGS['correlograms.corrbin'] = corrbin
        # Reset the cache.
        # NOTE(review): when only corrbin changes, this passes None to
        # statscache.reset. If that stores None as statscache.ncorrbins,
        # _correlograms_computed would then discard every result. Confirm
        # how reset handles None.
        self.statscache.reset(ncorrbins)
        # Update the correlograms.
        clusters = self.loader.get_clusters_selected()
        return ('_compute_correlograms', (clusters, ))

    # Merge/split actions.
    # --------------------
    def _merge(self, clusters, wizard=False):
        """Merge ``clusters`` (needs at least two) through the controller."""
        if len(clusters) >= 2:
            action, output = self.controller.merge_clusters(clusters)
            # Tell the next nodes whether the merge occurred after a wizard
            # selection or not, so that the merged cluster background is
            # highlighted or not.
            output['wizard'] = wizard
            return after_merge(output)

    def _split(self, clusters, spikes_selected, wizard=False):
        """Split ``spikes_selected`` (needs at least one) out of ``clusters``."""
        if len(spikes_selected) >= 1:
            action, output = self.controller.split_clusters(
                clusters, spikes_selected)
            output['wizard'] = wizard
            return after_split(output)

    def _undo(self, wizard=False):
        """Undo the last controller action and queue the matching follow-ups."""
        undo = self.controller.undo()
        if undo is None:
            return
        action, output = undo
        output['wizard'] = wizard
        if action == 'merge_clusters_undo':
            return after_merge_undo(output)
        elif action == 'split_clusters_undo':
            return after_split_undo(output)
        elif action == 'change_cluster_color_undo':
            return after_cluster_color_changed_undo(output)
        elif action == 'change_group_color_undo':
            return after_group_color_changed(output)
        elif action == 'move_clusters_undo':
            return after_clusters_moved_undo(output)
        elif action == 'add_group_undo':
            return after_group_added(output)
        elif action == 'rename_group_undo':
            return after_group_renamed(output)
        elif action == 'remove_group_undo':
            return after_group_removed(output)

    def _redo(self, wizard=False):
        """Redo the last undone controller action and queue the follow-ups."""
        redo = self.controller.redo()
        if redo is None:
            return
        action, output = redo
        output['wizard'] = wizard
        if action == 'merge_clusters':
            return after_merge(output)
        elif action == 'split_clusters':
            return after_split(output)
        elif action == 'change_cluster_color':
            return after_cluster_color_changed(output)
        elif action == 'change_group_color':
            return after_group_color_changed(output)
        elif action == 'move_clusters':
            return after_clusters_moved(output)
        elif action == 'add_group':
            return after_group_added(output)
        elif action == 'rename_group':
            return after_group_renamed(output)
        elif action == 'remove_group':
            return after_group_removed(output)

    # Other actions.
    # --------------
    def _cluster_color_changed(self, cluster, color, wizard=True):
        """Change a cluster's color through the controller."""
        action, output = self.controller.change_cluster_color(cluster, color)
        output['wizard'] = wizard
        return after_cluster_color_changed(output)

    def _group_color_changed(self, group, color):
        """Change a group's color through the controller."""
        action, output = self.controller.change_group_color(group, color)
        return after_group_color_changed(output)

    def _group_renamed(self, group, name):
        """Rename a group through the controller."""
        action, output = self.controller.rename_group(group, name)
        return after_group_renamed(output)

    def _clusters_moved(self, clusters, group, wizard=False):
        """Move ``clusters`` to ``group`` through the controller."""
        action, output = self.controller.move_clusters(clusters, group)
        output['wizard'] = wizard
        return after_clusters_moved(output)

    def _group_removed(self, group):
        """Remove a group through the controller."""
        action, output = self.controller.remove_group(group)
        return after_group_removed(output)

    def _group_added(self, group, name, color):
        """Add a group through the controller."""
        action, output = self.controller.add_group(group, name, color)
        return after_group_added(output)

    # Wizard.
    # -------
    def _wizard_update(self, target=None, update_matrix=True):
        """Push cluster groups (and optionally the normalized similarity
        matrix) to the wizard, then recompute its candidates for ``target``."""
        if update_matrix:
            self.wizard.set_data(
                cluster_groups=self.loader.get_cluster_groups('all'),
                similarity_matrix=self.statscache.similarity_matrix_normalized,
            )
        else:
            self.wizard.set_data(
                cluster_groups=self.loader.get_cluster_groups('all'),
            )
        self.wizard.update_candidates(target)

    def _wizard_change_color(self, clusters):
        """Set the cluster-view background of the wizard target (first entry)
        and candidate (second entry); nothing happens if ``clusters`` is None."""
        if clusters is not None:
            # Set the background color in the cluster view for the wizard
            # target and candidate.
            self.get_view('ClusterView').set_background({
                cluster: {
                    0: 'target',
                    1: 'candidate'
                }.get(i, None)
                for i, cluster in enumerate(clusters[:2])
            })

    def _wizard_change_candidate_color(self):
        """Return a task spec giving the current wizard candidate a random color."""
        candidate = self.wizard.current_candidate()
        return ('_cluster_color_changed', (candidate, random_color()))

    def _wizard_show_pair(self, target=None, candidate=None):
        """Show the wizard (target, candidate) pair in every feature view.

        ``target`` and ``candidate`` are ``(cluster, color)`` tuples. When
        omitted, they are built from the wizard's current pair.
        """
        if target is None:
            target_cluster = self.wizard.current_target()
            target = (target_cluster,
                      self.loader.get_cluster_color(target_cluster))
        if candidate is None:
            candidate_cluster = self.wizard.current_candidate()
            candidate = (candidate_cluster,
                         self.loader.get_cluster_color(candidate_cluster))
        for view in self.get_views('FeatureView'):
            view.set_wizard_pair(target, candidate)

    # Navigation.
    def _wizard_reset(self):
        """Reset the wizard, then queue an update and reselection of the
        current pair."""
        self.wizard.reset()
        return ['_wizard_update', '_wizard_current_candidate']

    def _wizard_previous_candidate(self):
        """Step the wizard back one pair; return ``after_wizard_selection`` specs."""
        clusters = self.wizard.previous_pair()
        return after_wizard_selection(clusters)

    def _wizard_current_candidate(self):
        """Re-read the current pair; return ``after_wizard_selection`` specs."""
        clusters = self.wizard.current_pair()
        return after_wizard_selection(clusters)

    def _wizard_next_candidate(self):
        """Step the wizard forward one pair; return ``after_wizard_selection`` specs."""
        clusters = self.wizard.next_pair()
        return after_wizard_selection(clusters)

    def _wizard_skip_target(self):
        """Skip the current target and go to the next target."""
        self.wizard.skip_target()
        return [
            ('_wizard_update', ()),
            ('_wizard_next_candidate', ),
        ]

    def _wizard_reset_skipped(self):
        """Clear the wizard's list of skipped targets."""
        self.wizard.reset_skipped()

    # Control.
    def _wizard_move_and_next(self, what, group):
        """Move target, candidate, or both, to a given group, and go to the
        next proposition.

        ``what`` must be 'candidate', 'target' or 'both'; any other value
        raises ValueError (previously an UnboundLocalError). Returns None
        when the wizard has no current pair.
        """
        # Current proposition.
        clusters = self.wizard.current_pair()
        if clusters is None:
            return
        target, candidate = clusters
        # Select appropriate clusters to move.
        if what == 'candidate':
            clusters = [candidate]
            # Keep the current target.
            target_next = target
            reset_skipped = False
        elif what == 'target':
            clusters = [target]
            # Go to the next best target cluster.
            target_next = None
            reset_skipped = True
        elif what == 'both':
            clusters = [candidate, target]
            # Go to the next best target cluster.
            target_next = None
            reset_skipped = True
        else:
            raise ValueError(
                "what must be 'candidate', 'target' or 'both', "
                "got {0!r}".format(what))
        # Move clusters, and select next proposition.
        r = [
            ('_clusters_moved', (clusters, group, True)),
        ]
        if reset_skipped:
            r += [
                ('_wizard_reset_skipped', ),
            ]
        r += [
            ('_wizard_update', (target_next, )),
            ('_wizard_next_candidate', ),
        ]
        return r