# NOTE: the snippets below are test functions collected from the
# spikeinterface test suite. The import header is an assumption (module
# paths follow the spikeinterface ~0.90-era layout); the remaining helpers
# under test (get_closest_channels, get_chunk_with_margin, get_noise_levels,
# bandpass_filter, common_reference, WaveformExtractor, HAVE_SHAREDMEMORY,
# ...) live in spikeinterface.core and spikeinterface.toolkit in that
# version; adjust the imports to your installed release.
import platform
import shutil
from pathlib import Path

import numpy as np
from scipy.signal import iirfilter

import ephyviewer  # used by the viewer/source examples below
from spikeinterface.core.testing_tools import generate_recording, generate_sorting


def test_get_closest_channels():
    rec = generate_recording(num_channels=32,
                             sampling_frequency=1000.,
                             durations=[0.1])
    closest_channels_inds, distances = get_closest_channels(rec)
    closest_channels_inds, distances = get_closest_channels(rec,
                                                            num_channels=4)
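    # hedged check (shapes assumed from the calls above): one row per channel,
    # with the indices/distances of its num_channels nearest neighbours
    assert closest_channels_inds.shape == (32, 4)
    assert distances.shape == (32, 4)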

# Example 2

def test_spikeinterface_viewer(interactive=False):
    import spikeinterface as si
    from spikeinterface.core.testing_tools import generate_recording, generate_sorting

    recording = generate_recording()
    sig_source = ephyviewer.SpikeInterfaceRecordingSource(recording=recording)

    sorting = generate_sorting()
    spike_source = ephyviewer.SpikeInterfaceSortingSource(sorting=sorting)

    app = ephyviewer.mkQApp()
    win = ephyviewer.MainViewer(debug=True, show_auto_scale=True)

    view = ephyviewer.TraceViewer(source=sig_source, name='signals')
    win.add_view(view)

    view = ephyviewer.SpikeTrainViewer(source=spike_source, name='spikes')
    win.add_view(view)

    if interactive:
        win.show()
        app.exec_()
    else:
        # close threads properly
        win.close()

# Example 3

def test_filter_opencl():
    rec = generate_recording(
        num_channels=256,
        # num_channels = 32,
        sampling_frequency=30000.,
        durations=[
            100.325,
        ],
        # durations = [10.325, 3.5],
    )
    rec = rec.save(total_memory="100M", n_jobs=1, progress_bar=True)

    print(rec.get_dtype())
    print(rec.is_dumpable)
    # print(rec.to_dict())

    rec_filtered = filter(rec, engine='scipy')
    rec_filtered = rec_filtered.save(chunk_size=1000,
                                     progress_bar=True,
                                     n_jobs=30)

    rec2 = filter(rec, engine='opencl')
    rec2_cached0 = rec2.save(chunk_size=1000,
                             verbose=False,
                             progress_bar=True,
                             n_jobs=1)

# Example 4

def test_filter():
    rec = generate_recording()
    rec = rec.save()

    rec2 = bandpass_filter(rec, freq_min=300., freq_max=6000.)

    # compute by chunk
    rec2_cached0 = rec2.save(chunk_size=100000,
                             verbose=False,
                             progress_bar=True)
    # compute by chunk with joblib
    rec2_cached1 = rec2.save(total_memory="10k", n_jobs=4, verbose=True)
    # compute once
    rec2_cached2 = rec2.save(verbose=False)

    trace0 = rec2.get_traces(segment_index=0)
    trace1 = rec2_cached1.get_traces(segment_index=0)
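    # note: trace0 (lazy) and trace1 (computed chunk-by-chunk) should agree up
    # to small chunk-edge effects from the filter margin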

    # other filtering types
    rec3 = filter(rec, band=[40., 60.], btype='bandstop')
    rec4 = filter(rec,
                  band=500.,
                  btype='highpass',
                  filter_mode='ba',
                  filter_order=2)

    rec5 = notch_filter(rec, freq=3000, q=30, margin_ms=5.)

# Example 5

def test_time_handling():
    cache_folder = Path('./my_cache_folder')
    durations = [[10], [10, 5]]

    # test multi-segment
    for i, dur in enumerate(durations):
        rec = generate_recording(num_channels=4, durations=dur)
        sort = generate_sorting(num_units=10, durations=dur)

        for segment_index in range(rec.get_num_segments()):
            original_times = rec.get_times(segment_index=segment_index)
            new_times = original_times + 5
            rec.set_times(new_times, segment_index=segment_index)

        sort.register_recording(rec)
        assert sort.has_recording()

        rec_cache = rec.save(folder=cache_folder / f"rec{i}")

        for segment_index in range(sort.get_num_segments()):
            assert rec.has_time_vector(segment_index=segment_index)
            assert sort.has_time_vector(segment_index=segment_index)

            # times are correctly saved by the recording
            assert np.allclose(
                rec.get_times(segment_index=segment_index),
                rec_cache.get_times(segment_index=segment_index))

            # spike times are correctly adjusted
            for u in sort.get_unit_ids():
                spike_times = sort.get_unit_spike_train(
                    u, segment_index=segment_index, return_times=True)
                rec_times = rec.get_times(segment_index=segment_index)
                assert np.all(spike_times >= rec_times[0])
                assert np.all(spike_times <= rec_times[-1])

# Example 6

def test_common_reference():
    rec = generate_recording(durations=[5.], num_channels=4)
    rec._main_ids = np.array(['a', 'b', 'c', 'd'])
    rec = rec.save()

    # no groups
    rec_cmr = common_reference(rec, reference='global', operator='median')
    rec_car = common_reference(rec, reference='global', operator='average')
    rec_sin = common_reference(rec, reference='single', ref_channel_ids=['a'])
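    # 'local' reference: for each channel, the median over neighbours whose
    # distance falls in the (20, 65) um annulus (exclusion radius, inclusion
    # radius) is subtracted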
    rec_local_car = common_reference(rec, reference='local', local_radius=(20, 65), operator='median')
   
    rec_cmr.save(verbose=False)
    rec_car.save(verbose=False)
    rec_sin.save(verbose=False)
    rec_local_car.save(verbose=False)

    traces = rec.get_traces()
    assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01)
    assert np.allclose(traces, rec_car.get_traces() + np.mean(traces, axis=1, keepdims=True), atol=0.01)
    assert not np.all(rec_sin.get_traces()[0])
    assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0])

    assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1),
                       atol=0.01)
    assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1),
                       atol=0.01)


def test_write_memory_recording():
    # 2 segments
    recording = generate_recording(num_channels=2, durations=[10.325, 3.5])
    # make dumpable
    recording = recording.save()

    # write with loop
    write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1)

    write_memory_recording(recording,
                           dtype=None,
                           verbose=True,
                           n_jobs=1,
                           chunk_memory='100k',
                           progress_bar=True)

    if HAVE_SHAREDMEMORY and platform.system() != 'Windows':
        # write in parallel
        write_memory_recording(recording,
                               dtype=None,
                               verbose=False,
                               n_jobs=2,
                               chunk_memory='100k')

        # write in parallel
        write_memory_recording(recording,
                               dtype=None,
                               verbose=False,
                               n_jobs=2,
                               total_memory='200k',
                               progress_bar=True)

# Example 8

def test_rectify():
    rec = generate_recording()

    rec2 = rectify(rec)
    rec2.save(verbose=False)

    traces = rec2.get_traces(segment_index=0, channel_ids=[1])
    assert traces.shape[1] == 1
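    # hedged check: rectification takes the absolute value, so no sample
    # should be negative
    assert np.all(rec2.get_traces(segment_index=0) >= 0)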


def test_remove_artifacts():
    # one segment only
    rec = generate_recording(durations=[10.])

    triggers = [15000, 30000]
    list_triggers = [triggers]

    ms = 10
    ms_frames = int(ms * rec.get_sampling_frequency() / 1000)

    traces_all_0_clean = rec.get_traces(start_frame=triggers[0] - ms_frames,
                                        end_frame=triggers[0] + ms_frames)
    traces_all_1_clean = rec.get_traces(start_frame=triggers[1] - ms_frames,
                                        end_frame=triggers[1] + ms_frames)

    rec_rmart = remove_artifacts(rec, triggers, ms_before=10, ms_after=10)
    traces_all_0 = rec_rmart.get_traces(start_frame=triggers[0] - ms_frames,
                                        end_frame=triggers[0] + ms_frames)
    traces_short_0 = rec_rmart.get_traces(start_frame=triggers[0] - 10,
                                          end_frame=triggers[0] + 10)
    traces_all_1 = rec_rmart.get_traces(start_frame=triggers[1] - ms_frames,
                                        end_frame=triggers[1] + ms_frames)
    traces_short_1 = rec_rmart.get_traces(start_frame=triggers[1] - 10,
                                          end_frame=triggers[1] + 10)

    assert not np.any(traces_all_0)
    assert not np.any(traces_all_1)
    assert not np.any(traces_short_0)
    assert not np.any(traces_short_1)

    rec_rmart_lin = remove_artifacts(rec,
                                     triggers,
                                     ms_before=10,
                                     ms_after=10,
                                     mode="linear")
    traces_all_0 = rec_rmart_lin.get_traces(start_frame=triggers[0] -
                                            ms_frames,
                                            end_frame=triggers[0] + ms_frames)
    traces_all_1 = rec_rmart_lin.get_traces(start_frame=triggers[1] -
                                            ms_frames,
                                            end_frame=triggers[1] + ms_frames)
    assert not np.allclose(traces_all_0, traces_all_0_clean)
    assert not np.allclose(traces_all_1, traces_all_1_clean)

    rec_rmart_cub = remove_artifacts(rec,
                                     triggers,
                                     ms_before=10,
                                     ms_after=10,
                                     mode="cubic")
    traces_all_0 = rec_rmart_cub.get_traces(start_frame=triggers[0] -
                                            ms_frames,
                                            end_frame=triggers[0] + ms_frames)
    traces_all_1 = rec_rmart_cub.get_traces(start_frame=triggers[1] -
                                            ms_frames,
                                            end_frame=triggers[1] + ms_frames)

    assert not np.allclose(traces_all_0, traces_all_0_clean)
    assert not np.allclose(traces_all_1, traces_all_1_clean)

# Example 10

def test_get_chunk_with_margin():
    rec = generate_recording(num_channels=1,
                             sampling_frequency=1000.,
                             durations=[10.])
    rec_seg = rec._recording_segments[0]
    length = rec_seg.get_num_samples()

    #  rec_segment, start_frame, end_frame, channel_indices, sample_margin
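    # get_chunk_with_margin returns (traces, left_margin, right_margin);
    # margins are truncated at segment borders unless add_zeros=True, in
    # which case the chunk is zero-padded to the requested size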

    traces, l, r = get_chunk_with_margin(rec_seg, None, None, None, 10)
    assert l == 0 and r == 0

    traces, l, r = get_chunk_with_margin(rec_seg, 5, None, None, 10)
    assert l == 5 and r == 0

    traces, l, r = get_chunk_with_margin(rec_seg, length - 1000, length - 5,
                                         None, 10)
    assert l == 10 and r == 5
    assert traces.shape[0] == 1010

    traces, l, r = get_chunk_with_margin(rec_seg, 2000, 3000, None, 10)
    assert l == 10 and r == 10
    assert traces.shape[0] == 1020

    # add zeros
    traces, l, r = get_chunk_with_margin(rec_seg,
                                         5,
                                         1005,
                                         None,
                                         10,
                                         add_zeros=True)
    assert traces.shape[0] == 1020
    assert l == 10
    assert r == 10
    assert np.all(traces[:5] == 0)

    traces, l, r = get_chunk_with_margin(rec_seg,
                                         length - 1005,
                                         length - 5,
                                         None,
                                         10,
                                         add_zeros=True)
    assert traces.shape[0] == 1020
    assert np.all(traces[-5:] == 0)
    assert l == 10
    assert r == 10

    traces, l, r = get_chunk_with_margin(rec_seg,
                                         length - 500,
                                         length + 500,
                                         None,
                                         10,
                                         add_zeros=True)
    assert traces.shape[0] == 1020
    assert np.all(traces[-510:] == 0)
    assert l == 10
    assert r == 510


def test_get_random_data_chunks():
    rec = generate_recording(num_channels=1,
                             sampling_frequency=1000.,
                             durations=[10., 20.])
    chunks = get_random_data_chunks(rec,
                                    num_chunks_per_segment=50,
                                    chunk_size=500,
                                    seed=0)
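    # 50 chunks x 500 samples x 2 segments = 50000 rows, 1 channel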
    assert chunks.shape == (50000, 1)


def test_normalize_by_quantile():
    rec = generate_recording()

    rec2 = normalize_by_quantile(rec, mode='by_channel')
    rec2.save(verbose=False)

    traces = rec2.get_traces(segment_index=0, channel_ids=[1])
    assert traces.shape[1] == 1

    rec2 = normalize_by_quantile(rec, mode='pool_channel')
    rec2.save(verbose=False)

# Example 13

def test_blank_staturation():
    # note: 'blank_staturation' is the actual (misspelled) function name in
    # this spikeinterface version
    rec = generate_recording()

    rec2 = blank_staturation(rec, abs_threshold=3.)
    rec2.save(verbose=False)

    rec3 = blank_staturation(rec, quantile_threshold=0.01, direction='both')
    rec3.save(verbose=False)

    traces = rec2.get_traces(segment_index=0, channel_ids=[1])
    assert traces.shape[1] == 1

# Example 14

def test_clip():
    rec = generate_recording()

    rec2 = clip(rec, a_min=-2, a_max=3.)
    rec2.save(verbose=False)

    rec3 = clip(rec, a_min=-1.5)
    rec3.save(verbose=False)

    traces = rec2.get_traces(segment_index=0, channel_ids=[1])
    assert traces.shape[1] == 1
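
    # hedged check: clipped samples should stay within [a_min, a_max]
    assert traces.min() >= -2 and traces.max() <= 3.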


def test_get_noise_levels():
    rec = generate_recording(num_channels=2,
                             sampling_frequency=1000.,
                             durations=[60.])

    noise_levels = get_noise_levels(rec, return_scaled=False)
    print(noise_levels)

    rec.set_channel_gains(0.1)
    rec.set_channel_offsets(0)
    noise_levels = get_noise_levels(rec, return_scaled=True)
    print(noise_levels)
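    # hedged check: get_noise_levels returns one estimate per channel
    assert noise_levels.shape == (rec.get_num_channels(),)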

# Example 16

def test_WaveformExtractor():
    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments
    recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec1"
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)

    # test with dump !!!!
    recording = recording.save()
    sorting = sorting.save()

    folder = Path('test_waveform_extractor')
    if folder.is_dir():
        shutil.rmtree(folder)

    we = WaveformExtractor.create(recording, sorting, folder)

    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)

    we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True)

    wfs = we.get_waveforms(0)
    assert wfs.shape[0] <= 500
    assert wfs.shape[1:] == (210, 2)

    wfs, sampled_index = we.get_waveforms(0, with_index=True)

    # load back
    we = WaveformExtractor.load_from_folder(folder)

    wfs = we.get_waveforms(0)

    template = we.get_template(0)
    assert template.shape == (210, 2)
    templates = we.get_all_templates()
    assert templates.shape == (5, 210, 2)

    wf_std = we.get_template(0, mode='std')
    assert wf_std.shape == (210, 2)
    wfs_std = we.get_all_templates(mode='std')
    assert wfs_std.shape == (5, 210, 2)


    wf_segment = we.get_template_segment(unit_id=0, segment_index=0)
    assert wf_segment.shape == (210, 2)

# Example 17

def test_portability():
    durations = [30, 40]
    sampling_frequency = 30000.
    
    folder_to_move = Path("original_folder")
    if folder_to_move.is_dir():
        shutil.rmtree(folder_to_move)
    folder_to_move.mkdir()
    folder_moved = Path("moved_folder")
    if folder_moved.is_dir():
        shutil.rmtree(folder_moved)
    # folder_moved.mkdir()
        
    # 2 segments
    num_channels = 2
    recording = generate_recording(num_channels=num_channels, durations=durations, 
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = folder_to_move / "rec"
    recording = recording.save(folder=folder_rec)
    num_units = 15
    sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations)
    folder_sort = folder_to_move / "sort"
    sorting = sorting.save(folder=folder_sort)

    wf_folder = folder_to_move / "waveform_extractor"
    if wf_folder.is_dir():
        shutil.rmtree(wf_folder)

    # save with relative paths
    we = extract_waveforms(recording, sorting, wf_folder, use_relative_path=True)

    # move all to a separate folder
    shutil.copytree(folder_to_move, folder_moved)
    wf_folder_moved = folder_moved / "waveform_extractor"
    we_loaded = extract_waveforms(recording, sorting, wf_folder_moved, load_if_exists=True)
    
    assert we_loaded.recording is not None
    assert we_loaded.sorting is not None

    assert np.allclose(
        we.recording.get_channel_ids(), we_loaded.recording.get_channel_ids())
    assert np.allclose(
        we.sorting.get_unit_ids(), we_loaded.sorting.get_unit_ids())


    for unit in we.sorting.get_unit_ids():
        wf = we.get_waveforms(unit_id=unit)
        wf_loaded = we_loaded.get_waveforms(unit_id=unit)

        assert np.allclose(wf, wf_loaded)


def test_scale():
    rec = generate_recording()
    n = rec.get_num_channels()
    gain = np.ones(n) * 2.
    offset = np.ones(n) * -10.

    rec2 = scale(rec, gain=gain, offset=offset)
    rec2.get_traces(segment_index=0)

    rec2 = scale(rec, gain=2., offset=-10.)
    rec2.get_traces(segment_index=0)

    rec2 = scale(rec, gain=gain, offset=-10.)
    rec2.get_traces(segment_index=0)
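
    # hedged check (assumption): scale() applies traces * gain + offset
    assert np.allclose(rec2.get_traces(segment_index=0),
                       rec.get_traces(segment_index=0) * 2. - 10.)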


def test_spikeinterface_sources():
    import spikeinterface as si
    from spikeinterface.core.testing_tools import generate_recording, generate_sorting

    recording = generate_recording()
    source = ephyviewer.SpikeInterfaceRecordingSource(recording=recording)
    print(source)

    print(source.t_start, source.nb_channel, source.sample_rate)

    sorting = generate_sorting()
    source = ephyviewer.SpikeInterfaceSortingSource(sorting=sorting)
    print(source)

    print(source.t_start, source.nb_channel, source.get_channel_name())

# Example 20

def test_extract_waveforms():
    # 2 segments

    durations = [30, 40]
    sampling_frequency = 30000.

    recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec2"
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)
    folder_sort = "wf_sort2"
    sorting = sorting.save(folder=folder_sort)
    # test without dump !!!!
    #  recording = recording.save()
    #  sorting = sorting.save()

    folder1 = Path('test_extract_waveforms_1job')
    if folder1.is_dir():
        shutil.rmtree(folder1)
    we1 = extract_waveforms(recording, sorting, folder1, max_spikes_per_unit=None, return_scaled=False)

    folder2 = Path('test_extract_waveforms_2job')
    if folder2.is_dir():
        shutil.rmtree(folder2)
    we2 = extract_waveforms(recording, sorting, folder2, n_jobs=2, total_memory="10M", max_spikes_per_unit=None,
                            return_scaled=False)

    wf1 = we1.get_waveforms(0)
    wf2 = we2.get_waveforms(0)
    assert np.array_equal(wf1, wf2)

    folder3 = Path('test_extract_waveforms_returnscaled')
    if folder3.is_dir():
        shutil.rmtree(folder3)

    # set scaling values to recording
    gain = 0.1
    recording.set_channel_gains(gain)
    recording.set_channel_offsets(0)
    we3 = extract_waveforms(recording, sorting, folder3, n_jobs=2, total_memory="10M", max_spikes_per_unit=None,
                            return_scaled=True)

    wf3 = we3.get_waveforms(0)
    assert np.array_equal(wf1.astype("float32") * gain, wf3)


def test_filter():
    rec = generate_recording()
    rec = rec.save()

    rec2 = bandpass_filter(rec, freq_min=300., freq_max=6000.)

    # compute by chunk
    rec2_cached0 = rec2.save(chunk_size=100000,
                             verbose=False,
                             progress_bar=True)
    # compute by chunk with joblib
    rec2_cached1 = rec2.save(total_memory="10k", n_jobs=4, verbose=True)
    # compute once
    rec2_cached2 = rec2.save(verbose=False)

    trace0 = rec2.get_traces(segment_index=0)
    trace1 = rec2_cached1.get_traces(segment_index=0)

    # other filtering types
    rec3 = filter(rec,
                  band=500.,
                  btype='highpass',
                  filter_mode='ba',
                  filter_order=2)

    rec4 = notch_filter(rec, freq=3000, q=30, margin_ms=5.)

    # filter from coefficients
    coeff = iirfilter(8, [0.02, 0.4],
                      rs=30,
                      btype='band',
                      analog=False,
                      ftype='cheby2',
                      output='sos')
    rec5 = filter(rec, coeff=coeff, filter_mode='sos')

    # compute by chunk
    rec5_cached0 = rec5.save(chunk_size=100000,
                             verbose=False,
                             progress_bar=True)

    trace50 = rec5.get_traces(segment_index=0)
    trace51 = rec5_cached0.get_traces(segment_index=0)

    assert np.allclose(rec.get_times(0), rec2.get_times(0))


def test_ensure_n_jobs():
    recording = generate_recording()

    n_jobs = ensure_n_jobs(recording)
    assert n_jobs == 1

    n_jobs = ensure_n_jobs(recording, n_jobs=0)
    assert n_jobs == 1

    n_jobs = ensure_n_jobs(recording, n_jobs=1)
    assert n_jobs == 1

    # not dumpable force n_jobs=1
    n_jobs = ensure_n_jobs(recording, n_jobs=-1)
    assert n_jobs == 1

    # dumpable
    n_jobs = ensure_n_jobs(recording.save(), n_jobs=-1)
    assert n_jobs > 1

# Example 23

def test_sparsity():
    durations = [30]
    sampling_frequency = 30000.

    recording = generate_recording(num_channels=10, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec3"
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)
    folder_sort = "wf_sort3"
    sorting = sorting.save(folder=folder_sort)

    folder = Path('test_extract_waveforms_sparsity')
    if folder.is_dir():
        shutil.rmtree(folder)
    we = extract_waveforms(recording, sorting, folder, max_spikes_per_unit=None)

    # sparsity: same number of channels
    num_channels = 3
    channel_bounds = [2, 9]
    sparsity_same = {}
    sparsity_diff = {}

    for unit in sorting.get_unit_ids():
        sparsity_same[unit] = np.random.permutation(recording.get_channel_ids())[:num_channels]
        rand_channel_num = np.random.randint(channel_bounds[0], channel_bounds[1])
        sparsity_diff[unit] = np.random.permutation(recording.get_channel_ids())[:rand_channel_num]

    print(sparsity_same)
    print(sparsity_diff)

    for unit in sorting.get_unit_ids():
        wf_same = we.get_waveforms(unit_id=unit, sparsity=sparsity_same)
        temp_same = we.get_template(unit_id=unit, sparsity=sparsity_same)

        assert wf_same.shape[-1] == num_channels
        assert temp_same.shape[-1] == num_channels

        wf_diff = we.get_waveforms(unit_id=unit, sparsity=sparsity_diff)
        temp_diff = we.get_template(unit_id=unit, sparsity=sparsity_diff)

        assert wf_diff.shape[-1] == len(sparsity_diff[unit])
        assert temp_diff.shape[-1] == len(sparsity_diff[unit])


def test_ChunkRecordingExecutor():
    recording = generate_recording(num_channels=2)
    # make dumpable
    recording = recording.save()

    init_args = 'a', 120, 'yep'

    # no chunk
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=False,
                                       n_jobs=1,
                                       chunk_size=None)
    processor.run()

    # chunk + loop
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=False,
                                       n_jobs=1,
                                       chunk_memory="500k")
    processor.run()

    # chunk + parallel
    processor = ChunkRecordingExecutor(recording,
                                       func,
                                       init_func,
                                       init_args,
                                       verbose=True,
                                       progress_bar=True,
                                       n_jobs=2,
                                       total_memory="200k",
                                       job_name='job_name')
    processor.run()


def test_ensure_chunk_size():
    recording = generate_recording(num_channels=2)
    dtype = recording.get_dtype()
    assert dtype == 'float32'
    # make dumpable
    recording = recording.save()

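    # chunk-size arithmetic (float32 x 2 channels = 8 bytes per frame):
    #   total_memory 512M / n_jobs 2 / 8 B -> 32_000_000 frames
    #   chunk_memory 256M / 8 B            -> 32_000_000 frames
    #   chunk_memory 1k   / 8 B            -> 125 frames
    #   chunk_memory 1G   / 8 B            -> 125_000_000 frames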
    chunk_size = ensure_chunk_size(recording,
                                   total_memory="512M",
                                   chunk_size=None,
                                   chunk_memory=None,
                                   n_jobs=2)
    assert chunk_size == 32000000

    chunk_size = ensure_chunk_size(recording, chunk_memory="256M")
    assert chunk_size == 32000000

    chunk_size = ensure_chunk_size(recording, chunk_memory="1k")
    assert chunk_size == 125

    chunk_size = ensure_chunk_size(recording, chunk_memory="1G")
    assert chunk_size == 125000000

# Example 26

def test_frame_slicing():
    duration = [10]

    rec = generate_recording(num_channels=4, durations=duration)
    sort = generate_sorting(num_units=10, durations=duration)

    original_times = rec.get_times()
    new_times = original_times + 5
    rec.set_times(new_times)

    sort.register_recording(rec)

    start_frame = 3 * rec.get_sampling_frequency()
    end_frame = 7 * rec.get_sampling_frequency()

    rec_slice = rec.frame_slice(start_frame=start_frame, end_frame=end_frame)
    sort_slice = sort.frame_slice(start_frame=start_frame, end_frame=end_frame)

    for u in sort_slice.get_unit_ids():
        spike_times = sort_slice.get_unit_spike_train(u, return_times=True)
        rec_times = rec_slice.get_times()
        assert np.all(spike_times >= rec_times[0])
        assert np.all(spike_times <= rec_times[-1])
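
    # hedged check (assumption): the sliced recording spans exactly
    # end_frame - start_frame samples
    assert rec_slice.get_num_samples(segment_index=0) == int(end_frame - start_frame)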


def test_write_binary_recording():
    # 2 segments
    recording = generate_recording(num_channels=2, durations=[10.325, 3.5])
    # make dumpable
    recording = recording.save()

    # write with loop
    write_binary_recording(recording,
                           file_paths=['binary01.raw', 'binary02.raw'],
                           dtype=None,
                           verbose=True,
                           n_jobs=1)

    write_binary_recording(recording,
                           file_paths=['binary01.raw', 'binary02.raw'],
                           dtype=None,
                           verbose=True,
                           n_jobs=1,
                           chunk_memory='100k',
                           progress_bar=True)

    # write in parallel
    write_binary_recording(recording,
                           file_paths=['binary01.raw', 'binary02.raw'],
                           dtype=None,
                           verbose=False,
                           n_jobs=2,
                           chunk_memory='100k')

    # write in parallel
    write_binary_recording(recording,
                           file_paths=['binary01.raw', 'binary02.raw'],
                           dtype=None,
                           verbose=False,
                           n_jobs=2,
                           total_memory='200k',
                           progress_bar=True)


def test_center():
    rec = generate_recording()

    rec2 = center(rec, mode='median')
    rec2.get_traces(segment_index=0)

# Example 29

def test_waveform_tools():

    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments
    num_channels = 2
    recording = generate_recording(num_channels=num_channels,
                                   durations=durations,
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec1"
    #~ recording = recording.save(folder=folder_rec)
    num_units = 15
    sorting = generate_sorting(num_units=num_units,
                               sampling_frequency=sampling_frequency,
                               durations=durations)

    # test with dump !!!!
    recording = recording.save()
    sorting = sorting.save()

    #~ we = WaveformExtractor.create(recording, sorting, folder)

    nbefore = int(3. * sampling_frequency / 1000.)
    nafter = int(4. * sampling_frequency / 1000.)

    dtype = recording.get_dtype()
    return_scaled = False

    spikes = sorting.to_spike_vector()

    unit_ids = sorting.unit_ids

    some_job_kwargs = [
        {},
        {
            'n_jobs': 1,
            'chunk_size': 3000,
            'progress_bar': True
        },
        {
            'n_jobs': 2,
            'chunk_size': 3000,
            'progress_bar': True
        },
    ]

    # memmap mode
    list_wfs = []
    for j, job_kwargs in enumerate(some_job_kwargs):
        wf_folder = Path(f'test_waveform_tools_{j}')
        if wf_folder.is_dir():
            shutil.rmtree(wf_folder)
        wf_folder.mkdir()
        wfs_arrays, wfs_arrays_info = allocate_waveforms(recording,
                                                         spikes,
                                                         unit_ids,
                                                         nbefore,
                                                         nafter,
                                                         mode='memmap',
                                                         folder=wf_folder,
                                                         dtype=dtype)
        distribute_waveforms_to_buffers(recording, spikes, unit_ids,
                                        wfs_arrays_info, nbefore, nafter,
                                        return_scaled, **job_kwargs)
        for unit_ind, unit_id in enumerate(unit_ids):
            wf = wfs_arrays[unit_id]
            assert wf.shape[0] == np.sum(spikes['unit_ind'] == unit_ind)
        list_wfs.append(
            {unit_id: wfs_arrays[unit_id].copy()
             for unit_id in unit_ids})
    _check_all_wf_equal(list_wfs)

    # memory
    if platform.system() != 'Windows':
        # shared memory on windows is buggy...
        list_wfs = []
        for job_kwargs in some_job_kwargs:
            wfs_arrays, wfs_arrays_info = allocate_waveforms(
                recording,
                spikes,
                unit_ids,
                nbefore,
                nafter,
                mode='shared_memory',
                folder=None,
                dtype=dtype)
            distribute_waveforms_to_buffers(recording,
                                            spikes,
                                            unit_ids,
                                            wfs_arrays_info,
                                            nbefore,
                                            nafter,
                                            return_scaled,
                                            mode='shared_memory',
                                            **job_kwargs)
            for unit_ind, unit_id in enumerate(unit_ids):
                wf = wfs_arrays[unit_id]
                assert wf.shape[0] == np.sum(spikes['unit_ind'] == unit_ind)
            list_wfs.append(
                {unit_id: wfs_arrays[unit_id].copy()
                 for unit_id in unit_ids})
            # to avoid a warning, destroy the arrays first, then the shared
            # memory objects
            del wfs_arrays
            del wfs_arrays_info
        _check_all_wf_equal(list_wfs)

    # with sparsity
    wf_folder = Path('test_waveform_tools_sparse')
    if wf_folder.is_dir():
        shutil.rmtree(wf_folder)
    wf_folder.mkdir()

    sparsity_mask = np.random.randint(0,
                                      2,
                                      size=(unit_ids.size,
                                            recording.channel_ids.size),
                                      dtype='bool')

    wfs_arrays, wfs_arrays_info = allocate_waveforms(
        recording,
        spikes,
        unit_ids,
        nbefore,
        nafter,
        mode='memmap',
        folder=wf_folder,
        dtype=dtype,
        sparsity_mask=sparsity_mask)
    job_kwargs = {'n_jobs': 1, 'chunk_size': 3000, 'progress_bar': True}
    distribute_waveforms_to_buffers(recording,
                                    spikes,
                                    unit_ids,
                                    wfs_arrays_info,
                                    nbefore,
                                    nafter,
                                    return_scaled,
                                    sparsity_mask=sparsity_mask,
                                    **job_kwargs)


def test_whiten():
    rec = generate_recording()

    rec2 = whiten(rec)
    rec2.save(verbose=False)
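
    # hedged check (assumption): whitening should roughly decorrelate the
    # channels, so the covariance of the whitened traces is near identity
    # (loose tolerance, since it is estimated from random data chunks)
    traces = rec2.get_traces(segment_index=0)
    cov = np.cov(traces.T)
    assert np.allclose(cov, np.eye(cov.shape[0]), atol=0.3)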