예제 #1
0
def test_get_noise_levels():
    """Smoke test for ``get_noise_levels`` on a 2-channel, 60 s synthetic recording.

    The result is only printed; no value is asserted.
    """
    recording = generate_recording(
        num_channels=2,
        sampling_frequency=1000.,
        durations=[60.],
    )
    levels = get_noise_levels(recording)
    print(levels)
예제 #2
0
def test_filter_opencl():
    """Run ``filter`` through the scipy engine and the opencl engine, saving both.

    The two outputs are not compared; the test only checks that both
    pipelines run end to end on a large (256-channel, ~100 s) recording.
    """
    # For quicker local runs a smaller setup works too:
    # num_channels=32, durations=[10.325, 3.5].
    rec = generate_recording(
        num_channels=256,
        sampling_frequency=30000.,
        durations=[100.325],
    )
    # Save the generated recording (saved recordings are dumpable).
    rec = rec.save(total_memory="100M", n_jobs=1, progress_bar=True)

    print(rec.get_dtype())
    print(rec.is_dumpable)

    # scipy engine, 30 parallel jobs.
    filter(rec, engine='scipy').save(chunk_size=1000,
                                     progress_bar=True,
                                     n_jobs=30)

    # opencl engine, single job.
    filter(rec, engine='opencl').save(chunk_size=1000,
                                      verbose=False,
                                      progress_bar=True,
                                      n_jobs=1)
def test_spikeinterface_viewer(interactive=False):
    """Show a generated spikeinterface recording and sorting in an ephyviewer window.

    Parameters
    ----------
    interactive : bool, default False
        If True, show the window and run the application event loop
        (``app.exec_()``), which returns when the user closes the window.
        If False, the window is built and closed immediately.
    """
    from spikeinterface.core.tests.testing_tools import generate_recording, generate_sorting

    recording = generate_recording()
    sig_source = ephyviewer.FromSpikeinterfaceRecordingSource(
        recording=recording)

    sorting = generate_sorting()
    # "Sorintg" is spelled exactly as the ephyviewer class this test uses;
    # do not "fix" it without checking the ephyviewer API.
    spike_source = ephyviewer.FromSpikeinterfaceSorintgSource(sorting=sorting)

    # The app object must exist before any widget is created.
    app = ephyviewer.mkQApp()
    win = ephyviewer.MainViewer(debug=True, show_auto_scale=True)

    view = ephyviewer.TraceViewer(source=sig_source, name='signals')
    win.add_view(view)

    view = ephyviewer.SpikeTrainViewer(source=spike_source, name='spikes')
    win.add_view(view)

    if interactive:
        win.show()
        app.exec_()
    else:
        # close thread properly
        win.close()
예제 #4
0
def test_filter():
    """Exercise bandpass, bandstop, highpass and notch filtering.

    The bandpass result is saved three ways: with an explicit chunk size,
    with a total memory budget over 4 jobs, and in a single pass. Traces are
    read back, but no values are compared.
    """
    rec = generate_recording().save()

    bandpassed = bandpass_filter(rec, freq_min=300., freq_max=6000.)

    # Chunked by an explicit chunk size.
    bandpassed.save(chunk_size=100000, verbose=False, progress_bar=True)
    # Chunked by a total memory budget, spread over 4 jobs (joblib).
    saved_parallel = bandpassed.save(total_memory="10k", n_jobs=4, verbose=True)
    # Everything at once.
    bandpassed.save(verbose=False)

    # Read back the lazily filtered traces and the parallel-saved traces.
    bandpassed.get_traces(segment_index=0)
    saved_parallel.get_traces(segment_index=0)

    # Other filter configurations: bandstop, 2nd-order 'ba' highpass, notch.
    filter(rec, band=[40., 60.], btype='bandstop')
    filter(rec,
           band=500.,
           btype='highpass',
           filter_mode='ba',
           filter_order=2)
    notch_filter(rec, freq=3000, q=30, margin_ms=5.)
