Example #1
0
 async def test_extension_numbers_of_type(self):
     # One query for a type with known extensions, one for a type that does
     # not exist; the latter is expected to produce a NOT_FOUND error.
     known_type_request = reflection_pb2.ServerReflectionRequest(
         all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME)
     unknown_type_request = reflection_pb2.ServerReflectionRequest(
         all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo')
     call = self._stub.ServerReflectionInfo(
         iter((known_type_request, unknown_type_request)))
     actual = [reply async for reply in call]

     extension_numbers = reflection_pb2.ExtensionNumberResponse(
         base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME,
         extension_number=_EMPTY_EXTENSIONS_NUMBERS)
     not_found = grpc.StatusCode.NOT_FOUND.value
     missing_type_error = reflection_pb2.ErrorResponse(
         error_code=not_found[0],
         error_message=not_found[1].encode(),
     )
     expected = (
         reflection_pb2.ServerReflectionResponse(
             valid_host='',
             all_extension_numbers_response=extension_numbers),
         reflection_pb2.ServerReflectionResponse(
             valid_host='', error_response=missing_type_error),
     )
     self.assertSequenceEqual(expected, actual)
Example #2
0
 def testFileByName(self):
   # Look up one file that the server knows and one that it does not; the
   # unknown file name is expected to produce a NOT_FOUND error response.
   known_file_request = reflection_pb2.ServerReflectionRequest(
     file_by_filename=_EMPTY_PROTO_FILE_NAME)
   unknown_file_request = reflection_pb2.ServerReflectionRequest(
     file_by_filename='i-donut-exist')
   actual = tuple(
     self._stub.ServerReflectionInfo(
       iter((known_file_request, unknown_file_request))))

   descriptor_reply = reflection_pb2.FileDescriptorResponse(
     file_descriptor_proto=(
       _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
     ))
   not_found = grpc.StatusCode.NOT_FOUND.value
   missing_file_error = reflection_pb2.ErrorResponse(
     error_code=not_found[0],
     error_message=not_found[1].encode(),
   )
   expected = (
     reflection_pb2.ServerReflectionResponse(
       valid_host='', file_descriptor_response=descriptor_reply),
     reflection_pb2.ServerReflectionResponse(
       valid_host='', error_response=missing_file_error),
   )
   self.assertSequenceEqual(expected, actual)
Example #3
0
 async def test_file_by_symbol(self):
     # Resolve a known symbol to its defining file, then ask for a symbol
     # that does not exist and expect a NOT_FOUND error response.
     known_symbol_request = reflection_pb2.ServerReflectionRequest(
         file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME)
     unknown_symbol_request = reflection_pb2.ServerReflectionRequest(
         file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo')
     call = self._stub.ServerReflectionInfo(
         iter((known_symbol_request, unknown_symbol_request)))
     actual = [reply async for reply in call]

     descriptor_reply = reflection_pb2.FileDescriptorResponse(
         file_descriptor_proto=(
             _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
         ))
     not_found = grpc.StatusCode.NOT_FOUND.value
     missing_symbol_error = reflection_pb2.ErrorResponse(
         error_code=not_found[0],
         error_message=not_found[1].encode(),
     )
     expected = (
         reflection_pb2.ServerReflectionResponse(
             valid_host='', file_descriptor_response=descriptor_reply),
         reflection_pb2.ServerReflectionResponse(
             valid_host='', error_response=missing_symbol_error),
     )
     self.assertSequenceEqual(expected, actual)
Example #4
0
 def testFileContainingExtension(self):
     # First request targets an extension that is registered; the second
     # names a containing type that does not exist and should be answered
     # with a NOT_FOUND error response.
     registered_extension = reflection_pb2.ExtensionRequest(
         containing_type=_EMPTY_EXTENSIONS_SYMBOL_NAME,
         extension_number=125,
     )
     unregistered_extension = reflection_pb2.ExtensionRequest(
         containing_type='i.donut.exist.co.uk.org.net.me.name.foo',
         extension_number=55,
     )
     outgoing = (
         reflection_pb2.ServerReflectionRequest(
             file_containing_extension=registered_extension),
         reflection_pb2.ServerReflectionRequest(
             file_containing_extension=unregistered_extension),
     )
     actual = tuple(self._stub.ServerReflectionInfo(iter(outgoing)))

     descriptor_reply = reflection_pb2.FileDescriptorResponse(
         file_descriptor_proto=(
             _file_descriptor_to_proto(empty2_extensions_pb2.DESCRIPTOR),
         ))
     not_found = grpc.StatusCode.NOT_FOUND.value
     missing_extension_error = reflection_pb2.ErrorResponse(
         error_code=not_found[0],
         error_message=not_found[1].encode(),
     )
     expected = (
         reflection_pb2.ServerReflectionResponse(
             valid_host='', file_descriptor_response=descriptor_reply),
         reflection_pb2.ServerReflectionResponse(
             valid_host='', error_response=missing_extension_error),
     )
     self._assert_sequence_of_proto_equal(expected, actual)
Example #5
0
def build_database_from_stub(reflection_client):
    """Create a descriptor pool and symbol database via server reflection.

    The service list is requested first; the file descriptor protos that
    declare those services are then fetched and loaded into a new pool.

    Args:
        reflection_client: ServerReflectionStub. GRPC reflection client

    Returns:
        tuple (descriptor pool, symbol database)
    """
    # The request is streamed through a generator expression on purpose:
    # the original author observed that grpc had trouble iterating a list.
    request_stream = (request for request in (GET_SERVICES_REQUEST, ))
    services_response = funcy.first(
        reflection_client.ServerReflectionInfo(request_stream))

    listed_services = services_response.list_services_response.service
    service_names = [entry.name for entry in listed_services]

    # One "file containing symbol" request per advertised service.
    symbol_requests = (
        reflection_pb2.ServerReflectionRequest(file_containing_symbol=symbol)
        for symbol in service_names)
    serialized_protos = get_proto_bytes_for_requests(reflection_client,
                                                     symbol_requests)

    pool = DescriptorPool()
    add_protos_to_descriptor_pool(reflection_client, pool, serialized_protos)

    return (pool, SymbolDatabase(pool))
Example #6
0
    def _get_file_descriptor(self, reflection_stub, service):
        """Fetch the file descriptor declaring ``service`` and preload from it.

        For every reflection response, the first serialized file descriptor
        is parsed. When the last dotted component of its package equals
        ``self._version``, the gRPC stub is created and the message and
        field types of that file are preloaded.

        Args:
            reflection_stub: stub whose ``ServerReflectionInfo`` is called.
            service: object whose ``name`` is used as the symbol to look up.
        """
        request = reflection_pb2.ServerReflectionRequest(
            file_containing_symbol=service.name)
        for reply in reflection_stub.ServerReflectionInfo(iter((request, ))):
            serialized = reply.file_descriptor_response.file_descriptor_proto[0]
            file_proto = descriptor_pb2.FileDescriptorProto.FromString(
                serialized)

            package = file_proto.package
            # Everything after the final dot of the package is compared
            # against the configured version; other files are ignored.
            if package.rpartition('.')[2] != self._version:
                continue

            api_class_name = file_proto.service[0].name
            module_name = self._parse_proto_name(file_proto.name, package)

            self._create_grpc_stub(package, module_name, api_class_name)
            self._preload_message_type(package, module_name, api_class_name,
                                       file_proto)
            self._preload_related_message_type(file_proto, module_name)
            self._preload_field_type(file_proto.message_type, module_name)
Example #7
0
def import_dependencies_then_proto(reflection_client, descriptor_pool,
                                   proto_name, imported_names,
                                   reflected_names):
    """Depth-first import of a file descriptor proto, dependencies first.

    Args:
        reflection_client: ServerReflectionStub. GRPC reflection client
        descriptor_pool: DescriptorPool. GRPC descriptor pool where file descriptor protos will be
            loaded
        proto_name: string. Name of the file descriptor proto to be imported
        imported_names: set. The names of the memoized, already imported file descriptor protos
        reflected_names: dict. Map of to-be-imported file descriptor proto names to their
            respective file descriptor protos

    Returns:
        None

    Raises:
        RuntimeError: if proto_name is neither already imported nor present
            in reflected_names.
    """
    # Memoized: nothing to do for a proto that is already in the pool.
    if proto_name in imported_names:
        return

    # Every name reaching this point must have been reflected beforehand.
    if proto_name not in reflected_names:
        raise RuntimeError(
            "Something went wrong. Stacked planned imports should either "
            "be already reflected or already imported. Please fix.")

    dependencies = reflected_names[proto_name].dependency  # pylint: disable=no-member

    # Dependencies that are known neither as imported nor as reflected have
    # to be fetched from the server before the recursion can visit them.
    unknown_dependencies = [
        dep for dep in dependencies
        if not (dep in imported_names or dep in reflected_names)
    ]
    if unknown_dependencies:
        dependency_requests = (
            reflection_pb2.ServerReflectionRequest(file_by_filename=dep)
            for dep in unknown_dependencies)
        serialized_protos = get_proto_bytes_for_requests(
            reflection_client, dependency_requests)
        for raw_proto in serialized_protos:
            parsed = descriptor_pb2.FileDescriptorProto()
            parsed.ParseFromString(raw_proto)
            reflected_names[parsed.name] = parsed  # pylint: disable=no-member

    # Import every dependency before the proto that needs it.
    for dep in dependencies:
        import_dependencies_then_proto(reflection_client, descriptor_pool,
                                       dep, imported_names, reflected_names)

    # Finally add this proto and move it from "reflected" to "imported".
    descriptor_pool.Add(reflected_names[proto_name])
    imported_names.add(proto_name)
    del reflected_names[proto_name]
Example #8
0
def _list_services(stub):
    """Yield the name of each service listed by the reflection endpoint.

    Args:
        stub: reflection stub whose ``ServerReflectionInfo`` is called with
            a single ``list_services`` request and ``QUERY_TIMEOUT``.

    Raises:
        ServiceError: if a response carries an ``error_response``.
    """
    list_request = reflection_pb2.ServerReflectionRequest(list_services="")
    replies = stub.ServerReflectionInfo(iter([list_request]),
                                        timeout=QUERY_TIMEOUT)
    for reply in replies:
        if reply.HasField("error_response"):
            raise ServiceError(reply.error_response.error_message)
        for entry in reply.list_services_response.service:
            yield entry.name
Example #9
0
    def _get_server_reflection_info(self):
        """List the server's services and process each one's file descriptor.

        A reflection stub is created on ``self._channel``; every service
        returned by the ``list_services`` query is handed to
        ``self._get_file_descriptor`` together with that stub.
        """
        stub = reflection_pb2_grpc.ServerReflectionStub(self._channel)

        list_request = reflection_pb2.ServerReflectionRequest(list_services='')
        for reply in stub.ServerReflectionInfo(iter((list_request,))):
            for listed_service in reply.list_services_response.service:
                self._get_file_descriptor(stub, listed_service)
Example #10
0
 def testListServices(self):
     # A single list_services query must be answered with one response that
     # names every entry of _SERVICE_NAMES.
     list_request = reflection_pb2.ServerReflectionRequest(list_services='')
     actual = tuple(self._stub.ServerReflectionInfo(iter((list_request, ))))

     advertised = tuple(
         reflection_pb2.ServiceResponse(name=service_name)
         for service_name in _SERVICE_NAMES)
     expected_reply = reflection_pb2.ServerReflectionResponse(
         valid_host='',
         list_services_response=reflection_pb2.ListServiceResponse(
             service=advertised))
     self._assert_sequence_of_proto_equal((expected_reply, ), actual)
Example #11
0
 async def test_list_services(self):
     # A single list_services query must be answered with one response that
     # names every entry of _SERVICE_NAMES.
     list_request = reflection_pb2.ServerReflectionRequest(list_services='')
     call = self._stub.ServerReflectionInfo(iter((list_request,)))
     actual = [reply async for reply in call]

     advertised = tuple(
         reflection_pb2.ServiceResponse(name=service_name)
         for service_name in _SERVICE_NAMES)
     expected_reply = reflection_pb2.ServerReflectionResponse(
         valid_host='',
         list_services_response=reflection_pb2.ListServiceResponse(
             service=advertised))
     self.assertSequenceEqual((expected_reply,), actual)
Example #12
0
    def test_get_proto_bytes_for_requests(self):
        fetch = reflection_descriptor_database.get_proto_bytes_for_requests
        client = utils.reflection_client_mock

        # No requests in, no descriptors out.
        self.assertEqual([], fetch(client, []))

        # Asking for protos A and C must return exactly those, never B.
        wanted_files = ("proto_a", "proto_c")
        outgoing = (
            reflection_pb2.ServerReflectionRequest(file_by_filename=file_name)
            for file_name in wanted_files)
        fetched = fetch(client, outgoing)
        for present in (utils.PROTO_BYTES_A, utils.PROTO_BYTES_C):
            self.assertIn(present, fetched)
        self.assertNotIn(utils.PROTO_BYTES_B, fetched)
Example #13
0
 def _get_file_descriptor_by_name(self, name):
     """Return the parsed FileDescriptorProto for the file called ``name``.

     Only the first serialized descriptor of the reflection response is
     parsed and returned.
     """
     reply = self._reflection_single_request(
         reflection_pb2.ServerReflectionRequest(file_by_filename=name))
     serialized = reply.file_descriptor_response.file_descriptor_proto[0]
     return descriptor_pb2.FileDescriptorProto.FromString(serialized)
Example #14
0
    def load_protocols(self, channel, filenames=None, symbols=None):
        """Implementation of `GrpcReflectionClient.load_protocols`

        File descriptors are requested by file name and/or symbol; when
        neither is given, every listed service except the reflection service
        itself is used as the starting symbol. Dependencies are then fetched
        round by round, and finally all descriptors are added to
        ``self.pool`` with dependencies preceding their dependents.

        Returns:
            The keys view of the loaded descriptors (file names).

        Raises:
            ServiceError: if a reflection response carries an error.
            KeyError: if a dependency could not be fetched.
        """
        stub = reflection_pb2_grpc.ServerReflectionStub(channel)

        pending = []
        if filenames:
            pending += [
                reflection_pb2.ServerReflectionRequest(file_by_filename=name)
                for name in filenames
            ]
        if symbols:
            pending += [
                reflection_pb2.ServerReflectionRequest(
                    file_containing_symbol=symbol) for symbol in symbols
            ]
        if not pending:
            # Nothing requested explicitly: start from every listed service
            # other than the reflection service.
            pending += [
                reflection_pb2.ServerReflectionRequest(
                    file_containing_symbol=name)
                for name in _list_services(stub)
                if name != "grpc.reflection.v1alpha.ServerReflection"
            ]

        # Phase 1: fetch descriptors, then their not-yet-seen dependencies,
        # until a round produces no new dependency names.
        protos = {}
        seen = set()
        while pending:
            replies = stub.ServerReflectionInfo(iter(pending),
                                                timeout=QUERY_TIMEOUT)
            wanted = set()
            for reply in replies:
                if reply.HasField("error_response"):
                    raise ServiceError(reply.error_response.error_message)
                descriptor_reply = reply.file_descriptor_response
                for desc_bytes in descriptor_reply.file_descriptor_proto:
                    proto = descriptor_pb2.FileDescriptorProto.FromString(  # pylint: disable=no-member
                        desc_bytes)
                    seen.add(proto.name)
                    wanted.update(proto.dependency)
                    protos[proto.name] = proto
                    self.methods_by_file[proto.name] = {
                        service.name: service.method
                        for service in proto.service
                    }
            wanted -= seen
            pending = [
                reflection_pb2.ServerReflectionRequest(file_by_filename=dep)
                for dep in wanted
            ]
            # prevent unsatisfied deps from looping forever
            seen.update(wanted)

        # Phase 2: add descriptors to the pool, moving unvisited
        # dependencies to the front so they are added first.
        order = deque(protos)
        visited = set()
        while order:
            current = order[0]
            visited.add(current)
            # raises KeyError if unsatisfied dep:
            descriptor = protos[current]
            unvisited = set(descriptor.dependency) - visited
            if not unvisited:
                order.popleft()
                self.pool.Add(descriptor)
                continue
            order = deque(name for name in order if name not in unvisited)
            order.extendleft(unvisited)

        return protos.keys()
Example #15
0
 def _get_service_names(self):
     """Return a tuple with the name of every service the server lists."""
     reply = self._reflection_single_request(
         reflection_pb2.ServerReflectionRequest(list_services=""))
     return tuple(entry.name for entry in reply.list_services_response.service)
Example #16
0
Because most usecases will require also requesting the transitive dependencies of requested files,
the queries will also return all transitive dependencies of the returned file.
Should interesting usecases for non-transitive queries turn up later, we can easily extend the
protocol to support them.
```
Note that the response to a file_containing_symbol request contains a list of descriptors, not a
single one.
"""

import funcy
from google.protobuf import descriptor_pb2
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.symbol_database import SymbolDatabase
from grpc_reflection.v1alpha import reflection_pb2, reflection_pb2_grpc

# Reusable reflection request that selects the `list_services` query.
# NOTE(review): the gRPC reflection proto states that the content of
# `list_services` is not checked, so "services" looks like an arbitrary
# placeholder value — confirm before relying on it.
GET_SERVICES_REQUEST = reflection_pb2.ServerReflectionRequest(
    list_services="services")


def get_proto_bytes_for_requests(reflection_client, file_requests):
    """Return the file descriptor proto bytes list given file requests for services or files

    Args:
        reflection_client: ServerReflectionStub. GRPC reflection client
        file_requests: generator (ServerReflectionRequest). The requests for file descriptor
            protos to be queried on during server reflection

    Returns:
        bytes
    """
    file_descriptors_responses = reflection_client.ServerReflectionInfo(
        file_requests)
Example #17
0
 def _get_file_descriptor_by_symbol(self, symbol):
     """Return the parsed FileDescriptorProto of the file declaring ``symbol``.

     Only the first serialized descriptor of the reflection response is
     parsed and returned.
     """
     reply = self._reflection_single_request(
         reflection_pb2.ServerReflectionRequest(file_containing_symbol=symbol))
     serialized = reply.file_descriptor_response.file_descriptor_proto[0]
     return descriptor_pb2.FileDescriptorProto.FromString(serialized)