def test_at_most_n_matcher_no_args_true(self) -> None:
    # AtMostN()/ZeroOrOne() are used here without an inner matcher; every
    # (node, matcher) pair in the table below is expected to match.
    def foo_call(*digits):
        # Concrete call foo(...) with one integer argument per digit string.
        return libcst.Call(
            libcst.Name("foo"),
            tuple(libcst.Arg(libcst.Integer(digit)) for digit in digits),
        )

    cases = (
        # At most two arguments: one argument supplied.
        (
            foo_call("1"),
            m.Call(m.Name("foo"), (m.AtMostN(n=2), )),
        ),
        # At most two arguments: exactly two supplied.
        (
            foo_call("1", "2"),
            m.Call(m.Name("foo"), (m.AtMostN(n=2), )),
        ),
        # At most six arguments, the last one being the integer 1.
        (
            foo_call("1"),
            m.Call(m.Name("foo"), [m.AtMostN(n=5), m.Arg(m.Integer("1"))]),
        ),
        # At most six arguments, the last one being the integer 2.
        (
            foo_call("1", "2"),
            m.Call(m.Name("foo"), (m.AtMostN(n=5), m.Arg(m.Integer("2")))),
        ),
        # At most six arguments, the first one being the integer 1.
        (
            foo_call("1", "2"),
            m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.AtMostN(n=5))),
        ),
        # At most two arguments (ZeroOrOne), the first one being the
        # integer 1.
        (
            foo_call("1", "2"),
            m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.ZeroOrOne())),
        ),
    )
    for node, matcher in cases:
        self.assertTrue(matches(node, matcher))
def test_and_matcher_true(self) -> None:
    # AllOf must succeed when every one of its options matches the node.

    # A Name node that is also a name whose value matches the regex "True".
    name_and_regex = m.AllOf(m.Name(), m.Name(value=m.MatchRegex(r"True")))
    self.assertTrue(matches(cst.Name("True"), name_and_regex))

    # AllOf over argument sequences: three wildcard arguments AND the exact
    # integer arguments 1, 2, 3.
    digits = ("1", "2", "3")
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in digits),
    )
    any_three = tuple(m.Arg() for _ in digits)
    exact_three = tuple(m.Arg(m.Integer(digit)) for digit in digits)
    call_matcher = m.Call(
        func=m.Name("foo"),
        args=m.AllOf(any_three, exact_three),
    )
    self.assertTrue(matches(call, call_matcher))
def test_and_matcher_false(self) -> None:
    # AllOf must fail as soon as one of its options does not match.

    # No single name can be both "True" and "False" (and the node is "None").
    contradictory = m.AllOf(m.Name("True"), m.Name("False"))
    self.assertFalse(matches(cst.Name("None"), contradictory))

    # The wildcard option matches foo(1, 2, 3) but the exact option expects
    # the arguments in the reverse order, so the combination fails.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    any_three = tuple(m.Arg() for _ in range(3))
    reversed_three = tuple(
        m.Arg(m.Integer(digit)) for digit in ("3", "2", "1")
    )
    call_matcher = m.Call(
        func=m.Name("foo"),
        args=m.AllOf(any_three, reversed_three),
    )
    self.assertFalse(matches(call, call_matcher))
def test_does_not_match_true(self) -> None:
    # Each (node, matcher) pair below uses DoesNotMatch at a different level
    # of the matcher tree and is expected to match.
    def foo_call(value):
        # Concrete call foo(<value>) with a single positional argument.
        return libcst.Call(libcst.Name("foo"), (libcst.Arg(value), ))

    cases = (
        # A call taking one argument whose value isn't the name None.
        (
            foo_call(libcst.Name("True")),
            m.Call(args=(m.Arg(value=m.DoesNotMatch(m.Name("None"))), )),
        ),
        # The same idea with the whole Arg matcher negated.
        (
            foo_call(libcst.Integer("1")),
            m.Call(args=(m.DoesNotMatch(m.Arg(m.Name("None"))), )),
        ),
        # The same idea with the whole argument sequence negated.
        (
            foo_call(libcst.Integer("1")),
            m.Call(args=m.DoesNotMatch((m.Arg(m.Integer("2")), ))),
        ),
        # A call taking an argument which isn't True or False.
        (
            foo_call(libcst.Integer("1")),
            m.Call(args=(m.Arg(value=m.DoesNotMatch(
                m.OneOf(m.Name("True"), m.Name("False")))), )),
        ),
        # A name node whose value doesn't match the regex for True.
        (
            libcst.Name("False"),
            m.Name(value=m.DoesNotMatch(m.MatchRegex(r"True"))),
        ),
    )
    for node, matcher in cases:
        self.assertTrue(matches(node, matcher))
def test_at_least_n_matcher_args_false(self) -> None:
    # The node under test is always foo(1, 2, 3); the matcher always pins the
    # first argument to the integer 1 and then applies an AtLeastN matcher
    # that the remaining two arguments cannot satisfy.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    unsatisfiable_tails = (
        # At least two string arguments (the remaining ones are integers).
        m.AtLeastN(m.Arg(m.SimpleString()), n=2),
        # At least three wildcard arguments (only two remain).
        m.AtLeastN(m.Arg(), n=3),
        # At least two arguments that are the integer 2 (only one is).
        m.AtLeastN(m.Arg(m.Integer("2")), n=2),
    )
    for tail in unsatisfiable_tails:
        call_matcher = m.Call(
            func=m.Name("foo"),
            args=(m.Arg(m.Integer("1")), tail),
        )
        self.assertFalse(matches(call, call_matcher))
