def visit_Module(self, node: libcst.Module) -> None:
    """Drop pending import work items that ``node`` already satisfies.

    A ``GatherImportsVisitor`` pre-pass collects the imports present in the
    module. Its ``all_imports`` is stored on ``self.all_imports``. Then each
    pending collection (``module_imports``, ``module_aliases``,
    ``alias_mapping``, ``module_mapping``) is pruned so only imports that are
    still missing remain.
    """
    existing = GatherImportsVisitor(self.context)
    node.visit(existing)
    self.all_imports = existing.all_imports

    # ``import x``: keep only modules not imported yet.
    self.module_imports = self.module_imports - existing.module_imports

    # ``import x as y``: forget requests whose exact alias is already present.
    for mod_name, present_alias in existing.module_aliases.items():
        if mod_name in self.module_aliases and self.module_aliases[mod_name] == present_alias:
            del self.module_aliases[mod_name]

    # ``from x import a as b``: remove satisfied (object, alias) pairs and
    # drop the module entry entirely once nothing is left for it.
    for mod_name, present_pairs in existing.alias_mapping.items():
        for (obj_name, obj_alias) in present_pairs:
            wanted = (obj_name, obj_alias)
            if mod_name not in self.alias_mapping or wanted not in self.alias_mapping[mod_name]:
                continue
            self.alias_mapping[mod_name].remove(wanted)
            if len(self.alias_mapping[mod_name]) == 0:
                del self.alias_mapping[mod_name]

    # ``from x import a``: subtract what is already imported.
    for mod_name, present_names in existing.object_mapping.items():
        if mod_name not in self.module_mapping:
            # Not something we were asked to import.
            continue
        if "*" in present_names:
            # A star import already brings in everything from this module.
            del self.module_mapping[mod_name]
            continue
        still_needed = self.module_mapping[mod_name] - present_names
        if still_needed:
            self.module_mapping[mod_name] = still_needed
        else:
            # Nothing left to add for this module.
            del self.module_mapping[mod_name]
def transform_module_impl(self, tree: cst.Module) -> cst.Module:
    """
    Apply type annotations collected from the stub stored in the codemod
    context to ``tree``.

    The imports already present in ``tree`` are gathered first and handed to
    the ``TypeCollector``, so the annotations do not cause duplicate imports.
    """
    present_imports = GatherImportsVisitor(CodemodContext())
    tree.visit(present_imports)
    imported_names = _get_import_names(present_imports.all_imports)

    stored = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
    if stored:
        stub, overwrite_flag = stored
        # Overwriting is enabled if either the constructor or the context asks for it.
        self.overwrite_existing_annotations = (
            self.overwrite_existing_annotations or overwrite_flag
        )
        collector = TypeCollector(imported_names, self.context)
        stub.visit(collector)
        collected = self.annotations
        collected.function_annotations.update(collector.function_annotations)
        collected.attribute_annotations.update(collector.attribute_annotations)
        collected.class_definitions.update(collector.class_definitions)

    with_imports = AddImportsVisitor(self.context).transform_module(tree)
    return with_imports.visit(self)
def _annotate_source(stubs: cst.Module, source: cst.Module) -> cst.Module:
    """Return ``source`` rewritten with the annotations found in ``stubs``.

    A ``TypeCollector`` walks the stub module. Its function annotations,
    attribute annotations and imports are then handed to a ``TypeTransformer``
    that rewrites ``source``.
    """
    collector = TypeCollector()
    stubs.visit(collector)
    return source.visit(
        TypeTransformer(
            collector.function_annotations,
            collector.attribute_annotations,
            collector.imports,
        )
    )
def _gen_impl(self, module: Module) -> None:
    """Run code generation for ``module`` with a ``_ReentrantCodegenState``.

    The state copies the module's default indent, default newline and encoding,
    and is given ``self`` as its provider.
    """
    codegen_state = _ReentrantCodegenState(
        provider=self,
        default_indent=module.default_indent,
        default_newline=module.default_newline,
        encoding=module.encoding,
    )
    module._codegen(codegen_state)
