def testGraphInitializedWithProtoConfig(self):
  """Runs a one-node PassThroughCalculator graph built from a parsed proto.

  Checks that the graph reports `max_queue_size` from the config and the
  WAIT_TILL_NOT_FULL input-stream add mode. Also checks that both
  'hello world' packets come out of 'out' with timestamps 0 and 1.
  """
  text_config = """ max_queue_size: 1 input_stream: 'in' output_stream: 'out' node { calculator: 'PassThroughCalculator' input_stream: 'in' output_stream: 'out' } """
  config_proto = calculator_pb2.CalculatorGraphConfig()
  text_format.Parse(text_config, config_proto)
  hello_world_packet = mp.packet_creator.create_string('hello world')
  out = []
  # Construct the graph exactly once; a second throw-away construction adds
  # nothing to what this test verifies.
  graph = mp.CalculatorGraph(graph_config=config_proto)
  graph.observe_output_stream('out', lambda _, packet: out.append(packet))
  graph.start_run()
  # Exercise both ways of stamping a packet: an explicit `timestamp=` argument
  # and a packet already stamped via `.at()`.
  graph.add_packet_to_input_stream(
      stream='in', packet=hello_world_packet, timestamp=0)
  graph.add_packet_to_input_stream(
      stream='in', packet=hello_world_packet.at(1))
  graph.close()
  self.assertEqual(graph.graph_input_stream_add_mode,
                   mp.GraphInputStreamAddMode.WAIT_TILL_NOT_FULL)
  self.assertEqual(graph.max_queue_size, 1)
  self.assertFalse(graph.has_error())
  self.assertLen(out, 2)
  self.assertEqual(out[0].timestamp, 0)
  self.assertEqual(out[1].timestamp, 1)
  self.assertEqual(mp.packet_getter.get_str(out[0]), 'hello world')
  self.assertEqual(mp.packet_getter.get_str(out[1]), 'hello world')
def testInsertPacketsWithSameTimestamp(self):
  """Re-sending a packet at an already-consumed timestamp surfaces an error.

  The first packet at timestamp 0 is processed to idle. A second packet at
  timestamp 0 is then expected to make `wait_until_idle` raise ValueError.
  """
  graph_text = """ max_queue_size: 1 input_stream: 'in' output_stream: 'out' node { calculator: 'PassThroughCalculator' input_stream: 'in' output_stream: 'out' } """
  graph_config = calculator_pb2.CalculatorGraphConfig()
  text_format.Parse(graph_text, graph_config)
  payload = mp.packet_creator.create_string('hello world')
  collected = []
  graph = mp.CalculatorGraph(graph_config=graph_config)
  graph.observe_output_stream(
      'out', lambda _, received: collected.append(received))
  graph.start_run()

  def push_at_zero():
    graph.add_packet_to_input_stream(stream='in', packet=payload.at(0))

  push_at_zero()
  graph.wait_until_idle()
  # Same timestamp again: the graph now expects at least timestamp 1.
  push_at_zero()
  expected_error = 'Current minimum expected timestamp is 1 but received 0.'
  with self.assertRaisesRegex(ValueError, expected_error):
    graph.wait_until_idle()
def test_valid_input_data_type_proto(self):
  """Feeds a two-element DetectionList through DetectionUniqueIdCalculator.

  The test expects the output detections to keep their scores and to carry
  detection_id 1 and 2, in input order.
  """
  graph_text = """ input_stream: 'input_detections' output_stream: 'output_detections' node { calculator: 'DetectionUniqueIdCalculator' input_stream: 'DETECTION_LIST:input_detections' output_stream: 'DETECTION_LIST:output_detections' } """
  graph_config = text_format.Parse(
      graph_text, calculator_pb2.CalculatorGraphConfig())
  input_texts = ('score: 0.5', 'score: 0.8')
  expected_texts = ('score: 0.5, detection_id: 1',
                    'score: 0.8, detection_id: 2')
  with solution_base.SolutionBase(graph_config=graph_config) as solution:
    detection_list = detection_pb2.DetectionList()
    for input_text in input_texts:
      text_format.Parse(input_text, detection_list.detection.add())
    outcome = solution.process({'input_detections': detection_list})
    self.assertTrue(hasattr(outcome, 'output_detections'))
    self.assertLen(outcome.output_detections.detection, 2)
    wanted = []
    for expected_text in expected_texts:
      detection = detection_pb2.Detection()
      text_format.Parse(expected_text, detection)
      wanted.append(detection)
    for position, wanted_detection in enumerate(wanted):
      self.assertEqual(outcome.output_detections.detection[position],
                       wanted_detection)
