Beispiel #1
0
 def test_biocam_extractor(self):
     """Round-trip ``self.RX`` through a Biocam ``.brw`` file.

     The recording is written, read back, and checked for return types,
     equality with the source recording and dump/reload support.
     """
     biocam_cls = se.BiocamRecordingExtractor
     brw_file = self.test_dir + '/raw.brw'
     biocam_cls.write_recording(self.RX, brw_file)
     reloaded = biocam_cls(brw_file)
     check_recording_return_types(reloaded)
     check_recordings_equal(self.RX, reloaded)
     check_dumping(reloaded)
Beispiel #2
0
def test_remove_artifacts():
    """Check that ``remove_artifacts`` blanks or interpolates around triggers.

    Two trigger frames are processed with a symmetric window of ``ms``
    milliseconds on each side. In the default mode every sample inside the
    window must be zero. In ``"linear"`` and ``"cubic"`` mode the window must
    differ from the clean input. Every processed recording must also survive
    a dump/reload round trip. The ``'test'`` dump folder is removed even when
    an assertion fails.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)
    try:
        triggers = [15000, 30000]
        ms = 10
        ms_frames = int(ms * rec.get_sampling_frequency() / 1000)

        traces_all_0_clean = rec.get_traces(start_frame=triggers[0] - ms_frames, end_frame=triggers[0] + ms_frames)
        traces_all_1_clean = rec.get_traces(start_frame=triggers[1] - ms_frames, end_frame=triggers[1] + ms_frames)

        # Pass ``ms`` (not a repeated literal) so the removal window and the
        # inspected window (``ms_frames``) cannot drift apart.
        rec_rmart = remove_artifacts(rec, triggers, ms_before=ms, ms_after=ms)
        traces_all_0 = rec_rmart.get_traces(start_frame=triggers[0] - ms_frames, end_frame=triggers[0] + ms_frames)
        traces_short_0 = rec_rmart.get_traces(start_frame=triggers[0] - 10, end_frame=triggers[0] + 10)
        traces_all_1 = rec_rmart.get_traces(start_frame=triggers[1] - ms_frames, end_frame=triggers[1] + ms_frames)
        traces_short_1 = rec_rmart.get_traces(start_frame=triggers[1] - 10, end_frame=triggers[1] + 10)

        # Default mode: the full window and a 10-frame sub-window are all zero.
        assert not np.any(traces_all_0)
        assert not np.any(traces_all_1)
        assert not np.any(traces_short_0)
        assert not np.any(traces_short_1)

        rec_rmart_lin = remove_artifacts(rec, triggers, ms_before=ms, ms_after=ms, mode="linear")
        traces_all_0 = rec_rmart_lin.get_traces(start_frame=triggers[0] - ms_frames, end_frame=triggers[0] + ms_frames)
        traces_all_1 = rec_rmart_lin.get_traces(start_frame=triggers[1] - ms_frames, end_frame=triggers[1] + ms_frames)
        assert not np.allclose(traces_all_0, traces_all_0_clean)
        assert not np.allclose(traces_all_1, traces_all_1_clean)

        rec_rmart_cub = remove_artifacts(rec, triggers, ms_before=ms, ms_after=ms, mode="cubic")
        traces_all_0 = rec_rmart_cub.get_traces(start_frame=triggers[0] - ms_frames, end_frame=triggers[0] + ms_frames)
        traces_all_1 = rec_rmart_cub.get_traces(start_frame=triggers[1] - ms_frames, end_frame=triggers[1] + ms_frames)
        assert not np.allclose(traces_all_0, traces_all_0_clean)
        assert not np.allclose(traces_all_1, traces_all_1_clean)

        # Every mode must round-trip, not only the default one.
        check_dumping(rec_rmart)
        check_dumping(rec_rmart_lin)
        check_dumping(rec_rmart_cub)
    finally:
        shutil.rmtree('test')
def test_transform():
    """Check that ``transform`` applies ``scalar * traces + offset``.

    The transform is tested with scalar gain/offset and with per-channel
    gain/offset arrays. Both transformed recordings must survive a
    dump/reload round trip. The ``'test'`` dump folder is removed even when
    an assertion fails.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test',
                                                dumpable=True,
                                                duration=2,
                                                num_channels=4,
                                                seed=0)
    try:
        scalar = 3
        offset = 50

        rec_t = transform(rec, scalar=scalar, offset=offset)
        assert np.allclose(rec_t.get_traces(),
                           scalar * rec.get_traces() + offset,
                           atol=0.001)

        # Per-channel gains/offsets sized to the recording. Use a seeded
        # generator so that any failure is reproducible.
        rng = np.random.RandomState(0)
        num_channels = rec.get_num_channels()
        scalars = rng.randn(num_channels)
        offsets = rng.randn(num_channels)
        rec_t_arr = transform(rec, scalar=scalars, offset=offsets)
        for tt, to, s, o in zip(rec_t_arr.get_traces(), rec.get_traces(),
                                scalars, offsets):
            assert np.allclose(tt, s * to + o, atol=0.001)

        check_dumping(rec_t)
        check_dumping(rec_t_arr)
    finally:
        shutil.rmtree('test')
