Example #1
def test_run_sorters_with_list():
    # This import is to get an error on GitHub when the import fails
    import tridesclous
    
    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_list'

    if os.path.exists(cache_folder):
        shutil.rmtree(cache_folder)
    if os.path.exists(working_folder):
        shutil.rmtree(working_folder)
    
    rec0, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
    rec1, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)
    
    # make dumpable
    set_global_tmp_folder(cache_folder)
    rec0 = rec0.save(name='rec0')
    rec1 = rec1.save(name='rec1')
    
    recording_list = [rec0, rec1]
    sorter_list = ['tridesclous']

    run_sorters(sorter_list, recording_list, working_folder,
            engine='loop', verbose=False, with_output=False)
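Because this test passes with_output=False, the sortings are not returned. A minimal sketch of collecting them afterwards, assuming sorter_list, recording_list and working_folder as defined above are still in scope and the folders were kept: re-running with mode_if_folder_exists='keep', as in the later examples, skips the already computed runs and only gathers the outputs.

# Sketch (assumed usage): collect outputs of the run above without re-sorting
results = run_sorters(sorter_list, recording_list, working_folder,
                      engine='loop', with_output=True,
                      mode_if_folder_exists='keep')
for (rec_name, sorter_name), sorting in results.items():
    print(rec_name, sorter_name, sorting.get_unit_ids())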
Example #2
def test_run_sorters_joblib():
    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_joblib'
    if os.path.exists(cache_folder):
        shutil.rmtree(cache_folder)
    if os.path.exists(working_folder):
        shutil.rmtree(working_folder)
    
    set_global_tmp_folder(cache_folder)
    
    recording_dict = {}
    for i in range(8):
        rec, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
        # make dumpable
        rec = rec.save(name=f'rec_{i}')
        recording_dict[f'rec_{i}'] = rec

    sorter_list = ['tridesclous', ]

    # joblib
    t0 = time.perf_counter()
    run_sorters(sorter_list, recording_dict, working_folder,
                engine='joblib', engine_kwargs={'n_jobs' : 4},
                with_output=False,
                mode_if_folder_exists='keep')
    t1 = time.perf_counter()
    print(t1 - t0)
Example #3
def test_run_sorters_dask():
    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_dask'
    if os.path.exists(cache_folder):
        shutil.rmtree(cache_folder)
    if os.path.exists(working_folder):
        shutil.rmtree(working_folder)

    # create recording
    recording_dict = {}
    for i in range(8):
        rec, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
        # make dumpable
        rec = rec.save(name=f'rec_{i}')
        recording_dict[f'rec_{i}'] = rec

    sorter_list = ['tridesclous', ]

    # create a dask Client for a slurm queue
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster

    python = '/home/samuel.garcia/.virtualenvs/py36/bin/python3.6'
    cluster = SLURMCluster(processes=1, cores=1, memory="12GB", python=python, walltime='12:00:00', )
    cluster.scale(5)
    client = Client(cluster)

    # dask
    t0 = time.perf_counter()
    run_sorters(sorter_list, recording_dict, working_folder,
                engine='dask', engine_kwargs={'client': client},
                with_output=False,
                mode_if_folder_exists='keep')
    t1 = time.perf_counter()
    print(t1 - t0)
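The SLURMCluster above is site-specific (note the hard-coded Python path and walltime). A minimal sketch of the same engine='dask' run on a single machine, assuming only dask.distributed is installed; a LocalCluster stands in for the SLURM cluster:

# Sketch (assumed alternative): local dask cluster instead of SLURM
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=4, threads_per_worker=1)
client = Client(cluster)

run_sorters(sorter_list, recording_dict, working_folder,
            engine='dask', engine_kwargs={'client': client},
            with_output=False,
            mode_if_folder_exists='keep')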
Example #4
    def run_sorters(self, sorter_list, mode_if_folder_exists='keep', **kwargs):

        sorter_folders = self.study_folder / 'sorter_folders'
        recording_dict = get_recordings(self.study_folder)

        run_sorters(sorter_list, recording_dict, sorter_folders,
                    with_output=False, mode_if_folder_exists=mode_if_folder_exists, **kwargs)

        # results are copied so the heavy sorter_folders can be removed
        self.copy_sortings()
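A usage sketch for this method, assuming study is an instance of the class it belongs to; the extra keyword arguments are forwarded to run_sorters through **kwargs:

# Sketch (assumed study instance); kwargs are passed straight to run_sorters
study.run_sorters(['tridesclous', 'spykingcircus'],
                  mode_if_folder_exists='keep',
                  engine='joblib', engine_kwargs={'n_jobs': 4},
                  verbose=False)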
Example #5
def test_run_sorters_with_dict():
    # This import is to get an error on GitHub when the import fails
    import tridesclous
    import circus

    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_dict'

    if os.path.exists(cache_folder):
        shutil.rmtree(cache_folder)
    if os.path.exists(working_folder):
        shutil.rmtree(working_folder)

    rec0, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
    rec1, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1)

    # make dumpable
    set_global_tmp_folder(cache_folder)
    rec0 = rec0.save(name='rec0')
    rec1 = rec1.save(name='rec1')

    recording_dict = {'toy_tetrode': rec0, 'toy_octotrode': rec1}

    sorter_list = ['tridesclous', 'spykingcircus']

    sorter_params = {
        'tridesclous': dict(detect_threshold=5.6),
        'spykingcircus': dict(detect_threshold=5.6),
    }

    # simple loop
    t0 = time.perf_counter()
    results = run_sorters(sorter_list,
                          recording_dict,
                          working_folder,
                          engine='loop',
                          sorter_params=sorter_params,
                          with_output=True,
                          mode_if_folder_exists='raise')

    t1 = time.perf_counter()
    print(t1 - t0)
    print(results)

    shutil.rmtree(working_folder + '/toy_tetrode/tridesclous')
    run_sorters(sorter_list,
                recording_dict,
                working_folder,
                engine='loop',
                sorter_params=sorter_params,
                with_output=False,
                mode_if_folder_exists='keep')
Example #6
# The sorter name can now be a parameter, e.g. chosen with a command-line interface or a GUI
sorter_name = 'klusta'
sorting_KL = ss.run_sorter(sorter_name_or_class='klusta',
                           recording=recording,
                           output_folder='my_sorter_output')
print(sorting_KL.get_unit_ids())

##############################################################################
# This will launch the klusta sorter on the recording object.
#
# You can also run multiple sorters on the same recording:

recording_list = [recording]
sorter_list = ['klusta', 'mountainsort4', 'tridesclous']
sorting_output = ss.run_sorters(sorter_list,
                                recording_list,
                                working_folder='tmp_some_sorters',
                                mode='overwrite')

##############################################################################
# The 'mode' argument allows you to 'overwrite' the 'working_folder' (if it exists), 'raise' an Exception, or 'keep' the
# folder and skip the spike sorting run.
#
# The 'sorting_output' is a dictionary that has (recording, sorter) pairs as keys and the corresponding
# :code:`SortingExtractor` as values. It can be accessed as follows:

for (rec, sorter), sort in sorting_output.items():
    print(rec, sorter, ':', sort.get_unit_ids())

##############################################################################
# With the same mechanism, you can run several spike sorters on many recordings, just by creating a list of
# :code:`RecordingExtractor` objects (:code:`recording_list`).
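For instance, a minimal sketch with a second recording (recording2 is hypothetical, e.g. another loaded file or toy example; it is not defined in this snippet):

recording_list = [recording, recording2]  # recording2 is assumed, not defined above
sorter_list = ['klusta', 'tridesclous']
sorting_output = ss.run_sorters(sorter_list,
                                recording_list,
                                working_folder='tmp_many_recordings',
                                mode='overwrite')
for (rec, sorter), sort in sorting_output.items():
    print(rec, sorter, ':', sort.get_unit_ids())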
Example #7
]
sorter_params = {
    'mountainsort4': {
        'adjacency_radius': 50
    },
    'spyking_circus': {
        'adjacency_radius': 50
    }
}

if not (results_folder / 'sortings').is_dir():
    print('-----------------------------------------')
    print('Running sorters')
    result_dict = ss.run_sorters(sorter_list=sorter_list,
                                 recording_dict_or_list=rec_dict,
                                 with_output=True,
                                 debug=True,
                                 sorter_params=sorter_params,
                                 working_folder=working_folder)

    print('-----------------------------------------')
    print('Saving results')
    for s in sorter_list:
        sorting = result_dict[('rec', s)]
        se.NpzSortingExtractor.write_sorting(
            sorting, results_folder / 'sortings' / str(s + '.npz'))

print('-----------------------------------------')
print('Loading sorting output')

exclude_sorters = []
sorting_dict = {}
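The snippet is cut off here. A minimal sketch of how loading could continue, assuming the .npz files written above and the NpzSortingExtractor loader (the layout and the exclude logic are assumptions):

# Sketch (assumed continuation): load each saved '<sorter>.npz' back as a sorting
for s in sorter_list:
    if s in exclude_sorters:
        continue
    sorting_dict[s] = se.NpzSortingExtractor(
        results_folder / 'sortings' / str(s + '.npz'))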
Example #8
sorting_HS = ss.run_sorter(sorter_name='herdingspikes',
                           recording=recording,
                           output_folder='my_sorter_output',
                           clustering_bandwidth=8)
print(sorting_HS.get_unit_ids())

##############################################################################
#
# You can also run multiple sorters on the same recording:

recordings = {'toy': recording}
sorter_list = ['herdingspikes', 'tridesclous']
sorter_params = {'herdingspikes': {'clustering_bandwidth': 8}}
sorting_output = ss.run_sorters(sorter_list,
                                recordings,
                                working_folder='tmp_some_sorters',
                                mode_if_folder_exists='overwrite',
                                sorter_params=sorter_params)

##############################################################################
# The 'mode_if_folder_exists' argument allows you to 'overwrite' the 'working_folder' (if it exists), 'raise' an Exception,
# or 'keep' the folder and skip the spike sorting run.
#
# The 'sorting_output' is a dictionary that has (recording, sorter) pairs as keys and the corresponding
# :code:`SortingExtractor` as values. It can be accessed as follows:

for (rec_name, sorter_name), sorting in sorting_output.items():
    print(rec_name, sorter_name, ':', sorting.get_unit_ids())

##############################################################################
# With the same mechanism, you can run several spike sorters on many recordings, just by creating a list/dict of
# :code:`RecordingExtractor` objects.
Example #9
    def run(self):

        task_args_list = []
        for key in self._sortings_pre.keys():
            # recording_dict = self._recordings[key[0]].to_dict()
            # sorting_dict = self._sortings_pre[key].to_dict()
            # gt_dict = self._gt[key[0]].to_dict() if self._gt is not None else None
            # comparison = sc.compare_sorter_to_ground_truth(tested_sorting=self._sortings_pre[key], gt_sorting=self._gt[key[0]])
            # self._comparisons[key] = comparison
            # task_args_list.append((recording_dict, gt_dict, sorting_dict, key,
            #                        self._params_dict['wd_score'], self._params_dict['isi_thr'],
            #                        self._params_dict['fr_thr'], self._params_dict['sample_window_ms'],
            #                        self._params_dict['percentage_spikes'], self._params_dict['balance_spikes'],
            #                        self._params_dict['detect_threshold'], self._params_dict['method'],
            #                        self._params_dict['skew_thr'], self._params_dict['n_jobs'], self._we_params,
            #                        comparison, self._output_folder, self._params_dict['job_kwargs']))
            self._recordings[key[0]].save_to_folder(
                folder=self._output_folder / 'back_recording' / key[1] /
                key[0])
            self._sortings_pre[key].save_to_folder(folder=self._output_folder /
                                                   'back_recording' / key[0] /
                                                   (key[1] + '_pre'))
            self._gt[key[0]].save_to_folder(folder=self._output_folder /
                                            'back_recording' / key[1] /
                                            (key[0] + '_gt'))
            task_args_list.append(
                (key, self._params_dict['wd_score'],
                 self._params_dict['isi_thr'], self._params_dict['fr_thr'],
                 self._params_dict['sample_window_ms'],
                 self._params_dict['percentage_spikes'],
                 self._params_dict['balance_spikes'],
                 self._params_dict['detect_threshold'],
                 self._params_dict['method'], self._params_dict['skew_thr'],
                 self._params_dict['n_jobs'], self._we_params, self._compare,
                 self._output_folder, self._params_dict['job_kwargs']))

        if self._params_dict['parallel']:
            # raise NotImplementedError()
            from joblib import Parallel, delayed
            Parallel(n_jobs=self._params_dict['n_jobs'],
                     backend='loky')(delayed(_do_recovery_loop)(task_args)
                                     for task_args in task_args_list)
        else:
            for task_args in task_args_list:
                _do_recovery_loop(task_args)

        for key in self._sortings_pre.keys():
            if key[1] in self._recordings_backprojected.keys():
                self._recordings_backprojected[key[1]].append(
                    load_extractor(self._output_folder / 'back_recording' /
                                   key[0] / key[1]))
            else:
                self._recordings_backprojected[key[1]] = \
                    [load_extractor(self._output_folder / 'back_recording' / key[0] / key[1])]

        for sorter in self._recordings_backprojected.keys():
            self._sortings_post[sorter] = ss.run_sorters(
                sorter,
                self._recordings_backprojected[sorter],
                working_folder=self._output_folder / 'sortings_post' / sorter,
                sorter_params=self._sorters_params['sorters_params'],
                mode_if_folder_exists='overwrite',
                engine=self._sorters_params['engine'],
                engine_kwargs=self._sorters_params['engine_kwargs'],
                verbose=self._sorters_params['verbose'],
                with_output=self._sorters_params['with_output'])
            for key in self._sortings_post[sorter].keys():
                self._aggregated_sortings[key] = aggregate_units([
                    self._sortings_post[sorter][key], self._sortings_pre[key]
                ])
                self._comparisons[key] = sc.compare_sorter_to_ground_truth(
                    tested_sorting=self._sortings_pre[key],
                    gt_sorting=self._gt[key[0]])
Example #10
    def __init__(self,
                 sorters: list,
                 recordings: Union[list, dict],
                 gt=None,
                 sorters_params={},
                 output_folder=None,
                 overwrite=False,
                 we_params={},
                 well_detected_score=.7,
                 isi_thr=.3,
                 fr_thr=None,
                 sample_window_ms=2,
                 percentage_spikes=None,
                 balance_spikes=False,
                 detect_threshold=5,
                 method='locally_exclusive',
                 skew_thr=0.1,
                 n_jobs=4,
                 parallel=False,
                 **job_kwargs):
        """
        Apply a spike sorting algorithm twice to increase its accuracy. After the first sorting, well-detected units
        are removed from the recording. ICA is run on the "new recording" to increase its SNR and ease the detection
        of small units. Finally, the spike sorting algorithm is run again on the ICA-filtered recording.

        Multiple sorting algorithms can be run at the same time, each on its own recording. The recovery can be run
        in parallel or in a loop. The former is recommended when the number of recordings or sortings is high.

        Parameters
        ----------
        sorters: list
            list of sorter names to be run.
        recordings: list or dict
            list or dict of RecordingExtractors. If a dict, the keys are recording names.
        gt: list or dict
            list or dict of ground truth SortingExtractors.
        sorters_params: dict
            dict whose keys are parameters of spikeinterface.sorters.run_sorters().
            If a parameter is not set, its default value is used.
        output_folder: str
            String with the name or path of the output folder. If None, it is named 'recovery_output'.
        overwrite: bool
            If True and output_folder exists, it will be overwritten. If False and output_folder exists, an exception is raised.
        we_params: dict
            dict whose keys are parameters of spikeinterface.core.extract_waveforms().
            If a parameter is not set, its default value is used.
        well_detected_score: float
            agreement score to mark a unit as well detected. Used only if gt is provided.
        isi_thr: float
            If the ISI violation ratio of a unit is above the threshold, it will be discarded.
        fr_thr: list
            list with 2 values. If the firing rate of a unit is not in the provided interval,
            it will be discarded.
        sample_window_ms: list or int
            If a list, the [ms_before, ms_after] window of the recording selected around each detected spike when
            subsampling for ICA.
        percentage_spikes: float
            percentage of detected spikes to be used in subsampling for ICA. If None, all spikes are used.
        balance_spikes: bool
            If True, the same percentage of spikes is selected channel by channel. If False, spikes are picked randomly.
            Used only if percentage_spikes is not None.
        detect_threshold: float
            MAD threshold for spike detection in subsampling for ICA.
        method: str
            How to detect peaks:
            * 'by_channel' : peaks are detected in each channel independently.
            * 'locally_exclusive' : within a given radius, only the best peak is taken,
              not the neighboring channels (default).
        skew_thr: float
            Skewness threshold for cleaning ICA sources. If the skewness of a source is lower than the threshold,
            it is discarded.
        n_jobs: int
            Number of parallel processes
        parallel: bool
            If True, the recovery is run in parallel for each sorter. If False, the recovery is run in a loop.
        job_kwargs: dict
            Parameters for parallel processing of RecordingExtractors.

        Returns
        --------
        unitsrecovery object
        """
        self._sorters = sorters
        if output_folder is None:
            output_folder = 'recovery_output'
        self._output_folder = Path(output_folder)
        if fr_thr is None:
            fr_thr = [3.5, 19.5]
        self._params_dict = {
            'wd_score': well_detected_score,
            'isi_thr': isi_thr,
            'fr_thr': fr_thr,
            'parallel': parallel,
            'sample_window_ms': sample_window_ms,
            'percentage_spikes': percentage_spikes,
            'balance_spikes': balance_spikes,
            'detect_threshold': detect_threshold,
            'method': method,
            'skew_thr': skew_thr,
            'n_jobs': n_jobs,
            'job_kwargs': job_kwargs
        }

        self._sorters_params, self._we_params = _set_sorters_params(
            sorters_params, we_params)

        # assert len(sorters) == len(recordings), "The number of sorters must equal the number of recordings"
        if self._output_folder.is_dir() and not overwrite:
            raise Exception(
                'Output folder already exists. Set overwrite=True to overwrite it'
            )
        elif self._output_folder.is_dir() and overwrite:
            rmtree(self._output_folder)
        self._sortings_pre = ss.run_sorters(
            sorters,
            recordings,
            working_folder=self._output_folder / 'sorting_pre',
            sorter_params=self._sorters_params['sorters_params'],
            mode_if_folder_exists='overwrite',
            engine=self._sorters_params['engine'],
            engine_kwargs=self._sorters_params['engine_kwargs'],
            verbose=self._sorters_params['verbose'],
            with_output=self._sorters_params['with_output'])

        if not isinstance(recordings, dict):
            self._recordings = {
                key[0]: recordings[int(key[0][-1])]
                for key in self._sortings_pre.keys()
            }
        else:
            self._recordings = recordings
        if not isinstance(gt, dict) and gt is not None:
            assert len(recordings) == len(
                gt), "Recordings and gts must be of same length"
            self._gt = {
                key[0]: gt[int(key[0][-1])]
                for key in self._sortings_pre.keys()
            }
        else:
            if isinstance(gt, dict) and isinstance(recordings, dict):
                assert gt.keys() == recordings.keys(
                ), "Recordings and gts dictionaries must have same keys"
            self._gt = gt
        if gt is not None:
            self._comparisons = {}
            self._compare = True
        else:
            self._compare = False
        self._recordings_backprojected = {}
        self._aggregated_sortings = {}
        self._sortings_post = {}
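A minimal usage sketch for this constructor. The docstring calls the result a "unitsrecovery object"; the class name UnitsRecovery and the recording/ground-truth variables below are assumptions, and run() refers to the method shown in Example #9, assuming both belong to the same class:

# Sketch with assumed names; every keyword below is documented in the signature above
recovery = UnitsRecovery(
    sorters=['tridesclous'],
    recordings={'rec0': recording},   # assumed RecordingExtractor
    gt={'rec0': gt_sorting},          # assumed ground-truth SortingExtractor (optional)
    output_folder='recovery_output',
    overwrite=True,
    well_detected_score=0.7,
    parallel=False,
    n_jobs=4,
)
recovery.run()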
Example #11
recordings = recording_4_tetrodes.split_by(property='group')
print(recordings)

##############################################################################
# We can also get a dict instead of a list, which makes it easier to handle the group keys.

recordings = recording_4_tetrodes.split_by(property='group', outputs='dict')
print(recordings)

##############################################################################
# We can now use the `run_sorters()` function instead of `run_sorter()`.
# This function can run several sorters on several recordings with different parallel engines.
# Here we use the 'loop' engine, but we could also use 'joblib' or 'dask' for multi-process or multi-node computing.
# Have a look at the documentation of this function, which handles many more cases.

sorter_list = ['tridesclous']
working_folder = 'sorter_outputs'
results = ss.run_sorters(sorter_list,
                         recordings,
                         working_folder,
                         engine='loop',
                         with_output=True,
                         mode_if_folder_exists='overwrite')

##############################################################################
#  The output is a dict whose keys are all combinations of (group, sorter_name).

from pprint import pprint
pprint(results)
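Each value in results is a SortingExtractor; a minimal sketch of iterating them (the exact group labels in the keys are whatever values the 'group' property holds):

for (group, sorter_name), sorting in results.items():
    print('group', group, '-', sorter_name, ':', len(sorting.get_unit_ids()), 'units')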