def get_unit_amplitudes(waveform_extractor, peak_sign='neg', outputs='concatenated', **job_kwargs):
    """
    Computes the spike amplitudes from a WaveformExtractor.

    1. The waveform extractor is used to determine the max channel per unit.
    2. Then a "peak_shift" is estimated because for some sorters the spike index
       is not always at the extremum.
    3. Amplitudes are extracted chunk by chunk (in parallel or not).

    Parameters
    ----------
    waveform_extractor: WaveformExtractor
        The waveform extractor object.
    peak_sign: 'neg', 'pos', 'both'
        Sign used both to find the extremum channel of each unit and to
        estimate its peak shift.
    outputs: 'concatenated', 'by_units'
        * 'concatenated': a list (one entry per segment) of amplitude arrays
        * 'by_units': a list (one entry per segment) of dicts mapping
          unit_id to the amplitudes of that unit's spikes
    job_kwargs: dict
        Parameters for ChunkRecordingExecutor.

    Returns
    -------
    amplitudes: list
        See `outputs`.
    """
    # validate before the (possibly long) chunked run
    assert outputs in ('concatenated', 'by_units')

    we = waveform_extractor
    recording = we.recording
    sorting = we.sorting
    num_segments = recording.get_num_segments()

    all_spikes = sorting.get_all_spike_trains()

    extremum_channels_index = get_template_extremum_channel(
        waveform_extractor, peak_sign=peak_sign, outputs='index')
    # the peak shift must use the same sign as the extremum channel
    # (it was previously hard-coded to 'neg')
    peak_shifts = get_template_extremum_channel_peak_shift(
        waveform_extractor, peak_sign=peak_sign)

    # and run
    func = _unit_amplitudes_chunk
    init_func = _init_worker_unit_amplitudes
    init_args = (recording.to_dict(), sorting.to_dict(), extremum_channels_index, peak_shifts)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                       handle_returns=True, job_name='extract amplitudes',
                                       **job_kwargs)
    out = processor.run()
    amps, segments = zip(*out)
    amps = np.concatenate(amps)
    segments = np.concatenate(segments)

    # regroup the amplitudes per segment
    amplitudes = [amps[segments == segment_index] for segment_index in range(num_segments)]

    if outputs == 'concatenated':
        return amplitudes

    # outputs == 'by_units'
    amplitudes_by_units = []
    for segment_index in range(num_segments):
        _, spike_labels = all_spikes[segment_index]
        seg_amplitudes = amplitudes[segment_index]
        amplitudes_by_units.append({
            unit_id: seg_amplitudes[spike_labels == unit_id]
            for unit_id in sorting.unit_ids
        })
    return amplitudes_by_units
def localize_peaks(recording, peaks, method='center_of_mass',
                   local_radius_um=150, ms_before=0.3, ms_after=0.6, **job_kwargs):
    """
    Localize peaks (spikes) in 2D or 3D depending on the probe.ndim of the recording.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    peaks: array
        Peak vector given by detect_peaks() in "numpy_compact" mode.
    method: str
        Method to be used ('center_of_mass')
    local_radius_um: float
        Radius in micrometers defining the channel neighborhood around the peak
    ms_before: float
        The window before a peak, in milliseconds
    ms_after: float
        The window after a peak, in milliseconds
    {}

    Returns
    -------
    peak_locations: np.array
        Array with estimated x-y location for each spike
    """
    assert method in ('center_of_mass',), f"Method {method} is not supported"

    # find channel neighbours
    assert local_radius_um is not None
    channel_distance = get_channel_distances(recording)
    neighbours_mask = channel_distance < local_radius_um

    # convert the windows from milliseconds to samples
    nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.)
    nafter = int(ms_after * recording.get_sampling_frequency() / 1000.)

    contact_locations = recording.get_channel_locations()

    # TODO
    # make a memmap for peaks to avoid serilisation

    # and run
    func = _localize_peaks_chunk
    init_func = _init_worker_localize_peaks
    init_args = (recording.to_dict(), peaks, method, nbefore, nafter,
                 neighbours_mask, contact_locations)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                       handle_returns=True, job_name='localize peaks',
                                       **job_kwargs)
    peak_locations = processor.run()
    peak_locations = np.concatenate(peak_locations)

    return peak_locations
