def testRemoveAll(self):
    """Check top_n_remove when every listed id is present in the shortlist.

    The ids 100..500 sit at positions 1..5 of the shortlist input, so the
    op is expected to report those five positions and a new length of 0.
    """
    # NOTE(review): element 0 (the leading 5) presumably stores the current
    # shortlist length rather than an id -- confirm against the op's docs.
    shortlist = [5, 100, 200, 300, 400, 500]
    ids_to_drop = [100, 200, 300, 400, 500]

    with self.test_session():
        removed_slots, remaining = tensor_forest_ops.top_n_remove(
            shortlist, ids_to_drop)

        self.assertAllEqual([1, 2, 3, 4, 5], removed_slots.eval())
        self.assertAllEqual([0], remaining.eval())
def testRemoveAll(self):
    """Check top_n_remove when every listed id is present in the shortlist.

    The ids 100..500 sit at positions 1..5 of the shortlist input, so the
    op is expected to report those five positions and a new length of 0.
    """
    # NOTE(review): element 0 (the leading 5) presumably stores the current
    # shortlist length rather than an id -- confirm against the op's docs.
    shortlist = [5, 100, 200, 300, 400, 500]
    ids_to_drop = [100, 200, 300, 400, 500]

    with self.test_session():
        removed_slots, remaining = tensor_forest_ops.top_n_remove(
            shortlist, ids_to_drop)

        self.assertAllEqual([1, 2, 3, 4, 5], removed_slots.eval())
        self.assertAllEqual([0], remaining.eval())
def remove(self, ids):
    """Drop the given ids, together with their scores, from the TopN.

    Builds three scatter updates, all gated on ``self.last_ops``:
    one that overwrites the ids' entries in ``self.id_to_score`` with the
    smallest float32, one that rewrites ``self.sl_ids`` (new length at
    index 0, -1 at every removed shortlist slot), and one that overwrites
    the matching ``self.sl_scores`` slots with the smallest float32.
    ``self.last_ops`` is then replaced by those three ops.

    Args:
      ids: The ids to remove; used directly as indices into
        ``self.id_to_score``.
    """
    with tf.control_dependencies(self.last_ops):
        # Reset the full id -> score table entries for the removed ids.
        floor_scores = tf.ones_like(ids, dtype=tf.float32) * tf.float32.min
        score_reset = tf.scatter_update(self.id_to_score, ids, floor_scores)

        # Removed ids are expected to be in the shortlist nearly every time,
        # so the shortlist op runs unconditionally instead of being wrapped
        # in a tf.cond.
        stale_slots, remaining = tensor_forest_ops.top_n_remove(
            self.sl_ids, ids)

        # Slot 0 receives the new length; each removed slot receives -1.
        id_indices = tf.concat_v2([[0], stale_slots], 0)
        id_values = tf.concat_v2(
            [remaining, tf.ones_like(stale_slots) * -1], 0)
        ids_update = tf.scatter_update(self.sl_ids, id_indices, id_values)

        slot_floor_scores = tf.float32.min * tf.ones_like(
            stale_slots, dtype=tf.float32)
        scores_update = tf.scatter_update(
            self.sl_scores, stale_slots, slot_floor_scores)

        self.last_ops = [score_reset, ids_update, scores_update]
