def test_thresh_isolation_distances():
    """Check isolation-distance thresholding with the 'less' comparison.

    A dumpable toy dataset is written to ``./test``. Isolation distances are
    computed on the original sorting, and the sorting is thresholded. The test
    asserts that every unit that survives has an original isolation distance
    >= the threshold. The dump folder is removed afterwards.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    threshold = 200

    distances = compute_isolation_distances(sorting, recording, apply_filter=False, seed=0)
    curated = threshold_isolation_distances(
        sorting, recording, threshold, 'less', apply_filter=False, seed=0)

    # Map each surviving unit back to its position in the original sorting.
    all_ids = sorting.get_unit_ids()
    kept = np.array([distances[all_ids.index(uid)] for uid in curated.get_unit_ids()])
    assert np.all(kept >= threshold)

    check_dumping(curated)
    shutil.rmtree('test')
def test_thresh_l_ratios():
    """Check L-ratio thresholding with the 'less' comparison.

    L-ratios are computed on the original toy sorting. After thresholding, the
    original L-ratio of every remaining unit must be >= the threshold. The
    ``test`` dump folder is cleaned up at the end.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    threshold = 0

    ratios = compute_l_ratios(sorting, recording, apply_filter=False, seed=0)
    curated = threshold_l_ratios(
        sorting, recording, threshold, "less", apply_filter=False, seed=0)

    # Look up each kept unit's value by its index in the original unit list.
    all_ids = sorting.get_unit_ids()
    kept = np.array([ratios[all_ids.index(uid)] for uid in curated.get_unit_ids()])
    assert np.all(kept >= threshold)

    check_dumping(curated)
    shutil.rmtree('test')
def test_remove_artifacts():
    """Check that ``remove_artifacts`` zeroes the traces around each trigger.

    Two trigger frames are blanked with a window of ``ms`` milliseconds before
    and after each one. Both windows must read back as all zeros: the full
    +/- ``ms`` window and a short +/- 10-frame window around each trigger.
    The ``test`` dump folder is removed at the end.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test',
                                                dumpable=True,
                                                duration=10,
                                                num_channels=4,
                                                seed=0)
    triggers = [30000, 90000]
    ms = 10
    # Blanking window converted from milliseconds to frames.
    ms_frames = int(ms * rec.get_sampling_frequency() / 1000)

    # Pass ``ms`` (not a literal) so the removed span and the checked span
    # (``ms_frames``) cannot drift apart if one of them is edited.
    rec_rmart = remove_artifacts(rec, triggers, ms_before=ms, ms_after=ms)
    traces_all_0 = rec_rmart.get_traces(start_frame=triggers[0] - ms_frames,
                                        end_frame=triggers[0] + ms_frames)
    traces_short_0 = rec_rmart.get_traces(start_frame=triggers[0] - 10,
                                          end_frame=triggers[0] + 10)
    traces_all_1 = rec_rmart.get_traces(start_frame=triggers[1] - ms_frames,
                                        end_frame=triggers[1] + ms_frames)
    traces_short_1 = rec_rmart.get_traces(start_frame=triggers[1] - 10,
                                          end_frame=triggers[1] + 10)

    assert not np.any(traces_all_0)
    assert not np.any(traces_all_1)
    assert not np.any(traces_short_0)
    assert not np.any(traces_short_1)

    check_dumping(rec_rmart)
    shutil.rmtree('test')
Example #4
0
def test_thresh_nn_metrics():
    """Check nearest-neighbour hit-rate and miss-rate thresholding.

    The hit-rate sorting is thresholded with 'less', and the test asserts that
    surviving units have hit rate >= the threshold. The miss-rate sorting is
    thresholded with 'greater', and surviving units must have miss rate <= the
    threshold. Values are taken from the metrics computed on the original
    sorting.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    hit_limit = 0.9
    miss_limit = 0.002

    hit_rates, miss_rates = compute_nn_metrics(sorting, recording, apply_filter=False, seed=0)
    hit_sorting = threshold_nn_metrics(sorting, recording, hit_limit, 'less',
                                       metric_name="nn_hit_rate", apply_filter=False, seed=0)
    miss_sorting = threshold_nn_metrics(sorting, recording, miss_limit, 'greater',
                                        metric_name="nn_miss_rate", apply_filter=False, seed=0)

    all_ids = sorting.get_unit_ids()
    kept_hits = np.array([hit_rates[all_ids.index(uid)] for uid in hit_sorting.get_unit_ids()])
    kept_misses = np.array([miss_rates[all_ids.index(uid)] for uid in miss_sorting.get_unit_ids()])
    assert np.all(kept_hits >= hit_limit)
    assert np.all(kept_misses <= miss_limit)

    check_dumping(hit_sorting)
    check_dumping(miss_sorting)
    shutil.rmtree('test')
