コード例 #1
0
def get_ground_truths(study_folder):
    """
    Get the ground-truth sorting extractors of a study as a dict.

    They are read from the 'ground_truth' folder, one npz file per recording
    (``ground_truth/<rec_name>.npz``).

    Parameters
    ----------
    study_folder: str or Path
        The study folder.

    Returns
    -------
    ground_truths: dict
        Dict mapping each recording name (as returned by ``get_rec_names``)
        to its ground-truth ``NpzSortingExtractor``.
    """
    study_folder = Path(study_folder)
    gt_folder = study_folder / 'ground_truth'
    # One '<rec_name>.npz' file is read per recording name of the study.
    return {
        rec_name: se.NpzSortingExtractor(gt_folder / (rec_name + '.npz'))
        for rec_name in get_rec_names(study_folder)
    }
コード例 #2
0
def iter_computed_sorting(study_folder):
    """
    Iterate over the sorting results stored in the study's 'sortings' folder.

    Only files named ``<rec_name>[#]<sorter_name>.npz`` are considered; any
    other file in the folder is skipped. Iteration order follows
    ``os.listdir`` (arbitrary).

    Parameters
    ----------
    study_folder: str or Path
        The study folder.

    Yields
    ------
    rec_name: str
        Part of the file name before the '[#]' separator.
    sorter_name: str
        Part of the file name between '[#]' and the '.npz' extension.
    sorting: NpzSortingExtractor
        The sorting loaded from that file.

    Raises
    ------
    ValueError
        If a matching file name contains '[#]' more than once.
    """
    sorting_folder = Path(study_folder) / 'sortings'
    suffix = '.npz'
    for filename in os.listdir(sorting_folder):
        if not (filename.endswith(suffix) and '[#]' in filename):
            continue
        # Strip only the trailing extension: str.replace would also remove a
        # '.npz' occurring inside the recording or sorter name.
        stem = filename[:-len(suffix)]
        rec_name, sorter_name = stem.split('[#]')
        sorting = se.NpzSortingExtractor(sorting_folder / filename)
        yield rec_name, sorter_name, sorting
コード例 #3
0
def create_dumpable_extractors_from_existing(folder, RX, SX):
    """
    Write existing recording/sorting extractors to disk and reload them.

    The recording is written in MDA format into ``folder`` and the sorting
    to ``folder/sorting.npz``. The returned extractors are the ones read
    back from those files.

    Note: if ``RX`` has no shared 'location' channel property, random 2D
    channel locations are set on ``RX`` itself (the input is modified).

    Parameters
    ----------
    folder: str or Path
        Output folder.
    RX: RecordingExtractor
        Recording to write.
    SX: SortingExtractor
        Sorting to write.

    Returns
    -------
    RX_mda: MdaRecordingExtractor
    SX_npz: NpzSortingExtractor
    """
    out_dir = Path(folder)
    sorting_file = out_dir / 'sorting.npz'

    has_locations = 'location' in RX.get_shared_channel_property_names()
    if not has_locations:
        # Placeholder locations drawn from NumPy's global RNG (not seeded
        # here). Presumably needed for the MDA writer -- confirm.
        random_locations = np.random.randn(RX.get_num_channels(), 2)
        RX.set_channel_locations(random_locations)

    se.MdaRecordingExtractor.write_recording(RX, out_dir)
    recording_mda = se.MdaRecordingExtractor(out_dir)

    se.NpzSortingExtractor.write_sorting(SX, sorting_file)
    sorting_npz = se.NpzSortingExtractor(sorting_file)

    return recording_mda, sorting_npz
コード例 #4
0
    def test_write_then_read(self):
        """Round-trip a toy sorting through the npz format and compare unit ids and sampling frequency."""
        file_path = 'test_NpzSortingExtractors.npz'
        _, sorting_gt = se.example_datasets.toy_example(num_channels=4, duration=10, seed=0)

        se.NpzSortingExtractor.write_sorting(sorting_gt, file_path)

        # np.load on an .npz returns an NpzFile that keeps the underlying file
        # open; use it as a context manager so the handle is released.
        with np.load(file_path) as npz:
            units_ids = npz['unit_ids']
        sorting_npz = se.NpzSortingExtractor(file_path)

        self.assertEqual(list(units_ids), list(sorting_gt.get_unit_ids()))
        self.assertEqual(list(sorting_npz.get_unit_ids()), list(sorting_gt.get_unit_ids()))
        # 30000.0 is the default sampling_frequency of toy_example.
        self.assertEqual(sorting_npz.get_sampling_frequency(), 30000.0)
