def collect_targets(
    self, stack: Tuple[cst.BaseExpression, ...]
) -> Tuple[
    List[cst.BaseExpression], Dict[cst.BaseExpression, List[cst.BaseExpression]]
]:
    """Group two-argument ``isinstance``-style calls in *stack* by their target.

    Returns ``(operands, targets)``:

    * ``operands`` follows the order of *stack*. A call matching
      ``f(target, cls)`` with a non-tuple ``cls``, whose callee has the
      qualified name ``_ISINSTANCE``, is replaced by its ``target``. A target
      that is deep-equal to one already seen is not repeated.
    * ``targets`` maps each first-seen target node to the list of ``cls``
      expressions it was checked against, in order of appearance.

    Every other element of *stack* is copied to ``operands`` unchanged.
    """
    single_class_check = m.Call(
        func=m.DoNotCare(), args=[m.Arg(), m.Arg(~m.Tuple())]
    )
    targets: Dict[cst.BaseExpression, List[cst.BaseExpression]] = {}
    operands: List[cst.BaseExpression] = []
    for operand in stack:
        if not m.matches(operand, single_class_check):
            operands.append(operand)
            continue
        call = cst.ensure_type(operand, cst.Call)
        if not QualifiedNameProvider.has_name(self, call, _ISINSTANCE):
            operands.append(operand)
            continue
        target = call.args[0].value
        klass = call.args[1].value
        # Dict keys are the nodes themselves, so find a structurally equal key.
        existing = next(
            (seen for seen in targets if target.deep_equals(seen)), None
        )
        if existing is None:
            operands.append(target)
            targets[target] = [klass]
        else:
            targets[existing].append(klass)
    return operands, targets
def test_at_most_n_matcher_no_args_true(self) -> None:
    # Match a function call to "foo" with at most two arguments.
    self.assertTrue(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Integer("1")),)),
            m.Call(m.Name("foo"), (m.AtMostN(n=2),)),
        )
    )
    # Match a function call to "foo" with at most two arguments.
    self.assertTrue(
        matches(
            libcst.Call(
                libcst.Name("foo"),
                (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
            ),
            m.Call(m.Name("foo"), (m.AtMostN(n=2),)),
        )
    )
    # Match a function call to "foo" with at most six arguments, the last
    # one being the integer 1.
    self.assertTrue(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Integer("1")),)),
            m.Call(m.Name("foo"), [m.AtMostN(n=5), m.Arg(m.Integer("1"))]),
        )
    )
    # Match a function call to "foo" with at most six arguments, the last
    # one being the integer 2.
    self.assertTrue(
        matches(
            libcst.Call(
                libcst.Name("foo"),
                (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
            ),
            m.Call(m.Name("foo"), (m.AtMostN(n=5), m.Arg(m.Integer("2")))),
        )
    )
    # Match a function call to "foo" with at most six arguments, the first
    # one being the integer 1.
    self.assertTrue(
        matches(
            libcst.Call(
                libcst.Name("foo"),
                (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
            ),
            m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.AtMostN(n=5))),
        )
    )
    # Match a function call to "foo" with at most two arguments (ZeroOrOne
    # allows at most one after the first), the first one being the integer 1.
    self.assertTrue(
        matches(
            libcst.Call(
                libcst.Name("foo"),
                (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
            ),
            m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.ZeroOrOne())),
        )
    )
