def leave_ImportFrom(
    self, original_node: libcst.ImportFrom, updated_node: libcst.ImportFrom
) -> libcst.ImportFrom:
    """Prepend the names scheduled for this module to an existing from-import.

    The absolute module of ``updated_node`` is looked up in
    ``self.module_mapping`` (plain names) and ``self.alias_mapping``
    (``(name, alias)`` pairs).  Matching entries are removed from both
    mappings and their names are placed, sorted, in front of the names the
    statement already imports.

    Returns ``updated_node`` unchanged for star imports, for imports whose
    absolute module cannot be determined, and for modules with nothing
    scheduled.
    """
    existing_names = updated_node.names
    if isinstance(existing_names, libcst.ImportStar):
        # A star import cannot take additional explicit names.
        return updated_node

    module = get_absolute_module_for_import(
        self.context.full_module_name, updated_node
    )
    if module is None:
        return updated_node
    if not (module in self.module_mapping or module in self.alias_mapping):
        return updated_node

    # Consume the entries so that this module is not handled a second time.
    plain_imports = self.module_mapping.pop(module, [])
    aliased_imports = self.alias_mapping.pop(module, [])

    combined = [
        libcst.ImportAlias(name=libcst.Name(plain))
        for plain in sorted(plain_imports)
    ]
    combined.extend(
        libcst.ImportAlias(
            name=libcst.Name(target),
            asname=libcst.AsName(name=libcst.Name(local_name)),
        )
        for (target, local_name) in sorted(aliased_imports)
    )
    combined.extend(existing_names)
    return updated_node.with_changes(names=combined)
def leave_ImportFrom(self, original_node: cst.ImportFrom,
                     updated_node: cst.ImportFrom) -> cst.ImportFrom:
    """Rewrite a ``from ... import ...`` statement that imports ``self.old_name``.

    Each imported alias is qualified as ``<module>.<alias>`` and compared with
    ``self.old_name``:

    * exact match and ``gen_replacement`` returns something falsy: the alias
      is kept and the statement is added to ``self.scheduled_removals``;
    * exact match and it is the only imported name: the statement is
      rewritten immediately (new module and new name) and
      ``self.bypass_import`` is set;
    * exact match and the replacement module equals the current module: the
      alias is replaced by the new name and ``self.bypass_import`` is set;
    * no match: the alias is kept; if ``self.old_name`` starts with the
      qualified name plus a dot, the statement is also scheduled for removal.

    Returns ``updated_node`` unchanged when the statement has no module, when
    the module name cannot be resolved, or when ``names`` is not a
    ``Sequence``.
    """
    module = updated_node.module
    if module is None:
        return updated_node
    imported_module_name = get_full_name_for_node(module)
    # The aliases are read from the original node, not the updated one.
    names = original_node.names

    if imported_module_name is None or not isinstance(names, Sequence):
        return updated_node
    else:
        new_names = []
        for import_alias in names:
            alias_name = get_full_name_for_node(import_alias.name)
            if alias_name is not None:
                qual_name = f"{imported_module_name}.{alias_name}"
                if self.old_name == qual_name:
                    replacement_module = self.gen_replacement_module(
                        imported_module_name)
                    replacement_obj = self.gen_replacement(alias_name)
                    if not replacement_obj:
                        # The user has requested an `import` statement rather than a `from ... import`.
                        # This will be taken care of in `leave_Module`; in the meantime, schedule for potential removal.
                        new_names.append(import_alias)
                        self.scheduled_removals.add(original_node)
                        continue
                    new_import_alias_name: Union[
                        cst.Attribute, cst.Name] = self.gen_name_or_attr_node(
                            replacement_obj)
                    # Rename on the spot only if this is the only imported name under the module.
                    if len(names) == 1:
                        self.bypass_import = True
                        return updated_node.with_changes(
                            module=cst.parse_expression(
                                replacement_module),
                            names=(cst.ImportAlias(
                                name=new_import_alias_name), ),
                        )
                    # Or if the module name is to stay the same.
                    elif replacement_module == imported_module_name:
                        self.bypass_import = True
                        new_names.append(
                            cst.ImportAlias(name=new_import_alias_name))
                    # NOTE(review): when neither branch above applies, the
                    # matched alias is not added to `new_names` at all;
                    # presumably the replacement import is added elsewhere
                    # (e.g. in `leave_Module`) - confirm.
                else:
                    if self.old_name.startswith(qual_name + "."):
                        # This import might be in use elsewhere in the code, so schedule a potential removal.
                        self.scheduled_removals.add(original_node)
                    new_names.append(import_alias)
        return updated_node.with_changes(names=new_names)
    # Not reached: both branches of the if/else above return.
    return updated_node