def test_invalid_calculator_options(self):
  """Setting a calculator param on the 'SignalGate' node is rejected.

  The graph has a GateCalculator node named 'SignalGate'. Passing
  `calculator_params` that target it must raise ValueError at construction.
  """
  graph_text = """ input_stream: 'image_in' output_stream: 'image_out' node { calculator: 'ImageTransformationCalculator' input_stream: 'IMAGE:image_in' output_stream: 'IMAGE:transformed_image_in' } node { name: 'SignalGate' calculator: 'GateCalculator' input_stream: 'transformed_image_in' input_side_packet: 'ALLOW:allow_signal' output_stream: 'image_out_to_transform' } node { calculator: 'ImageTransformationCalculator' input_stream: 'IMAGE:image_out_to_transform' output_stream: 'IMAGE:image_out' } """
  graph_config = text_format.Parse(
      graph_text, calculator_pb2.CalculatorGraphConfig())
  bad_params = {'SignalGate.invalid_field': 'I am invalid'}
  expected_error = (
      'Modifying the calculator options of SignalGate is not supported.')
  with self.assertRaisesRegex(ValueError, expected_error):
    solution_base.SolutionBase(
        graph_config=graph_config, calculator_params=bad_params)
def test_solution_reset(self, text_config, side_inputs):
  """Processes one 3x3 RGB image 20 times, calling reset() after each run.

  Every iteration expects `image_out` to equal the input image.
  """
  graph_config = text_format.Parse(
      text_config, calculator_pb2.CalculatorGraphConfig())
  image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
  solution_kwargs = {'graph_config': graph_config, 'side_inputs': side_inputs}
  with solution_base.SolutionBase(**solution_kwargs) as solution:
    iteration = 0
    while iteration < 20:
      result = solution.process(image)
      self.assertTrue(np.array_equal(image, result.image_out))
      solution.reset()
      iteration += 1
def test_calculator_has_both_options_and_node_options(self):
  """Overriding params on a node with both options kinds raises ValueError.

  CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG is used unmodified, so its
  'ImageTransformation' node keeps both `options` and `node_options`.
  """
  graph_config = text_format.Parse(
      CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
      calculator_pb2.CalculatorGraphConfig())
  size_overrides = {
      'ImageTransformation.output_width': 0,
      'ImageTransformation.output_height': 0,
  }
  expected_error = 'has both options and node_options fields.'
  with self.assertRaisesRegex(ValueError, expected_error):
    solution_base.SolutionBase(
        graph_config=graph_config, calculator_params=size_overrides)
def _initialize_graph_interface(
    self,
    validated_graph: validated_graph_config.ValidatedGraphConfig,
    side_inputs: Optional[Mapping[str, Any]] = None,
    outputs: Optional[List[str]] = None):
  """Gets graph interface type information and returns the canonical graph config proto.

  Populates three `name -> _PacketDataType` mappings on `self`:
    * `_input_stream_type_info` for every graph-level input stream;
    * `_output_stream_type_info` for `outputs`, or for every graph-level
      output stream when `outputs` is None or empty;
    * `_side_input_type_info` for every key of `side_inputs`.

  Args:
    validated_graph: Validated graph. Its `binary_config` is parsed, and its
      registered stream / side-packet type names are looked up.
    side_inputs: Optional mapping whose keys name input side packets (a
      trailing ':name' component is stripped to the bare name). Only the keys
      are read here.
    outputs: Optional list of output stream names to expose instead of all of
      the graph's output streams.

  Returns:
    The CalculatorGraphConfig parsed from `validated_graph.binary_config`.
  """
  canonical_graph_config_proto = calculator_pb2.CalculatorGraphConfig()
  canonical_graph_config_proto.ParseFromString(validated_graph.binary_config)

  def get_name(tag_index_name):
    """Returns the last ':'-separated component of a 'TAG:index:name' str."""
    return tag_index_name.split(':')[-1]

  # The stream-name -> packet-type mappings decide which packet creator and
  # getter methods process() calls.
  def get_stream_packet_type(packet_tag_index_name):
    return _PacketDataType.from_registered_name(
        validated_graph.registered_stream_type_name(
            get_name(packet_tag_index_name)))

  self._input_stream_type_info = {
      get_name(tag_index_name): get_stream_packet_type(tag_index_name)
      for tag_index_name in canonical_graph_config_proto.input_stream
  }
  # A None or empty `outputs` falls back to all graph-level output streams.
  output_streams = outputs or canonical_graph_config_proto.output_stream
  self._output_stream_type_info = {
      get_name(tag_index_name): get_stream_packet_type(tag_index_name)
      for tag_index_name in output_streams
  }

  # The side-packet-name -> packet-type mapping is used to build the
  # input_side_packets dict passed to the graph's start_run().
  def get_side_packet_type(packet_tag_index_name):
    return _PacketDataType.from_registered_name(
        validated_graph.registered_side_packet_type_name(
            get_name(packet_tag_index_name)))

  # Only the side-input names matter here, so iterate the keys directly
  # instead of unpacking unused values from .items().
  self._side_input_type_info = {
      get_name(tag_index_name): get_side_packet_type(tag_index_name)
      for tag_index_name in (side_inputs or {})
  }
  return canonical_graph_config_proto
