def test_thresh_metrics():
    """Check ``threshold_metrics`` with firing-rate and SNR thresholds in both modes.

    Both metrics use the ``'less'`` threshold sign. The assertions encode the
    expected semantics of the two combination modes:

    * ``mode='or'``  -- every surviving unit must satisfy *both*
      ``firing_rate >= fr_thresh`` and ``snr >= snr_thresh``.
    * ``mode='and'`` -- every surviving unit must satisfy *at least one* of
      those two conditions.
    """
    rec, sort = se.example_datasets.toy_example(duration=10,
                                                num_channels=4,
                                                seed=0)
    fr_thresh = 2
    snr_thresh = 4

    def _threshold(mode):
        # Only the combination mode differs between the two checks below.
        return threshold_metrics(sort,
                                 rec,
                                 metrics=['firing_rate', 'snr'],
                                 thresholds=[fr_thresh, snr_thresh],
                                 threshold_signs=['less', 'less'],
                                 mode=mode)

    # mode='or': survivors must pass both thresholds.
    sorting_or = _threshold('or')
    fr_ok = compute_firing_rates(sorting_or) >= fr_thresh
    snr_ok = compute_snrs(sorting_or, rec) >= snr_thresh
    # Separate asserts so a failure reports which metric violated its threshold.
    assert np.all(fr_ok)
    assert np.all(snr_ok)

    # mode='and': survivors must pass at least one threshold.
    sorting_and = _threshold('and')
    fr_ok = compute_firing_rates(sorting_and) >= fr_thresh
    snr_ok = compute_snrs(sorting_and, rec) >= snr_thresh
    # ``|`` is the explicit element-wise OR; ``+`` on bool arrays only
    # happens to behave the same way.
    assert np.all(fr_ok | snr_ok)
def test_functions():
    """Each standalone ``compute_*`` metric must agree with ``compute_metrics``.

    Every metric is computed with ``seed=0`` twice: once through its dedicated
    function and once through ``compute_metrics(..., return_dict=True)``.
    Element ``[0]`` of each result is compared with ``np.allclose``.

    NOTE(review): a second ``def test_functions`` further down this file rebinds
    the same name, so pytest only collects that later definition.
    """
    recording, sorting = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

    # Metrics computed from the sorting alone.
    firing_rates = compute_firing_rates(sorting, seed=0)[0]
    num_spikes = compute_num_spikes(sorting, seed=0)[0]
    isi = compute_isi_violations(sorting, seed=0)[0]
    presence = compute_presence_ratios(sorting, seed=0)[0]

    # Metrics that additionally take the recording.
    amp_cutoff = compute_amplitude_cutoffs(sorting, recording, seed=0)[0]
    max_drift, cum_drift = compute_drift_metrics(sorting, recording, seed=0, memmap=False)[0]
    silh = compute_silhouette_scores(sorting, recording, seed=0)[0]
    iso = compute_isolation_distances(sorting, recording, seed=0)[0]
    l_ratio = compute_l_ratios(sorting, recording, seed=0)[0]
    dprime = compute_d_primes(sorting, recording, seed=0)[0]
    nn_hit, nn_miss = compute_nn_metrics(sorting, recording, seed=0)[0]
    snr = compute_snrs(sorting, recording, seed=0)[0]

    aggregated = compute_metrics(sorting, recording, return_dict=True, seed=0)

    # Key in the aggregated dict -> value from the dedicated function.
    expected = {
        'firing_rate': firing_rates,
        'num_spikes': num_spikes,
        'isi_viol': isi,
        'amplitude_cutoff': amp_cutoff,
        'presence_ratio': presence,
        'silhouette_score': silh,
        'isolation_distance': iso,
        'l_ratio': l_ratio,
        'd_prime': dprime,
        'snr': snr,
        'max_drift': max_drift,
        'cumulative_drift': cum_drift,
        'nn_hit_rate': nn_hit,
        'nn_miss_rate': nn_miss,
    }
    for metric_name, standalone_value in expected.items():
        assert np.allclose(aggregated[metric_name][0], standalone_value)
# Example #3
# 0
def test_thresh_snr():
    """After ``threshold_snr`` with the ``'less'`` sign, every kept unit has SNR >= the threshold."""
    snr_thresh = 4
    recording, sorting = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

    kept_sorting = threshold_snr(sorting, recording, snr_thresh, 'less')
    kept_snrs = compute_snrs(kept_sorting, recording)[0]

    assert np.all(kept_snrs >= snr_thresh)
# Example #4
# 0
def test_thresh_snrs():
    """Threshold units by SNR on a dumpable toy dataset and check the result.

    ``toy_example`` is given ``dump_folder='test'``, and that folder is deleted
    when the test ends. The deletion happens in a ``finally`` clause so a failing
    assertion or ``check_dumping`` call does not leave the folder behind for later
    runs.
    """
    dump_folder = 'test'
    rec, sort = se.example_datasets.toy_example(dump_folder=dump_folder, dumpable=True, duration=10, num_channels=4,
                                                K=10, seed=0)
    try:
        snr_thresh = 4

        # Same apply_filter/seed settings on both calls, so the SNRs verified
        # here are computed like the ones used for thresholding.
        sort_snr = threshold_snrs(sort, rec, snr_thresh, 'less', apply_filter=False, seed=0)
        new_snr = compute_snrs(sort_snr, rec, apply_filter=False, seed=0)

        assert np.all(new_snr >= snr_thresh)
        check_dumping(sort_snr)
    finally:
        shutil.rmtree(dump_folder)
def test_functions():
    """Compute each quality metric on a toy recording/sorting pair.

    NOTE(review): unlike the other tests in this file, no ``seed`` is passed to
    ``toy_example`` or to the metric functions. The data and metric values may
    therefore differ between runs; confirm this is intended.
    NOTE(review): this definition rebinds the earlier ``test_functions`` in the
    same module, so pytest only collects this one.
    """
    rec, sort = se.example_datasets.toy_example(duration=10, num_channels=4)

    # Metrics computed from the sorting alone.
    firing_rates = compute_firing_rates(sort)
    num_spikes = compute_num_spikes(sort)
    isi = compute_isi_violations(sort)
    presence = compute_presence_ratios(sort)

    # Metrics that additionally take the recording.
    amp_cutoff = compute_amplitude_cutoffs(sort, rec)

    # Drift and nearest-neighbour metrics are each unpacked into two values.
    max_drift, cum_drift = compute_drift_metrics(sort, rec)
    silh = compute_silhouette_scores(sort, rec)
    iso = compute_isolation_distances(sort, rec)
    l_ratio = compute_l_ratios(sort, rec)
    dprime = compute_d_primes(sort, rec)
    nn_hit, nn_miss = compute_nn_metrics(sort, rec)

    snr = compute_snrs(sort, rec)