def test_invalid_initialization_arguments(self):
  """SolutionBase must be given exactly one graph source.

  Both the "no source" and the "two sources" constructions are expected to
  raise ValueError with the same message.
  """
  expected_error = (
      'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.')

  # Case 1: neither graph_config nor binary_graph_path supplied.
  with self.assertRaisesRegex(ValueError, expected_error):
    solution_base.SolutionBase()

  # Case 2: both graph_config and binary_graph_path supplied.
  with self.assertRaisesRegex(ValueError, expected_error):
    solution_base.SolutionBase(
        graph_config=calculator_pb2.CalculatorGraphConfig(),
        binary_graph_path='/tmp/no_such.binarypb')
def test_invalid_calculator_options(self):
  """calculator_params targeting the 'SignalGate' node must be rejected.

  The graph contains a GateCalculator node named 'SignalGate'. The test
  expects SolutionBase construction to raise ValueError when
  calculator_params tries to set a field on that node.
  """
  gated_graph_text = """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        name: 'SignalGate'
        calculator: 'GateCalculator'
        input_stream: 'transformed_image_in'
        input_side_packet: 'ALLOW:allow_signal'
        output_stream: 'image_out_to_transform'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_out_to_transform'
        output_stream: 'IMAGE:image_out'
      }
  """
  gated_graph = text_format.Parse(gated_graph_text,
                                  calculator_pb2.CalculatorGraphConfig())
  bad_params = {'SignalGate.invalid_field': 'I am invalid'}

  with self.assertRaisesRegex(
      ValueError,
      'Modifying the calculator options of SignalGate is not supported.'):
    solution_base.SolutionBase(
        graph_config=gated_graph, calculator_params=bad_params)
def test_valid_input_data_type_proto(self):
  """A DetectionList proto input is processed and each detection gets an id.

  Two detections (scores 0.5 and 0.8) go through DetectionUniqueIdCalculator.
  The output is expected to hold both, in order, with detection_id 1 and 2.
  """
  graph_text = """
      input_stream: 'input_detections'
      output_stream: 'output_detections'
      node {
        calculator: 'DetectionUniqueIdCalculator'
        input_stream: 'DETECTION_LIST:input_detections'
        output_stream: 'DETECTION_LIST:output_detections'
      }
  """
  graph = text_format.Parse(graph_text, calculator_pb2.CalculatorGraphConfig())

  input_texts = ('score: 0.5', 'score: 0.8')
  expected_texts = ('score: 0.5, detection_id: 1',
                    'score: 0.8, detection_id: 2')

  with solution_base.SolutionBase(graph_config=graph) as solution:
    detection_list = detection_pb2.DetectionList()
    for detection_text in input_texts:
      text_format.Parse(detection_text, detection_list.detection.add())

    results = solution.process({'input_detections': detection_list})
    self.assertTrue(hasattr(results, 'output_detections'))
    self.assertLen(results.output_detections.detection, 2)

    expected = []
    for detection_text in expected_texts:
      want = detection_pb2.Detection()
      text_format.Parse(detection_text, want)
      expected.append(want)
    for index, want in enumerate(expected):
      self.assertEqual(results.output_detections.detection[index], want)
def test_solution_reset(self, text_config, side_inputs):
  """Repeated process()/reset() cycles keep returning the input image.

  Runs 20 iterations. Each iteration processes a fixed 3x3x3 uint8 image,
  checks that `image_out` equals the input, then resets the solution.

  Args:
    text_config: Text-format CalculatorGraphConfig for the graph under test.
    side_inputs: Side inputs forwarded to SolutionBase. These are supplied by
      the test's parameterization, which is not visible here.
  """
  config_proto = text_format.Parse(text_config,
                                   calculator_pb2.CalculatorGraphConfig())
  input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
  with solution_base.SolutionBase(
      graph_config=config_proto, side_inputs=side_inputs) as solution:
    for _ in range(20):
      outputs = solution.process(input_image)
      # assert_array_equal reports the shape/element mismatch on failure,
      # whereas assertTrue(np.array_equal(...)) only reports "False".
      np.testing.assert_array_equal(outputs.image_out, input_image)
      solution.reset()