def test_does_not_match_true(self) -> None:
    """DoesNotMatch succeeds wherever its wrapped matcher would fail."""
    foo_true = libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Name("True")),))
    foo_one = libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Integer("1")),))

    # Any single-argument call whose argument value isn't None.
    self.assertTrue(
        matches(
            foo_true,
            m.Call(args=(m.Arg(value=m.DoesNotMatch(m.Name("None"))),)),
        )
    )
    # The negation may wrap the whole Arg instead of just its value...
    self.assertTrue(
        matches(
            foo_one,
            m.Call(args=(m.DoesNotMatch(m.Arg(m.Name("None"))),)),
        )
    )
    # ...or the whole argument sequence.
    self.assertTrue(
        matches(
            foo_one,
            m.Call(args=m.DoesNotMatch((m.Arg(m.Integer("2")),))),
        )
    )
    # Any call whose argument is neither True nor False.
    neither_bool = m.DoesNotMatch(m.OneOf(m.Name("True"), m.Name("False")))
    self.assertTrue(matches(foo_one, m.Call(args=(m.Arg(value=neither_bool),))))
    # A Name whose value doesn't match the regex for True.
    self.assertTrue(
        matches(
            libcst.Name("False"),
            m.Name(value=m.DoesNotMatch(m.MatchRegex(r"True"))),
        )
    )
def leave_Call(self, original_node: cst.Call) -> None:
    """Report two-argument ``super(...)`` calls made while inside a class.

    The call is reported when ``self.current_classes`` is non-empty and the
    first argument matches ``self._build_arg_class_matcher()``. The suggested
    replacement is the same call with no arguments, i.e. ``super()``.
    """
    if not self.current_classes:
        return
    explicit_super = m.Call(
        func=m.Name("super"),
        args=[
            m.Arg(value=self._build_arg_class_matcher()),
            m.Arg(),
        ],
    )
    if m.matches(original_node, explicit_super):
        self.report(original_node, replacement=original_node.with_changes(args=()))
def test_at_least_n_matcher_no_args_false(self) -> None:
    """AtLeastN without a sub-matcher rejects calls with too few arguments."""
    foo_1_2_3 = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )

    # At least four arguments: foo(1, 2, 3) has only three.
    self.assertFalse(
        matches(foo_1_2_3, m.Call(func=m.Name("foo"), args=(m.AtLeastN(n=4),)))
    )
    # At least four arguments, the first being 1: still one short.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")), m.AtLeastN(n=3)),
            ),
        )
    )
    # At least three arguments, the last being 2: the last one is 3 instead.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(m.AtLeastN(n=2), m.Arg(m.Integer("2"))),
            ),
        )
    )
