def setup_module():
    """Build the toy datasets shared by the tests of this module.

    Any folder left over from a previous run is deleted first. Then, for a
    2-segment toy example and a 1-segment toy example with 12 channels, the
    recording and sorting are saved to disk and waveforms are extracted from
    the saved copies.
    """
    stale_folders = ('toy_rec_1seg', 'toy_sorting_1seg', 'toy_waveforms_1seg',
                     'toy_rec_2seg', 'toy_sorting_2seg', 'toy_waveforms_2seg')
    for name in stale_folders:
        if Path(name).is_dir():
            shutil.rmtree(name)

    # (toy_example kwargs, recording folder, sorting folder, waveform folder)
    configs = (
        (dict(num_segments=2, num_units=10),
         'toy_rec_2seg', 'toy_sorting_2seg', 'toy_waveforms_2seg'),
        (dict(num_segments=1, num_units=10, num_channels=12),
         'toy_rec_1seg', 'toy_sorting_1seg', 'toy_waveforms_1seg'),
    )
    for toy_kwargs, rec_folder, sort_folder, wf_folder in configs:
        rec, sort = toy_example(**toy_kwargs)
        saved_rec = rec.save(folder=rec_folder)
        saved_sort = sort.save(folder=sort_folder)
        extract_waveforms(saved_rec, saved_sort, wf_folder,
                          ms_before=3., ms_after=4., max_spikes_per_unit=500,
                          n_jobs=1, chunk_size=30000)
def test_export_to_phy_by_property():
    """Export to Phy with 'by_property' sparsity, before and after removing a channel.

    Channels and units are split into two groups of four, so every exported
    ``template_ind`` row has 4 entries. After channel 1 is sliced away, group 0
    only has 3 channels left, so -1 padding must appear in ``template_ind``.
    """
    num_units = 4
    recording, sorting = se.toy_example(num_channels=8, duration=10, num_units=num_units, num_segments=1)
    recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1])
    sorting.set_property("group", [0, 0, 1, 1])

    waveform_folder, waveform_folder_rm = Path('waveforms'), Path('waveforms_rm')
    output_folder, output_folder_rm = Path('phy_output'), Path('phy_output_rm')
    rec_folder, sort_folder = Path("rec"), Path("sort")
    for stale in (waveform_folder, waveform_folder_rm, output_folder, output_folder_rm, rec_folder, sort_folder):
        if stale.is_dir():
            shutil.rmtree(stale)

    recording = recording.save(folder=rec_folder)
    sorting = sorting.save(folder=sort_folder)

    def export_and_load(wf_extractor, out_folder):
        # Every export gets freshly built kwargs (including its own sparsity dict).
        export_to_phy(wf_extractor, out_folder,
                      compute_pc_features=True, compute_amplitudes=True, max_channels_per_template=8,
                      sparsity_dict=dict(method="by_property", by_property="group"),
                      n_jobs=1, chunk_size=10000, progress_bar=True)
        return np.load(out_folder / "template_ind.npy")

    waveform_extractor = extract_waveforms(recording, sorting, waveform_folder)
    template_inds = export_and_load(waveform_extractor, output_folder)
    assert template_inds.shape == (num_units, 4)

    # Remove one channel
    recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7])
    waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm)
    template_inds = export_and_load(waveform_extractor_rm, output_folder_rm)
    assert template_inds.shape == (num_units, 4)
    assert (template_inds == -1).any()