def _gen_impl(self, module: Module) -> None:
    """Run code generation for ``module`` with a ``SpanProvidingCodegenState``.

    Lengths are measured with ``byte_length_in_utf8``, and ``self`` is passed as
    the provider.
    """
    codegen_state = SpanProvidingCodegenState(
        provider=self,
        get_length=byte_length_in_utf8,
        default_indent=module.default_indent,
        default_newline=module.default_newline,
    )
    module._codegen(codegen_state)
def visit_Module(self, node: cst.Module) -> None:
    """Remember which objects the module already imports from ``wx``.

    ``self.wx_imports`` is set to the gathered ``object_mapping`` entry for
    ``"wx"``, or to an empty set when there is none.
    """
    import_scan = GatherImportsVisitor(self.context)
    node.visit(import_scan)
    from_wx = import_scan.object_mapping.get("wx", set())
    self.wx_imports = from_wx
def visit_Module(self, node: cst.Module) -> bool:
    """Pre-scan the module for explicitly exported names and string-annotation names.

    Stores the results on ``self._exported_names`` and
    ``self._string_annotation_names``, then returns ``True``.
    """
    exports_scan = GatherExportsVisitor(self.context)
    node.visit(exports_scan)
    self._exported_names = exports_scan.explicit_exported_objects

    string_annotation_scan = GatherNamesFromStringAnnotationsVisitor(
        self.context,
        typing_functions=self._typing_functions,
    )
    node.visit(string_annotation_scan)
    self._string_annotation_names = string_annotation_scan.names

    return True
def visit_Module(self, node: cst.Module) -> bool:
    """Walk the module with ``FullyQualifiedNameVisitor``, then tag the module node itself.

    The module node's metadata becomes a one-element set holding
    ``QualifiedName(self.module_name, LOCAL)``. Returns ``True``.
    """
    node.visit(FullyQualifiedNameVisitor(self, self.module_name))
    own_name = QualifiedName(name=self.module_name, source=QualifiedNameSource.LOCAL)
    self.set_metadata(node, {own_name})
    return True
def _annotate_source(stubs: cst.Module, source: cst.Module) -> cst.Module:
    """Return ``source`` rewritten with the annotations found in ``stubs``.

    The imports already present in ``source`` are collected first and passed to
    the ``TypeCollector``. The collector's function annotations, attribute
    annotations, imports and class definitions then drive a ``TypeTransformer``
    over ``source``.
    """
    present = ImportCollector()
    source.visit(present)

    collector = TypeCollector(present.existing_imports)
    stubs.visit(collector)

    rewriter = TypeTransformer(
        collector.function_annotations,
        collector.attribute_annotations,
        collector.imports,
        collector.class_definitions,
    )
    return source.visit(rewriter)
