def create_expression_from_file_descriptor_set(
    tensor_of_protos, proto_name, file_descriptor_set, message_format="binary"
):
    """Build an expression over a 1D tensor of serialized protos.

    Args:
      tensor_of_protos: 1D tensor of serialized protos.
      proto_name: fully qualified name (e.g. "some.package.SomeProto") of the
        proto stored in `tensor_of_protos`.
      file_descriptor_set: FileDescriptorSet proto holding the
        FileDescriptorProto of `proto_name` and of all its dependencies. Order
        matters: if file1 imports file2, file2's FileDescriptorProto must come
        before file1's in file_descriptor_set.file.
      message_format: format of the protocol buffer, one of 'text' or
        'binary'.

    Returns:
      An expression.
    """
    descriptor_pool = DescriptorPool()
    for file_proto in file_descriptor_set.file:
        # Add raises when a dependency of file_proto has not been added yet.
        descriptor_pool.Add(file_proto)

    # The lookup raises when the proto is not found.
    message_descriptor = descriptor_pool.FindMessageTypeByName(proto_name)
    return create_expression_from_proto(
        tensor_of_protos, message_descriptor, message_format
    )
def __init__(self, path):
    """Open the log at `path` and load the proto descriptors from its header."""
    reader = _LogReader(path)
    self._log_reader = reader

    # header: the raw header carries a serialized FileDescriptorSet
    raw_header = reader.header
    descriptor_set = FileDescriptorSet()
    descriptor_set.ParseFromString(raw_header.proto)
    self._header = Header(proto=descriptor_set, types=raw_header.types)

    # descriptors: register every file of the header in a private pool
    self._pool = DescriptorPool()
    for file_proto in self._header.proto.file:
        self._pool.Add(file_proto)
    self._factory = MessageFactory()
def setUp(self):
    # Start a test server with reflection enabled, then build a descriptor
    # pool backed by the reflection descriptor database.
    self._server = test_common.test_server()

    test_service = test_pb2.DESCRIPTOR.services_by_name["TestService"]
    self._SERVICE_NAMES = (test_service.full_name, reflection.SERVICE_NAME)
    reflection.enable_server_reflection(self._SERVICE_NAMES, self._server)

    # Port 0 is requested; the port number returned by the server is used
    # for the client channel.
    bound_port = self._server.add_insecure_port("[::]:0")
    self._server.start()

    self._channel = grpc.insecure_channel("localhost:%d" % bound_port)
    self._reflection_db = ProtoReflectionDescriptorDatabase(self._channel)
    self.desc_pool = DescriptorPool(self._reflection_db)
def test_data(request):
    """Yield the TestData bundle for the test case named by `request.param`.

    The reference module directory is put on `sys.path` before the yield and
    removed from it afterwards.
    """
    test_case_name = request.param

    # Reset the internal symbol database so we can import the `Test` message
    # multiple times. Ugh.
    symbol_database.Default().pool = DescriptorPool()

    package_parts = reference_output_package.split(".")
    reference_module_root = os.path.join(*package_parts, test_case_name)
    sys.path.append(reference_module_root)

    def load_reference_module():
        # Imported lazily: only when the consumer calls reference_module().
        return importlib.import_module(
            f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
        )

    bundle = TestData(
        plugin_module=importlib.import_module(
            f"{plugin_output_package}.{test_case_name}.{test_case_name}"
        ),
        reference_module=load_reference_module,
        json_data=get_test_case_json_data(test_case_name),
    )
    yield bundle

    sys.path.remove(reference_module_root)
def build_database_from_stub(reflection_client):
    """Build descriptor pool and symbol database from reflection service.

    Args:
        reflection_client: ServerReflectionStub. GRPC reflection client

    Returns:
        tuple (descriptor pool, symbol database)
    """
    # Note that this is stupid, but grpc has problems iterating over lists
    request_stream = (x for x in [GET_SERVICES_REQUEST])
    services_response = funcy.first(
        reflection_client.ServerReflectionInfo(request_stream))

    listed_services = services_response.list_services_response.service
    service_names = [service.name for service in listed_services]

    # One "file containing symbol" request per advertised service, built lazily.
    file_requests = (
        reflection_pb2.ServerReflectionRequest(file_containing_symbol=name)
        for name in service_names
    )
    file_descriptor_proto_bytes = get_proto_bytes_for_requests(
        reflection_client, file_requests)

    pool = DescriptorPool()
    add_protos_to_descriptor_pool(reflection_client, pool,
                                  file_descriptor_proto_bytes)
    return (pool, SymbolDatabase(pool))
