def leave_Return(
    self, original_node: Return, updated_node: Return
) -> Union[BaseSmallStatement, RemovalSentinel]:
    """Rewrite a tuple returned from a permalink method into a `reverse()` call.

    When `self.visiting_permalink_method` is set and the returned value is a
    non-empty tuple, `return (a, b, c)` becomes `return reverse(a, None, b, c)`:
    the first element stays first, a literal `None` is inserted as the second
    argument, and at most the next two tuple elements follow. Any further
    elements are not carried over.

    Every other return statement (including one returning an empty tuple,
    which has no first element to pass on) is delegated to the parent class.
    """
    returned = updated_node.value
    if self.visiting_permalink_method and m.matches(returned, m.Tuple()):
        elements = returned.elements
        # `m.Tuple()` also matches `()`; indexing its first element would
        # raise IndexError, so leave such a statement to the parent handler.
        if elements:
            first, *following = elements[:3]
            args = (
                Arg(first.value),
                Arg(Name("None")),
                *[Arg(el.value) for el in following],
            )
            return updated_node.with_changes(
                value=Call(func=Name("reverse"), args=args)
            )
    return super().leave_Return(original_node, updated_node)
def update_call_args(self, node: Call) -> Sequence[Arg]:
    """Remove keyword argument from first argument of `re_path`.

    If the first argument is passed as `regex=...`, it is rebuilt as a
    positional argument and returned together with the remaining arguments.
    In every other case, including a call without any arguments, the result
    of the parent implementation is returned.
    """
    # Unpacking `first, *rest` from an empty argument list raises ValueError,
    # so only inspect the first argument when there is one.
    if node.args:
        first_arg, *other_args = node.args
        if m.matches(first_arg, m.Arg(keyword=m.Name("regex"))):
            return (Arg(value=first_arg.value), *other_args)
    return super().update_call_args(node)