コード例 #5
0
    def test_npz_extractor(self):
        """Write self.SX as npz, reload it and check the round trip; also write an empty sorting."""
        npz_path = self.test_dir + '/sorting.npz'
        se.NpzSortingExtractor.write_sorting(self.SX, npz_path)
        reloaded = se.NpzSortingExtractor(npz_path)

        # Writing a fresh NumpySortingExtractor (nothing set on it) must not
        # fail; the resulting file is not read back.
        blank_sorting = se.NumpySortingExtractor()
        blank_path = self.test_dir + '/sorting_empty.npz'
        se.NpzSortingExtractor.write_sorting(blank_sorting, blank_path)

        check_sorting_return_types(reloaded)
        check_sortings_equal(self.SX, reloaded)
        check_dumping(reloaded)
コード例 #6
0
ファイル: groundtruthstudy.py プロジェクト: yger/spiketoolkit
def collect_study_sorting(study_folder):
    """
    Collect the sorting results copied into the study's 'sortings' folder.

    Only files named ``<rec_name>[#]<sorter_name>.npz`` are considered; any
    other file in the folder is skipped.

    Parameters
    ----------
    study_folder: str or Path
        The study folder.

    Returns
    -------
    sortings: dict
        Maps ``(rec_name, sorter_name)`` to the ``NpzSortingExtractor``
        loaded from the corresponding file.

    Raises
    ------
    ValueError
        If a matching file name contains '[#]' more than once.
    """
    sorting_folder = Path(study_folder) / 'sortings'
    suffix = '.npz'

    sortings = {}
    for filename in os.listdir(sorting_folder):
        if not (filename.endswith(suffix) and '[#]' in filename):
            continue
        # Strip only the trailing extension: str.replace would also remove a
        # '.npz' occurring inside the recording or sorter name.
        stem = filename[:-len(suffix)]
        rec_name, sorter_name = stem.split('[#]')
        sortings[(rec_name, sorter_name)] = se.NpzSortingExtractor(sorting_folder / filename)

    return sortings
コード例 #7
0
def toy_example(duration=10,
                num_channels=4,
                sampling_frequency=30000.0,
                K=10,
                dumpable=False,
                dump_folder=None,
                seed=None):
    """
    Create a synthetic (toy) recording together with its ground-truth sorting.

    Parameters
    ----------
    duration: float
        Duration in s (default 10).
    num_channels: int
        Number of channels (default 4).
    sampling_frequency: float
        Sampling frequency (default 30000.0).
    K: int
        Number of units (default 10).
    dumpable: bool
        If True, the recording is written in MDA format to ``dump_folder``
        and the sorting to ``dump_folder/sorting.npz``; both are then
        reloaded from disk.
    dump_folder: str or Path or None
        Output folder used when ``dumpable`` is True ('toy_example' if None).
    seed: int or None
        Seed forwarded to every synthesis step.

    Returns
    -------
    recording: NumpyRecordingExtractor, or MdaRecordingExtractor if dumpable
    sorting: NumpySortingExtractor, or NpzSortingExtractor if dumpable
    """
    upsample_factor = 13

    # Keep the synthesis calls in this order: with seed=None they may share
    # random state, so reordering could change the generated data.
    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsample_factor,
                                                  seed=seed)
    times, labels = synthesize_random_firings(K=K,
                                              duration=duration,
                                              sampling_frequency=sampling_frequency,
                                              seed=seed)

    sorting = se.NumpySortingExtractor()
    sorting.set_times_labels(times, labels.astype(np.int64))

    traces = synthesize_timeseries(sorting=sorting,
                                   waveforms=waveforms,
                                   noise_level=10,
                                   sampling_frequency=sampling_frequency,
                                   duration=duration,
                                   waveform_upsamplefac=upsample_factor,
                                   seed=seed)
    sorting.set_sampling_frequency(sampling_frequency)

    recording = se.NumpyRecordingExtractor(timeseries=traces,
                                           sampling_frequency=sampling_frequency,
                                           geom=geom)
    recording.is_filtered = True

    if not dumpable:
        return recording, sorting

    target = Path('toy_example' if dump_folder is None else dump_folder)
    sorting_file = target / 'sorting.npz'

    se.MdaRecordingExtractor.write_recording(recording, target)
    recording = se.MdaRecordingExtractor(target)
    se.NpzSortingExtractor.write_sorting(sorting, sorting_file)
    sorting = se.NpzSortingExtractor(sorting_file)

    return recording, sorting
