# Example #1
# 0
def test_pca_models_and_project_new():
    """Fit PCA models in every supported mode and project new waveforms.

    Modes exercised: ``"by_channel_local"``, ``"by_channel_global"`` and
    ``"concatenated"``. Before each mode the on-disk ``principal_components``
    extension folder is removed, so ``load_if_exists=True`` recomputes the
    model instead of loading one fitted with a different mode.

    The folder is also removed when the test finishes (even on failure).
    Other tests call ``compute_principal_components(we, load_if_exists=True)``
    without a ``mode``. Without this cleanup they could pick up the
    ``"concatenated"`` model this test leaves behind.
    """
    from sklearn.decomposition import IncrementalPCA

    pc_folder = Path('toy_waveforms_1seg/principal_components')

    def remove_pc_folder():
        # Drop any saved extension so the next compute call recomputes it
        # rather than loading a model fitted with another mode.
        if pc_folder.is_dir():
            shutil.rmtree(pc_folder)

    remove_pc_folder()
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')

    # Axis 1 of the waveform array is samples and axis 2 is channels; random
    # inputs to project_new are built with the same layout.
    wfs0 = we.get_waveforms(unit_id=we.sorting.unit_ids[0])
    n_samples = wfs0.shape[1]
    n_channels = wfs0.shape[2]
    n_components = 5
    n_new = 100
    # Fixed seed so the projected inputs are reproducible across runs.
    rng = np.random.default_rng(0)

    try:
        # local: one fitted model per recording channel
        pc_local = compute_principal_components(we, n_components=n_components,
                                                load_if_exists=True, mode="by_channel_local")

        all_pca = pc_local.get_pca_model()
        assert len(all_pca) == we.recording.get_num_channels()

        new_waveforms = rng.standard_normal((n_new, n_samples, n_channels))
        new_proj = pc_local.project_new(new_waveforms)
        assert new_proj.shape == (n_new, n_components, n_channels)

        # global: a single IncrementalPCA model; projections keep a channel axis
        remove_pc_folder()
        pc_global = compute_principal_components(we, n_components=n_components,
                                                 load_if_exists=True, mode="by_channel_global")

        all_pca = pc_global.get_pca_model()
        assert isinstance(all_pca, IncrementalPCA)

        new_waveforms = rng.standard_normal((n_new, n_samples, n_channels))
        new_proj = pc_global.project_new(new_waveforms)
        assert new_proj.shape == (n_new, n_components, n_channels)

        # concatenated: a single IncrementalPCA model; projections have no channel axis
        remove_pc_folder()
        pc_concatenated = compute_principal_components(we, n_components=n_components,
                                                       load_if_exists=True, mode="concatenated")

        all_pca = pc_concatenated.get_pca_model()
        assert isinstance(all_pca, IncrementalPCA)

        new_waveforms = rng.standard_normal((n_new, n_samples, n_channels))
        new_proj = pc_concatenated.project_new(new_waveforms)
        assert new_proj.shape == (n_new, n_components)
    finally:
        # Do not leak the last (non-default) mode's model into other tests.
        remove_pc_folder()
# Example #2
# 0
def test_select_units():
    """Check that ``select_units`` carries the principal-components extension over.

    The filtered extractor is written to ``toy_waveforms_1seg_filt``. Any copy
    left over from a previous run is removed first, so each run starts clean.
    NOTE(review): confirm whether ``select_units`` rejects an existing target
    folder; the cleanup is harmless either way.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    # Called for its side effect: the extension must exist on `we` before
    # the selection for the assertion below to be meaningful.
    compute_principal_components(we, load_if_exists=True)

    filt_folder = 'toy_waveforms_1seg_filt'
    if Path(filt_folder).is_dir():
        shutil.rmtree(filt_folder)

    keep_units = we.sorting.get_unit_ids()[::2]
    we_filt = we.select_units(keep_units, filt_folder)
    assert "principal_components" in we_filt.get_available_extension_names()
def test_compute_principal_components_for_all_spikes():
    """Project every spike with ``run_for_all_spikes`` and load the result.

    The projections are written to ``all_pc.npy`` in the current directory
    using a single job.

    NOTE(review): a later function in this file has exactly the same name.
    That later definition replaces this one at import time, so pytest never
    collects this version. Consider renaming one of them.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    pc = compute_principal_components(we, load_if_exists=True)
    # Printing exercises the extension's string representation.
    print(pc)

    pc_file = 'all_pc.npy'
    pc.run_for_all_spikes(pc_file, max_channels_per_template=7, chunk_size=10000, n_jobs=1)

    all_pc = np.load(pc_file)
    # Without an assertion this test would only check that nothing raised;
    # require that some projections were actually saved.
    assert all_pc.size > 0
# Example #4
# 0
def test_compute_principal_components_for_all_spikes():
    """Check that ``run_for_all_spikes`` is independent of ``n_jobs``.

    All-spike projections are written to ``all_pc1.npy`` (``n_jobs=1``) and
    ``all_pc2.npy`` (``n_jobs=2``) in the current directory. The two arrays
    must then be exactly equal.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    pc = compute_principal_components(we, load_if_exists=True)
    # Printing exercises the extension's string representation.
    print(pc)

    pc_file1 = 'all_pc1.npy'
    pc.run_for_all_spikes(pc_file1, max_channels_per_template=7, chunk_size=10000, n_jobs=1)
    all_pc1 = np.load(pc_file1)

    pc_file2 = 'all_pc2.npy'
    pc.run_for_all_spikes(pc_file2, max_channels_per_template=7, chunk_size=10000, n_jobs=2)
    all_pc2 = np.load(pc_file2)

    # Two empty arrays compare equal, so make sure there is data to compare.
    assert all_pc1.size > 0
    assert np.array_equal(all_pc1, all_pc2)
# Example #5
# 0
def test_compute_principal_components():
    """Compute principal components for the 'toy_waveforms' extractor.

    Loads the waveform extractor from the 'toy_waveforms' folder. Because
    ``load_if_exists=True``, an existing principal-components extension is
    reused instead of being recomputed. The result is then printed, which
    exercises its string representation.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms')
    pc = compute_principal_components(we, load_if_exists=True)
    print(pc)