def test_modifying_calculator_proto3_node_options(self):
  """Overrides output size through proto3 `node_options` only.

  Clears the proto2 `options` field from the first node before processing,
  so the override can only go through the node's `node_options`.
  """
  graph_config = text_format.Parse(
      CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
      calculator_pb2.CalculatorGraphConfig())
  first_node = graph_config.node[0]
  self.assertEqual('ImageTransformation', first_node.name)
  # Remove the proto2 options field so only proto3 node options are exercised.
  first_node.ClearField('options')
  size_overrides = {
      'ImageTransformation.output_width': 0,
      'ImageTransformation.output_height': 0,
  }
  self._process_and_verify(
      config_proto=graph_config, calculator_params=size_overrides)
def main():
  """Loads text-format graph configs and starts a CalculatorGraph run.

  The first path in `args.graphs` is passed to `CalculatorGraph.initialize`
  as its first argument. The remaining configs are passed as its second
  argument. The run is then started with `args.input_side_packet`.
  """
  import logging  # Local import keeps the getattr() lookup below self-contained.

  args = get_args()
  # Resolve the level by attribute name rather than eval(): a command-line
  # value must not be able to execute arbitrary Python.
  logzero.loglevel(getattr(logging, args.log_level))
  graph_configs = []
  for graph_path in args.graphs:
    graph_config = calculator_pb2.CalculatorGraphConfig()
    # Close each config file deterministically instead of leaking the handle.
    with open(graph_path) as graph_file:
      text_format.Merge(graph_file.read(), graph_config)
    logger.info(graph_config)
    graph_configs.append(graph_config)
  graph = CalculatorGraph()
  graph.initialize(graph_configs[0], graph_configs[1:])
  graph.start_run(args.input_side_packet, None, True)
def testInvalidCalculatorType(self):
  """Constructing a graph with an unregistered calculator raises RuntimeError."""
  graph_config = calculator_pb2.CalculatorGraphConfig()
  text_format.Parse(
      """ node { calculator: 'SomeUnknownCalculator' input_stream: 'in' output_stream: 'out' } """,
      graph_config)
  expected_error = 'Unable to find Calculator \"SomeUnknownCalculator\"'
  with self.assertRaisesRegex(RuntimeError, expected_error):
    mp.CalculatorGraph(graph_config=graph_config)
def test_invalid_initialization_arguments(self):
  """SolutionBase needs exactly one of `binary_graph_path` / `graph_config`.

  Checks both failure modes: neither argument given, and both given.
  """
  expected_error = (
      'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.')
  invalid_constructions = (
      lambda: solution_base.SolutionBase(),
      lambda: solution_base.SolutionBase(
          graph_config=calculator_pb2.CalculatorGraphConfig(),
          binary_graph_path='/tmp/no_such.binarypb'),
  )
  for construct in invalid_constructions:
    with self.assertRaisesRegex(ValueError, expected_error):
      construct()
def testInvalidNodeConfig(self):
  """A PassThroughCalculator with mismatched input/output streams is rejected.

  The node declares two 'in' input streams but only one output stream, and
  graph construction is expected to raise ValueError.
  """
  graph_config = calculator_pb2.CalculatorGraphConfig()
  text_format.Parse(
      """ node { calculator: 'PassThroughCalculator' input_stream: 'in' input_stream: 'in' output_stream: 'out' } """,
      graph_config)
  expected_error = ('Input and output streams to PassThroughCalculator must '
                    'use matching tags and indexes.')
  with self.assertRaisesRegex(ValueError, expected_error):
    mp.CalculatorGraph(graph_config=graph_config)
