def generate_code(request, response):
    plugin_options = request.parameter.split(",") if request.parameter else []

    env = jinja2.Environment(
        trim_blocks=True,
        lstrip_blocks=True,
        loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)),
    )
    template = env.get_template("template.py.j2")

    output_map = {}
    for proto_file in request.proto_file:
        out = proto_file.package
        if out == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options:
            continue
        if not out:
            out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
        if out not in output_map:
            output_map[out] = {"package": proto_file.package, "files": []}
        output_map[out]["files"].append(proto_file)

    # TODO: Figure out how to handle gRPC request/response messages and add
    # processing below for Service.

    for filename, options in output_map.items():
        package = options["package"]
        # print(package, filename, file=sys.stderr)
        output = {
            "package": package,
            "files": [f.name for f in options["files"]],
            "imports": set(),
            "datetime_imports": set(),
            "typing_imports": set(),
            "messages": [],
            "enums": [],
            "services": [],
        }

        type_mapping = {}

        for proto_file in options["files"]:
            # print(proto_file.message_type, file=sys.stderr)
            # print(proto_file.service, file=sys.stderr)
            # print(proto_file.source_code_info, file=sys.stderr)

            for item, path in traverse(proto_file):
                # print(item, file=sys.stderr)
                # print(path, file=sys.stderr)
                data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}

                if isinstance(item, DescriptorProto):
                    # print(item, file=sys.stderr)
                    if item.options.map_entry:
                        # Skip generated map entry messages since we just use dicts
                        continue

                    data.update(
                        {
                            "type": "Message",
                            "comment": get_comment(proto_file, path),
                            "properties": [],
                        }
                    )

                    for i, f in enumerate(item.field):
                        t = py_type(package, output["imports"], item, f)
                        zero = get_py_zero(f.type)

                        repeated = False
                        packed = False

                        field_type = f.Type.Name(f.type).lower()[5:]

                        field_wraps = ""
                        match_wrapper = re.match(
                            r"\.google\.protobuf\.(.+)Value", f.type_name
                        )
                        if match_wrapper:
                            wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
                            if hasattr(betterproto, wrapped_type):
                                field_wraps = f"betterproto.{wrapped_type}"

                        map_types = None
                        if f.type == 11:
                            # This might be a map...
                            message_type = f.type_name.split(".").pop().lower()
                            # message_type = py_type(package)
                            map_entry = f"{f.name.replace('_', '').lower()}entry"

                            if message_type == map_entry:
                                for nested in item.nested_type:
                                    if nested.name.replace("_", "").lower() == map_entry:
                                        if nested.options.map_entry:
                                            # print("Found a map!", file=sys.stderr)
                                            k = py_type(
                                                package,
                                                output["imports"],
                                                item,
                                                nested.field[0],
                                            )
                                            v = py_type(
                                                package,
                                                output["imports"],
                                                item,
                                                nested.field[1],
                                            )
                                            t = f"Dict[{k}, {v}]"
                                            field_type = "map"
                                            map_types = (
                                                f.Type.Name(nested.field[0].type),
                                                f.Type.Name(nested.field[1].type),
                                            )
                                            output["typing_imports"].add("Dict")

                        if f.label == 3 and field_type != "map":
                            # Repeated field
                            repeated = True
                            t = f"List[{t}]"
                            zero = "[]"
                            output["typing_imports"].add("List")

                            if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
                                packed = True

                        one_of = ""
                        if f.HasField("oneof_index"):
                            one_of = item.oneof_decl[f.oneof_index].name

                        if "Optional[" in t:
                            output["typing_imports"].add("Optional")

                        if "timedelta" in t:
                            output["datetime_imports"].add("timedelta")
                        elif "datetime" in t:
                            output["datetime_imports"].add("datetime")

                        data["properties"].append(
                            {
                                "name": f.name,
                                "py_name": safe_snake_case(f.name),
                                "number": f.number,
                                "comment": get_comment(proto_file, path + [2, i]),
                                "proto_type": int(f.type),
                                "field_type": field_type,
                                "field_wraps": field_wraps,
                                "map_types": map_types,
                                "type": t,
                                "zero": zero,
                                "repeated": repeated,
                                "packed": packed,
                                "one_of": one_of,
                            }
                        )
                        # print(f, file=sys.stderr)

                    output["messages"].append(data)
                elif isinstance(item, EnumDescriptorProto):
                    # print(item.name, path, file=sys.stderr)
                    data.update(
                        {
                            "type": "Enum",
                            "comment": get_comment(proto_file, path),
                            "entries": [
                                {
                                    "name": v.name,
                                    "value": v.number,
                                    "comment": get_comment(proto_file, path + [2, i]),
                                }
                                for i, v in enumerate(item.value)
                            ],
                        }
                    )

                    output["enums"].append(data)

            for i, service in enumerate(proto_file.service):
                # print(service, file=sys.stderr)

                data = {
                    "name": service.name,
                    "py_name": stringcase.pascalcase(service.name),
                    "comment": get_comment(proto_file, [6, i]),
                    "methods": [],
                }

                for j, method in enumerate(service.method):
                    input_message = None
                    input_type = get_ref_type(
                        package, output["imports"], method.input_type
                    ).strip('"')
                    for msg in output["messages"]:
                        if msg["name"] == input_type:
                            input_message = msg
                            for field in msg["properties"]:
                                if field["zero"] == "None":
                                    output["typing_imports"].add("Optional")
                            break

                    data["methods"].append(
                        {
                            "name": method.name,
                            "py_name": stringcase.snakecase(method.name),
                            "comment": get_comment(proto_file, [6, i, 2, j], indent=8),
                            "route": f"/{package}.{service.name}/{method.name}",
                            "input": get_ref_type(
                                package, output["imports"], method.input_type
                            ).strip('"'),
                            "input_message": input_message,
                            "output": get_ref_type(
                                package,
                                output["imports"],
                                method.output_type,
                                unwrap=False,
                            ).strip('"'),
                            "client_streaming": method.client_streaming,
                            "server_streaming": method.server_streaming,
                        }
                    )

                    if method.client_streaming:
                        output["typing_imports"].add("AsyncIterable")
                        output["typing_imports"].add("Iterable")
                        output["typing_imports"].add("Union")
                    if method.server_streaming:
                        output["typing_imports"].add("AsyncIterator")

                output["services"].append(data)

        output["imports"] = sorted(output["imports"])
        output["datetime_imports"] = sorted(output["datetime_imports"])
        output["typing_imports"] = sorted(output["typing_imports"])

        # Fill response
        f = response.file.add()
        # print(filename, file=sys.stderr)
        f.name = filename.replace(".", os.path.sep) + ".py"

        # Render and then format the output file.
        f.content = black.format_str(
            template.render(description=output),
            mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
        )

    inits = set([""])
    for f in response.file:
        # Ensure output paths exist
        # print(f.name, file=sys.stderr)
        dirnames = os.path.dirname(f.name)
        if dirnames:
            os.makedirs(dirnames, exist_ok=True)
            base = ""
            for part in dirnames.split(os.path.sep):
                base = os.path.join(base, part)
                inits.add(base)

    for base in inits:
        name = os.path.join(base, "__init__.py")

        if os.path.exists(name):
            # Never overwrite inits as they may have custom stuff in them.
            continue

        init = response.file.add()
        init.name = name
        init.content = b""

    filenames = sorted([f.name for f in response.file])
    for fname in filenames:
        print(f"Writing {fname}", file=sys.stderr)
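# --- Illustrative wiring sketch (not part of the original module) -----------
# generate_code() follows protoc's plugin protocol: protoc pipes a serialized
# CodeGeneratorRequest to the plugin on stdin and expects a serialized
# CodeGeneratorResponse back on stdout. The helper below is only a minimal
# example of that handshake, assuming `sys` is imported at the top of this
# module; the project's real entry point may be wired differently.
def _example_main():
    from google.protobuf.compiler import plugin_pb2

    # Read the binary CodeGeneratorRequest that protoc pipes to us.
    request = plugin_pb2.CodeGeneratorRequest()
    request.ParseFromString(sys.stdin.buffer.read())

    # Let generate_code() populate the response with rendered modules.
    response = plugin_pb2.CodeGeneratorResponse()
    generate_code(request, response)

    # Hand the serialized response back to protoc on stdout.
    sys.stdout.buffer.write(response.SerializeToString())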
def pythonize_field_name(name: str) -> str:
    return casing.safe_snake_case(name)
def pythonize_method_name(name: str) -> str:
    return casing.safe_snake_case(name)
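# Rough usage note (illustrative only, assuming casing.safe_snake_case turns
# camelCase/PascalCase identifiers into snake_case and suffixes Python
# keywords with an underscore):
#
#     pythonize_field_name("myFieldName")   # -> "my_field_name"
#     pythonize_method_name("GetUserInfo")  # -> "get_user_info"
#     pythonize_field_name("in")            # -> "in_"  (avoids the keyword)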