def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Add missing top-level annotations and unseen class definitions.

    Class definitions from ``self.annotations`` whose names are not in
    ``self.visited_classes`` are treated as new. When there are no top-level
    annotations and no new classes, ``updated_node`` is returned unchanged.
    Otherwise one ``name: annotation`` statement per entry of
    ``self.toplevel_annotations`` is emitted, followed by the new classes. All
    of them are spliced in at the import insertion point found by
    ``self._split_module``.
    """
    new_classes = [
        class_def
        for class_name, class_def in self.annotations.class_definitions.items()
        if class_name not in self.visited_classes
    ]
    if not (self.toplevel_annotations or new_classes):
        return updated_node

    # Split the body at the insertion point for imports.
    head, tail = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    tail = self._insert_empty_line(tail)

    inserted = [
        cst.SimpleStatementLine(
            [cst.AnnAssign(cst.Name(var_name), var_annotation, None)]
        )
        for var_name, var_annotation in self.toplevel_annotations.items()
    ]
    inserted += new_classes
    return updated_node.with_changes(body=[*head, *inserted, *tail])
def leave_Module(
    self,
    original_node: cst.Module,
    updated_node: cst.Module,
) -> cst.Module:
    """Add missing globals, TypeVars and class definitions to the module.

    The block spliced in at the import insertion point (``self._split_module``)
    contains, in order:

    * one annotated declaration per entry of ``self.toplevel_annotations``,
      built by ``self._apply_annotation_to_attribute_or_global``;
    * every TypeVar from ``self.annotations.typevars`` whose name is not
      already in ``self.typevars``, each preceded by a ``cst.Newline`` and the
      group followed by one more ``cst.Newline``;
    * class definitions whose names are not in ``self.visited_classes``.

    ``self.annotation_counts`` is updated with the number of TypeVars and
    classes added.
    """
    new_classes = [
        class_def
        for class_name, class_def in self.annotations.class_definitions.items()
        if class_name not in self.visited_classes
    ]
    # NOTE: The entire change will also be abandoned if
    # self.annotation_counts is all 0s, so if adding any new category make
    # sure to record it there.
    if not (self.toplevel_annotations or new_classes or self.annotations.typevars):
        return updated_node

    # Split the body at the insertion point for imports.
    head, tail = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    tail = self._insert_empty_line(tail)

    inserted = [
        cst.SimpleStatementLine([
            self._apply_annotation_to_attribute_or_global(
                name=var_name,
                annotation=var_annotation,
                value=None,
            )
        ])
        for var_name, var_annotation in self.toplevel_annotations.items()
    ]

    # TypeVar definitions could be scattered through the file, so do not
    # attempt to put new ones with existing ones, just add them at the top.
    new_typevars = {
        typevar_name: typevar_stmt
        for typevar_name, typevar_stmt in self.annotations.typevars.items()
        if typevar_name not in self.typevars
    }
    if new_typevars:
        for typevar_stmt in new_typevars.values():
            inserted.extend((cst.Newline(), typevar_stmt))
        self.annotation_counts.typevars_and_generics_added += len(new_typevars)
        inserted.append(cst.Newline())

    self.annotation_counts.classes_added = len(new_classes)
    inserted.extend(new_classes)
    return updated_node.with_changes(body=[*head, *inserted, *tail])
def leave_Module(
    self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
    """Add missing ``from`` imports and top-level annotated declarations.

    When ``self.is_generated`` is set, the module is returned exactly as
    parsed (``original_node``). Otherwise, when there is anything to add, the
    new statements are spliced in at the import insertion point found by
    ``self._split_module``, in this order:

    * one ``from module import a, b`` per entry of ``self.imports``, limited
      to names not already bound by ``self.import_statements`` (sorted, and
      skipped entirely when nothing is left);
    * one ``name: annotation`` per entry of ``self.toplevel_annotations``.
    """
    if self.is_generated:
        return original_node
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    # Split the body at the insertion point for imports.
    head, tail = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    tail = self._insert_empty_line(tail)

    # Names already bound by existing non-star imports: the ``as`` target when
    # one is present, otherwise the imported name itself.
    bound_names = set()
    for existing in self.import_statements:
        existing_names = existing.names
        if isinstance(existing_names, cst.ImportStar):
            continue
        for alias in existing_names:
            target = alias.asname if alias.asname else alias
            if target:
                bound_names.add(_get_name_as_string(target.name))

    additions = []
    for pending in self.imports.values():
        # Filter out anything that has already been imported.
        missing = sorted(pending.names.difference(bound_names))
        if not missing:
            continue
        import_from = cst.ImportFrom(
            module=pending.module,
            names=[cst.ImportAlias(cst.Name(missing_name)) for missing_name in missing],
        )
        # SimpleStatementLine needs a subscriptable sequence of statements.
        additions.append(cst.SimpleStatementLine([import_from]))

    for var_name, var_annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(var_name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(var_annotation.annotation),
            None,
        )
        additions.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(body=[*head, *additions, *tail])
def visit_Module(self, node: cst.Module) -> None:
    """Report modules whose header does not begin with the expected lines.

    Does nothing when ``self.rule_disabled`` is set. A header passes when it
    starts with ``self.header_matcher`` (anything may follow). A failing
    module is reported with an autofix that prepends
    ``self.header_replacement`` to the existing header.
    """
    if self.rule_disabled:
        return
    expected_prefix = m.Module(header=[*self.header_matcher, m.ZeroOrMore()])
    if m.matches(node, expected_prefix):
        return
    fixed_module = node.with_changes(
        header=[*self.header_replacement, *node.header])
    self.report(node, replacement=fixed_module)
def leave_Module(self, node: cst.Module, updated_node: cst.Module) -> cst.CSTNode:
    """Insert bare annotated declarations for top-level names.

    Each entry of ``self.toplevel_annotations`` becomes a
    ``name: annotation`` statement (no value). All of them are spliced into
    the module body at the position returned by ``self._get_toplevel_index``,
    keeping the mapping's iteration order.

    Args:
        node: the original module (unused here).
        updated_node: the module produced by visiting the children.

    Returns:
        ``updated_node`` with the declarations inserted.
    """
    body = list(updated_node.body)
    index = self._get_toplevel_index(body)
    new_statements = [
        cst.SimpleStatementLine([
            cst.AnnAssign(
                cst.Name(name),
                # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
                cst.Annotation(annotation.annotation),
                None,
            )
        ])
        for name, annotation in self.toplevel_annotations.items()
    ]
    # Splice as one slice so the declarations keep their order; repeatedly
    # calling ``body.insert(index, ...)`` at a fixed index would reverse them.
    body[index:index] = new_statements
    return updated_node.with_changes(body=tuple(body))
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Sort each run of sortable statements found by ``sortable_blocks``.

    For every block, the first statement's leading lines are split with
    ``partition_leading_lines``. The statement keeps only the comment part;
    the blank part is handed to ``fixup_whitespace`` together with the sorted
    statements. The block's slice ``[start_idx:end_idx]`` of the body is then
    replaced with the reordered nodes.
    """
    new_body: List[cst.CSTNode] = list(updated_node.body)
    for block in sortable_blocks(updated_node.body, config=self.config):
        head = block.stmts[0]
        blank_lines, comment_lines = partition_leading_lines(head.node.leading_lines)
        head.node = head.node.with_changes(leading_lines=comment_lines)
        reordered = fixup_whitespace(blank_lines, sorted(block.stmts))
        new_body[block.start_idx:block.end_idx] = [stmt.node for stmt in reordered]
    return updated_node.with_changes(body=new_body)
