def test_gather_with_invalid_field(self):
    """Expect ValueError when gathering fields never added to the BoxList."""
    box_corners = tf.constant([[0.0] * 4, [1.0] * 4])
    gather_indices = tf.constant([0, 1], tf.int32)
    box_weights = tf.constant([[0.1], [0.3]], tf.float32)

    boxlist = box_list.BoxList(box_corners)
    boxlist.add_field('weights', box_weights)

    # Only 'weights' was added above, so asking for 'foo' and 'bar' must fail.
    unknown_fields = ['foo', 'bar']
    with self.assertRaises(ValueError):
        box_list_ops.gather(boxlist, gather_indices, unknown_fields)
def test_gather_with_invalid_inputs(self):
    """Expect ValueError for float32 indices and for rank-2 indices."""
    # Five boxes; box i has all four coordinates equal to float(i).
    box_corners = tf.constant([[float(row)] * 4 for row in range(5)])

    # Case 1: the index values are fine but the dtype is float32.
    float_indices = tf.constant([0, 2, 4], tf.float32)
    boxlist = box_list.BoxList(box_corners)
    with self.assertRaises(ValueError):
        box_list_ops.gather(boxlist, float_indices)

    # Case 2: int32 indices nested one level deeper (a 1 x 3 matrix).
    nested_indices = tf.constant([[0, 2, 4]], tf.int32)
    boxlist = box_list.BoxList(box_corners)
    with self.assertRaises(ValueError):
        box_list_ops.gather(boxlist, nested_indices)
def test_gather(self):
    """Gathering rows 0, 2 and 4 of five boxes returns those three boxes."""
    # Five boxes; box i has all four coordinates equal to float(i).
    all_corners = [[float(row)] * 4 for row in range(5)]
    picked_rows = [0, 2, 4]

    box_corners = tf.constant(all_corners)
    gather_indices = tf.constant(picked_rows, tf.int32)
    want_corners = [all_corners[row] for row in picked_rows]

    boxlist = box_list.BoxList(box_corners)
    gathered = box_list_ops.gather(boxlist, gather_indices)

    with self.test_session() as sess:
        got_corners = sess.run(gathered.get())
        self.assertAllClose(got_corners, want_corners)
def test_gather_with_field(self):
    """Gathering with the 'weights' field selects matching box and weight rows."""
    # Five boxes; box i has all four coordinates equal to float(i).
    all_corners = [[float(row)] * 4 for row in range(5)]
    all_weights = [[0.1], [0.3], [0.5], [0.7], [0.9]]
    picked_rows = [0, 2, 4]

    box_corners = tf.constant(all_corners)
    gather_indices = tf.constant(picked_rows, tf.int32)
    box_weights = tf.constant(all_weights, tf.float32)

    want_corners = [all_corners[row] for row in picked_rows]
    want_weights = [all_weights[row] for row in picked_rows]

    boxlist = box_list.BoxList(box_corners)
    boxlist.add_field('weights', box_weights)
    gathered = box_list_ops.gather(boxlist, gather_indices, ['weights'])

    with self.test_session() as sess:
        fetches = [gathered.get(), gathered.get_field('weights')]
        got_corners, got_weights = sess.run(fetches)
        self.assertAllClose(got_corners, want_corners)
        self.assertAllClose(got_weights, want_weights)
def test_gather_with_dynamic_indexing(self):
    """Gather using indices computed from the weights instead of a constant.

    The indices come from tf.where over `weights > 0.4`, so the boxes with
    weights .5, .7 and .9 (rows 0, 2 and 4) are the ones expected back.
    """
    # Five boxes; box i has all four coordinates equal to float(i).
    box_corners = tf.constant([[float(row)] * 4 for row in range(5)])
    box_weights = tf.constant([0.5, 0.3, 0.7, 0.1, 0.9], tf.float32)

    # Flatten the tf.where result into a 1-D index vector.
    above_threshold = tf.greater(box_weights, 0.4)
    gather_indices = tf.reshape(tf.where(above_threshold), [-1])

    want_corners = [[0.0] * 4, [2.0] * 4, [4.0] * 4]
    want_weights = [0.5, 0.7, 0.9]

    boxlist = box_list.BoxList(box_corners)
    boxlist.add_field('weights', box_weights)
    gathered = box_list_ops.gather(boxlist, gather_indices, ['weights'])

    with self.test_session() as sess:
        fetches = [gathered.get(), gathered.get_field('weights')]
        got_corners, got_weights = sess.run(fetches)
        self.assertAllClose(got_corners, want_corners)
        self.assertAllClose(got_weights, want_weights)