Example #1
def test_spikeinterface_viewer(interactive=False):
    import spikeinterface as si
    from spikeinterface.core.testing_tools import generate_recording, generate_sorting

    recording = generate_recording()
    sig_source = ephyviewer.SpikeInterfaceRecordingSource(recording=recording)

    sorting = generate_sorting()
    spike_source = ephyviewer.SpikeInterfaceSortingSource(sorting=sorting)

    app = ephyviewer.mkQApp()
    win = ephyviewer.MainViewer(debug=True, show_auto_scale=True)

    view = ephyviewer.TraceViewer(source=sig_source, name='signals')
    win.add_view(view)

    view = ephyviewer.SpikeTrainViewer(source=spike_source, name='spikes')
    win.add_view(view)

    if interactive:
        win.show()
        app.exec_()
    else:
        # close the window so its threads shut down cleanly
        win.close()
Example #2
def test_FrameSliceSorting():
    fs = 30000
    duration = 10
    sort = generate_sorting(num_units=10,
                            durations=[duration],
                            sampling_frequency=fs)

    mid_frame = (duration * fs) // 2
    # slicing with (None, None) keeps the full duration, so every spike survives
    sub_sort = sort.frame_slice(None, None)
    for u in sort.get_unit_ids():
        assert len(sort.get_unit_spike_train(u)) == len(
            sub_sort.get_unit_spike_train(u))

    # each bounded slice below spans mid_frame frames; spike trains are
    # re-referenced to start_frame, so the largest spike frame is <= mid_frame
    sub_sort = sort.frame_slice(None, mid_frame)
    for u in sort.get_unit_ids():
        assert max(sub_sort.get_unit_spike_train(u)) <= mid_frame

    sub_sort = sort.frame_slice(mid_frame, None)
    for u in sort.get_unit_ids():
        assert max(sub_sort.get_unit_spike_train(u)) <= mid_frame

    sub_sort = sort.frame_slice(mid_frame - mid_frame // 2,
                                mid_frame + mid_frame // 2)
    for u in sort.get_unit_ids():
        assert max(sub_sort.get_unit_spike_train(u)) <= mid_frame
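
A minimal sketch of the re-referencing rule these assertions rely on (assuming spikeinterface's NumpySorting.from_times_labels constructor, from the same API generation as these tests): a spike at absolute frame f inside the slice reappears at frame f - start_frame.

import numpy as np
from spikeinterface.core import NumpySorting

fs = 30000
frames = np.array([1000, 5000, 9000])   # absolute spike frames for one unit
labels = np.zeros(3, dtype='int64')     # a single unit with id 0
sort = NumpySorting.from_times_labels(frames, labels, sampling_frequency=fs)

sub = sort.frame_slice(start_frame=4000, end_frame=10000)
# spikes before the slice are dropped; the survivors shift by start_frame
assert np.array_equal(sub.get_unit_spike_train(0), frames[frames >= 4000] - 4000)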
Example #3
def test_time_handling():
    cache_folder = Path('./my_cache_folder')
    if cache_folder.is_dir():
        shutil.rmtree(cache_folder)
    durations = [[10], [10, 5]]

    # test multi-segment
    for i, dur in enumerate(durations):
        rec = generate_recording(num_channels=4, durations=dur)
        sort = generate_sorting(num_units=10, durations=dur)

        for segment_index in range(rec.get_num_segments()):
            original_times = rec.get_times(segment_index=segment_index)
            new_times = original_times + 5
            rec.set_times(new_times, segment_index=segment_index)

        sort.register_recording(rec)
        assert sort.has_recording()

        rec_cache = rec.save(folder=cache_folder / f"rec{i}")

        for segment_index in range(sort.get_num_segments()):
            assert rec.has_time_vector(segment_index=segment_index)
            assert sort.has_time_vector(segment_index=segment_index)

            # times are correctly saved by the recording
            assert np.allclose(
                rec.get_times(segment_index=segment_index),
                rec_cache.get_times(segment_index=segment_index))

            # spike times are correctly adjusted
            for u in sort.get_unit_ids():
                spike_times = sort.get_unit_spike_train(
                    u, segment_index=segment_index, return_times=True)
                rec_times = rec.get_times(segment_index=segment_index)
                assert np.all(spike_times >= rec_times[0])
                assert np.all(spike_times <= rec_times[-1])
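
The arithmetic behind the spike-time checks above, as a self-contained sketch: after set_times(original_times + 5), sample index k maps to k / fs + 5 seconds, and return_times=True resolves spike frames through that same time vector.

import numpy as np
from spikeinterface.core.testing_tools import generate_recording

rec = generate_recording(num_channels=2, durations=[2.])
rec.set_times(rec.get_times() + 5)
fs = rec.get_sampling_frequency()
k = 100
# with the shifted time vector, sample k sits at k / fs + 5 seconds
assert np.isclose(rec.get_times()[k], k / fs + 5)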
Example #4
def test_WaveformExtractor():
    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments
    recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec1"
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)

    # test with dumped (saved) recording and sorting
    recording = recording.save()
    sorting = sorting.save()

    folder = Path('test_waveform_extractor')
    if folder.is_dir():
        shutil.rmtree(folder)

    we = WaveformExtractor.create(recording, sorting, folder)

    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)

    we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True)

    wfs = we.get_waveforms(0)
    assert wfs.shape[0] <= 500
    assert wfs.shape[1:] == (210, 2)

    wfs, sampled_index = we.get_waveforms(0, with_index=True)

    # load back
    we = WaveformExtractor.load_from_folder(folder)

    wfs = we.get_waveforms(0)

    template = we.get_template(0)
    assert template.shape == (210, 2)
    templates = we.get_all_templates()
    assert templates.shape == (5, 210, 2)

    wf_std = we.get_template(0, mode='std')
    assert wf_std.shape == (210, 2)
    wfs_std = we.get_all_templates(mode='std')
    assert wfs_std.shape == (5, 210, 2)


    wf_segment = we.get_template_segment(unit_id=0, segment_index=0)
    assert wf_segment.shape == (210, 2)
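
Where the (210, 2) shapes come from, spelled out: with ms_before=3. and ms_after=4. at 30 kHz, a waveform spans nbefore + nafter samples (the same formula Example #10 computes explicitly) across the 2 channels.

ms_before, ms_after, fs = 3., 4., 30000.
nbefore = int(ms_before * fs / 1000.)   # 90 samples before the peak
nafter = int(ms_after * fs / 1000.)     # 120 samples after the peak
assert nbefore + nafter == 210          # x 2 channels -> shape (210, 2)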
Example #5
def test_portability():
    durations = [30, 40]
    sampling_frequency = 30000.
    
    folder_to_move = Path("original_folder")
    if folder_to_move.is_dir():
        shutil.rmtree(folder_to_move)
    folder_to_move.mkdir()
    folder_moved = Path("moved_folder")
    if folder_moved.is_dir():
        shutil.rmtree(folder_moved)
    # no need to mkdir: shutil.copytree below creates folder_moved
        
    # 2 segments
    num_channels = 2
    recording = generate_recording(num_channels=num_channels, durations=durations, 
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = folder_to_move / "rec"
    recording = recording.save(folder=folder_rec)
    num_units = 15
    sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations)
    folder_sort = folder_to_move / "sort"
    sorting = sorting.save(folder=folder_sort)

    wf_folder = folder_to_move / "waveform_extractor"
    if wf_folder.is_dir():
        shutil.rmtree(wf_folder)

    # save with relative paths
    we = extract_waveforms(recording, sorting, wf_folder, use_relative_path=True)

    # move all to a separate folder
    shutil.copytree(folder_to_move, folder_moved)
    wf_folder_moved = folder_moved / "waveform_extractor"
    we_loaded = extract_waveforms(recording, sorting, wf_folder_moved, load_if_exists=True)
    
    assert we_loaded.recording is not None
    assert we_loaded.sorting is not None

    assert np.allclose(
        we.recording.get_channel_ids(), we_loaded.recording.get_channel_ids())
    assert np.allclose(
        we.sorting.get_unit_ids(), we_loaded.sorting.get_unit_ids())


    for unit in we.sorting.get_unit_ids():
        wf = we.get_waveforms(unit_id=unit)
        wf_loaded = we_loaded.get_waveforms(unit_id=unit)

        assert np.allclose(wf, wf_loaded)
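
Because the extractor was written with use_relative_path=True, the moved copy can also be opened directly from disk, without re-passing recording and sorting. A hedged sketch using the load_from_folder entry point shown in Example #4:

from pathlib import Path
from spikeinterface import WaveformExtractor

we_moved = WaveformExtractor.load_from_folder(Path("moved_folder") / "waveform_extractor")
print(we_moved.recording.get_channel_ids(), we_moved.sorting.get_unit_ids())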
Example #6
def test_spikeinterface_sources():
    import spikeinterface as si
    from spikeinterface.core.testing_tools import generate_recording, generate_sorting

    recording = generate_recording()
    source = ephyviewer.SpikeInterfaceRecordingSource(recording=recording)
    print(source)

    print(source.t_start, source.nb_channel, source.sample_rate)

    sorting = generate_sorting()
    source = ephyviewer.SpikeInterfaceSortingSource(sorting=sorting)
    print(source)

    print(source.t_start, source.nb_channel, source.get_channel_name())
Example #7
def test_extract_waveforms():
    # 2 segments

    durations = [30, 40]
    sampling_frequency = 30000.

    recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec2"
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)
    folder_sort = "wf_sort2"
    sorting = sorting.save(folder=folder_sort)
    # recording and sorting were already saved to folders above; no extra dump needed
    #  recording = recording.save()
    #  sorting = sorting.save()

    folder1 = Path('test_extract_waveforms_1job')
    if folder1.is_dir():
        shutil.rmtree(folder1)
    we1 = extract_waveforms(recording, sorting, folder1, max_spikes_per_unit=None, return_scaled=False)

    folder2 = Path('test_extract_waveforms_2job')
    if folder2.is_dir():
        shutil.rmtree(folder2)
    we2 = extract_waveforms(recording, sorting, folder2, n_jobs=2, total_memory="10M", max_spikes_per_unit=None,
                            return_scaled=False)

    wf1 = we1.get_waveforms(0)
    wf2 = we2.get_waveforms(0)
    assert np.array_equal(wf1, wf2)

    folder3 = Path('test_extract_waveforms_returnscaled')
    if folder3.is_dir():
        shutil.rmtree(folder3)

    # set scaling values to recording
    gain = 0.1
    recording.set_channel_gains(gain)
    recording.set_channel_offsets(0)
    we3 = extract_waveforms(recording, sorting, folder3, n_jobs=2, total_memory="10M", max_spikes_per_unit=None,
                            return_scaled=True)

    wf3 = we3.get_waveforms(0)
    assert np.array_equal(wf1.astype("float32") * gain, wf3)
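
The scaling rule the final assertion encodes: with return_scaled=True, raw samples become raw.astype("float32") * gain + offset per channel (offset is 0 here, so only the gain appears). A standalone check of that arithmetic:

import numpy as np

gain, offset = 0.1, 0.
raw = np.array([[10, -20], [30, 40]], dtype="int16")
scaled = raw.astype("float32") * gain + offset
assert np.allclose(scaled, [[1., -2.], [3., 4.]])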
Example #8
def test_sparsity():
    durations = [30]
    sampling_frequency = 30000.

    recording = generate_recording(num_channels=10, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec3"
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)
    folder_sort = "wf_sort3"
    sorting = sorting.save(folder=folder_sort)

    folder = Path('test_extract_waveforms_sparsity')
    if folder.is_dir():
        shutil.rmtree(folder)
    we = extract_waveforms(recording, sorting, folder, max_spikes_per_unit=None)

    # sparsity: same number of channels
    num_channels = 3
    channel_bounds = [2, 9]
    sparsity_same = {}
    sparsity_diff = {}

    for unit in sorting.get_unit_ids():
        sparsity_same[unit] = np.random.permutation(recording.get_channel_ids())[:num_channels]
        rand_channel_num = np.random.randint(channel_bounds[0], channel_bounds[1])
        sparsity_diff[unit] = np.random.permutation(recording.get_channel_ids())[:rand_channel_num]

    print(sparsity_same)
    print(sparsity_diff)

    for unit in sorting.get_unit_ids():
        wf_same = we.get_waveforms(unit_id=unit, sparsity=sparsity_same)
        temp_same = we.get_template(unit_id=unit, sparsity=sparsity_same)

        assert wf_same.shape[-1] == num_channels
        assert temp_same.shape[-1] == num_channels

        wf_diff = we.get_waveforms(unit_id=unit, sparsity=sparsity_diff)
        temp_diff = we.get_template(unit_id=unit, sparsity=sparsity_diff)

        assert wf_diff.shape[-1] == len(sparsity_diff[unit])
        assert temp_diff.shape[-1] == len(sparsity_diff[unit])
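
The sparsity argument above is a plain dict mapping each unit id to the channel ids to keep; get_waveforms and get_template then return only those channels, which is why the last axis equals len(sparsity[unit]). A deterministic sketch of such a dict:

import numpy as np

channel_ids = np.arange(10)
sparsity = {unit: channel_ids[unit:unit + 3] for unit in range(5)}
for unit, chans in sparsity.items():
    assert len(chans) == 3   # the last waveform axis would be 3 for every unit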
Example #9
def test_frame_slicing():
    duration = [10]

    rec = generate_recording(num_channels=4, durations=duration)
    sort = generate_sorting(num_units=10, durations=duration)

    original_times = rec.get_times()
    new_times = original_times + 5
    rec.set_times(new_times)

    sort.register_recording(rec)

    start_frame = int(3 * rec.get_sampling_frequency())
    end_frame = int(7 * rec.get_sampling_frequency())

    rec_slice = rec.frame_slice(start_frame=start_frame, end_frame=end_frame)
    sort_slice = sort.frame_slice(start_frame=start_frame, end_frame=end_frame)

    for u in sort_slice.get_unit_ids():
        spike_times = sort_slice.get_unit_spike_train(u, return_times=True)
        rec_times = rec_slice.get_times()
        assert np.all(spike_times >= rec_times[0])
        assert np.all(spike_times <= rec_times[-1])
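
The seconds-to-frames conversion used for the slice bounds, made explicit (frame indices are integer sample counts, hence the int casts above):

fs = 30000.
start_frame = int(3 * fs)   # 3 s -> sample 90000
end_frame = int(7 * fs)     # 7 s -> sample 210000
assert (end_frame - start_frame) / fs == 4.   # the slice spans 4 seconds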
Example #10
def test_waveform_tools():

    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments
    num_channels = 2
    recording = generate_recording(num_channels=num_channels,
                                   durations=durations,
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec1"
    #~ recording = recording.save(folder=folder_rec)
    num_units = 15
    sorting = generate_sorting(num_units=num_units,
                               sampling_frequency=sampling_frequency,
                               durations=durations)

    # test with dumped (saved) recording and sorting
    recording = recording.save()
    sorting = sorting.save()

    #~ we = WaveformExtractor.create(recording, sorting, folder)

    nbefore = int(3. * sampling_frequency / 1000.)
    nafter = int(4. * sampling_frequency / 1000.)

    dtype = recording.get_dtype()
    return_scaled = False

    spikes = sorting.to_spike_vector()

    unit_ids = sorting.unit_ids

    some_job_kwargs = [
        {},
        {
            'n_jobs': 1,
            'chunk_size': 3000,
            'progress_bar': True
        },
        {
            'n_jobs': 2,
            'chunk_size': 3000,
            'progress_bar': True
        },
    ]
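    # the three variants above exercise the chunked job machinery: default
    # (serial), one worker with explicit chunking, and two workers; the
    # resulting buffers are checked for exact equality below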

    # memmap mode
    list_wfs = []
    for j, job_kwargs in enumerate(some_job_kwargs):
        wf_folder = Path(f'test_waveform_tools_{j}')
        if wf_folder.is_dir():
            shutil.rmtree(wf_folder)
        wf_folder.mkdir()
        wfs_arrays, wfs_arrays_info = allocate_waveforms(recording,
                                                         spikes,
                                                         unit_ids,
                                                         nbefore,
                                                         nafter,
                                                         mode='memmap',
                                                         folder=wf_folder,
                                                         dtype=dtype)
        distribute_waveforms_to_buffers(recording, spikes, unit_ids,
                                        wfs_arrays_info, nbefore, nafter,
                                        return_scaled, **job_kwargs)
        for unit_ind, unit_id in enumerate(unit_ids):
            wf = wfs_arrays[unit_id]
            assert wf.shape[0] == np.sum(spikes['unit_ind'] == unit_ind)
        list_wfs.append(
            {unit_id: wfs_arrays[unit_id].copy()
             for unit_id in unit_ids})
    _check_all_wf_equal(list_wfs)

    # memory
    if platform.system() != 'Windows':
        # shared memory on windows is buggy...
        list_wfs = []
        for job_kwargs in some_job_kwargs:
            wfs_arrays, wfs_arrays_info = allocate_waveforms(
                recording,
                spikes,
                unit_ids,
                nbefore,
                nafter,
                mode='shared_memory',
                folder=None,
                dtype=dtype)
            distribute_waveforms_to_buffers(recording,
                                            spikes,
                                            unit_ids,
                                            wfs_arrays_info,
                                            nbefore,
                                            nafter,
                                            return_scaled,
                                            mode='shared_memory',
                                            **job_kwargs)
            for unit_ind, unit_id in enumerate(unit_ids):
                wf = wfs_arrays[unit_id]
                assert wf.shape[0] == np.sum(spikes['unit_ind'] == unit_ind)
            list_wfs.append(
                {unit_id: wfs_arrays[unit_id].copy()
                 for unit_id in unit_ids})
            # to avoid a warning, destroy the arrays before the shared-memory objects
            del wfs_arrays
            del wfs_arrays_info
        _check_all_wf_equal(list_wfs)

    # with sparsity
    wf_folder = Path('test_waveform_tools_sparse')
    if wf_folder.is_dir():
        shutil.rmtree(wf_folder)
    wf_folder.mkdir()

    sparsity_mask = np.random.randint(0,
                                      2,
                                      size=(unit_ids.size,
                                            recording.channel_ids.size),
                                      dtype='bool')

    wfs_arrays, wfs_arrays_info = allocate_waveforms(
        recording,
        spikes,
        unit_ids,
        nbefore,
        nafter,
        mode='memmap',
        folder=wf_folder,
        dtype=dtype,
        sparsity_mask=sparsity_mask)
    job_kwargs = {'n_jobs': 1, 'chunk_size': 3000, 'progress_bar': True}
    distribute_waveforms_to_buffers(recording,
                                    spikes,
                                    unit_ids,
                                    wfs_arrays_info,
                                    nbefore,
                                    nafter,
                                    return_scaled,
                                    sparsity_mask=sparsity_mask,
                                    **job_kwargs)
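
A note on the sparsity_mask above: it is a boolean array of shape (num_units, num_channels); row i selects the channels retained for unit i, so each unit's buffer ends up with wf.shape[-1] == sparsity_mask[i].sum(). A small deterministic sketch:

import numpy as np

num_units, num_channels = 3, 4
mask = np.zeros((num_units, num_channels), dtype='bool')
mask[0, [0, 1]] = True   # unit 0 keeps channels 0 and 1
mask[1, [2]] = True      # unit 1 keeps channel 2
mask[2, :] = True        # unit 2 keeps all channels
print(mask.sum(axis=1))  # channels kept per unit -> [2 1 4]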


"""
import ephyviewer
import spikeinterface.full as si
from spikeinterface.core.testing_tools import generate_recording, generate_sorting

recording = generate_recording()
sig_source = ephyviewer.SpikeInterfaceRecordingSource(recording=recording)

filtered_recording = si.bandpass_filter(recording, freq_min=60., freq_max=100.)
sig_filtered_source = ephyviewer.SpikeInterfaceRecordingSource(
    recording=filtered_recording)

sorting = generate_sorting()
spike_source = ephyviewer.SpikeInterfaceSortingSource(sorting=sorting)

app = ephyviewer.mkQApp()
win = ephyviewer.MainViewer(debug=True, show_auto_scale=True)

view = ephyviewer.TraceViewer(source=sig_source, name='signals')
win.add_view(view)

view = ephyviewer.TraceViewer(source=sig_filtered_source,
                              name='signals filtered')
win.add_view(view)

view = ephyviewer.SpikeTrainViewer(source=spike_source, name='spikes')
win.add_view(view)

win.show()
app.exec_()
"""