Ejemplo n.º 1
0
    def leave_ImportFrom(self, original_node: libcst.ImportFrom,
                         updated_node: libcst.ImportFrom) -> libcst.ImportFrom:
        """Prepend pending imports for this module onto a ``from`` import.

        Star imports are returned untouched. Otherwise the imported module is
        resolved to an absolute dotted name via
        ``get_absolute_module_for_import``. If ``self.module_mapping`` or
        ``self.alias_mapping`` has entries for it, those entries are removed
        from the mappings (so the work happens only once). They are then emitted,
        sorted, as ``ImportAlias`` nodes placed before the existing names.
        """
        names = updated_node.names
        if isinstance(names, libcst.ImportStar):
            # Nothing can be merged into `from x import *`.
            return updated_node

        module = get_absolute_module_for_import(self.context.full_module_name,
                                                updated_node)
        if module is None:
            return updated_node
        if module not in self.module_mapping and module not in self.alias_mapping:
            return updated_node

        # Consume the pending entries so this module is not rewritten again.
        plain = self.module_mapping.pop(module, [])
        aliased = self.alias_mapping.pop(module, [])

        new_aliases = [
            libcst.ImportAlias(name=libcst.Name(imp)) for imp in sorted(plain)
        ]
        new_aliases.extend(
            libcst.ImportAlias(
                name=libcst.Name(imp),
                asname=libcst.AsName(name=libcst.Name(alias)),
            )
            for imp, alias in sorted(aliased)
        )
        new_aliases.extend(names)
        return updated_node.with_changes(names=new_aliases)
Ejemplo n.º 2
0
    def leave_ImportFrom(self, original_node: cst.ImportFrom,
                         updated_node: cst.ImportFrom) -> cst.ImportFrom:
        """Rewrite a ``from module import name`` that imports ``self.old_name``.

        For each alias whose qualified name equals ``self.old_name``:

        * if ``self.gen_replacement`` yields nothing, the alias is kept and the
          statement is scheduled for potential removal (the comment in the
          original code says ``leave_Module`` handles this case);
        * if it is the only imported name, the whole statement is rewritten to
          import the replacement from the replacement module;
        * if the module stays the same, the alias is renamed in place;
        * otherwise the alias is dropped from this statement. Presumably the new
          import is added later, since ``bypass_import`` stays unset; confirm
          against ``leave_Module``.

        Aliases that are a parent of ``self.old_name`` are kept and schedule the
        statement for potential removal. All other aliases are kept as-is.
        """
        module = updated_node.module
        if module is None:
            return updated_node
        imported_module_name = get_full_name_for_node(module)
        names = original_node.names

        # Unresolvable module names and `import *` (not a Sequence) are left alone.
        if imported_module_name is None or not isinstance(names, Sequence):
            return updated_node

        new_names = []
        for import_alias in names:
            alias_name = get_full_name_for_node(import_alias.name)
            if alias_name is None:
                # Matches the previous behavior: unresolvable aliases are skipped.
                continue
            qual_name = f"{imported_module_name}.{alias_name}"

            if self.old_name != qual_name:
                if self.old_name.startswith(qual_name + "."):
                    # This import might be in use elsewhere in the code, so schedule a potential removal.
                    self.scheduled_removals.add(original_node)
                new_names.append(import_alias)
                continue

            replacement_module = self.gen_replacement_module(
                imported_module_name)
            replacement_obj = self.gen_replacement(alias_name)
            if not replacement_obj:
                # The user has requested an `import` statement rather than an `from ... import`.
                # This will be taken care of in `leave_Module`, in the meantime, schedule for potential removal.
                new_names.append(import_alias)
                self.scheduled_removals.add(original_node)
                continue

            new_import_alias_name: Union[
                cst.Attribute,
                cst.Name] = self.gen_name_or_attr_node(replacement_obj)
            # Rename on the spot only if this is the only imported name under the module.
            if len(names) == 1:
                self.bypass_import = True
                return updated_node.with_changes(
                    module=cst.parse_expression(replacement_module),
                    names=(cst.ImportAlias(name=new_import_alias_name), ),
                )
            # Or if the module name is to stay the same.
            if replacement_module == imported_module_name:
                self.bypass_import = True
                new_names.append(cst.ImportAlias(name=new_import_alias_name))
            # Otherwise the alias is intentionally dropped from this statement.

        # Dropping an alias can leave the new last alias holding the comma that
        # used to separate it from its successor. libcst rejects a trailing comma
        # in a `from` import unless the names are parenthesized.
        if (new_names and updated_node.rpar is None
                and new_names[-1].comma != cst.MaybeSentinel.DEFAULT):
            new_names[-1] = new_names[-1].with_changes(
                comma=cst.MaybeSentinel.DEFAULT)
        return updated_node.with_changes(names=new_names)