def main():
    """Scan a file or directory for protobuf definitions and dump .proto files.

    Command-line arguments:
        input_path: file or directory to scan (directories are walked
            recursively).
        output_path: directory that receives the recovered .proto files.

    Nothing is written when any discovered proto file has dependencies that
    could not be found; a warning listing them is printed instead.

    Raises:
        NotImplementedError: if the installed Protobuf runtime does not use
            the C++ implementation.
    """
    import argparse

    parser = argparse.ArgumentParser(description=(
        "Read all files in a given directory and scan each file for protobuf definitions,"
        " printing usable .proto files to a given directory."))
    parser.add_argument("input_path",
                        help="Input path to scan. May be a file or directory.")
    parser.add_argument("output_path",
                        help="Output directory to dump .protoc files to.")
    args = parser.parse_args()

    if api_implementation.Type() != "cpp":
        raise NotImplementedError(
            "This script requires the Protobuf installation to use the C++ implementation. Please"
            " reinstall Protobuf with C++ support.")

    GLOBAL_DESCRIPTOR_POOL = DescriptorPool()

    all_filenames = [
        str(path) for path in Path(args.input_path).rglob("*")
        if not path.is_dir()
    ]
    cprint(
        f"Bootstrap: scanning {len(all_filenames):,} files under {args.input_path} for protobuf definitions",
        "green",
    )

    proto_files_found = set()
    for path in all_filenames:
        proto_files_found.update(
            extract_proto_from_file(path, GLOBAL_DESCRIPTOR_POOL))

    missing_deps = set()
    for found in proto_files_found:
        if not found.attempt_to_load():
            missing_deps.update(
                find_missing_dependencies(proto_files_found, found.path))

    if missing_deps:
        cprint(
            f"Warning: unable to print out all Protobuf definitions; {len(missing_deps):,} proto files could not be found:\n{missing_deps}",
            "red",
        )
        return

    output_root = Path(args.output_path)
    for proto_file in proto_files_found:
        target = output_root / proto_file.path
        # proto_file.path may contain sub-directories (e.g. "pkg/file.proto");
        # create the target's own parent, not just the output root, so that
        # open() does not fail on a missing directory.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w") as f:
            source = proto_file.source
            if source:
                f.write(source)
            else:
                cprint(f"Warning: no source available for {proto_file}", "red")
    cprint(
        f"Bootstrap: wrote {len(proto_files_found):,} proto files to {args.output_path}",
        "green",
    )
def test_add_protos_to_descriptor_pool(self):
    # Checks which files add_protos_to_descriptor_pool hands to
    # DescriptorPool.Add, and in which order.
    pool_mock = create_autospec(spec=DescriptorPool(), spec_set=True)
    add_protos = reflection_descriptor_database.add_protos_to_descriptor_pool

    # Test empty case
    add_protos(utils.reflection_client_mock, pool_mock, [])
    self.assertEqual(len(pool_mock.Add.call_args_list), 0)

    # Test non-empty case
    add_protos(
        utils.reflection_client_mock,
        pool_mock,
        [utils.PROTO_BYTES_A, utils.PROTO_BYTES_B],
    )
    expected_order = (
        utils.PROTO_A,
        utils.PROTO_C,
        utils.PROTO_E,
        utils.PROTO_F,
        utils.PROTO_D,
        utils.PROTO_B,
    )
    pool_mock.Add.assert_has_calls(
        [call(proto) for proto in expected_order], any_order=False)
