def test_thresh_isolation_distances():
    """Threshold isolation distances with ``'less'`` and check the surviving units.

    Each unit kept by ``threshold_isolation_distances`` is looked up in the
    metrics computed on the full sorting; all of them must be
    ``>= s_threshold``. The thresholded sorting must also pass
    ``check_dumping``. The ``'test'`` dump folder is removed at the end.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    s_threshold = 200
    iso = compute_isolation_distances(sort, rec, apply_filter=False, seed=0)
    sort_iso = threshold_isolation_distances(sort, rec, s_threshold, 'less', apply_filter=False, seed=0)

    # Metric values are positional w.r.t. the ORIGINAL unit list, so map each
    # surviving unit id back to its index there.
    original_ids = sort.get_unit_ids()
    new_iso = np.array([iso[original_ids.index(unit)] for unit in sort_iso.get_unit_ids()])
    assert np.all(new_iso >= s_threshold)

    check_dumping(sort_iso)
    shutil.rmtree('test')
def test_thresh_l_ratios():
    """Threshold L-ratios with ``"less"`` and check the surviving units.

    Each unit kept by ``threshold_l_ratios`` is looked up in the L-ratios
    computed on the full sorting; all of them must be ``>= l_ratios_thresh``.
    The thresholded sorting must also pass ``check_dumping``. The ``'test'``
    dump folder is removed at the end.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    l_ratios_thresh = 0
    l_ratios = compute_l_ratios(sort, rec, apply_filter=False, seed=0)
    sort_l_ratios = threshold_l_ratios(sort, rec, l_ratios_thresh, "less", apply_filter=False, seed=0)

    # Metric values are positional w.r.t. the ORIGINAL unit list.
    original_ids = sort.get_unit_ids()
    new_l_ratios = np.array([l_ratios[original_ids.index(unit)] for unit in sort_l_ratios.get_unit_ids()])
    assert np.all(new_l_ratios >= l_ratios_thresh)

    check_dumping(sort_l_ratios)
    shutil.rmtree('test')
def test_remove_artifacts():
    """``remove_artifacts`` must zero the traces around every trigger.

    For each trigger, two windows of the artifact-removed recording are
    checked to contain no non-zero samples:

    * the full ``ms``-wide window on either side (converted to frames);
    * a short window of ``short_frames`` frames on either side.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, seed=0)
    triggers = [30000, 90000]
    ms = 10
    short_frames = 10
    ms_frames = int(ms * rec.get_sampling_frequency() / 1000)

    # Use ``ms`` for the removal window too, so the checked window and the
    # removed window cannot drift apart.
    rec_rmart = remove_artifacts(rec, triggers, ms_before=ms, ms_after=ms)

    for trigger in triggers:
        traces_all = rec_rmart.get_traces(start_frame=trigger - ms_frames, end_frame=trigger + ms_frames)
        traces_short = rec_rmart.get_traces(start_frame=trigger - short_frames,
                                            end_frame=trigger + short_frames)
        assert not np.any(traces_all)
        assert not np.any(traces_short)

    check_dumping(rec_rmart)
    shutil.rmtree('test')
def test_thresh_nn_metrics():
    """Threshold nearest-neighbour hit and miss rates and check the survivors.

    * Hit rate uses ``'less'``: every surviving unit's hit rate must be
      ``>= s_threshold_hit``.
    * Miss rate uses ``'greater'``: every surviving unit's miss rate must be
      ``<= s_threshold_miss``.

    Both thresholded sortings must pass ``check_dumping``.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    s_threshold_hit = 0.9
    s_threshold_miss = 0.002
    nn_hit, nn_miss = compute_nn_metrics(sort, rec, apply_filter=False, seed=0)
    sort_hit = threshold_nn_metrics(sort, rec, s_threshold_hit, 'less', metric_name="nn_hit_rate",
                                    apply_filter=False, seed=0)
    sort_miss = threshold_nn_metrics(sort, rec, s_threshold_miss, 'greater', metric_name="nn_miss_rate",
                                     apply_filter=False, seed=0)

    # Metric values are positional w.r.t. the ORIGINAL unit list.
    original_ids = sort.get_unit_ids()
    new_nn_hit = np.array([nn_hit[original_ids.index(unit)] for unit in sort_hit.get_unit_ids()])
    new_nn_miss = np.array([nn_miss[original_ids.index(unit)] for unit in sort_miss.get_unit_ids()])

    assert np.all(new_nn_hit >= s_threshold_hit)
    assert np.all(new_nn_miss <= s_threshold_miss)
    check_dumping(sort_hit)
    check_dumping(sort_miss)
    shutil.rmtree('test')