def update_imports(
    self,
    original_module: Module,
    updated_module: Module,
    import_name: str,
    updated_import_node: SimpleStatementLine,
    current_imports: Dict[str, str],
    new_imports: Set[str],
    noqa: bool,
) -> Module:
    """Make ``updated_module`` import ``new_imports`` from ``import_name``.

    If ``updated_import_node`` is falsy, a new
    ``from {import_name} import ...`` line is parsed and spliced into the
    body. It goes right after ``self.last_import_node_stmt`` when that anchor
    is set, with no extra blank lines. Otherwise it goes at the very top of
    the module, followed by two blank lines. The parsed module is then stored
    in ``self.last_import_node_stmt``.

    If ``updated_import_node`` is given, it is replaced via ``deep_replace``
    with one statement importing the sorted union of ``new_imports`` and the
    names in ``current_imports``. When ``current_imports`` contains ``"*"``,
    nothing is replaced.

    Args:
        original_module: module before transformation; used only to locate
            the anchor statement by identity.
        updated_module: module to modify.
        import_name: module path placed after ``from``.
        updated_import_node: existing import statement to extend, or a falsy
            value when there is none.
        current_imports: names already imported by ``updated_import_node``,
            mapped to their local name (equal to the key when not aliased).
        new_imports: names to add; the module is returned as-is when empty.
        noqa: append `` # noqa`` to the generated import line.

    Returns:
        The modified module, or ``updated_module`` itself when nothing changed.
    """
    if not new_imports:
        return updated_module
    noqa_comment = " # noqa" if noqa else ""
    if not updated_import_node:
        # ``i = -1`` makes the splice below insert at the very top of the body.
        i = -1
        blank_lines = "\n\n"
        if self.last_import_node_stmt:
            # Insert right after the anchor, without extra blank lines.
            blank_lines = ""
            # The anchor is found by identity in the ORIGINAL body, but ``i``
            # indexes the UPDATED body below. That assumes both bodies still
            # line up index-for-index.
            # NOTE(review): if the anchor is never matched, the loop runs to the
            # end and ``i`` is the last index the two bodies share, so the
            # import lands near the end of the module rather than after the
            # imports.
            for i, (original, updated) in enumerate(
                zip(original_module.body, updated_module.body)
            ):
                if original is self.last_import_node_stmt:
                    break
        stmt = parse_module(
            f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
            config=updated_module.config_for_parsing,
        )
        body = list(updated_module.body)
        # NOTE(review): this stores the parsed *Module*, not a statement. A
        # freshly parsed Module is never an element of ``original_module.body``,
        # so a later call on this instance cannot match it as an anchor (see
        # the NOTE above). Confirm the intended anchor, and whether it should
        # be looked up in ``updated_module.body``.
        self.last_import_node_stmt = stmt
        # Splice the parsed module's children in after index ``i``. Presumably
        # these are the import statement plus any blank lines parsed from
        # ``blank_lines``.
        return updated_module.with_changes(
            body=body[: i + 1] + stmt.children + body[i + 1 :]
        )
    else:
        # A star import already covers everything, so leave the module alone.
        if "*" not in current_imports:
            # Re-render existing names, keeping ``name as alias`` where aliased.
            current_imports_set = {
                f"{k}" if k == v else f"{k} as {v}"
                for k, v in current_imports.items()
            }
            stmt = parse_statement(
                f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}"
            )
            return updated_module.deep_replace(updated_import_node, stmt)
    return updated_module
