def test_preprocess_output_keys(self):
  """Checks that `output_keys` restricts the result to exactly those keys.

  `preprocess` is called with a single requested output key
  (`camera_image`); the returned dict must contain that key and nothing
  else.
  """
  height, width = (240, 320)
  input_dict = self._get_input_dict(height, width)
  output_dict = preprocessor.preprocess(
      inputs=input_dict,
      images_points_correspondence_fn=self._image_correspondence_fn,
      output_keys=[standard_fields.InputDataFields.camera_image],
      image_preprocess_fn_dic=None)
  # Compare the full key set instead of `len(output_dict.keys()) == 1` so a
  # failure reports exactly which unexpected keys leaked through.
  self.assertEqual(
      set(output_dict), {standard_fields.InputDataFields.camera_image})
def test_preprocess_output_shapes(self):
  """Checks the default outputs of `preprocess` and their shapes.

  With no `output_keys` given, every field listed below must be present
  in the result with the listed shape, `num_valid_points` must equal 100,
  and each key in `preprocessor._OBJECT_KEYS` must have a leading
  dimension of 2.
  """
  fields = standard_fields.InputDataFields
  height, width = 240, 320
  input_dict = self._get_input_dict(height, width)
  object_keys = preprocessor._OBJECT_KEYS
  # A sequence of pairs (rather than a dict) so no entry can be silently
  # merged with another should two field constants ever compare equal.
  expected_shapes = (
      (fields.camera_intrinsics, (3, 3)),
      (fields.camera_rotation_matrix, (3, 3)),
      (fields.camera_translation, (3,)),
      (fields.point_positions, (100, 3)),
      (fields.object_class_points, (100,)),
      (fields.object_center_points, (100, 3)),
      (fields.object_height_points, (100, 1)),
      (fields.object_width_points, (100, 1)),
      (fields.object_length_points, (100, 1)),
      (fields.object_rotation_matrix_points, (100, 3, 3)),
      (fields.object_instance_id_points, (100,)),
  )
  output_dict = preprocessor.preprocess(
      inputs=input_dict,
      images_points_correspondence_fn=self._image_correspondence_fn,
      image_preprocess_fn_dic=None)

  # Every shape-checked field, plus the scalar point count, must exist.
  required_keys = [key for key, _ in expected_shapes]
  required_keys.append(fields.num_valid_points)
  for key in required_keys:
    self.assertIn(key, output_dict)

  for key, shape in expected_shapes:
    self.assertEqual(output_dict[key].shape, shape)
  self.assertEqual(output_dict[fields.num_valid_points].numpy(), 100)

  for key in object_keys:
    self.assertEqual(output_dict[key].shape[0], 2)
def test_preprocess_missing_input_raises(self):
  """Expects `preprocess` to raise ValueError for an empty inputs dict."""
  # Build the fixture outside the context manager so only the call under
  # test can satisfy the assertRaises expectation.
  empty_input = {}
  with self.assertRaises(ValueError):
    preprocessor.preprocess(inputs=empty_input)