Beispiel #4
0
def test_thresh_l_ratios():
    """Units kept by ``threshold_l_ratios(..., "less")`` have L-ratio >= threshold.

    The L-ratios of the full sorting are computed once. Each surviving
    unit's value is then looked up by its position in the original unit-id
    list.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    l_ratios_thresh = 0
    metric_kwargs = dict(apply_filter=False, seed=0)

    l_ratios = compute_l_ratios(sort, rec, **metric_kwargs)
    sort_l_ratios = threshold_l_ratios(sort, rec, l_ratios_thresh, "less", **metric_kwargs)

    all_ids = sort.get_unit_ids()
    kept = np.array([l_ratios[all_ids.index(unit_id)]
                     for unit_id in sort_l_ratios.get_unit_ids()])
    assert np.all(kept >= l_ratios_thresh)
    check_dumping(sort_l_ratios)
    shutil.rmtree('test')
Beispiel #5
0
def test_resample():
    """Resampled recordings have the expected number of frames.

    The check runs twice. The first run uses the plain toy recording. The
    second runs after a custom time vector has been attached, shifted by
    -10 s with a +0.5 s jump from frame 1000 on.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)

    fs = rec.get_sampling_frequency()
    resample_rate_low = 0.1 * fs
    resample_rate_high = 2 * fs

    def _resample_and_check():
        # Down- and up-sample, then compare lengths with the scaled original.
        low = resample(rec, resample_rate_low)
        high = resample(rec, resample_rate_high)
        assert low.get_num_frames() == int(rec.get_num_frames() * 0.1)
        assert high.get_num_frames() == int(rec.get_num_frames() * 2)
        return low, high

    _resample_and_check()

    # Attach non-uniform times to the source recording.
    times = rec.frame_to_time(np.arange(rec.get_num_frames())) - 10
    times[1000:] += 0.5
    rec.set_times(times)

    rec_rsl, rec_rsh = _resample_and_check()

    check_dumping(rec_rsl)
    check_dumping(rec_rsh)
    shutil.rmtree('test')
Beispiel #6
0
def test_thresh_silhouettes():
    """Units kept by ``threshold_silhouette_scores(..., "less")`` score >= 0.5.

    Silhouette scores are computed once for the full sorting. Each surviving
    unit's score is then looked up by its position in the original unit-id
    list.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    cutoff = .5
    shared_kwargs = {'apply_filter': False, 'seed': 0}

    scores = compute_silhouette_scores(sort, rec, **shared_kwargs)
    thresholded = threshold_silhouette_scores(sort, rec, cutoff, "less", **shared_kwargs)

    id_list = sort.get_unit_ids()
    surviving_scores = []
    for uid in thresholded.get_unit_ids():
        position = id_list.index(uid)
        surviving_scores.append(scores[position])

    assert np.all(np.array(surviving_scores) >= cutoff)
    check_dumping(thresholded)
    shutil.rmtree('test')
Beispiel #7
0
def test_thresh_isolation_distances():
    """Units kept by ``threshold_isolation_distances(..., 'less')`` have distance >= 200.

    Isolation distances are computed once for the full sorting and looked
    up for each surviving unit by its original index.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    min_distance = 200
    common = dict(apply_filter=False, seed=0)

    distances = compute_isolation_distances(sort, rec, **common)
    sort_iso = threshold_isolation_distances(sort, rec, min_distance, 'less', **common)

    lookup_ids = sort.get_unit_ids()
    kept_distances = np.array([distances[lookup_ids.index(u)] for u in sort_iso.get_unit_ids()])

    assert np.all(kept_distances >= min_distance)
    check_dumping(sort_iso)
    shutil.rmtree('test')
