def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index, peak_shifts, return_scaled): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) if isinstance(sorting, dict): from spikeinterface.core import load_extractor sorting = load_extractor(sorting) worker_ctx['recording'] = recording worker_ctx['sorting'] = sorting worker_ctx['return_scaled'] = return_scaled all_spikes = sorting.get_all_spike_trains() for segment_index in range(recording.get_num_segments()): spike_times, spike_labels = all_spikes[segment_index] for unit_id in sorting.unit_ids: if peak_shifts[unit_id] != 0: mask = spike_labels == unit_id spike_times[mask] += peak_shifts[unit_id] # reorder otherwise the chunk processing and searchsorted will not work order = np.argsort(spike_times) all_spikes[segment_index] = spike_times[order], spike_labels[order] worker_ctx['all_spikes'] = all_spikes worker_ctx['extremum_channels_index'] = extremum_channels_index return worker_ctx
def _init_worker_waveform_extractor(recording, sorting, wfs_memmap, selected_spikes, selected_spike_times, nbefore, nafter): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx['recording'] = recording if isinstance(sorting, dict): from spikeinterface.core import load_extractor sorting = load_extractor(sorting) worker_ctx['sorting'] = sorting worker_ctx['wfs_memmap'] = wfs_memmap worker_ctx['selected_spikes'] = selected_spikes worker_ctx['selected_spike_times'] = selected_spike_times worker_ctx['nbefore'] = nbefore worker_ctx['nafter'] = nafter num_seg = sorting.get_num_segments() unit_cum_sum = {} for unit_id in sorting.unit_ids: # spike per segment n_per_segment = [ selected_spikes[unit_id][i].size for i in range(num_seg) ] cum_sum = [0] + np.cumsum(n_per_segment).tolist() unit_cum_sum[unit_id] = cum_sum worker_ctx['unit_cum_sum'] = unit_cum_sum return worker_ctx
def load_from_folder(cls, folder):
    """Re-create an instance of *cls* from a folder written earlier.

    The folder must contain 'recording.json' and 'sorting.json'. For each mode
    in ``_possible_template_modes``, a 'templates_<mode>.npy' file, if present,
    is loaded into ``_template_cache``.
    """
    folder = Path(folder)
    assert folder.is_dir(), f'This folder do not exists {folder}'

    rec = load_extractor(folder / 'recording.json')
    srt = load_extractor(folder / 'sorting.json')
    we = cls(rec, srt, folder)

    # restore any template arrays cached on disk
    for mode in _possible_template_modes:
        cached = folder / f'templates_{mode}.npy'
        if not cached.is_file():
            continue
        we._template_cache[mode] = np.load(cached)
    return we
def get_result_from_folder(cls, output_folder):
    """Load the sorting produced in *output_folder* after checking the run log.

    Parameters
    ----------
    output_folder: str or Path
        Folder of a previous sorter run. It must contain 'spikeinterface_log.json'.

    Returns
    -------
    sorting
        Result of ``cls._get_result_from_folder``. The recording loaded from
        'spikeinterface_recording.json' is registered on it when that load
        gives a non-None value.

    Raises
    ------
    SpikeSortingError
        If the log file is missing or its 'error' entry is truthy.
    """
    output_folder = Path(output_folder)

    log_file = output_folder / 'spikeinterface_log.json'
    if not log_file.is_file():
        raise SpikeSortingError(
            'get result error: the folder do not contain spikeinterface_log.json'
        )
    log = json.loads(log_file.read_text(encoding='utf8'))
    if log['error']:
        raise SpikeSortingError(
            "Spike sorting failed. You can inspect the runtime trace in spikeinterface_log.json"
        )

    sorting = cls._get_result_from_folder(output_folder)
    recording = load_extractor(output_folder / 'spikeinterface_recording.json')
    # the recording can be None when it was not dumpable
    if recording is not None:
        sorting.register_recording(recording)
    return sorting
def get_recordings(study_folder):
    """
    Get the ground-truth study recordings as a dict.

    Each recording is loaded with ``load_extractor`` from
    '<study_folder>/raw_files/<rec_name>', for every name returned by
    ``get_rec_names``.

    Parameters
    ----------
    study_folder: str
        The study folder.

    Returns
    -------
    recording_dict: dict
        Maps recording name to recording.
    """
    study_folder = Path(study_folder)
    names = get_rec_names(study_folder)
    raw_dir = study_folder / 'raw_files'
    return {name: load_extractor(raw_dir / name) for name in names}
