def __get_required_imports(self):
    """Build the import statements required by the applied type annotations.

    Reads ``self.all_applied_types``, an iterable of pairs whose first
    element is searched as text for word tokens and whose second element
    exposes an ``annotation`` node. Three groups of imports are derived:

    * ``from typing import ...`` for the tokens also in ``PY_TYPING_MOD``;
    * ``from collections import ...`` for the tokens also in
      ``PY_COLLECTION_MOD``;
    * ``import <name>`` for the first name found inside every attribute
      access that occurs in an annotation.

    The names of each group come from sets, so they are emitted in sorted
    order; otherwise the generated code would change from run to run with
    string hash randomization.

    Returns:
        A list of ``cst.SimpleStatementLine`` nodes (empty when nothing is
        required).
    """

    def find_required_modules(all_types):
        """Collect the first name of every attribute access in the annotations."""
        # The patterns do not depend on the loop variables; build them once.
        attribute_pattern = match.Attribute(value=match.DoNotCare(),
                                            attr=match.DoNotCare())
        name_pattern = match.Name(value=match.DoNotCare())
        req_mod = set()
        for _, a_node in all_types:
            for attribute in match.findall(a_node.annotation,
                                           attribute_pattern):
                # Only the first Name match inside the attribute is recorded.
                req_mod.add(match.findall(attribute, name_pattern)[0].value)
        return req_mod

    def import_aliases(names):
        """Return one plain ``ImportAlias`` per name, in sorted order."""
        return [
            cst.ImportAlias(name=cst.Name(value=n), asname=None)
            for n in sorted(names)
        ]

    all_req_mods = find_required_modules(self.all_applied_types)
    all_type_names = set(
        chain.from_iterable(
            regex.findall(r"\w+", t[0]) for t in self.all_applied_types))

    req_imports = []
    # (module name, names to import from it), in the order they are emitted.
    from_imports = (
        ("typing", PY_TYPING_MOD & all_type_names),
        ("collections", PY_COLLECTION_MOD & all_type_names),
    )
    for module_name, imported_names in from_imports:
        if imported_names:
            req_imports.append(
                cst.SimpleStatementLine(body=[
                    cst.ImportFrom(module=cst.Name(value=module_name),
                                   names=import_aliases(imported_names)),
                ]))

    # One ``import <name>`` statement per referenced module.
    for mod_name in sorted(all_req_mods):
        req_imports.append(
            cst.SimpleStatementLine(body=[
                cst.Import(names=[
                    cst.ImportAlias(name=cst.Name(value=mod_name),
                                    asname=None)
                ])
            ]))
    return req_imports
