def test_feature_extractor():
    """Unnormalised P2P features must equal the dummy spike amplitude."""
    # Each source is constructed and provided in turn, in pipeline order.
    source_factories = [
        ("SignalSource", DummySignalSource),
        ("SpikeMarkerSource", DummySpikeDetector),
        ("SpikeSource", components.SpikeExtractor),
    ]
    for feature_name, make_source in source_factories:
        base.features.Provide(feature_name, make_source())

    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")

    p2p_data = extractor.features['data']
    ok_((p2p_data == spike_amp).all())
def test_feature_extractor_hide_features_not_found():
    """Hiding a feature pattern that matches nothing must raise ValueError."""
    provide = base.features.Provide
    provide("SignalSource", DummySignalSource())
    provide("SpikeMarkerSource", DummySpikeDetector())
    provide("SpikeSource", components.SpikeExtractor())

    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")
    extractor.update()

    assert_raises(ValueError, extractor.hide_features, "NonExistingFeature")
def test_propagate_truncate_to_features():
    """The spike source's ``is_valid`` mask must be copied into the features."""
    source = DummySpikeSource()
    n_waves = source._sp_waves['data'].shape[1]
    valid_mask = np.ones(n_waves, dtype=bool)  # every waveform marked valid
    source._sp_waves['is_valid'] = valid_mask
    base.features.Provide("SpikeSource", source)

    extractor = components.FeatureExtractor(normalize=False)
    extractor.add_feature("P2P")

    propagated = extractor.features['is_valid']
    ok_((propagated == valid_mask).all())
def test_pipeline_update():
    """Updating the signal source must propagate through to the cluster labels.

    The test asserts that, after ``SignalSource.update()``, the
    ``ClusterAnalyzer`` returns a label array of a different length than
    before the update.
    """
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource",
                          components.SpikeDetector(thresh=spike_amp / 2.))
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 2))

    base.features['FeatureSource'].add_feature("P2P")

    labels_before = base.features["ClusterAnalyzer"].labels
    base.features["SignalSource"].update()
    labels_after = base.features["ClusterAnalyzer"].labels

    # Use `!=`, not `~(a == b)`: `~` on a bool is integer bitwise inversion
    # (~True == -2, ~False == -1), which is always truthy and never fails.
    ok_(len(labels_before) != len(labels_after))
def test_feature_extractor_hide_features():
    """Hiding ``Sp*`` features must leave only the P2P feature and its data."""
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())

    feat_comp = components.FeatureExtractor(normalize=False)
    feat_comp.add_feature("P2P")
    feat_comp.add_feature("SpIdx")
    feat_comp.hide_features("Sp*")
    feat_comp.update()
    features = feat_comp.features

    # Assert each condition separately so a failure reports which one broke.
    # list() normalises `names` so the comparison is a single bool even if it
    # is an ndarray of unexpected length (an ndarray == list comparison would
    # otherwise broadcast and raise "truth value ambiguous").
    ok_(list(features['names']) == ["P2P:Ch0:P2P"])  # only P2P remains
    ok_(features['data'].shape[1] == 1)  # a single feature column
    ok_((features['data'] == spike_amp).all())  # with the matching P2P data
def test_null_labels_returned_for_truncated_spikes():
    """A spike cut off by truncating the signal must get the null label 0.

    The signal is shortened by 2.5 spike periods. The test expects
    ``n_spikes - 2`` labels, all 1 except the last, which must be 0.
    """
    signal_src = DummySignalSource()
    # Number of trailing samples to drop: 2.5 periods' worth (presumably
    # `period` is in ms and `FS` in Hz). Slice bounds must be integers --
    # NumPy raises TypeError for float slice indices -- so convert
    # explicitly; int() truncates, matching legacy NumPy's implicit cast.
    n_drop = int(period / 1000. * FS * 2.5)
    # Guard n_drop == 0: `[:-0]` would yield an empty array, not the full one.
    stop = -n_drop if n_drop else None
    signal_src._spikes = signal_src._spikes[:, :stop]

    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 1))
    base.features['FeatureSource'].add_feature("P2P")

    cl = base.features["ClusterAnalyzer"].labels

    true_labels = np.ones(n_spikes - 2)
    true_labels[-1] = 0  # the truncated spike is expected to be unlabelled
    ok_((cl == true_labels).all())
