def test_predicate_logic_on_attributes(self) -> None:
    """Combine metadata matchers with OneOf ("or"), AllOf ("and") and DoesNotMatch ("not")."""

    def _spans(start, end):
        # PositionProvider matcher for the (line, column) range start..end.
        return m.MatchMetadata(meta.PositionProvider, self._make_coderange(start, end))

    def _one_or_two_columns():
        # Accept a node covering either columns 0-1 or columns 0-2 of line 1.
        return m.OneOf(_spans((1, 0), (1, 1)), _spans((1, 0), (1, 2)))

    # "or": a one-column Name on the left satisfies one of the alternatives.
    either_name = m.BinaryOperation(left=m.Name(metadata=_one_or_two_columns()))
    tree, resolver = self._make_fixture("a + b")
    self.assertTrue(matches(tree, either_name, metadata_resolver=resolver))

    # "or" on Integers: two digits fit an alternative, three digits fit neither.
    either_integer = m.BinaryOperation(left=m.Integer(metadata=_one_or_two_columns()))
    tree, resolver = self._make_fixture("12 + 3")
    self.assertTrue(matches(tree, either_integer, metadata_resolver=resolver))
    tree, resolver = self._make_fixture("123 + 4")
    self.assertFalse(matches(tree, either_integer, metadata_resolver=resolver))

    # "and": the left Name must be one column wide AND be in LOAD context.
    loaded_single_char = m.BinaryOperation(
        left=m.Name(
            metadata=m.AllOf(
                _spans((1, 0), (1, 1)),
                m.MatchMetadata(
                    meta.ExpressionContextProvider, meta.ExpressionContext.LOAD
                ),
            )
        )
    )
    tree, resolver = self._make_fixture("a + b")
    self.assertTrue(matches(tree, loaded_single_char, metadata_resolver=resolver))
    tree, resolver = self._make_fixture("ab + cd")
    self.assertFalse(matches(tree, loaded_single_char, metadata_resolver=resolver))

    # "not": a Name read in an expression is not in STORE context.
    not_stored = m.BinaryOperation(
        left=m.Name(
            metadata=m.DoesNotMatch(
                m.MatchMetadata(
                    meta.ExpressionContextProvider, meta.ExpressionContext.STORE
                )
            )
        )
    )
    tree, resolver = self._make_fixture("a + b")
    self.assertTrue(matches(tree, not_stored, metadata_resolver=resolver))
