def str_to_schema(raw_schema, message_names=(
        'tutorial.Person',
        'tutorial.AddressBook',
        'tutorial.Person.PhoneNumber')):
    """Load a serialized file descriptor and look up message descriptors.

    The serialized schema is parsed into a ``FileDescriptorProto`` and added
    to a fresh ``DescriptorPool``; each requested message type is then
    resolved through that pool.

    Args:
        raw_schema: Serialized ``FileDescriptorProto`` in the form accepted by
            ``descriptor_pb2.FileDescriptorProto.FromString``.
        message_names: Fully-qualified message type names to resolve. The
            default is the three ``tutorial.*`` types this function has
            always returned, so existing callers are unaffected.

    Returns:
        A tuple of message descriptors, one per entry of ``message_names``
        and in the same order.

    Raises:
        KeyError: Presumably raised by the pool when a name cannot be
            resolved (this is what ``FindMessageTypeByName`` documents).
    """
    file_descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
        raw_schema)

    # A private pool keeps the loaded schema from leaking into (or clashing
    # with) whatever is already registered elsewhere.
    descriptor_pool = DescriptorPool()
    descriptor_pool.Add(file_descriptor_proto)

    return tuple(
        descriptor_pool.FindMessageTypeByName(message_name)
        for message_name in message_names)
def main():
    """Run one plugin round trip: request in on stdin, response out on stdout.

    The serialized ``CodeGeneratorRequest`` is read from stdin. Every proto
    file it carries is added to a descriptor pool and passed to the analyzer,
    table references are linked, and the tables are rendered into a
    ``CodeGeneratorResponse`` that is serialized to stdout.
    """
    # Both streams are used through ``.buffer`` because the payloads are
    # serialized messages, not text.
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(sys.stdin.buffer.read())

    # TODO: clean this part.
    resolver = TableResolver()
    table_analyzer = Analyzer(resolver)
    descriptor_pool = DescriptorPool()

    for file_proto in request.proto_file:
        # The analyzer is handed the descriptor looked up from the pool,
        # not the raw proto that came in on the request.
        descriptor_pool.Add(file_proto)
        resolved_file = descriptor_pool.FindFileByName(file_proto.name)
        table_analyzer.generate_tables_for_file(file_descriptor=resolved_file)

    # Linking runs only after every file has been analyzed.
    table_analyzer.link_tables_references()

    response = plugin.CodeGeneratorResponse()
    ProtoPluginResponseWriter().write(
        generator=KotlinExposedGenerator(),
        tables=resolver.tables,
        plugin_response=response,
    )

    sys.stdout.buffer.write(response.SerializeToString())