def test_transform():
    """``transform`` must compute ``scalar * traces + offset``.

    Two cases are checked:

    * a single scalar/offset applied to all channels;
    * per-channel scalars/offsets.

    The per-channel values come from a locally seeded generator, so the test
    is reproducible and does not disturb NumPy's global RNG state.
    """
    num_channels = 4
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=num_channels, seed=0)
    scalar = 3
    offset = 50
    rec_t = transform(rec, scalar=scalar, offset=offset)
    assert np.allclose(rec_t.get_traces(), scalar * rec.get_traces() + offset, atol=0.001)

    rng = np.random.RandomState(0)
    scalars = rng.randn(num_channels)
    offsets = rng.randn(num_channels)
    rec_t_arr = transform(rec, scalar=scalars, offset=offsets)
    # Iterating get_traces() yields one row per channel, which pairs each
    # channel with its own scalar/offset.
    for tt, to, s, o in zip(rec_t_arr.get_traces(), rec.get_traces(), scalars, offsets):
        assert np.allclose(tt, s * to + o, atol=0.001)

    check_dumping(rec_t)
    check_dumping(rec_t_arr)
    shutil.rmtree('test')
def test_thresh_silhouettes():
    """Threshold silhouette scores with ``"less"`` and check the surviving units.

    Each unit kept by ``threshold_silhouette_scores`` is looked up in the
    scores computed on the full sorting; all of them must be
    ``>= silhouette_thresh``. The thresholded sorting must also pass
    ``check_dumping``.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    silhouette_thresh = .5
    silhouette = compute_silhouette_scores(sort, rec, apply_filter=False, seed=0)
    sort_silhouette = threshold_silhouette_scores(sort, rec, silhouette_thresh, "less",
                                                  apply_filter=False, seed=0)

    # Metric values are positional w.r.t. the ORIGINAL unit list.
    original_ids = sort.get_unit_ids()
    new_silhouette = np.array([silhouette[original_ids.index(unit)]
                               for unit in sort_silhouette.get_unit_ids()])
    assert np.all(new_silhouette >= silhouette_thresh)

    check_dumping(sort_silhouette)
    shutil.rmtree('test')
def test_thresh_frs():
    """Threshold firing rates (frame-count signature) and check the survivors.

    Units kept by ``threshold_firing_rates(..., 'less', num_frames)`` must all
    have a recomputed firing rate ``>= min_rate``.

    NOTE(review): ``test_thresh_frs`` is defined again later in this module,
    so this definition is shadowed and pytest will not collect it.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0,
    )
    min_rate = 2
    num_frames = recording.get_num_frames()

    kept = threshold_firing_rates(sorting, min_rate, 'less', num_frames)
    rates = compute_firing_rates(kept, num_frames)
    assert np.all(rates >= min_rate)

    check_dumping(kept)
    shutil.rmtree('test')
def test_rectify():
    """``rectify`` must equal the element-wise absolute value of the traces.

    NOTE(review): ``test_rectify`` is redefined later in this module; this
    version is shadowed and is not collected by pytest.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4,
                                                                  folder='test')
    rectified = rectify(recording)

    expected = np.abs(recording.get_traces())
    assert np.allclose(rectified.get_traces(), expected)

    check_dumping(rectified)
    shutil.rmtree('test')
def test_thresh_num_spikes():
    """Units kept by ``threshold_num_spikes(..., 'less')`` must have >= ``min_spikes`` spikes."""
    _, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0,
    )
    min_spikes = 25

    kept = threshold_num_spikes(sorting, min_spikes, 'less')
    spike_counts = compute_num_spikes(kept, sorting.get_sampling_frequency())
    assert np.all(spike_counts >= min_spikes)

    check_dumping(kept)
    shutil.rmtree('test')
def test_thresh_isi_violations():
    """Threshold ISI violations (frame-count signature) and check the survivors.

    Units kept by ``threshold_isi_violations(..., 'greater', num_frames)``
    must all have an ISI-violation value ``<= max_isi``.

    NOTE(review): ``test_thresh_isi_violations`` is defined again later in this
    module, so this definition is shadowed and not collected by pytest.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0,
    )
    max_isi = 0.01
    num_frames = recording.get_num_frames()

    kept = threshold_isi_violations(sorting, max_isi, 'greater', num_frames)
    violations = compute_isi_violations(kept, num_frames, sorting.get_sampling_frequency())
    assert np.all(violations <= max_isi)

    check_dumping(kept)
    shutil.rmtree('test')
