Example #1
0
    def visit_Module(self, node: libcst.Module) -> None:
        """
        Prune the pending import work against what ``node`` already imports.

        A ``GatherImportsVisitor`` is run over ``node`` and its ``all_imports``
        is stored on ``self``.  Every entry it reports is then removed from
        ``module_imports``, ``module_aliases``, ``alias_mapping`` and
        ``module_mapping`` so that only missing imports remain queued.
        """
        existing = GatherImportsVisitor(self.context)
        node.visit(existing)
        self.all_imports = existing.all_imports

        # Plain module imports: keep only the ones the gatherer did not see.
        self.module_imports = self.module_imports - existing.module_imports

        # Aliased module imports: drop a request only when the alias matches.
        for module, alias in existing.module_aliases.items():
            if module not in self.module_aliases:
                continue
            if self.module_aliases[module] == alias:
                del self.module_aliases[module]

        # Aliased object imports: remove each (object, alias) pair already
        # present, and forget the module once its list becomes empty.
        for module, aliases in existing.alias_mapping.items():
            for (obj, alias) in aliases:
                if module not in self.alias_mapping:
                    continue
                wanted = self.alias_mapping[module]
                if (obj, alias) not in wanted:
                    continue
                wanted.remove((obj, alias))
                if not wanted:
                    del self.alias_mapping[module]

        # Un-aliased object imports.
        for module, imports in existing.object_mapping.items():
            if module not in self.module_mapping:
                # Nothing was requested from this module.
                continue
            if "*" in imports:
                # A star import already covers every requested name.
                del self.module_mapping[module]
                continue
            remaining = self.module_mapping[module] - imports
            if remaining:
                self.module_mapping[module] = remaining
            else:
                # Everything requested is already imported.
                del self.module_mapping[module]
    def transform_module_impl(self, tree: cst.Module) -> cst.Module:
        """
        Apply annotations collected from the stored stub to ``tree``.

        The names already imported by ``tree`` are gathered first and handed to
        the ``TypeCollector``.  If the codemod scratch space holds an entry
        under ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY``, the stub found there
        is visited and its annotations are merged into ``self.annotations``.
        Finally ``AddImportsVisitor`` is run over ``tree`` and this visitor is
        applied to its result.
        """
        gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(gatherer)
        known_names = _get_import_names(gatherer.all_imports)

        stashed = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
        if stashed:
            stub, overwrite_flag = stashed
            # The flag is sticky: once either side enables it, it stays on.
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations or overwrite_flag
            )
            collector = TypeCollector(known_names, self.context)
            stub.visit(collector)
            merged = self.annotations
            merged.function_annotations.update(collector.function_annotations)
            merged.attribute_annotations.update(collector.attribute_annotations)
            merged.class_definitions.update(collector.class_definitions)

        with_imports = AddImportsVisitor(self.context).transform_module(tree)
        return with_imports.visit(self)
Example #3
0
def _annotate_source(stubs: cst.Module, source: cst.Module) -> cst.Module:
    """
    Return ``source`` visited by a ``TypeTransformer`` built from ``stubs``.

    A ``TypeCollector`` visits ``stubs``; its function annotations, attribute
    annotations and imports are passed on to the transformer.
    """
    collector = TypeCollector()
    stubs.visit(collector)
    annotator = TypeTransformer(
        visitor_functions := collector.function_annotations,
        collector.attribute_annotations,
        collector.imports,
    ) if False else TypeTransformer(
        collector.function_annotations,
        collector.attribute_annotations,
        collector.imports,
    )
    return source.visit(annotator)
 def _gen_impl(self, module: Module) -> None:
     """
     Run code generation for ``module`` with a ``_ReentrantCodegenState``.

     The state takes the module's default indent, default newline and
     encoding, and records this object as its provider.
     """
     codegen_state = _ReentrantCodegenState(
         default_indent=module.default_indent,
         default_newline=module.default_newline,
         provider=self,
         encoding=module.encoding,
     )
     # The call is made for its effect on ``codegen_state``; nothing is returned.
     module._codegen(codegen_state)