Ejemplo n.º 3
0
    def __get_required_imports(self):
        """Build the import statements needed by ``self.all_applied_types``.

        Returns a list of ``cst.SimpleStatementLine`` nodes, in this order:

        * a ``from typing import ...`` line for names in ``PY_TYPING_MOD``;
        * a ``from collections import ...`` line for names in ``PY_COLLECTION_MOD``;
        * one ``import <mod>`` line per module referenced through an attribute
          annotation.

        Each group is emitted only when non-empty. Names are sorted so that the
        generated code does not depend on set iteration order, which varies
        between interpreter runs.
        """
        def find_required_modules(all_types):
            # For every Attribute node inside each annotation, record the first
            # Name that match.findall returns for it (presumably the leftmost
            # module name, e.g. `np` in `np.ndarray` -- confirm).
            req_mod = set()
            for _, a_node in all_types:
                attributes = match.findall(
                    a_node.annotation,
                    match.Attribute(value=match.DoNotCare(),
                                    attr=match.DoNotCare()))
                for attribute in attributes:
                    first_name = match.findall(
                        attribute, match.Name(value=match.DoNotCare()))[0]
                    req_mod.add(first_name.value)
            return req_mod

        def from_import(module_name, imported_names):
            # One `from <module_name> import a, b, ...` line, names sorted.
            return cst.SimpleStatementLine(body=[
                cst.ImportFrom(module=cst.Name(value=module_name),
                               names=[
                                   cst.ImportAlias(name=cst.Name(value=t),
                                                   asname=None)
                                   for t in sorted(imported_names)
                               ]),
            ])

        all_req_mods = find_required_modules(self.all_applied_types)
        # Every identifier-like token appearing in the applied type strings.
        all_type_names = set(
            chain.from_iterable(
                regex.findall(r"\w+", t[0]) for t in self.all_applied_types))

        typing_imports = PY_TYPING_MOD & all_type_names
        collection_imports = PY_COLLECTION_MOD & all_type_names

        req_imports = []
        if typing_imports:
            req_imports.append(from_import("typing", typing_imports))
        if collection_imports:
            req_imports.append(from_import("collections", collection_imports))
        for mod_name in sorted(all_req_mods):
            req_imports.append(
                cst.SimpleStatementLine(body=[
                    cst.Import(names=[
                        cst.ImportAlias(name=cst.Name(value=mod_name),
                                        asname=None)
                    ])
                ]))

        return req_imports
Ejemplo n.º 4
0
    def refactor_import_star(self,
                             updated_node: cst.ImportFrom) -> cst.ImportFrom:
        """Replace a star import with aliases for the names actually used.

        Dotted names in ``self._used_names`` are skipped. When more than three
        names are used (counting skipped dotted ones), each alias is converted
        with ``self._multiline_alias``. The final layout is delegated to
        ``self._stylize``.

        :param updated_node: `cst.ImportFrom` node to refactor.
        :returns: refactored node.
        """
        multiline = len(self._used_names) > 3
        aliases: List[cst.ImportAlias] = []
        for used in self._used_names:
            if "." in used:
                # Dotted names are left out to avoid name collisions.
                continue
            # Start from a single-line alias: `name, `.
            alias = cst.ImportAlias(
                name=cst.Name(used),
                comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
            )
            aliases.append(self._multiline_alias(alias) if multiline else alias)
        return self._stylize(updated_node, aliases, multiline)
Ejemplo n.º 5
0
    def visit_Import(self, node) -> None:
        """Record each name bound by an ``import`` statement.

        ``self.imprts`` maps the bound name to a standalone ``cst.Import`` that
        recreates that single alias. For ``import a.b as c`` the bound name is
        ``c``. For ``import a.b.c`` Python binds only the root package ``a``, so
        that is used as the key.
        """
        for alias in node.names:
            if alias.asname is not None:
                name = alias.asname.name.value
            else:
                # `a.b.c` is nested Attribute nodes; walk down to the root Name.
                # (Attribute.value is a node, not a string, so the old
                # `alias.name.value` produced a node key for dotted imports.)
                root = alias.name
                while isinstance(root, cst.Attribute):
                    root = root.value
                name = root.value

            # Regenerate alias to avoid trailing comma issue
            alias = cst.ImportAlias(name=alias.name, asname=alias.asname)
            self.imprts[name] = cst.Import(names=[alias])
Ejemplo n.º 6
0
 def visit_Assign(self, node) -> None:
     """Record a top-level ``name = ...`` assignment as an importable name.

     Only assignments with a single plain-Name target, and only while
     ``self.toplevel == 0``, are recorded. ``self.imprts[name]`` becomes
     ``from <self.mod> import <name>``.
     """
     single_name_target = m.Assign(targets=[m.AssignTarget(m.Name())])
     if not m.matches(node, single_name_target) or self.toplevel != 0:
         return
     target = node.targets[0].target
     self.imprts[target.value] = cst.ImportFrom(
         module=parse_expr(self.mod),
         names=[cst.ImportAlias(name=target, asname=None)],
     )