def test_transform():
    """Check ``transform`` with a scalar and with per-channel affine coefficients.

    First a single ``scalar``/``offset`` pair is applied to every trace. Then
    one coefficient pair per channel (4 channels) is applied, and each channel
    is compared with ``s * trace + o``. Per-channel coefficients come from a
    seeded generator so the test is reproducible.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test',
                                                dumpable=True,
                                                duration=10,
                                                num_channels=4,
                                                seed=0)

    scalar = 3
    offset = 50

    rec_t = transform(rec, scalar=scalar, offset=offset)
    assert np.allclose(rec_t.get_traces(),
                       scalar * rec.get_traces() + offset,
                       atol=0.001)

    # Seeded (not global) RNG: unseeded np.random made this test non-reproducible.
    rng = np.random.RandomState(0)
    scalars = rng.randn(4)
    offsets = rng.randn(4)
    rec_t_arr = transform(rec, scalar=scalars, offset=offsets)
    # Iterating traces yields one row per channel, paired with its coefficients.
    for (tt, to, s, o) in zip(rec_t_arr.get_traces(), rec.get_traces(),
                              scalars, offsets):
        assert np.allclose(tt, s * to + o, atol=0.001)

    check_dumping(rec_t)
    check_dumping(rec_t_arr)
    shutil.rmtree('test')
def test_thresh_silhouettes():
    """Check silhouette-score thresholding with the 'less' comparison.

    Silhouette scores are computed on the original toy sorting. After
    thresholding, every remaining unit's original score must be >= the
    threshold. The ``test`` dump folder is deleted at the end.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    threshold = .5

    scores = compute_silhouette_scores(sorting, recording, apply_filter=False, seed=0)
    curated = threshold_silhouette_scores(
        sorting, recording, threshold, "less", apply_filter=False, seed=0)

    all_ids = sorting.get_unit_ids()
    kept = np.array([scores[all_ids.index(uid)] for uid in curated.get_unit_ids()])
    assert np.all(kept >= threshold)

    check_dumping(curated)
    shutil.rmtree('test')
Example #7
0
def test_thresh_frs():
    """Check firing-rate thresholding ('less') using an explicit frame count.

    The firing rates recomputed on the thresholded sorting must all be >= the
    threshold.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    rate_limit = 2

    curated = threshold_firing_rates(sorting, rate_limit, 'less', recording.get_num_frames())
    rates = compute_firing_rates(curated, recording.get_num_frames())
    assert np.all(rates >= rate_limit)

    check_dumping(curated)
    shutil.rmtree('test')
def test_rectify():
    """Check that ``rectify`` returns the absolute value of the traces.

    This variant builds its recording with ``create_dumpable_extractors``.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, folder='test')

    rectified = rectify(recording)
    expected = np.abs(recording.get_traces())
    assert np.allclose(rectified.get_traces(), expected)

    check_dumping(rectified)
    shutil.rmtree('test')
Example #9
0
def test_thresh_num_spikes():
    """Check spike-count thresholding with the 'less' comparison.

    Spike counts recomputed on the thresholded sorting must all be >= the
    threshold.
    """
    _, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    min_spikes = 25

    curated = threshold_num_spikes(sorting, min_spikes, 'less')
    counts = compute_num_spikes(curated, sorting.get_sampling_frequency())
    assert np.all(counts >= min_spikes)

    check_dumping(curated)
    shutil.rmtree('test')