# Example no. 3 (listing marker left over from the snippet source; "0" was its counter)
class ReflectionClientTest(unittest.TestCase):
    """Check DescriptorPool lookups answered through server reflection.

    ``setUp`` starts a test server with reflection enabled and builds a
    ``DescriptorPool`` on top of a ``ProtoReflectionDescriptorDatabase`` that
    talks to that server over an insecure channel. Each test then resolves
    files, messages, services or methods through the pool.
    """

    def setUp(self):
        self._server = test_common.test_server()
        test_service = test_pb2.DESCRIPTOR.services_by_name["TestService"]
        self._SERVICE_NAMES = (
            test_service.full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(self._SERVICE_NAMES, self._server)

        # Bind to port 0 and connect to whichever port the server reports.
        bound_port = self._server.add_insecure_port("[::]:0")
        self._server.start()

        target = "localhost:%d" % bound_port
        self._channel = grpc.insecure_channel(target)

        self._reflection_db = ProtoReflectionDescriptorDatabase(self._channel)
        self.desc_pool = DescriptorPool(self._reflection_db)

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    def _find_file_and_check_header(self, file_name):
        # Shared by testFindFile: look the file up through the pool and
        # verify its name, package and syntax before handing it back.
        found = self.desc_pool.FindFileByName(file_name)
        self.assertEqual(file_name, found.name)
        self.assertEqual(_PROTO_PACKAGE_NAME, found.package)
        self.assertEqual("proto3", found.syntax)
        return found

    def testListServices(self):
        listed = self._reflection_db.get_services()
        self.assertCountEqual(self._SERVICE_NAMES, listed)

    def testReflectionServiceName(self):
        expected = "grpc.reflection.v1alpha.ServerReflection"
        self.assertEqual(reflection.SERVICE_NAME, expected)

    def testFindFile(self):
        # The main test proto must expose the service ...
        service_file = self._find_file_and_check_header(_PROTO_FILE_NAME)
        self.assertIn("TestService", service_file.services_by_name)

        # ... and the empty proto must expose the Empty message.
        empty_file = self._find_file_and_check_header(_EMPTY_PROTO_FILE_NAME)
        self.assertIn("Empty", empty_file.message_types_by_name)

    def testFindFileError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindFileByName(_INVALID_FILE_NAME)

    def testFindMessage(self):
        wanted = _EMPTY_PROTO_SYMBOL_NAME
        found = self.desc_pool.FindMessageTypeByName(wanted)
        self.assertEqual(wanted, found.full_name)
        self.assertTrue(wanted.endswith(found.name))

    def testFindMessageError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindMessageTypeByName(_INVALID_SYMBOL_NAME)

    def testFindServiceFindMethod(self):
        wanted_service = self._SERVICE_NAMES[0]
        service = self.desc_pool.FindServiceByName(wanted_service)
        self.assertEqual(wanted_service, service.full_name)
        self.assertTrue(wanted_service.endswith(service.name))
        # The service must point at the very same file object the pool returns.
        self.assertIs(
            self.desc_pool.FindFileByName(_PROTO_FILE_NAME), service.file)

        rpc_name = "EmptyCall"
        self.assertIn(rpc_name, service.methods_by_name)

        rpc = service.FindMethodByName(rpc_name)
        self.assertIs(rpc, service.methods_by_name[rpc_name])
        self.assertIs(service, rpc.containing_service)
        self.assertEqual(rpc_name, rpc.name)
        self.assertTrue(rpc.full_name.endswith(rpc_name))

        # EmptyCall takes and returns the same Empty message type.
        empty_message = self.desc_pool.FindMessageTypeByName(
            _EMPTY_PROTO_SYMBOL_NAME)
        self.assertEqual(empty_message, rpc.input_type)
        self.assertEqual(empty_message, rpc.output_type)

    def testFindServiceError(self):
        with self.assertRaises(KeyError):
            self.desc_pool.FindServiceByName(_INVALID_SYMBOL_NAME)

    def testFindMethodError(self):
        service = self.desc_pool.FindServiceByName(self._SERVICE_NAMES[0])

        # FindMethodByName sometimes raises a KeyError, and sometimes returns
        # None (https://github.com/protocolbuffers/protobuf/issues/9592), so a
        # None result is converted into the KeyError the test expects.
        with self.assertRaises(KeyError):
            if service.FindMethodByName(_INVALID_SYMBOL_NAME) is None:
                raise KeyError()

    def testFindExtensionNotImplemented(self):
        """
        Extensions aren't implemented in Protobuf for Python.
        For now, simply assert that indeed they don't work.
        """
        wanted = _EMPTY_EXTENSIONS_SYMBOL_NAME
        found = self.desc_pool.FindMessageTypeByName(wanted)
        self.assertEqual(wanted, found.full_name)
        self.assertTrue(wanted.endswith(found.name))

        # The message itself resolves, but no extensions are reported for it
        # and looking it up as an extension fails.
        extensions = self.desc_pool.FindAllExtensions(found)
        self.assertEqual(0, len(extensions))
        with self.assertRaises(KeyError):
            self.desc_pool.FindExtensionByName(wanted)