def get_type_reference(
    package: str,
    imports: set,
    source_type: str,
    unwrap: bool = True,
) -> str:
    """
    Resolve a proto type reference to the Python type name used in generated code.

    When ``unwrap`` is true, google.protobuf wrapper types are mapped to
    ``Optional[<builtin>]`` and Duration / Timestamp are mapped to
    ``timedelta`` / ``datetime``. Otherwise (or for any other type) the
    name is resolved relative to ``package`` by one of the ``reference_*``
    helpers, which adds the import to ``imports`` if necessary.
    """
    if unwrap:
        if source_type in WRAPPER_TYPES:
            # Instantiate the wrapper to discover the builtin type of its
            # ``value`` field (e.g. str, int, bool).
            inner_value = WRAPPER_TYPES[source_type]().value
            return f"Optional[{type(inner_value).__name__}]"
        time_types = {
            ".google.protobuf.Duration": "timedelta",
            ".google.protobuf.Timestamp": "datetime",
        }
        if source_type in time_types:
            return time_types[source_type]

    source_package, type_name = parse_source_type_name(source_type)

    here: List[str] = package.split(".") if package else []
    target: List[str] = source_package.split(".") if source_package else []
    class_name: str = pythonize_class_name(type_name)

    google_protobuf = ["google", "protobuf"]
    if target == google_protobuf and here != google_protobuf:
        # Outside google.protobuf itself, well-known types are referenced
        # through betterproto's bundled library package.
        target = ["betterproto", "lib", *target]

    if target[:1] == ["betterproto"]:
        return reference_absolute(imports, target, class_name)
    if target == here:
        return reference_sibling(class_name)
    if target[: len(here)] == here:
        return reference_descendent(here, imports, target, class_name)
    if here[: len(target)] == target:
        return reference_ancestor(here, imports, target, class_name)
    return reference_cousin(here, imports, target, class_name)
def py_name(self) -> str:
    """Return the Python class name derived from ``self.proto_name``."""
    proto_name = self.proto_name
    return pythonize_class_name(proto_name)
def _defer_mutable_defaults(input_message, client_streaming, typing_imports):
    """
    Build a per-method view of ``input_message`` with mutable defaults deferred.

    Fields whose zero value is a mutable literal (List/Dict) get ``"None"`` as
    their zero in the returned copy. Their original ``(py_name, zero)`` pairs
    are returned so that the generated method can initialise them at the top
    of its body. This avoids Python's shared-mutable-default pitfall:
    https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments

    The shared message record (which other methods may also reference) is
    never modified. Mutating it in place would make every later method that
    reuses the same request message see ``"None"`` already, and silently lose
    its deferred initialisation.

    :param input_message: the method's input-message record, or a falsy value
        when none was found (returned unchanged in that case).
    :param client_streaming: whether the method is client-streaming; such
        methods do not take per-field arguments, so nothing is deferred.
    :param typing_imports: set receiving ``"Optional"`` when any field's zero is ``"None"``.
    :return: ``(message_for_this_method, mutable_default_args)``.
    """
    mutable_default_args = []
    if not input_message:
        return input_message, mutable_default_args

    # Shallow-copy each property dict so that rewriting "zero" stays local to this method.
    properties = [dict(field) for field in input_message["properties"]]
    for field in properties:
        if (
            not client_streaming
            and field["zero"] != "None"
            and is_mutable_field_type(field["type"])
        ):
            mutable_default_args.append((field["py_name"], field["zero"]))
            field["zero"] = "None"

        if field["zero"] == "None":
            typing_imports.add("Optional")

    return {**input_message, "properties": properties}, mutable_default_args


def read_protobuf_service(service: ServiceDescriptorProto, index, proto_file, content, output_types):
    """
    Collect template data for one protobuf service.

    Appends a dict describing the service and its methods to
    ``content["template_data"]["services"]``. It also records the typing
    imports that the generated stubs need in
    ``content["template_data"]["typing_imports"]``.

    :param service: the service descriptor being compiled.
    :param index: position of ``service`` in its file's service list; used to
        build the source-location path ``[6, index]`` passed to ``get_comment``
        (6 is the ``service`` field number of FileDescriptorProto).
    :param proto_file: the file descriptor passed through to ``get_comment``.
    :param content: dict providing ``"input_package"`` and ``"template_data"``.
    :param output_types: previously collected type records, searched by
        ``lookup_method_input_type`` for each method's input message.
    """
    input_package_name = content["input_package"]
    template_data = content["template_data"]

    data = {
        "name": service.name,
        "py_name": pythonize_class_name(service.name),
        "comment": get_comment(proto_file, [6, index]),
        "methods": [],
    }

    for j, method in enumerate(service.method):
        method_input_message, mutable_default_args = _defer_mutable_defaults(
            lookup_method_input_type(method, output_types),
            method.client_streaming,
            template_data["typing_imports"],
        )

        data["methods"].append({
            "name": method.name,
            "py_name": pythonize_method_name(method.name),
            # 2 is the ``method`` field number of ServiceDescriptorProto.
            "comment": get_comment(proto_file, [6, index, 2, j], indent=8),
            "route": f"/{input_package_name}.{service.name}/{method.name}",
            "input": get_type_reference(
                input_package_name, template_data["imports"], method.input_type
            ).strip('"'),
            "input_message": method_input_message,
            "output": get_type_reference(
                input_package_name,
                template_data["imports"],
                method.output_type,
                unwrap=False,
            ),
            "client_streaming": method.client_streaming,
            "server_streaming": method.server_streaming,
            "mutable_default_args": mutable_default_args,
        })

        if method.client_streaming:
            template_data["typing_imports"].add("AsyncIterable")
            template_data["typing_imports"].add("Iterable")
            template_data["typing_imports"].add("Union")
        if method.server_streaming:
            template_data["typing_imports"].add("AsyncIterator")

    template_data["services"].append(data)
