def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
    """Collapse flattened API arguments into a single ``request={...}`` kwarg.

    Only calls of the form ``<expr>.<method>(...)`` whose method name is a key
    of ``self.METHOD_TO_PARAMS`` are rewritten. Anything else, or a call that
    already passes ``request=``, is returned unchanged.

    Argument binding:
      * positional arguments bind, in order, to the method's parameters
        (``self.METHOD_TO_PARAMS[method]``);
      * positional arguments beyond those bind, in order, to
        ``self.CTRL_PARAMS`` and are re-emitted as keyword arguments;
      * keyword arguments named in ``self.CTRL_PARAMS`` stay keyword arguments;
      * every other keyword argument goes into the request dict under its
        own keyword (not by position, which would mislabel out-of-order or
        gapped keyword arguments).

    Args:
        original: The untransformed call node; only its method name is read.
        updated: The call node that is rewritten and returned.

    Returns:
        ``updated`` with its arguments replaced by ``request={...}`` followed
        by the control-parameter keyword arguments, or ``updated`` unchanged.
    """
    try:
        key = original.func.attr.value
        kword_params = self.METHOD_TO_PARAMS[key]
    except (AttributeError, KeyError):
        # Either not a method from the API or too convoluted to be sure.
        return updated

    # If the existing code is valid, keyword args come after positional args.
    # Therefore, all positional args must map to the first parameters.
    args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
    if any(k.keyword.value == "request" for k in kwargs):
        # We've already fixed this file, don't fix it again.
        return updated

    kwargs, ctrl_kwargs = partition(
        lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs)

    args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
    ctrl_kwargs.extend(
        cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
        for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))

    # Positional args are matched to parameter names by position; keyword
    # args already carry their parameter name, so use it directly.
    fields = [(name, arg.value) for name, arg in zip(kword_params, args)]
    fields.extend((arg.keyword.value, arg.value) for arg in kwargs)

    request_arg = cst.Arg(
        value=cst.Dict([
            # DictElement's value is a plain expression, not a cst.Element.
            cst.DictElement(cst.SimpleString("'{}'".format(name)), value)
            for name, value in fields
        ]),
        keyword=cst.Name("request"))

    return updated.with_changes(args=[request_arg] + ctrl_kwargs)
def test_at_most_n_matcher_no_args_true(self) -> None:
    # Match a function call to "foo" with at most two arguments.
    self.assertTrue(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Integer("1")),)),
            m.Call(m.Name("foo"), (m.AtMostN(n=2),)),
        )
    )
    # Match a function call to "foo" with at most two arguments.
    self.assertTrue(
        matches(
            libcst.Call(
                libcst.Name("foo"),
                (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
            ),
            m.Call(m.Name("foo"), (m.AtMostN(n=2),)),
        )
    )
    # Match a function call to "foo" with at most six arguments, the last
    # one being the integer 1.
    self.assertTrue(
        matches(
            libcst.Call(libcst.Name("foo"), (libcst.Arg(libcst.Integer("1")),)),
            m.Call(m.Name("foo"), [m.AtMostN(n=5), m.Arg(m.Integer("1"))]),
        )
    )
    # Match a function call to "foo" with at most six arguments, the last
    # one being the integer 2.
    self.assertTrue(
        matches(
            libcst.Call(
                libcst.Name("foo"),
                (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
            ),
            m.Call(m.Name("foo"), (m.AtMostN(n=5), m.Arg(m.Integer("2")))),
        )
    )
    # Match a function call to "foo" with at most six arguments, the first
    # one being the integer 1.
    self.assertTrue(
        matches(
            libcst.Call(
                libcst.Name("foo"),
                (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
            ),
            m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.AtMostN(n=5))),
        )
    )
    # Match a function call to "foo" with at most two arguments, the first
    # one being the integer 1 (ZeroOrOne() admits at most one more).
    self.assertTrue(
        matches(
            libcst.Call(
                libcst.Name("foo"),
                (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
            ),
            m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.ZeroOrOne())),
        )
    )