def test_at_least_n_matcher_no_args_false(self) -> None:
    """AtLeastN-based argument patterns that must NOT match ``foo(1, 2, 3)``."""

    def _foo_one_two_three():
        # A fresh cst.Call for foo(1, 2, 3) for each check.
        return cst.Call(
            func=cst.Name("foo"),
            args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
        )

    rejected_arg_patterns = (
        # At least four arguments of any kind.
        (m.AtLeastN(n=4),),
        # The value 1 followed by at least three more (four or more in total).
        (m.Arg(m.Integer("1")), m.AtLeastN(n=3)),
        # At least two arguments followed by a final value of 2.
        (m.AtLeastN(n=2), m.Arg(m.Integer("2"))),
    )
    for arg_pattern in rejected_arg_patterns:
        self.assertFalse(
            matches(
                _foo_one_two_three(),
                m.Call(func=m.Name("foo"), args=arg_pattern),
            )
        )
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Append ``obj=obj`` to ``super().has_add_permission(...)`` calls.

    The rewrite applies only while ``self.is_visiting_subclass`` is truthy
    and the call currently passes fewer than two arguments. Every other call
    is handed to the parent class's ``leave_Call``.
    """
    super_permission_call = m.Call(
        func=m.Attribute(
            attr=m.Name("has_add_permission"),
            value=m.Call(func=m.Name("super")),
        )
    )
    needs_obj = (
        self.is_visiting_subclass
        and m.matches(updated_node, super_permission_call)
        and len(updated_node.args) < 2
    )
    if not needs_obj:
        return super().leave_Call(original_node, updated_node)
    return updated_node.with_changes(
        args=(*updated_node.args, parse_arg("obj=obj"))
    )
class TestVisitor(MatcherDecoratableVisitor):
    """Record which matcher-decorated visit/leave handlers fire, by function name."""

    def __init__(self) -> None:
        super().__init__()
        # Names of FunctionDefs for which visit_function ran, in call order.
        self.visits: List[str] = []
        # Names of FunctionDefs for which leave_function ran, in call order.
        self.leaves: List[str] = []

    @visit(m.FunctionDef(m.Name("foo")))
    @visit(m.FunctionDef(m.Name("bar")))
    def visit_function(self, node: cst.FunctionDef) -> None:
        # Registered for functions named "foo" or "bar".
        function_name = node.name.value
        self.visits.append(function_name)

    @leave(m.FunctionDef(m.Name("bar")))
    @leave(m.FunctionDef(m.Name("baz")))
    def leave_function(self, original_node: cst.FunctionDef) -> None:
        # Registered for functions named "bar" or "baz".
        function_name = original_node.name.value
        self.leaves.append(function_name)
def test_replace_sentinel(self) -> None:
    """m.replace hands sentinel inputs back untouched."""

    def _swap_bools(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        # "False" becomes "True"; anything else becomes "False".
        current = cst.ensure_type(node, cst.Name).value
        flipped = "True" if current == "False" else "False"
        return cst.Name(flipped)

    # Verify behavior when provided a sentinel.
    for sentinel in (cst.RemovalSentinel.REMOVE, cst.MaybeSentinel.DEFAULT):
        replaced = m.replace(sentinel, m.Name("True") | m.Name("False"), _swap_bools)
        self.assertEqual(replaced, sentinel)
def _check_import_from_parent(
    self, original_node: ImportFrom, updated_node: ImportFrom
) -> Optional[Union[BaseSmallStatement, RemovalSentinel]]:
    """
    Check for when the parent module of the thing to replace is imported.

    When `parent.module.the_thing` is transformed, detect such an import:

        from parent import module

    Returns None when the statement is not relevant (``import *``, or not an
    import from ``self.old_parent_module_parts``). Returns ``RemoveFromParent()``
    when no imported names remain. Otherwise returns the statement with its
    remaining aliases.

    For a matching alias, this stores in ``self.context.scratch`` a matcher
    for ``<local module name>.<old_name>`` and the replacement ``Attribute``
    node. When the parent package changes, the alias is dropped and
    ``AddImportsVisitor`` is asked to add the new import instead.
    """
    # A star import cannot be rewritten alias by alias.
    if isinstance(updated_node.names, ImportStar):
        return None
    # Only statements importing from the old parent package are relevant.
    if not import_from_matches(updated_node, self.old_parent_module_parts):
        return None

    kept_aliases = []
    for alias in updated_node.names:
        if alias.evaluated_name != self.old_parent_name:
            kept_aliases.append(alias)
            continue

        # Match: record the import scope and how the module is referenced locally.
        self.save_import_scope(original_node)
        local_module_name = alias.evaluated_alias or alias.evaluated_name
        self.context.scratch[self.ctx_key_name_matcher] = m.Attribute(
            value=m.Name(local_module_name),
            attr=m.Name(self.old_name),
        )
        self.context.scratch[self.ctx_key_new_func] = Attribute(
            attr=Name(self.new_name),
            value=Name(alias.evaluated_alias or self.new_parent_name),
        )
        if self.old_parent_module_parts == self.new_parent_module_parts:
            # Parent package unchanged: keep the alias as it is.
            kept_aliases.append(alias)
        else:
            # Import statement needs updating: request the new import and drop this alias.
            AddImportsVisitor.add_needed_import(
                context=self.context,
                module=".".join(self.new_parent_module_parts),
                obj=self.new_parent_name,
                asname=alias.evaluated_alias,
            )

    if not kept_aliases:
        # Nothing left in the import statement: remove it.
        return RemoveFromParent()
    # Some imports are left: clean them up and update the statement.
    return updated_node.with_changes(names=clean_new_import_aliases(kept_aliases))
def test_extract_metadata(self) -> None:
    """m.extract honours metadata matchers when given a metadata resolver."""
    module = cst.parse_module("a + b[c], d(e, f * g)")
    wrapper = cst.MetadataWrapper(module)
    first_line = cst.ensure_type(wrapper.module.body[0], cst.SimpleStatementLine)
    expression = cst.ensure_type(first_line.body[0], cst.Expr).value

    def _extract_left_spanning(end_column):
        # Save the first element's left Name as "left" only when its position
        # is exactly (1, 0)-(1, end_column).
        left_name = m.Name(
            metadata=m.SaveMatchedNode(
                m.MatchMetadata(
                    meta.PositionProvider,
                    self._make_coderange((1, 0), (1, end_column)),
                ),
                "left",
            )
        )
        pattern = m.Tuple(
            elements=[
                m.Element(m.BinaryOperation(left=left_name)),
                m.Element(m.Call()),
            ]
        )
        return m.extract(expression, pattern, metadata_resolver=wrapper)

    # Verify true behavior: `a` spans one column.
    first_operation = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[0].value,
        cst.BinaryOperation,
    )
    self.assertEqual(_extract_left_spanning(1), {"left": first_operation.left})
    # Verify false behavior: a two-column span does not fit `a`.
    self.assertIsNone(_extract_left_spanning(2))
def test_extract_predicates(self) -> None:
    """Only the alternative of an OR matcher that actually matched gets saved."""

    def _check(source, callee_key):
        expression = cst.parse_expression(source)
        pattern = m.Tuple(
            elements=[
                m.Element(
                    m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))
                ),
                m.Element(
                    m.Call(
                        func=m.SaveMatchedNode(m.Name(), "func")
                        | m.SaveMatchedNode(m.Attribute(), "attr")
                    )
                ),
            ]
        )
        nodes = m.extract(expression, pattern)
        elements = cst.ensure_type(expression, cst.Tuple).elements
        left = cst.ensure_type(elements[0].value, cst.BinaryOperation).left
        callee = cst.ensure_type(elements[1].value, cst.Call).func
        self.assertEqual(nodes, {"left": left, callee_key: callee})

    # Plain call `d(...)`: the Name alternative is saved as "func".
    _check("a + b[c], d(e, f * g)", "func")
    # Method call `d.z(...)`: the Attribute alternative is saved as "attr".
    _check("a + b[c], d.z(e, f * g)", "attr")
def _library_import_matcher(self, node: ImportFrom) -> Optional[m.Call]:
    """Return a ``Library()``-call matcher if django.template.Library is imported.

    The matcher targets the name returned by ``self._get_imported_name``.
    Returns None when that helper yields a falsy name.
    """
    local_name = self._get_imported_name(node, "django.template.Library")
    return m.Call(func=m.Name(local_name)) if local_name else None
def _gen_decorator_matchers(
    assign_targets: Sequence[AssignTarget],
) -> Generator[m.Decorator, None, None]:
    """Generate matchers for all possible decorators.

    For every assignment target that is a plain ``Name`` (say ``register``),
    yield a matcher for the ``@register.assignment_tag`` decorator. Other
    target shapes are skipped.
    """
    target_names = (
        assign_target.target.value
        for assign_target in assign_targets
        if isinstance(assign_target.target, Name)
    )
    for target_name in target_names:
        yield m.Decorator(
            decorator=m.Attribute(
                value=m.Name(target_name),
                attr=m.Name("assignment_tag"),
            )
        )
def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """Report ``dataclasses.dataclass`` decorators that do not pass ``frozen``.

    The suggested replacement calls the decorator with the existing arguments
    plus ``frozen=True``.
    """
    dataclass_name = QualifiedName(
        name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
    )
    for decorator_node in node.decorators:
        expr = decorator_node.decorator
        if not QualifiedNameProvider.has_name(self, expr, dataclass_name):
            continue
        if isinstance(expr, cst.Call):
            callee, existing_args = expr.func, expr.args
        else:
            # Bare decorator (cst.Name or cst.Attribute): no arguments yet.
            callee, existing_args = expr, ()
        # pyre-fixme[29]: `typing.Union[typing.Callable(tuple.__iter__)[[], typing.Iterator[Variable[_T_co](covariant)]], typing.Callable(typing.Sequence.__iter__)[[], typing.Iterator[cst._nodes.expression.Arg]]]` is not a function.
        if any(m.matches(arg.keyword, m.Name("frozen")) for arg in existing_args):
            continue
        frozen_arg = cst.Arg(
            keyword=cst.Name("frozen"),
            value=cst.Name("True"),
            equal=cst.AssignEqual(
                whitespace_before=SimpleWhitespace(value=""),
                whitespace_after=SimpleWhitespace(value=""),
            ),
        )
        replacement_call = cst.Call(func=callee, args=[*existing_args, frozen_arg])
        self.report(
            decorator_node,
            replacement=decorator_node.with_changes(decorator=replacement_call),
        )
def new_obf_function_name(self, func: cst.Call):
    """Return ``func`` with its callee obfuscated according to the enabled options.

    * ``obj.method(...)``: the receiver goes through ``obf_universal(..., 'v')``
      when ``change_variables`` is set. The method name goes through
      ``obf_universal(..., 'cf')`` when ``change_methods`` is set.
    * ``name(...)``: replaced by ``get_new_cst_name`` when ``change_functions``
      or ``change_classes`` is set and ``can_rename(name, 'c', 'f')`` allows it.
    * Any other callee is left unchanged.
    """
    callee = func.func
    if m.matches(callee, m.Attribute()):
        callee = cst.ensure_type(callee, cst.Attribute)
        # Rename the object the method is looked up on.
        if self.change_variables:
            callee = callee.with_changes(value=self.obf_universal(callee.value, 'v'))
        # Rename the method itself.
        if self.change_methods:
            callee = callee.with_changes(attr=self.obf_universal(callee.attr, 'cf'))
    elif m.matches(callee, m.Name()):
        callee = cst.ensure_type(callee, cst.Name)
        renaming_enabled = self.change_functions or self.change_classes
        if renaming_enabled and self.can_rename(callee.value, 'c', 'f'):
            callee = self.get_new_cst_name(callee.value)
    return func.with_changes(func=callee)
def visit_Call(self, node: cst.Call) -> None:
    """Report calls of the form ``"string literal".format(...)``."""
    literal_format_call = m.Call(
        func=m.Attribute(value=m.SimpleString(), attr=m.Name(value="format"))
    )
    if not m.matches(node, literal_format_call):
        return
    self.report(node)
def test_extract_simple(self) -> None:
    """m.extract returns the saved captures on a match and None otherwise."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")

    def _extract_left(left_matcher):
        # Save the left operand of the first tuple element under "left".
        pattern = m.Tuple(
            elements=[
                m.Element(
                    m.BinaryOperation(left=m.SaveMatchedNode(left_matcher, "left"))
                ),
                m.Element(m.Call()),
            ]
        )
        return m.extract(expression, pattern)

    # Verify true behavior: the left operand `a` is a Name.
    first_operation = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[0].value,
        cst.BinaryOperation,
    )
    self.assertEqual(_extract_left(m.Name()), {"left": first_operation.left})
    # Verify false behavior: it is not a Subscript.
    self.assertIsNone(_extract_left(m.Subscript()))
