def test_get_closest_channels():
    """Smoke-test ``get_closest_channels`` with and without ``num_channels``.

    Builds a 32-channel, 0.1 s synthetic recording and checks that both call
    forms run without raising; the returned indices and distances are not
    inspected.
    """
    recording = generate_recording(
        num_channels=32,
        sampling_frequency=1000.,
        durations=[0.1],
    )

    # Default call: no explicit neighbour count.
    neighbour_inds, neighbour_dists = get_closest_channels(recording)

    # Explicitly ask for 4 channels.
    neighbour_inds, neighbour_dists = get_closest_channels(recording, num_channels=4)
def test_spikeinterface_viewer(interactive=False):
    """Display a generated recording and sorting in an ephyviewer MainViewer.

    Parameters
    ----------
    interactive : bool
        When True, show the window and enter the Qt event loop (blocking).
        When False (the test default), close the window immediately so the
        viewer shuts down its thread properly.
    """
    import spikeinterface as si
    from spikeinterface.core.testing_tools import generate_recording, generate_sorting

    signal_source = ephyviewer.SpikeInterfaceRecordingSource(recording=generate_recording())
    spike_source = ephyviewer.SpikeInterfaceSortingSource(sorting=generate_sorting())

    app = ephyviewer.mkQApp()
    win = ephyviewer.MainViewer(debug=True, show_auto_scale=True)

    # Each view is built and attached in turn: traces first, then spike trains.
    for viewer_cls, source, label in (
        (ephyviewer.TraceViewer, signal_source, 'signals'),
        (ephyviewer.SpikeTrainViewer, spike_source, 'spikes'),
    ):
        win.add_view(viewer_cls(source=source, name=label))

    if not interactive:
        # close thread properly
        win.close()
        return

    win.show()
    app.exec_()
def test_filter_opencl():
    """Run the default filter with the 'scipy' and 'opencl' engines on a large recording.

    Uses 256 channels at 30 kHz for ~100 s, so this behaves more like a
    benchmark than a unit test. Neither output is compared to the other.
    """
    saved = generate_recording(
        num_channels=256,
        sampling_frequency=30000.,
        durations=[100.325],
    ).save(total_memory="100M", n_jobs=1, progress_bar=True)

    print(saved.get_dtype())
    print(saved.is_dumpable)

    # Reference CPU implementation, computed chunk-wise in parallel.
    scipy_filtered = filter(saved, engine='scipy')
    scipy_filtered = scipy_filtered.save(chunk_size=1000, progress_bar=True, n_jobs=30)

    # GPU implementation, computed chunk-wise in a single job.
    opencl_filtered = filter(saved, engine='opencl')
    opencl_cached = opencl_filtered.save(chunk_size=1000, verbose=False, progress_bar=True, n_jobs=1)
def test_filter():
    """Smoke-test bandpass, bandstop, highpass and notch filtering.

    The bandpass result is materialised three ways (chunked, chunked with
    parallel jobs, and in one go). Only construction and trace retrieval are
    exercised; no values are compared.

    NOTE(review): as this file stands, a later ``def test_filter`` rebinds this
    name, so pytest never collects this version. Consider renaming it.
    """
    rec = generate_recording().save()
    bandpassed = bandpass_filter(rec, freq_min=300., freq_max=6000.)

    # compute by chunk
    bandpassed.save(chunk_size=100000, verbose=False, progress_bar=True)
    # compute by chunk with joblib
    cached_parallel = bandpassed.save(total_memory="10k", n_jobs=4, verbose=True)
    # compute once
    bandpassed.save(verbose=False)

    bandpassed.get_traces(segment_index=0)
    cached_parallel.get_traces(segment_index=0)

    # other filtering types
    filter(rec, band=[40., 60.], btype='bandstop')
    filter(rec, band=500., btype='highpass', filter_mode='ba', filter_order=2)
    notch_filter(rec, freq=3000, q=30, margin_ms=5.)
