def test_call_train(self):
    """Call the detection model with training=True and check output shapes.

    The model is built with three weighted loss functions and with rotation
    prediction enabled around the x, y and z axes. Every per-voxel output
    checked here must have shape (1, num_voxels, last_dim).
    """
    num_classes = 5
    num_voxels = 100

    # Loss callables, each pre-configured through functools.partial.
    loss_names_to_functions = {
        'box_corner_distance_loss_on_voxel_tensors': functools.partial(
            box_prediction_losses.box_corner_distance_loss_on_voxel_tensors,
            is_intermediate=False,
            loss_type='absolute_difference',
            is_balanced=True),
        'box_classification_using_center_distance_loss': functools.partial(
            classification_losses.box_classification_using_center_distance_loss,
            is_intermediate=False,
            is_balanced=True,
            max_positive_normalized_distance=0.3),
        'hard_negative_classification_loss': functools.partial(
            classification_losses.hard_negative_classification_loss,
            is_intermediate=False,
            gamma=1.0),
    }
    loss_names_to_weights = {
        'box_corner_distance_loss_on_voxel_tensors': 5.0,
        'box_classification_using_center_distance_loss': 1.0,
        'hard_negative_classification_loss': 1.0,
    }

    object_detection_model = model.ObjectDetectionModel(
        loss_names_to_functions=loss_names_to_functions,
        loss_names_to_weights=loss_names_to_weights,
        num_classes=num_classes,
        predict_rotation_x=True,
        predict_rotation_y=True,
        predict_rotation_z=True)
    inputs = self.get_inputs(num_voxels=num_voxels, num_classes=num_classes)
    outputs = object_detection_model(inputs, training=True)

    # (output field, expected size of the last dimension), in check order.
    result_fields = standard_fields.DetectionResultFields
    expected_last_dims = (
        (result_fields.object_semantic_voxels, num_classes),
        (result_fields.object_center_voxels, 3),
        (result_fields.object_length_voxels, 1),
        (result_fields.object_height_voxels, 1),
        (result_fields.object_width_voxels, 1),
        (result_fields.object_rotation_x_cos_voxels, 1),
        (result_fields.object_rotation_x_sin_voxels, 1),
        (result_fields.object_rotation_y_cos_voxels, 1),
        (result_fields.object_rotation_y_sin_voxels, 1),
        (result_fields.object_rotation_z_cos_voxels, 1),
        (result_fields.object_rotation_z_sin_voxels, 1),
    )
    for field, last_dim in expected_last_dims:
        self.assertEqual(
            outputs[field].get_shape(), (1, num_voxels, last_dim))
def test_call_eval(self):
    """Call the detection model with training=False and check its outputs.

    The model is built without loss functions and with rotation prediction
    enabled around the x, y and z axes. Every per-voxel output checked here
    must have shape (num_valid_voxels, last_dim), where num_valid_voxels is
    read from the first entry of the inputs. The object-level result keys
    must also be present in the outputs.
    """
    num_classes = 5
    num_voxels = 100

    object_detection_model = model.ObjectDetectionModel(
        num_classes=num_classes,
        predict_rotation_x=True,
        predict_rotation_y=True,
        predict_rotation_z=True)
    inputs = self.get_inputs(num_voxels=num_voxels, num_classes=num_classes)
    outputs = object_detection_model(inputs, training=False)

    valid_voxels_tensor = inputs[
        standard_fields.InputDataFields.num_valid_voxels]
    num_valid_voxels = valid_voxels_tensor.numpy()[0]

    # (output field, expected size of the last dimension), in check order.
    result_fields = standard_fields.DetectionResultFields
    expected_last_dims = (
        (result_fields.object_semantic_voxels, num_classes),
        (result_fields.object_center_voxels, 3),
        (result_fields.object_length_voxels, 1),
        (result_fields.object_height_voxels, 1),
        (result_fields.object_width_voxels, 1),
        (result_fields.object_rotation_x_cos_voxels, 1),
        (result_fields.object_rotation_x_sin_voxels, 1),
        (result_fields.object_rotation_y_cos_voxels, 1),
        (result_fields.object_rotation_y_sin_voxels, 1),
        (result_fields.object_rotation_z_cos_voxels, 1),
        (result_fields.object_rotation_z_sin_voxels, 1),
    )
    for field, last_dim in expected_last_dims:
        self.assertEqual(
            outputs[field].get_shape(), (num_valid_voxels, last_dim))

    # Object-level keys only need to be present; their shapes are not checked.
    expected_object_keys = (
        result_fields.objects_center,
        result_fields.objects_length,
        result_fields.objects_height,
        result_fields.objects_width,
        result_fields.objects_rotation_matrix,
        result_fields.objects_rotation_x_cos,
        result_fields.objects_rotation_x_sin,
        result_fields.objects_rotation_y_cos,
        result_fields.objects_rotation_y_sin,
        result_fields.objects_rotation_z_cos,
        result_fields.objects_rotation_z_sin,
    )
    for key in expected_object_keys:
        self.assertIn(key, outputs)