def find_keyword_arg(args: Sequence[Arg], keyword_name: str) -> Optional[Arg]:
    """Return the first argument passed as ``keyword_name=...``, or None if absent."""
    wanted = m.Arg(keyword=m.Name(keyword_name))
    return next((candidate for candidate in args if m.matches(candidate, wanted)), None)
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef):
    """Pop ``class_stack``, then optionally rename the class and its plain-name bases.

    Nothing beyond the pop happens unless ``change_classes`` is set. The class
    goes through ``self.renamed`` when ``can_rename(name, 'c')`` allows it.
    Each base given as a plain Name is replaced by ``get_new_cst_name`` under
    the same check. Dotted (Attribute) bases are left untouched.
    """
    self.class_stack.pop()
    if not self.change_classes:
        return updated_node

    if self.can_rename(updated_node.name.value, 'c'):
        updated_node = self.renamed(updated_node)

    def _rename_base(base):
        base_expr = base.value
        if m.matches(base_expr, m.Name()):
            base_name = cst.ensure_type(base_expr, cst.Name).value
            if self.can_rename(base_name, 'c'):
                return base.with_changes(value=self.get_new_cst_name(base_name))
        # TODO: support imported (Attribute) base classes.
        return base

    return updated_node.with_changes(
        bases=[_rename_base(base) for base in updated_node.bases]
    )