Ejemplo n.º 7
0
 def test_multiline_alias(self, init, indent):
     """`_multiline_alias` indents the continuation line by `indent` + 4 spaces."""
     # `init` is presumably a mock of the transformer's __init__ (see decorators).
     init.return_value = None
     transformer = transform.ImportTransformer(None, None)
     transformer._indentation = indent
     source_alias = cst.ImportAlias(name=cst.Name("x"))
     result = transformer._multiline_alias(source_alias)
     expected = indent + " " * 4
     assert result.comma.whitespace_after.last_line.value == expected
Ejemplo n.º 8
0
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Insert collected imports and top-level annotations into the module.

        Generated files are returned as ``original_node``. When there is nothing
        to add, ``updated_node`` is returned unchanged. Otherwise:

        * for each entry of ``self.imports``, a ``from <module> import ...`` line
          is emitted for the names (sorted) that are not already bound by a
          statement in ``self.import_statements``;
        * for each entry of ``self.toplevel_annotations``, a bare
          ``name: annotation`` declaration is emitted.

        The new statements are placed at the split point returned by
        ``self._split_module``.
        """
        if self.is_generated:
            return original_node
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        # First, find the insertion point for imports.
        before, after = self._split_module(original_node, updated_node)
        # Make sure there's at least one empty line before the first non-import.
        after = self._insert_empty_line(after)

        # Names already bound by existing `from` imports (alias name if present).
        already_imported = set()
        for statement in self.import_statements:
            if isinstance(statement.names, cst.ImportStar):
                continue
            for alias in statement.names:
                bound = alias.asname if alias.asname else alias
                if bound:
                    already_imported.add(_get_name_as_string(bound.name))

        new_statements = []
        for import_statement in self.imports.values():
            # Filter out anything that has already been imported.
            missing = sorted(import_statement.names.difference(already_imported))
            if not missing:
                continue
            new_import = cst.ImportFrom(
                module=import_statement.module,
                names=[cst.ImportAlias(cst.Name(n)) for n in missing],
            )
            # SimpleStatementLine needs a subscriptable body, hence the list.
            new_statements.append(cst.SimpleStatementLine([new_import]))

        for name, annotation in self.toplevel_annotations.items():
            declaration = cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
            new_statements.append(cst.SimpleStatementLine([declaration]))

        return updated_node.with_changes(
            body=[*before, *new_statements, *after]
        )
Ejemplo n.º 9
0
 def _add_annotation_to_imports(
         self, annotation: cst.Attribute) -> Union[cst.Name, cst.Attribute]:
     """Shorten ``owner.attr`` to ``attr`` by scheduling an import of ``attr``.

     The alias for ``attr`` is passed to ``self._add_to_imports`` together with
     the owner node and its dotted-string key (presumably to produce
     ``from owner import attr``). If the key is already in
     ``self.existing_imports``, nothing is added and ``annotation`` is returned
     unchanged.
     """
     key = _get_attribute_as_string(annotation.value)
     if key not in self.existing_imports:
         self._add_to_imports([cst.ImportAlias(name=annotation.attr)],
                              annotation.value, key)
         return annotation.attr
     # Don't attempt to re-import existing imports.
     return annotation
Ejemplo n.º 10
0
def import_to_node_single(imp: SortableImport,
                          module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a one-line ``import`` / ``from ... import`` statement.

    ``imp.comments.before`` entries become leading lines: comment strings are
    kept as comment lines, other entries become blank lines. All remaining
    comments (first inline, per-item before/inline/following, final, last
    inline) are joined with ``COMMENT_INDENT`` into a single trailing comment.
    ``module`` is accepted for signature parity with the multi-line renderer
    but is not used here.
    """
    leading_lines = []
    for line in imp.comments.before:
        if line.startswith("#"):
            leading_lines.append(
                cst.EmptyLine(indent=True, comment=cst.Comment(line)))
        else:
            leading_lines.append(cst.EmptyLine(indent=False))

    names: List[cst.ImportAlias] = []
    comments = list(imp.comments.first_inline)
    for item in imp.items:
        alias_name = name_to_node(item.name)
        alias_as = cst.AsName(name=cst.Name(item.asname)) if item.asname else None
        names.append(cst.ImportAlias(name=alias_name, asname=alias_as))
        comments.extend(item.comments.before)
        comments.extend(item.comments.inline)
        comments.extend(item.comments.following)
    comments.extend(imp.comments.final)
    comments.extend(imp.comments.last_inline)

    if comments:
        trailing_whitespace = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(COMMENT_INDENT.join(comments)))
    else:
        trailing_whitespace = cst.TrailingWhitespace()

    if not imp.stem:
        statement = cst.Import(names=names)
    else:
        # `from <stem> import ...`; leading dots become `relative`.
        stem, ndots = split_relative(imp.stem)
        statement = cst.ImportFrom(
            module=name_to_node(stem) if stem else None,
            names=names,
            relative=(cst.Dot(), ) * ndots,
        )

    return cst.SimpleStatementLine(
        body=[statement],
        leading_lines=leading_lines,
        trailing_whitespace=trailing_whitespace,
    )