def update_imports(
    self,
    original_module: Module,
    updated_module: Module,
    import_name: str,
    updated_import_node: SimpleStatementLine,
    current_imports: Dict[str, str],
    new_imports: Set[str],
    noqa: bool,
) -> Module:
    """Add ``new_imports`` from ``import_name`` to ``updated_module``.

    Two paths:

    * ``updated_import_node`` is falsy: a new
      ``from <import_name> import <sorted new_imports>`` statement is parsed,
      optionally suffixed with ``# noqa``. It is spliced into the body after the
      statement of ``original_module`` that is identical (``is``) to
      ``self.last_import_node_stmt``. The parsed module is then stored back in
      ``self.last_import_node_stmt``.
    * Otherwise the existing import node is replaced with one importing the
      union of ``new_imports`` and ``current_imports`` (each rendered as ``k``
      or ``k as v``). This path is skipped when ``current_imports`` has a
      ``"*"`` key, and the module is returned unchanged.

    Args:
        original_module: Module before transformation; only its body is scanned
            for the insertion point.
        updated_module: Module to modify and return.
        import_name: Module path used after ``from``.
        updated_import_node: Existing import statement to extend, if any.
        current_imports: Names already imported by ``updated_import_node``,
            mapping imported name to bound name.
        new_imports: Names to add.
        noqa: Whether to append ``  # noqa`` to the generated import.

    Returns:
        The updated module, or ``updated_module`` itself when there is nothing
        to add.
    """
    if not new_imports:
        return updated_module
    noqa_comment = " # noqa" if noqa else ""
    if not updated_import_node:
        # -1 means "insert at the very start" if the loop below never runs
        # (empty body).
        i = -1
        # Only the first generated import gets trailing blank lines; later ones
        # (once last_import_node_stmt is set) get none.
        blank_lines = "\n\n"
        if self.last_import_node_stmt:
            blank_lines = ""
        # NOTE(review): last_import_node_stmt is assigned below to the Module
        # returned by parse_module, but is compared here by identity against
        # statements of original_module.body. Unless it is set elsewhere to a
        # body statement, this never matches, so `i` ends at the last index of
        # the zipped bodies and the import is appended after the last
        # statement. Confirm this is intended.
        for i, (original, updated) in enumerate(
            zip(original_module.body, updated_module.body)
        ):
            if original is self.last_import_node_stmt:
                break
        stmt = parse_module(
            f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
            config=updated_module.config_for_parsing,
        )
        body = list(updated_module.body)
        self.last_import_node_stmt = stmt
        # Splice the parsed module's children in right after index `i`.
        return updated_module.with_changes(
            body=body[: i + 1] + stmt.children + body[i + 1 :]
        )
    else:
        # A star import already covers everything, so there is nothing to extend.
        if "*" not in current_imports:
            current_imports_set = {
                f"{k}" if k == v else f"{k} as {v}"
                for k, v in current_imports.items()
            }
            stmt = parse_statement(
                f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}"
            )
            return updated_module.deep_replace(updated_import_node, stmt)
        # NOTE(review): dead alternative implementation kept by the author;
        # it references an undefined `original_import_node`.
        # for i, (original, updated) in enumerate(
        #     zip(original_module.body, updated_module.body)
        # ):
        #     if original is original_import_node:
        #         body = list(updated_module.body)
        #         return updated_module.with_changes(
        #             body=body[:i] + [stmt] + body[i + 1 :]
        #         )
    return updated_module