def test_data(request):
    """Yield the TestData bundle for the test case named by `request.param`.

    The plugin package of the test case is searched for an entry-point module;
    an Exception is raised when none is found. The reference module directory
    is put on `sys.path` before the yield and removed from it afterwards.
    """
    test_case_name = request.param

    # Reset the internal symbol database so we can import the `Test` message
    # multiple times. Ugh.
    symbol_database.Default().pool = DescriptorPool()

    package_parts = reference_output_package.split(".")
    reference_module_root = os.path.join(*package_parts, test_case_name)
    sys.path.append(reference_module_root)

    plugin_package = importlib.import_module(
        f"{plugin_output_package}.{test_case_name}")
    entry_point = find_module(plugin_package, module_has_entry_point)
    if not entry_point:
        raise Exception(
            f"Test case {repr(test_case_name)} has no entry point. "
            "Please add a proto message or service called Test and recompile.")

    def load_reference_module():
        # Imported lazily: only when the consumer calls reference_module().
        return importlib.import_module(
            f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
        )

    yield TestData(
        plugin_module=entry_point,
        reference_module=load_reference_module,
        json_data=get_test_case_json_data(test_case_name),
    )

    sys.path.remove(reference_module_root)
class ProtoFactory:
    """Holds a descriptor pool and the message classes used with the RPC API.

    The pool is filled from the 'trace_processor.descriptor' and
    'metrics.descriptor' resources supplied by the platform delegate.
    """

    def __init__(self, platform_delegate: PlatformDelegate):
        # Declare descriptor pool
        self.descriptor_pool = DescriptorPool()

        # Load the trace processor descriptor first, then the metrics
        # descriptor, adding every contained file to the descriptor pool.
        for resource_name in ('trace_processor.descriptor',
                              'metrics.descriptor'):
            serialized_set = platform_delegate.get_resource(resource_name)
            file_desc_set = descriptor_pb2.FileDescriptorSet()
            file_desc_set.MergeFromString(serialized_set)
            for file_desc in file_desc_set.file:
                self.descriptor_pool.Add(file_desc)

        def prototype_for(full_name):
            # Resolve the descriptor in our pool and build its message class.
            descriptor = self.descriptor_pool.FindMessageTypeByName(full_name)
            return message_factory.MessageFactory().GetPrototype(descriptor)

        # Create proto messages to correctly communicate with the RPC API by
        # sending and receiving data as protos
        self.AppendTraceDataResult = prototype_for(
            'perfetto.protos.AppendTraceDataResult')
        self.StatusResult = prototype_for('perfetto.protos.StatusResult')
        self.ComputeMetricArgs = prototype_for(
            'perfetto.protos.ComputeMetricArgs')
        self.ComputeMetricResult = prototype_for(
            'perfetto.protos.ComputeMetricResult')
        self.RawQueryArgs = prototype_for('perfetto.protos.RawQueryArgs')
        self.QueryResult = prototype_for('perfetto.protos.QueryResult')
        self.TraceMetrics = prototype_for('perfetto.protos.TraceMetrics')
        self.DisableAndReadMetatraceResult = prototype_for(
            'perfetto.protos.DisableAndReadMetatraceResult')
        self.CellsBatch = prototype_for(
            'perfetto.protos.QueryResult.CellsBatch')
def str_to_schema(raw_schema):
    """Build the message classes of the `tutorial` schema from a raw schema.

    Args:
        raw_schema: serialized FileDescriptorProto describing the `tutorial`
            package (it is parsed with FileDescriptorProto.FromString).

    Returns:
        Tuple (person_class, addressbook_class, phonenumber_class) for
        'tutorial.Person', 'tutorial.AddressBook' and
        'tutorial.Person.PhoneNumber'.
    """
    file_descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
        raw_schema)
    descriptor_pool = DescriptorPool()
    descriptor_pool.Add(file_descriptor_proto)

    # Resolve all three descriptors before building any class.
    message_names = (
        'tutorial.Person',
        'tutorial.AddressBook',
        'tutorial.Person.PhoneNumber',
    )
    descriptors = [
        descriptor_pool.FindMessageTypeByName(message_name)
        for message_name in message_names
    ]

    # NOTE: message_factory.GetMessages([file_descriptor_proto]) is another way
    # to extract the types, but its result only contains 'tutorial.Person' and
    # 'tutorial.AddressBook' (not the nested PhoneNumber), hence the explicit
    # per-descriptor lookups. A separate MessageFactory is used per class.
    person_class, addressbook_class, phonenumber_class = [
        MessageFactory().GetPrototype(descriptor) for descriptor in descriptors
    ]
    return person_class, addressbook_class, phonenumber_class