Ejemplo n.º 11
0
 def _multiline_alias(self, alias: cst.ImportAlias) -> cst.ImportAlias:
     """Return a copy of ``alias`` whose comma is followed by a line break.

     The continuation line is indented by ``self._indentation + SPACE4``. Any
     comma already on ``alias`` is replaced.
     """
     continuation = ImportTransformer._multiline_parenthesized_whitespace(
         self._indentation + SPACE4)
     new_comma = cst.Comma(whitespace_after=continuation)
     return cst.ImportAlias(name=alias.name, asname=alias.asname, comma=new_comma)
Ejemplo n.º 12
0
 def _create_import_from_annotation(self,
                                    returns: cst.CSTNode) -> cst.CSTNode:
     """Shorten a dotted return annotation and schedule the matching import.

     If ``returns.annotation`` is an Attribute (``owner.attr``), an alias for
     ``attr`` is passed to ``self._add_to_imports`` and a new Annotation of just
     ``attr`` is returned. Any other annotation is returned unchanged.
     """
     # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
     annotation = returns.annotation
     if not isinstance(annotation, cst.Attribute):
         return returns
     key = _get_attribute_as_string(annotation.value)
     self._add_to_imports([cst.ImportAlias(name=annotation.attr)],
                          annotation.value, key)
     return cst.Annotation(annotation=annotation.attr)
Ejemplo n.º 13
0
 def leave_StarImport(
     updated_node: cst.ImportFrom,
     imp: ImportFrom,
 ) -> Union[cst.ImportFrom, cst.RemovalSentinel]:
     """Replace ``from x import *`` with ``imp.suggestions``, or remove it.

     Each suggestion becomes a plain ``ImportAlias``. With no suggestions, the
     statement is removed from its parent.
     """
     if not imp.suggestions:
         return cst.RemoveFromParent()
     aliases = []
     for suggestion in imp.suggestions:
         aliases.append(cst.ImportAlias(cst.Name(suggestion)))
     return updated_node.with_changes(names=aliases)
Ejemplo n.º 14
0
 def leave_StarImport(self, original_node, updated_node, **kwargs):
     """Replace a star import with the names listed in ``kwargs["imp"]``.

     ``imp["modules"]`` entries are joined and re-split on commas, so an entry
     such as ``"a,b"`` yields two names. With no modules, the statement is
     removed when ``imp["module"]`` is truthy. Otherwise ``original_node`` is
     returned.
     """
     imp = kwargs["imp"]
     if not imp["modules"]:
         return cst.RemoveFromParent() if imp["module"] else original_node
     names = [
         cst.ImportAlias(cst.Name(part))
         for part in ",".join(imp["modules"]).split(",")
     ]
     return updated_node.with_changes(names=names)
Ejemplo n.º 15
0
    def leave_ImportFrom(self, original_node: libcst.ImportFrom,
                         updated_node: libcst.ImportFrom) -> libcst.ImportFrom:
        """Prepend pending imports for this module onto a ``from`` import.

        Relative imports, imports without a module, and star imports are
        returned untouched. Otherwise, if ``self.module_mapping`` or
        ``self.alias_mapping`` has entries for the module's dotted name, those
        entries are removed from the mappings (so the work happens only once).
        They are then emitted, sorted, as ``ImportAlias`` nodes before the
        existing names.
        """
        if len(updated_node.relative) > 0 or updated_node.module is None:
            # Don't support relative imports at the moment.
            return updated_node
        if isinstance(updated_node.names, libcst.ImportStar):
            # `names` is an ImportStar node (never the string "*"); there is no
            # alias list to extend, and unpacking it below would raise.
            return updated_node

        # Get the module we're importing as a string, see if we have work to do.
        module = self._get_string_name(updated_node.module)
        if module not in self.module_mapping and module not in self.alias_mapping:
            return updated_node

        # We have work to do; consume the entries so we won't modify this again.
        imports_to_add = self.module_mapping.pop(module, [])
        aliases_to_add = self.alias_mapping.pop(module, [])

        # Now, do the actual update. Sorted for deterministic output, matching
        # the absolute-module variant of this visitor.
        return updated_node.with_changes(names=(
            *(libcst.ImportAlias(name=libcst.Name(imp))
              for imp in sorted(imports_to_add)),
            *(libcst.ImportAlias(
                name=libcst.Name(imp),
                asname=libcst.AsName(name=libcst.Name(alias)),
            ) for (imp, alias) in sorted(aliases_to_add)),
            *updated_node.names,
        ))
