def test_get_sign(self):
    """Signs round-trip through their ids, and unknown ids map to ``None``.

    Builds a ``TrieSignIndex`` over the signs "0".."9" and checks, for each sign,
    that ``get_id`` followed by ``get_sign`` returns the original sign. It then
    checks that an id never assigned (86) is reported as absent. Finally it
    checks that the highest valid id is ``len(signs) - 1`` and that
    ``len(signs)`` itself is out of range.
    """
    dim = 100
    act = 10
    gen = Generator(dim, act)

    signs = [str(i) for i in range(10)]
    sign_index = TrieSignIndex(gen, vocabulary=signs)

    for sign in signs:
        self.assertTrue(sign_index.contains(sign))
        # named sign_id rather than `id` to avoid shadowing the builtin
        sign_id = sign_index.get_id(sign)
        self.assertTrue(sign_index.contains_id(sign_id))
        self.assertEqual(sign, sign_index.get_sign(sign_id))

    # get_sign for an id that doesn't exist
    missing_id = 86
    self.assertIsNone(sign_index.get_sign(missing_id))
    self.assertFalse(sign_index.contains_id(missing_id))

    # the underlying trie holds exactly one entry per sign, and ids stop at
    # len(signs) - 1
    self.assertEqual(len(sign_index.sign_trie), len(signs))
    self.assertTrue(sign_index.contains_id(len(signs) - 1))
    self.assertFalse(sign_index.contains_id(len(signs)))
def test_load(self):
    """The ids should be the same when the index is loaded back up.

    Saves a ``TrieSignIndex`` over 1000 signs to ``index.hdf5`` next to this
    test module and loads it back. It then checks that both indexes have the
    same length, contain every sign, assign it the same id, and return equal
    random-index vectors. The file is removed afterwards, even when an
    assertion fails.
    """
    dim = 100
    act = 10
    gen = Generator(dim, act)
    signs1 = [str(i) for i in range(1000)]
    index1 = TrieSignIndex(gen, vocabulary=signs1)

    directory = os.path.dirname(os.path.abspath(__file__))
    index_file = os.path.join(directory, "index.hdf5")
    # refuse to run over a pre-existing file: the finally clause deletes it
    self.assertFalse(os.path.exists(index_file))

    try:
        index1.save(index_file)
        self.assertTrue(os.path.exists(index_file))

        index2 = TrieSignIndex.load(index_file)
        self.assertEqual(len(index2), len(index1))

        for sign in signs1:
            self.assertTrue(index1.contains(sign))
            self.assertTrue(index2.contains(sign))
            self.assertEqual(index1.get_id(sign), index2.get_id(sign))
            np.testing.assert_array_equal(
                index1.get_ri(sign).to_vector(),
                index2.get_ri(sign).to_vector(),
            )
    finally:
        # always clean up, whether or not an assertion above failed
        if os.path.exists(index_file):
            os.remove(index_file)

    self.assertFalse(os.path.exists(index_file))
# Inspect the ten most frequent entries: their frequencies, the ids `trie`
# assigns to the words, and the words recovered from those ids.
# NOTE(review): frequencies, top10w, trie, vocabulary and h5v are defined
# before this chunk. trie exposes get/restore_key, so it is presumably a
# marisa_trie.Trie, and h5v is presumably an open HDF5 file -- confirm.
top10f = list(frequencies[0:10])
top10ids = [trie.get(top10w[i]) for i in range(10)]  # uses only top10w[0:10]; needs len(top10w) >= 10
top10w_trie = [trie.restore_key(i) for i in top10ids]

print(top10w)
print(top10f)
print(top10w_trie)

# Build a sign index over the whole vocabulary and print how long construction
# took (wall-clock seconds from time.time()).
ri_gen = Generator(dim=1000, num_active=10)
t0 = time.time()
# NOTE(review): vocabulary[:] materializes the sequence before list(). If
# vocabulary is an HDF5 dataset, this reads it in a single call, so keep the
# slice rather than iterating the dataset directly.
sign_index = TrieSignIndex(ri_gen, list(vocabulary[:]))
t1 = time.time()
print(t1 - t0)

# Resolve the ids obtained from `trie` through the sign index. This is
# presumably a check that both assign the same ids to the same words.
print(top10ids)
top10w_index = [sign_index.get_sign(i) for i in top10ids]
print(top10w_index)

# test building an index over the top ten words only
print("=============================================")
index = TrieSignIndex(generator=ri_gen, vocabulary=top10w)
print(top10w)
top10ids = [index.get_id(w) for w in top10w]
print(top10ids)
# The result of map_frequencies is indexed by the ids from `index`.
# Presumably freq[id] is the frequency of that word -- confirm.
freq = TrieSignIndex.map_frequencies(top10w, top10f, index)
top10freq = [freq[i] for i in top10ids]
print(top10freq)

h5v.close()