# Example #1
# 0
def test_NumpySorting():
    """Exercise the NumpySorting constructors.

    Covers an empty sorting, construction from spike times + labels for a
    single segment and for several segments, and conversion from another
    extractor (an NpzSortingExtractor backed by a file written by
    ``create_sorting_npz``).
    """
    sampling_frequency = 30000

    # empty: no unit ids at all
    unit_ids = []
    sorting = NumpySorting(sampling_frequency, unit_ids)
    assert sorting.get_num_units() == 0

    # 2 columns: spike times plus one label per spike, cycling over units 0, 1, 2
    times = np.arange(0, 1000, 10)
    labels = np.zeros(times.size, dtype='int64')
    labels[0::3] = 0
    labels[1::3] = 1
    labels[2::3] = 2
    sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency)
    print(sorting)
    assert sorting.get_num_segments() == 1

    # lists of arrays: one segment per list element
    sorting = NumpySorting.from_times_labels([times] * 3, [labels] * 3,
                                             sampling_frequency)
    assert sorting.get_num_segments() == 3

    # from other extractor
    num_seg = 2
    file_path = 'test_NpzSortingExtractor.npz'
    create_sorting_npz(num_seg, file_path)
    other_sorting = NpzSortingExtractor(file_path)

    sorting = NumpySorting.from_extractor(other_sorting)
    # The conversion must preserve segments, unit ids and every spike train.
    assert sorting.get_num_segments() == other_sorting.get_num_segments()
    assert np.array_equal(sorting.get_unit_ids(), other_sorting.get_unit_ids())
    for segment_index in range(num_seg):
        for unit_id in other_sorting.get_unit_ids():
            assert np.array_equal(
                sorting.get_unit_spike_train(unit_id, segment_index=segment_index),
                other_sorting.get_unit_spike_train(unit_id, segment_index=segment_index),
            )
# Example #2
# 0
def test_unitsselectionsorting():
    """Check UnitsSelectionSorting on a two-segment NPZ sorting.

    Selecting units 0 and 2 must expose exactly those ids (optionally renamed
    to 'a' and 'b'), and each selected unit must keep the spike train it had
    in the parent sorting (checked on segment 0).
    """
    file_path = 'test_BaseSorting.npz'
    num_seg = 2
    create_sorting_npz(num_seg, file_path)

    parent = NpzSortingExtractor(file_path)
    print(parent)
    print(parent.unit_ids)

    # plain selection keeps the parent's unit ids
    subset = UnitsSelectionSorting(parent, unit_ids=[0, 2])
    print(subset.unit_ids)
    assert np.array_equal(subset.unit_ids, [0, 2])

    # selection with renaming exposes the new labels instead
    relabeled = UnitsSelectionSorting(parent,
                                      unit_ids=[0, 2],
                                      renamed_unit_ids=['a', 'b'])
    print(relabeled.unit_ids)
    assert np.array_equal(relabeled.unit_ids, ['a', 'b'])

    # each kept unit must carry the parent's spike train under both views
    for old_id, new_id in zip((0, 2), ('a', 'b')):
        reference = parent.get_unit_spike_train(old_id, segment_index=0)
        assert np.array_equal(reference,
                              subset.get_unit_spike_train(old_id, segment_index=0))
        assert np.array_equal(reference,
                              relabeled.get_unit_spike_train(new_id, segment_index=0))
# Example #3
# 0
def test_unitsaggregationsorting():
    """Check aggregate_units on three copies of the same NPZ sorting.

    Verifies the aggregated unit count, that unit ``u`` of the k-th input is
    found at aggregated id ``k * num_units + u``, renaming of the aggregated
    units, and how unit properties are merged: kept when every input has them,
    skipped when some input lacks them or when their shapes cannot be combined.
    """
    num_seg = 2
    file_path = 'test_BaseSorting.npz'

    create_sorting_npz(num_seg, file_path)

    sorting1 = NpzSortingExtractor(file_path)
    sorting2 = sorting1.clone()
    sorting3 = sorting1.clone()
    print(sorting1)
    num_units = len(sorting1.get_unit_ids())

    # test num units
    sorting_agg = aggregate_units([sorting1, sorting2, sorting3])
    print(sorting_agg)
    assert len(sorting_agg.get_unit_ids()) == 3 * num_units

    # test spike trains: the offset arithmetic below assumes integer unit ids
    unit_ids = sorting1.get_unit_ids()

    for seg in range(num_seg):
        spiketrain1_1 = sorting1.get_unit_spike_train(unit_ids[1], segment_index=seg)
        spiketrain2_0 = sorting2.get_unit_spike_train(unit_ids[0], segment_index=seg)
        spiketrain3_2 = sorting3.get_unit_spike_train(unit_ids[2], segment_index=seg)
        assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(unit_ids[1], segment_index=seg))
        assert np.allclose(spiketrain2_0, sorting_agg.get_unit_spike_train(num_units + unit_ids[0],
                                                                           segment_index=seg))
        assert np.allclose(spiketrain3_2, sorting_agg.get_unit_spike_train(2 * num_units + unit_ids[2],
                                                                           segment_index=seg))

    # test rename units
    renamed_unit_ids = [f"#Unit {i}" for i in range(3 * num_units)]
    sorting_agg_renamed = aggregate_units([sorting1, sorting2, sorting3], renamed_unit_ids=renamed_unit_ids)
    agg_renamed_ids = sorting_agg_renamed.get_unit_ids()
    # the length check keeps the membership check from passing on an empty result
    assert len(agg_renamed_ids) == len(renamed_unit_ids)
    assert all(unit in renamed_unit_ids for unit in agg_renamed_ids)

    # test properties

    # complete property: set on every input with a consistent shape -> kept
    sorting1.set_property("brain_area", ["CA1"] * num_units)
    sorting2.set_property("brain_area", ["CA2"] * num_units)
    sorting3.set_property("brain_area", ["CA3"] * num_units)

    # inconsistent property: set on every input, but the per-unit shapes
    # differ between sortings so the values cannot be combined -> skipped.
    # (Each sorting must receive its own template; setting all three on
    # sorting1 would only test the "incomplete" path.)
    sorting1.set_property("template", np.zeros((num_units, 4, 30)))
    sorting2.set_property("template", np.zeros((num_units, 20, 50)))
    sorting3.set_property("template", np.zeros((num_units, 2, 10)))

    # incomplete property: missing from sorting3 -> skipped
    sorting1.set_property("quality", ["good"] * num_units)
    sorting2.set_property("quality", ["bad"] * num_units)

    sorting_agg_prop = aggregate_units([sorting1, sorting2, sorting3])
    property_keys = sorting_agg_prop.get_property_keys()
    assert "brain_area" in property_keys
    assert "template" not in property_keys
    assert "quality" not in property_keys
    print(sorting_agg_prop.get_property("brain_area"))
