Exemplo n.º 1
0
  def testGraphInitializedWithProtoConfig(self):
    """Runs a PassThroughCalculator graph built from a parsed config proto.

    Checks that max_queue_size from the config and the input stream add mode
    (expected WAIT_TILL_NOT_FULL) are exposed on the graph, that no error is
    reported, and that both packets fed into 'in' arrive on 'out' unchanged
    and in timestamp order.
    """
    text_config = """
      max_queue_size: 1
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
    config_proto = calculator_pb2.CalculatorGraphConfig()
    text_format.Parse(text_config, config_proto)

    hello_world_packet = mp.packet_creator.create_string('hello world')
    out = []
    # Construct the graph exactly once from the parsed config.
    graph = mp.CalculatorGraph(graph_config=config_proto)
    graph.observe_output_stream('out', lambda _, packet: out.append(packet))
    graph.start_run()
    # Exercise both ways of stamping a packet: an explicit `timestamp`
    # argument and Packet.at().
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet, timestamp=0)
    graph.add_packet_to_input_stream(
        stream='in', packet=hello_world_packet.at(1))
    graph.close()
    self.assertEqual(graph.graph_input_stream_add_mode,
                     mp.GraphInputStreamAddMode.WAIT_TILL_NOT_FULL)
    self.assertEqual(graph.max_queue_size, 1)
    self.assertFalse(graph.has_error())
    self.assertLen(out, 2)
    self.assertEqual(out[0].timestamp, 0)
    self.assertEqual(out[1].timestamp, 1)
    self.assertEqual(mp.packet_getter.get_str(out[0]), 'hello world')
    self.assertEqual(mp.packet_getter.get_str(out[1]), 'hello world')
Exemplo n.º 2
0
    def testInsertPacketsWithSameTimestamp(self):
        """Feeding a second packet at an already-used timestamp must fail.

        After a packet at timestamp 0 has been processed, adding another packet
        at timestamp 0 to 'in' is expected to surface a ValueError from
        wait_until_idle().
        """
        text_config = """
      max_queue_size: 1
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
    """
        config_proto = calculator_pb2.CalculatorGraphConfig()
        text_format.Parse(text_config, config_proto)

        greeting = mp.packet_creator.create_string('hello world')
        received = []
        graph = mp.CalculatorGraph(graph_config=config_proto)
        graph.observe_output_stream(
            'out', lambda _, packet: received.append(packet))
        graph.start_run()

        # First packet at timestamp 0 is accepted; let the graph drain it.
        graph.add_packet_to_input_stream(stream='in', packet=greeting.at(0))
        graph.wait_until_idle()

        # Reusing timestamp 0: the error surfaces from wait_until_idle().
        graph.add_packet_to_input_stream(stream='in', packet=greeting.at(0))
        expected_error = 'Current minimum expected timestamp is 1 but received 0.'
        with self.assertRaisesRegex(ValueError, expected_error):
            graph.wait_until_idle()
Exemplo n.º 3
0
 def test_valid_input_data_type_proto(self):
   """DetectionUniqueIdCalculator stamps ids onto a DetectionList input.

   Feeds two detections (scores 0.5 and 0.8) through SolutionBase and checks
   they come back in order with detection_id 1 and 2 respectively.
   """
   text_config = """
     input_stream: 'input_detections'
     output_stream: 'output_detections'
     node {
       calculator: 'DetectionUniqueIdCalculator'
       input_stream: 'DETECTION_LIST:input_detections'
       output_stream: 'DETECTION_LIST:output_detections'
     }
   """
   graph_config = calculator_pb2.CalculatorGraphConfig()
   text_format.Parse(text_config, graph_config)
   with solution_base.SolutionBase(graph_config=graph_config) as solution:
     detection_list = detection_pb2.DetectionList()
     for score_text in ('score: 0.5', 'score: 0.8'):
       text_format.Parse(score_text, detection_list.detection.add())
     results = solution.process({'input_detections': detection_list})
     self.assertTrue(hasattr(results, 'output_detections'))
     returned = results.output_detections.detection
     self.assertLen(returned, 2)
     expected_texts = ('score: 0.5, detection_id: 1',
                       'score: 0.8, detection_id: 2')
     for position, expected_text in enumerate(expected_texts):
       expected = detection_pb2.Detection()
       text_format.Parse(expected_text, expected)
       self.assertEqual(returned[position], expected)
 def test_invalid_calculator_options(self):
     """calculator_params targeting SignalGate's options are rejected.

     SolutionBase construction is expected to raise ValueError when asked to
     set a field on the node named 'SignalGate'.
     """
     text_config = """
   input_stream: 'image_in'
   output_stream: 'image_out'
   node {
     calculator: 'ImageTransformationCalculator'
     input_stream: 'IMAGE:image_in'
     output_stream: 'IMAGE:transformed_image_in'
   }
   node {
     name: 'SignalGate'
     calculator: 'GateCalculator'
     input_stream: 'transformed_image_in'
     input_side_packet: 'ALLOW:allow_signal'
     output_stream: 'image_out_to_transform'
   }
   node {
     calculator: 'ImageTransformationCalculator'
     input_stream: 'IMAGE:image_out_to_transform'
     output_stream: 'IMAGE:image_out'
   }
 """
     graph_config = calculator_pb2.CalculatorGraphConfig()
     text_format.Parse(text_config, graph_config)
     bad_params = {'SignalGate.invalid_field': 'I am invalid'}
     expected_error = 'Modifying the calculator options of SignalGate is not supported.'
     with self.assertRaisesRegex(ValueError, expected_error):
         solution_base.SolutionBase(
             graph_config=graph_config, calculator_params=bad_params)
Exemplo n.º 5
0
 def test_solution_reset(self, text_config, side_inputs):
   """Repeated process()/reset() cycles keep producing the expected output.

   Runs 20 cycles on a fixed 3x3x3 uint8 image; each cycle checks that the
   'image_out' output equals the input image, then resets the solution.
   """
   image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
   graph_config = calculator_pb2.CalculatorGraphConfig()
   text_format.Parse(text_config, graph_config)
   with solution_base.SolutionBase(
       graph_config=graph_config, side_inputs=side_inputs) as solution:
     for _cycle in range(20):
       self.assertTrue(
           np.array_equal(image, solution.process(image).image_out))
       solution.reset()
Exemplo n.º 6
0
 def test_calculator_has_both_options_and_node_options(self):
   """Overriding params on a node carrying both option kinds is an error.

   The test graph's 'ImageTransformation' node has both `options` and
   `node_options`, so SolutionBase construction is expected to raise.
   """
   graph_config = calculator_pb2.CalculatorGraphConfig()
   text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG, graph_config)
   conflicting_params = {
       'ImageTransformation.output_width': 0,
       'ImageTransformation.output_height': 0,
   }
   with self.assertRaisesRegex(ValueError,
                               'has both options and node_options fields.'):
     solution_base.SolutionBase(graph_config=graph_config,
                                calculator_params=conflicting_params)
Exemplo n.º 7
0
    def _initialize_graph_interface(
            self,
            validated_graph: validated_graph_config.ValidatedGraphConfig,
            side_inputs: Optional[Mapping[str, Any]] = None,
            outputs: Optional[List[str]] = None):
        """Gets graph interface type information and returns the canonical graph config proto.

        Side effects: sets ``self._input_stream_type_info``,
        ``self._output_stream_type_info`` and ``self._side_input_type_info``.
        Each is a dict mapping a bare stream / side-packet name to the
        ``_PacketDataType`` registered for it in ``validated_graph``.

        Args:
          validated_graph: The validated graph. Its ``binary_config`` is
            parsed into the canonical config, and it is queried for the
            registered stream and side-packet type names.
          side_inputs: Optional mapping keyed by input side packet name. Only
            the keys are used here.
          outputs: Optional list of output stream names to describe. When None
            or empty, every ``output_stream`` of the canonical config is used.

        Returns:
          The canonical ``CalculatorGraphConfig`` parsed from
          ``validated_graph.binary_config``.
        """
        canonical_graph_config_proto = calculator_pb2.CalculatorGraphConfig()
        canonical_graph_config_proto.ParseFromString(
            validated_graph.binary_config)

        # Gets name from a 'TAG:index:name' str.
        def get_name(tag_index_name):
            return tag_index_name.split(':')[-1]

        # Gets the packet type information of the input streams and output streams
        # from the validated calculator graph. The mappings from the stream names to
        # the packet data types is for deciding which packet creator and getter
        # methods to call in the process() method.
        def get_stream_packet_type(packet_tag_index_name):
            return _PacketDataType.from_registered_name(
                validated_graph.registered_stream_type_name(
                    get_name(packet_tag_index_name)))

        self._input_stream_type_info = {
            get_name(tag_index_name): get_stream_packet_type(tag_index_name)
            for tag_index_name in canonical_graph_config_proto.input_stream
        }

        # A None or empty `outputs` falls back to all graph output streams.
        output_streams = outputs or canonical_graph_config_proto.output_stream
        self._output_stream_type_info = {
            get_name(tag_index_name): get_stream_packet_type(tag_index_name)
            for tag_index_name in output_streams
        }

        # Gets the packet type information of the input side packets from the
        # validated calculator graph. The mappings from the side packet names to the
        # packet data types is for making the input_side_packets dict for graph
        # start_run().
        def get_side_packet_type(packet_tag_index_name):
            return _PacketDataType.from_registered_name(
                validated_graph.registered_side_packet_type_name(
                    get_name(packet_tag_index_name)))

        # Iterating a Mapping yields its keys; the side-input values are not
        # needed to look up their registered types.
        self._side_input_type_info = {
            get_name(tag_index_name): get_side_packet_type(tag_index_name)
            for tag_index_name in (side_inputs or {})
        }
        return canonical_graph_config_proto
 def test_modifying_calculator_proto3_node_options(self):
     """calculator_params can override proto3 node_options on their own.

     The proto2 `options` field is cleared from the 'ImageTransformation'
     node first, so only its proto3 node options remain to be modified.
     """
     graph_config = calculator_pb2.CalculatorGraphConfig()
     text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG, graph_config)
     first_node = graph_config.node[0]
     self.assertEqual('ImageTransformation', first_node.name)
     # Strip the proto2 options so the test exercises node_options only.
     first_node.ClearField('options')
     size_params = {
         'ImageTransformation.output_width': 0,
         'ImageTransformation.output_height': 0,
     }
     self._process_and_verify(config_proto=graph_config,
                              calculator_params=size_params)
Exemplo n.º 9
0
def main():
    """Load CalculatorGraphConfig text protos from the CLI and start the graph.

    ``args.graphs`` lists text-format config files. The first is passed to
    ``graph.initialize`` as the main config and the rest as its second
    argument (presumably subgraph configs; confirm against
    CalculatorGraph.initialize). ``args.log_level`` names a ``logging`` level
    attribute such as 'DEBUG'.
    """
    import logging

    args = get_args()

    # Resolve the level by attribute lookup instead of eval(), which would
    # execute arbitrary code passed on the command line.
    logzero.loglevel(getattr(logging, args.log_level))
    graph_configs = []
    for graph_path in args.graphs:
        graph_config = calculator_pb2.CalculatorGraphConfig()
        # Close each config file promptly rather than leaking the handle.
        with open(graph_path) as graph_file:
            text_format.Merge(graph_file.read(), graph_config)
        logger.info(graph_config)
        graph_configs.append(graph_config)
    graph = CalculatorGraph()
    graph.initialize(graph_configs[0], graph_configs[1:])
    graph.start_run(args.input_side_packet, None, True)
Exemplo n.º 10
0
 def testInvalidCalculatorType(self):
   """Building a graph that names an unregistered calculator raises."""
   text_config = """
     node {
       calculator: 'SomeUnknownCalculator'
       input_stream: 'in'
       output_stream: 'out'
     }
   """
   graph_config = text_format.Parse(text_config,
                                    calculator_pb2.CalculatorGraphConfig())
   expected_message = 'Unable to find Calculator \"SomeUnknownCalculator\"'
   with self.assertRaisesRegex(RuntimeError, expected_message):
     mp.CalculatorGraph(graph_config=graph_config)
 def test_invalid_initialization_arguments(self):
     """SolutionBase requires exactly one of binary_graph_path / graph_config.

     Both the "neither given" and the "both given" cases must raise ValueError.
     """
     exactly_one_error = (
         'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.')
     # Neither source of the graph supplied.
     with self.assertRaisesRegex(ValueError, exactly_one_error):
         solution_base.SolutionBase()
     # Both sources supplied at once.
     with self.assertRaisesRegex(ValueError, exactly_one_error):
         solution_base.SolutionBase(
             binary_graph_path='/tmp/no_such.binarypb',
             graph_config=calculator_pb2.CalculatorGraphConfig())
Exemplo n.º 12
0
 def testInvalidNodeConfig(self):
     """A PassThroughCalculator with two inputs but one output is rejected."""
     text_config = """
   node {
     calculator: 'PassThroughCalculator'
     input_stream: 'in'
     input_stream: 'in'
     output_stream: 'out'
   }
 """
     graph_config = text_format.Parse(text_config,
                                      calculator_pb2.CalculatorGraphConfig())
     mismatch_error = 'Input and output streams to PassThroughCalculator must use matching tags and indexes.'
     with self.assertRaisesRegex(ValueError, mismatch_error):
         mp.CalculatorGraph(graph_config=graph_config)
Exemplo n.º 13
0
 def testSidePacketGraph(self):
   """StringToUint64Calculator turns the side packet '42' into the uint 42."""
   text_config = """
     node {
       calculator: 'StringToUint64Calculator'
       input_side_packet: "string"
       output_side_packet: "number"
     }
   """
   graph_config = text_format.Parse(text_config,
                                    calculator_pb2.CalculatorGraphConfig())
   graph = mp.CalculatorGraph(graph_config=graph_config)
   side_packets = {'string': mp.packet_creator.create_string('42')}
   graph.start_run(input_side_packets=side_packets)
   graph.wait_until_done()
   self.assertFalse(graph.has_error())
   number_packet = graph.get_output_side_packet('number')
   self.assertEqual(mp.packet_getter.get_uint(number_packet), 42)
 def test_invalid_input_data_type(self):
     """SolutionBase.process rejects a Detection sent on a DETECTIONS stream.

     The expected NotImplementedError states that only image data can be
     processed and that the PROTO_LIST type is unsupported.
     """
     text_config = """
   input_stream: 'input_detections'
   output_stream: 'output_detections'
   node {
     calculator: 'DetectionUniqueIdCalculator'
     input_stream: 'DETECTIONS:input_detections'
     output_stream: 'DETECTIONS:output_detections'
   }
 """
     graph_config = calculator_pb2.CalculatorGraphConfig()
     text_format.Parse(text_config, graph_config)
     expected_error = 'SolutionBase can only process image data. PROTO_LIST type is not supported.'
     with solution_base.SolutionBase(graph_config=graph_config) as solution:
         single_detection = detection_pb2.Detection()
         text_format.Parse('score: 0.5', single_detection)
         with self.assertRaisesRegex(NotImplementedError, expected_error):
             solution.process({'input_detections': single_detection})
Exemplo n.º 15
0
 def test_invalid_input_image_data(self):
   """A 3x3x4 (four-channel) image is rejected by SolutionBase.process."""
   text_config = """
     input_stream: 'image_in'
     output_stream: 'image_out'
     node {
       calculator: 'ImageTransformationCalculator'
       input_stream: 'IMAGE:image_in'
       output_stream: 'IMAGE:transformed_image_in'
     }
     node {
       calculator: 'ImageTransformationCalculator'
       input_stream: 'IMAGE:transformed_image_in'
       output_stream: 'IMAGE:image_out'
     }
   """
   graph_config = calculator_pb2.CalculatorGraphConfig()
   text_format.Parse(text_config, graph_config)
   four_channel_image = np.arange(36, dtype=np.uint8).reshape(3, 3, 4)
   with solution_base.SolutionBase(graph_config=graph_config) as solution:
     with self.assertRaisesRegex(
         ValueError, 'Input image must contain three channel rgb data.'):
       solution.process(four_channel_image)
Exemplo n.º 16
0
  }
}

# Subgraph that detects objects (see object_detection_cpu.pbtxt).
node {
  calculator: "ObjectDetectionSubgraphCpu"
  input_stream: "IMAGE:throttled_input_video"
  output_stream: "DETECTIONS:output_detections"
}

# Subgraph that tracks objects (see object_tracking_cpu.pbtxt).
node {
  calculator: "ObjectTrackingSubgraphCpu"
  input_stream: "VIDEO:input_video"
  input_stream: "DETECTIONS:output_detections"
  output_stream: "DETECTIONS:tracked_detections"
}

# Subgraph that renders annotations and overlays them on top of input images (see renderer_cpu.pbtxt).
node {
  calculator: "RendererSubgraphCpu"
  input_stream: "IMAGE:input_video"
  input_stream: "DETECTIONS:tracked_detections"
  output_stream: "IMAGE:output_video"
}
"""
import mediapipe.framework.calculator_pb2 as calculator_pb2
from google.protobuf import text_format
# Parse the text-format graph held in `pb` into a config proto and print it.
# text_format.Merge fills and returns the message passed to it.
graph = text_format.Merge(pb, calculator_pb2.CalculatorGraphConfig())
print(graph)
 def test_invalid_config(self, text_config, error_type, error_message):
     """Building SolutionBase from ``text_config`` raises ``error_type``.

     The raised error's message must match the ``error_message`` regex.
     """
     graph_config = calculator_pb2.CalculatorGraphConfig()
     text_format.Parse(text_config, graph_config)
     with self.assertRaisesRegex(error_type, error_message):
         solution_base.SolutionBase(graph_config=graph_config)
 def test_solution_process(self, text_config, side_inputs):
     """Parses ``text_config`` and runs the shared process/verify helper."""
     graph_config = calculator_pb2.CalculatorGraphConfig()
     text_format.Parse(text_config, graph_config)
     self._process_and_verify(config_proto=graph_config,
                              side_inputs=side_inputs)