def test_spike_extractor():
    """Extracted mean waveform must match the known dummy spike shape."""
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    sp_waves = components.SpikeExtractor().spikes
    # Average over all extracted spikes on the first contact.
    avg_waveform = sp_waves['data'][:, :, 0].mean(1)
    timebase = sp_waves['time']
    # Ideal spike: rectangular pulse of spike_amp over [0, spike_dur).
    expected = spike_amp * ((timebase >= 0) & (timebase < spike_dur))
    total_error = np.sum(np.abs(avg_waveform - expected))
    ok_(total_error < 0.01 * spike_amp)
def test_truncated_spikes_from_end():
    """Spikes whose window runs past the end of the signal are invalid."""
    signal_src = DummySignalSource()
    # Truncate the signal so the last spike window extends beyond the end.
    # int() is required: a float slice index raises TypeError on Python 3
    # (and is rejected by modern NumPy as well).
    cutoff = int(period / 1000. * FS * 2.5)
    signal_src._spikes = signal_src._spikes[:, :-cutoff]
    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    sp_waves = components.SpikeExtractor().spikes
    # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
    correct_mask = np.ones(n_spikes - 2, dtype=bool)
    correct_mask[-1] = False  # only the last (truncated) spike is invalid
    ok_((sp_waves['is_valid'] == correct_mask).all())
def test_feature_extractor():
    """P2P feature of the dummy spikes must equal the known amplitude."""
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")
    feature_data = extractor.features['data']
    ok_((feature_data == spike_amp).all())
def test_feature_extractor_hide_features_not_found():
    """hide_features must raise ValueError for an unknown feature name."""
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")
    extractor.update()
    assert_raises(ValueError, extractor.hide_features, "NonExistingFeature")
def test_truncated_spikes_from_begin():
    """A spike marker at t=0 yields a window truncated at the signal start."""
    signal_src = DummySignalSource()
    detector = DummySpikeDetector()
    # Prepend a spike time of 0 so its extraction window begins before the
    # first sample of the signal.
    spt = detector._spt_data['data']
    detector._spt_data['data'] = np.insert(spt, 0, 0)
    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", detector)
    sp_waves = components.SpikeExtractor().spikes
    # np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
    correct_mask = np.ones(n_spikes - 1, dtype=bool)
    correct_mask[0] = False  # only the inserted t=0 spike is invalid
    ok_((sp_waves['is_valid'] == correct_mask).all())
def test_pipeline_update():
    """Updating the signal source must propagate through the pipeline.

    Bug fix: the original asserted ``ok_(~(len(cl1) == len(cl2)))``.
    ``~`` on a bool yields an int (``~True == -2``, ``~False == -1``),
    both of which are truthy, so the assertion could never fail.  The
    intended check is that the label counts differ after the update.
    """
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource",
                          components.SpikeDetector(thresh=spike_amp / 2.))
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 2))
    base.features['FeatureSource'].add_feature("P2P")
    cl1 = base.features["ClusterAnalyzer"].labels
    # Re-reading the source should change the data seen downstream.
    base.features["SignalSource"].update()
    cl2 = base.features["ClusterAnalyzer"].labels
    ok_(len(cl1) != len(cl2))
def test_feature_extractor_hide_features():
    """Hiding by glob pattern removes matching features from the output."""
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")
    extractor.add_feature("SpIdx")
    extractor.hide_features("Sp*")
    extractor.update()
    feats = extractor.features
    only_p2p_left = feats['names'] == ["P2P:Ch0:P2P"]  # it's P2P
    one_column = feats['data'].shape[1] == 1
    values_match = (feats['data'] == spike_amp).all()  # with corresponding data
    ok_(only_p2p_left and one_column and values_match)
def test_null_labels_returned_for_truncated_spikes():
    """Truncated (invalid) spikes must receive the null cluster label 0."""
    signal_src = DummySignalSource()
    # Truncate the signal so the last spike window extends beyond the end.
    # int() is required: a float slice index raises TypeError on Python 3
    # (and is rejected by modern NumPy as well).
    cutoff = int(period / 1000. * FS * 2.5)
    signal_src._spikes = signal_src._spikes[:, :-cutoff]
    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 1))
    base.features['FeatureSource'].add_feature("P2P")
    cl = base.features["ClusterAnalyzer"].labels
    true_labels = np.ones(n_spikes - 2)
    true_labels[-1] = 0  # the truncated spike gets the null label
    ok_((cl == true_labels).all())
import os

# Build the spike-sorting pipeline around an ABF recording.
# NOTE(review): this chunk relies on names defined elsewhere in the file
# (ABFSource, file_name, contact, thresh, type, sp_win, n_clusters,
# MultiChannelSpikeDetector) — presumably a configuration section above;
# verify before running this fragment standalone.
io = ABFSource(file_name, electrodes=[1, 2])
io_filter = components.FilterStack()
# The same ABF handle serves as both the raw-signal input and the
# destination for exported events.
base.features.Provide("RawSource", io)
base.features.Provide("EventsOutput", io)
base.features.Provide("SignalSource", io_filter)
base.features.Provide(
    "SpikeMarkerSource",
    MultiChannelSpikeDetector(contact=contact, thresh=thresh, type=type,
                              sp_win=[-0.6, 0.8], resample=10, align=True))
base.features.Provide("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.features.Provide("FeatureSource", components.FeatureExtractor())
base.features.Provide("LabelSource",
                      components.ClusterAnalyzer("gmm", n_clusters))

# Convenience handles into the dependency registry.
# NOTE(review): 'clusters' and 'labels' both point at LabelSource —
# presumably intentional aliases; confirm downstream usage.
src = base.features['SignalSource']
features = base.features['FeatureSource']
clusters = base.features['LabelSource']
events = base.features['SpikeMarkerSource']
labels = base.features['LabelSource']

# GUI components for browsing, plotting and exporting the sorted spikes.
browser = components.SpikeBrowserWithLabels()
plot1 = components.PlotFeaturesTimeline()
plot2 = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()
from spike_sort.io import neo_filters

####################################
# Adjust these fields for your needs
sp_win = [-0.6, 0.8]
url = 'https://portal.g-node.org/neo/axon/File_axon_1.abf'
path = 'file_axon.abf'

# Bug fix: urllib.urlretrieve existed only in Python 2; on Python 3 the
# function lives in urllib.request.
import urllib.request
urllib.request.urlretrieve(url, path)

# Minimal pipeline: Neo-backed signal source, single-contact detector,
# and a spike extractor feeding the interactive browser.
io = neo_filters.NeoSource(path)
base.register("SignalSource", io)
base.register(
    "SpikeMarkerSource",
    components.SpikeDetector(contact=0, thresh='auto', type='max',
                             sp_win=sp_win, resample=1, align=True))
base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win))

browser = components.SpikeBrowser()
browser.show()