def test_reference_camel_cased():
    """A snake_cased message name is rendered CamelCased in the reference."""
    collected_imports = set()
    reference = get_type_reference(
        package="",
        imports=collected_imports,
        source_type="child_package.example_message",
    )
    assert collected_imports == {"from . import child_package"}
    assert reference == '"child_package.ExampleMessage"'
def test_reference_nested_child_from_root():
    """A message two levels below the root is imported under an aliased module name."""
    expected_imports = {"from .nested import child as nested_child"}
    expected_reference = '"nested_child.Message"'

    collected_imports = set()
    reference = get_type_reference(
        package="",
        imports=collected_imports,
        source_type="nested.child.Message",
    )
    assert collected_imports == expected_imports
    assert reference == expected_reference
def test_reference_child_package_from_root():
    """A message in a direct child of the root package is imported via `from . import`."""
    collected_imports = set()
    reference = get_type_reference(
        package="",
        imports=collected_imports,
        source_type="child.Message",
    )
    assert collected_imports == {"from . import child"}
    assert reference == '"child.Message"'
def test_reference_far_far_cousin_package():
    """A distant cousin is reached by climbing to the shared ancestor `a`."""
    expected_imports = {"from ....b.c import d as ___b_c_d__"}
    expected_reference = '"___b_c_d__.Message"'

    collected_imports = set()
    reference = get_type_reference(
        package="a.x.y.z",
        imports=collected_imports,
        source_type="a.b.c.d.Message",
    )
    assert collected_imports == expected_imports
    assert reference == expected_reference
def test_reference_unrelated_package():
    """An unrelated top-level package is imported relative to the root, aliased."""
    collected_imports = set()
    reference = get_type_reference(
        package="a",
        imports=collected_imports,
        source_type="p.Message",
    )
    assert collected_imports == {"from .. import p as _p__"}
    assert reference == '"_p__.Message"'
def test_reference_cousin_package_same_name():
    """Cousins sharing a leaf name (`package`) are disambiguated by the alias."""
    expected_imports = {"from ...cousin import package as __cousin_package__"}
    expected_reference = '"__cousin_package__.Message"'

    collected_imports = set()
    reference = get_type_reference(
        package="test.package",
        imports=collected_imports,
        source_type="cousin.package.Message",
    )
    assert collected_imports == expected_imports
    assert reference == expected_reference
def test_reference_cousin_package():
    """A sibling subpackage under the same parent is imported one level up."""
    collected_imports = set()
    reference = get_type_reference(
        package="a.x",
        imports=collected_imports,
        source_type="a.y.Message",
    )
    assert collected_imports == {"from .. import y as _y__"}
    assert reference == '"_y__.Message"'
def test_reference_unrelated_nested_package():
    """A nested package in an unrelated tree is imported from the root, aliased."""
    expected_imports = {"from ...p import q as __p_q__"}
    expected_reference = '"__p_q__.Message"'

    collected_imports = set()
    reference = get_type_reference(
        package="a.b",
        imports=collected_imports,
        source_type="p.q.Message",
    )
    assert collected_imports == expected_imports
    assert reference == expected_reference
def test_reference_unrelated_deeply_nested_package():
    """Deep unrelated trees: climb to the root, then descend `p.q.r` to import `s`."""
    collected_imports = set()
    reference = get_type_reference(
        package="a.b.c.d",
        imports=collected_imports,
        source_type="p.q.r.s.Message",
    )
    assert collected_imports == {"from .....p.q.r import s as ____p_q_r_s__"}
    assert reference == '"____p_q_r_s__.Message"'
def test_reference_root_package_from_deeply_nested_child():
    """A root-level message referenced from deep inside is imported by name, aliased."""
    expected_imports = {"from ..... import Message as ____Message__"}
    expected_reference = '"____Message__"'

    collected_imports = set()
    reference = get_type_reference(
        package="package.deeply.nested.child",
        imports=collected_imports,
        source_type="Message",
    )
    assert collected_imports == expected_imports
    assert reference == expected_reference
def test_referenceing_google_wrappers_without_unwrapping(
    google_type: str, expected_name: str
):
    """With ``unwrap=False`` a Google wrapper type resolves to ``expected_name``."""
    reference = get_type_reference(
        package="",
        imports=set(),
        source_type=google_type,
        unwrap=False,
    )
    assert reference == expected_name
def test_reference_parent_package_from_child():
    """A parent-package message is reached by importing the parent from the root."""
    collected_imports = set()
    reference = get_type_reference(
        package="package.child",
        imports=collected_imports,
        source_type="package.Message",
    )
    assert collected_imports == {"from ... import package as __package__"}
    assert reference == '"__package__.Message"'
