Beispiel #1
0
    def test_postprocess_per_class(self):
        """Compare per-class NMS post-processing with the legacy path.

        Seeded random class and box predictions, keyed by 1 and 2
        (presumably feature levels -- confirm), are fed to
        `postprocess.generate_detections` and to
        `inference.det_post_process` with `disable_pyfun` set to False.
        The two detection results must be numerically close.
        """
        tf.random.set_seed(1111)
        # The draws happen in a fixed order (cls 1, cls 2, box 1, box 2).
        # With only a global seed set, each op's seed depends on call order,
        # so reordering these would change the generated values.
        cls_outputs = {
            level: tf.random.normal(shape)
            for level, shape in ((1, [2, 4, 4, 2]), (2, [2, 2, 2, 2]))
        }
        box_outputs = {
            level: tf.random.normal(shape)
            for level, shape in ((1, [2, 4, 4, 4]), (2, [2, 2, 2, 4]))
        }
        levels = (1, 2)
        cls_per_level = [cls_outputs[level] for level in levels]
        box_per_level = [box_outputs[level] for level in levels]
        image_scales = [1.0, 2.0]
        image_ids = [0, 1]

        params = self.params
        params['max_detection_points'] = 10
        detections = postprocess.generate_detections(
            params, cls_per_level, box_per_level, image_scales, image_ids)

        # The legacy path is configured only after the new path has run, so
        # these settings cannot influence `generate_detections` above.
        params['disable_pyfun'] = False
        nms_score_threshold = 0.5
        nms_max_output = params['nms_configs']['max_output_size']
        params['batch_size'] = len(image_scales)
        legacy_detections = inference.det_post_process(
            params,
            cls_outputs,
            box_outputs,
            image_scales,
            nms_score_threshold,
            nms_max_output,
        )
        self.assertAllClose(detections, legacy_detections)
Beispiel #2
0
    def test_postprocess_global(self):
        """Compare global NMS post-processing with the legacy path.

        Seeded random class and box predictions, keyed by 1 and 2, go through
        `postprocess.postprocess_global`. The test expects two valid
        detections per image. The boxes, scores and classes must then match
        the corresponding columns of `inference.det_post_process` output
        computed with `disable_pyfun` set to True.
        """
        tf.random.set_seed(1111)
        # The draws happen in a fixed order (cls 1, cls 2, box 1, box 2).
        # With only a global seed set, each op's seed depends on call order,
        # so reordering these would change the generated values.
        cls_outputs = {
            level: tf.random.normal(shape)
            for level, shape in ((1, [2, 4, 4, 2]), (2, [2, 2, 2, 2]))
        }
        box_outputs = {
            level: tf.random.normal(shape)
            for level, shape in ((1, [2, 4, 4, 4]), (2, [2, 2, 2, 4]))
        }
        levels = (1, 2)
        cls_per_level = [cls_outputs[level] for level in levels]
        box_per_level = [box_outputs[level] for level in levels]
        image_scales = [1.0, 2.0]

        params = self.params
        params['max_detection_points'] = 10
        boxes, scores, classes, valid_len = postprocess.postprocess_global(
            params, cls_per_level, box_per_level, image_scales)
        self.assertAllClose(valid_len, [2, 2])

        # The legacy path is configured only after the new path has run.
        params['disable_pyfun'] = True
        nms_score_threshold = 0.5
        params['batch_size'] = len(image_scales)
        nms_max_output = params['nms_configs']['max_output_size']
        legacy = inference.det_post_process(
            params,
            cls_outputs,
            box_outputs,
            image_scales,
            nms_score_threshold,
            nms_max_output,
        )
        # Legacy rows are sliced as: columns 1-4 = box, 5 = score, 6 = class.
        # Column 0 is not compared here (its meaning is not shown -- verify).
        expected_pairs = (
            (boxes, legacy[:, :, 1:5]),
            (scores, legacy[:, :, 5]),
            (classes, legacy[:, :, 6]),
        )
        for actual, expected in expected_pairs:
            self.assertAllClose(actual, expected)