def test_extract_sequence_element(self) -> None:
    """SaveMatchedNode around a sequence wildcard captures the matched slice."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    second_call = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )

    # Capturing every argument of the call yields the whole argument sequence.
    capture_all = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(m.Call(args=[m.SaveMatchedNode(m.ZeroOrMore(), "args")])),
        ]
    )
    self.assertEqual(m.extract(expression, capture_all), {"args": second_call.args})

    # Requiring every argument to be a subscript fails, so nothing is extracted.
    capture_subscripts = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(
                m.Call(
                    args=[
                        m.SaveMatchedNode(
                            m.ZeroOrMore(m.Arg(m.Subscript())), "args"
                        )
                    ]
                )
            ),
        ]
    )
    self.assertIsNone(m.extract(expression, capture_subscripts))
def test_replace_add_one_to_foo_args(self) -> None:
    """m.replace can rewrite a node nested inside the matched call."""

    def _add_one_to_arg(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        saved = extraction["arg"]
        # The saved value may be a node or a sequence; pyre can't tell which.
        old_node = cst.ensure_type(saved, cst.CSTNode)
        bumped = int(cst.ensure_type(saved, cst.Integer).value) + 1
        return node.deep_replace(old_node, cst.Integer(str(bumped)))

    source = (
        "foo: int = 37\ndef bar(baz: int) -> int:\n return baz\n\nbiz: int = bar(41)\n"
    )
    expected = (
        "foo: int = 37\ndef bar(baz: int) -> int:\n return baz\n\nbiz: int = bar(42)\n"
    )
    bar_with_int_arg = m.Call(
        func=m.Name("bar"),
        args=[m.Arg(m.SaveMatchedNode(m.Integer(), "arg"))],
    )
    result = m.replace(cst.parse_module(source), bar_with_int_arg, _add_one_to_arg)
    self.assertEqual(cst.ensure_type(result, cst.Module).code, expected)
def find_keyword_arg(args: Sequence[Arg], keyword_name: str) -> Optional[Arg]:
    """Return the first argument passed as ``keyword_name=...``, else ``None``."""
    keyword_matcher = m.Arg(keyword=m.Name(keyword_name))
    return next((arg for arg in args if m.matches(arg, keyword_matcher)), None)
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """
    Remove the `weak` argument if present in the call.

    This is only changing calls with keyword arguments.

    Only calls matching one of ``self.disconnect_call_matchers`` are
    considered. If no ``weak=`` argument is present, the call is handed to
    ``super().leave_Call`` unchanged.
    """
    if self.disconnect_call_matchers and m.matches(
        updated_node, m.OneOf(*self.disconnect_call_matchers)
    ):
        weak_kwarg = m.Arg(keyword=m.Name("weak"))
        # Keep all arguments except the one with the keyword `weak`.
        updated_args = [
            arg for arg in updated_node.args if not m.matches(arg, weak_kwarg)
        ]
        if len(updated_args) != len(updated_node.args):
            if updated_args:
                # The new last argument takes the comma of the original last
                # argument (none, or a trailing one). Removing a final
                # `weak=...` then doesn't leave a dangling ", " behind.
                updated_args[-1] = updated_args[-1].with_changes(
                    comma=updated_node.args[-1].comma
                )
            # `weak=...` may have been the only argument; an empty list is valid.
            return updated_node.with_changes(args=updated_args)
    return super().leave_Call(original_node, updated_node)
def test_replace_updated_node_changes(self) -> None:
    """m.replace passes callbacks the already-updated node for nested matches."""

    def _replace_nested(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        outer_call = cst.ensure_type(node, cst.Call)
        inner_call = cst.ensure_type(extraction["inner"], cst.Call)
        inner_name = cst.ensure_type(inner_call.func, cst.Name).value
        new_arg = cst.Arg(cst.Name(value=inner_name + "_immediate"))
        return outer_call.with_changes(args=[new_arg])

    source = (
        "def foo(val: int) -> int:\n return val\nbar = foo\nbaz = foo\nbiz = foo\nfoo(bar(baz(biz(5))))\n"
    )
    expected = (
        "def foo(val: int) -> int:\n return val\nbar = foo\nbaz = foo\nbiz = foo\nfoo(bar_immediate)\n"
    )
    call_of_call = m.Call(args=[m.Arg(m.SaveMatchedNode(m.Call(), "inner"))])
    result = m.replace(cst.parse_module(source), call_of_call, _replace_nested)
    self.assertEqual(cst.ensure_type(result, cst.Module).code, expected)
def update_call_args(self, node: Call) -> Sequence[Arg]:
    """Remove keyword argument from first argument of `re_path`.

    ``re_path(regex=..., ...)`` becomes ``re_path(..., ...)``. The first
    argument's comma is kept, so multi-line and trailing-comma formatting
    survives. Calls with no arguments, or whose first argument is not
    ``regex=``, are delegated to ``super().update_call_args``.
    """
    # Guard: unpacking an empty argument list would raise ValueError.
    if node.args and m.matches(node.args[0], m.Arg(keyword=m.Name("regex"))):
        first_arg, *other_args = node.args
        positional_first = Arg(value=first_arg.value, comma=first_arg.comma)
        return (positional_first, *other_args)
    return super().update_call_args(node)
def test_at_most_n_matcher_no_args_false(self) -> None:
    """AtMostN / ZeroOrOne reject calls with too many arguments."""
    foo_1_2_3 = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )

    # At most two arguments: foo(1, 2, 3) has three.
    self.assertFalse(
        matches(foo_1_2_3, m.Call(func=m.Name("foo"), args=(m.AtMostN(n=2),)))
    )
    # At most two arguments, the last being 3.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(m.AtMostN(n=1), m.Arg(m.Integer("3"))),
            ),
        )
    )
    # Same constraint expressed with ZeroOrOne.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(m.ZeroOrOne(), m.Arg(m.Integer("3"))),
            ),
        )
    )
def test_extract_precedence_sequence(self) -> None:
    """When two captures share a name, the later match's node is kept."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    pattern = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(
                m.Call(
                    args=[
                        m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")),
                        m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")),
                    ]
                )
            ),
        ]
    )
    second_call = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )
    self.assertEqual(
        m.extract(expression, pattern), {"arg": second_call.args[1].value}
    )
