コード例 #1
0
def load():
    """Open the mock data set and return its loader.

    Builds a ``KlustersLoader`` on the ``test.xml`` file located in
    ``TEST_FOLDER``.
    """
    mock_xml_path = os.path.join(TEST_FOLDER, 'test.xml')
    return KlustersLoader(filename=mock_xml_path)
コード例 #2
0
def load():
    """Open the mock data set and wrap its loader in a ``Processor``.

    Returns a ``(loader, processor)`` tuple, where the loader is a
    ``KlustersLoader`` built on ``test.xml`` in ``TEST_FOLDER``.
    """
    mock_xml_path = os.path.join(TEST_FOLDER, 'test.xml')
    loader = KlustersLoader(filename=mock_xml_path)
    processor = Processor(loader)
    return loader, processor
コード例 #3
0
def test_klusters_save():
    """Reassign clusters and cluster info, save, then re-read the files.

    WARNING: this test should occur at the end of the module since it
    changes the mock data sets.
    """
    # Open the mock data.
    loader = KlustersLoader(filename=os.path.join(TEST_FOLDER, 'test.xml'))

    spike_clusters = loader.get_clusters()
    # The return values of these getters are not inspected by this test.
    loader.get_cluster_colors()
    loader.get_cluster_groups()
    loader.get_group_colors()
    loader.get_group_names()

    # Alternate the spikes between clusters 2 and 3.
    spike_ids = get_indices(spike_clusters)
    for start, target_cluster in ((0, 2), (1, 3)):
        loader.set_cluster(spike_ids[start::2], target_cluster)

    # Alternate the cluster colors (10 / 20), then the groups (1 / 0).
    unique_clusters = loader.get_clusters_unique()
    for start, color in ((0, 10), (1, 20)):
        loader.set_cluster_colors(unique_clusters[start::2], color)
    for start, group in ((0, 1), (1, 0)):
        loader.set_cluster_groups(unique_clusters[start::2], group)

    # Drop the empty clusters and write everything to disk.
    loader.remove_empty_clusters()
    loader.save()

    # Read back what has just been written.
    saved_clusters = read_clusters(loader.filename_aclu)
    saved_info = read_cluster_info(loader.filename_acluinfo)

    assert np.all(saved_clusters[::2] == 2)
    assert np.all(saved_clusters[1::2] == 3)

    assert np.array_equal(saved_info.index, unique_clusters)
    # Column 0 is compared against the colors set above, column 1 against
    # the groups.
    assert np.all(saved_info.values[::2, 0] == 10)
    assert np.all(saved_info.values[1::2, 0] == 20)
    assert np.all(saved_info.values[::2, 1] == 1)
    assert np.all(saved_info.values[1::2, 1] == 0)

    loader.close()
    
    
    
コード例 #4
0
def test_klusters_loader_control():
    """Exercise the loader's editing methods on the mock data set.

    Covers moving spikes between clusters, changing cluster groups and
    colors, renaming and recoloring groups, and adding then removing a
    cluster and a group.
    """
    # Open the mock data.
    loader = KlustersLoader(filename=os.path.join(TEST_FOLDER, 'test.xml'))

    # Move every spike of cluster 3 into cluster 4 and verify the move.
    old_spikes = get_indices(loader.get_clusters(clusters=3))
    loader.set_cluster(old_spikes, 4)
    moved_spikes = get_indices(loader.get_clusters(clusters=4))
    assert np.all(np.in1d(old_spikes, moved_spikes))

    # Put clusters 2, 3 and 4 in group 0.
    edited_clusters = [2, 3, 4]
    target_group = 0
    loader.set_cluster_groups(edited_clusters, target_group)
    assert np.all(loader.get_cluster_groups(edited_clusters) == target_group)

    # Give the same clusters color 12.
    edited_clusters = [2, 3, 4]
    target_color = 12
    loader.set_cluster_colors(edited_clusters, target_color)
    assert np.all(loader.get_cluster_colors(edited_clusters) == target_color)

    # Rename group 0, which is expected to start out as 'Noise'.
    renamed_group = 0
    old_name = loader.get_group_names(renamed_group)
    assert old_name == 'Noise'
    loader.set_group_names(renamed_group, 'Noise new')
    assert loader.get_group_names(renamed_group) == 'Noise new'

    # Recolor groups 1 and 2.
    recolored_groups = [1, 2]
    loader.get_group_colors(recolored_groups)  # result is not inspected
    loader.set_group_colors(recolored_groups, 10)
    assert np.all(loader.get_group_colors(recolored_groups) == 10)

    # Add group 100 and cluster 10000 (placed in that group), then move up
    # to ten of the spikes returned for cluster 3 into the new cluster.
    new_group, new_cluster = 100, 10000
    few_spikes = get_indices(loader.get_clusters(clusters=3))[:10]
    loader.add_group(new_group, 'New group', 10)
    loader.add_cluster(new_cluster, new_group, 10)
    loader.set_cluster(few_spikes, new_cluster)
    assert np.all(loader.get_clusters(spikes=few_spikes) == new_cluster)
    assert loader.get_cluster_groups(new_cluster) == new_group
    # Move those spikes to cluster 2 before removing the new cluster.
    loader.set_cluster(few_spikes, 2)

    # Remove the new cluster and group and check that they are gone.
    loader.remove_cluster(new_cluster)
    loader.remove_group(new_group)
    assert np.all(~np.in1d(new_cluster, loader.get_clusters()))
    assert np.all(~np.in1d(new_group, loader.get_cluster_groups()))

    loader.close()
