def test_postprocess_per_class(self):
    """Compare per-class NMS in `generate_detections` with the legacy path.

    Two levels of random class/box feature maps are fed both to
    `postprocess.generate_detections` and to `inference.det_post_process`
    (with `disable_pyfun` set to False); the two results must be close.
    """
    tf.random.set_seed(1111)
    # The draw order matters under the fixed seed: class level 1, class
    # level 2, box level 1, box level 2 -- keep the literals in this order.
    cls_outputs = {
        1: tf.random.normal([2, 4, 4, 2]),
        2: tf.random.normal([2, 2, 2, 2]),
    }
    box_outputs = {
        1: tf.random.normal([2, 4, 4, 4]),
        2: tf.random.normal([2, 2, 2, 4]),
    }
    levels = (1, 2)
    image_scales = [1.0, 2.0]
    image_ids = [0, 1]

    self.params['max_detection_points'] = 10
    detections = postprocess.generate_detections(
        self.params,
        [cls_outputs[lvl] for lvl in levels],
        [box_outputs[lvl] for lvl in levels],
        image_scales,
        image_ids,
    )

    # Configure the legacy implementation; the batch size is taken from
    # the number of per-image scales.
    self.params['disable_pyfun'] = False
    min_score = 0.5
    nms_max_output = self.params['nms_configs']['max_output_size']
    self.params['batch_size'] = len(image_scales)
    legacy_detections = inference.det_post_process(
        self.params, cls_outputs, box_outputs, image_scales, min_score,
        nms_max_output)

    self.assertAllClose(detections, legacy_detections)
def test_postprocess_global(self):
    """Compare global NMS in `postprocess_global` with the legacy path.

    Two levels of random class/box feature maps go through
    `postprocess.postprocess_global`; the valid lengths are checked, then
    boxes, scores and classes are compared column-wise with the output of
    `inference.det_post_process` (run with `disable_pyfun` set to True).
    """
    tf.random.set_seed(1111)
    # The draw order matters under the fixed seed: class level 1, class
    # level 2, box level 1, box level 2 -- keep the literals in this order.
    cls_outputs = {
        1: tf.random.normal([2, 4, 4, 2]),
        2: tf.random.normal([2, 2, 2, 2]),
    }
    box_outputs = {
        1: tf.random.normal([2, 4, 4, 4]),
        2: tf.random.normal([2, 2, 2, 4]),
    }
    levels = (1, 2)
    image_scales = [1.0, 2.0]

    self.params['max_detection_points'] = 10
    boxes, scores, classes, valid_len = postprocess.postprocess_global(
        self.params,
        [cls_outputs[lvl] for lvl in levels],
        [box_outputs[lvl] for lvl in levels],
        image_scales,
    )
    self.assertAllClose(valid_len, [2, 2])

    # Configure the legacy implementation; the batch size is taken from
    # the number of per-image scales.
    self.params['disable_pyfun'] = True
    min_score = 0.5
    self.params['batch_size'] = len(image_scales)
    nms_max_output = self.params['nms_configs']['max_output_size']
    legacy = inference.det_post_process(
        self.params, cls_outputs, box_outputs, image_scales, min_score,
        nms_max_output)

    # Legacy rows are compared on columns 1-4 (box), 5 (score) and
    # 6 (class); column 0 is not checked here.
    expectations = (
        (boxes, slice(1, 5)),
        (scores, 5),
        (classes, 6),
    )
    for actual, column in expectations:
        self.assertAllClose(actual, legacy[:, :, column])