def test_thresh_presence_ratios():
    """Threshold presence ratios (frame-count signature) and check the survivors.

    Units kept by ``threshold_presence_ratios(..., 'less', num_frames)`` must
    all have a presence ratio ``>= min_presence``.

    NOTE(review): ``test_thresh_presence_ratios`` is defined again later in
    this module, so this definition is shadowed and not collected by pytest.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0,
    )
    min_presence = 0.18
    num_frames = recording.get_num_frames()

    kept = threshold_presence_ratios(sorting, min_presence, 'less', num_frames)
    ratios = compute_presence_ratios(kept, num_frames, sorting.get_sampling_frequency())
    assert np.all(ratios >= min_presence)

    check_dumping(kept)
    shutil.rmtree('test')
def test_curation_sorting_extractor():
    """Exercise ``CurationSortingExtractor.merge_units`` and ``split_unit``.

    Merge phase (units 1 and 2 into new unit 4):

    * the merged spike train is the sorted concatenation of the two trains;
    * the 'f_int' features are reordered by the same permutation;
    * 'bad_features' exists only on unit 2 and is expected to be dropped, so
      only 'f_int' remains.

    Split phase (unit 3 into new units 5 and 6):

    * the first ``split_size`` spikes and features go to unit 5;
    * the remainder goes to unit 6.

    The ``'test'`` dump folder created by ``toy_example`` is removed at the
    end, as in every other test in this module.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=3, seed=0)

    # Dummy features for testing merging and splitting of features.
    sort.set_unit_spike_features(1, 'f_int', range(1, len(sort.get_unit_spike_train(1)) + 1))
    sort.set_unit_spike_features(2, 'f_int', range(0, len(sort.get_unit_spike_train(2))))
    sort.set_unit_spike_features(2, 'bad_features', np.repeat(1, len(sort.get_unit_spike_train(2))))
    sort.set_unit_spike_features(3, 'f_int', range(0, len(sort.get_unit_spike_train(3))))

    CSX = st.curation.CurationSortingExtractor(parent_sorting=sort)

    # --- merge ---
    merged_unit_id = CSX.merge_units(unit_ids=[1, 2])
    assert np.allclose(merged_unit_id, 4)
    original_spike_train = np.concatenate((sort.get_unit_spike_train(1), sort.get_unit_spike_train(2)))
    indices_sort = np.argsort(original_spike_train)
    original_spike_train = original_spike_train[indices_sort]
    original_features = np.concatenate((sort.get_unit_spike_features(1, 'f_int'),
                                        sort.get_unit_spike_features(2, 'f_int')))
    original_features = original_features[indices_sort]
    assert np.allclose(CSX.get_unit_spike_train(4), original_spike_train)
    assert np.allclose(CSX.get_unit_spike_features(4, 'f_int'), original_features)
    assert CSX.get_unit_spike_feature_names(4) == ['f_int']
    assert np.allclose(CSX.get_sampling_frequency(), sort.get_sampling_frequency())

    # --- split ---
    split_size = 10
    unit_ids_split = CSX.split_unit(unit_id=3, indices=list(range(split_size)))
    assert np.allclose(unit_ids_split[0], 5)
    assert np.allclose(unit_ids_split[1], 6)
    original_spike_train = sort.get_unit_spike_train(3)
    original_features = sort.get_unit_spike_features(3, 'f_int')
    split_spike_train_1 = CSX.get_unit_spike_train(5)
    split_spike_train_2 = CSX.get_unit_spike_train(6)
    split_features_1 = CSX.get_unit_spike_features(5, 'f_int')
    split_features_2 = CSX.get_unit_spike_features(6, 'f_int')
    assert np.allclose(original_spike_train[:split_size], split_spike_train_1)
    assert np.allclose(original_spike_train[split_size:], split_spike_train_2)
    assert np.allclose(original_features[:split_size], split_features_1)
    assert np.allclose(original_features[split_size:], split_features_2)

    check_dumping(CSX)
    # Previously missing: without this the 'test' dump folder was left behind.
    shutil.rmtree('test')