def test_time_handling():
    """Check that custom time vectors are saved and are applied to spike times.

    For a single-segment and a two-segment case, every segment's time vector
    is shifted by +5 s, and the sorting is attached to the recording. The test
    then verifies that:

    * the saved recording reproduces the shifted times, and
    * spike times returned with ``return_times=True`` lie within the
      recording's time range.
    """
    import shutil

    cache_folder = Path('./my_cache_folder')
    durations = [[10], [10, 5]]

    # test multi-segment
    for i, dur in enumerate(durations):
        rec = generate_recording(num_channels=4, durations=dur)
        sort = generate_sorting(num_units=10, durations=dur)

        for segment_index in range(rec.get_num_segments()):
            original_times = rec.get_times(segment_index=segment_index)
            new_times = original_times + 5
            rec.set_times(new_times, segment_index=segment_index)

        sort.register_recording(rec)
        assert sort.has_recording()

        rec_folder = cache_folder / f"rec{i}"
        # Clear a leftover folder from a previous run, as the other tests here
        # do, so a rerun does not save into an already-existing folder.
        if rec_folder.is_dir():
            shutil.rmtree(rec_folder)
        rec_cache = rec.save(folder=rec_folder)

        for segment_index in range(sort.get_num_segments()):
            assert rec.has_time_vector(segment_index=segment_index)
            assert sort.has_time_vector(segment_index=segment_index)

            # times are correctly saved by the recording
            assert np.allclose(rec.get_times(segment_index=segment_index),
                               rec_cache.get_times(segment_index=segment_index))

            # spike times are correctly adjusted
            rec_times = rec.get_times(segment_index=segment_index)
            for u in sort.get_unit_ids():
                spike_times = sort.get_unit_spike_train(u, segment_index=segment_index, return_times=True)
                assert np.all(spike_times >= rec_times[0])
                assert np.all(spike_times <= rec_times[-1])
def test_common_reference():
    """Check global (median/average), single-channel and local common referencing.

    Channel ids are relabelled 'a'..'d' so the single reference can be given
    by id.
    """
    rec = generate_recording(durations=[5.], num_channels=4)
    rec._main_ids = np.array(['a', 'b', 'c', 'd'])
    rec = rec.save()

    # no groups
    rec_cmr = common_reference(rec, reference='global', operator='median')
    rec_car = common_reference(rec, reference='global', operator='average')
    rec_sin = common_reference(rec, reference='single', ref_channel_ids=['a'])
    rec_local_car = common_reference(rec, reference='local', local_radius=(20, 65), operator='median')

    rec_cmr.save(verbose=False)
    rec_car.save(verbose=False)
    rec_sin.save(verbose=False)
    rec_local_car.save(verbose=False)

    traces = rec.get_traces()
    assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01)
    assert np.allclose(traces, rec_car.get_traces() + np.mean(traces, axis=1, keepdims=True), atol=0.01)

    sin_traces = rec_sin.get_traces()
    # The reference channel 'a' (column 0) minus itself must be zero. The
    # previous check indexed row 0 (the first sample), which passed trivially.
    assert np.allclose(sin_traces[:, 0], 0)
    assert np.allclose(sin_traces[:, 1], traces[:, 1] - traces[:, 0])

    # Expected neighbour sets ([2, 3] for channel 0, [3] for channel 1)
    # presumably follow from the generated probe geometry and the
    # (20, 65) radius -- TODO confirm against generate_recording's layout.
    assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01)
    assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01)
def test_write_memory_recording():
    """Write a 2-segment recording to memory serially and, where supported, in parallel.

    The parallel variants only run when ``HAVE_SHAREDMEMORY`` is true and the
    platform is not Windows.
    """
    # 2 segments; saving makes the recording dumpable for worker processes
    recording = generate_recording(num_channels=2, durations=[10.325, 3.5]).save()

    serial_runs = (
        dict(verbose=True, n_jobs=1),
        dict(verbose=True, n_jobs=1, chunk_memory='100k', progress_bar=True),
    )
    for options in serial_runs:
        write_memory_recording(recording, dtype=None, **options)

    if not (HAVE_SHAREDMEMORY and platform.system() != 'Windows'):
        return

    parallel_runs = (
        dict(verbose=False, n_jobs=2, chunk_memory='100k'),
        dict(verbose=False, n_jobs=2, total_memory='200k', progress_bar=True),
    )
    for options in parallel_runs:
        write_memory_recording(recording, dtype=None, **options)
def test_rectify():
    """Rectify a generated recording, save it, and read back a single channel."""
    rectified = rectify(generate_recording())
    rectified.save(verbose=False)

    single_channel = rectified.get_traces(segment_index=0, channel_ids=[1])
    assert single_channel.shape[1] == 1