def _run_from_folder(cls, output_folder, params, verbose):
    """Run Mountainsort4 on the recording dumped in *output_folder*.

    Steps: optional bandpass filter, optional whitening, then
    ``mountainsort4.mountainsort4`` through the old-API recording wrapper, then
    optional ``mountainsort4_curation``. The result is written to
    '<output_folder>/firings.npz'.

    Parameters
    ----------
    output_folder: Path
        Must contain 'spikeinterface_recording.json'.
    params: dict
        Sorter parameters. The keys read here include 'filter', 'freq_min',
        'freq_max', 'whiten', 'detect_sign', 'adjacency_radius', 'clip_size',
        'detect_threshold', 'detect_interval', 'num_workers',
        'noise_overlap_threshold' and 'curation'.
    verbose: bool
        Enables the progress prints.
    """
    import mountainsort4

    recording = load_extractor(output_folder / 'spikeinterface_recording.json')

    # alias to params
    p = params

    # sampling rate is read before filtering/whitening; used to build the output sorting
    samplerate = recording.get_sampling_frequency()

    # Bandpass filter
    if p['filter'] and p['freq_min'] is not None and p['freq_max'] is not None:
        if verbose:
            print('filtering')
        recording = bandpass_filter(recording=recording, freq_min=p['freq_min'], freq_max=p['freq_max'])

    # Whiten
    if p['whiten']:
        if verbose:
            print('whitenning')
        recording = whiten(recording=recording)

    # NOTE(review): this print is not gated by `verbose`, unlike the others
    print('Mountainsort4 use the OLD spikeextractors mapped with RecordingExtractorOldAPI')
    old_api_recording = RecordingExtractorOldAPI(recording)

    # Check location no more needed done in basesorter
    old_api_sorting = mountainsort4.mountainsort4(
        recording=old_api_recording,
        detect_sign=p['detect_sign'],
        adjacency_radius=p['adjacency_radius'],
        clip_size=p['clip_size'],
        detect_threshold=p['detect_threshold'],
        detect_interval=p['detect_interval'],
        num_workers=p['num_workers'],
        verbose=verbose)

    # Curate (only when a threshold is given AND curation is exactly True)
    if p['noise_overlap_threshold'] is not None and p['curation'] is True:
        if verbose:
            print('Curating')
        old_api_sorting = mountainsort4.mountainsort4_curation(
            recording=old_api_recording,
            sorting=old_api_sorting,
            noise_overlap_threshold=p['noise_overlap_threshold'])

    # convert sorting to new API and save it
    # a single dict in the list -> presumably one segment; TODO confirm from_dict semantics
    unit_ids = old_api_sorting.get_unit_ids()
    units_dict_list = [{u: old_api_sorting.get_unit_spike_train(u) for u in unit_ids}]
    new_api_sorting = NumpySorting.from_dict(units_dict_list, samplerate)
    NpzSortingExtractor.write_sorting(new_api_sorting, str(output_folder / 'firings.npz'))
def _run_from_folder(cls, output_folder, params, verbose):
    """Run the sorter directly on the raw binary file behind the recording.

    The recording dumped in *output_folder* must be a single-segment
    ``BinaryRecordingExtractor``. Its first file path is passed to ``run``
    together with a ``Bunch`` describing the channel map and the x/y contact
    coordinates.

    Parameters
    ----------
    output_folder: Path
        Must contain 'spikeinterface_recording.json'. Also passed as ``dir_path``.
    params: dict
        Forwarded unchanged to ``run``.
    verbose: bool
        When True, prints the path of the data file being sorted.
    """
    recording = load_extractor(output_folder / 'spikeinterface_recording.json')
    assert isinstance(recording, BinaryRecordingExtractor)
    assert recording.get_num_segments() == 1

    dat_path = recording._kwargs['file_paths'][0]
    # previously printed unconditionally (leftover debug output); now honours verbose
    if verbose:
        print('dat_path', dat_path)

    num_chans = recording.get_num_channels()
    locations = recording.get_channel_locations()

    # ks_probe is not probeinterface Probe at all
    ks_probe = Bunch()
    ks_probe.NchanTOT = num_chans
    ks_probe.chanMap = np.arange(num_chans)
    ks_probe.kcoords = np.ones(num_chans)
    ks_probe.xc = locations[:, 0]
    ks_probe.yc = locations[:, 1]

    run(
        dat_path,
        params=params,
        probe=ks_probe,
        dir_path=output_folder,
        n_channels=num_chans,
        dtype=recording.get_dtype(),
        sample_rate=recording.get_sampling_frequency(),
    )