def run_for_all_spikes(self, file_path, max_channels_per_template=16, peak_sign='neg',
                       **job_kwargs):
    """
    Project every spike of the sorting onto the fitted PCA model(s).

    Waveforms have to be extracted for each individual spike, so this is a
    long computation. It is used mainly by `export_to_phy()`.

    The projections are written to a single .npy file at `file_path`.
    """
    params = self._params
    we = self.waveform_extractor
    sorting = we.sorting
    recording = we.recording

    assert sorting.get_num_segments() == 1
    assert params['mode'] in ('by_channel_local', 'by_channel_global')

    file_path = Path(file_path)

    # single segment (asserted above): take its spike times and unit indices
    spike_times, spike_labels = sorting.get_all_spike_trains(outputs='unit_index')[0]

    num_chans = min(max_channels_per_template, we.recording.get_num_channels())
    best_channels_index = get_template_best_channels(we, num_chans, peak_sign=peak_sign,
                                                     outputs='index')
    unit_channels = [best_channels_index[u] for u in sorting.unit_ids]

    if params['mode'] == 'by_channel_global':
        # one shared model, replicated so there is one entry per channel
        all_pca = [self._fit_by_channel_global()] * recording.get_num_channels()
    else:
        # 'by_channel_local' (the only other mode allowed by the assert)
        all_pca = self._fit_by_channel_local()

    # layout follows phy template-gui: (nSpikes, nFeaturesPerChannel, nPCFeatures)
    # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets
    shape = (spike_times.size, params['n_components'], num_chans)
    all_pcs = np.lib.format.open_memmap(file_path, mode='w+', dtype='float32', shape=shape)

    n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
    # with one job the recording object is passed as-is; otherwise its dict
    # form is passed (presumably so it can be rebuilt in each worker)
    recording_arg = recording if n_jobs == 1 else recording.to_dict()
    init_args = (recording_arg, all_pcs, spike_times, spike_labels,
                 we.nbefore, we.nafter, unit_channels, all_pca)

    processor = ChunkRecordingExecutor(recording, _all_pc_extractor_chunk,
                                       _init_work_all_pc_extractor, init_args,
                                       job_name='extract PCs', **job_kwargs)
    processor.run()
def find_spike_from_templates(recording, waveform_extractor, method='simple',
                              method_kwargs=None, **job_kwargs):
    """
    Find spikes in a recording from known templates.

    Templates are given as a WaveformExtractor so that statistics can be extracted.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    waveform_extractor: WaveformExtractor
        The waveform extractor holding the templates
    method: 'simple'
        Matching method (only 'simple' is currently supported)
    method_kwargs: dict, optional
        Parameters for the matching method. For 'simple' they are passed
        through check_kwargs_simple_matching() before use.
    job_kwargs: dict
        Parameters for ChunkRecordingExecutor

    Returns
    -------
    spikes: np.array
        Found spikes, concatenated over all chunks
    """
    assert method in ('simple', )

    # avoid a shared mutable default argument
    if method_kwargs is None:
        method_kwargs = {}

    if method == 'simple':
        method_kwargs = check_kwargs_simple_matching(recording, waveform_extractor, method_kwargs)

    # and run
    func = _find_spike_chunk
    init_func = _init_worker_find_spike
    init_args = (recording.to_dict(), method, method_kwargs)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                       handle_returns=True, job_name='find spikes',
                                       **job_kwargs)
    spikes = processor.run()

    spikes = np.concatenate(spikes)
    return spikes