def test_remove_artifacts():
    """Check remove_artifacts around two triggers in its default, 'linear' and 'cubic' modes.

    Default mode: traces within +/-10 ms of each trigger are all zero.
    Interpolating modes: those windows differ from the untouched recording.
    """
    # one segment only
    rec = generate_recording(durations=[10.])
    triggers = [15000, 30000]
    ms = 10
    half_span = int(ms * rec.get_sampling_frequency() / 1000)

    def around(recording, frame, half_width):
        """Fetch traces from ``frame - half_width`` to ``frame + half_width``."""
        return recording.get_traces(start_frame=frame - half_width, end_frame=frame + half_width)

    clean = [around(rec, t, half_span) for t in triggers]

    zeroed = remove_artifacts(rec, triggers, ms_before=10, ms_after=10)
    wide = [around(zeroed, t, half_span) for t in triggers]
    narrow = [around(zeroed, t, 10) for t in triggers]
    for chunk in wide + narrow:
        assert not np.any(chunk)

    for mode in ("linear", "cubic"):
        interpolated = remove_artifacts(rec, triggers, ms_before=10, ms_after=10, mode=mode)
        for t, reference in zip(triggers, clean):
            assert not np.allclose(around(interpolated, t, half_span), reference)
def test_get_chunk_with_margin():
    """Check margins and zero-padding returned by ``get_chunk_with_margin``.

    Positional arguments are: rec_segment, start_frame, end_frame,
    channel_indices, sample_margin. The two extra return values are the
    margins actually applied on the left and right. With ``add_zeros=True``,
    any part of the margin outside the segment is filled with zeros.
    """
    rec = generate_recording(num_channels=1, sampling_frequency=1000., durations=[10.])
    segment = rec._recording_segments[0]
    n_samples = segment.get_num_samples()
    margin = 10

    # whole segment: nothing available on either side
    traces, left, right = get_chunk_with_margin(segment, None, None, None, margin)
    assert left == 0 and right == 0

    # only 5 samples exist before frame 5
    traces, left, right = get_chunk_with_margin(segment, 5, None, None, margin)
    assert left == 5 and right == 0

    # only 5 samples exist after the chunk
    traces, left, right = get_chunk_with_margin(segment, n_samples - 1000, n_samples - 5, None, margin)
    assert left == 10 and right == 5
    assert traces.shape[0] == 1010

    # fully interior chunk: full margin both sides
    traces, left, right = get_chunk_with_margin(segment, 2000, 3000, None, margin)
    assert left == 10 and right == 10
    assert traces.shape[0] == 1020

    # add zeros: missing left margin is padded
    traces, left, right = get_chunk_with_margin(segment, 5, 1005, None, margin, add_zeros=True)
    assert traces.shape[0] == 1020
    assert left == 10
    assert right == 10
    assert np.all(traces[:5] == 0)

    # add zeros: missing right margin is padded
    traces, left, right = get_chunk_with_margin(segment, n_samples - 1005, n_samples - 5, None, margin, add_zeros=True)
    assert traces.shape[0] == 1020
    assert np.all(traces[-5:] == 0)
    assert left == 10
    assert right == 10

    # add zeros: chunk extends 500 samples past the end
    traces, left, right = get_chunk_with_margin(segment, n_samples - 500, n_samples + 500, None, margin, add_zeros=True)
    assert traces.shape[0] == 1020
    assert np.all(traces[-510:] == 0)
    assert left == 10
    assert right == 510
def test_get_random_data_chunks():
    """Check that random chunks from every segment are stacked along the sample axis."""
    durations = [10., 20.]
    rec = generate_recording(num_channels=1, sampling_frequency=1000., durations=durations)

    per_segment, chunk_len = 50, 500
    chunks = get_random_data_chunks(rec, num_chunks_per_segment=per_segment, chunk_size=chunk_len, seed=0)

    # 2 segments x 50 chunks x 500 samples = 50000 rows, 1 channel
    assert chunks.shape == (len(durations) * per_segment * chunk_len, 1)