def test_at_least_n_matcher_args_false(self) -> None:
    # Every case is checked against the same call: foo(1, 2, 3).
    call = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    # (matcher for the arguments after the leading 1, required minimum count)
    trailing_requirements = (
        # At least two string arguments: the trailing arguments are integers.
        (m.Arg(m.SimpleString()), 2),
        # At least three wildcard arguments: only two arguments follow.
        (m.Arg(), 3),
        # At least two arguments equal to the integer 2: the trailing
        # arguments are 2 and 3, so not all of them are 2.
        (m.Arg(m.Integer("2")), 2),
    )
    for trailing, minimum in trailing_requirements:
        expected = m.Call(
            func=m.Name("foo"),
            args=(m.Arg(m.Integer("1")), m.AtLeastN(trailing, n=minimum)),
        )
        self.assertFalse(matches(call, expected))
def test_at_least_n_matcher_args_true(self) -> None:
    # Every case is checked against foo(1, 2, 3): a leading integer 1
    # followed by two more arguments.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    trailing_matchers = (
        # At least two wildcard arguments after the 1.
        m.Arg(),
        # At least two arguments that are integers of any value.
        m.Arg(m.Integer()),
        # At least two arguments that are the integer 2 or 3.
        m.Arg(m.OneOf(m.Integer("2"), m.Integer("3"))),
    )
    for trailing in trailing_matchers:
        expected = m.Call(
            func=m.Name("foo"),
            args=(m.Arg(m.Integer("1")), m.AtLeastN(trailing, n=2)),
        )
        self.assertTrue(matches(call, expected))
def choice_ast(rng_key):
    """Build the CST for ``mcx.jax.choice(rng_key, <name>.shape[0])``.

    ``<name>`` is ``nodes[0].name``. ``nodes`` is not a parameter; it is
    read from a scope outside this function (NOTE(review): presumably an
    enclosing function's local captured as a closure — confirm).

    Args:
        rng_key: CST expression passed as the first argument.

    Returns:
        A ``cst.Call`` node.
    """
    choice_fn = cst.Attribute(
        value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
        attr=cst.Name("choice"),
    )
    shape = cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape"))
    first_dim = cst.Subscript(
        shape,
        [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
    )
    call_args = [cst.Arg(rng_key), cst.Arg(first_dim)]
    return cst.Call(func=choice_fn, args=call_args)
def pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
    node: cst.Yield,
) -> cst.BaseExpression:
    """Turn ``yield <value>`` into ``asyncio.gather(*<value>)``.

    The yielded value is star-unpacked into ``asyncio.gather``; the ``Yield``
    node itself is not part of the result. NOTE(review): ``node.value`` is
    used unchecked, so callers presumably pass only yields whose value is a
    list or list comprehension, as the name suggests.
    """
    gather_fn = cst.Attribute(value=cst.Name("asyncio"), attr=cst.Name("gather"))
    unpacked = cst.Arg(value=node.value, star="*")
    return cst.Call(func=gather_fn, args=[unpacked])
def sampleop_to_logpdf(cst_generator, *args, **kwargs):
    """Wrap a generated expression in a ``.logpdf_sum(<var_name>)`` call.

    ``var_name`` is removed from ``kwargs`` (``KeyError`` if absent) and used
    as the single call argument. All remaining arguments are forwarded to
    ``cst_generator``, whose result becomes the receiver of ``logpdf_sum``.
    """
    var_name = kwargs.pop("var_name")
    receiver = cst_generator(*args, **kwargs)
    method = cst.Attribute(receiver, cst.Name("logpdf_sum"))
    return cst.Call(method, [cst.Arg(var_name)])
def replace_unnecessary_subscript_reversal(self, _, updated_node):
    """Fix flake8-comprehensions C415.

    Unnecessary subscript reversal of iterable within <reversed/set/sorted>().

    The first argument is expected to be ``seq[::-1]`` (the caller's matcher
    is presumably what guarantees this). The reversal is stripped:

    * ``set(seq[::-1])`` -> ``set(seq)`` and ``sorted(seq[::-1], ...)`` ->
      ``sorted(seq, ...)``; all other arguments (e.g. ``key=``) are kept.
    * ``reversed(seq[::-1])`` iterates ``seq`` front to back, so it becomes
      ``iter(seq)``; rewriting it to ``reversed(seq)`` would flip the order.
    """
    first, *rest = updated_node.args
    # ``first.value`` is the ``seq[::-1]`` subscript; its ``.value`` is
    # ``seq``. ``with_changes`` keeps the argument's comma and whitespace.
    unreversed = first.with_changes(value=first.value.value)
    func = updated_node.func
    if isinstance(func, cst.Name) and func.value == "reversed":
        return updated_node.with_changes(func=cst.Name("iter"), args=[unreversed])
    return updated_node.with_changes(args=[unreversed, *rest])