def visit_Assign(self, node) -> None:
    """Record a single-name assignment seen while ``self.toplevel`` is 0.

    When ``node`` assigns to exactly one plain name and ``self.toplevel``
    equals 0, an ``ImportFrom`` that imports that name from the expression
    parsed out of ``self.mod`` is stored in ``self.imprts`` under the name.
    Any other assignment is ignored.
    """
    single_name_assign = m.Assign(targets=[m.AssignTarget(m.Name())])
    if not m.matches(node, single_name_assign) or self.toplevel != 0:
        return

    target = node.targets[0].target
    source_module = parse_expr(self.mod)
    alias = cst.ImportAlias(name=target, asname=None)
    self.imprts[target.value] = cst.ImportFrom(module=source_module,
                                               names=[alias])
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Insert the pending imports and top-level annotations into the module.

    Returns ``original_node`` untouched when ``self.is_generated`` is set,
    and ``updated_node`` untouched when there are neither top-level
    annotations nor imports to add. Otherwise the new statements are placed
    between the two halves produced by ``self._split_module``. Names that
    the statements in ``self.import_statements`` already import are left
    out of the new imports, and an import with no remaining names is
    dropped entirely.
    """
    if self.is_generated:
        return original_node
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    # Split the body at the insertion point for imports, then make sure the
    # first non-import statement is preceded by at least one empty line.
    before_imports, after_imports = self._split_module(
        original_node, updated_node
    )
    after_imports = self._insert_empty_line(after_imports)

    # Names already brought in by existing imports (star imports are skipped).
    already_imported = set()
    for existing in self.import_statements:
        existing_names = existing.names
        if isinstance(existing_names, cst.ImportStar):
            continue
        for alias in existing_names:
            # Prefer the ``as`` node when the alias has one.
            bound = alias.asname if alias.asname else alias
            if bound:
                already_imported.add(_get_name_as_string(bound.name))

    new_statements = []
    for pending in self.imports.values():
        # Keep only the names that are not imported yet.
        missing = sorted(pending.names.difference(already_imported))
        aliases = [cst.ImportAlias(cst.Name(missing_name)) for missing_name in missing]
        if not aliases:
            continue
        # SimpleStatementLine needs a subscriptable body, hence the list.
        new_statements.append(
            cst.SimpleStatementLine(
                [cst.ImportFrom(module=pending.module, names=aliases)]
            )
        )

    for target, annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(target),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        new_statements.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(
        body=[*before_imports, *new_statements, *after_imports]
    )
def import_to_node_single(imp: SortableImport, module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a single-line import statement node.

    Lines in ``imp.comments.before`` become leading lines: a comment line
    when the text starts with ``#``, otherwise an empty line. Every other
    comment (first inline, per-item before/inline/following, final, last
    inline) is joined with ``COMMENT_INDENT`` into one trailing comment.
    A truthy ``imp.stem`` produces a ``from ... import ...`` statement,
    otherwise a plain ``import ...``. ``module`` is not used here.
    """
    leading_lines = []
    for text in imp.comments.before:
        if text.startswith("#"):
            leading_lines.append(
                cst.EmptyLine(indent=True, comment=cst.Comment(text)))
        else:
            leading_lines.append(cst.EmptyLine(indent=False))

    # Collect the aliases and, in source order, every comment that has to
    # end up after the statement.
    inline_comments = list(imp.comments.first_inline)
    aliases: List[cst.ImportAlias] = []
    for item in imp.items:
        alias_name = name_to_node(item.name)
        alias_asname = None
        if item.asname:
            alias_asname = cst.AsName(name=cst.Name(item.asname))
        aliases.append(cst.ImportAlias(name=alias_name, asname=alias_asname))
        inline_comments.extend(item.comments.before)
        inline_comments.extend(item.comments.inline)
        inline_comments.extend(item.comments.following)
    inline_comments.extend(imp.comments.final)
    inline_comments.extend(imp.comments.last_inline)

    if inline_comments:
        joined = COMMENT_INDENT.join(inline_comments)
        trailing = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(joined))
    else:
        trailing = cst.TrailingWhitespace()

    if imp.stem:
        stem, ndots = split_relative(imp.stem)
        # A purely relative import ("from .. import x") has no module name.
        module_name = name_to_node(stem) if stem else None
        statement = cst.ImportFrom(module=module_name,
                                   names=aliases,
                                   relative=(cst.Dot(), ) * ndots)
    else:
        statement = cst.Import(names=aliases)

    return cst.SimpleStatementLine(
        body=[statement],
        leading_lines=leading_lines,
        trailing_whitespace=trailing,
    )
def body(
    self,
) -> typing.Iterable[typing.Union[cst.BaseCompoundStatement, cst.SimpleStatementLine]]:
    """Lazily yield the statements of the generated body.

    Order: a ``from typing import *`` line, whatever ``assign_properties``
    produces for ``self.properties``, whatever ``function_defs`` produces
    for the overloads and functions, then one ``class_def`` per entry of
    ``sort_items(self.classes)``.
    """
    typing_star_import = cst.ImportFrom(cst.Name("typing"),
                                        names=cst.ImportStar())
    yield cst.SimpleStatementLine([typing_star_import])

    property_statements = assign_properties(self.properties)
    yield from property_statements

    function_statements = function_defs(self.function_overloads,
                                        self.functions, "function")
    yield from function_statements

    for class_name, class_record in sort_items(self.classes):
        yield class_record.class_def(class_name)
