예제 #1
0
def generate_code(request, response) -> None:
    """Fill ``response`` with one rendered Python module per output package.

    The proto files in ``request.proto_file`` are grouped by package name
    (or, when the package is empty, by the dotted form of the file name).
    For every group the messages, enums and services are collected into a
    description dict, rendered through ``templates/template.py.j2`` and
    formatted with ``black``.  Afterwards an empty ``__init__.py`` entry is
    added for every directory of the generated paths, unless such a file
    already exists on disk.

    Args:
        request: Object whose ``parameter`` (comma-separated plugin options)
            and ``proto_file`` attributes are read.  Presumably a protoc
            ``CodeGeneratorRequest`` -- confirm against the caller.
        response: Object receiving the generated files through
            ``response.file.add()``.  Presumably a protoc
            ``CodeGeneratorResponse`` -- confirm against the caller.

    Side effects:
        Creates the output directories with ``os.makedirs`` and prints the
        sorted names of all response files to stderr.
    """
    # Plugin options arrive as a single comma-separated string.
    plugin_options = request.parameter.split(",") if request.parameter else []

    # Templates are loaded from the "templates" directory next to this file.
    env = jinja2.Environment(
        trim_blocks=True,
        lstrip_blocks=True,
        loader=jinja2.FileSystemLoader("%s/templates/" %
                                       os.path.dirname(__file__)),
    )
    template = env.get_template("template.py.j2")

    # Maps output module name -> {"package": ..., "files": [proto files]}.
    output_map = {}
    for proto_file in request.proto_file:
        out = proto_file.package

        # google.protobuf is only generated when INCLUDE_GOOGLE is passed.
        if out == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options:
            continue

        if not out:
            # No package declared: derive the module name from the file name
            # (extension dropped, path separators turned into dots).
            out = os.path.splitext(proto_file.name)[0].replace(
                os.path.sep, ".")

        if out not in output_map:
            output_map[out] = {"package": proto_file.package, "files": []}
        output_map[out]["files"].append(proto_file)

    # TODO: Figure out how to handle gRPC request/response messages and add
    # processing below for Service.

    for filename, options in output_map.items():
        package = options["package"]
        # Description dict handed to the template.  The three *imports sets
        # are filled while walking the descriptors and sorted before
        # rendering.
        output = {
            "package": package,
            "files": [f.name for f in options["files"]],
            "imports": set(),
            "datetime_imports": set(),
            "typing_imports": set(),
            "messages": [],
            "enums": [],
            "services": [],
        }

        # NOTE(review): not referenced again in the visible code.
        type_mapping = {}

        for proto_file in options["files"]:
            # traverse() yields (descriptor, path) pairs; the path is passed
            # on to get_comment() to look up the item's comment.
            for item, path in traverse(proto_file):
                data = {
                    "name": item.name,
                    "py_name": stringcase.pascalcase(item.name)
                }

                if isinstance(item, DescriptorProto):
                    if item.options.map_entry:
                        # Skip generated map entry messages since we just use dicts
                        continue

                    data.update({
                        "type": "Message",
                        "comment": get_comment(proto_file, path),
                        "properties": [],
                    })

                    for i, f in enumerate(item.field):
                        # py_type() also receives output["imports"] so it can
                        # record imports -- presumably; confirm in py_type.
                        t = py_type(package, output["imports"], item, f)
                        zero = get_py_zero(f.type)

                        repeated = False
                        packed = False

                        # Lower-cased enum name of the field type with its
                        # first five characters (the "type_" prefix) dropped.
                        field_type = f.Type.Name(f.type).lower()[5:]

                        # Wrapper messages (".google.protobuf.<X>Value") are
                        # mapped to betterproto.TYPE_<X> when betterproto
                        # defines such an attribute.
                        field_wraps = ""
                        match_wrapper = re.match(
                            r"\.google\.protobuf\.(.+)Value", f.type_name)
                        if match_wrapper:
                            wrapped_type = "TYPE_" + match_wrapper.group(
                                1).upper()
                            if hasattr(betterproto, wrapped_type):
                                field_wraps = f"betterproto.{wrapped_type}"

                        map_types = None
                        # 11 is TYPE_MESSAGE in FieldDescriptorProto.Type.
                        if f.type == 11:
                            # This might be a map...
                            message_type = f.type_name.split(".").pop().lower()
                            # Name the nested entry message is expected to
                            # have for this field, compared case-insensitively
                            # and without underscores.
                            map_entry = f"{f.name.replace('_', '').lower()}entry"

                            if message_type == map_entry:
                                for nested in item.nested_type:
                                    if (nested.name.replace(
                                            "_", "").lower() == map_entry):
                                        if nested.options.map_entry:
                                            # Confirmed map: the entry's first
                                            # field supplies the key type and
                                            # its second field the value type.
                                            k = py_type(
                                                package,
                                                output["imports"],
                                                item,
                                                nested.field[0],
                                            )
                                            v = py_type(
                                                package,
                                                output["imports"],
                                                item,
                                                nested.field[1],
                                            )
                                            t = f"Dict[{k}, {v}]"
                                            field_type = "map"
                                            map_types = (
                                                f.Type.Name(
                                                    nested.field[0].type),
                                                f.Type.Name(
                                                    nested.field[1].type),
                                            )
                                            output["typing_imports"].add(
                                                "Dict")

                        # 3 is LABEL_REPEATED in FieldDescriptorProto.Label.
                        # Map fields were already turned into Dict above and
                        # must not be wrapped in List as well.
                        if f.label == 3 and field_type != "map":
                            # Repeated field
                            repeated = True
                            t = f"List[{t}]"
                            zero = "[]"
                            output["typing_imports"].add("List")

                            # Type numbers of the numeric and bool scalars
                            # (double, float, the integer variants, bool);
                            # repeated fields of these types are marked packed.
                            if f.type in [
                                    1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18
                            ]:
                                packed = True

                        # Name of the oneof group this field belongs to, or ""
                        # when the field has no oneof_index.
                        one_of = ""
                        if f.HasField("oneof_index"):
                            one_of = item.oneof_decl[f.oneof_index].name

                        # Register imports by inspecting the rendered type
                        # string.
                        if "Optional[" in t:
                            output["typing_imports"].add("Optional")

                        if "timedelta" in t:
                            output["datetime_imports"].add("timedelta")
                        elif "datetime" in t:
                            output["datetime_imports"].add("datetime")

                        # NOTE(review): the [2, i] suffix presumably selects
                        # the i-th entry of DescriptorProto.field (field
                        # number 2) in the source-info path -- confirm against
                        # get_comment.
                        data["properties"].append({
                            "name":
                            f.name,
                            "py_name":
                            safe_snake_case(f.name),
                            "number":
                            f.number,
                            "comment":
                            get_comment(proto_file, path + [2, i]),
                            "proto_type":
                            int(f.type),
                            "field_type":
                            field_type,
                            "field_wraps":
                            field_wraps,
                            "map_types":
                            map_types,
                            "type":
                            t,
                            "zero":
                            zero,
                            "repeated":
                            repeated,
                            "packed":
                            packed,
                            "one_of":
                            one_of,
                        })

                    output["messages"].append(data)
                elif isinstance(item, EnumDescriptorProto):
                    # NOTE(review): here [2, i] presumably selects the i-th
                    # entry of EnumDescriptorProto.value (field number 2) --
                    # confirm against get_comment.
                    data.update({
                        "type":
                        "Enum",
                        "comment":
                        get_comment(proto_file, path),
                        "entries": [{
                            "name":
                            v.name,
                            "value":
                            v.number,
                            "comment":
                            get_comment(proto_file, path + [2, i]),
                        } for i, v in enumerate(item.value)],
                    })

                    output["enums"].append(data)

            # NOTE(review): [6, i] presumably addresses the i-th entry of
            # FileDescriptorProto.service (field number 6), and the trailing
            # [2, j] further below the j-th method of that service -- confirm
            # against get_comment.
            for i, service in enumerate(proto_file.service):
                data = {
                    "name": service.name,
                    "py_name": stringcase.pascalcase(service.name),
                    "comment": get_comment(proto_file, [6, i]),
                    "methods": [],
                }

                for j, method in enumerate(service.method):
                    # Look up the already-collected message description whose
                    # name equals the method's input type.  Only messages
                    # gathered so far for this output module can match.
                    input_message = None
                    input_type = get_ref_type(package, output["imports"],
                                              method.input_type).strip('"')
                    for msg in output["messages"]:
                        if msg["name"] == input_type:
                            input_message = msg
                            # A field whose zero value is "None" makes the
                            # Optional typing import necessary.
                            for field in msg["properties"]:
                                if field["zero"] == "None":
                                    output["typing_imports"].add("Optional")
                            break

                    data["methods"].append({
                        "name":
                        method.name,
                        "py_name":
                        stringcase.snakecase(method.name),
                        "comment":
                        get_comment(proto_file, [6, i, 2, j], indent=8),
                        "route":
                        f"/{package}.{service.name}/{method.name}",
                        "input":
                        get_ref_type(package, output["imports"],
                                     method.input_type).strip('"'),
                        "input_message":
                        input_message,
                        "output":
                        get_ref_type(
                            package,
                            output["imports"],
                            method.output_type,
                            unwrap=False,
                        ).strip('"'),
                        "client_streaming":
                        method.client_streaming,
                        "server_streaming":
                        method.server_streaming,
                    })

                    # Streaming methods need extra typing imports.
                    if method.client_streaming:
                        output["typing_imports"].add("AsyncIterable")
                        output["typing_imports"].add("Iterable")
                        output["typing_imports"].add("Union")
                    if method.server_streaming:
                        output["typing_imports"].add("AsyncIterator")

                output["services"].append(data)

        # Replace the import sets with sorted lists before rendering.
        output["imports"] = sorted(output["imports"])
        output["datetime_imports"] = sorted(output["datetime_imports"])
        output["typing_imports"] = sorted(output["typing_imports"])

        # Fill response
        f = response.file.add()
        # Dotted module name -> relative path of the .py file.
        f.name = filename.replace(".", os.path.sep) + ".py"

        # Render and then format the output file.
        f.content = black.format_str(
            template.render(description=output),
            mode=black.FileMode(
                target_versions=set([black.TargetVersion.PY37])),
        )

    # Directories that need an __init__.py; "" stands for the output root.
    inits = set([""])
    for f in response.file:
        # Ensure output paths exist
        dirnames = os.path.dirname(f.name)
        if dirnames:
            os.makedirs(dirnames, exist_ok=True)
            # Register every ancestor directory of the file, not only the
            # innermost one.
            base = ""
            for part in dirnames.split(os.path.sep):
                base = os.path.join(base, part)
                inits.add(base)

    for base in inits:
        name = os.path.join(base, "__init__.py")

        if os.path.exists(name):
            # Never overwrite inits as they may have custom stuff in them.
            continue

        init = response.file.add()
        init.name = name
        init.content = b""

    # Report every file in the response on stderr, in sorted order.
    filenames = sorted([f.name for f in response.file])
    for fname in filenames:
        print(f"Writing {fname}", file=sys.stderr)
예제 #2
0
def pythonize_field_name(name: str) -> str:
    """Convert ``name`` with ``casing.safe_snake_case`` and return the result."""
    py_name = casing.safe_snake_case(name)
    return py_name
예제 #3
0
def pythonize_method_name(name: str) -> str:
    """Convert ``name`` with ``casing.safe_snake_case`` and return the result.

    The return annotation mirrors ``pythonize_field_name``, which returns the
    result of the same helper.
    """
    return casing.safe_snake_case(name)