Beispiel #8
0
def test_thresh_noise_overlaps():
    """Units kept by ``threshold_noise_overlaps(..., 'less')`` have overlap >= 0.3.

    Noise overlaps are computed once for the full sorting and mapped onto
    the units that survive thresholding via their original positions.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)

    noise_thresh = 0.3
    opts = {'apply_filter': False, 'seed': 0}

    noise_overlaps = compute_noise_overlaps(sort, rec, **opts)
    sort_noise = threshold_noise_overlaps(sort, rec, noise_thresh, 'less', **opts)

    parent_ids = sort.get_unit_ids()

    def _overlap_of(unit_id):
        # Metric of ``unit_id`` as computed on the unfiltered sorting.
        return noise_overlaps[parent_ids.index(unit_id)]

    remaining = np.array(list(map(_overlap_of, sort_noise.get_unit_ids())))
    assert np.all(remaining >= noise_thresh)
    check_dumping(sort_noise)
    shutil.rmtree('test')
Beispiel #9
0
 def test_spykingcircus_extractor(self):
     """Round-trip ``self.SX`` through the SpykingCircus sorting format."""
     sc_cls = se.SpykingCircusSortingExtractor
     output_folder = self.test_dir + '/sc'
     sc_cls.write_sorting(self.SX, output_folder)
     reloaded = sc_cls(output_folder)
     check_sorting_return_types(reloaded)
     check_sortings_equal(self.SX, reloaded)
     check_dumping(reloaded)
Beispiel #10
0
 def test_hs2_extractor(self):
     """Round-trip ``self.SX`` through an HS2 ``.hdf5`` file.

     Besides the generic sorting checks, the sampling frequency must be
     preserved by the round trip.
     """
     hs2_file = self.test_dir + '/firings_true.hdf5'
     se.HS2SortingExtractor.write_sorting(self.SX, hs2_file)
     sorting_hs2 = se.HS2SortingExtractor(hs2_file)
     check_sorting_return_types(sorting_hs2)
     check_sortings_equal(self.SX, sorting_hs2)
     reloaded_fs = sorting_hs2.get_sampling_frequency()
     self.assertEqual(reloaded_fs, self.SX.get_sampling_frequency())
     check_dumping(sorting_hs2)
Beispiel #11
0
def test_rectify():
    """``rectify`` must return the element-wise absolute value of the traces."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)

    rectified = rectify(rec)
    observed = rectified.get_traces()
    expected = np.abs(rec.get_traces())
    assert np.allclose(observed, expected)

    check_dumping(rectified)
    shutil.rmtree('test')
Beispiel #12
0
 def test_cell_explorer_extractor(self):
     """Round-trip ``self.SX`` through a CellExplorer ``.spikes.cellinfo.mat`` file."""
     sorter_id = "cell_explorer_sorter"
     matfile = Path(self.test_dir) / sorter_id / f"{sorter_id}.spikes.cellinfo.mat"
     se.CellExplorerSortingExtractor.write_sorting(sorting=self.SX, save_path=matfile)
     reloaded = se.CellExplorerSortingExtractor(spikes_matfile_path=matfile)
     check_sorting_return_types(reloaded)
     check_sortings_equal(self.SX, reloaded)
     check_dumping(reloaded)
Beispiel #13
0
 def test_hdsort_extractor(self):
     """Round-trip ``self.SX`` through an HDSort results ``.mat`` file.

     Dummy channel locations (a 10 x 2 array of ones) are written with the
     sorting, and no per-channel noise estimate is supplied.
     """
     mat_file = self.test_dir + '/results_test_hdsort_extractor.mat'
     write_kwargs = {'locations': np.ones((10, 2)), 'noise_std_by_channel': None}
     se.HDSortSortingExtractor.write_sorting(self.SX, mat_file, **write_kwargs)
     reloaded = se.HDSortSortingExtractor(mat_file)
     check_sorting_return_types(reloaded)
     check_sortings_equal(self.SX, reloaded)
     check_dumping(reloaded)
