Beispiel #1
0
def test_probe_merge_1(tempdir):
    """Merging two identical single-probe datasets doubles the per-probe counts.

    Spike, channel and template counts of the merged dataset must be twice those of
    one input dataset, while the sample rate is unchanged.
    """
    out_dir = tempdir / 'merged'

    # One identical dataset per probe subdirectory.
    probe_names = ('probe_left', 'probe_right')
    subdirs = [tempdir / name for name in probe_names]
    for subdir in subdirs:
        subdir.mkdir(exist_ok=True, parents=True)
        _make_dataset(subdir, param='dense', has_spike_attributes=False)

    # Merge them, keeping a reference model of the first probe for comparison.
    merger = Merger(subdirs, out_dir)
    single = load_model(subdirs[0] / 'params.py')

    merged = merger.merge()
    for attr in ('n_spikes', 'n_channels', 'n_templates'):
        assert getattr(merged, attr) == 2 * getattr(single, attr)
    assert merged.sample_rate == single.sample_rate
Beispiel #2
0
def test_probe_merge_2(tempdir):
    """Merge two probes that differ by a known shift, and check the spike mapping.

    Both probes start from the same dataset. The right probe's spike samples are
    shifted by one sample and its amplitudes are offset by 20. Every merged spike can
    then be traced back to its source probe through its (unique) amplitude value.
    The merged model and its ALF conversion are both checked against the single-probe
    model.
    """
    out_dir = tempdir / 'merged'

    # Create two identical datasets.
    probe_names = ('probe_left', 'probe_right')
    for name in probe_names:
        (tempdir / name).mkdir(exist_ok=True, parents=True)
        _make_dataset(tempdir / name, param='dense', has_spike_attributes=False)
    subdirs = [tempdir / name for name in probe_names]

    # Add a one-sample shift to the spike times of the second probe.
    single = load_model(tempdir / probe_names[0] / 'params.py')
    st_path = tempdir / 'probe_right/spike_times.npy'
    np.save(st_path, single.spike_samples + 1)
    # Make amplitudes unique and increasing so they can serve as keys and sorting indices.
    # The second probe is offset by 20, so its amplitudes ([25, 35]) never collide with
    # the first probe's ([5, 15]).
    single.amplitudes = np.linspace(5, 15, single.n_spikes)
    for probe_index, subdir in enumerate(subdirs):
        np.save(subdir / 'amplitudes.npy', single.amplitudes + 20 * probe_index)
        np.save(subdir / 'spike_clusters.npy', single.spike_clusters)

    # Merge them.
    merger = Merger(subdirs, out_dir)
    merged = merger.merge()

    # Test the merged dataset.
    for name in ('n_spikes', 'n_channels', 'n_templates'):
        assert getattr(merged, name) == getattr(single, name) * 2
    assert merged.sample_rate == single.sample_rate

    # Reload the first probe so `single` reflects the amplitudes/clusters saved above
    # (the in-memory `single.amplitudes` assignment was never persisted for this model).
    single = load_model(tempdir / probe_names[0] / 'params.py')
    # The one-sample shift of the second probe, in seconds. Derived from the sample rate
    # instead of hard-coding 4e-5, which only holds at 25 kHz.
    time_shift = 1. / single.sample_rate

    def _check_merged_single(merged, merged_original_amps=None):
        """Check a merged model against `single`, matching spikes by amplitude.

        `merged_original_amps` are the amplitudes used as matching keys. They default to
        ``merged.amplitudes``. Pass the pre-conversion amplitudes when the model's own
        amplitudes may differ (e.g. the ALF-converted model).
        """
        if merged_original_amps is None:
            merged_original_amps = merged.amplitudes
        _, im1, i1 = np.intersect1d(merged_original_amps, single.amplitudes, return_indices=True)
        _, im2, i2 = np.intersect1d(merged_original_amps, single.amplitudes + 20,
                                    return_indices=True)
        # intersection spans the full vector
        assert i1.size + i2.size == merged.amplitudes.size
        # test spikes
        assert np.allclose(merged.spike_times[im1], single.spike_times[i1])
        assert np.allclose(merged.spike_times[im2], single.spike_times[i2] + time_shift)
        # test clusters
        assert np.allclose(merged.spike_clusters[im2], single.spike_clusters[i2] + 64)
        assert np.allclose(merged.spike_clusters[im1], single.spike_clusters[i1])
        # test templates
        assert np.all(merged.spike_templates[im1] - single.spike_templates[i1] == 0)
        assert np.all(merged.spike_templates[im2] - single.spike_templates[i2] == 64)
        # test probes
        assert np.all(merged.channel_probes == np.r_[single.channel_probes,
                                                     single.channel_probes + 1])
        assert np.all(merged.templates_channels[merged.templates_probes == 0] < single.n_channels)
        assert np.all(merged.templates_channels[merged.templates_probes == 1] >= single.n_channels)
        spike_probes = merged.templates_probes[merged.spike_templates]

        assert np.all(merged_original_amps[spike_probes == 0] <= 15)
        assert np.all(merged_original_amps[spike_probes == 1] >= 20)

        # The first probe's templates must occupy the leading block of the merged
        # template array. This was previously evaluated without `assert`, so it could
        # never fail. The slice bounds come from the single-probe array's own shape.
        n_templates, _, n_channels = single.sparse_templates.data.shape
        assert np.all(merged.sparse_templates.data[:n_templates, :, :n_channels] ==
                      single.sparse_templates.data)

    # Convert into ALF and load.
    alf_dir = tempdir / 'alf'
    alf = EphysAlfCreator(merged).convert(alf_dir)
    _check_merged_single(merged)
    _check_merged_single(alf, merged_original_amps=merged.amplitudes)

    # specific test channel ids only for ALF merge dataset: the raw indices are still individual
    # file indices, the merged channel mapping is in `channels._phy_ids.npy`
    chid = np.load(alf_dir / 'channels.rawInd.npy')
    assert np.all(chid == np.r_[single.channel_mapping, single.channel_mapping])

    # Every ALF .npy file of a given object must agree on its first dimension. Only
    # .npy files are loaded, consistently for all objects.
    npy_files = list(alf_dir.glob('*.npy'))
    for prefix in ('clusters.', 'spikes.', 'channels.'):
        first_dims = {np.load(f).shape[0] for f in npy_files if f.name.startswith(prefix)}
        assert len(first_dims) == 1, prefix
Beispiel #3
0
def test_template_describe(qtbot, tempdir):
    """`template_describe` prints a summary of the dataset to stdout."""
    dataset_path = _make_dataset(tempdir, param='dense', has_spike_attributes=False)
    model = load_model(dataset_path)
    params_path = model.dir_path / 'params.py'
    with captured_output() as (out, _err):
        template_describe(params_path)
    printed = out.getvalue()
    # The expected summary contains this value for the dense test dataset.
    assert '314' in printed
Beispiel #4
0
 def _create_dataset(cls, tempdir):
     """Build the 'dense' test dataset under *tempdir*, without spike attributes."""
     options = dict(param='dense', has_spike_attributes=False)
     return _make_dataset(tempdir, **options)
Beispiel #5
0
 def _create_dataset(cls, tempdir):
     """Build the 'misc' test dataset under *tempdir* and return its result."""
     dataset = _make_dataset(tempdir, param='misc')
     return dataset