def test_run_sorters_with_list():
    # This import is here to get an error on github when the import fails
    import tridesclous

    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_list'

    # start from a clean state
    for folder in (cache_folder, working_folder):
        if os.path.exists(folder):
            shutil.rmtree(folder)

    tetrode, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
    octotrode, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)

    # saving makes the recordings dumpable, which run_sorters() needs
    set_global_tmp_folder(cache_folder)
    recording_list = [tetrode.save(name='rec0'), octotrode.save(name='rec1')]

    run_sorters(['tridesclous'], recording_list, working_folder,
                engine='loop', verbose=False, with_output=False)
def test_run_sorters_joblib():
    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_joblib'

    # start from a clean state
    for folder in (cache_folder, working_folder):
        if os.path.exists(folder):
            shutil.rmtree(folder)

    set_global_tmp_folder(cache_folder)

    # build several dumpable recordings
    recording_dict = {}
    for i in range(8):
        rec, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
        name = f'rec_{i}'
        recording_dict[name] = rec.save(name=name)

    sorter_list = ['tridesclous', ]

    # run with the joblib engine and report the elapsed time
    t0 = time.perf_counter()
    run_sorters(sorter_list, recording_dict, working_folder,
                engine='joblib', engine_kwargs={'n_jobs': 4},
                with_output=False, mode_if_folder_exists='keep')
    t1 = time.perf_counter()
    print(t1 - t0)
def test_run_sorters_dask():
    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_dask'

    # start from a clean state
    for folder in (cache_folder, working_folder):
        if os.path.exists(folder):
            shutil.rmtree(folder)

    # build several dumpable recordings
    recording_dict = {}
    for i in range(8):
        rec, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
        name = f'rec_{i}'
        recording_dict[name] = rec.save(name=name)

    sorter_list = ['tridesclous', ]

    # create a dask Client for a slurm queue
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster

    python = '/home/samuel.garcia/.virtualenvs/py36/bin/python3.6'
    cluster = SLURMCluster(processes=1, cores=1, memory="12GB",
                           python=python, walltime='12:00:00', )
    cluster.scale(5)
    client = Client(cluster)

    # run with the dask engine and report the elapsed time
    t0 = time.perf_counter()
    run_sorters(sorter_list, recording_dict, working_folder,
                engine='dask', engine_kwargs={'client': client},
                with_output=False, mode_if_folder_exists='keep')
    t1 = time.perf_counter()
    print(t1 - t0)
def run_sorters(self, sorter_list, mode_if_folder_exists='keep', **kwargs):
    """Run every sorter in `sorter_list` on all recordings of this study.

    Sorter outputs are written into `<study_folder>/sorter_folders`;
    the sortings are then copied into the study via `copy_sortings()`.
    Extra `**kwargs` are forwarded to the module-level `run_sorters()`.
    """
    recording_dict = get_recordings(self.study_folder)
    sorter_folders = self.study_folder / 'sorter_folders'
    run_sorters(sorter_list, recording_dict, sorter_folders,
                with_output=False,
                mode_if_folder_exists=mode_if_folder_exists, **kwargs)
    # results are copied so the heavy sorter_folders can be removed
    self.copy_sortings()
def test_run_sorters_with_dict():
    # These imports are here to get an error on github when an import fails
    import tridesclous
    import circus

    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_dict'

    # start from a clean state
    for folder in (cache_folder, working_folder):
        if os.path.exists(folder):
            shutil.rmtree(folder)

    tetrode, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
    octotrode, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)

    # saving makes the recordings dumpable, which run_sorters() needs
    set_global_tmp_folder(cache_folder)
    recording_dict = {
        'toy_tetrode': tetrode.save(name='rec0'),
        'toy_octotrode': octotrode.save(name='rec1'),
    }

    sorter_list = ['tridesclous', 'spykingcircus']
    sorter_params = {
        'tridesclous': dict(detect_threshold=5.6),
        'spykingcircus': dict(detect_threshold=5.6),
    }

    # simple loop, timed
    t0 = time.perf_counter()
    results = run_sorters(sorter_list, recording_dict, working_folder,
                          engine='loop', sorter_params=sorter_params,
                          with_output=True, mode_if_folder_exists='raise')
    t1 = time.perf_counter()
    print(t1 - t0)
    print(results)

    # remove one output folder and re-run with 'keep' mode
    shutil.rmtree(working_folder + '/toy_tetrode/tridesclous')
    run_sorters(sorter_list, recording_dict, working_folder,
                engine='loop', sorter_params=sorter_params,
                with_output=False, mode_if_folder_exists='keep')
