def postprocess(inputs, outputs, is_training, apply_nms, nms_score_threshold,
                nms_iou_threshold, nms_max_num_predicted_boxes,
                use_furthest_voxel_sampling, num_furthest_voxel_samples,
                sampler_score_vs_distance_coef):
    """Post-processor function applied to model outputs outside of training.

    When `is_training` is True this is a no-op. Otherwise it:
      1. Squeezes axis 0 out of every voxel, point and object output field
         (as listed by `standard_fields`) that is present in `outputs` with a
         non-None value, writing the result back into `outputs`.
      2. Calls `mask_valid_voxels(inputs=inputs, outputs=outputs)`.
      3. Calls `postprocessor.postprocess` on `outputs` with the NMS and
         furthest-voxel-sampling settings below.

    Args:
      inputs: Input dict forwarded to `mask_valid_voxels`.
      outputs: Mutable mapping of output tensors; updated in place.
      is_training: If True, nothing is done.
      apply_nms: Forwarded as `apply_nms`.
      nms_score_threshold: Forwarded as `score_thresh`.
      nms_iou_threshold: Forwarded as `iou_thresh`.
      nms_max_num_predicted_boxes: Forwarded as `max_output_size`.
      use_furthest_voxel_sampling: Forwarded unchanged.
      num_furthest_voxel_samples: Forwarded unchanged.
      sampler_score_vs_distance_coef: Forwarded unchanged.

    Returns:
      None. Results are communicated through `outputs`.
    """
    if is_training:
        return

    # Drop the leading axis of each output field. tf.squeeze(axis=0) requires
    # that axis to have size 1 (presumably a batch of one -- TODO confirm).
    # The getters are called in voxel -> point -> object order.
    field_getters = (
        standard_fields.get_output_voxel_fields,
        standard_fields.get_output_point_fields,
        standard_fields.get_output_object_fields,
    )
    for get_fields in field_getters:
        for field_name in get_fields():
            tensor = outputs.get(field_name)
            if tensor is not None:
                outputs[field_name] = tf.squeeze(tensor, axis=0)

    # Mask the valid voxels.
    mask_valid_voxels(inputs=inputs, outputs=outputs)

    # NMS / sampling over the predicted objects.
    postprocessor.postprocess(
        outputs=outputs,
        score_thresh=nms_score_threshold,
        iou_thresh=nms_iou_threshold,
        max_output_size=nms_max_num_predicted_boxes,
        use_furthest_voxel_sampling=use_furthest_voxel_sampling,
        num_furthest_voxel_samples=num_furthest_voxel_samples,
        sampler_score_vs_distance_coef=sampler_score_vs_distance_coef,
        apply_nms=apply_nms)
    def test_postprocess(self):
        """Checks output ranks after `postprocessor.postprocess` on random boxes.

        Builds 1000 random object predictions and runs `postprocessor.postprocess`
        on them in place. It then asserts that the per-object size, center, class
        and score fields are rank 2 and the rotation matrices are rank 3.
        """
        fields = standard_fields.DetectionResultFields
        num_boxes = 1000
        num_classes = 10

        def _uniform(shape, low, high):
            # All synthetic inputs are float32 uniform samples in [low, high).
            return tf.random.uniform(shape,
                                     minval=low,
                                     maxval=high,
                                     dtype=tf.float32)

        # Insertion order matches the order in which the tensors are created.
        outputs = {}
        outputs[fields.objects_score] = _uniform((num_boxes, num_classes),
                                                 -2.0, 2.0)
        outputs[fields.objects_rotation_matrix] = _uniform((num_boxes, 3, 3),
                                                           -1.0, 1.0)
        outputs[fields.objects_center] = _uniform((num_boxes, 3), 10.0, 20.0)
        for size_key in (fields.objects_length,
                         fields.objects_height,
                         fields.objects_width):
            outputs[size_key] = _uniform((num_boxes, 1), 0.1, 3.0)

        postprocessor.postprocess(outputs=outputs,
                                  score_thresh=0.1,
                                  iou_thresh=0.5,
                                  max_output_size=10)

        # objects_class is not an input here; it is expected to be produced
        # by postprocessor.postprocess.
        rank_two_keys = (
            fields.objects_length,
            fields.objects_height,
            fields.objects_width,
            fields.objects_center,
            fields.objects_class,
            fields.objects_score,
        )
        for key in rank_two_keys:
            self.assertEqual(len(outputs[key].shape), 2)
        rotation_rank = len(outputs[fields.objects_rotation_matrix].shape)
        self.assertEqual(rotation_rank, 3)