def get_ground_truths(study_folder):
    """
    Get the ground-truth sorting extractors as a dict.

    Each sorting is loaded with ``load_extractor`` from
    '<study_folder>/ground_truth/<rec_name>', for every name returned by
    ``get_rec_names``.

    Parameters
    ----------
    study_folder: str
        The study folder.

    Returns
    ----------
    ground_truths: dict
        Maps recording name to ground-truth sorting.
    """
    study_folder = Path(study_folder)
    names = get_rec_names(study_folder)
    gt_dir = study_folder / 'ground_truth'
    return {name: load_extractor(gt_dir / name) for name in names}
def _run_one(arg_list):
    """Run one sorter job, locally or in docker.

    ``arg_list`` is a single tuple because the multiprocessing module passes
    one argument per task:
    (sorter_name, recording, output_folder, verbose, sorter_params,
    docker_image, with_output). A recording given as a dict is re-instantiated
    with ``load_extractor``. The job is sent to ``run_sorter_local`` when
    ``docker_image`` is None and to ``run_sorter_docker`` otherwise.
    """
    (sorter_name, recording, output_folder, verbose,
     sorter_params, docker_image, with_output) = arg_list

    if isinstance(recording, dict):
        recording = load_extractor(recording)

    shared_kwargs = dict(
        output_folder=output_folder,
        # existing folders were already checked in run_sorters before this call
        remove_existing_folder=False,
        # the result is retrieved later
        delete_output_folder=False,
        verbose=verbose,
        # a failing job must not break the loop/worker
        raise_error=False,
        with_output=with_output,
    )

    if docker_image is None:
        run_sorter_local(sorter_name, recording, **shared_kwargs, **sorter_params)
    else:
        run_sorter_docker(sorter_name, recording, docker_image, **shared_kwargs, **sorter_params)
def _init_worker_localize_peaks(recording, peaks, method, method_kwargs, nbefore, nafter, contact_locations, margin): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx['recording'] = recording worker_ctx['peaks'] = peaks worker_ctx['method'] = method worker_ctx['method_kwargs'] = method_kwargs worker_ctx['nbefore'] = nbefore worker_ctx['nafter'] = nafter worker_ctx['contact_locations'] = contact_locations worker_ctx['margin'] = margin if method in ('center_of_mass', 'monopolar_triangulation'): # handle sparsity channel_distance = get_channel_distances(recording) neighbours_mask = channel_distance < method_kwargs['local_radius_um'] worker_ctx['neighbours_mask'] = neighbours_mask #~ if method == 'center_of_mass': #~ pass #~ elif method == 'monopolar_triangulation': #~ pass return worker_ctx
def _init_memory_worker(recording, arrays, shm_names, shapes, dtype): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor worker_ctx['recording'] = load_extractor(recording) else: worker_ctx['recording'] = recording worker_ctx['dtype'] = np.dtype(dtype) if arrays is None: # create it from share memory name from multiprocessing.shared_memory import SharedMemory arrays = [] # keep shm alive worker_ctx['shms'] = [] for i in range(len(shm_names)): shm = SharedMemory(shm_names[i]) worker_ctx['shms'].append(shm) arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf) arrays.append(arr) worker_ctx['arrays'] = arrays return worker_ctx
def test_BaseSorting():
    """Smoke-test BaseSorting on an npz file written by ``create_sorting_npz``.

    Covers annotations/properties, dict/json/pickle round-trips, saving to a
    folder and spike-train accessors. All files are written relative to the
    current working directory. The reloaded sortings are not compared with the
    original; these steps only check that loading does not raise.
    """
    num_seg = 2
    file_path = 'test_BaseSorting.npz'

    create_sorting_npz(num_seg, file_path)

    sorting = NpzSortingExtractor(file_path)
    print(sorting)

    assert sorting.get_num_segments() == 2
    assert sorting.get_num_units() == 3

    # annotations / properties
    sorting.annotate(yep='yop')
    assert sorting.get_annotation('yep') == 'yop'

    sorting.set_property('amplitude', [-20, -40., -55.5])
    values = sorting.get_property('amplitude')
    assert np.all(values == [-20, -40., -55.5])

    # dump/load dict
    d = sorting.to_dict()
    sorting2 = BaseExtractor.from_dict(d)
    sorting3 = load_extractor(d)

    # dump/load json
    sorting.dump_to_json('test_BaseSorting.json')
    sorting2 = BaseExtractor.load('test_BaseSorting.json')
    sorting3 = load_extractor('test_BaseSorting.json')

    # dump/load pickle
    sorting.dump_to_pickle('test_BaseSorting.pkl')
    sorting2 = BaseExtractor.load('test_BaseSorting.pkl')
    sorting3 = load_extractor('test_BaseSorting.pkl')

    # cache
    folder = Path('./my_cache_folder') / 'simple_sorting'
    sorting.save(folder=folder)
    sorting2 = BaseExtractor.load_from_folder(folder)
    # but also possible
    sorting3 = BaseExtractor.load(folder)

    # accessor smoke checks: results are not asserted
    spikes = sorting.get_all_spike_trains()
    # print(spikes)

    spikes = sorting.to_spike_vector()
