Example #1
0
def test_select_units():
    """Check that 'principal_components' is listed on a unit-filtered extractor.

    Principal components are computed (or reloaded) on the
    'toy_waveforms_1seg' extractor, every second unit id is kept, and the
    filtered extractor is written to 'toy_waveforms_1seg_filt'.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    # The result is not inspected; the call is made before filtering so the
    # extension can be looked up on the filtered extractor afterwards.
    _pc = compute_principal_components(waveforms, load_if_exists=True)

    every_other_unit = waveforms.sorting.get_unit_ids()[::2]
    filtered = waveforms.select_units(every_other_unit, 'toy_waveforms_1seg_filt')
    extension_names = filtered.get_available_extension_names()
    assert "principal_components" in extension_names
def test_compute_quality_metrics_peak_sign():
    """Compare metrics of a recording (peak_sign="neg") and its negated copy (peak_sign="pos").

    Waveforms for the negated recording are extracted into
    'toy_waveforms_inv'; 'snr' and 'amplitude_cutoff' must then be close
    (``np.allclose``) to the values obtained from 'toy_waveforms'.
    """
    compared_metrics = ('snr', 'amplitude_cutoff')

    recording = load_extractor('toy_rec')
    sorting = load_extractor('toy_sorting')

    # Negate the traces by scaling with a gain of -1
    inverted_recording = scale(recording, gain=-1.)

    we = WaveformExtractor.load_from_folder('toy_waveforms')
    print(we)

    we_inv = WaveformExtractor.create(inverted_recording, sorting, 'toy_waveforms_inv')
    we_inv.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    we_inv.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    print(we_inv)

    # Metrics that do not need principal components
    metrics = compute_quality_metrics(
        we, metric_names=list(compared_metrics), peak_sign="neg")
    metrics_inv = compute_quality_metrics(
        we_inv, metric_names=list(compared_metrics), peak_sign="pos")

    for metric_name in compared_metrics:
        assert np.allclose(metrics[metric_name].values,
                           metrics_inv[metric_name].values)
def test_select_units():
    """Check that 'quality_metrics' is listed on a unit-filtered extractor.

    Quality metrics are computed (or reloaded) on the 'toy_waveforms'
    extractor, every second unit id is kept, and the filtered extractor is
    written to 'toy_waveforms_filt'.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')
    # The result is not inspected; the call is made before filtering so the
    # extension can be looked up on the filtered extractor afterwards.
    _qm = compute_quality_metrics(waveforms, load_if_exists=True)

    every_other_unit = waveforms.sorting.get_unit_ids()[::2]
    filtered = waveforms.select_units(every_other_unit, 'toy_waveforms_filt')
    extension_names = filtered.get_available_extension_names()
    assert "quality_metrics" in extension_names
def test_select_units():
    """Check that 'spike_amplitudes' is listed on a unit-filtered extractor.

    Spike amplitudes are computed (or reloaded) on the 'mearec_waveforms'
    extractor, every second unit id is kept, and the filtered extractor is
    written to 'mearec_waveforms_filt'.
    """
    waveforms = WaveformExtractor.load_from_folder('mearec_waveforms')
    # The result is not inspected; the call is made before filtering so the
    # extension can be looked up on the filtered extractor afterwards.
    _amps = compute_spike_amplitudes(waveforms, load_if_exists=True)

    every_other_unit = waveforms.sorting.get_unit_ids()[::2]
    filtered = waveforms.select_units(every_other_unit, 'mearec_waveforms_filt')
    extension_names = filtered.get_available_extension_names()
    assert "spike_amplitudes" in extension_names