def fix_map(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Wrap the call in `list(...)` unless its name is in `self.builtins_imports`.

    The callee of `updated_node` must be a plain `Name`; `ensure_type`
    enforces that before the name is looked up.
    """
    # TODO test with CompFor etc.
    # TODO improve join test
    callee = ensure_type(updated_node.func, Name)
    if callee.value in self.builtins_imports:
        # Name is listed in `builtins_imports`: leave the call untouched.
        return updated_node
    return Call(func=Name("list"), args=[Arg(updated_node)])
def update_call_args(self, node: Call) -> Sequence[Arg]:
    """Make the length argument explicit when the call omits it.

    * No arguments at all: return just `self.default_length_value`.
    * A single `allowed_chars` keyword argument: return the default length
      followed by that keyword argument.
    * Anything else: return the arguments unchanged.
    """
    call_args = node.args
    if not call_args:
        # Nothing was passed, so the default length is the only argument.
        return [Arg(value=self.default_length_value)]

    allowed_chars_kwarg = find_keyword_arg(call_args, "allowed_chars")
    if len(call_args) != 1 or not allowed_chars_kwarg:
        # Not the "only allowed_chars" shape: nothing to do.
        return call_args

    # Only `allowed_chars=` was given: put the length in front of it.
    return [
        Arg(value=self.default_length_value),
        allowed_chars_kwarg,
    ]
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Append `obj=obj` to `super().has_add_permission(...)` calls.

    The argument is only added while `self.is_visiting_subclass` is set and
    the call currently has fewer than two arguments; all other calls are
    handed to the parent class.
    """
    super_has_add_permission = m.Call(
        func=m.Attribute(
            attr=m.Name("has_add_permission"),
            value=m.Call(func=m.Name("super")),
        )
    )
    needs_obj_arg = (
        self.is_visiting_subclass
        and m.matches(updated_node, super_has_add_permission)
        and len(updated_node.args) < 2
    )
    if not needs_obj_arg:
        return super().leave_Call(original_node, updated_node)

    obj_arg = Arg(keyword=Name("obj"), value=Name("obj"))
    return updated_node.with_changes(args=(*updated_node.args, obj_arg))
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Add `on_delete=models.CASCADE` to relation fields that lack `on_delete`.

    Applies when the original call is a one-to-one field or a foreign key
    (as decided by `is_one_to_one_field` / `is_foreign_key`) and
    `has_on_delete` reports no `on_delete`. An import of `models` from
    `django.db` is requested as well. Other calls go to the parent class.
    """
    is_relation_field = is_one_to_one_field(original_node) or is_foreign_key(
        original_node
    )
    if not is_relation_field or has_on_delete(original_node):
        return super().leave_Call(original_node, updated_node)

    # The added argument refers to `models`, so make sure it gets imported.
    AddImportsVisitor.add_needed_import(
        context=self.context,
        module="django.db",
        obj="models",
    )
    on_delete_arg = Arg(
        keyword=Name("on_delete"),
        value=Attribute(value=Name("models"), attr=Name("CASCADE")),
    )
    return updated_node.with_changes(args=(*updated_node.args, on_delete_arg))
def fix_not_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Return `updated_node` wrapped in a `list(...)` call.

    `original_node` is accepted for signature compatibility and is not used.
    """
    wrapped_call = Arg(updated_node)
    return Call(func=Name("list"), args=[wrapped_call])
def fix_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Turn a method call `obj.name(...)` into the function call `name(obj)`.

    The callee must be an `Attribute` (checked by `ensure_type`). Its
    attribute name becomes the new callee and its receiver becomes the sole
    argument; the original arguments are replaced.
    """
    method = ensure_type(updated_node.func, Attribute)
    receiver_arg = Arg(method.value)
    return updated_node.with_changes(func=method.attr, args=[receiver_arg])
def build_path_call(self, pattern, other_args):
    """Return a `path(...)` `Call` node for `pattern`.

    The route produced by `self.build_route(pattern)` is emitted as a
    single-quoted string literal in first position, followed by
    `other_args` in their given order.
    """
    route = self.build_route(pattern)
    route_arg = Arg(value=SimpleString(f"'{route}'"))
    return Call(func=Name("path"), args=(route_arg, *other_args))
def update_call_args(self, node: Call) -> Sequence[Arg]:
    """Return the call's arguments with a literal `None` inserted in front."""
    none_arg = Arg(value=Name("None"))
    return (none_arg, *node.args)
def on_leave(self, original_node: CSTNodeT, updated_node: CSTNodeT) -> Union[CSTNodeT, RemovalSentinel]:
    """Turn a `print(...)` statement line into a `logging.<level>(...)` call.

    Only a `SimpleStatementLine` whose first statement is an expression
    calling the bare name `print` is rewritten; any other node is returned
    as `updated_node`. For a matching line:

    1. The positional arguments are merged into one message expression
       (string pieces joined by `sep`, names interpolated through an
       f-string, a lone name passed as-is).
    2. The comments of the original line are gathered, removed from the line
       and re-attached as a single trailing comment.
    3. A header and a preview are printed, and unless `self.accept_all` is
       set the user is asked whether to accept the change and at which level.

    Args:
        original_node: The node before its children were transformed.
        updated_node: The node after its children were transformed.

    Returns:
        The rewritten line, `original_node` when the user answers `n`, or
        `updated_node` when the line is not a `print` call.

    Raises:
        SystemExit: When the user answers `q` or `Q`.
    """
    # Visit line nodes with print function calls
    if not isinstance(updated_node, SimpleStatementLine):
        return updated_node
    original_line_node = original_node
    line_node = updated_node
    if not isinstance(line_node.body[0], Expr):
        return updated_node
    node = line_node.body[0].value
    # Require a plain Name callee: some callee kinds (e.g. the result of
    # another call) have no `value` attribute to compare against.
    if not (
        isinstance(node, Call)
        and isinstance(node.func, Name)
        and node.func.value == 'print'
    ):
        return line_node

    pos_args = [x.value for x in node.args if not x.keyword]
    has_vars = False
    n_variables = 0
    terms = []
    # Positions in `terms` (not in `pos_args`) of plain, non-f string terms.
    # The two lists differ in length whenever an argument is skipped below.
    simple_ixs = []
    for arg in pos_args:
        if isinstance(arg, FormattedString):
            if any(isinstance(part, FormattedStringExpression) for part in arg.parts):
                has_vars = True
            terms.append(make_str(arg))
        elif isinstance(arg, SimpleString):
            simple_ixs.append(len(terms))
            terms.append(extract_string(arg.value))
        elif isinstance(arg, ConcatenatedString):
            visitor = GatherStringVisitor()
            arg.visit(visitor)
            simple_ixs.append(len(terms))
            terms.append(''.join([extract_string(s) for s in visitor.strings]))
        elif isinstance(arg, Name):
            has_vars = True
            n_variables += 1
            terms.append('{' + arg.value + '}')
        # NOTE(review): positional arguments of any other kind (calls,
        # numbers, attributes, ...) are left out of the generated message.

    # Escape {} in non-f strings
    if has_vars:
        for ix in simple_ixs:
            terms[ix] = terms[ix].replace('{', '{{').replace('}', '}}')

    sep = ' '
    sep_ = get_keyword(node, 'sep')
    try:
        # fails if sep is a variable
        sep = extract_string(sep_)
    except TypeError:
        pass

    if n_variables == len(terms) == 1:
        # Avoid putting a single variable inside f-string: pass the bare
        # name. The term is stored as '{name}', which parsed on its own
        # would be a set literal, so the braces are stripped.
        arg_line = terms[0][1:-1]
    else:
        arg_line = '"' + sep.join(terms) + '"'
        if has_vars:
            arg_line = 'f' + arg_line
    args = [Arg(value=cst.parse_expression(arg_line))]

    # Gather up comments
    cg = GatherCommentsVisitor()
    original_line_node.visit(cg)
    comment = cg.get_joined_comment(self.comment_sep)

    # Remove all comments in order to put them all at the end
    stripped_line = line_node.visit(RemoveCommentsTransformer())

    def get_line_node(level):
        # Build the comment-free line with the call replaced by
        # `logging.<level>(<message>)` and the joined comment at the end.
        func = Attribute(value=Name('logging'), attr=Name(level))
        node_ = node.with_changes(func=func, args=args)
        line_node_ = stripped_line.deep_replace(stripped_line.body[0].value, node_)
        return line_node_.with_deep_changes(
            line_node_.trailing_whitespace, comment=comment
        )

    line_node = get_line_node(self.default_level)

    pos = self.get_metadata(PositionProvider, original_line_node)
    lineix = pos.start.line - 1  # 1 indexed line number
    end_lineno = pos.end.line

    # Predict the source code for the newly changed line node
    line = cst.parse_module("").code_for_node(line_node)

    # Find the function or class containing the print line
    context_node = original_line_node
    while not isinstance(context_node, (FunctionDef, ClassDef, Module)):
        context_node = self.get_parent(context_node)
    if isinstance(context_node, Module):
        source_context = ''
    else:
        source_context = '/' + context_node.name.value

    print(
        Bcolor.HEADER,
        f"{self.fpath}{source_context}:"
        f"{pos.start.line}-{end_lineno}",
        Bcolor.ENDC
    )
    print()
    print_context2(self.lines, lineix, end_lineno, line, self.context_lines)
    print()

    # Query the user to decide whether to accept, modify, or reject changes
    if self.accept_all:
        return line_node
    inp = None
    # 'Q' is accepted here because it is handled as "quit" below.
    while inp not in ['', 'y', 'n', 'A', 'i', 'w', 'e', 'c', 'x', 'q', 'Q']:
        inp = input(
            Bcolor.OKCYAN +
            "Accept change? ("
            f"y = yes ({self.default_level}) [default], "
            "n = no, "
            "A = yes to all, "
            "i = yes (info), "
            "w = yes (warning), "
            "e = yes (error), "
            "c = yes (critical), "
            "x = yes (exception), "
            "q = quit): "
            + Bcolor.ENDC
        )
    if inp in ('q', 'Q'):
        sys.exit(0)
    elif inp == 'n':
        return original_line_node
    elif inp in ['i', 'w', 'e', 'c', 'x']:
        # Rebuild the line at the level selected by the user.
        line_node = get_line_node(levels[inp])
    elif inp == 'A':
        self.accept_all = True
    return line_node