Example #10
0
def test_thresh_isi_violations():
    """Check ISI-violation thresholding with the 'greater' comparison.

    This variant passes the frame count explicitly. ISI violations recomputed
    on the thresholded sorting must all be <= the threshold.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    max_violation = 0.01

    curated = threshold_isi_violations(sorting, max_violation, 'greater', recording.get_num_frames())
    violations = compute_isi_violations(curated, recording.get_num_frames(),
                                        sorting.get_sampling_frequency())
    assert np.all(violations <= max_violation)

    check_dumping(curated)
    shutil.rmtree('test')
Example #11
0
def test_thresh_presence_ratios():
    """Check presence-ratio thresholding ('less') with an explicit frame count.

    Presence ratios recomputed on the thresholded sorting must all be >= the
    threshold.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    min_ratio = 0.18

    curated = threshold_presence_ratios(sorting, min_ratio, 'less', recording.get_num_frames())
    ratios = compute_presence_ratios(curated, recording.get_num_frames(),
                                     sorting.get_sampling_frequency())
    assert np.all(ratios >= min_ratio)

    check_dumping(curated)
    shutil.rmtree('test')
Example #12
0
def test_curation_sorting_extractor():
    """Check merging and splitting units in ``CurationSortingExtractor``.

    Units 1 and 2 are merged into a new unit 4. Its spike train and 'f_int'
    features must equal the concatenation of the originals, sorted by spike
    time. Only 'f_int' is kept, because unit 1 has no 'bad_features'. Unit 3
    is then split at spike index 10 into units 5 and 6, and both trains and
    features must match the two halves of the original. The ``test`` dump
    folder is removed at the end, like in every other test in this module.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test',
                                                dumpable=True,
                                                duration=10,
                                                num_channels=4,
                                                K=3,
                                                seed=0)

    # Dummy features for testing merging and splitting of features.
    sort.set_unit_spike_features(
        1, 'f_int', range(0 + 1,
                          len(sort.get_unit_spike_train(1)) + 1))
    sort.set_unit_spike_features(2, 'f_int',
                                 range(0, len(sort.get_unit_spike_train(2))))
    # A feature present on unit 2 only: it should not survive the merge.
    sort.set_unit_spike_features(
        2, 'bad_features', np.repeat(1, len(sort.get_unit_spike_train(2))))
    sort.set_unit_spike_features(3, 'f_int',
                                 range(0, len(sort.get_unit_spike_train(3))))

    CSX = st.curation.CurationSortingExtractor(parent_sorting=sort)
    merged_unit_id = CSX.merge_units(unit_ids=[1, 2])
    assert np.allclose(merged_unit_id, 4)
    original_spike_train = np.concatenate(
        (sort.get_unit_spike_train(1), sort.get_unit_spike_train(2)))
    # The merged train is expected in time order; features follow the same permutation.
    indices_sort = np.argsort(original_spike_train)
    original_spike_train = original_spike_train[indices_sort]
    original_features = np.concatenate(
        (sort.get_unit_spike_features(1, 'f_int'),
         sort.get_unit_spike_features(2, 'f_int')))
    original_features = original_features[indices_sort]
    assert np.allclose(CSX.get_unit_spike_train(4), original_spike_train)
    assert np.allclose(CSX.get_unit_spike_features(4, 'f_int'),
                       original_features)
    assert CSX.get_unit_spike_feature_names(4) == ['f_int']
    assert np.allclose(CSX.get_sampling_frequency(),
                       sort.get_sampling_frequency())

    unit_ids_split = CSX.split_unit(unit_id=3,
                                    indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    assert np.allclose(unit_ids_split[0], 5)
    assert np.allclose(unit_ids_split[1], 6)
    original_spike_train = sort.get_unit_spike_train(3)
    original_features = sort.get_unit_spike_features(3, 'f_int')
    split_spike_train_1 = CSX.get_unit_spike_train(5)
    split_spike_train_2 = CSX.get_unit_spike_train(6)
    split_features_1 = CSX.get_unit_spike_features(5, 'f_int')
    split_features_2 = CSX.get_unit_spike_features(6, 'f_int')
    assert np.allclose(original_spike_train[:10], split_spike_train_1)
    assert np.allclose(original_spike_train[10:], split_spike_train_2)
    assert np.allclose(original_features[:10], split_features_1)
    assert np.allclose(original_features[10:], split_features_2)

    check_dumping(CSX)
    # Previously missing: without this the 'test' dump folder leaked into later tests.
    shutil.rmtree('test')
Example #13
0
def test_thresh_snrs():
    """Check SNR thresholding with the 'less' comparison.

    SNRs recomputed on the thresholded sorting must all be >= the threshold.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    min_snr = 4

    curated = threshold_snrs(sorting, recording, min_snr, 'less', apply_filter=False, seed=0)
    snrs = compute_snrs(curated, recording, apply_filter=False, seed=0)
    assert np.all(snrs >= min_snr)

    check_dumping(curated)
    shutil.rmtree('test')