def __get_required_imports(self):
    """Build the import statements required by the applied type annotations.

    Three groups of statements are produced, in this order:

    1. ``from typing import ...`` for every word in the applied type strings
       that is also in ``PY_TYPING_MOD``;
    2. ``from collections import ...`` for every such word that is in
       ``PY_COLLECTION_MOD``;
    3. one ``import <name>`` per leading name of a dotted (attribute)
       annotation.

    Names inside each group are emitted in sorted order so that the generated
    code does not depend on set iteration order.

    Returns:
        A list of ``cst.SimpleStatementLine`` nodes; empty when nothing is
        required.
    """

    def find_required_modules(all_types):
        # Collect the first `Name` found inside every attribute expression
        # that occurs in an annotation.
        req_mod = set()
        for _, a_node in all_types:
            attributes = match.findall(
                a_node.annotation,
                match.Attribute(value=match.DoNotCare(),
                                attr=match.DoNotCare()))
            for attribute in attributes:
                found_names = match.findall(
                    attribute, match.Name(value=match.DoNotCare()))
                # Guard against an attribute without any `Name` in it instead
                # of failing with an IndexError.
                if found_names:
                    req_mod.add(found_names[0].value)
        return req_mod

    def from_import(module_name, imported_names):
        # Build `from <module_name> import <names...>` with sorted names.
        return cst.SimpleStatementLine(body=[
            cst.ImportFrom(
                module=cst.Name(value=module_name),
                names=[
                    cst.ImportAlias(name=cst.Name(value=t), asname=None)
                    for t in sorted(imported_names)
                ]),
        ])

    req_imports = []
    all_req_mods = find_required_modules(self.all_applied_types)
    # Every identifier-like word that appears in the applied type strings.
    all_type_names = set(
        chain.from_iterable(
            regex.findall(r"\w+", t[0]) for t in self.all_applied_types))

    typing_imports = PY_TYPING_MOD & all_type_names
    collection_imports = PY_COLLECTION_MOD & all_type_names

    if typing_imports:
        req_imports.append(from_import("typing", typing_imports))
    if collection_imports:
        req_imports.append(from_import("collections", collection_imports))

    for mod_name in sorted(all_req_mods):
        req_imports.append(
            cst.SimpleStatementLine(body=[
                cst.Import(names=[
                    cst.ImportAlias(name=cst.Name(value=mod_name),
                                    asname=None)
                ])
            ]))

    return req_imports
def refactor_import_star(self, updated_node: cst.ImportFrom) -> cst.ImportFrom:
    """Replace a star import with the names that are actually used.

    :param updated_node: `cst.ImportFrom` node to refactor.
    :returns: the node produced by `self._stylize` for the used aliases.
    """
    # The layout decision counts every used name, dotted ones included.
    wrap_lines = len(self._used_names) > 3

    aliases: List[cst.ImportAlias] = []
    for used_name in self._used_names:
        # Dotted names are left out to avoid name collisions.
        if "." not in used_name:
            single_line = cst.ImportAlias(
                name=cst.Name(used_name),
                comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
            )
            # With more than 3 used names every alias goes on its own line.
            aliases.append(
                self._multiline_alias(single_line) if wrap_lines else single_line
            )

    return self._stylize(updated_node, aliases, wrap_lines)
def visit_Import(self, node) -> None:
    """Record each name bound by an ``import`` statement in ``self.imprts``.

    The key is the ``as`` name when there is one, otherwise ``alias.name.value``.
    The stored value is a fresh single-alias ``cst.Import`` node.
    """
    for original in node.names:
        if original.asname is None:
            bound_name = original.name.value
        else:
            bound_name = original.asname.name.value
        # Rebuild the alias without its comma to avoid the trailing comma issue.
        comma_free = cst.ImportAlias(name=original.name, asname=original.asname)
        self.imprts[bound_name] = cst.Import(names=[comma_free])
def visit_Assign(self, node) -> None:
    """Register a single-name assignment as importable from ``self.mod``.

    Only assignments with exactly one plain-name target are considered, and
    only while ``self.toplevel`` is 0.  The entry stored in ``self.imprts`` is
    a ``from <self.mod> import <name>`` node keyed by the target name.
    """
    single_name_assign = m.Assign(targets=[m.AssignTarget(m.Name())])
    if not m.matches(node, single_name_assign):
        return
    if self.toplevel != 0:
        return

    target = node.targets[0].target
    self.imprts[target.value] = cst.ImportFrom(
        module=parse_expr(self.mod),
        names=[cst.ImportAlias(name=target, asname=None)],
    )
