def visit_Module(self, node: libcst.Module) -> None:
    """
    Remove from the pending work every import that ``node`` already has.

    A ``GatherImportsVisitor`` is run over the module first. Everything it
    found is stored on ``self.all_imports``. The queued requests in
    ``module_imports``, ``module_aliases``, ``alias_mapping`` and
    ``module_mapping`` are then pruned so that only imports still missing
    from the module remain.
    """
    existing = GatherImportsVisitor(self.context)
    node.visit(existing)
    self.all_imports = existing.all_imports

    # Plain module imports that are already present need no further work.
    self.module_imports = self.module_imports - existing.module_imports

    # Aliased module imports: drop the request only on an exact alias match.
    for mod_name, existing_alias in existing.module_aliases.items():
        if mod_name not in self.module_aliases:
            continue
        if self.module_aliases[mod_name] == existing_alias:
            del self.module_aliases[mod_name]

    # Aliased object imports: remove each (object, alias) pair already present.
    for mod_name, pairs in existing.alias_mapping.items():
        for obj_name, alias_name in pairs:
            entry = (obj_name, alias_name)
            if mod_name not in self.alias_mapping:
                continue
            pending = self.alias_mapping[mod_name]
            if entry not in pending:
                continue
            pending.remove(entry)
            # Once nothing is left for this module, drop its work item.
            if len(pending) == 0:
                del self.alias_mapping[mod_name]

    # Object imports: subtract what the module already imports by name.
    for mod_name, present in existing.object_mapping.items():
        if mod_name not in self.module_mapping:
            # Nothing is queued for this module, so its imports don't matter.
            continue
        if "*" in present:
            # A star import already covers every name we could request.
            del self.module_mapping[mod_name]
            continue
        remaining = self.module_mapping[mod_name] - present
        if remaining:
            self.module_mapping[mod_name] = remaining
        else:
            # Everything requested is already imported; drop the work item.
            del self.module_mapping[mod_name]
def transform_module_impl(self, tree: cst.Module) -> cst.Module:
    """
    Merge stub annotations into this visitor and apply them to ``tree``.

    The names ``tree`` already imports are gathered up front so that no
    duplicate imports are added. If the codemod scratch holds a
    ``(stub, overwrite_existing_annotations)`` pair under
    ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY``:

    - the overwrite flag is OR-ed into ``self.overwrite_existing_annotations``;
    - a ``TypeCollector`` is run over the stub;
    - the collector's function, attribute and class-definition results are
      merged into ``self.annotations``.

    ``tree`` is then passed through ``AddImportsVisitor``, and the result is
    visited with ``self`` and returned.
    """
    gathered = GatherImportsVisitor(CodemodContext())
    tree.visit(gathered)
    already_imported = _get_import_names(gathered.all_imports)

    scratch_entry = self.context.scratch.get(
        ApplyTypeAnnotationsVisitor.CONTEXT_KEY
    )
    if scratch_entry is not None:
        stub, overwrite_flag = scratch_entry
        # A flag set on either this visitor or the scratch entry wins.
        self.overwrite_existing_annotations = (
            self.overwrite_existing_annotations or overwrite_flag
        )

        collector = TypeCollector(already_imported, self.context)
        stub.visit(collector)

        merged = self.annotations
        merged.function_annotations.update(collector.function_annotations)
        merged.attribute_annotations.update(collector.attribute_annotations)
        merged.class_definitions.update(collector.class_definitions)

    return AddImportsVisitor(self.context).transform_module(tree).visit(self)
def visit_Module(self, node: cst.Module) -> None:
    """
    Record the symbols that ``node`` imports from the ``wx`` package.

    The symbols are read from the gatherer's ``object_mapping["wx"]`` and
    stored on ``self.wx_imports``. An empty set is stored when there is no
    ``"wx"`` entry.
    """
    collector = GatherImportsVisitor(self.context)
    node.visit(collector)
    by_module = collector.object_mapping
    self.wx_imports = by_module["wx"] if "wx" in by_module else set()
def gather_imports(self, code: str) -> GatherImportsVisitor:
    """
    Parse fixture ``code`` and return a ``GatherImportsVisitor`` that has visited it.

    The source goes through ``CodemodTest.make_fixture_data`` before parsing.
    The gatherer's context uses ``full_module_name="a.b.foobar"`` as the
    module the code is treated as living in.
    """
    module = parse_module(CodemodTest.make_fixture_data(code))
    gatherer = GatherImportsVisitor(CodemodContext(full_module_name="a.b.foobar"))
    module.visit(gatherer)
    return gatherer