def test_extract_simple(self) -> None:
    """`m.extract` returns the captured node on a match and None otherwise."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")

    def pattern_capturing_left(left_matcher):
        # Tuple of (binary operation, call); the binary op's left side is saved
        # under "left" when it satisfies `left_matcher`.
        return m.Tuple(
            elements=[
                m.Element(
                    m.BinaryOperation(
                        left=m.SaveMatchedNode(left_matcher, "left")
                    )
                ),
                m.Element(m.Call()),
            ]
        )

    # Positive case: `a` is a Name, so it is captured.
    captured = m.extract(expression, pattern_capturing_left(m.Name()))
    first_value = cst.ensure_type(expression, cst.Tuple).elements[0].value
    expected_left = cst.ensure_type(first_value, cst.BinaryOperation).left
    self.assertEqual(captured, {"left": expected_left})

    # Negative case: `a` is not a Subscript, so nothing is extracted.
    captured = m.extract(expression, pattern_capturing_left(m.Subscript()))
    self.assertIsNone(captured)
def test_predicate_logic_operators_on_attributes(self) -> None:
    """Metadata matchers compose with the `|`, `&` and `~` operators."""

    def span_to(end_col):
        # Position matcher for a node spanning (1, 0) -> (1, end_col).
        return m.MatchMetadata(
            meta.PositionProvider, self._make_coderange((1, 0), (1, end_col))
        )

    def check(code, matcher, should_match):
        node, wrapper = self._make_fixture(code)
        outcome = matches(node, matcher, metadata_resolver=wrapper)
        if should_match:
            self.assertTrue(outcome)
        else:
            self.assertFalse(outcome)

    # OR: the left operand spans either one or two columns.
    name_or = m.BinaryOperation(left=m.Name(metadata=span_to(1) | span_to(2)))
    check("a + b", name_or, True)

    integer_or = m.BinaryOperation(
        left=m.Integer(metadata=span_to(1) | span_to(2))
    )
    check("12 + 3", integer_or, True)
    check("123 + 4", integer_or, False)

    # AND: one column wide AND in a load context.
    name_and = m.BinaryOperation(
        left=m.Name(
            metadata=span_to(1)
            & m.MatchMetadata(
                meta.ExpressionContextProvider, meta.ExpressionContext.LOAD
            )
        )
    )
    check("a + b", name_and, True)
    check("ab + cd", name_and, False)

    # NOT: anything that is not a store context.
    name_not = m.BinaryOperation(
        left=m.Name(
            metadata=~(
                m.MatchMetadata(
                    meta.ExpressionContextProvider, meta.ExpressionContext.STORE
                )
            )
        )
    )
    check("a + b", name_not, True)
def test_predicate_logic(self) -> None:
    """Metadata matchers compose with OneOf / AllOf / DoesNotMatch."""

    def position_ending_at(col):
        # Position matcher for a node spanning (1, 0) -> (1, col).
        return m.MatchMetadata(
            meta.PositionProvider, self._make_coderange((1, 0), (1, col))
        )

    def run_cases(matcher, cases):
        for code, should_match in cases:
            node, wrapper = self._make_fixture(code)
            outcome = matches(node, matcher, metadata_resolver=wrapper)
            if should_match:
                self.assertTrue(outcome)
            else:
                self.assertFalse(outcome)

    # OneOf: the left operand is one or two columns wide.
    run_cases(
        m.BinaryOperation(
            left=m.OneOf(position_ending_at(1), position_ending_at(2))
        ),
        (("a + b", True), ("12 + 3", True), ("123 + 4", False)),
    )

    # AllOf: one column wide AND in a load context.
    run_cases(
        m.BinaryOperation(
            left=m.AllOf(
                position_ending_at(1),
                m.MatchMetadata(
                    meta.ExpressionContextProvider, meta.ExpressionContext.LOAD
                ),
            )
        ),
        (("a + b", True), ("ab + cd", False)),
    )

    # DoesNotMatch: anything that is not a store context.
    run_cases(
        m.BinaryOperation(
            left=m.DoesNotMatch(
                m.MatchMetadata(
                    meta.ExpressionContextProvider, meta.ExpressionContext.STORE
                )
            )
        ),
        (("a + b", True),),
    )
def test_extract_metadata(self) -> None:
    """SaveMatchedNode around a metadata matcher captures the matched node."""
    module = cst.parse_module("a + b[c], d(e, f * g)")
    wrapper = cst.MetadataWrapper(module)
    first_line = cst.ensure_type(wrapper.module.body[0], cst.SimpleStatementLine)
    expression = cst.ensure_type(first_line.body[0], cst.Expr).value

    def extract_left_spanning_to(end_col):
        # Capture the binary op's left Name if it spans (1, 0) -> (1, end_col).
        position = m.MatchMetadata(
            meta.PositionProvider,
            self._make_coderange((1, 0), (1, end_col)),
        )
        pattern = m.Tuple(
            elements=[
                m.Element(
                    m.BinaryOperation(
                        left=m.Name(metadata=m.SaveMatchedNode(position, "left"))
                    )
                ),
                m.Element(m.Call()),
            ]
        )
        return m.extract(expression, pattern, metadata_resolver=wrapper)

    # `a` spans (1, 0) -> (1, 1), so it is captured.
    captured = extract_left_spanning_to(1)
    binary = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[0].value,
        cst.BinaryOperation,
    )
    self.assertEqual(captured, {"left": binary.left})

    # Nothing on the left spans (1, 0) -> (1, 2), so there is no match.
    self.assertIsNone(extract_left_spanning_to(2))
def test_extract_predicates(self) -> None:
    """Only the alternative of an OR that actually matched is captured."""

    def build_pattern():
        return m.Tuple(
            elements=[
                m.Element(
                    m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))
                ),
                m.Element(
                    m.Call(
                        func=m.SaveMatchedNode(m.Name(), "func")
                        | m.SaveMatchedNode(m.Attribute(), "attr")
                    )
                ),
            ]
        )

    def left_and_callee(expr):
        elements = cst.ensure_type(expr, cst.Tuple).elements
        left = cst.ensure_type(elements[0].value, cst.BinaryOperation).left
        callee = cst.ensure_type(elements[1].value, cst.Call).func
        return left, callee

    # A plain-name callee is saved under "func".
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    captured = m.extract(expression, build_pattern())
    left, callee = left_and_callee(expression)
    self.assertEqual(captured, {"left": left, "func": callee})

    # An attribute callee is saved under "attr" instead.
    expression = cst.parse_expression("a + b[c], d.z(e, f * g)")
    captured = m.extract(expression, build_pattern())
    left, callee = left_and_callee(expression)
    self.assertEqual(captured, {"left": left, "attr": callee})
def _has_none(node):
    """Return True if `node` is the name `None`, or a (possibly nested)
    BinaryOperation with `None` somewhere among its operands."""
    # Explicit depth-first walk; left operands are examined before right ones.
    pending = [node]
    while pending:
        current = pending.pop()
        if m.matches(current, m.Name("None")):
            return True
        if m.matches(current, m.BinaryOperation()):
            pending.append(current.right)
            pending.append(current.left)
    return False
def test_simple_matcher_false(self) -> None:
    """Matchers fail when the node's syntactic position is not the requested one."""

    def located(start, end):
        return m.MatchMetadata(
            meta.SyntacticPositionProvider, self._make_coderange(start, end)
        )

    # The Name `foo` exists, but not at (2, 0) -> (2, 3).
    node, wrapper = self._make_fixture("foo")
    wrong_place_name = m.Name(value="foo", metadata=located((2, 0), (2, 3)))
    self.assertFalse(matches(node, wrong_place_name, metadata_resolver=wrapper))

    # Neither child of `foo + bar` sits at these exact positions.
    node, wrapper = self._make_fixture("foo + bar")
    wrong_place_binop = m.BinaryOperation(
        left=located((1, 0), (1, 1)),
        right=located((1, 4), (1, 5)),
    )
    self.assertFalse(matches(node, wrong_place_binop, metadata_resolver=wrapper))
