def test_biocam_extractor(self):
    """Round-trip self.RX through the Biocam (.brw) recording writer and reader."""
    brw_file = self.test_dir + '/raw.brw'
    se.BiocamRecordingExtractor.write_recording(self.RX, brw_file)
    biocam_rec = se.BiocamRecordingExtractor(brw_file)

    check_recording_return_types(biocam_rec)
    check_recordings_equal(self.RX, biocam_rec)
    check_dumping(biocam_rec)
def test_remove_artifacts():
    """Check remove_artifacts around a list of trigger frames.

    With the default mode, every sample in the +/- ``ms`` window around each
    trigger must be zero. With ``mode="linear"`` and ``mode="cubic"``, the
    window must differ from the untouched input.
    The 'test' dump folder is removed even if an assertion fails.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                num_channels=4, seed=0)
    try:
        triggers = [15000, 30000]
        ms = 10
        # Half-width of the checked window in frames. It is derived from the same
        # `ms` passed as ms_before/ms_after below, so the two cannot drift apart.
        ms_frames = int(ms * rec.get_sampling_frequency() / 1000)

        def _window(recording, trigger, half_width=ms_frames):
            """Return traces of `recording` in [trigger - half_width, trigger + half_width)."""
            return recording.get_traces(start_frame=trigger - half_width,
                                        end_frame=trigger + half_width)

        clean_windows = [_window(rec, t) for t in triggers]

        rec_rmart = remove_artifacts(rec, triggers, ms_before=ms, ms_after=ms)
        for t in triggers:
            # Both the full window and a short +/-10-frame window must be all zeros.
            assert not np.any(_window(rec_rmart, t))
            assert not np.any(_window(rec_rmart, t, half_width=10))

        for mode in ("linear", "cubic"):
            rec_interp = remove_artifacts(rec, triggers, ms_before=ms, ms_after=ms, mode=mode)
            for t, clean in zip(triggers, clean_windows):
                # Interpolated samples must not simply reproduce the original signal.
                assert not np.allclose(_window(rec_interp, t), clean)

        check_dumping(rec_rmart)
    finally:
        shutil.rmtree('test')
def test_transform():
    """Check transform with a global scalar/offset and with per-channel arrays.

    The per-channel values are drawn from a seeded generator so that failures
    are reproducible. The 'test' dump folder is removed even if an assertion fails.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                num_channels=4, seed=0)
    try:
        scalar = 3
        offset = 50
        rec_t = transform(rec, scalar=scalar, offset=offset)
        assert np.allclose(rec_t.get_traces(), scalar * rec.get_traces() + offset, atol=0.001)

        # Seeded RNG: the previous unseeded np.random.randn made the test non-reproducible.
        rng = np.random.RandomState(0)
        num_channels = rec.get_num_channels()
        scalars = rng.randn(num_channels)
        offsets = rng.randn(num_channels)
        rec_t_arr = transform(rec, scalar=scalars, offset=offsets)
        # Rows of get_traces() are channels, so each row is paired with its own scalar/offset.
        for tt, to, s, o in zip(rec_t_arr.get_traces(), rec.get_traces(), scalars, offsets):
            assert np.allclose(tt, s * to + o, atol=0.001)

        check_dumping(rec_t)
        check_dumping(rec_t_arr)
    finally:
        shutil.rmtree('test')
