def test_cluster_component():
    """Check that k-means splits the dummy features into two label groups.

    The first ``n_spikes`` labels must all share one value and the remaining
    labels the other. k-means may assign label 1 or 2 to either group, so
    both orderings are accepted.
    """
    base.features.Provide("FeatureSource", DummyFeatureExtractor())
    analyzer = components.ClusterAnalyzer("k_means", 2)
    found = analyzer.read_labels()

    head, tail = found[:n_spikes], found[n_spikes:]
    one_then_two = (head == 1).all() & (tail == 2).all()
    two_then_one = (head == 2).all() & (tail == 1).all()
    ok_(one_then_two | two_then_one)
def test_cluster_component_relabel():
    """After deleting cells 1-4 of a 5-cluster result, relabel() leaves labels {0, 1}."""
    base.features.Provide("FeatureSource", RandomFeatures())
    cluster_comp = components.ClusterAnalyzer("k_means", 5)
    # Reading ``labels`` triggers the initial clustering before cells are
    # deleted; the value itself is not needed.
    _ = cluster_comp.labels
    cluster_comp.delete_cells(1, 2, 3, 4)
    cluster_comp.relabel()
    # np.unique already returns sorted values, so no extra sort is needed.
    labels = np.unique(cluster_comp.labels)
    # array_equal also returns False (instead of misbehaving) when the
    # number of distinct labels differs from the expected two.
    ok_(np.array_equal(labels, [0, 1]))
def test_pipeline_update():
    """Updating the signal source must propagate to the cluster labels.

    Builds a full detector -> extractor -> features -> k-means pipeline,
    reads the labels, calls ``update()`` on the signal source and requires
    that the new label array has a different length. This presumably relies
    on ``DummySignalSource.update()`` yielding a different number of spikes
    (TODO confirm against DummySignalSource).
    """
    base.features.Provide("SignalSource", DummySignalSource())
    base.features.Provide("SpikeMarkerSource",
                          components.SpikeDetector(thresh=spike_amp / 2.))
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 2))
    base.features['FeatureSource'].add_feature("P2P")
    cl1 = base.features["ClusterAnalyzer"].labels
    base.features["SignalSource"].update()
    cl2 = base.features["ClusterAnalyzer"].labels
    # The previous check was ``~(len(cl1) == len(cl2))``. ``~`` on a Python
    # bool is bitwise inversion (~True == -2, ~False == -1), both truthy,
    # so that assertion could never fail.
    ok_(len(cl1) != len(cl2))
def test_null_labels_returned_for_truncated_spikes():
    """Spikes cut off by the end of the signal get a null (0) label.

    The dummy signal is shortened by ``2.5`` periods. With a single k-means
    cluster, the expected labels are ``n_spikes - 2`` entries: all 1 except
    the last, which is 0. Per the test name, that last spike is presumably
    the one truncated by the signal end.
    """
    signal_src = DummySignalSource()
    # Slice bounds must be integers; modern NumPy raises on float indices.
    # int() truncates toward zero, as older NumPy did when it accepted them.
    n_trim = int(-period / 1000. * FS * 2.5)
    signal_src._spikes = signal_src._spikes[:, :n_trim]
    base.features.Provide("SignalSource", signal_src)
    base.features.Provide("SpikeMarkerSource", DummySpikeDetector())
    base.features.Provide("SpikeSource", components.SpikeExtractor())
    base.features.Provide("FeatureSource",
                          components.FeatureExtractor(normalize=False))
    base.features.Provide("ClusterAnalyzer",
                          components.ClusterAnalyzer("k_means", 1))
    base.features['FeatureSource'].add_feature("P2P")
    cl = base.features["ClusterAnalyzer"].labels
    true_labels = np.ones(n_spikes - 2)
    true_labels[-1] = 0
    # array_equal also fails cleanly if the number of labels is wrong.
    ok_(np.array_equal(cl, true_labels))
def test_cluster_component_methods_before_labels_requested():
    """Label-editing methods must not fail before labels were requested.

    Regression test for issue #75. Between successive edits the
    ``cluster_labels`` attribute is reset to ``None``, mimicking a component
    whose labels have not yet been externally requested. The final
    ``relabel()`` must leave ``cluster_labels`` populated.
    """
    base.features.Provide("FeatureSource", RandomFeatures())
    analyzer = components.ClusterAnalyzer("k_means", 2)

    def forget_labels():
        # Put the component back into the "labels never requested" state.
        analyzer.cluster_labels = None

    analyzer.delete_cells(1)
    forget_labels()
    analyzer.delete_spikes([0])
    forget_labels()
    analyzer.merge_cells(1, 2)
    forget_labels()
    analyzer.relabel()

    ok_(analyzer.cluster_labels is not None)
