def _make_descriptor(descriptor_proto, package, full_name):
    """
    Build a message descriptor for ``descriptor_proto`` using the factory's
    own descriptor pool instead of protobuf's default pool.

    Google's stock ``MakeDescriptor()`` wraps the provided descriptor proto
    in a synthetic file descriptor proto with a GUID-ish name, registers it
    in the *default* descriptor pool, and resolves it there.  That breaks
    when fields are message types that require importing another file,
    because the default call offers no way to declare that import.  This
    re-implementation performs the same steps against a custom pool that
    already has Bonsai's Inkling Types imported, and explicitly adds the
    required import to the generated FileDescriptorProto.

    for reference, see:
    https://github.com/google/protobuf/python/google/protobuf/descriptor.py

    :param descriptor_proto: The descriptor proto to turn into a descriptor.
    :param package: Proto package the synthetic file is placed under.
    :param full_name: Fully-qualified message name used for the initial
        pool lookup.
    :return: A descriptor corresponding to descriptor_proto.
    """
    pool = _message_factory.pool

    # The descriptor may already exist... look for it first.
    try:
        return pool.FindMessageTypeByName(full_name)
    except KeyError:
        pass

    # Wrap the message in a synthetic, uniquely-named .proto file.
    synthetic_path = os.path.join(package, str(uuid.uuid4()) + '.proto')
    fd_proto = FileDescriptorProto()
    fd_proto.message_type.add().MergeFrom(descriptor_proto)
    fd_proto.name = synthetic_path
    fd_proto.package = package
    fd_proto.dependency.append('bonsai/proto/inkling_types.proto')
    # Not sure why this is needed; there's no documentation indicating how
    # this field is used. Some Google unit tests do this when adding a
    # dependency, so it's being done here too.
    fd_proto.public_dependency.append(0)

    pool.Add(fd_proto)
    resolved = pool.FindFileByName(synthetic_path)
    return resolved.message_types_by_name[descriptor_proto.name]
def _find_descriptor(self, desc_proto, package):
    """
    Resolve ``desc_proto`` to a message descriptor via this instance's
    descriptor pool, registering a synthetic .proto file for it on demand.

    :param desc_proto: Descriptor proto to resolve; ``None`` is tolerated.
    :param package: Proto package the synthetic file is placed under.
    :return: The resolved descriptor, or ``None`` when ``desc_proto`` is
        ``None``.
    """
    if desc_proto is None:
        return None

    pool = self._message_factory.pool
    full_name = '{}.{}'.format(package, desc_proto.name)

    # Reuse an already-registered descriptor when possible.
    try:
        return pool.FindMessageTypeByName(full_name)
    except KeyError:
        pass

    # Otherwise wrap the proto in a uniquely-named synthetic file that
    # imports the Inkling types, add that file to the pool, and resolve
    # the message from there.
    synthetic_path = os.path.join(package, str(uuid.uuid4()) + '.proto')
    file_desc_proto = FileDescriptorProto()
    file_desc_proto.message_type.add().MergeFrom(desc_proto)
    file_desc_proto.name = synthetic_path
    file_desc_proto.package = package
    file_desc_proto.dependency.append('bonsai/proto/inkling_types.proto')
    file_desc_proto.public_dependency.append(0)
    pool.Add(file_desc_proto)

    resolved_file = pool.FindFileByName(synthetic_path)
    return resolved_file.message_types_by_name[desc_proto.name]
