def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
                              updated_node: cst.SimpleStatementLine):
    """Add an inferred type annotation to a single-name assignment line.

    Two shapes are handled, both matched against ``original_node``:

    * ``name = value`` -- the type comes from ``self.__get_var_type_assign_t``
      and the line is turned into ``name: T = value`` (the ``=`` keeps the
      original whitespace around it).
    * ``name: X [= value]`` -- the type comes from
      ``self.__get_var_type_an_assign`` and the annotation is replaced.

    In both cases the type is passed through ``self.resolve_type_alias`` and
    ``self.__name2annotation``; each applied ``(resolved, annotation)`` pair is
    recorded in ``self.all_applied_types``. The replacement statement is
    assembled from the parts of ``original_node``, so child changes made by
    this transformer are not carried into it.

    Returns ``original_node`` whenever no annotation can be produced.
    """
    plain_assign_line = match.SimpleStatementLine(body=[
        match.Assign(targets=[
            match.AssignTarget(target=match.Name(value=match.DoNotCare()))
        ])
    ])
    annotated_assign_line = match.SimpleStatementLine(body=[
        match.AnnAssign(target=match.Name(value=match.DoNotCare()))
    ])

    if match.matches(original_node, plain_assign_line):
        assign_stmt = original_node.body[0]
        only_target = assign_stmt.targets[0]
        inferred = self.__get_var_type_assign_t(only_target.target.value)
        if inferred is None:
            return original_node
        resolved = self.resolve_type_alias(inferred)
        annotation = self.__name2annotation(resolved)
        if annotation is None:
            return original_node
        self.all_applied_types.add((resolved, annotation))
        # Re-use the original whitespace around "=" so formatting is stable.
        equal_sign = cst.AssignEqual(
            whitespace_before=only_target.whitespace_before_equal,
            whitespace_after=only_target.whitespace_after_equal)
        annotated = cst.AnnAssign(target=only_target.target,
                                  value=assign_stmt.value,
                                  annotation=annotation,
                                  equal=equal_sign)
        return updated_node.with_changes(body=[annotated])

    if match.matches(original_node, annotated_assign_line):
        ann_stmt = original_node.body[0]
        inferred = self.__get_var_type_an_assign(ann_stmt.target.value)
        if inferred is None:
            return original_node
        resolved = self.resolve_type_alias(inferred)
        annotation = self.__name2annotation(resolved)
        if annotation is None:
            return original_node
        self.all_applied_types.add((resolved, annotation))
        annotated = cst.AnnAssign(target=ann_stmt.target,
                                  value=ann_stmt.value,
                                  annotation=annotation,
                                  equal=ann_stmt.equal)
        return updated_node.with_changes(body=[annotated])

    return original_node
def _split_module(
    self,
    orig_module: libcst.Module,
    updated_module: libcst.Module,
) -> Tuple[
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
]:
    """Split ``updated_module.body`` into three consecutive statement lists.

    Returns ``(prelude, import_section, rest)`` where:

    * ``prelude`` holds statements that must stay before any new import
      (a leading ``__strict__ = ...`` flag or a module docstring),
    * ``import_section`` runs up to and including the last statement that
      contains one of ``self.all_imports``,
    * ``rest`` is everything after that.

    The split points are computed on ``orig_module`` (matching by node
    identity) and applied to ``updated_module``; this relies on the visitor
    not having changed the number of top-level statements.
    """
    statement_before_import_location = 0
    import_add_location = 0

    # never insert an import before initial __strict__ flag
    if m.matches(
            orig_module,
            m.Module(body=[
                m.SimpleStatementLine(body=[
                    m.Assign(targets=[
                        m.AssignTarget(target=m.Name("__strict__"))
                    ])
                ]),
                m.ZeroOrMore(),
            ]),
    ):
        statement_before_import_location = import_add_location = 1

    docstring_line = m.SimpleStatementLine(
        body=[m.Expr(value=m.SimpleString())])

    # This works under the principle that while we might modify node contents,
    # we have yet to modify the number of statements. So we can match on the
    # original tree but break up the statements of the modified tree. If we
    # change this assumption in this visitor, we will have to change this code.
    for i, statement in enumerate(orig_module.body):
        # Only a string in the very first position is a module docstring. A
        # bare string expression further down must not reset the split points
        # (that used to push new imports above the existing ones).
        if i == 0 and m.matches(statement, docstring_line):
            statement_before_import_location = import_add_location = 1
        elif isinstance(statement, libcst.SimpleStatementLine):
            if any(possible_import is last_import
                   for possible_import in statement.body
                   for last_import in self.all_imports):
                import_add_location = i + 1

    return (
        list(updated_module.body[:statement_before_import_location]),
        list(updated_module.
             body[statement_before_import_location:import_add_location]),
        list(updated_module.body[import_add_location:]),
    )