def test_calculator_has_both_options_and_node_options(self):
  """Overriding options of the 'ImageTransformation' node must raise.

  The test parses CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG, which is defined
  elsewhere in this file. Per the test name, that graph presumably gives the
  node both `options` and `node_options`. Supplying calculator_params for it
  is expected to raise ValueError.
  """
  graph = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                            calculator_pb2.CalculatorGraphConfig())
  overrides = {
      'ImageTransformation.output_width': 0,
      'ImageTransformation.output_height': 0,
  }
  with self.assertRaisesRegex(ValueError,
                              'has both options and node_options fields.'):
    solution_base.SolutionBase(graph_config=graph, calculator_params=overrides)
def _process_and_verify(self,
                        config_proto,
                        side_inputs=None,
                        calculator_params=None):
  """Checks that both process() input forms return the input image unchanged.

  The same 3x3x3 uint8 image is fed once as a bare array and once as
  {'image_in': image}. Both results' `image_out` must equal the input.

  Args:
    config_proto: CalculatorGraphConfig passed to SolutionBase.
    side_inputs: Optional side inputs forwarded to SolutionBase.
    calculator_params: Optional calculator option overrides forwarded to
      SolutionBase.
  """
  input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
  with solution_base.SolutionBase(
      graph_config=config_proto,
      side_inputs=side_inputs,
      calculator_params=calculator_params) as solution:
    outputs = solution.process(input_image)
    outputs2 = solution.process({'image_in': input_image})
    # assert_array_equal reports exactly which elements/shapes differ,
    # unlike assertTrue(np.array_equal(...)).
    np.testing.assert_array_equal(outputs.image_out, input_image)
    np.testing.assert_array_equal(outputs2.image_out, input_image)
def test_invalid_input_data_type(self):
  """Feeding a proto to a DETECTIONS-tagged stream raises NotImplementedError.

  The expected message says the stream's PROTO_LIST type is unsupported,
  because SolutionBase only processes image data.
  """
  graph_text = """
      input_stream: 'input_detections'
      output_stream: 'output_detections'
      node {
        calculator: 'DetectionUniqueIdCalculator'
        input_stream: 'DETECTIONS:input_detections'
        output_stream: 'DETECTIONS:output_detections'
      }
  """
  graph = text_format.Parse(graph_text, calculator_pb2.CalculatorGraphConfig())
  expected_message = ('SolutionBase can only process image data. '
                      'PROTO_LIST type is not supported.')

  with solution_base.SolutionBase(graph_config=graph) as solution:
    single_detection = detection_pb2.Detection()
    text_format.Parse('score: 0.5', single_detection)
    with self.assertRaisesRegex(NotImplementedError, expected_message):
      solution.process({'input_detections': single_detection})
def test_invalid_input_image_data(self):
  """A 4-channel (3x3x4) image must be rejected with ValueError."""
  graph_text = """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:transformed_image_in'
        output_stream: 'IMAGE:image_out'
      }
  """
  graph = text_format.Parse(graph_text, calculator_pb2.CalculatorGraphConfig())
  four_channel_image = np.arange(36, dtype=np.uint8).reshape(3, 3, 4)

  with solution_base.SolutionBase(graph_config=graph) as solution:
    with self.assertRaisesRegex(
        ValueError, 'Input image must contain three channel rgb data.'):
      solution.process(four_channel_image)
def test_invalid_config(self, text_config, error_type, error_message):
  """Building SolutionBase from an invalid graph config raises as expected.

  Args:
    text_config: Text-format CalculatorGraphConfig, supplied by the test's
      parameterization (not visible here).
    error_type: Exception class expected from the SolutionBase constructor.
    error_message: Regex the exception message must match.
  """
  # Parse outside the assertion context so a malformed text proto is not
  # mistaken for the expected constructor failure.
  graph = text_format.Parse(text_config,
                            calculator_pb2.CalculatorGraphConfig())
  expectation = self.assertRaisesRegex(error_type, error_message)
  with expectation:
    solution_base.SolutionBase(graph_config=graph)