def test_normalize_by_quantile():
    """Smoke-test normalize_by_quantile in 'by_channel' and 'pool_channel' modes.

    NOTE(review): as this file stands, a later ``def test_normalize_by_quantile``
    (which actually exercises ``whiten``) rebinds this name, so this version is
    never collected by pytest.
    """
    rec = generate_recording()

    per_channel = normalize_by_quantile(rec, mode='by_channel')
    per_channel.save(verbose=False)
    one_channel = per_channel.get_traces(segment_index=0, channel_ids=[1])
    assert one_channel.shape[1] == 1

    pooled = normalize_by_quantile(rec, mode='pool_channel')
    pooled.save(verbose=False)
def test_blank_staturationy():
    """Smoke-test blank_staturation with an absolute and a quantile threshold."""
    rec = generate_recording()

    by_absolute = blank_staturation(rec, abs_threshold=3.)
    by_absolute.save(verbose=False)

    by_quantile = blank_staturation(rec, quantile_threshold=0.01, direction='both')
    by_quantile.save(verbose=False)

    one_channel = by_absolute.get_traces(segment_index=0, channel_ids=[1])
    assert one_channel.shape[1] == 1
def test_clip():
    """Smoke-test clip with both bounds and with only a lower bound."""
    rec = generate_recording()

    both_bounds = clip(rec, a_min=-2, a_max=3.)
    both_bounds.save(verbose=False)

    lower_only = clip(rec, a_min=-1.5)
    lower_only.save(verbose=False)

    one_channel = both_bounds.get_traces(segment_index=0, channel_ids=[1])
    assert one_channel.shape[1] == 1
def test_get_noise_levels():
    """Compute noise levels on raw traces, then on scaled traces after setting gain/offset.

    Results are printed only; nothing is asserted.
    """
    rec = generate_recording(num_channels=2, sampling_frequency=1000., durations=[60.])

    raw_levels = get_noise_levels(rec, return_scaled=False)
    print(raw_levels)

    # Uniform gain 0.1 and zero offset so that scaled traces can be requested.
    rec.set_channel_gains(0.1)
    rec.set_channel_offsets(0)
    scaled_levels = get_noise_levels(rec, return_scaled=True)
    print(scaled_levels)
def test_WaveformExtractor():
    """Exercise WaveformExtractor end to end on a 2-segment, 2-channel recording.

    The test creates an extractor and extracts waveforms serially and then in
    parallel. It reloads the extractor from its folder and checks waveform and
    template shapes: 3 ms before + 4 ms after at 30 kHz gives 90 + 120 = 210
    samples, with 2 channels and 5 units.
    """
    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments
    recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec1"
    # Clear a previous run's output so the folder save does not hit an
    # existing folder (the extractor folder below is cleared the same way).
    if Path(folder_rec).is_dir():
        shutil.rmtree(folder_rec)
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)

    # test with dump !!!!
    recording = recording.save()
    sorting = sorting.save()

    folder = Path('test_waveform_extractor')
    if folder.is_dir():
        shutil.rmtree(folder)

    we = WaveformExtractor.create(recording, sorting, folder)
    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True)

    wf_shape = (210, 2)
    num_units = 5

    wfs = we.get_waveforms(0)
    assert wfs.shape[0] <= 500
    assert wfs.shape[1:] == wf_shape
    wfs, sampled_index = we.get_waveforms(0, with_index=True)

    # load back
    we = WaveformExtractor.load_from_folder(folder)
    wfs = we.get_waveforms(0)

    assert we.get_template(0).shape == wf_shape
    assert we.get_all_templates().shape == (num_units, *wf_shape)
    assert we.get_template(0, mode='std').shape == wf_shape
    assert we.get_all_templates(mode='std').shape == (num_units, *wf_shape)

    # (previously asserted twice in a row; once is sufficient)
    wf_segment = we.get_template_segment(unit_id=0, segment_index=0)
    assert wf_segment.shape == wf_shape
