예제 #1
0
def test_spike_extractor():
    """The spike-averaged extracted waveform matches the ideal boxcar spike.

    The ideal spike has height ``spike_amp`` for ``0 <= t < spike_dur`` and
    is zero elsewhere.  The summed absolute error between it and the mean
    extracted waveform must stay below 1% of ``spike_amp``.
    """
    provide = base.features.Provide
    provide("SignalSource", DummySignalSource())
    provide("SpikeMarkerSource", DummySpikeDetector())

    waves = components.SpikeExtractor().spikes

    # Average along axis 1 of the first slice of the last axis.
    # Presumably the layout is (time, spike, contact); TODO confirm.
    averaged = waves['data'][:, :, 0].mean(1)

    t = waves['time']
    expected = spike_amp * ((t >= 0) & (t < spike_dur))

    error = np.abs(averaged - expected).sum()
    ok_(error < 0.01 * spike_amp)
예제 #2
0
def test_truncated_spikes_from_end():
    """A spike cut off by the end of the recording is flagged as invalid.

    The dummy signal is shortened by ``period / 1000. * FS * 2.5`` samples.
    If ``period`` is in ms and ``FS`` in Hz, that is 2.5 spike periods
    (TODO confirm).  The extractor is then expected to return
    ``n_spikes - 2`` spikes, all valid except the last one.
    """
    signal_src = DummySignalSource()
    # Slice bounds must be integers: NumPy raises TypeError for float
    # indices, which the original float expression produced.
    n_trim = int(period / 1000. * FS * 2.5)
    signal_src._spikes = signal_src._spikes[:, :-n_trim]
    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    sp_waves = components.SpikeExtractor().spikes

    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # equivalent dtype.
    correct_mask = np.ones(n_spikes - 2, dtype=bool)
    correct_mask[-1] = False
    ok_((sp_waves['is_valid'] == correct_mask).all())
예제 #3
0
def test_feature_extractor():
    """Without normalization, every P2P feature value equals ``spike_amp``."""
    # Register the upstream pipeline stages in dependency order.  Each
    # component is constructed right before it is provided.
    providers = [
        ("SignalSource", DummySignalSource),
        ("SpikeMarkerSource", DummySpikeDetector),
        ("SpikeSource", components.SpikeExtractor),
    ]
    for feature_name, factory in providers:
        base.features.Provide(feature_name, factory())

    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")
    result = extractor.features

    all_match = (result['data'] == spike_amp).all()
    ok_(all_match)
예제 #4
0
def test_feature_extractor_hide_features_not_found():
    """``hide_features`` raises ValueError for a name matching no feature."""
    registry = base.features
    registry.Provide("SignalSource", DummySignalSource())
    registry.Provide("SpikeMarkerSource", DummySpikeDetector())
    registry.Provide("SpikeSource", components.SpikeExtractor())

    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")
    extractor.update()

    missing_name = "NonExistingFeature"
    assert_raises(ValueError, extractor.hide_features, missing_name)
예제 #5
0
def test_truncated_spikes_from_begin():
    """A spike marker at time 0 yields a waveform flagged as invalid.

    A marker at 0 is prepended to the dummy detector's spike times.  The
    extractor is then expected to return ``n_spikes - 1`` spikes, all valid
    except the first one.
    """
    signal_src = DummySignalSource()
    detector = DummySpikeDetector()
    spt = detector._spt_data['data']
    # Insert a spike time of 0 at index 0 of the marker array.
    detector._spt_data['data'] = np.insert(spt, 0, 0)
    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", detector)
    sp_waves = components.SpikeExtractor().spikes

    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # equivalent dtype.
    correct_mask = np.ones(n_spikes - 1, dtype=bool)
    correct_mask[0] = False
    ok_((sp_waves['is_valid'] == correct_mask).all())
예제 #6
0
def test_pipeline_update():
    """Updating the signal source propagates through to the cluster labels.

    Builds the full pipeline (signal -> threshold detector -> extractor ->
    P2P features -> 2-cluster k-means).  It reads the labels, calls
    ``update()`` on the signal source, and checks that the number of labels
    changed.
    """
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource",
                          components.SpikeDetector(thresh=spike_amp / 2.))
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 2))

    base.features['FeatureSource'].add_feature("P2P")

    cl1 = base.features["ClusterAnalyzer"].labels
    base.features["SignalSource"].update()
    cl2 = base.features["ClusterAnalyzer"].labels
    # The original ``~(a == b)`` applied bitwise NOT to a bool.  That gives
    # -2 or -1, both truthy, so the assertion could never fail.
    ok_(len(cl1) != len(cl2))
