Ejemplo n.º 1
0
 def graph_fn():
   """Builds a four-box pool and runs per-class refinement on it.

   Returns:
     A tuple (boxes, scores, classes) of tensors read from the BoxList
     returned by box_list_ops.refine_boxes_multi_class.
   """
   corners = tf.constant(
       [[0.1, 0.1, 0.4, 0.4],
        [0.1, 0.1, 0.5, 0.5],
        [0.6, 0.6, 0.8, 0.8],
        [0.2, 0.2, 0.3, 0.3]], tf.float32)
   box_pool = box_list.BoxList(corners)
   # Two boxes per class label; the scores drive the refinement.
   box_pool.add_field('classes', tf.constant([0, 0, 1, 1]))
   box_pool.add_field('scores', tf.constant([0.75, 0.25, 0.3, 0.2]))
   # NOTE(review): positional args presumably mean num_classes=3, an IoU
   # threshold of 0.5 and max detections of 10 -- confirm in box_list_ops.
   refined = box_list_ops.refine_boxes_multi_class(box_pool, 3, 0.5, 10)
   refined_scores = refined.get_field('scores')
   refined_classes = refined.get_field('classes')
   return (refined.get(), refined_scores, refined_classes)
Ejemplo n.º 2
0
  def test_refine_boxes_multi_class(self):
    """Checks per-class refinement of a small four-box pool.

    Per the expected values below, boxes 0 and 1 (both class 0) collapse
    into a single box. Its coordinates are the score-weighted average of
    the two (0.75 * 0.4 + 0.25 * 0.5 = 0.425) and its score is their mean
    (0.5). Boxes 2 and 3 (class 1) come through unchanged.
    """
    pool = box_list.BoxList(
        tf.constant([[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
                     [0.6, 0.6, 0.8, 0.8], [0.2, 0.2, 0.3, 0.3]], tf.float32))
    pool.add_field('classes', tf.constant([0, 0, 1, 1]))
    pool.add_field('scores', tf.constant([0.75, 0.25, 0.3, 0.2]))
    # NOTE(review): positional args presumably mean num_classes=3, an IoU
    # threshold of 0.5 and max detections of 10 -- confirm in box_list_ops.
    refined_boxes = box_list_ops.refine_boxes_multi_class(pool, 3, 0.5, 10)

    expected_boxes = [[0.1, 0.1, 0.425, 0.425], [0.6, 0.6, 0.8, 0.8],
                      [0.2, 0.2, 0.3, 0.3]]
    expected_scores = [0.5, 0.3, 0.2]
    expected_classes = [0, 1, 1]
    with self.test_session() as sess:
      boxes_out, scores_out, classes_out = sess.run(
          [refined_boxes.get(), refined_boxes.get_field('scores'),
           refined_boxes.get_field('classes')])

      # All assertions pass (expected, actual) in the same order.
      self.assertAllClose(expected_boxes, boxes_out)
      self.assertAllClose(expected_scores, scores_out)
      self.assertAllEqual(expected_classes, classes_out)
Ejemplo n.º 3
0
  def test_refine_boxes_multi_class(self):
    """Checks per-class refinement of a small four-box pool.

    Per the expected values below, boxes 0 and 1 (both class 0) collapse
    into a single box. Its coordinates are the score-weighted average of
    the two (0.75 * 0.4 + 0.25 * 0.5 = 0.425) and its score is their mean
    (0.5). Boxes 2 and 3 (class 1) come through unchanged.
    """
    pool = box_list.BoxList(
        tf.constant([[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
                     [0.6, 0.6, 0.8, 0.8], [0.2, 0.2, 0.3, 0.3]], tf.float32))
    pool.add_field('classes', tf.constant([0, 0, 1, 1]))
    pool.add_field('scores', tf.constant([0.75, 0.25, 0.3, 0.2]))
    # NOTE(review): positional args presumably mean num_classes=3, an IoU
    # threshold of 0.5 and max detections of 10 -- confirm in box_list_ops.
    refined_boxes = box_list_ops.refine_boxes_multi_class(pool, 3, 0.5, 10)

    expected_boxes = [[0.1, 0.1, 0.425, 0.425], [0.6, 0.6, 0.8, 0.8],
                      [0.2, 0.2, 0.3, 0.3]]
    expected_scores = [0.5, 0.3, 0.2]
    expected_classes = [0, 1, 1]
    with self.test_session() as sess:
      boxes_out, scores_out, classes_out = sess.run(
          [refined_boxes.get(), refined_boxes.get_field('scores'),
           refined_boxes.get_field('classes')])

      # All assertions pass (expected, actual) in the same order.
      self.assertAllClose(expected_boxes, boxes_out)
      self.assertAllClose(expected_scores, scores_out)
      self.assertAllEqual(expected_classes, classes_out)