def _init_worker_find_spike(recording, method, method_kwargs): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx['recording'] = recording worker_ctx['method'] = method worker_ctx['method_kwargs'] = method_kwargs return worker_ctx
def _init_binary_worker(recording, rec_memmaps, dtype): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor worker_ctx['recording'] = load_extractor(recording) else: worker_ctx['recording'] = recording worker_ctx['rec_memmaps'] = rec_memmaps worker_ctx['dtype'] = np.dtype(dtype) return worker_ctx
def _init_worker_detect_peaks(recording, method, peak_sign, abs_threholds, n_shifts, neighbours_mask): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx['recording'] = recording worker_ctx['method'] = method worker_ctx['peak_sign'] = peak_sign worker_ctx['abs_threholds'] = abs_threholds worker_ctx['n_shifts'] = n_shifts worker_ctx['neighbours_mask'] = neighbours_mask return worker_ctx
def _init_worker_localize_peaks(recording, peaks, method, nbefore, nafter, neighbours_mask, contact_locations): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx['recording'] = recording worker_ctx['peaks'] = peaks worker_ctx['method'] = method worker_ctx['nbefore'] = nbefore worker_ctx['nafter'] = nafter worker_ctx['neighbours_mask'] = neighbours_mask worker_ctx['contact_locations'] = contact_locations return worker_ctx
def _init_work_all_pc_extractor(recording, all_pcs, spike_times, spike_labels, nbefore, nafter, unit_channels, all_pca): worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx['recording'] = recording worker_ctx['all_pcs'] = all_pcs worker_ctx['spike_times'] = spike_times worker_ctx['spike_labels'] = spike_labels worker_ctx['nbefore'] = nbefore worker_ctx['nafter'] = nafter worker_ctx['unit_channels'] = unit_channels worker_ctx['all_pca'] = all_pca return worker_ctx
def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index, peak_shifts, return_scaled): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) if isinstance(sorting, dict): from spikeinterface.core import load_extractor sorting = load_extractor(sorting) worker_ctx['recording'] = recording worker_ctx['sorting'] = sorting worker_ctx['return_scaled'] = return_scaled worker_ctx['peak_shifts'] = peak_shifts worker_ctx['min_shift'] = np.min(peak_shifts) worker_ctx['max_shifts'] = np.max(peak_shifts) all_spikes = sorting.get_all_spike_trains(outputs='unit_index') worker_ctx['all_spikes'] = all_spikes worker_ctx['extremum_channels_index'] = extremum_channels_index return worker_ctx
def _init_worker_unit_amplitudes(recording, sorting, extremum_channels_index, peak_shifts): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) if isinstance(sorting, dict): from spikeinterface.core import load_extractor sorting = load_extractor(sorting) worker_ctx['recording'] = recording worker_ctx['sorting'] = sorting all_spikes = sorting.get_all_spike_trains() # apply peak shift for unit_id in sorting.unit_ids: if peak_shifts[unit_id] != 0: for segment_index in range(recording.get_num_segments()): spike_times, spike_labels = all_spikes[segment_index] mask = spike_labels == unit_id spike_times[mask] += peak_shifts[unit_id] all_spikes[segment_index] = spike_times, spike_labels worker_ctx['all_spikes'] = all_spikes worker_ctx['extremum_channels_index'] = extremum_channels_index return worker_ctx
def _init_work_all_pc_extractor(recording, all_pcs_args, spike_times, spike_labels, nbefore, nafter, unit_channels, pca_model): worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx['recording'] = recording worker_ctx['all_pcs'] = np.lib.format.open_memmap(**all_pcs_args) worker_ctx['spike_times'] = spike_times worker_ctx['spike_labels'] = spike_labels worker_ctx['nbefore'] = nbefore worker_ctx['nafter'] = nafter worker_ctx['unit_channels'] = unit_channels worker_ctx['pca_model'] = pca_model return worker_ctx
def load_from_folder(folder_path):
    """Reload a MultiSortingComparison saved in *folder_path*.

    Reads 'kwargs.json' (constructor kwargs), 'sortings.json' (name ->
    extractor description loaded with ``load_extractor``) and the pickled graph
    'multicomparison.gpickle'. Matching is not redone; instead, steps 3 and 4
    (graph cleaning, agreement, spike trains) are recomputed from the stored
    graph.
    """
    import pickle

    folder_path = Path(folder_path)
    with (folder_path / 'kwargs.json').open() as f:
        kwargs = json.load(f)
    with (folder_path / 'sortings.json').open() as f:
        dict_sortings = json.load(f)

    name_list = list(dict_sortings.keys())
    sorting_list = [load_extractor(v) for v in dict_sortings.values()]
    mcmp = MultiSortingComparison(sorting_list=sorting_list, name_list=name_list,
                                  do_matching=False, **kwargs)

    # networkx >= 3.0 removed read_gpickle; a .gpickle file is a plain pickle.
    # NOTE(security): pickle.load can execute arbitrary code -- only load trusted folders.
    with (folder_path / 'multicomparison.gpickle').open('rb') as f:
        mcmp.graph = pickle.load(f)

    # do step 3 and 4
    mcmp._clean_graph()
    mcmp._do_agreement()
    mcmp._populate_spiketrains()
    return mcmp