def obf_function_args(self, func: cst.Call):
    """Obfuscate the argument values and keyword names of the call ``func``.

    When neither ``change_arguments`` nor ``change_method_arguments`` is set,
    ``func`` is returned as is. Otherwise every argument value goes through
    ``self.obf_universal``. A keyword name is replaced by ``get_new_cst_name``
    when ``can_rename_func_param(keyword, callee_name)`` allows it.
    ``callee_name`` is the called Name, the last component of a dotted
    callee, or '' for any other callee.
    """
    if not (self.change_arguments or self.change_method_arguments):
        # Nothing to obfuscate: never rebuild the call, so its arguments
        # cannot be dropped.
        return func

    callee = func.func
    callee_name = ''
    if m.matches(callee, m.Name()):
        callee_name = cst.ensure_type(callee, cst.Name).value
    elif m.matches(callee, m.Attribute()):
        callee_name = split_attribute(cst.ensure_type(callee, cst.Attribute))[-1]

    new_args = []
    for arg in func.args:
        # Argument values.
        arg = arg.with_changes(value=self.obf_universal(arg.value))
        # Argument (keyword) names.
        if arg.keyword is not None and self.can_rename_func_param(
                arg.keyword.value, callee_name):
            arg = arg.with_changes(keyword=self.get_new_cst_name(arg.keyword))
        new_args.append(arg)
    return func.with_changes(args=new_args)
