def get_init_import_records(self) -> List[ImportRecord]:
    """
    Get import records for `__init__.py[i]`.

    Collects one import, relative to `ImportString.parent()`, for the client,
    the service resource (only when it is set), every waiter and every
    paginator.

    Returns:
        A sorted list of unique import records.
    """

    def _from_sibling(module: ServiceModuleName, name: str) -> ImportRecord:
        # Import `name` from `module`, prefixed with `ImportString.parent()`.
        return ImportRecord(ImportString.parent() + ImportString(module.name), name)

    import_records: Set[ImportRecord] = {
        _from_sibling(ServiceModuleName.client, self.client.name),
    }
    if self.service_resource:
        import_records.add(
            _from_sibling(ServiceModuleName.service_resource, self.service_resource.name)
        )
    import_records.update(
        _from_sibling(ServiceModuleName.waiter, waiter.name) for waiter in self.waiters
    )
    import_records.update(
        _from_sibling(ServiceModuleName.paginator, paginator.name)
        for paginator in self.paginators
    )
    return sorted(import_records)
def test_str(self) -> None:
    """
    Check `str()` rendering of `ImportRecord` for every name/alias combination.
    """
    source = MagicMock()
    source.__str__.return_value = "source"
    cases = (
        (ImportRecord(source, "name", "alias"), "from source import name as alias"),
        (ImportRecord(source, alias="alias"), "import source as alias"),
        (ImportRecord(source, "name"), "from source import name"),
        (ImportRecord(source), "import source"),
    )
    for record, expected in cases:
        assert str(record) == expected
def test_get_local_name(self) -> None:
    """
    Check `get_local_name`: alias wins, then name, then the rendered source.
    """
    source = MagicMock()
    source.render.return_value = "source"
    expectations = (
        (ImportRecord(source), "source"),
        (ImportRecord(source, "name"), "name"),
        (ImportRecord(source, "name", "alias"), "alias"),
        (ImportRecord(source, alias="alias"), "alias"),
    )
    for record, local_name in expectations:
        assert record.get_local_name() == local_name
def get_type_defs_required_import_records(self) -> List[ImportRecord]:
    """
    Get import records required by the TypedDicts in `self.typed_dicts`.

    Always includes `sys` and `typing.TypedDict` (with `min_version=(3, 8)`
    and a `typing_extensions` fallback). Child type records are converted
    with `get_external`; falsy, builtins and type_defs records are skipped.

    Returns:
        A sorted list of unique import records, or an empty list when there
        are no TypedDicts.
    """
    if not self.typed_dicts:
        return []

    records: Set[ImportRecord] = {
        ImportRecord(ImportString("sys")),
        ImportRecord(
            ImportString("typing"),
            "TypedDict",
            min_version=(3, 8),
            fallback=ImportRecord(ImportString("typing_extensions"), "TypedDict"),
        ),
    }
    for typed_dict in self.typed_dicts:
        for child in typed_dict.get_children_types():
            record = child.get_import_record()
            if not record or record.is_builtins() or record.is_type_defs():
                continue
            records.add(record.get_external(self.service_name.module_name))
    return sorted(records)
def test_is_type_defs(self) -> None:
    """
    Check `ImportRecord.is_type_defs` for type_defs and non-type_defs sources.
    """
    assert ImportRecord(ImportString("type_defs")).is_type_defs()
    assert ImportRecord(ImportString("service", "type_defs")).is_type_defs()
    assert not ImportRecord(ImportString("builtins")).is_type_defs()
    assert not ImportRecord(ImportString("other")).is_type_defs()
    # A third-party source must not be reported as type_defs either; this test
    # exercises `is_type_defs`, not `is_builtins`.
    assert not ImportRecord(ImportString("boto3")).is_type_defs()
def test_is_standalone(self) -> None:
    """
    Check `is_standalone` for a named record, a module-only record and a
    named record that has a fallback.
    """
    named = ImportRecord(ImportString("test"), name="my")
    module_only = ImportRecord(ImportString("test"))
    with_fallback = ImportRecord(
        ImportString("test"),
        name="my",
        fallback=ImportRecord(ImportString("test2")),
    )
    assert not named.is_standalone()
    assert module_only.is_standalone()
    assert with_fallback.is_standalone()
def setup_method(self) -> None:
    """
    Build an `ImportRecordGroup` for `typing` holding `Any` and `Text as string`.
    """
    group_source = ImportString("typing")
    members = [
        ImportRecord(ImportString("typing"), "Any"),
        ImportRecord(ImportString("typing"), "Text", "string"),
    ]
    self.result = ImportRecordGroup(group_source, members)