def test_portability():
    """Check that a waveform extractor saved with relative paths survives a folder copy.

    The recording, sorting and extractor are all saved under one folder. That
    folder is copied elsewhere, the extractor is loaded from the copy, and its
    ids and waveforms are compared with the original.
    """
    durations = [30, 40]
    sampling_frequency = 30000.

    folder_to_move = Path("original_folder")
    folder_moved = Path("moved_folder")
    for leftover in (folder_to_move, folder_moved):
        if leftover.is_dir():
            shutil.rmtree(leftover)
    # Only the source is recreated; shutil.copytree creates the destination itself.
    folder_to_move.mkdir()

    # 2 segments
    num_channels = 2
    recording = generate_recording(num_channels=num_channels, durations=durations,
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    recording = recording.save(folder=folder_to_move / "rec")

    num_units = 15
    sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency,
                               durations=durations)
    sorting = sorting.save(folder=folder_to_move / "sort")

    wf_folder = folder_to_move / "waveform_extractor"
    if wf_folder.is_dir():
        shutil.rmtree(wf_folder)

    # save with relative paths
    we = extract_waveforms(recording, sorting, wf_folder, use_relative_path=True)

    # move all to a separate folder
    shutil.copytree(folder_to_move, folder_moved)
    we_loaded = extract_waveforms(recording, sorting, folder_moved / "waveform_extractor",
                                  load_if_exists=True)

    assert we_loaded.recording is not None
    assert we_loaded.sorting is not None
    assert np.allclose(we.recording.get_channel_ids(), we_loaded.recording.get_channel_ids())
    assert np.allclose(we.sorting.get_unit_ids(), we_loaded.sorting.get_unit_ids())

    for unit in we.sorting.get_unit_ids():
        assert np.allclose(we.get_waveforms(unit_id=unit), we_loaded.get_waveforms(unit_id=unit))
def test_scale():
    """Smoke-test scale with per-channel arrays, scalars, and a mix of both."""
    rec = generate_recording()
    num_chans = rec.get_num_channels()
    per_channel_gain = np.full(num_chans, 2.)
    per_channel_offset = np.full(num_chans, -10.)

    combos = (
        (per_channel_gain, per_channel_offset),  # both arrays
        (2., -10.),                              # both scalars
        (per_channel_gain, -10.),                # array gain, scalar offset
    )
    for gain_value, offset_value in combos:
        scale(rec, gain=gain_value, offset=offset_value).get_traces(segment_index=0)
def test_spikeinterface_sources():
    """Wrap a generated recording and sorting as ephyviewer sources and print their metadata."""
    import spikeinterface as si
    from spikeinterface.core.testing_tools import generate_recording, generate_sorting

    signal_source = ephyviewer.SpikeInterfaceRecordingSource(recording=generate_recording())
    print(signal_source)
    print(signal_source.t_start, signal_source.nb_channel, signal_source.sample_rate)

    spike_source = ephyviewer.SpikeInterfaceSortingSource(sorting=generate_sorting())
    print(spike_source)
    print(spike_source.t_start, spike_source.nb_channel, spike_source.get_channel_name())
def test_extract_waveforms():
    """Check extract_waveforms is consistent across job counts and with scaling.

    Waveforms for unit 0 must be identical with 1 job and with 2 jobs. With
    ``return_scaled=True`` and gain 0.1, they must equal the unscaled ones
    multiplied by the gain.
    """
    # 2 segments
    durations = [30, 40]
    sampling_frequency = 30000.

    folder_rec = "wf_rec2"
    folder_sort = "wf_sort2"
    # Remove a previous run's output so the folder saves below start fresh,
    # as is already done for the extractor folders.
    for leftover in (Path(folder_rec), Path(folder_sort)):
        if leftover.is_dir():
            shutil.rmtree(leftover)

    recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)
    sorting = sorting.save(folder=folder_sort)

    # The folder-saved recording/sorting are used directly (no extra dump).
    folder1 = Path('test_extract_waveforms_1job')
    if folder1.is_dir():
        shutil.rmtree(folder1)
    we1 = extract_waveforms(recording, sorting, folder1, max_spikes_per_unit=None, return_scaled=False)

    folder2 = Path('test_extract_waveforms_2job')
    if folder2.is_dir():
        shutil.rmtree(folder2)
    we2 = extract_waveforms(recording, sorting, folder2, n_jobs=2, total_memory="10M",
                            max_spikes_per_unit=None, return_scaled=False)

    wf1 = we1.get_waveforms(0)
    wf2 = we2.get_waveforms(0)
    assert np.array_equal(wf1, wf2)

    folder3 = Path('test_extract_waveforms_returnscaled')
    if folder3.is_dir():
        shutil.rmtree(folder3)

    # set scaling values to recording
    gain = 0.1
    recording.set_channel_gains(gain)
    recording.set_channel_offsets(0)
    we3 = extract_waveforms(recording, sorting, folder3, n_jobs=2, total_memory="10M",
                            max_spikes_per_unit=None, return_scaled=True)

    wf3 = we3.get_waveforms(0)
    assert np.array_equal(wf1.astype("float32") * gain, wf3)