예제 #5
0
def test_write_memory_recording():
    """Write a two-segment recording to memory under several chunking/parallel setups."""
    # Two segments of different lengths.
    recording = generate_recording(num_channels=2, durations=[10.325, 3.5])
    # Save so the recording is dumpable; the n_jobs=2 runs below rely on it.
    recording = recording.save()

    configurations = [
        # Sequential, no chunking options.
        dict(dtype=None, verbose=True, n_jobs=1),
        # Sequential, chunked by per-chunk memory.
        dict(dtype=None, verbose=True, n_jobs=1,
             chunk_memory='100k', progress_bar=True),
        # Parallel, chunked by per-chunk memory.
        dict(dtype=None, verbose=False, n_jobs=2, chunk_memory='100k'),
        # Parallel, chunked by a total memory budget.
        dict(dtype=None, verbose=False, n_jobs=2,
             total_memory='200k', progress_bar=True),
    ]
    for options in configurations:
        write_memory_recording(recording, **options)
예제 #6
0
def test_get_closest_channels():
    """Call ``get_closest_channels`` with its default and with ``num_channels=4``.

    Each call's result is unpacked into (indices, distances); nothing else is checked.
    """
    rec = generate_recording(num_channels=32, sampling_frequency=1000., durations=[0.1])

    default_inds, default_dists = get_closest_channels(rec)
    four_inds, four_dists = get_closest_channels(rec, num_channels=4)
예제 #7
0
def test_rectify():
    """The output of ``rectify`` can be saved and sliced down to one channel."""
    recording = generate_recording()

    rectified = rectify(recording)
    rectified.save(verbose=False)

    single_channel = rectified.get_traces(segment_index=0, channel_ids=[1])
    assert single_channel.shape[1] == 1
예제 #8
0
def test_remove_artifacts():
    """Check ``remove_artifacts`` around two triggers, with and without interpolation.

    Without an explicit mode, the traces around each trigger must come back
    all zeros. With mode "linear" or "cubic" they must differ from the
    untouched recording.
    """
    # one segment only
    rec = generate_recording(durations=[10.])

    triggers = [15000, 30000]

    ms = 10
    ms_frames = int(ms * rec.get_sampling_frequency() / 1000)

    def window(recording, trigger, half_width):
        # Traces from `half_width` frames before to `half_width` frames after `trigger`.
        return recording.get_traces(start_frame=trigger - half_width,
                                    end_frame=trigger + half_width)

    # Reference windows from the untouched recording.
    clean_windows = [window(rec, trigger, ms_frames) for trigger in triggers]

    # No mode given: both the full ms-wide window and a narrow 10-frame
    # window around each trigger must be all zeros.
    rec_removed = remove_artifacts(rec, triggers, ms_before=ms, ms_after=ms)
    blanked = []
    for trigger in triggers:
        blanked.append(window(rec_removed, trigger, ms_frames))
        blanked.append(window(rec_removed, trigger, 10))
    for traces in blanked:
        assert not np.any(traces)

    # The "linear" and "cubic" modes must change the traces inside the window.
    for mode in ("linear", "cubic"):
        rec_interp = remove_artifacts(rec,
                                      triggers,
                                      ms_before=ms,
                                      ms_after=ms,
                                      mode=mode)
        for trigger, clean in zip(triggers, clean_windows):
            assert not np.allclose(window(rec_interp, trigger, ms_frames), clean)
예제 #9
0
def test_get_random_data_chunks():
    """Draw 50 random 500-frame chunks per segment from a 2-segment, 1-channel recording."""
    num_segments = 2
    num_chunks_per_segment = 50
    chunk_size = 500

    rec = generate_recording(num_channels=1,
                             sampling_frequency=1000.,
                             durations=[10., 20.])
    chunks = get_random_data_chunks(rec,
                                    num_chunks_per_segment=num_chunks_per_segment,
                                    chunk_size=chunk_size,
                                    seed=0)

    # Expected rows: segments * chunks per segment * chunk size (= 50000).
    # The trailing 1 presumably corresponds to the single channel.
    expected_rows = num_segments * num_chunks_per_segment * chunk_size
    assert chunks.shape == (expected_rows, 1)