def transform_module_impl(
    self,
    tree: cst.Module,
) -> cst.Module:
    """
    Apply type annotations from the stub stored in the codemod context to ``tree``.

    Imports already present in ``tree`` are gathered first so the annotations
    do not introduce duplicate imports. If no annotation was actually applied,
    the untouched ``tree`` is returned, so import edits alone never change the
    file.
    """
    gatherer = GatherImportsVisitor(CodemodContext())
    tree.visit(gatherer)
    imported_names = _get_imported_names(gatherer.all_imports)

    settings = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY)
    if settings is not None:
        (
            stub,
            overwrite_flag,
            future_flag,
            posargs_flag,
            annotation_flag,
        ) = settings
        # Merge context options with the constructor's. Most switch on if
        # either side asks; positional-arg strictness stays on only if both do.
        self.overwrite_existing_annotations = self.overwrite_existing_annotations or overwrite_flag
        self.use_future_annotations = self.use_future_annotations or future_flag
        self.strict_posargs_matching = self.strict_posargs_matching and posargs_flag
        self.strict_annotation_matching = self.strict_annotation_matching or annotation_flag

        collector = TypeCollector(imported_names, self.context)
        cst.MetadataWrapper(stub).visit(collector)
        self.annotations.update(collector.annotations)

    if self.use_future_annotations:
        AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
    annotated = AddImportsVisitor(self.context).transform_module(tree).visit(self)

    # Don't modify the imports if we didn't actually add any type information.
    if not self.annotation_counts.any_changes_applied():
        return tree
    return annotated
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Insert top-level annotated declarations and unseen stub classes after the imports.

    Class definitions whose names are not in ``self.visited_classes`` are
    "fresh". If there are neither top-level annotations nor fresh classes,
    ``updated_node`` is returned unchanged.
    """
    new_classes = [
        class_def
        for class_name, class_def in self.annotations.class_definitions.items()
        if class_name not in self.visited_classes
    ]
    if not (self.toplevel_annotations or new_classes):
        return updated_node

    head, tail = self._split_module(original_node, updated_node)
    # Keep at least one empty line before the first non-import statement.
    tail = self._insert_empty_line(tail)

    inserted = [
        cst.SimpleStatementLine([cst.AnnAssign(cst.Name(var_name), annotation, None)])
        for var_name, annotation in self.toplevel_annotations.items()
    ]
    inserted += new_classes

    return updated_node.with_changes(body=[*head, *inserted, *tail])
def leave_Module(
    self,
    original_node: cst.Module,
    updated_node: cst.Module,
) -> cst.Module:
    """Insert global annotations, missing TypeVars and unseen stub classes after the imports.

    ``self.annotation_counts.typevars_and_generics_added`` is incremented once
    per TypeVar added, and ``classes_added`` is set to the number of fresh
    classes. Returns ``updated_node`` unchanged when there is nothing to add.
    """
    new_classes = [
        class_def
        for class_name, class_def in self.annotations.class_definitions.items()
        if class_name not in self.visited_classes
    ]
    # NOTE: The entire change will also be abandoned if
    # self.annotation_counts is all 0s, so if adding any new category make
    # sure to record it there.
    if not self.toplevel_annotations and not new_classes and not self.annotations.typevars:
        return updated_node

    head, tail = self._split_module(original_node, updated_node)
    # Keep at least one empty line before the first non-import statement.
    tail = self._insert_empty_line(tail)

    inserted = []
    for var_name, annotation in self.toplevel_annotations.items():
        declaration = self._apply_annotation_to_attribute_or_global(
            name=var_name,
            annotation=annotation,
            value=None,
        )
        inserted.append(cst.SimpleStatementLine([declaration]))

    # TypeVar definitions may be scattered through the file, so new ones are
    # not merged with existing ones; they are simply added at the top.
    missing_typevars = {
        tv_name: tv_stmt
        for tv_name, tv_stmt in self.annotations.typevars.items()
        if tv_name not in self.typevars
    }
    if missing_typevars:
        for tv_stmt in missing_typevars.values():
            inserted.extend((cst.Newline(), tv_stmt))
            self.annotation_counts.typevars_and_generics_added += 1
        inserted.append(cst.Newline())

    self.annotation_counts.classes_added = len(new_classes)
    inserted.extend(new_classes)

    return updated_node.with_changes(body=[*head, *inserted, *tail])
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Insert the missing ``from ... import`` statements and top-level annotations.

    Generated files are returned as ``original_node``. Names already bound by
    the statements in ``self.import_statements`` (using the ``as`` name when
    present; star imports are skipped) are filtered out of each pending import.
    """
    if self.is_generated:
        return original_node
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    head, tail = self._split_module(original_node, updated_node)
    # Keep at least one empty line before the first non-import statement.
    tail = self._insert_empty_line(tail)

    already_bound = set()
    for existing in self.import_statements:
        aliases = existing.names
        if isinstance(aliases, cst.ImportStar):
            continue
        for alias in aliases:
            target = alias.asname if alias.asname else alias
            if target:
                already_bound.add(_get_name_as_string(target.name))

    additions = []
    for pending in self.imports.values():
        missing = sorted(pending.names.difference(already_bound))
        if not missing:
            continue
        rebuilt = cst.ImportFrom(
            module=pending.module,
            names=[cst.ImportAlias(cst.Name(missing_name)) for missing_name in missing],
        )
        # SimpleStatementLine needs a subscriptable sequence of statements.
        additions.append(cst.SimpleStatementLine([rebuilt]))

    for var_name, annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(var_name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(annotation.annotation),
            None,
        )
        additions.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(body=[*head, *additions, *tail])
def visit_Module(self, node: cst.Module) -> None:
    """Report the module if its header does not start with the required lines.

    The expected prefix is ``self.header_matcher`` followed by anything. The
    suggested fix prepends ``self.header_replacement`` to the existing header.
    Nothing happens when ``self.rule_disabled`` is set.
    """
    if self.rule_disabled:
        return
    required_header = m.Module(header=[*self.header_matcher, m.ZeroOrMore()])
    if m.matches(node, required_header):
        return
    fixed_module = node.with_changes(header=[*self.header_replacement, *node.header])
    self.report(node, replacement=fixed_module)