class RemoveBarTransformer(VisitorBasedCodemodCommand):
    """Codemod that deletes lines consisting of a bare call to imported ``foo.bar``.

    A line is removed when its only statement is an expression whose call
    target resolves (via ``QualifiedNameProvider``) to exactly the imported
    qualified name ``foo.bar``. The import that made the name available is
    handed to ``RemoveImportsVisitor`` so it can be dropped if now unused.
    """

    METADATA_DEPENDENCIES = (QualifiedNameProvider, ScopeProvider)

    # A simple statement line holding only `foo.bar(...)` where the callee's
    # qualified-name metadata is exactly the single imported name "foo.bar".
    _FOO_BAR_CALL_LINE = m.SimpleStatementLine(body=[
        m.Expr(
            m.Call(metadata=m.MatchMetadata(
                QualifiedNameProvider,
                {
                    QualifiedName(
                        source=QualifiedNameSource.IMPORT,
                        name="foo.bar",
                    )
                },
            )))
    ])

    @m.leave(_FOO_BAR_CALL_LINE)
    def _leave_foo_bar(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> cst.RemovalSentinel:
        # Schedule removal of the import backing this call before dropping it.
        RemoveImportsVisitor.remove_unused_import_by_node(
            self.context, original_node)
        return cst.RemoveFromParent()
def on_leave(self, original_node, updated_node):
    """Drop statements whose recorded execution count is zero.

    After the parent transformer has produced its result, any statement
    (``cst.BaseStatement``) is removed when ``self.exec_counts[original_node]``
    equals 0. Lines that consist only of a string literal expression (e.g.
    docstrings) are always kept. Non-statement nodes pass through untouched.

    NOTE(review): ``exec_counts`` presumably maps original nodes to execution
    counts gathered from a trace -- confirm; a missing key raises here.
    """
    result = super().on_leave(original_node, updated_node)
    if not isinstance(result, cst.BaseStatement):
        return result
    string_only_line = m.SimpleStatementLine(body=[m.Expr(m.SimpleString())])
    if m.matches(result, string_only_line):
        return result
    if self.exec_counts[original_node] != 0:
        return result
    return cst.RemoveFromParent()
def _is_awaitable_callable(annotation: str) -> bool: if not (annotation.startswith("typing.Callable") or annotation.startswith("typing.ClassMethod") or annotation.startswith("StaticMethod")): # Exit early if this is not even a `typing.Callable` annotation. return False try: # Wrap this in a try-except since the type annotation may not be parse-able as a module. # If it is not parse-able, we know it's not what we are looking for anyway, so return `False`. parsed_ann = cst.parse_module(annotation) except Exception: return False # If passed annotation does not match the expected annotation structure for a `typing.Callable` with # typing.Coroutine as the return type, matched_callable_ann will simply be `None`. # The expected structure of an awaitable callable annotation from Pyre is: typing.Callable()[[...], typing.Coroutine[...]] matched_callable_ann: Optional[Dict[str, Union[ Sequence[cst.CSTNode], cst.CSTNode]]] = m.extract( parsed_ann, m.Module(body=[ m.SimpleStatementLine(body=[ m.Expr(value=m.Subscript(slice=[ m.SubscriptElement(), m.SubscriptElement(slice=m.Index(value=m.Subscript( value=m.SaveMatchedNode( m.Attribute(), "base_return_type", )))), ], )) ]), ]), ) if (matched_callable_ann is not None and "base_return_type" in matched_callable_ann): base_return_type = get_full_name_for_node( cst.ensure_type(matched_callable_ann["base_return_type"], cst.CSTNode)) return (base_return_type is not None and base_return_type == "typing.Coroutine") return False
def on_leave(self, old_node, new_node):
    """Restructure statement blocks that contain returns.

    After the parent transformer runs, nodes that are instances of
    ``self.block_types`` have their statements scanned from the end. When a
    statement is either a plain ``return`` line or a compound statement whose
    body is recorded in ``self.return_blocks``, the statements after it are
    handed to ``self._build_if`` and appended as a single statement. The scan
    then restarts on the shortened list until no more such statements are found.
    Blocks that were changed are added to ``self.return_blocks`` so enclosing
    blocks get the same treatment.

    NOTE(review): presumably ``_build_if`` guards the trailing statements so
    they only run when no return happened -- confirm against its definition.
    """
    new_node = super().on_leave(old_node, new_node)

    if isinstance(new_node, self.block_types):
        # NOTE(review): if ``new_node.body`` is a tuple (libcst typically stores
        # visited sequences as tuples), ``cur_stmts[:i]`` is a tuple too and the
        # ``.append`` calls below would raise AttributeError -- confirm the body
        # type reaching here, or wrap the slice in ``list(...)``.
        cur_stmts = new_node.body
        any_change = False
        while True:
            change = False
            N = len(cur_stmts)
            # Walk backwards so the first hit is the last return-like statement.
            for i in reversed(range(N)):
                stmt = cur_stmts[i]
                is_return = m.matches(stmt, m.SimpleStatementLine([m.Return()]))
                is_return_block = isinstance(stmt, cst.BaseCompoundStatement) and \
                    stmt.body in self.return_blocks
                if is_return or is_return_block:
                    change = True
                    any_change = True
                    # Split into statements before and after the hit.
                    [cur_stmts, block] = [cur_stmts[:i], cur_stmts[i + 1:]]

                    if is_return_block:
                        # Consume the marker so the same block is not handled
                        # again, and keep the compound statement itself. A plain
                        # `return` line is not re-appended, i.e. it is dropped.
                        self.return_blocks.remove(stmt.body)
                        cur_stmts.append(stmt)

                    # Only wrap when something actually followed the hit.
                    if i < N - 1:
                        cur_stmts.append(self._build_if(block))

                    break

            if not change:
                break

        new_node = new_node.with_changes(body=cur_stmts)
        if any_change:
            self.return_blocks.add(new_node)

    return new_node
def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """Report classes marked ``@sorted-attributes`` whose assignments are unsorted.

    Only classes whose docstring contains ``@sorted-attributes`` are checked.
    Lines of the form ``name = value`` (single target, plain name) are
    collected and sorted by name. Lines before the first such assignment stay
    in front, and all other lines after it go behind. If the order changes, the
    class is reported with the reordered body as the replacement.
    """
    doc_string = node.get_docstring()
    if not doc_string or "@sorted-attributes" not in doc_string:
        return

    found_any_assign: bool = False
    pre_assign_lines: List[LineType] = []
    assign_lines: List[LineType] = []
    post_assign_lines: List[LineType] = []

    # Only plain-name targets have a string sort key. Tuple, attribute or
    # subscript targets (`a, b = ...`, `x.y = ...`, `x[0] = ...`) used to be
    # collected too and made the sort key below raise.
    sortable_assign = m.SimpleStatementLine(
        body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])])

    for line in node.body.body:
        if m.matches(line, sortable_assign):
            found_any_assign = True
            assign_lines.append(line)
        elif found_any_assign:
            post_assign_lines.append(line)
        else:
            pre_assign_lines.append(line)

    sorted_assign_lines = sorted(
        assign_lines, key=lambda line: line.body[0].targets[0].target.value)
    # sorted() keeps the same node objects, so this detects "already sorted".
    if sorted_assign_lines == assign_lines:
        return
    self.report(
        node,
        replacement=node.with_changes(body=node.body.with_changes(
            body=pre_assign_lines + sorted_assign_lines + post_assign_lines)),
    )