def test_filter():
    """Smoke-test bandpass, highpass, notch and coefficient-based ('sos') filtering.

    The only value check is that bandpass filtering leaves the time vector of
    segment 0 unchanged.

    NOTE(review): as this file stands, this definition rebinds an earlier
    ``def test_filter``, so only this version is collected by pytest.
    """
    rec = generate_recording().save()
    bandpassed = bandpass_filter(rec, freq_min=300., freq_max=6000.)

    # compute by chunk
    bandpassed.save(chunk_size=100000, verbose=False, progress_bar=True)
    # compute by chunk with joblib
    cached_parallel = bandpassed.save(total_memory="10k", n_jobs=4, verbose=True)
    # compute once
    bandpassed.save(verbose=False)

    bandpassed.get_traces(segment_index=0)
    cached_parallel.get_traces(segment_index=0)

    # other filtering types
    filter(rec, band=500., btype='highpass', filter_mode='ba', filter_order=2)
    notch_filter(rec, freq=3000, q=30, margin_ms=5.)

    # filter from precomputed second-order-section coefficients
    sos = iirfilter(8, [0.02, 0.4], rs=30, btype='band', analog=False, ftype='cheby2', output='sos')
    from_coeff = filter(rec, coeff=sos, filter_mode='sos')
    # compute by chunk
    from_coeff_cached = from_coeff.save(chunk_size=100000, verbose=False, progress_bar=True)
    from_coeff.get_traces(segment_index=0)
    from_coeff_cached.get_traces(segment_index=0)

    assert np.allclose(rec.get_times(0), bandpassed.get_times(0))
def test_ensure_n_jobs():
    """Check that non-dumpable recordings are forced to one job and dumpable ones are not."""
    recording = generate_recording()

    # not dumpable: every request, including -1, collapses to a single job
    assert ensure_n_jobs(recording) == 1
    for requested in (0, 1, -1):
        assert ensure_n_jobs(recording, n_jobs=requested) == 1

    # dumpable: -1 is honoured.
    # NOTE(review): presumably -1 means "one job per CPU", so this expects a
    # multi-core host -- TODO confirm / skip on single-core CI.
    assert ensure_n_jobs(recording.save(), n_jobs=-1) > 1
def test_sparsity():
    """Check that waveforms and templates respect per-unit sparsity channel lists.

    Two sparsity dicts are tried: one with 3 random channels per unit, and one
    with a random count in [2, 9) per unit. The last axis of every waveform
    and template must match the number of channels selected for that unit.
    """
    durations = [30]
    sampling_frequency = 30000.

    folder_rec = "wf_rec3"
    folder_sort = "wf_sort3"
    # Remove a previous run's output so the folder saves below start fresh,
    # as is already done for the extractor folder.
    for leftover in (Path(folder_rec), Path(folder_sort)):
        if leftover.is_dir():
            shutil.rmtree(leftover)

    recording = generate_recording(num_channels=10, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)
    sorting = sorting.save(folder=folder_sort)

    folder = Path('test_extract_waveforms_sparsity')
    if folder.is_dir():
        shutil.rmtree(folder)
    we = extract_waveforms(recording, sorting, folder, max_spikes_per_unit=None)

    # Seeded generator so that a failing channel draw can be reproduced.
    rng = np.random.default_rng(0)
    channel_ids = recording.get_channel_ids()

    # sparsity: same number of channels
    num_channels = 3
    channel_bounds = [2, 9]
    sparsity_same = {}
    sparsity_diff = {}
    for unit in sorting.get_unit_ids():
        sparsity_same[unit] = rng.permutation(channel_ids)[:num_channels]
        # upper bound is exclusive, matching np.random.randint
        rand_channel_num = rng.integers(channel_bounds[0], channel_bounds[1])
        sparsity_diff[unit] = rng.permutation(channel_ids)[:rand_channel_num]

    print(sparsity_same)
    print(sparsity_diff)

    for unit in sorting.get_unit_ids():
        wf_same = we.get_waveforms(unit_id=unit, sparsity=sparsity_same)
        temp_same = we.get_template(unit_id=unit, sparsity=sparsity_same)
        assert wf_same.shape[-1] == num_channels
        assert temp_same.shape[-1] == num_channels

        wf_diff = we.get_waveforms(unit_id=unit, sparsity=sparsity_diff)
        temp_diff = we.get_template(unit_id=unit, sparsity=sparsity_diff)
        assert wf_diff.shape[-1] == len(sparsity_diff[unit])
        assert temp_diff.shape[-1] == len(sparsity_diff[unit])
