Ejemplo n.º 1
0
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Splice pending top-level annotations and unseen class stubs into the module.

        For each entry of ``self.toplevel_annotations`` a bare ``name: T``
        statement is produced. These are followed by every class definition
        from ``self.annotations.class_definitions`` whose name is not in
        ``self.visited_classes``. The new statements are placed between the
        two halves returned by ``self._split_module``. When there is nothing
        to add, ``updated_node`` is returned untouched.
        """
        unvisited_classes = [
            class_def
            for class_name, class_def in self.annotations.class_definitions.items()
            if class_name not in self.visited_classes
        ]
        if not (self.toplevel_annotations or unvisited_classes):
            return updated_node

        # Locate where new statements go relative to the existing imports.
        head, tail = self._split_module(original_node, updated_node)
        # Guarantee a blank line separates the inserted block from what follows.
        tail = self._insert_empty_line(tail)

        inserted = [
            cst.SimpleStatementLine(
                [cst.AnnAssign(cst.Name(var_name), var_annotation, None)]
            )
            for var_name, var_annotation in self.toplevel_annotations.items()
        ]
        inserted += unvisited_classes

        return updated_node.with_changes(body=[*head, *inserted, *tail])
Ejemplo n.º 2
0
    def leave_Module(
        self,
        original_node: cst.Module,
        updated_node: cst.Module,
    ) -> cst.Module:
        """Add queued top-level annotations, TypeVars and new class stubs.

        The following are inserted, in order, between the halves produced by
        ``self._split_module``:

        * an annotated statement for every entry of ``self.toplevel_annotations``;
        * every TypeVar statement from ``self.annotations.typevars`` whose name
          is not already in ``self.typevars``, each preceded by a
          ``cst.Newline()``, with one more ``cst.Newline()`` after the group;
        * the class definitions whose names were never visited.

        ``self.annotation_counts`` records how many TypeVars and classes were
        added.
        """
        new_classes = [
            class_def
            for class_name, class_def in self.annotations.class_definitions.items()
            if class_name not in self.visited_classes
        ]

        # NOTE: The entire change will also be abandoned if
        # self.annotation_counts is all 0s, so if adding any new category make
        # sure to record it there.
        has_work = bool(
            self.toplevel_annotations or new_classes or self.annotations.typevars
        )
        if not has_work:
            return updated_node

        # Locate the insertion point relative to the existing imports.
        before, after = self._split_module(original_node, updated_node)
        # Ensure at least one blank line precedes the first non-import statement.
        after = self._insert_empty_line(after)

        inserted = []
        for global_name, global_annotation in self.toplevel_annotations.items():
            assign = self._apply_annotation_to_attribute_or_global(
                name=global_name,
                annotation=global_annotation,
                value=None,
            )
            inserted.append(cst.SimpleStatementLine([assign]))

        # TypeVar definitions may be spread throughout the file, so new ones
        # are not merged with existing ones; they all go at the top.
        missing_typevars = [
            typevar_stmt
            for typevar_name, typevar_stmt in self.annotations.typevars.items()
            if typevar_name not in self.typevars
        ]
        for typevar_stmt in missing_typevars:
            inserted.append(cst.Newline())
            inserted.append(typevar_stmt)
            self.annotation_counts.typevars_and_generics_added += 1
        if missing_typevars:
            inserted.append(cst.Newline())

        self.annotation_counts.classes_added = len(new_classes)
        inserted.extend(new_classes)

        return updated_node.with_changes(body=[*before, *inserted, *after])
Ejemplo n.º 3
0
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Insert pending imports and top-level annotations into the module.

        When ``self.is_generated`` is set, the original node is returned as is.
        Otherwise, for each queued entry in ``self.imports`` a
        ``from <module> import ...`` line is emitted. It lists (sorted) only the
        names not already bound by a ``from`` import in
        ``self.import_statements``, and it is skipped when no names remain.
        An annotated statement for every entry of ``self.toplevel_annotations``
        follows. All new statements go between the halves returned by
        ``self._split_module``.
        """
        if self.is_generated:
            return original_node
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        # Locate the insertion point relative to the existing imports.
        head, tail = self._split_module(original_node, updated_node)
        # Guarantee a blank line before the first non-import statement.
        tail = self._insert_empty_line(tail)

        # Names already bound by existing `from ... import` statements; for an
        # aliased import the alias is what gets recorded. Star imports are
        # skipped.
        already_imported = set()
        for existing in self.import_statements:
            aliases = existing.names
            if isinstance(aliases, cst.ImportStar):
                continue
            for alias in aliases:
                bound = alias.asname if alias.asname else alias
                if bound:
                    already_imported.add(_get_name_as_string(bound.name))

        inserted = []
        for pending in self.imports.values():
            # Filter out anything that has already been imported.
            remaining = sorted(pending.names.difference(already_imported))
            if not remaining:
                continue
            new_import = cst.ImportFrom(
                module=pending.module,
                names=[cst.ImportAlias(cst.Name(n)) for n in remaining],
            )
            # SimpleStatementLine takes a subscriptable sequence of statements.
            inserted.append(cst.SimpleStatementLine([new_import]))

        for global_name, global_annotation in self.toplevel_annotations.items():
            assign = cst.AnnAssign(
                cst.Name(global_name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(global_annotation.annotation),
                None,
            )
            inserted.append(cst.SimpleStatementLine([assign]))

        return updated_node.with_changes(body=[*head, *inserted, *tail])
Ejemplo n.º 4
0
 def visit_Module(self, node: cst.Module) -> None:
     """Report the module when its header does not start with the required lines.

     The header must begin with the lines matched by ``self.header_matcher``;
     anything may follow them. On a mismatch the module is reported with a
     suggested replacement that prepends ``self.header_replacement`` to the
     existing header. Nothing is checked while ``self.rule_disabled`` is set.
     """
     if self.rule_disabled:
         return
     expected_shape = m.Module(header=[*self.header_matcher, m.ZeroOrMore()])
     if m.matches(node, expected_shape):
         return
     suggested = node.with_changes(
         header=[*self.header_replacement, *node.header])
     self.report(node, replacement=suggested)
Ejemplo n.º 5
0
 def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
     """Insert an annotated statement for each entry of ``self.toplevel_annotations``.

     The ``name: T`` statements are spliced into the module body at the
     position returned by ``self._get_toplevel_index(body)``. They keep the
     iteration order of ``self.toplevel_annotations``.
     """
     body = list(updated_node.body)
     index = self._get_toplevel_index(body)
     new_statements = [
         cst.SimpleStatementLine(
             [
                 cst.AnnAssign(
                     cst.Name(name),
                     # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                     cst.Annotation(annotation.annotation),
                     None,
                 )
             ]
         )
         for name, annotation in self.toplevel_annotations.items()
     ]
     # Splice the group in at once. Inserting the statements one at a time at
     # the same fixed index would emit them in reverse order.
     body[index:index] = new_statements
     return updated_node.with_changes(body=tuple(body))
Ejemplo n.º 6
0
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        """Sort each contiguous block of sortable statements in the module body.

        Blocks come from ``sortable_blocks(updated_node.body, config=self.config)``.
        For each block, the leading lines of its first statement are split by
        ``partition_leading_lines``. The comment part stays on that statement's
        node. The blank part is passed to ``fixup_whitespace`` together with
        the sorted statements. The resulting nodes replace
        ``body[b.start_idx:b.end_idx]``.
        """
        blocks = sortable_blocks(updated_node.body, config=self.config)
        body: List[cst.CSTNode] = list(updated_node.body)

        for b in blocks:
            # Keep only the comment portion on the block's first statement; the
            # blank portion is handed to fixup_whitespace so it can be
            # reapplied after sorting (presumably to whichever statement ends
            # up first — confirm).
            initial_blank, initial_comment = partition_leading_lines(
                b.stmts[0].node.leading_lines)
            b.stmts[0].node = b.stmts[0].node.with_changes(
                leading_lines=initial_comment)
            sorted_stmts = fixup_whitespace(initial_blank, sorted(b.stmts))
            # NOTE(review): later blocks' start_idx/end_idx stay valid only if
            # the replacement has the same length as the original slice —
            # presumably guaranteed by fixup_whitespace; confirm.
            body[b.start_idx:b.end_idx] = [s.node for s in sorted_stmts]
        return updated_node.with_changes(body=body)
Ejemplo n.º 7
0
 def update_imports(
     self,
     original_module: Module,
     updated_module: Module,
     import_name: str,
     updated_import_node: SimpleStatementLine,
     current_imports: Dict[str, str],
     new_imports: Set[str],
     noqa: bool,
 ) -> Module:
     """Add ``new_imports`` to a ``from <import_name> import ...`` statement.

     Args:
         original_module: Module before transformation. Its body is scanned
             for ``self.last_import_node_stmt`` to choose where a new
             statement is inserted.
         updated_module: Module to modify and return.
         import_name: Module path written after ``from``.
         updated_import_node: Existing import statement to rewrite. Pass a
             falsy value to insert a brand-new statement instead.
         current_imports: Names already imported by ``updated_import_node``,
             mapped to their local name (equal to the key when unaliased).
         new_imports: Names to add. When empty, ``updated_module`` is
             returned unchanged.
         noqa: If true, a ``  # noqa`` comment is appended to the statement.

     Returns:
         The updated module. An existing import whose ``current_imports``
         contains ``"*"`` is left as is.
     """
     if not new_imports:
         return updated_module
     noqa_comment = "  # noqa" if noqa else ""
     if not updated_import_node:
         # No statement to extend: insert a new one right after body index i.
         # i == -1 inserts at the very top of the body.
         i = -1
         blank_lines = "\n\n"
         if self.last_import_node_stmt:
             blank_lines = ""
             # NOTE(review): if last_import_node_stmt is not found, the loop
             # runs to completion and i is left at the last index visited, so
             # the import lands at the end of the compared range.
             for i, (original, updated) in enumerate(
                 zip(original_module.body, updated_module.body)
             ):
                 if original is self.last_import_node_stmt:
                     break
         stmt = parse_module(
             f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
             config=updated_module.config_for_parsing,
         )
         body = list(updated_module.body)
         # NOTE(review): this stores the parsed Module rather than one of its
         # statements. Unless something else resets it, the identity check
         # against original_module.body above will not match it on a later
         # call — confirm intent.
         self.last_import_node_stmt = stmt
         return updated_module.with_changes(
             body=body[: i + 1] + stmt.children + body[i + 1 :]
         )
     else:
         if "*" not in current_imports:
             # Re-render the existing names (keeping "x as y" aliases) and
             # merge them with the new ones into a single sorted statement.
             current_imports_set = {
                 f"{k}" if k == v else f"{k} as {v}"
                 for k, v in current_imports.items()
             }
             stmt = parse_statement(
                 f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}"
             )
             return updated_module.deep_replace(updated_import_node, stmt)
     return updated_module