def test_thresh_snrs():
    """Units kept by ``threshold_snrs(..., 'less')`` must have a recomputed SNR >= ``min_snr``."""
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0,
    )
    min_snr = 4
    # Same keyword options for thresholding and for recomputing the metric.
    metric_opts = dict(apply_filter=False, seed=0)

    kept = threshold_snrs(sorting, recording, min_snr, 'less', **metric_opts)
    snrs = compute_snrs(kept, recording, **metric_opts)
    assert np.all(snrs >= min_snr)

    check_dumping(kept)
    shutil.rmtree('test')
def test_rectify():
    """``rectify`` must equal the element-wise absolute value of the traces.

    This definition shadows an earlier ``test_rectify`` in the module and is
    the one pytest actually collects.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, seed=0,
    )
    rectified = rectify(recording)

    expected = np.abs(recording.get_traces())
    assert np.allclose(rectified.get_traces(), expected)

    check_dumping(rectified)
    shutil.rmtree('test')
def test_blank_saturation():
    """Samples originally below ``threshold`` must still be below it after ``blank_saturation``.

    NOTE(review): ``test_blank_saturation`` is redefined later in this module;
    this version is shadowed and not collected by pytest.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4,
                                                                  folder='test')
    threshold = 2
    blanked = blank_saturation(recording, threshold=threshold)

    # A boolean mask selects the same elements, in the same order, as
    # indexing with np.where(mask).
    below_mask = recording.get_traces() < threshold
    assert np.all(blanked.get_traces()[below_mask] < threshold)

    check_dumping(blanked)
    shutil.rmtree('test')
def test_thresh_amplitude_cutoffs():
    """Units kept by ``threshold_amplitude_cutoffs(..., "less")`` must have a cutoff >= ``min_cutoff``."""
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0,
    )
    min_cutoff = 0
    metric_opts = dict(apply_filter=False, seed=0)

    kept = threshold_amplitude_cutoffs(sorting, recording, min_cutoff, "less", **metric_opts)
    cutoffs = compute_amplitude_cutoffs(kept, recording, **metric_opts)
    assert np.all(cutoffs >= min_cutoff)

    check_dumping(kept)
    shutil.rmtree('test')