def leave_Module(self, original_node: cst.Module,
                 updated_node: cst.Module) -> cst.Module:
    """Insert the pending imports and top-level annotations into the module.

    Returns ``updated_node`` unchanged when there are neither top-level
    annotations nor imports to add. Otherwise one ``from ... import ...``
    line per entry of ``self.imports`` and one annotated declaration per
    entry of ``self.toplevel_annotations`` are placed between the two
    halves produced by ``self._split_module``.
    """
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    # Split the body at the insertion point for imports, then make sure the
    # first non-import statement is preceded by at least one empty line.
    head, tail = self._split_module(original_node, updated_node)
    tail = self._insert_empty_line(tail)

    # SimpleStatementLine needs a subscriptable body, hence the lists.
    import_lines = [
        cst.SimpleStatementLine([
            cst.ImportFrom(
                module=pending.module,
                # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
                #  for 2nd param but got `List[ImportFrom]`.
                names=pending.names,
            )
        ]) for pending in self.imports.values()
    ]
    annotation_lines = [
        cst.SimpleStatementLine([
            cst.AnnAssign(
                cst.Name(target),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
        ]) for target, annotation in self.toplevel_annotations.items()
    ]

    return updated_node.with_changes(
        body=[*head, *import_lines, *annotation_lines, *tail])
def visit_ImportFrom(self, node) -> None:
    """Record every name bound by a ``from ... import ...`` statement.

    For each imported alias a fresh single-name ``cst.ImportFrom`` is
    stored in ``self.imprts``, keyed by the name the alias binds (its
    ``as`` name when present, otherwise the imported name).

    Relative imports are rewritten as absolute ones against ``self.mod``:
    the last ``level`` dotted components of ``self.mod`` are dropped (a
    ``self.mod`` with no dot is used as is) and the statement's own module,
    if any, is appended.

    Star imports (``from x import *``) carry a ``cst.ImportStar`` instead
    of a sequence of aliases and bind no explicit names, so they are
    skipped instead of failing on iteration.
    """
    aliases = node.names
    if isinstance(aliases, cst.ImportStar):
        # ImportStar is not iterable; there is nothing to record.
        return

    # The target module is the same for every alias of the statement, so
    # resolve its source text once.
    level = len(node.relative)
    if level > 0:
        parts = self.mod.split('.')
        mod_level = '.'.join(
            parts[:-level]) if len(parts) > 1 else parts[0]
        if node.module is not None:
            # NOTE(review): a2s presumably renders the node back to dotted
            # source text - confirm against its definition.
            module_source = f'{mod_level}.{a2s(node.module)}'
        else:
            module_source = mod_level
    else:
        module_source = None

    for alias in aliases:
        if alias.asname is not None:
            name = alias.asname.name.value
        else:
            name = alias.name.value
        # Parse per alias so that every stored ImportFrom owns its own
        # module node, as before.
        if module_source is not None:
            module = parse_expr(module_source)
        else:
            module = node.module
        # Regenerate alias to avoid trailing comma issue
        clean_alias = cst.ImportAlias(name=alias.name, asname=alias.asname)
        self.imprts[name] = cst.ImportFrom(module=module, names=[clean_alias])
def import_to_node_multi(imp: SortableImport, module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a parenthesized ``from ... import (...)`` statement.

    Every item of ``imp.items`` becomes an ``ImportAlias`` followed by a
    comma whose trailing whitespace carries that item's inline comment, its
    following comments and the indentation of the next line.

    Args:
        imp: The import to render; ``imp.items``, ``imp.stem`` and
            ``imp.comments`` are read.
        module: Only ``module.default_indent`` is used, as the indentation
            of the wrapped lines.

    Returns:
        A ``cst.SimpleStatementLine`` wrapping the ``cst.ImportFrom``.

    Raises:
        ValueError: If ``imp.stem`` is falsy, i.e. a plain ``import x``
            statement, which this function does not render.
    """
    body: List[cst.BaseSmallStatement] = []
    names: List[cst.ImportAlias] = []
    # Alias built in the previous iteration; None while on the first item.
    prev: Optional[cst.ImportAlias] = None
    following: List[str] = []
    # Comment lines and inline comment attached right after the "(".
    lpar_lines: List[cst.EmptyLine] = []
    lpar_inline: cst.TrailingWhitespace = cst.TrailingWhitespace()

    item_count = len(imp.items)
    for idx, item in enumerate(imp.items):
        name = name_to_node(item.name)
        asname = cst.AsName(
            name=cst.Name(item.asname)) if item.asname else None

        # Leading comments actually have to be trailing comments on the previous node.
        # That means putting them on the lpar node for the first item
        if item.comments.before:
            lines = [
                cst.EmptyLine(
                    indent=True,
                    comment=cst.Comment(c),
                    whitespace=cst.SimpleWhitespace(module.default_indent),
                ) for c in item.comments.before
            ]
            if prev is None:
                lpar_lines.extend(lines)
            else:
                # Mutates the whitespace list of the alias created in the
                # previous iteration, in place.
                prev.comma.whitespace_after.empty_lines.extend(
                    lines)  # type: ignore

        # all items except the last needs whitespace to indent the *next* line/item
        indent = idx != (len(imp.items) - 1)

        first_line = cst.TrailingWhitespace()
        inline = COMMENT_INDENT.join(item.comments.inline)
        if inline:
            first_line = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(inline),
            )

        # The statement-level final comments are appended after the last item.
        if idx == item_count - 1:
            following = item.comments.following + imp.comments.final
        else:
            following = item.comments.following

        after = cst.ParenthesizedWhitespace(
            indent=True,
            first_line=first_line,
            empty_lines=[
                cst.EmptyLine(
                    indent=True,
                    comment=cst.Comment(c),
                    whitespace=cst.SimpleWhitespace(module.default_indent),
                ) for c in following
            ],
            last_line=cst.SimpleWhitespace(
                module.default_indent if indent else ""),
        )

        node = cst.ImportAlias(
            name=name,
            asname=asname,
            comma=cst.Comma(whitespace_after=after),
        )
        names.append(node)
        prev = node

    # from foo import (
    #     bar
    # )
    if imp.stem:
        stem, ndots = split_relative(imp.stem)
        # An empty stem means a purely relative import: no module name.
        if not stem:
            module_name = None
        else:
            module_name = name_to_node(stem)
        relative = (cst.Dot(), ) * ndots

        # inline comment following lparen
        if imp.comments.first_inline:
            inline = COMMENT_INDENT.join(imp.comments.first_inline)
            lpar_inline = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(inline),
            )

        body = [
            cst.ImportFrom(
                module=module_name,
                names=names,
                relative=relative,
                lpar=cst.LeftParen(
                    whitespace_after=cst.ParenthesizedWhitespace(
                        indent=True,
                        first_line=lpar_inline,
                        empty_lines=lpar_lines,
                        last_line=cst.SimpleWhitespace(module.default_indent),
                    ), ),
                rpar=cst.RightParen(),
            )
        ]

    # import foo
    else:
        raise ValueError("can't render basic imports on multiple lines")

    # comment lines above import
    leading_lines = [
        cst.EmptyLine(indent=True, comment=cst.Comment(line))
        if line.startswith("#") else cst.EmptyLine(indent=False)
        for line in imp.comments.before
    ]

    # inline comments following import/rparen
    if imp.comments.last_inline:
        inline = COMMENT_INDENT.join(imp.comments.last_inline)
        trailing = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(inline))
    else:
        trailing = cst.TrailingWhitespace()

    return cst.SimpleStatementLine(
        body=body,
        leading_lines=leading_lines,
        trailing_whitespace=trailing,
    )
class ImportFromParseTest(CSTNodeTest):
    """Table-driven cases pairing ``from ... import ...`` source with nodes.

    Each case provides the expected ``cst.ImportFrom`` under ``"node"`` and
    the source text under ``"code"``; ``test_valid`` hands both to
    ``self.validate_node`` together with a parser for the source text.
    """

    @data_provider(
        (
            # Simple from import statement
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),)
                ),
                "code": "from foo import bar",
            },
            # From import statement with alias
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"), asname=cst.AsName(cst.Name("baz"))
                        ),
                    ),
                ),
                "code": "from foo import bar as baz",
            },
            # Multiple imports
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(cst.Name("baz")),
                    ),
                ),
                "code": "from foo import bar, baz",
            },
            # Trailing comma
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()),
                    ),
                ),
                "code": "from foo import bar, baz,",
            },
            # Star import statement
            {
                "node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
                "code": "from foo import *",
            },
            # Simple relative import statement
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(),),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from .foo import bar",
            },
            # Relative import statement with two leading dots
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from ..foo import bar",
            },
            # Relative only import
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=None,
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from .. import bar",
            },
            # Parenthesis
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"), asname=cst.AsName(cst.Name("baz"))
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "from foo import (bar as baz)",
            },
            # Verify whitespace works everywhere.
            {
                "node": cst.ImportFrom(
                    relative=(
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                    whitespace_before_import=cst.SimpleWhitespace(" "),
                    whitespace_after_import=cst.SimpleWhitespace(" "),
                ),
                "code": "from . . foo import ( bar as baz , unittest as ut )",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one case, parsing its code as a single small statement."""
        # The code is parsed as a statement line; its first small statement
        # is the node handed to validate_node's parser argument.
        self.validate_node(
            parser=lambda code: ensure_type(
                parse_statement(code), cst.SimpleStatementLine
            ).body[0],
            **kwargs,
        )
class ImportFromCreateTest(CSTNodeTest):
    """Table-driven cases for hand-built ``cst.ImportFrom`` nodes.

    ``test_valid`` passes each case (``"node"``, ``"code"`` and, for some
    cases, ``"expected_position"``) to ``self.validate_node``.
    ``test_invalid`` passes a node factory and an expected message pattern
    to ``self.assert_invalid``.
    """

    @data_provider(
        (
            # Simple from import statement
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),)
                ),
                "code": "from foo import bar",
            },
            # From import statement with alias
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"), asname=cst.AsName(cst.Name("baz"))
                        ),
                    ),
                ),
                "code": "from foo import bar as baz",
            },
            # Multiple imports
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(cst.Name("bar")),
                        cst.ImportAlias(cst.Name("baz")),
                    ),
                ),
                "code": "from foo import bar, baz",
            },
            # Trailing comma
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(cst.Name("bar"), comma=cst.Comma()),
                        cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()),
                    ),
                ),
                "code": "from foo import bar,baz,",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Star import statement
            {
                "node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
                "code": "from foo import *",
                "expected_position": CodeRange((1, 0), (1, 17)),
            },
            # Simple relative import statement
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(),),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from .foo import bar",
            },
            # Relative import statement with two leading dots
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from ..foo import bar",
            },
            # Relative only import
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=None,
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from .. import bar",
            },
            # Parenthesis
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"), asname=cst.AsName(cst.Name("baz"))
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "from foo import (bar as baz)",
                "expected_position": CodeRange((1, 0), (1, 28)),
            },
            # Verify whitespace works everywhere.
            # NOTE(review): the "code" literal of this case is 51 characters
            # long as it appears here, while expected_position ends at column
            # 61; runs of spaces inside this case's string literals look
            # collapsed - confirm the exact whitespace against the canonical
            # copy of this test before relying on it.
            {
                "node": cst.ImportFrom(
                    relative=(
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                    whitespace_before_import=cst.SimpleWhitespace(" "),
                    whitespace_after_import=cst.SimpleWhitespace(" "),
                ),
                "code": "from . . foo import ( bar as baz , unittest as ut )",
                "expected_position": CodeRange((1, 0), (1, 61)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one hand-built node against its expected code."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # No module and no relative dots
            {
                "get_node": lambda: cst.ImportFrom(
                    module=None, names=(cst.ImportAlias(cst.Name("bar")),)
                ),
                "expected_re": "Must have a module specified",
            },
            # Empty sequence of names
            {
                "get_node": lambda: cst.ImportFrom(module=cst.Name("foo"), names=()),
                "expected_re": "at least one ImportAlias",
            },
            # Unbalanced parentheses
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "right paren without left paren",
            },
            # Parentheses around a star import
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"), names=cst.ImportStar(), lpar=cst.LeftParen()
                ),
                "expected_re": "cannot have parens",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=cst.ImportStar(),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "cannot have parens",
            },
            # Empty whitespace around the keywords
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space after from",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    whitespace_before_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space before import",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space after import",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that building the node from one case is rejected."""
        # NOTE(review): expected_re is presumably matched against the
        # validation error message - confirm in assert_invalid.
        self.assert_invalid(**kwargs)