Пример #1
0
def get_type_reference(
    package: str,
    imports: set,
    source_type: str,
    unwrap: bool = True,
) -> str:
    """
    Translate a fully qualified proto type name into a Python type expression.

    With ``unwrap`` enabled, wrapper messages listed in ``WRAPPER_TYPES``
    become ``Optional[<type of their value>]``, and google.protobuf Duration
    and Timestamp become ``timedelta`` and ``datetime``. Every other type is
    resolved relative to ``package`` by one of the ``reference_*`` helpers.
    Those helpers are handed ``imports`` so they can add whatever import the
    reference requires.

    :param package: dotted proto package being generated ("" for none).
    :param imports: set of imports for the generated module.
    :param source_type: proto type name as found in a descriptor, e.g.
        ".google.protobuf.Timestamp".
    :param unwrap: whether to apply the well-known-type shortcuts above.
    :return: the Python type name to emit.
    """
    if unwrap:
        if source_type in WRAPPER_TYPES:
            # Instantiate the wrapper message to learn the Python type of its
            # ``value`` field.
            scalar = WRAPPER_TYPES[source_type]().value
            return f"Optional[{type(scalar).__name__}]"
        builtin_equivalent = {
            ".google.protobuf.Duration": "timedelta",
            ".google.protobuf.Timestamp": "datetime",
        }.get(source_type)
        if builtin_equivalent is not None:
            return builtin_equivalent

    source_package, type_name = parse_source_type_name(source_type)

    own_pkg = package.split(".") if package else []
    target_pkg = source_package.split(".") if source_package else []
    class_name = pythonize_class_name(type_name)

    # google.protobuf types are served from betterproto.lib, except while
    # generating the google.protobuf package itself.
    well_known = ["google", "protobuf"]
    if target_pkg == well_known and own_pkg != well_known:
        target_pkg = ["betterproto", "lib", *target_pkg]

    if target_pkg[:1] == ["betterproto"]:
        return reference_absolute(imports, target_pkg, class_name)
    if target_pkg == own_pkg:
        return reference_sibling(class_name)
    if target_pkg[:len(own_pkg)] == own_pkg:
        return reference_descendent(own_pkg, imports, target_pkg, class_name)
    if own_pkg[:len(target_pkg)] == target_pkg:
        return reference_ancestor(own_pkg, imports, target_pkg, class_name)
    return reference_cousin(own_pkg, imports, target_pkg, class_name)
 def py_name(self) -> str:
     """Return ``self.proto_name`` converted by ``pythonize_class_name``."""
     proto_name = self.proto_name
     return pythonize_class_name(proto_name)
Пример #3
0
def _copy_input_message(message):
    """
    Return a copy of ``message`` whose ``properties`` entries are fresh dicts.

    ``read_protobuf_service`` rewrites field defaults per method. Working on a
    copy keeps those edits out of the shared message data and out of other
    methods that take the same input message. Falsy values are returned
    unchanged.
    """
    if not message:
        return message
    copied = dict(message)
    copied["properties"] = [dict(field) for field in message["properties"]]
    return copied


def read_protobuf_service(service: ServiceDescriptorProto, index, proto_file,
                          content, output_types):
    """
    Collect template data for one service and append it to
    ``content["template_data"]["services"]``.

    Typing imports and type-reference imports needed by the generated stub
    are recorded in ``content["template_data"]`` as a side effect.

    :param service: the service descriptor to read.
    :param index: position of the service within ``proto_file``. It is used
        to build the source-location path for comments.
    :param proto_file: the file descriptor passed to ``get_comment``.
    :param content: dict holding ``"input_package"`` and ``"template_data"``.
    :param output_types: collection searched by ``lookup_method_input_type``.
    """
    input_package_name = content["input_package"]
    template_data = content["template_data"]
    data = {
        "name": service.name,
        "py_name": pythonize_class_name(service.name),
        # FileDescriptorProto.service has field number 6.
        "comment": get_comment(proto_file, [6, index]),
        "methods": [],
    }
    for j, method in enumerate(service.method):
        # The looked-up message dict may be shared with the collected message
        # data and with other methods using the same input type. The loop
        # below rewrites "zero" on its fields, so operate on a private copy.
        # Otherwise only the first such method would see the original
        # mutable defaults.
        method_input_message = _copy_input_message(
            lookup_method_input_type(method, output_types))

        # Method arguments whose default value is a mutable List/Dict get
        # None as their default instead. The original default is recorded in
        # mutable_default_args so initialisation can be deferred to the
        # beginning of the method definition, avoiding side-effects.
        # Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
        mutable_default_args = []

        if method_input_message:
            for field in method_input_message["properties"]:
                if (not method.client_streaming and field["zero"] != "None"
                        and is_mutable_field_type(field["type"])):
                    mutable_default_args.append(
                        (field["py_name"], field["zero"]))
                    field["zero"] = "None"

                if field["zero"] == "None":
                    template_data["typing_imports"].add("Optional")

        data["methods"].append({
            "name": method.name,
            "py_name": pythonize_method_name(method.name),
            # ServiceDescriptorProto.method has field number 2.
            "comment": get_comment(proto_file, [6, index, 2, j], indent=8),
            "route": f"/{input_package_name}.{service.name}/{method.name}",
            "input": get_type_reference(input_package_name,
                                        template_data["imports"],
                                        method.input_type).strip('"'),
            "input_message": method_input_message,
            "output": get_type_reference(
                input_package_name,
                template_data["imports"],
                method.output_type,
                unwrap=False,
            ),
            "client_streaming": method.client_streaming,
            "server_streaming": method.server_streaming,
            "mutable_default_args": mutable_default_args,
        })

        if method.client_streaming:
            template_data["typing_imports"].add("AsyncIterable")
            template_data["typing_imports"].add("Iterable")
            template_data["typing_imports"].add("Union")
        if method.server_streaming:
            template_data["typing_imports"].add("AsyncIterator")
    template_data["services"].append(data)
