async def test_file_by_symbol(self):
    """Symbol lookup over the aio stub.

    A known symbol must resolve to the descriptor of the file that defines
    it; an unknown symbol must produce a NOT_FOUND error response.
    """
    requests = (
        reflection_pb2.ServerReflectionRequest(
            file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME),
        reflection_pb2.ServerReflectionRequest(
            file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo'),
    )
    # Drain the response stream into a list.
    responses = [
        response
        async for response in self._stub.ServerReflectionInfo(iter(requests))
    ]

    found = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        file_descriptor_response=reflection_pb2.FileDescriptorResponse(
            file_descriptor_proto=(
                _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),)))
    not_found = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        error_response=reflection_pb2.ErrorResponse(
            error_code=grpc.StatusCode.NOT_FOUND.value[0],
            error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
        ))
    self.assertSequenceEqual((found, not_found), responses)
def testFileByName(self):
    """File lookup by name.

    A known file name must return that file's serialized descriptor; an
    unknown name must produce a NOT_FOUND error response.
    """
    known = reflection_pb2.ServerReflectionRequest(
        file_by_filename=_EMPTY_PROTO_FILE_NAME)
    unknown = reflection_pb2.ServerReflectionRequest(
        file_by_filename='i-donut-exist')
    requests = (known, unknown)

    stream = self._stub.ServerReflectionInfo(iter(requests))
    responses = tuple(stream)

    descriptor_reply = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        file_descriptor_response=reflection_pb2.FileDescriptorResponse(
            file_descriptor_proto=(
                _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),)))
    error_reply = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        error_response=reflection_pb2.ErrorResponse(
            error_code=grpc.StatusCode.NOT_FOUND.value[0],
            error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
        ))
    self.assertSequenceEqual((descriptor_reply, error_reply), responses)
async def test_extension_numbers_of_type(self):
    """Extension-number listing over the aio stub.

    A known extendable type must report its base type name and extension
    numbers; an unknown type must produce a NOT_FOUND error response.
    """
    requests = (
        reflection_pb2.ServerReflectionRequest(
            all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME),
        reflection_pb2.ServerReflectionRequest(
            all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo'),
    )
    responses = [
        reply
        async for reply in self._stub.ServerReflectionInfo(iter(requests))
    ]

    numbers_reply = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        all_extension_numbers_response=reflection_pb2.ExtensionNumberResponse(
            base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME,
            extension_number=_EMPTY_EXTENSIONS_NUMBERS))
    not_found_reply = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        error_response=reflection_pb2.ErrorResponse(
            error_code=grpc.StatusCode.NOT_FOUND.value[0],
            error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
        ))
    self.assertSequenceEqual((numbers_reply, not_found_reply), responses)
def testFileContainingExtension(self):
    """File lookup by (containing type, extension number).

    Extension 125 of the known extendable type must resolve to the
    empty2_extensions descriptor; an unknown containing type must produce a
    NOT_FOUND error response.
    """
    def extension_request(containing_type, number):
        # Build a reflection request asking which file defines the extension.
        return reflection_pb2.ServerReflectionRequest(
            file_containing_extension=reflection_pb2.ExtensionRequest(
                containing_type=containing_type,
                extension_number=number,
            ))

    requests = (
        extension_request(_EMPTY_EXTENSIONS_SYMBOL_NAME, 125),
        extension_request('i.donut.exist.co.uk.org.net.me.name.foo', 55),
    )
    responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))

    file_reply = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        file_descriptor_response=reflection_pb2.FileDescriptorResponse(
            file_descriptor_proto=(
                _file_descriptor_to_proto(empty2_extensions_pb2.DESCRIPTOR),)))
    missing_reply = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        error_response=reflection_pb2.ErrorResponse(
            error_code=grpc.StatusCode.NOT_FOUND.value[0],
            error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
        ))
    self._assert_sequence_of_proto_equal((file_reply, missing_reply), responses)
def _file_descriptor_response(descriptor):
    """Wrap a file descriptor in a reflection response.

    The descriptor is copied into a ``FileDescriptorProto``, serialized to
    bytes, and returned as the single entry of ``file_descriptor_proto``.

    Args:
        descriptor: object exposing ``CopyToProto`` (a protobuf file
            descriptor).

    Returns:
        A ``ServerReflectionResponse`` with ``file_descriptor_response`` set.
    """
    file_proto = descriptor_pb2.FileDescriptorProto()
    descriptor.CopyToProto(file_proto)
    payload = file_proto.SerializeToString()
    fd_response = _reflection_pb2.FileDescriptorResponse(
        file_descriptor_proto=(payload,))
    return _reflection_pb2.ServerReflectionResponse(
        file_descriptor_response=fd_response)
async def ServerReflectionInfo(
        self,
        request_iterator: AsyncIterable[
            _reflection_pb2.ServerReflectionRequest],
        unused_context,
) -> AsyncIterable[_reflection_pb2.ServerReflectionResponse]:
    """Answer a stream of reflection requests, one response per request.

    Requests are examined in arrival order. The first populated field among
    ``file_by_filename``, ``file_containing_symbol``,
    ``file_containing_extension``, ``all_extension_numbers_of_type`` and
    ``list_services`` selects the handler. A request with none of them set
    is answered with an INVALID_ARGUMENT error response. ``unused_context``
    is ignored.
    """
    async for request in request_iterator:
        if request.HasField('file_by_filename'):
            response = self._file_by_filename(request.file_by_filename)
        elif request.HasField('file_containing_symbol'):
            response = self._file_containing_symbol(
                request.file_containing_symbol)
        elif request.HasField('file_containing_extension'):
            extension_query = request.file_containing_extension
            response = self._file_containing_extension(
                extension_query.containing_type,
                extension_query.extension_number)
        elif request.HasField('all_extension_numbers_of_type'):
            response = self._all_extension_numbers_of_type(
                request.all_extension_numbers_of_type)
        elif request.HasField('list_services'):
            response = self._list_services()
        else:
            # No recognised request field was set.
            invalid = grpc.StatusCode.INVALID_ARGUMENT.value
            response = _reflection_pb2.ServerReflectionResponse(
                error_response=_reflection_pb2.ErrorResponse(
                    error_code=invalid[0],
                    error_message=invalid[1].encode(),
                ))
        yield response
