def test_cluster_component():
    """Check that k-means splits the dummy features into two clean groups.

    The first ``n_spikes`` labels must all carry one cluster id and the
    remaining labels the other.  Which group gets id 1 and which gets id 2 is
    arbitrary, so both assignments are accepted.
    """
    base.features.Provide("FeatureSource", DummyFeatureExtractor())

    cluster_comp = components.ClusterAnalyzer("k_means", 2)
    labels = cluster_comp.read_labels()

    first, rest = labels[:n_spikes], labels[n_spikes:]
    # ``.all()`` is vacuously True on an empty slice, so both groups must be
    # present; otherwise a missing second cluster would still pass the test.
    ok_(len(first) > 0 and len(rest) > 0)

    ok = (((first == 1).all() and (rest == 2).all()) or
          ((first == 2).all() and (rest == 1).all()))
    ok_(ok)
def test_cluster_component_relabel():
    """Check that ``relabel`` compacts cluster ids after cells are deleted.

    Five k-means clusters are requested and cells 1-4 are deleted.  After
    ``relabel`` the set of distinct labels must be exactly ``[0, 1]``.
    """
    base.features.Provide("FeatureSource", RandomFeatures())

    cluster_comp = components.ClusterAnalyzer("k_means", 5)
    # The value is unused; reading ``labels`` appears to be what triggers the
    # initial clustering before cells are deleted, so keep the access.
    _ = cluster_comp.labels
    cluster_comp.delete_cells(1, 2, 3, 4)
    cluster_comp.relabel()

    # np.unique already returns its result sorted, so no extra sort is needed.
    labels = np.unique(cluster_comp.labels)

    # np.array_equal returns False (instead of raising on a shape mismatch)
    # when the number of distinct labels is not two.
    ok_(np.array_equal(labels, [0, 1]))
def test_pipeline_update():
    """Check that updating the signal source changes the cluster labels.

    Registers a detection -> extraction -> feature -> k-means pipeline on top
    of ``DummySignalSource``.  After ``update()`` is called on the signal
    source, the number of labels must differ from the number before the
    update.
    """
    base.features.Provide("SignalSource", DummySignalSource())
    # Detection threshold at half of ``spike_amp`` (defined elsewhere).
    base.features.Provide("SpikeMarkerSource",
                          components.SpikeDetector(thresh=spike_amp / 2.))
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 2))

    base.features['FeatureSource'].add_feature("P2P")

    cl1 = base.features["ClusterAnalyzer"].labels
    base.features["SignalSource"].update()
    cl2 = base.features["ClusterAnalyzer"].labels
    # Use ``!=``, not ``~(... == ...)``: ``~`` on a Python bool is bitwise
    # NOT (~True == -2, ~False == -1), both truthy, so that form never fails.
    ok_(len(cl1) != len(cl2))
def test_null_labels_returned_for_truncated_spikes():
    """Check the labels returned when the signal is cut short at its end.

    The dummy signal is shortened by ``period / 1000. * FS * 2.5`` samples
    (presumably 2.5 periods, if ``period`` is in ms and ``FS`` in Hz; both are
    defined elsewhere).  The pipeline is then clustered with k-means and one
    cluster requested.  The expected labels are ``n_spikes - 2`` entries equal
    to 1, except the last one, which is 0 (the null label, per the test name).
    """
    signal_src = DummySignalSource()
    # Slice bounds must be integers: a float bound raises TypeError on
    # Python 3 / current NumPy.  int() truncates toward zero, which matches how
    # older NumPy converted float indices.
    n_trim = int(period / 1000. * FS * 2.5)
    signal_src._spikes = signal_src._spikes[:, :-n_trim]

    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 1))
    base.features['FeatureSource'].add_feature("P2P")
    cl = base.features["ClusterAnalyzer"].labels

    true_labels = np.ones(n_spikes - 2)
    true_labels[-1] = 0

    # np.array_equal fails cleanly (instead of raising) if the number of
    # returned labels differs from the expected count.
    ok_(np.array_equal(cl, true_labels))
Example #5
0
def test_cluster_component_methods_before_labels_requested():
    """Regression test for issue #75.

    The label-editing methods must not fail when they are called before the
    labels have been requested externally, i.e. while ``cluster_labels`` is
    None.  Before each follow-up call, ``cluster_labels`` is reset to None to
    recreate that state.
    """
    base.features.Provide("FeatureSource", RandomFeatures())
    analyzer = components.ClusterAnalyzer("k_means", 2)

    # First call goes to the freshly constructed analyzer, before any labels
    # have been requested.
    analyzer.delete_cells(1)

    for method_name, args in (("delete_spikes", ([0],)),
                              ("merge_cells", (1, 2)),
                              ("relabel", ())):
        analyzer.cluster_labels = None
        getattr(analyzer, method_name)(*args)

    # After the final relabel() the labels must be populated again.
    ok_(analyzer.cluster_labels is not None)
Example #6
0
def test_cluster_component_smart_update():
    """Check that the analyzer reclusters only when the spike count changes.

    After a manual edit (``delete_cells``):

    * adding a feature without changing the spikes must leave the edited
      labels untouched;
    * adding spikes must trigger reclustering, so the label count changes.
    """
    feature_comp = base.register("FeatureSource", DummyFeatureExtractor())
    cluster_comp = components.ClusterAnalyzer("k_means", 2)
    _ = cluster_comp.labels  # this is a workaround to call _cluster()

    # cluster_comp should NOT recluster when updated, if the number of
    # spikes didn't change
    cluster_comp.delete_cells(1)  # modify labels
    labels_orig = cluster_comp.labels.copy()

    feature_comp.add_feature('new_feature')
    feature_comp.update()
    labels_new_feat = cluster_comp.labels.copy()

    # cluster_comp should recluster when updated, if the number of
    # spikes DID change
    feature_comp.add_spikes(10)
    feature_comp.update()
    labels_new_spikes = cluster_comp.labels.copy()

    # Assert each condition separately so a failure shows which one broke.
    # np.array_equal returns False (rather than raising) on a length mismatch.
    ok_(np.array_equal(labels_orig, labels_new_feat),
        "labels changed although the number of spikes did not")
    ok_(len(labels_new_spikes) != len(labels_new_feat),
        "labels were not recomputed after spikes were added")