def remove(self, ids):
    """Drop the given ids, together with their scores, from the TopN.

    Builds three scatter updates, all gated on ``self.last_ops``:
    one that overwrites the ids' entries in ``self.id_to_score`` with the
    smallest float32, one that rewrites ``self.sl_ids`` (new length at
    index 0, -1 at every removed shortlist slot), and one that overwrites
    the matching ``self.sl_scores`` slots with the smallest float32.
    ``self.last_ops`` is then replaced by those three ops.

    Args:
      ids: The ids to remove; used directly as indices into
        ``self.id_to_score``.
    """
    with tf.control_dependencies(self.last_ops):
        # Reset the full id -> score table entries for the removed ids.
        floor_scores = tf.ones_like(ids, dtype=tf.float32) * tf.float32.min
        score_reset = tf.scatter_update(self.id_to_score, ids, floor_scores)

        # Removed ids are expected to be in the shortlist nearly every time,
        # so the shortlist op runs unconditionally instead of being wrapped
        # in a tf.cond.
        stale_slots, remaining = tensor_forest_ops.top_n_remove(
            self.sl_ids, ids)

        # Slot 0 receives the new length; each removed slot receives -1.
        id_indices = tf.concat_v2([[0], stale_slots], 0)
        id_values = tf.concat_v2(
            [remaining, tf.ones_like(stale_slots) * -1], 0)
        ids_update = tf.scatter_update(self.sl_ids, id_indices, id_values)

        slot_floor_scores = tf.float32.min * tf.ones_like(
            stale_slots, dtype=tf.float32)
        scores_update = tf.scatter_update(
            self.sl_scores, stale_slots, slot_floor_scores)

        self.last_ops = [score_reset, ids_update, scores_update]
def testRemoveAllMissing(self):
    """Check top_n_remove when none of the requested ids is in the shortlist.

    1200, 1400 and 600 do not occur in the shortlist input, so the op is
    expected to report no positions and leave the length at 5.
    """
    # NOTE(review): element 0 (the leading 5) presumably stores the current
    # shortlist length rather than an id -- confirm against the op's docs.
    shortlist = [5, 100, 200, 300, 400, 500]
    absent_ids = [1200, 1400, 600]

    with self.test_session():
        removed_slots, remaining = tensor_forest_ops.top_n_remove(
            shortlist, absent_ids)

        self.assertAllEqual([], removed_slots.eval())
        self.assertAllEqual([5], remaining.eval())
def testRemoveSimple(self):
    """Check top_n_remove with a mix of present and absent ids.

    200 and 400 sit at positions 2 and 4 of the shortlist input while 600
    does not occur in it, so the op is expected to report positions
    [2, 4] and a new length of 3.
    """
    # NOTE(review): element 0 (the leading 5) presumably stores the current
    # shortlist length rather than an id -- confirm against the op's docs.
    shortlist = [5, 100, 200, 300, 400, 500]
    ids_to_drop = [200, 400, 600]

    with self.test_session():
        removed_slots, remaining = tensor_forest_ops.top_n_remove(
            shortlist, ids_to_drop)

        self.assertAllEqual([2, 4], removed_slots.eval())
        self.assertAllEqual([3], remaining.eval())
def testRemoveAllMissing(self):
    """Check top_n_remove when none of the requested ids is in the shortlist.

    1200, 1400 and 600 do not occur in the shortlist input, so the op is
    expected to report no positions and leave the length at 5.
    """
    # NOTE(review): element 0 (the leading 5) presumably stores the current
    # shortlist length rather than an id -- confirm against the op's docs.
    shortlist = [5, 100, 200, 300, 400, 500]
    absent_ids = [1200, 1400, 600]

    with self.test_session():
        removed_slots, remaining = tensor_forest_ops.top_n_remove(
            shortlist, absent_ids)

        self.assertAllEqual([], removed_slots.eval())
        self.assertAllEqual([5], remaining.eval())
def testRemoveSimple(self):
    """Check top_n_remove with a mix of present and absent ids.

    200 and 400 sit at positions 2 and 4 of the shortlist input while 600
    does not occur in it, so the op is expected to report positions
    [2, 4] and a new length of 3.
    """
    # NOTE(review): element 0 (the leading 5) presumably stores the current
    # shortlist length rather than an id -- confirm against the op's docs.
    shortlist = [5, 100, 200, 300, 400, 500]
    ids_to_drop = [200, 400, 600]

    with self.test_session():
        removed_slots, remaining = tensor_forest_ops.top_n_remove(
            shortlist, ids_to_drop)

        self.assertAllEqual([2, 4], removed_slots.eval())
        self.assertAllEqual([3], remaining.eval())