def test_compute_quality_metrics_peak_sign():
    """Check that metrics agree between a recording and its sign-inverted copy.

    Waveforms are extracted from a recording scaled by a gain of -1 and the
    'snr' and 'amplitude_cutoff' metrics computed with ``peak_sign="pos"``
    must match those of the original waveforms computed with
    ``peak_sign="neg"``.
    """
    recording = load_extractor('toy_rec')
    sorting = load_extractor('toy_sorting')

    # flip the sign of every trace
    inverted_recording = scale(recording, gain=-1.)

    we = WaveformExtractor.load_from_folder('toy_waveforms')
    print(we)

    we_inv = WaveformExtractor.create(inverted_recording, sorting, 'toy_waveforms_inv')
    we_inv.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    we_inv.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    print(we_inv)

    # metrics that do not need principal components
    compared = ('snr', 'amplitude_cutoff')
    metrics = compute_quality_metrics(we, metric_names=list(compared), peak_sign="neg")
    metrics_inv = compute_quality_metrics(we_inv, metric_names=list(compared), peak_sign="pos")

    for metric_name in compared:
        assert np.allclose(metrics[metric_name].values, metrics_inv[metric_name].values)
def test_WaveformExtractor():
    """Extract waveforms from a generated 2-segment recording and check shapes.

    The test runs the extraction with one and with several jobs, reloads the
    extractor from its folder and checks the shapes of waveforms, templates
    (average and 'std') and per-segment templates.
    """
    durations = [30, 40]
    sampling_frequency = 30000.

    folder_rec = "wf_rec1"
    folder = Path('test_waveform_extractor')
    # Remove leftovers of a previous run for BOTH output folders, so the test
    # starts from a clean state and can be re-run in the same directory.
    for stale_folder in (Path(folder_rec), folder):
        if stale_folder.is_dir():
            shutil.rmtree(stale_folder)

    # 2 segments
    recording = generate_recording(num_channels=2, durations=durations, sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency, durations=durations)

    # test with dump !!!!
    recording = recording.save()
    sorting = sorting.save()

    we = WaveformExtractor.create(recording, sorting, folder)
    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    # run once single-job, then again with several jobs and a progress bar
    we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True)

    wfs = we.get_waveforms(0)
    # at most max_spikes_per_unit waveforms are kept
    assert wfs.shape[0] <= 500
    # 210 samples matches (3 ms + 4 ms) at 30 kHz; 2 is the channel count
    assert wfs.shape[1:] == (210, 2)

    wfs, sampled_index = we.get_waveforms(0, with_index=True)

    # load back
    we = WaveformExtractor.load_from_folder(folder)
    wfs = we.get_waveforms(0)

    template = we.get_template(0)
    assert template.shape == (210, 2)
    templates = we.get_all_templates()
    assert templates.shape == (5, 210, 2)

    wf_std = we.get_template(0, mode='std')
    assert wf_std.shape == (210, 2)
    wfs_std = we.get_all_templates(mode='std')
    assert wfs_std.shape == (5, 210, 2)

    wf_segment = we.get_template_segment(unit_id=0, segment_index=0)
    assert wf_segment.shape == (210, 2)
def test_compute_quality_metrics():
    """Compute quality metrics with and without PCs and reload the extension."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(waveform_extractor)

    # without principal components only the requested metric is present
    snr_only = compute_quality_metrics(waveform_extractor, metric_names=['snr'])
    assert 'snr' in snr_only.columns
    assert 'isolation_distance' not in snr_only.columns
    print(snr_only)

    # once principal components have been run, 'isolation_distance' shows up
    principal_components = WaveformPrincipalComponent(waveform_extractor)
    principal_components.set_params(n_components=5, mode='by_channel_local')
    principal_components.run()
    all_metrics = compute_quality_metrics(waveform_extractor)
    assert 'isolation_distance' in all_metrics.columns
    print(all_metrics)

    # the calculator must be reachable as an extension of the extractor ...
    assert QualityMetricCalculator in waveform_extractor.get_available_extensions()
    assert waveform_extractor.is_extension('quality_metrics')
    calculator = waveform_extractor.load_extension('quality_metrics')
    assert isinstance(calculator, QualityMetricCalculator)
    assert calculator._metrics is not None

    # ... and loadable directly from the waveform folder
    calculator = QualityMetricCalculator.load_from_folder('toy_waveforms')
    assert calculator._metrics is not None
def test_select_units():
    """Selecting units must keep the 'quality_metrics' extension available."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    # make sure the extension exists before selecting units
    compute_quality_metrics(waveform_extractor, load_if_exists=True)

    every_other_unit = waveform_extractor.sorting.get_unit_ids()[::2]
    filtered = waveform_extractor.select_units(every_other_unit, 'toy_waveforms_filt')
    extension_names = filtered.get_available_extension_names()
    assert "quality_metrics" in extension_names