コード例 #5
0
def test_klusters_loader_2():
    """Check that selecting a single cluster restricts the loaded data.

    The cluster of the spike with index ``nspikes // 2`` is selected, then
    the shapes of the selected features, masks, waveforms, clusters and
    spike times are checked, as well as the waveform of that spike.
    """
    # Open the mock data.
    folder = TEST_FOLDER
    xmlfile = os.path.join(folder, 'test.xml')
    l = KlustersLoader(filename=xmlfile)

    # Get full data sets. Only the waveforms and the clusters are used
    # below; the other getters are still called, their results are ignored.
    l.get_features()
    l.get_masks()
    waveforms = l.get_waveforms()
    clusters = l.get_clusters()
    l.get_spiketimes()

    l.get_probe()
    l.get_cluster_colors()
    l.get_cluster_groups()
    l.get_group_colors()
    l.get_group_names()
    l.get_cluster_sizes()

    # Check selection.
    # ----------------
    # Floor division keeps the index an integer: with true division
    # (Python 3) ``nspikes / 2`` is a float and cannot be used as an index.
    index = nspikes // 2
    waveform = select(waveforms, index)
    cluster = clusters[index]
    spikes_in_cluster = np.nonzero(clusters == cluster)[0]
    nspikes_in_cluster = len(spikes_in_cluster)
    l.select(clusters=[cluster])

    # Check the size of the selected data.
    # ------------------------------------
    assert check_shape(l.get_features(),
                       (nspikes_in_cluster, nchannels * fetdim + 1))
    assert check_shape(l.get_masks(full=True),
                       (nspikes_in_cluster, nchannels * fetdim + 1))
    assert check_shape(l.get_waveforms(),
                       (nspikes_in_cluster, nsamples, nchannels))
    assert check_shape(l.get_clusters(), (nspikes_in_cluster,))
    assert check_shape(l.get_spiketimes(), (nspikes_in_cluster,))

    # Check waveform sub selection.
    # -----------------------------
    # The waveform of the chosen spike must be the same before and after
    # the selection.
    waveforms_selected = l.get_waveforms()
    assert np.array_equal(get_array(select(waveforms_selected, index)),
                          get_array(waveform))

    l.close()
コード例 #6
0
def test_klusters_loader_1():
    """Check the shape and dtype of each full data set from the loader."""
    # Open the mock data.
    loader = KlustersLoader(filename=os.path.join(TEST_FOLDER, 'test.xml'))

    # Get full data sets.
    fet = loader.get_features()
    msk = loader.get_masks()
    wav = loader.get_waveforms()
    clu = loader.get_clusters()
    times = loader.get_spiketimes()
    n_clusters = len(Counter(clu))

    probe_data = loader.get_probe()
    clu_colors = loader.get_cluster_colors()
    clu_groups = loader.get_cluster_groups()
    grp_colors = loader.get_group_colors()
    grp_names = loader.get_group_names()
    clu_sizes = loader.get_cluster_sizes()

    # Check the shape of the data sets.
    # ---------------------------------
    expected_shapes = [
        (fet, (nspikes, nchannels * fetdim + 1)),
        (msk, (nspikes, nchannels)),
        (wav, (nspikes, nsamples, nchannels)),
        (clu, (nspikes,)),
        (times, (nspikes,)),
        (probe_data, (nchannels, 2)),
        (clu_colors, (n_clusters,)),
        (clu_groups, (n_clusters,)),
        (grp_colors, (4,)),
        (grp_names, (4,)),
        (clu_sizes, (n_clusters,)),
    ]
    for data, shape in expected_shapes:
        assert check_shape(data, shape)

    # Check the data type of the data sets.
    # -------------------------------------
    # HACK: Panel has no dtype(s) attribute, so the waveforms are left out.
    expected_dtypes = [
        (fet, np.float32),
        (msk, np.float32),
        (clu, np.int32),
        (times, np.float32),
        (probe_data, np.float32),
        (clu_colors, np.int32),
        (clu_groups, np.int32),
        (grp_colors, np.int32),
        (grp_names, object),
        (clu_sizes, np.int32),
    ]
    for data, dtype in expected_dtypes:
        assert check_dtype(data, dtype)

    loader.close()