def test_compute_quality_metrics():
    """Compute quality metrics with and without PCs, then reload them as an extension.

    Steps, in order:
      1. request only 'snr' and check 'isolation_distance' is absent;
      2. run a 5-component 'by_channel_local' PCA, recompute with default
         arguments and check 'isolation_distance' is now present;
      3. reload the 'quality_metrics' extension through the extractor and
         through ``QualityMetricCalculator.load_from_folder``.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')
    print(waveforms)

    # Step 1: no principal components available yet
    snr_only = compute_quality_metrics(waveforms, metric_names=['snr'])
    assert 'snr' in snr_only.columns
    assert 'isolation_distance' not in snr_only.columns
    print(snr_only)

    # Step 2: run the PCA first, then compute metrics with default arguments
    pca = WaveformPrincipalComponent(waveforms)
    pca.set_params(n_components=5, mode='by_channel_local')
    pca.run()
    with_pcs = compute_quality_metrics(waveforms)
    assert 'isolation_distance' in with_pcs.columns
    print(with_pcs)

    # Step 3: the calculator must be reachable as an extension of the extractor
    assert QualityMetricCalculator in waveforms.get_available_extensions()
    assert waveforms.is_extension('quality_metrics')
    from_extractor = waveforms.load_extension('quality_metrics')
    assert isinstance(from_extractor, QualityMetricCalculator)
    assert from_extractor._metrics is not None

    # ... and directly from the waveform folder
    from_folder = QualityMetricCalculator.load_from_folder('toy_waveforms')
    assert from_folder._metrics is not None
Example #6
0
def test_get_template_channel_sparsity():
    """Call ``get_template_channel_sparsity`` for every method with both output kinds.

    'best_channels', 'radius' and 'threshold' run on the 'toy_waveforms'
    extractor; 'by_property' runs on a freshly extracted 'toy_waveforms_1'.
    Each result is printed; nothing is asserted.
    """
    output_kinds = ('id', 'index')

    we = WaveformExtractor.load_from_folder('toy_waveforms')
    method_options = (
        ('best_channels', {'num_channels': 5}),
        ('radius', {'radius_um': 50}),
        ('threshold', {'threshold': 3}),
    )
    for method, options in method_options:
        for kind in output_kinds:
            print(get_template_channel_sparsity(we, method=method, outputs=kind, **options))

    # load from folder because sorting properties must be loaded
    rec = load_extractor('toy_rec')
    sort = load_extractor('toy_sort')
    we = extract_waveforms(rec, sort, 'toy_waveforms_1')
    for kind in output_kinds:
        print(get_template_channel_sparsity(we, method='by_property', outputs=kind, by_property="group"))
def test_calculate_pc_metrics():
    """Run ``calculate_pc_metrics`` on the PCA stored in 'toy_waveforms' and print the result."""
    folder = 'toy_waveforms'

    waveforms = WaveformExtractor.load_from_folder(folder)
    print(waveforms)
    principal_components = WaveformPrincipalComponent.load_from_folder(folder)
    print(principal_components)

    pc_metrics = calculate_pc_metrics(principal_components)
    print(pc_metrics)
Example #8
0
def test_compute_unit_center_of_mass():
    """Call ``localize_units`` with method 'center_of_mass', with and without ``output='dict'``.

    Only checks that both calls complete; the results are not inspected.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')

    shared_options = dict(method='center_of_mass', num_channels=4)
    unit_location = localize_units(waveforms, **shared_options)
    unit_location_dict = localize_units(waveforms, output='dict', **shared_options)
Example #9
0
def test_get_template_channel_sparsity():
    """Call ``get_template_channel_sparsity`` with 'best_channels' and 'radius'.

    Each method is run with ``outputs='id'`` then ``outputs='index'``; the
    results are discarded, so this only checks that the calls complete.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms')

    method_options = (
        ('best_channels', {'num_channels': 5}),
        ('radius', {'radius_um': 50}),
    )
    for method, options in method_options:
        for kind in ('id', 'index'):
            sparsity = get_template_channel_sparsity(we, method=method, outputs=kind, **options)
def test_compute_principal_components_for_all_spikes():
    """Run ``run_for_all_spikes`` into 'all_pc.npy' and load the file back.

    The loaded array is not inspected; the test passes when the file can
    be written and read with ``np.load``.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    principal_components = compute_principal_components(waveforms, load_if_exists=True)
    print(principal_components)

    output_file = 'all_pc.npy'
    principal_components.run_for_all_spikes(
        output_file,
        max_channels_per_template=7,
        chunk_size=10000,
        n_jobs=1,
    )

    all_pc = np.load(output_file)