def test_select_units():
    """Selecting units must keep the 'spike_amplitudes' extension available."""
    waveform_extractor = WaveformExtractor.load_from_folder('mearec_waveforms')
    # make sure the extension exists before selecting units
    compute_spike_amplitudes(waveform_extractor, load_if_exists=True)

    every_other_unit = waveform_extractor.sorting.get_unit_ids()[::2]
    filtered = waveform_extractor.select_units(every_other_unit, 'mearec_waveforms_filt')
    extension_names = filtered.get_available_extension_names()
    assert "spike_amplitudes" in extension_names
def test_get_template_channel_sparsity():
    """Smoke-test each sparsity method with both 'id' and 'index' outputs.

    Nothing is asserted: every result is only printed.
    """
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')

    # (method, method-specific keyword arguments)
    cases = (
        ('best_channels', dict(num_channels=5)),
        ('radius', dict(radius_um=50)),
        ('threshold', dict(threshold=3)),
    )
    for method, method_kwargs in cases:
        for outputs in ('id', 'index'):
            print(get_template_channel_sparsity(waveform_extractor, method=method, outputs=outputs,
                                                **method_kwargs))

    # load from folder because sorting properties must be loaded
    recording = load_extractor('toy_rec')
    sorting = load_extractor('toy_sort')
    waveform_extractor = extract_waveforms(recording, sorting, 'toy_waveforms_1')
    for outputs in ('id', 'index'):
        print(get_template_channel_sparsity(waveform_extractor, method='by_property', outputs=outputs,
                                            by_property="group"))
def test_select_units():
    """Selecting units must keep the 'principal_components' extension available."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    # make sure the extension exists before selecting units
    compute_principal_components(waveform_extractor, load_if_exists=True)

    every_other_unit = waveform_extractor.sorting.get_unit_ids()[::2]
    filtered = waveform_extractor.select_units(every_other_unit, 'toy_waveforms_1seg_filt')
    extension_names = filtered.get_available_extension_names()
    assert "principal_components" in extension_names
def compute_metrics(self, rec_name, metric_names=['snr'], ms_before=3., ms_after=4., max_spikes_per_unit=500,
                    n_jobs=-1, total_memory='1G', **snr_kargs):
    """Compute quality metrics of the ground-truth sorting of one recording.

    Waveforms are (re-)extracted into
    ``<study_folder>/metrics/waveforms_<rec_name>`` (an existing folder is
    deleted first), the metrics are computed from them and written as a
    tab-separated file next to the waveform folder.

    Parameters
    ----------
    rec_name:
        Name of the recording, passed to ``self.get_recording`` and
        ``self.get_ground_truth``.
    metric_names:
        Metric names forwarded to ``compute_quality_metrics``.
    ms_before, ms_after, max_spikes_per_unit:
        Waveform extraction parameters forwarded to ``set_params``.
    n_jobs, total_memory:
        Job parameters forwarded to the waveform extractor ``run``.
    **snr_kargs:
        Accepted for backward compatibility; not used by this method.

    Returns
    -------
    The object returned by ``compute_quality_metrics`` (it must provide
    ``to_csv``).
    """
    # Hand downstream code its own copy, so neither the shared default list of
    # this signature nor the caller's list can be modified behind their back.
    if isinstance(metric_names, (list, tuple)):
        metric_names = list(metric_names)

    rec = self.get_recording(rec_name)
    gt_sorting = self.get_ground_truth(rec_name)

    # waveform extractor: always start from an empty folder
    waveform_folder = self.study_folder / 'metrics' / f'waveforms_{rec_name}'
    if waveform_folder.is_dir():
        shutil.rmtree(waveform_folder)
    we = WaveformExtractor.create(rec, gt_sorting, waveform_folder)
    we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit)
    we.run(n_jobs=n_jobs, total_memory=total_memory)

    # metrics
    metrics = compute_quality_metrics(we, metric_names=metric_names)
    # NOTE(review): the space in 'metrics _' looks accidental, but it is kept
    # because other code may look the file up under this exact name.
    filename = self.study_folder / 'metrics' / f'metrics _{rec_name}.txt'
    metrics.to_csv(filename, sep='\t', index=True)

    return metrics
def test_get_template_channel_sparsity():
    """Smoke-test the 'best_channels' and 'radius' sparsity methods.

    Each method is called with both 'id' and 'index' outputs; the results are
    not checked.
    """
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')

    # (method, method-specific keyword arguments)
    cases = (
        ('best_channels', dict(num_channels=5)),
        ('radius', dict(radius_um=50)),
    )
    for method, method_kwargs in cases:
        for outputs in ('id', 'index'):
            sparsity = get_template_channel_sparsity(waveform_extractor, method=method, outputs=outputs,
                                                     **method_kwargs)
def test_calculate_pc_metrics():
    """Smoke test: compute PC metrics from the components stored in 'toy_waveforms'."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(waveform_extractor)

    principal_components = WaveformPrincipalComponent.load_from_folder('toy_waveforms')
    print(principal_components)

    print(calculate_pc_metrics(principal_components))