def test_import_dependencies_then_proto(self):
    # Checks the order in which import_dependencies_then_proto hands files to
    # DescriptorPool.Add, for a detached proto and for a dependency tree.

    def run_import(proto_name, proto_bytes):
        # Fresh mock per scenario so recorded calls do not leak between them.
        pool_mock = create_autospec(spec=DescriptorPool(), spec_set=True)
        parsed_proto = descriptor_pb2.FileDescriptorProto()
        parsed_proto.ParseFromString(proto_bytes)
        reflection_descriptor_database.import_dependencies_then_proto(
            utils.reflection_client_mock,
            pool_mock,
            proto_name,
            set(),
            {proto_name: parsed_proto},
        )
        return pool_mock

    # Test detached proto_a
    pool_mock_a = run_import("proto_a", utils.PROTO_BYTES_A)
    pool_mock_a.Add.assert_has_calls([call(utils.PROTO_A)], any_order=False)

    # Test proto_b including dependencies in a tree
    pool_mock_b = run_import("proto_b", utils.PROTO_BYTES_B)
    expected_order = (
        utils.PROTO_C,
        utils.PROTO_E,
        utils.PROTO_F,
        utils.PROTO_D,
        utils.PROTO_B,
    )
    pool_mock_b.Add.assert_has_calls(
        [call(proto) for proto in expected_order], any_order=False)
def main():
    """Read a CodeGeneratorRequest from stdin and write the response to stdout.

    Every proto file of the request is added to a descriptor pool and passed
    to the analyzer; the resulting tables are rendered into the response with
    the Kotlin Exposed generator.
    """
    # Read and parse the request message from stdin
    raw_request = sys.stdin.buffer.read()
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(raw_request)

    # Create response
    response = plugin.CodeGeneratorResponse()

    # TODO: clean this part.
    # Generate code
    table_resolver = TableResolver()
    analyzer = Analyzer(table_resolver)
    pool = DescriptorPool()
    for proto_file in request.proto_file:
        pool.Add(proto_file)
        file_descriptor = pool.FindFileByName(proto_file.name)
        analyzer.generate_tables_for_file(file_descriptor=file_descriptor)
    analyzer.link_tables_references()

    ProtoPluginResponseWriter().write(
        generator=KotlinExposedGenerator(),
        tables=table_resolver.tables,
        plugin_response=response,
    )

    # Serialise the response message and write it to stdout
    sys.stdout.buffer.write(response.SerializeToString())
def str_to_schema(raw_schema):
    """Return the message descriptors of the `tutorial` schema.

    Args:
        raw_schema: serialized FileDescriptorProto describing the `tutorial`
            package (it is parsed with FileDescriptorProto.FromString).

    Returns:
        Tuple of the descriptors for 'tutorial.Person',
        'tutorial.AddressBook' and 'tutorial.Person.PhoneNumber', in that
        order.
    """
    file_descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
        raw_schema)
    descriptor_pool = DescriptorPool()
    descriptor_pool.Add(file_descriptor_proto)

    # NOTE: per the original author's note, building the descriptors directly
    # with descriptor.MakeDescriptor(message_type) raised an error asking for
    # an import, which is why they are looked up through the pool instead.
    message_names = (
        'tutorial.Person',
        'tutorial.AddressBook',
        'tutorial.Person.PhoneNumber',
    )
    return tuple(
        descriptor_pool.FindMessageTypeByName(message_name)
        for message_name in message_names
    )
def reset_grpc_db(self):
    """Reset the cached client and, if requested, protobuf's default databases."""
    reset_cached_client()

    if not self.reset_descriptor_pool:
        return

    # Replace the module-level default pool first, then build the default
    # symbol database on top of whatever Default() now returns.
    _descriptor_pool._DEFAULT = DescriptorPool()
    default_pool = _descriptor_pool.Default()
    _symbol_database._DEFAULT = SymbolDatabase(pool=default_pool)
