def testSimpler(self):
    """A single inserted (id, value) pair is what get_best(1) returns."""
    top_n = topn.TopN(1000, shortlist_size=10)
    top_n.insert([1], [33.0])
    best_ids, best_vals = top_n.get_best(1)

    with session.Session() as sess:
        sess.run(tf.initialize_all_variables())
        fetched = sess.run([best_ids, best_vals])
        ids_out, vals_out = fetched
        self.assertListEqual([1], list(ids_out))
        self.assertListEqual([33.0], list(vals_out))
def testLotsOfInsertsDescending(self):
    """Inserting ids 99 down to 2 one at a time keeps the top five (95..99)."""
    top_n = topn.TopN(1000, shortlist_size=10)
    # One insert per id, from largest to smallest (99, 98, ..., 2).
    for item_id in reversed(range(2, 100)):
        top_n.insert([item_id], [float(item_id)])
    best_ids, best_vals = top_n.get_best(5)

    expected_ids = list(range(95, 100))
    expected_vals = [float(item_id) for item_id in expected_ids]

    with session.Session() as sess:
        sess.run(tf.initialize_all_variables())
        ids_out, vals_out = sess.run([best_ids, best_vals])
        self.assertItemsEqual(expected_ids, list(ids_out))
        self.assertItemsEqual(expected_vals, list(vals_out))
def testSimple(self):
    """Removed ids (4 and 5) are excluded, so the best two are ids 2 and 3."""
    top_n = topn.TopN(1000, shortlist_size=10)
    top_n.insert([1, 2, 3, 4, 5], [1.0, 2.0, 3.0, 4.0, 5.0])
    top_n.remove([4, 5])
    best_ids, best_vals = top_n.get_best(2)

    with session.Session() as sess:
        sess.run(tf.initialize_all_variables())
        fetched = sess.run([best_ids, best_vals])
        ids_out, vals_out = fetched
        self.assertItemsEqual([2, 3], list(ids_out))
        self.assertItemsEqual([2.0, 3.0], list(vals_out))
def testRemoveNotInShortlist(self):
    """Removing low-valued ids (4, 5) does not disturb the best two (18, 19)."""
    top_n = topn.TopN(1000, shortlist_size=10)
    # Insert ids 0..19 one at a time, each with value equal to its id.
    for item_id in range(20):
        top_n.insert([item_id], [float(item_id)])
    top_n.remove([4, 5])
    best_ids, best_vals = top_n.get_best(2)

    with session.Session() as sess:
        sess.run(tf.initialize_all_variables())
        ids_out, vals_out = sess.run([best_ids, best_vals])
        self.assertItemsEqual([18.0, 19.0], list(vals_out))
        self.assertItemsEqual([18, 19], list(ids_out))
def testNeedToRefreshShortlistInGetBest(self):
    """get_best(2) returns 9 and 10 after most of the shortlisted ids are removed."""
    top_n = topn.TopN(1000, shortlist_size=10)
    for item_id in range(20):
        top_n.insert([item_id], [float(item_id)])
    # Shortlist now has 10 .. 19
    top_n.remove([11, 12, 13, 14, 15, 16, 17, 18, 19])
    best_ids, best_vals = top_n.get_best(2)

    with session.Session() as sess:
        sess.run(tf.initialize_all_variables())
        fetched = sess.run([best_ids, best_vals])
        ids_out, vals_out = fetched
        self.assertItemsEqual([9, 10], list(ids_out))
        self.assertItemsEqual([9.0, 10.0], list(vals_out))