예제 #10
0
def test_get_chunk_with_margin():
    """Check the margins and zero-padding reported by ``get_chunk_with_margin``.

    Positional arguments are (rec_segment, start_frame, end_frame,
    channel_indices, sample_margin). The call returns a triple, unpacked here
    as (traces, left, right).
    """
    rec = generate_recording(num_channels=1,
                             sampling_frequency=1000.,
                             durations=[10.])
    segment = rec._recording_segments[0]
    num_samples = segment.get_num_samples()
    margin = 10

    # Whole segment: no margin is reported on either side.
    _, left, right = get_chunk_with_margin(segment, None, None, None, margin)
    assert left == 0 and right == 0

    # Starting 5 frames in: left margin is limited to 5.
    _, left, right = get_chunk_with_margin(segment, 5, None, None, margin)
    assert left == 5 and right == 0

    # Ending 5 frames before the segment end: right margin is limited to 5.
    traces, left, right = get_chunk_with_margin(segment, num_samples - 1000,
                                                num_samples - 5, None, margin)
    assert left == 10 and right == 5
    assert traces.shape[0] == 1010

    # Chunk well inside the segment: full margin on both sides.
    traces, left, right = get_chunk_with_margin(segment, 2000, 3000, None, margin)
    assert left == 10 and right == 10
    assert traces.shape[0] == 1020

    # add_zeros=True: samples beyond the segment edges are filled with zeros.
    # Near the start: the first 5 rows are padding.
    traces, left, right = get_chunk_with_margin(segment, 5, 1005, None, margin,
                                                add_zeros=True)
    assert traces.shape[0] == 1020
    assert left == 10
    assert right == 10
    assert np.all(traces[:5] == 0)

    # Near the end: the last 5 rows are padding.
    traces, left, right = get_chunk_with_margin(segment, num_samples - 1005,
                                                num_samples - 5, None, margin,
                                                add_zeros=True)
    assert traces.shape[0] == 1020
    assert np.all(traces[-5:] == 0)
    assert left == 10
    assert right == 10

    # Chunk running 500 frames past the end: the overflow plus the margin
    # (510 rows) is zero-filled, and the reported right margin is 510.
    traces, left, right = get_chunk_with_margin(segment, num_samples - 500,
                                                num_samples + 500, None, margin,
                                                add_zeros=True)
    assert traces.shape[0] == 1020
    assert np.all(traces[-510:] == 0)
    assert left == 10
    assert right == 510
예제 #11
0
def test_clip():
    """``clip`` with both bounds and with a lower bound only; outputs save and slice."""
    recording = generate_recording()

    clipped_both = clip(recording, a_min=-2, a_max=3.)
    clipped_both.save(verbose=False)

    clipped_lower = clip(recording, a_min=-1.5)
    clipped_lower.save(verbose=False)

    single_channel = clipped_both.get_traces(segment_index=0, channel_ids=[1])
    assert single_channel.shape[1] == 1
예제 #12
0
def test_ChunkRecordingExecutor():
    """Run ``ChunkRecordingExecutor`` unchunked, chunked in a loop, and chunked in parallel."""
    recording = generate_recording(num_channels=2)
    # make dumpable
    recording = recording.save()

    def func(segment_index, start_frame, end_frame, worker_ctx):
        # Per-chunk job: sleep 10 ms and return the id of the process that ran it.
        import os
        import time
        time.sleep(0.010)
        return os.getpid()

    def init_func(arg1, arg2, arg3):
        # Build the worker context dict; presumably handed to `func` as worker_ctx.
        return {'arg1': arg1, 'arg2': arg2, 'arg3': arg3}

    init_args = ('a', 120, 'yep')

    def run(**options):
        # Shared constructor arguments; each scenario adds its own options.
        executor = ChunkRecordingExecutor(recording, func, init_func, init_args,
                                          verbose=True, **options)
        executor.run()

    # no chunk
    run(progress_bar=False, n_jobs=1, chunk_size=None)

    # chunk + loop
    run(progress_bar=False, n_jobs=1, chunk_memory="500k")

    # chunk + parallel
    run(progress_bar=True, n_jobs=2, total_memory="200k", job_name='job_name')