Beispiel #14
0
def test_blank_saturation():
    """Samples already below ``threshold`` must stay below it after blanking."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)
    threshold = 2
    rec_bs = blank_saturation(rec, threshold=threshold)

    # Boolean mask of the input samples that were not above the threshold.
    below_mask = rec.get_traces() < threshold
    blanked = rec_bs.get_traces()
    assert np.all(blanked[below_mask] < threshold)

    check_dumping(rec_bs)
    shutil.rmtree('test')
Beispiel #15
0
def test_curation_sorting_extractor():
    """Check merging and splitting units with ``CurationSortingExtractor``.

    Integer spike features are attached to units 1-3 beforehand. The test
    then verifies the following:

    * Merging units 1 and 2 creates unit 4. Its spike train and ``f_int``
      feature are the time-sorted concatenation of the originals, and only
      features shared by both parents survive (``bad_features`` exists on
      unit 2 only).
    * Splitting unit 3 at its first 10 spikes creates units 5 and 6, which
      carry the corresponding slices of spikes and features.

    The ``'test'`` dump folder created by ``toy_example`` is removed at the
    end, even when an assertion fails.
    """
    import shutil  # local import so this test does not rely on module imports

    rec, sort = se.example_datasets.toy_example(dump_folder='test',
                                                dumpable=True,
                                                duration=10,
                                                num_channels=4,
                                                K=3,
                                                seed=0)
    try:
        # Dummy features for testing merging and splitting of features
        sort.set_unit_spike_features(
            1, 'f_int', range(0 + 1,
                              len(sort.get_unit_spike_train(1)) + 1))
        sort.set_unit_spike_features(2, 'f_int',
                                     range(0, len(sort.get_unit_spike_train(2))))
        sort.set_unit_spike_features(
            2, 'bad_features', np.repeat(1, len(sort.get_unit_spike_train(2))))
        sort.set_unit_spike_features(3, 'f_int',
                                     range(0, len(sort.get_unit_spike_train(3))))

        CSX = st.curation.CurationSortingExtractor(parent_sorting=sort)
        merged_unit_id = CSX.merge_units(unit_ids=[1, 2])
        assert np.allclose(merged_unit_id, 4)
        original_spike_train = np.concatenate(
            (sort.get_unit_spike_train(1), sort.get_unit_spike_train(2)))
        indices_sort = np.argsort(original_spike_train)
        original_spike_train = original_spike_train[indices_sort]
        original_features = np.concatenate(
            (sort.get_unit_spike_features(1, 'f_int'),
             sort.get_unit_spike_features(2, 'f_int')))
        original_features = original_features[indices_sort]
        assert np.allclose(CSX.get_unit_spike_train(4), original_spike_train)
        assert np.allclose(CSX.get_unit_spike_features(4, 'f_int'),
                           original_features)
        assert CSX.get_unit_spike_feature_names(4) == ['f_int']
        assert np.allclose(CSX.get_sampling_frequency(),
                           sort.get_sampling_frequency())

        unit_ids_split = CSX.split_unit(unit_id=3,
                                        indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
        assert np.allclose(unit_ids_split[0], 5)
        assert np.allclose(unit_ids_split[1], 6)
        original_spike_train = sort.get_unit_spike_train(3)
        original_features = sort.get_unit_spike_features(3, 'f_int')
        split_spike_train_1 = CSX.get_unit_spike_train(5)
        split_spike_train_2 = CSX.get_unit_spike_train(6)
        split_features_1 = CSX.get_unit_spike_features(5, 'f_int')
        split_features_2 = CSX.get_unit_spike_features(6, 'f_int')
        assert np.allclose(original_spike_train[:10], split_spike_train_1)
        assert np.allclose(original_spike_train[10:], split_spike_train_2)
        assert np.allclose(original_features[:10], split_features_1)
        assert np.allclose(original_features[10:], split_features_2)

        check_dumping(CSX)
    finally:
        # Every sibling toy_example test removes its dump folder; this one
        # previously leaked it.
        shutil.rmtree('test')
Beispiel #16
0
def test_center():
    """After ``center`` each channel has ~zero mean (or median) in its mode."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)

    # Pair each centering mode with the statistic that it should zero out.
    for mode, statistic in (('mean', np.mean), ('median', np.median)):
        rec_c = center(rec, mode=mode)
        per_channel = statistic(rec_c.get_traces(), axis=1)
        assert np.allclose(per_channel, 0, atol=0.001)
        check_dumping(rec_c)

    shutil.rmtree('test')