def test_thresh_l_ratios():
    """Units kept by threshold_l_ratios(..., "less") must have L-ratio >= the threshold."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    threshold = 0
    ratios = compute_l_ratios(sort, rec, apply_filter=False, seed=0)
    curated = threshold_l_ratios(sort, rec, threshold, "less", apply_filter=False, seed=0)

    # Look up each surviving unit's L-ratio by its position in the original id list.
    ids_before = sort.get_unit_ids()
    kept = np.array([ratios[ids_before.index(uid)] for uid in curated.get_unit_ids()])
    assert np.all(kept >= threshold)

    check_dumping(curated)
    shutil.rmtree('test')
def test_resample():
    """Check resampled frame counts, without and then with custom (non-uniform) times."""
    rec, _sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                 num_channels=4, seed=0)
    low_rate = 0.1 * rec.get_sampling_frequency()
    high_rate = 2 * rec.get_sampling_frequency()

    def _resample_and_check():
        """Resample down and up, and assert the scaled frame counts."""
        down = resample(rec, low_rate)
        up = resample(rec, high_rate)
        assert down.get_num_frames() == int(rec.get_num_frames() * 0.1)
        assert up.get_num_frames() == int(rec.get_num_frames() * 2)
        return down, up

    _resample_and_check()

    # Attach custom times: shifted by -10 s, with a 0.5 s jump from frame 1000 on.
    times = rec.frame_to_time(np.arange(rec.get_num_frames())) - 10
    times[1000:] += 0.5
    rec.set_times(times)
    down, up = _resample_and_check()

    check_dumping(down)
    check_dumping(up)
    shutil.rmtree('test')
def test_thresh_silhouettes():
    """Units kept by threshold_silhouette_scores(..., "less") must score >= the threshold."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    threshold = .5
    scores = compute_silhouette_scores(sort, rec, apply_filter=False, seed=0)
    curated = threshold_silhouette_scores(sort, rec, threshold, "less", apply_filter=False, seed=0)

    # Map each surviving unit back to its score via its index in the original ids.
    ids_before = sort.get_unit_ids()
    kept = np.array([scores[ids_before.index(uid)] for uid in curated.get_unit_ids()])
    assert np.all(kept >= threshold)

    check_dumping(curated)
    shutil.rmtree('test')
def test_thresh_isolation_distances():
    """Units kept by threshold_isolation_distances(..., 'less') must have distance >= threshold."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    threshold = 200
    distances = compute_isolation_distances(sort, rec, apply_filter=False, seed=0)
    curated = threshold_isolation_distances(sort, rec, threshold, 'less',
                                            apply_filter=False, seed=0)

    # Look up each surviving unit's isolation distance by original position.
    ids_before = sort.get_unit_ids()
    kept = np.array([distances[ids_before.index(uid)] for uid in curated.get_unit_ids()])
    assert np.all(kept >= threshold)

    check_dumping(curated)
    shutil.rmtree('test')
def test_thresh_noise_overlaps():
    """Units kept by threshold_noise_overlaps(..., 'less') must have overlap >= the threshold."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    threshold = 0.3
    overlaps = compute_noise_overlaps(sort, rec, apply_filter=False, seed=0)
    curated = threshold_noise_overlaps(sort, rec, threshold, 'less', apply_filter=False, seed=0)

    # Map each surviving unit back to its noise overlap via the original id order.
    ids_before = sort.get_unit_ids()
    kept = np.array([overlaps[ids_before.index(uid)] for uid in curated.get_unit_ids()])
    assert np.all(kept >= threshold)

    check_dumping(curated)
    shutil.rmtree('test')
def test_spykingcircus_extractor(self):
    """Round-trip self.SX through the SpykingCircus sorting writer and reader."""
    sc_location = self.test_dir + '/sc'
    se.SpykingCircusSortingExtractor.write_sorting(self.SX, sc_location)
    sc_sorting = se.SpykingCircusSortingExtractor(sc_location)

    check_sorting_return_types(sc_sorting)
    check_sortings_equal(self.SX, sc_sorting)
    check_dumping(sc_sorting)
def test_hs2_extractor(self):
    """Round-trip self.SX through the HS2 (.hdf5) sorting writer and reader."""
    hdf5_file = self.test_dir + '/firings_true.hdf5'
    se.HS2SortingExtractor.write_sorting(self.SX, hdf5_file)
    hs2_sorting = se.HS2SortingExtractor(hdf5_file)

    check_sorting_return_types(hs2_sorting)
    check_sortings_equal(self.SX, hs2_sorting)
    # The sampling frequency must also survive the round trip.
    self.assertEqual(hs2_sorting.get_sampling_frequency(), self.SX.get_sampling_frequency())
    check_dumping(hs2_sorting)