def test_multiline_alias(self, init, indent):
    """`_multiline_alias` indents the next line by `indent` plus 4 spaces."""
    # Make the patched initializer a no-op so the transformer can be built
    # with placeholder arguments.
    init.return_value = None
    subject = transform.ImportTransformer(None, None)
    subject._indentation = indent

    source_alias = cst.ImportAlias(name=cst.Name("x"))
    result = subject._multiline_alias(source_alias)

    expected_indent = indent + " " * 4
    assert result.comma.whitespace_after.last_line.value == expected_indent
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Insert the collected imports and top-level annotations into the module.

    Returns ``original_node`` when ``self.is_generated`` is set, and
    ``updated_node`` unchanged when there is nothing to add.  Otherwise the new
    statements are placed between the two statement groups returned by
    ``self._split_module``.
    """
    if self.is_generated:
        return original_node
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    new_statements = []

    # Locate the insertion point for the imports.
    before_imports, after_imports = self._split_module(original_node, updated_node)
    # Keep at least one empty line in front of the first non-import.
    after_imports = self._insert_empty_line(after_imports)

    # Names already bound by the import statements seen in the module.
    already_imported = set()
    for statement in self.import_statements:
        if isinstance(statement.names, cst.ImportStar):
            continue
        for alias in statement.names:
            bound = alias.asname if alias.asname else alias
            if bound:
                already_imported.add(_get_name_as_string(bound.name))

    for pending in self.imports.values():
        # Drop anything that has already been imported.
        missing = pending.names.difference(already_imported)
        aliases = [cst.ImportAlias(cst.Name(item)) for item in sorted(missing)]
        if not aliases:
            continue
        from_import = cst.ImportFrom(module=pending.module, names=aliases)
        # SimpleStatementLine needs a subscriptable body, hence the list.
        new_statements.append(cst.SimpleStatementLine([from_import]))

    for target, annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(target),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        new_statements.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(
        body=[*before_imports, *new_statements, *after_imports]
    )
def _add_annotation_to_imports(
        self, annotation: cst.Attribute) -> Union[cst.Name, cst.Attribute]:
    """Schedule an import for a dotted annotation unless its base is already imported.

    Returns the full ``annotation`` when its base (``annotation.value``) is in
    ``self.existing_imports``; otherwise registers ``annotation.attr`` through
    ``self._add_to_imports`` and returns just ``annotation.attr``.
    """
    base = annotation.value
    key = _get_attribute_as_string(base)
    if key not in self.existing_imports:
        self._add_to_imports([cst.ImportAlias(name=annotation.attr)], base, key)
        return annotation.attr
    # Existing imports are not re-imported; keep the dotted form.
    return annotation
def import_to_node_single(imp: SortableImport,
                          module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a single-line import statement.

    Comment lines from ``imp.comments.before`` become leading lines (anything
    not starting with ``#`` becomes an empty line).  All inline and per-item
    comments are joined with ``COMMENT_INDENT`` into one trailing comment.
    A ``from`` import is produced when ``imp.stem`` is truthy, a plain
    ``import`` otherwise.
    """
    leading_lines = []
    for text in imp.comments.before:
        if text.startswith("#"):
            leading_lines.append(
                cst.EmptyLine(indent=True, comment=cst.Comment(text)))
        else:
            leading_lines.append(cst.EmptyLine(indent=False))

    comment_parts = list(imp.comments.first_inline)
    names: List[cst.ImportAlias] = []
    for item in imp.items:
        name_node = name_to_node(item.name)
        if item.asname:
            asname_node = cst.AsName(name=cst.Name(item.asname))
        else:
            asname_node = None
        names.append(cst.ImportAlias(name=name_node, asname=asname_node))
        comment_parts.extend(item.comments.before)
        comment_parts.extend(item.comments.inline)
        comment_parts.extend(item.comments.following)
    comment_parts.extend(imp.comments.final)
    comment_parts.extend(imp.comments.last_inline)

    if comment_parts:
        trailing_whitespace = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(COMMENT_INDENT.join(comment_parts)))
    else:
        trailing_whitespace = cst.TrailingWhitespace()

    if imp.stem:
        stem, ndots = split_relative(imp.stem)
        module_name = name_to_node(stem) if stem else None
        statement = cst.ImportFrom(module=module_name,
                                   names=names,
                                   relative=(cst.Dot(), ) * ndots)
    else:
        statement = cst.Import(names=names)

    return cst.SimpleStatementLine(
        body=[statement],
        leading_lines=leading_lines,
        trailing_whitespace=trailing_whitespace,
    )
