def test_compute_quality_metrics_peak_sign():
    """Compare 'neg' metrics on stored waveforms with 'pos' metrics on inverted data.

    Waveforms are loaded from ``'toy_waveforms'`` and a second extractor is
    built in ``'toy_waveforms_inv'`` from the ``'toy_rec'`` recording scaled
    by -1. The ``snr`` and ``amplitude_cutoff`` columns of both metric tables
    must be numerically close.
    """
    recording = load_extractor('toy_rec')
    sorting = load_extractor('toy_sorting')

    # Sign-flipped copy of the recording.
    recording_inverted = scale(recording, gain=-1.)

    we = WaveformExtractor.load_from_folder('toy_waveforms')
    print(we)

    we_inv = WaveformExtractor.create(recording_inverted, sorting,
                                      'toy_waveforms_inv')
    we_inv.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    we_inv.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    print(we_inv)

    # Metrics that can be computed without principal components.
    # A fresh list is handed to each call so the two calls stay independent.
    names = ('snr', 'amplitude_cutoff')
    metrics = compute_quality_metrics(we, metric_names=list(names),
                                      peak_sign="neg")
    metrics_inv = compute_quality_metrics(we_inv, metric_names=list(names),
                                          peak_sign="pos")

    for name in names:
        assert np.allclose(metrics[name].values, metrics_inv[name].values)
Exemple #2
0
def test_WaveformExtractor():
    """Run a WaveformExtractor through create / extract / reload / template access.

    A 2-channel, 2-segment recording and a 5-unit sorting are generated,
    saved, and used to build a WaveformExtractor in
    ``'test_waveform_extractor'``. Waveform and template shapes are checked
    both before and after reloading the extractor from its folder.
    """
    durations = [30, 40]
    sampling_frequency = 30000.

    # Remove leftovers from a previous run, for both the recording folder and
    # the waveform folder, so the test can be executed repeatedly.
    folder_rec = "wf_rec1"
    folder = Path('test_waveform_extractor')
    for leftover in (Path(folder_rec), folder):
        if leftover.is_dir():
            shutil.rmtree(leftover)

    # 2 segments
    recording = generate_recording(num_channels=2, durations=durations,
                                   sampling_frequency=sampling_frequency)
    recording.annotate(is_filtered=True)
    recording = recording.save(folder=folder_rec)
    sorting = generate_sorting(num_units=5, sampling_frequency=sampling_frequency,
                               durations=durations)

    # test with dump !!!!
    recording = recording.save()
    sorting = sorting.save()

    we = WaveformExtractor.create(recording, sorting, folder)
    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)

    # Extraction is run twice on purpose: once single-job, once with 4 jobs
    # and a progress bar.
    we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
    we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True)

    # At most max_spikes_per_unit waveforms; the expected trailing shape is
    # 210 samples ((3 ms + 4 ms) at 30 kHz) by 2 channels.
    wfs = we.get_waveforms(0)
    assert wfs.shape[0] <= 500
    assert wfs.shape[1:] == (210, 2)

    wfs, sampled_index = we.get_waveforms(0, with_index=True)

    # load back
    we = WaveformExtractor.load_from_folder(folder)

    wfs = we.get_waveforms(0)

    template = we.get_template(0)
    assert template.shape == (210, 2)
    templates = we.get_all_templates()
    assert templates.shape == (5, 210, 2)

    wf_std = we.get_template(0, mode='std')
    assert wf_std.shape == (210, 2)
    wfs_std = we.get_all_templates(mode='std')
    assert wfs_std.shape == (5, 210, 2)

    wf_segment = we.get_template_segment(unit_id=0, segment_index=0)
    assert wf_segment.shape == (210, 2)