def _list_services(self):
    """Return a reflection response listing every name in ``self._service_names``."""
    entries = [
        _reflection_pb2.ServiceResponse(name=name)
        for name in self._service_names
    ]
    listing = _reflection_pb2.ListServiceResponse(service=entries)
    return _reflection_pb2.ServerReflectionResponse(
        list_services_response=listing)
def _file_containing_extension(containing_type, extension_number):
    """Answer a file-containing-extension request with UNIMPLEMENTED.

    Args:
        containing_type: requested containing message type (unused).
        extension_number: requested extension number (unused).

    Returns:
        A ``ServerReflectionResponse`` whose ``error_response`` carries the
        UNIMPLEMENTED status code and its name, encoded as bytes.
    """
    # TODO(atash) Python protobuf currently doesn't support querying extensions.
    # https://github.com/google/protobuf/issues/2248
    # Fixed: the status was previously spelled ``UNIMPLMENTED``, which raised
    # AttributeError instead of building this error response.
    return reflection_pb2.ServerReflectionResponse(
        error_response=reflection_pb2.ErrorResponse(
            error_code=grpc.StatusCode.UNIMPLEMENTED.value[0],
            error_message=grpc.StatusCode.UNIMPLEMENTED.value[1].encode(),
        ))
def _extension_numbers_of_type(fully_qualified_name):
    """Answer an all-extension-numbers request with UNIMPLEMENTED.

    Args:
        fully_qualified_name: requested message type name (unused).

    Returns:
        A ``ServerReflectionResponse`` whose ``error_response`` carries the
        UNIMPLEMENTED status code and its name, encoded as bytes.
    """
    # TODO(atash) We're allowed to leave this unsupported according to the
    # protocol, but we should still eventually implement it. Hits the same issue
    # as `_file_containing_extension`, however.
    # https://github.com/google/protobuf/issues/2248
    # Fixed: the status was previously spelled ``UNIMPLMENTED``, which raised
    # AttributeError instead of building this error response.
    return reflection_pb2.ServerReflectionResponse(
        error_response=reflection_pb2.ErrorResponse(
            error_code=grpc.StatusCode.UNIMPLEMENTED.value[0],
            error_message=grpc.StatusCode.UNIMPLEMENTED.value[1].encode(),
        ))
def testListServices(self):
    """A list_services request must return every name in _SERVICE_NAMES."""
    only_request = reflection_pb2.ServerReflectionRequest(list_services='')
    responses = tuple(self._stub.ServerReflectionInfo(iter((only_request,))))

    services = tuple(
        reflection_pb2.ServiceResponse(name=service_name)
        for service_name in _SERVICE_NAMES)
    listing_reply = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        list_services_response=reflection_pb2.ListServiceResponse(
            service=services))
    self._assert_sequence_of_proto_equal((listing_reply,), responses)
async def test_list_services(self):
    """Over the aio stub, list_services must return every name in _SERVICE_NAMES."""
    requests = (reflection_pb2.ServerReflectionRequest(list_services=''),)
    responses = [
        reply
        async for reply in self._stub.ServerReflectionInfo(iter(requests))
    ]

    services = tuple(
        reflection_pb2.ServiceResponse(name=service_name)
        for service_name in _SERVICE_NAMES)
    listing_reply = reflection_pb2.ServerReflectionResponse(
        valid_host='',
        list_services_response=reflection_pb2.ListServiceResponse(
            service=services))
    self.assertSequenceEqual((listing_reply,), responses)
def _all_extension_numbers_of_type(self, containing_type):
    """Report the extension numbers registered for ``containing_type``.

    The message type and its extensions are looked up in ``self._pool``.

    Returns:
        A ``ServerReflectionResponse`` with ``all_extension_numbers_response``
        holding the type's full name and its extension numbers sorted
        ascending. If either pool lookup raises ``KeyError``, the result of
        ``_not_found_error()`` is returned instead.
    """
    try:
        message_descriptor = self._pool.FindMessageTypeByName(containing_type)
        extensions = self._pool.FindAllExtensions(message_descriptor)
        extension_numbers = tuple(sorted(ext.number for ext in extensions))
    except KeyError:
        return _not_found_error()

    numbers_response = reflection_pb2.ExtensionNumberResponse(
        base_type_name=message_descriptor.full_name,
        extension_number=extension_numbers)
    return reflection_pb2.ServerReflectionResponse(
        all_extension_numbers_response=numbers_response)
def _not_found_error():
    """Build a reflection response carrying the NOT_FOUND status.

    The error code is element 0 of the status value. The message is
    element 1, encoded to bytes.
    """
    status = grpc.StatusCode.NOT_FOUND.value
    error = _reflection_pb2.ErrorResponse(
        error_code=status[0],
        error_message=status[1].encode(),
    )
    return _reflection_pb2.ServerReflectionResponse(error_response=error)
def mock_error(self):
    """Return a ServerReflectionResponse holding a fake error (code 1)."""
    fake_error = reflection_pb2.ErrorResponse(
        error_code=1,
        error_message="fake error",
    )
    return reflection_pb2.ServerReflectionResponse(error_response=fake_error)