Ejemplo n.º 8
0
def insert_header_comments(node: libcst.Module,
                           comments: List[str]) -> libcst.Module:
    """Return ``node`` with ``comments`` inserted after the last header comment.

    Each string in ``comments`` becomes its own ``EmptyLine`` carrying that
    comment. The new lines go immediately after the last header line that
    has a comment. Header lines after that point (the trailing whitespace
    section) stay after the inserted comments. If no header line has a
    comment, the new comments come first.
    """
    header = node.header
    # Index just past the last comment-bearing header line (0 when none).
    split_at = 0
    for position in range(len(header) - 1, -1, -1):
        if header[position].comment is not None:
            split_at = position + 1
            break
    new_lines = [
        libcst.EmptyLine(comment=libcst.Comment(value=text)) for text in comments
    ]
    return node.with_changes(
        header=(*header[:split_at], *new_lines, *header[split_at:]))
Ejemplo n.º 9
0
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        """Insert every queued import and top-level annotation into the module.

        One ``from <module> import ...`` line is emitted per entry of
        ``self.imports``, followed by one ``name: T`` line per entry of
        ``self.toplevel_annotations``. They are placed between the halves
        returned by ``self._split_module``. ``updated_node`` is returned
        untouched when both collections are empty.
        """
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        # Locate the insertion point relative to the existing imports.
        leading, trailing = self._split_module(original_node, updated_node)
        # Keep at least one blank line ahead of the first non-import statement.
        trailing = self._insert_empty_line(trailing)

        additions = [
            cst.SimpleStatementLine([
                cst.ImportFrom(
                    module=queued.module,
                    # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
                    #  for 2nd param but got `List[ImportFrom]`.
                    names=queued.names,
                )
            ])
            for queued in self.imports.values()
        ]
        additions.extend(
            cst.SimpleStatementLine([
                cst.AnnAssign(
                    cst.Name(var_name),
                    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                    cst.Annotation(var_annotation.annotation),
                    None,
                )
            ])
            for var_name, var_annotation in self.toplevel_annotations.items()
        )

        return updated_node.with_changes(
            body=[*leading, *additions, *trailing])