def test_compute_quality_metrics():
    """Check quality metrics with and without PCs, then reload them as an extension."""
    we = WaveformExtractor.load_from_folder('toy_waveforms')
    print(we)

    # Without principal components: only the requested 'snr' column is expected.
    metrics = compute_quality_metrics(we, metric_names=['snr'])
    assert 'snr' in metrics.columns
    assert 'isolation_distance' not in metrics.columns
    print(metrics)

    # With principal components computed first, the default metric set is
    # expected to include 'isolation_distance'.
    pca = WaveformPrincipalComponent(we)
    pca.set_params(n_components=5, mode='by_channel_local')
    pca.run()
    metrics = compute_quality_metrics(we)
    assert 'isolation_distance' in metrics.columns
    print(metrics)

    # The calculator must be reachable as an extension of the extractor ...
    assert QualityMetricCalculator in we.get_available_extensions()
    assert we.is_extension('quality_metrics')
    calculator = we.load_extension('quality_metrics')
    assert isinstance(calculator, QualityMetricCalculator)
    assert calculator._metrics is not None

    # ... and directly from the waveform folder.
    calculator = QualityMetricCalculator.load_from_folder('toy_waveforms')
    assert calculator._metrics is not None
def test_select_units():
    """Selecting a subset of units must keep the 'quality_metrics' extension."""
    we = WaveformExtractor.load_from_folder('toy_waveforms')
    # Make sure the extension exists (computed or loaded) before selecting.
    compute_quality_metrics(we, load_if_exists=True)

    # Keep every second unit.
    every_other_unit = we.sorting.get_unit_ids()[::2]
    we_filt = we.select_units(every_other_unit, 'toy_waveforms_filt')

    extension_names = we_filt.get_available_extension_names()
    assert "quality_metrics" in extension_names
def test_select_units():
    """Selecting a subset of units must keep the 'spike_amplitudes' extension."""
    we = WaveformExtractor.load_from_folder('mearec_waveforms')
    # Make sure the extension exists (computed or loaded) before selecting.
    compute_spike_amplitudes(we, load_if_exists=True)

    # Keep every second unit.
    every_other_unit = we.sorting.get_unit_ids()[::2]
    we_filt = we.select_units(every_other_unit, 'mearec_waveforms_filt')

    extension_names = we_filt.get_available_extension_names()
    assert "spike_amplitudes" in extension_names
