예제 #1
0
def test_feature_extractor():
    """The "P2P" feature data must equal ``spike_amp`` in every element."""
    provide = base.features.Provide
    provide("SignalSource", DummySignalSource())
    provide("SpikeMarkerSource", DummySpikeDetector())
    provide("SpikeSource", components.SpikeExtractor())

    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")

    p2p_data = extractor.features['data']
    all_match = (p2p_data == spike_amp).all()
    ok_(all_match)
예제 #2
0
def test_feature_extractor_hide_features_not_found():
    """Hiding a feature name that was never added must raise ValueError."""
    provide = base.features.Provide
    provide("SignalSource", DummySignalSource())
    provide("SpikeMarkerSource", DummySpikeDetector())
    provide("SpikeSource", components.SpikeExtractor())

    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")
    extractor.update()

    # Only "P2P" was added, so this name is unknown to the extractor.
    unknown_name = "NonExistingFeature"
    assert_raises(ValueError, extractor.hide_features, unknown_name)
예제 #3
0
def test_propagate_truncate_to_features():
    """The 'is_valid' mask set on the spike source must appear in the features."""
    source = DummySpikeSource()

    # One True flag per entry along axis 1 of the waveform 'data' array.
    n_columns = source._sp_waves['data'].shape[1]
    valid_mask = np.ones(n_columns, dtype=bool)
    source._sp_waves['is_valid'] = valid_mask

    base.features.Provide("SpikeSource", source)

    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")

    propagated = extractor.features['is_valid']
    ok_((propagated == valid_mask).all())
예제 #4
0
def test_pipeline_update():
    """Updating the signal source must change the cluster labels read afterwards.

    Wires the full pipeline (signal -> spike markers -> spikes -> features ->
    clusters), reads the labels, calls ``update()`` on the signal source, and
    checks that the re-read labels have a different length.
    """
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource",
                          components.SpikeDetector(thresh=spike_amp / 2.))
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 2))

    base.features['FeatureSource'].add_feature("P2P")

    cl1 = base.features["ClusterAnalyzer"].labels
    base.features["SignalSource"].update()
    cl2 = base.features["ClusterAnalyzer"].labels

    # Use a real boolean comparison. Bitwise ``~`` on a bool gives -1 or -2,
    # and both are truthy, so an ``ok_(~(...))`` check could never fail.
    ok_(len(cl1) != len(cl2),
        "labels length unchanged after update: %d" % len(cl2))
예제 #5
0
def test_feature_extractor_hide_features():
    """Features hidden with ``hide_features("Sp*")`` are dropped from the output.

    Adds "P2P" and "SpIdx", hides "Sp*", and expects only the P2P column to
    remain, with data equal to ``spike_amp``.
    """
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())

    feat_comp = components.FeatureExtractor(normalize=False)
    feat_comp.add_feature("P2P")
    feat_comp.add_feature("SpIdx")
    feat_comp.hide_features("Sp*")
    feat_comp.update()
    features = feat_comp.features

    # Each condition is asserted on its own so a failure says which one broke.
    # list(...) turns the name check into a plain bool even when 'names' is an
    # array. Comparing an array to a list is elementwise, and the resulting
    # array is ambiguous when used in a boolean expression.
    names = list(features['names'])
    ok_(names == ["P2P:Ch0:P2P"],  # only the P2P feature remains
        "unexpected feature names: %r" % (names,))
    n_cols = features['data'].shape[1]
    ok_(n_cols == 1, "expected 1 feature column, got %d" % n_cols)
    ok_((features['data'] == spike_amp).all(),  # with corresponding data
        "P2P data does not match spike_amp")
예제 #6
0
def test_null_labels_returned_for_truncated_spikes():
    """Spikes cut off by truncating the signal should get the null label.

    Trims the end of the dummy recording, runs the pipeline with a single
    k-means cluster, and expects ``n_spikes - 2`` labels. All of them are 1
    except the last, which is 0.
    """
    signal_src = DummySignalSource()
    # Number of trailing samples to drop: 2.5 * period / 1000 * FS
    # (presumably period in ms and FS in Hz, i.e. 2.5 periods -- confirm).
    # Slice bounds must be integers; a float bound raises TypeError on
    # current NumPy / Python 3.
    n_trim = int(period / 1000. * FS * 2.5)
    # Slicing to an explicit stop index (instead of ``:-n_trim``) keeps the
    # data intact if n_trim happens to be 0, where ``:-0`` would empty it.
    n_keep = max(signal_src._spikes.shape[1] - n_trim, 0)
    signal_src._spikes = signal_src._spikes[:, :n_keep]

    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 1))
    base.features['FeatureSource'].add_feature("P2P")
    cl = base.features["ClusterAnalyzer"].labels

    true_labels = np.ones(n_spikes - 2)
    true_labels[-1] = 0

    ok_((cl == true_labels).all())