def remove_lambda_indirection(self, _, updated_node):
    """Replace ``lambda a, b: f(a, b)`` with ``f`` when they are interchangeable.

    ``updated_node`` is returned unchanged in any of these cases:

    * the lambda has ``*args``, ``**kwargs`` or keyword-only parameters;
    * any parameter has a default value;
    * the body is not a call forwarding exactly the lambda's positional
      parameters, in order, as plain positional arguments;
    * the callee expression refers to one of the parameters.

    Each of these would change behaviour after the rewrite.
    """
    params = updated_node.params
    # Extra positional/keyword parameters mean the lambda accepts (and drops)
    # arguments that the bare callee would receive.
    if (
        not isinstance(params.star_arg, cst.MaybeSentinel)
        or params.star_kwarg is not None
        or params.kwonly_params
    ):
        return updated_node
    positional = [*params.posonly_params, *params.params]
    # A default lets the lambda be called with fewer arguments than `f` needs.
    if any(param.default is not None for param in positional):
        return updated_node
    same_args = [
        m.Arg(m.Name(param.name.value), star="", keyword=None)
        for param in positional
    ]
    if not m.matches(updated_node.body, m.Call(args=same_args)):
        return updated_node
    func = cst.ensure_type(updated_node.body, cst.Call).func
    param_names = {param.name.value for param in positional}
    # e.g. `lambda x: x.method(x)` -> `x.method` would leave `x` unbound.
    if param_names and m.findall(
        func, m.Name(value=m.MatchIfTrue(lambda value: value in param_names))
    ):
        return updated_node
    return func
def leave_ImportFrom(
    self, original_node: ImportFrom, updated_node: ImportFrom
) -> Union[BaseSmallStatement, RemovalSentinel]:
    """Record matchers for inline-admin base classes made visible by this import.

    * ``from django.contrib.admin import TabularInline/StackedInline`` adds a
      matcher for the bare local name, honouring ``as`` aliases.
    * ``from django.contrib import admin`` adds matchers for
      ``<admin>.TabularInline`` and ``<admin>.StackedInline``, where ``<admin>``
      is the local (possibly aliased) module name.

    New matchers are combined with any matcher already stored under
    ``self.ctx_key_base_cls_matcher``. Several such imports in one module
    therefore accumulate instead of overwriting each other.
    """
    names = updated_node.names
    # `from ... import *` is not iterable and has no per-name aliases.
    if m.matches(names, m.ImportStar()):
        return super().leave_ImportFrom(original_node, updated_node)

    base_cls_matcher = []
    if m.matches(
        updated_node,
        m.ImportFrom(module=module_matcher(["django", "contrib", "admin"])),
    ):
        inline_class = m.Name("TabularInline") | m.Name("StackedInline")
        for imported_name in names:
            if m.matches(imported_name, m.ImportAlias(name=inline_class)):
                local_name = (
                    imported_name.evaluated_alias or imported_name.evaluated_name
                )
                base_cls_matcher.append(m.Arg(m.Name(local_name)))
    if m.matches(
        updated_node,
        m.ImportFrom(module=module_matcher(["django", "contrib"])),
    ):
        for imported_name in names:
            if m.matches(imported_name, m.ImportAlias(name=m.Name("admin"))):
                module_name = imported_name.evaluated_alias or "admin"
                base_cls_matcher.extend(
                    m.Arg(
                        m.Attribute(
                            value=m.Name(module_name), attr=m.Name(class_name)
                        )
                    )
                    for class_name in ("TabularInline", "StackedInline")
                )
    # Save valid matchers in the context, keeping those from earlier imports.
    # This assumes any stored value is a matcher written by this method.
    if base_cls_matcher:
        previous = self.context.scratch.get(self.ctx_key_base_cls_matcher)
        if previous is not None:
            base_cls_matcher.insert(0, previous)
        self.context.scratch[self.ctx_key_base_cls_matcher] = m.OneOf(
            *base_cls_matcher
        )
    return super().leave_ImportFrom(original_node, updated_node)