def _init_worker_waveform_extractor(recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask): # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx['recording'] = recording if mode == 'memmap': # ~ if not platform.system().lower().startswith('linux'): # For OSX and windows : need to re open all npy files in r+ mode for each worker wfs_arrays = {} for unit_id, filename in wfs_arrays_info.items(): wfs_arrays[unit_id] = np.load(str(filename), mmap_mode='r+') elif mode == 'shared_memory': from multiprocessing.shared_memory import SharedMemory wfs_arrays = {} shms = {} for unit_id, (sm, shm_name, dtype, shape) in wfs_arrays_info.items(): shm = SharedMemory(shm_name) arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) wfs_arrays[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! shms[unit_id] = shm worker_ctx['shms'] = shms worker_ctx['unit_ids'] = unit_ids worker_ctx['spikes'] = spikes worker_ctx['wfs_arrays'] = wfs_arrays worker_ctx['nbefore'] = nbefore worker_ctx['nafter'] = nafter worker_ctx['return_scaled'] = return_scaled worker_ctx['inds_by_unit'] = inds_by_unit worker_ctx['sparsity_mask'] = sparsity_mask return worker_ctx
def _run_one(arg_list):
    """Run one sorter job through the class-level sorter pipeline.

    ``arg_list`` is a single tuple because the multiprocessing module passes
    one argument per task:
    (sorter_name, recording, output_folder, verbose, sorter_params). A
    recording given as a dict is re-instantiated with ``load_extractor``. The
    sorter class is looked up in ``sorter_dict``. Only classmethods are
    called: the state lives in the output folder, not in an instance.
    """
    sorter_name, recording, output_folder, verbose, sorter_params = arg_list

    if isinstance(recording, dict):
        recording = load_extractor(recording)

    SorterClass = sorter_dict[sorter_name]

    # existing folders were already checked in run_sorters before this call
    remove_existing_folder = False
    # a failing job must not break the loop/worker
    raise_error = False

    output_folder = SorterClass.initialize_folder(recording, output_folder, verbose, remove_existing_folder)
    SorterClass.set_params_to_folder(recording, output_folder, sorter_params, verbose)
    SorterClass.setup_recording(recording, output_folder, verbose=verbose)
    SorterClass.run_from_folder(output_folder, raise_error, verbose)