def test_rectify():
    """rectify(rec) must return the absolute value of the input traces."""
    rec, _sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                 num_channels=4, seed=0)
    rectified = rectify(rec)
    expected = np.abs(rec.get_traces())
    assert np.allclose(rectified.get_traces(), expected)

    check_dumping(rectified)
    shutil.rmtree('test')
def test_cell_explorer_extractor(self):
    """Round-trip self.SX through the CellExplorer ``.spikes.cellinfo.mat`` writer and reader."""
    sorter_id = "cell_explorer_sorter"
    mat_file = Path(self.test_dir) / sorter_id / f"{sorter_id}.spikes.cellinfo.mat"

    se.CellExplorerSortingExtractor.write_sorting(sorting=self.SX, save_path=mat_file)
    ce_sorting = se.CellExplorerSortingExtractor(spikes_matfile_path=mat_file)

    check_sorting_return_types(ce_sorting)
    check_sortings_equal(self.SX, ce_sorting)
    check_dumping(ce_sorting)
def test_hdsort_extractor(self):
    """Round-trip self.SX through the HDSort (.mat) sorting writer and reader."""
    mat_file = self.test_dir + '/results_test_hdsort_extractor.mat'
    # NOTE(review): 10 dummy 2-D locations are hard-coded; presumably this must match
    # the channel count of the test fixture — confirm.
    dummy_locations = np.ones((10, 2))
    se.HDSortSortingExtractor.write_sorting(self.SX, mat_file, locations=dummy_locations,
                                            noise_std_by_channel=None)
    hd_sorting = se.HDSortSortingExtractor(mat_file)

    check_sorting_return_types(hd_sorting)
    check_sortings_equal(self.SX, hd_sorting)
    check_dumping(hd_sorting)
def test_blank_saturation():
    """Input samples below the threshold must still be below it after blank_saturation."""
    rec, _sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                 num_channels=4, seed=0)
    threshold = 2
    blanked = blank_saturation(rec, threshold=threshold)

    below = np.where(rec.get_traces() < threshold)
    assert np.all(blanked.get_traces()[below] < threshold)

    check_dumping(blanked)
    shutil.rmtree('test')