def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
    """Fix flake8-comprehensions C413.

    Unnecessary reversed call around sorted().

    ``reversed(sorted(xs, ...))`` is replaced by the inner ``sorted`` call
    with its ``reverse=`` flag inverted:

    * no ``reverse=`` argument       -> ``reverse=True`` is appended;
    * ``reverse=<falsy literal>``    -> ``reverse=True``;
    * ``reverse=<truthy literal>``   -> the argument is removed;
    * ``reverse=<other expression>`` -> ``reverse=not <expression>``.
    """
    call = updated_node.args[0].value
    args = list(call.args)
    for i, arg in enumerate(args):
        if not m.matches(arg.keyword, m.Name("reverse")):
            continue
        try:
            val = bool(literal_eval(self.module.code_for_node(arg.value)))
        except Exception:
            # Not a literal, so negate it at runtime. ``not`` binds tighter
            # than ``and``/``or``, conditional expressions and lambdas; such
            # operands need parentheses or only part of them is negated
            # (``not a or b`` means ``(not a) or b``).
            operand = arg.value
            if (
                isinstance(operand, (cst.BooleanOperation, cst.IfExp, cst.Lambda))
                and not operand.lpar
            ):
                operand = operand.with_changes(
                    lpar=[cst.LeftParen()], rpar=[cst.RightParen()])
            args[i] = arg.with_changes(
                value=cst.UnaryOperation(cst.Not(), operand))
        else:
            if val:
                # reversed(sorted(xs, reverse=True)) == sorted(xs)
                del args[i]
                # Only when the removed argument was the last one does the
                # new last argument need its trailing comma stripped.
                if 0 < i == len(args):
                    args[i - 1] = remove_trailing_comma(args[i - 1])
            else:
                args[i] = arg.with_changes(value=cst.Name("True"))
        break
    else:
        args.append(
            cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
    return call.with_changes(args=args)
def to_sampler(cst_generator, *args, **kwargs):
    """Wrap a generated expression in a ``.sample(<rng_key>)`` call.

    ``rng_key`` is removed from ``kwargs`` (``KeyError`` if absent) and used
    as the single call argument. All remaining arguments are forwarded to
    ``cst_generator``, whose result becomes the receiver of ``sample``.
    """
    key = kwargs.pop("rng_key")
    sample_method = cst.Attribute(
        value=cst_generator(*args, **kwargs), attr=cst.Name("sample")
    )
    return cst.Call(func=sample_method, args=[cst.Arg(value=key)])
def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """Report ``dataclasses.dataclass`` decorators that do not pass ``frozen``.

    Each such decorator is reported with a replacement that adds
    ``frozen=True``. A bare ``@dataclass`` becomes ``@dataclass(frozen=True)``.
    An existing ``@dataclass(...)`` call keeps its arguments and formatting
    and gets ``frozen=True`` appended. Decorators that already pass
    ``frozen`` (with any value) are not reported.
    """
    for d in node.decorators:
        decorator = d.decorator
        if not QualifiedNameProvider.has_name(
            self,
            decorator,
            QualifiedName(
                name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
            ),
        ):
            continue
        # Normalise to a list so both decorator shapes share one type.
        if isinstance(decorator, cst.Call):
            args = list(decorator.args)
        else:
            # decorator is either cst.Name or cst.Attribute
            args = []
        if any(m.matches(arg.keyword, m.Name("frozen")) for arg in args):
            continue
        frozen_arg = cst.Arg(
            keyword=cst.Name("frozen"),
            value=cst.Name("True"),
            equal=cst.AssignEqual(
                whitespace_before=SimpleWhitespace(value=""),
                whitespace_after=SimpleWhitespace(value=""),
            ),
        )
        if isinstance(decorator, cst.Call):
            # Keep the call's own whitespace; only extend its arguments.
            new_decorator = decorator.with_changes(args=args + [frozen_arg])
        else:
            new_decorator = cst.Call(func=decorator, args=[frozen_arg])
        self.report(d, replacement=d.with_changes(decorator=new_decorator))
def _spot_reg_write(node: cst.Expr) -> Optional[NBAssign]:
    # Recognise the register-write statements
    #
    #   state.gprs.get_reg(foo).write_unsigned(bar)
    #   state.gprs.get_reg(foo).write_signed(bar)
    #
    # and turn them into
    #
    #   GPRs[FOO] = bar
    #   GPRs[FOO] = to_2s_complement(bar)
    #
    # Any expression statement of a different shape yields None.
    call = node.value
    if not isinstance(call, cst.Call):
        return None
    method = call.func
    if not isinstance(method, cst.Attribute) or len(call.args) != 1:
        return None

    written = call.args[0].value
    method_name = method.attr.value
    if method_name == 'write_signed':
        rhs = cst.Call(func=cst.Name('to_2s_complement'),
                       args=[cst.Arg(value=written)])
    elif method_name == 'write_unsigned':
        rhs = written
    else:
        return None

    # The receiver should be state.gprs.get_reg(foo); extract the register
    # reference from it, or give up if it has some other shape.
    reg_ref = ImplTransformer.match_get_reg(method.value)
    return None if reg_ref is None else NBAssign.make(reg_ref, rhs)
def test_at_most_n_matcher_args_false(self) -> None:
    # foo(1, 2, 3) has three arguments, but none is the integer 4, so
    # "at most three arguments, all of them the integer 4" cannot match.
    node = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    only_fours = m.Call(m.Name("foo"), (m.AtMostN(m.Arg(m.Integer("4")), n=3),))
    self.assertFalse(matches(node, only_fours))
def test_does_not_match_false(self) -> None:
    def call_with(value):
        # foo(<value>) with a single positional argument.
        return cst.Call(func=cst.Name("foo"), args=(cst.Arg(value),))

    cases = (
        # "One argument that isn't None" against foo(None).
        (
            call_with(cst.Name("None")),
            m.Call(args=(m.Arg(value=m.DoesNotMatch(m.Name("None"))),)),
        ),
        # Negated Arg matcher against foo(1).
        (
            call_with(cst.Integer("1")),
            m.Call(args=(m.DoesNotMatch(m.Arg(m.Integer("1"))),)),
        ),
        # Negated whole-argument-sequence matcher against foo(1).
        (
            call_with(cst.Integer("1")),
            m.Call(args=m.DoesNotMatch((m.Arg(m.Integer("1")),))),
        ),
        # "An argument which isn't True or False" against foo(False).
        (
            call_with(cst.Name("False")),
            m.Call(
                args=(
                    m.Arg(
                        value=m.DoesNotMatch(
                            m.OneOf(m.Name("True"), m.Name("False"))
                        )
                    ),
                )
            ),
        ),
        # Same idea via the ~ and & operators, against foo(True).
        (
            call_with(cst.Name("True")),
            m.Call(args=(m.Arg(value=(~m.Name("True")) & (~m.Name("False"))),)),
        ),
        # A Name that doesn't match the regex for True, against True itself.
        (
            cst.Name("True"),
            m.Name(value=m.DoesNotMatch(m.MatchRegex(r"True"))),
        ),
    )
    for node, matcher in cases:
        self.assertFalse(matches(node, matcher))
def replace_unnecessary_nested_calls(self, _, updated_node):
    """Fix flake8-comprehensions C414.

    Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>().

    ``outer(inner(x, ...), rest...)`` becomes ``outer(x, rest...)``: the outer
    call's first argument is replaced by a fresh argument holding the inner
    call's first argument value, and the outer call's other arguments are
    kept unchanged.
    """
    outer_args = updated_node.args
    inner_call = outer_args[0].value
    unwrapped = cst.Arg(inner_call.args[0].value)
    return updated_node.with_changes(args=[unwrapped, *outer_args[1:]])
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
    """Append a literal ``0`` argument to calls matching ``self.matcher``.

    Calls that do not match are returned untouched.
    """
    if not matchers.matches(updated_node, self.matcher):
        return updated_node
    zero = cst.Arg(value=cst.Integer(value="0"))
    return updated_node.with_changes(args=list(updated_node.args) + [zero])
def test_args(self) -> None:
    # The template must render the same way whether its substitutions are
    # plain expressions or already-wrapped Arg nodes.
    substitutions = (
        # Plain expressions inserted as call arguments.
        (cst.Name("bar"), cst.Name("baz")),
        # Arg nodes inserted as call arguments (the special case).
        (cst.Arg(cst.Name("bar")), cst.Arg(cst.Name("baz"))),
    )
    for first, second in substitutions:
        expression = parse_template_expression(
            "foo({arg1}, {arg2})", arg1=first, arg2=second
        )
        self.assertEqual(self.code(expression), "foo(bar, baz)")
def test_at_least_n_matcher_no_args_false(self) -> None:
    # foo(1, 2, 3) is the subject of every case below.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    rejected_arg_patterns = (
        # At least four arguments: there are only three.
        (m.AtLeastN(n=4),),
        # At least four arguments, the first being 1: only three exist.
        (m.Arg(m.Integer("1")), m.AtLeastN(n=3)),
        # At least three arguments, the last being 2: the last one is 3.
        (m.AtLeastN(n=2), m.Arg(m.Integer("2"))),
    )
    for pattern in rejected_arg_patterns:
        self.assertFalse(matches(call, m.Call(func=m.Name("foo"), args=pattern)))