def detect_peaks(recording, method='by_channel', peak_sign='neg', detect_threshold=5,
                 n_shifts=2, local_radius_um=100, noise_levels=None,
                 random_chunk_kwargs=None, outputs='numpy_compact', **job_kwargs):
    """
    Peak detection ported from tridesclous into spikeinterface.

    Peak detection is based on threshold crossing in terms of k x MAD.
    If the MAD (noise levels) is not provided, it is estimated from random snippets.

    Several methods:

      * 'by_channel' : peaks are detected on each channel independently
      * 'locally_exclusive' : within a given radius only the best peak is kept,
        not the ones on neighboring channels

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    method: 'by_channel', 'locally_exclusive'
        Detection method (see above). 'locally_exclusive' requires numba.
    peak_sign: 'neg', 'pos', 'both'
        Sign of the peak.
    detect_threshold: float
        Threshold in median absolute deviations (MAD) to detect peaks
    n_shifts: int
        Number of shifts to find peak. E.g. if n_shift is 2, a peak is detected
        (if detect_sign is 'negative') if a sample is below the threshold, the
        two samples before are higher than the sample, and the two samples after
        the sample are higher than the sample.
    local_radius_um: float
        Radius, in micrometers, of the channel neighborhood used by 'locally_exclusive'
    noise_levels: np.array, optional
        noise_levels can be provided externally if already computed.
    random_chunk_kwargs: dict, optional
        A dict that contains options to randomize chunks for get_noise_levels().
        Only used if noise_levels is None
    outputs: 'numpy_compact', 'numpy_split', 'sorting'
        The type of the output. By default "numpy_compact" gives a vector with a
        structured dtype (sample_ind, channel_ind, amplitude, segment_ind).
        'numpy_split' gives the tuple
        (sample_inds, channel_inds, amplitudes, segment_inds).
        'sorting' is not implemented yet and raises NotImplementedError.
    job_kwargs: dict
        Parameters for ChunkRecordingExecutor
    """
    assert method in ('by_channel', 'locally_exclusive')
    assert peak_sign in ('both', 'neg', 'pos')
    assert outputs in ('numpy_compact', 'numpy_split', 'sorting')

    if outputs == 'sorting':
        # @alessio : here we can do what you did in old API
        # the output is a sorting where unit_id is in fact one channel
        # raised up front so no detection work is wasted
        raise NotImplementedError

    if method == 'locally_exclusive' and not HAVE_NUMBA:
        raise ModuleNotFoundError(
            '"locally_exclusive" need numba which is not installed')

    # avoid a shared mutable default argument
    if random_chunk_kwargs is None:
        random_chunk_kwargs = {}

    if noise_levels is None:
        noise_levels = get_noise_levels(recording, **random_chunk_kwargs)

    abs_threholds = noise_levels * detect_threshold

    if method == 'locally_exclusive':
        assert local_radius_um is not None
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < local_radius_um
    else:
        neighbours_mask = None

    # and run
    func = _detect_peaks_chunk
    init_func = _init_worker_detect_peaks
    init_args = (recording.to_dict(), method, peak_sign, abs_threholds, n_shifts,
                 neighbours_mask)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                       handle_returns=True, **job_kwargs)
    peaks = processor.run()
    peak_sample_inds, peak_chan_inds, peak_amplitudes, peak_segments = zip(*peaks)

    peak_sample_inds = np.concatenate(peak_sample_inds)
    peak_chan_inds = np.concatenate(peak_chan_inds)
    peak_amplitudes = np.concatenate(peak_amplitudes)
    peak_segments = np.concatenate(peak_segments)

    if outputs == 'numpy_compact':
        dtype = [('sample_ind', 'int64'), ('channel_ind', 'int64'),
                 ('amplitude', 'float64'), ('segment_ind', 'int64')]
        peaks = np.zeros(peak_sample_inds.size, dtype=dtype)
        peaks['sample_ind'] = peak_sample_inds
        peaks['channel_ind'] = peak_chan_inds
        peaks['amplitude'] = peak_amplitudes
        peaks['segment_ind'] = peak_segments
        return peaks

    # outputs == 'numpy_split'
    return peak_sample_inds, peak_chan_inds, peak_amplitudes, peak_segments