예제 #13
0
def test_blank_staturationy():
    """``blank_staturation`` with an absolute threshold and with a quantile threshold.

    The "staturation" spelling matches the function under test.
    """
    recording = generate_recording()

    by_abs = blank_staturation(recording, abs_threshold=3.)
    by_abs.save(verbose=False)

    by_quantile = blank_staturation(recording, quantile_threshold=0.01, direction='both')
    by_quantile.save(verbose=False)

    single_channel = by_abs.get_traces(segment_index=0, channel_ids=[1])
    assert single_channel.shape[1] == 1
예제 #14
0
def test_spikeinterface_sources():
    """Wrap a generated recording and sorting as ephyviewer sources and print their metadata."""
    from spikeinterface.core.tests.testing_tools import generate_recording, generate_sorting

    recording = generate_recording()
    sig_source = ephyviewer.FromSpikeinterfaceRecordingSource(recording=recording)
    print(sig_source)

    print(sig_source.t_start, sig_source.nb_channel, sig_source.sample_rate)

    sorting = generate_sorting()
    # "Sorintg" is spelled exactly as the ephyviewer class this test uses;
    # do not "fix" it without checking the ephyviewer API.
    spike_source = ephyviewer.FromSpikeinterfaceSorintgSource(sorting=sorting)
    print(spike_source)

    print(spike_source.t_start, spike_source.nb_channel, spike_source.get_channel_name())
예제 #15
0
def test_extract_waveforms():
    """Extract waveforms from a 2-segment recording/sorting pair into a fresh folder."""
    # 2 segments
    durations = [30, 40]
    sampling_frequency = 30000.

    recording = generate_recording(num_channels=2,
                                   durations=durations,
                                   sampling_frequency=sampling_frequency)
    sorting = generate_sorting(num_units=5,
                               sampling_frequency=sampling_frequency,
                               durations=durations)
    recording = recording.save()
    sorting = sorting.save()

    # Remove any folder left over from a previous run.
    folder = Path('test_extract_waveforms')
    if folder.is_dir():
        shutil.rmtree(folder)

    we = extract_waveforms(recording, sorting, folder)
    print(we)
예제 #16
0
def test_ensure_n_jobs():
    """``ensure_n_jobs`` resolves to 1 unless n_jobs=-1 is used with a dumpable recording."""
    recording = generate_recording()

    # Default, 0 and 1 all give a single job. -1 also gives a single job,
    # because this recording is not dumpable yet.
    for options in ({}, {'n_jobs': 0}, {'n_jobs': 1}, {'n_jobs': -1}):
        assert ensure_n_jobs(recording, **options) == 1

    # A saved (dumpable) recording lets n_jobs=-1 expand past one job.
    # NOTE(review): this presumably fails on a single-core machine.
    assert ensure_n_jobs(recording.save(), n_jobs=-1) > 1
예제 #17
0
def test_WaveformExtractor():
    """Create, run (sequential and parallel), and reload a ``WaveformExtractor``."""
    durations = [30, 40]
    sampling_frequency = 30000.
    num_channels = 2
    num_units = 5
    max_spikes_per_unit = 500
    # Expected waveform length in samples: (3 ms + 4 ms) at 30 kHz.
    num_samples = 210

    # 2 segments
    recording = generate_recording(num_channels=num_channels,
                                   durations=durations,
                                   sampling_frequency=sampling_frequency)
    sorting = generate_sorting(num_units=num_units,
                               sampling_frequency=sampling_frequency,
                               durations=durations)

    # Deliberately exercise the saved (dumped) path.
    recording = recording.save()
    sorting = sorting.save()

    folder = Path('test_waveform_extractor')
    if folder.is_dir():
        shutil.rmtree(folder)

    we = WaveformExtractor.create(recording, sorting, folder)
    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=max_spikes_per_unit)

    # Sequential, then parallel extraction with the same chunk size.
    we.run(n_jobs=1, chunk_size=30000)
    we.run(n_jobs=4, chunk_size=30000, progress_bar=True)

    wfs = we.get_waveforms(0)
    assert wfs.shape[0] <= max_spikes_per_unit
    assert wfs.shape[1:] == (num_samples, num_channels)

    wfs, sampled_index = we.get_waveforms(0, with_index=True)

    # Reload from the folder and check the template shapes.
    we = WaveformExtractor.load_from_folder(folder)
    wfs = we.get_waveforms(0)

    template = we.get_template(0)
    assert template.shape == (num_samples, num_channels)
    templates = we.get_all_templates()
    assert templates.shape == (num_units, num_samples, num_channels)