def test_curation_sorting_extractor():
    """Check merge_units / split_unit of CurationSortingExtractor, including spike features.

    Like the other toy_example tests, this creates the 'test' dump folder. The
    folder is now removed at the end, even if an assertion fails; previously it
    was leaked.
    """
    import shutil  # local import keeps the cleanup self-contained in this test

    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=3, seed=0)
    try:
        # Dummy features for testing merging and splitting of features.
        sort.set_unit_spike_features(1, 'f_int', range(0 + 1, len(sort.get_unit_spike_train(1)) + 1))
        sort.set_unit_spike_features(2, 'f_int', range(0, len(sort.get_unit_spike_train(2))))
        sort.set_unit_spike_features(2, 'bad_features',
                                     np.repeat(1, len(sort.get_unit_spike_train(2))))
        sort.set_unit_spike_features(3, 'f_int', range(0, len(sort.get_unit_spike_train(3))))

        CSX = st.curation.CurationSortingExtractor(parent_sorting=sort)

        # --- merge units 1 and 2 into new unit 4 ---
        merged_unit_id = CSX.merge_units(unit_ids=[1, 2])
        assert np.allclose(merged_unit_id, 4)
        # Expected merged spike train / features: concatenation, then sorted by spike time.
        original_spike_train = np.concatenate(
            (sort.get_unit_spike_train(1), sort.get_unit_spike_train(2)))
        indices_sort = np.argsort(original_spike_train)
        original_spike_train = original_spike_train[indices_sort]
        original_features = np.concatenate(
            (sort.get_unit_spike_features(1, 'f_int'), sort.get_unit_spike_features(2, 'f_int')))
        original_features = original_features[indices_sort]
        assert np.allclose(CSX.get_unit_spike_train(4), original_spike_train)
        assert np.allclose(CSX.get_unit_spike_features(4, 'f_int'), original_features)
        # Only 'f_int' (set on both units) is expected on the merged unit;
        # 'bad_features' was set on unit 2 only.
        assert CSX.get_unit_spike_feature_names(4) == ['f_int']
        assert np.allclose(CSX.get_sampling_frequency(), sort.get_sampling_frequency())

        # --- split unit 3 at its first 10 spikes into new units 5 and 6 ---
        unit_ids_split = CSX.split_unit(unit_id=3, indices=list(range(10)))
        assert np.allclose(unit_ids_split[0], 5)
        assert np.allclose(unit_ids_split[1], 6)
        original_spike_train = sort.get_unit_spike_train(3)
        original_features = sort.get_unit_spike_features(3, 'f_int')
        assert np.allclose(original_spike_train[:10], CSX.get_unit_spike_train(5))
        assert np.allclose(original_spike_train[10:], CSX.get_unit_spike_train(6))
        assert np.allclose(original_features[:10], CSX.get_unit_spike_features(5, 'f_int'))
        assert np.allclose(original_features[10:], CSX.get_unit_spike_features(6, 'f_int'))

        check_dumping(CSX)
    finally:
        # ignore_errors: this cleanup is new for this test and must not introduce a new failure.
        shutil.rmtree('test', ignore_errors=True)
def test_center():
    """center(rec, mode=...) must bring the per-channel mean/median to ~0."""
    rec, _sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                 num_channels=4, seed=0)
    for mode, statistic in (('mean', np.mean), ('median', np.median)):
        centered = center(rec, mode=mode)
        # axis=1 reduces over frames, giving one value per channel.
        assert np.allclose(statistic(centered.get_traces(), axis=1), 0, atol=0.001)
        check_dumping(centered)
    shutil.rmtree('test')
def test_mda_extractor(self):
    """Round-trip self.RX and self.SX through the MDA recording/sorting writers and readers."""
    mda_dir = self.test_dir + '/mda'
    firings_file = mda_dir + '/firings_true.mda'

    se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
    se.MdaSortingExtractor.write_sorting(self.SX, firings_file)
    mda_rec = se.MdaRecordingExtractor(mda_dir)
    mda_sort = se.MdaSortingExtractor(firings_file)

    check_recording_return_types(mda_rec)
    check_recordings_equal(self.RX, mda_rec)
    check_sorting_return_types(mda_sort)
    check_sortings_equal(self.SX, mda_sort)
    check_dumping(mda_rec)
    check_dumping(mda_sort)
def test_highpass_filter():
    """FFT high-pass at 5 kHz must reduce 1-5 kHz power, for full and partial reads."""
    # NOTE(review): another function named test_highpass_filter is defined later in this
    # source. If both live in the same module, this earlier definition is shadowed and
    # never collected by pytest — consider renaming one of them.
    rec, _sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                 num_channels=4, seed=0)
    rec_fft = highpass_filter(rec, freq_min=5000, filter_type='fft')
    fs = rec.get_sampling_frequency()
    for read_kwargs in ({}, {'end_frame': 30000}):
        assert check_signal_power_signal1_below_signal2(
            rec_fft.get_traces(**read_kwargs), rec.get_traces(**read_kwargs),
            freq_range=[1000, 5000], fs=fs)
    check_dumping(rec_fft)
    shutil.rmtree('test')
def test_notch_filter():
    """A 3 kHz notch (q=10) must reduce 2.9-3.1 kHz power, for full and partial reads."""
    rec, _sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                 num_channels=4, seed=0)
    notched = notch_filter(rec, 3000, q=10)
    fs = rec.get_sampling_frequency()
    for read_kwargs in ({}, {'end_frame': 30000}):
        assert check_signal_power_signal1_below_signal2(
            notched.get_traces(**read_kwargs), rec.get_traces(**read_kwargs),
            freq_range=[2900, 3100], fs=fs)
    check_dumping(notched)
    shutil.rmtree('test')