Ejemplo n.º 16
0
    def leave_Import(self, original_node: cst.Import,
                     updated_node: cst.Import) -> cst.Import:
        """Add the replacement module next to any import that covers ``self.old_name``.

        For each ``import a`` / ``import a.b`` alias that is a dotted prefix of
        ``self.old_name``:

        * the original alias is kept, since it might still be in use;
        * the statement is scheduled for potential removal;
        * an alias for ``self.gen_replacement_module(<alias name>)`` is appended;
        * ``self.bypass_import`` is set.

        Other aliases are kept unchanged.

        Raises:
            Exception: if an alias name cannot be converted to a dotted string.
        """
        new_names = []
        for import_alias in updated_node.names:
            import_alias_name = import_alias.name
            import_alias_full_name = get_full_name_for_node(import_alias_name)
            if import_alias_full_name is None:
                raise Exception(
                    "Could not parse full name for ImportAlias.name node.")

            new_names.append(import_alias)
            covers_old_name = (
                isinstance(import_alias_name, (cst.Name, cst.Attribute))
                and self.old_name.startswith(import_alias_full_name + "."))
            if not covers_old_name:
                continue

            # Might be in use elsewhere in the code, so schedule a potential
            # removal, and add another alias.
            self.scheduled_removals.add(original_node)
            # gen_name_or_attr_node also handles dotted replacement modules,
            # which a bare cst.Name (used previously for Name aliases) cannot
            # represent.
            new_name_node: Union[
                cst.Attribute, cst.Name] = self.gen_name_or_attr_node(
                    self.gen_replacement_module(import_alias_full_name))
            new_names.append(cst.ImportAlias(name=new_name_node))
            self.bypass_import = True

        return updated_node.with_changes(names=new_names)
Ejemplo n.º 17
0
    def visit_ImportFrom(self, node) -> None:
        """Record each name bound by a ``from ... import`` statement.

        ``self.imprts`` maps every bound name (the ``as`` alias if present) to a
        standalone ``cst.ImportFrom`` for just that name. The source module of a
        relative import is rebuilt from ``self.mod``. Star imports bind names
        that cannot be listed here, so they are skipped.
        """
        if isinstance(node.names, cst.ImportStar):
            # `from x import *` has no alias list; iterating it would raise.
            return

        # The source module is identical for every alias, so resolve it once.
        level = len(node.relative)
        if level > 0:
            parts = self.mod.split('.')
            # NOTE(review): when level >= len(parts) (and len(parts) > 1) this
            # yields "" -- confirm parse_expr's behavior for that case.
            mod_level = '.'.join(
                parts[:-level]) if len(parts) > 1 else parts[0]
            if node.module is not None:
                module = parse_expr(f'{mod_level}.{a2s(node.module)}')
            else:
                module = parse_expr(mod_level)
        else:
            module = node.module

        for alias in node.names:
            name = alias.asname.name.value if alias.asname is not None else alias.name.value
            # Regenerate alias to avoid trailing comma issue
            alias = cst.ImportAlias(name=alias.name, asname=alias.asname)
            self.imprts[name] = cst.ImportFrom(module=module, names=[alias])
Ejemplo n.º 18
0
 def leave_ImportFrom(
     self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
 ) -> cst.ImportFrom:
     """Record this import and merge any pending names for its module into it.

     Every ``from`` import is appended to ``self.import_statements``. If
     ``self.imports`` has pending names for the same module (keyed by the
     module's dotted string), the statement's unaliased names and the pending
     names are merged, sorted, and the pending entry is deleted. Aliased names
     (``x as y``) are kept after the merged names, so their bindings survive.
     """
     self.import_statements.append(original_node)
     module = original_node.module
     import_names = updated_node.names
     # `from . import x` has no module to key on; `import *` has no alias list.
     if module is None or isinstance(import_names, cst.ImportStar):
         return updated_node

     key = _get_attribute_as_string(module)
     # Check the string key itself: `module.value` is a node for dotted modules.
     # NOTE(review): relative dots are not part of the key, so `from .a import x`
     # and `from a import x` share one -- confirm this is intended.
     if key not in self.imports:
         return updated_node

     plain_names = {
         _get_name_as_string(alias.name)
         for alias in import_names
         if alias.asname is None
     }
     updated_names = self.imports[key].names.union(plain_names)
     names = [cst.ImportAlias(cst.Name(name)) for name in sorted(updated_names)]
     # Regenerate aliased entries (fresh commas) so the last one can't carry a
     # trailing comma, which libcst rejects without parentheses.
     names.extend(
         cst.ImportAlias(name=alias.name, asname=alias.asname)
         for alias in import_names
         if alias.asname is not None
     )
     updated_node = updated_node.with_changes(names=tuple(names))
     del self.imports[key]
     return updated_node