def get_import_record(self) -> ImportRecord:
    """
    Get import record required for using type annotation.

    Returns:
        `typing.Literal` with `min_version=(3, 8)` and a
        `typing_extensions.Literal` fallback.
    """
    fallback = ImportRecord(ImportString("typing_extensions"), "Literal")
    return ImportRecord(
        ImportString("typing"),
        "Literal",
        min_version=(3, 8),
        fallback=fallback,
    )
def get_helpers_import_record_groups(self) -> List[ImportRecordGroup]:
    """
    Get grouped import records for the helper functions.

    Always includes `boto3`, `typing.Dict` and `typing.Any`, plus the import
    record of every type returned by each helper's `get_types()`.

    Returns:
        Groups built by `ImportRecordGroup.from_import_records`.
    """
    records: Set[ImportRecord] = {
        ImportRecord(ImportString("boto3")),
        ImportRecord(ImportString("typing"), "Dict"),
        ImportRecord(ImportString("typing"), "Any"),
    }
    records.update(
        annotation.get_import_record()
        for helper in self.helper_functions
        for annotation in helper.get_types()
    )
    return ImportRecordGroup.from_import_records(records)
def __init__(
    self,
    source: ImportString,
    name: str = "",
    alias: str = "",
) -> None:
    """
    Store the import parts and the `ImportRecord` built from them.

    Arguments:
        source -- Module to import from.
        name -- Imported name; empty imports the module itself.
        alias -- Local alias; empty means no alias.
    """
    self.source, self.name, self.alias = source, name, alias
    self.import_record = ImportRecord(source=source, name=name, alias=alias)
def get_typing_import_record() -> ImportRecord:
    """
    Get import record required for using TypedDict.

    Fallback to typing_extensions for py38-.

    Returns:
        `typing.TypedDict` with `min_version=(3, 9)` and a
        `typing_extensions.TypedDict` fallback.
    """
    typing_extensions_record = ImportRecord(ImportString("typing_extensions"), "TypedDict")
    return ImportRecord(
        source=ImportString("typing"),
        name="TypedDict",
        min_version=(3, 9),
        fallback=typing_extensions_record,
    )
def get_import_record(self) -> ImportRecord:
    """
    Create a safe Import Record for annotation.

    The record imports `get_import_name()` from `typing`; when
    `has_fallback()` is true it also carries the same name from
    `typing_extensions` as a fallback.
    """
    import_name = self.get_import_name()
    if not self.has_fallback():
        return ImportRecord(source=ImportString("typing"), name=import_name)
    return ImportRecord(
        source=ImportString("typing"),
        name=import_name,
        fallback=ImportRecord(source=ImportString("typing_extensions"), name=import_name),
    )
def get_import_record(self) -> ImportRecord:
    """
    Get import record required for using type annotation.

    Non-inline literals are imported by name from the internal `literals`
    module; inline literals need `typing.Literal` (with `min_version=(3, 8)`
    and a `typing_extensions` fallback).
    """
    if not self.inline:
        return InternalImportRecord(ServiceModuleName.literals, name=self.name)
    return ImportRecord(
        ImportString("typing"),
        "Literal",
        min_version=(3, 8),
        fallback=ImportRecord(ImportString("typing_extensions"), "Literal"),
    )
def get_import_records(self) -> Set[ImportRecord]:
    """
    Get import records from the `service_resource` module.

    Covers `ServiceResource`, every sub-resource, every top-level collection
    and every collection of every sub-resource.
    """
    source = f"{MODULE_NAME}_{self.service_name.name}.service_resource"
    names = ["ServiceResource"]
    names.extend(resource.name for resource in self.sub_resources)
    names.extend(collection.name for collection in self.collections)
    names.extend(
        collection.name
        for resource in self.sub_resources
        for collection in resource.collections
    )
    return {ImportRecord(source, name) for name in names}
def get_literals_required_import_records(self) -> List[ImportRecord]:
    """
    Get import records for `literals.py[i]`.

    Returns:
        Sorted `sys` and `typing.Literal` records; `Literal` carries
        `min_version=(3, 8)` and a `typing_extensions` fallback.
    """
    sys_record = ImportRecord(ImportString("sys"))
    literal_record = ImportRecord(
        ImportString("typing"),
        "Literal",
        min_version=(3, 8),
        fallback=ImportRecord(ImportString("typing_extensions"), "Literal"),
    )
    return sorted({sys_record, literal_record})
