示例#1
0
    def test_gather_with_invalid_field(self):
        """Expect ValueError when gather is asked for fields never added.

        Only a 'weights' field is attached to the box list, yet the call
        requests 'foo' and 'bar'.
        """
        box_corners = tf.constant([[0.0] * 4, [1.0] * 4])
        gather_indices = tf.constant([0, 1], tf.int32)
        box_weights = tf.constant([[0.1], [0.3]], tf.float32)

        boxlist = box_list.BoxList(box_corners)
        boxlist.add_field('weights', box_weights)

        requested_fields = ['foo', 'bar']
        with self.assertRaises(ValueError):
            box_list_ops.gather(boxlist, gather_indices, requested_fields)
示例#2
0
 def test_gather_with_invalid_inputs(self):
     """Expect ValueError for float32 indices and for 2-D int32 indices."""
     box_corners = tf.constant([[float(row)] * 4 for row in range(5)])

     # Case 1: indices carry a float32 dtype.
     float_indices = tf.constant([0, 2, 4], tf.float32)
     boxlist = box_list.BoxList(box_corners)
     self.assertRaises(
         ValueError, box_list_ops.gather, boxlist, float_indices)

     # Case 2: indices are int32 but nested one level deep (2-D).
     nested_indices = tf.constant([[0, 2, 4]], tf.int32)
     boxlist = box_list.BoxList(box_corners)
     self.assertRaises(
         ValueError, box_list_ops.gather, boxlist, nested_indices)
示例#3
0
 def test_gather(self):
     """Gathering indices 0, 2 and 4 from five boxes yields those rows."""
     corner_rows = [[float(row)] * 4 for row in range(5)]
     picked = [0, 2, 4]

     box_corners = tf.constant(corner_rows)
     gather_indices = tf.constant(picked, tf.int32)
     # The expected result is simply the picked rows of the input corners.
     wanted_boxes = [corner_rows[row] for row in picked]

     boxlist = box_list.BoxList(box_corners)
     gathered = box_list_ops.gather(boxlist, gather_indices)
     with self.test_session() as sess:
         gathered_boxes = sess.run(gathered.get())
         self.assertAllClose(gathered_boxes, wanted_boxes)
示例#4
0
    def test_gather_with_field(self):
        """Gather indices 0, 2, 4 and check boxes plus the 'weights' field."""
        corner_rows = [[float(row)] * 4 for row in range(5)]
        box_corners = tf.constant(corner_rows)
        gather_indices = tf.constant([0, 2, 4], tf.int32)
        box_weights = tf.constant(
            [[0.1], [0.3], [0.5], [0.7], [0.9]], tf.float32)
        wanted_boxes = [[0.0] * 4, [2.0] * 4, [4.0] * 4]
        wanted_weights = [[0.1], [0.5], [0.9]]

        boxlist = box_list.BoxList(box_corners)
        boxlist.add_field('weights', box_weights)
        gathered = box_list_ops.gather(boxlist, gather_indices, ['weights'])
        with self.test_session() as sess:
            # Fetch the boxes and the field in a single run call.
            fetches = [gathered.get(), gathered.get_field('weights')]
            boxes_out, weights_out = sess.run(fetches)
            self.assertAllClose(boxes_out, wanted_boxes)
            self.assertAllClose(weights_out, wanted_weights)
示例#5
0
    def test_gather_with_dynamic_indexing(self):
        """Gather with indices derived in-graph from a weights threshold.

        The indices are the flattened positions where weights exceed 0.4,
        which for the weights used here are entries 0, 2 and 4.
        """
        box_corners = tf.constant([[float(row)] * 4 for row in range(5)])
        box_weights = tf.constant([0.5, 0.3, 0.7, 0.1, 0.9], tf.float32)

        # Build the index tensor step by step: mask -> positions -> flat.
        above_threshold = tf.greater(box_weights, 0.4)
        positions = tf.where(above_threshold)
        gather_indices = tf.reshape(positions, [-1])

        wanted_boxes = [[0.0] * 4, [2.0] * 4, [4.0] * 4]
        wanted_weights = [0.5, 0.7, 0.9]

        boxlist = box_list.BoxList(box_corners)
        boxlist.add_field('weights', box_weights)
        gathered = box_list_ops.gather(boxlist, gather_indices, ['weights'])
        with self.test_session() as sess:
            fetches = [gathered.get(), gathered.get_field('weights')]
            boxes_out, weights_out = sess.run(fetches)
            self.assertAllClose(boxes_out, wanted_boxes)
            self.assertAllClose(weights_out, wanted_weights)