def leave_BinaryOperation(
    self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
) -> cst.BaseExpression:
    """Rewrite `"...%s..." % args` as an equivalent f-string when it is safe.

    The left operand must satisfy `_match_simple_string` and the right one the
    predicate built by `_gen_match_simple_expression(self.module)`. Each `%s`
    placeholder is replaced by the corresponding argument, escaped for the
    string's quote character via `EscapeStringQuote`. Literal braces in the
    string are doubled so they survive inside the f-string.

    The node is left untouched when:
      * the number of arguments differs from the number of `%s` placeholders
        (at runtime that is a TypeError, so rewriting would change behaviour);
      * the argument tuple contains a starred element (its length is unknown);
      * escaping an argument fails for any reason.
    """
    expr_key = "expr"
    extracts = m.extract(
        original_node,
        m.BinaryOperation(
            left=m.MatchIfTrue(_match_simple_string),
            operator=m.Modulo(),
            right=m.SaveMatchedNode(
                m.MatchIfTrue(_gen_match_simple_expression(self.module)),
                expr_key,
            ),
        ),
    )
    if not extracts:
        return original_node

    expr = extracts[expr_key]
    simple_string = cst.ensure_type(original_node.left, cst.SimpleString)
    innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}")
    tokens = innards.split("%s")

    if isinstance(expr, cst.Tuple):
        # `"%s" % (*xs,)` has a runtime-determined arity; don't guess.
        if any(isinstance(elm, cst.StarredElement) for elm in expr.elements):
            return original_node
        expressions = [elm.value for elm in expr.elements]
    else:
        expressions = [expr]

    # Exactly one argument per `%s`: too few or too many both raise TypeError
    # at runtime, and extra arguments would otherwise be silently dropped.
    if len(expressions) != len(tokens) - 1:
        return original_node

    escape_transformer = EscapeStringQuote(simple_string.quote)
    parts = []
    if tokens[0]:
        parts.append(cst.FormattedStringText(value=tokens[0]))
    for expression, token in zip(expressions, tokens[1:]):
        try:
            parts.append(
                cst.FormattedStringExpression(
                    expression=cast(
                        cst.BaseExpression,
                        expression.visit(escape_transformer),
                    )
                )
            )
        except Exception:
            # Deliberate best-effort: any failure means "don't rewrite".
            return original_node
        if token:
            parts.append(cst.FormattedStringText(value=token))

    start = f"f{simple_string.prefix}{simple_string.quote}"
    return cst.FormattedString(parts=parts, start=start, end=simple_string.quote)
