Example #1
0
def test_templates():
    """Check ``get_unit_templates`` against a signal with known templates.

    Three scenarios are exercised on the same recording/sorting pair:

    * no grouping, nothing saved: templates must match the ground truth and
      neither the ``'template'`` property nor the ``'waveforms'`` feature may
      appear on the sorting;
    * a different cut-out (``ms_before=ms_after=2``) with saving enabled: the
      central ``n_wf_samples`` samples of each template must match the ground
      truth and both names must now be present on the sorting;
    * grouping by the ``'group'`` property with channel groups
      ``[0, 0, 1, 1]``: each template must match either the first two or the
      last two rows of its ground-truth template.

    The scratch folder is removed before the test starts and again when it
    ends, whether or not an assertion failed.
    """
    n_wf_samples = 100
    folder = 'test'
    # Start from a clean slate in case an earlier run left the folder behind.
    if os.path.isdir(folder):
        shutil.rmtree(folder)
    try:
        rec, sort, _, templates, _, _ = create_signal_with_known_waveforms(
            n_waveforms=2, n_channels=4, n_wf_samples=n_wf_samples)
        rec, sort = create_dumpable_extractors_from_existing(folder, rec, sort)
        # Half the waveform length converted from samples to milliseconds
        # (sampling frequency is taken to be in Hz).
        ms_cut = n_wf_samples // 2 / rec.get_sampling_frequency() * 1000

        # no group
        temp = get_unit_templates(rec,
                                  sort,
                                  ms_before=ms_cut,
                                  ms_after=ms_cut,
                                  save_property_or_features=False,
                                  save_wf_as_features=False,
                                  recompute_info=True)

        for (t, t_gt) in zip(temp, templates):
            assert np.allclose(t, t_gt, atol=1)
        assert 'template' not in sort.get_shared_unit_property_names()
        assert 'waveforms' not in sort.get_shared_unit_spike_feature_names()

        # change cut ms
        temp = get_unit_templates(rec,
                                  sort,
                                  ms_before=2,
                                  ms_after=2,
                                  save_property_or_features=True,
                                  recompute_waveforms=True,
                                  recompute_info=True)

        for (t, t_gt) in zip(temp, templates):
            _, samples = t.shape
            # Compare only the central n_wf_samples samples of the template
            # with the ground truth.
            assert np.allclose(t[:, samples // 2 - n_wf_samples // 2:samples // 2 +
                                 n_wf_samples // 2],
                               t_gt,
                               atol=1)
        assert 'template' in sort.get_shared_unit_property_names()
        assert 'waveforms' in sort.get_shared_unit_spike_feature_names()

        # by group
        rec.set_channel_groups([0, 0, 1, 1])
        temp = get_unit_templates(rec,
                                  sort,
                                  ms_before=ms_cut,
                                  ms_after=ms_cut,
                                  grouping_property='group',
                                  recompute_info=True)

        for (t, t_gt) in zip(temp, templates):
            # A grouped template must equal one half of the ground-truth rows.
            assert np.allclose(t, t_gt[:2], atol=1) or np.allclose(
                t, t_gt[2:], atol=1)
    finally:
        # Always clean up, even when an assertion above failed; the guard
        # keeps a failure before the folder exists from being masked by a
        # FileNotFoundError raised here.
        if os.path.isdir(folder):
            shutil.rmtree(folder)
Example #2
0
def test_templates():
    """Validate ``get_unit_templates`` on synthetic data with known templates.

    Runs three checks in sequence: the matching cut-out with nothing stored
    on the sorting, a ``ms_before=ms_after=2`` cut-out that is stored, and
    extraction grouped by the ``'group'`` channel property.
    """
    n_wf_samples = 100
    half_window = n_wf_samples // 2
    recording, sorting, _, gt_templates, _, _ = create_signal_with_known_waveforms(
        n_waveforms=2, n_channels=4, n_wf_samples=n_wf_samples)
    # Half of the waveform length, expressed in milliseconds.
    ms_cut = half_window / recording.get_sampling_frequency() * 1000
    matching_cut = dict(ms_before=ms_cut, ms_after=ms_cut)

    # Case 1: ungrouped, nothing stored on the sorting.
    estimated = get_unit_templates(recording, sorting,
                                   save_as_property=False,
                                   save_wf_as_features=False,
                                   **matching_cut)
    for found, expected in zip(estimated, gt_templates):
        assert np.allclose(found, expected, atol=1)
    property_names = sorting.get_shared_unit_property_names()
    assert 'template' not in property_names
    feature_names = sorting.get_shared_unit_spike_feature_names()
    assert 'waveforms' not in feature_names

    # Case 2: different cut-out, results stored on the sorting.
    estimated = get_unit_templates(recording, sorting,
                                   ms_before=2, ms_after=2,
                                   save_as_property=True,
                                   recompute_waveforms=True)
    for found, expected in zip(estimated, gt_templates):
        _, n_samples = found.shape
        centre = n_samples // 2
        # Only the central n_wf_samples samples are comparable.
        central = found[:, centre - half_window:centre + half_window]
        assert np.allclose(central, expected, atol=1)
    property_names = sorting.get_shared_unit_property_names()
    assert 'template' in property_names
    feature_names = sorting.get_shared_unit_spike_feature_names()
    assert 'waveforms' in feature_names

    # Case 3: templates computed per channel group.
    channel_ids = recording.get_channel_ids()
    recording.set_channel_groups(channel_ids, [0, 0, 1, 1])
    estimated = get_unit_templates(recording, sorting,
                                   grouping_property='group',
                                   recompute_waveforms=True,
                                   **matching_cut)
    row_halves = (slice(None, 2), slice(2, None))
    for found, expected in zip(estimated, gt_templates):
        # Each grouped template has to equal one half of the expected rows.
        assert any(np.allclose(found, expected[rows], atol=1)
                   for rows in row_halves)