def detect_peaks(recording, method='by_channel', peak_sign='neg', detect_threshold=5,
                 n_shifts=2, local_radius_um=50, noise_levels=None,
                 random_chunk_kwargs=None, outputs='numpy_compact',
                 localization_dict=None, **job_kwargs):
    """Peak detection based on threshold crossing in terms of k x MAD.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object.
    method: 'by_channel', 'locally_exclusive'
        Method to use. Options:
            * 'by_channel' : peaks are detected on each channel independently
            * 'locally_exclusive' : a single best peak is taken from a set of neighboring channels
    peak_sign: 'neg', 'pos', 'both'
        Sign of the peak.
    detect_threshold: float
        Threshold, in median absolute deviations (MAD), to use to detect peaks.
    n_shifts: int
        Number of shifts to find peak.
        For example, if `n_shift` is 2, a peak is detected if a sample crosses the threshold,
        and the two samples before and after are above the sample.
    local_radius_um: float
        The radius to use for detection across local channels.
    noise_levels: array, optional
        Estimated noise levels to use, if already computed.
        If not provided, they are estimated from random snippets of the data.
    random_chunk_kwargs: dict, optional
        A dict that contains options to randomize chunks for get_noise_levels().
        Only used if noise_levels is None.
    outputs: 'numpy_compact', 'sorting'
        The type of the output. By default, "numpy_compact" returns an array with a
        structured dtype. In case of 'sorting', each unit corresponds to a recording channel.
        'numpy_split' is accepted by the signature but not implemented yet and
        raises NotImplementedError.
    localization_dict : dict, optional
        Can optionally do peak localization at the same time as detection.
        This avoids running `localize_peaks` separately and re-reading the entire dataset.
    {}

    Returns
    -------
    peaks: array
        Detected peaks.

    Notes
    -----
    This peak detection was ported from tridesclous into spikeinterface.
    """
    assert method in ('by_channel', 'locally_exclusive')
    assert peak_sign in ('both', 'neg', 'pos')
    assert outputs in ('numpy_compact', 'numpy_split', 'sorting')

    if outputs == 'numpy_split':
        # this mode used to fall through and silently return None after the full run
        raise NotImplementedError('outputs="numpy_split" is not implemented')

    if method == 'locally_exclusive' and not HAVE_NUMBA:
        raise ModuleNotFoundError(
            '"locally_exclusive" need numba which is not installed')

    # avoid a shared mutable default argument
    if random_chunk_kwargs is None:
        random_chunk_kwargs = {}

    if noise_levels is None:
        noise_levels = get_noise_levels(recording, return_scaled=False, **random_chunk_kwargs)

    abs_threholds = noise_levels * detect_threshold

    if method == 'locally_exclusive':
        assert local_radius_um is not None
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < local_radius_um
    else:
        neighbours_mask = None

    # deal with margin: when localizing, chunks need max(nbefore, nafter)
    # extra samples at their borders
    if localization_dict is None:
        extra_margin = 0
    else:
        assert isinstance(localization_dict, dict)
        assert localization_dict['method'] in dtype_localize_by_method.keys()
        localization_dict = init_kwargs_dict(localization_dict['method'], localization_dict)

        nbefore = int(localization_dict['ms_before'] * recording.get_sampling_frequency() / 1000.)
        nafter = int(localization_dict['ms_after'] * recording.get_sampling_frequency() / 1000.)
        extra_margin = max(nbefore, nafter)

    # and run
    func = _detect_peaks_chunk
    init_func = _init_worker_detect_peaks
    init_args = (recording.to_dict(), method, peak_sign, abs_threholds, n_shifts,
                 neighbours_mask, extra_margin, localization_dict)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                       handle_returns=True, job_name='detect peaks',
                                       **job_kwargs)
    peaks = processor.run()
    peaks = np.concatenate(peaks)

    if outputs == 'numpy_compact':
        return peaks
    elif outputs == 'sorting':
        return NumpySorting.from_peaks(
            peaks, sampling_frequency=recording.get_sampling_frequency())