コード例 #7
0
def test_hdf5_loader1():
    """Convert the mock Klusters data to HDF5 and compare both loaders.

    The same cluster is selected in an ``HDF5Loader`` (on ``test.kwik``) and
    in a ``KlustersLoader`` (on ``test.xml``); clusters, spike times,
    features, masks and waveforms read through both are then compared.
    """
    # Open the mock data.
    folder = TEST_FOLDER
    filename = os.path.join(folder, 'test.xml')

    # Total number of spikes, read from the module-level ``nspikes``. The
    # per-cluster count computed below is kept in a separate local so the
    # module-level value is not overwritten for the other tests.
    nspikes_total = nspikes

    # Convert in HDF5.
    klusters_to_hdf5(filename)

    # Open the file.
    filename_h5 = os.path.join(folder, 'test.kwik')

    l = HDF5Loader(filename=filename_h5)
    lk = KlustersLoader(filename=filename)

    # Both loaders are closed even when an assertion fails.
    try:
        # Open probe.
        probe = l.get_probe()
        assert np.array_equal(probe[1]['channels'], np.arange(nchannels))

        # Select cluster.
        cluster = 3
        l.select(clusters=[cluster])
        lk.select(clusters=[cluster])

        # Get clusters.
        clusters = l.get_clusters('all')
        clusters_k = lk.get_clusters('all')
        nspikes_selected = np.sum(clusters == cluster)

        # Check the clusters are correct.
        assert np.array_equal(get_array(clusters), get_array(clusters_k))

        # Get the spike times.
        spiketimes = l.get_spiketimes()
        spiketimes_k = lk.get_spiketimes()
        assert np.all(spiketimes <= 60)

        # Check the spiketimes are correct.
        assert np.allclose(get_array(spiketimes), get_array(spiketimes_k))

        # Get features.
        features = l.get_features()
        spikes = l.get_spikes()
        # Assert the indices in the features Pandas object correspond to the
        # spikes in the selected cluster.
        assert np.array_equal(features.index, spikes)
        # Assert the array has the right number of spikes.
        assert features.shape[0] == nspikes_selected
        assert l.fetdim == fetdim
        assert l.nextrafet == 1

        # Get all features.
        features = l.get_features('all')
        features_k = lk.get_features('all')
        assert isinstance(features, pd.DataFrame)
        assert features.shape[0] == nspikes_total

        # Check the features are correct. The last column is left out of
        # the comparison.
        f = get_array(features)[:, :-1]
        f_k = get_array(features_k)[:, :-1]
        normalize_inplace(f)
        normalize_inplace(f_k)
        assert np.allclose(f, f_k, atol=1e-5)

        # Get masks.
        masks = l.get_masks('all')
        masks_k = lk.get_masks('all')
        assert masks.shape[0] == nspikes_total

        # Check the masks.
        assert np.allclose(masks.values, masks_k.values, atol=1e-2)

        # Get waveforms.
        waveforms = l.get_waveforms().values
        waveforms_k = lk.get_waveforms().values
        assert np.array_equal(waveforms.shape,
                              (nspikes_selected, nsamples, nchannels))

        # Check waveforms
        normalize_inplace(waveforms)
        normalize_inplace(waveforms_k)
        assert np.allclose(waveforms, waveforms_k, atol=1e-4)
    finally:
        l.close()
        lk.close()