def insert_header_comments(node: libcst.Module, comments: List[str]) -> libcst.Module:
    """Return ``node`` with ``comments`` inserted into its header.

    The new comment lines go immediately after the last header line that
    carries a comment, ahead of any blank header lines that follow it. If no
    header line has a comment, they are placed first.
    """
    header = node.header
    # One past the final comment-bearing header line (0 when there is none).
    split_at = 0
    for position, header_line in enumerate(header):
        if header_line.comment is not None:
            split_at = position + 1
    new_lines = [
        libcst.EmptyLine(comment=libcst.Comment(value=text)) for text in comments
    ]
    return node.with_changes(
        header=(*header[:split_at], *new_lines, *header[split_at:]))
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Insert recorded ``from`` imports and top-level annotated declarations.

    When ``self.imports`` and ``self.toplevel_annotations`` are both empty,
    ``updated_node`` is returned unchanged. Otherwise the new statements are
    spliced in at the import insertion point found by ``self._split_module``:
    first one rebuilt ``ImportFrom`` per entry of ``self.imports``, then one
    ``name: annotation`` per entry of ``self.toplevel_annotations``.
    """
    if not (self.toplevel_annotations or self.imports):
        return updated_node

    # Split the body at the insertion point for imports.
    head, tail = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import.
    tail = self._insert_empty_line(tail)

    additions = []
    for pending in self.imports.values():
        rebuilt = cst.ImportFrom(
            module=pending.module,
            # pyre-fixme[6]: Expected `Union[Sequence[ImportAlias], ImportStar]`
            # for 2nd param but got `List[ImportFrom]`.
            names=pending.names,
        )
        # SimpleStatementLine needs a subscriptable sequence of statements.
        additions.append(cst.SimpleStatementLine([rebuilt]))

    for var_name, var_annotation in self.toplevel_annotations.items():
        declaration = cst.AnnAssign(
            cst.Name(var_name),
            # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
            cst.Annotation(var_annotation.annotation),
            None,
        )
        additions.append(cst.SimpleStatementLine([declaration]))

    return updated_node.with_changes(body=[*head, *additions, *tail])
def with_added_imports(
        module_node: cst.Module,
        import_nodes: Sequence[Union[cst.Import, cst.ImportFrom]]) -> cst.Module:
    """Return ``module_node`` with ``import_nodes`` added after its first import.

    Each node in ``import_nodes`` is wrapped in its own ``SimpleStatementLine``.
    The wrapped nodes are placed, in order, immediately after the first
    top-level statement for which ``_is_import_line`` is true.

    Raises:
        RuntimeError: if no top-level statement qualifies as an import line.
    """
    statements: List[Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]] = list(
        module_node.body)
    anchor = next(
        (idx for idx, stmt in enumerate(statements) if _is_import_line(stmt)),
        None,
    )
    if anchor is None:
        raise RuntimeError("Failed to add imports")
    statements[anchor + 1:anchor + 1] = [
        cst.SimpleStatementLine(body=(imp,)) for imp in import_nodes
    ]
    return module_node.with_changes(body=tuple(statements))
def insert_header_comments(node: libcst.Module, comments: List[str]) -> libcst.Module:
    """
    Return ``node`` with ``comments`` added after its header's last comment.

    Use this to insert one or more comments after any copyright preamble in a
    :class:`~libcst.Module`. Each string in ``comments`` must start with a
    ``#`` and is placed on its own line. The new lines go directly after the
    last comment-bearing header line, so blank lines trailing the preamble
    stay below them. If the header has no comments, the new lines come first.
    """
    header = node.header
    # Scan backwards for the final comment-bearing line. The cut point sits
    # just past it, or at 0 when the header holds no comments.
    cut = next(
        (pos + 1 for pos in range(len(header) - 1, -1, -1)
         if header[pos].comment is not None),
        0,
    )
    new_lines = tuple(
        libcst.EmptyLine(comment=libcst.Comment(value=text)) for text in comments
    )
    return node.with_changes(
        header=tuple(header[:cut]) + new_lines + tuple(header[cut:]))
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Append an ``__all__`` list built from ``self.names`` to the module.

    Nothing is added when ``self.names`` is empty or ``self.already_exists``
    is set. The names are sorted and written one ``repr`` per line, using the
    module's own indentation and newline sequence.

    The result is built from ``updated_node`` rather than ``original_node``,
    so changes made to the tree earlier in this traversal are preserved.

    Args:
        original_node: the module as originally parsed (unused).
        updated_node: the module after its children were transformed.

    Returns:
        ``updated_node``, with the ``__all__`` assignment appended when needed.
    """
    if not self.names or self.already_exists:
        return updated_node
    config = updated_node.config_for_parsing
    newline = config.default_newline
    indent = config.default_indent
    list_of_names = f",{newline}{indent}".join(
        repr(name) for name in sorted(self.names))
    # Use the module's newline everywhere. A triple-quoted template would
    # always contribute "\n" and mix line endings in CRLF files. The leading
    # newline keeps the blank line before ``__all__``.
    all_names = cst.parse_statement(
        f"{newline}__all__ = [{newline}{indent}{list_of_names}{newline}]{newline}",
        config=config,
    )
    return updated_node.with_changes(body=[*updated_node.body, all_names])