def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
    """Insert ``name: annotation`` declarations for every top-level annotation.

    Each entry of ``self.toplevel_annotations`` becomes a bare annotated
    assignment. The declarations are placed as one contiguous run at the index
    returned by ``self._get_toplevel_index(body)``, in the iteration order of
    ``self.toplevel_annotations``.
    """
    body = list(updated_node.body)
    index = self._get_toplevel_index(body)
    declarations = [
        cst.SimpleStatementLine(
            [
                cst.AnnAssign(
                    cst.Name(name),
                    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                    cst.Annotation(annotation.annotation),
                    None,
                )
            ]
        )
        for name, annotation in self.toplevel_annotations.items()
    ]
    # Slice assignment inserts the whole run at `index` in order. Repeated
    # body.insert(index, ...) at a fixed index would reverse it.
    body[index:index] = declarations
    return updated_node.with_changes(body=tuple(body))
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Sort each sortable block of statements in the module body.

    For every block returned by ``sortable_blocks``:

    * the leading lines of the block's first statement are split with
      ``partition_leading_lines``, and only the second part stays on the node;
    * the statements are sorted and passed through ``fixup_whitespace`` together
      with the first part;
    * the result replaces ``body[start_idx:end_idx]``.
    """
    sortable = sortable_blocks(updated_node.body, config=self.config)
    new_body: List[cst.CSTNode] = list(updated_node.body)
    for block in sortable:
        head_stmt = block.stmts[0]
        leading_blank, leading_comment = partition_leading_lines(
            head_stmt.node.leading_lines
        )
        # Detach the first part so it stays at the top of the block after sorting.
        head_stmt.node = head_stmt.node.with_changes(leading_lines=leading_comment)
        reordered = fixup_whitespace(leading_blank, sorted(block.stmts))
        new_body[block.start_idx:block.end_idx] = [entry.node for entry in reordered]
    return updated_node.with_changes(body=new_body)
def insert_header_comments(node: libcst.Module, comments: List[str]) -> libcst.Module:
    """Insert comments right after the last comment line of the module header."""
    header = node.header
    # Index just past the last header line carrying a comment (0 when none do).
    # Everything before it is the comment section; what follows is trailing
    # whitespace that stays after the inserted comments.
    split_at = 0
    for position, line in enumerate(header):
        if line.comment is not None:
            split_at = position + 1
    new_lines = [
        libcst.EmptyLine(comment=libcst.Comment(value=text)) for text in comments
    ]
    return node.with_changes(
        header=(*header[:split_at], *new_lines, *header[split_at:])
    )
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Insert pending ``from ... import`` statements and top-level annotations after the imports.

    Returns ``updated_node`` unchanged when there are neither pending imports nor
    top-level annotations.
    """
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    head, tail = self._split_module(original_node, updated_node)
    # Keep at least one empty line before the first non-import statement.
    tail = self._insert_empty_line(tail)

    additions = []
    for pending in self.imports.values():
        rebuilt = cst.ImportFrom(
            module=pending.module,
            # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
            # for 2nd param but got `List[ImportFrom]`.
            names=pending.names,
        )
        # SimpleStatementLine needs a subscriptable sequence of statements.
        additions.append(cst.SimpleStatementLine([rebuilt]))

    additions.extend(
        cst.SimpleStatementLine(
            [
                cst.AnnAssign(
                    cst.Name(var_name),
                    # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                    cst.Annotation(annotation.annotation),
                    None,
                )
            ]
        )
        for var_name, annotation in self.toplevel_annotations.items()
    )

    return updated_node.with_changes(body=[*head, *additions, *tail])
def transform_module(self, tree: Module) -> Module:
    """
    Entry point that runs the codemod with metadata resolution and, when
    enabled, repeated passes.

    Call this to invoke a codemod directly; it is also what
    :func:`~libcst.codemod.transform_module` calls. Without multiple passes,
    ``transform_module_impl`` runs once. Otherwise it is re-run on its own
    output until a pass produces a tree that ``deep_equals`` the previous one,
    and that last tree is returned.
    """
    if not self.should_allow_multiple_passes():
        with self._handle_metadata_reference(tree) as prepared:
            return self.transform_module_impl(prepared)

    current = tree
    while True:
        with self._handle_metadata_reference(current) as prepared:
            candidate = self.transform_module_impl(prepared)
        if candidate.deep_equals(current):
            # Fixed point reached: another pass would change nothing.
            return candidate
        current = candidate