def test_at_least_n_matcher_args_true(self) -> None:
    # The node under test is always foo(1, 2, 3); the matcher always pins the
    # first argument to the integer 1 and then applies an AtLeastN matcher
    # that the remaining two arguments satisfy.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    satisfiable_tails = (
        # At least two wildcard arguments.
        m.AtLeastN(m.Arg(), n=2),
        # At least two arguments that are integers of any value.
        m.AtLeastN(m.Arg(m.Integer()), n=2),
        # At least two arguments that are integers with the value 2 or 3.
        m.AtLeastN(m.Arg(m.OneOf(m.Integer("2"), m.Integer("3"))), n=2),
    )
    for tail in satisfiable_tails:
        call_matcher = m.Call(
            func=m.Name("foo"),
            args=(m.Arg(m.Integer("1")), tail),
        )
        self.assertTrue(matches(call, call_matcher))
def test_does_not_match_operator_true(self) -> None:
    # Operator spelling (~, |, &) of the DoesNotMatch tests; every
    # (node, matcher) pair below is expected to match.
    def foo_call(value):
        # Concrete call foo(<value>) with a single positional argument.
        return libcst.Call(libcst.Name("foo"), (libcst.Arg(value), ))

    def one_arg(value_matcher):
        # Matcher for any call whose only argument's value matches.
        return m.Call(args=(m.Arg(value=value_matcher), ))

    cases = (
        # A call taking one argument that isn't the value None.
        (foo_call(libcst.Name("True")), one_arg(~m.Name("None"))),
        # The same idea with the whole Arg matcher inverted.
        (
            foo_call(libcst.Integer("1")),
            m.Call(args=(~m.Arg(m.Name("None")), )),
        ),
        # A call taking an argument which isn't True or False.
        (
            foo_call(libcst.Integer("1")),
            one_arg(~(m.Name("True") | m.Name("False"))),
        ),
        (
            foo_call(libcst.Name("None")),
            one_arg((~m.Name("True")) & (~m.Name("False"))),
        ),
        # Roundabout way to verify that the or operator works with inverted
        # nodes.
        (
            foo_call(libcst.Name("False")),
            one_arg((~m.Name("True")) | (~m.Name("True"))),
        ),
        # Roundabout way to verify that the inverse operator works properly
        # on AllOf.
        (
            foo_call(libcst.Integer("1")),
            one_arg(~(m.Name() & m.Name("True"))),
        ),
        # A name node whose value doesn't match the regex for True.
        (libcst.Name("False"), m.Name(value=~m.MatchRegex(r"True"))),
    )
    for node, matcher in cases:
        self.assertTrue(matches(node, matcher))
def leave_For(
    self, original_node: cst.For, updated_node: cst.For
) -> tp.Union[cst.For, cst.FlattenSentinel[cst.BaseStatement]]:
    """Unroll ``for`` loops whose iterator evaluates to an ``unroll`` object.

    The iterator expression is rendered back to source with ``to_module`` and
    evaluated with ``eval`` against ``self.env``.  If the evaluation raises, or
    the result is not an ``unroll`` instance, the loop is passed to the parent
    class unchanged.  Otherwise each statement of the loop body is copied once
    per item with the loop variable replaced by an integer literal, and the
    copies are wrapped in a ``FlattenSentinel`` before delegating to the
    parent class.

    NOTE(review): ``eval`` executes whatever the iterator expression contains;
    this is only safe when the transformed source is trusted.

    Raises:
        NotImplementedError: if the loop target is not a plain name, or if the
            ``unroll`` object yields an item that is not an ``int``.
    """
    try:
        iterable = eval(to_module(updated_node.iter).code, {}, self.env)
    except Exception:
        # The iterator is not a compile-time constant; leave the loop alone.
        should_unroll = False
    else:
        should_unroll = isinstance(iterable, unroll)

    if should_unroll:
        target = updated_node.target
        if not isinstance(target, cst.Name):
            raise NotImplementedError('Unrolling with non-name target')

        unrolled = []
        for item in iterable:
            if not isinstance(item, int):
                raise NotImplementedError('Unrolling non-int iterator')
            # Substitute the loop variable by the literal value of this item.
            bindings = {target.value: cst.Integer(value=repr(item))}
            unrolled.extend(
                replace_symbols(statement, bindings)
                for statement in updated_node.body.body
            )
        updated_node = cst.FlattenSentinel(unrolled)

    return super().leave_For(original_node, updated_node)
def test_deprecated_construction(self) -> None:
    # Build the module "foo[1, 2]" with ExtSlice elements (the construction
    # this test calls deprecated) and check that it still renders as expected.
    elements = [
        cst.ExtSlice(slice=cst.Index(value=cst.Integer(value=digit)))
        for digit in ("1", "2")
    ]
    subscript = cst.Subscript(value=cst.Name(value="foo"), slice=elements)
    statement = cst.SimpleStatementLine(body=[cst.Expr(value=subscript)])
    module = cst.Module(body=[statement])
    self.assertEqual(module.code, "foo[1, 2]\n")
