def _reset_principal_components_folder(waveforms_folder):
    """Delete the ``principal_components`` sub-folder of *waveforms_folder* if it exists.

    The test below calls ``compute_principal_components(..., load_if_exists=True)``
    once per PCA mode. Removing this folder first presumably leaves nothing
    cached to reload, so each mode is fitted fresh.
    """
    pc_folder = Path(waveforms_folder) / 'principal_components'
    if pc_folder.is_dir():
        shutil.rmtree(pc_folder)


def test_pca_models_and_project_new():
    """Fit PCA in each supported mode and check model types and projection shapes.

    For every mode ("by_channel_local", "by_channel_global", "concatenated") this:
    clears the cached extension, fits it, inspects ``get_pca_model()`` and
    projects random waveforms with ``project_new``.
    """
    from sklearn.decomposition import IncrementalPCA

    folder = 'toy_waveforms_1seg'
    _reset_principal_components_folder(folder)
    we = WaveformExtractor.load_from_folder(folder)

    # Waveform arrays are indexed (spike, sample, channel), matching the
    # (n_new, n_samples, n_channels) probe arrays built below.
    wfs0 = we.get_waveforms(unit_id=we.sorting.unit_ids[0])
    n_samples = wfs0.shape[1]
    n_channels = wfs0.shape[2]
    n_components = 5
    n_new = 100
    # Seeded generator: only shapes are asserted, but fixed inputs make any
    # failure reproducible across runs.
    rng = np.random.default_rng(0)

    # by_channel_local: one PCA model per channel
    pc_local = compute_principal_components(we, n_components=n_components, load_if_exists=True,
                                            mode="by_channel_local")
    all_pca = pc_local.get_pca_model()
    assert len(all_pca) == we.recording.get_num_channels()
    new_waveforms = rng.standard_normal((n_new, n_samples, n_channels))
    new_proj = pc_local.project_new(new_waveforms)
    assert new_proj.shape == (n_new, n_components, n_channels)

    # by_channel_global: a single model shared by all channels
    _reset_principal_components_folder(folder)
    pc_global = compute_principal_components(we, n_components=n_components, load_if_exists=True,
                                             mode="by_channel_global")
    all_pca = pc_global.get_pca_model()
    assert isinstance(all_pca, IncrementalPCA)
    new_waveforms = rng.standard_normal((n_new, n_samples, n_channels))
    new_proj = pc_global.project_new(new_waveforms)
    assert new_proj.shape == (n_new, n_components, n_channels)

    # concatenated: a single model; projections lose the channel axis
    _reset_principal_components_folder(folder)
    pc_concatenated = compute_principal_components(we, n_components=n_components, load_if_exists=True,
                                                   mode="concatenated")
    all_pca = pc_concatenated.get_pca_model()
    assert isinstance(all_pca, IncrementalPCA)
    new_waveforms = rng.standard_normal((n_new, n_samples, n_channels))
    new_proj = pc_concatenated.project_new(new_waveforms)
    assert new_proj.shape == (n_new, n_components)
def test_select_units():
    """Check that ``select_units`` carries the principal-components extension over.

    The extension is computed (or reloaded) on the source extractor. Every
    other unit is then kept in a new folder, and the new extractor must still
    list ``principal_components``.
    """
    source_we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    pca_extension = compute_principal_components(source_we, load_if_exists=True)

    every_other_unit = source_we.sorting.get_unit_ids()[::2]
    subset_we = source_we.select_units(every_other_unit, 'toy_waveforms_1seg_filt')

    extension_names = subset_we.get_available_extension_names()
    assert "principal_components" in extension_names
def test_compute_principal_components_for_all_spikes_single_job():
    """Run ``run_for_all_spikes`` with a single job and reload the saved array.

    This test previously shared its name with a later test in this module.
    That later definition replaced it at import time, so pytest never
    collected it. The distinct name makes it run.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    pc = compute_principal_components(we, load_if_exists=True)
    print(pc)

    pc_file = 'all_pc.npy'
    pc.run_for_all_spikes(pc_file, max_channels_per_template=7, chunk_size=10000, n_jobs=1)
    # NOTE(review): only checks that the file is written and loadable; an
    # assertion on the expected shape would make this test stronger.
    all_pc = np.load(pc_file)
    assert isinstance(all_pc, np.ndarray)
def test_compute_principal_components_for_all_spikes():
    """Check that ``run_for_all_spikes`` gives identical output for 1 and 2 jobs.

    Each run writes its own ``.npy`` file. Both arrays are reloaded and must
    match element-for-element.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    pc = compute_principal_components(we, load_if_exists=True)
    print(pc)

    loaded = []
    for job_count, output_path in ((1, 'all_pc1.npy'), (2, 'all_pc2.npy')):
        pc.run_for_all_spikes(output_path, max_channels_per_template=7, chunk_size=10000, n_jobs=job_count)
        loaded.append(np.load(output_path))

    serial_result, parallel_result = loaded
    assert np.array_equal(serial_result, parallel_result)
def test_compute_principal_components():
    """Compute (or reload) principal components for the ``toy_waveforms`` folder.

    Presumably the multi-segment counterpart of ``toy_waveforms_1seg``; this
    is a smoke test that only checks the call succeeds and prints the result.
    """
    extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    principal_components = compute_principal_components(extractor, load_if_exists=True)
    print(principal_components)