def leave_SimpleStatementLine(
    self, original_node: SimpleStatementLine, updated_node: SimpleStatementLine
) -> Union[BaseStatement, RemovalSentinel]:
    """Remember lines importing from ``__future__``, ``builtins`` or ``future.utils``.

    The line is stored on ``python_future_updated_node``,
    ``builtins_updated_node`` or ``future_utils_updated_node`` respectively.
    The line itself is returned unchanged.
    """
    tracked_imports = (
        (m.ImportFrom(module=m.Name("__future__")), "python_future_updated_node"),
        (m.ImportFrom(module=m.Name("builtins")), "builtins_updated_node"),
        (
            m.ImportFrom(
                module=m.Attribute(value=m.Name("future"), attr=m.Name("utils"))
            ),
            "future_utils_updated_node",
        ),
    )
    for statement in updated_node.body:
        for pattern, attribute_name in tracked_imports:
            if m.matches(statement, pattern):
                setattr(self, attribute_name, updated_node)
                break
    return updated_node
def test_replace_add_one_to_foo_args(self) -> None:
    """m.replace can rewrite a saved Integer argument inside the matched call."""

    def _add_one_to_arg(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        saved = extraction["arg"]
        # This can be either a node or a sequence; pyre doesn't know which.
        target = cst.ensure_type(saved, cst.CSTNode)
        incremented = cst.Integer(str(int(cst.ensure_type(saved, cst.Integer).value) + 1))
        return node.deep_replace(target, incremented)

    # Verify a more complex transform: only the literal passed to bar() changes.
    source = "foo: int = 37\ndef bar(baz: int) -> int:\n return baz\n\nbiz: int = bar(41)\n"
    expected = "foo: int = 37\ndef bar(baz: int) -> int:\n return baz\n\nbiz: int = bar(42)\n"
    bar_call = m.Call(
        func=m.Name("bar"),
        args=[m.Arg(m.SaveMatchedNode(m.Integer(), "arg"))],
    )
    result = m.replace(cst.parse_module(source), bar_call, _add_one_to_arg)
    self.assertEqual(cst.ensure_type(result, cst.Module).code, expected)
def test_simple_matcher_false(self) -> None:
    """SyntacticPositionProvider matchers reject nodes located elsewhere."""

    def _at(start, end):
        # Syntactic-position matcher for the (line, column) range start..end.
        return m.MatchMetadata(
            meta.SyntacticPositionProvider, self._make_coderange(start, end)
        )

    # `foo` sits on line 1, so a line-2 position cannot match.
    tree, resolver = self._make_fixture("foo")
    misplaced_name = m.Name(value="foo", metadata=_at((2, 0), (2, 3)))
    self.assertFalse(matches(tree, misplaced_name, metadata_resolver=resolver))

    # The operands do not occupy exactly these spans, so the match fails.
    tree, resolver = self._make_fixture("foo + bar")
    misplaced_operands = m.BinaryOperation(
        left=_at((1, 0), (1, 1)),
        right=_at((1, 4), (1, 5)),
    )
    self.assertFalse(matches(tree, misplaced_operands, metadata_resolver=resolver))
def test_replace_metadata(self) -> None:
    """m.replace can select nodes via QualifiedNameProvider metadata."""

    def _rename_foo(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        return cst.ensure_type(node, cst.Name).with_changes(value="replaced")

    def _has_qualified_name_foo(qualnames):
        return any(qualname.name == "foo" for qualname in qualnames)

    source = "foo: int = 37\ndef bar(foo: int) -> int:\n return foo\n\nbiz: int = bar(42)\n"
    # Per the expected output, only the module-level `foo` is renamed.
    expected = "replaced: int = 37\ndef bar(foo: int) -> int:\n return foo\n\nbiz: int = bar(42)\n"
    wrapper = cst.MetadataWrapper(cst.parse_module(source))
    qualified_foo = m.Name(
        metadata=m.MatchMetadataIfTrue(
            meta.QualifiedNameProvider, _has_qualified_name_foo
        )
    )
    result = m.replace(wrapper, qualified_foo, _rename_foo)
    self.assertEqual(cst.ensure_type(result, cst.Module).code, expected)
def test_lambda_matcher_true(self) -> None:
    """MatchIfTrue accepts a Name whose value satisfies the predicate."""
    name_containing_o = m.Name(value=m.MatchIfTrue(lambda value: "o" in value))
    self.assertTrue(matches(cst.Name("foo"), name_containing_o))
def _func_name(self, func):
    """Return a readable name for a call target.

    Returns the attribute name for ``obj.attr``, the identifier for a plain
    Name, and the placeholder ``'func'`` for anything else.
    """
    if m.matches(func, m.Attribute()):
        return func.attr.value
    if m.matches(func, m.Name()):
        return func.value
    return 'func'
def obf_function_name(self, func: cst.Call, updated_node):
    """Return ``updated_node`` with its callee renamed when allowed.

    * A method call ``obj.attr(...)`` has ``attr`` renamed when
      ``change_methods`` is set and ``can_rename(attr, 'cf')`` allows it.
    * A plain ``name(...)`` is renamed when it may be renamed as a function
      (``change_functions`` and ``can_rename(name, 'f')``) or as a class
      (``change_classes`` and ``can_rename(name, 'c')``).
    Otherwise ``updated_node`` is returned unchanged.
    """
    callee = func.func
    if (m.matches(callee, m.Attribute()) and self.change_methods
            and self.can_rename(callee.attr.value, 'cf')):
        method = cst.ensure_type(callee, cst.Attribute)
        renamed_method = method.with_changes(attr=self.get_new_cst_name(method.attr))
        return updated_node.with_changes(func=renamed_method)
    if m.matches(callee, m.Name()) and (
            (self.change_functions and self.can_rename(callee.value, 'f'))
            or (self.change_classes and self.can_rename(callee.value, 'c'))):
        new_name = self.get_new_cst_name(cst.ensure_type(callee, cst.Name).value)
        return updated_node.with_changes(func=new_name)
    return updated_node
def test_lambda_matcher_false(self) -> None:
    """MatchIfTrue rejects a Name whose value fails the predicate."""
    name_containing_a = m.Name(value=m.MatchIfTrue(lambda value: "a" in value))
    self.assertFalse(matches(cst.Name("foo"), name_containing_a))
def handle_any_string(
    self, node: Union[cst.SimpleString, cst.ConcatenatedString]
) -> None:
    """Collect names referenced by the source text inside a string literal.

    The literal's ``evaluated_value`` is parsed as a module. Names are
    gathered and added to ``self.names``:
    * every Name whose parent is not an Attribute;
    * the dotted names produced by ``scope_provider._gen_dotted_names`` for
      every Attribute.

    Literals without an evaluated value, or whose value does not parse as
    Python, are ignored.
    """
    value = node.evaluated_value
    if value is None:
        return
    try:
        mod = cst.parse_module(value)
    except cst.ParserSyntaxError:
        # The literal's contents are not guaranteed to be valid Python; skip it
        # instead of aborting the whole traversal.
        return
    extracted_nodes = m.extractall(
        mod,
        m.Name(
            value=m.SaveMatchedNode(m.DoNotCare(), "name"),
            metadata=m.MatchMetadataIfTrue(
                cst.metadata.ParentNodeProvider,
                # Names inside an Attribute are covered by the "attribute" capture.
                lambda parent: not isinstance(parent, cst.Attribute),
            ),
        )
        | m.SaveMatchedNode(m.Attribute(), "attribute"),
        metadata_resolver=MetadataWrapper(mod, unsafe_skip_copy=True),
    )
    names = {
        cast(str, values["name"]) for values in extracted_nodes if "name" in values
    } | {
        name
        for values in extracted_nodes
        if "attribute" in values
        for name, _ in cst.metadata.scope_provider._gen_dotted_names(
            cast(cst.Attribute, values["attribute"])
        )
    }
    self.names.update(names)
def process_variable(self, node: Union[cst.BaseExpression, cst.BaseAssignTargetExpression]):
    """Record the variable name(s) bound by an assignment target.

    * Name: appended to the innermost class's ``variables`` when
      ``class_stack`` is non-empty and ``function_stack`` is empty. Otherwise
      appended to ``self.info.variables``.
    * Attribute: ``self.<x>...`` while both stacks are non-empty records ``x``
      on the innermost class. Otherwise the leftmost component is appended
      to ``self.info.variables``.
    * Tuple or List: each element's value is processed recursively.
    Other targets (e.g. subscripts) are ignored.
    """
    if m.matches(node, m.Name()):
        name = cst.ensure_type(node, cst.Name).value
        if self.class_stack and not self.function_stack:
            self.class_stack[-1].variables.append(name)
        else:
            self.info.variables.append(name)
    elif m.matches(node, m.Attribute()):
        parts = split_attribute(cst.ensure_type(node, cst.Attribute))
        if (parts[0] == 'self' and self.class_stack and self.function_stack
                and len(parts) > 1):
            # `self.<attr> = ...` inside a function nested in a class.
            self.class_stack[-1].variables.append(parts[1])
        else:
            self.info.variables.append(parts[0])
    elif m.matches(node, m.Tuple()) or m.matches(node, m.List()):
        # Destructuring targets: both `a, b = ...` and `[a, b] = ...`
        # bind every element.
        for element in node.elements:
            self.process_variable(element.value)
def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
    """Fix flake8-comprehensions C413.

    Unnecessary reversed call around sorted().

    ``reversed(sorted(...))`` becomes ``sorted(...)`` with ``reverse`` inverted:
    * no ``reverse`` argument       -> ``reverse=True`` is appended;
    * literal falsy ``reverse``     -> replaced with ``reverse=True``;
    * literal truthy ``reverse``    -> the argument is removed;
    * non-literal ``reverse=expr``  -> becomes ``reverse=not expr``.
    """
    call = updated_node.args[0].value
    args = list(call.args)
    for i, arg in enumerate(args):
        if not m.matches(arg.keyword, m.Name("reverse")):
            continue
        try:
            val = bool(literal_eval(self.module.code_for_node(arg.value)))
        except Exception:
            # Not a literal: negate the expression. `not` binds tighter than
            # and/or, conditional expressions and lambdas, so parenthesize
            # those to keep `not (a or b)` from becoming `(not a) or b`.
            operand = arg.value
            if (isinstance(operand, (cst.BooleanOperation, cst.IfExp, cst.Lambda))
                    and not operand.lpar):
                operand = operand.with_changes(
                    lpar=[cst.LeftParen()], rpar=[cst.RightParen()])
            args[i] = arg.with_changes(value=cst.UnaryOperation(cst.Not(), operand))
        else:
            if not val:
                args[i] = arg.with_changes(value=cst.Name("True"))
            else:
                del args[i]
                # The argument before `reverse` may now be last: drop its comma.
                # With i == 0 there is no preceding argument. args[i - 1] would
                # wrap around to args[-1], or raise IndexError when args is empty.
                if i > 0:
                    args[i - 1] = remove_trailing_comma(args[i - 1])
        break
    else:
        args.append(
            cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
    return call.with_changes(args=args)
def leave_Yield(self, node: cst.Yield, updated_node: cst.Yield) -> Union[cst.Await, cst.Yield]:
    """Rewrite ``yield <expr>`` as ``await <expr>`` inside a coroutine.

    The node is left unchanged when ``self.in_coroutine(self.coroutine_stack)``
    is false, or when the yielded value is not an expression. A yielded
    list or list comprehension is turned into the expression returned by
    ``pluck_asyncio_gather_expression_from_yield_list_or_list_comp``, and
    "asyncio" is added to ``self.required_imports``. A yielded dict, dict
    comprehension or ``dict(...)`` call raises TransformError. Whitespace
    and parentheses are carried over from the yield.
    """
    yielded = updated_node.value
    if not self.in_coroutine(self.coroutine_stack) or not isinstance(yielded, cst.BaseExpression):
        return updated_node

    yields_dict = m.Yield(value=m.Dict() | m.DictComp() | m.Call(func=m.Name("dict")))
    if isinstance(yielded, (cst.List, cst.ListComp)):
        self.required_imports.add("asyncio")
        awaited = self.pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
            updated_node)
    elif m.matches(updated_node, yields_dict):
        raise TransformError(
            "Yielding a dict of futures (https://www.tornadoweb.org/en/branch3.2/releases/v3.2.0.html#tornado-gen) added in tornado 3.2 is unsupported by the codemod. This file has not been modified. Manually update to supported syntax before running again."
        )
    else:
        awaited = yielded

    return cst.Await(
        expression=awaited,
        whitespace_after_await=updated_node.whitespace_after_yield,
        lpar=updated_node.lpar,
        rpar=updated_node.rpar,
    )
def _has_none(node):
    """Return True if ``None`` appears in ``node`` or in any operand of nested BinaryOperations.

    Operands are explored left before right, and the search stops at the
    first ``None`` Name found.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if m.matches(current, m.Name("None")):
            return True
        if m.matches(current, m.BinaryOperation()):
            # Push right first so the left operand is examined first.
            pending.append(current.right)
            pending.append(current.left)
    return False