def test_spikecache_1():
    """Exercise SpikeCache.load_features_masks in its three modes.

    Covers: fraction-based loading, explicit-cluster loading, and a
    request for a cluster that does not exist.
    """
    n_spikes = 100000
    n_clusters = 100
    n_channels = 8
    clu = np.random.randint(size=n_spikes, low=0, high=n_clusters)
    cache = SpikeCache(
        spike_clusters=clu,
        cache_fraction=.1,
        features_masks=np.zeros((n_spikes, 3 * n_channels, 2)),
        waveforms_raw=np.zeros((n_spikes, 20, n_channels)),
        waveforms_filtered=np.zeros((n_spikes, 20, n_channels)),
    )

    # Fraction-based load: .1 of the cached .1 fraction is 1% of all spikes.
    idx, fm = cache.load_features_masks(.1)
    assert len(idx) == n_spikes // 100
    assert fm.shape[0] == n_spikes // 100

    # Cluster-based load: returned indices are exactly the spikes whose
    # cluster label is 10 or 20, and line up with the returned rows.
    idx, fm = cache.load_features_masks(clusters=[10, 20])
    assert len(idx) == fm.shape[0]
    assert np.allclose(idx, np.nonzero(np.in1d(clu, (10, 20)))[0])

    # Asking for a cluster id that never occurs yields empty arrays.
    idx, fm = cache.load_features_masks(clusters=[1000])
    assert len(idx) == 0
    assert len(fm) == 0
def init_cache(self):
    """Initialize the cache for the features & masks."""
    # TODO: handle multiple clusterings in the spike cache here
    # TODO: put the cache fraction value in the parameters
    self._spikecache = SpikeCache(
        spike_clusters=self.clusters.main,
        features_masks=self.features_masks,
        waveforms_raw=self.waveforms_raw,
        waveforms_filtered=self.waveforms_filtered,
        cache_fraction=1.,
    )
def test_spikecache_2():
    """Exercise SpikeCache.load_waveforms.

    Covers: one cluster, two clusters (at least ``count`` spikes per
    cluster expected), and a cluster that does not exist.
    """
    n_spikes = 100000
    n_clusters = 100
    n_channels = 8
    clu = np.random.randint(size=n_spikes, low=0, high=n_clusters)
    cache = SpikeCache(
        spike_clusters=clu,
        cache_fraction=.1,
        features_masks=np.zeros((n_spikes, 3 * n_channels, 2)),
        waveforms_raw=np.zeros((n_spikes, 20, n_channels)),
        waveforms_filtered=np.zeros((n_spikes, 20, n_channels)),
    )

    # Single cluster: indices match the waveform rows, at least `count` spikes.
    idx, wav = cache.load_waveforms(clusters=[10], count=10)
    assert len(idx) == wav.shape[0]
    assert len(idx) >= 10

    # Two clusters: expect at least `count` per cluster.
    idx, wav = cache.load_waveforms(clusters=[10, 20], count=10)
    assert len(idx) == wav.shape[0]
    assert len(idx) >= 20

    # Nonexistent cluster: nothing is returned.
    idx, wav = cache.load_waveforms(clusters=[1000], count=10)
    assert len(idx) == 0