def compute_amplitudes(self, **job_kwargs):
    """
    Extract one amplitude per spike and store them per segment.

    The amplitudes are kept in `self._amplitudes` (one array per segment) and
    each segment is also saved to
    `self.extension_folder / 'amplitude_segment_<segment_index>.npy'`.
    The spike trains used are stored in `self._all_spikes`.
    """
    we = self.waveform_extractor
    recording = we.recording
    sorting = we.sorting

    all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
    self._all_spikes = all_spikes

    params = self._params
    peak_sign = params['peak_sign']
    return_scaled = params['return_scaled']

    extremum_by_unit = get_template_extremum_channel(we, peak_sign=peak_sign, outputs='index')
    shift_by_unit = get_template_extremum_channel_peak_shift(we, peak_sign=peak_sign)

    # turn the per-unit dicts into int64 vectors ordered like sorting.unit_ids
    extremum_channels_index = np.array(
        [extremum_by_unit[u] for u in sorting.unit_ids], dtype='int64')
    peak_shifts = np.array(
        [shift_by_unit[u] for u in sorting.unit_ids], dtype='int64')

    # fall back to unscaled values when the recording cannot provide scaled traces
    if return_scaled and not we.recording.has_scaled_traces():
        print("Setting 'return_scaled' to False")
        return_scaled = False

    n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
    if n_jobs == 1:
        base_args = (recording, sorting)
    else:
        base_args = (recording.to_dict(), sorting.to_dict())
    init_args = base_args + (extremum_channels_index, peak_shifts, return_scaled)

    processor = ChunkRecordingExecutor(recording, _spike_amplitudes_chunk,
                                       _init_worker_spike_amplitudes, init_args,
                                       handle_returns=True, job_name='extract amplitudes',
                                       **job_kwargs)
    chunk_amps, chunk_segments = zip(*processor.run())
    amps = np.concatenate(chunk_amps)
    segments = np.concatenate(chunk_segments)

    self._amplitudes = []
    for segment_index in range(recording.get_num_segments()):
        amps_seg = amps[segments == segment_index]
        self._amplitudes.append(amps_seg)

        # save to folder
        np.save(self.extension_folder / f'amplitude_segment_{segment_index}.npy', amps_seg)