Beispiel #17
0
 def test_mda_extractor(self):
     """Round-trip ``self.RX`` and ``self.SX`` through the MDA format.

     The recording is written to a folder, and the sorting is written to
     ``firings_true.mda`` inside that same folder.
     """
     mda_dir = self.test_dir + '/mda'
     firings_file = mda_dir + '/firings_true.mda'
     se.MdaRecordingExtractor.write_recording(self.RX, mda_dir)
     se.MdaSortingExtractor.write_sorting(self.SX, firings_file)
     rec_mda = se.MdaRecordingExtractor(mda_dir)
     sort_mda = se.MdaSortingExtractor(firings_file)

     check_recording_return_types(rec_mda)
     check_recordings_equal(self.RX, rec_mda)
     check_sorting_return_types(sort_mda)
     check_sortings_equal(self.SX, sort_mda)
     for extractor in (rec_mda, sort_mda):
         check_dumping(extractor)
Beispiel #18
0
def test_highpass_filter():
    """An FFT high-pass at 5 kHz attenuates the 1-5 kHz band.

    The check is done on the full recording and on the first 30000 frames.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)

    rec_fft = highpass_filter(rec, freq_min=5000, filter_type='fft')

    fs = rec.get_sampling_frequency()
    for trace_kwargs in ({}, {'end_frame': 30000}):
        filtered = rec_fft.get_traces(**trace_kwargs)
        raw = rec.get_traces(**trace_kwargs)
        assert check_signal_power_signal1_below_signal2(filtered, raw, freq_range=[1000, 5000], fs=fs)

    check_dumping(rec_fft)
    shutil.rmtree('test')
Beispiel #19
0
def test_notch_filter():
    """A 3 kHz notch (q=10) attenuates the 2.9-3.1 kHz band.

    The check is done on the full recording and on the first 30000 frames.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)

    rec_n = notch_filter(rec, 3000, q=10)

    def _band_attenuated(**window):
        # Compare the notched and raw power in the band around 3 kHz.
        return check_signal_power_signal1_below_signal2(rec_n.get_traces(**window),
                                                        rec.get_traces(**window),
                                                        freq_range=[2900, 3100],
                                                        fs=rec.get_sampling_frequency())

    assert _band_attenuated()
    assert _band_attenuated(end_frame=30000)

    check_dumping(rec_n)
    shutil.rmtree('test')
Beispiel #20
0
def test_clip():
    """Samples outside ``[-threshold, threshold]`` are clipped to the bounds."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)
    threshold = 5
    rec_clip = clip(rec, a_min=-threshold, a_max=threshold)

    raw = rec.get_traces()
    clipped = rec_clip.get_traces()
    # Each out-of-range region must be pinned to its own bound.
    for out_of_range, bound in ((raw < -threshold, -threshold), (raw > threshold, threshold)):
        assert np.all(clipped[out_of_range] == bound)

    check_dumping(rec_clip)
    shutil.rmtree('test')
Beispiel #21
0
    def test_npz_extractor(self):
        """Round-trip ``self.SX`` through ``.npz`` and write an empty sorting.

        Writing an empty ``NumpySortingExtractor`` is only checked for not
        raising. The resulting file is not read back.
        """
        npz_file = self.test_dir + '/sorting.npz'
        se.NpzSortingExtractor.write_sorting(self.SX, npz_file)
        reloaded = se.NpzSortingExtractor(npz_file)

        # Edge case: a sorting with no units must still be writable.
        empty_sorting = se.NumpySortingExtractor()
        empty_file = self.test_dir + '/sorting_empty.npz'
        se.NpzSortingExtractor.write_sorting(empty_sorting, empty_file)

        check_sorting_return_types(reloaded)
        check_sortings_equal(self.SX, reloaded)
        check_dumping(reloaded)
Beispiel #22
0
def test_whiten():
    """Whitened traces have ~identity covariance regardless of ``chunk_size``.

    Both whitened recordings must survive a dump/reload round trip. The
    ``'test'`` dump folder is removed even when an assertion fails.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, seed=0)
    try:
        rec_w = whiten(rec)
        cov_w = np.cov(rec_w.get_traces())
        # Size the identity from the recording instead of hard-coding 4.
        assert np.allclose(cov_w, np.eye(rec.get_num_channels()), atol=0.3)

        # The chunk size should not affect the whitened output.
        rec_w2 = whiten(rec, chunk_size=30000)
        assert np.array_equal(rec_w.get_traces(), rec_w2.get_traces())

        check_dumping(rec_w)
        check_dumping(rec_w2)
    finally:
        shutil.rmtree('test')