def nest_and_print_to_files(msg_path_to_obj, msg_to_referrers):
    """
    Regroup flat message/group/enum definitions into nested scopes and emit
    them as individual .proto files.

    :param msg_path_to_obj: mapping of full definition path (dot-separated;
        '$' appears to separate a parent path from a synthetic child name —
        see the split('$') calls below) to its DescriptorProto or enum
        descriptor proto object.
    :param msg_to_referrers: mapping of definition path to a list of
        (field, referrer_path, is_group) triples, one per field that
        references it.
    :return: generator of (file_name, file_text) pairs, where file_text is
        a '/** ... */' header listing the definitions in the file followed
        by the stringified .proto produced by descpb_to_proto().
    """
    # msg_to_topmost: definition path -> top-level definition it got merged
    # under (absence means it is still top-level itself).  These three maps
    # are filled in by merge_and_rename()/fix_naming().
    msg_to_topmost = {}
    msg_to_newloc = {}
    newloc_to_msg = {}

    # Each referrer implicitly imports every definition it references.
    msg_to_imports = defaultdict(list)
    for msg, referrers in msg_to_referrers.items():
        for _, referrer, _ in referrers:
            msg_to_imports[referrer].append(msg)

    # Iterate over referred to messages/groups/enums.
    # Merge groups first: sort by the is_group flag of the first referrer,
    # descending, so group entries come before plain messages/enums.
    msg_to_referrers = OrderedDict(
        sorted(msg_to_referrers.items(), key=lambda x: -x[1][0][2]))

    mergeable = {}
    enumfield_to_enums = defaultdict(set)
    enum_to_dupfields = defaultdict(set)

    # dict(...) takes a snapshot: merge_and_rename() mutates
    # msg_to_referrers while we iterate.
    for msg, referrers in dict(msg_to_referrers).items():
        msg_pkg = get_pkg(msg)
        msg_obj = msg_path_to_obj[msg]

        # Check for duplicate enum fields in the same package:
        if not isinstance(msg_obj, DescriptorProto):
            for enum_field in msg_obj.value:
                name = msg_pkg + '.' + enum_field.name
                enumfield_to_enums[name].add(msg)
                if len(enumfield_to_enums[name]) > 1:
                    # Record the clash on every enum sharing this field name.
                    for other_enum in enumfield_to_enums[name]:
                        enum_to_dupfields[other_enum].add(name)

        first_field = referrers[0]
        field, referrer, is_group = first_field

        # Check whether message/enum has exactly one reference in this
        # package:
        if not is_group:
            in_pkg = [(field, referrer) for field, referrer, _ in referrers \
                      if (get_pkg(referrer) == msg_pkg or not msg_pkg) \
                      and msg_to_topmost.get(referrer, referrer) != msg \
                      and not msg_path_to_obj[referrer].options.map_entry \
                      and ('$' not in msg or msg.split('.')[-1].split('$')[0] == \
                           referrer.split('.')[-1].split('$')[0])]
            if len({i for _, i in in_pkg}) != 1:
                # It doesn't. Keep for the next step
                if in_pkg:
                    mergeable[msg] = in_pkg
                continue
            else:
                field, referrer = in_pkg[0]
        else:
            # The code assumes a group has exactly one referrer.
            assert len(referrers) == 1

        merge_and_rename(msg, referrer, msg_pkg, is_group,
                         msg_to_referrers, msg_to_topmost, msg_to_newloc,
                         msg_to_imports, msg_path_to_obj, newloc_to_msg)

    # Try to fix recursive (mutual) imports, and conflicting enum field
    # names.
    for msg, in_pkg in mergeable.items():
        duplicate_enumfields = enum_to_dupfields.get(msg, set())

        # Try referrers closest to the package root first (fewest dots in
        # their current location).
        for field, referrer in sorted(
                in_pkg, key=lambda x: msg_to_newloc.get(x[1], x[1]).count('.')):
            top_referrer = msg_to_topmost.get(referrer, referrer)
            # Merge when the two top-level files import each other, or when
            # this enum carries field names that clash within the package.
            if (msg in msg_to_imports[top_referrer] and \
                top_referrer in msg_to_imports[msg] and \
                msg_to_topmost.get(referrer, referrer) != msg) or \
                duplicate_enumfields:
                merge_and_rename(msg, referrer, get_pkg(msg), False,
                                 msg_to_referrers, msg_to_topmost, msg_to_newloc,
                                 msg_to_imports, msg_path_to_obj, newloc_to_msg)
                break

        # Drop this enum from each clash set; a clash with only one enum
        # left is no longer a clash.
        for dupfield in duplicate_enumfields:
            siblings = enumfield_to_enums[dupfield]
            siblings.remove(msg)
            if len(siblings) == 1:
                enum_to_dupfields[siblings.pop()].remove(dupfield)

    for msg, msg_obj in msg_path_to_obj.items():
        # If we're a top-level message, enforce name transforms anyway
        if msg not in msg_to_topmost:
            # Strip any '$'-prefixed parent part and capitalize the first
            # letter.
            new_name = msg_obj.name.split('$')[-1]
            new_name = new_name[0].upper() + new_name[1:]
            msg_pkg = get_pkg(msg)
            if msg_pkg:
                msg_pkg += '.'
            if new_name != msg_obj.name:
                # Append underscores until the transformed name no longer
                # collides with another existing top-level definition.
                while newloc_to_msg.get(msg_pkg + new_name, msg_pkg + new_name) in msg_path_to_obj and \
                      newloc_to_msg.get(msg_pkg + new_name, msg_pkg + new_name) not in msg_to_topmost:
                    new_name += '_'
                msg_obj.name = new_name
                fix_naming(msg_obj, msg_pkg + new_name, msg, msg,
                           msg_to_referrers, msg_to_topmost, msg_to_newloc,
                           msg_to_imports, msg_path_to_obj, newloc_to_msg)

    # Turn messages into individual files and stringify.
    path_to_file = OrderedDict()
    path_to_defines = defaultdict(list)

    for msg, msg_obj in msg_path_to_obj.items():
        # Only still-top-level definitions get their own file; merged ones
        # ride along inside their new parent.
        if msg not in msg_to_topmost:
            path = msg.split('$')[0].replace('.', '/') + '.proto'

            if path not in path_to_file:
                path_to_file[path] = FileDescriptorProto()
                path_to_file[path].syntax = 'proto2'
                path_to_file[path].package = get_pkg(msg)
                path_to_file[path].name = path
            file_obj = path_to_file[path]

            # Declare file-level dependencies, skipping self-imports, merged
            # definitions, and duplicates.
            for imported in msg_to_imports[msg]:
                import_path = imported.split('$')[0].replace('.', '/') + '.proto'
                if import_path != path and imported not in msg_to_topmost:
                    if import_path not in file_obj.dependency:
                        file_obj.dependency.append(import_path)

            if isinstance(msg_obj, DescriptorProto):
                nested = file_obj.message_type.add()
            else:
                nested = file_obj.enum_type.add()
            nested.MergeFrom(msg_obj)

            path_to_defines[path].append(msg)
            # Also list what was merged under this definition, except
            # synthetic '$map' entries.
            path_to_defines[path] += [
                k for k, v in msg_to_topmost.items()
                if v == msg and '$map' not in k
            ]

    for path, file_obj in path_to_file.items():
        name, proto = descpb_to_proto(file_obj)
        header_lines = ['/**', 'Messages defined in this file:\n']
        header_lines += path_to_defines[path]
        yield name, '\n * '.join(header_lines) + '\n */\n\n' + proto