Ejemplo n.º 10
0
def with_added_imports(
        module_node: cst.Module,
        import_nodes: Sequence[Union[cst.Import,
                                     cst.ImportFrom]]) -> cst.Module:
    """
    Add each of ``import_nodes`` right after the first import line of ``module_node``.

    Each import node is wrapped in its own ``SimpleStatementLine``, and the
    new lines keep the order of ``import_nodes``. The anchor is the first
    body statement for which ``_is_import_line`` returns true.

    Returns:
        A copy of ``module_node`` with the updated body, or ``module_node``
        itself when ``import_nodes`` is empty.

    Raises:
        RuntimeError: if there are imports to add but the module has no
            import line to anchor them to.
    """
    if not import_nodes:
        # Nothing to insert, so a module without imports is not an error.
        return module_node

    new_lines = [
        cst.SimpleStatementLine(body=(import_node,))
        for import_node in import_nodes
    ]
    updated_body: List[Union[cst.SimpleStatementLine,
                             cst.BaseCompoundStatement]] = []
    added_import = False
    for line in module_node.body:
        updated_body.append(line)
        # Only the first import line receives the new statements.
        if not added_import and _is_import_line(line):
            updated_body.extend(new_lines)
            added_import = True

    if not added_import:
        raise RuntimeError("Failed to add imports")

    return module_node.with_changes(body=tuple(updated_body))