예제 #18
0
def test_ensure_chunk_size():
    """Check chunk sizes derived from memory budgets for a 2-channel float32 recording.

    A frame is 2 channels * 4 bytes = 8 bytes. The expected values imply
    decimal units (1k -> 1000 / 8 = 125 frames). They also imply that
    total_memory is split across n_jobs (512M over 2 jobs gives the same
    size as chunk_memory="256M").
    """
    recording = generate_recording(num_channels=2)
    assert recording.get_dtype() == 'float32'
    # make dumpable
    recording = recording.save()

    expectations = [
        (dict(total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2), 32000000),
        (dict(chunk_memory="256M"), 32000000),
        (dict(chunk_memory="1k"), 125),
        (dict(chunk_memory="1G"), 125000000),
    ]
    for options, expected in expectations:
        assert ensure_chunk_size(recording, **options) == expected
예제 #19
0
def test_common_reference():
    """Check global, single-channel and local common referencing on a 4-channel recording."""
    rec = generate_recording(durations=[5.], num_channels=4)
    rec = rec.save()

    # no groups
    rec_cmr = common_reference(rec, reference='global', operator='median')
    rec_car = common_reference(rec, reference='global', operator='average')
    rec_sin = common_reference(rec, reference='single', ref_channels=0)
    # Local reference with the median operator (a local CMR, not a CAR).
    rec_local_cmr = common_reference(rec,
                                     reference='local',
                                     local_radius=(20, 65),
                                     operator='median')

    rec_cmr.save(verbose=False)
    rec_car.save(verbose=False)
    rec_sin.save(verbose=False)
    rec_local_cmr.save(verbose=False)

    traces = rec.get_traces()
    # Global references: adding back the per-sample median / mean across
    # channels must recover the input.
    assert np.allclose(traces,
                       rec_cmr.get_traces() +
                       np.median(traces, axis=1, keepdims=True),
                       atol=0.01)
    assert np.allclose(traces,
                       rec_car.get_traces() +
                       np.mean(traces, axis=1, keepdims=True),
                       atol=0.01)

    # Single reference to channel 0. Index columns ([:, 0]) to check the
    # channel. The previous `[0]` indexed the first *time sample*, which
    # never verified that the reference channel itself is zeroed.
    traces_sin = rec_sin.get_traces()
    assert np.allclose(traces_sin[:, 0], 0.)
    assert np.allclose(traces_sin[:, 1], traces[:, 1] - traces[:, 0])

    # Local reference with local_radius=(20, 65). The expected values encode
    # that channel 0 is referenced to channels 2 and 3, and channel 1 to channel 3.
    assert np.allclose(traces[:, 0],
                       rec_local_cmr.get_traces()[:, 0] +
                       np.median(traces[:, [2, 3]], axis=1),
                       atol=0.01)
    assert np.allclose(traces[:, 1],
                       rec_local_cmr.get_traces()[:, 1] +
                       np.median(traces[:, [3]], axis=1),
                       atol=0.01)
예제 #20
0
def test_normalize_by_quantile():
    rec = generate_recording()
    
    rec2 = whiten(rec)
    rec2.save(verbose=False)