Example #11
0
def test_compute_monopolar_triangulation():
    """Call ``localize_units`` with method 'monopolar_triangulation', with and without ``output='dict'``.

    Only checks that both calls complete; the results are not inspected.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')

    shared_options = dict(method='monopolar_triangulation', radius_um=150)
    unit_location = localize_units(waveforms, **shared_options)
    unit_location_dict = localize_units(waveforms, output='dict', **shared_options)
Example #12
0
def test_pca_models_and_project_new():
    """Check the PCA model and ``project_new`` output for each of the three modes.

    For 'by_channel_local', 'by_channel_global' and 'concatenated' the saved
    'principal_components' folder is removed, the components are computed,
    the object returned by ``get_pca_model`` is checked, and 100 random
    waveforms are projected to check the output shape.
    """
    from sklearn.decomposition import IncrementalPCA

    saved_pca_folder = 'toy_waveforms_1seg/principal_components'
    n_new_waveforms = 100
    n_components = 5

    def _drop_saved_pca():
        # Remove any stored extension so the next compute call starts fresh.
        if Path(saved_pca_folder).is_dir():
            shutil.rmtree(saved_pca_folder)

    _drop_saved_pca()
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')

    first_unit_waveforms = we.get_waveforms(unit_id=we.sorting.unit_ids[0])
    n_samples, n_channels = first_unit_waveforms.shape[1], first_unit_waveforms.shape[2]

    def _random_waveforms():
        # Synthetic waveforms with the same sample/channel layout as the real ones.
        return np.random.randn(n_new_waveforms, n_samples, n_channels)

    # --- by_channel_local: the model has one entry per recording channel
    pc_local = compute_principal_components(
        we, n_components=n_components, load_if_exists=True, mode="by_channel_local")
    local_models = pc_local.get_pca_model()
    assert len(local_models) == we.recording.get_num_channels()

    local_projection = pc_local.project_new(_random_waveforms())
    assert local_projection.shape == (n_new_waveforms, n_components, n_channels)

    # --- by_channel_global: a single IncrementalPCA, per-channel projections
    _drop_saved_pca()
    pc_global = compute_principal_components(
        we, n_components=n_components, load_if_exists=True, mode="by_channel_global")
    global_model = pc_global.get_pca_model()
    assert isinstance(global_model, IncrementalPCA)

    global_projection = pc_global.project_new(_random_waveforms())
    assert global_projection.shape == (n_new_waveforms, n_components, n_channels)

    # --- concatenated: a single IncrementalPCA, no channel axis in the projection
    _drop_saved_pca()
    pc_concatenated = compute_principal_components(
        we, n_components=n_components, load_if_exists=True, mode="concatenated")
    concatenated_model = pc_concatenated.get_pca_model()
    assert isinstance(concatenated_model, IncrementalPCA)

    concatenated_projection = pc_concatenated.project_new(_random_waveforms())
    assert concatenated_projection.shape == (n_new_waveforms, n_components)
Example #13
0
def test_calculate_template_metrics():
    """Run ``calculate_template_metrics`` plain, upsampled, and upsampled with radius sparsity.

    Each result is printed; nothing is asserted.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')

    plain = calculate_template_metrics(waveforms, upsampling_factor=1)
    print(plain)

    upsampled = calculate_template_metrics(waveforms, upsampling_factor=2)
    print(upsampled)

    radius_sparsity = dict(method="radius", radius_um=20)
    upsampled_sparse = calculate_template_metrics(
        waveforms, upsampling_factor=2, sparsity_dict=radius_sparsity)
    print(upsampled_sparse)
