async def test_extension_numbers_of_type(self):
    requests = (
        reflection_pb2.ServerReflectionRequest(
            all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME),
        reflection_pb2.ServerReflectionRequest(
            all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo'),
    )
    responses = []
    async for response in self._stub.ServerReflectionInfo(iter(requests)):
        responses.append(response)
    expected_responses = (
        reflection_pb2.ServerReflectionResponse(
            valid_host='',
            all_extension_numbers_response=reflection_pb2.ExtensionNumberResponse(
                base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME,
                extension_number=_EMPTY_EXTENSIONS_NUMBERS)),
        reflection_pb2.ServerReflectionResponse(
            valid_host='',
            error_response=reflection_pb2.ErrorResponse(
                error_code=grpc.StatusCode.NOT_FOUND.value[0],
                error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
            )),
    )
    self.assertSequenceEqual(expected_responses, responses)
def testFileByName(self):
    requests = (
        reflection_pb2.ServerReflectionRequest(
            file_by_filename=_EMPTY_PROTO_FILE_NAME),
        reflection_pb2.ServerReflectionRequest(file_by_filename='i-donut-exist'),
    )
    responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
    expected_responses = (
        reflection_pb2.ServerReflectionResponse(
            valid_host='',
            file_descriptor_response=reflection_pb2.FileDescriptorResponse(
                file_descriptor_proto=(
                    _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))),
        reflection_pb2.ServerReflectionResponse(
            valid_host='',
            error_response=reflection_pb2.ErrorResponse(
                error_code=grpc.StatusCode.NOT_FOUND.value[0],
                error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
            )),
    )
    self.assertSequenceEqual(expected_responses, responses)
async def test_file_by_symbol(self):
    requests = (
        reflection_pb2.ServerReflectionRequest(
            file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME),
        reflection_pb2.ServerReflectionRequest(
            file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo'),
    )
    responses = []
    async for response in self._stub.ServerReflectionInfo(iter(requests)):
        responses.append(response)
    expected_responses = (
        reflection_pb2.ServerReflectionResponse(
            valid_host='',
            file_descriptor_response=reflection_pb2.FileDescriptorResponse(
                file_descriptor_proto=(
                    _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))),
        reflection_pb2.ServerReflectionResponse(
            valid_host='',
            error_response=reflection_pb2.ErrorResponse(
                error_code=grpc.StatusCode.NOT_FOUND.value[0],
                error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
            )),
    )
    self.assertSequenceEqual(expected_responses, responses)
def testFileContainingExtension(self):
    requests = (
        reflection_pb2.ServerReflectionRequest(
            file_containing_extension=reflection_pb2.ExtensionRequest(
                containing_type=_EMPTY_EXTENSIONS_SYMBOL_NAME,
                extension_number=125,
            ),),
        reflection_pb2.ServerReflectionRequest(
            file_containing_extension=reflection_pb2.ExtensionRequest(
                containing_type='i.donut.exist.co.uk.org.net.me.name.foo',
                extension_number=55,
            ),),
    )
    responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
    expected_responses = (
        reflection_pb2.ServerReflectionResponse(
            valid_host='',
            file_descriptor_response=reflection_pb2.FileDescriptorResponse(
                file_descriptor_proto=(_file_descriptor_to_proto(
                    empty2_extensions_pb2.DESCRIPTOR),))),
        reflection_pb2.ServerReflectionResponse(
            valid_host='',
            error_response=reflection_pb2.ErrorResponse(
                error_code=grpc.StatusCode.NOT_FOUND.value[0],
                error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
            )),
    )
    self._assert_sequence_of_proto_equal(expected_responses, responses)
def build_database_from_stub(reflection_client):
    """Build descriptor pool and symbol database from reflection service.

    Args:
        reflection_client: ServerReflectionStub. GRPC reflection client
    Returns:
        tuple (descriptor pool, symbol database)
    """
    services_response = funcy.first(
        # Note that this is stupid, but grpc has problems iterating over lists
        reflection_client.ServerReflectionInfo(
            x for x in [GET_SERVICES_REQUEST]))
    service_names = [
        service.name
        for service in services_response.list_services_response.service
    ]
    file_requests = (reflection_pb2.ServerReflectionRequest(
        file_containing_symbol=service_name) for service_name in service_names)
    file_descriptor_proto_bytes = get_proto_bytes_for_requests(
        reflection_client, file_requests)
    descriptor_pool = DescriptorPool()
    add_protos_to_descriptor_pool(reflection_client, descriptor_pool,
                                  file_descriptor_proto_bytes)
    symbol_database = SymbolDatabase(descriptor_pool)
    return (descriptor_pool, symbol_database)
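A hedged usage sketch of build_database_from_stub: it assumes a reflection-enabled server on localhost:50051 and a message named my.package.MyRequest (both are placeholders), and a protobuf release in which SymbolDatabase still exposes MessageFactory.GetPrototype.

import grpc
from grpc_reflection.v1alpha import reflection_pb2_grpc

channel = grpc.insecure_channel('localhost:50051')  # placeholder address
reflection_client = reflection_pb2_grpc.ServerReflectionStub(channel)
descriptor_pool, symbol_database = build_database_from_stub(reflection_client)

# Resolve a message type by its fully-qualified name (placeholder name) and
# build an instance of it without any generated _pb2 module on disk.
descriptor = descriptor_pool.FindMessageTypeByName('my.package.MyRequest')
request_class = symbol_database.GetPrototype(descriptor)
request = request_class()

On newer protobuf releases, google.protobuf.message_factory.GetMessageClass(descriptor) plays the same role as GetPrototype.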
def _get_file_descriptor(self, reflection_stub, service):
    # Get the file descriptor that defines the requested service.
    message = reflection_pb2.ServerReflectionRequest(
        file_containing_symbol=service.name)
    responses = reflection_stub.ServerReflectionInfo(iter((message,)))
    for response in responses:
        file_descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
            response.file_descriptor_response.file_descriptor_proto[0])
        # Parse the file descriptor; the last package segment encodes the API version.
        package = file_descriptor_proto.package
        version = package.split('.')[-1]
        if version == self._version:
            api_class_name = file_descriptor_proto.service[0].name
            module_name = self._parse_proto_name(file_descriptor_proto.name,
                                                 package)
            self._create_grpc_stub(package, module_name, api_class_name)
            self._preload_message_type(package, module_name, api_class_name,
                                       file_descriptor_proto)
            self._preload_related_message_type(file_descriptor_proto,
                                               module_name)
            self._preload_field_type(file_descriptor_proto.message_type,
                                     module_name)
def import_dependencies_then_proto(reflection_client, descriptor_pool, proto_name,
                                   imported_names, reflected_names):
    """Recursively add file descriptor protos to a descriptor pool, importing
    dependencies first.

    Args:
        reflection_client: ServerReflectionStub. GRPC reflection client
        descriptor_pool: DescriptorPool. GRPC descriptor pool where file
            descriptor protos will be loaded
        proto_name: string. Name of the file descriptor proto to be imported
        imported_names: set. The names of the memoized, already imported file
            descriptor protos
        reflected_names: dict. Map of to-be-imported file descriptor proto
            names to their respective file descriptor protos
    Returns:
        None
    """
    if proto_name in imported_names:
        return
    # Anything on the proto_name depth-first search stack should already be reflected.
    if proto_name not in reflected_names:
        raise RuntimeError(
            "Something went wrong. Stacked planned imports should either "
            "be already reflected or already imported. Please fix.")
    dependencies = reflected_names[proto_name].dependency  # pylint: disable=no-member
    # Reflect not-yet-reflected dependencies the same way we run server
    # reflection on root protos.
    not_yet_reflected_dependencies = []
    for dependency in dependencies:
        if (dependency not in imported_names) and (dependency not in reflected_names):
            not_yet_reflected_dependencies.append(dependency)
    if not_yet_reflected_dependencies:
        file_requests = (reflection_pb2.ServerReflectionRequest(
            file_by_filename=dependency)
                         for dependency in not_yet_reflected_dependencies)
        file_descriptor_proto_bytes = get_proto_bytes_for_requests(
            reflection_client, file_requests)
        for serialized_proto in file_descriptor_proto_bytes:
            parsed_file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
            parsed_file_descriptor_proto.ParseFromString(serialized_proto)
            file_descriptor_proto_name = (
                parsed_file_descriptor_proto.name  # pylint: disable=no-member
            )
            reflected_names[file_descriptor_proto_name] = parsed_file_descriptor_proto
    # Recursively import dependencies.
    for dependency in dependencies:
        import_dependencies_then_proto(reflection_client, descriptor_pool,
                                       dependency, imported_names,
                                       reflected_names)
    # Import the proto itself, then update the memos.
    descriptor_pool.Add(reflected_names[proto_name])
    imported_names.add(proto_name)
    del reflected_names[proto_name]
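Neither add_protos_to_descriptor_pool (referenced by build_database_from_stub above) nor its internals appear in this section. The sketch below, under a hypothetical name, shows one way such a driver could seed reflected_names from the reflected bytes and delegate dependency ordering to import_dependencies_then_proto; it is illustrative only, not the project's actual implementation.

from google.protobuf import descriptor_pb2


def add_reflected_protos_to_pool(reflection_client, descriptor_pool,
                                 file_descriptor_proto_bytes):
    # Hypothetical driver: parse every reflected file descriptor proto, then
    # import each root proto together with its (possibly still unreflected)
    # dependencies.
    imported_names = set()
    reflected_names = {}
    for serialized_proto in file_descriptor_proto_bytes:
        parsed = descriptor_pb2.FileDescriptorProto()
        parsed.ParseFromString(serialized_proto)
        reflected_names[parsed.name] = parsed
    # Snapshot the keys: import_dependencies_then_proto mutates the dict as it
    # imports, and already-imported names simply return early.
    for proto_name in list(reflected_names):
        import_dependencies_then_proto(reflection_client, descriptor_pool,
                                       proto_name, imported_names,
                                       reflected_names)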
def _list_services(stub):
    responses = stub.ServerReflectionInfo(
        iter([reflection_pb2.ServerReflectionRequest(list_services="")]),
        timeout=QUERY_TIMEOUT)
    for response in responses:
        if response.HasField("error_response"):
            raise ServiceError(response.error_response.error_message)
        for service in response.list_services_response.service:
            yield service.name
def _get_server_reflection_info(self):
    reflection_stub = reflection_pb2_grpc.ServerReflectionStub(self._channel)
    # List the services exposed by the server.
    message = reflection_pb2.ServerReflectionRequest(list_services='')
    responses = reflection_stub.ServerReflectionInfo(iter((message,)))
    for response in responses:
        for service in response.list_services_response.service:
            self._get_file_descriptor(reflection_stub, service)
def testListServices(self):
    requests = (reflection_pb2.ServerReflectionRequest(list_services='',),)
    responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
    expected_responses = (reflection_pb2.ServerReflectionResponse(
        valid_host='',
        list_services_response=reflection_pb2.ListServiceResponse(
            service=tuple(
                reflection_pb2.ServiceResponse(name=name)
                for name in _SERVICE_NAMES))),)
    self._assert_sequence_of_proto_equal(expected_responses, responses)
async def test_list_services(self):
    requests = (reflection_pb2.ServerReflectionRequest(list_services='',),)
    responses = []
    async for response in self._stub.ServerReflectionInfo(iter(requests)):
        responses.append(response)
    expected_responses = (reflection_pb2.ServerReflectionResponse(
        valid_host='',
        list_services_response=reflection_pb2.ListServiceResponse(
            service=tuple(
                reflection_pb2.ServiceResponse(name=name)
                for name in _SERVICE_NAMES))),)
    self.assertSequenceEqual(expected_responses, responses)
def test_get_proto_bytes_for_requests(self):
    # Test the empty case.
    file_requests_empty = []
    proto_bytes_empty = reflection_descriptor_database.get_proto_bytes_for_requests(
        utils.reflection_client_mock, file_requests_empty)
    self.assertEqual([], proto_bytes_empty)
    # Test non-empty cases.
    file_requests = (reflection_pb2.ServerReflectionRequest(
        file_by_filename=request) for request in ["proto_a", "proto_c"])
    proto_bytes = reflection_descriptor_database.get_proto_bytes_for_requests(
        utils.reflection_client_mock, file_requests)
    self.assertIn(utils.PROTO_BYTES_A, proto_bytes)
    self.assertIn(utils.PROTO_BYTES_C, proto_bytes)
    self.assertNotIn(utils.PROTO_BYTES_B, proto_bytes)
def _get_file_descriptor_by_name(self, name):
    request = reflection_pb2.ServerReflectionRequest(file_by_filename=name)
    result = self._reflection_single_request(request)
    proto = result.file_descriptor_response.file_descriptor_proto[0]
    return descriptor_pb2.FileDescriptorProto.FromString(proto)
def load_protocols(self, channel, filenames=None, symbols=None):
    """Implementation of `GrpcReflectionClient.load_protocols`"""
    stub = reflection_pb2_grpc.ServerReflectionStub(channel)
    requests = []
    if filenames:
        requests.extend(
            reflection_pb2.ServerReflectionRequest(file_by_filename=name)
            for name in filenames)
    if symbols:
        requests.extend(
            reflection_pb2.ServerReflectionRequest(file_containing_symbol=symbol)
            for symbol in symbols)
    if not requests:
        requests.extend(
            reflection_pb2.ServerReflectionRequest(file_containing_symbol=name)
            for name in _list_services(stub)
            if name != "grpc.reflection.v1alpha.ServerReflection")
    protos = {}
    traversed = set()
    while requests:
        responses = stub.ServerReflectionInfo(iter(requests),
                                              timeout=QUERY_TIMEOUT)
        deps = set()
        for response in responses:
            if response.HasField("error_response"):
                raise ServiceError(response.error_response.error_message)
            for desc_bytes in response.file_descriptor_response.file_descriptor_proto:
                proto = descriptor_pb2.FileDescriptorProto.FromString(  # pylint: disable=no-member
                    desc_bytes)
                traversed.add(proto.name)
                deps.update(proto.dependency)
                protos[proto.name] = proto
                self.methods_by_file[proto.name] = {
                    service.name: service.method for service in proto.service
                }
        deps -= traversed
        requests = [
            reflection_pb2.ServerReflectionRequest(file_by_filename=dep)
            for dep in deps
        ]
        # prevent unsatisfied deps from looping forever
        traversed.update(deps)
    names = deque(protos.keys())
    traversed = set()
    while names:
        name = names[0]
        traversed.add(name)
        # raises KeyError if unsatisfied dep:
        proto = protos[name]
        deps = set(proto.dependency) - traversed
        if deps:
            names = deque(x for x in names if x not in deps)
            names.extendleft(deps)
        else:
            del names[0]
            self.pool.Add(proto)
    return protos.keys()
def _get_service_names(self):
    request = reflection_pb2.ServerReflectionRequest(list_services="")
    resp = self._reflection_single_request(request)
    services = tuple(
        s.name for s in resp.list_services_response.service)
    return services
Because most use cases will also require requesting the transitive
dependencies of the requested files, the queries also return all transitive
dependencies of the returned file. Should interesting use cases for
non-transitive queries turn up later, we can easily extend the protocol to
support them.
```
Note that the response for a file_containing_symbol request contains a list
of descriptors, not a single one.
"""
import funcy
from google.protobuf import descriptor_pb2
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.symbol_database import SymbolDatabase

from grpc_reflection.v1alpha import reflection_pb2, reflection_pb2_grpc

GET_SERVICES_REQUEST = reflection_pb2.ServerReflectionRequest(
    list_services="services")


def get_proto_bytes_for_requests(reflection_client, file_requests):
    """Return the list of file descriptor proto bytes for the given service or
    file requests.

    Args:
        reflection_client: ServerReflectionStub. GRPC reflection client
        file_requests: generator (ServerReflectionRequest). The requests for
            file descriptor protos to be queried during server reflection
    Returns:
        list(bytes). The serialized file descriptor protos
    """
    file_descriptors_responses = reflection_client.ServerReflectionInfo(
        file_requests)
    file_descriptor_proto_bytes = []
    for response in file_descriptors_responses:
        # Each response may carry several descriptors: the requested file plus
        # its transitive dependencies.
        file_descriptor_proto_bytes.extend(
            response.file_descriptor_response.file_descriptor_proto)
    return file_descriptor_proto_bytes
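A minimal sketch of the note above, assuming a reflection-enabled server on localhost:50051 that exposes helloworld.Greeter (both names are placeholders): a single file_containing_symbol query answers with the defining file and its whole transitive dependency closure, so the response's file_descriptor_proto field is iterated rather than indexed.

import grpc
from google.protobuf import descriptor_pb2
from grpc_reflection.v1alpha import reflection_pb2, reflection_pb2_grpc

channel = grpc.insecure_channel('localhost:50051')  # placeholder address
stub = reflection_pb2_grpc.ServerReflectionStub(channel)
request = reflection_pb2.ServerReflectionRequest(
    file_containing_symbol='helloworld.Greeter')  # placeholder symbol

protos = {}
for response in stub.ServerReflectionInfo(iter([request])):
    # The single response carries the defining file *and* its transitive
    # dependencies, so iterate instead of taking only element [0].
    for serialized in response.file_descriptor_response.file_descriptor_proto:
        proto = descriptor_pb2.FileDescriptorProto()
        proto.ParseFromString(serialized)
        protos[proto.name] = proto

print(sorted(protos))  # e.g. helloworld.proto plus every file it imports

Before these protos can be registered in a DescriptorPool, dependencies still have to be added before their dependents, which is exactly what import_dependencies_then_proto and load_protocols above take care of.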
def _get_file_descriptor_by_symbol(self, symbol):
    request = reflection_pb2.ServerReflectionRequest(
        file_containing_symbol=symbol)
    result = self._reflection_single_request(request)
    proto = result.file_descriptor_response.file_descriptor_proto[0]
    return descriptor_pb2.FileDescriptorProto.FromString(proto)