Ejemplo n.º 11
0
def insert_header_comments(node: libcst.Module, comments: List[str]) -> libcst.Module:
    """
    Insert comments after last non-empty line in header. Use this to insert one or more
    comments after any copyright preamble in a :class:`~libcst.Module`. Each comment in
    the list of ``comments`` must start with a ``#`` and will be placed on its own line
    in the appropriate location.
    """
    header = tuple(node.header)
    commented_positions = [
        index for index, line in enumerate(header) if line.comment is not None
    ]
    # Split just after the last comment-bearing line; with no comments, insert first.
    split_at = commented_positions[-1] + 1 if commented_positions else 0
    new_lines = tuple(
        libcst.EmptyLine(comment=libcst.Comment(value=comment)) for comment in comments
    )
    return node.with_changes(header=header[:split_at] + new_lines + header[split_at:])
Ejemplo n.º 12
0
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        if not self.names or self.already_exists:
            return original_node

        modified_body = list(original_node.body)
        config = original_node.config_for_parsing

        list_of_names = f",{config.default_newline}{config.default_indent}".join(
            [repr(name) for name in sorted(self.names)])

        all_names = cst.parse_statement(
            f"""

__all__ = [
{config.default_indent}{list_of_names}
]
        """,
            config=original_node.config_for_parsing,
        )

        modified_body.append(all_names)
        return updated_node.with_changes(body=modified_body)