def test_portability():
    """A waveform folder saved with relative paths must still load after being copied.

    Recording, sorting and waveforms are all written under ``original_folder``.
    That tree is copied to ``moved_folder``, and the copy is reloaded and
    compared with the original extractor.
    """
    durations = [30, 40]
    sampling_frequency = 30000.

    folder_to_move = Path("original_folder")
    folder_moved = Path("moved_folder")
    if folder_to_move.is_dir():
        shutil.rmtree(folder_to_move)
    folder_to_move.mkdir()
    if folder_moved.is_dir():
        shutil.rmtree(folder_moved)

    # 2 segments, everything stored inside the folder that will be copied
    recording = generate_recording(num_channels=2, durations=durations,
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    recording = recording.save(folder=folder_to_move / "rec")

    sorting = generate_sorting(num_units=15, sampling_frequency=sampling_frequency,
                               durations=durations)
    sorting = sorting.save(folder=folder_to_move / "sort")

    wf_folder = folder_to_move / "waveform_extractor"
    if wf_folder.is_dir():
        shutil.rmtree(wf_folder)
    # relative paths are what keep the copied tree self-contained
    we = extract_waveforms(recording, sorting, wf_folder, use_relative_path=True)

    shutil.copytree(folder_to_move, folder_moved)
    we_loaded = extract_waveforms(recording, sorting, folder_moved / "waveform_extractor",
                                  load_if_exists=True)

    assert we_loaded.recording is not None
    assert we_loaded.sorting is not None
    assert np.allclose(we.recording.get_channel_ids(), we_loaded.recording.get_channel_ids())
    assert np.allclose(we.sorting.get_unit_ids(), we_loaded.sorting.get_unit_ids())
    for unit_id in we.sorting.get_unit_ids():
        assert np.allclose(we.get_waveforms(unit_id=unit_id),
                           we_loaded.get_waveforms(unit_id=unit_id))
def test_get_unit_amplitudes():
    """Smoke-test get_unit_amplitudes on the MEArec test file for two output layouts."""
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording = se.MEArecRecordingExtractor(local_path)
    sorting = se.MEArecSortingExtractor(local_path)
    we = extract_waveforms(recording, sorting, 'mearec_waveforms',
                           ms_before=1., ms_after=2., max_spikes_per_unit=500,
                           n_jobs=1, chunk_size=30000, load_if_exists=True)

    # NOTE(review): other tests pass outputs='by_unit'; confirm 'by_units' is the
    # spelling this (older) get_unit_amplitudes API expects.
    for outputs in ('concatenated', 'by_units'):
        get_unit_amplitudes(we, peak_sign='neg', outputs=outputs, chunk_size=10000, n_jobs=1)
def test_export_to_phy():
    """Extract waveforms from the MEArec test file and export them to Phy.

    This test uses the positional ``(recording, sorting, output_folder,
    waveform_extractor)`` calling convention of ``export_to_phy``.
    """
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording = se.MEArecRecordingExtractor(local_path)
    sorting = se.MEArecSortingExtractor(local_path)

    waveform_folder, output_folder = Path('waveforms'), Path('phy_output')
    for stale in (waveform_folder, output_folder):
        if stale.is_dir():
            shutil.rmtree(stale)

    waveform_extractor = extract_waveforms(recording, sorting, waveform_folder)
    export_to_phy(recording, sorting, output_folder, waveform_extractor,
                  compute_pc_features=True, compute_amplitudes=True,
                  max_channels_per_template=8,
                  n_jobs=1, chunk_size=10000, progress_bar=True)
def test_get_template_channel_sparsity():
    """Exercise every sparsity method of get_template_channel_sparsity.

    Each method is run with both ``outputs='id'`` and ``outputs='index'``, and
    each result is printed.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms')
    method_params = (
        ('best_channels', dict(num_channels=5)),
        ('radius', dict(radius_um=50)),
        ('threshold', dict(threshold=3)),
    )
    for method, params in method_params:
        for outputs in ('id', 'index'):
            print(get_template_channel_sparsity(we, method=method, outputs=outputs, **params))

    # load from folder because sorting properties must be loaded
    rec = load_extractor('toy_rec')
    sort = load_extractor('toy_sort')
    we = extract_waveforms(rec, sort, 'toy_waveforms_1')
    for outputs in ('id', 'index'):
        print(get_template_channel_sparsity(we, method='by_property', outputs=outputs, by_property="group"))
def test_find_spike_from_templates():
    """Run the 'simple' template-matching method on the MEArec test recording."""
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording, gt_sorting = read_mearec(local_path)

    we = extract_waveforms(recording, gt_sorting, 'waveforms_mearec', load_if_exists=True,
                           ms_before=1, ms_after=2., max_spikes_per_unit=500,
                           n_jobs=1, chunk_size=30000)

    print(find_spike_from_templates(recording, we, method='simple'))
def test_extract_waveforms():
    """Waveforms must not depend on n_jobs, and return_scaled must apply the gain.

    A 2-segment recording is extracted with 1 job and with 2 jobs, and the two
    results must be identical. The recording is then given a gain and an offset
    of 0, and the scaled waveforms must equal the raw ones times the gain.
    """
    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments
    recording = generate_recording(num_channels=2, durations=durations,
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    recording = recording.save(folder="wf_rec2")
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency,
                               durations=durations)
    sorting = sorting.save(folder="wf_sort2")

    def fresh_folder(name):
        # drop leftovers from a previous run and hand back the path
        path = Path(name)
        if path.is_dir():
            shutil.rmtree(path)
        return path

    we1 = extract_waveforms(recording, sorting, fresh_folder('test_extract_waveforms_1job'),
                            max_spikes_per_unit=None, return_scaled=False)
    we2 = extract_waveforms(recording, sorting, fresh_folder('test_extract_waveforms_2job'),
                            n_jobs=2, total_memory="10M", max_spikes_per_unit=None,
                            return_scaled=False)
    raw = we1.get_waveforms(0)
    assert np.array_equal(raw, we2.get_waveforms(0))

    scaled_folder = fresh_folder('test_extract_waveforms_returnscaled')
    gain = 0.1
    recording.set_channel_gains(gain)
    recording.set_channel_offsets(0)
    we3 = extract_waveforms(recording, sorting, scaled_folder, n_jobs=2, total_memory="10M",
                            max_spikes_per_unit=None, return_scaled=True)
    assert np.array_equal(raw.astype("float32") * gain, we3.get_waveforms(0))
def setup_module():
    """Prepare the MEArec waveform folder used by the tests in this module.

    Deletes any stale ``mearec_waveforms`` folder, downloads the MEArec test
    file, prints the recording and sorting, and extracts waveforms into
    ``mearec_waveforms``.
    """
    # The trailing comma makes this a 1-tuple. Without it, the loop walked the
    # characters of the string: it deleted any single-letter folders ('m', 'e',
    # ...) and never removed 'mearec_waveforms' itself.
    for folder in ('mearec_waveforms',):
        if Path(folder).is_dir():
            shutil.rmtree(folder)

    local_path = download_dataset(remote_path='mearec/mearec_test_10s.h5')
    recording, sorting = read_mearec(local_path)
    print(recording)
    print(sorting)

    extract_waveforms(recording, sorting, 'mearec_waveforms',
                      ms_before=3., ms_after=4., max_spikes_per_unit=500,
                      load_if_exists=True, n_jobs=1, chunk_size=30000)
def setUp(self):
    """Load the MEArec test data and build the per-test fixtures.

    Sets ``_rec``, ``_sorting``, ``num_units``, ``_we`` (waveforms cached in
    ``./mearec_test``), ``_amplitudes`` (negative-peak spike amplitudes grouped
    by unit) and ``_gt_comp`` (the sorting compared against itself as ground
    truth).
    """
    path = download_dataset(remote_path='mearec/mearec_test_10s.h5')
    rec = se.MEArecRecordingExtractor(path)
    sorting = se.MEArecSortingExtractor(path)
    self._rec, self._sorting = rec, sorting
    self.num_units = len(sorting.get_unit_ids())

    we = extract_waveforms(rec, sorting, './mearec_test', load_if_exists=True)
    self._we = we
    self._amplitudes = st.get_spike_amplitudes(we, peak_sign='neg', outputs='by_unit')
    # the same sorting is used as both tested and ground-truth sorting
    self._gt_comp = sc.compare_sorter_to_ground_truth(sorting, sorting)
def test_extract_waveforms():
    """Smoke-test extract_waveforms on a generated 2-segment recording/sorting pair."""
    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments; both objects are generated before either is saved
    rec = generate_recording(num_channels=2, durations=durations,
                             sampling_frequency=sampling_frequency)
    sort = generate_sorting(num_units=5, sampling_frequency=sampling_frequency,
                            durations=durations)
    rec, sort = rec.save(), sort.save()

    folder = Path('test_extract_waveforms')
    if folder.is_dir():
        shutil.rmtree(folder)
    print(extract_waveforms(rec, sort, folder))
def test_export_to_phy_by_sparsity():
    """Phy export with 'radius' and 'threshold' sparsity produces ragged templates.

    With ``max_channels_per_template=None``, units keep different numbers of
    channels, so ``template_ind.npy`` must contain -1 padding.
    """
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording = se.MEArecRecordingExtractor(local_path)
    sorting = se.MEArecSortingExtractor(local_path)

    waveform_folder = Path('waveforms')
    # (output folder, sparsity method, extra sparsity parameters)
    exports = (
        (Path('phy_output_radius'), "radius", dict(radius_um=50)),
        (Path('phy_output_thr'), "threshold", dict(threshold=2)),
    )
    for stale in [waveform_folder] + [out for out, _, _ in exports]:
        if stale.is_dir():
            shutil.rmtree(stale)

    waveform_extractor = extract_waveforms(recording, sorting, waveform_folder)
    for output_folder, method, params in exports:
        export_to_phy(waveform_extractor, output_folder,
                      compute_pc_features=True, compute_amplitudes=True,
                      max_channels_per_template=None,
                      sparsity_dict=dict(method=method, **params),
                      n_jobs=1, chunk_size=10000, progress_bar=True)
        template_ind = np.load(output_folder / "template_ind.npy")
        # templates have different shapes!
        assert -1 in template_ind
def test_sparsity():
    """get_waveforms/get_template must honour a per-unit channel-sparsity dict.

    Two random sparsities are checked. In one, every unit keeps the same number
    of channels. In the other, each unit keeps a random number of channels drawn
    from [2, 9).
    """
    durations = [30]
    sampling_frequency = 30000.
    recording = generate_recording(num_channels=10, durations=durations,
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    recording = recording.save(folder="wf_rec3")
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency,
                               durations=durations)
    sorting = sorting.save(folder="wf_sort3")

    folder = Path('test_extract_waveforms_sparsity')
    if folder.is_dir():
        shutil.rmtree(folder)
    we = extract_waveforms(recording, sorting, folder, max_spikes_per_unit=None)

    num_channels = 3
    low, high = 2, 9
    channel_ids = recording.get_channel_ids()
    sparsity_same, sparsity_diff = {}, {}
    # RNG draws happen in the same order as before: permutation, randint, permutation
    for unit_id in sorting.get_unit_ids():
        sparsity_same[unit_id] = np.random.permutation(channel_ids)[:num_channels]
        n_random = np.random.randint(low, high)
        sparsity_diff[unit_id] = np.random.permutation(channel_ids)[:n_random]
    print(sparsity_same)
    print(sparsity_diff)

    for unit_id in sorting.get_unit_ids():
        checks = ((sparsity_same, num_channels),
                  (sparsity_diff, len(sparsity_diff[unit_id])))
        for sparsity, expected in checks:
            wf = we.get_waveforms(unit_id=unit_id, sparsity=sparsity)
            template = we.get_template(unit_id=unit_id, sparsity=sparsity)
            assert wf.shape[-1] == expected
            assert template.shape[-1] == expected
def test_export_report():
    """Generate a full report for the MEArec ground-truth sorting."""
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording, sorting = se.read_mearec(local_path)

    waveform_folder = Path('waveforms')
    output_folder = Path('mearec_GT_report')
    for stale in (waveform_folder, output_folder):
        if stale.is_dir():
            shutil.rmtree(stale)

    waveform_extractor = extract_waveforms(recording, sorting, waveform_folder)
    export_report(waveform_extractor, output_folder,
                  n_jobs=1, chunk_size=30000, progress_bar=True)
def test_find_spikes_from_templates():
    """Run every registered template-matching method and wrap each result as a sorting."""
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording, gt_sorting = read_mearec(local_path)

    we = extract_waveforms(recording, gt_sorting, 'waveforms_mearec', load_if_exists=True,
                           ms_before=1, ms_after=2., max_spikes_per_unit=500,
                           return_scaled=False, n_jobs=1, chunk_size=10000)
    method_kwargs = {
        'waveform_extractor': we,
        'noise_levels': get_noise_levels(recording),
    }
    sampling_frequency = recording.get_sampling_frequency()

    def run_method(name):
        # detect spikes with one matching method and convert them to a NumpySorting
        spikes = find_spikes_from_templates(recording, method=name, method_kwargs=method_kwargs,
                                            n_jobs=1, chunk_size=30000, progress_bar=True)
        return NumpySorting.from_times_labels(spikes['sample_ind'], spikes['cluster_ind'],
                                              sampling_frequency)

    result = {name: run_method(name) for name in template_matching_methods}
def test_compute_spike_amplitudes_parallel():
    """Spike amplitudes must be identical whether computed with one job or two."""
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording = se.MEArecRecordingExtractor(local_path)
    sorting = se.MEArecSortingExtractor(local_path)

    we = extract_waveforms(recording, sorting, Path('mearec_waveforms_all'),
                           ms_before=1., ms_after=2., max_spikes_per_unit=None,
                           n_jobs=1, chunk_size=30000, load_if_exists=True)

    # TODO : fix multi processing for spike amplitudes
    amplitudes1, amplitudes2 = (
        compute_spike_amplitudes(we, peak_sign='neg', load_if_exists=False,
                                 outputs='concatenated', chunk_size=10000, n_jobs=jobs)
        for jobs in (1, 2)
    )
    assert np.array_equal(amplitudes1[0], amplitudes2[0])
def _do_recovery_loop(task_args):
    """Run one ICA-based recovery pass for a single ``key`` and save the result.

    Steps performed below:
      1. Unpack ``task_args``.
      2. Load a recording, an optional ground-truth sorting and a ``'_pre'``
         sorting from ``output_folder / 'back_recording'``.
      3. Extract waveforms and choose the units to subtract. With ground truth,
         these are the units the comparison reports as well detected. Without
         it, they are units passing ISI-violation and firing-rate thresholds.
      4. Subtract those units' templates from the recording.
      5. Run ICA (``SpyICASorter``) on the subtracted recording, clean and
         sign-flip the sources with ``clean_correlated_sources``, then map the
         kept sources back to channel space.
      6. Save the back-projected recording to
         ``output_folder / 'back_recording' / key[0] / key[1]``.

    Parameters
    ----------
    task_args : sequence of 15 items
        (key, well_detected_score, isi_thr, fr_thr, sample_window_ms,
        percentage_spikes, balance_spikes, detect_threshold, method, skew_thr,
        n_jobs, we_params, compare, output_folder, job_kwargs).
        ``key`` is indexed as a pair of strings; ``key[1]`` is compared to
        'hdsort', so it is presumably the sorter name (TODO confirm).
        ``fr_thr`` is indexed as (low, high).
        ``we_params`` must provide 'load_if_exists', 'ms_before', 'ms_after',
        'max_spikes_per_unit', 'return_scaled' and 'dtype'.
        ``output_folder`` must support ``/`` (e.g. ``pathlib.Path``).
        ``job_kwargs`` is forwarded to the waveform, masking and cleaning steps.

    Returns
    -------
    None
        The result is only written to disk.
    """
    key, well_detected_score, isi_thr, fr_thr, sample_window_ms, \
        percentage_spikes, balance_spikes, detect_threshold, method, skew_thr, n_jobs, we_params, compare, \
        output_folder, job_kwargs = task_args

    # NOTE(review): the recording and ground truth are read from back_recording/key[1]/key[0],
    # but the '_pre' sorting is read from (and the result written to) back_recording/key[0]/key[1].
    # The swapped order may be intentional; confirm against the caller.
    recording = load_extractor(output_folder / 'back_recording' / key[1] / key[0])
    if compare is True:
        gt = load_extractor(output_folder / 'back_recording' / key[1] / (key[0] + '_gt'))
    else:
        gt = None
    sorting = load_extractor(output_folder / 'back_recording' / key[0] / (key[1] + '_pre'))

    we = extract_waveforms(
        recording, sorting, folder=output_folder / 'waveforms' / key[0] / key[1],
        load_if_exists=we_params['load_if_exists'], ms_before=we_params['ms_before'],
        ms_after=we_params['ms_after'], max_spikes_per_unit=we_params['max_spikes_per_unit'],
        return_scaled=we_params['return_scaled'], dtype=we_params['dtype'], overwrite=True,
        **job_kwargs)

    if gt is not None:
        # Ground truth available: keep the units reported as well detected.
        comparison = sc.compare_sorter_to_ground_truth(tested_sorting=sorting, gt_sorting=gt)
        selected_units = comparison.get_well_detected_units(
            well_detected_score)
        print(key[1][:-1])  # NOTE(review): looks like leftover debug output
        if key[1] == 'hdsort':
            # Shift ids down by 1000; presumably hdsort unit ids start at 1000 (TODO confirm).
            selected_units = [unit - 1000 for unit in selected_units]
    else:
        # No ground truth: keep units whose ISI-violation value is below isi_thr and
        # whose firing rate lies strictly between fr_thr[0] and fr_thr[1].
        isi_violation = st.compute_isi_violations(we)[0]
        good_isi = np.argwhere(
            np.array(list(isi_violation.values())) < isi_thr)[:, 0]
        firing_rate = st.compute_firing_rate(we)
        good_fr_idx_up = np.argwhere(
            np.array(list(firing_rate.values())) < fr_thr[1])[:, 0]
        good_fr_idx_down = np.argwhere(
            np.array(list(firing_rate.values())) > fr_thr[0])[:, 0]
        # Positional (0-based) unit indices, not unit ids.
        selected_units = [
            unit for unit in range(sorting.get_num_units())
            if unit in good_fr_idx_up and unit in good_fr_idx_down and unit in good_isi
        ]

    templates = we.get_all_templates()
    # NOTE(review): `templates[unit - 1]` assumes 1-based unit numbers. In the
    # no-ground-truth branch `selected_units` are 0-based indices, so index 0 maps
    # to templates[-1] (the last unit). This looks like an off-by-one; confirm.
    templates_dict = {
        str(unit): templates[unit - 1] for unit in selected_units
    }
    recording_subtracted = subtract_templates(recording, sorting, templates_dict, we.nbefore, selected_units)

    sorter = SpyICASorter(recording_subtracted)
    sorter.mask_traces(sample_window_ms=sample_window_ms,
                       percent_spikes=percentage_spikes, balance_spikes_on_channel=balance_spikes,
                       detect_threshold=detect_threshold, method=method, **job_kwargs)
    sorter.compute_ica(n_comp='all')
    # NOTE(review): sources are cleaned against `recording`, not `recording_subtracted`; confirm.
    # NOTE(review): if job_kwargs contains 'n_jobs' or 'chunk_size', this call raises
    # TypeError (keyword argument given twice).
    cleaning_result = clean_correlated_sources(
        recording, sorter.W_ica, skew_thresh=skew_thr, n_jobs=n_jobs,
        chunk_size=recording.get_num_samples(0) // n_jobs, **job_kwargs)
    # Rows listed in cleaning_result[1] get their sign flipped in both A and W;
    # rows listed in cleaning_result[0] are the sources that are kept.
    sorter.A_ica[cleaning_result[1]] = -sorter.A_ica[cleaning_result[1]]
    sorter.W_ica[cleaning_result[1]] = -sorter.W_ica[cleaning_result[1]]
    sorter.source_idx = cleaning_result[0]
    sorter.cleaned_A_ica = sorter.A_ica[cleaning_result[0]]
    sorter.cleaned_W_ica = sorter.W_ica[cleaning_result[0]]

    # Map the subtracted recording onto the kept sources, then back to channels
    # (presumably lin_map applies a linear transform; TODO confirm).
    ica_recording = st.preprocessing.lin_map(recording_subtracted, sorter.cleaned_W_ica)
    recording_back = st.preprocessing.lin_map(ica_recording, sorter.cleaned_A_ica.T)
    recording_back.save_to_folder(folder=output_folder / 'back_recording' / key[0] / key[1])
def test_compute_spike_amplitudes():
    """Check compute_spike_amplitudes outputs, scaling, and reloading as an extension.

    - Both output layouts ('concatenated', 'by_unit') run on unscaled waveforms.
    - With a channel gain set, scaled amplitudes equal unscaled amplitudes times
      the gain.
    - The result can be reloaded through ``load_extension``, and also directly
      with ``SpikeAmplitudesCalculator.load_from_folder``.
    """
    repo = 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'
    remote_path = 'mearec/mearec_test_10s.h5'
    local_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None)
    recording = se.MEArecRecordingExtractor(local_path)
    sorting = se.MEArecSortingExtractor(local_path)

    folder = Path('mearec_waveforms')
    we = extract_waveforms(recording, sorting, folder, ms_before=1., ms_after=2.,
                           max_spikes_per_unit=500, n_jobs=1, chunk_size=30000,
                           load_if_exists=False, overwrite=True)
    compute_spike_amplitudes(we, peak_sign='neg', outputs='concatenated', chunk_size=10000, n_jobs=1)
    compute_spike_amplitudes(we, peak_sign='neg', outputs='by_unit', chunk_size=10000, n_jobs=1)

    gain = 0.1
    recording.set_channel_gains(gain)
    recording.set_channel_offsets(0)
    folder = Path('mearec_waveforms_scaled')
    we_scaled = extract_waveforms(recording, sorting, folder, ms_before=1., ms_after=2.,
                                  max_spikes_per_unit=500, n_jobs=1, chunk_size=30000,
                                  load_if_exists=False, overwrite=True, return_scaled=True)
    amplitudes_scaled = compute_spike_amplitudes(we_scaled, peak_sign='neg', outputs='concatenated',
                                                 chunk_size=10000, n_jobs=1, return_scaled=True)
    amplitudes_unscaled = compute_spike_amplitudes(we_scaled, peak_sign='neg', outputs='concatenated',
                                                   chunk_size=10000, n_jobs=1, return_scaled=False)
    assert np.allclose(amplitudes_scaled[0], amplitudes_unscaled[0] * gain)

    # reload as an extension from we
    assert SpikeAmplitudesCalculator in we.get_available_extensions()
    assert we_scaled.is_extension('spike_amplitudes')
    sac = we.load_extension('spike_amplitudes')
    assert isinstance(sac, SpikeAmplitudesCalculator)
    assert sac._amplitudes is not None

    # Reload directly from the scaled waveform folder. The assertion must target
    # the reloaded object; the previous version re-checked `sac` by mistake.
    sac_loaded = SpikeAmplitudesCalculator.load_from_folder(folder)
    assert sac_loaded._amplitudes is not None
# Inspect and plot the probe attached to the recording
# (`recording` and `sorting` are defined earlier in this script).
probe = recording.get_probe()
print(probe)

from probeinterface.plotting import plot_probe

plot_probe(probe)

###############################################################################
# A :code:`WaveformExtractor` object can be created with the :code:`extract_waveforms`
# function:

folder = 'waveform_folder'
# Keep at most 500 spikes per unit, cut from 1.5 ms before to 2 ms after each spike;
# reuse the folder's content if it already exists.
we = extract_waveforms(recording, sorting, folder,
                       ms_before=1.5, ms_after=2., max_spikes_per_unit=500,
                       load_if_exists=True)
print(we)

###############################################################################
# Alternatively, the :code:`WaveformExtractor` object can be instantiated
# directly. In this case, we need to call :code:`set_params()` to set the desired
# parameters:

folder = 'waveform_folder2'
we = WaveformExtractor.create(recording, sorting, folder, remove_if_exists=True)
# # Let's imagine that that sorting is in fact the output of a sorters. # local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') recording = se.MEArecRecordingExtractor(local_path) sorting = se.MEArecSortingExtractor(local_path) print(recording) print(sorting) ############################################################################## # Firt, we extractor waveforms and compute PC on it. folder = 'waveforms_mearec' we = si.extract_waveforms(recording, sorting, folder, load_if_exists=True, ms_before=1, ms_after=2., max_spikes_per_unit=500, n_jobs=1, chunk_size=30000) print(we) pc = st.compute_principal_components(we, load_if_exists=True, n_components=3, mode='by_channel_local') print(pc) ############################################################################## # Compute some metrics on it metrics = st.compute_quality_metrics(we, waveform_principal_component=pc, metric_names=['snr', 'isi_violation', 'nearest_neighbor']) print(metrics)
def test_get_spike_amplitudes():
    """get_spike_amplitudes runs for both output layouts, and return_scaled applies the gain."""
    local_path = download_dataset(repo='https://gin.g-node.org/NeuralEnsemble/ephy_testing_data',
                                  remote_path='mearec/mearec_test_10s.h5',
                                  local_folder=None)
    recording = se.MEArecRecordingExtractor(local_path)
    sorting = se.MEArecSortingExtractor(local_path)

    # shared extraction settings; always start from a freshly written folder
    we_kwargs = dict(ms_before=1., ms_after=2., max_spikes_per_unit=500,
                     n_jobs=1, chunk_size=30000, load_if_exists=False, overwrite=True)

    we = extract_waveforms(recording, sorting, Path('mearec_waveforms'), **we_kwargs)
    for outputs in ('concatenated', 'by_unit'):
        get_spike_amplitudes(we, peak_sign='neg', outputs=outputs, chunk_size=10000, n_jobs=1)

    gain = 0.1
    recording.set_channel_gains(gain)
    recording.set_channel_offsets(0)
    we_scaled = extract_waveforms(recording, sorting, Path('mearec_waveforms_scaled'),
                                  return_scaled=True, **we_kwargs)
    amplitudes_scaled, amplitudes_unscaled = (
        get_spike_amplitudes(we_scaled, peak_sign='neg', outputs='concatenated',
                             chunk_size=10000, n_jobs=1, return_scaled=scaled)
        for scaled in (True, False)
    )
    assert np.allclose(amplitudes_scaled[0], amplitudes_unscaled[0] * gain)