コード例 #8
0
def test_klusters_save():
    """Reassign clusters and cluster info, save, then re-read the files.

    WARNING: this test should occur at the end of the module since it
    changes the mock data sets.
    """
    # Open the mock data.
    loader = KlustersLoader(filename=os.path.join(TEST_FOLDER, 'test.xml'))

    spike_clusters = loader.get_clusters()
    # The return values of these getters are not inspected by this test.
    loader.get_cluster_colors()
    loader.get_cluster_groups()
    loader.get_group_colors()
    loader.get_group_names()

    # Alternate the spikes between clusters 2 and 3.
    spike_ids = get_indices(spike_clusters)
    for start, target_cluster in ((0, 2), (1, 3)):
        loader.set_cluster(spike_ids[start::2], target_cluster)

    # Alternate the cluster colors (10 / 20), then the groups (1 / 0).
    unique_clusters = loader.get_clusters_unique()
    for start, color in ((0, 10), (1, 20)):
        loader.set_cluster_colors(unique_clusters[start::2], color)
    for start, group in ((0, 1), (1, 0)):
        loader.set_cluster_groups(unique_clusters[start::2], group)

    # Drop the empty clusters and write everything to disk.
    loader.remove_empty_clusters()
    loader.save()

    # Read back what has just been written.
    saved_clusters = read_clusters(loader.filename_aclu)
    saved_info = read_cluster_info(loader.filename_acluinfo)

    assert np.all(saved_clusters[::2] == 2)
    assert np.all(saved_clusters[1::2] == 3)

    assert np.array_equal(saved_info.index, unique_clusters)
    # Column 0 is compared against the colors set above, column 1 against
    # the groups.
    assert np.all(saved_info.values[::2, 0] == 10)
    assert np.all(saved_info.values[1::2, 0] == 20)
    assert np.all(saved_info.values[::2, 1] == 1)
    assert np.all(saved_info.values[1::2, 1] == 0)

    loader.close()
コード例 #9
0
def test_klusters_loader_control():
    """Exercise the loader's editing methods on the mock data set.

    Covers moving spikes between clusters, changing cluster groups and
    colors, renaming and recoloring groups, and adding then removing a
    cluster and a group.
    """
    # Open the mock data.
    loader = KlustersLoader(filename=os.path.join(TEST_FOLDER, 'test.xml'))

    # Move every spike of cluster 3 into cluster 4 and verify the move.
    old_spikes = get_indices(loader.get_clusters(clusters=3))
    loader.set_cluster(old_spikes, 4)
    moved_spikes = get_indices(loader.get_clusters(clusters=4))
    assert np.all(np.in1d(old_spikes, moved_spikes))

    # Put clusters 2, 3 and 4 in group 0.
    edited_clusters = [2, 3, 4]
    target_group = 0
    loader.set_cluster_groups(edited_clusters, target_group)
    assert np.all(loader.get_cluster_groups(edited_clusters) == target_group)

    # Give the same clusters color 12.
    edited_clusters = [2, 3, 4]
    target_color = 12
    loader.set_cluster_colors(edited_clusters, target_color)
    assert np.all(loader.get_cluster_colors(edited_clusters) == target_color)

    # Rename group 0, which is expected to start out as 'Noise'.
    renamed_group = 0
    old_name = loader.get_group_names(renamed_group)
    assert old_name == 'Noise'
    loader.set_group_names(renamed_group, 'Noise new')
    assert loader.get_group_names(renamed_group) == 'Noise new'

    # Recolor groups 1 and 2.
    recolored_groups = [1, 2]
    loader.get_group_colors(recolored_groups)  # result is not inspected
    loader.set_group_colors(recolored_groups, 10)
    assert np.all(loader.get_group_colors(recolored_groups) == 10)

    # Add group 100 and cluster 10000 (placed in that group), then move up
    # to ten of the spikes returned for cluster 3 into the new cluster.
    new_group, new_cluster = 100, 10000
    few_spikes = get_indices(loader.get_clusters(clusters=3))[:10]
    loader.add_group(new_group, 'New group', 10)
    loader.add_cluster(new_cluster, new_group, 10)
    loader.set_cluster(few_spikes, new_cluster)
    assert np.all(loader.get_clusters(spikes=few_spikes) == new_cluster)
    assert loader.get_cluster_groups(new_cluster) == new_group
    # Move those spikes to cluster 2 before removing the new cluster.
    loader.set_cluster(few_spikes, 2)

    # Remove the new cluster and group and check that they are gone.
    loader.remove_cluster(new_cluster)
    loader.remove_group(new_group)
    assert np.all(~np.in1d(new_cluster, loader.get_clusters()))
    assert np.all(~np.in1d(new_group, loader.get_cluster_groups()))

    loader.close()
