def transform_module_impl(self, tree: cst.Module) -> cst.Module:
    """
    Merge annotations from the stub stored in the codemod context into this
    visitor, then apply them to ``tree``.

    The imports already present in ``tree`` are gathered first and handed to
    the stub collector, so that we don't add duplicate imports.
    """
    gatherer = GatherImportsVisitor(CodemodContext())
    tree.visit(gatherer)
    already_imported = _get_import_names(gatherer.all_imports)

    scratch_entry = self.context.scratch.get(
        ApplyTypeAnnotationsVisitor.CONTEXT_KEY
    )
    if scratch_entry:
        stub_module, overwrite_flag = scratch_entry
        # Overwriting is enabled if either this visitor or the stored context asks for it.
        self.overwrite_existing_annotations = (
            self.overwrite_existing_annotations or overwrite_flag
        )
        collector = TypeCollector(already_imported, self.context)
        stub_module.visit(collector)
        merged = self.annotations
        merged.function_annotations.update(collector.function_annotations)
        merged.attribute_annotations.update(collector.attribute_annotations)
        merged.class_definitions.update(collector.class_definitions)

    # Run AddImportsVisitor with the shared context, then apply this visitor to the result.
    return AddImportsVisitor(self.context).transform_module(tree).visit(self)
 def _add_annotation_to_imports(
     self, annotation: cst.Attribute
 ) -> Union[cst.Name, cst.Attribute]:
     """
     Decide how a dotted annotation such as ``mod.Name`` should be written.

     If ``mod`` is already imported, the dotted annotation is kept as is.
     Otherwise an import of ``Name`` from ``mod`` is queued on ``self.context``
     (when both parts resolve to names) and the bare ``Name`` node is returned.
     """
     module_name = get_full_name_for_node(annotation.value)
     if module_name is None:
         return annotation.attr
     # The module is already imported: keep the qualified form, import nothing.
     if module_name in self.existing_imports:
         return annotation
     attr_name = get_full_name_for_node(annotation.attr)
     if attr_name is not None:
         AddImportsVisitor.add_needed_import(self.context, module_name, attr_name)
     return annotation.attr
    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        """
        Queue every name imported by a ``from module import ...`` statement
        via ``AddImportsVisitor.add_needed_import`` on ``self.context``.

        Star imports and relative imports are ignored.
        """
        module = node.module
        names = node.names

        # Relative imports are ignored for now. ``module`` is None for
        # ``from .. import foo``, but for ``from .mod import foo`` it is set and
        # only ``node.relative`` (the leading dots) reveals the import is
        # relative; without that check it would be re-added as the absolute
        # ``from mod import foo``.
        if node.relative or module is None or isinstance(names, cst.ImportStar):
            return
        module_name = get_full_name_for_node(module)
        if module_name is not None:
            for import_name in _get_import_alias_names(names):
                AddImportsVisitor.add_needed_import(self.context, module_name,
                                                    import_name)
    def transform_module_impl(
        self,
        tree: cst.Module,
    ) -> cst.Module:
        """
        Collect type annotations from all stubs and apply them to ``tree``.

        Gather existing imports from ``tree`` so that we don't add duplicate
        imports. If the codemod context holds stub contents, their options are
        merged into this visitor and their annotations are collected. Needed
        imports are then added and the annotations applied. If no annotation
        change was applied, the original ``tree`` is returned so that its
        imports are left untouched.
        """
        import_gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(import_gatherer)
        existing_import_names = _get_imported_names(
            import_gatherer.all_imports)

        context_contents = self.context.scratch.get(
            ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
        if context_contents is not None:
            (
                stub,
                overwrite_existing_annotations,
                use_future_annotations,
                strict_posargs_matching,
                strict_annotation_matching,
            ) = context_contents
            # These flags are OR-ed: either this visitor or the stored
            # context can turn them on.
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations
                or overwrite_existing_annotations)
            self.use_future_annotations = (self.use_future_annotations
                                           or use_future_annotations)
            # This flag is AND-ed instead: either source can turn it off.
            self.strict_posargs_matching = (self.strict_posargs_matching
                                            and strict_posargs_matching)
            self.strict_annotation_matching = (self.strict_annotation_matching
                                               or strict_annotation_matching)
            visitor = TypeCollector(existing_import_names, self.context)
            cst.MetadataWrapper(stub).visit(visitor)
            self.annotations.update(visitor.annotations)

        # Done outside the ``if`` so that ``tree_with_imports`` is bound even
        # when the context holds no stub contents (otherwise the ``visit``
        # call below would raise UnboundLocalError).
        if self.use_future_annotations:
            AddImportsVisitor.add_needed_import(self.context, "__future__",
                                                "annotations")
        tree_with_imports = AddImportsVisitor(
            self.context).transform_module(tree)

        tree_with_changes = tree_with_imports.visit(self)

        # don't modify the imports if we didn't actually add any type information
        if self.annotation_counts.any_changes_applied():
            return tree_with_changes
        return tree
 def _handle_qualification_and_should_qualify(self,
                                              qualified_name: str) -> bool:
     """
     Basd on a qualified name and the existing module imports, record that
     we need to add an import if necessary and return whether or not we
     should use the qualified name due to a preexisting import.
     """
     split_name = qualified_name.split(".")
     if len(split_name) > 1 and qualified_name not in self.existing_imports:
         module, target = ".".join(split_name[:-1]), split_name[-1]
         if module == "builtins":
             return False
         elif module in self.existing_imports:
             return True
         else:
             AddImportsVisitor.add_needed_import(self.context, module,
                                                 target)
             return False
     return False
示例#6
0
 def _handle_qualification_and_should_qualify(self,
                                              qualified_name: str) -> bool:
     """
     Basd on a qualified name and the existing module imports, record that
     we need to add an import if necessary and return whether or not we
     should use the qualified name due to a preexisting import.
     """
     module, target = self._module_and_target(qualified_name)
     if module in ("", "builtins"):
         return False
     elif qualified_name not in self.existing_imports:
         if module == "builtins":
             return False
         elif module in self.existing_imports:
             return True
         else:
             AddImportsVisitor.add_needed_import(self.context, module,
                                                 target)
             return False
     return False