# The sorter name can now be a parameter, e.g. chosen with a command line interface or a GUI

sorter_name = 'klusta'
sorting_KL = ss.run_sorter(sorter_name_or_class='klusta', recording=recording,
                           output_folder='my_sorter_output')
print(sorting_KL.get_unit_ids())

##############################################################################
# This will launch the klusta sorter on the recording object.
#
# You can also run multiple sorters on the same recording:

recording_list = [recording]
sorter_list = ['klusta', 'mountainsort4', 'tridesclous']
# fix: the keyword is 'mode_if_folder_exists' — the same keyword used by every
# other run_sorters() call in this project; the bare 'mode' spelling is not.
sorting_output = ss.run_sorters(sorter_list, recording_list,
                                working_folder='tmp_some_sorters',
                                mode_if_folder_exists='overwrite')

##############################################################################
# The 'mode_if_folder_exists' argument allows to 'overwrite' the 'working_folder'
# (if existing), 'raise' an Exception, or 'keep' the folder and skip the spike sorting run.
#
# The 'sorting_output' is a dictionary that has (recording, sorter) pairs as keys and the correspondent
# :code:`SortingExtractor` as values. It can be accessed as follows:

for (rec, sorter), sort in sorting_output.items():
    print(rec, sorter, ':', sort.get_unit_ids())

##############################################################################
# With the same mechanism, you can run several spike sorters on many recordings, just by creating a list of
# :code:`RecordingExtractor` objects (:code:`recording_list`).
]  # NOTE(review): closes a list begun before this chunk — not visible here

# Per-sorter parameter overrides forwarded to run_sorters().
sorter_params = {
    'mountainsort4': {
        'adjacency_radius': 50
    },
    'spyking_circus': {
        'adjacency_radius': 50
    }
}

# Only run the sorters when no cached sortings exist yet.
if not (results_folder / 'sortings').is_dir():
    print('-----------------------------------------')
    print('Running sorters')
    # NOTE(review): this call passes 'debug=True' while other run_sorters()
    # calls in the project use 'verbose' — presumably an older API; confirm
    # against the spikesorters version this script targets.
    result_dict = ss.run_sorters(sorter_list=sorter_list,
                                 recording_dict_or_list=rec_dict,
                                 with_output=True,
                                 debug=True,
                                 sorter_params=sorter_params,
                                 working_folder=working_folder)
    print('-----------------------------------------')
    print('Saving results')
    # Persist each sorting as an .npz file so later runs can skip sorting.
    # Keys of result_dict are (recording_name, sorter_name) pairs.
    for s in sorter_list:
        sorting = result_dict[('rec', s)]
        se.NpzSortingExtractor.write_sorting(
            sorting, results_folder / 'sortings' / str(s + '.npz'))

print('-----------------------------------------')
print('Loading sorting output')

exclude_sorters = []
sorting_dict = {}
sorting_HS = ss.run_sorter(sorter_name='herdingspikes', recording=recording,
                           output_folder='my_sorter_output', clustering_bandwidth=8)
print(sorting_HS.get_unit_ids())

##############################################################################
#
# You can also run multiple sorters on the same recording:

recordings = {'toy': recording}
sorter_list = ['herdingspikes', 'tridesclous']
sorter_params = {'herdingspikes': {'clustering_bandwidth': 8}}
sorting_output = ss.run_sorters(sorter_list, recordings, working_folder='tmp_some_sorters',
                                mode_if_folder_exists='overwrite', sorter_params=sorter_params)