def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
    """Report `<string literal> % ...` when the literal evaluates to `str`."""
    if not m.matches(
        node, m.BinaryOperation(left=m.SimpleString(), operator=m.Modulo())
    ):
        return
    # SimpleString can be bytes and fstring don't support bytes.
    # https://www.python.org/dev/peps/pep-0498/#no-binary-f-strings
    literal = cst.ensure_type(node.left, cst.SimpleString)
    if isinstance(literal.evaluated_value, str):
        self.report(node)
def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
    """Report `<string literal> % ...` while `self.logging_stack` is non-empty.

    Both simple and implicitly concatenated string literals are checked.
    NOTE(review): presumably `logging_stack` tracks enclosing logging calls;
    confirm against where it is pushed/popped.
    """
    percent_format = m.BinaryOperation(
        left=m.SimpleString() | m.ConcatenatedString(),
        operator=m.Modulo(),
    )
    if self.logging_stack and m.matches(node, percent_format):
        self.report(node)
def test_extract_tautology(self) -> None:
    """Saving the root of the match yields the whole expression."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    whole_tuple = m.Tuple(
        elements=[m.Element(m.BinaryOperation()), m.Element(m.Call())]
    )
    captured = m.extract(expression, m.SaveMatchedNode(whole_tuple, name="node"))
    self.assertEqual(captured, {"node": expression})
def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
    """Report `%`-formatted string literals, with an f-string fix when safe.

    If the left operand satisfies `_match_simple_string` and the right one the
    predicate from `_gen_match_simple_expression`, an equivalent f-string is
    built and attached as the replacement. Otherwise, or if the conversion is
    not safe, the node is reported without a replacement. The conversion is
    unsafe when the argument count differs from the number of `%s`
    placeholders, when the tuple has a starred element, or when escaping fails.

    A plain `str` literal `%` anything that does not meet those predicates is
    still reported, but without a fix.
    """
    expr_key = "expr"
    extracts = m.extract(
        node,
        m.BinaryOperation(
            left=m.MatchIfTrue(_match_simple_string),
            operator=m.Modulo(),
            right=m.SaveMatchedNode(
                m.MatchIfTrue(
                    _gen_match_simple_expression(self.context.wrapper.module)
                ),
                expr_key,
            ),
        ),
    )
    if extracts:
        expr = extracts[expr_key]
        simple_string = cst.ensure_type(node.left, cst.SimpleString)
        innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}")
        tokens = innards.split("%s")

        if isinstance(expr, cst.Tuple):
            # `"%s" % (*xs,)` has a runtime-determined arity; don't autofix.
            if any(isinstance(elm, cst.StarredElement) for elm in expr.elements):
                self.report(node)
                return
            expressions = [elm.value for elm in expr.elements]
        else:
            expressions = [expr]

        # Only offer a fix when there is exactly one argument per `%s`; extra
        # arguments would otherwise be silently dropped from the f-string.
        if len(expressions) != len(tokens) - 1:
            self.report(node)
            return

        escape_transformer = EscapeStringQuote(simple_string.quote)
        parts = []
        if tokens[0]:
            parts.append(cst.FormattedStringText(value=tokens[0]))
        for expression, token in zip(expressions, tokens[1:]):
            try:
                parts.append(
                    cst.FormattedStringExpression(
                        expression=cast(
                            cst.BaseExpression,
                            expression.visit(escape_transformer),
                        )
                    )
                )
            except Exception:
                # Best-effort: if escaping fails, report without a fix.
                self.report(node)
                return
            if token:
                parts.append(cst.FormattedStringText(value=token))

        start = f"f{simple_string.prefix}{simple_string.quote}"
        replacement = cst.FormattedString(
            parts=parts, start=start, end=simple_string.quote
        )
        self.report(node, replacement=replacement)
    elif m.matches(
        node, m.BinaryOperation(left=m.SimpleString(), operator=m.Modulo())
    ) and isinstance(
        cst.ensure_type(node.left, cst.SimpleString).evaluated_value, str
    ):
        # Bytes literals are excluded: f-strings have no bytes form.
        self.report(node)
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        """Rename `NotImplemented` to `NotImplementedError` in `raise NotImplemented`.

        Every Name inside the matched raise statement reaches this hook,
        including the cause of `raise NotImplemented from err`. Only the
        `NotImplemented` name itself is rewritten.
        """
        if updated_node.value != "NotImplemented":
            return updated_node
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        """Remove always-passing asserts; turn always-failing ones into `raise`.

        Only asserts whose test can be evaluated by `literal_eval` are changed.
        A truthy literal removes the statement. A falsy one becomes
        `raise AssertionError` (with the assert's message, if it had one).
        """
        test_code = cst.Module("").code_for_node(updated_node.test)
        try:
            test_literal = literal_eval(test_code)
        except Exception:
            # Not a literal: leave the assert exactly as written.
            return updated_node
        if test_literal:
            return cst.RemovalSentinel.REMOVE
        if updated_node.msg is None:
            return cst.Raise(cst.Name("AssertionError"))
        return cst.Raise(
            cst.Call(cst.Name("AssertionError"), args=[cst.Arg(updated_node.msg)])
        )

    @m.leave(
        m.ComparisonTarget(
            comparator=oneof_names("None", "False", "True"), operator=m.Equal()
        )
    )
    def convert_none_cmp(self, _, updated_node):
        """Rewrite `== None/False/True` as `is ...`. Inspired by Pybetter."""
        return updated_node.with_changes(operator=cst.Is())

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(comparisons=[m.ComparisonTarget(operator=m.In())]),
        )
    )
    def replace_not_in_condition(self, _, updated_node):
        """Rewrite `not x in y` as `x not in y`. Also inspired by Pybetter."""
        expr = cst.ensure_type(updated_node.expression, cst.Comparison)
        return cst.Comparison(
            left=expr.left,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
            comparisons=[expr.comparisons[0].with_changes(operator=cst.NotIn())],
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        )
    )
    def remove_pointless_parens_around_call(self, _, updated_node):
        """Drop parentheses around a call when the bare call still compiles."""
        # This is *probably* valid, but we might have e.g. a multi-line parenthesised
        # chain of attribute accesses ("fluent interface"), where we need the parens.
        noparens = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(noparens), "<string>", "eval")
            return noparens
        except SyntaxError:
            return updated_node

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    @m.leave(m.Call(func=m.Name("list"), args=[m.Arg(m.GeneratorExp())]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400-402 and 403-404.

        C400-402: Unnecessary generator - rewrite as a <list/set/dict> comprehension.
        Note that set and dict conversions are handled by pyupgrade!
        """
        return cst.ListComp(
            elt=updated_node.args[0].value.elt, for_in=updated_node.args[0].value.for_in
        )

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")],
        )
    )
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C411 and C413.

        Unnecessary <list/reversed> call around sorted().

        Also covers C411 Unnecessary list call around list comprehension
        for lists and sets.
        """
        return updated_node.args[0].value

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        )
    )
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted().
        """
        call = updated_node.args[0].value
        args = list(call.args)
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Non-literal `reverse=expr`: negate the expression instead.
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), arg.value)
                    )
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        # reversed(sorted(x, reverse=True)) == sorted(x).
                        del args[i]
                        args[i - 1] = remove_trailing_comma(args[i - 1])
                break
        else:
            args.append(cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")

    @m.leave(
        m.Call(func=_sets, args=[m.Arg(m.Call(func=_sets | _seqs), star="")])
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[m.Arg(m.Call(func=oneof_names("list", "tuple")), star="")],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs), star=""), m.ZeroOrMore()],
        )
    )
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>()..
        """
        inner_args = updated_node.args[0].value.args
        # Only the inner call's first argument is kept, so it must exist and be
        # a plain positional argument: `list(tuple())` has nothing to unwrap,
        # and `list(tuple(*x))` is not equivalent to `list(x)`.
        if not inner_args or inner_args[0].star or inner_args[0].keyword is not None:
            return updated_node
        return updated_node.with_changes(
            args=[cst.Arg(inner_args[0].value)] + list(updated_node.args[1:]),
        )

    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]))],
        )
    )
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().
        """
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.value)],
        )

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(
                target=m.Name(), ifs=[], inner_for_in=None, asynchronous=None
            ),
        )
    )
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        Unnecessary <list/set> comprehension - rewrite using <list/set>().
        """
        if updated_node.elt.value == updated_node.for_in.target.value:
            func = cst.Name("list" if isinstance(updated_node, cst.ListComp) else "set")
            return cst.Call(func=func, args=[cst.Arg(updated_node.for_in.iter)])
        return updated_node

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        """Move a bare `None` member of `Union[...]` / `Literal[...]` to the end."""
        subscript = list(updated_node.slice)
        try:
            # Stable sort: only `None` moves; everything else keeps its order.
            subscript.sort(key=lambda elt: elt.slice.value.value == "None")
            subscript[-1] = remove_trailing_comma(subscript[-1])
            return updated_node.with_changes(slice=subscript)
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        )
    )
    def reorder_union_operator_contents_none_last(self, _, updated_node):
        """In annotations, swap `X | Y` operands when `None` is on the left."""

        def _has_none(node):
            if m.matches(node, m.Name("None")):
                return True
            elif m.matches(node, m.BinaryOperation()):
                return _has_none(node.left) or _has_none(node.right)
            else:
                return False

        node_left = updated_node.left
        if _has_none(node_left):
            return updated_node.with_changes(left=updated_node.right, right=node_left)
        else:
            return updated_node

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        """Flatten `Literal[Literal[a, b], c]` into `Literal[a, b, c]`."""
        new_slice = []
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Literal"))):
                new_slice += item.slice.value.slice
            else:
                new_slice.append(item)
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        """Flatten nested `Union`/`Optional` members and put a single `None` last."""
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Optional"))):
                new_slice += item.slice.value.slice  # peel off "Optional"
                has_none = True
            elif m.matches(
                item.slice.value, m.Subscript(m.Name("Union"))
            ) and m.matches(updated_node.value, item.slice.value.value):
                new_slice += item.slice.value.slice  # peel off nested "Union"
            elif m.matches(item.slice.value, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        """Remove `else: pass` blocks that carry no comments."""
        # An `else: pass` block can always simply be discarded, and libcst ensures
        # that an Else node can only ever occur attached to an If, While, For, or Try
        # node; in each case `None` is the valid way to represent "no else block".
        if m.findall(updated_node, m.Comment()):
            return updated_node  # If there are any comments, keep the node
        return cst.RemoveFromParent()

    @m.leave(
        m.Lambda(
            params=m.MatchIfTrue(
                lambda node: (
                    node.star_kwarg is None
                    and not node.kwonly_params
                    and not node.posonly_params
                    and isinstance(node.star_arg, cst.MaybeSentinel)
                    and all(param.default is None for param in node.params)
                )
            )
        )
    )
    def remove_lambda_indirection(self, _, updated_node):
        """Replace `lambda a, b: f(a, b)` with plain `f`.

        The lambda is kept when the callee expression itself mentions one of the
        parameters (e.g. `lambda x: x.method(x)`), since removing the lambda
        would leave that name unbound.
        """
        params = updated_node.params.params
        same_args = [
            m.Arg(m.Name(param.name.value), star="", keyword=None) for param in params
        ]
        if not m.matches(updated_node.body, m.Call(args=same_args)):
            return updated_node
        func = cst.ensure_type(updated_node.body, cst.Call).func
        if any(m.findall(func, m.Name(param.name.value)) for param in params):
            return updated_node
        return func

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        )
    )
    def collapse_isinstance_checks(self, _, updated_node):
        """Merge `isinstance(x, A) or isinstance(x, B)` into `isinstance(x, (A, B))`."""
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        if left_target.deep_equals(right_target):
            merged_type = cst.Arg(
                cst.Tuple([cst.Element(left_type.value), cst.Element(right_type.value)])
            )
            return updated_node.left.with_changes(args=[left_target, merged_type])
        return updated_node
class Checker(m.MatcherDecoratableVisitor):
    """Report Python-2-era constructs in a module.

    Checks for classic division (`/` and `/=`) without
    `from __future__ import division`, uses of `sys.maxint`, and calls to the
    deprecated `assertEquals` / `assertItemsEqual` methods. Each finding is
    printed as `path:line:column: message` and sets `self.errors` to True.
    Checks named in `ignored` ("division", "sys.maxint", or the assert method
    name) are skipped.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        self.future_division = False
        self.errors = False
        # Names of the enclosing classes/functions; maintained here but not
        # read by the checks in this class.
        self.stack: List[str] = []

    @m.call_if_inside(m.ImportFrom(module=m.Name("__future__")))
    @m.visit(m.ImportAlias(name=m.Name("division")))
    def import_div(self, node: ImportAlias) -> None:
        self.future_division = True

    @m.visit(m.BinaryOperation(operator=m.Divide()))
    def check_div(self, node: BinaryOperation) -> None:
        self._report_classic_division(node)

    # `x /= y` is affected by the future import exactly like `x / y`.
    # The node (a libcst AugAssign) is deliberately left unannotated so the
    # matcher-decorator type validation does not depend on an import outside
    # this class.
    @m.visit(m.AugAssign(operator=m.DivideAssign()))
    def check_div_assign(self, node) -> None:
        self._report_classic_division(node)

    def _report_classic_division(self, node) -> None:
        """Report `node` unless true division is enabled or the check is ignored."""
        if "division" in self.ignored:
            return
        if not self.future_division:
            pos = self.get_metadata(PositionProvider, node).start
            print(
                f"{self.path}:{pos.line}:{pos.column}: division without `from __future__ import division`"
            )
            self.errors = True

    @m.visit(m.Attribute(attr=m.Name("maxint"), value=m.Name("sys")))
    def check_maxint(self, node: Attribute) -> None:
        if "sys.maxint" in self.ignored:
            return
        pos = self.get_metadata(PositionProvider, node).start
        print(f"{self.path}:{pos.line}:{pos.column}: use of sys.maxint")
        self.errors = True

    def visit_ClassDef(self, node: ClassDef) -> None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: ClassDef) -> None:
        self.stack.pop()

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: FunctionDef) -> None:
        self.stack.pop()

    def visit_ClassDef_bases(self, node: "ClassDef") -> None:
        # No-op hook. NOTE(review): purpose not evident from this class; confirm
        # whether it can be removed.
        return

    @m.visit(
        m.Call(
            func=m.Attribute(attr=m.Name("assertEquals") | m.Name("assertItemsEqual"))
        )
    )
    def visit_old_assert(self, node: Call) -> None:
        name = ensure_type(node.func, Attribute).attr.value
        if name in self.ignored:
            return
        pos = self.get_metadata(PositionProvider, node).start
        print(f"{self.path}:{pos.line}:{pos.column}: use of {name}")
        self.errors = True