Пример #4
0
# FieldDescriptorProto.Type.TYPE_MESSAGE
_TYPE_MESSAGE = 11
# FieldDescriptorProto.Label.LABEL_REPEATED
_LABEL_REPEATED = 3
# Field types that use packed encoding when repeated: every scalar numeric
# type (varint, 32-bit and 64-bit wire types). This includes TYPE_ENUM (14),
# which is varint-encoded. It excludes STRING (9), GROUP (10), MESSAGE (11)
# and BYTES (12).
_PACKED_FIELD_TYPES = frozenset(
    {1, 2, 3, 4, 5, 6, 7, 8, 13, 14, 15, 16, 17, 18})


def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file,
                       content):
    """
    Collect template data for a message or enum descriptor.

    Messages are appended to ``content["template_data"]["messages"]`` and
    enums to ``content["template_data"]["enums"]``. The collected dict is
    also returned. Typing, datetime and type-reference imports needed by the
    field types are recorded in ``template_data`` as a side effect.

    :param item: a ``DescriptorProto`` or ``EnumDescriptorProto``.
    :param path: source-location path of ``item`` inside ``proto_file``,
        used to look up comments.
    :param proto_file: the file descriptor passed to ``get_comment``.
    :param content: dict holding ``"input_package"`` and ``"template_data"``.
    :return: the collected dict, or ``None`` for synthetic map-entry messages
        and for any other descriptor kind.
    """
    input_package_name = content["input_package"]
    template_data = content["template_data"]
    data = {
        "name": item.name,
        "py_name": pythonize_class_name(item.name),
        "descriptor": item,
        "package": input_package_name,
    }
    if isinstance(item, DescriptorProto):
        if item.options.map_entry:
            # Skip generated map entry messages since we just use dicts
            return

        data.update({
            "type": "Message",
            "comment": get_comment(proto_file, path),
            "properties": [],
        })

        for i, f in enumerate(item.field):
            t = py_type(input_package_name, template_data["imports"], f)
            zero = get_py_zero(f.type)

            repeated = False
            packed = False

            # e.g. "TYPE_INT32" -> "int32"
            field_type = f.Type.Name(f.type).lower()[5:]

            # google.protobuf.<X>Value wrapper fields are tagged with
            # betterproto.TYPE_<X> when betterproto defines that constant.
            field_wraps = ""
            match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value",
                                     f.type_name)
            if match_wrapper:
                wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
                if hasattr(betterproto, wrapped_type):
                    field_wraps = f"betterproto.{wrapped_type}"

            map_types = None
            if f.type == _TYPE_MESSAGE:
                # This might be a map: the field then refers to a nested
                # "<field name>Entry" message flagged with options.map_entry.
                # Names are compared lower-cased with underscores removed.
                message_type = f.type_name.split(".")[-1].lower()
                map_entry = f"{f.name.replace('_', '').lower()}entry"

                if message_type == map_entry:
                    for nested in item.nested_type:
                        if nested.name.replace("_", "").lower() == map_entry:
                            if nested.options.map_entry:
                                k = py_type(
                                    input_package_name,
                                    template_data["imports"],
                                    nested.field[0],
                                )
                                v = py_type(
                                    input_package_name,
                                    template_data["imports"],
                                    nested.field[1],
                                )
                                t = f"Dict[{k}, {v}]"
                                field_type = "map"
                                map_types = (
                                    f.Type.Name(nested.field[0].type),
                                    f.Type.Name(nested.field[1].type),
                                )
                                template_data["typing_imports"].add("Dict")

            if f.label == _LABEL_REPEATED and field_type != "map":
                # Repeated field
                repeated = True
                t = f"List[{t}]"
                zero = "[]"
                template_data["typing_imports"].add("List")

                if f.type in _PACKED_FIELD_TYPES:
                    packed = True

            one_of = ""
            if f.HasField("oneof_index"):
                one_of = item.oneof_decl[f.oneof_index].name

            if "Optional[" in t:
                template_data["typing_imports"].add("Optional")

            if "timedelta" in t:
                template_data["datetime_imports"].add("timedelta")
            elif "datetime" in t:
                template_data["datetime_imports"].add("datetime")

            data["properties"].append({
                "name": f.name,
                "py_name": pythonize_field_name(f.name),
                "number": f.number,
                # DescriptorProto.field has field number 2.
                "comment": get_comment(proto_file, path + [2, i]),
                "proto_type": int(f.type),
                "field_type": field_type,
                "field_wraps": field_wraps,
                "map_types": map_types,
                "type": t,
                "zero": zero,
                "repeated": repeated,
                "packed": packed,
                "one_of": one_of,
            })

        template_data["messages"].append(data)
        return data
    elif isinstance(item, EnumDescriptorProto):
        data.update({
            "type": "Enum",
            "comment": get_comment(proto_file, path),
            # EnumDescriptorProto.value has field number 2.
            "entries": [{
                "name": v.name,
                "value": v.number,
                "comment": get_comment(proto_file, path + [2, i]),
            } for i, v in enumerate(item.value)],
        })

        template_data["enums"].append(data)
        return data