Exemplo n.º 1
0
    def graph_fn():
      """Builds two BoxLists and copies extra fields from one to the other.

      A source BoxList is given the fields 'tensor1' and 'tensor2'. It is then
      passed to `box_list_ops._copy_extra_fields` together with a second
      BoxList that has different corners.

      Returns:
        A 2-tuple with the 'tensor1' and 'tensor2' fields read from the
        BoxList returned by `box_list_ops._copy_extra_fields`.
      """
      # NOTE(review): `tensor1` / `tensor2` are free variables of the enclosing
      # scope (not visible here); presumably numpy arrays — confirm at caller.
      source = box_list.BoxList(
          tf.constant([[0, 0, 1, 1], [0, 0.1, 1, 1.1]], tf.float32))
      for field_name, field_value in (('tensor1', tensor1),
                                      ('tensor2', tensor2)):
        source.add_field(field_name, tf.constant(field_value))

      target_corners = tf.constant([[0, 0, 10, 10], [1, 3, 5, 5]], tf.float32)
      target = box_list_ops._copy_extra_fields(
          box_list.BoxList(target_corners), source)
      return tuple(target.get_field(name) for name in ('tensor1', 'tensor2'))
Exemplo n.º 2
0
 def test_copy_extra_fields(self):
   """Fields added to a source BoxList are readable after copying.

   Two extra fields are attached to a source BoxList. The source is then
   copied onto a second BoxList via `box_list_ops._copy_extra_fields`, and
   each field evaluated from the result must match its numpy input.
   """
   source = box_list.BoxList(
       tf.constant([[0, 0, 1, 1], [0, 0.1, 1, 1.1]], tf.float32))
   # (field name, expected numpy value) pairs; a tuple keeps the order fixed.
   field_values = (('tensor1', np.array([[1], [4]])),
                   ('tensor2', np.array([[1, 1], [2, 2]])))
   for name, value in field_values:
     source.add_field(name, tf.constant(value))

   target = box_list_ops._copy_extra_fields(
       box_list.BoxList(tf.constant([[0, 0, 10, 10], [1, 3, 5, 5]],
                                    tf.float32)),
       source)
   # NOTE(review): `test_session` is deprecated in later TF releases in favor
   # of `cached_session`; confirm the TF version in use before switching.
   with self.test_session() as sess:
     for name, value in field_values:
       self.assertAllClose(value, sess.run(target.get_field(name)))
Exemplo n.º 3
0
 def test_copy_extra_fields(self):
   """Checks that `box_list_ops._copy_extra_fields` carries both fields over.

   The source BoxList holds 'tensor1' and 'tensor2'. The BoxList returned by
   the copy must yield the same values when its fields are evaluated.
   """
   src_corners = tf.constant([[0, 0, 1, 1],
                              [0, 0.1, 1, 1.1]], tf.float32)
   src_boxes = box_list.BoxList(src_corners)
   expected_tensor1 = np.array([[1], [4]])
   expected_tensor2 = np.array([[1, 1], [2, 2]])
   src_boxes.add_field('tensor1', tf.constant(expected_tensor1))
   src_boxes.add_field('tensor2', tf.constant(expected_tensor2))

   dst_corners = tf.constant([[0, 0, 10, 10],
                              [1, 3, 5, 5]], tf.float32)
   dst_boxes = box_list.BoxList(dst_corners)
   copied = box_list_ops._copy_extra_fields(dst_boxes, src_boxes)
   # NOTE(review): `test_session` is deprecated in later TF releases in favor
   # of `cached_session`; confirm the TF version in use before switching.
   with self.test_session() as sess:
     actual_tensor1 = sess.run(copied.get_field('tensor1'))
     self.assertAllClose(expected_tensor1, actual_tensor1)
     actual_tensor2 = sess.run(copied.get_field('tensor2'))
     self.assertAllClose(expected_tensor2, actual_tensor2)