def _multiline_alias(self, alias: cst.ImportAlias) -> cst.ImportAlias:
    """Return a copy of ``alias`` whose comma breaks onto a new line.

    The whitespace after the comma comes from
    ``ImportTransformer._multiline_parenthesized_whitespace`` called with the
    current indentation plus ``SPACE4``.
    """
    name, asname = alias.name, alias.asname
    line_break = ImportTransformer._multiline_parenthesized_whitespace(
        self._indentation + SPACE4)
    return cst.ImportAlias(
        name=name,
        asname=asname,
        comma=cst.Comma(whitespace_after=line_break),
    )
def _create_import_from_annotation(self, returns: cst.CSTNode) -> cst.CSTNode:
    """Turn a dotted annotation into an import plus a bare-name annotation.

    When ``returns.annotation`` is a ``cst.Attribute``, its last component is
    registered through ``self._add_to_imports`` and a new ``cst.Annotation``
    holding only that component is returned.  Any other annotation is
    returned untouched.
    """
    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
    annotation = returns.annotation
    if not isinstance(annotation, cst.Attribute):
        return returns

    key = _get_attribute_as_string(annotation.value)
    self._add_to_imports(
        [cst.ImportAlias(name=annotation.attr)], annotation.value, key)
    return cst.Annotation(annotation=annotation.attr)
def leave_StarImport(
    updated_node: cst.ImportFrom,
    imp: ImportFrom,
) -> Union[cst.ImportFrom, cst.RemovalSentinel]:
    """Expand a star import to ``imp.suggestions`` or remove it entirely.

    With no suggestions the statement is removed from its parent; otherwise
    its names are replaced by one alias per suggestion.
    """
    if not imp.suggestions:
        return cst.RemoveFromParent()

    explicit_names = [
        cst.ImportAlias(cst.Name(suggestion)) for suggestion in imp.suggestions
    ]
    return updated_node.with_changes(names=explicit_names)
def leave_StarImport(self, original_node, updated_node, **kwargs):
    """Expand a star import to the names in ``imp["modules"]``, or drop it.

    ``kwargs["imp"]`` is a mapping with the keys ``"modules"`` and ``"module"``:

    * ``"modules"`` truthy: the names of ``updated_node`` are replaced by one
      alias per entry;
    * ``"modules"`` falsy and ``"module"`` truthy: the statement is removed;
    * both falsy: ``original_node`` is returned untouched.
    """
    imp = kwargs["imp"]

    if not imp["modules"]:
        if imp["module"]:
            return cst.RemoveFromParent()
        return original_node

    # The entries are joined and split again on commas, so an entry that
    # itself contains a comma yields several names.
    joined = ",".join(imp["modules"])
    suggested = [cst.ImportAlias(cst.Name(part)) for part in joined.split(",")]
    return updated_node.with_changes(names=suggested)
def leave_ImportFrom(
    self, original_node: libcst.ImportFrom, updated_node: libcst.ImportFrom
) -> libcst.ImportFrom:
    """Prepend the names scheduled for this module to an existing from-import.

    The module name of ``updated_node`` is looked up in
    ``self.module_mapping`` (plain names) and ``self.alias_mapping``
    (``(name, alias)`` pairs).  Matching entries are removed from both
    mappings and their names are placed in front of the names the statement
    already imports.

    Returns ``updated_node`` unchanged for relative imports, star imports and
    modules with nothing scheduled.
    """
    if len(updated_node.relative) > 0 or updated_node.module is None:
        # Don't support relative-only imports at the moment.
        return updated_node

    if isinstance(updated_node.names, libcst.ImportStar):
        # There's nothing to do here!  (The star is represented by an
        # ImportStar node, so comparing `names` with the string "*" never
        # matched and star imports fell through to the unpacking below.)
        return updated_node

    # Get the module we're importing as a string, see if we have work to do.
    module = self._get_string_name(updated_node.module)
    if module not in self.module_mapping and module not in self.alias_mapping:
        return updated_node

    # We have work to do, mark that we won't modify this again.
    imports_to_add = self.module_mapping.pop(module, [])
    aliases_to_add = self.alias_mapping.pop(module, [])

    # Now, do the actual update.
    return updated_node.with_changes(names=(
        *[
            libcst.ImportAlias(name=libcst.Name(imp))
            for imp in imports_to_add
        ],
        *[
            libcst.ImportAlias(
                name=libcst.Name(imp),
                asname=libcst.AsName(name=libcst.Name(alias)),
            ) for (imp, alias) in aliases_to_add
        ],
        *updated_node.names,
    ))