def leave_Module(self, original_node: libcst.Module, updated_node: libcst.Module) -> libcst.Module:
    """Insert every requested import into the module and return the result.

    The new body is laid out as:

    1. statements before the first insertion point (``self._split_module``),
    2. requested ``from __future__ import ...`` statements,
    3. statements up to the point where other imports are added,
    4. ``import module`` statements, sorted by module name,
    5. ``import module as alias`` statements,
    6. ``from typing import TYPE_CHECKING`` and an ``if TYPE_CHECKING:``
       block holding ``from`` imports of every module not listed as
       import-cycle safe. Both are emitted only when there is such an import.
    7. ``from`` imports of the import-cycle-safe modules,
    8. the remaining statements, with at least one empty line before them.

    Returns ``updated_node`` untouched when no imports were requested.
    """
    # Don't try to modify if we have nothing to do
    if not (self.module_imports or self.module_mapping
            or self.module_aliases or self.alias_mapping):
        return updated_node

    # First, find the insertion point for imports
    (
        statements_before_imports,
        statements_until_add_imports,
        statements_after_imports,
    ) = self._split_module(original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import
    statements_after_imports = self._insert_empty_line(statements_after_imports)

    parser_config = updated_node.config_for_parsing

    def from_import(module, aliases):
        # Render ``from module import a, b as c`` from (object, alias-or-None) pairs.
        names = ", ".join(
            obj if alias is None else f"{obj} as {alias}" for obj, alias in aliases)
        return parse_statement(f"from {module} import {names}", config=parser_config)

    # Module -> (object, alias-or-None) pairs that should be imported from it.
    grouped = defaultdict(list)
    for module, aliases in self.alias_mapping.items():
        grouped[module].extend(aliases)
    for module, imports in self.module_mapping.items():
        grouped[module].extend((obj, None) for obj in imports)
    # Sort on (object, alias or ""). Plain tuple comparison raises TypeError
    # when one object is requested both with and without an alias (None vs str).
    module_and_alias_mapping = {
        module: sorted(aliases, key=lambda pair: (pair[0], pair[1] or ""))
        for module, aliases in grouped.items()
    }

    # Only these modules are imported at runtime. Every other ``from`` import
    # goes under TYPE_CHECKING so it cannot introduce an import cycle.
    import_cycle_safe_module_names = ("mypy_extensions", "typing", "typing_extensions")

    type_checking_only = [
        from_import(module, aliases)
        for module, aliases in module_and_alias_mapping.items()
        if module != "__future__" and module not in import_cycle_safe_module_names
    ]
    # Emit the TYPE_CHECKING import and block only when they would be non-empty,
    # instead of putting non-statement placeholder nodes into Module.body.
    type_checking_block = []
    if type_checking_only:
        type_checking_block = [
            parse_statement("from typing import TYPE_CHECKING", config=parser_config),
            libcst.If(
                test=libcst.Name("TYPE_CHECKING"),
                body=libcst.IndentedBlock(body=type_checking_only),
            ),
        ]

    # TODO: forward references could additionally be handled with
    # ``from __future__ import annotations``, either here or by another tool.

    # Now, add all of the imports we need!
    return updated_node.with_changes(body=(
        *statements_before_imports,
        *(
            from_import(module, aliases)
            for module, aliases in module_and_alias_mapping.items()
            if module == "__future__"
        ),
        *statements_until_add_imports,
        *(
            parse_statement(f"import {module}", config=parser_config)
            for module in sorted(self.module_imports)
        ),
        *(
            parse_statement(f"import {module} as {asname}", config=parser_config)
            for module, asname in self.module_aliases.items()
        ),
        *type_checking_block,
        *(
            from_import(module, aliases)
            for module, aliases in module_and_alias_mapping.items()
            if module in import_cycle_safe_module_names
        ),
        *statements_after_imports,
    ))
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module):
    """Prepend the statements from ``self.__get_required_imports()`` to the body."""
    required_imports = self.__get_required_imports()
    existing_body = list(updated_node.body)
    return updated_node.with_changes(body=required_imports + existing_body)