Example #14
0
def test_WaveformExtractor():
    """Extract waveforms from a generated 2-segment recording and check their shapes.

    A 2-channel recording (segments of 30 s and 40 s at 30 kHz) and a 5-unit
    sorting are generated, waveforms are extracted with ``ms_before=3`` and
    ``ms_after=4`` (210 samples at 30 kHz), and the extractor is reloaded
    from its folder to check waveform, template and per-segment template
    shapes.
    """
    durations = [30, 40]
    sampling_frequency = 30000.

    # 2 segments
    recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    folder_rec = "wf_rec1"
    # Clear leftovers from a previous run, exactly as done for the waveform
    # folder below, so the test can be re-run from the same working directory.
    # NOTE(review): save(folder=...) is expected to refuse an existing
    # folder -- confirm against the installed version.
    if Path(folder_rec).is_dir():
        shutil.rmtree(folder_rec)
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)

    # test with dump !!!!
    recording = recording.save()
    sorting = sorting.save()

    folder = Path('test_waveform_extractor')
    if folder.is_dir():
        shutil.rmtree(folder)

    we = WaveformExtractor.create(recording, sorting, folder)

    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)

    # Run once in a single job, then again with 4 jobs and a progress bar
    we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True)

    wfs = we.get_waveforms(0)
    assert wfs.shape[0] <= 500  # capped by max_spikes_per_unit
    assert wfs.shape[1:] == (210, 2)  # (3 ms + 4 ms) * 30 kHz samples, 2 channels

    wfs, sampled_index = we.get_waveforms(0, with_index=True)

    # load back
    we = WaveformExtractor.load_from_folder(folder)

    wfs = we.get_waveforms(0)

    template = we.get_template(0)
    assert template.shape == (210, 2)
    templates = we.get_all_templates()
    assert templates.shape == (5, 210, 2)

    wf_std = we.get_template(0, mode='std')
    assert wf_std.shape == (210, 2)
    wfs_std = we.get_all_templates(mode='std')
    assert wfs_std.shape == (5, 210, 2)

    wf_segment = we.get_template_segment(unit_id=0, segment_index=0)
    assert wf_segment.shape == (210, 2)
