예제 #1
0
 def test_valid_input_data_type_proto(self):
   """Checks DetectionUniqueIdCalculator assigns ids across a DetectionList.

   Two detections (scores 0.5 and 0.8) are sent through SolutionBase on a
   DETECTION_LIST-tagged stream. The output list must contain both
   detections, in input order, with detection_id 1 and 2 respectively.
   """
   graph_text = """
     input_stream: 'input_detections'
     output_stream: 'output_detections'
     node {
       calculator: 'DetectionUniqueIdCalculator'
       input_stream: 'DETECTION_LIST:input_detections'
       output_stream: 'DETECTION_LIST:output_detections'
     }
   """
   graph_config = calculator_pb2.CalculatorGraphConfig()
   # Parse() fills the message in place; the returned value is the same object.
   text_format.Parse(graph_text, graph_config)
   with solution_base.SolutionBase(graph_config=graph_config) as solution:
     detection_list = detection_pb2.DetectionList()
     for score_text in ('score: 0.5', 'score: 0.8'):
       text_format.Parse(score_text, detection_list.detection.add())
     results = solution.process({'input_detections': detection_list})
     self.assertTrue(hasattr(results, 'output_detections'))
     produced = results.output_detections.detection
     self.assertLen(produced, 2)
     # Expected outputs are the inputs plus the sequential ids the calculator
     # is expected to stamp on them.
     expected_texts = ('score: 0.5, detection_id: 1',
                       'score: 0.8, detection_id: 2')
     for position, expected_text in enumerate(expected_texts):
       expected = detection_pb2.Detection()
       text_format.Parse(expected_text, expected)
       self.assertEqual(produced[position], expected)
예제 #2
0
 def test_unqualified_detection(self):
     """Checks draw_detection raises ValueError for GLOBAL location data.

     The detection's location_data uses format GLOBAL. The test expects the
     error message to mention that LocationData must be relative.
     """
     global_detection = detection_pb2.Detection()
     text_format.Parse('location_data {format: GLOBAL}', global_detection)
     rgb_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
     expected_error = 'LocationData must be relative'
     with self.assertRaisesRegex(ValueError, expected_error):
         drawing_utils.draw_detection(rgb_image, global_detection)
예제 #3
0
 def test_invalid_input_image(self):
     """Checks both drawing helpers reject a two-channel image.

     draw_landmarks and draw_detection are each expected to raise ValueError
     with the same "three channel rgb data" message for a (3, 3, 2) image.
     """
     two_channel_image = np.arange(18, dtype=np.uint8).reshape(3, 3, 2)
     channel_error = 'Input image must contain three channel rgb data.'
     with self.assertRaisesRegex(ValueError, channel_error):
         drawing_utils.draw_landmarks(two_channel_image,
                                      landmark_pb2.NormalizedLandmarkList())
     with self.assertRaisesRegex(ValueError, channel_error):
         drawing_utils.draw_detection(two_channel_image,
                                      detection_pb2.Detection())
 def test_draw_bboxs_only(self):
   """Checks draw_detection renders a full-frame relative bounding box.

   The detection carries only a relative bounding box covering the whole
   image (no keypoints). The rendered canvas must match a cv2.rectangle
   drawn from (0, 0) to (99, 99) using DEFAULT_BBOX_DRAWING_SPEC.
   """
   bbox_text = ('location_data {'
                '  format: RELATIVE_BOUNDING_BOX'
                '  relative_bounding_box {xmin: 0 ymin: 0 width: 1 height: 1}}')
   detection = detection_pb2.Detection()
   text_format.Parse(bbox_text, detection)
   canvas = np.zeros((100, 100, 3), np.uint8)
   reference = canvas.copy()
   spec = DEFAULT_BBOX_DRAWING_SPEC
   cv2.rectangle(reference, (0, 0), (99, 99), spec.color, spec.thickness)
   drawing_utils.draw_detection(canvas, detection)
   np.testing.assert_array_equal(canvas, reference)
 def test_draw_keypoints_only(self):
   """Checks draw_detection renders relative keypoints as circles.

   The detection holds two relative keypoints, (0, 1) and (1, 0), on a
   100x100 canvas. The output must match circles drawn with
   DEFAULT_CIRCLE_DRAWING_SPEC at pixel centers (0, 99) and (99, 0).
   """
   keypoints_text = ('location_data {'
                     '  format: RELATIVE_BOUNDING_BOX'
                     '  relative_keypoints {x: 0 y: 1}'
                     '  relative_keypoints {x: 1 y: 0}}')
   detection = detection_pb2.Detection()
   text_format.Parse(keypoints_text, detection)
   canvas = np.zeros((100, 100, 3), np.uint8)
   reference = np.copy(canvas)
   spec = DEFAULT_CIRCLE_DRAWING_SPEC
   # Expected pixel centers for the two relative keypoints, in input order.
   for center in ((0, 99), (99, 0)):
     cv2.circle(reference, center, spec.circle_radius, spec.color,
                spec.thickness)
   drawing_utils.draw_detection(canvas, detection)
   np.testing.assert_array_equal(canvas, reference)
 def test_invalid_input_data_type(self):
     """Checks SolutionBase.process rejects a non-image (PROTO_LIST) input.

     The graph routes DetectionUniqueIdCalculator through DETECTIONS-tagged
     streams. Feeding a single Detection proto is expected to raise
     NotImplementedError saying only image data is supported and PROTO_LIST
     is not.
     """
     graph_text = """
   input_stream: 'input_detections'
   output_stream: 'output_detections'
   node {
     calculator: 'DetectionUniqueIdCalculator'
     input_stream: 'DETECTIONS:input_detections'
     output_stream: 'DETECTIONS:output_detections'
   }
 """
     graph_config = calculator_pb2.CalculatorGraphConfig()
     text_format.Parse(graph_text, graph_config)
     expected_message = 'SolutionBase can only process image data. PROTO_LIST type is not supported.'
     with solution_base.SolutionBase(graph_config=graph_config) as solution:
         single_detection = detection_pb2.Detection()
         text_format.Parse('score: 0.5', single_detection)
         with self.assertRaisesRegex(NotImplementedError, expected_message):
             solution.process({'input_detections': single_detection})
예제 #7
0
 def testDetectionProtoPacket(self):
   detection = detection_pb2.Detection()
   text_format.Parse('score: 0.5', detection)
   p = mp.packet_creator.create_proto(detection).at(100)