def test_at_most_n_matcher_args_true(self) -> None:
    # Match a function call to "foo" with at most two arguments, both of which
    # are the integer 1.
    self.assertTrue(
        matches(
            cst.Call(func=cst.Name("foo"), args=(cst.Arg(cst.Integer("1")),)),
            m.Call(
                func=m.Name("foo"),
                args=(m.AtMostN(m.Arg(m.Integer("1")), n=2),),
            ),
        )
    )
    # Match a function call to "foo" with at most two arguments, both of which
    # can be the integer 1 or 2.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(
                    m.AtMostN(m.Arg(m.OneOf(m.Integer("1"), m.Integer("2"))), n=2),
                ),
            ),
        )
    )
    # Match a function call to "foo" with at most two arguments, the first
    # one being the integer 1 and the second one, if included, being the
    # integer 2.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")), m.ZeroOrOne(m.Arg(m.Integer("2")))),
            ),
        )
    )
    # Match a function call to "foo" with at most six arguments, the first
    # one being the integer 1 and every further one (up to five) being the
    # integer 2.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")), m.AtMostN(m.Arg(m.Integer("2")), n=5)),
            ),
        )
    )
def leave_Call(self, original_node, updated_node) -> cst.BaseExpression:
    """Rewrite ``super().meth(args)`` as ``<Cls>.meth(self, args)``.

    ``<Cls>`` is ``self.cls.__name__``. The rewrite applies when the callee is
    an attribute of any ``super(...)`` call. NOTE(review): ``super``'s own
    arguments are not inspected. Other calls are returned unchanged.
    """
    method = updated_node.func
    if not m.matches(method, m.Attribute(value=m.Call(m.Name('super')))):
        return updated_node
    explicit_method = method.with_changes(value=cst.Name(self.cls.__name__))
    explicit_args = [cst.Arg(cst.Name('self')), *updated_node.args]
    return updated_node.with_changes(func=explicit_method, args=explicit_args)