def test_at_most_n_matcher_args_false(self) -> None:
    # foo(1, 2, 3) must not match "at most three arguments, all of which are
    # the integer 4".
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(
            libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")
        ),
    )
    up_to_three_fours = m.Call(
        m.Name("foo"),
        (m.AtMostN(m.Arg(m.Integer("4")), n=3), ),
    )
    self.assertFalse(matches(call, up_to_three_fours))
def test_parse_without_docstring(self) -> None:
    # Parse a function whose body is a lone expression statement (so it has
    # no docstring) and check that the first parsed node is the integer 1.
    # NOTE(review): presumably function_parser.parse works from the source of
    # test_function, so its body must stay exactly as written and must not
    # gain a docstring - confirm against function_parser.

    # pylint: disable=pointless-statement
    def test_function():
        1
    # pylint: enable=pointless-statement

    self.assert_node_equal(
        function_parser.parse(test_function)[0], libcst.Integer("1"))
def leave_Call(self, original_node: cst.Call,
               updated_node: cst.Call) -> cst.Call:
    """Append a positional ``0`` argument to calls matching ``self.matcher``.

    Calls that do not match are returned unchanged.
    """
    if not matchers.matches(updated_node, self.matcher):
        return updated_node
    zero_arg = cst.Arg(value=cst.Integer(value="0"))
    return updated_node.with_changes(args=[*updated_node.args, zero_arg])
def test_annotation(self) -> None:
    # The {type} placeholder accepts either a plain expression or, as a
    # special case, an Annotation node; both must render the same statement.
    annotation_inputs = (
        cst.Name("int"),
        cst.Annotation(cst.Name("int")),
    )
    for annotation in annotation_inputs:
        statement = parse_template_statement(
            "x: {type} = {val}",
            type=annotation,
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(statement), "x: int = 5\n")
def test_at_least_n_matcher_no_args_false(self) -> None:
    # The node under test is always foo(1, 2, 3); none of the argument
    # matchers below can be satisfied by those three arguments.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    failing_args = (
        # At least four arguments.
        (m.AtLeastN(n=4),),
        # At least four arguments, the first one being the value 1.
        (m.Arg(m.Integer("1")), m.AtLeastN(n=3)),
        # At least three arguments, the last one being the value 2.
        (m.AtLeastN(n=2), m.Arg(m.Integer("2"))),
    )
    for arg_matchers in failing_args:
        call_matcher = m.Call(func=m.Name("foo"), args=arg_matchers)
        self.assertFalse(matches(call, call_matcher))
def test_at_most_n_matcher_args_true(self) -> None:
    # AtMostN/ZeroOrOne with an inner Arg matcher; every case must match.
    one_arg_call = cst.Call(
        func=cst.Name("foo"),
        args=(cst.Arg(cst.Integer("1")), ),
    )
    two_arg_call = cst.Call(
        func=cst.Name("foo"),
        args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
    )

    # Match a function call to "foo" with at most two arguments, both of which
    # are the integer 1.
    self.assertTrue(
        matches(
            one_arg_call,
            m.Call(
                func=m.Name("foo"),
                args=(m.AtMostN(m.Arg(m.Integer("1")), n=2), ),
            ),
        ))
    # Match a function call to "foo" with at most two arguments, both of which
    # can be the integer 1 or 2.
    self.assertTrue(
        matches(
            two_arg_call,
            m.Call(
                func=m.Name("foo"),
                args=(m.AtMostN(
                    m.Arg(m.OneOf(m.Integer("1"), m.Integer("2"))), n=2), ),
            ),
        ))
    # Match a function call to "foo" with at most two arguments, the first
    # one being the integer 1 and the second one, if included, being the
    # integer 2.
    self.assertTrue(
        matches(
            two_arg_call,
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")),
                      m.ZeroOrOne(m.Arg(m.Integer("2")))),
            ),
        ))
    # Match a function call to "foo" with at most six arguments, the first
    # one being the integer 1 and the remaining ones, if included, being the
    # integer 2.  This case previously repeated the ZeroOrOne assertion above
    # verbatim, so the "at most six" scenario was never exercised.
    self.assertTrue(
        matches(
            two_arg_call,
            m.Call(
                func=m.Name("foo"),
                args=(m.Arg(m.Integer("1")),
                      m.AtMostN(m.Arg(m.Integer("2")), n=5)),
            ),
        ))
def test_or_matcher_true(self) -> None:
    # OneOf must succeed when at least one of its options matches.
    def int_args(*digits):
        # Sequence of Arg matchers, one exact integer per digit string.
        return tuple(m.Arg(m.Integer(digit)) for digit in digits)

    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

    # Match on either True or False identifier.
    self.assertTrue(matches(libcst.Name("True"), true_or_false))

    # Match any assignment that assigns a value of True or False to an
    # unspecified target.
    assignment = libcst.Assign(
        (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("True")
    )
    self.assertTrue(matches(assignment, m.Assign(value=true_or_false)))

    # OneOf over whole argument sequences: the second option describes
    # foo(1, 2, 3) exactly.
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(
            libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")
        ),
    )
    call_matcher = m.Call(
        m.Name("foo"),
        m.OneOf(int_args("3", "2", "1"), int_args("1", "2", "3")),
    )
    self.assertTrue(matches(call, call_matcher))