def test_compute_unit_center_of_mass():
    """Smoke test: localize units by center of mass, default and 'dict' output."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    common = dict(method='center_of_mass', num_channels=4)

    unit_location = localize_units(waveform_extractor, **common)
    unit_location_dict = localize_units(waveform_extractor, output='dict', **common)
def test_compute_monopolar_triangulation():
    """Smoke test: localize units by monopolar triangulation, default and 'dict' output."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    common = dict(method='monopolar_triangulation', radius_um=150)

    unit_location = localize_units(waveform_extractor, **common)
    unit_location_dict = localize_units(waveform_extractor, output='dict', **common)
def test_compute_principal_components_for_all_spikes():
    """Smoke test: project all spikes to 'all_pc.npy' and load the file back."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    principal_components = compute_principal_components(waveform_extractor, load_if_exists=True)
    print(principal_components)

    output_file = 'all_pc.npy'
    principal_components.run_for_all_spikes(
        output_file,
        max_channels_per_template=7,
        chunk_size=10000,
        n_jobs=1,
    )
    # only checks that the written file can be read
    all_pc = np.load(output_file)
def get_waveform_extractor(self, rec_name, sorter_name=None):
    """Return the waveform extractor of a recording for one sorting.

    With ``sorter_name=None`` the ground-truth sorting is used, otherwise the
    output of the given sorter (which must be listed in ``self.sorter_names``).
    The extractor lives in
    ``<study_folder>/waveforms/waveforms_<name>_<rec_name>``: it is loaded when
    that folder exists and created otherwise. The recording and the sorting
    are fetched in both cases.
    """
    rec = self.get_recording(rec_name)

    if sorter_name is not None:
        assert sorter_name in self.sorter_names
        label = sorter_name
        sorting = self.get_sorting(sorter_name, rec_name)
    else:
        label = 'GroundTruth'
        sorting = self.get_ground_truth(rec_name)

    waveform_folder = self.study_folder / 'waveforms' / f'waveforms_{label}_{rec_name}'

    if not waveform_folder.is_dir():
        return WaveformExtractor.create(rec, sorting, waveform_folder)
    return WaveformExtractor.load_from_folder(waveform_folder)
def test_pca_models_and_project_new():
    """Check the PCA model type and the shape of new projections for each mode.

    For 'by_channel_local', 'by_channel_global' and 'concatenated' the
    principal components are recomputed from scratch (the extension folder is
    deleted first), the fitted model is inspected and 100 random waveforms are
    projected.
    """
    from sklearn.decomposition import IncrementalPCA

    pc_folder = Path('toy_waveforms_1seg/principal_components')

    def remove_pc_folder():
        # drop a previously saved extension so it gets recomputed
        if pc_folder.is_dir():
            shutil.rmtree(pc_folder)

    def project_random_waveforms(pc):
        new_waveforms = np.random.randn(100, n_samples, n_channels)
        return pc.project_new(new_waveforms)

    remove_pc_folder()
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')

    first_unit_waveforms = we.get_waveforms(unit_id=we.sorting.unit_ids[0])
    n_samples = first_unit_waveforms.shape[1]
    n_channels = first_unit_waveforms.shape[2]
    n_components = 5

    # local: one model per channel
    pc_local = compute_principal_components(we, n_components=n_components,
                                            load_if_exists=True, mode="by_channel_local")
    local_models = pc_local.get_pca_model()
    assert len(local_models) == we.recording.get_num_channels()
    assert project_random_waveforms(pc_local).shape == (100, n_components, n_channels)

    # global: a single model, projections still split by channel
    remove_pc_folder()
    pc_global = compute_principal_components(we, n_components=n_components,
                                             load_if_exists=True, mode="by_channel_global")
    global_model = pc_global.get_pca_model()
    assert isinstance(global_model, IncrementalPCA)
    assert project_random_waveforms(pc_global).shape == (100, n_components, n_channels)

    # concatenated: a single model, projections without a channel axis
    remove_pc_folder()
    pc_concatenated = compute_principal_components(we, n_components=n_components,
                                                   load_if_exists=True, mode="concatenated")
    concatenated_model = pc_concatenated.get_pca_model()
    assert isinstance(concatenated_model, IncrementalPCA)
    assert project_random_waveforms(pc_concatenated).shape == (100, n_components)
def test_calculate_template_metrics():
    """Smoke test: template metrics without upsampling, upsampled, and sparse."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')

    print(calculate_template_metrics(waveform_extractor, upsampling_factor=1))

    print(calculate_template_metrics(waveform_extractor, upsampling_factor=2))

    radius_sparsity = dict(method="radius", radius_um=20)
    print(calculate_template_metrics(waveform_extractor, upsampling_factor=2,
                                     sparsity_dict=radius_sparsity))