def test_or_matcher_false(self) -> None:
    def int_arg_matchers(*values):
        # Tuple of Arg matchers for the given integer literals, in order.
        return tuple(m.Arg(m.Integer(value)) for value in values)

    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))
    foo_1_2_3 = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(value)) for value in ("1", "2", "3")),
    )
    cases = (
        # None is neither True nor False.
        (libcst.Name("None"), true_or_false),
        # Assigning None to a target is not the same as assigning True or
        # False to a target.
        (
            libcst.Assign(
                (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
            ),
            m.Assign(value=true_or_false),
        ),
        # foo(1, 2, 3) matches neither the argument sequence 3, 2, 1 nor
        # 4, 5, 6.
        (
            foo_1_2_3,
            m.Call(
                m.Name("foo"),
                m.OneOf(
                    int_arg_matchers("3", "2", "1"),
                    int_arg_matchers("4", "5", "6"),
                ),
            ),
        ),
    )
    for node, matcher in cases:
        self.assertFalse(matches(node, matcher))
def test_at_most_n_matcher_no_args_false(self) -> None:
    # foo(1, 2, 3) has three arguments; each pattern allows at most two.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    too_short_patterns = (
        # At most two arguments.
        (m.AtMostN(n=2),),
        # At most two arguments, the last one being the integer 3.
        (m.AtMostN(n=1), m.Arg(m.Integer("3"))),
        # At most two arguments, the last one being the integer 3, spelled
        # with ZeroOrOne.
        (m.ZeroOrOne(), m.Arg(m.Integer("3"))),
    )
    for pattern in too_short_patterns:
        self.assertFalse(matches(call, m.Call(func=m.Name("foo"), args=pattern)))