class DatetimeUtcnow_(VisitorBasedCodemodCommand):
    """Rewrite naive `utcnow()` timestamps as `now(UTC)` calls.

    Handled shapes (each with `datetime` or `datetime.datetime` as the receiver):

    * `datetime.utcnow()` -> `datetime.now(UTC)`
    * `datetime.utcnow().replace(tzinfo=<utc>)` -> `datetime.now(UTC)`
    * `(datetime.utcnow() + delta).replace(tzinfo=<utc>)` -> the inner
      `+` / `-` expression, whose `utcnow()` is itself rewritten
    * `UTC.localize(datetime.utcnow())` -> the rewritten inner call

    `<utc>` is one of `timezone.utc`, `utc`, `UTC` or `pytz.UTC` passed as
    `tzinfo=`. Every rewrite schedules an import of `UTC` from
    `bulb.platform.common.timezones` and removal of now-unused pytz/timezone imports.
    """

    DESCRIPTION: str = "Converts from datetime.utcnow() to datetime.now(UTC)"

    timezone_utc_matcher = m.Arg(
        value=m.Attribute(value=m.Name(value="timezone"), attr=m.Name(value="utc")),
        keyword=m.Name(value="tzinfo"),
    )

    utc_matcher = m.Arg(
        value=m.OneOf(
            m.Name(value="utc"),
            m.Name(value="UTC"),
            m.Attribute(value=m.Name(value="pytz"), attr=m.Name(value="UTC")),
        ),
        keyword=m.Name(value="tzinfo"),
    )

    datetime_utcnow_matcher = m.Call(
        func=m.Attribute(value=m.Name(value="datetime"), attr=m.Name(value="utcnow")),
        args=[],
    )

    datetime_datetime_utcnow_matcher = m.Call(
        func=m.Attribute(
            value=m.Attribute(
                value=m.Name(value="datetime"), attr=m.Name(value="datetime")
            ),
            attr=m.Name(value="utcnow"),
        ),
        args=[],
    )

    datetime_replace_matcher = m.Call(
        func=m.Attribute(value=datetime_utcnow_matcher, attr=m.Name(value="replace")),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )

    datetime_datetime_replace_matcher = m.Call(
        func=m.Attribute(
            value=datetime_datetime_utcnow_matcher,
            attr=m.Name(value="replace"),
        ),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )

    # `(utcnow() +/- delta).replace(tzinfo=...)`: subtraction is handled the
    # same way as addition.
    timedelta_replace_matcher = m.Call(
        func=m.Attribute(
            value=m.BinaryOperation(
                left=m.OneOf(datetime_utcnow_matcher, datetime_datetime_utcnow_matcher),
                operator=m.OneOf(m.Add(), m.Subtract()),
            ),
            attr=m.Name(value="replace"),
        ),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )

    utc_localize_matcher = m.Call(
        func=m.Attribute(
            value=m.Name(value="UTC"),
            attr=m.Name(value="localize"),
        ),
        args=[
            m.Arg(
                value=m.OneOf(datetime_utcnow_matcher, datetime_datetime_utcnow_matcher)
            )
        ],
    )

    def _update_imports(self):
        """Schedule import cleanup (pytz / timezone) and the `UTC` import."""
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz")
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz", "utc")
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz", "UTC")
        RemoveImportsVisitor.remove_unused_import(self.context, "datetime", "timezone")
        AddImportsVisitor.add_needed_import(
            self.context, "bulb.platform.common.timezones", "UTC"
        )

    @staticmethod
    def _datetime_now() -> cst.Attribute:
        """Build the callee `datetime.now`."""
        return cst.Attribute(value=cst.Name(value="datetime"), attr=cst.Name("now"))

    @staticmethod
    def _datetime_datetime_now() -> cst.Attribute:
        """Build the callee `datetime.datetime.now`."""
        return cst.Attribute(
            value=cst.Attribute(
                value=cst.Name(value="datetime"),
                attr=cst.Name(value="datetime"),
            ),
            attr=cst.Name(value="now"),
        )

    def _now_utc_call(self, updated_node: cst.Call, func: cst.Attribute) -> cst.Call:
        """Turn `updated_node` into `<func>(UTC)` and schedule the import changes."""
        self._update_imports()
        return updated_node.with_changes(
            func=func,
            args=[cst.Arg(value=cst.Name(value="UTC"))],
        )

    @m.leave(datetime_utcnow_matcher)
    def datetime_utcnow_call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        return self._now_utc_call(updated_node, self._datetime_now())

    @m.leave(datetime_datetime_utcnow_matcher)
    def datetime_datetime_utcnow_call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        return self._now_utc_call(updated_node, self._datetime_datetime_now())

    @m.leave(datetime_replace_matcher)
    def datetime_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        # The whole `.replace(...)` call collapses into `datetime.now(UTC)`.
        return self._now_utc_call(updated_node, self._datetime_now())

    @m.leave(datetime_datetime_replace_matcher)
    def datetime_datetime_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        return self._now_utc_call(updated_node, self._datetime_datetime_now())

    @m.leave(timedelta_replace_matcher)
    def timedelta_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.BinaryOperation:
        # Drop `.replace(...)` and keep the (already rewritten) arithmetic.
        self._update_imports()
        return cast(
            cst.BinaryOperation,
            cast(cst.Attribute, cast(cst.Call, updated_node).func).value,
        )

    @m.leave(utc_localize_matcher)
    def utc_localize(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        # `UTC.localize(x)` -> `x`, where `x` is the already rewritten call.
        self._update_imports()
        return cast(cst.Call, updated_node.args[0].value)
def obf_universal(self, node: cst.CSTNode, *types):
    """Recursively obfuscate names inside `node` and return the rewritten node.

    Each supported node kind is rebuilt with its relevant children passed back
    through this method. A `Name` is replaced via `self.get_new_cst_name` when
    `self.can_rename(name, *types)` allows it; `types` defaults to
    `('a', 'ca', 'v', 'cv')` when not given. Calls are delegated to
    `new_obf_function_name` / `obf_function_args` depending on the
    `change_*` flags. Literals and unhandled node kinds are returned as-is.
    """
    if m.matches(node, m.Name()):
        types = ('a', 'ca', 'v', 'cv') if not types else types
        node = cst.ensure_type(node, cst.Name)
        if self.can_rename(node.value, *types):
            node = self.get_new_cst_name(node)

    elif m.matches(node, m.NameItem()):
        node = cst.ensure_type(node, cst.NameItem)
        node = node.with_changes(name=self.obf_universal(node.name))

    elif m.matches(node, m.Call()):
        node = cst.ensure_type(node, cst.Call)
        if self.change_methods or self.change_functions:
            node = self.new_obf_function_name(node)
        if self.change_arguments or self.change_method_arguments:
            node = self.obf_function_args(node)

    elif m.matches(node, m.Attribute()):
        node = cst.ensure_type(node, cst.Attribute)
        value = node.value
        attr = node.attr
        # NOTE(review): the results of these recursive calls are discarded, so
        # the Attribute node is returned unchanged; only side effects of the
        # calls (if any) take place. Confirm this is intended.
        self.obf_universal(value)
        self.obf_universal(attr)

    elif m.matches(node, m.AssignTarget()):
        node = cst.ensure_type(node, cst.AssignTarget)
        node = node.with_changes(target=self.obf_universal(node.target))

    elif m.matches(node, m.List() | m.Tuple()):
        node = (
            cst.ensure_type(node, cst.List)
            if m.matches(node, m.List())
            else cst.ensure_type(node, cst.Tuple)
        )
        new_elements = []
        for el in node.elements:
            new_elements.append(self.obf_universal(el))
        node = node.with_changes(elements=new_elements)

    elif m.matches(node, m.Subscript()):
        node = cst.ensure_type(node, cst.Subscript)
        new_slice = []
        for el in node.slice:
            new_slice.append(el.with_changes(slice=self.obf_slice(el.slice)))
        node = node.with_changes(slice=new_slice)
        node = node.with_changes(value=self.obf_universal(node.value))

    elif m.matches(node, m.Element()):
        node = cst.ensure_type(node, cst.Element)
        node = node.with_changes(value=self.obf_universal(node.value))

    elif m.matches(node, m.Dict()):
        node = cst.ensure_type(node, cst.Dict)
        new_elements = []
        for el in node.elements:
            new_elements.append(self.obf_universal(el))
        node = node.with_changes(elements=new_elements)

    elif m.matches(node, m.DictElement()):
        node = cst.ensure_type(node, cst.DictElement)
        new_key = self.obf_universal(node.key)
        new_val = self.obf_universal(node.value)
        node = node.with_changes(key=new_key, value=new_val)

    elif m.matches(node, m.StarredDictElement()):
        node = cst.ensure_type(node, cst.StarredDictElement)
        node = node.with_changes(value=self.obf_universal(node.value))

    elif m.matches(node, m.If() | m.While()):
        # Previously this tested `m.matches(node, cst.If | cst.IfExp)` (a class
        # union, not a matcher) and called `ensure_type(node, cst.IfExp)`, which
        # raised for every `If` statement.
        if m.matches(node, m.If()):
            node = cst.ensure_type(node, cst.If)
        else:
            node = cst.ensure_type(node, cst.While)
        node = node.with_changes(test=self.obf_universal(node.test))

    elif m.matches(node, m.IfExp()):
        node = cst.ensure_type(node, cst.IfExp)
        node = node.with_changes(body=self.obf_universal(node.body))
        node = node.with_changes(test=self.obf_universal(node.test))
        node = node.with_changes(orelse=self.obf_universal(node.orelse))

    elif m.matches(node, m.Comparison()):
        node = cst.ensure_type(node, cst.Comparison)
        new_compars = []
        for target in node.comparisons:
            new_compars.append(self.obf_universal(target))
        node = node.with_changes(left=self.obf_universal(node.left))
        node = node.with_changes(comparisons=new_compars)

    elif m.matches(node, m.ComparisonTarget()):
        node = cst.ensure_type(node, cst.ComparisonTarget)
        node = node.with_changes(comparator=self.obf_universal(node.comparator))

    elif m.matches(node, m.FormattedString()):
        node = cst.ensure_type(node, cst.FormattedString)
        new_parts = []
        for part in node.parts:
            new_parts.append(self.obf_universal(part))
        node = node.with_changes(parts=new_parts)

    elif m.matches(node, m.FormattedStringExpression()):
        node = cst.ensure_type(node, cst.FormattedStringExpression)
        node = node.with_changes(expression=self.obf_universal(node.expression))

    elif m.matches(node, m.BinaryOperation() | m.BooleanOperation()):
        node = (
            cst.ensure_type(node, cst.BinaryOperation)
            if m.matches(node, m.BinaryOperation())
            else cst.ensure_type(node, cst.BooleanOperation)
        )
        node = node.with_changes(
            left=self.obf_universal(node.left),
            right=self.obf_universal(node.right),
        )

    elif m.matches(node, m.UnaryOperation()):
        node = cst.ensure_type(node, cst.UnaryOperation)
        node = node.with_changes(expression=self.obf_universal(node.expression))

    elif m.matches(node, m.ListComp()):
        node = cst.ensure_type(node, cst.ListComp)
        node = node.with_changes(elt=self.obf_universal(node.elt))
        node = node.with_changes(for_in=self.obf_universal(node.for_in))

    elif m.matches(node, m.DictComp()):
        node = cst.ensure_type(node, cst.DictComp)
        node = node.with_changes(key=self.obf_universal(node.key))
        node = node.with_changes(value=self.obf_universal(node.value))
        node = node.with_changes(for_in=self.obf_universal(node.for_in))

    elif m.matches(node, m.CompFor()):
        node = cst.ensure_type(node, cst.CompFor)
        new_ifs = []
        node = node.with_changes(target=self.obf_universal(node.target))
        node = node.with_changes(iter=self.obf_universal(node.iter))
        for el in node.ifs:
            new_ifs.append(self.obf_universal(el))
        node = node.with_changes(ifs=new_ifs)

    elif m.matches(node, m.CompIf()):
        node = cst.ensure_type(node, cst.CompIf)
        node = node.with_changes(test=self.obf_universal(node.test))

    # Integer / Float / SimpleString literals and any other node kinds are
    # returned unchanged.
    return node