def has_on_delete(node: Call) -> bool:
    """Tell whether the call *node* appears to supply ``on_delete``.

    True when any argument is the keyword ``on_delete=...``. Otherwise True
    only when there are at least two arguments and the second one has no
    keyword; it is then assumed to be ``on_delete`` passed positionally.
    """
    on_delete_kwarg = m.Arg(keyword=m.Name("on_delete"))
    if any(m.matches(arg, on_delete_kwarg) for arg in node.args):
        return True
    if len(node.args) < 2:
        return False
    return node.args[1].keyword is None
class InlineHasAddPermissionsTransformer(ContextAwareTransformer):
    """Add the ``obj`` argument to ``InlineModelAdmin.has_add_permission()``."""

    context_key = "InlineHasAddPermissionsTransformer"
    # Base classes that mark a class body as an admin inline.
    base_cls_matcher = m.OneOf(
        m.Arg(m.Attribute(value=m.Name("admin"), attr=m.Name("TabularInline"))),
        m.Arg(m.Name("TabularInline")),
        m.Arg(m.Attribute(value=m.Name("admin"), attr=m.Name("StackedInline"))),
        m.Arg(m.Name("StackedInline")),
    )

    def visit_ClassDef_bases(self, node: ClassDef) -> None:
        # Flag the context when any base of this class is an admin inline.
        if any(m.matches(base, self.base_cls_matcher) for base in node.bases):
            self.context.scratch[self.context_key] = True
        super().visit_ClassDef_bases(node)

    def leave_ClassDef(
        self, original_node: ClassDef, updated_node: ClassDef
    ) -> Union[BaseStatement, RemovalSentinel]:
        # Leaving a class clears the flag (if it was set).
        self.context.scratch.pop(self.context_key, None)
        return super().leave_ClassDef(original_node, updated_node)

    @property
    def _is_context_right(self):
        return self.context.scratch.get(self.context_key, False)

    def leave_FunctionDef(
        self, original_node: FunctionDef, updated_node: FunctionDef
    ) -> Union[BaseStatement, RemovalSentinel]:
        # Only `has_add_permission` with exactly two positional parameters
        # (e.g. `self, request`), inside a flagged class, gains `obj=None`.
        needs_obj = (
            m.matches(updated_node, m.FunctionDef(name=m.Name("has_add_permission")))
            and self._is_context_right
            and len(updated_node.params.params) == 2
        )
        if not needs_obj:
            return super().leave_FunctionDef(original_node, updated_node)
        params = updated_node.params
        obj_param = Param(name=Name("obj"), default=Name("None"))
        return updated_node.with_changes(
            params=params.with_changes(params=(*params.params, obj_param))
        )