class LogReader(object):
    """
    File-like interface for binary logs.

    >>> with LogReader("path/to/log/file.gz") as log:
    ...     for item in log.items():
    ...         print(item.type)
    ...         print(item.value.some.nested.object)
    ...         print()

    >>> with LogReader("path/to/log/file.gz") as log:
    ...     for value in log.values():
    ...         print(value.some.nested.object)
    ...         print()
    """

    class _Iterator(object):
        """Iterator that produces each element by calling `fetch()`.

        Implements both the Python 3 (`__next__`) and the Python 2 (`next`)
        iterator protocol; iteration ends when `fetch` raises StopIteration.
        """

        def __init__(self, fetch):
            self._fetch = fetch

        def __iter__(self):
            return self

        def __next__(self):
            return self._fetch()

        # Python 2 spelling, kept for callers that invoke .next() directly.
        next = __next__

    def __init__(self, path):
        """Open the log at `path` and load the proto descriptors from its header."""
        self._log_reader = _LogReader(path)

        # header: the raw header carries a serialized FileDescriptorSet
        header_ = self._log_reader.header
        fd_set = FileDescriptorSet()
        fd_set.ParseFromString(header_.proto)
        self._header = Header(proto=fd_set, types=header_.types)

        # descriptors: register every file of the header in a private pool
        self._pool = DescriptorPool()
        for proto in self._header.proto.file:
            self._pool.Add(proto)
        self._factory = MessageFactory()

    @property
    def path(self):
        """Log path."""
        return self._log_reader.path

    @property
    def header(self):
        """Log header."""
        return self._header

    def __repr__(self):
        return repr(self._log_reader)

    def items(self):
        """Return iterator to log items."""
        # The lambda defers the attribute lookup to each step, like the
        # original per-call `this._next()`.
        return self._Iterator(lambda: self._next())

    def values(self):
        """Return iterator to log values."""
        return self._Iterator(lambda: self._next().value)

    def _next(self):
        """Read the next record and decode it into a LogItem.

        The record's type name is resolved in the descriptor pool and its
        payload is parsed into a message of that type.
        """
        next_ = self._log_reader.next()
        descriptor = self._pool.FindMessageTypeByName(next_.type)
        value = self._factory.GetPrototype(descriptor)()
        value.ParseFromString(next_.data)
        return LogItem(next_.type, value)

    def read(self):
        """Return the next value, or None on EOF."""
        try:
            return self._next().value
        except StopIteration:
            return None

    def close(self):
        """Closes LogReader. LogReader will take EOF state."""
        self._log_reader.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
# Compile every .proto file twice: once with the stock Python generator and
# once with the custom plugin.
for filename in proto_files:
    proto_basename = os.path.basename(filename)
    print(f"Generating code for {proto_basename}")
    # Argument lists (no shell) so file names with spaces or shell
    # metacharacters are passed to protoc verbatim.
    subprocess.run(["protoc", "--python_out=.", proto_basename])
    subprocess.run(
        [
            "protoc",
            "--plugin=protoc-gen-custom=../plugin.py",
            "--custom_out=.",
            proto_basename,
        ]
    )

for filename in json_files:
    # Reset the internal symbol database so we can import the `Test` message
    # multiple times. Ugh.
    sym = symbol_database.Default()
    sym.pool = DescriptorPool()

    # The part of the base name before the first "-" names the generated
    # `<name>_pb2` module to import.
    parts = get_base(filename).split("-")
    out = filename.replace(".json", ".bin")
    print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")

    imported = importlib.import_module(f"{parts[0]}_pb2")
    # Context manager so the JSON file handle is closed deterministically.
    with open(filename) as json_file:
        input_json = json_file.read()
    parsed = Parse(input_json, imported.Test())
    serialized = parsed.SerializeToString()
    # Proto field names are preserved unless the file name contains "casing".
    preserve = "casing" not in filename
    serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve)

    s_loaded = json.loads(serialized_json)
    in_loaded = json.loads(input_json)