def test_rectify():
    """Check that ``rectify`` returns the absolute value of the traces.

    This variant builds its recording with ``toy_example``.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, seed=0)

    rectified = rectify(recording)
    expected = np.abs(recording.get_traces())
    assert np.allclose(rectified.get_traces(), expected)

    check_dumping(rectified)
    shutil.rmtree('test')
def test_blank_saturation():
    """Check that ``blank_saturation`` leaves sub-threshold samples below the threshold.

    Every sample that was below ``threshold`` in the input must still be below
    it in the output.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, folder='test')
    threshold = 2
    blanked = blank_saturation(recording, threshold=threshold)

    below = np.where(recording.get_traces() < threshold)
    assert np.all(blanked.get_traces()[below] < threshold)

    check_dumping(blanked)
    shutil.rmtree('test')
Example #16
0
def test_thresh_amplitude_cutoffs():
    """Check amplitude-cutoff thresholding with the 'less' comparison.

    Amplitude cutoffs recomputed on the thresholded sorting must all be >= the
    threshold.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    min_cutoff = 0

    curated = threshold_amplitude_cutoffs(sorting, recording, min_cutoff, "less",
                                          apply_filter=False, seed=0)
    cutoffs = compute_amplitude_cutoffs(curated, recording, apply_filter=False, seed=0)
    assert np.all(cutoffs >= min_cutoff)

    check_dumping(curated)
    shutil.rmtree('test')
def test_center():
    """Check that ``center`` zeroes each channel's mean or median.

    Both modes are checked ('mean' first, then 'median'), each followed by a
    dump check. This variant uses ``create_dumpable_extractors``.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, folder='test')

    for mode, statistic in (('mean', np.mean), ('median', np.median)):
        centered = center(recording, mode=mode)
        assert np.allclose(statistic(centered.get_traces(), axis=1), 0, atol=0.001)
        check_dumping(centered)

    shutil.rmtree('test')
Example #18
0
def test_thresh_presence_ratios():
    """Check presence-ratio thresholding ('less') without an explicit frame count.

    The first element returned by ``compute_presence_ratios`` on the
    thresholded sorting must be >= the threshold everywhere.
    """
    _, sorting = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, K=10, seed=0, folder='test')
    min_ratio = 0.18

    curated = threshold_presence_ratios(sorting, min_ratio, 'less')
    ratios = compute_presence_ratios(curated, sorting.get_sampling_frequency())[0]
    assert np.all(ratios >= min_ratio)

    check_dumping(curated)
    shutil.rmtree('test')
Example #19
0
def test_thresh_frs():
    """Check firing-rate thresholding ('less') without an explicit frame count.

    The first element returned by ``compute_firing_rates`` on the thresholded
    sorting must be >= the threshold everywhere.
    """
    _, sorting = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, K=10, seed=0, folder='test')
    rate_limit = 2

    curated = threshold_firing_rates(sorting, rate_limit, 'less')
    rates = compute_firing_rates(curated)[0]
    assert np.all(rates >= rate_limit)

    check_dumping(curated)
    shutil.rmtree('test')
