示例#1
0
 def shortlist_insert():
   """Insert the masked candidates into the shortlist and write the result back.

   Reads `ids`, `scores`, `larger_scores` and `self` from the enclosing scope.

   Returns:
     A grouped op covering the scatter updates of `self.sl_ids` and
     `self.sl_scores`.
   """
   # Keep only the entries selected by the `larger_scores` mask.
   candidate_ids = array_ops.boolean_mask(
       math_ops.to_int64(ids), larger_scores)
   candidate_scores = array_ops.boolean_mask(scores, larger_scores)
   slots, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
       self.sl_ids, self.sl_scores, candidate_ids, candidate_scores)
   # The same slot indices drive both variable updates.
   ids_update = state_ops.scatter_update(self.sl_ids, slots, slot_ids)
   scores_update = state_ops.scatter_update(
       self.sl_scores, slots, slot_scores)
   return control_flow_ops.group(ids_update, scores_update)
示例#2
0
 def shortlist_insert():
     """Insert the masked candidates into the shortlist variables.

     Uses `ids`, `scores`, `larger_scores` and `self` from the enclosing
     scope and returns one op grouping both scatter updates.
     """
     # Filter ids (cast to int64) and scores with the same boolean mask.
     kept_ids = tf.boolean_mask(tf.to_int64(ids), larger_scores)
     kept_scores = tf.boolean_mask(scores, larger_scores)
     slot_idx, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
         self.sl_ids, self.sl_scores, kept_ids, kept_scores
     )
     # Write the op's outputs into both shortlist variables at `slot_idx`.
     update_ids = tf.scatter_update(self.sl_ids, slot_idx, slot_ids)
     update_scores = tf.scatter_update(self.sl_scores, slot_idx, slot_scores)
     return tf.group(update_ids, update_scores)
示例#3
0
 def shortlist_insert():
     """Build the op that inserts the masked candidates into the shortlist.

     `ids`, `scores`, `larger_scores` and `self` come from the enclosing
     scope. The returned op groups the updates to `self.sl_ids` and
     `self.sl_scores`.
     """
     int64_ids = tf.to_int64(ids)
     # Both tensors are filtered by the same `larger_scores` mask.
     masked_ids = tf.boolean_mask(int64_ids, larger_scores)
     masked_scores = tf.boolean_mask(scores, larger_scores)
     positions, ids_out, scores_out = tensor_forest_ops.top_n_insert(
         self.sl_ids, self.sl_scores, masked_ids, masked_scores)
     write_ids = tf.scatter_update(self.sl_ids, positions, ids_out)
     write_scores = tf.scatter_update(self.sl_scores, positions, scores_out)
     return tf.group(write_ids, write_scores)
示例#4
0
 def testInsertOpIntoEmptyShortlist(self):
   """Checks top_n_insert when every slot after the first is a placeholder."""
   # Slot 0 holds id 0; all remaining slots hold id -1 with score -999.
   sl_ids = [0, -1, -1, -1, -1, -1]
   sl_scores = [-999, -999, -999, -999, -999, -999]
   # The single (id, score) pair being inserted.
   incoming_ids = [5]
   incoming_scores = [33.0]
   with self.test_session():
     slots, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
         sl_ids, sl_scores, incoming_ids, incoming_scores)
     # Slot 1 takes the new pair; slot 0 is rewritten to (1, -999).
     self.assertAllEqual([1, 0], slots.eval())
     self.assertAllEqual([5, 1], slot_ids.eval())
     self.assertAllEqual([33.0, -999], slot_scores.eval())
示例#5
0
 def testInsertOpIntoEmptyShortlist(self):
     """Checks top_n_insert when every slot after the first is a placeholder."""
     # Slot 0 holds id 0; all remaining slots hold id -1 with score -999.
     current_ids = [0, -1, -1, -1, -1, -1]
     current_scores = [-999, -999, -999, -999, -999, -999]
     # The single (id, score) pair being inserted.
     added_ids = [5]
     added_scores = [33.0]
     with self.test_session():
         positions, ids_out, scores_out = tensor_forest_ops.top_n_insert(
             current_ids, current_scores, added_ids, added_scores)
         # Slot 1 takes the new pair; slot 0 is rewritten to (1, -999).
         self.assertAllEqual([1, 0], positions.eval())
         self.assertAllEqual([5, 1], ids_out.eval())
         self.assertAllEqual([33.0, -999], scores_out.eval())
示例#6
0
 def testInsertOpIntoAlmostFullShortlist(self):
   """Checks top_n_insert when exactly one placeholder slot remains."""
   # Only slot 2 still holds the -1 / -999 placeholder.
   sl_ids = [4, 13, -1, 27, 99, 15]
   sl_scores = [60.0, 87.0, -999, 65.0, 1000.0, 256.0]
   # The single (id, score) pair being inserted.
   incoming_ids = [5]
   incoming_scores = [93.0]
   with self.test_session():
     slots, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
         sl_ids, sl_scores, incoming_ids, incoming_scores)
     # The new pair goes to slot 2; slot 0 is rewritten as well.
     self.assertAllEqual([2, 0], slots.eval())
     self.assertAllEqual([5, 5], slot_ids.eval())
     # Shortlist still contains all known scores > 60.0
     self.assertAllEqual([93.0, 60.0], slot_scores.eval())
