def test_reference_camel_cased():
    """A snake_case message name in a child package is referenced in CamelCase."""
    found_imports = set()
    reference = get_type_reference(
        package="",
        imports=found_imports,
        source_type="child_package.example_message",
    )

    assert found_imports == {"from . import child_package"}
    assert reference == '"child_package.ExampleMessage"'
def test_reference_nested_child_from_root():
    """A two-level child of the root package is imported under a joined alias."""
    expected_import = "from .nested import child as nested_child"
    found_imports = set()
    reference = get_type_reference(
        package="",
        imports=found_imports,
        source_type="nested.child.Message",
    )

    assert found_imports == {expected_import}
    assert reference == '"nested_child.Message"'
def test_reference_child_package_from_root():
    """A direct child of the root package is imported by its own name."""
    found_imports = set()
    reference = get_type_reference(
        package="",
        imports=found_imports,
        source_type="child.Message",
    )

    assert found_imports == {"from . import child"}
    assert reference == '"child.Message"'
def test_reference_far_far_cousin_package():
    """A deep cousin (a.b.c.d seen from a.x.y.z) gets a relative import and alias."""
    found_imports = set()
    reference = get_type_reference(
        package="a.x.y.z",
        imports=found_imports,
        source_type="a.b.c.d.Message",
    )

    expected_alias = "___b_c_d__"
    assert found_imports == {f"from ....b.c import d as {expected_alias}"}
    assert reference == f'"{expected_alias}.Message"'
def test_reference_unrelated_package():
    """A top-level package unrelated to ``a`` is imported from the parent level."""
    found_imports = set()
    reference = get_type_reference(
        package="a",
        imports=found_imports,
        source_type="p.Message",
    )

    assert found_imports == {"from .. import p as _p__"}
    assert reference == '"_p__.Message"'
def test_reference_cousin_package_same_name():
    """A cousin whose last segment matches our own (``package``) gets a distinct alias."""
    found_imports = set()
    reference = get_type_reference(
        package="test.package",
        imports=found_imports,
        source_type="cousin.package.Message",
    )

    assert found_imports == {
        "from ...cousin import package as __cousin_package__"
    }
    assert reference == '"__cousin_package__.Message"'
def test_reference_cousin_package():
    """A sibling subpackage (a.y seen from a.x) is imported from the parent level."""
    found_imports = set()
    reference = get_type_reference(
        package="a.x",
        imports=found_imports,
        source_type="a.y.Message",
    )

    assert found_imports == {"from .. import y as _y__"}
    assert reference == '"_y__.Message"'
def test_reference_unrelated_nested_package():
    """An unrelated two-level package (p.q seen from a.b) gets a relative import."""
    found_imports = set()
    reference = get_type_reference(
        package="a.b",
        imports=found_imports,
        source_type="p.q.Message",
    )

    assert found_imports == {"from ...p import q as __p_q__"}
    assert reference == '"__p_q__.Message"'
def test_reference_unrelated_deeply_nested_package():
    """An unrelated four-level package is reached with a five-dot relative import."""
    found_imports = set()
    reference = get_type_reference(
        package="a.b.c.d",
        imports=found_imports,
        source_type="p.q.r.s.Message",
    )

    expected_alias = "____p_q_r_s__"
    assert found_imports == {f"from .....p.q.r import s as {expected_alias}"}
    assert reference == f'"{expected_alias}.Message"'
def test_reference_root_package_from_deeply_nested_child():
    """A root-level message seen from a deep package is imported directly by name."""
    found_imports = set()
    reference = get_type_reference(
        package="package.deeply.nested.child",
        imports=found_imports,
        source_type="Message",
    )

    assert found_imports == {"from ..... import Message as ____Message__"}
    assert reference == '"____Message__"'
def test_referenceing_google_wrappers_without_unwrapping(
        google_type: str, expected_name: str):
    """Resolve a google wrapper type with ``unwrap=False`` and check its name.

    ``google_type`` and ``expected_name`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not part of this excerpt.
    """
    resolved = get_type_reference(
        package="",
        imports=set(),
        source_type=google_type,
        unwrap=False,
    )
    assert resolved == expected_name
def test_reference_parent_package_from_child():
    """A message in the parent package is imported via its package under an alias."""
    found_imports = set()
    reference = get_type_reference(
        package="package.child",
        imports=found_imports,
        source_type="package.Message",
    )

    assert found_imports == {"from ... import package as __package__"}
    assert reference == '"__package__.Message"'
def test_reference_deeply_nested_siblings():
    """A message in the same nested package needs no import and keeps its bare name."""
    found_imports = set()
    reference = get_type_reference(
        package="foo.bar",
        imports=found_imports,
        source_type="foo.bar.Message",
    )

    assert not found_imports and found_imports == set()
    assert reference == '"Message"'