コード例 #8
0
def create_dumpable_extractors(folder,
                               duration=10,
                               num_channels=4,
                               sampling_frequency=30000.0,
                               K=10,
                               seed=None):
    """
    Generate a toy recording/sorting pair, write it under ``folder`` and reload it.

    The toy data comes from ``toy_example`` (called with the given
    parameters). The recording is written in MDA format into ``folder`` and
    the sorting to ``folder/sorting.npz``.

    Returns
    -------
    RX_mda: MdaRecordingExtractor
        Recording read back from ``folder``.
    SX_npz: NpzSortingExtractor
        Sorting read back from ``folder/sorting.npz``.
    """
    toy_recording, toy_sorting = toy_example(duration=duration,
                                             num_channels=num_channels,
                                             K=K,
                                             sampling_frequency=sampling_frequency,
                                             seed=seed)

    target = Path(folder)
    npz_file = target / 'sorting.npz'

    se.MdaRecordingExtractor.write_recording(toy_recording, target)
    recording_on_disk = se.MdaRecordingExtractor(target)

    se.NpzSortingExtractor.write_sorting(toy_sorting, npz_file)
    sorting_on_disk = se.NpzSortingExtractor(npz_file)

    return recording_on_disk, sorting_on_disk
コード例 #9
0
 def get_ground_truth(self, rec_name=None):
     """
     Load the ground-truth sorting of one recording of the study.

     ``rec_name`` is first resolved through ``self._check_rec_name``
     (presumably it validates it / picks a default when None -- confirm);
     the file ``<study_folder>/ground_truth/<rec_name>.npz`` is then loaded
     as an ``NpzSortingExtractor``.
     """
     resolved_name = self._check_rec_name(rec_name)
     gt_file = self.study_folder / 'ground_truth' / (resolved_name + '.npz')
     return se.NpzSortingExtractor(gt_file)
コード例 #10
0
def toy_example(duration: float = 10.,
                num_channels: int = 4,
                sampling_frequency: float = 30000.,
                K: int = 10,
                dumpable: bool = False,
                dump_folder: Optional[Union[str, Path]] = None,
                seed: Optional[int] = None):
    """
    Create toy recording and sorting extractors.

    Parameters
    ----------
    duration: float
        Duration in s (default 10)
    num_channels: int
        Number of channels (default 4)
    sampling_frequency: float
        Sampling frequency (default 30000)
    K: int
        Number of units (default 10)
    dumpable: bool
        If True, objects are dumped to file and become 'dumpable': the
        recording is written in MDA format to ``dump_folder`` and the sorting
        to ``dump_folder/sorting.npz``, and both are reloaded from disk
    dump_folder: str or Path
        Path to dump folder (if None, 'toy_example' is used)
    seed: int
        Seed for random initialization, forwarded to every synthesis step

    Returns
    -------
    recording: RecordingExtractor
        The output recording extractor. If dumpable is False it's a NumpyRecordingExtractor, otherwise it's an
        MdaRecordingExtractor
    sorting: SortingExtractor
        The output sorting extractor. If dumpable is False it's a NumpySortingExtractor, otherwise it's an
        NpzSortingExtractor
    """
    upsamplefac = 13
    waveforms, geom = synthesize_random_waveforms(K=K,
                                                  M=num_channels,
                                                  average_peak_amplitude=-100,
                                                  upsamplefac=upsamplefac,
                                                  seed=seed)
    times, labels = synthesize_random_firings(
        K=K,
        duration=duration,
        sampling_frequency=sampling_frequency,
        seed=seed)
    labels = labels.astype(np.int64)
    SX = se.NumpySortingExtractor()
    SX.set_times_labels(times, labels)
    # The traces are synthesized from the sorting's spike times and labels.
    X = synthesize_timeseries(sorting=SX,
                              waveforms=waveforms,
                              noise_level=10,
                              sampling_frequency=sampling_frequency,
                              duration=duration,
                              waveform_upsamplefac=upsamplefac,
                              seed=seed)
    SX.set_sampling_frequency(sampling_frequency)

    RX = se.NumpyRecordingExtractor(timeseries=X,
                                    sampling_frequency=sampling_frequency,
                                    geom=geom)
    RX.is_filtered = True

    if dumpable:
        if dump_folder is None:
            dump_folder = 'toy_example'
        dump_folder = Path(dump_folder)

        # Write both objects to disk and replace them with file-backed extractors.
        se.MdaRecordingExtractor.write_recording(RX, dump_folder)
        RX = se.MdaRecordingExtractor(dump_folder)
        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')
        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')

    return RX, SX