def detect_peaks(recording, method='by_channel', peak_sign='neg', detect_threshold=5,
                 n_shifts=2, local_radius_um=50, noise_levels=None,
                 random_chunk_kwargs=None, outputs='numpy_compact',
                 localization_dict=None, **job_kwargs):
    """
    Peak detection ported from tridesclous into spikeinterface.

    Peak detection is based on threshold crossing in terms of k x MAD.
    If the MAD (noise levels) is not provided, it is estimated from random snippets.

    Several methods:

      * 'by_channel' : peaks are detected on each channel independently
      * 'locally_exclusive' : within a given radius only the best peak is kept,
        not the ones on neighboring channels

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    method: 'by_channel', 'locally_exclusive'
        Detection method (see above). 'locally_exclusive' requires numba.
    peak_sign: 'neg', 'pos', 'both'
        Sign of the peak.
    detect_threshold: float
        Threshold in median absolute deviations (MAD) to detect peaks
    n_shifts: int
        Number of shifts to find peak. E.g. if n_shift is 2, a peak is detected
        (if detect_sign is 'negative') if a sample is below the threshold, the
        two samples before are higher than the sample, and the two samples after
        the sample are higher than the sample.
    local_radius_um: float
        Radius, in micrometers, of the channel neighborhood used by 'locally_exclusive'
    noise_levels: np.array, optional
        noise_levels can be provided externally if already computed.
    random_chunk_kwargs: dict, optional
        A dict that contains options to randomize chunks for get_noise_levels().
        Only used if noise_levels is None
    outputs: 'numpy_compact'
        The type of the output. "numpy_compact" gives a vector with a structured dtype.
        'numpy_split' and 'sorting' are accepted by the signature but not implemented
        yet; they raise NotImplementedError before any work is done.
    localization_dict : None or dict
        Can optionally do peak localization at the same time as detection.
        This avoids running localize_peaks separately and re-reading the entire dataset.
    job_kwargs: dict
        Parameters for ChunkRecordingExecutor
    """
    assert method in ('by_channel', 'locally_exclusive')
    assert peak_sign in ('both', 'neg', 'pos')
    assert outputs in ('numpy_compact', 'numpy_split', 'sorting')

    if outputs == 'sorting':
        # @alessio : here we can do what you did in old API
        # the output is a sorting where unit_id is in fact one channel
        raise NotImplementedError
    if outputs == 'numpy_split':
        # this mode used to fall through and silently return None after the full run
        raise NotImplementedError('outputs="numpy_split" is not implemented')

    if method == 'locally_exclusive' and not HAVE_NUMBA:
        raise ModuleNotFoundError(
            '"locally_exclusive" need numba which is not installed')

    # avoid a shared mutable default argument
    if random_chunk_kwargs is None:
        random_chunk_kwargs = {}

    if noise_levels is None:
        noise_levels = get_noise_levels(recording, **random_chunk_kwargs)

    abs_threholds = noise_levels * detect_threshold

    if method == 'locally_exclusive':
        assert local_radius_um is not None
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < local_radius_um
    else:
        neighbours_mask = None

    # deal with margin: when localizing, chunks need max(nbefore, nafter)
    # extra samples at their borders
    if localization_dict is None:
        extra_margin = 0
    else:
        assert isinstance(localization_dict, dict)
        assert localization_dict['method'] in dtype_localize_by_method.keys()
        localization_dict = init_kwargs_dict(localization_dict['method'], localization_dict)

        nbefore = int(localization_dict['ms_before'] * recording.get_sampling_frequency() / 1000.)
        nafter = int(localization_dict['ms_after'] * recording.get_sampling_frequency() / 1000.)
        extra_margin = max(nbefore, nafter)

    # and run
    func = _detect_peaks_chunk
    init_func = _init_worker_detect_peaks
    init_args = (recording.to_dict(), method, peak_sign, abs_threholds, n_shifts,
                 neighbours_mask, extra_margin, localization_dict)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                       handle_returns=True, job_name='detect peaks',
                                       **job_kwargs)
    peaks = processor.run()
    peaks = np.concatenate(peaks)

    # outputs == 'numpy_compact' (the other modes were rejected above)
    return peaks
def test_ChunkRecordingExecutor():
    recording = generate_recording(num_channels=2)
    # make dumpable
    recording = recording.save()

    init_args = 'a', 120, 'yep'

    # run the executor in three modes, in order:
    # no chunk / chunk + loop / chunk + parallel
    executor_configs = [
        dict(verbose=True, progress_bar=False, n_jobs=1, chunk_size=None),
        dict(verbose=True, progress_bar=False, n_jobs=1, chunk_memory="500k"),
        dict(verbose=True, progress_bar=True, n_jobs=2, total_memory="200k",
             job_name='job_name'),
    ]
    for config in executor_configs:
        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, **config)
        processor.run()