def test_or_matcher_false(self) -> None:
    # OneOf must fail when none of its options matches.
    def int_args(*digits):
        # Sequence of Arg matchers, one exact integer per digit string.
        return tuple(m.Arg(m.Integer(digit)) for digit in digits)

    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

    # Fail to match since None is not True or False.
    self.assertFalse(matches(libcst.Name("None"), true_or_false))

    # Fail to match since assigning None to a target is not the same as
    # assigning True or False to a target.
    assignment = libcst.Assign(
        (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
    )
    self.assertFalse(matches(assignment, m.Assign(value=true_or_false)))

    # Neither argument sequence describes foo(1, 2, 3).
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(
            libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")
        ),
    )
    call_matcher = m.Call(
        m.Name("foo"),
        m.OneOf(int_args("3", "2", "1"), int_args("4", "5", "6")),
    )
    self.assertFalse(matches(call, call_matcher))
def test_at_most_n_matcher_no_args_false(self) -> None:
    # The node under test is always foo(1, 2, 3); each argument matcher below
    # allows at most two arguments, so none of them may match.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    failing_args = (
        # At most two arguments.
        (m.AtMostN(n=2),),
        # At most two arguments, the last one being the integer 3.
        (m.AtMostN(n=1), m.Arg(m.Integer("3"))),
        # The same limit spelled with ZeroOrOne.
        (m.ZeroOrOne(), m.Arg(m.Integer("3"))),
    )
    for arg_matchers in failing_args:
        call_matcher = m.Call(func=m.Name("foo"), args=arg_matchers)
        self.assertFalse(matches(call, call_matcher))
def test_assign_target(self) -> None:
    # The {a}/{b} placeholders accept either plain expressions or, as a
    # special case, AssignTarget nodes; both must render the same statement.
    def as_expression(name):
        return name

    for wrap in (as_expression, cst.AssignTarget):
        statement = parse_template_statement(
            "{a} = {b} = {val}",
            a=wrap(cst.Name("first")),
            b=wrap(cst.Name("second")),
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(statement), "first = second = 5\n")
def _add_one_to_arg(
    node: cst.CSTNode,
    extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
) -> cst.CSTNode:
    """Return ``node`` with the extracted ``"arg"`` integer incremented by one.

    ``extraction["arg"]`` must be a single ``cst.Integer`` node; it is replaced
    inside ``node`` by a new ``cst.Integer`` whose value is one larger.
    """
    # This can be either a node or a sequence, pyre doesn't know.
    old_arg = cst.ensure_type(extraction["arg"], cst.CSTNode)
    # Grab the arg and add one to its value.
    old_value = int(cst.ensure_type(extraction["arg"], cst.Integer).value)
    new_arg = cst.Integer(str(old_value + 1))
    return node.deep_replace(old_arg, new_arg)