def test_NpzSortingExtractor():
    """Read an NPZ sorting and check that ``write_sorting`` round-trips it.

    The copy written by ``NpzSortingExtractor.write_sorting`` must have the
    same number of segments and the same unit ids as the original. Every
    spike train of units 0, 1 and 2 in each segment must also be identical.
    """
    num_seg = 2
    file_path = 'test_NpzSortingExtractor.npz'

    create_sorting_npz(num_seg, file_path)

    sorting = NpzSortingExtractor(file_path)

    file_path_copy = 'test_NpzSortingExtractor_copy.npz'
    NpzSortingExtractor.write_sorting(sorting, file_path_copy)
    sorting_copy = NpzSortingExtractor(file_path_copy)

    assert sorting_copy.get_num_segments() == num_seg
    assert np.array_equal(sorting_copy.get_unit_ids(), sorting.get_unit_ids())

    for segment_index in range(num_seg):
        for unit_id in (0, 1, 2):
            st = sorting.get_unit_spike_train(unit_id,
                                              segment_index=segment_index)
            st_copy = sorting_copy.get_unit_spike_train(unit_id,
                                                        segment_index=segment_index)
            assert np.array_equal(st, st_copy)
# Example #5
# 0
def test_BaseSorting():
    """Exercise basic BaseSorting functionality on an NPZ-backed sorting.

    Steps visible here: segment/unit counts, annotations, unit properties,
    dict / JSON / pickle serialization round-trips, saving to and reloading
    from a folder, and exporting spikes via ``get_all_spike_trains`` and
    ``to_spike_vector``. All output paths are relative, so files are written
    into the current working directory.
    """
    num_seg = 2
    file_path = 'test_BaseSorting.npz'

    # create the NPZ test file via the create_sorting_npz helper, then open it
    create_sorting_npz(num_seg, file_path)

    sorting = NpzSortingExtractor(file_path)
    print(sorting)

    # the helper's file is expected to hold 2 segments and 3 units
    assert sorting.get_num_segments() == 2
    assert sorting.get_num_units() == 3

    # annotations / properties
    sorting.annotate(yep='yop')
    assert sorting.get_annotation('yep') == 'yop'

    # one property value per unit (3 units, matching the assert above)
    sorting.set_property('amplitude', [-20, -40., -55.5])
    values = sorting.get_property('amplitude')
    assert np.all(values == [-20, -40., -55.5])

    # dump/load dict
    # NOTE(review): the reloaded sorting2/sorting3 objects in the sections below
    # are not compared against `sorting` in the visible lines; only successful
    # loading is exercised
    d = sorting.to_dict()
    sorting2 = BaseExtractor.from_dict(d)
    sorting3 = load_extractor(d)

    # dump/load json
    sorting.dump_to_json('test_BaseSorting.json')
    sorting2 = BaseExtractor.load('test_BaseSorting.json')
    sorting3 = load_extractor('test_BaseSorting.json')

    # dump/load pickle
    sorting.dump_to_pickle('test_BaseSorting.pkl')
    sorting2 = BaseExtractor.load('test_BaseSorting.pkl')
    sorting3 = load_extractor('test_BaseSorting.pkl')

    # cache: save the sorting into a folder and reload it two ways
    folder = Path('./my_cache_folder') / 'simple_sorting'
    sorting.save(folder=folder)
    sorting2 = BaseExtractor.load_from_folder(folder)
    # but also possible: the generic load() entry point also accepts a folder
    sorting3 = BaseExtractor.load(folder)

    # spike export in two representations; results are only computed, not checked
    spikes = sorting.get_all_spike_trains()
    # print(spikes)

    spikes = sorting.to_spike_vector()