def test_center():
    """Per-channel mean (or median) of the centered traces must be ~0 for each ``mode``.

    NOTE(review): ``test_center`` is redefined later in this module; this
    version is shadowed and not collected by pytest.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4,
                                                                  folder='test')
    # Each centering mode is paired with the statistic it should zero out.
    for mode, statistic in (('mean', np.mean), ('median', np.median)):
        centered = center(recording, mode=mode)
        assert np.allclose(statistic(centered.get_traces(), axis=1), 0, atol=0.001)
        check_dumping(centered)
    shutil.rmtree('test')
def test_thresh_presence_ratios():
    """Units kept by ``threshold_presence_ratios(..., 'less')`` must have a presence ratio >= ``min_presence``.

    Uses the signature without a frame count. The first element of
    ``compute_presence_ratios``' result holds the per-unit values (the test
    indexes ``[0]``).
    """
    _, sorting = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4, K=10,
                                                                seed=0, folder='test')
    min_presence = 0.18

    kept = threshold_presence_ratios(sorting, min_presence, 'less')
    ratios = compute_presence_ratios(kept, sorting.get_sampling_frequency())[0]
    assert np.all(ratios >= min_presence)

    check_dumping(kept)
    shutil.rmtree('test')
def test_thresh_frs():
    """Units kept by ``threshold_firing_rates(..., 'less')`` must have a firing rate >= ``min_rate``.

    Uses the signature without a frame count; element ``[0]`` of
    ``compute_firing_rates``' result is compared against the threshold.
    """
    _, sorting = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4, K=10,
                                                                seed=0, folder='test')
    min_rate = 2

    kept = threshold_firing_rates(sorting, min_rate, 'less')
    rates = compute_firing_rates(kept)[0]
    assert np.all(rates >= min_rate)

    check_dumping(kept)
    shutil.rmtree('test')
def test_thresh_isi_violations():
    """Units kept by ``threshold_isi_violations(..., 'greater')`` must have ISI violations <= ``max_isi``.

    Uses the signature without a frame count; element ``[0]`` of
    ``compute_isi_violations``' result is compared against the threshold.
    """
    _, sorting = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4, K=10,
                                                                seed=0, folder='test')
    max_isi = 0.01

    kept = threshold_isi_violations(sorting, max_isi, 'greater')
    violations = compute_isi_violations(kept, sorting.get_sampling_frequency())[0]
    assert np.all(violations <= max_isi)

    check_dumping(kept)
    shutil.rmtree('test')
def test_clip():
    """Samples beyond ±``bound`` must be clipped to exactly ±``bound``."""
    recording, _ = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4,
                                                                  folder='test')
    bound = 5
    clipped = clip(recording, a_min=-bound, a_max=bound)

    raw = recording.get_traces()
    too_low = np.where(raw < -bound)
    too_high = np.where(raw > bound)

    out = clipped.get_traces()
    assert np.all(out[too_low] == -bound)
    assert np.all(out[too_high] == bound)

    check_dumping(clipped)
    shutil.rmtree('test')
def test_notch_filter():
    """A 3000-frequency notch (q=10) must reduce power in the 2900–3100 band versus the input."""
    recording, _ = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4,
                                                                  folder='test')
    notched = notch_filter(recording, 3000, q=10)

    filtered_traces = notched.get_traces()
    raw_traces = recording.get_traces()
    sampling_frequency = recording.get_sampling_frequency()
    assert check_signal_power_signal1_below_signal2(filtered_traces, raw_traces,
                                                    freq_range=[2900, 3100], fs=sampling_frequency)

    check_dumping(notched)
    shutil.rmtree('test')
def test_blank_saturation():
    """Samples originally below ``threshold`` must still be below it after ``blank_saturation``.

    This definition shadows an earlier ``test_blank_saturation`` in the module
    and is the one pytest actually collects.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0,
    )
    threshold = 2
    blanked = blank_saturation(recording, threshold=threshold)

    below_mask = recording.get_traces() < threshold
    assert np.all(blanked.get_traces()[below_mask] < threshold)

    check_dumping(blanked)
    shutil.rmtree('test')
