def test_ensure_n_jobs():
    """Check how ensure_n_jobs resolves the requested number of jobs."""
    recording = generate_recording()

    # Leaving n_jobs unspecified, or asking for 0 or 1, all resolve to one job.
    assert ensure_n_jobs(recording) == 1
    for requested in (0, 1):
        assert ensure_n_jobs(recording, n_jobs=requested) == 1

    # The generated recording is not dumpable, so n_jobs=-1 is forced down to 1.
    assert ensure_n_jobs(recording, n_jobs=-1) == 1

    # A saved recording is dumpable, so n_jobs=-1 resolves to more than one job.
    saved = recording.save()
    assert ensure_n_jobs(saved, n_jobs=-1) > 1
def run_for_all_spikes(self, file_path, max_channels_per_template=16, peak_sign='neg', **job_kwargs):
    """
    Compute PC projections for every spike of a single-segment sorting.

    Waveforms must be extracted for each spike, so this is slow; it is mainly
    used by `export_to_phy()`. All projections go into one .npy file of shape
    (num_spikes, n_components, max_channels_per_template).
    """
    params = self._params
    extractor = self.waveform_extractor
    sort = extractor.sorting
    rec = extractor.recording

    assert sort.get_num_segments() == 1
    assert params['mode'] in ('by_channel_local', 'by_channel_global')

    file_path = Path(file_path)

    # Only segment 0 exists (checked above).
    spike_times, spike_labels = sort.get_all_spike_trains(outputs='unit_index')[0]

    n_best = min(max_channels_per_template, extractor.recording.get_num_channels())
    best_by_unit = get_template_best_channels(extractor, n_best, peak_sign=peak_sign, outputs='index')
    unit_channels = [best_by_unit[uid] for uid in sort.unit_ids]

    if params['mode'] == 'by_channel_local':
        pca_per_channel = self._fit_by_channel_local()
    elif params['mode'] == 'by_channel_global':
        # One shared model, repeated once per recording channel.
        shared_pca = self._fit_by_channel_global()
        pca_per_channel = [shared_pca] * rec.get_num_channels()

    # Layout (nSpikes, nFeaturesPerChannel, nPCFeatures) follows phy template-gui:
    # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets
    out_shape = (spike_times.size, params['n_components'], n_best)
    all_pcs = np.lib.format.open_memmap(file_path, mode='w+', dtype='float32', shape=out_shape)

    n_jobs = ensure_n_jobs(rec, job_kwargs.get('n_jobs', None))
    # With several jobs the recording is passed as its dict form instead of the object.
    rec_arg = rec if n_jobs == 1 else rec.to_dict()
    # NOTE(review): the memmap itself is handed to the workers; if the executor pickles
    # initargs into spawned processes, the workers may write into a copy rather than the
    # file -- TODO confirm against ChunkRecordingExecutor's start method.
    init_args = (rec_arg, all_pcs, spike_times, spike_labels,
                 extractor.nbefore, extractor.nafter, unit_channels, pca_per_channel)

    executor = ChunkRecordingExecutor(rec, _all_pc_extractor_chunk, _init_work_all_pc_extractor,
                                      init_args, job_name='extract PCs', **job_kwargs)
    executor.run()
def compute_amplitudes(self, **job_kwargs):
    """
    Compute spike amplitudes for all spikes and store them per segment.

    The chunk workers receive, for every unit (in ``sorting.unit_ids`` order), the
    index of its template extremum channel and its template peak shift. The
    per-chunk results are concatenated, split by segment into ``self._amplitudes``
    and saved as ``amplitude_segment_<i>.npy`` in ``self.extension_folder``.
    """
    we = self.waveform_extractor
    recording = we.recording
    sorting = we.sorting

    self._all_spikes = sorting.get_all_spike_trains(outputs='unit_index')

    peak_sign = self._params['peak_sign']
    return_scaled = self._params['return_scaled']

    extremum_by_unit = get_template_extremum_channel(we, peak_sign=peak_sign, outputs='index')
    shift_by_unit = get_template_extremum_channel_peak_shift(we, peak_sign=peak_sign)

    # Turn the per-unit dicts into int64 vectors aligned with sorting.unit_ids.
    extremum_channels_index = np.array([extremum_by_unit[uid] for uid in sorting.unit_ids], dtype='int64')
    peak_shifts = np.array([shift_by_unit[uid] for uid in sorting.unit_ids], dtype='int64')

    # Scaled amplitudes are only possible if the recording can provide scaled traces.
    if return_scaled and not we.recording.has_scaled_traces():
        print("Setting 'return_scaled' to False")
        return_scaled = False

    n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
    if n_jobs == 1:
        base_args = (recording, sorting)
    else:
        # With several jobs, recording and sorting are passed as their dict forms.
        base_args = (recording.to_dict(), sorting.to_dict())
    init_args = (*base_args, extremum_channels_index, peak_shifts, return_scaled)

    executor = ChunkRecordingExecutor(recording, _spike_amplitudes_chunk, _init_worker_spike_amplitudes,
                                      init_args, handle_returns=True, job_name='extract amplitudes',
                                      **job_kwargs)
    chunk_results = executor.run()

    # Each chunk returns an (amplitudes, segment indices) pair.
    amp_parts, seg_parts = zip(*chunk_results)
    amps = np.concatenate(amp_parts)
    segments = np.concatenate(seg_parts)

    self._amplitudes = []
    for segment_index in range(recording.get_num_segments()):
        amps_seg = amps[segments == segment_index]
        self._amplitudes.append(amps_seg)
        np.save(self.extension_folder / f'amplitude_segment_{segment_index}.npy', amps_seg)
