コード例 #1
0
def str_to_schema(raw_schema):
    """Build message classes for the ``tutorial`` address-book schema.

    Args:
      raw_schema: Serialized ``FileDescriptorProto`` bytes for a file that
        defines ``tutorial.Person``, ``tutorial.AddressBook`` and the nested
        ``tutorial.Person.PhoneNumber``.

    Returns:
      A ``(person_class, addressbook_class, phonenumber_class)`` tuple of
      generated message classes.

    Raises:
      KeyError: if one of the three message types is not defined by the
        schema (raised by ``DescriptorPool.FindMessageTypeByName``).
    """
    file_descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
        raw_schema)

    descriptor_pool = DescriptorPool()
    descriptor_pool.Add(file_descriptor_proto)

    # A single factory bound to the pool serves every lookup, instead of
    # constructing a throwaway MessageFactory per message type.
    factory = MessageFactory(descriptor_pool)

    def message_class(full_name):
        return factory.GetPrototype(
            descriptor_pool.FindMessageTypeByName(full_name))

    # Alternative considered: message_factory.GetMessages([file_descriptor_proto])
    # returns a dict keyed by full name, but it only holds the top-level
    # 'tutorial.Person' and 'tutorial.AddressBook' -- not the nested
    # PhoneNumber type -- so each class is resolved explicitly here.
    return (message_class('tutorial.Person'),
            message_class('tutorial.AddressBook'),
            message_class('tutorial.Person.PhoneNumber'))
コード例 #2
0
ファイル: proto.py プロジェクト: KendallPark/struct2tensor
def create_expression_from_file_descriptor_set(tensor_of_protos,
                                               proto_name,
                                               file_descriptor_set,
                                               message_format="binary"):
    """Create an expression from a 1D tensor of serialized protos.

    Args:
      tensor_of_protos: 1D tensor of serialized protos.
      proto_name: fully qualified name (e.g. "some.package.SomeProto") of the
        proto in `tensor_of_protos`.
      file_descriptor_set: The FileDescriptorSet proto containing the
        FileDescriptorProto of `proto_name` and of all its dependencies.
        If file1 imports file2, file2's FileDescriptorProto must precede
        file1's in `file_descriptor_set.file`.
      message_format: Format of the serialized protos; one of 'text' or
        'binary'.

    Returns:
      An expression.

    Raises:
      KeyError: if `proto_name` is not defined in `file_descriptor_set`.
    """
    pool = DescriptorPool()
    # DescriptorPool.Add raises when a file's dependencies are not already in
    # the pool, which is why files must arrive in dependency order.
    for file_proto in file_descriptor_set.file:
        pool.Add(file_proto)

    # FindMessageTypeByName raises if proto_name is unknown to the pool.
    return create_expression_from_proto(
        tensor_of_protos,
        pool.FindMessageTypeByName(proto_name),
        message_format)
コード例 #3
0
def str_to_schema(raw_schema):
    """Resolve the ``tutorial`` message descriptors from a serialized schema.

    Args:
      raw_schema: Serialized ``FileDescriptorProto`` bytes for a file that
        defines ``tutorial.Person``, ``tutorial.AddressBook`` and the nested
        ``tutorial.Person.PhoneNumber``.

    Returns:
      A 3-tuple of message descriptors, in the order ``tutorial.Person``,
      ``tutorial.AddressBook``, ``tutorial.Person.PhoneNumber``.

    Raises:
      KeyError: if one of the three message types is not defined by the
        schema (raised by ``DescriptorPool.FindMessageTypeByName``).
    """
    file_descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
        raw_schema)

    # Descriptors are resolved through a DescriptorPool rather than by calling
    # descriptor.MakeDescriptor on each entry of
    # file_descriptor_proto.message_type: that approach was tried and raised an
    # error asking for the file to be imported.
    descriptor_pool = DescriptorPool()
    descriptor_pool.Add(file_descriptor_proto)

    return tuple(
        descriptor_pool.FindMessageTypeByName(full_name)
        for full_name in ('tutorial.Person',
                          'tutorial.AddressBook',
                          'tutorial.Person.PhoneNumber'))