def testSidePacketGraph(self):
  """StringToUint64Calculator turns side packet '42' into uint 42."""
  graph_config = calculator_pb2.CalculatorGraphConfig()
  text_format.Parse(
      """ node { calculator: 'StringToUint64Calculator' input_side_packet: "string" output_side_packet: "number" } """,
      graph_config)
  graph = mp.CalculatorGraph(graph_config=graph_config)
  side_packets = {'string': mp.packet_creator.create_string('42')}
  graph.start_run(input_side_packets=side_packets)
  graph.wait_until_done()
  self.assertFalse(graph.has_error())
  number_packet = graph.get_output_side_packet('number')
  self.assertEqual(mp.packet_getter.get_uint(number_packet), 42)
def test_invalid_input_data_type(self):
  """process() rejects a PROTO_LIST input with NotImplementedError.

  The graph's input stream is bound to a 'DETECTIONS' tag. The test expects
  SolutionBase to report it as an unsupported PROTO_LIST type.
  """
  graph_text = """ input_stream: 'input_detections' output_stream: 'output_detections' node { calculator: 'DetectionUniqueIdCalculator' input_stream: 'DETECTIONS:input_detections' output_stream: 'DETECTIONS:output_detections' } """
  graph_config = text_format.Parse(
      graph_text, calculator_pb2.CalculatorGraphConfig())
  expected_error = ('SolutionBase can only process image data. '
                    'PROTO_LIST type is not supported.')
  with solution_base.SolutionBase(graph_config=graph_config) as solution:
    single_detection = detection_pb2.Detection()
    text_format.Parse('score: 0.5', single_detection)
    with self.assertRaisesRegex(NotImplementedError, expected_error):
      solution.process({'input_detections': single_detection})
def test_invalid_input_image_data(self):
  """A 4-channel image fed to process() raises ValueError."""
  graph_text = """ input_stream: 'image_in' output_stream: 'image_out' node { calculator: 'ImageTransformationCalculator' input_stream: 'IMAGE:image_in' output_stream: 'IMAGE:transformed_image_in' } node { calculator: 'ImageTransformationCalculator' input_stream: 'IMAGE:transformed_image_in' output_stream: 'IMAGE:image_out' } """
  graph_config = text_format.Parse(
      graph_text, calculator_pb2.CalculatorGraphConfig())
  expected_error = 'Input image must contain three channel rgb data.'
  with solution_base.SolutionBase(graph_config=graph_config) as solution:
    four_channel_image = np.arange(36, dtype=np.uint8).reshape(3, 3, 4)
    with self.assertRaisesRegex(ValueError, expected_error):
      solution.process(four_channel_image)
} } # Subgraph that detections objects (see object_detection_cpu.pbtxt). node { calculator: "ObjectDetectionSubgraphCpu" input_stream: "IMAGE:throttled_input_video" output_stream: "DETECTIONS:output_detections" } # Subgraph that tracks objects (see object_tracking_cpu.pbtxt). node { calculator: "ObjectTrackingSubgraphCpu" input_stream: "VIDEO:input_video" input_stream: "DETECTIONS:output_detections" output_stream: "DETECTIONS:tracked_detections" } # Subgraph that renders annotations and overlays them on top of input images (see renderer_cpu.pbtxt). node { calculator: "RendererSubgraphCpu" input_stream: "IMAGE:input_video" input_stream: "DETECTIONS:tracked_detections" output_stream: "IMAGE:output_video" } """ import mediapipe.framework.calculator_pb2 as calculator_pb2 from google.protobuf import text_format graph = calculator_pb2.CalculatorGraphConfig() text_format.Merge(pb, graph) print(graph)
def test_invalid_config(self, text_config, error_type, error_message):
  """Building SolutionBase from `text_config` raises `error_type`.

  The raised error's message must match the `error_message` regex.
  """
  parsed_config = text_format.Parse(
      text_config, calculator_pb2.CalculatorGraphConfig())
  raises_expected = self.assertRaisesRegex(error_type, error_message)
  with raises_expected:
    solution_base.SolutionBase(graph_config=parsed_config)
def test_solution_process(self, text_config, side_inputs):
  """Parses `text_config` and hands it to `_process_and_verify`."""
  parsed_config = text_format.Parse(
      text_config, calculator_pb2.CalculatorGraphConfig())
  self._process_and_verify(
      config_proto=parsed_config, side_inputs=side_inputs)