def leave_Import(self, original_node: cst.Import,
                 updated_node: cst.Import) -> cst.Import:
    """Add the replacement module next to any import that prefixes ``self.old_name``.

    Every alias of the statement is kept.  When ``self.old_name`` starts with
    the alias' full name followed by a dot, the statement is added to
    ``self.scheduled_removals``, an alias for the replacement module is
    appended right after the existing one, and ``self.bypass_import`` is set.

    Raises:
        Exception: if the full name of an alias cannot be determined.
    """
    new_names = []
    for import_alias in updated_node.names:
        import_alias_name = import_alias.name
        import_alias_full_name = get_full_name_for_node(import_alias_name)
        if import_alias_full_name is None:
            raise Exception(
                "Could not parse full name for ImportAlias.name node.")

        # The existing alias always stays in place.
        new_names.append(import_alias)

        is_prefix_of_old_name = isinstance(
            import_alias_name,
            (cst.Name, cst.Attribute)) and self.old_name.startswith(
                import_alias_full_name + ".")
        if not is_prefix_of_old_name:
            continue

        # Might be in use elsewhere in the code, so schedule a potential
        # removal, and add another alias.
        self.scheduled_removals.add(original_node)
        # Build the node with `gen_name_or_attr_node` for plain and dotted
        # aliases alike: wrapping the replacement in `cst.Name` produced an
        # invalid identifier whenever the replacement module was dotted.
        # NOTE(review): assumes `gen_name_or_attr_node` returns a `cst.Name`
        # for an undotted string, as its Union return type suggests - confirm.
        new_name_node: Union[
            cst.Attribute, cst.Name] = self.gen_name_or_attr_node(
                self.gen_replacement_module(import_alias_full_name))
        new_names.append(cst.ImportAlias(name=new_name_node))
        self.bypass_import = True

    return updated_node.with_changes(names=new_names)
def visit_ImportFrom(self, node) -> None:
    """Record each name bound by a ``from ... import`` statement in ``self.imprts``.

    The key is the ``as`` name when there is one, otherwise ``alias.name.value``.
    The stored value is a fresh single-alias ``cst.ImportFrom`` node.  For a
    relative import the module is rebuilt as a dotted path derived from
    ``self.mod``, so the stored node has no leading dots.

    Star imports bind no explicit names and are skipped.
    """
    if isinstance(node.names, cst.ImportStar):
        # `names` is a single ImportStar node here, not a sequence of aliases.
        return

    # The module path is the same for every alias, so work it out once.
    level = len(node.relative)
    absolute_source = None
    if level > 0:
        parts = self.mod.split('.')
        # NOTE(review): if `level` is >= the number of parts this yields an
        # empty prefix - confirm whether that can happen for real inputs.
        mod_level = '.'.join(parts[:-level]) if len(parts) > 1 else parts[0]
        if node.module is not None:
            absolute_source = f'{mod_level}.{a2s(node.module)}'
        else:
            absolute_source = mod_level

    for alias in node.names:
        name = (alias.asname.name.value
                if alias.asname is not None else alias.name.value)
        if absolute_source is not None:
            # Parse per alias so every stored statement owns its module node.
            module = parse_expr(absolute_source)
        else:
            module = node.module
        # Regenerate alias to avoid trailing comma issue
        alias = cst.ImportAlias(name=alias.name, asname=alias.asname)
        self.imprts[name] = cst.ImportFrom(module=module, names=[alias])