コード例 #4
0
class ProtoFactory:
    """Protobuf message classes for the trace processor RPC API.

    A DescriptorPool is populated from the serialized FileDescriptorSets that
    ``platform_delegate.get_resource`` returns for 'trace_processor.descriptor'
    and 'metrics.descriptor'. The message classes used to send and receive data
    over the RPC API are then exposed as attributes (``QueryResult``,
    ``CellsBatch``, ...).
    """

    # Serialized FileDescriptorSet resources, loaded into the pool in order.
    _DESCRIPTOR_RESOURCES = ('trace_processor.descriptor', 'metrics.descriptor')

    def __init__(self, platform_delegate: PlatformDelegate):
        self.descriptor_pool = DescriptorPool()
        for resource_name in self._DESCRIPTOR_RESOURCES:
            self._add_descriptor_set(
                platform_delegate.get_resource(resource_name))

        # One factory bound to this pool builds every class below, rather than
        # constructing a fresh MessageFactory for each lookup.
        factory = message_factory.MessageFactory(self.descriptor_pool)

        def message_class(message_type):
            # FindMessageTypeByName raises KeyError for an unknown name.
            message_desc = self.descriptor_pool.FindMessageTypeByName(
                message_type)
            return factory.GetPrototype(message_desc)

        # Create proto messages to correctly communicate with the RPC API by
        # sending and receiving data as protos.
        self.AppendTraceDataResult = message_class(
            'perfetto.protos.AppendTraceDataResult')
        self.StatusResult = message_class(
            'perfetto.protos.StatusResult')
        self.ComputeMetricArgs = message_class(
            'perfetto.protos.ComputeMetricArgs')
        self.ComputeMetricResult = message_class(
            'perfetto.protos.ComputeMetricResult')
        self.RawQueryArgs = message_class(
            'perfetto.protos.RawQueryArgs')
        self.QueryResult = message_class(
            'perfetto.protos.QueryResult')
        self.TraceMetrics = message_class(
            'perfetto.protos.TraceMetrics')
        self.DisableAndReadMetatraceResult = message_class(
            'perfetto.protos.DisableAndReadMetatraceResult')
        self.CellsBatch = message_class(
            'perfetto.protos.QueryResult.CellsBatch')

    def _add_descriptor_set(self, serialized_set):
        """Parse a serialized FileDescriptorSet and add its files to the pool.

        Files are added in the order they appear in the set.
        """
        file_desc_set_pb2 = descriptor_pb2.FileDescriptorSet()
        file_desc_set_pb2.MergeFromString(serialized_set)
        for f_desc_pb2 in file_desc_set_pb2.file:
            self.descriptor_pool.Add(f_desc_pb2)
コード例 #5
0
class ReflectionClientTest(unittest.TestCase):
    """Checks a DescriptorPool backed by ProtoReflectionDescriptorDatabase.

    Each test talks to a local gRPC test server that has server reflection
    enabled for TestService and the reflection service itself.
    """

    def setUp(self):
        self._server = test_common.test_server()
        test_service = test_pb2.DESCRIPTOR.services_by_name["TestService"]
        self._SERVICE_NAMES = (test_service.full_name, reflection.SERVICE_NAME)
        reflection.enable_server_reflection(self._SERVICE_NAMES, self._server)
        # "[::]:0" asks for any free port; the bound port number is returned.
        bound_port = self._server.add_insecure_port("[::]:0")
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % bound_port)
        self._reflection_db = ProtoReflectionDescriptorDatabase(self._channel)
        self.desc_pool = DescriptorPool(self._reflection_db)

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    def _find_proto3_file(self, file_name):
        """Fetch file_name through reflection and check name/package/syntax."""
        found = self.desc_pool.FindFileByName(file_name)
        self.assertEqual(file_name, found.name)
        self.assertEqual(_PROTO_PACKAGE_NAME, found.package)
        self.assertEqual("proto3", found.syntax)
        return found

    def testListServices(self):
        self.assertCountEqual(self._SERVICE_NAMES,
                              self._reflection_db.get_services())

    def testReflectionServiceName(self):
        expected = "grpc.reflection.v1alpha.ServerReflection"
        self.assertEqual(reflection.SERVICE_NAME, expected)

    def testFindFile(self):
        service_file = self._find_proto3_file(_PROTO_FILE_NAME)
        self.assertIn("TestService", service_file.services_by_name)

        empty_file = self._find_proto3_file(_EMPTY_PROTO_FILE_NAME)
        self.assertIn("Empty", empty_file.message_types_by_name)

    def testFindFileError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindFileByName(_INVALID_FILE_NAME)

    def testFindMessage(self):
        symbol = _EMPTY_PROTO_SYMBOL_NAME
        found = self.desc_pool.FindMessageTypeByName(symbol)
        self.assertEqual(symbol, found.full_name)
        self.assertTrue(symbol.endswith(found.name))

    def testFindMessageError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindMessageTypeByName(_INVALID_SYMBOL_NAME)

    def testFindServiceFindMethod(self):
        service_name = self._SERVICE_NAMES[0]
        service = self.desc_pool.FindServiceByName(service_name)
        self.assertEqual(service_name, service.full_name)
        self.assertTrue(service_name.endswith(service.name))
        self.assertIs(self.desc_pool.FindFileByName(_PROTO_FILE_NAME),
                      service.file)

        method_name = "EmptyCall"
        self.assertIn(method_name, service.methods_by_name)

        method = service.FindMethodByName(method_name)
        self.assertIs(method, service.methods_by_name[method_name])
        self.assertIs(service, method.containing_service)
        self.assertEqual(method_name, method.name)
        self.assertTrue(method.full_name.endswith(method_name))

        empty = self.desc_pool.FindMessageTypeByName(_EMPTY_PROTO_SYMBOL_NAME)
        # EmptyCall takes and returns the Empty message.
        for message_type in (method.input_type, method.output_type):
            self.assertEqual(empty, message_type)

    def testFindServiceError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindServiceByName(_INVALID_SYMBOL_NAME)

    def testFindMethodError(self):
        service = self.desc_pool.FindServiceByName(self._SERVICE_NAMES[0])

        # FindMethodByName raises KeyError in some protobuf builds and returns
        # None in others; treat both as "not found".
        # See https://github.com/protocolbuffers/protobuf/issues/9592
        with self.assertRaises(KeyError):
            if service.FindMethodByName(_INVALID_SYMBOL_NAME) is None:
                raise KeyError()

    def testFindExtensionNotImplemented(self):
        """
        Extensions aren't implemented in Protobuf for Python.
        For now, simply assert that indeed they don't work.
        """
        symbol = _EMPTY_EXTENSIONS_SYMBOL_NAME
        found = self.desc_pool.FindMessageTypeByName(symbol)
        self.assertEqual(symbol, found.full_name)
        self.assertTrue(symbol.endswith(found.name))

        extensions = self.desc_pool.FindAllExtensions(found)
        self.assertEqual(0, len(extensions))
        with self.assertRaises(KeyError):
            self.desc_pool.FindExtensionByName(symbol)