Example #7
0
# Build a spike-sorting pipeline: each component is registered under a
# feature name with ``base.features.Provide`` and later looked up by that name.
# NOTE(review): ``io``, ``contact``, ``thresh``, ``type``, ``sp_win`` and
# ``n_clusters`` must be assigned earlier in the file (not visible here).
# If ``type`` is not rebound there, the builtin ``type`` is passed below, and
# ``io`` would be the stdlib module if it is imported -- confirm both are
# intentional assignments.
io_filter = components.FilterStack()
# The same ``io`` object serves as both RawSource and EventsOutput; the
# FilterStack is the SignalSource.
base.features.Provide("RawSource", io)
base.features.Provide("EventsOutput", io)
base.features.Provide("SignalSource", io_filter)
# NOTE(review): the detector uses a fixed window [-0.6, 0.8] while the
# extractor below uses ``sp_win``; confirm the difference is deliberate.
base.features.Provide(
    "SpikeMarkerSource",
    MultiChannelSpikeDetector(contact=contact,
                              thresh=thresh,
                              type=type,
                              sp_win=[-0.6, 0.8],
                              resample=10,
                              align=True))
base.features.Provide("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.features.Provide("FeatureSource", components.FeatureExtractor())
base.features.Provide("LabelSource",
                      components.ClusterAnalyzer("gmm", n_clusters))

# Convenience handles to the registered components.  ``clusters`` and
# ``labels`` are two names for the same LabelSource component.
src = base.features['SignalSource']
features = base.features['FeatureSource']
clusters = base.features['LabelSource']
events = base.features['SpikeMarkerSource']
labels = base.features['LabelSource']

# Browser / plotting / export components, created after the registrations
# above (judging by their names, these are the visualization front end).
browser = components.SpikeBrowserWithLabels()
plot1 = components.PlotFeaturesTimeline()
plot2 = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()

#############################################################
# Add filters here
# Load the dataset from a file under $DATAPATH (presumably HDF5, read through
# PyTablesSource) and register the processing pipeline.  ``hd5file``,
# ``dataset``, ``filter_freq``, ``contact``, ``thresh``, ``type`` and
# ``sp_win`` are assigned earlier in the file (not visible here).
#
# os.path.join inserts the separator that plain string concatenation would
# silently omit when DATAPATH has no trailing slash.  A leading separator is
# stripped from hd5file first because os.path.join discards every component
# before an absolute one.
hd5file = os.path.join(os.environ['DATAPATH'], hd5file.lstrip(os.sep + '/'))
io_filter = components.PyTablesSource(hd5file, dataset, f_filter=filter_freq)
base.features.Provide("SignalSource", io_filter)
base.features.Provide("EventsOutput", io_filter)
# NOTE(review): if ``type`` is not rebound earlier, the builtin ``type`` is
# passed here -- confirm it holds the intended detection type.
base.features.Provide(
    "SpikeMarkerSource",
    components.SpikeDetector(contact=contact,
                             thresh=thresh,
                             type=type,
                             sp_win=sp_win,
                             resample=10,
                             align=True))
base.features.Provide("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.features.Provide("FeatureSource", components.FeatureExtractor())
base.features.Provide("LabelSource", components.ClusterAnalyzer("gmm", 4))

browser = components.SpikeBrowserWithLabels()
plot1 = components.PlotFeaturesTimeline()
plot2 = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()

#############################################################
# Add the features here:

base.features["FeatureSource"].add_feature("P2P")
# NOTE(review): another example in this collection registers this feature as
# "PCA"; confirm "PCs" is a valid feature name for this FeatureExtractor.
base.features["FeatureSource"].add_feature("PCs", ncomps=2)

#############################################################
# Run the analysis (this can take a while)
Example #9
0
# Build a spike-sorting pipeline by registering one component per feature name
# with ``base.register``; components are later retrieved via ``base.features``.
# ``io``, ``contact``, ``thresh``, ``detection_type``, ``sp_win`` and
# ``filter_freq`` are assigned earlier in the file (not visible here).
io_filter = components.FilterStack()

# The same ``io`` object is registered as both RawSource and EventsOutput;
# the (initially empty) FilterStack is the SignalSource.
base.register("RawSource", io)
base.register("EventsOutput", io)
base.register("SignalSource", io_filter)
base.register(
    "SpikeMarkerSource",
    components.SpikeDetector(contact=contact,
                             thresh=thresh,
                             type=detection_type,
                             sp_win=sp_win,
                             resample=1,
                             align=True))
base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.register("FeatureSource", components.FeatureExtractor())
# Gaussian-mixture clustering with a fixed 4 clusters.
base.register("LabelSource", components.ClusterAnalyzer("gmm", 4))

# Browser / plotting / export components, created after registration.
browser = components.SpikeBrowser()
feature_plot = components.PlotFeaturesTimeline()
wave_plot = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()

#############################################################
# Add filters here:
# Presumably resolves to the FilterStack registered above as SignalSource;
# the filter arguments are unpacked from ``filter_freq``.
base.features["SignalSource"].add_filter("LinearIIR", *filter_freq)

# Add the features here:
base.features["FeatureSource"].add_feature("P2P")
base.features["FeatureSource"].add_feature("PCA", ncomps=2)