def _init_worker_detect_peaks(recording, method, peak_sign, abs_threholds, n_shifts, neighbours_mask, extra_margin, localization_dict): """Initialize a worker for detecting peaks.""" if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) # create a local dict per worker worker_ctx = {} worker_ctx['recording'] = recording worker_ctx['method'] = method worker_ctx['peak_sign'] = peak_sign worker_ctx['abs_threholds'] = abs_threholds worker_ctx['n_shifts'] = n_shifts worker_ctx['neighbours_mask'] = neighbours_mask worker_ctx['extra_margin'] = extra_margin worker_ctx['localization_dict'] = localization_dict if localization_dict is not None: worker_ctx['contact_locations'] = recording.get_channel_locations() channel_distance = get_channel_distances(recording) ms_before = worker_ctx['localization_dict']['ms_before'] ms_after = worker_ctx['localization_dict']['ms_after'] worker_ctx['localization_dict']['nbefore'] = \ int(ms_before * recording.get_sampling_frequency() / 1000.) worker_ctx['localization_dict']['nafter'] = \ int(ms_after * recording.get_sampling_frequency() / 1000.) # channel sparsity channel_distance = get_channel_distances(recording) neighbours_mask = channel_distance < localization_dict[ 'local_radius_um'] worker_ctx['localization_dict']['neighbours_mask'] = neighbours_mask return worker_ctx
def test_BaseRecording():
    """Smoke-test BaseRecording on two raw int16 files filled with random data.

    Covers id/index conversion, annotations/properties, dict/json/pickle
    round-trips, binary and in-memory caching, probe attachment and
    ``return_scaled`` handling. All files are written relative to the current
    working directory.
    """
    num_seg = 2
    num_chan = 3
    num_samples = 30
    sampling_frequency = 10000
    dtype = 'int16'

    # one raw file per segment, shape (num_samples, num_chan)
    files_path = [f'test_base_recording_{i}.raw' for i in range(num_seg)]
    for i in range(num_seg):
        a = np.memmap(files_path[i], dtype=dtype, mode='w+', shape=(num_samples, num_chan))
        a[:] = np.random.randn(*a.shape).astype(dtype)
    rec = BinaryRecordingExtractor(files_path, sampling_frequency, num_chan, dtype)
    print(rec)

    assert rec.get_num_segments() == 2
    assert rec.get_num_channels() == 3

    assert np.all(rec.ids_to_indices([0, 1, 2]) == [0, 1, 2])
    assert np.all(rec.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None))

    # annotations / properties
    rec.annotate(yep='yop')
    assert rec.get_annotation('yep') == 'yop'

    # NaN != NaN, so only the first two values are compared
    rec.set_property('quality', [1., 3.3, np.nan])
    values = rec.get_property('quality')
    assert np.all(values[:2] == [1., 3.3, ])

    # dump/load dict
    d = rec.to_dict()
    rec2 = BaseExtractor.from_dict(d)
    rec3 = load_extractor(d)

    # dump/load json
    rec.dump_to_json('test_BaseRecording.json')
    rec2 = BaseExtractor.load('test_BaseRecording.json')
    rec3 = load_extractor('test_BaseRecording.json')

    # dump/load pickle
    rec.dump_to_pickle('test_BaseRecording.pkl')
    rec2 = BaseExtractor.load('test_BaseRecording.pkl')
    rec3 = load_extractor('test_BaseRecording.pkl')

    # cache to binary
    cache_folder = Path('./my_cache_folder')
    folder = cache_folder / 'simple_recording'
    rec.save(format='binary', folder=folder)
    rec2 = BaseExtractor.load_from_folder(folder)
    assert 'quality' in rec2.get_property_keys()
    # but also possible
    rec3 = BaseExtractor.load('./my_cache_folder/simple_recording')

    # cache to memory
    rec4 = rec3.save(format='memory')

    traces4 = rec4.get_traces(segment_index=0)
    traces = rec.get_traces(segment_index=0)
    assert np.array_equal(traces4, traces)

    # cache joblib several jobs
    # NOTE(review): saved by name only; presumably lands in a default cache folder -- TODO confirm
    rec.save(name='simple_recording_2', chunk_size=10, n_jobs=4)

    # set/get Probe only 2 channels
    probe = Probe(ndim=2)
    positions = [[0., 0.], [0., 15.], [0, 30.]]
    probe.set_contacts(positions=positions, shapes='circle', shape_params={'radius': 5})
    # the middle contact gets device index -1; the assertion below expects only 2 channel locations
    probe.set_device_channel_indices([2, -1, 0])
    probe.create_auto_shape()

    rec2 = rec.set_probe(probe, group_mode='by_shank')
    rec2 = rec.set_probe(probe, group_mode='by_probe')
    positions2 = rec2.get_channel_locations()
    assert np.array_equal(positions2, [[0, 30.], [0., 0.]])

    probe2 = rec2.get_probe()
    positions3 = probe2.contact_positions
    assert np.array_equal(positions2, positions3)

    # from probeinterface.plotting import plot_probe_group, plot_probe
    # import matplotlib.pyplot as plt
    # plot_probe(probe)
    # plot_probe(probe2)
    # plt.show()

    # test return_scale
    sampling_frequency = 30000
    traces = np.zeros((1000, 5), dtype='int16')
    rec_int16 = NumpyRecording([traces], sampling_frequency)
    assert rec_int16.get_dtype() == 'int16'
    print(rec_int16)
    traces_int16 = rec_int16.get_traces()
    assert traces_int16.dtype == 'int16'
    # return_scaled raise error when no gain_to_uV/offset_to_uV properties
    with pytest.raises(ValueError):
        traces_float32 = rec_int16.get_traces(return_scaled=True)
    rec_int16.set_property('gain_to_uV', [.195] * 5)
    rec_int16.set_property('offset_to_uV', [0.] * 5)
    traces_float32 = rec_int16.get_traces(return_scaled=True)
    assert traces_float32.dtype == 'float32'