예제 #7
0
# Example #7: build a spike-sorting pipeline on the global ``base.features``
# registry from an ABF file, then create the browser, plot, legend and export
# components.
# NOTE(review): ABFSource, MultiChannelSpikeDetector, file_name, contact,
# thresh, type, sp_win and n_clusters are not defined in this snippet and are
# assumed to come from earlier in the script -- confirm.
io = ABFSource(file_name, electrodes=[1, 2])
io_filter = components.FilterStack()
# The same ABFSource instance is registered as both "RawSource" and
# "EventsOutput".
base.features.Provide("RawSource", io)
base.features.Provide("EventsOutput", io)
# The filter stack (not the raw source) is what is registered as
# "SignalSource".
base.features.Provide("SignalSource", io_filter)
# NOTE(review): ``type=type`` passes the builtin ``type`` class unless the
# script rebinds ``type`` earlier; a distinct name such as ``detection_type``
# would avoid shadowing the builtin.
# NOTE(review): the detector's sp_win is hard-coded to [-0.6, 0.8], while the
# SpikeExtractor below uses the ``sp_win`` variable -- confirm this is
# intentional.
base.features.Provide(
    "SpikeMarkerSource",
    MultiChannelSpikeDetector(contact=contact,
                              thresh=thresh,
                              type=type,
                              sp_win=[-0.6, 0.8],
                              resample=10,
                              align=True))
base.features.Provide("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.features.Provide("FeatureSource", components.FeatureExtractor())
# "gmm" presumably selects Gaussian-mixture clustering with n_clusters
# clusters -- confirm against ClusterAnalyzer.
base.features.Provide("LabelSource",
                      components.ClusterAnalyzer("gmm", n_clusters))

# Short handles to registered components.
# NOTE(review): ``clusters`` and ``labels`` both look up "LabelSource".
src = base.features['SignalSource']
features = base.features['FeatureSource']
clusters = base.features['LabelSource']
events = base.features['SpikeMarkerSource']
labels = base.features['LabelSource']

# Viewer/export components are created only after every feature above has
# been registered.
# NOTE(review): they presumably resolve their inputs from ``base.features``,
# so keep them after the registrations -- confirm.
browser = components.SpikeBrowserWithLabels()
plot1 = components.PlotFeaturesTimeline()
plot2 = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()
예제 #8
0
# Example #8: build a spike-sorting pipeline via ``base.register`` from a
# PyTables/HDF5 dataset, create the viewer/export components, then configure
# the filter and feature lists.
# NOTE(review): hdf5file, dataset, contact, thresh, detection_type, sp_win and
# filter_freq are not defined in this snippet and are assumed to come from
# earlier in the script -- confirm.
io = components.PyTablesSource(hdf5file, dataset)
io_filter = components.FilterStack()

# The same PyTablesSource instance is registered as both "RawSource" and
# "EventsOutput"; the filter stack is registered as "SignalSource".
base.register("RawSource", io)
base.register("EventsOutput", io)
base.register("SignalSource", io_filter)
base.register(
    "SpikeMarkerSource",
    components.SpikeDetector(contact=contact,
                             thresh=thresh,
                             type=detection_type,
                             sp_win=sp_win,
                             resample=1,
                             align=True))
base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.register("FeatureSource", components.FeatureExtractor())
# "gmm" with 4 clusters; "gmm" presumably means Gaussian mixture -- confirm.
base.register("LabelSource", components.ClusterAnalyzer("gmm", 4))

# Viewer/export components, created after all registrations above.
browser = components.SpikeBrowser()
feature_plot = components.PlotFeaturesTimeline()
wave_plot = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()

#############################################################
# Add filters here:
# ``filter_freq`` is unpacked as positional arguments to add_filter.
base.features["SignalSource"].add_filter("LinearIIR", *filter_freq)

# Add the features here:
# P2P plus a 2-component PCA feature.
base.features["FeatureSource"].add_feature("P2P")
base.features["FeatureSource"].add_feature("PCA", ncomps=2)