def leave_Module(self, original_node: libcst.Module, updated_node: libcst.Module) -> libcst.Module:
    """Insert every requested import into the module and return the result.

    The new body is laid out as:

    1. requested ``from __future__ import ...`` statements,
    2. statements before the insertion point found by ``self._split_module``,
    3. ``import module`` statements, sorted by module name,
    4. ``import module as alias`` statements,
    5. ``from module import ...`` statements for every other module,
    6. the remaining statements, with at least one empty line before them.

    Returns ``updated_node`` untouched when no imports were requested.
    """
    # Don't try to modify if we have nothing to do
    if not (self.module_imports or self.module_mapping
            or self.module_aliases or self.alias_mapping):
        return updated_node

    # First, find the insertion point for imports
    statements_before_imports, statements_after_imports = self._split_module(
        original_node, updated_node)
    # Make sure there's at least one empty line before the first non-import
    statements_after_imports = self._insert_empty_line(statements_after_imports)

    parser_config = updated_node.config_for_parsing

    def from_import(module, aliases):
        # Render ``from module import a, b as c`` from (object, alias-or-None) pairs.
        names = ", ".join(
            obj if alias is None else f"{obj} as {alias}" for obj, alias in aliases)
        return parse_statement(f"from {module} import {names}", config=parser_config)

    # Module -> (object, alias-or-None) pairs that should be imported from it.
    grouped = defaultdict(list)
    for module, aliases in self.alias_mapping.items():
        grouped[module].extend(aliases)
    for module, imports in self.module_mapping.items():
        grouped[module].extend((obj, None) for obj in imports)
    # Sort on (object, alias or ""). Plain tuple comparison raises TypeError
    # when one object is requested both with and without an alias (None vs str).
    module_and_alias_mapping = {
        module: sorted(aliases, key=lambda pair: (pair[0], pair[1] or ""))
        for module, aliases in grouped.items()
    }

    # NOTE(review): ``__future__`` imports are emitted ahead of
    # ``statements_before_imports``, i.e. before any module docstring in that
    # slice. That keeps the future statement legal but turns such a docstring
    # into an ordinary expression statement. Confirm this is acceptable.
    future_imports = [
        from_import(module, aliases)
        for module, aliases in module_and_alias_mapping.items()
        if module == "__future__"
    ]

    # Now, add all of the imports we need!
    return updated_node.with_changes(body=(
        *future_imports,
        *statements_before_imports,
        *(
            parse_statement(f"import {module}", config=parser_config)
            for module in sorted(self.module_imports)
        ),
        *(
            parse_statement(f"import {module} as {asname}", config=parser_config)
            for module, asname in self.module_aliases.items()
        ),
        *(
            from_import(module, aliases)
            for module, aliases in module_and_alias_mapping.items()
            if module != "__future__"
        ),
        *statements_after_imports,
    ))
def transform_module_impl(self, tree: cst.Module) -> cst.Module:
    """Emit a "Testing" warning and replace the module header with one comment line."""
    self.warn("Testing")
    comment_line = cst.EmptyLine(comment=cst.Comment("# A comment"))
    return tree.with_changes(header=[comment_line])