def test_extractall_simple(self) -> None:
    """extractall returns one capture dict per matching node, in order."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    call = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )
    found = extractall(expression, m.Arg(m.SaveMatchedNode(~m.Name(), "expr")))
    # Only `f * g` and `h.i.j` are non-Name arguments; `e` is skipped.
    expected = [{"expr": arg.value} for arg in call.args[1:]]
    self.assertEqual(found, expected)
def test_at_least_n_matcher_args_true(self) -> None:
    """AtLeastN with a sub-matcher accepts enough matching trailing args."""
    foo_1_2_3 = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    first_is_one = m.Arg(m.Integer("1"))

    # First argument is 1, followed by at least two wildcard arguments.
    self.assertTrue(
        matches(
            foo_1_2_3,
            m.Call(func=m.Name("foo"), args=(first_is_one, m.AtLeastN(m.Arg(), n=2))),
        )
    )
    # First argument is 1, followed by at least two integer arguments.
    self.assertTrue(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(first_is_one, m.AtLeastN(m.Arg(m.Integer()), n=2)),
            ),
        )
    )
    # First argument is 1, followed by at least two arguments that are the
    # integer 2 or 3.
    two_or_three = m.Arg(m.OneOf(m.Integer("2"), m.Integer("3")))
    self.assertTrue(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(first_is_one, m.AtLeastN(two_or_three, n=2)),
            ),
        )
    )
def test_and_matcher_false(self) -> None:
    """AllOf fails as soon as one of its alternatives fails."""
    # No Name can equal both "True" and "False".
    self.assertFalse(
        matches(cst.Name("None"), m.AllOf(m.Name("True"), m.Name("False")))
    )

    foo_1_2_3 = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    three_args = (m.Arg(), m.Arg(), m.Arg())
    reversed_args = (
        m.Arg(m.Integer("3")),
        m.Arg(m.Integer("2")),
        m.Arg(m.Integer("1")),
    )
    # Three arguments is satisfied, but 3, 2, 1 is the wrong order.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(func=m.Name("foo"), args=m.AllOf(three_args, reversed_args)),
        )
    )
def test_and_matcher_true(self) -> None:
    """AllOf succeeds when every alternative matches."""
    # A roundabout way of matching the True identifier.
    self.assertTrue(
        matches(
            cst.Name("True"),
            m.AllOf(m.Name(), m.Name(value=m.MatchRegex(r"True"))),
        )
    )

    foo_1_2_3 = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    three_args = (m.Arg(), m.Arg(), m.Arg())
    one_two_three = (
        m.Arg(m.Integer("1")),
        m.Arg(m.Integer("2")),
        m.Arg(m.Integer("3")),
    )
    self.assertTrue(
        matches(
            foo_1_2_3,
            m.Call(func=m.Name("foo"), args=m.AllOf(three_args, one_two_three)),
        )
    )
def test_at_least_n_matcher_args_false(self) -> None:
    """AtLeastN with a sub-matcher rejects too few matching trailing args."""
    foo_1_2_3 = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    first_is_one = m.Arg(m.Integer("1"))

    # First argument is 1, then at least two strings: the rest are integers.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(first_is_one, m.AtLeastN(m.Arg(m.SimpleString()), n=2)),
            ),
        )
    )
    # First argument is 1, then at least three wildcards: only two remain.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(func=m.Name("foo"), args=(first_is_one, m.AtLeastN(m.Arg(), n=3))),
        )
    )
    # First argument is 1, then at least two arguments equal to 2: only one is.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(first_is_one, m.AtLeastN(m.Arg(m.Integer("2")), n=2)),
            ),
        )
    )
def test_does_not_match_operator_false(self) -> None:
    """The ``~`` operator fails wherever the negated matcher succeeds."""
    # Fail to match a call whose only argument is None against "isn't None".
    self.assertFalse(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Name("None")),)),
            m.Call(args=(m.Arg(value=~m.Name("None")),)),
        )
    )
    # Fail to match a call whose only argument is 1 against "isn't Arg(1)".
    self.assertFalse(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Integer("1")),)),
            m.Call(args=((~m.Arg(m.Integer("1"))),)),
        )
    )
    # Fail to match an argument of False against "isn't True or False".
    self.assertFalse(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Name("False")),)),
            m.Call(args=(m.Arg(value=~(m.Name("True") | m.Name("False"))),)),
        )
    )
    # Roundabout way of verifying ~(x&y) behavior.
    self.assertFalse(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Name("False")),)),
            m.Call(args=(m.Arg(value=~(m.Name() & m.Name("False"))),)),
        )
    )
    # Roundabout way of verifying (~x)|(~y) behavior with two distinct
    # operands. The argument True is both a Name and the Name "True", so
    # neither negation matches and the disjunction fails.
    self.assertFalse(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Name("True")),)),
            m.Call(args=(m.Arg(value=(~m.Name()) | (~m.Name("True"))),)),
        )
    )
    # Fail to match the name True against "doesn't match the regex for True".
    self.assertFalse(
        matches(libcst.Name("True"), m.Name(value=~m.MatchRegex(r"True")))
    )
def test_zero_or_more_matcher_args_false(self) -> None:
    """ZeroOrMore with a sub-matcher must match every remaining argument."""
    foo_1_2_3 = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    first_is_one = m.Arg(m.Integer("1"))

    # First argument is 1 and the rest are strings: they are integers.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(first_is_one, m.ZeroOrMore(m.Arg(m.SimpleString()))),
            ),
        )
    )
    # First argument is 1 and the rest all equal 2: the last one is 3.
    self.assertFalse(
        matches(
            foo_1_2_3,
            m.Call(
                func=m.Name("foo"),
                args=(first_is_one, m.ZeroOrMore(m.Arg(m.Integer("2")))),
            ),
        )
    )
def test_at_most_n_matcher_args_true(self) -> None:
    # Match a function call to "foo" with at most two arguments, all of which
    # are the integer 1.
    self.assertTrue(
        matches(
            cst.Call(func=cst.Name("foo"), args=(cst.Arg(cst.Integer("1")),)),
            m.Call(
                func=m.Name("foo"),
                args=(m.AtMostN(m.Arg(m.Integer("1")), n=2),),
            ),
        )
    )
    # Match a function call to "foo" with at most two arguments, each of which
    # can be the integer 1 or 2.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(
                    m.AtMostN(m.Arg(m.OneOf(m.Integer("1"), m.Integer("2"))), n=2),
                ),
            ),
        )
    )
    # Match a function call to "foo" with at most two arguments, the first
    # one being the integer 1 and the second one, if included, being the
    # integer 2.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")), m.ZeroOrOne(m.Arg(m.Integer("2")))),
            ),
        )
    )
    # Match a function call to "foo" with at most six arguments, the first
    # one being the integer 1 and any others being the integer 2.
    # (Previously this duplicated the ZeroOrOne case above verbatim.)
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(
                    m.Arg(m.Integer("1")),
                    m.AtMostN(m.Arg(m.Integer("2")), n=5),
                ),
            ),
        )
    )
def visit_Call(self, node: cst.Call) -> None:
    """Report ``list``/``set``/``dict`` calls that wrap a single comprehension.

    The call must have exactly one plain positional argument, which is a
    generator expression or a list comprehension. A report is emitted with
    the equivalent list/set/dict comprehension as the replacement. For
    ``dict`` the element must be a two-item tuple or list of plain
    (non-starred) elements; any other shape is left unreported.
    """
    if not m.matches(
        node,
        m.Call(
            func=m.Name("list") | m.Name("set") | m.Name("dict"),
            # `dict(a=[...])` or `list(*[...])` mean something else entirely,
            # so only a plain positional argument qualifies.
            args=[
                m.Arg(value=m.GeneratorExp() | m.ListComp(), keyword=None, star="")
            ],
        ),
    ):
        return
    call_name = cst.ensure_type(node.func, cst.Name).value
    arg_value = node.args[0].value
    if m.matches(arg_value, m.GeneratorExp()):
        exp = cst.ensure_type(arg_value, cst.GeneratorExp)
        message_formatter = UNNECESSARY_GENERATOR
    else:
        exp = cst.ensure_type(arg_value, cst.ListComp)
        message_formatter = UNNECESSARY_LIST_COMPREHENSION

    replacement = None
    if call_name == "list":
        replacement = node.deep_replace(
            node, cst.ListComp(elt=exp.elt, for_in=exp.for_in)
        )
    elif call_name == "set":
        replacement = node.deep_replace(
            node, cst.SetComp(elt=exp.elt, for_in=exp.for_in)
        )
    elif call_name == "dict":
        elt = exp.elt
        # Note: `m.Tuple(x, y)` would bind `elements` and `lpar` positionally
        # and accept any length. Spell out exactly two plain elements so
        # `elements[1]` exists and no StarredElement is accepted.
        two_elements = [m.Element(), m.Element()]
        if m.matches(elt, m.Tuple(elements=two_elements)):
            pair = cst.ensure_type(elt, cst.Tuple).elements
        elif m.matches(elt, m.List(elements=two_elements)):
            pair = cst.ensure_type(elt, cst.List).elements
        else:
            # Unrecognized form
            return
        key = pair[0].value
        value = pair[1].value
        replacement = node.deep_replace(
            node,
            # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
            # param but got `BaseExpression`.
            cst.DictComp(key=key, value=value, for_in=exp.for_in),
        )
    self.report(
        node, message_formatter.format(func=call_name), replacement=replacement
    )
class UnhashableListTransformer(NoqaAwareTransformer):
    """Rewrite the elements of list/set/tuple literals via convert_lists_to_tuples.

    Only literals inside a set literal are affected, or inside a
    ``set(...)``/``frozenset(...)`` call whose single argument is a list,
    tuple or set literal.
    """

    @m.call_if_inside(
        m.Call(
            func=m.Name(value="set") | m.Name(value="frozenset"),
            args=[m.Arg(value=m.List() | m.Tuple() | m.Set())],
        )
        | m.Set()  # noqa: W503
    )
    @m.leave(m.OneOf(m.List(), m.Set(), m.Tuple()))
    def convert_list_arg(
        self, _, updated_node: Union[cst.Set, cst.List, cst.Tuple]
    ) -> cst.BaseExpression:
        new_elements = convert_lists_to_tuples(updated_node.elements)
        return updated_node.with_changes(elements=new_elements)
def test_extract_optional_wildcard(self) -> None:
    """A capture inside a ZeroOrOne that matched nothing contributes nothing."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    optional_attribute = m.ZeroOrOne(
        m.Arg(m.SaveMatchedNode(m.Attribute(), "arg"))
    )
    pattern = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(m.Call(args=[m.ZeroOrMore(), optional_attribute])),
        ]
    )
    self.assertEqual(m.extract(expression, pattern), {})