Example #20
0
def test_thresh_isi_violations():
    """Check ISI-violation thresholding ('greater') without an explicit frame count.

    The first element returned by ``compute_isi_violations`` on the
    thresholded sorting must be <= the threshold everywhere.
    """
    _, sorting = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, K=10, seed=0, folder='test')
    max_violation = 0.01

    curated = threshold_isi_violations(sorting, max_violation, 'greater')
    violations = compute_isi_violations(curated, sorting.get_sampling_frequency())[0]
    assert np.all(violations <= max_violation)

    check_dumping(curated)
    shutil.rmtree('test')
def test_clip():
    """Check that ``clip`` saturates samples outside [-threshold, threshold].

    Samples below ``-threshold`` must become exactly ``-threshold``, and samples
    above ``threshold`` must become exactly ``threshold``.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, folder='test')
    threshold = 5
    clipped = clip(recording, a_min=-threshold, a_max=threshold)

    raw = recording.get_traces()
    too_low = np.where(raw < -threshold)
    too_high = np.where(raw > threshold)

    out = clipped.get_traces()
    assert np.all(out[too_low] == -threshold)
    assert np.all(out[too_high] == threshold)

    check_dumping(clipped)
    shutil.rmtree('test')
def test_notch_filter():
    """Check that a 3 kHz notch filter reduces power in the 2.9-3.1 kHz band.

    The comparison is done by ``check_signal_power_signal1_below_signal2``,
    with the filtered traces as the first signal and the raw traces as the
    second.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, folder='test')

    notched = notch_filter(recording, 3000, q=10)

    band = [2900, 3100]
    assert check_signal_power_signal1_below_signal2(
        notched.get_traces(), recording.get_traces(),
        freq_range=band, fs=recording.get_sampling_frequency())

    check_dumping(notched)
    shutil.rmtree('test')
def test_blank_saturation():
    """Check that ``blank_saturation`` leaves sub-threshold samples below the threshold.

    This variant uses a 2-second ``toy_example`` recording.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)
    threshold = 2
    blanked = blank_saturation(recording, threshold=threshold)

    below = np.where(recording.get_traces() < threshold)
    assert np.all(blanked.get_traces()[below] < threshold)

    check_dumping(blanked)
    shutil.rmtree('test')
def test_center():
    """Check that ``center`` zeroes each channel's mean or median.

    This variant uses a 2-second ``toy_example`` recording. Modes are checked
    in the order 'mean', then 'median'.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)

    for mode, statistic in (('mean', np.mean), ('median', np.median)):
        centered = center(recording, mode=mode)
        assert np.allclose(statistic(centered.get_traces(), axis=1), 0, atol=0.001)
        check_dumping(centered)

    shutil.rmtree('test')
def test_resample():
    """Check that ``resample`` scales the frame count by the rate ratio.

    The recording is down-sampled to 0.1x and up-sampled to 2x its sampling
    frequency, and each result's frame count is compared with
    ``int(num_frames * factor)``. This variant uses
    ``create_dumpable_extractors``.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(
        duration=10, num_channels=4, folder='test')

    fs = recording.get_sampling_frequency()
    factors = (0.1, 2)
    resampled = [resample(recording, factor * fs) for factor in factors]

    for factor, rec_r in zip(factors, resampled):
        assert rec_r.get_num_frames() == int(recording.get_num_frames() * factor)

    for rec_r in resampled:
        check_dumping(rec_r)
    shutil.rmtree('test')
def test_whiten():
    """Check that ``whiten`` yields near-identity covariance, independent of chunk size.

    The whitened 4-channel recording must have a covariance within 0.3 of the
    identity. Whitening with ``chunk_size=30000`` must give exactly the same
    traces. Both whitened extractors are dump-checked.
    """
    rec, sort = se.example_datasets.create_dumpable_extractors(duration=10,
                                                               num_channels=4,
                                                               folder='test',
                                                               seed=0)

    rec_w = whiten(rec)
    # np.cov treats rows as variables, i.e. one row per channel here.
    cov_w = np.cov(rec_w.get_traces())
    assert np.allclose(cov_w, np.eye(4), atol=0.3)

    # Chunk size should not affect the whitened output.
    rec_w2 = whiten(rec, chunk_size=30000)

    assert np.array_equal(rec_w.get_traces(), rec_w2.get_traces())

    check_dumping(rec_w)
    # Also dump-check the chunked variant, which was previously left unchecked.
    check_dumping(rec_w2)
    shutil.rmtree('test')
Example #27
0
def test_thresh_threshold_drift_metrics():
    """Check max-drift and cumulative-drift thresholding ('greater').

    Each metric is recomputed on its own thresholded sorting, and every
    remaining value must be <= the threshold.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)
    max_drift_limit = 1

    max_sorting = threshold_drift_metrics(sorting, recording, max_drift_limit, 'greater',
                                          metric_name="max_drift", apply_filter=False, seed=0)
    cum_sorting = threshold_drift_metrics(sorting, recording, max_drift_limit, 'greater',
                                          metric_name="cumulative_drift", apply_filter=False, seed=0)
    # compute_drift_metrics returns (max_drift, cumulative_drift); keep the relevant one.
    max_drifts = compute_drift_metrics(max_sorting, recording, apply_filter=False, seed=0)[0]
    cum_drifts = compute_drift_metrics(cum_sorting, recording, apply_filter=False, seed=0)[1]

    assert np.all(max_drifts <= max_drift_limit)
    assert np.all(cum_drifts <= max_drift_limit)

    check_dumping(max_sorting)
    check_dumping(cum_sorting)
    shutil.rmtree('test')