def test_clip():
    """Samples outside [-threshold, threshold] must be clipped to the nearest bound."""
    rec, _sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                 num_channels=4, seed=0)
    limit = 5
    clipped_rec = clip(rec, a_min=-limit, a_max=limit)

    raw = rec.get_traces()
    too_low = np.where(raw < -limit)
    too_high = np.where(raw > limit)
    clipped = clipped_rec.get_traces()
    assert np.all(clipped[too_low] == -limit)
    assert np.all(clipped[too_high] == limit)

    check_dumping(clipped_rec)
    shutil.rmtree('test')
def test_npz_extractor(self):
    """Round-trip self.SX through the NPZ sorting format; also write an empty sorting."""
    npz_file = self.test_dir + '/sorting.npz'
    se.NpzSortingExtractor.write_sorting(self.SX, npz_file)
    npz_sorting = se.NpzSortingExtractor(npz_file)

    # Writing a sorting with no units must not raise (it is not read back here).
    empty_file = self.test_dir + '/sorting_empty.npz'
    se.NpzSortingExtractor.write_sorting(se.NumpySortingExtractor(), empty_file)

    check_sorting_return_types(npz_sorting)
    check_sortings_equal(self.SX, npz_sorting)
    check_dumping(npz_sorting)
def test_whiten():
    """whiten must give ~identity channel covariance, independent of chunk size."""
    rec, _sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                 num_channels=4, seed=0)
    whitened = whiten(rec)
    covariance = np.cov(whitened.get_traces())
    assert np.allclose(covariance, np.eye(4), atol=0.3)

    # The chunk size used for whitening must not affect the result.
    whitened_chunked = whiten(rec, chunk_size=30000)
    assert np.array_equal(whitened.get_traces(), whitened_chunked.get_traces())

    check_dumping(whitened)
    shutil.rmtree('test')
def test_remove_bad_channels():
    """Check remove_bad_channels with explicit ids and with automatic detection.

    The detection part builds a 4-channel recording whose channel 1 has 10x the
    amplitude, and expects it to be removed for several ``seconds`` settings.
    The random data is seeded for reproducibility, and the 'test' folder is
    removed even if an assertion fails.
    """
    # Part 1: explicit bad channel ids on the toy recording.
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                num_channels=4, seed=0)
    try:
        rec_rm = remove_bad_channels(rec, bad_channel_ids=[0])
        assert 0 not in rec_rm.get_channel_ids()
        rec_rm = remove_bad_channels(rec, bad_channel_ids=[1, 2])
        assert 1 not in rec_rm.get_channel_ids() and 2 not in rec_rm.get_channel_ids()
        check_dumping(rec_rm)
    finally:
        shutil.rmtree('test')

    # Part 2: automatic detection. Seeded so the data (and outcome) is reproducible.
    timeseries = np.random.RandomState(0).randn(4, 60000)
    timeseries[1] = 10 * timeseries[1]
    rec_np = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=30000)
    rec_np.set_channel_locations(np.ones((rec_np.get_num_channels(), 2)))
    se.MdaRecordingExtractor.write_recording(rec_np, 'test')
    try:
        rec = se.MdaRecordingExtractor('test')
        for extra_kwargs in ({}, {'seconds': 0.1}, {'seconds': 10}):
            rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2,
                                         **extra_kwargs)
            assert 1 not in rec_rm.get_channel_ids()
            check_dumping(rec_rm)
    finally:
        shutil.rmtree('test')
