Example #1
    def test_with_BinDatRecordingExtractor(self):
        # some sorters (TDC, KS, KS2, ...) work by default with the raw binary
        # format as input, to avoid a copy when the recording is already in this format

        recording, sorting_gt = se.example_datasets.toy_example(num_channels=2,
                                                                duration=10)

        # create a raw dat file and prb file
        raw_filename = 'raw_file.dat'
        prb_filename = 'raw_file.prb'

        samplerate = recording.get_sampling_frequency()
        traces = recording.get_traces().astype('float32')
        with open(raw_filename, mode='wb') as f:
            f.write(traces.T.tobytes())

        se.save_probe_file(recording, prb_filename, format='spyking_circus')

        recording = se.BinDatRecordingExtractor(raw_filename,
                                                samplerate,
                                                2,
                                                'float32',
                                                frames_first=True,
                                                offset=0)
        se.load_probe_file(recording, prb_filename)

        params = self.SorterClass.default_params()
        sorter = self.SorterClass(recording=recording, output_folder=None)
        sorter.set_params(**params)
        sorter.run()
        sorting = sorter.get_result()
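A quick way to sanity-check the binary file written above is to map it back with NumPy. A minimal sketch, assuming the raw_file.dat and traces from this example: the file holds frames-first float32 samples, so it maps to a (num_frames, num_channels) array.

import numpy as np

# minimal round-trip check; raw_file.dat and traces come from the example above
num_frames = traces.shape[1]  # get_traces() returns (channels, frames)
data = np.memmap('raw_file.dat', dtype='float32', mode='r', shape=(num_frames, 2))
assert np.allclose(data.T, traces)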
Example #2
def get_one_recording(study_folder, rec_name):
    """
    Get one recording from its name

    Parameters
    ----------
    study_folder: str
        The study folder.
    rec_name: str
        The recording name
    Returns
    ----------

    recording: RecordingExtractor
        The recording.
    
    """
    study_folder = Path(study_folder)
    raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
    prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
    json_filename = study_folder / 'raw_files' / (rec_name + '.json')
    with open(json_filename, 'r', encoding='utf8') as f:
        info = json.load(f)
    rec = se.BinDatRecordingExtractor(raw_filename,
                                      info['sample_rate'],
                                      info['num_chan'],
                                      info['dtype'],
                                      time_axis=info['time_axis'])
    load_probe_file_inplace(rec, prb_filename)

    return rec
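The function above assumes a small JSON sidecar next to each .dat file. A hypothetical sidecar with the four keys read by the code (the values are illustrative only):

import json

# hypothetical sidecar content; only the key names are taken from the code above
info = {'sample_rate': 20000.0, 'num_chan': 4, 'dtype': 'float32', 'time_axis': 0}
with open('raw_files/my_rec.json', 'w', encoding='utf8') as f:
    json.dump(info, f, indent=4)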
Example #3
    def test_with_BinDatRecordingExtractor(self):
        # some sorters (TDC, KS, KS2, ...) work by default with the raw binary
        # format as input, to avoid a copy when the recording is already in this format

        recording, sorting_gt = se.example_datasets.toy_example(num_channels=2, duration=10, seed=0)

        # create a raw dat file and prb file
        raw_filename = 'raw_file.dat'
        prb_filename = 'raw_file.prb'

        samplerate = recording.get_sampling_frequency()
        traces = recording.get_traces().astype('float32')
        with open(raw_filename, mode='wb') as f:
            f.write(traces.T.tobytes())

        recording.save_to_probe_file(prb_filename)
        recording = se.BinDatRecordingExtractor(raw_filename, samplerate, 2, 'float32', time_axis=0, offset=0)
        recording = recording.load_probe_file(prb_filename)

        params = self.SorterClass.default_params()
        sorter = self.SorterClass(recording=recording, output_folder=None)
        sorter.set_params(**params)
        sorter.run()
        sorting = sorter.get_result()

        for unit_id in sorting.get_unit_ids():
            print('unit #', unit_id, 'nb', len(sorting.get_unit_spike_train(unit_id)))
        del sorting
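Writing traces.T.tobytes() and reading back with time_axis=0 are two sides of the same memory layout. A pure-NumPy sketch of why they match (the shapes are illustrative):

import numpy as np

# 3 channels x 5 frames, as get_traces() returns (channels, frames)
traces = np.arange(15, dtype='float32').reshape(3, 5)

# traces.T.tobytes() stores the samples frame by frame, i.e. time_axis=0
raw = np.frombuffer(traces.T.tobytes(), dtype='float32').reshape(5, 3)
assert np.array_equal(raw.T, traces)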
Example #4
    def __init__(self, probe_file, xml_file, nrs_file, dat_file):
        se.RecordingExtractor.__init__(self)
        # info = check_load_nrs(dirpath)
        # assert info is not None
        probe_obj = kp.load_object(probe_file)
        xml_file = kp.load_file(xml_file)
        # nrs_file = kp.load_file(nrs_file)
        dat_file = kp.load_file(dat_file)

        from xml.etree import ElementTree as ET
        xml = ET.parse(xml_file)
        root_element = xml.getroot()
        try:
            txt = root_element.find('acquisitionSystem/samplingRate').text
            assert txt is not None
            self._samplerate = float(txt)
        except Exception:
            raise Exception('Unable to load acquisitionSystem/samplingRate')
        try:
            txt = root_element.find('acquisitionSystem/nChannels').text
            assert txt is not None
            self._nChannels = int(txt)
        except Exception:
            raise Exception('Unable to load acquisitionSystem/nChannels')
        try:
            txt = root_element.find('acquisitionSystem/nBits').text
            assert txt is not None
            self._nBits = int(txt)
        except Exception:
            raise Exception('Unable to load acquisitionSystem/nBits')

        if self._nBits == 16:
            dtype = np.int16
        elif self._nBits == 32:
            dtype = np.int32
        else:
            raise Exception(f'Unexpected nBits: {self._nBits}')

        self._rec = se.BinDatRecordingExtractor(
            dat_file,
            sampling_frequency=self._samplerate,
            numchan=self._nChannels,
            dtype=dtype)

        self._channel_ids = probe_obj['channel']
        for ii in range(len(probe_obj['channel'])):
            channel = probe_obj['channel'][ii]
            x = probe_obj['x'][ii]
            y = probe_obj['y'][ii]
            z = probe_obj['z'][ii]
            group = probe_obj.get('group', probe_obj.get('shank'))[ii]
            self.set_channel_property(channel, 'location', [x, y, z])
            self.set_channel_property(channel, 'group', group)
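The properties stored in __init__ can be read back through the standard spikeextractors accessors. A hypothetical usage sketch: the class name NrsRecordingExtractor and the four file arguments are assumptions, and the class is assumed to also implement get_channel_ids().

# hypothetical usage; the class name and file arguments are not in the excerpt
rec = NrsRecordingExtractor(probe_file, xml_file, nrs_file, dat_file)
for ch in rec.get_channel_ids():
    location = rec.get_channel_property(ch, 'location')
    group = rec.get_channel_property(ch, 'group')
    print(ch, location, group)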
Example #5
def setup_study():
    rec_names = [
        '20160415_patch2',
        '20160426_patch2',
        '20160426_patch3',
        '20170621_patch1',
        '20170713_patch1',
        '20170725_patch1',
        '20170728_patch2',
        '20170803_patch1',
    ]

    gt_dict = {}
    for rec_name in rec_names:

        # find raw file
        dirname = recording_folder + rec_name + '/'
        for f in os.listdir(dirname):
            if f.endswith('.raw') and not f.endswith('juxta.raw'):
                mea_filename = dirname + f

        # raw files have an internal offset that depends on the channel count;
        # a simple text header can be parsed to get it
        with open(mea_filename.replace('.raw', '.txt'), mode='r') as f:
            offset = int(re.findall(r'padding = (\d+)', f.read())[0])

        # recording
        rec = se.BinDatRecordingExtractor(mea_filename,
                                          20000.,
                                          256,
                                          'uint16',
                                          offset=offset,
                                          frames_first=True)

        # this reduces the channel count to 252
        rec = se.load_probe_file(rec, basedir + 'mea_256.prb')

        # ground-truth sorting
        gt_indexes = np.fromfile(ground_truth_folder + rec_name +
                                 '/juxta_peak_indexes.raw',
                                 dtype='int64')
        sorting_gt = se.NumpySortingExtractor()
        sorting_gt.set_times_labels(gt_indexes,
                                    np.zeros(gt_indexes.size, dtype='int64'))
        sorting_gt.set_sampling_frequency(20000.0)

        gt_dict[rec_name] = (rec, sorting_gt)

    study = GroundTruthStudy.setup(study_folder, gt_dict)
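mea_256.prb is assumed to be a standard PRB file, i.e. a Python file that defines a channel_groups dict. A minimal sketch of the format (a real file would list the 252 kept channels with their geometry):

# sketch of a PRB file; channel indices and coordinates are illustrative only
channel_groups = {
    0: {
        'channels': [0, 1, 2],
        'geometry': {0: (0.0, 0.0), 1: (0.0, 20.0), 2: (0.0, 40.0)},
    }
}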
Example #6
def get_recordings(study_folder):
    """
    Get ground recording as a dict.
    
    They are read from the 'raw_files' folder with binary format.
    
    Parameters
    ----------
    study_folder: str
        The study folder.
    
    Returns
    ----------
    
    recording_dict: dict
        Dict of rexording.
        
    """
    study_folder = Path(study_folder)

    rec_names = get_rec_names(study_folder)
    recording_dict = {}
    for rec_name in rec_names:
        raw_filename = study_folder / 'raw_files' / (rec_name + '.dat')
        prb_filename = study_folder / 'raw_files' / (rec_name + '.prb')
        json_filename = study_folder / 'raw_files' / (rec_name + '.json')
        with open(json_filename, 'r', encoding='utf8') as f:
            info = json.load(f)

        rec = se.BinDatRecordingExtractor(raw_filename,
                                          info['sample_rate'],
                                          info['num_chan'],
                                          info['dtype'],
                                          frames_first=info['frames_first'])
        se.load_probe_file(rec, prb_filename)

        recording_dict[rec_name] = rec

    return recording_dict
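get_rec_names() is not shown in this excerpt. A plausible sketch consistent with the folder layout above (one .dat file per recording under 'raw_files'):

from pathlib import Path

# hypothetical helper; infers recording names from the .dat file names
def get_rec_names(study_folder):
    study_folder = Path(study_folder)
    return sorted(p.stem for p in (study_folder / 'raw_files').glob('*.dat'))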
Example #7
def run_sorters(sorter_list,
                recording_dict_or_list,
                working_folder,
                grouping_property=None,
                shared_binary_copy=False,
                engine=None,
                engine_kargs={},
                debug=False,
                write_log=True):
    """
    This run several sorter on several recording.
    Simple implementation will nested loops.

    Need to be done with multiprocessing.

    sorter_list: list of str (sorter names)
    recording_dict_or_list: a dict (or a list) of recording
    working_folder : str

    engine = None ( = 'loop') or 'multiprocessing'
    processes = only if 'multiprocessing' if None then processes=os.cpu_count()
    debug=True/False to control sorter verbosity


    Note: engine='multiprocessing' use the python multiprocessing module.
    This do not allow to have subprocess in subprocess.
    So sorter that already use internally multiprocessing, this will fail.

    Parameters
    ----------
    
    sorter_list: list of str
        List of sorter name.
    
    recording_dict_or_list: dict or list
        A dict of recording. The key will be the name of the recording.
        In a list is given then the name will be recording_0, recording_1, ...
    
    working_folder: str
        The working directory.
        This must not exists before calling this function.
    
    grouping_property: str
        The property of grouping given to sorters.
    
    shared_binary_copy: False default
        Before running each sorter, all recording are copied inside 
        the working_folder with the raw binary format (BinDatRecordingExtractor)
        and new recording are instantiated as BinDatRecordingExtractor.
        This avoids multiple copy inside each sorter of the same file but
        imply a global of all files.

    engine: str
        'loop' or 'multiprocessing'
    
    engine_kargs: dict
        This contains kargs specific to the launcher engine:
            * 'loop' : no kargs
            * 'multiprocessing' : {'processes' : } number of processes
    
    debug: bool
        default True
    
    write_log: bool
        default True
    
    Output
    ----------
    
    results : dict
        The output is nested dict[rec_name][sorter_name] of SortingExtrator.



    """

    assert not os.path.exists(
        working_folder), 'working_folder already exists, please remove it'
    working_folder = Path(working_folder)

    for sorter_name in sorter_list:
        assert sorter_name in sorter_dict, '{} is not in sorter list'.format(
            sorter_name)

    if isinstance(recording_dict_or_list, list):
        # in case of list
        recording_dict = {
            'recording_{}'.format(i): rec
            for i, rec in enumerate(recording_dict_or_list)
        }
    elif isinstance(recording_dict_or_list, dict):
        recording_dict = recording_dict_or_list
    else:
        raise ValueError('recording_dict_or_list must be a dict or a list')

    if shared_binary_copy:
        os.makedirs(working_folder / 'raw_files')
        old_rec_dict = dict(recording_dict)
        recording_dict = {}
        for rec_name, recording in old_rec_dict.items():
            if grouping_property is not None:
                recording_list = se.get_sub_extractors_by_property(
                    recording, grouping_property)
                n_group = len(recording_list)
                assert n_group == 1, 'shared_binary_copy works only with a single group'
                recording = recording_list[0]
                grouping_property = None

            raw_filename = working_folder / 'raw_files' / (rec_name + '.raw')
            prb_filename = working_folder / 'raw_files' / (rec_name + '.prb')
            n_chan = recording.get_num_channels()
            chunksize = 2**24 // n_chan
            sr = recording.get_sampling_frequency()

            # save binary
            se.write_binary_dat_format(recording,
                                       raw_filename,
                                       time_axis=0,
                                       dtype='float32',
                                       chunksize=chunksize)
            # save location (with PRB format)
            se.save_probe_file(recording,
                               prb_filename,
                               format='spyking_circus')

            # instantiate a new recording from the binary copy
            new_rec = se.BinDatRecordingExtractor(raw_filename,
                                                  sr,
                                                  n_chan,
                                                  'float32',
                                                  frames_first=True)
            se.load_probe_file(new_rec, prb_filename)
            recording_dict[rec_name] = new_rec

    task_list = []
    for rec_name, recording in recording_dict.items():
        for sorter_name in sorter_list:
            output_folder = working_folder / 'output_folders' / rec_name / sorter_name
            task_list.append((rec_name, recording, sorter_name, output_folder,
                              grouping_property, debug, write_log))

    if engine is None or engine == 'loop':
        # simple loop in main process
        for arg_list in task_list:
            # print(arg_list)
            _run_one(arg_list)

    elif engine == 'multiprocessing':
        # use mp.Pool
        processes = engine_kargs.get('processes', None)
        pool = multiprocessing.Pool(processes)
        pool.map(_run_one, task_list)
        pool.close()
        pool.join()

    if write_log:
        # collect run times and write to csv
        with open(working_folder / 'run_time.csv', mode='w') as f:
            for task in task_list:
                rec_name = task[0]
                sorter_name = task[2]
                output_folder = task[3]
                if os.path.exists(output_folder / 'run_log.txt'):
                    with open(output_folder / 'run_log.txt',
                              mode='r') as logfile:
                        run_time = float(logfile.readline().replace(
                            'run_time:', ''))

                    txt = '{}\t{}\t{}\n'.format(rec_name, sorter_name,
                                                run_time)
                    f.write(txt)

    results = collect_results(working_folder)
    return results
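A hypothetical call, assuming two recordings already loaded and the named sorters installed locally; the result is indexed as documented above:

# sorter names and recording variables are assumptions for illustration
recording_dict = {'rec0': recording0, 'rec1': recording1}
results = run_sorters(['tridesclous', 'spykingcircus'],
                      recording_dict,
                      working_folder='sorting_output',
                      engine='loop')
sorting = results['rec0']['tridesclous']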