def test_gather_with_invalid_field_specified(self):
        """gather must reject a bare-string field spec and an unknown field name."""
        order = np.array([2, 0, 1], dtype=int)
        source = self.boxlist

        # Each of these `fields` arguments is expected to raise ValueError:
        # a plain string rather than a list, and a field the boxlist lacks.
        for bad_fields in ('labels', ['objectness']):
            with self.assertRaises(ValueError):
                np_box_list_ops.gather(source, order, bad_fields)
def gather(box_mask_list, indices, fields=None):
    """Gather boxes from np_box_mask_list.BoxMaskList according to indices.

    By default, gather returns boxes corresponding to the input index list, as
    well as all additional fields stored in the box_mask_list (indexing into the
    first dimension).  However one can optionally only gather from a subset of
    fields.  When a subset is given, 'masks' is always added to it so the masks
    are gathered too (presumably required by box_list_to_box_mask_list).

    Args:
      box_mask_list: np_box_mask_list.BoxMaskList holding N boxes
      indices: a 1-d numpy array of type int_
      fields: (optional) sequence of field names to also gather from.  If None
          (default), all fields are gathered from.  Pass an empty sequence to
          gather only the box coordinates and the masks.  The caller's
          sequence is never modified.

    Returns:
      subbox_mask_list: a np_box_mask_list.BoxMaskList corresponding to the subset
          of the input box_mask_list specified by indices

    Raises:
      ValueError: propagated from np_box_list_ops.gather, e.g. if a specified
          field is not contained in box_mask_list or if the indices are invalid
          (out of range, not 1-d, or presumably not of type int_).
    """
    if fields is not None and 'masks' not in fields:
        # Build a new list rather than appending in place, so the caller's
        # `fields` is left untouched and non-list sequences (tuples) work.
        fields = list(fields) + ['masks']
    return box_list_to_box_mask_list(
        np_box_list_ops.gather(boxlist=box_mask_list,
                               indices=indices,
                               fields=fields))
    def test_gather_with_fields_specified(self):
        """Gathering with fields=['labels'] keeps labels and drops scores."""
        source = self.boxlist
        order = np.array([2, 0, 1], dtype=int)
        picked = np_box_list_ops.gather(source, order, ['labels'])

        # Only the explicitly requested extra field should be carried over.
        self.assertFalse(picked.has_field('scores'))

        want_boxes = np.array([
            [0.0, 0.0, 20.0, 20.0],
            [3.0, 4.0, 6.0, 8.0],
            [14.0, 14.0, 15.0, 15.0],
        ], dtype=float)
        want_labels = np.array([
            [0, 0, 0, 0, 1],
            [0, 0, 0, 1, 0],
            [0, 1, 0, 0, 0],
        ], dtype=int)

        self.assertAllClose(want_boxes, picked.get())
        self.assertAllClose(want_labels, picked.get_field('labels'))
 def test_gather_with_invalid_multidimensional_indices(self):
     """A 2-d index array is expected to make gather raise ValueError."""
     nested_order = np.array([[0, 1], [1, 2]], dtype=int)
     source = self.boxlist
     with self.assertRaises(ValueError):
         np_box_list_ops.gather(source, nested_order)
 def test_gather_with_out_of_range_indices(self):
     """An index past the end of the fixture boxlist should raise ValueError."""
     # Index 3 is presumed to exceed the fixture's box count.
     bad_order = np.array([3, 1], dtype=int)
     source = self.boxlist
     with self.assertRaises(ValueError):
         np_box_list_ops.gather(source, bad_order)