예제 #7
0
def test_feature_extractor_hide_features():
    """Features hidden by a wildcard pattern disappear from the output.

    Adds the "P2P" and "SpIdx" features, hides everything matching "Sp*",
    and checks that only the P2P column remains and that its values equal
    ``spike_amp``.
    """
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())

    feat_comp = components.FeatureExtractor(normalize=False)
    feat_comp.add_feature("P2P")
    feat_comp.add_feature("SpIdx")
    feat_comp.hide_features("Sp*")
    feat_comp.update()
    features = feat_comp.features

    # Compare as a list: if 'names' is a NumPy array, ``==`` is elementwise.
    # Feeding that result to ``and`` raises "truth value is ambiguous" for
    # any length other than 1, instead of failing cleanly.
    ok_(list(features['names']) == ["P2P:Ch0:P2P"])  # only P2P remains
    ok_(features['data'].shape[1] == 1)               # a single column
    ok_((features['data'] == spike_amp).all())        # with matching data
예제 #8
0
def test_null_labels_returned_for_truncated_spikes():
    """A truncated (invalid) spike receives label 0 from the cluster analyzer.

    The dummy signal is shortened by ``period / 1000. * FS * 2.5`` samples,
    so the last extracted spike is truncated.  With a single k-means
    cluster, every valid spike is expected to get label 1 and the truncated
    one label 0.
    """
    signal_src = DummySignalSource()
    # Slice bounds must be integers: NumPy raises TypeError for float
    # indices, which the original float expression produced.
    n_trim = int(period / 1000. * FS * 2.5)
    signal_src._spikes = signal_src._spikes[:, :-n_trim]

    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 1))
    base.features['FeatureSource'].add_feature("P2P")
    cl = base.features["ClusterAnalyzer"].labels

    true_labels = np.ones(n_spikes - 2)
    true_labels[-1] = 0

    ok_((cl == true_labels).all())
예제 #9
0
import os

# Build a spike-sorting pipeline for an ABF recording (electrodes 1 and 2)
# and instantiate the remaining components.
# NOTE(review): file_name, contact, thresh, type, sp_win, n_clusters,
# ABFSource and MultiChannelSpikeDetector are not defined in the visible
# source.  Presumably they are set/imported earlier in the full script;
# verify.  If ``type`` is not defined there, ``type=type`` below passes the
# builtin ``type`` class to the detector.
# NOTE(review): ``io`` shadows the stdlib ``io`` module name at module level.
io = ABFSource(file_name, electrodes=[1, 2])
io_filter = components.FilterStack()
# The same ABF reader object is registered as both "RawSource" and
# "EventsOutput".
base.features.Provide("RawSource", io)
base.features.Provide("EventsOutput", io)
# "SignalSource" is the filter stack, not the raw reader.
base.features.Provide("SignalSource", io_filter)
# NOTE(review): the detector uses a literal window [-0.6, 0.8], while the
# extractor below uses the ``sp_win`` variable; confirm they should match.
base.features.Provide(
    "SpikeMarkerSource",
    MultiChannelSpikeDetector(contact=contact,
                              thresh=thresh,
                              type=type,
                              sp_win=[-0.6, 0.8],
                              resample=10,
                              align=True))
base.features.Provide("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.features.Provide("FeatureSource", components.FeatureExtractor())
# Gaussian-mixture ("gmm") clustering into n_clusters groups.
base.features.Provide("LabelSource",
                      components.ClusterAnalyzer("gmm", n_clusters))

# Module-level handles to the registered pipeline stages.  ``clusters`` and
# ``labels`` both look up "LabelSource", so they refer to the same object.
src = base.features['SignalSource']
features = base.features['FeatureSource']
clusters = base.features['LabelSource']
events = base.features['SpikeMarkerSource']
labels = base.features['LabelSource']

# Instantiate the browser, plotting, legend and export components.
browser = components.SpikeBrowserWithLabels()
plot1 = components.PlotFeaturesTimeline()
plot2 = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()
예제 #10
0
from spike_sort.io import neo_filters

####################################
# Adjust these fields for your needs

sp_win = [-0.6, 0.8]

url = 'https://portal.g-node.org/neo/axon/File_axon_1.abf'
path = 'file_axon.abf'

# ``urlretrieve`` lives in ``urllib.request`` on Python 3 and in ``urllib``
# on Python 2.  The bare ``urllib.urlretrieve`` call raised AttributeError
# on Python 3.
try:
    from urllib.request import urlretrieve
except ImportError:  # Python 2
    from urllib import urlretrieve

# Download the example Axon recording into ``path`` (working directory).
urlretrieve(url, path)

io = neo_filters.NeoSource(path)

# Pipeline: Neo file reader -> spike detector on contact 0 (automatic
# threshold, 'max' type, aligned) -> extractor with the same ``sp_win``.
base.register("SignalSource", io)
base.register(
    "SpikeMarkerSource",
    components.SpikeDetector(contact=0,
                             thresh='auto',
                             type='max',
                             sp_win=sp_win,
                             resample=1,
                             align=True))
base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win))

browser = components.SpikeBrowser()

browser.show()