def test_reference_root_sibling():
    """A root-level message seen from the root package needs no import."""
    found_imports = set()
    reference = get_type_reference(
        package="",
        imports=found_imports,
        source_type="Message",
    )

    assert not found_imports and found_imports == set()
    assert reference == '"Message"'
def test_referenceing_google_wrappers_unwraps_them(google_type: str,
                                                   expected_name: str):
    """Resolving a google wrapper type with default unwrapping adds no imports.

    ``google_type`` and ``expected_name`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not part of this excerpt.
    """
    found_imports = set()
    resolved = get_type_reference(
        package="",
        imports=found_imports,
        source_type=google_type,
    )

    assert resolved == expected_name
    assert found_imports == set()
def test_reference_ancestor_package_from_nested_child():
    """A message two levels up is imported through its ancestor package."""
    found_imports = set()
    reference = get_type_reference(package="package.ancestor.nested.child",
                                   imports=found_imports,
                                   source_type="package.ancestor.Message")

    expected_alias = "___ancestor__"
    assert found_imports == {f"from .... import ancestor as {expected_alias}"}
    assert reference == f'"{expected_alias}.Message"'
def test_reference_google_wellknown_types_non_wrappers(google_type: str,
                                                       expected_name: str,
                                                       expected_import: str):
    """A non-wrapper well-known google type resolves to a name and an import.

    ``google_type``, ``expected_name`` and ``expected_import`` are presumably
    supplied by a ``pytest.mark.parametrize`` decorator that is not part of
    this excerpt.
    """
    imports = set()
    name = get_type_reference(package="",
                              imports=imports,
                              source_type=google_type)

    assert name == expected_name
    # Membership rather than equality: other entries in the set do not fail
    # the test, only a missing expected import does.
    assert expected_import in imports, (
        f"{expected_import} not found in {imports}")
def test_reference_deeply_nested_child_from_package():
    """A three-level descendant is imported from its parent under a joined alias."""
    expected_import = "from .deeply.nested import child as deeply_nested_child"
    found_imports = set()
    reference = get_type_reference(package="package",
                                   imports=found_imports,
                                   source_type="package.deeply.nested.child.Message")

    assert found_imports == {expected_import}
    assert reference == '"deeply_nested_child.Message"'
    def py_input_message_type(self) -> str:
        """Unquoted Python type name of this method's input message.

        The input type is resolved relative to the output file's package;
        ``get_type_reference`` may add the import it needs to the output
        file's import set.

        Returns
        -------
        str
            The Python type name for the input message, without the
            surrounding double quotes.
        """
        quoted_name = get_type_reference(
            package=self.output_file.package,
            imports=self.output_file.imports,
            source_type=self.proto_obj.input_type,
        )
        # get_type_reference yields the name wrapped in double quotes
        # (e.g. '"child.Message"'); callers here want the bare name.
        return quoted_name.strip('"')
# Example #20
# 0
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
    """Return the Python type name for a protobuf field.

    ``field.type`` holds a ``FieldDescriptorProto.Type`` code (numbering as in
    ``google/protobuf/descriptor.proto``). Message (11) and enum (14) codes
    are resolved through ``get_type_reference``, which may add an import to
    ``imports``; scalar codes map directly to a builtin type name.

    Raises
    ------
    NotImplementedError
        For any code not listed below (for example 10, ``TYPE_GROUP``).
    """
    type_code = field.type
    if type_code in (11, 14):
        # Type referencing another defined Message or a named enum
        return get_type_reference(package, imports, field.type_name)

    scalar_names = {
        1: "float", 2: "float",        # double, float
        3: "int", 4: "int",            # int64, uint64
        5: "int", 13: "int",           # int32, uint32
        6: "int", 7: "int",            # fixed64, fixed32
        15: "int", 16: "int",          # sfixed32, sfixed64
        17: "int", 18: "int",          # sint32, sint64
        8: "bool",
        9: "str",
        12: "bytes",
    }
    python_name = scalar_names.get(type_code)
    if python_name is None:
        raise NotImplementedError(f"Unknown type {field.type}")
    return python_name
 def py_type(self) -> str:
     """String representation of the Python type of this field.

     Scalar protobuf types map to builtin type names. Types in
     ``PROTO_MESSAGE_TYPES`` (another defined Message or a named enum) are
     resolved with ``get_type_reference`` relative to the output file's
     package, which may add an import to the output file's import set.

     Returns
     -------
     str
         Python type name for ``self.proto_obj.type``.

     Raises
     ------
     NotImplementedError
         If ``self.proto_obj.type`` is in none of the known type groups.
     """
     proto_type = self.proto_obj.type
     if proto_type in PROTO_FLOAT_TYPES:
         return "float"
     elif proto_type in PROTO_INT_TYPES:
         return "int"
     elif proto_type in PROTO_BOOL_TYPES:
         return "bool"
     elif proto_type in PROTO_STR_TYPES:
         return "str"
     elif proto_type in PROTO_BYTES_TYPES:
         return "bytes"
     elif proto_type in PROTO_MESSAGE_TYPES:
         # Type referencing another defined Message or a named enum
         return get_type_reference(
             package=self.output_file.package,
             imports=self.output_file.imports,
             source_type=self.proto_obj.type_name,
         )
     else:
         # Report this field's own type code; there is no ``field`` name in
         # scope in this method.
         raise NotImplementedError(f"Unknown type {proto_type}")