def test_cluster_component_smart_update():
    """ClusterAnalyzer reclusters on update only when the spike count changes.

    Adding a feature (same spikes) must leave the hand-modified labels
    untouched. Adding spikes must produce a label array of a different
    length.
    """
    feature_comp = base.register("FeatureSource", DummyFeatureExtractor())
    cluster_comp = components.ClusterAnalyzer("k_means", 2)
    # Reading ``labels`` is a workaround to trigger ``_cluster()``.
    _ = cluster_comp.labels
    # Modify the labels, so an unwanted recluster would be detectable.
    cluster_comp.delete_cells(1)
    labels_orig = cluster_comp.labels.copy()

    # Same spikes, new feature: cluster_comp should NOT recluster.
    feature_comp.add_feature('new_feature')
    feature_comp.update()
    labels_new_feat = cluster_comp.labels.copy()

    # Number of spikes changed: cluster_comp SHOULD recluster.
    feature_comp.add_spikes(10)
    feature_comp.update()
    labels_new_spikes = cluster_comp.labels.copy()

    # Separate assertions report which expectation failed. array_equal is
    # also safe if the two label arrays differ in length.
    ok_(np.array_equal(labels_orig, labels_new_feat))
    ok_(len(labels_new_spikes) != len(labels_new_feat))
# Wire up a multi-channel spike-sorting pipeline: raw I/O -> filter stack ->
# multi-channel spike detector -> spike extractor -> features -> GMM
# clustering, followed by the browsing/plotting/export components.
# NOTE(review): io, contact, thresh, type, sp_win, n_clusters and
# MultiChannelSpikeDetector are defined earlier in the script (not visible
# here). If ``type`` is a script variable, it shadows the builtin.
io_filter = components.FilterStack()
base.features.Provide("RawSource", io)
base.features.Provide("EventsOutput", io)
# The filter stack, not the raw source, serves as the signal source.
base.features.Provide("SignalSource", io_filter)
base.features.Provide(
    "SpikeMarkerSource",
    # NOTE(review): the detector uses a fixed window [-0.6, 0.8] while the
    # extractor below uses ``sp_win`` -- confirm this difference is intended.
    MultiChannelSpikeDetector(contact=contact, thresh=thresh, type=type,
                              sp_win=[-0.6, 0.8], resample=10, align=True))
base.features.Provide("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.features.Provide("FeatureSource", components.FeatureExtractor())
base.features.Provide("LabelSource", components.ClusterAnalyzer("gmm", n_clusters))
# Convenience handles to the registered components.
src = base.features['SignalSource']
features = base.features['FeatureSource']
# NOTE(review): ``clusters`` and ``labels`` refer to the same LabelSource.
clusters = base.features['LabelSource']
events = base.features['SpikeMarkerSource']
labels = base.features['LabelSource']
# Visualisation and export components.
browser = components.SpikeBrowserWithLabels()
plot1 = components.PlotFeaturesTimeline()
plot2 = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()
#############################################################
# Add filters here
# Build a spike-sorting pipeline reading from a PyTables (HDF5) file:
# PyTablesSource -> spike detector -> spike extractor -> features ->
# 4-cluster GMM, plus browsing/plotting/export components.
# NOTE(review): hd5file, dataset, filter_freq, contact, thresh, type and
# sp_win are defined earlier in the script (not visible here).
# NOTE(review): plain string concatenation -- DATAPATH must end with a path
# separator (or hd5file must start with one) for this path to be valid.
hd5file = os.environ['DATAPATH'] + hd5file
io_filter = components.PyTablesSource(hd5file, dataset, f_filter=filter_freq)
# The same PyTablesSource serves both as signal source and events output.
base.features.Provide("SignalSource", io_filter)
base.features.Provide("EventsOutput", io_filter)
base.features.Provide(
    "SpikeMarkerSource",
    components.SpikeDetector(contact=contact, thresh=thresh, type=type,
                             sp_win=sp_win, resample=10, align=True))
base.features.Provide("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.features.Provide("FeatureSource", components.FeatureExtractor())
base.features.Provide("LabelSource", components.ClusterAnalyzer("gmm", 4))
# Visualisation and export components.
browser = components.SpikeBrowserWithLabels()
plot1 = components.PlotFeaturesTimeline()
plot2 = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()
#############################################################
# Add the features here:
base.features["FeatureSource"].add_feature("P2P")
base.features["FeatureSource"].add_feature("PCs", ncomps=2)
#############################################################
# Run the analysis (this can take a while)
# Build a spike-sorting pipeline using ``base.register``: raw I/O -> filter
# stack -> spike detector -> spike extractor -> features -> 4-cluster GMM,
# plus browsing/plotting/export components, then configure filters and
# features.
# NOTE(review): io, contact, thresh, detection_type, sp_win and filter_freq
# are defined earlier in the script (not visible here).
io_filter = components.FilterStack()
base.register("RawSource", io)
base.register("EventsOutput", io)
# The filter stack, not the raw source, serves as the signal source.
base.register("SignalSource", io_filter)
base.register(
    "SpikeMarkerSource",
    components.SpikeDetector(contact=contact, thresh=thresh,
                             type=detection_type, sp_win=sp_win,
                             resample=1, align=True))
base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
base.register("FeatureSource", components.FeatureExtractor())
base.register("LabelSource", components.ClusterAnalyzer("gmm", 4))
# Visualisation and export components.
browser = components.SpikeBrowser()
feature_plot = components.PlotFeaturesTimeline()
wave_plot = components.PlotSpikes()
legend = components.Legend()
export = components.ExportCells()
#############################################################
# Add filters here:
# filter_freq is unpacked into the LinearIIR filter's positional arguments.
base.features["SignalSource"].add_filter("LinearIIR", *filter_freq)
# Add the features here:
base.features["FeatureSource"].add_feature("P2P")
base.features["FeatureSource"].add_feature("PCA", ncomps=2)