def setup_module():
    """Build the toy recording, sorting and waveform folders used by the tests.

    Folders left by a previous run are deleted first.
    """
    stale_folders = [name for name in ('toy_rec', 'toy_sorting', 'toy_waveforms') if Path(name).is_dir()]
    for name in stale_folders:
        shutil.rmtree(name)

    toy_recording, toy_sorting = toy_example(num_segments=2, num_units=10)
    toy_recording = toy_recording.save(folder='toy_rec')
    toy_sorting = toy_sorting.save(folder='toy_sorting')

    extractor = WaveformExtractor.create(toy_recording, toy_sorting, 'toy_waveforms')
    extractor.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    extractor.run_extract_waveforms(n_jobs=1, chunk_size=30000)
def test_compute_principal_components_for_all_spikes():
    """Projections of all spikes must be identical with 1 job and with 2 jobs."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    principal_components = compute_principal_components(waveform_extractor, load_if_exists=True)
    print(principal_components)

    def run_and_load(output_file, n_jobs):
        # write the projections of every spike to `output_file`, then read them back
        principal_components.run_for_all_spikes(output_file, max_channels_per_template=7,
                                                chunk_size=10000, n_jobs=n_jobs)
        return np.load(output_file)

    single_job = run_and_load('all_pc1.npy', 1)
    two_jobs = run_and_load('all_pc2.npy', 2)

    assert np.array_equal(single_job, two_jobs)
def test_compute_quality_metrics():
    """Smoke test: quality metrics without PCs, then with an explicit PC object."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(waveform_extractor)

    # without principal components
    snr_metrics = compute_quality_metrics(waveform_extractor, metric_names=['snr'])
    print(snr_metrics)
    print(snr_metrics.columns)

    # with principal components passed explicitly
    principal_components = WaveformPrincipalComponent(waveform_extractor)
    principal_components.set_params(n_components=5, mode='by_channel_local')
    principal_components.run()
    pc_metrics = compute_quality_metrics(waveform_extractor,
                                         waveform_principal_component=principal_components)
    print(pc_metrics)
    print(pc_metrics.columns)
def setup_module():
    """Build the toy recording, sorting, waveforms and principal components.

    Folders left by a previous run are deleted first.
    """
    stale_folders = [name for name in ('toy_rec', 'toy_sorting', 'toy_waveforms') if Path(name).is_dir()]
    for name in stale_folders:
        shutil.rmtree(name)

    toy_recording, toy_sorting = toy_example(num_segments=2, num_units=10)
    toy_recording = toy_recording.save(folder='toy_rec')
    toy_sorting = toy_sorting.save(folder='toy_sorting')

    extractor = WaveformExtractor.create(toy_recording, toy_sorting, 'toy_waveforms')
    extractor.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    extractor.run(n_jobs=1, chunk_size=30000)

    # principal components are computed on top of the extracted waveforms
    principal_components = WaveformPrincipalComponent(extractor)
    principal_components.set_params(n_components=5, mode='by_channel_local')
    principal_components.run()