def test_center():
    """Per-channel mean (or median) of the centered traces must be ~0 for each ``mode``.

    This definition shadows an earlier ``test_center`` in the module and is
    the one pytest actually collects.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0,
    )
    for mode, statistic in (('mean', np.mean), ('median', np.median)):
        centered = center(recording, mode=mode)
        assert np.allclose(statistic(centered.get_traces(), axis=1), 0, atol=0.001)
        check_dumping(centered)
    shutil.rmtree('test')
def test_resample():
    """Resampling by factor ``f`` must yield ``int(num_frames * f)`` frames (down- and up-sampling).

    NOTE(review): ``test_resample`` is redefined later in this module; this
    version is shadowed and not collected by pytest.
    """
    recording, _ = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=4,
                                                                  folder='test')
    fs = recording.get_sampling_frequency()
    factors = (0.1, 2)
    resampled = [resample(recording, factor * fs) for factor in factors]

    for factor, rec_rs in zip(factors, resampled):
        assert rec_rs.get_num_frames() == int(recording.get_num_frames() * factor)
    for rec_rs in resampled:
        check_dumping(rec_rs)
    shutil.rmtree('test')
def test_whiten():
    """``whiten`` must decorrelate channels and be independent of ``chunk_size``.

    * The covariance of the whitened traces (rows = channels) must be close
      to the identity (``atol=0.3``).
    * Whitening with ``chunk_size=30000`` must give bit-identical traces to
      the default.
    """
    num_channels = 4
    rec, sort = se.example_datasets.create_dumpable_extractors(duration=10, num_channels=num_channels,
                                                               folder='test', seed=0)
    rec_w = whiten(rec)
    cov_w = np.cov(rec_w.get_traces())
    assert np.allclose(cov_w, np.eye(num_channels), atol=0.3)

    # Chunk size should not affect the whitened output.
    rec_w2 = whiten(rec, chunk_size=30000)
    assert np.array_equal(rec_w.get_traces(), rec_w2.get_traces())

    check_dumping(rec_w)
    shutil.rmtree('test')
def test_thresh_threshold_drift_metrics():
    """Threshold max drift and cumulative drift with ``'greater'`` and check the survivors.

    Units kept by each threshold must have the corresponding recomputed drift
    ``<= max_allowed``:

    * max drift is element 0 of ``compute_drift_metrics``' result;
    * cumulative drift is element 1.
    """
    recording, sorting = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, seed=0,
    )
    max_allowed = 1
    metric_opts = dict(apply_filter=False, seed=0)

    kept_max = threshold_drift_metrics(sorting, recording, max_allowed, 'greater',
                                       metric_name="max_drift", **metric_opts)
    kept_cum = threshold_drift_metrics(sorting, recording, max_allowed, 'greater',
                                       metric_name="cumulative_drift", **metric_opts)

    max_drift, _ = compute_drift_metrics(kept_max, recording, **metric_opts)
    _, cum_drift = compute_drift_metrics(kept_cum, recording, **metric_opts)
    assert np.all(max_drift <= max_allowed)
    assert np.all(cum_drift <= max_allowed)

    check_dumping(kept_max)
    check_dumping(kept_cum)
    shutil.rmtree('test')
def test_thresh_noise_overlaps():
    """Threshold noise overlaps with ``'less'`` and check the surviving units.

    Each unit kept by ``threshold_noise_overlaps`` is looked up in the
    noise-overlap values computed on the full sorting; all of them must be
    ``>= noise_thresh``. The thresholded sorting must also pass
    ``check_dumping``.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    noise_thresh = 0.3
    noise_overlaps = compute_noise_overlaps(sort, rec, apply_filter=False, seed=0)
    sort_noise = threshold_noise_overlaps(sort, rec, noise_thresh, 'less', apply_filter=False, seed=0)

    # Metric values are positional w.r.t. the ORIGINAL unit list.
    original_ids = sort.get_unit_ids()
    new_noise = np.array([noise_overlaps[original_ids.index(unit)] for unit in sort_noise.get_unit_ids()])
    assert np.all(new_noise >= noise_thresh)

    check_dumping(sort_noise)
    shutil.rmtree('test')
def test_resample():
    """Resampling by factor ``f`` must yield ``int(num_frames * f)`` frames (down- and up-sampling).

    This definition shadows an earlier ``test_resample`` in the module and is
    the one pytest actually collects.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0,
    )
    fs = recording.get_sampling_frequency()
    factors = (0.1, 2)
    resampled = [resample(recording, factor * fs) for factor in factors]

    for factor, rec_rs in zip(factors, resampled):
        assert rec_rs.get_num_frames() == int(recording.get_num_frames() * factor)
    for rec_rs in resampled:
        check_dumping(rec_rs)
    shutil.rmtree('test')
def test_bandpass_filter_with_cache():
    """Band-pass filtering must agree across caching strategies and chunk sizes.

    The reference uses file caching with ``chunk_size=10000``. It is compared
    (``rtol=atol=1e-2``) against:

    * file caching without chunking;
    * in-memory chunk caching with ``chunk_size=10000`` (its chunk '0' must
      be cached after one ``get_traces()`` call);
    * in-memory chunk caching without chunking.

    Every variant must pass ``check_dumping``.
    """
    recording, _ = se.example_datasets.toy_example(
        dump_folder='test', dumpable=True, duration=10, num_channels=4, seed=0,
    )
    band = dict(freq_min=5000, freq_max=10000)

    reference = bandpass_filter(recording, cache_to_file=True, chunk_size=10000, **band)
    file_unchunked = bandpass_filter(recording, cache_to_file=True, chunk_size=None, **band)
    mem_chunked = bandpass_filter(recording, cache_chunks=True, chunk_size=10000, **band)

    # Reading the traces once populates the in-memory chunk cache.
    mem_chunked.get_traces()
    assert mem_chunked._filtered_cache_chunks.get('0') is not None

    mem_unchunked = bandpass_filter(recording, cache_chunks=True, chunk_size=None, **band)

    for variant in (file_unchunked, mem_chunked, mem_unchunked):
        assert np.allclose(reference.get_traces(), variant.get_traces(), rtol=1e-02, atol=1e-02)
    for filtered in (reference, file_unchunked, mem_chunked, mem_unchunked):
        check_dumping(filtered)
    shutil.rmtree('test')