def visit_Import(self, node) -> None:
    """Record every alias of an ``import`` statement as its own ``cst.Import``.

    For each alias in ``node.names`` a fresh single-alias ``cst.Import`` is
    stored in ``self.imprts``. The key is the ``as`` alias when there is one,
    otherwise the full (possibly dotted) module name, e.g. ``import a.b, c as d``
    stores ``"a.b" -> import a.b`` and ``"d" -> import c as d``.

    Args:
        node: the ``cst.Import`` node being visited.
    """
    for alias in node.names:
        # ``alias.name.value`` is only a ``str`` when the imported name is a
        # plain ``cst.Name``; for a dotted import the name is a
        # ``cst.Attribute`` whose ``.value`` is a child node, which produced a
        # non-string key. ``evaluated_name`` renders dotted names as "a.b".
        name = alias.evaluated_alias or alias.evaluated_name
        # Regenerate the alias without its comma: a single-alias import with a
        # trailing comma is rejected by libcst ("trailing comma").
        single_alias = cst.ImportAlias(name=alias.name, asname=alias.asname)
        self.imprts[name] = cst.Import(names=[single_alias])
def import_to_node_single(imp: SortableImport, module: cst.Module) -> cst.BaseStatement:
    """Render one ``SortableImport`` as a single ``import`` / ``from ... import`` line.

    Comment handling:

    * each entry of ``imp.comments.before`` becomes a leading line -- entries
      starting with ``#`` keep their text as an indented comment line, any
      other entry becomes an unindented blank line;
    * ``first_inline``, every item's ``before``/``inline``/``following``
      comments, then ``final`` and ``last_inline`` are joined with
      ``COMMENT_INDENT`` into one trailing comment on the statement.

    ``module`` is part of the signature but is not read by this function.

    Returns:
        A ``cst.SimpleStatementLine`` holding either a ``cst.ImportFrom``
        (when ``imp.stem`` is set) or a plain ``cst.Import``.
    """
    leading_lines = []
    for text_line in imp.comments.before:
        if text_line.startswith("#"):
            leading_lines.append(
                cst.EmptyLine(indent=True, comment=cst.Comment(text_line))
            )
        else:
            leading_lines.append(cst.EmptyLine(indent=False))

    collected = list(imp.comments.first_inline)
    aliases: List[cst.ImportAlias] = []
    for entry in imp.items:
        target = name_to_node(entry.name)
        renamed = cst.AsName(name=cst.Name(entry.asname)) if entry.asname else None
        aliases.append(cst.ImportAlias(name=target, asname=renamed))
        # Item-level comments are folded into the statement's trailing comment.
        collected.extend(entry.comments.before)
        collected.extend(entry.comments.inline)
        collected.extend(entry.comments.following)
    collected.extend(imp.comments.final)
    collected.extend(imp.comments.last_inline)

    if collected:
        trailing_whitespace = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(COMMENT_INDENT.join(collected)),
        )
    else:
        trailing_whitespace = cst.TrailingWhitespace()

    if imp.stem:
        # ``split_relative`` separates the leading dots of a relative stem.
        stem, ndots = split_relative(imp.stem)
        statement = cst.ImportFrom(
            module=name_to_node(stem) if stem else None,
            names=aliases,
            relative=(cst.Dot(),) * ndots,
        )
    else:
        statement = cst.Import(names=aliases)

    return cst.SimpleStatementLine(
        body=[statement],
        leading_lines=leading_lines,
        trailing_whitespace=trailing_whitespace,
    )