def test_WaveformPrincipalComponent():
    """Run every PCA mode, check projection shapes and reload the extension."""
    we = WaveformExtractor.load_from_folder('toy_waveforms_2seg')
    unit_ids = we.sorting.unit_ids
    # queried but not asserted on: the channel count expected below is hard-coded
    num_channels = we.recording.get_num_channels()
    pc = WaveformPrincipalComponent(we)

    # per-channel modes: projections keep a channel axis
    for mode in ('by_channel_local', 'by_channel_global'):
        pc.set_params(n_components=5, mode=mode)
        print(pc)
        pc.run()
        for unit_id in unit_ids:
            projections = pc.get_projections(unit_id)
            assert projections.shape[1:] == (5, 4)

    # concatenated mode: only the component axis is checked
    pc.set_params(n_components=5, mode='concatenated')
    print(pc)
    pc.run()
    for unit_id in unit_ids:
        projections = pc.get_projections(unit_id)
        assert projections.shape[1] == 5

    all_labels, all_components = pc.get_all_components()

    # reload as an extension from the waveform extractor
    assert WaveformPrincipalComponent in we.get_available_extensions()
    assert we.is_extension('principal_components')
    pc = we.load_extension('principal_components')
    assert isinstance(pc, WaveformPrincipalComponent)

    # and directly from the waveform folder
    pc = WaveformPrincipalComponent.load_from_folder('toy_waveforms_2seg')
def test_get_template_amplitudes():
    """Smoke test: print the template amplitudes of the 'toy_waveforms' folder."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(get_template_amplitudes(waveform_extractor))
def test_get_template_extremum_channel_peak_shift():
    """Smoke test: print the peak shifts computed with ``peak_sign='neg'``."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign='neg'))
def test_get_template_best_channels():
    """Smoke test: print the result of requesting 2 best channels per template."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(get_template_best_channels(waveform_extractor, num_channels=2))
def test_calculate_template_metrics():
    """Smoke test: print the template metrics computed with default arguments."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(calculate_template_metrics(waveform_extractor))