Beispiel #23
0
def test_remove_bad_channels():
    """``remove_bad_channels`` drops explicit ids and auto-detected noisy channels.

    The test has two parts:

    * Explicit ``bad_channel_ids`` are removed from the toy recording.
    * A synthetic 4-channel recording is written to MDA, with channel 1 at
      10x the amplitude of the others. With ``bad_channel_ids=None`` and
      ``bad_threshold=2``, channel 1 must be dropped for the default window
      and for ``seconds=0.1`` and ``seconds=10``.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=2, num_channels=4, seed=0)
    try:
        rec_rm = remove_bad_channels(rec, bad_channel_ids=[0])
        assert 0 not in rec_rm.get_channel_ids()

        rec_rm = remove_bad_channels(rec, bad_channel_ids=[1, 2])
        assert 1 not in rec_rm.get_channel_ids() and 2 not in rec_rm.get_channel_ids()

        check_dumping(rec_rm)
    finally:
        shutil.rmtree('test')

    # Seed the synthetic noise so that the detection outcome is reproducible.
    rng = np.random.RandomState(0)
    timeseries = rng.randn(4, 60000)
    timeseries[1] = 10 * timeseries[1]

    rec_np = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=30000)
    rec_np.set_channel_locations(np.ones((rec_np.get_num_channels(), 2)))
    se.MdaRecordingExtractor.write_recording(rec_np, 'test')
    try:
        rec = se.MdaRecordingExtractor('test')
        for extra_kwargs in ({}, {'seconds': 0.1}, {'seconds': 10}):
            rec_rm = remove_bad_channels(rec, bad_channel_ids=None, bad_threshold=2, **extra_kwargs)
            assert 1 not in rec_rm.get_channel_ids()
            check_dumping(rec_rm)
    finally:
        shutil.rmtree('test')
Beispiel #24
0
    def test_cache_extractor(self):
        """Exercise ``CacheRecordingExtractor`` and ``CacheSortingExtractor``.

        For each extractor type the test checks four things:

        * Caching, then renaming with ``move_to``.
        * Caching directly to ``save_path``.
        * Explicitly saved files remain on disk after the extractor is
          deleted.
        * A temporary cache file is gone after its extractor is deleted.

        Saved files use relative paths (i.e. the current working directory)
        and are removed at the end.

        NOTE(review): the ``del`` + file-existence checks assume that cleanup
        runs as soon as the last reference is dropped (CPython reference
        counting). Confirm this against the extractor implementation.
        NOTE(review): the final ``os.remove`` cleanup is skipped if any
        assertion fails, which leaves the files in the working directory.
        """
        # --- recording: cache, then rename with move_to ('.dat' is appended)
        cache_rec = se.CacheRecordingExtractor(self.RX)
        check_recording_return_types(cache_rec)
        check_recordings_equal(self.RX, cache_rec)
        cache_rec.move_to('cache_rec')

        assert cache_rec.filename == 'cache_rec.dat'
        check_dumping(cache_rec, test_relative=True)

        # --- recording: cache directly to a chosen save_path.
        # Rebinding ``cache_rec`` drops the moved extractor; its file
        # 'cache_rec.dat' is expected to persist (removed in the cleanup).
        cache_rec = se.CacheRecordingExtractor(self.RX, save_path='cache_rec2')
        check_recording_return_types(cache_rec)
        check_recordings_equal(self.RX, cache_rec)

        assert cache_rec.filename == 'cache_rec2.dat'
        check_dumping(cache_rec, test_relative=True)

        # test saving to file
        # (a file created via save_path must survive deletion of the extractor)
        del cache_rec
        assert Path('cache_rec2.dat').is_file()

        # test tmp
        # (without a save_path the cache file is expected to be removed on del)
        cache_rec = se.CacheRecordingExtractor(self.RX)
        tmp_file = cache_rec.filename
        del cache_rec
        assert not Path(tmp_file).is_file()

        # --- sorting: cache, then rename with move_to ('.npz' is appended)
        cache_sort = se.CacheSortingExtractor(self.SX)
        check_sorting_return_types(cache_sort)
        check_sortings_equal(self.SX, cache_sort)
        cache_sort.move_to('cache_sort')

        assert cache_sort.filename == 'cache_sort.npz'
        check_dumping(cache_sort, test_relative=True)

        # test saving to file
        # (a file created via move_to must survive deletion of the extractor)
        del cache_sort
        assert Path('cache_sort.npz').is_file()

        # --- sorting: cache directly to a chosen save_path
        cache_sort = se.CacheSortingExtractor(self.SX, save_path='cache_sort2')
        check_sorting_return_types(cache_sort)
        check_sortings_equal(self.SX, cache_sort)

        assert cache_sort.filename == 'cache_sort2.npz'
        check_dumping(cache_sort, test_relative=True)

        # test saving to file
        del cache_sort
        assert Path('cache_sort2.npz').is_file()

        # test tmp
        # (without a save_path the cache file is expected to be removed on del)
        cache_sort = se.CacheSortingExtractor(self.SX)
        tmp_file = cache_sort.filename
        del cache_sort
        assert not Path(tmp_file).is_file()

        # cleanup
        # (removes the four persisted files; raises if any is missing)
        os.remove('cache_rec.dat')
        os.remove('cache_rec2.dat')
        os.remove('cache_sort.npz')
        os.remove('cache_sort2.npz')
Beispiel #25
0
    def test_exdir_extractors(self):
        """Round-trip ``self.RX`` and ``self.SX`` through Exdir files.

        The sorting is written together with ``self.RX``.
        """
        # Recording round trip.
        raw_exdir = self.test_dir + '/raw.exdir'
        se.ExdirRecordingExtractor.write_recording(self.RX, raw_exdir)
        rec_exdir = se.ExdirRecordingExtractor(raw_exdir)
        check_recording_return_types(rec_exdir)
        check_recordings_equal(self.RX, rec_exdir)
        check_dumping(rec_exdir)

        # Sorting round trip.
        firings_exdir = self.test_dir + '/firings.exdir'
        se.ExdirSortingExtractor.write_sorting(self.SX, firings_exdir, self.RX)
        sort_exdir = se.ExdirSortingExtractor(firings_exdir)
        check_sorting_return_types(sort_exdir)
        check_sortings_equal(self.SX, sort_exdir)
        check_dumping(sort_exdir)
def test_highpass_filter():
    """Check FFT high-pass and Butterworth/FFT band-pass filtering.

    Each filtered recording must have less power than its input in the
    stop band(s). An unsigned ``uint16`` input is also band-pass filtered,
    and the filtered dtype must not be unsigned. The ``'test'`` dump folder
    is removed even when an assertion fails.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test',
                                                dumpable=True,
                                                duration=2,
                                                num_channels=4,
                                                seed=0)
    try:
        rec_fft = highpass_filter(rec, freq_min=5000, filter_type='fft')

        assert check_signal_power_signal1_below_signal2(
            rec_fft.get_traces(),
            rec.get_traces(),
            freq_range=[1000, 5000],
            fs=rec.get_sampling_frequency())

        rec_sci = bandpass_filter(rec,
                                  freq_min=3000,
                                  freq_max=6000,
                                  filter_type='butter',
                                  order=3)

        assert check_signal_power_signal1_below_signal2(
            rec_sci.get_traces(),
            rec.get_traces(),
            freq_range=[1000, 3000],
            fs=rec.get_sampling_frequency())

        # Shift to non-negative (and clip to the uint16 range) before
        # casting. Casting negative or out-of-range floats to an unsigned
        # dtype is undefined, and recent NumPy versions warn with
        # platform-dependent results. A constant offset only adds DC power,
        # so the band-power checks below are unaffected.
        raw = rec.get_traces()
        shifted = np.clip(raw - raw.min(), 0, np.iinfo(np.uint16).max)
        traces = shifted.astype('uint16')
        rec_u = se.NumpyRecordingExtractor(
            traces, sampling_frequency=rec.get_sampling_frequency())
        rec_fu = bandpass_filter(rec_u,
                                 freq_min=5000,
                                 freq_max=10000,
                                 filter_type='fft')

        assert check_signal_power_signal1_below_signal2(
            rec_fu.get_traces(),
            rec_u.get_traces(),
            freq_range=[1000, 5000],
            fs=rec.get_sampling_frequency())
        assert check_signal_power_signal1_below_signal2(
            rec_fu.get_traces(),
            rec_u.get_traces(),
            freq_range=[10000, 15000],
            fs=rec.get_sampling_frequency())
        # Filtering an unsigned input must yield a signed/float output.
        assert not str(rec_fu.get_dtype()).startswith('u')

        check_dumping(rec_fft)
        check_dumping(rec_sci)
    finally:
        shutil.rmtree('test')