# Pipeline set-up: read electrodes 1 and 2 from an ABF file, run the signal
# through a FilterStack, detect and extract spikes, compute features, cluster
# with ClusterAnalyzer("gmm", n_clusters), and create the browser, plot,
# legend and export components.
# NOTE(review): file_name, contact, thresh, type, sp_win and n_clusters must
# be defined earlier in the script (not visible in this section).
io = ABFSource(file_name, electrodes=[1, 2])
io_filter = components.FilterStack()
# The same ABF source object is registered for both raw data and events.
base.features.Provide("RawSource", io)
base.features.Provide("EventsOutput", io)
# The FilterStack is registered as the SignalSource; presumably it reads
# "RawSource" -- confirm against FilterStack.
base.features.Provide("SignalSource", io_filter)
# NOTE(review): `type` shadows the builtin; if no `type` variable is defined
# earlier in the script, the builtin class itself is passed here. The
# detector also uses a fixed sp_win=[-0.6, 0.8], while the extractor below
# uses the configurable `sp_win` -- confirm this difference is intended.
base.features.Provide(
    "SpikeMarkerSource",
    MultiChannelSpikeDetector(contact=contact, thresh=thresh, type=type,
                              sp_win=[-0.6, 0.8], resample=10, align=True))
base.features.Provide("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.features.Provide("FeatureSource", components.FeatureExtractor())
base.features.Provide("LabelSource", components.ClusterAnalyzer("gmm", n_clusters))

# Module-level handles to the registered components.
src = base.features['SignalSource']
features = base.features['FeatureSource']
clusters = base.features['LabelSource']
events = base.features['SpikeMarkerSource']
# NOTE(review): looks up the same 'LabelSource' entry as `clusters`; both
# names are kept in case later code uses either.
labels = base.features['LabelSource']

# Visualisation and export components.
browser = components.SpikeBrowserWithLabels()
plot1 = components.PlotFeaturesTimeline()
plot2 = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()
# Pipeline set-up backed by a PyTables (HDF5) dataset: filter the signal,
# detect and extract spikes, compute features, cluster with
# ClusterAnalyzer("gmm", 4), then add the IIR filter and the P2P and PCA
# features.
# NOTE(review): hdf5file, dataset, contact, thresh, detection_type, sp_win and
# filter_freq must be defined earlier in the script (not visible here).
io = components.PyTablesSource(hdf5file, dataset)
io_filter = components.FilterStack()
# The same PyTables source object is registered for raw data and events.
base.register("RawSource", io)
base.register("EventsOutput", io)
base.register("SignalSource", io_filter)
base.register(
    "SpikeMarkerSource",
    components.SpikeDetector(contact=contact, thresh=thresh,
                             type=detection_type, sp_win=sp_win,
                             resample=1, align=True))
base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.register("FeatureSource", components.FeatureExtractor())
# The number of clusters is hard-coded to 4 here.
base.register("LabelSource", components.ClusterAnalyzer("gmm", 4))

# Visualisation and export components.
browser = components.SpikeBrowser()
feature_plot = components.PlotFeaturesTimeline()
wave_plot = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()

# NOTE(review): components are registered with base.register but read back
# through base.features[...] below; presumably register() stores into
# base.features -- confirm.
#############################################################
# Add filters here:
# filter_freq is unpacked as positional arguments to add_filter.
base.features["SignalSource"].add_filter("LinearIIR", *filter_freq)

# Add the features here:
base.features["FeatureSource"].add_feature("P2P")
base.features["FeatureSource"].add_feature("PCA", ncomps=2)