示例#7
0
 def testInsertOpIntoAlmostFullShortlist(self):
     """Checks top_n_insert when exactly one placeholder slot remains."""
     # Only slot 2 still holds the -1 / -999 placeholder.
     current_ids = [4, 13, -1, 27, 99, 15]
     current_scores = [60.0, 87.0, -999, 65.0, 1000.0, 256.0]
     # The single (id, score) pair being inserted.
     added_ids = [5]
     added_scores = [93.0]
     with self.test_session():
         positions, ids_out, scores_out = tensor_forest_ops.top_n_insert(
             current_ids, current_scores, added_ids, added_scores)
         # The new pair goes to slot 2; slot 0 is rewritten as well.
         self.assertAllEqual([2, 0], positions.eval())
         self.assertAllEqual([5, 5], ids_out.eval())
         # Shortlist still contains all known scores > 60.0
         self.assertAllEqual([93.0, 60.0], scores_out.eval())
示例#8
0
 def shortlist_insert():
     """Insert the masked candidates into the shortlist and persist the result.

     Depends on `ids`, `scores`, `larger_scores` and `self` from the
     enclosing scope.

     Returns:
         A grouped op that runs the scatter updates of `self.sl_ids` and
         `self.sl_scores`.
     """
     # Apply the `larger_scores` mask to the int64 ids and to the scores.
     selected_ids = array_ops.boolean_mask(
         math_ops.to_int64(ids), larger_scores)
     selected_scores = array_ops.boolean_mask(scores, larger_scores)
     targets, target_ids, target_scores = tensor_forest_ops.top_n_insert(
         self.sl_ids, self.sl_scores, selected_ids, selected_scores)
     # Both variables are updated at the same target indices.
     store_ids = state_ops.scatter_update(
         self.sl_ids, targets, target_ids)
     store_scores = state_ops.scatter_update(
         self.sl_scores, targets, target_scores)
     return control_flow_ops.group(store_ids, store_scores)
示例#9
0
 def testInsertOpHard(self):
   """Checks top_n_insert with five candidates and one placeholder slot."""
   # Only slot 2 still holds the -1 / -999 placeholder.
   sl_ids = [4, 13, -1, 27, 99, 15]
   sl_scores = [60.0, 87.0, -999, 65.0, 1000.0, 256.0]
   # Five (id, score) pairs offered for insertion, paired by position.
   incoming_ids = [5, 6, 7, 8, 9]
   incoming_scores = [61.0, 66.0, 90.0, 100.0, 2000.0]
   with self.test_session():
     slots, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
         sl_ids, sl_scores, incoming_ids, incoming_scores)
     # Top 5 scores are: 2000.0, 1000.0, 256.0, 100.0, 90.0
     self.assertAllEqual([2, 3, 1, 0], slots.eval())
     self.assertAllEqual([9, 8, 7, 5], slot_ids.eval())
     # 87.0 is the highest score we overwrote or didn't insert.
     self.assertAllEqual([2000.0, 100.0, 90.0, 87.0], slot_scores.eval())
示例#10
0
 def testInsertOpIntoFullShortlist(self):
   """Checks top_n_insert when no slot holds the -1 / -999 placeholder."""
   sl_ids = [5, 13, 44, 27, 99, 15]
   sl_scores = [60.0, 87.0, 111.0, 65.0, 1000.0, 256.0]
   # The single (id, score) pair being inserted.
   incoming_ids = [5]
   incoming_scores = [93.0]
   with self.test_session():
     slots, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
         sl_ids, sl_scores, incoming_ids, incoming_scores)
     # The new pair replaces slot 3; slot 0 is rewritten as well.
     self.assertAllEqual([3, 0], slots.eval())
     self.assertAllEqual([5, 5], slot_ids.eval())
     # We removed a 65.0 from the list, so now we can only claim that
     # it holds all scores > 65.0.
     self.assertAllEqual([93.0, 65.0], slot_scores.eval())
示例#11
0
 def testInsertOpHard(self):
     """Checks top_n_insert with five candidates and one placeholder slot."""
     # Only slot 2 still holds the -1 / -999 placeholder.
     current_ids = [4, 13, -1, 27, 99, 15]
     current_scores = [60.0, 87.0, -999, 65.0, 1000.0, 256.0]
     # Five (id, score) pairs offered for insertion, paired by position.
     added_ids = [5, 6, 7, 8, 9]
     added_scores = [61.0, 66.0, 90.0, 100.0, 2000.0]
     with self.test_session():
         positions, ids_out, scores_out = tensor_forest_ops.top_n_insert(
             current_ids, current_scores, added_ids, added_scores)
         # Top 5 scores are: 2000.0, 1000.0, 256.0, 100.0, 90.0
         self.assertAllEqual([2, 3, 1, 0], positions.eval())
         self.assertAllEqual([9, 8, 7, 5], ids_out.eval())
         # 87.0 is the highest score we overwrote or didn't insert.
         self.assertAllEqual([2000.0, 100.0, 90.0, 87.0], scores_out.eval())
示例#12
0
 def testInsertOpIntoFullShortlist(self):
     """Checks top_n_insert when no slot holds the -1 / -999 placeholder."""
     current_ids = [5, 13, 44, 27, 99, 15]
     current_scores = [60.0, 87.0, 111.0, 65.0, 1000.0, 256.0]
     # The single (id, score) pair being inserted.
     added_ids = [5]
     added_scores = [93.0]
     with self.test_session():
         positions, ids_out, scores_out = tensor_forest_ops.top_n_insert(
             current_ids, current_scores, added_ids, added_scores)
         # The new pair replaces slot 3; slot 0 is rewritten as well.
         self.assertAllEqual([3, 0], positions.eval())
         self.assertAllEqual([5, 5], ids_out.eval())
         # We removed a 65.0 from the list, so now we can only claim that
         # it holds all scores > 65.0.
         self.assertAllEqual([93.0, 65.0], scores_out.eval())