示例#1
0
  def test_gather_with_invalid_field_specified(self):
    """Checks that gather raises ValueError for unusable `fields` arguments.

    Two cases are exercised, each expected to raise on its own: a bare
    string ('labels') passed where a list of field names is expected, and a
    list naming a field ('objectness') that the fixture presumably does not
    contain.
    """
    fixture = self.boxlist
    order = np.array([2, 0, 1], dtype=int)

    for bad_fields in ('labels', ['objectness']):
      with self.assertRaises(ValueError):
        np_box_list_ops.gather(fixture, order, bad_fields)
示例#2
0
  def test_gather_with_fields_specified(self):
    """Checks that gather copies only the explicitly requested fields.

    Gathering with fields=['labels'] must return the boxes and the 'labels'
    field taken at `indices`, and must not carry over the 'scores' field
    (presumably present on the fixture).
    """
    indices = np.array([2, 0, 1], dtype=int)
    boxlist = self.boxlist
    subboxlist = np_box_list_ops.gather(boxlist, indices, ['labels'])

    # Only the requested field should survive the gather.
    self.assertTrue(subboxlist.has_field('labels'))
    self.assertFalse(subboxlist.has_field('scores'))

    expected_boxes = np.array([[0.0, 0.0, 20.0, 20.0], [3.0, 4.0, 6.0, 8.0],
                               [14.0, 14.0, 15.0, 15.0]],
                              dtype=float)
    self.assertAllClose(expected_boxes, subboxlist.get())

    # Labels are integer one-hot rows, so compare them exactly rather than
    # with a floating-point tolerance.
    expected_labels = np.array([[0, 0, 0, 0, 1], [0, 0, 0, 1, 0],
                                [0, 1, 0, 0, 0]],
                               dtype=int)
    self.assertAllEqual(expected_labels, subboxlist.get_field('labels'))
示例#3
0
 def test_gather_with_invalid_multidimensional_indices(self):
   """Checks that gather raises ValueError for a 2-D index array."""
   nested_order = np.array([[0, 1], [1, 2]], dtype=int)
   with self.assertRaises(ValueError):
     np_box_list_ops.gather(self.boxlist, nested_order)
示例#4
0
 def test_gather_with_out_of_range_indices(self):
   """Checks that gather raises ValueError when an index is out of range.

   NOTE(review): this assumes self.boxlist holds three boxes, so index 3 is
   one past the last valid position. Confirm against the fixture setup.
   """
   fixture = self.boxlist
   too_far = np.array([3, 1], dtype=int)
   with self.assertRaises(ValueError):
     np_box_list_ops.gather(fixture, too_far)