def run_for_all_spikes(self, file_path, max_channels_per_template=16, peak_sign='neg',
                       **job_kwargs):
    """
    Project all spikes from the sorting on the PCA model.

    This is a long computation because a waveform has to be extracted for each spike.
    Used mainly for `export_to_phy()`.

    PCs are exported to a single .npy file.

    Parameters
    ----------
    file_path : str or Path
        Path to npy file that will store the PCA projections
    max_channels_per_template : int, optional
        Maximum number of best channels to compute PCA projections on
    peak_sign : str, optional
        Peak sign to get best channels ('neg', 'pos', 'both'), by default 'neg'
    {}
    """
    params = self._params
    we = self.waveform_extractor
    sorting = we.sorting
    recording = we.recording

    assert sorting.get_num_segments() == 1
    assert params['mode'] in ('by_channel_local', 'by_channel_global')

    file_path = Path(file_path)

    # single segment (asserted above): take its spike times and unit indices
    spike_times, spike_labels = sorting.get_all_spike_trains(outputs='unit_index')[0]

    num_chans = min(max_channels_per_template, we.recording.get_num_channels())
    sparsity_index = get_template_channel_sparsity(we, outputs="index", peak_sign=peak_sign,
                                                   num_channels=num_chans)
    unit_channels = [sparsity_index[u] for u in sorting.unit_ids]

    pca_model = self.get_pca_model()
    if params['mode'] in ['by_channel_global', 'concatenated']:
        # one shared model, replicated so there is one entry per channel
        pca_model = [pca_model] * recording.get_num_channels()

    # layout follows phy template-gui: (nSpikes, nFeaturesPerChannel, nPCFeatures)
    # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets
    shape = (spike_times.size, params['n_components'], num_chans)
    # create the .npy file on disk (kept referenced until the run is over) ...
    all_pcs = np.lib.format.open_memmap(filename=file_path, mode='w+', dtype='float32',
                                        shape=shape)
    # ... and hand the workers the arguments to re-open it in 'r+' mode
    all_pcs_args = dict(filename=file_path, mode='r+', dtype='float32', shape=shape)

    n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
    # with one job the recording object is passed as-is; otherwise its dict form
    recording_arg = recording if n_jobs == 1 else recording.to_dict()
    init_args = (recording_arg, all_pcs_args, spike_times, spike_labels,
                 we.nbefore, we.nafter, unit_channels, pca_model)

    processor = ChunkRecordingExecutor(recording, _all_pc_extractor_chunk,
                                       _init_work_all_pc_extractor, init_args,
                                       job_name='extract PCs', **job_kwargs)
    processor.run()
def localize_peaks(recording, peaks, ms_before=0.1, ms_after=0.3, method='center_of_mass',
                   method_kwargs=None, **job_kwargs):
    """
    Localize peaks (spikes) in 2D or 3D depending on the method.

    When a probe is 2D then:
       * X is axis 0 of the probe
       * Y is axis 1 of the probe
       * Z is orthogonal to the plane of the probe

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    peaks: array
        Peak vector given by detect_peaks() in "numpy_compact" mode.
    ms_before: float
        The window before a peak, in milliseconds
    ms_after: float
        The window after a peak, in milliseconds
    method: str
        Method to be used ('center_of_mass' or 'monopolar_triangulation')
    method_kwargs: dict, optional
        Keyword arguments for the method; defaults are handled by init_kwargs_dict().
        'center_of_mass':
            * local_radius_um: float
                For channel sparsity
        'monopolar_triangulation':
            * local_radius_um: float
                For channel sparsity
            * max_distance_um: float, default 1000
                Boundary for distance estimation
    job_kwargs: dict
        Parameters for ChunkRecordingExecutor

    Returns
    -------
    peak_locations: np.array
        Array with estimated location for each spike.
        The dtype depends on the method: ('x', 'y') or ('x', 'y', 'z', 'alpha')
    """
    assert method in possible_localization_methods, f"Method {method} is not supported. Choose from {possible_localization_methods}"

    # avoid a shared mutable default argument
    if method_kwargs is None:
        method_kwargs = {}

    # handle default method_kwargs
    method_kwargs = init_kwargs_dict(method, method_kwargs)

    # convert the windows from milliseconds to samples
    nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.)
    nafter = int(ms_after * recording.get_sampling_frequency() / 1000.)

    contact_locations = recording.get_channel_locations()

    # margin at border for get_trace
    margin = max(nbefore, nafter)

    # TODO
    # make a memmap for peaks to avoid serilisation

    # and run
    func = _localize_peaks_chunk
    init_func = _init_worker_localize_peaks
    init_args = (recording.to_dict(), peaks, method, method_kwargs, nbefore, nafter,
                 contact_locations, margin)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                       handle_returns=True, job_name='localize peaks',
                                       **job_kwargs)
    peak_locations = processor.run()

    peak_locations = np.concatenate(peak_locations)

    return peak_locations