def __get_required_imports(self):
    """Build the import statements needed by the applied type annotations.

    ``self.all_applied_types`` is iterated as pairs: the first element is the
    type's text (scanned for ``\\w+`` words) and the second has an
    ``.annotation`` node (presumably ``(str, cst.Annotation)`` -- confirm
    against the producer).

    Returns:
        A list of ``cst.SimpleStatementLine`` nodes, in this order:

        * ``from typing import ...`` for words found in ``PY_TYPING_MOD``;
        * ``from collections import ...`` for words found in ``PY_COLLECTION_MOD``;
        * one ``import <mod>`` per module referenced through an attribute
          access in an annotation (e.g. ``np`` for ``np.ndarray``).

        Imported names and modules are emitted in sorted order so the
        generated code is identical from run to run (iteration order of a
        set of strings depends on hash randomization).
    """
    def find_required_modules(all_types):
        # For each attribute access in an annotation, keep the first ``Name``
        # found inside it -- the leftmost name of the chain, e.g. ``a`` in
        # ``a.b.C``.
        req_mod = set()
        for _, a_node in all_types:
            attributes = match.findall(
                a_node.annotation,
                match.Attribute(value=match.DoNotCare(), attr=match.DoNotCare()))
            for attr_node in attributes:
                root = match.findall(attr_node, match.Name(value=match.DoNotCare()))[0]
                req_mod.add(root.value)
        return req_mod

    def from_import(module_name, names):
        # ``from <module_name> import n1, n2, ...`` with names sorted.
        return cst.SimpleStatementLine(body=[
            cst.ImportFrom(
                module=cst.Name(value=module_name),
                names=[
                    cst.ImportAlias(name=cst.Name(value=t), asname=None)
                    for t in sorted(names)
                ],
            ),
        ])

    all_req_mods = find_required_modules(self.all_applied_types)
    all_type_names = set(
        chain.from_iterable(
            regex.findall(r"\w+", t[0]) for t in self.all_applied_types))
    typing_imports = PY_TYPING_MOD & all_type_names
    collection_imports = PY_COLLECTION_MOD & all_type_names

    req_imports = []
    if typing_imports:
        req_imports.append(from_import("typing", typing_imports))
    if collection_imports:
        req_imports.append(from_import("collections", collection_imports))
    for mod_name in sorted(all_req_mods):
        req_imports.append(
            cst.SimpleStatementLine(body=[
                cst.Import(names=[
                    cst.ImportAlias(name=cst.Name(value=mod_name), asname=None)
                ])
            ]))
    return req_imports
def make_simple_package_import(package: str) -> cst.Import:
    """Build ``import <package>`` for a single root package, e.g. ``import os``.

    Args:
        package: a top-level package name containing no ``.``.

    Returns:
        A ``cst.Import`` with one ``cst.ImportAlias`` naming ``package``.

    Raises:
        ValueError: if ``package`` is dotted; a dotted name would need a
            ``cst.Attribute`` chain, which this helper does not build.
    """
    # An explicit check instead of ``assert``: asserts are removed when Python
    # runs with ``-O``, which would let dotted names through silently.
    if "." in package:
        raise ValueError("this only supports a root package, e.g. 'import os'")
    return cst.Import(names=[cst.ImportAlias(name=cst.Name(package))])
class ImportParseTest(CSTNodeTest):
    """Parser tests for ``import`` statements.

    Each case pairs a ``cst.Import`` node with its source text; both are handed
    to ``validate_node`` together with a parser that takes the first statement
    of the parsed code.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Nested dotted name (replaces a verbatim duplicate of the case above)
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                                cst.Name("baz"),
                            )
                        ),
                    )
                ),
                "code": "import foo.bar.baz",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere. Two-space runs differ from the
            # single spaces used by the cases above, so each custom whitespace
            # field is distinguishable in "code".
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "import  foo . bar  as  baz ,  unittest  as  ut",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(
            parser=lambda code: ensure_type(
                parse_statement(code), cst.SimpleStatementLine
            ).body[0],
            **kwargs,
        )
class ImportCreateTest(CSTNodeTest):
    """Construction tests for ``cst.Import``.

    ``test_valid`` checks that hand-built nodes render to ``code`` (and, where
    given, occupy ``expected_position``); ``test_invalid`` checks that
    malformed nodes are rejected with an error matching ``expected_re``.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Nested dotted name (replaces a verbatim duplicate of the case above)
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                                cst.Name("baz"),
                            )
                        ),
                    )
                ),
                "code": "import foo.bar.baz",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere. The two-space runs make the
            # rendered statement 46 characters long, matching the expected
            # position (the single-space variant is only 40 characters).
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "import  foo . bar  as  baz ,  unittest  as  ut",
                "expected_position": CodeRange((1, 0), (1, 46)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.Import(names=()),
                "expected_re": "at least one ImportAlias",
            },
            {
                "get_node": lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(),
                        ),
                    )
                ),
                "expected_re": "trailing comma",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)