def get_client_required_import_records(self) -> List[ImportRecord]:
    """
    Get import records for `client.py[i]`.

    Records required by the client and by its exceptions class are converted
    with `get_external` for this service's module; `sys` is added whenever
    one of them has a fallback.
    """
    providers = (
        self.client.get_required_import_records,
        self.client.exceptions_class.get_required_import_records,
    )
    records: Set[ImportRecord] = set()
    for get_records in providers:
        for record in get_records():
            records.add(record.get_external(self.service_name.module_name))
            if record.fallback:
                records.add(ImportRecord(ImportString("sys")))
    return sorted(records)
def get_required_import_records(self) -> set[ImportRecord]:
    """
    Extract import records from required type annotations.

    Returns the parent implementation's set extended with `typing.Dict`.
    """
    required = super().get_required_import_records()
    required.update({ImportRecord(ImportString("typing"), "Dict")})
    return required
def write_submodule(session: Session, service_name: ServiceName, output_path: Path) -> None:
    """
    Write one service submodule into `output_path`.

    Processes the client, the service resource (if any), waiters and
    paginators with a shared `ImportRecordRenderer`, then writes
    `__init__.py` from the client and service-resource import records.

    Arguments:
        session -- Session passed to every processor.
        service_name -- Service to generate.
        output_path -- Target directory.
    """
    renderer = ImportRecordRenderer(
        [service_name.module_name],
        [ImportRecord("__future__", "annotations")],
    )
    logger.info(f"Writing {service_name.extras_name} submodule")

    client = process_service_client(session, service_name, output_path, renderer)
    init_import_records: Set[ImportRecord] = set(client.get_import_records())

    service_resource = process_service_resource(session, service_name, output_path, renderer)
    if service_resource:
        init_import_records.update(service_resource.get_import_records())

    process_service_waiter(session, service_name, output_path, renderer)
    process_service_paginator(session, service_name, output_path, renderer)

    init_file_path = output_path / "__init__.py"
    logger.debug(f"Writing {NicePath(init_file_path)}")
    write_init_file(init_file_path, init_import_records, service_name)
def from_import_records(
    cls, import_records: Iterable[ImportRecord]
) -> List["ImportRecordGroup"]:
    """
    Get groups from `ImportRecord` list.

    Records are deduplicated and sorted; falsy and builtins records are
    dropped. A record is appended to the previous group when both share a
    source, the group's first record has a name, and neither the record nor
    the group's first record is standalone; otherwise it starts a new group.
    A `sys` import is added when any record has a fallback.

    Arguments:
        import_records -- Import records iterable. It is consumed exactly
            once, so one-shot iterators and generators are accepted.

    Returns:
        A list of generated `ImportRecordGroup`.
    """
    # Materialize once: scanning `import_records` a second time would see an
    # exhausted iterator and silently drop the `sys` import fallbacks need.
    all_import_records: Set[ImportRecord] = set(import_records)
    if any(import_record.fallback for import_record in all_import_records):
        all_import_records.add(ImportRecord(ImportString("sys")))

    result: List[ImportRecordGroup] = []
    for import_record in sorted(all_import_records):
        if not import_record or import_record.is_builtins():
            continue
        last_group = result[-1] if result else None
        if (
            last_group is None
            or last_group.source != import_record.source
            or not last_group.import_records[0].name
            or import_record.is_standalone()
            or last_group.import_records[0].is_standalone()
        ):
            result.append(ImportRecordGroup(import_record.source, [import_record]))
        else:
            last_group.import_records.append(import_record)
    return result
def get_import_record(self) -> ImportRecord:
    """
    Import `self.render()` from the module `inspect.getmodule` reports for
    `self.wrapped_type`, or from `builtins` when no module is found.
    """
    module = inspect.getmodule(self.wrapped_type)
    if not module:
        source = "builtins"
    else:
        source = module.__name__
    return ImportRecord(source=source, name=self.render())
def get_import_records(self) -> Set[ImportRecord]:
    """
    Get one import record per waiter from the service's `waiter` module.
    """
    waiter_module = f"{MODULE_NAME}_{self.service_name.name}.waiter"
    return {ImportRecord(waiter_module, waiter.name) for waiter in self.waiters}
def get_import_records(self) -> Set[ImportRecord]:
    """
    Get one import record per paginator from the service's `paginator` module.
    """
    paginator_module = f"{MODULE_NAME}_{self.service_name.name}.paginator"
    return {ImportRecord(paginator_module, paginator.name) for paginator in self.paginators}
