Example #1
0
def _make_descriptor(descriptor_proto, package, full_name):
    """
    We basically need to re-implement the CPP API implementation of Protobuf's
    MakeDescriptor call here. The one provided by Google creates a file
    descriptor proto with a GUID-ish name, sticks the provided descriptor
    proto, adds that file into a default descriptor pool, then calls on the
    descriptor pool to return a descriptor with everything resolved.
    Unfortunately, if you have fields that are message types which require
    importing another file, there's no way to provide that import in the
    default MakeDescriptor() call.

    This call basically copies the default implementation, but instead of using
    the default pool, it uses a custom descriptor pool with Bonsai's Inkling
    Types already imported. It also adds the required import to the generated
    FileDescriptorProto for the schema represented by descriptor_proto.

    for reference, see:
    https://github.com/protocolbuffers/protobuf/blob/main/python/google/protobuf/descriptor.py

    :param descriptor_proto: The descriptor proto to turn into a descriptor.
    :param package: Dot-separated package name the message belongs to; used
        both as the generated file's package and as part of its path.
    :param full_name: Fully qualified message name used to look up an
        already-registered descriptor in the pool before building a new one.
    :return: A descriptor corresponding to descriptor_proto.
    """

    # The descriptor may already exist... look for it first.
    pool = _message_factory.pool
    try:
        return pool.FindMessageTypeByName(full_name)
    except KeyError:
        pass

    # Wrap the message in a synthetic single-message file. The file name is
    # a fresh UUID so repeated calls never collide inside the pool.
    proto_name = str(uuid.uuid4())
    proto_path = os.path.join(package, proto_name + '.proto')
    file_descriptor_proto = FileDescriptorProto()
    file_descriptor_proto.message_type.add().MergeFrom(descriptor_proto)
    file_descriptor_proto.name = proto_path
    file_descriptor_proto.package = package
    file_descriptor_proto.dependency.append('bonsai/proto/inkling_types.proto')

    # Not sure why this is needed; there's no documentation indicating how this
    # field is used. Some Google unit tests do this when adding a dependency,
    # so it's being done here too.
    file_descriptor_proto.public_dependency.append(0)

    # Register the synthetic file, then pull the fully-resolved descriptor
    # back out of the pool.
    pool.Add(file_descriptor_proto)
    result = pool.FindFileByName(proto_path)
    return result.message_types_by_name[descriptor_proto.name]
Example #2
0
    def _find_descriptor(self, desc_proto, package):
        """Resolve desc_proto to a message descriptor via this factory's pool.

        Returns None when desc_proto is None. Otherwise, returns the cached
        descriptor if the pool already knows the fully qualified name;
        failing that, wraps the proto in a synthetic one-message file (with
        the Inkling types import attached), registers it with the pool, and
        returns the freshly resolved descriptor.
        """
        if desc_proto is None:
            return None

        qualified_name = '{}.{}'.format(package, desc_proto.name)
        descriptor_pool = self._message_factory.pool

        # Fast path: the pool may already contain this message type.
        try:
            return descriptor_pool.FindMessageTypeByName(qualified_name)
        except KeyError:
            pass

        # Build a synthetic file holding just this message. A UUID-based
        # file name keeps repeated registrations from colliding in the pool.
        synthetic_file = FileDescriptorProto()
        synthetic_file.message_type.add().MergeFrom(desc_proto)
        synthetic_file.name = os.path.join(package,
                                           str(uuid.uuid4()) + '.proto')
        synthetic_file.package = package

        synthetic_file.dependency.append('bonsai/proto/inkling_types.proto')

        synthetic_file.public_dependency.append(0)

        descriptor_pool.Add(synthetic_file)
        registered = descriptor_pool.FindFileByName(synthetic_file.name)
        return registered.message_types_by_name[desc_proto.name]
