Example #1
    def test_load_save_probes(self):
        sub_RX = se.load_probe_file(self.RX, 'spikeextractors/tests/probe_test.prb')
        # print(sub_RX.get_shared_channel_property_names())
        assert 'location' in sub_RX.get_shared_channel_property_names()
        assert 'group' in sub_RX.get_shared_channel_property_names()
        positions = [sub_RX.get_channel_locations(chan)[0] for chan in range(self.RX.get_num_channels())]
        # save in csv
        sub_RX.save_to_probe_file(Path(self.test_dir) / 'geom.csv')
        # load csv locations
        sub_RX_load = sub_RX.load_probe_file(Path(self.test_dir) / 'geom.csv')
        position_loaded = [sub_RX_load.get_channel_locations(chan)[0] for
                           chan in range(sub_RX_load.get_num_channels())]
        self.assertTrue(np.allclose(positions[10], position_loaded[10]))

        # prb file
        RX = copy(self.RX)
        channel_groups = []
        n_group = 4
        for i in RX.get_channel_ids():
            channel_groups.append(i // n_group)
        RX.set_channel_groups(channel_groups)
        RX.save_to_probe_file('spikeextractors/tests/probe_test_no_groups.prb')
        RX.save_to_probe_file('spikeextractors/tests/probe_test_groups.prb', grouping_property='group')

        # load
        RX_loaded_no_groups = se.load_probe_file(RX, 'spikeextractors/tests/probe_test_no_groups.prb')
        RX_loaded_groups = se.load_probe_file(RX, 'spikeextractors/tests/probe_test_groups.prb')

        assert len(np.unique(RX_loaded_no_groups.get_channel_groups())) == 1
        assert len(np.unique(RX_loaded_groups.get_channel_groups())) == RX.get_num_channels() // n_group

        # cleanup
        os.remove('spikeextractors/tests/probe_test_no_groups.prb')
        os.remove('spikeextractors/tests/probe_test_groups.prb')
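
For reference, a PRB file such as the probe_test.prb loaded above is itself a small Python file defining a channel_groups dict; a minimal hypothetical sketch (channel ids and coordinates are illustrative only):

# probe_test.prb (hypothetical contents)
channel_groups = {
    0: {
        'channels': [0, 1, 2, 3],
        'geometry': {
            0: (0.0, 0.0),    # channel id -> (x, y) location
            1: (0.0, 25.0),
            2: (25.0, 0.0),
            3: (25.0, 25.0),
        },
    },
}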
Example #2
    def test_with_BinDatRecordingExtractor(self):
        # some sorters (TDC, KS, KS2, ...) work by default with the raw binary
        # format as input, to avoid a copy when the recording is already in this format

        recording, sorting_gt = se.example_datasets.toy_example(num_channels=2,
                                                                duration=10)

        # create a raw dat file and prb file
        raw_filename = 'raw_file.dat'
        prb_filename = 'raw_file.prb'

        samplerate = recording.get_sampling_frequency()
        traces = recording.get_traces().astype('float32')
        with open(raw_filename, mode='wb') as f:
            f.write(traces.T.tobytes())

        se.save_probe_file(recording, prb_filename, format='spyking_circus')

        recording = se.BinDatRecordingExtractor(raw_filename,
                                                samplerate,
                                                2,
                                                'float32',
                                                frames_first=True,
                                                offset=0)
        se.load_probe_file(recording, prb_filename)

        params = self.SorterClass.default_params()
        sorter = self.SorterClass(recording=recording, output_folder=None)
        sorter.set_params(**params)
        sorter.run()
        sorting = sorter.get_result()
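
As a sanity check on the binary layout written above, the flat file can be mapped back with NumPy and compared to the original traces; a minimal sketch, assuming the same raw_filename and float32 layout as in the test:

import numpy as np

# traces.T was written in C order, so the file is frames-first:
# reshaping the flat data gives shape (num_frames, num_channels)
data = np.memmap(raw_filename, dtype='float32', mode='r').reshape(-1, 2)
assert np.allclose(data, traces.T)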
Example #3
    def test_load_save_probes(self):
        SX = se.load_probe_file(self.RX, 'tests/probe_test.prb')
        # print(SX.get_channel_property_names())
        assert 'location' in SX.get_channel_property_names()
        assert 'group' in SX.get_channel_property_names()
        positions = [
            SX.get_channel_property(chan, 'location')
            for chan in range(self.RX.get_num_channels())
        ]
        # save in csv
        se.save_probe_file(SX, Path(self.test_dir) / 'geom.csv')
        # load csv locations
        SX_load = se.load_probe_file(SX, Path(self.test_dir) / 'geom.csv')
        position_loaded = [
            SX_load.get_channel_property(chan, 'location')
            for chan in range(SX_load.get_num_channels())
        ]
        self.assertTrue(np.allclose(positions[10], position_loaded[10]))
Example #4
def get_recordings(study_folder):
    """
    Get ground recording as a dict.
    
    They are read from the 'raw_files' folder with binary format.
    
    Parameters
    ----------
    study_folder: str
        The study folder.
    
    Returns
    ----------
    
    recording_dict: dict
        Dict of rexording.
        
    """
    study_folder = Path(study_folder)

    rec_names = get_rec_names(study_folder)
    recording_dict = {}
    for rec_name in rec_names:
        raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
        prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
        json_filename = study_folder / 'raw_files' / (rec_name + '.json')
        with open(json_filename, 'r', encoding='utf8') as f:
            info = json.load(f)

        rec = se.BinDatRecordingExtractor(raw_filename,
                                          info['sample_rate'],
                                          info['num_chan'],
                                          info['dtype'],
                                          frames_first=info['frames_first'])
        se.load_probe_file(rec, prb_filename)

        recording_dict[rec_name] = rec

    return recording_dict
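
The JSON sidecar read above only needs the four keys the extractor consumes; a hypothetical raw_files/<rec_name>.json could look like this (values are illustrative):

{
    "sample_rate": 20000.0,
    "num_chan": 32,
    "dtype": "float32",
    "frames_first": true
}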
Example #5
def setup_study():
    rec_names = [
        '20160415_patch2',
        '20160426_patch2',
        '20160426_patch3',
        '20170621_patch1',
        '20170713_patch1',
        '20170725_patch1',
        '20170728_patch2',
        '20170803_patch1',
    ]

    gt_dict = {}
    for rec_name in rec_names:

        # find raw file
        dirname = recording_folder + rec_name + '/'
        for f in os.listdir(dirname):
            if f.endswith('.raw') and not f.endswith('juxta.raw'):
                mea_filename = dirname + f

        # raw files have an internal offset that depends on the channel count;
        # a small text header next to the file can be parsed to get it
        with open(mea_filename.replace('.raw', '.txt'), mode='r') as f:
            offset = int(re.findall(r'padding = (\d+)', f.read())[0])

        # recording
        rec = se.BinDatRecordingExtractor(mea_filename,
                                          20000.,
                                          256,
                                          'uint16',
                                          offset=offset,
                                          frames_first=True)

        # this reduces the channel count to 252
        rec = se.load_probe_file(rec, basedir + 'mea_256.prb')

        # gt sorting
        gt_indexes = np.fromfile(ground_truth_folder + rec_name +
                                 '/juxta_peak_indexes.raw',
                                 dtype='int64')
        sorting_gt = se.NumpySortingExtractor()
        sorting_gt.set_times_labels(gt_indexes,
                                    np.zeros(gt_indexes.size, dtype='int64'))
        sorting_gt.set_sampling_frequency(20000.0)

        gt_dict[rec_name] = (rec, sorting_gt)

    study = GroundTruthStudy.setup(study_folder, gt_dict)
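
For context, the .txt header parsed above is a plain key = value file next to each .raw file; a hypothetical excerpt (only the padding line is actually used here):

# <mea_filename with .txt extension> (hypothetical excerpt)
nb_channels = 256
padding = 1024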
Example #6
def run_sorters(sorter_list,
                recording_dict_or_list,
                working_folder,
                grouping_property=None,
                shared_binary_copy=False,
                engine=None,
                engine_kargs={},
                debug=False,
                write_log=True):
    """
    This run several sorter on several recording.
    Simple implementation will nested loops.

    Need to be done with multiprocessing.

    sorter_list: list of str (sorter names)
    recording_dict_or_list: a dict (or a list) of recording
    working_folder : str

    engine = None ( = 'loop') or 'multiprocessing'
    processes = only if 'multiprocessing' if None then processes=os.cpu_count()
    debug=True/False to control sorter verbosity


    Note: engine='multiprocessing' use the python multiprocessing module.
    This do not allow to have subprocess in subprocess.
    So sorter that already use internally multiprocessing, this will fail.

    Parameters
    ----------
    
    sorter_list: list of str
        List of sorter name.
    
    recording_dict_or_list: dict or list
        A dict of recording. The key will be the name of the recording.
        In a list is given then the name will be recording_0, recording_1, ...
    
    working_folder: str
        The working directory.
        This must not exists before calling this function.
    
    grouping_property: str
        The property of grouping given to sorters.
    
    shared_binary_copy: False default
        Before running each sorter, all recording are copied inside 
        the working_folder with the raw binary format (BinDatRecordingExtractor)
        and new recording are instantiated as BinDatRecordingExtractor.
        This avoids multiple copy inside each sorter of the same file but
        imply a global of all files.

    engine: str
        'loop' or 'multiprocessing'
    
    engine_kargs: dict
        This contains kargs specific to the launcher engine:
            * 'loop' : no kargs
            * 'multiprocessing' : {'processes' : } number of processes
    
    debug: bool
        default True
    
    write_log: bool
        default True
    
    Output
    ----------
    
    results : dict
        The output is nested dict[rec_name][sorter_name] of SortingExtrator.



    """

    assert not os.path.exists(
        working_folder), 'working_folder already exists, please remove it'
    working_folder = Path(working_folder)

    for sorter_name in sorter_list:
        assert sorter_name in sorter_dict, '{} is not in sorter list'.format(
            sorter_name)

    if isinstance(recording_dict_or_list, list):
        # in case of list
        recording_dict = {
            'recording_{}'.format(i): rec
            for i, rec in enumerate(recording_dict_or_list)
        }
    elif isinstance(recording_dict_or_list, dict):
        recording_dict = recording_dict_or_list
    else:
        raise ValueError('recording_dict_or_list must be a dict or a list')

    if shared_binary_copy:
        os.makedirs(working_folder / 'raw_files')
        old_rec_dict = dict(recording_dict)
        recording_dict = {}
        for rec_name, recording in old_rec_dict.items():
            if grouping_property is not None:
                recording_list = se.get_sub_extractors_by_property(
                    recording, grouping_property)
                n_group = len(recording_list)
                assert n_group == 1, 'shared_binary_copy works only with a single group'
                recording = recording_list[0]
                grouping_property = None

            raw_filename = working_folder / 'raw_files' / (rec_name + '.raw')
            prb_filename = working_folder / 'raw_files' / (rec_name + '.prb')
            n_chan = recording.get_num_channels()
            chunksize = 2**24 // n_chan
            sr = recording.get_sampling_frequency()

            # save binary
            se.write_binary_dat_format(recording,
                                       raw_filename,
                                       time_axis=0,
                                       dtype='float32',
                                       chunksize=chunksize)
            # save location (with PRB format)
            se.save_probe_file(recording,
                               prb_filename,
                               format='spyking_circus')

            # make a new recording
            new_rec = se.BinDatRecordingExtractor(raw_filename,
                                                  sr,
                                                  n_chan,
                                                  'float32',
                                                  frames_first=True)
            se.load_probe_file(new_rec, prb_filename)
            recording_dict[rec_name] = new_rec

    task_list = []
    for rec_name, recording in recording_dict.items():
        for sorter_name in sorter_list:
            output_folder = working_folder / 'output_folders' / rec_name / sorter_name
            task_list.append((rec_name, recording, sorter_name, output_folder,
                              grouping_property, debug, write_log))

    if engine is None or engine == 'loop':
        # simple loop in main process
        for arg_list in task_list:
            # print(arg_list)
            _run_one(arg_list)

    elif engine == 'multiprocessing':
        # use mp.Pool
        processes = engine_kargs.get('processes', None)
        with multiprocessing.Pool(processes) as pool:
            pool.map(_run_one, task_list)

    if write_log:
        # collect run times and write to csv
        with open(working_folder / 'run_time.csv', mode='w') as f:
            for task in task_list:
                rec_name = task[0]
                sorter_name = task[2]
                output_folder = task[3]
                if os.path.exists(output_folder / 'run_log.txt'):
                    with open(output_folder / 'run_log.txt',
                              mode='r') as logfile:
                        run_time = float(logfile.readline().replace(
                            'run_time:', ''))

                    txt = '{}\t{}\t{}\n'.format(rec_name, sorter_name,
                                                run_time)
                    f.write(txt)

    results = collect_results(working_folder)
    return results
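
A minimal usage sketch (the sorter names here are assumptions and must be registered in sorter_dict; the toy recording comes from se.example_datasets as in Example #2):

recording, sorting_gt = se.example_datasets.toy_example(num_channels=4, duration=10)

results = run_sorters(['klusta', 'tridesclous'],
                      {'toy': recording},
                      'sorter_working_dir',  # must not exist yet
                      engine='loop',
                      write_log=True)
sorting = results['toy']['klusta']  # nested dict[rec_name][sorter_name]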