Exemple #6
0
def test_get_template_channel_sparsity():
    """Call get_template_channel_sparsity with each method and both output kinds.

    This is a smoke test: results are printed, not asserted.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms')

    # (method, method-specific keyword arguments), each tried with both outputs.
    method_kwargs = (
        ('best_channels', dict(num_channels=5)),
        ('radius', dict(radius_um=50)),
        ('threshold', dict(threshold=3)),
    )
    for method, kwargs in method_kwargs:
        for outputs in ('id', 'index'):
            sparsity = get_template_channel_sparsity(we, method=method,
                                                     outputs=outputs, **kwargs)
            print(sparsity)

    # load from folder because sorting properties must be loaded
    rec = load_extractor('toy_rec')
    sort = load_extractor('toy_sort')
    we = extract_waveforms(rec, sort, 'toy_waveforms_1')
    for outputs in ('id', 'index'):
        sparsity = get_template_channel_sparsity(we, method='by_property',
                                                 outputs=outputs,
                                                 by_property="group")
        print(sparsity)
Exemple #7
0
def test_select_units():
    """Selecting a subset of units must keep the 'principal_components' extension."""
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    # Make sure the extension exists (computed or loaded) before selecting.
    compute_principal_components(we, load_if_exists=True)

    # Keep every second unit.
    every_other_unit = we.sorting.get_unit_ids()[::2]
    we_filt = we.select_units(every_other_unit, 'toy_waveforms_1seg_filt')

    extension_names = we_filt.get_available_extension_names()
    assert "principal_components" in extension_names
Exemple #8
0
    def compute_metrics(self,
                        rec_name,
                        metric_names=None,
                        ms_before=3.,
                        ms_after=4.,
                        max_spikes_per_unit=500,
                        n_jobs=-1,
                        total_memory='1G',
                        **snr_kargs):
        """Compute quality metrics of the ground-truth sorting for one recording.

        Waveforms are extracted into
        ``<study_folder>/metrics/waveforms_<rec_name>`` (an existing folder
        is deleted first), the metrics are computed from them and written as
        a tab-separated file next to the waveform folder.

        Parameters
        ----------
        rec_name
            Name passed to ``self.get_recording`` and
            ``self.get_ground_truth``.
        metric_names : list of str, optional
            Metrics forwarded to ``compute_quality_metrics``. Defaults to
            ``['snr']`` when omitted or ``None``.
        ms_before, ms_after, max_spikes_per_unit
            Forwarded to ``WaveformExtractor.set_params``.
        n_jobs, total_memory
            Forwarded to ``WaveformExtractor.run``.
        **snr_kargs
            Accepted for backward compatibility; not used by this method.

        Returns
        -------
        The object returned by ``compute_quality_metrics``.
        """
        # A None sentinel avoids sharing one mutable default list across calls.
        if metric_names is None:
            metric_names = ['snr']

        rec = self.get_recording(rec_name)
        gt_sorting = self.get_ground_truth(rec_name)

        # waveform extractor: always start from a clean folder
        metrics_folder = self.study_folder / 'metrics'
        waveform_folder = metrics_folder / f'waveforms_{rec_name}'
        if waveform_folder.is_dir():
            shutil.rmtree(waveform_folder)
        we = WaveformExtractor.create(rec, gt_sorting, waveform_folder)
        we.set_params(ms_before=ms_before,
                      ms_after=ms_after,
                      max_spikes_per_unit=max_spikes_per_unit)
        we.run(n_jobs=n_jobs, total_memory=total_memory)

        # metrics
        metrics = compute_quality_metrics(we, metric_names=metric_names)
        # NOTE(review): the space in 'metrics _' looks unintentional, but the
        # name is kept byte-for-byte because other code may read this file.
        filename = metrics_folder / f'metrics _{rec_name}.txt'
        metrics.to_csv(filename, sep='\t', index=True)

        return metrics
Exemple #9
0
def test_get_template_channel_sparsity():
    """Smoke-test the 'best_channels' and 'radius' methods with both output kinds."""
    we = WaveformExtractor.load_from_folder('toy_waveforms')

    # (method, method-specific keyword arguments), each tried with both outputs.
    method_kwargs = (
        ('best_channels', dict(num_channels=5)),
        ('radius', dict(radius_um=50)),
    )
    for method, kwargs in method_kwargs:
        for outputs in ('id', 'index'):
            sparsity = get_template_channel_sparsity(we, method=method,
                                                     outputs=outputs, **kwargs)
def test_calculate_pc_metrics():
    """Smoke-test calculate_pc_metrics on principal components stored in 'toy_waveforms'."""
    waveform_folder = 'toy_waveforms'

    we = WaveformExtractor.load_from_folder(waveform_folder)
    print(we)
    pca = WaveformPrincipalComponent.load_from_folder(waveform_folder)
    print(pca)

    pc_metrics = calculate_pc_metrics(pca)
    print(pc_metrics)
Exemple #11
0
def test_compute_unit_center_of_mass():
    """Smoke-test localize_units with method 'center_of_mass', default and dict output."""
    we = WaveformExtractor.load_from_folder('toy_waveforms')

    common = dict(method='center_of_mass', num_channels=4)
    unit_location = localize_units(we, **common)
    unit_location_dict = localize_units(we, **common, output='dict')
Exemple #12
0
def test_compute_monopolar_triangulation():
    """Smoke-test localize_units with method 'monopolar_triangulation', default and dict output."""
    we = WaveformExtractor.load_from_folder('toy_waveforms')

    common = dict(method='monopolar_triangulation', radius_um=150)
    unit_location = localize_units(we, **common)
    unit_location_dict = localize_units(we, **common, output='dict')
def test_compute_principal_components_for_all_spikes():
    """Project all spikes to a .npy file and check that the file can be loaded."""
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    pc = compute_principal_components(we, load_if_exists=True)
    print(pc)

    output_file = 'all_pc.npy'
    pc.run_for_all_spikes(output_file,
                          max_channels_per_template=7,
                          chunk_size=10000,
                          n_jobs=1)

    # Loading raises if the file is missing or unreadable.
    all_pc = np.load(output_file)
Exemple #14
0
    def get_waveform_extractor(self, rec_name, sorter_name=None):
        """Return the WaveformExtractor for a recording and a sorting.

        With ``sorter_name=None`` the ground-truth sorting is used and the
        folder is labelled ``'GroundTruth'``; otherwise ``sorter_name`` must
        be one of ``self.sorter_names``. The extractor lives in
        ``<study_folder>/waveforms/waveforms_<label>_<rec_name>``: it is
        loaded when that folder exists and created otherwise.
        """
        recording = self.get_recording(rec_name)

        if sorter_name is not None:
            assert sorter_name in self.sorter_names
            label = sorter_name
            sorting = self.get_sorting(sorter_name, rec_name)
        else:
            label = 'GroundTruth'
            sorting = self.get_ground_truth(rec_name)

        waveform_folder = self.study_folder / 'waveforms' / f'waveforms_{label}_{rec_name}'

        if not waveform_folder.is_dir():
            return WaveformExtractor.create(recording, sorting, waveform_folder)
        return WaveformExtractor.load_from_folder(waveform_folder)
Exemple #15
0
def test_pca_models_and_project_new():
    """Check the PCA model type and the projection shape for each PCA mode.

    For 'by_channel_local', 'by_channel_global' and 'concatenated' the stored
    principal components are removed, recomputed, and used to project 100
    random waveforms whose output shape is asserted.
    """
    from sklearn.decomposition import IncrementalPCA

    pca_folder = 'toy_waveforms_1seg/principal_components'

    def _remove_saved_pca():
        # Force a fresh computation even though load_if_exists=True is used.
        if Path(pca_folder).is_dir():
            shutil.rmtree(pca_folder)

    _remove_saved_pca()
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')

    # Waveform dimensions are taken from the first unit.
    first_unit_wfs = we.get_waveforms(unit_id=we.sorting.unit_ids[0])
    n_samples, n_channels = first_unit_wfs.shape[1], first_unit_wfs.shape[2]
    n_components = 5
    n_new = 100

    # local: one model per channel
    pc_local = compute_principal_components(we, n_components=n_components,
                                            load_if_exists=True,
                                            mode="by_channel_local")
    local_models = pc_local.get_pca_model()
    assert len(local_models) == we.recording.get_num_channels()

    projected = pc_local.project_new(np.random.randn(n_new, n_samples, n_channels))
    assert projected.shape == (n_new, n_components, n_channels)

    # global: a single model, projections still per channel
    _remove_saved_pca()
    pc_global = compute_principal_components(we, n_components=n_components,
                                             load_if_exists=True,
                                             mode="by_channel_global")
    global_model = pc_global.get_pca_model()
    assert isinstance(global_model, IncrementalPCA)

    projected = pc_global.project_new(np.random.randn(n_new, n_samples, n_channels))
    assert projected.shape == (n_new, n_components, n_channels)

    # concatenated: a single model, one projection per waveform
    _remove_saved_pca()
    pc_concatenated = compute_principal_components(we, n_components=n_components,
                                                   load_if_exists=True,
                                                   mode="concatenated")
    concatenated_model = pc_concatenated.get_pca_model()
    assert isinstance(concatenated_model, IncrementalPCA)

    projected = pc_concatenated.project_new(np.random.randn(n_new, n_samples, n_channels))
    assert projected.shape == (n_new, n_components)
Exemple #16
0
def test_calculate_template_metrics():
    """Smoke-test calculate_template_metrics: no upsampling, upsampling, and radius sparsity."""
    we = WaveformExtractor.load_from_folder('toy_waveforms')

    features = calculate_template_metrics(we, upsampling_factor=1)
    print(features)

    features_up = calculate_template_metrics(we, upsampling_factor=2)
    print(features_up)

    radius_sparsity = dict(method="radius", radius_um=20)
    features_sparse = calculate_template_metrics(we, upsampling_factor=2,
                                                 sparsity_dict=radius_sparsity)
    print(features_sparse)
def setup_module():
    """Rebuild the toy recording, sorting and waveform folders used by the tests."""
    # Start from a clean state.
    stale_folders = [Path(name) for name in ('toy_rec', 'toy_sorting', 'toy_waveforms')]
    for stale in stale_folders:
        if stale.is_dir():
            shutil.rmtree(stale)

    recording, sorting = toy_example(num_segments=2, num_units=10)
    recording = recording.save(folder='toy_rec')
    sorting = sorting.save(folder='toy_sorting')

    we = WaveformExtractor.create(recording, sorting, 'toy_waveforms')
    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    we.run_extract_waveforms(n_jobs=1, chunk_size=30000)
Exemple #18
0
def test_compute_principal_components_for_all_spikes():
    """Projections of all spikes must be identical with 1 job and with 2 jobs."""
    we = WaveformExtractor.load_from_folder('toy_waveforms_1seg')
    pc = compute_principal_components(we, load_if_exists=True)
    print(pc)

    # One output file per job count, processed in the same order as before.
    runs = (('all_pc1.npy', 1), ('all_pc2.npy', 2))
    loaded = []
    for pc_file, n_jobs in runs:
        pc.run_for_all_spikes(pc_file, max_channels_per_template=7,
                              chunk_size=10000, n_jobs=n_jobs)
        loaded.append(np.load(pc_file))

    all_pc1, all_pc2 = loaded
    assert np.array_equal(all_pc1, all_pc2)
def test_compute_quality_metrics():
    """Smoke-test compute_quality_metrics without and with an explicit PCA object."""
    we = WaveformExtractor.load_from_folder('toy_waveforms')
    print(we)

    def _show(table):
        # Print the metric table followed by its column index.
        print(table)
        print(table.columns)

    # without PC
    _show(compute_quality_metrics(we, metric_names=['snr']))

    # with PCs
    pca = WaveformPrincipalComponent(we)
    pca.set_params(n_components=5, mode='by_channel_local')
    pca.run()
    _show(compute_quality_metrics(we, waveform_principal_component=pca))
Exemple #20
0
def setup_module():
    """Rebuild the toy recording, sorting, waveforms and principal components."""
    # Start from a clean state.
    stale_folders = [Path(name) for name in ('toy_rec', 'toy_sorting', 'toy_waveforms')]
    for stale in stale_folders:
        if stale.is_dir():
            shutil.rmtree(stale)

    recording, sorting = toy_example(num_segments=2, num_units=10)
    recording = recording.save(folder='toy_rec')
    sorting = sorting.save(folder='toy_sorting')

    we = WaveformExtractor.create(recording, sorting, 'toy_waveforms')
    we.set_params(ms_before=3., ms_after=4., max_spikes_per_unit=500)
    we.run(n_jobs=1, chunk_size=30000)

    # Principal components are computed up front for the tests that need them.
    pca = WaveformPrincipalComponent(we)
    pca.set_params(n_components=5, mode='by_channel_local')
    pca.run()
Exemple #21
0
def test_WaveformPrincipalComponent():
    """Run WaveformPrincipalComponent in every mode and reload it as an extension.

    Per-channel modes must give projections whose trailing shape is
    ``(5, 4)``; the concatenated mode must give 5 components per waveform.
    """
    we = WaveformExtractor.load_from_folder('toy_waveforms_2seg')
    unit_ids = we.sorting.unit_ids
    num_channels = we.recording.get_num_channels()
    pc = WaveformPrincipalComponent(we)

    def _fit(calculator, mode):
        # Configure, show and run the calculator for one mode.
        calculator.set_params(n_components=5, mode=mode)
        print(calculator)
        calculator.run()

    for mode in ('by_channel_local', 'by_channel_global'):
        _fit(pc, mode)
        for unit_id in unit_ids:
            projections = pc.get_projections(unit_id)
            assert projections.shape[1:] == (5, 4)

    for mode in ('concatenated',):
        _fit(pc, mode)
        for unit_id in unit_ids:
            projections = pc.get_projections(unit_id)
            assert projections.shape[1] == 5

    all_labels, all_components = pc.get_all_components()

    # reload as an extension from we
    assert WaveformPrincipalComponent in we.get_available_extensions()
    assert we.is_extension('principal_components')
    reloaded = we.load_extension('principal_components')
    assert isinstance(reloaded, WaveformPrincipalComponent)
    reloaded = WaveformPrincipalComponent.load_from_folder('toy_waveforms_2seg')
Exemple #22
0
def test_get_template_amplitudes():
    """Smoke-test get_template_amplitudes on the toy waveforms (result is printed)."""
    extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    amplitudes = get_template_amplitudes(extractor)
    print(amplitudes)
Exemple #23
0
def test_get_template_extremum_channel_peak_shift():
    """Smoke-test get_template_extremum_channel_peak_shift with peak_sign='neg'."""
    extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    peak_shifts = get_template_extremum_channel_peak_shift(extractor,
                                                           peak_sign='neg')
    print(peak_shifts)
Exemple #24
0
def test_get_template_best_channels():
    """Smoke-test get_template_best_channels with num_channels=2 (result is printed)."""
    extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    channels_per_unit = get_template_best_channels(extractor, num_channels=2)
    print(channels_per_unit)
Exemple #25
0
def test_calculate_template_metrics():
    """Smoke-test calculate_template_metrics with default arguments (result is printed)."""
    extractor = WaveformExtractor.load_from_folder('toy_waveforms')
    template_features = calculate_template_metrics(extractor)
    print(template_features)
                       sorting,
                       folder,
                       ms_before=1.5,
                       ms_after=2.,
                       max_spikes_per_unit=500,
                       load_if_exists=True)
print(we)

###############################################################################
# Alternatively, the :code:`WaveformExtractor` object can be instantiated
# directly. In this case, we need to :code:`set_params()` to set the desired
# parameters:

folder = 'waveform_folder2'
we = WaveformExtractor.create(recording, sorting, folder, remove_if_exists=True)

waveform_params = dict(ms_before=3., ms_after=4., max_spikes_per_unit=1000)
we.set_params(**waveform_params)
we.run_extract_waveforms(n_jobs=1, chunk_size=30000, progress_bar=True)
print(we)

###############################################################################
# The :code:`'waveform_folder'` folder contains:
#  * the dumped recording (json)
#  * the dumped sorting (json)
#  * the parameters (json)
#  * a subfolder with "waveforms_XXX.npy" and "sampled_index_XXX.npy"

import os

folder_content = os.listdir(folder)
print(folder_content)
Exemple #27
0
def test_get_template_extremum_amplitude():
    """Smoke-test get_template_extremum_amplitude with peak_sign='both'."""
    extractor = WaveformExtractor.load_from_folder('toy_waveforms')

    extremum_amplitudes = get_template_extremum_amplitude(extractor, peak_sign='both')
    print(extremum_amplitudes)
Exemple #28
0
def nearest_neighbors_noise_overlap(waveform_extractor: si.WaveformExtractor,
                                    this_unit_id: int,
                                    max_spikes_for_nn: int = 1000,
                                    n_neighbors: int = 5,
                                    n_components: int = 10,
                                    radius_um: float = 100,
                                    seed: int = 0):
    """Calculates unit noise overlap based on NearestNeighbors search in PCA space.

    Based on noise overlap metric described in Chung et al. (2017) Neuron 95: 1381-1394.

    Rough logic
    -----------
    1) Generate a noise cluster by randomly sampling voltage snippets from recording.
    2) Subtract projection onto the weighted average of noise snippets
       of both the target and noise clusters to correct for bias in sampling.
    3) Compute the isolation score between the noise cluster and the target cluster.

    Implementation details
    ----------------------
    As with nn_isolation, the clusters that are compared (target and noise clusters)
    have the same number of spikes: the larger of the two is randomly
    subsampled down to the size of the smaller one.

    See docstring for `_compute_isolation` for the definition of isolation score.

    Parameters
    ----------
    waveform_extractor : si.WaveformExtractor
        Provides the recording, the waveforms of the unit and its median template.
    this_unit_id : int
        ID of unit for which this metric will be calculated
    max_spikes_for_nn : int
        max number of spikes to use per cluster
    n_neighbors : int
        number of neighbors to check membership of
    n_components : int
        number of PC components to project the snippets
    radius_um : float
        only the channels within this radius of the peak channel
        are used to compute the metric
    seed : int
        seed for random subsampling of spikes; also passed to
        `get_random_data_chunks` when drawing the noise snippets

    Returns
    -------
    nn_noise_overlap : float
        One minus the value returned by `_compute_isolation` for the target
        and noise clusters.
    """

    # set random seed
    rng = np.random.default_rng(seed=seed)

    # get random snippets from the recording to create a noise cluster
    recording = waveform_extractor.recording
    noise_cluster = get_random_data_chunks(
        recording,
        return_scaled=waveform_extractor.return_scaled,
        num_chunks_per_segment=max_spikes_for_nn,
        chunk_size=waveform_extractor.nsamples,
        seed=seed)

    # NOTE(review): this reshape yields exactly max_spikes_for_nn snippets, which
    # presumably assumes a single-segment recording — confirm what
    # get_random_data_chunks returns for multi-segment recordings.
    noise_cluster = np.reshape(
        noise_cluster, (max_spikes_for_nn, waveform_extractor.nsamples, -1))

    # get waveforms for target cluster
    waveforms = waveform_extractor.get_waveforms(unit_id=this_unit_id)

    # adjust the size of the target and noise clusters to be equal
    if waveforms.shape[0] > max_spikes_for_nn:
        # too many spikes: subsample the target cluster
        wf_ind = rng.choice(waveforms.shape[0],
                            max_spikes_for_nn,
                            replace=False)
        waveforms = waveforms[wf_ind]
        n_snippets = max_spikes_for_nn
    elif waveforms.shape[0] < max_spikes_for_nn:
        # too few spikes: subsample the noise cluster instead
        noise_ind = rng.choice(noise_cluster.shape[0],
                               waveforms.shape[0],
                               replace=False)
        noise_cluster = noise_cluster[noise_ind]
        n_snippets = waveforms.shape[0]
    else:
        n_snippets = max_spikes_for_nn

    # restrict to channels with significant signal
    closest_chans_idx = get_template_channel_sparsity(waveform_extractor,
                                                      method='radius',
                                                      outputs='index',
                                                      peak_sign='both',
                                                      radius_um=radius_um)
    waveforms = waveforms[:, :, closest_chans_idx[this_unit_id]]
    noise_cluster = noise_cluster[:, :, closest_chans_idx[this_unit_id]]

    # compute weighted noise snippet (Z)
    median_waveform = waveform_extractor.get_template(unit_id=this_unit_id,
                                                      mode='median')
    median_waveform = median_waveform[:, closest_chans_idx[this_unit_id]]
    # (sample, channel) position of the largest absolute value of the median template
    tmax, chmax = np.unravel_index(np.argmax(np.abs(median_waveform)),
                                   median_waveform.shape)
    # each noise snippet is weighted by its own value at that position,
    # normalized so that the weights sum to one
    weights = [noise_clip[tmax, chmax] for noise_clip in noise_cluster]
    weights = np.asarray(weights)
    weights = weights / np.sum(weights)
    # weighted sum over snippets; the axis swaps let the per-snippet weights
    # broadcast and give a (samples, channels) result
    weighted_noise_snippet = np.sum(weights * noise_cluster.swapaxes(0, 2),
                                    axis=2).swapaxes(0, 1)

    # subtract projection onto weighted noise snippet
    for snippet in range(n_snippets):
        waveforms[snippet, :, :] = _subtract_clip_component(
            waveforms[snippet, :, :], weighted_noise_snippet)
        noise_cluster[snippet, :, :] = _subtract_clip_component(
            noise_cluster[snippet, :, :], weighted_noise_snippet)

    # compute principal components after concatenation
    # (target snippets first, then noise snippets, each flattened to one row)
    all_snippets = np.concatenate([
        waveforms.reshape((n_snippets, -1)),
        noise_cluster.reshape((n_snippets, -1))
    ],
                                  axis=0)
    pca = IncrementalPCA(n_components=n_components)
    pca.partial_fit(all_snippets)
    projected_snippets = pca.transform(all_snippets)

    # compute overlap
    # the first n_snippets rows are the target cluster, the rest the noise cluster
    nn_noise_overlap = 1 - _compute_isolation(
        projected_snippets[:n_snippets, :], projected_snippets[n_snippets:, :],
        n_neighbors)
    return nn_noise_overlap
Exemple #29
0
def test_compute_template_similarity():
    """Smoke-test compute_template_similarity on the 'mearec_waveforms' folder."""
    extractor = WaveformExtractor.load_from_folder('mearec_waveforms')
    similarity = compute_template_similarity(extractor)
Exemple #30
0
def test_compute_unit_centers_of_mass():
    """Smoke-test compute_unit_centers_of_mass with num_channels=4 (result is printed)."""
    extractor = WaveformExtractor.load_from_folder('toy_waveforms')

    centers_of_mass = compute_unit_centers_of_mass(extractor, num_channels=4)
    print(centers_of_mass)