Example #3
0
def nest_and_print_to_files(msg_path_to_obj, msg_to_referrers):
    """Nest singly-referenced messages/enums under their referrers, then
    emit one .proto file per remaining top-level definition.

    Generator yielding ``(filename, file_text)`` pairs, where ``file_text``
    starts with a block comment listing the messages defined in that file.

    :param msg_path_to_obj: maps a full message path (dot-separated, with
        '$'-suffixed nesting markers) to its DescriptorProto or enum
        descriptor proto.
    :param msg_to_referrers: maps a message path to a list of
        ``(field, referrer_path, is_group)`` tuples, one per reference.
    """
    # Bookkeeping shared with merge_and_rename / fix_naming (mutated in place):
    msg_to_topmost = {}   # msg path -> path of its topmost enclosing message
    msg_to_newloc = {}    # msg path -> its new (nested/renamed) location
    newloc_to_msg = {}    # reverse mapping of msg_to_newloc
    msg_to_imports = defaultdict(list)  # referrer path -> messages it refers to
    for msg, referrers in msg_to_referrers.items():
        for _, referrer, _ in referrers:
            msg_to_imports[referrer].append(msg)

    # Iterate over referred to messages/groups/enums.

    # Merge groups first:
    # (the key -x[1][0][2] sorts entries whose first referrer has
    # is_group == True to the front)
    msg_to_referrers = OrderedDict(
        sorted(msg_to_referrers.items(), key=lambda x: -x[1][0][2]))

    mergeable = {}
    # enum field name (qualified by package) -> enums declaring that field
    enumfield_to_enums = defaultdict(set)
    # enum path -> set of its field names duplicated by a sibling enum
    enum_to_dupfields = defaultdict(set)

    # Iterate over a copy: merge_and_rename may mutate msg_to_referrers.
    for msg, referrers in dict(msg_to_referrers).items():
        msg_pkg = get_pkg(msg)
        msg_obj = msg_path_to_obj[msg]

        # Check for duplicate enum fields in the same package:
        # (non-DescriptorProto objects are enum descriptors here, so
        # msg_obj.value is the list of enum fields)
        if not isinstance(msg_obj, DescriptorProto):
            for enum_field in msg_obj.value:
                name = msg_pkg + '.' + enum_field.name
                enumfield_to_enums[name].add(msg)

                if len(enumfield_to_enums[name]) > 1:
                    for other_enum in enumfield_to_enums[name]:
                        enum_to_dupfields[other_enum].add(name)

        first_field = referrers[0]
        field, referrer, is_group = first_field

        # Check whether message/enum has exactly one reference in this
        # package:
        if not is_group:
            # Candidate referrers: same package (or no package), not already
            # nested inside msg itself, not a synthesized map entry, and —
            # for '$'-nested names — sharing the same top-level base name.
            in_pkg = [(field, referrer) for field, referrer, _ in referrers \
                      if (get_pkg(referrer) == msg_pkg or not msg_pkg) \
                      and msg_to_topmost.get(referrer, referrer) != msg \
                      and not msg_path_to_obj[referrer].options.map_entry \
                      and ('$' not in msg or msg.split('.')[-1].split('$')[0] == \
                                        referrer.split('.')[-1].split('$')[0])]

            # Count distinct referrers, not distinct (field, referrer) pairs.
            if len({i for _, i in in_pkg}) != 1:
                # It doesn't. Keep for the next step
                if in_pkg:
                    mergeable[msg] = in_pkg
                continue
            else:
                field, referrer = in_pkg[0]

        else:
            # Groups are expected to have exactly one referrer.
            assert len(referrers) == 1

        # Nest msg under its single referrer, updating all shared maps.
        merge_and_rename(msg, referrer, msg_pkg, is_group, msg_to_referrers,
                         msg_to_topmost, msg_to_newloc, msg_to_imports,
                         msg_path_to_obj, newloc_to_msg)

    # Try to fix recursive (mutual) imports, and conflicting enum field names.
    for msg, in_pkg in mergeable.items():
        duplicate_enumfields = enum_to_dupfields.get(msg, set())

        # Prefer referrers at shallower (fewer-dots) locations first.
        for field, referrer in sorted(
                in_pkg,
                key=lambda x: msg_to_newloc.get(x[1], x[1]).count('.')):
            top_referrer = msg_to_topmost.get(referrer, referrer)

            # Merge when msg and its referrer's top-level file would import
            # each other (a cycle), or when msg has duplicated enum fields.
            if (msg in msg_to_imports[top_referrer] and \
                top_referrer in msg_to_imports[msg] and \
                msg_to_topmost.get(referrer, referrer) != msg) or \
                duplicate_enumfields:

                merge_and_rename(msg, referrer, get_pkg(msg), False,
                                 msg_to_referrers, msg_to_topmost,
                                 msg_to_newloc, msg_to_imports,
                                 msg_path_to_obj, newloc_to_msg)
                break

        # msg is no longer a top-level sibling; a field name duplicated by
        # only one remaining enum is no longer a conflict.
        for dupfield in duplicate_enumfields:
            siblings = enumfield_to_enums[dupfield]
            siblings.remove(msg)
            if len(siblings) == 1:
                enum_to_dupfields[siblings.pop()].remove(dupfield)

    for msg, msg_obj in msg_path_to_obj.items():
        # If we're a top-level message, enforce name transforms anyway
        if msg not in msg_to_topmost:
            # Strip '$'-nesting prefixes and capitalize the first letter.
            new_name = msg_obj.name.split('$')[-1]
            new_name = new_name[0].upper() + new_name[1:]

            msg_pkg = get_pkg(msg)
            if msg_pkg:
                msg_pkg += '.'

            if new_name != msg_obj.name:
                # Append underscores until the new name no longer collides
                # with another top-level message in the same package.
                while newloc_to_msg.get(msg_pkg + new_name, msg_pkg + new_name) in msg_path_to_obj and \
                      newloc_to_msg.get(msg_pkg + new_name, msg_pkg + new_name) not in msg_to_topmost:
                    new_name += '_'
                msg_obj.name = new_name

            # Propagate the rename to every reference to this message.
            fix_naming(msg_obj, msg_pkg + new_name, msg, msg, msg_to_referrers,
                       msg_to_topmost, msg_to_newloc, msg_to_imports,
                       msg_path_to_obj, newloc_to_msg)

    # Turn messages into individual files and stringify.

    path_to_file = OrderedDict()        # output path -> FileDescriptorProto
    path_to_defines = defaultdict(list) # output path -> messages it defines

    for msg, msg_obj in msg_path_to_obj.items():
        # Only messages that stayed top-level get their own file entry.
        if msg not in msg_to_topmost:
            path = msg.split('$')[0].replace('.', '/') + '.proto'

            if path not in path_to_file:
                path_to_file[path] = FileDescriptorProto()
                path_to_file[path].syntax = 'proto2'
                path_to_file[path].package = get_pkg(msg)
                path_to_file[path].name = path
            file_obj = path_to_file[path]

            # Add an import for every referenced message living in a
            # different output file (skipping nested ones).
            for imported in msg_to_imports[msg]:
                import_path = imported.split('$')[0].replace('.',
                                                             '/') + '.proto'
                if import_path != path and imported not in msg_to_topmost:
                    if import_path not in file_obj.dependency:
                        file_obj.dependency.append(import_path)

            if isinstance(msg_obj, DescriptorProto):
                nested = file_obj.message_type.add()
            else:
                nested = file_obj.enum_type.add()
            nested.MergeFrom(msg_obj)

            # Record this message plus everything nested under it (except
            # synthesized map entries) for the file's header comment.
            path_to_defines[path].append(msg)
            path_to_defines[path] += [
                k for k, v in msg_to_topmost.items()
                if v == msg and '$map' not in k
            ]

    # Stringify each file and prepend a header listing its definitions.
    for path, file_obj in path_to_file.items():
        name, proto = descpb_to_proto(file_obj)
        header_lines = ['/**', 'Messages defined in this file:\n']
        header_lines += path_to_defines[path]
        yield name, '\n * '.join(header_lines) + '\n */\n\n' + proto