def test_or_matcher_true(self) -> None:
    def int_arg_matchers(*values):
        # Tuple of Arg matchers for the given integer literals, in order.
        return tuple(m.Arg(m.Integer(value)) for value in values)

    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))
    foo_1_2_3 = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(value)) for value in ("1", "2", "3")),
    )
    cases = (
        # The identifier True is one of True / False.
        (libcst.Name("True"), true_or_false),
        # An assignment of True to an unspecified target.
        (
            libcst.Assign(
                (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("True")
            ),
            m.Assign(value=true_or_false),
        ),
        # foo(1, 2, 3) matches the second alternative (1, 2, 3).
        (
            foo_1_2_3,
            m.Call(
                m.Name("foo"),
                m.OneOf(
                    int_arg_matchers("3", "2", "1"),
                    int_arg_matchers("1", "2", "3"),
                ),
            ),
        ),
    )
    for node, matcher in cases:
        self.assertTrue(matches(node, matcher))
def datetime_replace(
    self, original_node: cst.Call, updated_node: cst.Call
) -> cst.Call:
    """Rewrite the matched call as ``datetime.now(UTC)``.

    ``self._update_imports()`` runs first. NOTE(review): presumably it makes
    ``datetime`` and ``UTC`` available in the rewritten module — confirm.
    ``original_node`` is not used.
    """
    self._update_imports()
    now_fn = cst.Attribute(value=cst.Name(value="datetime"), attr=cst.Name("now"))
    utc_arg = cst.Arg(value=cst.Name(value="UTC"))
    return updated_node.with_changes(func=now_fn, args=[utc_arg])
def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
    """Fix flake8-comprehensions C416.

    Unnecessary <list/set> comprehension - rewrite using <list/set>().

    Only an identity comprehension, ``[x for x in it]`` or
    ``{x for x in it}``, is rewritten (to ``list(it)`` / ``set(it)``). The
    node is returned unchanged when the element or target is not a plain
    name, when they differ, or when the comprehension has ``if`` filters,
    further ``for`` clauses, or is ``async``. Any of those would make the
    rewrite change the result.
    """
    elt = updated_node.elt
    for_in = updated_node.for_in
    is_identity = (
        isinstance(elt, cst.Name)
        and isinstance(for_in.target, cst.Name)
        and elt.value == for_in.target.value
        and not for_in.ifs
        and for_in.inner_for_in is None
        and for_in.asynchronous is None
    )
    if not is_identity:
        return updated_node
    func = cst.Name("list" if isinstance(updated_node, cst.ListComp) else "set")
    return cst.Call(func=func, args=[cst.Arg(for_in.iter)])
def _replace_nested(
    node: cst.CSTNode,
    extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
) -> cst.CSTNode:
    """Replace the arguments of call ``node`` with one ``<inner>_immediate`` name.

    ``extraction["inner"]`` must be a call whose callee is a plain name
    ``<inner>``. The result is ``node`` (which must itself be a call) with its
    argument list replaced by the single name ``<inner>_immediate``.
    ``cst.ensure_type`` raises if any of these type expectations fails.
    """
    outer_call = cst.ensure_type(node, cst.Call)
    inner_call = cst.ensure_type(extraction["inner"], cst.Call)
    inner_name = cst.ensure_type(inner_call.func, cst.Name).value
    replacement = cst.Arg(cst.Name(value=inner_name + "_immediate"))
    return outer_call.with_changes(args=[replacement])
def collapse_isinstance_checks(self, _, updated_node):
    """Merge ``isinstance(x, A) <op> isinstance(x, B)`` into ``isinstance(x, (A, B))``.

    Both operands are unpacked as two-argument calls. The merge happens only
    when their first arguments are structurally equal (``deep_equals``);
    otherwise the node is returned unchanged. NOTE(review): the boolean
    operator is not checked here, so callers presumably only pass ``or``.
    """
    left_call, right_call = updated_node.left, updated_node.right
    target, first_type = left_call.args
    other_target, second_type = right_call.args
    if not target.deep_equals(other_target):
        return updated_node
    type_tuple = cst.Tuple([
        cst.Element(first_type.value),
        cst.Element(second_type.value),
    ])
    return left_call.with_changes(args=[target, cst.Arg(type_tuple)])
def leave_Yield(self, original_node, updated_node) -> cst.BaseExpression:
    """Replace a ``yield`` with a call that collects into ``self.ret_var``.

    * ``yield a``       -> ``<ret_var>.append(a)``
    * ``yield``         -> ``<ret_var>.append(None)`` (a bare yield yields None)
    * ``yield from it`` -> ``<ret_var>.extend(it)``
    """
    yield_val = updated_node.value
    if isinstance(yield_val, cst.From):
        # ``yield from it`` yields every item of ``it``; inserting the From
        # node into append() would produce invalid code.
        extend = parse_expr(f'{self.ret_var}.extend()')
        return extend.with_changes(args=[cst.Arg(yield_val.item)])

    append = parse_expr(f'{self.ret_var}.append()')
    if yield_val is None:
        yield_val = cst.Name('None')
    # If original expr was "yield a, b" then yield_val is the bare tuple
    # "a, b" (i.e. no parens) which errors if directly inserted into
    # foo.append(a, b). So we ensure that the tuple has parentheses.
    elif m.matches(yield_val, m.Tuple()):
        yield_val = yield_val.with_changes(lpar=[cst.LeftParen()],
                                           rpar=[cst.RightParen()])
    return append.with_changes(args=[cst.Arg(yield_val)])
def test_does_not_match_operator_false(self) -> None:
    def call_with(value):
        # foo(<value>) with a single positional argument.
        return libcst.Call(libcst.Name("foo"), (libcst.Arg(value),))

    cases = (
        # "One argument that isn't None" against foo(None).
        (
            call_with(libcst.Name("None")),
            m.Call(args=(m.Arg(value=~m.Name("None")),)),
        ),
        # Negated Arg matcher against foo(1).
        (
            call_with(libcst.Integer("1")),
            m.Call(args=((~m.Arg(m.Integer("1"))),)),
        ),
        # "An argument which isn't True or False" against foo(False).
        (
            call_with(libcst.Name("False")),
            m.Call(args=(m.Arg(value=~(m.Name("True") | m.Name("False"))),)),
        ),
        # Roundabout way of verifying ~(x & y) behavior.
        (
            call_with(libcst.Name("False")),
            m.Call(args=(m.Arg(value=~(m.Name() & m.Name("False"))),)),
        ),
        # Roundabout way of verifying (~x) | (~y) behavior.
        (
            call_with(libcst.Name("True")),
            m.Call(args=(m.Arg(value=(~m.Name("True")) | (~m.Name("True"))),)),
        ),
        # A Name that doesn't match the regex for True, against True itself.
        (libcst.Name("True"), m.Name(value=~m.MatchRegex(r"True"))),
    )
    for node, matcher in cases:
        self.assertFalse(matches(node, matcher))
def leave_Assert(self, _, updated_node):  # noqa
    """Simplify ``assert`` statements whose test is a Python literal.

    * test is not a literal (``literal_eval`` fails) -> left unchanged;
    * truthy literal                                -> statement removed;
    * falsy literal, no message                     -> ``raise AssertionError``;
    * falsy literal with message                    -> ``raise AssertionError(msg)``.
    """
    source = cst.Module("").code_for_node(updated_node.test)
    try:
        literal = literal_eval(source)
    except Exception:
        # Anything that is not a literal depends on runtime state; keep it.
        return updated_node

    if literal:
        return cst.RemovalSentinel.REMOVE

    message = updated_node.msg
    if message is None:
        exception = cst.Name("AssertionError")
    else:
        exception = cst.Call(cst.Name("AssertionError"), args=[cst.Arg(message)])
    return cst.Raise(exception)