コード例 #10
0
def test_klusters_loader_2():
    """Check that selecting a single cluster restricts the loaded data.

    The cluster of the spike with index ``nspikes // 2`` is selected, then
    the shapes of the selected features, masks, waveforms, clusters and
    spike times are checked, as well as the waveform of that spike.
    """
    # Open the mock data.
    folder = TEST_FOLDER
    xmlfile = os.path.join(folder, 'test.xml')
    l = KlustersLoader(filename=xmlfile)

    # Get full data sets. Only the waveforms and the clusters are used
    # below; the other getters are still called, their results are ignored.
    l.get_features()
    l.get_masks()
    waveforms = l.get_waveforms()
    clusters = l.get_clusters()
    l.get_spiketimes()

    l.get_probe()
    l.get_cluster_colors()
    l.get_cluster_groups()
    l.get_group_colors()
    l.get_group_names()
    l.get_cluster_sizes()

    # Check selection.
    # ----------------
    # Floor division keeps the index an integer: with true division
    # (Python 3) ``nspikes / 2`` is a float and cannot be used as an index.
    index = nspikes // 2
    waveform = select(waveforms, index)
    cluster = clusters[index]
    spikes_in_cluster = np.nonzero(clusters == cluster)[0]
    nspikes_in_cluster = len(spikes_in_cluster)
    l.select(clusters=[cluster])

    # Check the size of the selected data.
    # ------------------------------------
    assert check_shape(l.get_features(),
                       (nspikes_in_cluster, nchannels * fetdim + 1))
    assert check_shape(l.get_masks(full=True),
                       (nspikes_in_cluster, nchannels * fetdim + 1))
    assert check_shape(l.get_waveforms(),
                       (nspikes_in_cluster, nsamples, nchannels))
    assert check_shape(l.get_clusters(), (nspikes_in_cluster,))
    assert check_shape(l.get_spiketimes(), (nspikes_in_cluster,))

    # Check waveform sub selection.
    # -----------------------------
    # The waveform of the chosen spike must be the same before and after
    # the selection.
    waveforms_selected = l.get_waveforms()
    assert np.array_equal(get_array(select(waveforms_selected, index)),
                          get_array(waveform))

    l.close()
コード例 #11
0
def test_klusters_loader_1():
    """Check the shape and dtype of each full data set from the loader."""
    # Open the mock data.
    loader = KlustersLoader(filename=os.path.join(TEST_FOLDER, 'test.xml'))

    # Get full data sets.
    fet = loader.get_features()
    msk = loader.get_masks()
    wav = loader.get_waveforms()
    clu = loader.get_clusters()
    times = loader.get_spiketimes()
    n_clusters = len(Counter(clu))

    probe_data = loader.get_probe()
    clu_colors = loader.get_cluster_colors()
    clu_groups = loader.get_cluster_groups()
    grp_colors = loader.get_group_colors()
    grp_names = loader.get_group_names()
    clu_sizes = loader.get_cluster_sizes()

    # Check the shape of the data sets.
    # ---------------------------------
    expected_shapes = [
        (fet, (nspikes, nchannels * fetdim + 1)),
        (msk, (nspikes, nchannels)),
        (wav, (nspikes, nsamples, nchannels)),
        (clu, (nspikes,)),
        (times, (nspikes,)),
        (probe_data, (nchannels, 2)),
        (clu_colors, (n_clusters,)),
        (clu_groups, (n_clusters,)),
        (grp_colors, (4,)),
        (grp_names, (4,)),
        (clu_sizes, (n_clusters,)),
    ]
    for data, shape in expected_shapes:
        assert check_shape(data, shape)

    # Check the data type of the data sets.
    # -------------------------------------
    # HACK: Panel has no dtype(s) attribute, so the waveforms are left out.
    expected_dtypes = [
        (fet, np.float32),
        (msk, np.float32),
        (clu, np.int32),
        (times, np.float32),
        (probe_data, np.float32),
        (clu_colors, np.int32),
        (clu_groups, np.int32),
        (grp_colors, np.int32),
        (grp_names, object),
        (clu_sizes, np.int32),
    ]
    for data, dtype in expected_dtypes:
        assert check_dtype(data, dtype)

    loader.close()