def test_reference_deeply_nested_siblings():
    """A message in the same nested package needs no import and a bare name."""
    collected_imports = set()
    reference = get_type_reference(
        package="foo.bar",
        imports=collected_imports,
        source_type="foo.bar.Message",
    )
    assert collected_imports == set()
    assert reference == '"Message"'
def test_reference_root_sibling():
    """A message in the root package referenced from the root needs no import."""
    collected_imports = set()
    reference = get_type_reference(
        package="",
        imports=collected_imports,
        source_type="Message",
    )
    assert collected_imports == set()
    assert reference == '"Message"'
def test_referenceing_google_wrappers_unwraps_them(google_type: str, expected_name: str):
    """By default a Google wrapper type resolves to ``expected_name`` with no imports."""
    collected_imports = set()
    reference = get_type_reference(
        package="",
        imports=collected_imports,
        source_type=google_type,
    )
    assert reference == expected_name
    assert collected_imports == set()
def test_reference_ancestor_package_from_nested_child():
    """An ancestor's message is reached by importing the ancestor from its parent."""
    expected_imports = {"from .... import ancestor as ___ancestor__"}
    expected_reference = '"___ancestor__.Message"'

    collected_imports = set()
    reference = get_type_reference(
        package="package.ancestor.nested.child",
        imports=collected_imports,
        source_type="package.ancestor.Message",
    )
    assert collected_imports == expected_imports
    assert reference == expected_reference
def test_reference_google_wellknown_types_non_wrappers(
    google_type: str, expected_name: str, expected_import: str
):
    """Non-wrapper well-known Google types resolve to a name plus a recorded import."""
    imports = set()
    name = get_type_reference(package="", imports=imports, source_type=google_type)
    assert name == expected_name
    # Use the `in` operator rather than calling the `__contains__` dunder directly.
    assert expected_import in imports, f"{expected_import} not found in {imports}"
def test_reference_deeply_nested_child_from_package():
    """A message deep below the current package is imported from its parent, aliased."""
    expected_imports = {"from .deeply.nested import child as deeply_nested_child"}
    expected_reference = '"deeply_nested_child.Message"'

    collected_imports = set()
    reference = get_type_reference(
        package="package",
        imports=collected_imports,
        source_type="package.deeply.nested.child.Message",
    )
    assert collected_imports == expected_imports
    assert reference == expected_reference
def py_input_message_type(self) -> str:
    """Python type name of this method's input message, without surrounding quotes.

    The reference is resolved relative to ``self.output_file.package``. Any
    import it requires is recorded in ``self.output_file.imports``.

    Returns
    -------
    str
        The unquoted type reference for ``self.proto_obj.input_type``.
    """
    target_file = self.output_file
    quoted_reference = get_type_reference(
        package=target_file.package,
        imports=target_file.imports,
        source_type=self.proto_obj.input_type,
    )
    return quoted_reference.strip('"')
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
    """Map a protobuf field descriptor's numeric type code to a Python type name.

    Codes 11 and 14 refer to another defined message or a named enum. They are
    resolved with ``get_type_reference``, which may record an import in
    ``imports``. Every other listed code maps to a builtin type name.

    Raises
    ------
    NotImplementedError
        If ``field.type`` is not one of the handled codes.
    """
    code = field.type
    builtin_groups = (
        ([1, 2], "float"),
        ([3, 4, 5, 6, 7, 13, 15, 16, 17, 18], "int"),
    )
    for codes, builtin_name in builtin_groups:
        if code in codes:
            return builtin_name
    if code == 8:
        return "bool"
    if code == 9:
        return "str"
    if code in [11, 14]:
        # Type referencing another defined Message or a named enum
        return get_type_reference(package, imports, field.type_name)
    if code == 12:
        return "bytes"
    raise NotImplementedError(f"Unknown type {code}")