Example #28
0
def test_thresh_noise_overlaps():
    """Check noise-overlap thresholding with the 'less' comparison.

    Noise overlaps are computed on the original sorting. After thresholding,
    each remaining unit's original value must be >= the threshold.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0)

    min_overlap = 0.3

    overlaps = compute_noise_overlaps(sorting, recording, apply_filter=False, seed=0)
    curated = threshold_noise_overlaps(sorting, recording, min_overlap, 'less',
                                       apply_filter=False, seed=0)

    all_ids = sorting.get_unit_ids()
    kept = np.array([overlaps[all_ids.index(uid)] for uid in curated.get_unit_ids()])
    assert np.all(kept >= min_overlap)

    check_dumping(curated)
    shutil.rmtree('test')
def test_resample():
    """Check that ``resample`` scales the frame count by the rate ratio.

    This variant uses a 2-second ``toy_example`` recording, resampled to 0.1x
    and 2x its sampling frequency.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)

    fs = recording.get_sampling_frequency()
    factors = (0.1, 2)
    resampled = [resample(recording, factor * fs) for factor in factors]

    for factor, rec_r in zip(factors, resampled):
        assert rec_r.get_num_frames() == int(recording.get_num_frames() * factor)

    for rec_r in resampled:
        check_dumping(rec_r)
    shutil.rmtree('test')
def test_bandpass_filter_with_cache():
    """Check that bandpass caching modes agree with one another.

    Four 5-10 kHz bandpass extractors are built:

    1. file cache, chunked;
    2. file cache, unchunked;
    3. chunk cache, chunked;
    4. chunk cache, unchunked.

    After one ``get_traces`` call, the chunked chunk-cache extractor must hold
    an entry for key '0'. All variants must match the first within a 1e-2
    tolerance, and each must survive a dump round-trip.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, seed=0)

    band = dict(freq_min=5000, freq_max=10000)

    file_chunked = bandpass_filter(recording, cache_to_file=True, chunk_size=10000, **band)
    file_whole = bandpass_filter(recording, cache_to_file=True, chunk_size=None, **band)

    mem_chunked = bandpass_filter(recording, cache_chunks=True, chunk_size=10000, **band)
    # Reading the traces should populate the in-memory chunk cache.
    mem_chunked.get_traces()
    assert mem_chunked._filtered_cache_chunks.get('0') is not None

    mem_whole = bandpass_filter(recording, cache_chunks=True, chunk_size=None, **band)

    for other in (file_whole, mem_chunked, mem_whole):
        assert np.allclose(file_chunked.get_traces(), other.get_traces(),
                           rtol=1e-02, atol=1e-02)

    for extractor in (file_chunked, file_whole, mem_chunked, mem_whole):
        check_dumping(extractor)

    shutil.rmtree('test')