Example #5
0
 def _gen_impl(self, module: Module) -> None:
     """
     Run code generation for ``module`` with a ``SpanProvidingCodegenState``.

     The state takes the module's default indent and newline, records this
     object as its provider, and measures lengths with
     ``byte_length_in_utf8``.
     """
     span_state = SpanProvidingCodegenState(
         default_indent=module.default_indent,
         default_newline=module.default_newline,
         provider=self,
         get_length=byte_length_in_utf8,
     )
     # The call is made for its effect on ``span_state``; nothing is returned.
     module._codegen(span_state)
Example #6
0
    def visit_Module(self, node: cst.Module) -> None:
        """
        Record which names the module imports from the ``wx`` package.

        A ``GatherImportsVisitor`` is run over ``node``; the entry for ``"wx"``
        in its ``object_mapping`` is stored as ``self.wx_imports``, or an empty
        set if there is none.
        """
        import_gatherer = GatherImportsVisitor(self.context)
        node.visit(import_gatherer)

        imported_objects = import_gatherer.object_mapping
        self.wx_imports = imported_objects.get("wx", set())
 def visit_Module(self, node: cst.Module) -> bool:
     """
     Pre-compute exported names and names used in string annotations.

     Two helper visitors are run over ``node``: one collecting explicitly
     exported objects, one collecting names from string annotations.  Their
     results are stored on ``self``.  Always returns ``True``.
     """
     exports = GatherExportsVisitor(self.context)
     node.visit(exports)
     self._exported_names = exports.explicit_exported_objects

     string_annotations = GatherNamesFromStringAnnotationsVisitor(
         self.context, typing_functions=self._typing_functions)
     node.visit(string_annotations)
     self._string_annotation_names = string_annotations.names

     return True
Example #8
0
 def visit_Module(self, node: cst.Module) -> bool:
     """
     Compute qualified-name metadata for the module.

     A ``FullyQualifiedNameVisitor`` is run over ``node``, then the module
     node itself is given a single LOCAL ``QualifiedName`` equal to
     ``self.module_name``.  Always returns ``True``.
     """
     name_visitor = FullyQualifiedNameVisitor(self, self.module_name)
     node.visit(name_visitor)

     own_name = QualifiedName(
         name=self.module_name,
         source=QualifiedNameSource.LOCAL,
     )
     self.set_metadata(node, {own_name})
     return True
Example #9
0
def _annotate_source(stubs: cst.Module, source: cst.Module) -> cst.Module:
    """
    Return ``source`` visited by a ``TypeTransformer`` built from ``stubs``.

    The imports already present in ``source`` are collected first and given to
    the ``TypeCollector`` that visits ``stubs``.  The collector's function
    annotations, attribute annotations, imports and class definitions are then
    passed to the transformer.
    """
    source_imports = ImportCollector()
    source.visit(source_imports)

    collector = TypeCollector(source_imports.existing_imports)
    stubs.visit(collector)

    annotator = TypeTransformer(
        collector.function_annotations,
        collector.attribute_annotations,
        collector.imports,
        collector.class_definitions,
    )
    return source.visit(annotator)