def get_recording(self, rec_name=None):
    """Load one study recording from '<study_folder>/raw_files/<rec_name>'.

    ``rec_name`` is first resolved through ``self._check_rec_name``.
    """
    name = self._check_rec_name(rec_name)
    path = self.study_folder / 'raw_files' / name
    return load_extractor(path)
def get_ground_truth(self, rec_name=None):
    """Load one ground-truth sorting from '<study_folder>/ground_truth/<rec_name>'.

    ``rec_name`` is first resolved through ``self._check_rec_name``.
    """
    name = self._check_rec_name(rec_name)
    path = self.study_folder / 'ground_truth' / name
    return load_extractor(path)
def load_from_folder(cls, folder):
    """Re-create an instance of *cls* from 'recording.json' and 'sorting.json' in *folder*."""
    folder = Path(folder)
    rec = load_extractor(folder / 'recording.json')
    srt = load_extractor(folder / 'sorting.json')
    return cls(rec, srt, folder)
def _run_from_folder(cls, output_folder, params, verbose):
    """Run HerdingSpikes on the recording dumped in *output_folder*.

    Steps: optional bandpass filter and quantile normalisation, detection
    with ``hs.HSDetection``, then PCA + clustering (skipped when no spikes
    were detected), optional duplicate removal, and saving to
    '<output_folder>/HS2_sorted.hdf5'.

    Parameters
    ----------
    output_folder: Path
        Must contain 'spikeinterface_recording.json'.
    params: dict
        Sorter parameters. The keys read here are listed in the body.
    verbose: bool
        Enables the final "Saving to" print.
    """
    import herdingspikes as hs

    recording = load_extractor(output_folder / 'spikeinterface_recording.json')

    p = params

    # Bandpass filter
    if p['filter'] and p['freq_min'] is not None and p['freq_max'] is not None:
        recording = st.bandpass_filter(recording=recording, freq_min=p['freq_min'], freq_max=p['freq_max'])

    if p['pre_scale']:
        recording = st.normalize_by_quantile(recording=recording, scale=p['pre_scale_value'],
                                             median=0.0, q1=0.05, q2=0.95)

    # NOTE(review): this print is not gated by `verbose`
    print('Herdingspikes use the OLD spikeextractors with RecordingExtractorOldAPI')
    old_api_recording = RecordingExtractorOldAPI(recording)

    # this should have its name changed
    Probe = hs.probe.RecordingExtractor(
        old_api_recording,
        masked_channels=p['probe_masked_channels'],
        inner_radius=p['probe_inner_radius'],
        neighbor_radius=p['probe_neighbor_radius'],
        event_length=p['probe_event_length'],
        peak_jitter=p['probe_peak_jitter'])

    H = hs.HSDetection(
        Probe, file_directory_name=str(output_folder),
        left_cutout_time=p['left_cutout_time'],
        right_cutout_time=p['right_cutout_time'],
        threshold=p['detect_threshold'],
        to_localize=True,
        num_com_centers=p['num_com_centers'],
        maa=p['maa'],
        ahpthr=p['ahpthr'],
        out_file_name=p['out_file_name'],
        decay_filtering=p['decay_filtering'],
        save_all=p['save_all'],
        amp_evaluation_time=p['amp_evaluation_time'],
        spk_evaluation_time=p['spk_evaluation_time'])

    H.DetectFromRaw(load=True, tInc=int(p['t_inc']))

    sorted_file = str(output_folder / 'HS2_sorted.hdf5')
    if (not H.spikes.empty):
        C = hs.HSClustering(H)
        C.ShapePCA(pca_ncomponents=p['pca_ncomponents'], pca_whiten=p['pca_whiten'])
        C.CombinedClustering(
            alpha=p['clustering_alpha'],
            cluster_subset=p['clustering_subset'],
            bandwidth=p['clustering_bandwidth'],
            bin_seeding=p['clustering_bin_seeding'],
            n_jobs=p['clustering_n_jobs'],
            min_bin_freq=p['clustering_min_bin_freq'])
    else:
        # no spikes detected: build the clustering object without running PCA/clustering
        C = hs.HSClustering(H)

    if p['filter_duplicates']:
        # within each cluster, drop spikes closer than spk_evaluation_time to the previous one
        # (assumes spk_evaluation_time is in ms and t in frames, given "/ 1000 * fps" -- TODO confirm)
        uids = C.spikes.cl.unique()
        for u in uids:
            s = C.spikes[C.spikes.cl == u].t.diff() < p['spk_evaluation_time'] / 1000 * Probe.fps
            C.spikes = C.spikes.drop(s.index[s])

    if verbose:
        print('Saving to', sorted_file)
    C.SaveHDF5(sorted_file, sampling=Probe.fps)