##############################################################################
# The 'mode_if_folder_exists' argument allows to 'overwrite' the 'working_folder'
# (if existing), 'raise' an Exception, or 'keep' the folder and skip the spike sorting run.
#
# The 'sorting_output' is a dictionary that has (recording, sorter) pairs as keys and the
# correspondent :code:`SortingExtractor` as values. It can be accessed as follows:

for (rec_name, sorter_name), sorting in sorting_output.items():
    print(rec_name, sorter_name, ':', sorting.get_unit_ids())

##############################################################################
# With the same mechanism, you can run several spike sorters on many recordings, just by creating a list/dict of
def run(self):
    """Run the recovery pipeline.

    For every (recording, sorter) pair produced by the first sorting pass:
    save the recording, the pre-sorting and the ground truth under
    'back_recording', then execute ``_do_recovery_loop`` (in parallel via
    joblib, or in a plain loop), reload the back-projected recordings,
    re-run the sorters on them, and aggregate/compare the results.
    """
    # Build one argument tuple per (recording, sorter) pair; keys of
    # self._sortings_pre are (recording_name, sorter_name) pairs.
    task_args_list = []
    for key in self._sortings_pre.keys():
        # NOTE(review): the three save paths below use different orderings of
        # key[0]/key[1] — recording and gt go under key[1]/key[0] while the
        # pre-sorting goes under key[0]/(key[1] + '_pre'). Confirm this
        # matches what _do_recovery_loop expects.
        self._recordings[key[0]].save_to_folder(
            folder=self._output_folder / 'back_recording' / key[1] / key[0])
        self._sortings_pre[key].save_to_folder(folder=self._output_folder / 'back_recording' /
                                               key[0] / (key[1] + '_pre'))
        # NOTE(review): self._gt is indexed unconditionally here, but
        # __init__ accepts gt=None — confirm run() is only invoked when a
        # ground truth was provided.
        self._gt[key[0]].save_to_folder(folder=self._output_folder / 'back_recording' /
                                        key[1] / (key[0] + '_gt'))
        task_args_list.append(
            (key, self._params_dict['wd_score'], self._params_dict['isi_thr'],
             self._params_dict['fr_thr'], self._params_dict['sample_window_ms'],
             self._params_dict['percentage_spikes'], self._params_dict['balance_spikes'],
             self._params_dict['detect_threshold'], self._params_dict['method'],
             self._params_dict['skew_thr'], self._params_dict['n_jobs'], self._we_params,
             self._compare, self._output_folder, self._params_dict['job_kwargs']))

    # Execute the recovery work either in parallel (loky backend) or serially.
    if self._params_dict['parallel']:
        from joblib import Parallel, delayed
        Parallel(n_jobs=self._params_dict['n_jobs'],
                 backend='loky')(delayed(_do_recovery_loop)(task_args)
                                 for task_args in task_args_list)
    else:
        for task_args in task_args_list:
            _do_recovery_loop(task_args)

    # Collect the back-projected recordings, grouped by sorter name (key[1]).
    # NOTE(review): these are loaded from back_recording/key[0]/key[1] —
    # presumably written there by _do_recovery_loop; confirm.
    for key in self._sortings_pre.keys():
        if key[1] in self._recordings_backprojected.keys():
            self._recordings_backprojected[key[1]].append(
                load_extractor(self._output_folder / 'back_recording' / key[0] / key[1]))
        else:
            self._recordings_backprojected[key[1]] = \
                [load_extractor(self._output_folder / 'back_recording' / key[0] / key[1])]

    # Second sorting pass on the back-projected recordings, one sorter at a time.
    for sorter in self._recordings_backprojected.keys():
        self._sortings_post[sorter] = ss.run_sorters(
            sorter,
            self._recordings_backprojected[sorter],
            working_folder=self._output_folder / 'sortings_post' / sorter,
            sorter_params=self._sorters_params['sorters_params'],
            mode_if_folder_exists='overwrite',
            engine=self._sorters_params['engine'],
            engine_kwargs=self._sorters_params['engine_kwargs'],
            verbose=self._sorters_params['verbose'],
            with_output=self._sorters_params['with_output'])
        for key in self._sortings_post[sorter].keys():
            # Merge the units found in the second pass with the first pass.
            self._aggregated_sortings[key] = aggregate_units([
                self._sortings_post[sorter][key], self._sortings_pre[key]
            ])
            # NOTE(review): the comparison tests self._sortings_pre[key], not
            # the aggregated sorting — confirm this is intentional.
            self._comparisons[key] = sc.compare_sorter_to_ground_truth(
                tested_sorting=self._sortings_pre[key],
                gt_sorting=self._gt[key[0]])