def leave_ImportFrom(
    self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
    """Merge pending imports for this module into an existing from-import.

    The statement is always appended to ``self.import_statements``.  When
    ``self.imports`` has an entry for the statement's module, the entry's
    names are unioned with the names already imported, the statement's names
    are replaced by the sorted union, and the entry is deleted.

    Star imports and imports without a module are returned unchanged.
    """
    self.import_statements.append(original_node)

    module = original_node.module
    import_names = updated_node.names
    if module is None or isinstance(import_names, cst.ImportStar):
        return updated_node

    # Resolve the key only once a module is known to exist, and use the same
    # key for the membership test as for the lookup.  Testing `module.value`
    # could never match a dotted module, whose `value` is a node.
    key = _get_attribute_as_string(module)
    if key not in self.imports:
        return updated_node

    names_as_string = {_get_name_as_string(name.name) for name in import_names}
    updated_names = self.imports[key].names.union(names_as_string)
    # NOTE(review): the aliases are rebuilt from bare names, so any `as`
    # alias on the existing statement is not carried over - confirm intended.
    names = [cst.ImportAlias(cst.Name(name)) for name in sorted(updated_names)]
    del self.imports[key]
    return updated_node.with_changes(names=tuple(names))
def import_to_node_multi(imp: SortableImport,
                         module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a parenthesized from-import with one item per line.

    Comments attached to an item are placed in the whitespace of the
    preceding comma (or of the opening parenthesis for the first item), and
    every comment line is indented with ``module.default_indent``.

    Raises:
        ValueError: if ``imp.stem`` is falsy, i.e. the import is not a
            ``from ... import ...`` statement.
    """
    body: List[cst.BaseSmallStatement] = []
    names: List[cst.ImportAlias] = []
    prev: Optional[cst.ImportAlias] = None
    following: List[str] = []
    lpar_lines: List[cst.EmptyLine] = []
    lpar_inline: cst.TrailingWhitespace = cst.TrailingWhitespace()

    item_count = len(imp.items)
    for idx, item in enumerate(imp.items):
        name = name_to_node(item.name)
        asname = cst.AsName(
            name=cst.Name(item.asname)) if item.asname else None

        # Leading comments actually have to be trailing comments on the previous node.
        # That means putting them on the lpar node for the first item
        if item.comments.before:
            lines = [
                cst.EmptyLine(
                    indent=True,
                    comment=cst.Comment(c),
                    whitespace=cst.SimpleWhitespace(module.default_indent),
                ) for c in item.comments.before
            ]
            if prev is None:
                lpar_lines.extend(lines)
            else:
                # Extends, in place, the list that was passed as
                # `empty_lines` when the previous alias was built.
                prev.comma.whitespace_after.empty_lines.extend(
                    lines)  # type: ignore

        # all items except the last needs whitespace to indent the *next* line/item
        indent = idx != (len(imp.items) - 1)

        first_line = cst.TrailingWhitespace()
        inline = COMMENT_INDENT.join(item.comments.inline)
        if inline:
            first_line = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(inline),
            )

        # The last item also carries the import's final comments.
        if idx == item_count - 1:
            following = item.comments.following + imp.comments.final
        else:
            following = item.comments.following

        after = cst.ParenthesizedWhitespace(
            indent=True,
            first_line=first_line,
            empty_lines=[
                cst.EmptyLine(
                    indent=True,
                    comment=cst.Comment(c),
                    whitespace=cst.SimpleWhitespace(module.default_indent),
                ) for c in following
            ],
            last_line=cst.SimpleWhitespace(
                module.default_indent if indent else ""),
        )

        node = cst.ImportAlias(
            name=name,
            asname=asname,
            comma=cst.Comma(whitespace_after=after),
        )
        names.append(node)
        prev = node

    # from foo import (
    #     bar
    # )
    if imp.stem:
        stem, ndots = split_relative(imp.stem)
        if not stem:
            module_name = None
        else:
            module_name = name_to_node(stem)
        relative = (cst.Dot(), ) * ndots

        # inline comment following lparen
        if imp.comments.first_inline:
            inline = COMMENT_INDENT.join(imp.comments.first_inline)
            lpar_inline = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(inline),
            )

        body = [
            cst.ImportFrom(
                module=module_name,
                names=names,
                relative=relative,
                lpar=cst.LeftParen(
                    whitespace_after=cst.ParenthesizedWhitespace(
                        indent=True,
                        first_line=lpar_inline,
                        empty_lines=lpar_lines,
                        last_line=cst.SimpleWhitespace(module.default_indent),
                    ), ),
                rpar=cst.RightParen(),
            )
        ]

    # import foo
    else:
        raise ValueError("can't render basic imports on multiple lines")

    # comment lines above import
    leading_lines = [
        cst.EmptyLine(indent=True, comment=cst.Comment(line))
        if line.startswith("#") else cst.EmptyLine(indent=False)
        for line in imp.comments.before
    ]

    # inline comments following import/rparen
    if imp.comments.last_inline:
        inline = COMMENT_INDENT.join(imp.comments.last_inline)
        trailing = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(inline))
    else:
        trailing = cst.TrailingWhitespace()

    return cst.SimpleStatementLine(
        body=body,
        leading_lines=leading_lines,
        trailing_whitespace=trailing,
    )
def make_simple_package_import(package: str) -> cst.Import:
    """Build an ``import <package>`` node for a root (undotted) package."""
    assert "." not in package, "this only supports a root package, e.g. 'import os'"
    alias = cst.ImportAlias(name=cst.Name(package))
    return cst.Import(names=[alias])
class ImportCreateTest(CSTNodeTest):
    """Checks the code produced by hand-built ``cst.Import`` nodes.

    ``test_valid`` passes each case to ``validate_node``; ``test_invalid``
    expects node construction to be rejected with a message matching
    ``expected_re``.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # NOTE(review): this case is identical to the previous one.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            # NOTE(review): the expected end column (46) is larger than the
            # length of the code string as it appears here; the whitespace
            # literals in this case may have lost repeated spaces - confirm.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(" "),
                ),
                "code": "import foo . bar as baz , unittest as ut",
                "expected_position": CodeRange((1, 0), (1, 46)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one hand-built node against its expected code."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # No names at all
            {
                "get_node": lambda: cst.Import(names=()),
                "expected_re": "at least one ImportAlias",
            },
            # Empty identifiers, alone or inside a dotted name
            {
                "get_node": lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            # Comma on the last alias
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(),
                        ),
                    )
                ),
                "expected_re": "trailing comma",
            },
            # No whitespace after the `import` keyword
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Assert that building the node fails with the expected message."""
        self.assert_invalid(**kwargs)
def _add_annotation_to_imports(self, annotation: cst.Attribute) -> cst.Name:
    """Schedule an import for a dotted annotation and return its last component.

    ``annotation.attr`` is registered through ``self._add_to_imports`` under
    the string form of ``annotation.value``.
    """
    base = annotation.value
    imported_name = annotation.attr
    key = _get_attribute_as_string(base)
    self._add_to_imports([cst.ImportAlias(name=imported_name)], base, key)
    return imported_name
class StatementTest(UnitTest):
    """Tests for absolute-module resolution and the ``ImportAlias`` helpers."""

    @data_provider(
        (
            # Simple imports that are already absolute.
            (None, "from a.b import c", "a.b"),
            ("x.y.z", "from a.b import c", "a.b"),
            # Relative import that can't be resolved due to missing module.
            (None, "from ..w import c", None),
            # Relative import that goes past the module level.
            ("x", "from ...y import z", None),
            ("x.y.z", "from .....w import c", None),
            ("x.y.z", "from ... import c", None),
            # Correct resolution of absolute from relative modules.
            ("x.y.z", "from . import c", "x.y"),
            ("x.y.z", "from .. import c", "x"),
            ("x.y.z", "from .w import c", "x.y.w"),
            ("x.y.z", "from ..w import c", "x.w"),
            ("x.y.z", "from ...w import c", "w"),
        )
    )
    def test_get_absolute_module(
        self,
        module: Optional[str],
        importfrom: str,
        output: Optional[str],
    ) -> None:
        """Both resolvers agree with ``output``; the raising one raises on None."""
        statement = ensure_type(
            cst.parse_statement(importfrom), cst.SimpleStatementLine
        )
        assert len(statement.body) == 1, "Unexpected number of statements!"
        import_from = ensure_type(statement.body[0], cst.ImportFrom)

        resolved = get_absolute_module_for_import(module, import_from)
        self.assertEqual(resolved, output)

        if output is not None:
            self.assertEqual(
                get_absolute_module_for_import_or_raise(module, import_from),
                output,
            )
        else:
            with self.assertRaises(Exception):
                get_absolute_module_for_import_or_raise(module, import_from)

    @data_provider(
        (
            # Nodes without an asname
            (cst.ImportAlias(name=cst.Name("foo")), "foo", None),
            (
                cst.ImportAlias(name=cst.Attribute(cst.Name("foo"), cst.Name("bar"))),
                "foo.bar",
                None,
            ),
            # Nodes with an asname
            (
                cst.ImportAlias(
                    name=cst.Name("foo"), asname=cst.AsName(name=cst.Name("baz"))
                ),
                "foo",
                "baz",
            ),
            (
                cst.ImportAlias(
                    name=cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                    asname=cst.AsName(name=cst.Name("baz")),
                ),
                "foo.bar",
                "baz",
            ),
        )
    )
    def test_importalias_helpers(
        self, alias_node: cst.ImportAlias, full_name: str, alias: Optional[str]
    ) -> None:
        """``evaluated_name`` and ``evaluated_alias`` match the expected values."""
        actual_name = alias_node.evaluated_name
        self.assertEqual(actual_name, full_name)
        actual_alias = alias_node.evaluated_alias
        self.assertEqual(actual_alias, alias)
class ImportParseTest(CSTNodeTest):
    """Checks that parsing ``import`` statements yields the expected nodes.

    Each case is validated with a parser that parses the code as a statement
    and takes the first element of the resulting ``SimpleStatementLine`` body.
    Aliases followed by another alias spell out their comma explicitly.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # NOTE(review): this case is identical to the previous one.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            # NOTE(review): the whitespace literals in this case may have lost
            # repeated spaces during extraction - confirm against the
            # canonical test file.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(" "),
                ),
                "code": "import foo . bar as baz , unittest as ut",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one case, parsing the code as a single small statement."""
        self.validate_node(
            parser=lambda code: ensure_type(
                parse_statement(code), cst.SimpleStatementLine
            ).body[0],
            **kwargs,
        )
class ImportFromCreateTest(CSTNodeTest):
    """Checks the code produced by hand-built ``cst.ImportFrom`` nodes.

    ``test_valid`` passes each case to ``validate_node``; ``test_invalid``
    expects node construction to be rejected with a message matching
    ``expected_re``.
    """

    @data_provider(
        (
            # Simple from import statement
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),)
                ),
                "code": "from foo import bar",
            },
            # From import statement with alias
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"), asname=cst.AsName(cst.Name("baz"))
                        ),
                    ),
                ),
                "code": "from foo import bar as baz",
            },
            # Multiple imports
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(cst.Name("bar")),
                        cst.ImportAlias(cst.Name("baz")),
                    ),
                ),
                "code": "from foo import bar, baz",
            },
            # Trailing comma
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(cst.Name("bar"), comma=cst.Comma()),
                        cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()),
                    ),
                ),
                "code": "from foo import bar,baz,",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Star import statement
            {
                "node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
                "code": "from foo import *",
                "expected_position": CodeRange((1, 0), (1, 17)),
            },
            # Simple relative import statement
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(),),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from .foo import bar",
            },
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from ..foo import bar",
            },
            # Relative only import
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=None,
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from .. import bar",
            },
            # Parenthesis
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"), asname=cst.AsName(cst.Name("baz"))
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "from foo import (bar as baz)",
                "expected_position": CodeRange((1, 0), (1, 28)),
            },
            # Verify whitespace works everywhere.
            # NOTE(review): the expected end column (61) is larger than the
            # length of the code string as it appears here; the whitespace
            # literals in this case may have lost repeated spaces - confirm.
            {
                "node": cst.ImportFrom(
                    relative=(
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                    whitespace_before_import=cst.SimpleWhitespace(" "),
                    whitespace_after_import=cst.SimpleWhitespace(" "),
                ),
                "code": "from . . foo import ( bar as baz , unittest as ut )",
                "expected_position": CodeRange((1, 0), (1, 61)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one hand-built node against its expected code."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Neither a module nor relative dots
            {
                "get_node": lambda: cst.ImportFrom(
                    module=None, names=(cst.ImportAlias(cst.Name("bar")),)
                ),
                "expected_re": "Must have a module specified",
            },
            # No names at all
            {
                "get_node": lambda: cst.ImportFrom(module=cst.Name("foo"), names=()),
                "expected_re": "at least one ImportAlias",
            },
            # Unbalanced parentheses
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "right paren without left paren",
            },
            # Parentheses around a star import
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"), names=cst.ImportStar(), lpar=cst.LeftParen()
                ),
                "expected_re": "cannot have parens",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=cst.ImportStar(),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "cannot have parens",
            },
            # Missing whitespace around the keywords
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space after from",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    whitespace_before_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space before import",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space after import",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Assert that building the node fails with the expected message."""
        self.assert_invalid(**kwargs)