def test_ChunkRecordingExecutor():
    """Run ChunkRecordingExecutor unchunked, chunked serially, and chunked in parallel.

    ``func`` and ``init_func`` are module-level helpers not shown here. They
    are presumably the per-chunk worker and its initializer -- TODO confirm.
    """
    # make dumpable
    recording = generate_recording(num_channels=2).save()
    init_args = 'a', 120, 'yep'

    configurations = (
        # no chunk
        dict(progress_bar=False, n_jobs=1, chunk_size=None),
        # chunk + loop
        dict(progress_bar=False, n_jobs=1, chunk_memory="500k"),
        # chunk + parallel
        dict(progress_bar=True, n_jobs=2, total_memory="200k", job_name='job_name'),
    )
    for options in configurations:
        executor = ChunkRecordingExecutor(recording, func, init_func, init_args, verbose=True, **options)
        executor.run()
def test_ensure_chunk_size():
    """Check chunk sizes (in frames) derived from memory budgets.

    The recording is 2 channels of float32, i.e. 8 bytes per frame. The
    expected values imply decimal units (1k = 1000 B, 1M = 10**6 B). A total
    budget of 512M split over 2 jobs gives the same per-chunk size as 256M.
    """
    recording = generate_recording(num_channels=2)
    assert recording.get_dtype() == 'float32'

    # make dumpable
    recording = recording.save()

    cases = (
        (dict(total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2), 32000000),
        (dict(chunk_memory="256M"), 32000000),
        (dict(chunk_memory="1k"), 125),
        (dict(chunk_memory="1G"), 125000000),
    )
    for options, expected_frames in cases:
        assert ensure_chunk_size(recording, **options) == expected_frames
def test_frame_slicing():
    """Check that slicing keeps spike times within the sliced recording's time range.

    The recording's time vector is shifted by +5 s before the sorting is
    attached. Both recording and sorting are then sliced between 3 s and
    7 s, expressed as frames.
    """
    duration = [10]
    rec = generate_recording(num_channels=4, durations=duration)
    sort = generate_sorting(num_units=10, durations=duration)

    original_times = rec.get_times()
    new_times = original_times + 5
    rec.set_times(new_times)

    sort.register_recording(rec)

    fs = rec.get_sampling_frequency()
    # Frame indices must be integers; the sampling frequency may be a float,
    # which previously made start_frame/end_frame floats.
    start_frame = int(3 * fs)
    end_frame = int(7 * fs)

    rec_slice = rec.frame_slice(start_frame=start_frame, end_frame=end_frame)
    sort_slice = sort.frame_slice(start_frame=start_frame, end_frame=end_frame)

    rec_times = rec_slice.get_times()
    for u in sort_slice.get_unit_ids():
        spike_times = sort_slice.get_unit_spike_train(u, return_times=True)
        assert np.all(spike_times >= rec_times[0])
        assert np.all(spike_times <= rec_times[-1])
def test_write_binary_recording():
    """Write a 2-segment recording to two raw files, serially and in parallel.

    Each run writes the files 'binary01.raw' and 'binary02.raw' in the current
    directory, overwriting the previous run.
    """
    # 2 segments; saving makes the recording dumpable
    recording = generate_recording(num_channels=2, durations=[10.325, 3.5]).save()

    runs = (
        # write with loop
        dict(verbose=True, n_jobs=1),
        dict(verbose=True, n_jobs=1, chunk_memory='100k', progress_bar=True),
        # write parallel
        dict(verbose=False, n_jobs=2, chunk_memory='100k'),
        dict(verbose=False, n_jobs=2, total_memory='200k', progress_bar=True),
    )
    for options in runs:
        write_binary_recording(recording, file_paths=['binary01.raw', 'binary02.raw'], dtype=None, **options)