def get_import_record(self) -> ImportRecord:
    """
    Get a module import record for `self.module_name`, aliased to `self.scope`.

    Raises:
        ValueError -- If `real_service_name` is not set.
    """
    service = self.real_service_name
    if service is None:
        raise ValueError("Non-localized ImportString")
    module_path = f"{MODULE_NAME}_{service.name}.{self.module_name}"
    return ImportRecord(source=module_path, alias=self.scope)
def test_is_third_party(self) -> None:
    """
    Check `is_third_party` for local, builtins and `boto3`/`botocore` sources.
    """
    local_sources = (("type_defs",), ("builtins",), ("other",))
    third_party_sources = (
        ("boto3",),
        ("boto3", "test"),
        ("botocore",),
        ("botocore", "test"),
    )
    for parts in local_sources:
        assert not ImportRecord(ImportString(*parts)).is_third_party()
    for parts in third_party_sources:
        assert ImportRecord(ImportString(*parts)).is_third_party()
def get_waiter_required_import_records(self) -> List[ImportRecord]:
    """
    Get import records required by all waiters.

    Each required record is converted with `get_external` for this service's
    module; `sys` is added when any required record has a fallback.

    Returns:
        A sorted list of unique import records.
    """
    required = [
        record
        for waiter in self.waiters
        for record in waiter.get_required_import_records()
    ]
    records: Set[ImportRecord] = {
        record.get_external(self.service_name.module_name) for record in required
    }
    if any(record.fallback for record in required):
        records.add(ImportRecord(ImportString("sys")))
    return sorted(records)
def get_import_record(self) -> ImportRecord:
    """
    Get import record for `self.value` from the module reported by
    `inspect.getmodule`.

    Raises:
        ValueError -- If the module of `self.value` cannot be determined.
    """
    owner = inspect.getmodule(self.value)
    if owner is None:
        raise ValueError(f"Unknown module for {self.value}")
    source = ImportString.from_str(owner.__name__)
    return ImportRecord(source=source, name=self.get_import_name(), alias=self.alias)
def get_import_record(self) -> ImportRecord:
    """
    Get import record required for using type annotation.

    Without `service_name`, returns an `InternalImportRecord` for
    `self.module_name`; otherwise imports the service module's
    `self.module_name` submodule. Both are aliased to `self.scope`.
    """
    if self.service_name is None:
        return InternalImportRecord(self.module_name, alias=self.scope)
    source = ImportString(self.service_name.module_name, self.module_name.name)
    return ImportRecord(source=source, alias=self.scope)
def get_type_defs_required_import_record_groups(self) -> List[ImportRecordGroup]:
    """
    Get grouped import records required by `self.typed_dicts`.

    Adds `typing.TypedDict` (with `min_version=(3, 8)` and a
    `typing_extensions` fallback) when there is at least one TypedDict, plus
    every child type's import record that is not a type_defs record.

    Returns:
        Groups built by `ImportRecordGroup.from_import_records`.
    """
    records: Set[ImportRecord] = set()
    if self.typed_dicts:
        records.add(
            ImportRecord(
                ImportString("typing"),
                "TypedDict",
                min_version=(3, 8),
                fallback=ImportRecord(ImportString("typing_extensions"), "TypedDict"),
            )
        )
    for typed_dict in self.typed_dicts:
        for child in typed_dict.get_children_types():
            record = child.get_import_record()
            if not record.is_type_defs():
                records.add(record)
    return ImportRecordGroup.from_import_records(records)
class ExternalImport(FakeAnnotation):
    """
    Annotation backed by a single `ImportRecord`.

    Renders as the record's local name and reports that record as its import.
    """

    def __init__(self, source: str, name: str = "", alias: str = "") -> None:
        self.import_record = ImportRecord(source=source, name=name, alias=alias)

    def get_import_record(self) -> ImportRecord:
        """
        Return the stored import record.
        """
        return self.import_record

    def render(self) -> str:
        """
        Render the annotation as the record's local name.
        """
        record = self.import_record
        return record.get_local_name()
def get_paginator_required_import_records(self) -> List[ImportRecord]:
    """
    Get import records for `paginator.py[i]`.

    Every record required by a paginator is converted with `get_external`
    for this service's module; `sys` is added when any of them has a
    fallback.
    """
    import_records: Set[ImportRecord] = set()
    needs_sys = False
    for paginator in self.paginators:
        for required in paginator.get_required_import_records():
            import_records.add(required.get_external(self.service_name.module_name))
            needs_sys = needs_sys or bool(required.fallback)
    if needs_sys:
        import_records.add(ImportRecord(ImportString("sys")))
    return sorted(import_records)