def test_with_BinDatRecordingExtractor(self):
    # some sorters (TDC, KS, KS2, ...) work natively on the raw binary format,
    # which avoids a copy when the recording is already in this format
    recording, sorting_gt = se.example_datasets.toy_example(num_channels=2, duration=10)

    # create a raw dat file and a prb file
    raw_filename = 'raw_file.dat'
    prb_filename = 'raw_file.prb'
    samplerate = recording.get_sampling_frequency()
    traces = recording.get_traces().astype('float32')
    with open(raw_filename, mode='wb') as f:
        f.write(traces.T.tobytes())
    se.save_probe_file(recording, prb_filename, format='spyking_circus')

    recording = se.BinDatRecordingExtractor(raw_filename, samplerate, 2, 'float32',
                                            frames_first=True, offset=0)
    se.load_probe_file(recording, prb_filename)

    params = self.SorterClass.default_params()
    sorter = self.SorterClass(recording=recording, output_folder=None)
    sorter.set_params(**params)
    sorter.run()
    sorting = sorter.get_result()
def get_one_recording(study_folder, rec_name):
    """
    Get one recording from its name.

    Parameters
    ----------
    study_folder: str
        The study folder.
    rec_name: str
        The recording name.

    Returns
    ----------
    recording: RecordingExtractor
        The recording.
    """
    # convert to Path so the '/' path operators below work with a plain str input
    study_folder = Path(study_folder)

    raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
    prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
    json_filename = study_folder / 'raw_files' / (rec_name + '.json')

    with open(json_filename, 'r', encoding='utf8') as f:
        info = json.load(f)

    rec = se.BinDatRecordingExtractor(raw_filename, info['sample_rate'], info['num_chan'],
                                      info['dtype'], time_axis=info['time_axis'])
    load_probe_file_inplace(rec, prb_filename)

    return rec
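# A minimal usage sketch for get_one_recording. Assumptions: 'my_study' is a
# study folder previously populated with raw_files/<rec_name>.dat/.prb/.json,
# and 'rec0' is a hypothetical recording name.
rec = get_one_recording('my_study', 'rec0')
print('channels:', rec.get_num_channels(), '- rate:', rec.get_sampling_frequency())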
def test_with_BinDatRecordingExtractor(self):
    # some sorters (TDC, KS, KS2, ...) work natively on the raw binary format,
    # which avoids a copy when the recording is already in this format
    recording, sorting_gt = se.example_datasets.toy_example(num_channels=2, duration=10, seed=0)

    # create a raw dat file and a prb file
    raw_filename = 'raw_file.dat'
    prb_filename = 'raw_file.prb'
    samplerate = recording.get_sampling_frequency()
    traces = recording.get_traces().astype('float32')
    with open(raw_filename, mode='wb') as f:
        f.write(traces.T.tobytes())
    recording.save_to_probe_file(prb_filename)

    recording = se.BinDatRecordingExtractor(raw_filename, samplerate, 2, 'float32',
                                            time_axis=0, offset=0)
    recording = recording.load_probe_file(prb_filename)

    params = self.SorterClass.default_params()
    sorter = self.SorterClass(recording=recording, output_folder=None)
    sorter.set_params(**params)
    sorter.run()
    sorting = sorter.get_result()

    for unit_id in sorting.get_unit_ids():
        print('unit #', unit_id, 'nb', len(sorting.get_unit_spike_train(unit_id)))

    del sorting
def __init__(self, probe_file, xml_file, nrs_file, dat_file):
    se.RecordingExtractor.__init__(self)
    # info = check_load_nrs(dirpath)
    # assert info is not None
    probe_obj = kp.load_object(probe_file)
    xml_file = kp.load_file(xml_file)
    # nrs_file = kp.load_file(nrs_file)
    dat_file = kp.load_file(dat_file)

    from xml.etree import ElementTree as ET
    xml = ET.parse(xml_file)
    root_element = xml.getroot()
    try:
        txt = root_element.find('acquisitionSystem/samplingRate').text
        assert txt is not None
        self._samplerate = float(txt)
    except Exception:
        raise Exception('Unable to load acquisitionSystem/samplingRate')
    try:
        txt = root_element.find('acquisitionSystem/nChannels').text
        assert txt is not None
        self._nChannels = int(txt)
    except Exception:
        raise Exception('Unable to load acquisitionSystem/nChannels')
    try:
        txt = root_element.find('acquisitionSystem/nBits').text
        assert txt is not None
        self._nBits = int(txt)
    except Exception:
        raise Exception('Unable to load acquisitionSystem/nBits')

    if self._nBits == 16:
        dtype = np.int16
    elif self._nBits == 32:
        dtype = np.int32
    else:
        raise Exception(f'Unexpected nBits: {self._nBits}')

    self._rec = se.BinDatRecordingExtractor(dat_file, sampling_frequency=self._samplerate,
                                            numchan=self._nChannels, dtype=dtype)

    self._channel_ids = probe_obj['channel']
    for ii in range(len(probe_obj['channel'])):
        channel = probe_obj['channel'][ii]
        x = probe_obj['x'][ii]
        y = probe_obj['y'][ii]
        z = probe_obj['z'][ii]
        group = probe_obj.get('group', probe_obj.get('shank'))[ii]
        self.set_channel_property(channel, 'location', [x, y, z])
        self.set_channel_property(channel, 'group', group)
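# A self-contained sketch of the XML layout the parser above expects, derived
# from the find() calls; the root tag name and the values are assumptions:
from xml.etree import ElementTree as ET

example_xml = """
<parameters>
  <acquisitionSystem>
    <samplingRate>20000</samplingRate>
    <nChannels>32</nChannels>
    <nBits>16</nBits>
  </acquisitionSystem>
</parameters>
"""
root = ET.fromstring(example_xml)
print(float(root.find('acquisitionSystem/samplingRate').text))  # -> 20000.0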
def setup_study():
    rec_names = [
        '20160415_patch2',
        '20160426_patch2',
        '20160426_patch3',
        '20170621_patch1',
        '20170713_patch1',
        '20170725_patch1',
        '20170728_patch2',
        '20170803_patch1',
    ]

    gt_dict = {}
    for rec_name in rec_names:
        # find the raw file
        dirname = recording_folder + rec_name + '/'
        for f in os.listdir(dirname):
            if f.endswith('.raw') and not f.endswith('juxta.raw'):
                mea_filename = dirname + f

        # raw files have an internal offset that depends on the channel count;
        # a small text header shipped alongside can be parsed to get it
        with open(mea_filename.replace('.raw', '.txt'), mode='r') as f:
            offset = int(re.findall(r'padding = (\d+)', f.read())[0])

        # recording
        rec = se.BinDatRecordingExtractor(mea_filename, 20000., 256, 'uint16',
                                          offset=offset, frames_first=True)

        # this reduces the channel count to 252
        rec = se.load_probe_file(rec, basedir + 'mea_256.prb')

        # ground-truth sorting
        gt_indexes = np.fromfile(ground_truth_folder + rec_name + '/juxta_peak_indexes.raw',
                                 dtype='int64')
        sorting_gt = se.NumpySortingExtractor()
        sorting_gt.set_times_labels(gt_indexes, np.zeros(gt_indexes.size, dtype='int64'))
        sorting_gt.set_sampling_frequency(20000.0)

        gt_dict[rec_name] = (rec, sorting_gt)

    study = GroundTruthStudy.setup(study_folder, gt_dict)
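# A tiny self-contained check of the header parsing used above; the header
# content here is a hypothetical stand-in for the real .txt files:
import re

header_text = "padding = 1864\n"  # illustrative value
offset = int(re.findall(r'padding = (\d+)', header_text)[0])
assert offset == 1864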
def get_recordings(study_folder):
    """
    Get ground-truth recordings as a dict.

    They are read from the 'raw_files' folder in binary format.

    Parameters
    ----------
    study_folder: str
        The study folder.

    Returns
    ----------
    recording_dict: dict
        Dict of recordings.
    """
    study_folder = Path(study_folder)

    rec_names = get_rec_names(study_folder)
    recording_dict = {}
    for rec_name in rec_names:
        raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
        prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
        json_filename = study_folder / 'raw_files' / (rec_name + '.json')

        with open(json_filename, 'r', encoding='utf8') as f:
            info = json.load(f)

        rec = se.BinDatRecordingExtractor(raw_filename, info['sample_rate'], info['num_chan'],
                                          info['dtype'], frames_first=info['frames_first'])
        se.load_probe_file(rec, prb_filename)

        recording_dict[rec_name] = rec

    return recording_dict
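# A minimal usage sketch for get_recordings, assuming 'my_study' is an existing
# study folder set up beforehand:
recording_dict = get_recordings('my_study')
for rec_name, rec in recording_dict.items():
    print(rec_name, '->', rec.get_num_channels(), 'channels')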
def run_sorters(sorter_list, recording_dict_or_list, working_folder, grouping_property=None,
                shared_binary_copy=False, engine=None, engine_kargs={}, debug=False,
                write_log=True):
    """
    Run several sorters on several recordings.

    This is a simple implementation with nested loops (or multiprocessing).

    Note: engine='multiprocessing' uses the python multiprocessing module.
    This does not allow subprocesses inside subprocesses, so sorters that
    already use multiprocessing internally will fail.

    Parameters
    ----------
    sorter_list: list of str
        List of sorter names.
    recording_dict_or_list: dict or list
        A dict of recordings. The key will be the name of the recording.
        If a list is given then the names will be recording_0, recording_1, ...
    working_folder: str
        The working directory. This must not exist before calling this function.
    grouping_property: str
        The grouping property given to the sorters.
    shared_binary_copy: bool (default False)
        Before running each sorter, all recordings are copied inside the
        working_folder in raw binary format and new recordings are instantiated
        as BinDatRecordingExtractor. This avoids multiple copies of the same
        file inside each sorter but implies a global copy of all files.
    engine: str
        None (= 'loop') or 'multiprocessing'.
    engine_kargs: dict
        Kargs specific to the launcher engine:
            * 'loop' : no kargs
            * 'multiprocessing' : {'processes': int} number of processes;
              if None then processes=os.cpu_count()
    debug: bool (default False)
        Controls sorter verbosity.
    write_log: bool (default True)
        Whether to write run times to a csv log.

    Returns
    ----------
    results: dict
        The output is a nested dict[rec_name][sorter_name] of SortingExtractor.
    """
    assert not os.path.exists(working_folder), 'working_folder already exists, please remove it'
    working_folder = Path(working_folder)

    for sorter_name in sorter_list:
        assert sorter_name in sorter_dict, '{} is not in sorter list'.format(sorter_name)

    if isinstance(recording_dict_or_list, list):
        # in case of list
        recording_dict = {'recording_{}'.format(i): rec
                          for i, rec in enumerate(recording_dict_or_list)}
    elif isinstance(recording_dict_or_list, dict):
        recording_dict = recording_dict_or_list
    else:
        raise ValueError('bad recording dict')

    if shared_binary_copy:
        os.makedirs(working_folder / 'raw_files')
        old_rec_dict = dict(recording_dict)
        recording_dict = {}
        for rec_name, recording in old_rec_dict.items():
            if grouping_property is not None:
                recording_list = se.get_sub_extractors_by_property(recording, grouping_property)
                n_group = len(recording_list)
                assert n_group == 1, 'shared_binary_copy only works when there is one group'
                recording = recording_list[0]
                grouping_property = None

            raw_filename = working_folder / 'raw_files' / (rec_name + '.raw')
            prb_filename = working_folder / 'raw_files' / (rec_name + '.prb')

            n_chan = recording.get_num_channels()
            chunksize = 2**24 // n_chan
            sr = recording.get_sampling_frequency()

            # save binary
            se.write_binary_dat_format(recording, raw_filename, time_axis=0,
                                       dtype='float32', chunksize=chunksize)
            # save locations (with the PRB format)
            se.save_probe_file(recording, prb_filename, format='spyking_circus')

            # make a new recording
            new_rec = se.BinDatRecordingExtractor(raw_filename, sr, n_chan, 'float32',
                                                  frames_first=True)
            se.load_probe_file(new_rec, prb_filename)
            recording_dict[rec_name] = new_rec

    task_list = []
    for rec_name, recording in recording_dict.items():
        for sorter_name in sorter_list:
            output_folder = working_folder / 'output_folders' / rec_name / sorter_name
            task_list.append((rec_name, recording, sorter_name, output_folder,
                              grouping_property, debug, write_log))

    if engine is None or engine == 'loop':
        # simple loop in the main process
        for arg_list in task_list:
            # print(arg_list)
            _run_one(arg_list)
    elif engine == 'multiprocessing':
        # use mp.Pool
        processes = engine_kargs.get('processes', None)
        pool = multiprocessing.Pool(processes)
        pool.map(_run_one, task_list)

    if write_log:
        # collect run times and write them to csv
        with open(working_folder / 'run_time.csv', mode='w') as f:
            for task in task_list:
                rec_name = task[0]
                sorter_name = task[2]
                output_folder = task[3]
                if os.path.exists(output_folder / 'run_log.txt'):
                    with open(output_folder / 'run_log.txt', mode='r') as logfile:
                        run_time = float(logfile.readline().replace('run_time:', ''))
                    txt = '{}\t{}\t{}\n'.format(rec_name, sorter_name, run_time)
                    f.write(txt)

    results = collect_results(working_folder)
    return results