def __init__(self, sorters: list, recordings: Union[list, dict], gt=None,
             sorters_params={}, output_folder=None, overwrite=False, we_params={},
             well_detected_score=.7, isi_thr=.3, fr_thr=None, sample_window_ms=2,
             percentage_spikes=None, balance_spikes=False, detect_threshold=5,
             method='locally_exclusive', skew_thr=0.1, n_jobs=4, parallel=False,
             **job_kwargs):
    """
    Apply spike sorting algorithm two times to increase its accuracy. After the first sorting,
    well detected units are removed from the recording. ICA is run on the "new recording" to
    increase its SNR and then ease the detection of small units. Finally, the spike sorting
    algorithm is run again on the ica-filtered recording. Multiple sorting algorithms can be
    run at the same time, each one on its own recording.
    The recovery can be run in parallel or in a loop. The former option is suggested if the
    number of recordings or sortings is high.

    Parameters
    ----------
    sorters: list
        list of sorters name to be run.
    recordings: list or dict
        list or dict of RecordingExtractors. If dict, the keys are sorter names.
    gt: list or dict
        list or dict of ground truth SortingExtractors.
    sorters_params: dict
        dict with keys the parameters of spikeinterface.sorters.run_sorters(). If a parameter
        is not set, its default values is used.
    output_folder: str
        String with name or path of the output folder. If none it is named 'recovery_output'
    overwrite: bool
        If True and output_folder exists, it will be overwritten. If false and output_folder
        exists an exception is raised.
    we_params:
        dict with keys the parameters of spikeinterface.core.extract_waveforms(). If a
        parameter is not set, its default values is used.
    well_detected_score: float
        agreement score to mark a unit as well detected. Used only if gt is provided.
    isi_thr: float
        If the ISI violation ratio of a unit is above the threshold, it will be discarded.
    fr_thr: list
        list with 2 values. If the firing rate of a unit is not in the provided interval,
        it will be discarded.
    sample_window_ms: list or int
        If list [ms_before, ms_after] of recording selected for each detected spike in
        subsampling for ICA.
    percentage_spikes: float
        percentage of detected spikes to be used in subsampling for ICA. If None, all
        spikes are used.
    balance_spikes: bool
        If true, same percentage of spikes is selected channel by channel. If None, spikes
        are picked randomly. Used only if percentage_spikes is not None
    detect_threshold: float
        MAD threshold for spike detection in subsampling for ICA.
    method: str
        How to detect peaks:
        * 'by_channel' : peak are detected in each channel independently. (default)
        * 'locally_exclusive' : locally given a radius the best peak only is taken but
          not neighboring channels.
    skew_thr: float
        Skewness threshold for ICA sources cleaning. If the skewness is lower than the
        threshold, it will be discarded.
    n_jobs: int
        Number of parallel processes
    parallel: bool
        If True, the recovery is run in parallel for each sorter. If False, the recovery
        is run in loop.
    job_kwargs: dict
        Parameters for parallel processing of RecordingExtractors.

    Returns
    --------
    unitsrecovery object
    """
    self._sorters = sorters
    if output_folder is None:
        output_folder = 'recovery_output'
    self._output_folder = Path(output_folder)
    if fr_thr is None:
        fr_thr = [3.5, 19.5]

    # All thresholds/options used by the recovery loop, bundled in one dict.
    self._params_dict = {
        'wd_score': well_detected_score,
        'isi_thr': isi_thr,
        'fr_thr': fr_thr,
        'parallel': parallel,
        'sample_window_ms': sample_window_ms,
        'percentage_spikes': percentage_spikes,
        'balance_spikes': balance_spikes,
        'detect_threshold': detect_threshold,
        'method': method,
        'skew_thr': skew_thr,
        'n_jobs': n_jobs,
        'job_kwargs': job_kwargs
    }
    self._sorters_params, self._we_params = _set_sorters_params(
        sorters_params, we_params)
    # assert len(sorters) == len(recordings), "The number of sorters must equal the number of recordings"

    if self._output_folder.is_dir() and not overwrite:
        raise Exception(
            'Output folder already exists. Set overwrite=True to overwrite it'
        )
    elif self._output_folder.is_dir() and overwrite:
        rmtree(self._output_folder)

    # First sorting pass; keys of the result are (recording_name, sorter_name).
    self._sortings_pre = ss.run_sorters(
        sorters,
        recordings,
        working_folder=self._output_folder / 'sorting_pre',
        sorter_params=self._sorters_params['sorters_params'],
        mode_if_folder_exists='overwrite',
        engine=self._sorters_params['engine'],
        engine_kwargs=self._sorters_params['engine_kwargs'],
        verbose=self._sorters_params['verbose'],
        with_output=self._sorters_params['with_output'])

    # Normalize recordings to a dict keyed by recording name.
    # NOTE(review): int(key[0][-1]) assumes recording names end in a single
    # digit index — this breaks with 10+ recordings; confirm upstream naming.
    if not isinstance(recordings, dict):
        self._recordings = {
            key[0]: recordings[int(key[0][-1])]
            for key in self._sortings_pre.keys()
        }
    else:
        self._recordings = recordings

    # Normalize gt the same way when it is a list.
    if not isinstance(gt, dict) and gt is not None:
        assert len(recordings) == len(
            gt), "Recordings and gts must be of same length"
        self._gt = {
            key[0]: gt[int(key[0][-1])]
            for key in self._sortings_pre.keys()
        }
    else:
        if isinstance(gt, dict) and isinstance(recordings, dict):
            assert gt.keys() == recordings.keys(
            ), "Recordings and gts dictionaries must have same keys"
        # NOTE(review): reconstructed indentation — this assignment is taken
        # to run for the whole else-branch (so gt=None yields self._gt=None);
        # confirm against the original file.
        self._gt = gt

    # Comparisons against ground truth are only performed when gt was given.
    if gt is not None:
        self._comparisons = {}
        self._compare = True
    else:
        self._compare = False

    self._recordings_backprojected = {}
    self._aggregated_sortings = {}
    self._sortings_post = {}
recordings = recording_4_tetrodes.split_by(property='group')
print(recordings)

##############################################################################
# We can also get a dict instead of the list, which makes it easier to handle the group keys.

recordings = recording_4_tetrodes.split_by(property='group', outputs='dict')
print(recordings)

##############################################################################
# We can now use the `run_sorters()` function instead of `run_sorter()`.
# This function can run several sorters on several recordings with different parallel engines.
# Here we use the engine 'loop', but we could also use 'joblib' or 'dask' for multi-process
# or multi-node computing.
# Have a look at the documentation of this function, which handles many cases.

sorter_list = ['tridesclous']
working_folder = 'sorter_outputs'
results = ss.run_sorters(sorter_list, recordings, working_folder,
                         engine='loop', with_output=True,
                         mode_if_folder_exists='overwrite')

##############################################################################
# The output is a dict with all combinations of (group, sorter_name)

from pprint import pprint
pprint(results)