Example #10
0
 def update_imports(
     self,
     original_module: Module,
     updated_module: Module,
     import_name: str,
     updated_import_node: SimpleStatementLine,
     current_imports: Dict[str, str],
     new_imports: Set[str],
     noqa: bool,
 ) -> Module:
     """
     Add ``new_imports`` to a ``from import_name import ...`` statement.

     * With nothing to add, ``updated_module`` is returned as is.
     * With an existing import statement (``updated_import_node``), that
       statement is replaced by one listing the current and the new names,
       unless the current imports contain ``"*"``, in which case the module
       is returned unchanged.
     * Otherwise a new import statement is parsed and spliced into the body
       after the position of ``self.last_import_node_stmt`` (or at the top
       when that attribute is falsy), and ``self.last_import_node_stmt`` is
       set to the freshly parsed module.

     ``noqa`` appends ``"  # noqa"`` to the generated statement.
     """
     if not new_imports:
         return updated_module

     noqa_comment = "  # noqa" if noqa else ""

     if updated_import_node:
         if "*" in current_imports:
             return updated_module
         current_imports_set = {
             f"{k}" if k == v else f"{k} as {v}"
             for k, v in current_imports.items()
         }
         merged_names = ", ".join(sorted(new_imports | current_imports_set))
         replacement = parse_statement(
             f"from {import_name} import {merged_names}{noqa_comment}"
         )
         return updated_module.deep_replace(updated_import_node, replacement)

     # No statement to extend: work out where a brand new one goes.
     insert_after = -1
     blank_lines = "\n\n"
     if self.last_import_node_stmt:
         blank_lines = ""
         paired_bodies = zip(original_module.body, updated_module.body)
         for insert_after, (original, _updated) in enumerate(paired_bodies):
             if original is self.last_import_node_stmt:
                 break

     new_names = ", ".join(sorted(new_imports))
     parsed = parse_module(
         f"from {import_name} import {new_names}{noqa_comment}\n{blank_lines}",
         config=updated_module.config_for_parsing,
     )
     body = list(updated_module.body)
     self.last_import_node_stmt = parsed
     split = insert_after + 1
     return updated_module.with_changes(
         body=body[:split] + parsed.children + body[split:]
     )
    def transform_module_impl(
        self,
        tree: cst.Module,
    ) -> cst.Module:
        """
        Collect type annotations from all stubs and apply them to ``tree``.

        Gather existing imports from ``tree`` so that we don't add duplicate imports.

        If the codemod scratch space has an entry under
        ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY``, its options are merged
        into this visitor and its stub is visited to collect annotations.
        Without such an entry that step is skipped; the import pass and this
        visitor still run over ``tree``.

        Returns the transformed tree when ``self.annotation_counts`` reports
        that changes were applied, otherwise the untouched input ``tree``.
        """
        import_gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(import_gatherer)
        existing_import_names = _get_imported_names(
            import_gatherer.all_imports)

        context_contents = self.context.scratch.get(
            ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
        if context_contents is not None:
            (
                stub,
                overwrite_existing_annotations,
                use_future_annotations,
                strict_posargs_matching,
                strict_annotation_matching,
            ) = context_contents
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations
                or overwrite_existing_annotations)
            self.use_future_annotations = (self.use_future_annotations
                                           or use_future_annotations)
            # Unlike the other options, this one is combined with ``and``.
            self.strict_posargs_matching = (self.strict_posargs_matching
                                            and strict_posargs_matching)
            self.strict_annotation_matching = (self.strict_annotation_matching
                                               or strict_annotation_matching)
            visitor = TypeCollector(existing_import_names, self.context)
            cst.MetadataWrapper(stub).visit(visitor)
            self.annotations.update(visitor.annotations)

        # The import pass runs unconditionally so ``tree_with_imports`` is
        # always bound; leaving it inside the branch above raised
        # UnboundLocalError whenever the scratch entry was missing.
        if self.use_future_annotations:
            AddImportsVisitor.add_needed_import(self.context, "__future__",
                                                "annotations")
        tree_with_imports = AddImportsVisitor(
            self.context).transform_module(tree)

        tree_with_changes = tree_with_imports.visit(self)

        # don't modify the imports if we didn't actually add any type information
        if self.annotation_counts.any_changes_applied():
            return tree_with_changes
        return tree
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """
        Insert pending top-level annotations and unseen class definitions.

        Class definitions from ``self.annotations`` whose name is not in
        ``self.visited_classes`` count as unseen.  If there is nothing to add,
        ``updated_node`` is returned unchanged.  Otherwise the new statements
        are placed between the two halves returned by ``_split_module``.
        """
        unseen_classes = [
            definition
            for name, definition in self.annotations.class_definitions.items()
            if name not in self.visited_classes
        ]
        if not (self.toplevel_annotations or unseen_classes):
            return updated_node

        # Locate the insertion point for the new statements.
        head, tail = self._split_module(original_node, updated_node)
        # Pass the tail through ``_insert_empty_line`` before reassembling.
        tail = self._insert_empty_line(tail)

        additions = [
            cst.SimpleStatementLine(
                [cst.AnnAssign(cst.Name(name), annotation, None)]
            )
            for name, annotation in self.toplevel_annotations.items()
        ]
        additions.extend(unseen_classes)

        return updated_node.with_changes(body=[*head, *additions, *tail])
    def leave_Module(
        self,
        original_node: cst.Module,
        updated_node: cst.Module,
    ) -> cst.Module:
        """
        Insert top-level annotations, new TypeVars and unseen classes.

        If there are no top-level annotations, no unseen class definitions and
        no typevars in ``self.annotations``, ``updated_node`` is returned
        unchanged.  Otherwise the new statements are placed between the two
        halves returned by ``_split_module``, and ``self.annotation_counts``
        is updated for the typevars and classes that were added.
        """
        unseen_classes = [
            definition
            for name, definition in self.annotations.class_definitions.items()
            if name not in self.visited_classes
        ]

        # NOTE: The entire change will also be abandoned if
        # self.annotation_counts is all 0s, so if adding any new category make
        # sure to record it there.
        has_work = (self.toplevel_annotations or unseen_classes
                    or self.annotations.typevars)
        if not has_work:
            return updated_node

        # Locate the insertion point for the new statements.
        head, tail = self._split_module(original_node, updated_node)
        # Pass the tail through ``_insert_empty_line`` before reassembling.
        tail = self._insert_empty_line(tail)

        additions = []
        for name, annotation in self.toplevel_annotations.items():
            assignment = self._apply_annotation_to_attribute_or_global(
                name=name,
                annotation=annotation,
                value=None,
            )
            additions.append(cst.SimpleStatementLine([assignment]))

        # TypeVar definitions could be scattered through the file, so do not
        # attempt to put new ones with existing ones, just add them at the top.
        new_typevar_stmts = [
            stmt
            for key, stmt in self.annotations.typevars.items()
            if key not in self.typevars
        ]
        if new_typevar_stmts:
            for stmt in new_typevar_stmts:
                additions.append(cst.Newline())
                additions.append(stmt)
                self.annotation_counts.typevars_and_generics_added += 1
            additions.append(cst.Newline())

        self.annotation_counts.classes_added = len(unseen_classes)
        additions.extend(unseen_classes)

        return updated_node.with_changes(body=[*head, *additions, *tail])