Ejemplo n.º 19
0
def import_to_node_multi(imp: SortableImport,
                         module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a parenthesized ``from ... import (...)``, one name per line.

    Every item becomes an ``ImportAlias`` with a comma (including the last
    item), followed by a line break. The opening ``(`` is also followed by a
    line break. Continuation lines are indented with
    ``module.default_indent``; the whitespace after the last item's comma has
    an empty last line, so ``)`` is not indented further.

    Comment placement:

    * ``imp.comments.before``: lines above the statement (non-comment entries
      become blank lines).
    * ``imp.comments.first_inline``: trailing comment after ``(``.
    * ``item.comments.before``: own-line comments above the item. They are
      attached to the previous item's comma whitespace, or to ``(`` for the
      first item.
    * ``item.comments.inline``: trailing comment on the item's line.
    * ``item.comments.following`` (plus ``imp.comments.final`` for the last
      item): own-line comments after the item.
    * ``imp.comments.last_inline``: trailing comment after ``)``.

    Raises:
        ValueError: if ``imp.stem`` is empty (a plain ``import x``), which is not
            rendered on multiple lines.
    """
    body: List[cst.BaseSmallStatement] = []
    names: List[cst.ImportAlias] = []
    # Most recently built alias; own-line comments of the next item attach to it.
    prev: Optional[cst.ImportAlias] = None
    following: List[str] = []
    # Own-line comments that precede the first item (placed after the lparen).
    lpar_lines: List[cst.EmptyLine] = []
    lpar_inline: cst.TrailingWhitespace = cst.TrailingWhitespace()

    item_count = len(imp.items)
    for idx, item in enumerate(imp.items):
        name = name_to_node(item.name)
        asname = cst.AsName(
            name=cst.Name(item.asname)) if item.asname else None

        # Leading comments actually have to be trailing comments on the previous node.
        # That means putting them on the lpar node for the first item
        if item.comments.before:
            lines = [
                cst.EmptyLine(
                    indent=True,
                    comment=cst.Comment(c),
                    whitespace=cst.SimpleWhitespace(module.default_indent),
                ) for c in item.comments.before
            ]
            if prev is None:
                lpar_lines.extend(lines)
            else:
                # Extends, in place, the `empty_lines` list built for the previous
                # alias below (a plain list), instead of rebuilding that node.
                prev.comma.whitespace_after.empty_lines.extend(
                    lines)  # type: ignore

        # all items except the last needs whitespace to indent the *next* line/item
        indent = idx != (len(imp.items) - 1)

        # Inline comments for this item go right after its comma, before the newline.
        first_line = cst.TrailingWhitespace()
        inline = COMMENT_INDENT.join(item.comments.inline)
        if inline:
            first_line = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(inline),
            )

        # The statement's final comments are rendered after the last item.
        if idx == item_count - 1:
            following = item.comments.following + imp.comments.final
        else:
            following = item.comments.following

        after = cst.ParenthesizedWhitespace(
            indent=True,
            first_line=first_line,
            empty_lines=[
                cst.EmptyLine(
                    indent=True,
                    comment=cst.Comment(c),
                    whitespace=cst.SimpleWhitespace(module.default_indent),
                ) for c in following
            ],
            last_line=cst.SimpleWhitespace(
                module.default_indent if indent else ""),
        )

        node = cst.ImportAlias(
            name=name,
            asname=asname,
            comma=cst.Comma(whitespace_after=after),
        )
        names.append(node)
        prev = node

    # from foo import (
    #     bar
    # )
    if imp.stem:
        # Split leading dots off the stem; a stem of only dots has no module.
        stem, ndots = split_relative(imp.stem)
        if not stem:
            module_name = None
        else:
            module_name = name_to_node(stem)
        relative = (cst.Dot(), ) * ndots

        # inline comment following lparen
        if imp.comments.first_inline:
            inline = COMMENT_INDENT.join(imp.comments.first_inline)
            lpar_inline = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(inline),
            )

        body = [
            cst.ImportFrom(
                module=module_name,
                names=names,
                relative=relative,
                lpar=cst.LeftParen(
                    whitespace_after=cst.ParenthesizedWhitespace(
                        indent=True,
                        first_line=lpar_inline,
                        empty_lines=lpar_lines,
                        last_line=cst.SimpleWhitespace(module.default_indent),
                    ), ),
                rpar=cst.RightParen(),
            )
        ]

    # import foo
    else:
        raise ValueError("can't render basic imports on multiple lines")

    # comment lines above import
    leading_lines = [
        cst.EmptyLine(indent=True, comment=cst.Comment(line))
        if line.startswith("#") else cst.EmptyLine(indent=False)
        for line in imp.comments.before
    ]

    # inline comments following import/rparen
    if imp.comments.last_inline:
        inline = COMMENT_INDENT.join(imp.comments.last_inline)
        trailing = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(inline))
    else:
        trailing = cst.TrailingWhitespace()

    return cst.SimpleStatementLine(
        body=body,
        leading_lines=leading_lines,
        trailing_whitespace=trailing,
    )
Ejemplo n.º 20
0
 def make_simple_package_import(package: str) -> cst.Import:
     """Build ``import <package>`` for a top-level (undotted) package name."""
     assert "." not in package, "this only supports a root package, e.g. 'import os'"
     alias = cst.ImportAlias(name=cst.Name(package))
     return cst.Import(names=[alias])
Ejemplo n.º 21
0
class ImportCreateTest(CSTNodeTest):
    """Tests for building ``cst.Import`` nodes directly (not via the parser)."""

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Deeper dotted name (replaces an exact duplicate of the case above)
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                                cst.Name("baz"),
                            )
                        ),
                    )
                ),
                "code": "import foo.bar.baz",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "import  foo . bar  as  baz ,  unittest  as  ut",
                "expected_position": CodeRange((1, 0), (1, 46)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        # Each case pairs a hand-built node with the code it should render to.
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.Import(names=()),
                "expected_re": "at least one ImportAlias",
            },
            {
                "get_node": lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(),
                        ),
                    )
                ),
                "expected_re": "trailing comma",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        # Each case builds an invalid node lazily; construction must fail with
        # an error matching `expected_re`.
        self.assert_invalid(**kwargs)
Ejemplo n.º 22
0
 def _add_annotation_to_imports(self, annotation: cst.Attribute) -> cst.Name:
     key = _get_attribute_as_string(annotation.value)
     self._add_to_imports(
         [cst.ImportAlias(name=annotation.attr)], annotation.value, key
     )
     return annotation.attr
Ejemplo n.º 23
0
class StatementTest(UnitTest):
    """Tests for the import-statement helpers.

    Exercises ``get_absolute_module_for_import`` together with its raising
    twin ``get_absolute_module_for_import_or_raise``. Also covers the
    ``evaluated_name`` / ``evaluated_alias`` properties of ``cst.ImportAlias``.
    """

    @data_provider(
        (
            # Rows are (current module name, import statement, expected module).
            # Simple imports that are already absolute.
            (None, "from a.b import c", "a.b"),
            ("x.y.z", "from a.b import c", "a.b"),
            # Relative import that can't be resolved due to missing module.
            (None, "from ..w import c", None),
            # Relative import that goes past the module level.
            ("x", "from ...y import z", None),
            ("x.y.z", "from .....w import c", None),
            ("x.y.z", "from ... import c", None),
            # Correct resolution of absolute from relative modules.
            ("x.y.z", "from . import c", "x.y"),
            ("x.y.z", "from .. import c", "x"),
            ("x.y.z", "from .w import c", "x.y.w"),
            ("x.y.z", "from ..w import c", "x.w"),
            ("x.y.z", "from ...w import c", "w"),
        )
    )
    def test_get_absolute_module(
        self, module: Optional[str], importfrom: str, output: Optional[str],
    ) -> None:
        """Resolve ``importfrom`` against ``module`` and compare with ``output``.

        The non-raising helper must return ``output`` exactly. When ``output``
        is ``None``, the raising helper must raise. Otherwise the raising
        helper must also return ``output``.
        """
        statement = ensure_type(
            cst.parse_statement(importfrom), cst.SimpleStatementLine
        )
        assert len(statement.body) == 1, "Unexpected number of statements!"
        import_from = ensure_type(statement.body[0], cst.ImportFrom)

        resolved = get_absolute_module_for_import(module, import_from)
        self.assertEqual(resolved, output)

        if output is not None:
            self.assertEqual(
                get_absolute_module_for_import_or_raise(module, import_from), output
            )
            return

        # Unresolvable imports must be reported by the raising variant.
        with self.assertRaises(Exception):
            get_absolute_module_for_import_or_raise(module, import_from)

    @data_provider(
        (
            # Nodes without an asname
            (cst.ImportAlias(name=cst.Name("foo")), "foo", None),
            (
                cst.ImportAlias(name=cst.Attribute(cst.Name("foo"), cst.Name("bar"))),
                "foo.bar",
                None,
            ),
            # Nodes with an asname
            (
                cst.ImportAlias(
                    name=cst.Name("foo"), asname=cst.AsName(name=cst.Name("baz"))
                ),
                "foo",
                "baz",
            ),
            (
                cst.ImportAlias(
                    name=cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                    asname=cst.AsName(name=cst.Name("baz")),
                ),
                "foo.bar",
                "baz",
            ),
        )
    )
    def test_importalias_helpers(
        self, alias_node: cst.ImportAlias, full_name: str, alias: Optional[str]
    ) -> None:
        """Check the dotted name and optional alias reported by ``alias_node``."""
        name_seen = alias_node.evaluated_name
        self.assertEqual(name_seen, full_name)
        alias_seen = alias_node.evaluated_alias
        self.assertEqual(alias_seen, alias)
Ejemplo n.º 24
0
class ImportParseTest(CSTNodeTest):
    """Parse-side tests for ``import ...`` statements.

    Each case pairs source ``code`` with the ``cst.Import`` node expected for
    it. ``test_valid`` hands both to ``CSTNodeTest.validate_node``. The parser
    it supplies takes the first small statement of the parsed
    ``SimpleStatementLine``.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Three-component dotted name. This replaces a verbatim duplicate
            # of the case above and covers the left-nested Attribute that a
            # longer dotted module path produces.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                                cst.Name("baz"),
                            )
                        ),
                    )
                ),
                "code": "import foo.bar.baz",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "import  foo . bar  as  baz ,  unittest  as  ut",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one parse case; ``kwargs`` carries ``node`` and ``code``."""
        self.validate_node(
            # Parse the statement line and keep only its first small statement.
            parser=lambda code: ensure_type(
                parse_statement(code), cst.SimpleStatementLine
            ).body[0],
            **kwargs,
        )
Ejemplo n.º 25
0
class ImportFromCreateTest(CSTNodeTest):
    """Construction-side tests for ``from ... import ...`` statements.

    ``test_valid`` passes each hand-built ``cst.ImportFrom`` node, its
    expected ``code`` and, optionally, an ``expected_position`` to
    ``CSTNodeTest.validate_node``. ``test_invalid`` passes a lazily-built
    malformed node (``get_node``) and an error pattern (``expected_re``) to
    ``CSTNodeTest.assert_invalid``.
    """

    @data_provider(
        (
            # Simple from import statement
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from foo import bar",
            },
            # From import statement with alias
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    ),
                ),
                "code": "from foo import bar as baz",
            },
            # Multiple imports (no explicit commas; codegen supplies them)
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(name=cst.Name("bar")),
                        cst.ImportAlias(name=cst.Name("baz")),
                    ),
                ),
                "code": "from foo import bar, baz",
            },
            # Trailing comma; the expected position stops before the final comma.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(name=cst.Name("bar"), comma=cst.Comma()),
                        cst.ImportAlias(name=cst.Name("baz"), comma=cst.Comma()),
                    ),
                ),
                "code": "from foo import bar,baz,",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Star import statement
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"), names=cst.ImportStar()
                ),
                "code": "from foo import *",
                "expected_position": CodeRange((1, 0), (1, 17)),
            },
            # Simple relative import statement
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(),),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from .foo import bar",
            },
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from ..foo import bar",
            },
            # Relative only import
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=None,
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from .. import bar",
            },
            # Parenthesis
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "from foo import (bar as baz)",
                "expected_position": CodeRange((1, 0), (1, 28)),
            },
            # Verify whitespace works everywhere.
            {
                "node": cst.ImportFrom(
                    relative=(
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(
                                name=cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Name("unittest"),
                            asname=cst.AsName(
                                name=cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    whitespace_after_from=cst.SimpleWhitespace("  "),
                    whitespace_before_import=cst.SimpleWhitespace("  "),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "from   .  . foo  import  ( bar  as  baz ,  unittest  as  ut )",
                "expected_position": CodeRange((1, 0), (1, 61)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one hand-built ``ImportFrom`` case via ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # A bare ``from import`` with no dots still needs a module.
            {
                "get_node": lambda: cst.ImportFrom(
                    module=None, names=(cst.ImportAlias(name=cst.Name("bar")),)
                ),
                "expected_re": "Must have a module specified",
            },
            {
                "get_node": lambda: cst.ImportFrom(module=cst.Name("foo"), names=()),
                "expected_re": "at least one ImportAlias",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=cst.ImportStar(),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": "cannot have parens",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=cst.ImportStar(),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "cannot have parens",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space after from",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    whitespace_before_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space before import",
            },
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space after import",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that one malformed ``ImportFrom`` is rejected via ``assert_invalid``."""
        self.assert_invalid(**kwargs)