# Example #22
# 0
def read_protobuf_service(service: ServiceDescriptorProto, index, proto_file,
                          content, output_types):
    """Collect template data for one protobuf service.

    Builds a dict describing ``service`` (its name, Python class name,
    leading comment and one entry per RPC method). The dict is appended to
    ``content["template_data"]["services"]``.

    Side effects on ``content["template_data"]``:

    * ``"typing_imports"`` gains ``Optional`` when an input field's default
      is ``"None"``, ``AsyncIterable``/``Iterable``/``Union`` for
      client-streaming methods and ``AsyncIterator`` for server-streaming
      methods;
    * ``"imports"`` is passed to ``get_type_reference``, which may add the
      imports needed for each method's input/output types.

    Parameters
    ----------
    service : ServiceDescriptorProto
        The service to read.
    index
        Presumably the service's position within the file; it forms the
        source-location path handed to ``get_comment``.
    proto_file
        File descriptor passed to ``get_comment``.
    content : dict
        Must contain ``"input_package"`` (package name used for routes and
        type references) and ``"template_data"`` (mutable template context).
    output_types
        Searched by ``lookup_method_input_type`` for each method's input
        message.
    """
    input_package_name = content["input_package"]
    template_data = content["template_data"]
    data = {
        "name": service.name,
        "py_name": pythonize_class_name(service.name),
        # Source-location path: 6 = FileDescriptorProto.service.
        "comment": get_comment(proto_file, [6, index]),
        "methods": [],
    }
    for j, method in enumerate(service.method):
        method_input_message, mutable_default_args = _defer_mutable_defaults(
            method,
            lookup_method_input_type(method, output_types),
            template_data,
        )

        data["methods"].append({
            "name": method.name,
            "py_name": pythonize_method_name(method.name),
            # Source-location path: 2 = ServiceDescriptorProto.method.
            "comment": get_comment(proto_file, [6, index, 2, j], indent=8),
            "route": f"/{input_package_name}.{service.name}/{method.name}",
            "input": get_type_reference(
                input_package_name,
                template_data["imports"],
                method.input_type,
            ).strip('"'),
            "input_message": method_input_message,
            "output": get_type_reference(
                input_package_name,
                template_data["imports"],
                method.output_type,
                unwrap=False,
            ),
            "client_streaming": method.client_streaming,
            "server_streaming": method.server_streaming,
            "mutable_default_args": mutable_default_args,
        })

        if method.client_streaming:
            template_data["typing_imports"].add("AsyncIterable")
            template_data["typing_imports"].add("Iterable")
            template_data["typing_imports"].add("Union")
        if method.server_streaming:
            template_data["typing_imports"].add("AsyncIterator")
    template_data["services"].append(data)


def _defer_mutable_defaults(method, input_message, template_data):
    """Replace mutable argument defaults of a method's input message with None.

    Method arguments having a default value that is initialised as a
    List/Dict (mutable, per ``is_mutable_field_type``) get ``"None"`` as
    their default. They are recorded so initialisation can be deferred to
    the beginning of the generated method, avoiding shared mutable default
    arguments.
    Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments

    Adds ``Optional`` to ``template_data["typing_imports"]`` whenever a field
    ends up with a ``"None"`` default.

    Returns
    -------
    tuple
        ``(message, mutable_default_args)``. ``message`` is a per-method copy
        of ``input_message``, or ``input_message`` itself when it is falsy.
        ``mutable_default_args`` is a list of ``(py_name, zero)`` pairs.
    """
    mutable_default_args = []
    if not input_message:
        return input_message, mutable_default_args

    # NOTE(review): lookup_method_input_type presumably returns the message
    # dict stored in output_types, shared by every method with the same
    # input type. Rewriting field["zero"] on that shared dict would make a
    # later method see "None", skip the field, and lose its deferred
    # initialisation, so rewrite a per-method copy instead.
    input_message = {
        **input_message,
        "properties": [dict(field) for field in input_message["properties"]],
    }
    for field in input_message["properties"]:
        if (not method.client_streaming and field["zero"] != "None"
                and is_mutable_field_type(field["type"])):
            mutable_default_args.append((field["py_name"], field["zero"]))
            field["zero"] = "None"

        if field["zero"] == "None":
            template_data["typing_imports"].add("Optional")
    return input_message, mutable_default_args