def test_cache_extractor(self):
    """Exercise CacheRecordingExtractor and CacheSortingExtractor.

    Covers move_to(), an explicit save_path (whose file must outlive the
    extractor) and the default temporary file (which must be gone after
    ``del``). The four persistent cache files are removed in the current
    directory at the end.
    """
    # ---------------- recording cache ----------------
    rec_cache = se.CacheRecordingExtractor(self.RX)
    check_recording_return_types(rec_cache)
    check_recordings_equal(self.RX, rec_cache)
    rec_cache.move_to('cache_rec')
    assert rec_cache.filename == 'cache_rec.dat'
    check_dumping(rec_cache, test_relative=True)

    rec_cache = se.CacheRecordingExtractor(self.RX, save_path='cache_rec2')
    check_recording_return_types(rec_cache)
    check_recordings_equal(self.RX, rec_cache)
    assert rec_cache.filename == 'cache_rec2.dat'
    check_dumping(rec_cache, test_relative=True)
    # The file written via save_path must remain after the extractor is deleted.
    del rec_cache
    assert Path('cache_rec2.dat').is_file()

    # The default temporary file must be removed once the extractor is deleted.
    rec_cache = se.CacheRecordingExtractor(self.RX)
    tmp_path = rec_cache.filename
    del rec_cache
    assert not Path(tmp_path).is_file()

    # ---------------- sorting cache ----------------
    sort_cache = se.CacheSortingExtractor(self.SX)
    check_sorting_return_types(sort_cache)
    check_sortings_equal(self.SX, sort_cache)
    sort_cache.move_to('cache_sort')
    assert sort_cache.filename == 'cache_sort.npz'
    check_dumping(sort_cache, test_relative=True)
    # A moved cache file must remain after the extractor is deleted.
    del sort_cache
    assert Path('cache_sort.npz').is_file()

    sort_cache = se.CacheSortingExtractor(self.SX, save_path='cache_sort2')
    check_sorting_return_types(sort_cache)
    check_sortings_equal(self.SX, sort_cache)
    assert sort_cache.filename == 'cache_sort2.npz'
    check_dumping(sort_cache, test_relative=True)
    del sort_cache
    assert Path('cache_sort2.npz').is_file()

    sort_cache = se.CacheSortingExtractor(self.SX)
    tmp_path = sort_cache.filename
    del sort_cache
    assert not Path(tmp_path).is_file()

    # ---------------- cleanup ----------------
    for leftover in ('cache_rec.dat', 'cache_rec2.dat', 'cache_sort.npz', 'cache_sort2.npz'):
        os.remove(leftover)
def test_exdir_extractors(self):
    """Round-trip self.RX and self.SX through the Exdir recording/sorting formats."""
    rec_location = self.test_dir + '/raw.exdir'
    se.ExdirRecordingExtractor.write_recording(self.RX, rec_location)
    exdir_rec = se.ExdirRecordingExtractor(rec_location)
    check_recording_return_types(exdir_rec)
    check_recordings_equal(self.RX, exdir_rec)
    check_dumping(exdir_rec)

    sort_location = self.test_dir + '/firings.exdir'
    # The sorting writer is also given the recording (third positional argument).
    se.ExdirSortingExtractor.write_sorting(self.SX, sort_location, self.RX)
    exdir_sort = se.ExdirSortingExtractor(sort_location)
    check_sorting_return_types(exdir_sort)
    check_sortings_equal(self.SX, exdir_sort)
    check_dumping(exdir_sort)