def run_for_all_spikes(self, file_path, max_channels_per_template=16, peak_sign='neg', **job_kwargs):
    """
    Project all spikes from the sorting on the PCA model.

    This is a long computation because waveforms need to be extracted for each spike.
    Used mainly for `export_to_phy()`.
    PCs are exported to a single .npy file of shape
    (num_spikes, n_components, max_channels_per_template).

    Parameters
    ----------
    file_path : str or Path
        Path to npy file that will store the PCA projections
    max_channels_per_template : int, optional
        Maximum number of best channels to compute PCA projections on
        (capped at the number of recording channels), by default 16
    peak_sign : str, optional
        Peak sign to get best channels ('neg', 'pos', 'both'), by default 'neg'
    {}
    """
    p = self._params
    we = self.waveform_extractor
    sorting = we.sorting
    recording = we.recording

    assert sorting.get_num_segments() == 1, \
        "run_for_all_spikes() only supports sortings with a single segment"
    assert p['mode'] in ('by_channel_local', 'by_channel_global'), \
        f"run_for_all_spikes() does not support PCA mode {p['mode']!r}"

    file_path = Path(file_path)

    # Only segment 0 exists (checked above).
    all_spikes = sorting.get_all_spike_trains(outputs='unit_index')
    spike_times, spike_labels = all_spikes[0]

    max_channels_per_template = min(max_channels_per_template, recording.get_num_channels())
    best_channels_index = get_template_channel_sparsity(we, outputs="index", peak_sign=peak_sign,
                                                        num_channels=max_channels_per_template)
    unit_channels = [best_channels_index[unit_id] for unit_id in sorting.unit_ids]

    pca_model = self.get_pca_model()
    # NOTE: 'concatenated' is currently rejected by the mode assert above, so only
    # 'by_channel_global' reaches this branch. The single shared model is repeated once
    # per channel (presumably so workers can index it per channel as in the local mode).
    if p['mode'] in ['by_channel_global', 'concatenated']:
        pca_model = [pca_model] * recording.get_num_channels()

    # Layout (nSpikes, nFeaturesPerChannel, nPCFeatures) follows phy template-gui:
    # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets
    shape = (spike_times.size, p['n_components'], max_channels_per_template)
    # Create the .npy file (header plus zero-filled data) once in this process...
    all_pcs = np.lib.format.open_memmap(filename=file_path, mode='w+', dtype='float32', shape=shape)
    # ...then drop this unused writable mapping: the workers reopen the file themselves
    # from these plain arguments, so no memmap object has to be sent to them.
    del all_pcs
    all_pcs_args = dict(filename=file_path, mode='r+', dtype='float32', shape=shape)

    n_jobs = ensure_n_jobs(recording, job_kwargs.get('n_jobs', None))
    if n_jobs == 1:
        init_args = (recording,)
    else:
        # With several jobs the recording is passed as its dict form instead of the object.
        init_args = (recording.to_dict(),)
    init_args = init_args + (all_pcs_args, spike_times, spike_labels, we.nbefore, we.nafter,
                             unit_channels, pca_model)
    processor = ChunkRecordingExecutor(recording, _all_pc_extractor_chunk, _init_work_all_pc_extractor,
                                       init_args, job_name='extract PCs', **job_kwargs)
    processor.run()