Ejemplo n.º 26
0
class ImportFromParseTest(CSTNodeTest):
    """Parse-side tests for ``from ... import ...`` statements.

    Each case pairs source ``code`` with the ``cst.ImportFrom`` node expected
    for it. Both are handed to ``CSTNodeTest.validate_node``, together with a
    parser that takes the first small statement of the parsed line.
    """

    @data_provider(
        (
            # Simple from import statement
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from foo import bar",
            },
            # From import statement with alias
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    ),
                ),
                "code": "from foo import bar as baz",
            },
            # Multiple imports
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(name=cst.Name("baz")),
                    ),
                ),
                "code": "from foo import bar, baz",
            },
            # Trailing comma
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(name=cst.Name("baz"), comma=cst.Comma()),
                    ),
                ),
                "code": "from foo import bar, baz,",
            },
            # Star import statement
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"), names=cst.ImportStar()
                ),
                "code": "from foo import *",
            },
            # Simple relative import statement
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(),),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from .foo import bar",
            },
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from ..foo import bar",
            },
            # Relative only import
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=None,
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from .. import bar",
            },
            # Parenthesis
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "from foo import (bar as baz)",
            },
            # Verify whitespace works everywhere. When parsing, the space
            # between dots is attributed to whitespace_after of the previous dot.
            {
                "node": cst.ImportFrom(
                    relative=(
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(
                                name=cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Name("unittest"),
                            asname=cst.AsName(
                                name=cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    whitespace_after_from=cst.SimpleWhitespace("   "),
                    whitespace_before_import=cst.SimpleWhitespace("  "),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "from   .  . foo  import  ( bar  as  baz ,  unittest  as  ut )",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one parse case; ``kwargs`` carries ``node`` and ``code``."""

        def first_small_statement(code: str) -> cst.BaseSmallStatement:
            # The parsed statement line wraps the ImportFrom; unwrap it.
            line = ensure_type(parse_statement(code), cst.SimpleStatementLine)
            return line.body[0]

        self.validate_node(parser=first_small_statement, **kwargs)