def get_spike_amplitudes(waveform_extractor, peak_sign='neg', outputs='concatenated',
                         return_scaled=True, **job_kwargs):
    """
    Computes the spike amplitudes from a WaveformExtractor.

    1. The waveform extractor is used to determine the max channel per unit.
    2. Then a "peak_shift" is estimated because for some sorters the spike index
       is not always at the peak.
    3. Amplitudes are extracted in chunks (parallel or not)

    Parameters
    ----------
    waveform_extractor: WaveformExtractor
        The waveform extractor object
    peak_sign: str
        The sign used to compute the maximum channel and its peak shift:
            - 'neg'
            - 'pos'
            - 'both'
    return_scaled: bool
        If True and recording has gain_to_uV/offset_to_uV properties, amplitudes are converted to uV.
    outputs: str
        How the output should be returned:
            - 'concatenated'
            - 'by_unit'
    {}

    Returns
    -------
    amplitudes: list
        - If 'concatenated', a list (one entry per segment) of arrays holding the
          amplitudes of all spikes of all units in that segment
        - If 'by_unit', a list (one entry per segment) of dictionaries mapping
          each unit id to its amplitudes
    """
    # validate before the (possibly long) chunked run
    assert outputs in ('concatenated', 'by_unit')

    we = waveform_extractor
    recording = we.recording
    sorting = we.sorting
    num_segments = recording.get_num_segments()

    all_spikes = sorting.get_all_spike_trains()

    extremum_channels_index = get_template_extremum_channel(
        waveform_extractor, peak_sign=peak_sign, outputs='index')
    # the peak shift must use the same sign as the extremum channel
    # (it was previously hard-coded to 'neg')
    peak_shifts = get_template_extremum_channel_peak_shift(
        waveform_extractor, peak_sign=peak_sign)

    if return_scaled:
        # check if has scaled values:
        if not waveform_extractor.recording.has_scaled_traces():
            print("Setting 'return_scaled' to False")
            return_scaled = False

    # and run
    func = _spike_amplitudes_chunk
    init_func = _init_worker_spike_amplitudes
    init_args = (recording.to_dict(), sorting.to_dict(), extremum_channels_index,
                 peak_shifts, return_scaled)
    processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                       handle_returns=True, job_name='extract amplitudes',
                                       **job_kwargs)
    out = processor.run()
    amps, segments = zip(*out)
    amps = np.concatenate(amps)
    segments = np.concatenate(segments)

    # regroup the amplitudes per segment
    amplitudes = [amps[segments == segment_index] for segment_index in range(num_segments)]

    if outputs == 'concatenated':
        return amplitudes

    # outputs == 'by_unit'
    amplitudes_by_unit = []
    for segment_index in range(num_segments):
        _, spike_labels = all_spikes[segment_index]
        seg_amplitudes = amplitudes[segment_index]
        amplitudes_by_unit.append({
            unit_id: seg_amplitudes[spike_labels == unit_id]
            for unit_id in sorting.unit_ids
        })
    return amplitudes_by_unit
def test_ChunkRecordingExecutor():
    recording = generate_recording(num_channels=2)
    # make dumpable
    recording = recording.save()

    def func(segment_index, start_frame, end_frame, worker_ctx):
        # simulate a little work per chunk and report the pid of the process that ran it
        import os
        import time
        time.sleep(0.010)
        return os.getpid()

    def init_func(arg1, arg2, arg3):
        # the worker context simply records the init arguments
        return {'arg1': arg1, 'arg2': arg2, 'arg3': arg3}

    init_args = 'a', 120, 'yep'

    # run the executor in three modes, in order:
    # no chunk / chunk + loop / chunk + parallel
    executor_configs = [
        dict(verbose=True, progress_bar=False, n_jobs=1, chunk_size=None),
        dict(verbose=True, progress_bar=False, n_jobs=1, chunk_memory="500k"),
        dict(verbose=True, progress_bar=True, n_jobs=2, total_memory="200k",
             job_name='job_name'),
    ]
    for config in executor_configs:
        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, **config)
        processor.run()