def py_type(self) -> str:
    """String representation of the Python type for this field.

    Scalar protobuf types are looked up in the ``PROTO_*_TYPES`` groups and map
    to builtin names. Message and enum types are resolved through
    ``get_type_reference`` relative to ``self.output_file.package``, which may
    record an import in ``self.output_file.imports``.

    Raises
    ------
    NotImplementedError
        If ``self.proto_obj.type`` is in none of the known type groups.
    """
    field_type = self.proto_obj.type
    if field_type in PROTO_FLOAT_TYPES:
        return "float"
    elif field_type in PROTO_INT_TYPES:
        return "int"
    elif field_type in PROTO_BOOL_TYPES:
        return "bool"
    elif field_type in PROTO_STR_TYPES:
        return "str"
    elif field_type in PROTO_BYTES_TYPES:
        return "bytes"
    elif field_type in PROTO_MESSAGE_TYPES:
        # Type referencing another defined Message or a named enum
        return get_type_reference(
            package=self.output_file.package,
            imports=self.output_file.imports,
            source_type=self.proto_obj.type_name,
        )
    else:
        # Previously referenced an undefined name `field`, which raised
        # NameError instead of the intended NotImplementedError.
        raise NotImplementedError(f"Unknown type {field_type}")
def _defer_mutable_defaults(method, input_message, typing_imports):
    """Return a per-method copy of ``input_message`` with mutable defaults deferred.

    For a non-client-streaming ``method``, every field whose ``"zero"`` is not
    ``"None"`` and whose type ``is_mutable_field_type`` accepts has its default
    replaced by ``"None"``. The original ``(py_name, zero)`` pair is recorded so
    that initialisation can be deferred to the start of the generated method,
    which avoids sharing a mutable default between calls.
    Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments

    Any field whose default ends up as ``"None"`` causes ``"Optional"`` to be
    added to ``typing_imports``.

    Returns
    -------
    tuple
        ``(message, mutable_default_args)``. ``message`` is ``input_message``
        itself when it is falsy, otherwise a copy with copied field dicts.
    """
    mutable_default_args = []
    if not input_message:
        return input_message, mutable_default_args

    # Work on copies: the looked-up message may be shared by several methods
    # with the same input type. Rewriting its "zero" values in place would make
    # every later method see "None" and lose its deferred initialisation.
    message = dict(input_message)
    message["properties"] = [dict(field) for field in input_message["properties"]]

    for field in message["properties"]:
        if (
            not method.client_streaming
            and field["zero"] != "None"
            and is_mutable_field_type(field["type"])
        ):
            mutable_default_args.append((field["py_name"], field["zero"]))
            field["zero"] = "None"

        if field["zero"] == "None":
            typing_imports.add("Optional")

    return message, mutable_default_args


def read_protobuf_service(
    service: ServiceDescriptorProto, index, proto_file, content, output_types
):
    """Collect template data for one protobuf service and register it.

    Builds a dict describing ``service`` and each of its methods and appends it
    to ``content["template_data"]["services"]``. Along the way:

    * typing names required by the generated stubs are added to
      ``template_data["typing_imports"]``;
    * type imports required by method signatures are recorded in
      ``template_data["imports"]`` via ``get_type_reference``.

    Parameters
    ----------
    service : ServiceDescriptorProto
        The service descriptor to read.
    index
        Used to build the comment paths ``[6, index]`` and
        ``[6, index, 2, j]`` passed to ``get_comment``.
    proto_file
        The file descriptor passed to ``get_comment``.
    content : dict
        Must provide ``"input_package"`` and ``"template_data"``.
    output_types
        Passed to ``lookup_method_input_type`` to find each method's input message.
    """
    input_package_name = content["input_package"]
    template_data = content["template_data"]
    data = {
        "name": service.name,
        "py_name": pythonize_class_name(service.name),
        "comment": get_comment(proto_file, [6, index]),
        "methods": [],
    }

    for j, method in enumerate(service.method):
        method_input_message, mutable_default_args = _defer_mutable_defaults(
            method,
            lookup_method_input_type(method, output_types),
            template_data["typing_imports"],
        )

        data["methods"].append(
            {
                "name": method.name,
                "py_name": pythonize_method_name(method.name),
                "comment": get_comment(proto_file, [6, index, 2, j], indent=8),
                "route": f"/{input_package_name}.{service.name}/{method.name}",
                "input": get_type_reference(
                    input_package_name, template_data["imports"], method.input_type
                ).strip('"'),
                "input_message": method_input_message,
                "output": get_type_reference(
                    input_package_name,
                    template_data["imports"],
                    method.output_type,
                    unwrap=False,
                ),
                "client_streaming": method.client_streaming,
                "server_streaming": method.server_streaming,
                "mutable_default_args": mutable_default_args,
            }
        )

        if method.client_streaming:
            template_data["typing_imports"].add("AsyncIterable")
            template_data["typing_imports"].add("Iterable")
            template_data["typing_imports"].add("Union")
        if method.server_streaming:
            template_data["typing_imports"].add("AsyncIterator")

    template_data["services"].append(data)