Esempio n. 1
0
    def visit_Module(self, node: libcst.Module) -> None:
        """Prune pending import work down to what ``node`` does not already have.

        A ``GatherImportsVisitor`` pass over ``node`` is made first; its
        ``all_imports`` are stored on ``self.all_imports``. Every pending
        request already satisfied by the module is then removed from
        ``self.module_imports``, ``self.module_aliases``,
        ``self.alias_mapping`` and ``self.module_mapping``.
        """
        existing = GatherImportsVisitor(self.context)
        node.visit(existing)
        self.all_imports = existing.all_imports

        # Plain ``import x`` requests that are already present: drop them.
        # (A new set is built rather than mutating the old one in place.)
        self.module_imports = self.module_imports - existing.module_imports

        # Module-level aliases only count as satisfied when the alias matches.
        for mod_name, present_alias in existing.module_aliases.items():
            if mod_name not in self.module_aliases:
                continue
            if self.module_aliases[mod_name] == present_alias:
                del self.module_aliases[mod_name]

        # Aliased object imports: remove each matching (object, alias) pair,
        # and forget the module entirely once no pairs remain for it.
        for mod_name, pairs in existing.alias_mapping.items():
            for obj_name, obj_alias in pairs:
                if mod_name not in self.alias_mapping:
                    continue
                pending = self.alias_mapping[mod_name]
                if (obj_name, obj_alias) not in pending:
                    continue
                pending.remove((obj_name, obj_alias))
                if len(pending) == 0:
                    del self.alias_mapping[mod_name]

        # Object imports: a star import covers everything requested from that
        # module; otherwise keep only the names not yet imported.
        for mod_name, names in existing.object_mapping.items():
            if mod_name not in self.module_mapping:
                # Nothing was requested from this module.
                continue
            if "*" in names:
                del self.module_mapping[mod_name]
                continue
            remaining = self.module_mapping[mod_name] - names
            self.module_mapping[mod_name] = remaining
            if not remaining:
                # Every requested name is already imported.
                del self.module_mapping[mod_name]
Esempio n. 2
0
    def transform_module_impl(self, tree: cst.Module) -> cst.Module:
        """
        Merge type annotations from the stub stored in the context into ``tree``.

        The imports already present in ``tree`` are gathered first and handed to
        the ``TypeCollector`` so duplicate imports are not added. If the codemod
        context holds a ``(stub, overwrite_flag)`` entry under
        ``ApplyTypeAnnotationsVisitor.CONTEXT_KEY``, the stub's function,
        attribute and class annotations are merged into ``self.annotations``.
        The tree is then run through ``AddImportsVisitor`` and finally visited
        by this transformer.
        """
        # A fresh CodemodContext keeps this gathering pass separate from self.context.
        gatherer = GatherImportsVisitor(CodemodContext())
        tree.visit(gatherer)
        already_imported = _get_import_names(gatherer.all_imports)

        scratch_entry = self.context.scratch.get(
            ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
        if scratch_entry is not None:
            stub, overwrite_flag = scratch_entry
            # Once either source asks to overwrite, overwriting stays enabled.
            self.overwrite_existing_annotations = (
                self.overwrite_existing_annotations or overwrite_flag)

            collector = TypeCollector(already_imported, self.context)
            stub.visit(collector)
            for field_name in ("function_annotations",
                               "attribute_annotations",
                               "class_definitions"):
                getattr(self.annotations, field_name).update(
                    getattr(collector, field_name))

        with_imports = AddImportsVisitor(self.context).transform_module(tree)
        return with_imports.visit(self)
Esempio n. 3
0
    def visit_Module(self, node: cst.Module) -> None:
        """Remember which symbols ``node`` already imports from the ``wx`` package.

        The result is stored on ``self.wx_imports``. It is an empty set when
        the gathering pass finds no ``wx`` entry.
        """
        import_info = GatherImportsVisitor(self.context)
        node.visit(import_info)

        mapping = import_info.object_mapping
        self.wx_imports = mapping["wx"] if "wx" in mapping else set()
Esempio n. 4
0
 def gather_imports(self, code: str) -> GatherImportsVisitor:
     """Run a ``GatherImportsVisitor`` over ``code`` and return the visitor.

     ``code`` is normalised with ``CodemodTest.make_fixture_data`` before
     parsing. The visitor's context uses the module name ``a.b.foobar``.
     """
     context = CodemodContext(full_module_name="a.b.foobar")
     visitor = GatherImportsVisitor(context)
     fixture_source = CodemodTest.make_fixture_data(code)
     module = parse_module(fixture_source)
     module.visit(visitor)
     return visitor