Example #14
0
    def leave_Module(
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """
        Insert missing ``from`` imports and top-level annotations.

        Returns ``original_node`` when ``self.is_generated`` is set, and
        ``updated_node`` unchanged when there are neither top-level
        annotations nor imports to add.  Names already bound by the statements
        in ``self.import_statements`` are not imported again.
        """
        if self.is_generated:
            return original_node
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        # Locate the insertion point for the new statements.
        head, tail = self._split_module(original_node, updated_node)
        # Pass the tail through ``_insert_empty_line`` before reassembling.
        tail = self._insert_empty_line(tail)

        # Collect the names bound by the recorded import statements, using
        # the ``as`` name where one is given.  Star imports are skipped.
        already_imported = set()
        for statement in self.import_statements:
            aliases = statement.names
            if isinstance(aliases, cst.ImportStar):
                continue
            for entry in aliases:
                bound = entry.asname if entry.asname else entry
                if bound:
                    already_imported.add(_get_name_as_string(bound.name))

        additions = []
        for pending in self.imports.values():
            # Filter out anything that has already been imported.
            missing = sorted(pending.names.difference(already_imported))
            if not missing:
                continue
            import_from = cst.ImportFrom(
                module=pending.module,
                names=[cst.ImportAlias(cst.Name(name)) for name in missing],
            )
            # SimpleStatementLine needs a subscriptable body, hence the list.
            additions.append(cst.SimpleStatementLine([import_from]))

        for name, annotation in self.toplevel_annotations.items():
            declaration = cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
            additions.append(cst.SimpleStatementLine([declaration]))

        return updated_node.with_changes(body=[*head, *additions, *tail])
Example #15
0
 def visit_Module(self, node: cst.Module) -> None:
     """
     Report the module when its header does not start with the expected lines.

     Nothing happens when ``self.rule_disabled`` is set or when the header
     matches ``self.header_matcher`` followed by any number of further lines.
     Otherwise a report is filed whose replacement prepends
     ``self.header_replacement`` to the existing header.
     """
     if self.rule_disabled:
         return

     expected_module = m.Module(
         header=[*self.header_matcher, m.ZeroOrMore()])
     if m.matches(node, expected_module):
         return

     self.report(
         node,
         replacement=node.with_changes(
             header=[*self.header_replacement, *node.header]),
     )
Example #16
0
 def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
     """
     Insert an annotated declaration for every pending top-level annotation.

     The insertion index comes from ``self._get_toplevel_index``.  Each
     declaration is inserted at that same index, one after another.
     """
     statements = list(updated_node.body)
     insert_at = self._get_toplevel_index(statements)
     for name, annotation in self.toplevel_annotations.items():
         declaration = cst.AnnAssign(
             cst.Name(name),
             # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
             cst.Annotation(annotation.annotation),
             None,
         )
         new_line = cst.SimpleStatementLine([declaration])
         statements.insert(insert_at, new_line)
     return updated_node.with_changes(body=tuple(statements))
Example #17
0
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        """
        Sort each sortable block of statements in the module body.

        For every block returned by ``sortable_blocks``, the leading lines of
        its first statement are split by ``partition_leading_lines``; only the
        second part is kept on that statement and the first part is passed to
        ``fixup_whitespace`` together with the sorted statements.  The result
        replaces the block's slice of the body.
        """
        body: List[cst.CSTNode] = list(updated_node.body)

        for block in sortable_blocks(updated_node.body, config=self.config):
            first_stmt = block.stmts[0]
            blank_lines, comment_lines = partition_leading_lines(
                first_stmt.node.leading_lines)
            first_stmt.node = first_stmt.node.with_changes(
                leading_lines=comment_lines)

            ordered = fixup_whitespace(blank_lines, sorted(block.stmts))
            body[block.start_idx:block.end_idx] = [
                stmt.node for stmt in ordered
            ]

        return updated_node.with_changes(body=body)
Example #18
0
def insert_header_comments(node: libcst.Module,
                           comments: List[str]) -> libcst.Module:
    """
    Insert ``comments`` after the last comment line of the module header.

    Each entry of ``comments`` becomes its own header line.  Header lines
    after the last existing comment stay below the inserted ones.  If the
    header has no comment at all, the new lines go first.
    """
    # Index one past the last header line that carries a comment; 0 if none.
    split_at = 1 + max(
        (i for i, line in enumerate(node.header) if line.comment is not None),
        default=-1,
    )

    leading = islice(node.header, split_at)
    trailing = islice(node.header, split_at, None)
    new_lines = [
        libcst.EmptyLine(comment=libcst.Comment(value=text))
        for text in comments
    ]
    return node.with_changes(header=(*leading, *new_lines, *trailing))
Example #19
0
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        """
        Insert the collected ``from`` imports and top-level annotations.

        ``updated_node`` is returned unchanged when there is nothing to add.
        Otherwise the new statements are placed between the two halves
        returned by ``_split_module``.
        """
        if not (self.toplevel_annotations or self.imports):
            return updated_node

        # Locate the insertion point for the new statements.
        head, tail = self._split_module(original_node, updated_node)
        # Pass the tail through ``_insert_empty_line`` before reassembling.
        tail = self._insert_empty_line(tail)

        additions = []

        for pending in self.imports.values():
            import_from = cst.ImportFrom(
                module=pending.module,
                # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
                #  for 2nd param but got `List[ImportFrom]`.
                names=pending.names,
            )
            # SimpleStatementLine needs a subscriptable body, hence the list.
            additions.append(cst.SimpleStatementLine([import_from]))

        for name, annotation in self.toplevel_annotations.items():
            declaration = cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
            additions.append(cst.SimpleStatementLine([declaration]))

        return updated_node.with_changes(body=[*head, *additions, *tail])
Example #20
0
    def transform_module(self, tree: Module) -> Module:
        """
        Transform entrypoint which handles multi-pass logic and metadata calculation
        for you. This is the method that you should call if you wish to invoke a
        codemod directly. This is the method that is called by
        :func:`~libcst.codemod.transform_module`.

        When multiple passes are allowed, ``transform_module_impl`` is applied
        repeatedly until a pass produces a tree that is ``deep_equals`` to its
        input.
        """
        if self.should_allow_multiple_passes():
            current: Module = tree
            while True:
                before = current
                with self._handle_metadata_reference(before) as prepared:
                    current = self.transform_module_impl(prepared)
                # A pass that changes nothing means a fixed point was reached.
                if current.deep_equals(before):
                    return current

        # Single pass only.
        with self._handle_metadata_reference(tree) as prepared:
            return self.transform_module_impl(prepared)
def with_added_imports(
        module_node: cst.Module,
        import_nodes: Sequence[Union[cst.Import,
                                     cst.ImportFrom]]) -> cst.Module:
    """
    Return ``module_node`` with ``import_nodes`` added after its first import.

    Each node is wrapped in its own ``SimpleStatementLine`` and the lines are
    placed directly after the first body statement for which
    ``_is_import_line`` is true.

    Raises ``RuntimeError`` if the module body contains no such statement.
    """
    new_body: List[Union[cst.SimpleStatementLine,
                         cst.BaseCompoundStatement]] = []
    inserted = False
    for statement in module_node.body:
        new_body.append(statement)
        if inserted or not _is_import_line(statement):
            continue
        new_body.extend(
            cst.SimpleStatementLine(body=(import_node, ))
            for import_node in import_nodes)
        inserted = True

    if not inserted:
        raise RuntimeError("Failed to add imports")

    return module_node.with_changes(body=tuple(new_body))
Example #22
0
def insert_header_comments(node: libcst.Module, comments: List[str]) -> libcst.Module:
    """
    Insert comments after last non-empty line in header. Use this to insert one or more
    comments after any copyright preamble in a :class:`~libcst.Module`. Each comment in
    the list of ``comments`` must start with a ``#`` and will be placed on its own line
    in the appropriate location.
    """
    # The header is cut right after its last comment-bearing line; everything
    # before the cut stays above the new comments, everything after goes below.
    comment_positions = [
        index for index, line in enumerate(node.header) if line.comment is not None
    ]
    boundary = comment_positions[-1] + 1 if comment_positions else 0

    above = islice(node.header, boundary)
    below = islice(node.header, boundary, None)
    additions = [
        libcst.EmptyLine(comment=libcst.Comment(value=text)) for text in comments
    ]
    # pyre-fixme[60]: Concatenation not yet support for multiple variadic tuples:
    #  `*above, *additions, *below`.
    return node.with_changes(header=(*above, *additions, *below))
Example #23
0
    def leave_Module(self, original_node: cst.Module,
                     updated_node: cst.Module) -> cst.Module:
        if not self.names or self.already_exists:
            return original_node

        modified_body = list(original_node.body)
        config = original_node.config_for_parsing

        list_of_names = f",{config.default_newline}{config.default_indent}".join(
            [repr(name) for name in sorted(self.names)])

        all_names = cst.parse_statement(
            f"""

__all__ = [
{config.default_indent}{list_of_names}
]
        """,
            config=original_node.config_for_parsing,
        )

        modified_body.append(all_names)
        return updated_node.with_changes(body=modified_body)
Example #24
0
 def visit_Module(self, node: cst.Module) -> Optional[bool]:
     """Run a ``QualifiedNameVisitor`` bound to this object over ``node``.

     Returns ``None`` implicitly.
     """
     name_visitor = QualifiedNameVisitor(self)
     node.visit(name_visitor)
Example #25
0
 def visit_Module(self, node: cst.Module) -> Optional[bool]:
     """Run a ``ScopeVisitor`` over ``node``, then let it infer accesses.

     Returns ``None`` implicitly.
     """
     scope_visitor = ScopeVisitor(self)
     node.visit(scope_visitor)
     # Called only after the whole module has been visited.
     scope_visitor.infer_accesses()
 def visit_Module(self, node: cst.Module) -> Optional[bool]:
     """Run an ``ExpressionContextVisitor`` over ``node``, starting in LOAD.

     Returns ``None`` implicitly.
     """
     context_visitor = ExpressionContextVisitor(self, ExpressionContext.LOAD)
     node.visit(context_visitor)
Example #27
0
 def test_code_for_node(self, module: cst.Module, node: cst.CSTNode,
                        expected: str) -> None:
     """Assert that ``module.code_for_node(node)`` equals ``expected``."""
     rendered = module.code_for_node(node)
     self.assertEqual(rendered, expected)
 def visit_Module(self, node: cst.Module) -> Optional[bool]:
     """Run a ``ParentNodeVisitor`` bound to this object over ``node``.

     Returns ``None`` implicitly.
     """
     parent_visitor = ParentNodeVisitor(self)
     node.visit(parent_visitor)
Example #29
0
    def leave_Module(self, original_node: libcst.Module,
                     updated_node: libcst.Module) -> libcst.Module:
        """
        Insert all pending imports into the module body.

        ``updated_node`` is returned unchanged when nothing is pending.
        Otherwise the new body is laid out in this order:

        1. statements before the imports,
        2. ``from __future__ import ...`` statements,
        3. statements up to the point where imports are added,
        4. ``import x`` and ``import x as y`` statements,
        5. ``from typing import TYPE_CHECKING`` and an ``if TYPE_CHECKING:``
           block holding every ``from`` import whose module is neither
           ``__future__`` nor in ``import_cycle_safe_module_names``,
        6. ``from`` imports of the modules in that list,
        7. the remaining statements.

        When the ``if TYPE_CHECKING:`` block would be empty, both it and its
        ``typing`` import are replaced by ``libcst.EmptyLine()`` nodes.
        """
        # Don't try to modify if we have nothing to do
        nothing_pending = not (self.module_imports or self.module_mapping
                               or self.module_aliases or self.alias_mapping)
        if nothing_pending:
            return updated_node

        # First, find the insertion point for imports
        (
            before_imports,
            until_add_imports,
            after_imports,
        ) = self._split_module(original_node, updated_node)

        # Pass the trailing statements through ``_insert_empty_line``.
        after_imports = self._insert_empty_line(after_imports)

        parse_config = updated_node.config_for_parsing

        def from_import(module, aliases):
            # Build one ``from <module> import a, b as c`` statement.
            imported = ", ".join(
                obj if alias is None else f"{obj} as {alias}"
                for (obj, alias) in aliases)
            return parse_statement(f"from {module} import " + imported,
                                   config=parse_config)

        # Per module: the (object, alias) pairs to import, alias None if plain.
        grouped = defaultdict(list)
        for module, aliases in self.alias_mapping.items():
            grouped[module].extend(aliases)
        for module, imports in self.module_mapping.items():
            grouped[module].extend((name, None) for name in imports)
        from_imports = {
            module: sorted(aliases)
            for module, aliases in grouped.items()
        }

        import_cycle_safe_module_names = [
            'mypy_extensions',
            'typing',
            'typing_extensions',
        ]

        type_checking_cond_import = parse_statement(
            "from typing import TYPE_CHECKING",
            config=parse_config,
        )
        guarded_imports = [
            from_import(module, aliases)
            for module, aliases in from_imports.items()
            if module != "__future__"
            and module not in import_cycle_safe_module_names
        ]
        type_checking_cond_statement = libcst.If(
            test=libcst.Name("TYPE_CHECKING"),
            body=libcst.IndentedBlock(body=guarded_imports),
        )
        if not type_checking_cond_statement.body.body:
            type_checking_cond_statement = libcst.EmptyLine()
            type_checking_cond_import = libcst.EmptyLine()

        future_imports = [
            from_import(module, aliases)
            for module, aliases in from_imports.items()
            if module == "__future__"
        ]
        plain_imports = [
            parse_statement(f"import {module}", config=parse_config)
            for module in sorted(self.module_imports)
        ]
        aliased_imports = [
            parse_statement(f"import {module} as {asname}",
                            config=parse_config)
            for (module, asname) in self.module_aliases.items()
        ]
        safe_imports = [
            from_import(module, aliases)
            for module, aliases in from_imports.items()
            if module != "__future__"
            and module in import_cycle_safe_module_names
            and not module.startswith("monkeytype")
        ]

        # Now, add all of the imports we need!
        # TODO: forward references could be handled further with
        # `from __future__ import annotations`; that could be added here or
        # left to another tool.
        return updated_node.with_changes(body=(
            *before_imports,
            *future_imports,
            *until_add_imports,
            *plain_imports,
            *aliased_imports,
            type_checking_cond_import,
            type_checking_cond_statement,
            *safe_imports,
            *after_imports,
        ))
Example #30
0
 def transform_module_impl(self, tree: libcst.Module) -> libcst.Module:
     """Visit ``tree`` with this object and return the result."""
     transformed = tree.visit(self)
     return transformed