예제 #1
0
    def visit_Module(self, node: libcst.Module) -> None:
        """
        Drop every pending import request that ``node`` already satisfies.

        A ``GatherImportsVisitor`` is run over the module first. Whatever it
        finds is then subtracted from this visitor's pending work:

        * ``module_imports`` loses the modules the gatherer reports.
        * ``module_aliases`` loses a module's entry only when the requested
          alias equals the one already present.
        * ``alias_mapping`` loses each (object, alias) pair already present.
          A module's entry is dropped once its pair list becomes empty.
        * ``module_mapping`` loses the names already imported from a module.
          The module's entry is dropped entirely when the gatherer reports a
          ``"*"`` import for it or when nothing is left to import.

        ``self.all_imports`` is set to the gatherer's ``all_imports``.
        """
        existing = GatherImportsVisitor(self.context)
        node.visit(existing)
        self.all_imports = existing.all_imports

        self.module_imports = self.module_imports - existing.module_imports

        for mod_name, present_alias in existing.module_aliases.items():
            if mod_name not in self.module_aliases:
                continue
            # Only forget the request if it asks for that very same alias.
            if self.module_aliases[mod_name] == present_alias:
                del self.module_aliases[mod_name]

        for mod_name, pairs in existing.alias_mapping.items():
            for obj, alias_name in pairs:
                if mod_name not in self.alias_mapping:
                    continue
                pending_pairs = self.alias_mapping[mod_name]
                if (obj, alias_name) not in pending_pairs:
                    continue
                pending_pairs.remove((obj, alias_name))
                if not pending_pairs:
                    # Nothing left for this module; drop the work item.
                    del self.alias_mapping[mod_name]

        for mod_name, imported_names in existing.object_mapping.items():
            if mod_name not in self.module_mapping:
                # No pending work for this module, so it is irrelevant.
                continue
            if "*" in imported_names:
                # A star import already brings in everything we wanted.
                del self.module_mapping[mod_name]
                continue
            remaining = self.module_mapping[mod_name] - imported_names
            if remaining:
                self.module_mapping[mod_name] = remaining
            else:
                # Every requested name is already imported.
                del self.module_mapping[mod_name]
    def transform_module_impl(self, tree: cst.Module) -> cst.Module:
        """
        Apply stub-derived type annotations to ``tree``.

        First, compute the names that ``tree`` already imports and pass them
        to ``TypeCollector``, so that annotations taken from the stub do not
        lead to duplicate imports. If the context scratch holds a
        ``(stub, overwrite_existing_annotations)`` entry under
        ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY``, merge the stub's function,
        attribute and class annotations into ``self.annotations``.

        Finally, run ``AddImportsVisitor`` over ``tree`` and then this
        transformer over the result, and return that module.
        """
        gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(gatherer)
        already_imported = _get_import_names(gatherer.all_imports)

        scratch_entry = self.context.scratch.get(
            ApplyTypeAnnotationsVisitor.CONTEXT_KEY
        )
        if scratch_entry:
            stub, overwrite_flag = scratch_entry
            # Overwriting is on if either this instance or the scratch entry asks for it.
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations or overwrite_flag
            )

            collector = TypeCollector(already_imported, self.context)
            stub.visit(collector)

            merged = self.annotations
            merged.function_annotations.update(collector.function_annotations)
            merged.attribute_annotations.update(collector.attribute_annotations)
            merged.class_definitions.update(collector.class_definitions)

        import_adder = AddImportsVisitor(self.context)
        return import_adder.transform_module(tree).visit(self)
예제 #3
0
    def transform_module_impl(
        self,
        tree: cst.Module,
    ) -> cst.Module:
        """
        Collect type annotations from all stubs and apply them to ``tree``.

        Gather existing imports from ``tree`` so that we don't add duplicate imports.

        If the context scratch holds an entry under
        ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY``, it is unpacked as
        ``(stub, overwrite_existing_annotations, use_future_annotations,
        strict_posargs_matching, strict_annotation_matching)``. The flags are
        merged into this instance, the annotations collected from ``stub`` are
        added to ``self.annotations``, and a ``from __future__ import
        annotations`` is requested when ``use_future_annotations`` ends up set.

        Returns the transformed module if ``self.annotation_counts`` reports
        that at least one change was applied. Otherwise returns the original
        ``tree`` untouched, so imports are not modified when no type
        information was added.
        """
        import_gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(import_gatherer)
        existing_import_names = _get_imported_names(
            import_gatherer.all_imports)

        context_contents = self.context.scratch.get(
            ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
        if context_contents is not None:
            (
                stub,
                overwrite_existing_annotations,
                use_future_annotations,
                strict_posargs_matching,
                strict_annotation_matching,
            ) = context_contents
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations
                or overwrite_existing_annotations)
            self.use_future_annotations = (self.use_future_annotations
                                           or use_future_annotations)
            # Combined with ``and`` rather than ``or``: either the instance or
            # the scratch entry can turn strict positional-argument matching
            # off. NOTE(review): presumably because it defaults to True —
            # confirm against the constructor.
            self.strict_posargs_matching = (self.strict_posargs_matching
                                            and strict_posargs_matching)
            self.strict_annotation_matching = (self.strict_annotation_matching
                                               or strict_annotation_matching)
            visitor = TypeCollector(existing_import_names, self.context)
            cst.MetadataWrapper(stub).visit(visitor)
            self.annotations.update(visitor.annotations)

            if self.use_future_annotations:
                AddImportsVisitor.add_needed_import(self.context, "__future__",
                                                    "annotations")

        # Build the import-augmented tree unconditionally. It used to be
        # assigned only inside the ``if`` above, so a missing scratch entry
        # raised UnboundLocalError on the ``visit`` call below.
        tree_with_imports = AddImportsVisitor(
            self.context).transform_module(tree)
        tree_with_changes = tree_with_imports.visit(self)

        # don't modify the imports if we didn't actually add any type information
        if self.annotation_counts.any_changes_applied():
            return tree_with_changes
        else:
            return tree