def visit_Call(self, node: cst.Call) -> None:
    """Collect the first-argument string of ``Extension``/``addMacExtension`` calls.

    When the call's first argument is a simple string literal, its evaluated
    value is appended to ``self.extension_names``.
    """
    extension_call = m.Call(
        func=m.Name("Extension") | m.Name("addMacExtension"),
        args=(
            m.Arg(value=m.SaveMatchedNode(m.SimpleString(), "extension_name")),
            m.ZeroOrMore(m.DoNotCare()),
        ),
    )
    captured = m.extract(node, extension_call)
    if not captured:
        return
    name_node = captured["extension_name"]
    assert isinstance(name_node, cst.SimpleString)
    self.extension_names.append(name_node.evaluated_value)
def test_at_most_n_matcher_args_false(self) -> None:
    """AtMostN with a sub-matcher rejects arguments that don't satisfy it."""
    foo_1_2_3 = libcst.Call(
        libcst.Name("foo"),
        (
            libcst.Arg(libcst.Integer("1")),
            libcst.Arg(libcst.Integer("2")),
            libcst.Arg(libcst.Integer("3")),
        ),
    )
    # At most three arguments, all of which must be the integer 4.
    all_fours = m.AtMostN(m.Arg(m.Integer("4")), n=3)
    self.assertFalse(matches(foo_1_2_3, m.Call(m.Name("foo"), (all_fours,))))