def test_highpass_filter():
    """Check FFT high-pass, Butterworth band-pass, and FFT band-pass on uint16 input.

    Each filtered recording must carry less power than the raw recording in its
    stop band(s). This definition shadows any earlier same-named test in the
    module, so the partial-read (``end_frame``) check is included here to keep
    that coverage. The 'test' folder is removed even if an assertion fails.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2,
                                                num_channels=4, seed=0)
    try:
        fs = rec.get_sampling_frequency()

        rec_fft = highpass_filter(rec, freq_min=5000, filter_type='fft')
        assert check_signal_power_signal1_below_signal2(
            rec_fft.get_traces(), rec.get_traces(), freq_range=[1000, 5000], fs=fs)
        # Also check a partial read via end_frame.
        assert check_signal_power_signal1_below_signal2(
            rec_fft.get_traces(end_frame=30000), rec.get_traces(end_frame=30000),
            freq_range=[1000, 5000], fs=fs)

        rec_sci = bandpass_filter(rec, freq_min=3000, freq_max=6000, filter_type='butter',
                                  order=3)
        assert check_signal_power_signal1_below_signal2(
            rec_sci.get_traces(), rec.get_traces(), freq_range=[1000, 3000], fs=fs)

        # Unsigned (uint16) input: both stop bands must drop, and the output dtype
        # must not be unsigned.
        traces = rec.get_traces().astype('uint16')
        rec_u = se.NumpyRecordingExtractor(traces, sampling_frequency=fs)
        rec_fu = bandpass_filter(rec_u, freq_min=5000, freq_max=10000, filter_type='fft')
        for stop_band in ([1000, 5000], [10000, 15000]):
            assert check_signal_power_signal1_below_signal2(
                rec_fu.get_traces(), rec_u.get_traces(), freq_range=stop_band, fs=fs)
        assert not str(rec_fu.get_dtype()).startswith('u')

        check_dumping(rec_fft)
    finally:
        shutil.rmtree('test')
def test_thresh_num_spikes():
    """Units kept by threshold_num_spikes(..., 'less') must have >= threshold spikes."""
    _rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                 num_channels=4, K=10, seed=0)
    min_spikes = 25
    curated = threshold_num_spikes(sort, min_spikes, 'less')
    spike_counts = compute_num_spikes(curated, sort.get_sampling_frequency())
    assert np.all(spike_counts >= min_spikes)

    check_dumping(curated)
    shutil.rmtree('test')
def test_thresh_frs():
    """Units kept by threshold_firing_rates(..., 'less') must fire at >= the threshold."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    min_rate = 2
    num_frames = rec.get_num_frames()
    curated = threshold_firing_rates(sort, min_rate, 'less', num_frames)
    rates = compute_firing_rates(curated, num_frames)
    assert np.all(rates >= min_rate)

    check_dumping(curated)
    shutil.rmtree('test')
def test_mearec_extractors(self):
    """Round-trip self.RX and self.SX through the MEArec (.h5) recording/sorting formats.

    Also reads a channel subset back from the MEArec recording and checks its
    channel dimension.
    """
    path1 = self.test_dir + '/raw.h5'
    se.MEArecRecordingExtractor.write_recording(self.RX, path1)
    RX_mearec = se.MEArecRecordingExtractor(path1)
    # Channel-subset read. Previously `tr` was computed and never checked; rows of
    # get_traces() are channels, so requesting 2 channel ids must yield 2 rows.
    tr = RX_mearec.get_traces(channel_ids=[0, 1], end_frame=1000)
    assert tr.shape[0] == 2
    check_recording_return_types(RX_mearec)
    check_recordings_equal(self.RX, RX_mearec)
    check_dumping(RX_mearec)

    path2 = self.test_dir + '/firings_true.h5'
    se.MEArecSortingExtractor.write_sorting(self.SX, path2, self.RX.get_sampling_frequency())
    SX_mearec = se.MEArecSortingExtractor(path2)
    check_sorting_return_types(SX_mearec)
    check_sortings_equal(self.SX, SX_mearec)
    check_dumping(SX_mearec)
def test_thresh_presence_ratios():
    """Units kept by threshold_presence_ratios(..., 'less') must have ratio >= threshold."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    min_ratio = 0.18
    num_frames = rec.get_num_frames()
    curated = threshold_presence_ratios(sort, min_ratio, 'less', num_frames)
    ratios = compute_presence_ratios(curated, num_frames, sort.get_sampling_frequency())
    assert np.all(ratios >= min_ratio)

    check_dumping(curated)
    shutil.rmtree('test')