def _is_import_line(
        line: Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]) -> bool:
    """Return True if *line* is a simple line holding exactly one import statement.

    Both ``import x`` and ``from x import y`` forms count. Compound statements
    and lines with several ``;``-separated statements do not.
    """
    single_import = m.OneOf(m.Import(), m.ImportFrom())
    import_line = m.SimpleStatementLine(body=[single_import])
    return m.matches(line, import_line)
class Modernizer(m.MatcherDecoratableTransformer):
    """Rewrite Python 2 idioms towards ``builtins`` / ``future.utils`` equivalents.

    Calls like ``map``/``filter``/``zip``/``range``, ``xrange``/``raw_input``,
    ``d.iter*()``, ``d.keys()``-style views and the ``unicode`` name are
    rewritten. The names they need are recorded and imported in
    ``leave_Module`` via ``update_imports``. In verbose mode, parameters and
    assigned variables are printed together with their ``# type:`` comments.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    # FIXME use a stack of e.g. SimpleStatementLine then proper visit_Import/ImportFrom to store the ssl node

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        self.errors = False
        self.stack: List[Tuple[str, ...]] = []
        self.annotations: Dict[
            Tuple[str, ...], Comment  # key: tuple of canonical variable name
        ] = {}
        self.python_future_updated_node: Optional[SimpleStatementLine] = None
        self.python_future_imports: Dict[str, str] = {}
        self.python_future_new_imports: Set[str] = set()
        self.builtins_imports: Dict[str, str] = {}
        self.builtins_new_imports: Set[str] = set()
        self.builtins_updated_node: Optional[SimpleStatementLine] = None
        self.future_utils_imports: Dict[str, str] = {}
        self.future_utils_new_imports: Set[str] = set()
        self.future_utils_updated_node: Optional[SimpleStatementLine] = None
        # self.last_import_node: Optional[CSTNode] = None
        # Top-level statement line after which new import lines are inserted.
        self.last_import_node_stmt: Optional[CSTNode] = None

    # @m.call_if_inside(m.ImportFrom(module=m.Name("__future__")))
    # @m.visit(m.ImportAlias() | m.ImportStar())
    # def import_python_future_check(self, node: Union[ImportAlias, ImportStar]) -> None:
    #     self.add_import(self.python_future_imports, node)

    # @m.leave(m.ImportFrom(module=m.Name("__future__")))
    # def import_python_future_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @m.call_if_inside(m.ImportFrom(module=m.Name("builtins")))
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_builtins_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        self.add_import(self.builtins_imports, node)

    # @m.leave(m.ImportFrom(module=m.Name("builtins")))
    # def builtins_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @m.call_if_inside(
        m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    )
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_future_utils_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        self.add_import(self.future_utils_imports, node)

    # @m.leave(
    #     m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    # )
    # def future_utils_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @staticmethod
    def add_import(
        imports: Dict[str, str], node: Union[ImportAlias, ImportStar]
    ) -> None:
        """Record an imported name in *imports* as ``{name: alias_or_name}``; ``*`` maps to ``*``."""
        if isinstance(node, ImportAlias):
            imports[node.name.value] = (
                node.asname.name.value if node.asname else node.name.value
            )
        else:
            imports["*"] = "*"

    # @m.call_if_not_inside(m.BaseCompoundStatement())
    # def visit_Import(self, node: Import) -> Optional[bool]:
    #     self.last_import_node = node
    #     return None

    # @m.call_if_not_inside(m.BaseCompoundStatement())
    # def visit_ImportFrom(self, node: ImportFrom) -> Optional[bool]:
    #     self.last_import_node = node
    #     return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def visit_SimpleStatementLine(self, node: SimpleStatementLine) -> Optional[bool]:
        for n in node.body:
            if m.matches(n, m.Import() | m.ImportFrom()):
                self.last_import_node_stmt = node
        return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def leave_SimpleStatementLine(
        self, original_node: SimpleStatementLine, updated_node: SimpleStatementLine
    ) -> Union[BaseStatement, RemovalSentinel]:
        for n in updated_node.body:
            if m.matches(n, m.ImportFrom(module=m.Name("__future__"))):
                self.python_future_updated_node = updated_node
            elif m.matches(n, m.ImportFrom(module=m.Name("builtins"))):
                self.builtins_updated_node = updated_node
            elif m.matches(
                n,
                m.ImportFrom(
                    module=m.Attribute(value=m.Name("future"), attr=m.Name("utils"))
                ),
            ):
                self.future_utils_updated_node = updated_node
        return updated_node

    # @m.visit(
    #     m.AllOf(
    #         m.SimpleStatementLine(),
    #         m.MatchIfTrue(
    #             lambda node: any(m.matches(c, m.Assign()) for c in node.children)
    #         ),
    #         m.MatchIfTrue(
    #             lambda node: "# type:" in node.trailing_whitespace.comment.value
    #         ),
    #     )
    # )
    # def visit_assign(self, node: SimpleStatementSuite) -> None:
    #     return None

    def visit_Param(self, node: Param) -> Optional[bool]:
        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.ptype: Optional[str] = None

            def visit_TrailingWhitespace_comment(
                self, node: "TrailingWhitespace"
            ) -> None:
                if node.comment and "type:" in node.comment.value:
                    mo = re.match(r"#\s*type:\s*(\S*)", node.comment.value)
                    self.ptype = mo.group(1) if mo else None
                return None

        v = Visitor()
        node.visit(v)
        if self.verbose:
            pos = self.get_metadata(PositionProvider, node).start
            print(
                f"{self.path}:{pos.line}:{pos.column}: parameter {node.name.value}: {v.ptype or 'unknown type'}"
            )
        return None

    @m.visit(m.SimpleStatementLine())
    def visit_simple_stmt(self, node: SimpleStatementLine) -> None:
        assign = None
        for c in node.children:
            if m.matches(c, m.Assign()):
                assign = ensure_type(c, Assign)
        if assign:
            # Read the `# type: X` comment from the line's own trailing comment.
            # (The old code tested a bare matcher object, which is always
            # truthy, and its nested visitor stored the type in a local, so a
            # type was never reported.)
            vtype = None
            comment = node.trailing_whitespace.comment
            if comment and "type:" in comment.value:
                mo = re.match(r"#\s*type:\s*(\S*)", comment.value)
                if mo:
                    vtype = mo.group(1)

            class NameVisitor(m.MatcherDecoratableVisitor):
                def __init__(self):
                    super().__init__()
                    self.names: List[str] = []

                def visit_Name(self, node: Name) -> Optional[bool]:
                    self.names.append(node.value)
                    return None

            if self.verbose:
                pos = self.get_metadata(PositionProvider, node).start
                for target in assign.targets:
                    v = NameVisitor()
                    target.visit(v)
                    for name in v.names:
                        print(
                            f"{self.path}:{pos.line}:{pos.column}: variable {name}: {vtype or 'unknown type'}"
                        )

    def visit_FunctionDef_body(self, node: FunctionDef) -> None:
        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()

            def visit_EmptyLine_comment(self, node: "EmptyLine") -> None:
                # FIXME too many matches on test_param_02
                if not node.comment:
                    return
                # TODO: use comment.value
                return None

        v = Visitor()
        node.visit(v)
        return None

    map_matcher = m.Call(
        func=m.Name("filter") | m.Name("map") | m.Name("zip") | m.Name("range")
    )

    @m.visit(map_matcher)
    def visit_map(self, node: Call) -> None:
        func_name = ensure_type(node.func, Name).value
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(map_matcher)
    def fix_map(self, original_node: Call, updated_node: Call) -> BaseExpression:
        # TODO test with CompFor etc.
        # TODO improve join test
        func_name = ensure_type(updated_node.func, Name).value
        if func_name not in self.builtins_imports:
            updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.visit(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def visit_xrange(self, node: Call) -> None:
        orig_func_name = ensure_type(node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.leave(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def fix_xrange(self, original_node: Call, updated_node: Call) -> BaseExpression:
        orig_func_name = ensure_type(updated_node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        return updated_node.with_changes(func=Name(func_name))

    iter_matcher = m.Call(
        func=m.Attribute(
            attr=m.Name("iterkeys") | m.Name("itervalues") | m.Name("iteritems")
        )
    )

    @m.visit(iter_matcher)
    def visit_iter(self, node: Call) -> None:
        func_name = ensure_type(node.func, Attribute).attr.value
        if func_name not in self.future_utils_imports:
            self.future_utils_new_imports.add(func_name)

    @m.leave(iter_matcher)
    def fix_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        # d.iteritems() -> iteritems(d)
        attribute = ensure_type(updated_node.func, Attribute)
        func_name = attribute.attr
        dict_name = attribute.value
        return updated_node.with_changes(func=func_name, args=[Arg(dict_name)])

    not_iter_matcher = m.Call(
        func=m.Attribute(attr=m.Name("keys") | m.Name("values") | m.Name("items"))
    )

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(not_iter_matcher)
    def fix_not_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.call_if_not_inside(m.Import() | m.ImportFrom())
    @m.leave(m.Name(value="unicode"))
    def fix_unicode(self, original_node: Name, updated_node: Name) -> BaseExpression:
        value = "text_type"
        if value not in self.future_utils_imports:
            self.future_utils_new_imports.add(value)
        return updated_node.with_changes(value=value)

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "builtins",
            self.builtins_updated_node,
            self.builtins_imports,
            self.builtins_new_imports,
            True,
        )
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "future.utils",
            self.future_utils_updated_node,
            self.future_utils_imports,
            self.future_utils_new_imports,
            False,
        )
        return updated_node

    def update_imports(
        self,
        original_module: Module,
        updated_module: Module,
        import_name: str,
        updated_import_node: SimpleStatementLine,
        current_imports: Dict[str, str],
        new_imports: Set[str],
        noqa: bool,
    ) -> Module:
        """Add ``from <import_name> import ...`` for *new_imports* to the module.

        If the module has no such import line yet, a new one is inserted after
        the anchor ``self.last_import_node_stmt`` (or at the top, followed by
        two blank lines, when there is no anchor). Otherwise the existing line
        is replaced by one importing the union of old and new names, unless
        it is a star import. ``noqa`` appends a ``# noqa`` comment.
        """
        if not new_imports:
            return updated_module
        noqa_comment = " # noqa" if noqa else ""
        if not updated_import_node:
            i = -1
            blank_lines = "\n\n"
            if self.last_import_node_stmt:
                blank_lines = ""
                anchor = self.last_import_node_stmt
                # A line inserted by an earlier call only exists in the updated
                # tree, so look for the anchor there first; fall back to the
                # original tree (same statement count) for pre-existing imports.
                found = next(
                    (
                        idx
                        for idx, body_stmt in enumerate(updated_module.body)
                        if body_stmt is anchor
                    ),
                    None,
                )
                if found is not None:
                    i = found
                else:
                    for i, (original, _updated) in enumerate(
                        zip(original_module.body, updated_module.body)
                    ):
                        if original is anchor:
                            break
            stmt = parse_module(
                f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
                config=updated_module.config_for_parsing,
            )
            body = list(updated_module.body)
            # Anchor on the inserted import line itself (not the throw-away
            # Module wrapper) so a following insertion lands right after it.
            self.last_import_node_stmt = stmt.body[0]
            return updated_module.with_changes(
                body=body[: i + 1] + stmt.children + body[i + 1 :]
            )
        else:
            if "*" not in current_imports:
                current_imports_set = {
                    f"{k}" if k == v else f"{k} as {v}"
                    for k, v in current_imports.items()
                }
                stmt = parse_statement(
                    f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}"
                )
                # NOTE(review): if an earlier call already rebuilt the module
                # with deep_replace, updated_import_node may no longer be
                # identical to a node in updated_module, making this a no-op --
                # confirm.
                return updated_module.deep_replace(updated_import_node, stmt)
            # for i, (original, updated) in enumerate(
            #     zip(original_module.body, updated_module.body)
            # ):
            #     if original is original_import_node:
            #         body = list(updated_module.body)
            #         return updated_module.with_changes(
            #             body=body[:i] + [stmt] + body[i + 1 :]
            #         )
        return updated_module
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        test_code = cst.Module("").code_for_node(updated_node.test)
        try:
            test_literal = literal_eval(test_code)
        except Exception:
            return updated_node
        if test_literal:
            return cst.RemovalSentinel.REMOVE
        if updated_node.msg is None:
            return cst.Raise(cst.Name("AssertionError"))
        return cst.Raise(
            cst.Call(cst.Name("AssertionError"), args=[cst.Arg(updated_node.msg)]))

    @m.leave(
        m.ComparisonTarget(comparator=oneof_names("None", "False", "True"),
                           operator=m.Equal()))
    def convert_none_cmp(self, _, updated_node):
        """Inspired by Pybetter."""
        return updated_node.with_changes(operator=cst.Is())

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(
                comparisons=[m.ComparisonTarget(operator=m.In())]),
        ))
    def replace_not_in_condition(self, _, updated_node):
        """Also inspired by Pybetter."""
        expr = cst.ensure_type(updated_node.expression, cst.Comparison)
        return cst.Comparison(
            left=expr.left,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
            comparisons=[
                expr.comparisons[0].with_changes(operator=cst.NotIn())
            ],
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        ))
    def remove_pointless_parens_around_call(self, _, updated_node):
        # This is *probably* valid, but we might have e.g. a multi-line parenthesised
        # chain of attribute accesses ("fluent interface"), where we need the parens.
        noparens = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(noparens), "<string>", "eval")
            return noparens
        except SyntaxError:
            return updated_node

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.GeneratorExp(), star="")]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400-402 and 403-404.

        C400-402: Unnecessary generator - rewrite as a <list/set/dict> comprehension.
        Note that set and dict conversions are handled by pyupgrade!

        Only a plain (non-starred) generator argument is rewritten; ``list(*gen)``
        means something different.
        """
        return cst.ListComp(elt=updated_node.args[0].value.elt,
                            for_in=updated_node.args[0].value.for_in)

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")],
        ))
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C411 and C413.

        Unnecessary <list/reversed> call around sorted().

        Also covers C411 Unnecessary list call around list comprehension
        for lists and sets.
        """
        return updated_node.args[0].value

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        ))
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted().
        """
        call = updated_node.args[0].value
        args = list(call.args)
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(
                        literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Unknown at transform time: flip it symbolically.
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), arg.value))
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        # reversed(sorted(x, reverse=True)) == sorted(x)
                        del args[i]
                        args[i - 1] = remove_trailing_comma(args[i - 1])
                break
        else:
            args.append(
                cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")
    # Inner calls must start with a plain positional argument (the iterable):
    # `set(list())` has nothing to unwrap and `set(list(*a))` is not `set(a)`.
    _iterable_first = [m.Arg(star="", keyword=None), m.ZeroOrMore()]

    @m.leave(
        m.Call(func=_sets,
               args=[m.Arg(m.Call(func=_sets | _seqs, args=_iterable_first),
                           star="")])
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[m.Arg(m.Call(func=oneof_names("list", "tuple"),
                               args=_iterable_first),
                        star="")],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs, args=_iterable_first), star=""),
                  m.ZeroOrMore()],
        ))
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>()..
        The inner call's first positional argument replaces the inner call.
        """
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.args[0].value)] +
            list(updated_node.args[1:]), )

    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[
                m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]))
            ],
        ))
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().
        """
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.value)], )

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(target=m.Name(),
                             ifs=[],
                             inner_for_in=None,
                             asynchronous=None),
        ))
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        Unnecessary <list/set> comprehension - rewrite using <list/set>().
        """
        if updated_node.elt.value == updated_node.for_in.target.value:
            func = cst.Name(
                "list" if isinstance(updated_node, cst.ListComp) else "set")
            return cst.Call(func=func,
                            args=[cst.Arg(updated_node.for_in.iter)])
        return updated_node

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        subscript = list(updated_node.slice)
        try:
            subscript.sort(key=lambda elt: elt.slice.value.value == "None")
            subscript[-1] = remove_trailing_comma(subscript[-1])
            return updated_node.with_changes(slice=subscript)
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        ))
    def reorder_union_operator_contents_none_last(self, _, updated_node):
        def _has_none(node):
            if m.matches(node, m.Name("None")):
                return True
            elif m.matches(node, m.BinaryOperation()):
                return _has_none(node.left) or _has_none(node.right)
            else:
                return False

        node_left = updated_node.left
        if _has_none(node_left):
            return updated_node.with_changes(left=updated_node.right,
                                             right=node_left)
        else:
            return updated_node

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        new_slice = []
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Literal"))):
                new_slice += item.slice.value.slice
            else:
                new_slice.append(item)
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Optional"))):
                new_slice += item.slice.value.slice  # peel off "Optional"
                has_none = True
            elif m.matches(item.slice.value,
                           m.Subscript(m.Name("Union"))) and m.matches(
                               updated_node.value, item.slice.value.value):
                new_slice += item.slice.value.slice  # peel off "Union" or "Literal"
            elif m.matches(item.slice.value, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(
                cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        # An `else: pass` block can always simply be discarded, and libcst ensures
        # that an Else node can only ever occur attached to an If, While, For, or Try
        # node; in each case `None` is the valid way to represent "no else block".
        if m.findall(updated_node, m.Comment()):
            return updated_node  # If there are any comments, keep the node
        return cst.RemoveFromParent()

    @m.leave(
        m.Lambda(params=m.MatchIfTrue(lambda node: (
            node.star_kwarg is None and not node.kwonly_params and not node.
            posonly_params and isinstance(node.star_arg, cst.MaybeSentinel) and
            all(param.default is None for param in node.params)))))
    def remove_lambda_indirection(self, _, updated_node):
        """Rewrite ``lambda a, b: f(a, b)`` as plain ``f``.

        The rewrite is skipped when the callee expression itself mentions one
        of the lambda's parameter names (e.g. ``lambda x: x.meth(x)`` or
        ``lambda x: g(x)(x)``): hoisting it out of the lambda would leave that
        name unbound.
        """
        params = updated_node.params.params
        same_args = [
            m.Arg(m.Name(param.name.value), star="", keyword=None)
            for param in params
        ]
        if not m.matches(updated_node.body, m.Call(args=same_args)):
            return updated_node
        func = cst.ensure_type(updated_node.body, cst.Call).func
        param_names = {param.name.value for param in params}
        if m.findall(
                func,
                m.Name(value=m.MatchIfTrue(lambda name: name in param_names))):
            return updated_node
        return func

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        ))
    def collapse_isinstance_checks(self, _, updated_node):
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        if left_target.deep_equals(right_target):
            merged_type = cst.Arg(
                cst.Tuple([
                    cst.Element(left_type.value),
                    cst.Element(right_type.value)
                ]))
            return updated_node.left.with_changes(
                args=[left_target, merged_type])
        return updated_node
def is_docstring(node):
    """Return True when *node* is a line holding only a plain string literal expression.

    Only ``SimpleString`` counts; concatenated strings and f-strings do not.
    The node's position in its parent is not checked.
    """
    string_expression = m.Expr(value=m.SimpleString())
    string_only_line = m.SimpleStatementLine(body=[string_expression])
    return m.matches(node, string_only_line)
newline=m.Newline()) _django_model_field_name_value = m.Call(func=m.Attribute( attr=m.Name(m.MatchIfTrue(is_model_field_type)))) | m.Call( func=m.Name(m.MatchIfTrue(is_model_field_type))) _django_model_field_name_with_leading_comment_value = m.Call( func=m.Attribute(attr=m.Name(m.MatchIfTrue(is_model_field_type))), whitespace_before_args=m.ParenthesizedWhitespace(_any_comment), ) | m.Call( func=m.Name(m.MatchIfTrue(is_model_field_type)), whitespace_before_args=m.ParenthesizedWhitespace(_any_comment), ) _django_model_field_with_leading_comment = m.SimpleStatementLine(body=[ m.Assign(value=_django_model_field_name_with_leading_comment_value) | m.AnnAssign(value=_django_model_field_name_with_leading_comment_value) ]) _django_model_field_with_trailing_comment = m.SimpleStatementLine( body=[ m.Assign(value=_django_model_field_name_value) | m.AnnAssign(value=_django_model_field_name_value) ], trailing_whitespace=_any_comment, ) django_model_field_with_comments = (_django_model_field_with_leading_comment | _django_model_field_with_trailing_comment) def get_leading_comment(node: cst.SimpleStatementLine) -> typing.Optional[str]:
comment=m.Comment(m.MatchIfTrue(is_valid_comment)), newline=m.Newline(), ) field_without_comment = m.SimpleStatementLine( body=[ m.Assign(value=(m.Call( args=[ m.ZeroOrMore(), m.Arg(keyword=m.Name('null'), value=m.Name('True')), m.ZeroOrMore(), ], whitespace_before_args=m.DoesNotMatch( m.ParenthesizedWhitespace(null_comment)), ) | m.Call( func=m.Attribute(attr=m.Name('NullBooleanField')), whitespace_before_args=m.DoesNotMatch( m.ParenthesizedWhitespace(null_comment)), ) | m.Call( func=m.Name('NullBooleanField'), whitespace_before_args=m.DoesNotMatch( m.ParenthesizedWhitespace(null_comment)), ))) ], trailing_whitespace=m.DoesNotMatch(null_comment), ) class FieldValidator(m.MatcherDecoratableVisitor): METADATA_DEPENDENCIES = (PositionProvider, )
class LoggerTransformer(cst.CSTTransformer):
    """Interactively rewrite ``print(...)`` statement lines into ``logging.<level>(...)``.

    For every simple statement line whose first statement is a ``print`` call
    with only string/name positional arguments, a single-argument logging call
    is built (an f-string when names are involved). All comments on the line
    are moved to its end, joined with ``comment_sep``. The user is then shown
    the proposed line and context and asked whether to accept it. The file is
    printed via ``print_context2`` and the user is prompted with ``input``.
    Lines that cannot be translated faithfully are left untouched.
    """
    METADATA_DEPENDENCIES = (
        WhitespaceInclusivePositionProvider, PositionProvider, ParentNodeProvider
    )

    def __init__(
        self,
        fpath,
        lines,
        default_level='info',
        accept_all=False,
        comment_sep=' / ',
        context_lines=13,
    ):
        super().__init__()
        self.fpath = fpath
        self.lines = lines
        self.default_level: str = default_level
        self.accept_all: bool = accept_all
        self.comment_sep: str = comment_sep
        self.context_lines: int = context_lines

    def get_parent(self, node) -> CSTNodeT:
        return self.get_metadata(cst.metadata.ParentNodeProvider, node)

    # NOTE: CSTTransformer only dispatches `leave_<NodeClassName>` methods, and
    # the `m.call_if_*` decorators only take effect on matcher-decoratable
    # transformers, so the two methods below are not invoked during a visit.
    @m.call_if_not_inside(m.SimpleStatementLine())
    def leave_line(self, original_node, updated_node):
        return updated_node

    @m.call_if_inside(m.Call())
    def leave_call(self, original_node, updated_node):
        # Previously referenced an undefined `line_node`; pass the node through.
        return updated_node

    def on_leave(self, original_node: CSTNodeT, updated_node: CSTNodeT) -> Union[CSTNodeT, RemovalSentinel]:
        # Visit line nodes with print function calls
        if not isinstance(updated_node, SimpleStatementLine):
            return updated_node
        original_line_node = original_node
        line_node = updated_node
        if not isinstance(line_node.body[0], Expr):
            return updated_node
        node = line_node.body[0].value
        # Check the callee is a bare Name first: other callees (e.g. `f()(x)`,
        # `(lambda: 0)()`) have no `.value` attribute.
        if not (isinstance(node, Call) and isinstance(node.func, Name)
                and node.func.value == 'print'):
            return line_node
        # `print(*args)` / `print(**kw)` cannot be expressed as one message.
        if any(arg.star for arg in node.args):
            return line_node

        # Arg.value, Arg.keyword
        pos_args = [x.value for x in node.args if not x.keyword]
        has_vars = False
        terms = []
        n_variables = 0
        simple_ixs = []  # indexes of regular, simple strings
        for ix, arg in enumerate(pos_args):
            if isinstance(arg, FormattedString):
                for part in arg.parts:
                    if isinstance(part, FormattedStringExpression):
                        has_vars = True
                        break
                term = make_str(arg)
                terms.append(term)
            elif isinstance(arg, SimpleString):
                term = extract_string(arg.value)
                terms.append(term)
                simple_ixs.append(ix)
            elif isinstance(arg, ConcatenatedString):
                visitor = GatherStringVisitor()
                arg.visit(visitor)
                term = ''.join([extract_string(s) for s in visitor.strings])
                terms.append(term)
                simple_ixs.append(ix)
            elif isinstance(arg, Name):
                has_vars = True
                n_variables += 1
                terms.append('{' + arg.value + '}')
            else:
                # Any other expression would be silently dropped from the
                # message (and would misalign `simple_ixs` with `terms`), so
                # leave this print untouched.
                return line_node

        # Escape {} in non-f strings
        if has_vars:
            for ix in simple_ixs:
                term = terms[ix]
                terms[ix] = term.replace('{', '{{').replace('}', '}}')

        sep = ' '
        sep_ = get_keyword(node, 'sep')
        try:  # fails if sep is a variable
            sep = extract_string(sep_)
        except TypeError:
            pass

        if n_variables == len(terms) == 1:
            # Avoid putting a single variable inside f-string: pass the bare
            # name (terms[0] is "{name}", which would parse as a set literal).
            arg_line = terms[0][1:-1]
        else:
            arg_line = '"' + sep.join(terms) + '"'
            if has_vars:
                arg_line = 'f' + arg_line
        try:
            args = [Arg(value=cst.parse_expression(arg_line))]
        except cst.ParserSyntaxError:
            # e.g. a term containing a double quote; keep the original print.
            return line_node

        # Gather up comments
        cg = GatherCommentsVisitor()
        original_line_node.visit(cg)
        comment = cg.get_joined_comment(self.comment_sep)

        # Remove all comments in order to put them all at the end
        rc_trans = RemoveCommentsTransformer()
        stripped_line_node = line_node.visit(rc_trans)

        def get_line_node(level):
            # Built from the comment-free line captured above, so calling this
            # again for another level does not depend on later rebinding.
            func = Attribute(value=Name('logging'), attr=Name(level))
            node_ = node.with_changes(func=func, args=args)
            line_node_ = stripped_line_node.deep_replace(
                stripped_line_node.body[0].value, node_)
            line_node_ = line_node_.with_deep_changes(
                line_node_.trailing_whitespace, comment=comment
            )
            return line_node_

        line_node = get_line_node(self.default_level)

        pos = self.get_metadata(PositionProvider, original_line_node)
        lineix = pos.start.line - 1  # 1 indexed line number
        end_lineix, end_lineno = pos.end.line - 1, pos.end.line

        # Predict the source code for the newly changed line node
        line = cst.parse_module("").code_for_node(line_node)

        # Find the function or class containing the print line
        context_node = original_line_node
        while not isinstance(context_node, (FunctionDef, ClassDef, Module)):
            context_node = self.get_parent(context_node)
        if isinstance(context_node, Module):
            source_context = ''
        else:
            source_context = '/' + context_node.name.value
        print(
            Bcolor.HEADER,
            f"{self.fpath}{source_context}:"
            f"{lineix + 1}-{end_lineix + 1}",
            Bcolor.ENDC
        )
        print()
        print_context2(self.lines, lineix, end_lineno, line, self.context_lines)
        print()

        # Query the user to decide whether to accept, modify, or reject changes
        if self.accept_all:
            return line_node
        inp = None
        while inp not in ['', 'y', 'n', 'A', 'i', 'w', 'e', 'c', 'x', 'q', 'Q']:
            inp = input(
                Bcolor.OKCYAN + "Accept change? ("
                f"y = yes ({self.default_level}) [default], "
                "n = no, "
                "A = yes to all, "
                "i = yes (info), "
                "w = yes (warning), "
                "e = yes (error), "
                "c = yes (critical), "
                "x = yes (exception), "
                "q = quit): " + Bcolor.ENDC
            )
        if inp in ('q', 'Q'):
            sys.exit(0)
        elif inp == 'n':
            return original_line_node
        elif inp in ['i', 'w', 'e', 'c', 'x']:
            level = levels[inp]
            line_node = get_line_node(level)
        elif inp == 'A':
            self.accept_all = True
        return line_node
def inline_function(func_obj,
                    call,
                    ret_var,
                    cls=None,
                    f_ast=None,
                    is_toplevel=False):
    """Inline a call to ``func_obj`` and return the statements replacing it.

    Args:
        func_obj: The function being called. Its source is read with
            ``inspect.getsource`` when ``f_ast`` is not supplied.
        call: The CST call expression being inlined. It is used for argument
            binding, ``super`` replacement, non-local import generation and
            the optional header comment.
        ret_var: Passed to ``ReplaceReturn`` when rewriting ``return``
            statements (presumably the variable that receives the return
            value -- TODO confirm against ``ReplaceReturn``).
        cls: Class passed to ``replace_super`` when the first parameter is
            ``self``. When not None, no explicit trailing ``return None`` is
            added.
        f_ast: An already-parsed function definition to inline instead of
            parsing ``func_obj``'s source.
        is_toplevel: When True, no explicit trailing ``return None`` is added.

    Returns:
        A list of statements. It contains the imports generated for non-local
        names, then whatever ``replace_super``/``bind_arguments`` put into the
        list, then the rewritten function body. If the function carries a
        decorator other than ``property``/``classmethod``/``staticmethod``/
        ``*.setter``, the result of ``inline_decorators`` is returned instead.

    Raises:
        TypeError, OSError: Propagated from ``inspect.getsource`` after
            printing a diagnostic.
        AssertionError: If the function has more than one decorator.
    """
    log.debug('Inlining {}'.format(a2s(call)))

    inliner = ctx_inliner.get()
    pass_ = ctx_pass.get()

    if f_ast is None:
        # Get the source code for the function. inspect.getsource raises
        # TypeError for built-in objects and OSError when the source cannot
        # be retrieved; report both before propagating the original error.
        try:
            f_source = inspect.getsource(func_obj)
        except (TypeError, OSError):
            print('Failed to get source of {}'.format(a2s(call)))
            raise

        # Record statistics about length of inlined source. splitlines() does
        # not count the empty string that split('\n') yields after the
        # trailing newline getsource() output typically ends with.
        inliner.length_inlined += len(f_source.splitlines())

        # Then parse the function into an AST
        f_ast = parse_statement(f_source)

    # Give the function a fresh name so it won't conflict with other calls to
    # the same function
    f_ast = f_ast.with_changes(
        name=cst.Name(pass_.fresh_var(f_ast.name.value)))

    # If function has decorators, deal with those first. Just inline decorator
    # call and stop there.
    decorators = f_ast.decorators
    assert len(decorators) <= 1  # TODO: deal with multiple decorators
    if len(decorators) == 1:
        d = decorators[0].decorator
        # @property / @classmethod / @staticmethod and @<name>.setter are
        # handled by inlining the body directly; any other decorator is
        # delegated to inline_decorators.
        builtin_decorator = (isinstance(d, cst.Name) and
                             (d.value in ['property', 'classmethod',
                                          'staticmethod']))
        derived_decorator = (isinstance(d, cst.Attribute) and
                             (d.attr.value in ['setter']))
        if not (builtin_decorator or derived_decorator):
            return inline_decorators(f_ast, call, func_obj, ret_var)

    # If we're inlining a decorator, we need to remove @functools.wraps calls
    # to avoid messing up inspect.getsource
    f_ast = f_ast.with_changes(body=f_ast.body.visit(RemoveFunctoolsWraps()))

    # Statements emitted ahead of the inlined body; replace_super and
    # bind_arguments receive this list and may add to it.
    new_stmts = []

    # If the function is a method (which we proxy by first arg being named
    # "self"), then we need to replace uses of special "super" keywords.
    args_def = f_ast.params
    if len(args_def.params) > 0:
        first_arg_is_self = m.matches(args_def.params[0],
                                      m.Param(m.Name('self')))
        if first_arg_is_self:
            f_ast = replace_super(f_ast, cls, call, func_obj, new_stmts)

    # Add bindings from arguments in the call expression to arguments in
    # function def
    f_ast = bind_arguments(f_ast, call, new_stmts)

    # Rename every plain Name assigned in the function's own scope via
    # unique_and_rename (presumably so inlined locals cannot clash with the
    # caller's variables). The scope is resolved once, before any renaming.
    scopes = cst.MetadataWrapper(
        f_ast, unsafe_skip_copy=True).resolve(ScopeProviderFunction)
    func_scope = scopes[f_ast.body]

    for assgn in func_scope.assignments:
        if m.matches(assgn.node, m.Name()):
            var = assgn.node.value
            f_ast = unique_and_rename(f_ast, var)

    # Add an explicit return None at the end to reify implicit return
    f_body = f_ast.body
    last_stmt_is_return = m.matches(f_body.body[-1],
                                    m.SimpleStatementLine([m.Return()]))
    if (not is_toplevel and  # If function return is being assigned
            cls is None and  # And not an __init__ fn
            not last_stmt_is_return):
        f_ast = f_ast.with_deep_changes(
            f_body,
            body=list(f_body.body) + [parse_statement("return None")])

    # Replace returns with if statements
    f_ast = f_ast.with_changes(body=f_ast.body.visit(ReplaceReturn(ret_var)))

    # Inline function body
    new_stmts.extend(f_ast.body.body)

    # Create imports for non-local variables
    imports = generate_imports_for_nonlocals(f_ast, func_obj, call)
    new_stmts = imports + new_stmts

    if inliner.add_comments:
        # Add header comment to first statement: a blank line followed by the
        # original call source, one comment line per source line.
        call_str = a2s(call)
        header_comment = [
            cst.EmptyLine(comment=cst.Comment(f'# {line}'))
            for line in call_str.splitlines()
        ]
        first_stmt = new_stmts[0]
        new_stmts[0] = first_stmt.with_changes(
            leading_lines=[cst.EmptyLine(indent=False)] + header_comment +
            list(first_stmt.leading_lines))

    return new_stmts