Example #1
0
 def graph_fn():
   """Filters a six-box BoxList on its 'classes' field, once per label.

   The caller that consumes this function is outside the visible snippet
   (presumably a test helper that builds and runs the graph — confirm).

   Returns:
     A 2-tuple: the `.get()` tensor of the boxes labelled 1, then the
     `.get()` tensor of the boxes labelled 2.
   """
   box_corners = tf.constant(
       [[0, 0, 1, 1],
        [0, 0.1, 1, 1.1],
        [0, -0.1, 1, 0.9],
        [0, 10, 1, 11],
        [0, 10.1, 1, 11.1],
        [0, 100, 1, 101]],
       tf.float32)
   labelled_boxes = box_list.BoxList(box_corners)
   labelled_boxes.add_field('classes', tf.constant([1, 2, 1, 2, 2, 1]))
   # Build both filters first, then read their tensors, in label order 1, 2.
   per_label = [
       box_list_ops.filter_field_value_equals(labelled_boxes, 'classes', label)
       for label in (1, 2)
   ]
   return tuple(subset.get() for subset in per_label)
    def test_filter_field_value_equals(self):
        """filter_field_value_equals keeps only boxes whose 'classes' matches.

        Six boxes are labelled with classes 1 and 2. Filtering on each label
        must return exactly that label's boxes, in their input order.
        """
        box_corners = tf.constant(
            [[0, 0, 1, 1],
             [0, 0.1, 1, 1.1],
             [0, -0.1, 1, 0.9],
             [0, 10, 1, 11],
             [0, 10.1, 1, 11.1],
             [0, 100, 1, 101]],
            tf.float32)
        labelled = box_list.BoxList(box_corners)
        labelled.add_field('classes', tf.constant([1, 2, 1, 2, 2, 1]))

        # Rows of box_corners that should survive each filter, in input order.
        expected_by_label = {
            1: [[0, 0, 1, 1], [0, -0.1, 1, 0.9], [0, 100, 1, 101]],
            2: [[0, 0.1, 1, 1.1], [0, 10, 1, 11], [0, 10.1, 1, 11.1]],
        }
        labels = (1, 2)
        kept = [
            box_list_ops.filter_field_value_equals(labelled, 'classes', label)
            for label in labels
        ]

        # NOTE(review): test_session() appears to be deprecated in newer TF
        # releases in favour of cached_session()/session(); confirm against
        # the TF version this project targets before migrating.
        with self.test_session() as sess:
            results = sess.run([subset.get() for subset in kept])
            for label, result in zip(labels, results):
                self.assertAllClose(result, expected_by_label[label])
  def test_filter_field_value_equals(self):
    """Checks filtering a BoxList on the 'classes' field for labels 1 and 2.

    Each filtered result must contain exactly the boxes carrying that label,
    in the same order they appear in the input.
    """
    six_boxes = box_list.BoxList(
        tf.constant([[0, 0, 1, 1],
                     [0, 0.1, 1, 1.1],
                     [0, -0.1, 1, 0.9],
                     [0, 10, 1, 11],
                     [0, 10.1, 1, 11.1],
                     [0, 100, 1, 101]], tf.float32))
    six_boxes.add_field('classes', tf.constant([1, 2, 1, 2, 2, 1]))

    only_ones = box_list_ops.filter_field_value_equals(
        six_boxes, 'classes', 1)
    only_twos = box_list_ops.filter_field_value_equals(
        six_boxes, 'classes', 2)

    # NOTE(review): test_session() may be deprecated depending on the TF
    # version in use (cached_session()/session() are the newer forms).
    with self.test_session() as sess:
      got_ones, got_twos = sess.run([only_ones.get(), only_twos.get()])
      self.assertAllClose(
          got_ones, [[0, 0, 1, 1], [0, -0.1, 1, 0.9], [0, 100, 1, 101]])
      self.assertAllClose(
          got_twos, [[0, 0.1, 1, 1.1], [0, 10, 1, 11], [0, 10.1, 1, 11.1]])