def test_load_save_probes(self):
    """Probe information survives round trips through .prb and .csv files."""
    probed_rx = se.load_probe_file(self.RX, 'spikeextractors/tests/probe_test.prb')
    assert 'location' in probed_rx.get_shared_channel_property_names()
    assert 'group' in probed_rx.get_shared_channel_property_names()
    num_chans = self.RX.get_num_channels()
    positions = [probed_rx.get_channel_locations(ch)[0] for ch in range(num_chans)]

    # round trip the geometry through a csv file and compare one channel
    csv_path = Path(self.test_dir) / 'geom.csv'
    probed_rx.save_to_probe_file(csv_path)
    reloaded_rx = probed_rx.load_probe_file(csv_path)
    reloaded_positions = [reloaded_rx.get_channel_locations(ch)[0]
                          for ch in range(reloaded_rx.get_num_channels())]
    self.assertTrue(np.allclose(positions[10], reloaded_positions[10]))

    # round trip through .prb files, with and without the 'group' property
    rx_copy = copy(self.RX)
    group_size = 4
    rx_copy.set_channel_groups([ch_id // group_size for ch_id in rx_copy.get_channel_ids()])
    rx_copy.save_to_probe_file('spikeextractors/tests/probe_test_no_groups.prb')
    rx_copy.save_to_probe_file('spikeextractors/tests/probe_test_groups.prb',
                               grouping_property='group')

    loaded_no_groups = se.load_probe_file(rx_copy, 'spikeextractors/tests/probe_test_no_groups.prb')
    loaded_groups = se.load_probe_file(rx_copy, 'spikeextractors/tests/probe_test_groups.prb')
    # without grouping_property everything collapses to a single group;
    # with it, every block of `group_size` channels becomes its own group
    assert len(np.unique(loaded_no_groups.get_channel_groups())) == 1
    assert len(np.unique(loaded_groups.get_channel_groups())) == rx_copy.get_num_channels() // group_size

    # cleanup the temporary .prb files
    os.remove('spikeextractors/tests/probe_test_no_groups.prb')
    os.remove('spikeextractors/tests/probe_test_groups.prb')
def test_with_BinDatRecordingExtractor(self):
    """Run the sorter on a recording exposed as a flat binary (.dat) file.

    Some sorters (TDC, KS, KS2, ...) work by default with the raw binary
    format as input, to avoid a copy when the recording already is in this
    format.
    """
    recording, sorting_gt = se.example_datasets.toy_example(num_channels=2, duration=10)

    # create a raw dat file and prb file in the working directory
    raw_filename = 'raw_file.dat'
    prb_filename = 'raw_file.prb'
    samplerate = recording.get_sampling_frequency()
    traces = recording.get_traces().astype('float32')
    try:
        with open(raw_filename, mode='wb') as f:
            f.write(traces.T.tobytes())
        se.save_probe_file(recording, prb_filename, format='spyking_circus')

        # re-open the data as a BinDatRecordingExtractor and attach the probe
        recording = se.BinDatRecordingExtractor(raw_filename, samplerate, 2, 'float32',
                                                frames_first=True, offset=0)
        se.load_probe_file(recording, prb_filename)

        params = self.SorterClass.default_params()
        sorter = self.SorterClass(recording=recording, output_folder=None)
        sorter.set_params(**params)
        sorter.run()
        sorting = sorter.get_result()
        assert sorting is not None
    finally:
        # FIX: the original leaked raw_file.dat / raw_file.prb into the CWD
        for fname in (raw_filename, prb_filename):
            if os.path.exists(fname):
                os.remove(fname)
def test_load_save_probes(self):
    """Channel locations survive a .prb load followed by a csv round trip."""
    probed = se.load_probe_file(self.RX, 'tests/probe_test.prb')
    assert 'location' in probed.get_channel_property_names()
    assert 'group' in probed.get_channel_property_names()
    positions = [probed.get_channel_property(ch, 'location')
                 for ch in range(self.RX.get_num_channels())]

    # dump the geometry to csv, then read it back
    csv_file = Path(self.test_dir) / 'geom.csv'
    se.save_probe_file(probed, csv_file)
    reloaded = se.load_probe_file(probed, csv_file)
    reloaded_positions = [reloaded.get_channel_property(ch, 'location')
                          for ch in range(reloaded.get_num_channels())]
    self.assertTrue(np.allclose(positions[10], reloaded_positions[10]))
def get_recordings(study_folder):
    """
    Get the ground-truth recordings as a dict.

    They are read from the 'raw_files' folder with binary format.

    Parameters
    ----------
    study_folder: str
        The study folder.

    Returns
    ----------
    recording_dict: dict
        Dict of recording, keyed by recording name.
    """
    study_folder = Path(study_folder)

    rec_names = get_rec_names(study_folder)
    recording_dict = {}
    for rec_name in rec_names:
        # each recording is stored as a trio of files: raw binary traces (.dat),
        # a probe file (.prb) and a json file describing the binary layout
        raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
        prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
        json_filename = study_folder / 'raw_files' / (rec_name + '.json')
        with open(json_filename, 'r', encoding='utf8') as f:
            info = json.load(f)
        # rebuild the extractor from the binary file using the saved metadata
        rec = se.BinDatRecordingExtractor(raw_filename, info['sample_rate'],
                                          info['num_chan'], info['dtype'],
                                          frames_first=info['frames_first'])
        # attach channel locations/groups from the probe file
        se.load_probe_file(rec, prb_filename)

        recording_dict[rec_name] = rec

    return recording_dict
def setup_study():
    """Build the ground-truth study from the patch-clamp mea recordings.

    For each recording: locate the 256-channel .raw file, parse the binary
    offset ("padding") from the accompanying text header, wrap the data in a
    BinDatRecordingExtractor, attach the probe file, and pair it with the
    juxta-cellular ground-truth spike train. Finally create the study folder.

    NOTE(review): relies on the module-level globals ``recording_folder``,
    ``basedir``, ``ground_truth_folder`` and ``study_folder``.
    """
    rec_names = [
        '20160415_patch2',
        '20160426_patch2',
        '20160426_patch3',
        '20170621_patch1',
        '20170713_patch1',
        '20170725_patch1',
        '20170728_patch2',
        '20170803_patch1',
    ]

    gt_dict = {}
    for rec_name in rec_names:
        # find the raw file, skipping the juxta-cellular dump
        dirname = recording_folder + rec_name + '/'
        mea_filename = None
        for fname in os.listdir(dirname):  # FIX: no longer shadows the file handle below
            if fname.endswith('.raw') and not fname.endswith('juxta.raw'):
                mea_filename = dirname + fname
        # FIX: fail with a clear message instead of a NameError when no file matches
        assert mea_filename is not None, 'no .raw file found for ' + rec_name

        # raw files have an internal offset that depend on the channel count
        # a simple built header can be parsed to get it
        with open(mea_filename.replace('.raw', '.txt'), mode='r') as header_file:
            # FIX: raw string for the regex (plain '\d' is an invalid escape sequence)
            offset = int(re.findall(r'padding = (\d+)', header_file.read())[0])

        # recording
        rec = se.BinDatRecordingExtractor(mea_filename, 20000., 256, 'uint16',
                                          offset=offset, frames_first=True)
        # this reduce channel count to 252
        rec = se.load_probe_file(rec, basedir + 'mea_256.prb')

        # gt sorting built from the juxta-cellular peak indexes
        gt_indexes = np.fromfile(ground_truth_folder + rec_name + '/juxta_peak_indexes.raw',
                                 dtype='int64')
        sorting_gt = se.NumpySortingExtractor()
        sorting_gt.set_times_labels(gt_indexes, np.zeros(gt_indexes.size, dtype='int64'))
        sorting_gt.set_sampling_frequency(20000.0)

        gt_dict[rec_name] = (rec, sorting_gt)

    GroundTruthStudy.setup(study_folder, gt_dict)
def run_sorters(sorter_list, recording_dict_or_list, working_folder, grouping_property=None,
                shared_binary_copy=False, engine=None, engine_kargs=None, debug=False,
                write_log=True):
    """
    Run several sorters on several recordings (simple nested loops).

    Note: engine='multiprocessing' uses the python multiprocessing module.
    This does not allow subprocesses inside subprocesses, so sorters that
    already use multiprocessing internally will fail.

    Parameters
    ----------
    sorter_list: list of str
        List of sorter names.
    recording_dict_or_list: dict or list
        A dict of recordings. The key will be the name of the recording.
        If a list is given, the names will be recording_0, recording_1, ...
    working_folder: str
        The working directory. This must not exist before calling this function.
    grouping_property: str
        The property of grouping given to sorters.
    shared_binary_copy: bool, default False
        Before running each sorter, all recordings are copied inside the
        working_folder with the raw binary format (BinDatRecordingExtractor)
        and new recordings are instantiated as BinDatRecordingExtractor.
        This avoids multiple copies inside each sorter of the same file but
        implies a global copy of all files.
    engine: str or None
        None (same as 'loop') or 'multiprocessing'.
    engine_kargs: dict or None
        kargs specific to the launcher engine:
            * 'loop' : no kargs
            * 'multiprocessing' : {'processes': n} number of processes
              (if None then processes=os.cpu_count())
    debug: bool, default False
        Control sorter verbosity.
    write_log: bool, default True
        Collect per-task run times and write them to 'run_time.csv'.

    Returns
    ----------
    results: dict
        Nested dict[rec_name][sorter_name] of SortingExtractor.
    """
    # FIX: mutable default argument `engine_kargs={}` replaced with None sentinel
    if engine_kargs is None:
        engine_kargs = {}

    assert not os.path.exists(working_folder), 'working_folder already exists, please remove it'
    working_folder = Path(working_folder)

    for sorter_name in sorter_list:
        assert sorter_name in sorter_dict, '{} is not in sorter list'.format(sorter_name)

    # normalize the recordings into a {name: recording} dict
    if isinstance(recording_dict_or_list, list):
        recording_dict = {'recording_{}'.format(i): rec
                          for i, rec in enumerate(recording_dict_or_list)}
    elif isinstance(recording_dict_or_list, dict):
        recording_dict = recording_dict_or_list
    else:
        raise ValueError('bad recording dict')

    if shared_binary_copy:
        os.makedirs(working_folder / 'raw_files')
        old_rec_dict = dict(recording_dict)
        recording_dict = {}
        for rec_name, recording in old_rec_dict.items():
            if grouping_property is not None:
                recording_list = se.get_sub_extractors_by_property(recording, grouping_property)
                n_group = len(recording_list)
                assert n_group == 1, 'shared_binary_copy work only when one group'
                recording = recording_list[0]
                grouping_property = None

            raw_filename = working_folder / 'raw_files' / (rec_name + '.raw')
            prb_filename = working_folder / 'raw_files' / (rec_name + '.prb')

            n_chan = recording.get_num_channels()
            # chunk the copy so each write stays around 2**24 samples in total
            chunksize = 2 ** 24 // n_chan
            sr = recording.get_sampling_frequency()

            # save binary
            se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                                       dtype='float32', chunksize=chunksize)
            # save location (with PRB format)
            se.save_probe_file(recording, prb_filename, format='spyking_circus')
            # make new recording backed by the shared binary copy
            new_rec = se.BinDatRecordingExtractor(raw_filename, sr, n_chan, 'float32',
                                                  frames_first=True)
            se.load_probe_file(new_rec, prb_filename)
            recording_dict[rec_name] = new_rec

    # one task per (recording, sorter) pair
    task_list = []
    for rec_name, recording in recording_dict.items():
        for sorter_name in sorter_list:
            output_folder = working_folder / 'output_folders' / rec_name / sorter_name
            task_list.append((rec_name, recording, sorter_name, output_folder,
                              grouping_property, debug, write_log))

    if engine is None or engine == 'loop':
        # simple loop in main process
        for arg_list in task_list:
            _run_one(arg_list)
    elif engine == 'multiprocessing':
        # use mp.Pool; FIX: the pool is now released via the context manager
        # (the original never called close/join)
        processes = engine_kargs.get('processes', None)
        with multiprocessing.Pool(processes) as pool:
            pool.map(_run_one, task_list)
    else:
        # FIX: unknown engines were silently ignored (tasks never ran)
        raise ValueError('unknown engine: {}'.format(engine))

    if write_log:
        # collect run time and write to csv
        with open(working_folder / 'run_time.csv', mode='w') as f:
            for rec_name, _, sorter_name, output_folder, _, _, _ in task_list:
                log_file = output_folder / 'run_log.txt'
                if os.path.exists(log_file):
                    with open(log_file, mode='r') as logfile:
                        run_time = float(logfile.readline().replace('run_time:', ''))
                    f.write('{}\t{}\t{}\n'.format(rec_name, sorter_name, run_time))

    results = collect_results(working_folder)
    return results