class ImportFromParseTest(CSTNodeTest):
    """Checks that parsing ``from ... import`` statements yields the expected nodes.

    Each case is validated with a parser that parses the code as a statement
    and takes the first element of the resulting ``SimpleStatementLine`` body.
    Aliases followed by another alias spell out their comma explicitly.
    """

    @data_provider(
        (
            # Simple from import statement
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),)
                ),
                "code": "from foo import bar",
            },
            # From import statement with alias
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"), asname=cst.AsName(cst.Name("baz"))
                        ),
                    ),
                ),
                "code": "from foo import bar as baz",
            },
            # Multiple imports
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(cst.Name("baz")),
                    ),
                ),
                "code": "from foo import bar, baz",
            },
            # Trailing comma
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()),
                    ),
                ),
                "code": "from foo import bar, baz,",
            },
            # Star import statement
            {
                "node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
                "code": "from foo import *",
            },
            # Simple relative import statement
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(),),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from .foo import bar",
            },
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from ..foo import bar",
            },
            # Relative only import
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=None,
                    names=(cst.ImportAlias(cst.Name("bar")),),
                ),
                "code": "from .. import bar",
            },
            # Parenthesis
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"), asname=cst.AsName(cst.Name("baz"))
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "from foo import (bar as baz)",
            },
            # Verify whitespace works everywhere.
            # NOTE(review): the whitespace literals in this case may have lost
            # repeated spaces during extraction - confirm against the
            # canonical test file.
            {
                "node": cst.ImportFrom(
                    relative=(
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    names=(
                        cst.ImportAlias(
                            cst.Name("bar"),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                    whitespace_before_import=cst.SimpleWhitespace(" "),
                    whitespace_after_import=cst.SimpleWhitespace(" "),
                ),
                "code": "from . . foo import ( bar as baz , unittest as ut )",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one case, parsing the code as a single small statement."""
        self.validate_node(
            parser=lambda code: ensure_type(
                parse_statement(code), cst.SimpleStatementLine
            ).body[0],
            **kwargs,
        )