# FieldDescriptorProto.Type numbers whose repeated form may use packed
# encoding: every varint / 32-bit / 64-bit scalar, *including* TYPE_ENUM (14),
# which is varint-encoded. Length-delimited types (string 9, group 10,
# message 11, bytes 12) are never packed.
_PACKABLE_FIELD_TYPES = frozenset({1, 2, 3, 4, 5, 6, 7, 8, 13, 14, 15, 16, 17, 18})


def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content):
    """
    Collect template data for a message or enum descriptor.

    Messages are appended to ``content["template_data"]["messages"]`` and
    enums to ``content["template_data"]["enums"]``. The typing and datetime
    imports that their field types require are recorded along the way.

    :param item: a DescriptorProto (message) or EnumDescriptorProto.
    :param path: source-location path of ``item``; extended with
        ``[2, i]`` to reach field / enum-value comments (2 is the field
        number of both ``DescriptorProto.field`` and ``EnumDescriptorProto.value``).
    :param proto_file: the file descriptor passed through to ``get_comment``.
    :param content: dict providing ``"input_package"`` and ``"template_data"``.
    :return: the collected data dict. Returns ``None`` for synthesized
        map-entry messages and for any other descriptor kind.
    """
    input_package_name = content["input_package"]
    template_data = content["template_data"]
    data = {
        "name": item.name,
        "py_name": pythonize_class_name(item.name),
        "descriptor": item,
        "package": input_package_name,
    }

    if isinstance(item, DescriptorProto):
        if item.options.map_entry:
            # Skip generated map entry messages since we just use dicts
            return

        data.update({
            "type": "Message",
            "comment": get_comment(proto_file, path),
            "properties": [],
        })

        for i, f in enumerate(item.field):
            t = py_type(input_package_name, template_data["imports"], f)
            zero = get_py_zero(f.type)
            repeated = False
            packed = False
            # Enum name without its "TYPE_" prefix, e.g. "TYPE_STRING" -> "string".
            field_type = f.Type.Name(f.type).lower()[5:]

            # google.protobuf.<X>Value wrapper messages are tagged with the
            # betterproto TYPE_<X> constant they wrap, when one exists.
            field_wraps = ""
            match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name)
            if match_wrapper:
                wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
                if hasattr(betterproto, wrapped_type):
                    field_wraps = f"betterproto.{wrapped_type}"

            map_types = None
            if f.type == 11:  # TYPE_MESSAGE: this might be a map...
                # A map field references a nested "<FieldName>Entry" message
                # flagged with options.map_entry; compare names case- and
                # underscore-insensitively.
                message_type = f.type_name.split(".").pop().lower()
                map_entry = f"{f.name.replace('_', '').lower()}entry"
                if message_type == map_entry:
                    for nested in item.nested_type:
                        if (
                            nested.name.replace("_", "").lower() == map_entry
                            and nested.options.map_entry
                        ):
                            # field[0] is the key, field[1] the value.
                            k = py_type(
                                input_package_name,
                                template_data["imports"],
                                nested.field[0],
                            )
                            v = py_type(
                                input_package_name,
                                template_data["imports"],
                                nested.field[1],
                            )
                            t = f"Dict[{k}, {v}]"
                            field_type = "map"
                            map_types = (
                                f.Type.Name(nested.field[0].type),
                                f.Type.Name(nested.field[1].type),
                            )
                            template_data["typing_imports"].add("Dict")

            if f.label == 3 and field_type != "map":  # LABEL_REPEATED (maps excluded)
                repeated = True
                t = f"List[{t}]"
                zero = "[]"
                template_data["typing_imports"].add("List")
                if f.type in _PACKABLE_FIELD_TYPES:
                    packed = True

            one_of = ""
            if f.HasField("oneof_index"):
                one_of = item.oneof_decl[f.oneof_index].name

            if "Optional[" in t:
                template_data["typing_imports"].add("Optional")

            if "timedelta" in t:
                template_data["datetime_imports"].add("timedelta")
            elif "datetime" in t:
                template_data["datetime_imports"].add("datetime")

            data["properties"].append({
                "name": f.name,
                "py_name": pythonize_field_name(f.name),
                "number": f.number,
                "comment": get_comment(proto_file, path + [2, i]),
                "proto_type": int(f.type),
                "field_type": field_type,
                "field_wraps": field_wraps,
                "map_types": map_types,
                "type": t,
                "zero": zero,
                "repeated": repeated,
                "packed": packed,
                "one_of": one_of,
            })

        template_data["messages"].append(data)
        return data

    elif isinstance(item, EnumDescriptorProto):
        data.update({
            "type": "Enum",
            "comment": get_comment(proto_file, path),
            "entries": [
                {
                    "name": v.name,
                    "value": v.number,
                    "comment": get_comment(proto_file, path + [2, i]),
                }
                for i, v in enumerate(item.value)
            ],
        })

        template_data["enums"].append(data)
        return data