sorting, folder, ms_before=1.5, ms_after=2., max_spikes_per_unit=500, load_if_exists=True) print(we) ############################################################################### # Alternatively, the :code:`WaveformExtractor` object can be instantiated # directly. In this case, we need to :code:`set_params()` to set the desired # parameters: folder = 'waveform_folder2' we = WaveformExtractor.create(recording, sorting, folder, remove_if_exists=True) we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=1000) we.run_extract_waveforms(n_jobs=1, chunk_size=30000, progress_bar=True) print(we) ############################################################################### # The :code:`'waveform_folder'` folder contains: # * the dumped recording (json) # * the dumped sorting (json) # * the parameters (json) # * a subfolder with "waveforms_XXX.npy" and "sampled_index_XXX.npy" import os print(os.listdir(folder))
def test_get_template_extremum_amplitude():
    """Smoke test: print the extremum amplitudes computed with ``peak_sign='both'``."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(get_template_extremum_amplitude(waveform_extractor, peak_sign='both'))
def nearest_neighbors_noise_overlap(waveform_extractor: si.WaveformExtractor,
                                    this_unit_id: int,
                                    max_spikes_for_nn: int = 1000,
                                    n_neighbors: int = 5,
                                    n_components: int = 10,
                                    radius_um: float = 100,
                                    seed: int = 0):
    """Calculates unit noise overlap based on NearestNeighbors search in PCA space.

    Based on noise overlap metric described in Chung et al. (2017) Neuron 95: 1381-1394.

    Rough logic:
    ------------
    1) Build a noise cluster from random snippets of the recording.
    2) Remove, from both the unit waveforms and the noise snippets, their
       component along a weighted average of the noise snippets (bias
       correction for the sampling).
    3) Project both clusters with a PCA fitted on their concatenation and
       compute the isolation score between them; the overlap is
       ``1 - isolation``.

    Implementation details:
    -----------------------
    Both clusters are brought to the same number of snippets before being
    compared. See `_compute_isolation` for the definition of the isolation
    score.

    Parameters:
    -----------
    waveform_extractor: si.WaveformExtractor
    this_unit_id: int
        ID of unit for which this metric will be calculated
    max_spikes_for_nn: int
        max number of spikes to use per cluster
    n_neighbors: int
        number of neighbors to check membership of
    n_components: int
        number of PC components to project the snippets
    radius_um: float
        only the channels within this radius of the peak channel are
        used to compute the metric
    seed: int
        seed for random subsampling of spikes

    Outputs:
    --------
    nn_noise_overlap : float
    """
    rng = np.random.default_rng(seed=seed)

    # noise cluster: random snippets, each as long as a waveform
    noise_cluster = get_random_data_chunks(
        waveform_extractor.recording,
        return_scaled=waveform_extractor.return_scaled,
        num_chunks_per_segment=max_spikes_for_nn,
        chunk_size=waveform_extractor.nsamples,
        seed=seed)
    noise_cluster = np.reshape(
        noise_cluster, (max_spikes_for_nn, waveform_extractor.nsamples, -1))

    # target cluster: the waveforms of the unit
    unit_waveforms = waveform_extractor.get_waveforms(unit_id=this_unit_id)

    # subsample whichever cluster is larger so that both have n_snippets items
    n_spikes = unit_waveforms.shape[0]
    n_snippets = min(n_spikes, max_spikes_for_nn)
    if n_spikes > max_spikes_for_nn:
        kept = rng.choice(n_spikes, max_spikes_for_nn, replace=False)
        unit_waveforms = unit_waveforms[kept]
    elif n_spikes < max_spikes_for_nn:
        kept = rng.choice(noise_cluster.shape[0], n_spikes, replace=False)
        noise_cluster = noise_cluster[kept]

    # keep only the channels selected for this unit by the radius sparsity
    sparsity = get_template_channel_sparsity(waveform_extractor, method='radius', outputs='index',
                                             peak_sign='both', radius_um=radius_um)
    unit_channels = sparsity[this_unit_id]
    unit_waveforms = unit_waveforms[:, :, unit_channels]
    noise_cluster = noise_cluster[:, :, unit_channels]

    # weighted noise snippet (Z): each noise snippet is weighted by its value
    # at the (sample, channel) position where the median template peaks
    median_template = waveform_extractor.get_template(unit_id=this_unit_id, mode='median')
    median_template = median_template[:, unit_channels]
    peak_position = np.argmax(np.abs(median_template))
    tmax, chmax = np.unravel_index(peak_position, median_template.shape)
    weights = np.asarray([noise_clip[tmax, chmax] for noise_clip in noise_cluster])
    weights = weights / np.sum(weights)
    weighted_noise_snippet = np.sum(weights * noise_cluster.swapaxes(0, 2), axis=2).swapaxes(0, 1)

    # subtract the projection onto the weighted noise snippet, snippet by snippet
    for idx in range(n_snippets):
        unit_waveforms[idx, :, :] = _subtract_clip_component(
            unit_waveforms[idx, :, :], weighted_noise_snippet)
        noise_cluster[idx, :, :] = _subtract_clip_component(
            noise_cluster[idx, :, :], weighted_noise_snippet)

    # fit the PCA on both clusters together; unit rows first, noise rows after
    flattened = np.concatenate([
        unit_waveforms.reshape((n_snippets, -1)),
        noise_cluster.reshape((n_snippets, -1)),
    ], axis=0)
    pca = IncrementalPCA(n_components=n_components)
    pca.partial_fit(flattened)
    projected = pca.transform(flattened)

    # overlap is the complement of the isolation between the two clusters
    isolation = _compute_isolation(projected[:n_snippets, :], projected[n_snippets:, :], n_neighbors)
    nn_noise_overlap = 1 - isolation
    return nn_noise_overlap
def test_compute_template_similarity():
    """Smoke test: compute the template similarity of the 'mearec_waveforms' folder."""
    waveform_extractor = WaveformExtractor.load_from_folder('mearec_waveforms')
    # the result is not checked, the call only has to succeed
    similarity = compute_template_similarity(waveform_extractor)
def test_compute_unit_centers_of_mass():
    """Smoke test: print the unit centers of mass computed with ``num_channels=4``."""
    waveform_extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    print(compute_unit_centers_of_mass(waveform_extractor, num_channels=4))