def test_preprocess_output_keys(self):
     """Restricting `output_keys` to the camera image returns only that key."""
     image_key = standard_fields.InputDataFields.camera_image
     height = 240
     width = 320
     input_dict = self._get_input_dict(height, width)
     output_dict = preprocessor.preprocess(
         inputs=input_dict,
         images_points_correspondence_fn=self._image_correspondence_fn,
         output_keys=[image_key],
         image_preprocess_fn_dic=None)
     # "Key is present" plus "exactly one key" is the same as the key set
     # being exactly {camera_image}; one assertion shows any extra keys.
     self.assertEqual(set(output_dict.keys()), {image_key})
 def test_preprocess_output_shapes(self):
     """Checks keys, shapes and point count of the default preprocess output.

     No `output_keys` argument is passed, so this exercises the default
     output set. The expected per-key shapes live in a single table, so every
     key whose presence is checked also has its shape checked. `subTest`
     reports every mismatching key in one run instead of stopping at the
     first failure.
     """
     height, width = (240, 320)
     input_dict = self._get_input_dict(height, width)
     object_keys = preprocessor._OBJECT_KEYS
     fields = standard_fields.InputDataFields
     # The leading 100 is presumably the number of points produced by the
     # `_get_input_dict` fixture -- TODO confirm against the fixture.
     expected_shapes = {
         fields.camera_intrinsics: (3, 3),
         fields.camera_rotation_matrix: (3, 3),
         fields.camera_translation: (3,),
         fields.point_positions: (100, 3),
         fields.object_class_points: (100,),
         fields.object_center_points: (100, 3),
         fields.object_height_points: (100, 1),
         fields.object_width_points: (100, 1),
         fields.object_rotation_matrix_points: (100, 3, 3),
         fields.object_length_points: (100, 1),
         fields.object_instance_id_points: (100,),
     }
     output_dict = preprocessor.preprocess(
         inputs=input_dict,
         images_points_correspondence_fn=self._image_correspondence_fn,
         image_preprocess_fn_dic=None)
     # num_valid_points is checked by value rather than by shape.
     self.assertIn(fields.num_valid_points, output_dict)
     self.assertEqual(output_dict[fields.num_valid_points].numpy(), 100)
     for key, expected_shape in expected_shapes.items():
         with self.subTest(key=key):
             # Check presence first so a missing key fails the assertion
             # instead of erroring with KeyError.
             self.assertIn(key, output_dict)
             self.assertEqual(output_dict[key].shape, expected_shape)
     # For per-object fields only the leading dimension is checked. The value
     # 2 is presumably the object count in the fixture -- TODO confirm.
     for key in object_keys:
         with self.subTest(key=key):
             self.assertIn(key, output_dict)
             self.assertEqual(output_dict[key].shape[0], 2)
 def test_preprocess_missing_input_raises(self):
     """Passing an empty `inputs` dict to `preprocess` must raise ValueError."""
     # Callable form: only the preprocess call itself is expected to raise.
     self.assertRaises(ValueError, preprocessor.preprocess, inputs={})