def test_center():
    """Smoke-test median centering by fetching segment 0 traces."""
    centered = center(generate_recording(), mode='median')
    centered.get_traces(segment_index=0)
def test_waveform_tools():
    """Check allocate_waveforms / distribute_waveforms_to_buffers across buffer modes.

    The same extraction is run with several job configurations, in 'memmap'
    mode and (outside Windows) in 'shared_memory' mode. Each unit must receive
    one waveform per spike, and all configurations must agree (checked by
    ``_check_all_wf_equal``, defined elsewhere). Finally, a memmap run with a
    random sparsity mask is executed as a smoke test.
    """
    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments
    num_channels = 2
    recording = generate_recording(num_channels=num_channels, durations=durations,
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)

    num_units = 15
    sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency,
                               durations=durations)

    # test with dump !!!!
    recording = recording.save()
    sorting = sorting.save()

    nbefore = int(3. * sampling_frequency / 1000.)
    nafter = int(4. * sampling_frequency / 1000.)

    dtype = recording.get_dtype()
    return_scaled = False

    spikes = sorting.to_spike_vector()
    unit_ids = sorting.unit_ids

    some_job_kwargs = [
        {},
        {'n_jobs': 1, 'chunk_size': 3000, 'progress_bar': True},
        {'n_jobs': 2, 'chunk_size': 3000, 'progress_bar': True},
    ]

    def assert_one_waveform_per_spike(buffers):
        """Each unit's buffer has as many rows as that unit has spikes."""
        for unit_ind, unit_id in enumerate(unit_ids):
            assert buffers[unit_id].shape[0] == np.sum(spikes['unit_ind'] == unit_ind)

    def snapshot(buffers):
        """Copy every unit's buffer so it outlives the underlying storage."""
        return {unit_id: buffers[unit_id].copy() for unit_id in unit_ids}

    # memmap mode
    collected = []
    for j, job_kwargs in enumerate(some_job_kwargs):
        wf_folder = Path(f'test_waveform_tools_{j}')
        if wf_folder.is_dir():
            shutil.rmtree(wf_folder)
        wf_folder.mkdir()
        wfs_arrays, wfs_arrays_info = allocate_waveforms(recording, spikes, unit_ids, nbefore, nafter,
                                                         mode='memmap', folder=wf_folder, dtype=dtype)
        distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter,
                                        return_scaled, **job_kwargs)
        assert_one_waveform_per_spike(wfs_arrays)
        collected.append(snapshot(wfs_arrays))
    _check_all_wf_equal(collected)

    # memory
    if platform.system() != 'Windows':
        # shared memory on windows is buggy...
        collected = []
        for job_kwargs in some_job_kwargs:
            wfs_arrays, wfs_arrays_info = allocate_waveforms(recording, spikes, unit_ids, nbefore, nafter,
                                                             mode='shared_memory', folder=None, dtype=dtype)
            distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter,
                                            return_scaled, mode='shared_memory', **job_kwargs)
            assert_one_waveform_per_spike(wfs_arrays)
            collected.append(snapshot(wfs_arrays))
            # to avoid warning we need to first destroy arrays then sharedmem object
            del wfs_arrays
            del wfs_arrays_info
        _check_all_wf_equal(collected)

    # with sparsity
    wf_folder = Path('test_waveform_tools_sparse')
    if wf_folder.is_dir():
        shutil.rmtree(wf_folder)
    wf_folder.mkdir()

    sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype='bool')
    wfs_arrays, wfs_arrays_info = allocate_waveforms(recording, spikes, unit_ids, nbefore, nafter,
                                                     mode='memmap', folder=wf_folder, dtype=dtype,
                                                     sparsity_mask=sparsity_mask)
    job_kwargs = {'n_jobs': 1, 'chunk_size': 3000, 'progress_bar': True}
    distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter,
                                    return_scaled, sparsity_mask=sparsity_mask, **job_kwargs)
def test_normalize_by_quantile():
    """Smoke-test ``whiten`` by building a whitened recording and saving it.

    NOTE(review): despite its name, this test exercises ``whiten``, not
    ``normalize_by_quantile``. As this file stands, it also rebinds the name of
    the earlier ``test_normalize_by_quantile``, so that test is never
    collected. Consider renaming this one to ``test_whiten``.
    """
    whitened = whiten(generate_recording())
    whitened.save(verbose=False)