Example #15
0
def test_compute_principal_components_for_all_spikes():
    """Check that ``run_for_all_spikes`` writes identical arrays with 1 and 2 jobs.

    The projections are written to 'all_pc1.npy' (``n_jobs=1``) and
    'all_pc2.npy' (``n_jobs=2``) and compared with ``np.array_equal``.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    principal_components = compute_principal_components(waveforms, load_if_exists=True)
    print(principal_components)

    runs = (('all_pc1.npy', 1), ('all_pc2.npy', 2))
    loaded = []
    for file_name, job_count in runs:
        principal_components.run_for_all_spikes(
            file_name, max_channels_per_template=7, chunk_size=10000, n_jobs=job_count)
        loaded.append(np.load(file_name))

    single_job, two_jobs = loaded
    assert np.array_equal(single_job, two_jobs)
def test_compute_quality_metrics():
    """Compute quality metrics without PCs, then with an explicitly passed PCA object.

    The resulting tables and their columns are printed; nothing is asserted.
    """
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')
    print(waveforms)

    # First pass: only 'snr', no principal components involved
    snr_only = compute_quality_metrics(waveforms, metric_names=['snr'])
    print(snr_only)
    print(snr_only.columns)

    # Second pass: run a 5-component PCA and hand it to the metric computation
    principal_components = WaveformPrincipalComponent(waveforms)
    principal_components.set_params(n_components=5, mode='by_channel_local')
    principal_components.run()
    with_pcs = compute_quality_metrics(
        waveforms, waveform_principal_component=principal_components)
    print(with_pcs)
    print(with_pcs.columns)
Example #17
0
    def get_waveform_extractor(self, rec_name, sorter_name=None):
        """Return the WaveformExtractor for a recording and, optionally, a sorter.

        Parameters
        ----------
        rec_name
            Name passed to ``self.get_recording`` (and to the sorting lookup).
        sorter_name
            When None, the ground-truth sorting from ``self.get_ground_truth``
            is used and the folder is labelled 'GroundTruth'. Otherwise the
            name must be in ``self.sorter_names`` (checked with ``assert``)
            and the sorting comes from ``self.get_sorting``.

        Returns
        -------
        The extractor loaded from
        ``study_folder/waveforms/waveforms_{name}_{rec_name}`` when that
        folder exists, otherwise a newly created one at that location.
        """
        rec = self.get_recording(rec_name)

        if sorter_name is not None:
            assert sorter_name in self.sorter_names
            name = sorter_name
            sorting = self.get_sorting(sorter_name, rec_name)
        else:
            name = 'GroundTruth'
            sorting = self.get_ground_truth(rec_name)

        waveform_folder = self.study_folder / 'waveforms' / f'waveforms_{name}_{rec_name}'

        # Reuse what is on disk; only create when the folder is missing.
        if not waveform_folder.is_dir():
            return WaveformExtractor.create(rec, sorting, waveform_folder)
        return WaveformExtractor.load_from_folder(waveform_folder)
Example #18
0
def test_WaveformPrincipalComponent():
    """Run WaveformPrincipalComponent in every mode and reload it as an extension.

    For the two 'by_channel_*' modes each unit's projections must have
    trailing shape (5, 4); for 'concatenated' the second axis must be 5.
    The object is then reloaded through the extractor and from the folder.
    """
    waveform_folder = 'toy_waveforms_2seg'

    we = WaveformExtractor.load_from_folder(waveform_folder)
    unit_ids = we.sorting.unit_ids
    num_channels = we.recording.get_num_channels()
    pc = WaveformPrincipalComponent(we)

    # Per-channel modes: projections keep a channel axis
    for mode in ('by_channel_local', 'by_channel_global'):
        pc.set_params(n_components=5, mode=mode)
        print(pc)
        pc.run()
        for unit_id in unit_ids:
            projections = pc.get_projections(unit_id)
            assert projections.shape[1:] == (5, 4)

    # Concatenated mode: only the component count is checked on axis 1
    for mode in ('concatenated',):
        pc.set_params(n_components=5, mode=mode)
        print(pc)
        pc.run()
        for unit_id in unit_ids:
            projections = pc.get_projections(unit_id)
            assert projections.shape[1] == 5

    all_labels, all_components = pc.get_all_components()

    # Reload as an extension of the waveform extractor
    assert WaveformPrincipalComponent in we.get_available_extensions()
    assert we.is_extension('principal_components')
    pc = we.load_extension('principal_components')
    assert isinstance(pc, WaveformPrincipalComponent)
    pc = WaveformPrincipalComponent.load_from_folder(waveform_folder)
Example #19
0
def test_get_template_extremum_amplitude():
    """Call ``get_template_extremum_amplitude`` with ``peak_sign='both'`` and print the result."""
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')

    extremum_amplitudes = get_template_extremum_amplitude(waveforms, peak_sign='both')
    print(extremum_amplitudes)
Example #20
0
def test_compute_principal_components():
    """Compute (or reload) principal components for 'toy_waveforms' and print the result."""
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')
    principal_components = compute_principal_components(waveforms, load_if_exists=True)
    print(principal_components)
Example #21
0
def test_calculate_template_metrics():
    """Call ``calculate_template_metrics`` with default arguments and print the result."""
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')
    template_metrics = calculate_template_metrics(waveforms)
    print(template_metrics)
Example #22
0
def test_get_template_amplitudes():
    """Call ``get_template_amplitudes`` with default arguments and print the result."""
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')
    amplitudes = get_template_amplitudes(waveforms)
    print(amplitudes)
Example #23
0
def test_get_template_extremum_channel_peak_shift():
    """Call ``get_template_extremum_channel_peak_shift`` with ``peak_sign='neg'`` and print the result."""
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')
    peak_shifts = get_template_extremum_channel_peak_shift(waveforms, peak_sign='neg')
    print(peak_shifts)
Example #24
0
def test_get_template_best_channels():
    """Call ``get_template_best_channels`` with ``num_channels=2`` and print the result."""
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')
    channels_per_unit = get_template_best_channels(waveforms, num_channels=2)
    print(channels_per_unit)
Example #25
0
def test_compute_template_similarity():
    """Call ``compute_template_similarity`` on 'mearec_waveforms'; the result is not inspected."""
    waveforms = WaveformExtractor.load_from_folder('mearec_waveforms')
    similarity = compute_template_similarity(waveforms)
Example #26
0
def test_compute_unit_centers_of_mass():
    """Call ``compute_unit_centers_of_mass`` with ``num_channels=4`` and print the result."""
    waveforms = WaveformExtractor.load_from_folder('toy_waveforms')

    centers_of_mass = compute_unit_centers_of_mass(waveforms, num_channels=4)
    print(centers_of_mass)