def transform_module_impl(self, tree: cst.Module) -> cst.Module:
    """
    Apply the type annotations collected from the stored stub to ``tree``.

    The imports already present in ``tree`` are gathered first and handed to
    ``TypeCollector`` so that duplicate imports are not requested.
    """
    gatherer = GatherImportsVisitor(CodemodContext())
    tree.visit(gatherer)
    known_imports = _get_import_names(gatherer.all_imports)

    scratch_entry = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
    if scratch_entry:
        stub_module, overwrite_flag = scratch_entry
        # Overwriting is on if either this instance or the stored entry enables it.
        self.overwrite_existing_annotations = (
            self.overwrite_existing_annotations or overwrite_flag
        )
        collector = TypeCollector(known_imports, self.context)
        stub_module.visit(collector)
        merged = self.annotations
        merged.function_annotations.update(collector.function_annotations)
        merged.attribute_annotations.update(collector.attribute_annotations)
        merged.class_definitions.update(collector.class_definitions)

    # Run AddImportsVisitor over ``tree`` using ``self.context``, then apply this
    # transformer to the result.
    return AddImportsVisitor(self.context).transform_module(tree).visit(self)
def _add_annotation_to_imports(
    self, annotation: cst.Attribute
) -> Union[cst.Name, cst.Attribute]:
    """
    Register the import that lets ``annotation`` be written as its bare attribute.

    ``annotation`` is a dotted reference such as ``a.b.C``. The outcomes are:

    * If its module part (``a.b``) is in ``self.existing_imports``, no
      re-import is attempted and ``annotation`` is returned unchanged.
    * Otherwise an import of ``C`` from ``a.b`` is recorded with
      ``AddImportsVisitor.add_needed_import``, and the bare ``C`` node is
      returned.
    * If the module part or the attribute cannot be resolved to a dotted name,
      no import can be recorded. The qualified ``annotation`` is then returned
      unchanged, rather than a bare name that nothing imports.
    """
    module_name = get_full_name_for_node(annotation.value)
    if module_name is None:
        # Stripping the qualifier without an import would leave a dangling name.
        return annotation
    # Don't attempt to re-import existing imports.
    if module_name in self.existing_imports:
        return annotation
    import_name = get_full_name_for_node(annotation.attr)
    if import_name is None:
        return annotation
    AddImportsVisitor.add_needed_import(self.context, module_name, import_name)
    return annotation.attr
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
    """
    Record every name of a ``from module import ...`` statement.

    Each name is registered on ``self.context`` via
    ``AddImportsVisitor.add_needed_import``. Star imports and relative imports
    are skipped, as is any import whose module name cannot be resolved by
    ``get_full_name_for_node``.
    """
    module = node.module
    names = node.names
    # Relative imports are ignored for now. ``module`` is None for imports like
    # ``from .. import foo``. However, ``from .foo import bar`` has module ``foo``,
    # and its leading dots live only in ``node.relative``. Checking ``relative``
    # keeps such an import from being re-recorded as ``from foo import bar``.
    if node.relative or module is None or isinstance(names, cst.ImportStar):
        return
    module_name = get_full_name_for_node(module)
    if module_name is None:
        return
    for import_name in _get_import_alias_names(names):
        AddImportsVisitor.add_needed_import(self.context, module_name, import_name)
def transform_module_impl(
    self,
    tree: cst.Module,
) -> cst.Module:
    """
    Collect type annotations from all stubs and apply them to ``tree``.

    Gather existing imports from ``tree`` so that we don't add duplicate imports.

    A scratch entry may be stored under ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY``.
    If present, its settings are merged with this instance's own:

    * ``overwrite_existing_annotations``, ``use_future_annotations`` and
      ``strict_annotation_matching`` are enabled if either side enables them.
    * ``strict_posargs_matching`` stays enabled only if both sides enable it.

    The stub's annotations are also merged into ``self.annotations``.

    Returns the transformed module. If no annotation changes were applied, the
    original ``tree`` is returned untouched, imports included.
    """
    import_gatherer = GatherImportsVisitor(CodemodContext())
    tree.visit(import_gatherer)
    existing_import_names = _get_imported_names(import_gatherer.all_imports)

    context_contents = self.context.scratch.get(
        ApplyTypeAnnotationsVisitor.CONTEXT_KEY
    )
    if context_contents is not None:
        (
            stub,
            overwrite_existing_annotations,
            use_future_annotations,
            strict_posargs_matching,
            strict_annotation_matching,
        ) = context_contents
        self.overwrite_existing_annotations = (
            self.overwrite_existing_annotations or overwrite_existing_annotations
        )
        self.use_future_annotations = (
            self.use_future_annotations or use_future_annotations
        )
        self.strict_posargs_matching = (
            self.strict_posargs_matching and strict_posargs_matching
        )
        self.strict_annotation_matching = (
            self.strict_annotation_matching or strict_annotation_matching
        )
        visitor = TypeCollector(existing_import_names, self.context)
        cst.MetadataWrapper(stub).visit(visitor)
        self.annotations.update(visitor.annotations)

    # Kept outside the branch above so that ``tree_with_imports`` is always bound.
    # Otherwise, with no stub entry in the scratch space, the
    # ``tree_with_imports.visit(self)`` call below would raise UnboundLocalError.
    if self.use_future_annotations:
        AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
    tree_with_imports = AddImportsVisitor(self.context).transform_module(tree)
    tree_with_changes = tree_with_imports.visit(self)

    # don't modify the imports if we didn't actually add any type information
    if self.annotation_counts.any_changes_applied():
        return tree_with_changes
    return tree
def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool: """ Basd on a qualified name and the existing module imports, record that we need to add an import if necessary and return whether or not we should use the qualified name due to a preexisting import. """ split_name = qualified_name.split(".") if len(split_name) > 1 and qualified_name not in self.existing_imports: module, target = ".".join(split_name[:-1]), split_name[-1] if module == "builtins": return False elif module in self.existing_imports: return True else: AddImportsVisitor.add_needed_import(self.context, module, target) return False return False
def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool: """ Basd on a qualified name and the existing module imports, record that we need to add an import if necessary and return whether or not we should use the qualified name due to a preexisting import. """ module, target = self._module_and_target(qualified_name) if module in ("", "builtins"): return False elif qualified_name not in self.existing_imports: if module == "builtins": return False elif module in self.existing_imports: return True else: AddImportsVisitor.add_needed_import(self.context, module, target) return False return False