class ReflectionClientTest(unittest.TestCase):
    # Exercises DescriptorPool lookups backed by a
    # ProtoReflectionDescriptorDatabase talking to an in-process test server.

    def setUp(self):
        self._server = test_common.test_server()

        test_service = test_pb2.DESCRIPTOR.services_by_name["TestService"]
        self._SERVICE_NAMES = (test_service.full_name, reflection.SERVICE_NAME)
        reflection.enable_server_reflection(self._SERVICE_NAMES, self._server)

        bound_port = self._server.add_insecure_port("[::]:0")
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % bound_port)
        self._reflection_db = ProtoReflectionDescriptorDatabase(self._channel)
        self.desc_pool = DescriptorPool(self._reflection_db)

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    def _find_proto3_file(self, file_name):
        # Look up `file_name` and check the properties shared by both test
        # files; the descriptor is returned for file-specific checks.
        file_desc = self.desc_pool.FindFileByName(file_name)
        self.assertEqual(file_name, file_desc.name)
        self.assertEqual(_PROTO_PACKAGE_NAME, file_desc.package)
        self.assertEqual("proto3", file_desc.syntax)
        return file_desc

    def testListServices(self):
        listed = self._reflection_db.get_services()
        self.assertCountEqual(self._SERVICE_NAMES, listed)

    def testReflectionServiceName(self):
        expected = "grpc.reflection.v1alpha.ServerReflection"
        self.assertEqual(reflection.SERVICE_NAME, expected)

    def testFindFile(self):
        proto_file = self._find_proto3_file(_PROTO_FILE_NAME)
        self.assertIn("TestService", proto_file.services_by_name)

        empty_file = self._find_proto3_file(_EMPTY_PROTO_FILE_NAME)
        self.assertIn("Empty", empty_file.message_types_by_name)

    def testFindFileError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindFileByName(_INVALID_FILE_NAME)

    def testFindMessage(self):
        descriptor = self.desc_pool.FindMessageTypeByName(
            _EMPTY_PROTO_SYMBOL_NAME)
        self.assertEqual(_EMPTY_PROTO_SYMBOL_NAME, descriptor.full_name)
        self.assertTrue(_EMPTY_PROTO_SYMBOL_NAME.endswith(descriptor.name))

    def testFindMessageError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindMessageTypeByName(_INVALID_SYMBOL_NAME)

    def testFindServiceFindMethod(self):
        # Service lookup
        expected_service = self._SERVICE_NAMES[0]
        service = self.desc_pool.FindServiceByName(expected_service)
        self.assertEqual(expected_service, service.full_name)
        self.assertTrue(expected_service.endswith(service.name))

        # The service must point at the very file object the pool returns.
        containing_file = self.desc_pool.FindFileByName(_PROTO_FILE_NAME)
        self.assertIs(containing_file, service.file)

        # Method lookup
        expected_method = "EmptyCall"
        self.assertIn(expected_method, service.methods_by_name)
        method = service.FindMethodByName(expected_method)
        self.assertIs(method, service.methods_by_name[expected_method])
        self.assertIs(service, method.containing_service)
        self.assertEqual(expected_method, method.name)
        self.assertTrue(method.full_name.endswith(expected_method))

        # Both request and response of EmptyCall are the empty message.
        empty_message = self.desc_pool.FindMessageTypeByName(
            _EMPTY_PROTO_SYMBOL_NAME)
        self.assertEqual(empty_message, method.input_type)
        self.assertEqual(empty_message, method.output_type)

    def testFindServiceError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindServiceByName(_INVALID_SYMBOL_NAME)

    def testFindMethodError(self):
        service = self.desc_pool.FindServiceByName(self._SERVICE_NAMES[0])

        # FindMethodByName sometimes raises a KeyError, and sometimes returns
        # None. See https://github.com/protocolbuffers/protobuf/issues/9592
        # A None result is turned into a KeyError so both outcomes pass.
        with self.assertRaises(KeyError):
            if service.FindMethodByName(_INVALID_SYMBOL_NAME) is None:
                raise KeyError()

    def testFindExtensionNotImplemented(self):
        """
        Extensions aren't implemented in Protobuf for Python.
        For now, simply assert that indeed they don't work.
        """
        extendee_name = _EMPTY_EXTENSIONS_SYMBOL_NAME
        extendee = self.desc_pool.FindMessageTypeByName(extendee_name)
        self.assertEqual(extendee_name, extendee.full_name)
        self.assertTrue(extendee_name.endswith(extendee.name))

        found_extensions = self.desc_pool.FindAllExtensions(extendee)
        self.assertEqual(0, len(found_extensions))
        with self.assertRaises(KeyError):
            self.desc_pool.FindExtensionByName(extendee_name)