Esempio n. 1
0
 def testSimpler(self):
     """A single inserted (id, value) pair is returned as the one best entry."""
     top_n = topn.TopN(1000, shortlist_size=10)
     top_n.insert([1], [33.0])
     best_ids, best_vals = top_n.get_best(1)
     with session.Session() as sess:
         init_op = tf.initialize_all_variables()
         sess.run(init_op)
         fetched_ids, fetched_vals = sess.run([best_ids, best_vals])
         # An order-sensitive comparison is safe: exactly one entry is expected.
         self.assertListEqual([1], list(fetched_ids))
         self.assertListEqual([33.0], list(fetched_vals))
Esempio n. 2
0
 def testLotsOfInsertsDescending(self):
     """After inserting ids 99 down to 2 one at a time, get_best(5) yields 95..99."""
     top_n = topn.TopN(1000, shortlist_size=10)
     # One insert call per id, from the largest value down to the smallest.
     for item_id in range(99, 1, -1):
         top_n.insert([item_id], [float(item_id)])
     best_ids, best_vals = top_n.get_best(5)
     expected_ids = [95, 96, 97, 98, 99]
     with session.Session() as sess:
         sess.run(tf.initialize_all_variables())
         fetched_ids, fetched_vals = sess.run([best_ids, best_vals])
         # Only the contents are checked; the output order is not asserted.
         self.assertItemsEqual(expected_ids, list(fetched_ids))
         self.assertItemsEqual([float(k) for k in expected_ids], list(fetched_vals))
Esempio n. 3
0
 def testSimple(self):
     """Removed ids are excluded from the results of get_best."""
     top_n = topn.TopN(1000, shortlist_size=10)
     inserted_ids = [1, 2, 3, 4, 5]
     top_n.insert(inserted_ids, [float(x) for x in inserted_ids])
     # Drop the two largest entries; 2 and 3 should then be the best pair.
     top_n.remove([4, 5])
     best_ids, best_vals = top_n.get_best(2)
     with session.Session() as sess:
         sess.run(tf.initialize_all_variables())
         fetched_ids, fetched_vals = sess.run([best_ids, best_vals])
         self.assertItemsEqual([2, 3], list(fetched_ids))
         self.assertItemsEqual([2.0, 3.0], list(fetched_vals))
Esempio n. 4
0
 def testRemoveNotInShortlist(self):
     """Removing low-valued ids leaves the top two results unchanged."""
     top_n = topn.TopN(1000, shortlist_size=10)
     for item_id in range(20):
         top_n.insert([item_id], [float(item_id)])
     # Per the test name, 4 and 5 are expected to be outside the 10-entry
     # shortlist, which presumably holds the largest values.
     top_n.remove([4, 5])
     best_ids, best_vals = top_n.get_best(2)
     with session.Session() as sess:
         sess.run(tf.initialize_all_variables())
         fetched_ids, fetched_vals = sess.run([best_ids, best_vals])
         self.assertItemsEqual([18.0, 19.0], list(fetched_vals))
         self.assertItemsEqual([18, 19], list(fetched_ids))
Esempio n. 5
0
 def testNeedToRefreshShortlistInGetBest(self):
     """get_best still finds the top two after most shortlisted ids are removed."""
     top_n = topn.TopN(1000, shortlist_size=10)
     for item_id in range(20):
         top_n.insert([item_id], [float(item_id)])
     # Shortlist now has 10 .. 19
     top_n.remove(list(range(11, 20)))
     # Only 10 is left of that shortlist, so 9 has to be found outside it.
     best_ids, best_vals = top_n.get_best(2)
     with session.Session() as sess:
         sess.run(tf.initialize_all_variables())
         fetched_ids, fetched_vals = sess.run([best_ids, best_vals])
         self.assertItemsEqual([9, 10], list(fetched_ids))
         self.assertItemsEqual([9.0, 10.0], list(fetched_vals))