Ejemplo n.º 13
0
    def leave_Module(self, original_node: libcst.Module,
                     updated_node: libcst.Module) -> libcst.Module:
        """Insert the imports recorded on this visitor into the module.

        Returns ``updated_node`` unchanged when ``self.module_imports``,
        ``self.module_mapping``, ``self.module_aliases`` and
        ``self.alias_mapping`` are all empty. Otherwise the body becomes:

        1. ``statements_before_imports``;
        2. ``from __future__ import ...`` for any ``__future__`` entries;
        3. ``statements_until_add_imports``;
        4. ``import X`` for each name in ``self.module_imports``, sorted;
        5. ``import X as Y`` for each entry of ``self.module_aliases``;
        6. ``from typing import TYPE_CHECKING`` followed by an
           ``if TYPE_CHECKING:`` block. The block holds the ``from ... import``
           lines of every module outside ``import_cycle_safe_module_names``.
           Both are replaced by ``EmptyLine`` placeholders when the block
           would be empty;
        7. plain ``from ... import`` lines for the modules in
           ``import_cycle_safe_module_names``;
        8. ``statements_after_imports``, with at least one blank line
           ensured before it.
        """
        # Don't try to modify if we have nothing to do
        if (not self.module_imports and not self.module_mapping
                and not self.module_aliases and not self.alias_mapping):
            return updated_node

        # First, find the insertion point for imports
        (
            statements_before_imports,
            statements_until_add_imports,
            statements_after_imports,
        ) = self._split_module(original_node, updated_node)

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(
            statements_after_imports)

        # Mapping of modules we're adding to the (object, alias-or-None)
        # pairs they should import.
        module_and_alias_mapping = defaultdict(list)
        for module, aliases in self.alias_mapping.items():
            module_and_alias_mapping[module].extend(aliases)
        for module, imports in self.module_mapping.items():
            module_and_alias_mapping[module].extend(
                [(obj, None) for obj in imports])
        # Treat a missing alias as "" for ordering. If the same object is
        # requested both with and without an alias, plain tuple sorting would
        # compare None with a str and raise TypeError.
        module_and_alias_mapping = {
            module: sorted(aliases, key=lambda pair: (pair[0], pair[1] or ""))
            for module, aliases in module_and_alias_mapping.items()
        }

        config = updated_node.config_for_parsing

        def from_import(module, aliases):
            # Build `from <module> import a, b as c, ...` as a statement node.
            names = ", ".join(
                obj if alias is None else f"{obj} as {alias}"
                for (obj, alias) in aliases)
            return parse_statement(f"from {module} import {names}",
                                   config=config)

        # Modules whose imports are emitted unguarded; every other non-future
        # module is imported under `if TYPE_CHECKING:`.
        import_cycle_safe_module_names = [
            'mypy_extensions',
            'typing',
            'typing_extensions',
        ]
        type_checking_cond_import = parse_statement(
            "from typing import TYPE_CHECKING",
            config=config,
        )
        type_checking_cond_statement = libcst.If(
            test=libcst.Name("TYPE_CHECKING"),
            body=libcst.IndentedBlock(body=[
                from_import(module, aliases)
                for module, aliases in module_and_alias_mapping.items()
                if module != "__future__"
                and module not in import_cycle_safe_module_names
            ]),
        )
        if not type_checking_cond_statement.body.body:
            # Nothing to guard: drop the TYPE_CHECKING import and the empty if.
            # NOTE(review): EmptyLine is a whitespace node rather than a
            # statement; it is kept here to preserve the current output.
            type_checking_cond_statement = libcst.EmptyLine()
            type_checking_cond_import = libcst.EmptyLine()

        # TODO: forward references could also be handled by adding
        # `from __future__ import annotations`, either here or with another tool.
        # Now, add all of the imports we need!
        return updated_node.with_changes(body=(
            *statements_before_imports,
            *[
                from_import(module, aliases)
                for module, aliases in module_and_alias_mapping.items()
                if module == "__future__"
            ],
            *statements_until_add_imports,
            *[
                parse_statement(f"import {module}", config=config)
                for module in sorted(self.module_imports)
            ],
            *[
                parse_statement(f"import {module} as {asname}", config=config)
                for (module, asname) in self.module_aliases.items()
            ],
            type_checking_cond_import,
            type_checking_cond_statement,
            *[
                from_import(module, aliases)
                for module, aliases in module_and_alias_mapping.items()
                if module in import_cycle_safe_module_names
            ],
            *statements_after_imports,
        ))