コード例 #6
0
class LogReader(object):
    """
        File-like interface for binary logs.

        The log header carries a serialized FileDescriptorSet; its files are
        loaded into a DescriptorPool so that each record's type name can be
        resolved to a message class and its payload decoded.

        >>> with LogReader("path/to/log/file.gz") as log:
        ...     for item in log.items():
        ...         print(item.type)
        ...         print(item.value.some.nested.object)
        ...         print()

        >>> with LogReader("path/to/log/file.gz") as log:
        ...     for value in log.values():
        ...         print(value.some.nested.object)
        ...         print()
    """

    class _Iterator(object):
        """Iterator that calls ``fetch`` for each element.

        Iteration ends when ``fetch`` raises StopIteration.
        """

        def __init__(self, fetch):
            self._fetch = fetch

        def __iter__(self):
            return self

        def __next__(self):
            return self._fetch()

        # Python 2 spelling, kept for callers that invoke ``.next()`` directly.
        next = __next__

    def __init__(self, path):
        self._log_reader = _LogReader(path)
        # header: its ``proto`` field is parsed as a FileDescriptorSet.
        header_ = self._log_reader.header
        fd_set = FileDescriptorSet()
        fd_set.ParseFromString(header_.proto)
        self._header = Header(proto=fd_set, types=header_.types)
        # descriptors: every file in the header's set is added to the pool.
        self._pool = DescriptorPool()
        for proto in self._header.proto.file:
            self._pool.Add(proto)
        self._factory = MessageFactory()

    @property
    def path(self):
        """Log path."""
        return self._log_reader.path

    @property
    def header(self):
        """Log header."""
        return self._header

    def __repr__(self):
        return repr(self._log_reader)

    def items(self):
        """Return iterator to log items (objects with ``type`` and ``value``)."""
        return self._Iterator(self._next)

    def values(self):
        """Return iterator to log values (decoded messages only)."""
        return self._Iterator(lambda: self._next().value)

    def _next(self):
        """Read and decode the next record.

        NOTE(review): assumes ``_log_reader.next()`` raises StopIteration at
        end of log, as ``read()`` and the iterators rely on.
        """
        next_ = self._log_reader.next()
        descriptor = self._pool.FindMessageTypeByName(next_.type)
        value = self._factory.GetPrototype(descriptor)()
        value.ParseFromString(next_.data)
        return LogItem(next_.type, value)

    def read(self):
        """Return the next value, or None on EOF."""
        try:
            return self._next().value
        except StopIteration:
            return None

    def close(self):
        """Closes LogReader. LogReader will take EOF state."""
        self._log_reader.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Returns None, so exceptions raised inside the with-body propagate.
        self.close()