Beispiel #27
0
def test_thresh_num_spikes():
    """After ``threshold_num_spikes(..., 'less')`` every unit has >= 25 spikes."""
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    min_spikes = 25

    kept_sorting = threshold_num_spikes(sort, min_spikes, 'less')
    spike_counts = compute_num_spikes(kept_sorting, sort.get_sampling_frequency())

    assert np.all(spike_counts >= min_spikes)
    check_dumping(kept_sorting)
    shutil.rmtree('test')
Beispiel #28
0
def test_thresh_frs():
    """After ``threshold_firing_rates(..., 'less')`` every unit fires at >= 2 Hz.

    Both the thresholding and the recomputation use the recording's total
    number of frames.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    min_rate = 2
    n_frames = rec.get_num_frames()

    kept_sorting = threshold_firing_rates(sort, min_rate, 'less', n_frames)
    rates = compute_firing_rates(kept_sorting, n_frames)

    assert np.all(rates >= min_rate)
    check_dumping(kept_sorting)
    shutil.rmtree('test')
Beispiel #29
0
    def test_mearec_extractors(self):
        """Round-trip ``self.RX`` and ``self.SX`` through MEArec ``.h5`` files.

        The test also checks that channel sub-selection on the reloaded
        recording returns exactly the requested channels. The sorting is
        written with the recording's sampling frequency.
        """
        path1 = self.test_dir + '/raw.h5'
        se.MEArecRecordingExtractor.write_recording(self.RX, path1)
        RX_mearec = se.MEArecRecordingExtractor(path1)
        # Sub-selecting two channels must return two rows of traces.
        tr = RX_mearec.get_traces(channel_ids=[0, 1], end_frame=1000)
        assert tr.shape[0] == 2
        check_recording_return_types(RX_mearec)
        check_recordings_equal(self.RX, RX_mearec)
        check_dumping(RX_mearec)

        path2 = self.test_dir + '/firings_true.h5'
        se.MEArecSortingExtractor.write_sorting(
            self.SX, path2, self.RX.get_sampling_frequency())
        SX_mearec = se.MEArecSortingExtractor(path2)
        check_sorting_return_types(SX_mearec)
        check_sortings_equal(self.SX, SX_mearec)
        check_dumping(SX_mearec)
Beispiel #30
0
def test_thresh_presence_ratios():
    """After ``threshold_presence_ratios(..., 'less')`` every unit has ratio >= 0.18.

    The thresholding uses the recording's frame count. The recomputation
    additionally uses the sorting's sampling frequency.
    """
    rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10,
                                                num_channels=4, K=10, seed=0)
    min_ratio = 0.18
    n_frames = rec.get_num_frames()

    kept_sorting = threshold_presence_ratios(sort, min_ratio, 'less', n_frames)
    ratios = compute_presence_ratios(kept_sorting, n_frames, sort.get_sampling_frequency())

    assert np.all(ratios >= min_ratio)
    check_dumping(kept_sorting)
    shutil.rmtree('test')