コード例 #11
0
def run_sorter_docker(sorter_name,
                      recording,
                      output_folder,
                      delete_output_folder=False,
                      grouping_property=None,
                      parallel=False,
                      verbose=False,
                      raise_error=True,
                      n_jobs=-1,
                      joblib_backend='loky',
                      use_docker=True,
                      container=None,
                      **params):
    """
    Run a spike sorter, either through a hither job or directly via ``ss.run_sorter``.

    With ``use_docker=True`` the output folder is created, the recording is
    dumped to a dict whose input paths are remapped to '/input' (via
    ``modify_input_folder``), a ``hither.Job`` running
    ``run_sorter_docker_with_container`` is executed and waited on, and the
    result is loaded from ``<output_folder>/sorting_docker.npz``.

    With ``use_docker=False`` all arguments are forwarded to
    ``ss.run_sorter``.

    Notes
    -----
    - ``container`` is currently unused; selection of a default docker image
      per sorter has not been implemented.
    - In the docker branch ``delete_output_folder`` is ignored: the job is
      always called with ``delete_output_folder=False``, presumably because
      the result file is read back from the output folder afterwards.
    - Extra ``**params`` override the base job arguments, while
      'input_directory' and 'output_directory' always take the computed values.
    - NOTE(review): the hither config uses ``use_container=False`` in the
      docker branch -- confirm this is intended.

    Returns
    -------
    sorting: SortingExtractor
    """
    if not use_docker:
        # Standard, in-process call.
        return ss.run_sorter(sorter_name,
                             recording,
                             output_folder=output_folder,
                             delete_output_folder=delete_output_folder,
                             grouping_property=grouping_property,
                             parallel=parallel,
                             verbose=verbose,
                             raise_error=raise_error,
                             n_jobs=n_jobs,
                             joblib_backend=joblib_backend,
                             **params)

    out_dir = Path(output_folder).absolute()
    out_dir.mkdir(exist_ok=True, parents=True)

    # Dump the recording with file paths rewritten relative to the
    # container's '/input' folder.
    container_dict, input_directory = modify_input_folder(
        recording.dump_to_dict(), '/input')

    with hither.Config(use_container=False, show_console=True):
        base_kwargs = dict(recording_dict=container_dict,
                           sorter_name=sorter_name,
                           output_folder=str(out_dir),
                           delete_output_folder=False,
                           grouping_property=grouping_property,
                           parallel=parallel,
                           verbose=verbose,
                           raise_error=raise_error,
                           n_jobs=n_jobs,
                           joblib_backend=joblib_backend)
        # Later entries win, matching successive dict.update calls.
        job_kwargs = {
            **base_kwargs,
            **params,
            'input_directory': str(input_directory),
            'output_directory': str(out_dir),
        }

        job = hither.Job(run_sorter_docker_with_container, job_kwargs)
        job.wait()

    return se.NpzSortingExtractor(out_dir / "sorting_docker.npz")