Example #1
def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index,
                                  peak_shifts, return_scaled):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    if isinstance(sorting, dict):
        from spikeinterface.core import load_extractor
        sorting = load_extractor(sorting)
    worker_ctx['recording'] = recording
    worker_ctx['sorting'] = sorting
    worker_ctx['return_scaled'] = return_scaled
    all_spikes = sorting.get_all_spike_trains()
    for segment_index in range(recording.get_num_segments()):
        spike_times, spike_labels = all_spikes[segment_index]
        for unit_id in sorting.unit_ids:
            if peak_shifts[unit_id] != 0:
                mask = spike_labels == unit_id
                spike_times[mask] += peak_shifts[unit_id]
        # re-sort, otherwise the chunk processing and np.searchsorted will not work
        order = np.argsort(spike_times)
        all_spikes[segment_index] = spike_times[order], spike_labels[order]
    worker_ctx['all_spikes'] = all_spikes
    worker_ctx['extremum_channels_index'] = extremum_channels_index
    return worker_ctx
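This init-per-worker pattern (recording shipped as a serializable dict, rebuilt once in each process) repeats across the examples below. Here is a minimal stand-alone sketch of the idea using only the standard library; the executor and names are illustrative, not the actual spikeinterface chunk-processing machinery:

from multiprocessing import Pool

_worker_ctx = None  # one private context per worker process

def _init_worker(recording_dict):
    # runs once in each worker, like the _init_worker_* functions here
    global _worker_ctx
    # a real worker would call load_extractor(recording_dict) at this point
    _worker_ctx = {'recording': recording_dict}

def _process_chunk(chunk_index):
    # per-chunk function reading from the worker-local context
    return chunk_index, _worker_ctx['recording']['name']

if __name__ == '__main__':
    rec_as_dict = {'name': 'toy_recording'}  # stand-in for recording.to_dict()
    with Pool(2, initializer=_init_worker, initargs=(rec_as_dict,)) as pool:
        print(pool.map(_process_chunk, range(4)))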
Example #2
def _init_worker_waveform_extractor(recording, sorting, wfs_memmap,
                                    selected_spikes, selected_spike_times,
                                    nbefore, nafter):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    worker_ctx['recording'] = recording

    if isinstance(sorting, dict):
        from spikeinterface.core import load_extractor
        sorting = load_extractor(sorting)
    worker_ctx['sorting'] = sorting

    worker_ctx['wfs_memmap'] = wfs_memmap
    worker_ctx['selected_spikes'] = selected_spikes
    worker_ctx['selected_spike_times'] = selected_spike_times
    worker_ctx['nbefore'] = nbefore
    worker_ctx['nafter'] = nafter

    num_seg = sorting.get_num_segments()
    unit_cum_sum = {}
    for unit_id in sorting.unit_ids:
        # spikes per segment
        n_per_segment = [
            selected_spikes[unit_id][i].size for i in range(num_seg)
        ]
        cum_sum = [0] + np.cumsum(n_per_segment).tolist()
        unit_cum_sum[unit_id] = cum_sum
    worker_ctx['unit_cum_sum'] = unit_cum_sum

    return worker_ctx
Example #3
    def load_from_folder(cls, folder):
        folder = Path(folder)
        assert folder.is_dir(), f'This folder does not exist: {folder}'
        recording = load_extractor(folder / 'recording.json')
        sorting = load_extractor(folder / 'sorting.json')
        we = cls(recording, sorting, folder)

        for mode in _possible_template_modes:
            # load cached templates
            template_file = folder / f'templates_{mode}.npy'
            if template_file.is_file():
                we._template_cache[mode] = np.load(template_file)

        return we
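The cache loaded above lives as plain templates_<mode>.npy files next to recording.json and sorting.json. A hedged sketch of reading such a folder directly; the mode names stand in for _possible_template_modes and the array shape is the usual one, both assumptions here:

import numpy as np
from pathlib import Path

folder = Path('waveforms_folder')  # hypothetical WaveformExtractor folder
for mode in ('average', 'median', 'std'):  # assumed contents of _possible_template_modes
    template_file = folder / f'templates_{mode}.npy'
    if template_file.is_file():
        templates = np.load(template_file)
        print(mode, templates.shape)  # typically (num_units, num_samples, num_channels)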
Example #4
    def get_result_from_folder(cls, output_folder):
        output_folder = Path(output_folder)
        # check errors in log file
        log_file = output_folder / 'spikeinterface_log.json'
        if not log_file.is_file():
            raise SpikeSortingError(
                'get result error: the folder does not contain spikeinterface_log.json'
            )

        with log_file.open('r', encoding='utf8') as f:
            log = json.load(f)

        if bool(log['error']):
            raise SpikeSortingError(
                "Spike sorting failed. You can inspect the runtime trace in spikeinterface_log.json"
            )

        sorting = cls._get_result_from_folder(output_folder)

        recording = load_extractor(output_folder /
                                   'spikeinterface_recording.json')
        if recording is not None:
            # can be None when not dumpable
            sorting.register_recording(recording)
        return sorting
Example #5
def get_recordings(study_folder):
    """
    Get recordings as a dict.

    They are read from the 'raw_files' folder with binary format.

    Parameters
    ----------
    study_folder: str
        The study folder.

    Returns
    -------
    recording_dict: dict
        Dict of recordings.
    """
    study_folder = Path(study_folder)

    rec_names = get_rec_names(study_folder)
    recording_dict = {}
    for rec_name in rec_names:
        rec = load_extractor(study_folder / 'raw_files' / rec_name)
        recording_dict[rec_name] = rec

    return recording_dict
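A hypothetical usage of this helper, assuming the study folder was created beforehand:

recording_dict = get_recordings('my_study_folder')  # folder name is illustrative
for rec_name, rec in recording_dict.items():
    print(rec_name, rec.get_num_channels(), rec.get_sampling_frequency())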
Example #6
    def _run_from_folder(cls, output_folder, params, verbose):
        import mountainsort4

        recording = load_extractor(output_folder /
                                   'spikeinterface_recording.json')

        # alias to params
        p = params

        samplerate = recording.get_sampling_frequency()

        # Bandpass filter
        if (p['filter'] and p['freq_min'] is not None
                and p['freq_max'] is not None):
            if verbose:
                print('filtering')
            recording = bandpass_filter(recording=recording,
                                        freq_min=p['freq_min'],
                                        freq_max=p['freq_max'])

        # Whiten
        if p['whiten']:
            if verbose:
                print('whitening')
            recording = whiten(recording=recording)

        print(
            'Mountainsort4 uses the OLD spikeextractors API, mapped with RecordingExtractorOldAPI'
        )
        old_api_recording = RecordingExtractorOldAPI(recording)

        # checking the location is no longer needed; it is done in basesorter
        old_api_sorting = mountainsort4.mountainsort4(
            recording=old_api_recording,
            detect_sign=p['detect_sign'],
            adjacency_radius=p['adjacency_radius'],
            clip_size=p['clip_size'],
            detect_threshold=p['detect_threshold'],
            detect_interval=p['detect_interval'],
            num_workers=p['num_workers'],
            verbose=verbose)

        # Curate
        if p['noise_overlap_threshold'] is not None and p['curation'] is True:
            if verbose:
                print('Curating')
            old_api_sorting = mountainsort4.mountainsort4_curation(
                recording=old_api_recording,
                sorting=old_api_sorting,
                noise_overlap_threshold=p['noise_overlap_threshold'])

        # convert sorting to new API and save it
        unit_ids = old_api_sorting.get_unit_ids()
        units_dict_list = [{
            u: old_api_sorting.get_unit_spike_train(u)
            for u in unit_ids
        }]
        new_api_sorting = NumpySorting.from_dict(units_dict_list, samplerate)
        NpzSortingExtractor.write_sorting(new_api_sorting,
                                          str(output_folder / 'firings.npz'))
Example #7
    def _run_from_folder(cls, output_folder, params, verbose):
        recording = load_extractor(output_folder / 'spikeinterface_recording.json')

        assert isinstance(recording, BinaryRecordingExtractor)
        assert recording.get_num_segments() == 1
        dat_path = recording._kwargs['file_paths'][0]
        print('dat_path', dat_path)

        num_chans = recording.get_num_channels()
        locations = recording.get_channel_locations()
        print(locations)
        print(type(locations))

        # ks_probe is not a probeinterface Probe at all
        ks_probe = Bunch()
        ks_probe.NchanTOT = num_chans
        ks_probe.chanMap = np.arange(num_chans)
        ks_probe.kcoords = np.ones(num_chans)
        ks_probe.xc = locations[:, 0]
        ks_probe.yc = locations[:, 1]

        run(
            dat_path,
            params=params,
            probe=ks_probe,
            dir_path=output_folder,
            n_channels=num_chans,
            dtype=recording.get_dtype(),
            sample_rate=recording.get_sampling_frequency(),
        )
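Bunch above (imported from the sorter's utilities) is an attribute-style dict, not a probeinterface Probe. A minimal stand-in with the behavior this snippet relies on, for illustration only:

class Bunch(dict):
    """Dict whose keys can also be read and written as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value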
Example #8
def get_ground_truths(study_folder):
    """
    Get ground truth sorting extractor as a dict.

    They are read from the 'ground_truth' folder with npz format.

    Parameters
    ----------
    study_folder: str
        The study folder.

    Returns
    -------
    ground_truths: dict
        Dict of sorting_gt.

    """
    study_folder = Path(study_folder)
    rec_names = get_rec_names(study_folder)
    ground_truths = {}
    for rec_name in rec_names:
        sorting = load_extractor(study_folder / 'ground_truth' / rec_name)
        ground_truths[rec_name] = sorting
    return ground_truths
Example #9
def _run_one(arg_list):
    # the multiprocessing module forces a single tuple argument
    sorter_name, recording, output_folder, verbose, sorter_params, docker_image, with_output = arg_list

    if isinstance(recording, dict):
        recording = load_extractor(recording)

    # because this is checked in run_sorters before this call
    remove_existing_folder = False
    # results are retrieved later
    delete_output_folder = False
    # because we don't want the loop/worker to break
    raise_error = False

    if docker_image is None:
        run_sorter_local(sorter_name, recording, output_folder=output_folder,
                         remove_existing_folder=remove_existing_folder, delete_output_folder=delete_output_folder,
                         verbose=verbose, raise_error=raise_error, with_output=with_output, **sorter_params)
    else:
        run_sorter_docker(sorter_name, recording, docker_image, output_folder=output_folder,
                          remove_existing_folder=remove_existing_folder, delete_output_folder=delete_output_folder,
                          verbose=verbose, raise_error=raise_error, with_output=with_output, **sorter_params)
Example #10
def _init_worker_localize_peaks(recording, peaks, method, method_kwargs, nbefore, nafter, contact_locations, margin):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    worker_ctx['recording'] = recording
    worker_ctx['peaks'] = peaks
    worker_ctx['method'] = method
    worker_ctx['method_kwargs'] = method_kwargs
    worker_ctx['nbefore'] = nbefore
    worker_ctx['nafter'] = nafter

    worker_ctx['contact_locations'] = contact_locations
    worker_ctx['margin'] = margin

    if method in ('center_of_mass', 'monopolar_triangulation'):
        # handle sparsity
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < method_kwargs['local_radius_um']
        worker_ctx['neighbours_mask'] = neighbours_mask

    return worker_ctx
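The sparsity handling above is a thresholded pairwise-distance matrix. A self-contained numpy sketch of the same computation; the locations and radius are made up, and get_channel_distances is assumed to return pairwise Euclidean distances between contact positions:

import numpy as np

locations = np.array([[0., 0.], [0., 20.], [0., 40.], [30., 0.]])  # illustrative contact positions (um)
# pairwise Euclidean distances, the same quantity get_channel_distances(recording) provides
channel_distance = np.linalg.norm(locations[:, None, :] - locations[None, :, :], axis=2)
local_radius_um = 25.
neighbours_mask = channel_distance < local_radius_um  # (num_chans, num_chans) bool
print(neighbours_mask.sum(axis=1))  # neighbours per channel, each channel counts itself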
Example #11
def _init_memory_worker(recording, arrays, shm_names, shapes, dtype):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        worker_ctx['recording'] = load_extractor(recording)
    else:
        worker_ctx['recording'] = recording

    worker_ctx['dtype'] = np.dtype(dtype)

    if arrays is None:
        # re-create the arrays from the shared memory names
        from multiprocessing.shared_memory import SharedMemory
        arrays = []
        # keep shm alive
        worker_ctx['shms'] = []
        for i in range(len(shm_names)):
            shm = SharedMemory(shm_names[i])
            worker_ctx['shms'].append(shm)
            arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf)
            arrays.append(arr)

    worker_ctx['arrays'] = arrays

    return worker_ctx
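A self-contained round trip of the shared-memory handoff used above: the parent creates named shared memory, and a worker reattaches by name, keeping the SharedMemory object alive as the comment above insists. Standard library only; sizes are illustrative:

import numpy as np
from multiprocessing.shared_memory import SharedMemory

shape, dtype = (100, 4), np.dtype('float32')
shm = SharedMemory(create=True, size=int(np.prod(shape)) * dtype.itemsize)
parent_arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
parent_arr[:] = 1.

# worker side: reattach by name; the SharedMemory object must stay referenced
# or the buffer underneath the array can be released
worker_shm = SharedMemory(shm.name)
worker_arr = np.ndarray(shape, dtype=dtype, buffer=worker_shm.buf)
assert worker_arr[0, 0] == 1.

worker_shm.close()
shm.close()
shm.unlink()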
Example #12
def test_BaseSorting():
    num_seg = 2
    file_path = 'test_BaseSorting.npz'

    create_sorting_npz(num_seg, file_path)

    sorting = NpzSortingExtractor(file_path)
    print(sorting)

    assert sorting.get_num_segments() == 2
    assert sorting.get_num_units() == 3

    # annotations / properties
    sorting.annotate(yep='yop')
    assert sorting.get_annotation('yep') == 'yop'

    sorting.set_property('amplitude', [-20, -40., -55.5])
    values = sorting.get_property('amplitude')
    assert np.all(values == [-20, -40., -55.5])

    # dump/load dict
    d = sorting.to_dict()
    sorting2 = BaseExtractor.from_dict(d)
    sorting3 = load_extractor(d)

    # dump/load json
    sorting.dump_to_json('test_BaseSorting.json')
    sorting2 = BaseExtractor.load('test_BaseSorting.json')
    sorting3 = load_extractor('test_BaseSorting.json')

    # dump/load pickle
    sorting.dump_to_pickle('test_BaseSorting.pkl')
    sorting2 = BaseExtractor.load('test_BaseSorting.pkl')
    sorting3 = load_extractor('test_BaseSorting.pkl')

    # cache
    folder = Path('./my_cache_folder') / 'simple_sorting'
    sorting.save(folder=folder)
    sorting2 = BaseExtractor.load_from_folder(folder)
    # but also possible
    sorting3 = BaseExtractor.load(folder)

    spikes = sorting.get_all_spike_trains()
    # print(spikes)

    spikes = sorting.to_spike_vector()
Example #13
def _init_worker_find_spike(recording, method, method_kwargs):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    worker_ctx['recording'] = recording
    worker_ctx['method'] = method
    worker_ctx['method_kwargs'] = method_kwargs
    return worker_ctx
Example #14
def _init_binary_worker(recording, rec_memmaps, dtype):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        worker_ctx['recording'] = load_extractor(recording)
    else:
        worker_ctx['recording'] = recording

    worker_ctx['rec_memmaps'] = rec_memmaps
    worker_ctx['dtype'] = np.dtype(dtype)

    return worker_ctx
Example #15
def _init_worker_detect_peaks(recording, method, peak_sign, abs_threholds, n_shifts, neighbours_mask):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    worker_ctx['recording'] = recording
    worker_ctx['method'] = method
    worker_ctx['peak_sign'] = peak_sign
    worker_ctx['abs_threholds'] = abs_threholds
    worker_ctx['n_shifts'] = n_shifts
    worker_ctx['neighbours_mask'] = neighbours_mask
    return worker_ctx
Example #16
def _init_worker_localize_peaks(recording, peaks, method, nbefore, nafter, neighbours_mask, contact_locations):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    worker_ctx['recording'] = recording
    worker_ctx['peaks'] = peaks
    worker_ctx['method'] = method
    worker_ctx['nbefore'] = nbefore
    worker_ctx['nafter'] = nafter
    worker_ctx['neighbours_mask'] = neighbours_mask
    worker_ctx['contact_locations'] = contact_locations

    return worker_ctx
Example #17
def _init_work_all_pc_extractor(recording, all_pcs, spike_times, spike_labels, nbefore, nafter, unit_channels, all_pca):
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    worker_ctx['recording'] = recording
    worker_ctx['all_pcs'] = all_pcs
    worker_ctx['spike_times'] = spike_times
    worker_ctx['spike_labels'] = spike_labels
    worker_ctx['nbefore'] = nbefore
    worker_ctx['nafter'] = nafter
    worker_ctx['unit_channels'] = unit_channels
    worker_ctx['all_pca'] = all_pca

    return worker_ctx
Example #18
def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index,
                                  peak_shifts, return_scaled):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    if isinstance(sorting, dict):
        from spikeinterface.core import load_extractor
        sorting = load_extractor(sorting)

    worker_ctx['recording'] = recording
    worker_ctx['sorting'] = sorting
    worker_ctx['return_scaled'] = return_scaled
    worker_ctx['peak_shifts'] = peak_shifts
    worker_ctx['min_shift'] = np.min(peak_shifts)
    worker_ctx['max_shifts'] = np.max(peak_shifts)

    all_spikes = sorting.get_all_spike_trains(outputs='unit_index')

    worker_ctx['all_spikes'] = all_spikes
    worker_ctx['extremum_channels_index'] = extremum_channels_index

    return worker_ctx
Example #19
def _init_worker_unit_amplitudes(recording, sorting, extremum_channels_index,
                                 peak_shifts):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    if isinstance(sorting, dict):
        from spikeinterface.core import load_extractor
        sorting = load_extractor(sorting)
    worker_ctx['recording'] = recording
    worker_ctx['sorting'] = sorting
    all_spikes = sorting.get_all_spike_trains()
    # apply peak shift
    for unit_id in sorting.unit_ids:
        if peak_shifts[unit_id] != 0:
            for segment_index in range(recording.get_num_segments()):
                spike_times, spike_labels = all_spikes[segment_index]
                mask = spike_labels == unit_id
                spike_times[mask] += peak_shifts[unit_id]
                all_spikes[segment_index] = spike_times, spike_labels
    worker_ctx['all_spikes'] = all_spikes
    worker_ctx['extremum_channels_index'] = extremum_channels_index
    return worker_ctx
Example #20
def _init_work_all_pc_extractor(recording, all_pcs_args, spike_times, spike_labels, nbefore, nafter, unit_channels,
                                pca_model):
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    worker_ctx['recording'] = recording
    worker_ctx['all_pcs'] = np.lib.format.open_memmap(**all_pcs_args)
    worker_ctx['spike_times'] = spike_times
    worker_ctx['spike_labels'] = spike_labels
    worker_ctx['nbefore'] = nbefore
    worker_ctx['nafter'] = nafter
    worker_ctx['unit_channels'] = unit_channels
    worker_ctx['pca_model'] = pca_model

    return worker_ctx
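The open_memmap(**all_pcs_args) trick above lets the parent create one .npy file and ship only the kwargs, so every worker re-opens the same file instead of pickling a large array. A minimal sketch; the file name and shape are illustrative:

import numpy as np

all_pcs_args = dict(filename='all_pcs.npy', mode='w+',
                    dtype='float32', shape=(1000, 3, 8))
parent_pcs = np.lib.format.open_memmap(**all_pcs_args)  # creates the file

# worker side: same kwargs but 'r+', so all workers write into one shared file
worker_pcs = np.lib.format.open_memmap(**dict(all_pcs_args, mode='r+'))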
Example #21
    def load_from_folder(folder_path):
        folder_path = Path(folder_path)
        with (folder_path / 'kwargs.json').open() as f:
            kwargs = json.load(f)
        with (folder_path / 'sortings.json').open() as f:
            dict_sortings = json.load(f)
        name_list = list(dict_sortings.keys())
        sorting_list = [load_extractor(v) for v in dict_sortings.values()]
        mcmp = MultiSortingComparison(sorting_list=sorting_list,
                                      name_list=name_list,
                                      do_matching=False, **kwargs)
        mcmp.graph = nx.read_gpickle(
            str(folder_path / 'multicomparison.gpickle'))
        # do steps 3 and 4
        mcmp._clean_graph()
        mcmp._do_agreement()
        mcmp._populate_spiketrains()
        return mcmp
Example #22
def _init_worker_waveform_extractor(recording, unit_ids, spikes,
                                    wfs_arrays_info, nbefore, nafter,
                                    return_scaled, inds_by_unit, mode,
                                    sparsity_mask):
    # create a local dict per worker
    worker_ctx = {}
    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)
    worker_ctx['recording'] = recording

    if mode == 'memmap':
        # on macOS and Windows we need to re-open all npy files in r+ mode
        # in each worker
        wfs_arrays = {}
        for unit_id, filename in wfs_arrays_info.items():
            wfs_arrays[unit_id] = np.load(str(filename), mmap_mode='r+')
    elif mode == 'shared_memory':
        from multiprocessing.shared_memory import SharedMemory
        wfs_arrays = {}
        shms = {}
        for unit_id, (sm, shm_name, dtype, shape) in wfs_arrays_info.items():
            shm = SharedMemory(shm_name)
            arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
            wfs_arrays[unit_id] = arr
            # keep a reference to every shm, otherwise we get a segmentation fault
            shms[unit_id] = shm
        worker_ctx['shms'] = shms

    worker_ctx['unit_ids'] = unit_ids
    worker_ctx['spikes'] = spikes
    worker_ctx['wfs_arrays'] = wfs_arrays
    worker_ctx['nbefore'] = nbefore
    worker_ctx['nafter'] = nafter
    worker_ctx['return_scaled'] = return_scaled
    worker_ctx['inds_by_unit'] = inds_by_unit
    worker_ctx['sparsity_mask'] = sparsity_mask

    return worker_ctx
Example #23
def _run_one(arg_list):
    # the multiprocessing module forces a single tuple argument
    sorter_name, recording, output_folder, verbose, sorter_params = arg_list
    if isinstance(recording, dict):
        recording = load_extractor(recording)

    SorterClass = sorter_dict[sorter_name]

    # because this is checked in run_sorters before this call
    remove_existing_folder = False
    # results are retrieved later
    delete_output_folder = False
    # because we don't want the loop/worker to break
    raise_error = False

    # only classmethod calls, no instance: the state lives in the folder, not the object
    output_folder = SorterClass.initialize_folder(recording, output_folder, verbose, remove_existing_folder)
    SorterClass.set_params_to_folder(recording, output_folder, sorter_params, verbose)
    SorterClass.setup_recording(recording, output_folder, verbose=verbose)
    SorterClass.run_from_folder(output_folder, raise_error, verbose)
Example #24
def _init_worker_detect_peaks(recording, method, peak_sign, abs_threholds,
                              n_shifts, neighbours_mask, extra_margin,
                              localization_dict):
    """Initialize a worker for detecting peaks."""

    if isinstance(recording, dict):
        from spikeinterface.core import load_extractor
        recording = load_extractor(recording)

    # create a local dict per worker
    worker_ctx = {}
    worker_ctx['recording'] = recording
    worker_ctx['method'] = method
    worker_ctx['peak_sign'] = peak_sign
    worker_ctx['abs_threholds'] = abs_threholds
    worker_ctx['n_shifts'] = n_shifts
    worker_ctx['neighbours_mask'] = neighbours_mask
    worker_ctx['extra_margin'] = extra_margin
    worker_ctx['localization_dict'] = localization_dict

    if localization_dict is not None:
        worker_ctx['contact_locations'] = recording.get_channel_locations()

        ms_before = worker_ctx['localization_dict']['ms_before']
        ms_after = worker_ctx['localization_dict']['ms_after']
        worker_ctx['localization_dict']['nbefore'] = \
            int(ms_before * recording.get_sampling_frequency() / 1000.)
        worker_ctx['localization_dict']['nafter'] = \
            int(ms_after * recording.get_sampling_frequency() / 1000.)

        # channel sparsity
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < localization_dict[
            'local_radius_um']
        worker_ctx['localization_dict']['neighbours_mask'] = neighbours_mask

    return worker_ctx
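The ms_before/ms_after conversion above is a plain sampling-rate scaling; a worked instance with illustrative values:

sampling_frequency = 30000.  # Hz, illustrative
ms_before, ms_after = 1., 2.
nbefore = int(ms_before * sampling_frequency / 1000.)  # 1 ms -> 30 samples
nafter = int(ms_after * sampling_frequency / 1000.)    # 2 ms -> 60 samples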
Example #25
def test_BaseRecording():
    num_seg = 2
    num_chan = 3
    num_samples = 30
    sampling_frequency = 10000
    dtype = 'int16'

    files_path = [f'test_base_recording_{i}.raw' for i in range(num_seg)]
    for i in range(num_seg):
        a = np.memmap(files_path[i],
                      dtype=dtype,
                      mode='w+',
                      shape=(num_samples, num_chan))
        a[:] = np.random.randn(*a.shape).astype(dtype)

    rec = BinaryRecordingExtractor(files_path, sampling_frequency, num_chan,
                                   dtype)
    print(rec)

    assert rec.get_num_segments() == 2
    assert rec.get_num_channels() == 3

    assert np.all(rec.ids_to_indices([0, 1, 2]) == [0, 1, 2])
    assert np.all(
        rec.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None))

    # annotations / properties
    rec.annotate(yep='yop')
    assert rec.get_annotation('yep') == 'yop'

    rec.set_property('quality', [1., 3.3, np.nan])
    values = rec.get_property('quality')
    assert np.all(values[:2] == [1., 3.3])

    # dump/load dict
    d = rec.to_dict()
    rec2 = BaseExtractor.from_dict(d)
    rec3 = load_extractor(d)

    # dump/load json
    rec.dump_to_json('test_BaseRecording.json')
    rec2 = BaseExtractor.load('test_BaseRecording.json')
    rec3 = load_extractor('test_BaseRecording.json')

    # dump/load pickle
    rec.dump_to_pickle('test_BaseRecording.pkl')
    rec2 = BaseExtractor.load('test_BaseRecording.pkl')
    rec3 = load_extractor('test_BaseRecording.pkl')

    # cache to binary
    cache_folder = Path('./my_cache_folder')
    folder = cache_folder / 'simple_recording'
    rec.save(format='binary', folder=folder)
    rec2 = BaseExtractor.load_from_folder(folder)
    assert 'quality' in rec2.get_property_keys()
    # but also possible
    rec3 = BaseExtractor.load('./my_cache_folder/simple_recording')

    # cache to memory
    rec4 = rec3.save(format='memory')

    traces4 = rec4.get_traces(segment_index=0)
    traces = rec.get_traces(segment_index=0)
    assert np.array_equal(traces4, traces)

    # cache joblib several jobs
    rec.save(name='simple_recording_2', chunk_size=10, n_jobs=4)

    # set/get Probe only 2 channels
    probe = Probe(ndim=2)
    positions = [[0., 0.], [0., 15.], [0, 30.]]
    probe.set_contacts(positions=positions,
                       shapes='circle',
                       shape_params={'radius': 5})
    probe.set_device_channel_indices([2, -1, 0])
    probe.create_auto_shape()

    rec2 = rec.set_probe(probe, group_mode='by_shank')
    rec2 = rec.set_probe(probe, group_mode='by_probe')
    positions2 = rec2.get_channel_locations()
    assert np.array_equal(positions2, [[0, 30.], [0., 0.]])

    probe2 = rec2.get_probe()
    positions3 = probe2.contact_positions
    assert np.array_equal(positions2, positions3)

    # from probeinterface.plotting import plot_probe_group, plot_probe
    # import matplotlib.pyplot as plt
    # plot_probe(probe)
    # plot_probe(probe2)
    # plt.show()

    # test return_scaled
    sampling_frequency = 30000
    traces = np.zeros((1000, 5), dtype='int16')
    rec_int16 = NumpyRecording([traces], sampling_frequency)
    assert rec_int16.get_dtype() == 'int16'
    print(rec_int16)
    traces_int16 = rec_int16.get_traces()
    assert traces_int16.dtype == 'int16'
    # return_scaled raises an error when the gain_to_uV/offset_to_uV properties are missing
    with pytest.raises(ValueError):
        traces_float32 = rec_int16.get_traces(return_scaled=True)
    rec_int16.set_property('gain_to_uV', [.195] * 5)
    rec_int16.set_property('offset_to_uV', [0.] * 5)
    traces_float32 = rec_int16.get_traces(return_scaled=True)
    assert traces_float32.dtype == 'float32'
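The return_scaled behavior tested at the end boils down to a per-channel affine transform, traces * gain_to_uV + offset_to_uV. A minimal numpy sketch with the same 0.195 gain; the trace values are made up:

import numpy as np

traces_int16 = np.array([[100, -200], [512, 0]], dtype='int16')
gain_to_uV = np.array([.195, .195], dtype='float32')
offset_to_uV = np.array([0., 0.], dtype='float32')
traces_uV = traces_int16.astype('float32') * gain_to_uV + offset_to_uV
print(traces_uV.dtype, traces_uV[0])  # float32 [ 19.5 -39. ]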
Example #26
    def get_recording(self, rec_name=None):
        rec_name = self._check_rec_name(rec_name)
        rec = load_extractor(self.study_folder / 'raw_files' / rec_name)
        return rec
Example #27
    def get_ground_truth(self, rec_name=None):
        rec_name = self._check_rec_name(rec_name)
        sorting = load_extractor(self.study_folder / 'ground_truth' / rec_name)
        return sorting
Example #28
    def load_from_folder(cls, folder):
        folder = Path(folder)
        recording = load_extractor(folder / 'recording.json')
        sorting = load_extractor(folder / 'sorting.json')
        we = cls(recording, sorting, folder)
        return we
Example #29
    def _run_from_folder(cls, output_folder, params, verbose):
        import herdingspikes as hs

        recording = load_extractor(output_folder /
                                   'spikeinterface_recording.json')

        p = params

        # Bandpass filter
        if (p['filter'] and p['freq_min'] is not None
                and p['freq_max'] is not None):
            recording = st.bandpass_filter(recording=recording,
                                           freq_min=p['freq_min'],
                                           freq_max=p['freq_max'])

        if p['pre_scale']:
            recording = st.normalize_by_quantile(recording=recording,
                                                 scale=p['pre_scale_value'],
                                                 median=0.0,
                                                 q1=0.05,
                                                 q2=0.95)

        print(
            'Herdingspikes uses the OLD spikeextractors API with RecordingExtractorOldAPI'
        )
        old_api_recording = RecordingExtractorOldAPI(recording)

        # this should have its name changed
        Probe = hs.probe.RecordingExtractor(
            old_api_recording,
            masked_channels=p['probe_masked_channels'],
            inner_radius=p['probe_inner_radius'],
            neighbor_radius=p['probe_neighbor_radius'],
            event_length=p['probe_event_length'],
            peak_jitter=p['probe_peak_jitter'])

        H = hs.HSDetection(Probe,
                           file_directory_name=str(output_folder),
                           left_cutout_time=p['left_cutout_time'],
                           right_cutout_time=p['right_cutout_time'],
                           threshold=p['detect_threshold'],
                           to_localize=True,
                           num_com_centers=p['num_com_centers'],
                           maa=p['maa'],
                           ahpthr=p['ahpthr'],
                           out_file_name=p['out_file_name'],
                           decay_filtering=p['decay_filtering'],
                           save_all=p['save_all'],
                           amp_evaluation_time=p['amp_evaluation_time'],
                           spk_evaluation_time=p['spk_evaluation_time'])

        H.DetectFromRaw(load=True, tInc=int(p['t_inc']))

        sorted_file = str(output_folder / 'HS2_sorted.hdf5')
        if not H.spikes.empty:
            C = hs.HSClustering(H)
            C.ShapePCA(pca_ncomponents=p['pca_ncomponents'],
                       pca_whiten=p['pca_whiten'])
            C.CombinedClustering(alpha=p['clustering_alpha'],
                                 cluster_subset=p['clustering_subset'],
                                 bandwidth=p['clustering_bandwidth'],
                                 bin_seeding=p['clustering_bin_seeding'],
                                 n_jobs=p['clustering_n_jobs'],
                                 min_bin_freq=p['clustering_min_bin_freq'])
        else:
            C = hs.HSClustering(H)

        if p['filter_duplicates']:
            uids = C.spikes.cl.unique()
            for u in uids:
                s = (C.spikes[C.spikes.cl == u].t.diff()
                     < p['spk_evaluation_time'] / 1000 * Probe.fps)
                C.spikes = C.spikes.drop(s.index[s])

        if verbose:
            print('Saving to', sorted_file)
        C.SaveHDF5(sorted_file, sampling=Probe.fps)
Example #30
    def _run_from_folder(cls, output_folder, params, verbose):
        source_dir = Path(__file__).parent

        p = params.copy()
        if p['detect_sign'] < 0:
            p['detect_sign'] = 'neg'
        elif p['detect_sign'] > 0:
            p['detect_sign'] = 'pos'
        else:
            p['detect_sign'] = 'both'

        if not p['enable_detect_filter']:
            p['detect_filter_order'] = 0
        del p['enable_detect_filter']

        if not p['enable_sort_filter']:
            p['sort_filter_order'] = 0
        del p['enable_sort_filter']

        if p['interpolation']:
            p['interpolation'] = 'y'
        else:
            p['interpolation'] = 'n'

        recording = load_extractor(output_folder /
                                   'spikeinterface_recording.json')
        samplerate = recording.get_sampling_frequency()
        p['sr'] = samplerate

        num_channels = recording.get_num_channels()
        tmpdir = output_folder

        par_str = ''
        par_renames = {
            'detect_sign': 'detection',
            'detect_threshold': 'stdmin',
            'feature_type': 'features',
            'detect_filter_fmin': 'detect_fmin',
            'detect_filter_fmax': 'detect_fmax',
            'detect_filter_order': 'detect_order',
            'sort_filter_fmin': 'sort_fmin',
            'sort_filter_fmax': 'sort_fmax',
            'sort_filter_order': 'sort_order'
        }
        for key, value in p.items():
            if isinstance(value, str):
                value = '\'{}\''.format(value)
            elif isinstance(value, bool):
                value = '{}'.format(value).lower()
            if key in par_renames:
                key = par_renames[key]
            par_str += 'par.{} = {};\n'.format(key, value)

        if verbose:
            print('Running waveclus in {tmpdir}...'.format(tmpdir=tmpdir))

        matlab_code = _matlab_code.format(
            waveclus_path=WaveClusSorter.waveclus_path,
            source_path=source_dir,
            tmpdir=tmpdir.absolute(),
            nChans=num_channels,
            parameters=par_str)

        with (output_folder / 'run_waveclus.m').open('w') as f:
            f.write(matlab_code)

        if 'win' in sys.platform and sys.platform != 'darwin':
            shell_cmd = '''
                {disk_move}
                cd {tmpdir}
                matlab -nosplash -wait -log -r run_waveclus
            '''.format(disk_move=str(tmpdir)[:2], tmpdir=tmpdir)
        else:
            shell_cmd = '''
                #!/bin/bash
                cd "{tmpdir}"
                matlab -nosplash -nodisplay -log -r run_waveclus
            '''.format(tmpdir=tmpdir)
        shell_cmd = ShellScript(
            shell_cmd,
            script_path=output_folder / f'run_{cls.sorter_name}',
            log_path=output_folder / f'{cls.sorter_name}.log',
            verbose=verbose)
        shell_cmd.start()

        retcode = shell_cmd.wait()

        if retcode != 0:
            raise Exception('waveclus returned a non-zero exit code')

        result_fname = tmpdir / 'times_results.mat'
        if not result_fname.is_file():
            raise Exception(f'Result file does not exist: {result_fname}')