示例#1
0
    def test_prune_completely_outside_window_with_empty_boxlist(self):
        """Pruning an empty BoxList yields empty boxes, indices and fields.

        The 'extra_data' field attached to the empty input must survive
        pruning as an empty int32 tensor, and no keep indices are returned.
        """
        clip_window = tf.constant([0, 0, 9, 14], tf.float32)
        empty_boxes = box_list.BoxList(
            tf.zeros(shape=[0, 4], dtype=tf.float32))
        empty_boxes.add_field(
            'extra_data', tf.zeros(shape=[0], dtype=tf.int32))

        pruned, keep_indices = box_list_ops.prune_completely_outside_window(
            empty_boxes, clip_window)

        # Evaluate every output of interest in a single session run.
        fetches = {
            'boxes': pruned.get(),
            'indices': keep_indices,
            'extra': pruned.get_field('extra_data'),
        }
        with self.test_session() as sess:
            results = sess.run(fetches)

        self.assertAllClose(np.zeros(shape=[0, 4], dtype=np.float32),
                            results['boxes'])
        self.assertAllEqual([], results['indices'])
        self.assertAllEqual(np.zeros(shape=[0], dtype=np.int32),
                            results['extra'])
示例#2
0
 def test_prune_completely_outside_window(self):
     """Only boxes with no overlap with the window are pruned.

     Boxes that partially overlap the window, exactly equal it, or fully
     contain it are all kept.  The 'extra_data' field must be gathered with
     the same keep indices as the boxes.
     """
     window = tf.constant([0, 0, 9, 14], tf.float32)
     corners = tf.constant([[5.0, 5.0, 6.0, 6.0], [-1.0, -2.0, 4.0, 5.0],
                            [2.0, 3.0, 5.0, 9.0], [0.0, 0.0, 9.0, 14.0],
                            [-10.0, -10.0, -9.0, -9.0],
                            [-100.0, -100.0, 300.0, 600.0]])
     boxes = box_list.BoxList(corners)
     boxes.add_field('extra_data',
                     tf.constant([[1], [2], [3], [4], [5], [6]]))
     # Row 4 ([-10, -10, -9, -9]) is the only box entirely outside the
     # window, so it is the only one expected to be dropped.
     exp_output = [[5.0, 5.0, 6.0, 6.0], [-1.0, -2.0, 4.0, 5.0],
                   [2.0, 3.0, 5.0, 9.0], [0.0, 0.0, 9.0, 14.0],
                   [-100.0, -100.0, 300.0, 600.0]]
     exp_keep_indices = [0, 1, 2, 3, 5]
     exp_extra_data = [[1], [2], [3], [4], [6]]
     pruned, keep_indices = box_list_ops.prune_completely_outside_window(
         boxes, window)
     with self.test_session() as sess:
         # Fetch all outputs in one run instead of re-executing the graph
         # once per tensor.
         pruned_output, keep_indices_out, extra_data_out = sess.run(
             [pruned.get(), keep_indices, pruned.get_field('extra_data')])
     # Expected value first, matching the tf.test.TestCase convention so
     # failure messages label the arrays correctly.
     self.assertAllClose(exp_output, pruned_output)
     self.assertAllEqual(exp_keep_indices, keep_indices_out)
     self.assertAllEqual(exp_extra_data, extra_data_out)