def with_added_imports(
        module_node: cst.Module,
        import_nodes: Sequence[Union[cst.Import, cst.ImportFrom]]) -> cst.Module:
    """
    Return ``module_node`` with every node of ``import_nodes`` inserted right
    after the first top-level statement for which ``_is_import_line`` is true.

    Each import node gets its own ``SimpleStatementLine``, in the given order.
    Raises ``RuntimeError`` when the module has no such import line.
    """
    new_body: List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]] = []
    inserted = False
    for statement in module_node.body:
        new_body.append(statement)
        if inserted or not _is_import_line(statement):
            continue
        new_body.extend(
            cst.SimpleStatementLine(body=(import_node,)) for import_node in import_nodes
        )
        inserted = True
    if not inserted:
        raise RuntimeError("Failed to add imports")
    return module_node.with_changes(body=tuple(new_body))
def insert_header_comments(node: libcst.Module, comments: List[str]) -> libcst.Module:
    """
    Insert comments after the last comment-bearing line of the module header.

    Use this to insert one or more comments after any copyright preamble in a
    :class:`~libcst.Module`. Each comment in ``comments`` must start with a
    ``#`` and is placed on its own line. Header lines after the last comment
    (such as trailing blank lines) remain after the inserted comments.
    """
    header = tuple(node.header)
    # Search from the end for the last line with a comment; 0 means "insert first".
    cut = next(
        (
            idx + 1
            for idx in reversed(range(len(header)))
            if header[idx].comment is not None
        ),
        0,
    )
    added = tuple(
        libcst.EmptyLine(comment=libcst.Comment(value=text)) for text in comments
    )
    return node.with_changes(header=header[:cut] + added + header[cut:])
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Append an ``__all__`` list built from ``self.names`` to the end of the module.

    Nothing is added when ``self.names`` is empty or ``self.already_exists`` is
    truthy. Names are sorted, written with ``repr``, one per line, using the
    module's default indent and newline. A blank line precedes the new
    statement.
    """
    if not self.names or self.already_exists:
        # Return the transformed node, not ``original_node``, so changes made
        # while visiting child nodes are kept.
        return updated_node
    config = original_node.config_for_parsing
    newline = config.default_newline
    indent = config.default_indent
    list_of_names = f",{newline}{indent}".join(
        repr(name) for name in sorted(self.names)
    )
    # Use the module's newline everywhere (the names were already joined
    # with it) instead of mixing it with literal "\n".
    all_names = cst.parse_statement(
        f"{newline}__all__ = [{newline}{indent}{list_of_names}{newline}]{newline}",
        config=config,
    )
    # Build on updated_node.body so transformed children are not discarded.
    return updated_node.with_changes(body=[*updated_node.body, all_names])
def visit_Module(self, node: cst.Module) -> Optional[bool]:
    """Walk the whole module with a ``QualifiedNameVisitor`` bound to ``self``."""
    node.visit(QualifiedNameVisitor(self))
    return None
def visit_Module(self, node: cst.Module) -> Optional[bool]:
    """Walk the module with a ``ScopeVisitor`` bound to ``self``, then infer accesses."""
    scope_walker = ScopeVisitor(self)
    node.visit(scope_walker)
    # Accesses are inferred only after the entire module has been walked.
    scope_walker.infer_accesses()
    return None
def visit_Module(self, node: cst.Module) -> Optional[bool]:
    """Walk the module with an ``ExpressionContextVisitor`` starting in ``LOAD`` context."""
    context_walker = ExpressionContextVisitor(self, ExpressionContext.LOAD)
    node.visit(context_walker)
    return None
def test_code_for_node(self, module: cst.Module, node: cst.CSTNode, expected: str) -> None: self.assertEqual(module.code_for_node(node), expected)
def visit_Module(self, node: cst.Module) -> Optional[bool]:
    """Walk the whole module with a ``ParentNodeVisitor`` bound to ``self``."""
    parent_walker = ParentNodeVisitor(self)
    node.visit(parent_walker)
    return None
def leave_Module(self, original_node: libcst.Module,
                 updated_node: libcst.Module) -> libcst.Module:
    """Insert every still-pending import into the module.

    The resulting body is laid out as:

    1. statements before the import block (per ``self._split_module``);
    2. ``from __future__ import ...`` statements;
    3. the statements ``_split_module`` places before the added imports;
    4. ``import x`` (sorted) and ``import x as y`` statements;
    5. ``from typing import TYPE_CHECKING`` plus an ``if TYPE_CHECKING:`` block
       holding ``from m import ...`` for each module that is neither
       ``__future__`` nor in the import-cycle-safe list;
    6. ``from m import ...`` for the import-cycle-safe modules;
    7. the remaining statements, with at least one empty line before them.

    Returns ``updated_node`` unchanged when nothing is pending.
    """
    # Don't try to modify if we have nothing to do
    if not (self.module_imports or self.module_mapping or self.module_aliases
            or self.alias_mapping):
        return updated_node

    config = updated_node.config_for_parsing

    # First, find the insertion point for imports
    (
        statements_before_imports,
        statements_until_add_imports,
        statements_after_imports,
    ) = self._split_module(original_node, updated_node)

    # Make sure there's at least one empty line before the first non-import
    statements_after_imports = self._insert_empty_line(statements_after_imports)

    # module -> [(object, alias or None), ...] for every object to import from it
    module_and_alias_mapping = defaultdict(list)
    for module, aliases in self.alias_mapping.items():
        module_and_alias_mapping[module].extend(aliases)
    for module, imports in self.module_mapping.items():
        module_and_alias_mapping[module].extend((obj, None) for obj in imports)
    # Sort by (object, alias) with None treated as "", so the same object
    # requested both with and without an alias no longer raises TypeError
    # (None vs str comparison). Other orderings are unchanged.
    module_and_alias_mapping = {
        module: sorted(aliases, key=lambda pair: (pair[0], pair[1] or ""))
        for module, aliases in module_and_alias_mapping.items()
    }

    def from_import(module, aliases):
        # Render `from <module> import a, b as c, ...` as a statement node.
        names = ", ".join(
            obj if alias is None else f"{obj} as {alias}"
            for (obj, alias) in aliases
        )
        return parse_statement(f"from {module} import " + names, config=config)

    # Modules that are imported normally; every other module (except
    # __future__) is imported only under `if TYPE_CHECKING:`.
    import_cycle_safe_module_names = [
        'mypy_extensions',
        'typing',
        'typing_extensions',
    ]

    guarded_imports = [
        from_import(module, aliases)
        for module, aliases in module_and_alias_mapping.items()
        if module != "__future__" and module not in import_cycle_safe_module_names
    ]
    if guarded_imports:
        type_checking_cond_import = parse_statement(
            "from typing import TYPE_CHECKING",
            config=config,
        )
        type_checking_cond_statement = libcst.If(
            test=libcst.Name("TYPE_CHECKING"),
            body=libcst.IndentedBlock(body=guarded_imports),
        )
    else:
        # NOTE(review): EmptyLine placeholders are placed in Module.body,
        # which otherwise holds statements; confirm the generated output is
        # as intended.
        type_checking_cond_import = libcst.EmptyLine()
        type_checking_cond_statement = libcst.EmptyLine()

    # TODO: forward references could also be handled with
    # `from __future__ import annotations`, either added here or by another tool.
    return updated_node.with_changes(body=(
        *statements_before_imports,
        *[
            from_import(module, aliases)
            for module, aliases in module_and_alias_mapping.items()
            if module == "__future__"
        ],
        *statements_until_add_imports,
        *[
            parse_statement(f"import {module}", config=config)
            for module in sorted(self.module_imports)
        ],
        *[
            parse_statement(f"import {module} as {asname}", config=config)
            for (module, asname) in self.module_aliases.items()
        ],
        type_checking_cond_import,
        type_checking_cond_statement,
        *[
            from_import(module, aliases)
            for module, aliases in module_and_alias_mapping.items()
            if module in import_cycle_safe_module_names
        ],
        *statements_after_imports,
    ))
def transform_module_impl(self, tree: libcst.Module) -> libcst.Module:
    """Run this transformer over ``tree`` and return the resulting module."""
    transformed = tree.visit(self)
    return transformed