def _run_from_folder(cls, output_folder, params, verbose):
    """Run Wave_clus through MATLAB on the recording dumped in *output_folder*.

    The sorter params are translated into MATLAB ``par.<name> = <value>;``
    assignments and injected into ``_matlab_code``, which is written to
    'run_waveclus.m'. MATLAB is then launched via a ``ShellScript``.

    Raises
    ------
    Exception
        If MATLAB exits with a non-zero code, or if 'times_results.mat' is not
        produced in *output_folder*.
    """
    source_dir = Path(__file__).parent

    # work on a copy so the caller's params are not modified
    p = params.copy()

    # numeric sign -> Wave_clus detection keyword
    if p['detect_sign'] < 0:
        p['detect_sign'] = 'neg'
    elif p['detect_sign'] > 0:
        p['detect_sign'] = 'pos'
    else:
        p['detect_sign'] = 'both'

    # a disabled filter is expressed as order 0; the enable_* flags themselves are never sent to MATLAB
    if not p['enable_detect_filter']:
        p['detect_filter_order'] = 0
    del p['enable_detect_filter']

    if not p['enable_sort_filter']:
        p['sort_filter_order'] = 0
    del p['enable_sort_filter']

    if p['interpolation']:
        p['interpolation'] = 'y'
    else:
        p['interpolation'] = 'n'

    recording = load_extractor(output_folder / 'spikeinterface_recording.json')
    samplerate = recording.get_sampling_frequency()
    p['sr'] = samplerate

    num_channels = recording.get_num_channels()
    tmpdir = output_folder

    # spikeinterface param name -> Wave_clus `par` field name
    par_str = ''
    par_renames = {
        'detect_sign': 'detection',
        'detect_threshold': 'stdmin',
        'feature_type': 'features',
        'detect_filter_fmin': 'detect_fmin',
        'detect_filter_fmax': 'detect_fmax',
        'detect_filter_order': 'detect_order',
        'sort_filter_fmin': 'sort_fmin',
        'sort_filter_fmax': 'sort_fmax',
        'sort_filter_order': 'sort_order'
    }
    for key, value in p.items():
        # strings are single-quoted and booleans lower-cased for MATLAB syntax
        if type(value) == str:
            value = '\'{}\''.format(value)
        elif type(value) == bool:
            value = '{}'.format(value).lower()
        if key in par_renames:
            key = par_renames[key]
        par_str += 'par.{} = {};\n'.format(key, value)

    if verbose:
        print('Running waveclus in {tmpdir}...'.format(tmpdir=tmpdir))

    matlab_code = _matlab_code.format(
        waveclus_path=WaveClusSorter.waveclus_path,
        source_path=source_dir,
        tmpdir=tmpdir.absolute(),
        nChans=num_channels,
        parameters=par_str)

    with (output_folder / 'run_waveclus.m').open('w') as f:
        f.write(matlab_code)

    # 'darwin' contains 'win', hence the explicit exclusion of macOS
    if 'win' in sys.platform and sys.platform != 'darwin':
        # str(tmpdir)[:2] is presumably the drive (e.g. 'C:'), switched to before cd -- TODO confirm
        shell_cmd = '''
                    {disk_move}
                    cd {tmpdir}
                    matlab -nosplash -wait -log -r run_waveclus
                '''.format(disk_move=str(tmpdir)[:2], tmpdir=tmpdir)
    else:
        shell_cmd = '''
                    #!/bin/bash
                    cd "{tmpdir}"
                    matlab -nosplash -nodisplay -log -r run_waveclus
                '''.format(tmpdir=tmpdir)

    shell_cmd = ShellScript(
        shell_cmd,
        script_path=output_folder / f'run_{cls.sorter_name}',
        log_path=output_folder / f'{cls.sorter_name}.log',
        verbose=verbose)
    shell_cmd.start()
    retcode = shell_cmd.wait()

    if retcode != 0:
        raise Exception('waveclus returned a non-zero exit code')

    result_fname = tmpdir / 'times_results.mat'
    if not result_fname.is_file():
        raise Exception(f'Result file does not exist: {result_fname}')