Ejemplo n.º 14
0
 def leave_Module(self, original_node: cst.Module,
                  updated_node: cst.Module):
     """Prepend the statements from ``self.__get_required_imports()`` to the body."""
     required_imports = self.__get_required_imports()
     existing_body = list(updated_node.body)
     return updated_node.with_changes(body=required_imports + existing_body)
Ejemplo n.º 15
0
    def leave_Module(self, original_node: libcst.Module,
                     updated_node: libcst.Module) -> libcst.Module:
        """Insert the imports recorded on this visitor into the module.

        Returns ``updated_node`` unchanged when ``self.module_imports``,
        ``self.module_mapping``, ``self.module_aliases`` and
        ``self.alias_mapping`` are all empty. Otherwise the body becomes:

        1. ``from __future__ import ...`` for any ``__future__`` entries;
        2. ``statements_before_imports``;
        3. ``import X`` for each name in ``self.module_imports``, sorted;
        4. ``import X as Y`` for each entry of ``self.module_aliases``;
        5. ``from M import ...`` for every other module, with its
           (object, alias) pairs sorted;
        6. ``statements_after_imports``, with at least one blank line
           ensured before it.
        """
        # Don't try to modify if we have nothing to do
        if (not self.module_imports and not self.module_mapping
                and not self.module_aliases and not self.alias_mapping):
            return updated_node

        # First, find the insertion point for imports
        statements_before_imports, statements_after_imports = self._split_module(
            original_node, updated_node)

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(
            statements_after_imports)

        # Mapping of modules we're adding to the (object, alias-or-None)
        # pairs they should import.
        module_and_alias_mapping = defaultdict(list)
        for module, aliases in self.alias_mapping.items():
            module_and_alias_mapping[module].extend(aliases)
        for module, imports in self.module_mapping.items():
            module_and_alias_mapping[module].extend(
                [(obj, None) for obj in imports])
        # Treat a missing alias as "" for ordering. If the same object is
        # requested both with and without an alias, plain tuple sorting would
        # compare None with a str and raise TypeError.
        module_and_alias_mapping = {
            module: sorted(aliases, key=lambda pair: (pair[0], pair[1] or ""))
            for module, aliases in module_and_alias_mapping.items()
        }

        config = updated_node.config_for_parsing

        def from_import(module, aliases):
            # Build `from <module> import a, b as c, ...` as a statement node.
            names = ", ".join(
                obj if alias is None else f"{obj} as {alias}"
                for (obj, alias) in aliases)
            return parse_statement(f"from {module} import {names}",
                                   config=config)

        # Now, add all of the imports we need!
        return updated_node.with_changes(body=(
            # NOTE(review): new `__future__` imports are placed ahead of
            # statements_before_imports (and so ahead of a module docstring,
            # if that part contains one); confirm this matches _split_module.
            *[
                from_import(module, aliases)
                for module, aliases in module_and_alias_mapping.items()
                if module == "__future__"
            ],
            *statements_before_imports,
            *[
                parse_statement(f"import {module}", config=config)
                for module in sorted(self.module_imports)
            ],
            *[
                parse_statement(f"import {module} as {asname}", config=config)
                for (module, asname) in self.module_aliases.items()
            ],
            *[
                from_import(module, aliases)
                for module, aliases in module_and_alias_mapping.items()
                if module != "__future__"
            ],
            *statements_after_imports,
        ))
Ejemplo n.º 16
0
 def transform_module_impl(self, tree: cst.Module) -> cst.Module:
     """Emit a "Testing" warning and replace the module header with one comment.

     Any header ``tree`` already had is discarded. The result's header is the
     single line ``# A comment``.
     """
     self.warn("Testing")
     comment_line = cst.EmptyLine(comment=cst.Comment("# A comment"))
     return tree.with_changes(header=[comment_line])