def choice_ast(rng_key):
    """Build the call node ``mcx.jax.choice(<rng_key>, <name>.shape[0])``.

    ``rng_key`` is the expression node used as the first argument; ``<name>``
    is ``nodes[0].name``, taken from the enclosing scope.
    """
    choice_fn = cst.Attribute(
        value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
        attr=cst.Name("choice"),
    )
    key_arg = cst.Arg(rng_key)
    shape_of_first = cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape"))
    first_dimension = cst.Subscript(
        shape_of_first,
        [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
    )
    return cst.Call(func=choice_fn, args=[key_arg, cst.Arg(first_dimension)])
def test_collect_targets():
    """Check that ``collect_targets`` returns each assignment target in order.

    The module assigns to a plain name, a subscript and an attribute; the
    collected targets must be deeply equal to hand-built nodes for ``x``,
    ``x[0]`` and ``x.attr``.
    """
    # One statement per line: the three assignments must stay separate
    # statements for the module to parse.
    tree = cst.parse_module('''
x = [0, 1]
x[0] = 1
x.attr = 2
''')
    x = cst.Name(value='x')
    x0 = cst.Subscript(
        value=x,
        slice=[cst.SubscriptElement(slice=cst.Index(value=cst.Integer('0')))],
    )
    xa = cst.Attribute(
        value=x,
        attr=cst.Name('attr'),
    )
    golds = [x, x0, xa]
    targets = list(collect_targets(tree))
    # zip() stops at the shorter input, so without this check a missing (or
    # entirely empty) result would make the comparison below pass vacuously.
    assert len(targets) == len(golds)
    assert all(t.deep_equals(g) for t, g in zip(targets, golds))
def test_does_not_match_operator_false(self) -> None:
    # Operator spelling (~, |, &) of the DoesNotMatch tests; every
    # (node, matcher) pair below must fail to match.
    def foo_call(value):
        # Concrete call foo(<value>) with a single positional argument.
        return libcst.Call(libcst.Name("foo"), (libcst.Arg(value), ))

    def one_arg(value_matcher):
        # Matcher for any call whose only argument's value matches.
        return m.Call(args=(m.Arg(value=value_matcher), ))

    cases = (
        # "Not None" must reject an argument that is None.
        (foo_call(libcst.Name("None")), one_arg(~m.Name("None"))),
        # An inverted Arg matcher must reject the argument it describes.
        (
            foo_call(libcst.Integer("1")),
            m.Call(args=((~m.Arg(m.Integer("1"))), )),
        ),
        # "Neither True nor False" must reject False.
        (
            foo_call(libcst.Name("False")),
            one_arg(~(m.Name("True") | m.Name("False"))),
        ),
        # Roundabout way of verifying ~(x&y) behavior.
        (
            foo_call(libcst.Name("False")),
            one_arg(~(m.Name() & m.Name("False"))),
        ),
        # Roundabout way of verifying (~x)|(~y) behavior.
        (
            foo_call(libcst.Name("True")),
            one_arg((~m.Name("True")) | (~m.Name("True"))),
        ),
        # An inverted regex for True must reject the name True.
        (libcst.Name("True"), m.Name(value=~m.MatchRegex(r"True"))),
    )
    for node, matcher in cases:
        self.assertFalse(matches(node, matcher))
def test_zero_or_more_matcher_args_false(self) -> None:
    # The node under test is always foo(1, 2, 3); the matcher pins the first
    # argument to the integer 1 and requires every remaining argument to
    # satisfy a matcher that at least one of them does not.
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    unsatisfiable_tails = (
        # The rest of the arguments are strings.
        m.ZeroOrMore(m.Arg(m.SimpleString())),
        # The rest of the arguments are integers with the value 2.
        m.ZeroOrMore(m.Arg(m.Integer("2"))),
    )
    for tail in unsatisfiable_tails:
        call_matcher = m.Call(
            func=m.Name("foo"),
            args=(m.Arg(m.Integer("1")), tail),
        )
        self.assertFalse(matches(call, call_matcher))
class NumberTest(CSTNodeTest):
    """Rendering, parsing, position and paren-validation tests for numbers."""

    @data_provider(
        (
            # Simple number
            (cst.Integer("5"), "5", parse_expression),
            # Negated number
            (
                cst.UnaryOperation(
                    operator=cst.Minus(), expression=cst.Integer("5")
                ),
                "-5",
                parse_expression,
                CodeRange((1, 0), (1, 2)),
            ),
            # In parenthesis
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Minus(),
                    expression=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "(-5)",
                parse_expression,
                CodeRange((1, 1), (1, 3)),
            ),
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Minus(),
                    expression=cst.Integer(
                        "5", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(-(5))",
                parse_expression,
                CodeRange((1, 1), (1, 5)),
            ),
            (
                cst.UnaryOperation(
                    operator=cst.Minus(),
                    expression=cst.UnaryOperation(
                        operator=cst.Minus(), expression=cst.Integer("5")
                    ),
                ),
                "--5",
                parse_expression,
                CodeRange((1, 0), (1, 3)),
            ),
            # multiple nested parenthesis
            (
                cst.Integer(
                    "5",
                    lpar=(cst.LeftParen(), cst.LeftParen()),
                    rpar=(cst.RightParen(), cst.RightParen()),
                ),
                "((5))",
                parse_expression,
                CodeRange((1, 2), (1, 3)),
            ),
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Plus(),
                    expression=cst.Integer(
                        "5",
                        lpar=(cst.LeftParen(), cst.LeftParen()),
                        rpar=(cst.RightParen(), cst.RightParen()),
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(+((5)))",
                parse_expression,
                CodeRange((1, 1), (1, 7)),
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        parser: Optional[Callable[[str], cst.CSTNode]],
        position: Optional[CodeRange] = None,
    ) -> None:
        # Each row: node, expected code, parser, optional expected position.
        self.validate_node(node, code, parser, expected_position=position)

    @data_provider(
        (
            (
                lambda: cst.Integer("5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Integer("5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Float("5.5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Float("5.5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            # Python's imaginary suffix is "j", not "i".  A valid literal is
            # used so that unbalanced parentheses are the only thing wrong
            # with these nodes, whatever order the validations run in.
            (
                lambda: cst.Imaginary("5j", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Imaginary("5j", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        # Each row: a node factory and a regex the validation error must match.
        self.assert_invalid(get_node, expected_re)
class FooterBehaviorTest(UnitTest):
    """Check which node ends up owning header, footer and trailing comments.

    Every case parses ``code`` and compares the result against a hand-built
    ``expected_module``.  The embedded sources use four spaces per indentation
    level; the nesting depth of each comment is what decides its owner, so the
    exact number of spaces inside the ``code`` strings matters.
    """

    @data_provider({
        # Literally the most basic example
        "simple_module": {
            "code": "\n",
            "expected_module": cst.Module(body=()),
        },
        # A module with a header comment
        "header_only_module": {
            "code": "# This is a header comment\n",
            "expected_module": cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[],
            ),
        },
        # A module with a header and footer
        "simple_header_footer_module": {
            "code": "# This is a header comment\npass\n# This is a footer comment\n",
            "expected_module": cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[cst.SimpleStatementLine([cst.Pass()])],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # A module which should have a footer comment taken from the
        # if statement's indented block.
        "simple_reparented_footer_module": {
            "code": "# This is a header comment\nif True:\n    pass\n# This is a footer comment\n",
            "expected_module": cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[
                    cst.If(
                        test=cst.Name(value="True"),
                        body=cst.IndentedBlock(
                            header=cst.TrailingWhitespace(),
                            body=[
                                cst.SimpleStatementLine(
                                    body=[cst.Pass()],
                                    trailing_whitespace=cst.TrailingWhitespace(),
                                )
                            ],
                        ),
                    )
                ],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # Verifying that we properly parse and spread out footer comments to the
        # relative indents they go with.
        "complex_reparented_footer_module": {
            "code": (
                "# This is a header comment\nif True:\n    if True:\n        pass"
                + "\n        # This is an inner indented block comment\n    # This "
                + "is an outer indented block comment\n# This is a footer comment\n"
            ),
            "expected_module": cst.Module(
                body=[
                    cst.If(
                        test=cst.Name(value="True"),
                        body=cst.IndentedBlock(
                            body=[
                                cst.If(
                                    test=cst.Name(value="True"),
                                    body=cst.IndentedBlock(
                                        body=[
                                            cst.SimpleStatementLine(
                                                body=[cst.Pass()])
                                        ],
                                        footer=[
                                            cst.EmptyLine(comment=cst.Comment(
                                                value="# This is an inner indented block comment"
                                            ))
                                        ],
                                    ),
                                )
                            ],
                            footer=[
                                cst.EmptyLine(comment=cst.Comment(
                                    value="# This is an outer indented block comment"
                                ))
                            ],
                        ),
                    )
                ],
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # Verify that comments belonging to statements are still owned even
        # after an indented block.
        "statement_comment_reparent": {
            "code": "if foo:\n    return\n# comment\nx = 7\n",
            "expected_module": cst.Module(body=[
                cst.If(
                    test=cst.Name(value="foo"),
                    body=cst.IndentedBlock(body=[
                        cst.SimpleStatementLine(body=[
                            cst.Return(
                                whitespace_after_return=cst.SimpleWhitespace(
                                    value=""))
                        ])
                    ]),
                ),
                cst.SimpleStatementLine(
                    body=[
                        cst.Assign(
                            targets=[
                                cst.AssignTarget(target=cst.Name(value="x"))
                            ],
                            value=cst.Integer(value="7"),
                        )
                    ],
                    leading_lines=[
                        cst.EmptyLine(comment=cst.Comment(value="# comment"))
                    ],
                ),
            ]),
        },
        # Verify that even if there are completely empty lines, we give all lines
        # up to and including the last line that's indented correctly. That way
        # comments that line up with indented block's indentation level aren't
        # parented to the next line just because there's a blank line or two
        # between them.
        "statement_comment_with_empty_lines": {
            "code": (
                "def foo():\n    if True:\n        pass\n\n        # Empty "
                + "line before me\n\n    else:\n        pass\n"
            ),
            "expected_module": cst.Module(body=[
                cst.FunctionDef(
                    name=cst.Name(value="foo"),
                    params=cst.Parameters(),
                    body=cst.IndentedBlock(body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.SimpleStatementLine(body=[cst.Pass()])
                                ],
                                footer=[
                                    cst.EmptyLine(indent=False),
                                    cst.EmptyLine(comment=cst.Comment(
                                        value="# Empty line before me")),
                                ],
                            ),
                            orelse=cst.Else(
                                body=cst.IndentedBlock(body=[
                                    cst.SimpleStatementLine(body=[cst.Pass()])
                                ]),
                                leading_lines=[cst.EmptyLine(indent=False)],
                            ),
                        )
                    ]),
                )
            ]),
        },
    })
    def test_parsers(self, code: str, expected_module: cst.CSTNode) -> None:
        # Parse the (dedented) source and require deep equality with the
        # hand-built tree; both reprs are shown on failure.
        parsed_module = parse_module(dedent(code))
        self.assertTrue(
            deep_equals(parsed_module, expected_module),
            msg=f"\n{parsed_module!r}\nis not deeply equal to \n{expected_module!r}",
        )
class SubscriptTest(CSTNodeTest):
    """Rendering, parsing and validation tests for subscript expressions.

    Covers ``Subscript`` together with its ``Index``, ``Slice`` and
    ``SubscriptElement`` helper nodes.
    """

    @data_provider((
        # Simple subscript expression
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
            ),
            "foo[5]",
            True,
        ),
        # Test creation of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=cst.Integer("2"),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            upper=cst.Integer("2"),
                            step=cst.Integer("3"),
                        )),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3, 5]",
            False,
            CodeRange((1, 0), (1, 13)),
        ),
        # Test parsing of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(),
                    ),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3,5]",
            True,
        ),
        # Some more wild slice creations
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"),
                              upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=None,
                              step=cst.Integer("3"))), ),
            ),
            "foo[::3]",
            False,
            CodeRange((1, 0), (1, 8)),
        ),
        # Some more wild slice parsings
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"),
                              upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[::3]",
            True,
        ),
        # Valid list clone operations rendering
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # Valid list clone operations parsing
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # In parenthesis
        (
            cst.Subscript(
                lpar=(cst.LeftParen(), ),
                value=cst.Name("foo"),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "(foo[5])",
            True,
        ),
        # Verify spacing
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 5 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        step=cst.Integer("3"),
                    )), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ] )",
            True,
        ),
        # The comma below is followed by TWO spaces, both in the node and in
        # the expected code: that is what makes the subscript end at column
        # 24, as asserted by the CodeRange (with a single space it would end
        # at column 23).
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(
                    cst.SubscriptElement(
                        slice=cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                    ),
                    cst.SubscriptElement(slice=cst.Index(cst.Integer("5"))),
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ,  5 ] )",
            True,
            CodeRange((1, 2), (1, 24)),
        ),
        # Test Index, Slice, SubscriptElement
        (cst.Index(cst.Integer("5")), "5", False, CodeRange((1, 0), (1, 1))),
        (
            cst.Slice(lower=None, upper=None, second_colon=cst.Colon(),
                      step=None),
            "::",
            False,
            CodeRange((1, 0), (1, 2)),
        ),
        (
            cst.SubscriptElement(
                slice=cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    step=cst.Integer("3"),
                ),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
            ),
            "1 : 2 : 3 , ",
            False,
            CodeRange((1, 0), (1, 9)),
        ),
    ))
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        check_parsing: bool,
        position: Optional[CodeRange] = None,
    ) -> None:
        # Rows flagged with check_parsing are also parsed back with
        # parse_expression; the others are only validated against the code.
        if check_parsing:
            self.validate_node(node, code, parse_expression,
                               expected_position=position)
        else:
            self.validate_node(node, code, expected_position=position)

    @data_provider((
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        (lambda: cst.Subscript(cst.Name("foo"), ()), "empty SubscriptElement"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        # Each row: a node factory and a regex the validation error must match.
        self.assert_invalid(get_node, expected_re)
class YieldConstructionTest(CSTNodeTest):
    """Render hand-constructed ``Yield`` / ``From`` nodes and validate them.

    ``test_valid`` checks each node against its expected source text (and,
    where a ``CodeRange`` is supplied, against its expected position).
    ``test_invalid`` checks that malformed nodes are rejected with a message
    matching the given regular expression.
    """

    @data_provider(
        (
            # Simple yield
            (cst.Yield(), "yield"),
            # yield expression
            (cst.Yield(cst.Name("a")), "yield a"),
            # yield from expression
            (cst.Yield(cst.From(cst.Call(cst.Name("a")))), "yield from a()"),
            # Parenthesizing tests
            (
                cst.Yield(
                    lpar=(cst.LeftParen(),),
                    value=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "(yield 5)",
            ),
            # Whitespace oddities tests
            (
                cst.Yield(
                    cst.Name(
                        "a",
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    whitespace_after_yield=cst.SimpleWhitespace(""),
                ),
                "yield(a)",
                CodeRange((1, 0), (1, 8)),
            ),
            (
                cst.Yield(
                    cst.From(
                        cst.Call(
                            cst.Name("a"),
                            lpar=(cst.LeftParen(),),
                            rpar=(cst.RightParen(),),
                        ),
                        whitespace_after_from=cst.SimpleWhitespace(""),
                    )
                ),
                "yield from(a())",
            ),
            # Whitespace rendering/parsing tests
            (
                cst.Yield(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    value=cst.Integer("5"),
                    whitespace_after_yield=cst.SimpleWhitespace(" "),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "( yield 5 )",
            ),
            (
                cst.Yield(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    value=cst.From(
                        cst.Call(cst.Name("bla")),
                        whitespace_after_from=cst.SimpleWhitespace(" "),
                    ),
                    whitespace_after_yield=cst.SimpleWhitespace(" "),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "( yield from bla() )",
                # The range starts after "( " (column 2), so it excludes the
                # parentheses and their padding; "yield from bla()" is 16
                # characters wide and therefore ends at column 18.
                CodeRange((1, 2), (1, 18)),
            ),
            # From expression position tests
            (
                cst.From(
                    cst.Integer("5"),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                "from 5",
                CodeRange((1, 0), (1, 6)),
            ),
        )
    )
    def test_valid(
        self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
    ) -> None:
        """Validate ``node`` against ``code`` and, if given, ``position``.

        No parser is passed to ``validate_node`` here; parsing of the same
        snippets is exercised by a separate test class.
        """
        self.validate_node(node, code, expected_position=position)

    @data_provider(
        (
            # Paren validation
            (
                lambda: cst.Yield(lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Yield(rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            # Make sure we have adequate space after yield
            (
                lambda: cst.Yield(
                    cst.Name("a"),
                    whitespace_after_yield=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space after 'yield' keyword",
            ),
            (
                lambda: cst.Yield(
                    cst.From(cst.Call(cst.Name("a"))),
                    whitespace_after_yield=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space after 'yield' keyword",
            ),
            # Make sure we have adequate space after from
            (
                lambda: cst.Yield(
                    cst.From(
                        cst.Call(cst.Name("a")),
                        whitespace_after_from=cst.SimpleWhitespace(""),
                    )
                ),
                "Must have at least one space after 'from' keyword",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Assert that calling ``get_node`` fails with a message matching ``expected_re``.

        Construction is deferred behind a callable so that the failure is
        triggered inside ``assert_invalid`` rather than while the data
        provider tuple is being built.
        """
        self.assert_invalid(get_node, expected_re)
class YieldParsingTest(CSTNodeTest):
    """Round-trip ``yield`` expressions through the statement parser.

    Unlike the construction tests, every node here spells out the whitespace
    that the parser is expected to produce for the given source, so that the
    parsed tree can be compared against the hand-built one.
    """

    @data_provider(
        (
            # Simple yield
            (cst.Yield(), "yield"),
            # yield expression
            (
                cst.Yield(
                    cst.Name("a"),
                    whitespace_after_yield=cst.SimpleWhitespace(" "),
                ),
                "yield a",
            ),
            # yield from expression
            (
                cst.Yield(
                    cst.From(
                        cst.Call(cst.Name("a")),
                        whitespace_after_from=cst.SimpleWhitespace(" "),
                    ),
                    whitespace_after_yield=cst.SimpleWhitespace(" "),
                ),
                "yield from a()",
            ),
            # Parenthesizing tests
            (
                cst.Yield(
                    lpar=(cst.LeftParen(),),
                    whitespace_after_yield=cst.SimpleWhitespace(" "),
                    value=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "(yield 5)",
            ),
            # Whitespace oddities tests
            (
                cst.Yield(
                    cst.Name(
                        "a",
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    whitespace_after_yield=cst.SimpleWhitespace(""),
                ),
                "yield(a)",
            ),
            (
                cst.Yield(
                    cst.From(
                        cst.Call(
                            cst.Name("a"),
                            lpar=(cst.LeftParen(),),
                            rpar=(cst.RightParen(),),
                        ),
                        whitespace_after_from=cst.SimpleWhitespace(""),
                    ),
                    whitespace_after_yield=cst.SimpleWhitespace(" "),
                ),
                "yield from(a())",
            ),
            # Whitespace rendering/parsing tests
            (
                cst.Yield(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    value=cst.Integer("5"),
                    whitespace_after_yield=cst.SimpleWhitespace(" "),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "( yield 5 )",
            ),
            (
                cst.Yield(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    value=cst.From(
                        cst.Call(cst.Name("bla")),
                        whitespace_after_from=cst.SimpleWhitespace(" "),
                    ),
                    whitespace_after_yield=cst.SimpleWhitespace(" "),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "( yield from bla() )",
            ),
        )
    )
    def test_valid(
        self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
    ) -> None:
        """Parse ``code`` as a statement and compare its yield with ``node``.

        ``position`` is optional; when a data row supplies it, it is forwarded
        to ``validate_node`` as the expected position instead of being
        silently ignored.
        """

        def parse_yield(source: str) -> cst.CSTNode:
            # A bare yield only parses as a statement, so unwrap
            # SimpleStatementLine -> Expr -> value to reach the expression.
            statement = ensure_type(
                parse_statement(source), cst.SimpleStatementLine
            )
            return ensure_type(statement.body[0], cst.Expr).value

        self.validate_node(node, code, parse_yield, expected_position=position)

    @data_provider(
        (
            {
                "code": "yield from x",
                "parser": parse_statement_as(python_version="3.3"),
                "expect_success": True,
            },
            {
                "code": "yield from x",
                "parser": parse_statement_as(python_version="3.1"),
                "expect_success": False,
            },
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        """Check that ``yield from`` parses under 3.3 but is rejected under 3.1."""
        self.assert_parses(**kwargs)
class NamedExprTest(CSTNodeTest):
    """Construction, parsing and validation tests for ``NamedExpr`` (``:=``).

    Each ``test_valid`` row is a keyword dictionary handed straight to
    ``validate_node``; each ``test_invalid`` row is handed straight to
    ``assert_invalid``.
    """

    @data_provider(
        (
            # Simple named expression
            {
                "node": cst.NamedExpr(cst.Name("x"), cst.Float("5.5")),
                "code": "x := 5.5",
                # Walrus operator is illegal as top-level statement
                "parser": None,
                "expected_position": None,
            },
            # Parenthesized named expression
            {
                "node": cst.NamedExpr(
                    lpar=(cst.LeftParen(),),
                    target=cst.Name("foo"),
                    value=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "code": "(foo := 5)",
                "parser": _parse_expression_force_38,
                "expected_position": CodeRange((1, 1), (1, 9)),
            },
            # Make sure that spacing works
            {
                "node": cst.NamedExpr(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    target=cst.Name("foo"),
                    whitespace_before_walrus=cst.SimpleWhitespace(" "),
                    whitespace_after_walrus=cst.SimpleWhitespace(" "),
                    value=cst.Name("bar"),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "code": "( foo := bar )",
                "parser": _parse_expression_force_38,
                # As in the "(foo := 5)" case, the range excludes the
                # parentheses (and here their padding): "foo := bar" spans
                # columns 2 through 12.
                "expected_position": CodeRange((1, 2), (1, 12)),
            },
            # Make sure we can use these where allowed in if/while statements
            {
                "node": cst.While(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Call(func=cst.Name(value="some_input")),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "while x := some_input(): pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            {
                "node": cst.If(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Call(func=cst.Name(value="some_input")),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "if x := some_input(): pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            {
                "node": cst.If(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Integer(value="1"),
                        whitespace_before_walrus=cst.SimpleWhitespace(""),
                        whitespace_after_walrus=cst.SimpleWhitespace(""),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "if x:=1: pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            # Function args
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace(""),
                                whitespace_after_walrus=cst.SimpleWhitespace(""),
                            )
                        ),
                    ],
                ),
                "code": "f(y:=1)",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            # Whitespace handling on args is fragile
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.Name(value="x"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace(" "),
                                whitespace_after_walrus=cst.SimpleWhitespace(" "),
                            ),
                            whitespace_after_arg=cst.SimpleWhitespace(" "),
                        ),
                    ],
                ),
                "code": "f(x, y := 1 )",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace(" "),
                                whitespace_after_walrus=cst.SimpleWhitespace(" "),
                            ),
                            whitespace_after_arg=cst.SimpleWhitespace(" "),
                        ),
                    ],
                    whitespace_before_args=cst.SimpleWhitespace(" "),
                ),
                "code": "f( y := 1 )",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward one data row (node, code, parser, expected_position) to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": (
                    lambda: cst.NamedExpr(
                        cst.Name("foo"),
                        cst.Name("bar"),
                        lpar=(cst.LeftParen(),),
                    )
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.NamedExpr(
                        cst.Name("foo"),
                        cst.Name("bar"),
                        rpar=(cst.RightParen(),),
                    )
                ),
                "expected_re": "right paren without left paren",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward one data row (get_node, expected_re) to ``assert_invalid``."""
        self.assert_invalid(**kwargs)