def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report ``@dataclass`` decorators that do not pass ``frozen``.

        Every decorator on ``node`` that ``QualifiedNameProvider`` resolves to
        an imported ``dataclasses.dataclass`` is inspected. If none of its
        arguments uses the ``frozen`` keyword (whatever its value), the
        decorator is reported with a replacement that appends ``frozen=True``
        (no spaces around ``=``). A bare ``@dataclass`` becomes
        ``@dataclass(frozen=True)``; an existing call keeps its arguments and
        gains the extra one at the end.
        """
        dataclass_qualname = QualifiedName(
            name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
        )
        for deco_node in node.decorators:
            expr = deco_node.decorator
            if not QualifiedNameProvider.has_name(self, expr, dataclass_qualname):
                continue

            # A called decorator already has a callee and arguments; a bare
            # Name/Attribute decorator is its own callee with no arguments.
            if isinstance(expr, cst.Call):
                callee = expr.func
                existing_args = list(expr.args)
            else:
                callee = expr
                existing_args = []

            # An explicit ``frozen=...`` (True or False) is respected as-is.
            has_frozen = any(
                m.matches(a.keyword, m.Name("frozen")) for a in existing_args
            )
            if has_frozen:
                continue

            frozen_arg = cst.Arg(
                keyword=cst.Name("frozen"),
                value=cst.Name("True"),
                equal=cst.AssignEqual(
                    whitespace_before=SimpleWhitespace(value=""),
                    whitespace_after=SimpleWhitespace(value=""),
                ),
            )
            fixed_decorator = cst.Call(func=callee, args=existing_args + [frozen_arg])
            self.report(
                deco_node, replacement=deco_node.with_changes(decorator=fixed_decorator)
            )
Ejemplo n.º 2
0
    def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine,
                                  updated_node: cst.SimpleStatementLine
                                  ) -> cst.SimpleStatementLine:
        """Apply a looked-up type annotation to a single-name assignment line.

        Two line shapes are recognised, both matched against ``original_node``:

        * ``name = value`` (one ``Assign`` with a single ``Name`` target) is
          rewritten to ``name: T = value``, keeping the original spacing
          around ``=`` and any trailing semicolon.
        * ``name: old = value`` / ``name: old`` (one ``AnnAssign`` with a
          ``Name`` target) has its annotation replaced by ``T``.

        ``T`` is obtained via ``__get_var_type_assign_t`` /
        ``__get_var_type_an_assign``, then ``resolve_type_alias`` and
        ``__name2annotation``. If any step yields ``None`` the line is left
        alone. Each applied ``(resolved_type, annotation_node)`` pair is
        added to ``self.all_applied_types``.

        Returns ``updated_node`` (never ``original_node``) so that changes
        made by child ``leave_*`` hooks inside this line are kept.
        """

        def _annotation_for(var_type):
            # Turn a looked-up type into an annotation node and record the
            # application; None means "leave this line alone".
            if var_type is None:
                return None
            resolved = self.resolve_type_alias(var_type)
            annotation = self.__name2annotation(resolved)
            if annotation is not None:
                self.all_applied_types.add((resolved, annotation))
            return annotation

        def _current_stmt(kind):
            # Prefer the statement as already transformed by child hooks so
            # their edits survive. Fall back to the original if a child hook
            # removed it or replaced it with a different node type.
            if len(updated_node.body) == 1 and isinstance(updated_node.body[0], kind):
                return updated_node.body[0]
            return original_node.body[0]

        if match.matches(
                original_node,
                match.SimpleStatementLine(body=[
                    match.Assign(targets=[
                        match.AssignTarget(target=match.Name(
                            value=match.DoNotCare()))
                    ])
                ])):
            # `name = value`  ->  `name: T = value`
            annotation = _annotation_for(self.__get_var_type_assign_t(
                original_node.body[0].targets[0].target.value))
            if annotation is not None:
                assign = _current_stmt(cst.Assign)
                assign_target = assign.targets[0]
                return updated_node.with_changes(body=[
                    cst.AnnAssign(
                        target=assign_target.target,
                        value=assign.value,
                        annotation=annotation,
                        # Reuse the spacing the author had around `=`.
                        equal=cst.AssignEqual(
                            whitespace_before=assign_target.whitespace_before_equal,
                            whitespace_after=assign_target.whitespace_after_equal),
                        semicolon=assign.semicolon)
                ])
        elif match.matches(
                original_node,
                match.SimpleStatementLine(body=[
                    match.AnnAssign(target=match.Name(value=match.DoNotCare()))
                ])):
            # `name: old [= value]`  ->  `name: T [= value]`
            annotation = _annotation_for(self.__get_var_type_an_assign(
                original_node.body[0].target.value))
            if annotation is not None:
                ann_assign = _current_stmt(cst.AnnAssign)
                # with_changes keeps target, value, `=` spacing and semicolon.
                return updated_node.with_changes(
                    body=[ann_assign.with_changes(annotation=annotation)])

        # Returning original_node here would silently discard any changes
        # that child leave_* hooks made inside this line.
        return updated_node
Ejemplo n.º 3
0
    def leave_Call(self, original_node, updated_node):
        """Convert positional to keyword arguments.

        Applies to calls whose single qualified name is in
        ``self.kwonly_functions`` and whose first argument is positional. The
        real function's signature is inspected, and every positional argument
        that lines up with a keyword-only parameter gets an explicit
        ``name=`` keyword (no spaces around ``=``).

        The call is returned unchanged when looking the function up raises
        ``ImportError`` (e.g. an optional dependency is not installed), or
        when the call has more arguments than the function has parameters.
        """
        metadata = self.get_metadata(cst.metadata.QualifiedNameProvider,
                                     original_node)
        qualnames = {qn.name for qn in metadata}

        # If this isn't one of our known functions, or it has no posargs, stop there.
        if (len(qualnames) != 1
                or not qualnames.intersection(self.kwonly_functions)
                or not m.matches(
                    updated_node,
                    m.Call(
                        func=m.DoesNotMatch(m.Call()),
                        args=[m.Arg(keyword=None),
                              m.ZeroOrMore()],
                    ),
                )):
            return updated_node

        # Get the actual function object so that we can inspect the signature.
        # This does e.g. incur a dependency on Numpy to fix Numpy-dependent code,
        # but having a single source of truth about the signatures is worth it.
        # If that dependency is missing, skip this call instead of crashing
        # the whole codemod.
        try:
            fn = get_fn(*qualnames)
        except ImportError:
            return updated_node
        params = signature(fn).parameters.values()

        # st.floats() has a new allow_subnormal kwonly argument not at the end,
        # so we do a bit more of a dance here.
        if qualnames == {"hypothesis.strategies.floats"}:
            params = [p for p in params if p.name != "allow_subnormal"]

        if len(updated_node.args) > len(params):
            return updated_node

        # Create new arg nodes with the newly required keywords
        assign_nospace = cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )
        newargs = []
        after_star = False
        for p, arg in zip(params, updated_node.args):
            # After an `*iterable` argument we can no longer tell which
            # parameter a later positional argument binds to, so leave it be.
            after_star = after_star or arg.star == "*"
            if (arg.keyword or arg.star or after_star
                    or p.kind is not Parameter.KEYWORD_ONLY):
                newargs.append(arg)
            else:
                newargs.append(arg.with_changes(
                    keyword=cst.Name(p.name), equal=assign_nospace))
        return updated_node.with_changes(args=newargs)
Ejemplo n.º 4
0
    def leave_Call(self, original_node, updated_node):
        """Convert positional to keyword arguments.

        Applies to calls whose single qualified name is in
        ``self.kwonly_functions`` and whose first argument is positional. The
        real function is imported and its signature inspected, and every
        positional argument that lines up with a keyword-only parameter gets
        an explicit ``name=`` keyword (no spaces around ``=``).

        The call is returned unchanged when the function's module cannot be
        imported or does not define it, or when the call has more arguments
        than the function has parameters.
        """
        metadata = self.get_metadata(cst.metadata.QualifiedNameProvider, original_node)
        qualnames = {qn.name for qn in metadata}

        # If this isn't one of our known functions, or it has no posargs, stop there.
        if (
            len(qualnames) != 1
            or not qualnames.intersection(self.kwonly_functions)
            or not m.matches(
                updated_node,
                m.Call(
                    func=m.DoesNotMatch(m.Call()),
                    args=[m.Arg(keyword=None), m.ZeroOrMore()],
                ),
            )
        ):
            return updated_node

        # Get the actual function object so that we can inspect the signature.
        # This does e.g. incur a dependency on Numpy to fix Numpy-dependent code,
        # but having a single source of truth about the signatures is worth it.
        # The guard above ensures `qualnames` holds exactly one known name.
        (qualname,) = qualnames
        mod, fn = qualname.rsplit(".", 1)
        try:
            func = getattr(importlib.import_module(mod), fn)
        except (ImportError, AttributeError):
            # Missing optional dependency, or an installed version that does
            # not define this function: leave the call untouched.
            return updated_node
        params = list(signature(func).parameters.values())

        # zip() below would silently drop arguments beyond the last parameter,
        # so leave such calls alone rather than delete user code.
        if len(updated_node.args) > len(params):
            return updated_node

        # Create new arg nodes with the newly required keywords
        assign_nospace = cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )
        newargs = []
        after_star = False
        for p, arg in zip(params, updated_node.args):
            # After an `*iterable` argument we can no longer tell which
            # parameter a later positional argument binds to, so leave it be.
            after_star = after_star or arg.star == "*"
            if arg.keyword or arg.star or after_star or p.kind is not Parameter.KEYWORD_ONLY:
                newargs.append(arg)
            else:
                newargs.append(
                    arg.with_changes(keyword=cst.Name(p.name), equal=assign_nospace)
                )
        return updated_node.with_changes(args=newargs)
Ejemplo n.º 5
0
class ClassDefParserTest(CSTNodeTest):
    """Data-driven tests pairing hand-built ``cst.ClassDef`` nodes with source.

    Each case gives a ``node`` and the ``code`` it should correspond to.
    Positional and ``*``-expanded class arguments go in ``bases``.
    ``keyword=`` and ``**``-expanded arguments go in ``keywords``. The
    decorator cases also give an ``expected_position`` for the node.
    """

    @data_provider((
        # Simple classdef
        # pyre-fixme[6]: Incompatible parameter type
        {
            "node":
            cst.ClassDef(cst.Name("Foo"),
                         cst.SimpleStatementSuite((cst.Pass(), ))),
            "code":
            "class Foo: pass\n",
        },
        # Explicit lpar/rpar with no bases render as empty parentheses.
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                rpar=cst.RightParen(),
            ),
            "code":
            "class Foo(): pass\n",
        },
        # Positional arguments render test
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                bases=(cst.Arg(cst.Name("obj")), ),
                rpar=cst.RightParen(),
            ),
            "code":
            "class Foo(obj): pass\n",
        },
        # Multiple bases: each non-final Arg carries its own comma + space.
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                bases=(
                    cst.Arg(
                        cst.Name("Bar"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        cst.Name("Baz"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(cst.Name("object")),
                ),
                rpar=cst.RightParen(),
            ),
            "code":
            "class Foo(Bar, Baz, object): pass\n",
        },
        # Keyword arguments render test
        # (a default AssignEqual() renders as " = ", spaces on both sides)
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                keywords=(cst.Arg(
                    keyword=cst.Name("metaclass"),
                    equal=cst.AssignEqual(),
                    value=cst.Name("Bar"),
                ), ),
                rpar=cst.RightParen(),
            ),
            "code":
            "class Foo(metaclass = Bar): pass\n",
        },
        # Iterator expansion render test
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                bases=(cst.Arg(star="*", value=cst.Name("one")), ),
                rpar=cst.RightParen(),
            ),
            "code":
            "class Foo(*one): pass\n",
        },
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                bases=(
                    cst.Arg(
                        star="*",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
                rpar=cst.RightParen(),
            ),
            "code":
            "class Foo(*one, *two, *three): pass\n",
        },
        # Dictionary expansion render test
        # (``**`` arguments live in ``keywords``, not ``bases``)
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                keywords=(cst.Arg(star="**", value=cst.Name("one")), ),
                rpar=cst.RightParen(),
            ),
            "code":
            "class Foo(**one): pass\n",
        },
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                keywords=(
                    cst.Arg(
                        star="**",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
                rpar=cst.RightParen(),
            ),
            "code":
            "class Foo(**one, **two, **three): pass\n",
        },
        # Decorator render tests
        # expected_position covers only the `class` line (line 2, columns
        # 0-17); the decorator line above it is not part of the range.
        {
            "node":
            cst.ClassDef(
                cst.Name("Foo"),
                cst.SimpleStatementSuite((cst.Pass(), )),
                decorators=(cst.Decorator(cst.Name("foo")), ),
                lpar=cst.LeftParen(),
                rpar=cst.RightParen(),
            ),
            "code":
            "@foo\nclass Foo(): pass\n",
            "expected_position":
            CodeRange((2, 0), (2, 17)),
        },
        # Leading lines, per-decorator comments and lines_after_decorators all
        # precede the `class` line, which ends up as line 9 of the code.
        {
            "node":
            cst.ClassDef(
                leading_lines=(
                    cst.EmptyLine(),
                    cst.EmptyLine(comment=cst.Comment("# leading comment 1")),
                ),
                decorators=(
                    cst.Decorator(cst.Name("foo"), leading_lines=()),
                    cst.Decorator(
                        cst.Name("bar"),
                        leading_lines=(cst.EmptyLine(
                            comment=cst.Comment("# leading comment 2")), ),
                    ),
                    cst.Decorator(
                        cst.Name("baz"),
                        leading_lines=(cst.EmptyLine(
                            comment=cst.Comment("# leading comment 3")), ),
                    ),
                ),
                lines_after_decorators=(cst.EmptyLine(
                    comment=cst.Comment("# class comment")), ),
                name=cst.Name("Foo"),
                body=cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(),
                rpar=cst.RightParen(),
            ),
            "code":
            "\n# leading comment 1\n@foo\n# leading comment 2\n@bar\n# leading comment 3\n@baz\n# class comment\nclass Foo(): pass\n",
            "expected_position":
            CodeRange((9, 0), (9, 17)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one case, always using ``parse_statement`` as the parser.

        ``validate_node`` presumably checks both that ``node`` renders to
        ``code`` and that parsing ``code`` yields an equal node. Confirm
        against ``CSTNodeTest``.
        """
        self.validate_node(**kwargs, parser=parse_statement)
Ejemplo n.º 6
0
class AnnAssignTest(CSTNodeTest):
    """Data-driven tests for ``cst.AnnAssign`` (annotated assignment).

    Cases come in two flavours:

    * ``"parser": None``: a bare ``AnnAssign`` paired with code that has no
      trailing newline, plus an ``expected_position`` for the node.
    * ``"parser": parse_statement``: the ``AnnAssign`` wrapped in a
      ``SimpleStatementLine``, paired with code ending in a newline, with
      ``expected_position`` set to ``None``.
    """

    @data_provider((
        # Simple assignment creation case.
        {
            "node":
            cst.AnnAssign(cst.Name("foo"), cst.Annotation(cst.Name("str")),
                          cst.Integer("5")),
            "code":
            "foo: str = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 12)),
        },
        # Annotation creation without assignment
        {
            "node": cst.AnnAssign(cst.Name("foo"),
                                  cst.Annotation(cst.Name("str"))),
            "code": "foo: str",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 8)),
        },
        # Complex annotation creation
        {
            "node":
            cst.AnnAssign(
                cst.Name("foo"),
                cst.Annotation(
                    cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    )),
                cst.Integer("5"),
            ),
            "code":
            "foo: Optional[str] = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 22)),
        },
        # Simple assignment parser case.
        # The parsed form has no space before ":", hence the explicit empty
        # whitespace_before_indicator in the parser cases below.
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo: str = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Annotation without assignment
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                value=None,
            ), )),
            "code":
            "foo: str\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Complex annotation
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo: Optional[str] = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Whitespace test
        # Non-default whitespace on both sides of ":" and "=" must be
        # reproduced exactly (26 columns in total).
        {
            "node":
            cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace("  "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                value=cst.Integer("5"),
            ),
            "code":
            "foo :  Optional[str]  =  5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 26)),
        },
        # Same whitespace case, round-tripped through parse_statement.
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace("  "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo :  Optional[str]  =  5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one case; each case supplies its own ``parser``."""
        self.validate_node(**kwargs)

    # Building an AnnAssign with an ``equal`` but no ``value`` is invalid.
    # ``get_node`` is a lambda so construction happens inside assert_invalid
    # (presumably it asserts the raised error matches ``expected_re``).
    @data_provider(({
        "get_node": (lambda: cst.AnnAssign(
            target=cst.Name("foo"),
            annotation=cst.Annotation(cst.Name("str")),
            equal=cst.AssignEqual(),
            value=None,
        )),
        "expected_re":
        "Must have a value when specifying an AssignEqual.",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that constructing the case's node is rejected as invalid."""
        self.assert_invalid(**kwargs)
Ejemplo n.º 7
0
class LambdaParserTest(CSTNodeTest):
    @data_provider((
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(
                    cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(cst.Name("baz"), star=""),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(
                        cst.Name("bar"),
                        default=cst.SimpleString('"one"'),
                        equal=cst.AssignEqual(),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ), ),
                    default_params=(cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ), ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(cst.Name("baz"), star=""),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(
                            cst.Name("first"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params"), star="*")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params: 5",
        ),
        # Typed star_arg, include kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_arg and star_kwarg
        (
            cst.Lambda(
                cst.Parameters(
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and kwarg
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**"),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                params=cst.Parameters(),
                colon=cst.Colon(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda  : 5 )",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        self.validate_node(node, code, parse_expression, position)
Ejemplo n.º 8
0
class LambdaCreationTest(CSTNodeTest):
    """Tests for ``cst.Lambda`` nodes constructed directly in code.

    ``test_valid`` passes no parser to ``validate_node`` (unlike a parse-based
    test), and most cases supply only the required node arguments, so the
    expected strings show the default rendering, e.g. ``bar = "one"`` with
    spaces around ``=``. ``test_invalid`` pairs node factories with regexes
    for the validation error each one should produce.
    """

    @data_provider((
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar")),
                                       cst.Param(cst.Name("baz")))),
                cst.Integer("5"),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                )),
                cst.Integer("5"),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(cst.Name("bar")), ),
                    default_params=(cst.Param(cst.Name("baz"),
                                              default=cst.Integer("5")), ),
                ),
                cst.Integer("5"),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params (a bare "*" is rendered before them)
        (
            cst.Lambda(
                cst.Parameters(kwonly_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz")),
                )),
                cst.Integer("5"),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(cst.Name("first"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("second"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
            # Expected position covers the whole 84-character expression.
            CodeRange((1, 0), (1, 84)),
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("params"))),
                cst.Integer("5"),
            ),
            "lambda *params: 5",
        ),
        # star_arg followed by kwonly_params (no bare "*" is needed)
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_kwarg alone
        (
            cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))),
                cst.Integer("5"),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and star_kwarg together
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    star_kwarg=cst.Param(cst.Name("kwparams")),
                ),
                cst.Integer("5"),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                whitespace_after_lambda=cst.SimpleWhitespace("  "),
                params=cst.Parameters(),
                colon=cst.Colon(whitespace_after=cst.SimpleWhitespace(" ")),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda  : 5 )",
            # Columns 2-13 are "lambda  : 5": the position excludes the
            # parentheses and the padding just inside them.
            CodeRange((1, 2), (1, 13)),
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Check ``node`` against ``code`` and, if given, ``position``.

        Delegates to the inherited ``validate_node`` without a parser.
        """
        self.validate_node(node, code, expected_position=position)

    @data_provider((
        # Unbalanced parentheses.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        # Whenever any parameter group is non-empty, whitespace after the
        # ``lambda`` keyword must not be empty; one case per group.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"), default=cst.Integer("5")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        # Param-level checks: AssignEqual without a default, bad star string.
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    equal=cst.AssignEqual())),
                cst.Integer("5"),
            ),
            "Must have a default when specifying an AssignEqual.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    star="***")),
                cst.Integer("5"),
            ),
            r"Must specify either '', '\*' or '\*\*' for star.",
        ),
        # Defaults belong only in default_params.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("bar"), default=cst.SimpleString('"one"')), )),
                cst.Integer("5"),
            ),
            "Cannot have defaults for params",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(cst.Name("bar")), )),
                cst.Integer("5"),
            ),
            "Must have defaults for default_params",
        ),
        # A bare ParamStar must be followed by at least one kwonly param.
        (
            lambda: cst.Lambda(cst.Parameters(star_arg=cst.ParamStar()),
                               cst.Integer("5")),
            "Must have at least one kwonly param if ParamStar is used.",
        ),
        # Each parameter group requires a specific star prefix.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"), )
                               ),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("bar"),
                    default=cst.SimpleString('"one"'),
                    star="*",
                ), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("bar"),
                                                        star="*"), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("bar"), star="**")),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*'",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")
                               ),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*\*'",
        ),
        # Lambda params may not carry annotations; one case per group.
        # NOTE(review): these cases also set an empty whitespace_after_lambda,
        # so the expected message implies the annotation check runs before
        # the whitespace check in Lambda validation — confirm if reordered.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"),
                    default=cst.Integer("5"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"),
                                                  annotation=cst.Annotation(
                                                      cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"),
                                                    annotation=cst.Annotation(
                                                        cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Expect ``get_node()`` to be rejected with an error matching ``expected_re``.

        Construction is deferred behind a zero-argument factory, presumably so
        that the inherited ``assert_invalid`` can catch an error raised while
        building the node.
        """
        self.assert_invalid(get_node, expected_re)
Ejemplo n.º 9
0
    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
        """Rewrite a call to a known API method into keyword / request form.

        A call is left unchanged when its ``func`` is not an attribute access
        whose attribute name is a key of ``self.METHOD_TO_PARAMS``, or when it
        already passes ``request=``.

        Positional arguments take, in order, the parameter names listed in
        ``METHOD_TO_PARAMS``. Positional arguments beyond those are treated as
        control parameters and named from ``self.CTRL_PARAMS`` in order.
        Keyword arguments keep their own names.

        If ``self._use_keywords`` is true, every argument is emitted as a
        ``name=value`` keyword argument. Otherwise the API parameters are
        packed into a single ``request={...}`` dict, and the control
        parameters follow it as separate keyword arguments (they are not
        request fields).

        Args:
            original: The call node before transformation (used for lookup).
            updated: The call node with its children already transformed.

        Returns:
            The rewritten call, or ``updated`` unchanged when not applicable.
        """
        try:
            key = original.func.attr.value
            kword_params = self.METHOD_TO_PARAMS[key]
        except (AttributeError, KeyError):
            # Either not a method from the API or too convoluted to be sure.
            return updated

        def keyword_arg(name: str, value: cst.BaseExpression) -> cst.Arg:
            # ``name=value`` with no spaces around ``=``. A fresh AssignEqual
            # is built per call so no node instance is shared between args.
            return cst.Arg(
                value=value,
                keyword=cst.Name(value=name),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace(""),
                    whitespace_after=cst.SimpleWhitespace(""),
                ),
            )

        # If the existing code is valid, keyword args come after positional args.
        # Therefore, all positional args must map to the first parameters.
        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
        if any(k.keyword.value == "request" for k in kwargs):
            # We've already fixed this file, don't fix it again.
            return updated

        kwargs, ctrl_kwargs = partition(
            lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs)

        # Positional args past the API parameters are control params, named
        # in CTRL_PARAMS order.
        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]

        # (name, value) pairs for API parameters. Only positional args are
        # named by position; keyword args keep their names, so keywords given
        # out of declaration order are not re-labelled (or dropped by zip).
        api_items = [(name, a.value) for name, a in zip(kword_params, args)]
        api_items.extend((a.keyword.value, a.value) for a in kwargs)

        # Control params: explicit keywords first, then positional leftovers.
        ctrl_items = [(a.keyword.value, a.value) for a in ctrl_kwargs]
        ctrl_items.extend(
            (ctrl, a.value) for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))
        ctrl_new = [keyword_arg(name, value) for name, value in ctrl_items]

        if self._use_keywords:
            new_args = [keyword_arg(name, value) for name, value in api_items]
            return updated.with_changes(args=new_args + ctrl_new)

        # Only API parameters go in the request dict; control params are
        # method-level keyword arguments, not request fields.
        request_dict = cst.Dict([
            cst.DictElement(cst.SimpleString('"{}"'.format(name)), value)
            for name, value in api_items
        ])
        return updated.with_changes(
            args=[keyword_arg("request", request_dict)] + ctrl_new)
Ejemplo n.º 10
0
class CallTest(CSTNodeTest):
    @data_provider((
        # Simple call
        {
            "node": cst.Call(cst.Name("foo")),
            "code": "foo()",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(cst.Name("foo"),
                     whitespace_before_args=cst.SimpleWhitespace(" ")),
            "code":
            "foo( )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Call with attribute dereference
        {
            "node": cst.Call(cst.Attribute(cst.Name("foo"), cst.Name("bar"))),
            "code": "foo.bar()",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Positional arguments render test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(cst.Integer("1")), )),
            "code": "foo(1)",
            "parser": None,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(cst.Integer("1")),
                    cst.Arg(cst.Integer("2")),
                    cst.Arg(cst.Integer("3")),
                ),
            ),
            "code":
            "foo(1, 2, 3)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Positional arguments parse test
        {
            "node": cst.Call(cst.Name("foo"),
                             (cst.Arg(value=cst.Integer("1")), )),
            "code": "foo(1)",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    value=cst.Integer("1"),
                    whitespace_after_arg=cst.SimpleWhitespace(" "),
                ), ),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
            ),
            "code":
            "foo ( 1 )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    value=cst.Integer("1"),
                    comma=cst.Comma(
                        whitespace_after=cst.SimpleWhitespace(" ")),
                ), ),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
            ),
            "code":
            "foo ( 1, )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(value=cst.Integer("3")),
                ),
            ),
            "code":
            "foo(1, 2, 3)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Keyword arguments render test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(keyword=cst.Name("one"), value=cst.Integer("1")), ),
            ),
            "code":
            "foo(one = 1)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(keyword=cst.Name("one"), value=cst.Integer("1")),
                    cst.Arg(keyword=cst.Name("two"), value=cst.Integer("2")),
                    cst.Arg(keyword=cst.Name("three"), value=cst.Integer("3")),
                ),
            ),
            "code":
            "foo(one = 1, two = 2, three = 3)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Keyword arguments parser test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    keyword=cst.Name("one"),
                    equal=cst.AssignEqual(),
                    value=cst.Integer("1"),
                ), ),
            ),
            "code":
            "foo(one = 1)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        keyword=cst.Name("one"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("two"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("three"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("3"),
                    ),
                ),
            ),
            "code":
            "foo(one = 1, two = 2, three = 3)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Iterator expansion render test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="*", value=cst.Name("one")), )),
            "code":
            "foo(*one)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(star="*", value=cst.Name("one")),
                    cst.Arg(star="*", value=cst.Name("two")),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(*one, *two, *three)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Iterator expansion parser test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="*", value=cst.Name("one")), )),
            "code":
            "foo(*one)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        star="*",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(*one, *two, *three)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Dictionary expansion render test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="**", value=cst.Name("one")), )),
            "code":
            "foo(**one)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(star="**", value=cst.Name("one")),
                    cst.Arg(star="**", value=cst.Name("two")),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(**one, **two, **three)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Dictionary expansion parser test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="**", value=cst.Name("one")), )),
            "code":
            "foo(**one)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        star="**",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(**one, **two, **three)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Complicated mingling rules render test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(value=cst.Name("pos1")),
                    cst.Arg(star="*", value=cst.Name("list1")),
                    cst.Arg(value=cst.Name("pos2")),
                    cst.Arg(value=cst.Name("pos3")),
                    cst.Arg(star="*", value=cst.Name("list2")),
                    cst.Arg(value=cst.Name("pos4")),
                    cst.Arg(star="*", value=cst.Name("list3")),
                    cst.Arg(keyword=cst.Name("kw1"), value=cst.Integer("1")),
                    cst.Arg(star="*", value=cst.Name("list4")),
                    cst.Arg(keyword=cst.Name("kw2"), value=cst.Integer("2")),
                    cst.Arg(star="*", value=cst.Name("list5")),
                    cst.Arg(keyword=cst.Name("kw3"), value=cst.Integer("3")),
                    cst.Arg(star="**", value=cst.Name("dict1")),
                    cst.Arg(keyword=cst.Name("kw4"), value=cst.Integer("4")),
                    cst.Arg(star="**", value=cst.Name("dict2")),
                ),
            ),
            "code":
            "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Complicated mingling rules parser test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        value=cst.Name("pos1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw1"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw2"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list5"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw3"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("dict1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw4"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("dict2")),
                ),
            ),
            "code":
            "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Test whitespace
        {
            "node":
            cst.Call(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                func=cst.Name("foo"),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
                args=(
                    cst.Arg(
                        keyword=None,
                        value=cst.Name("pos1"),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                    ),
                    cst.Arg(
                        star="*",
                        whitespace_after_star=cst.SimpleWhitespace("  "),
                        keyword=None,
                        value=cst.Name("list1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw1"),
                        equal=cst.AssignEqual(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        keyword=None,
                        whitespace_after_star=cst.SimpleWhitespace(" "),
                        value=cst.Name("dict1"),
                        whitespace_after_arg=cst.SimpleWhitespace(" "),
                    ),
                ),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "code":
            "( foo ( pos1 ,  *  list1, kw1=1, ** dict1 ) )",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 2), (1, 43)),
        },
        # Test args
        {
            "node":
            cst.Arg(
                star="*",
                whitespace_after_star=cst.SimpleWhitespace("  "),
                keyword=None,
                value=cst.Name("list1"),
                comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
            ),
            "code":
            "*  list1, ",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 8)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Forward one well-formed data-provider case to ``validate_node``.

        Each case in the table above supplies ``node``, ``code``, ``parser``
        and ``expected_position``; they are passed through unchanged as
        keyword arguments.
        """
        check = self.validate_node
        check(**kwargs)

    @data_provider((
        # Basic expression parenthesizing tests.
        {
            "get_node":
            lambda: cst.Call(func=cst.Name("foo"), lpar=(cst.LeftParen(), )),
            "expected_re":
            "left paren without right paren",
        },
        {
            "get_node":
            lambda: cst.Call(func=cst.Name("foo"), rpar=(cst.RightParen(), )),
            "expected_re":
            "right paren without left paren",
        },
        # Test that we handle keyword stuff correctly.
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(equal=cst.AssignEqual(),
                              value=cst.SimpleString("'baz'")), ),
            ),
            "expected_re":
            "Must have a keyword when specifying an AssignEqual",
        },
        # Test that we separate *, ** and keyword args correctly: a star of
        # either kind cannot be combined with a keyword.
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(
                    star="*",
                    keyword=cst.Name("bar"),
                    value=cst.SimpleString("'baz'"),
                ), ),
            ),
            "expected_re":
            "Cannot specify a star and a keyword together",
        },
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(
                    star="**",
                    keyword=cst.Name("bar"),
                    value=cst.SimpleString("'baz'"),
                ), ),
            ),
            "expected_re":
            "Cannot specify a star and a keyword together",
        },
        # Test for expected star inputs only
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                # pyre-ignore: Ignore type on 'star' since we're testing behavior
                # when somebody isn't using a type checker.
                args=(cst.Arg(star="***", value=cst.SimpleString("'baz'")), ),
            ),
            "expected_re":
            r"Must specify either '', '\*' or '\*\*' for star",
        },
        # Test ordering exceptions
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(star="**", value=cst.Name("bar")),
                    cst.Arg(star="*", value=cst.Name("baz")),
                ),
            ),
            "expected_re":
            "Cannot have iterable argument unpacking after keyword argument unpacking",
        },
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(star="**", value=cst.Name("bar")),
                    cst.Arg(value=cst.Name("baz")),
                ),
            ),
            "expected_re":
            "Cannot have positional argument after keyword argument unpacking",
        },
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(keyword=cst.Name("arg"),
                            value=cst.SimpleString("'baz'")),
                    cst.Arg(value=cst.SimpleString("'bar'")),
                ),
            ),
            "expected_re":
            "Cannot have positional argument after keyword argument",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward one malformed-``Call`` case to ``assert_invalid``.

        Each case supplies ``get_node``, a zero-argument callable that builds
        the node, and ``expected_re``, a regular expression for the expected
        error message (note the escaped ``\\*`` in the star case). The node is
        built lazily inside a lambda. Presumably this is because construction
        itself raises and ``assert_invalid`` must be the one to trigger it;
        confirm against ``CSTNodeTest.assert_invalid``.
        """
        self.assert_invalid(**kwargs)
Ejemplo n.º 11
0
class AtomTest(CSTNodeTest):
    @data_provider(
        (
            # Simple identifier
            {
                "node": cst.Name("test"),
                "code": "test",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Parenthesized identifier
            {
                "node": cst.Name(
                    "test", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                ),
                "code": "(test)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 5)),
            },
            # Decimal integers
            {
                "node": cst.Integer("12345"),
                "code": "12345",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Integer("0000"),
                "code": "0000",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Integer("1_234_567"),
                "code": "1_234_567",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Integer("0_000"),
                "code": "0_000",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Binary integers
            {
                "node": cst.Integer("0b0000"),
                "code": "0b0000",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Integer("0B1011_0100"),
                "code": "0B1011_0100",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Octal integers
            {
                "node": cst.Integer("0o12345"),
                "code": "0o12345",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Integer("0O12_345"),
                "code": "0O12_345",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Hex numbers
            {
                "node": cst.Integer("0x123abc"),
                "code": "0x123abc",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Integer("0X12_3ABC"),
                "code": "0X12_3ABC",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Parenthesized integers
            {
                "node": cst.Integer(
                    "123", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                ),
                "code": "(123)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 4)),
            },
            # Non-exponent floats
            {
                "node": cst.Float("12345."),
                "code": "12345.",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("00.00"),
                "code": "00.00",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("12.21"),
                "code": "12.21",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float(".321"),
                "code": ".321",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("1_234_567."),
                "code": "1_234_567.",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("0.000_000"),
                "code": "0.000_000",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Exponent floats
            {
                "node": cst.Float("12345.e10"),
                "code": "12345.e10",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("00.00e10"),
                "code": "00.00e10",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("12.21e10"),
                "code": "12.21e10",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float(".321e10"),
                "code": ".321e10",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("1_234_567.e10"),
                "code": "1_234_567.e10",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("0.000_000e10"),
                "code": "0.000_000e10",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("1e+10"),
                "code": "1e+10",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Float("1e-10"),
                "code": "1e-10",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Parenthesized floats
            {
                "node": cst.Float(
                    "123.4", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                ),
                "code": "(123.4)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 6)),
            },
            # Imaginary numbers
            {
                "node": cst.Imaginary("12345j"),
                "code": "12345j",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Imaginary("1_234_567J"),
                "code": "1_234_567J",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Imaginary("12345.e10j"),
                "code": "12345.e10j",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.Imaginary(".321J"),
                "code": ".321J",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Parenthesized imaginary
            {
                "node": cst.Imaginary(
                    "123.4j", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                ),
                "code": "(123.4j)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 7)),
            },
            # Simple ellipsis
            {
                "node": cst.Ellipsis(),
                "code": "...",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Parenthesized ellipsis
            {
                "node": cst.Ellipsis(lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)),
                "code": "(...)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 4)),
            },
            # Simple strings
            {
                "node": cst.SimpleString('""'),
                "code": '""',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.SimpleString("''"),
                "code": "''",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.SimpleString('"test"'),
                "code": '"test"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.SimpleString('b"test"'),
                "code": 'b"test"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.SimpleString('r"test"'),
                "code": 'r"test"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.SimpleString('"""test"""'),
                "code": '"""test"""',
                "parser": parse_expression,
                "expected_position": None,
            },
            # Validate parens
            {
                "node": cst.SimpleString(
                    '"test"', lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                ),
                "code": '("test")',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.SimpleString(
                    'rb"test"', lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                ),
                "code": '(rb"test")',
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 9)),
            },
            # Test that _safe_to_use_with_word_operator allows no space around quotes
            {
                "node": cst.Comparison(
                    cst.SimpleString('"a"'),
                    [
                        cst.ComparisonTarget(
                            cst.In(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.SimpleString('"abc"'),
                        )
                    ],
                ),
                "code": '"a"in"abc"',
                "parser": parse_expression,
            },
            {
                "node": cst.Comparison(
                    cst.SimpleString('"a"'),
                    [
                        cst.ComparisonTarget(
                            cst.In(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.ConcatenatedString(
                                cst.SimpleString('"a"'), cst.SimpleString('"bc"')
                            ),
                        )
                    ],
                ),
                "code": '"a"in"a""bc"',
                "parser": parse_expression,
            },
            # Parentheses make it okay to omit spaces around a prefixed string
            {
                "node": cst.Comparison(
                    cst.SimpleString('b"a"'),
                    [
                        cst.ComparisonTarget(
                            cst.In(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.SimpleString(
                                'b"abc"',
                                lpar=[cst.LeftParen()],
                                rpar=[cst.RightParen()],
                            ),
                        )
                    ],
                ),
                "code": 'b"a"in(b"abc")',
                "parser": parse_expression,
            },
            {
                "node": cst.Comparison(
                    cst.SimpleString('b"a"'),
                    [
                        cst.ComparisonTarget(
                            cst.In(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.ConcatenatedString(
                                cst.SimpleString('b"a"'),
                                cst.SimpleString('b"bc"'),
                                lpar=[cst.LeftParen()],
                                rpar=[cst.RightParen()],
                            ),
                        )
                    ],
                ),
                "code": 'b"a"in(b"a"b"bc")',
                "parser": parse_expression,
            },
            # Empty formatted strings
            {
                "node": cst.FormattedString(start='f"', parts=(), end='"'),
                "code": 'f""',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(start="f'", parts=(), end="'"),
                "code": "f''",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(start='f"""', parts=(), end='"""'),
                "code": 'f""""""',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(start="f'''", parts=(), end="'''"),
                "code": "f''''''",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Non-empty formatted strings
            {
                "node": cst.FormattedString(parts=(cst.FormattedStringText("foo"),)),
                "code": 'f"foo"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(cst.FormattedStringExpression(cst.Name("foo")),)
                ),
                "code": 'f"{foo}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringText("foo "),
                        cst.FormattedStringExpression(cst.Name("bar")),
                        cst.FormattedStringText(" baz"),
                    )
                ),
                "code": 'f"foo {bar} baz"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringText("foo "),
                        cst.FormattedStringExpression(cst.Call(cst.Name("bar"))),
                        cst.FormattedStringText(" baz"),
                    )
                ),
                "code": 'f"foo {bar()} baz"',
                "parser": parse_expression,
                "expected_position": None,
            },
            # Formatted strings with conversions and format specifiers
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(cst.Name("foo"), conversion="s"),
                    )
                ),
                "code": 'f"{foo!s}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(cst.Name("foo"), format_spec=()),
                    )
                ),
                "code": 'f"{foo:}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("today"),
                            format_spec=(cst.FormattedStringText("%B %d, %Y"),),
                        ),
                    )
                ),
                "code": 'f"{today:%B %d, %Y}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            format_spec=(
                                cst.FormattedStringExpression(cst.Name("bar")),
                            ),
                        ),
                    )
                ),
                "code": 'f"{foo:{bar}}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            format_spec=(
                                cst.FormattedStringExpression(cst.Name("bar")),
                                cst.FormattedStringText("."),
                                cst.FormattedStringExpression(cst.Name("baz")),
                            ),
                        ),
                    )
                ),
                "code": 'f"{foo:{bar}.{baz}}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            conversion="s",
                            format_spec=(
                                cst.FormattedStringExpression(cst.Name("bar")),
                            ),
                        ),
                    )
                ),
                "code": 'f"{foo!s:{bar}}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            # Test equality expression added in 3.8.
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            equal=cst.AssignEqual(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                        ),
                    ),
                ),
                "code": 'f"{foo=}"',
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            equal=cst.AssignEqual(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            conversion="s",
                        ),
                    ),
                ),
                "code": 'f"{foo=!s}"',
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Name("foo"),
                            equal=cst.AssignEqual(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            conversion="s",
                            format_spec=(
                                cst.FormattedStringExpression(cst.Name("bar")),
                            ),
                        ),
                    ),
                ),
                "code": 'f"{foo=!s:{bar}}"',
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            # Test that equality support doesn't break existing support
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Comparison(
                                left=cst.Name(
                                    value="a",
                                ),
                                comparisons=[
                                    cst.ComparisonTarget(
                                        operator=cst.Equal(),
                                        comparator=cst.Name(
                                            value="b",
                                        ),
                                    ),
                                ],
                            ),
                        ),
                    ),
                ),
                "code": 'f"{a == b}"',
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Comparison(
                                left=cst.Name(
                                    value="a",
                                ),
                                comparisons=[
                                    cst.ComparisonTarget(
                                        operator=cst.NotEqual(),
                                        comparator=cst.Name(
                                            value="b",
                                        ),
                                    ),
                                ],
                            ),
                        ),
                    ),
                ),
                "code": 'f"{a != b}"',
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.NamedExpr(
                                target=cst.Name(
                                    value="a",
                                ),
                                value=cst.Integer(
                                    value="5",
                                ),
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                    ),
                ),
                "code": 'f"{(a := 5)}"',
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringExpression(
                            cst.Yield(
                                value=cst.Integer("1"),
                                whitespace_after_yield=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                ),
                "code": 'f"{yield 1}"',
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.FormattedString(
                    parts=(
                        cst.FormattedStringText("\\N{X Y}"),
                        cst.FormattedStringExpression(
                            cst.Name(value="Z"),
                        ),
                    ),
                ),
                "code": 'f"\\N{X Y}{Z}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            # Validate parens
            {
                "node": cst.FormattedString(
                    start='f"',
                    parts=(),
                    end='"',
                    lpar=(cst.LeftParen(),),
                    rpar=(cst.RightParen(),),
                ),
                "code": '(f"")',
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 4)),
            },
            # Generator expression (doesn't make sense, but legal syntax)
            {
                "node": cst.FormattedString(
                    start='f"',
                    parts=[
                        cst.FormattedStringExpression(
                            expression=cst.GeneratorExp(
                                elt=cst.Name(
                                    value="x",
                                ),
                                for_in=cst.CompFor(
                                    target=cst.Name(
                                        value="x",
                                    ),
                                    iter=cst.Name(
                                        value="y",
                                    ),
                                ),
                                lpar=[],
                                rpar=[],
                            ),
                        ),
                    ],
                    end='"',
                ),
                "code": 'f"{x for x in y}"',
                "parser": parse_expression,
                "expected_position": None,
            },
            # Concatenated strings
            {
                "node": cst.ConcatenatedString(
                    cst.SimpleString('"ab"'), cst.SimpleString('"c"')
                ),
                "code": '"ab""c"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.ConcatenatedString(
                    cst.SimpleString('"ab"'),
                    cst.ConcatenatedString(
                        cst.SimpleString('"c"'), cst.SimpleString('"d"')
                    ),
                ),
                "code": '"ab""c""d"',
                "parser": parse_expression,
                "expected_position": None,
            },
            # mixed SimpleString and FormattedString
            {
                "node": cst.ConcatenatedString(
                    cst.FormattedString([cst.FormattedStringText("ab")]),
                    cst.SimpleString('"c"'),
                ),
                "code": 'f"ab""c"',
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.ConcatenatedString(
                    cst.SimpleString('"ab"'),
                    cst.FormattedString([cst.FormattedStringText("c")]),
                ),
                "code": '"ab"f"c"',
                "parser": parse_expression,
                "expected_position": None,
            },
            # Concatenated parenthesized strings
            {
                "node": cst.ConcatenatedString(
                    lpar=(cst.LeftParen(),),
                    left=cst.SimpleString('"ab"'),
                    right=cst.SimpleString('"c"'),
                    rpar=(cst.RightParen(),),
                ),
                "code": '("ab""c")',
                "parser": parse_expression,
                "expected_position": None,
            },
            # Validate spacing
            {
                "node": cst.ConcatenatedString(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    left=cst.SimpleString('"ab"'),
                    whitespace_between=cst.SimpleWhitespace(" "),
                    right=cst.SimpleString('"c"'),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "code": '( "ab" "c" )',
                "parser": parse_expression,
                "expected_position": CodeRange((1, 2), (1, 10)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Check each valid atom case with ``validate_node``.

        Each case supplies ``node``, ``code``, ``parser`` and
        ``expected_position``. These are forwarded unchanged to
        ``validate_node``.
        """
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    @data_provider(
        (
            dict(
                node=cst.FormattedStringExpression(
                    cst.Name("today"),
                    format_spec=(cst.FormattedStringText("%B %d, %Y"),),
                ),
                code="{today:%B %d, %Y}",
                parser=None,
                expected_position=CodeRange((1, 0), (1, 17)),
            ),
        )
    )
    def test_valid_no_parse(self, **kwargs: Any) -> None:
        """Validate nodes whose rendered code is not standalone source.

        A bare replacement field such as ``{today:%B %d, %Y}`` only makes sense
        inside an enclosing f-string, so these cases pass ``parser=None``.
        Presumably ``validate_node`` then skips the re-parse step.
        """
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Expression wrapping parenthesis rules
            {
                "get_node": (lambda: cst.Name("foo", lpar=(cst.LeftParen(),))),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": lambda: cst.Name("foo", rpar=(cst.RightParen(),)),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": lambda: cst.Ellipsis(lpar=(cst.LeftParen(),)),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": lambda: cst.Ellipsis(rpar=(cst.RightParen(),)),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": lambda: cst.Integer("5", lpar=(cst.LeftParen(),)),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": lambda: cst.Integer("5", rpar=(cst.RightParen(),)),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": lambda: cst.Float("5.5", lpar=(cst.LeftParen(),)),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": lambda: cst.Float("5.5", rpar=(cst.RightParen(),)),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": (lambda: cst.Imaginary("5j", lpar=(cst.LeftParen(),))),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (lambda: cst.Imaginary("5j", rpar=(cst.RightParen(),))),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": (
                    lambda: cst.SimpleString("'foo'", lpar=(cst.LeftParen(),))
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.SimpleString("'foo'", rpar=(cst.RightParen(),))
                ),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": (
                    lambda: cst.FormattedString(parts=(), lpar=(cst.LeftParen(),))
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.FormattedString(parts=(), rpar=(cst.RightParen(),))
                ),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": (
                    lambda: cst.ConcatenatedString(
                        cst.SimpleString("'foo'"),
                        cst.SimpleString("'foo'"),
                        lpar=(cst.LeftParen(),),
                    )
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.ConcatenatedString(
                        cst.SimpleString("'foo'"),
                        cst.SimpleString("'foo'"),
                        rpar=(cst.RightParen(),),
                    )
                ),
                "expected_re": "right paren without left paren",
            },
            # Node-specific rules
            {
                "get_node": (lambda: cst.Name("")),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": (lambda: cst.Name(r"\/")),
                "expected_re": "not a valid identifier",
            },
            {
                "get_node": (lambda: cst.Integer("")),
                "expected_re": "not a valid integer",
            },
            {
                # Leading zeros are not allowed on a non-zero decimal integer.
                "get_node": (lambda: cst.Integer("012345")),
                "expected_re": "not a valid integer",
            },
            {
                "get_node": (lambda: cst.Integer("_12345")),
                "expected_re": "not a valid integer",
            },
            {
                "get_node": (lambda: cst.Integer("0b2")),
                "expected_re": "not a valid integer",
            },
            {
                "get_node": (lambda: cst.Integer("0o8")),
                "expected_re": "not a valid integer",
            },
            {
                "get_node": (lambda: cst.Integer("0xg")),
                "expected_re": "not a valid integer",
            },
            {
                "get_node": (lambda: cst.Integer("123.45")),
                "expected_re": "not a valid integer",
            },
            {
                "get_node": (lambda: cst.Integer("12345j")),
                "expected_re": "not a valid integer",
            },
            {
                "get_node": (lambda: cst.Float("12.3.45")),
                "expected_re": "not a valid float",
            },
            {"get_node": (lambda: cst.Float("12")), "expected_re": "not a valid float"},
            {
                "get_node": (lambda: cst.Float("12.3j")),
                "expected_re": "not a valid float",
            },
            {
                "get_node": (lambda: cst.Imaginary("_12345j")),
                "expected_re": "not a valid imaginary",
            },
            {
                "get_node": (lambda: cst.Imaginary("0b0j")),
                "expected_re": "not a valid imaginary",
            },
            {
                "get_node": (lambda: cst.Imaginary("0o0j")),
                "expected_re": "not a valid imaginary",
            },
            {
                "get_node": (lambda: cst.Imaginary("0x0j")),
                "expected_re": "not a valid imaginary",
            },
            {
                "get_node": (lambda: cst.SimpleString('wee""')),
                "expected_re": "Invalid string prefix",
            },
            {
                "get_node": (lambda: cst.SimpleString("'")),
                "expected_re": "must have enclosing quotes",
            },
            {
                "get_node": (lambda: cst.SimpleString('"')),
                "expected_re": "must have enclosing quotes",
            },
            {
                "get_node": (lambda: cst.SimpleString("\"'")),
                "expected_re": "must have matching enclosing quotes",
            },
            {
                "get_node": (lambda: cst.SimpleString("")),
                "expected_re": "must have enclosing quotes",
            },
            {
                "get_node": (lambda: cst.SimpleString("'bla")),
                "expected_re": "must have matching enclosing quotes",
            },
            {
                "get_node": (lambda: cst.SimpleString("f''")),
                "expected_re": "Invalid string prefix",
            },
            {
                "get_node": (lambda: cst.SimpleString("'''bla''")),
                "expected_re": "must have matching enclosing quotes",
            },
            {
                "get_node": (lambda: cst.SimpleString("'''bla\"\"\"")),
                "expected_re": "must have matching enclosing quotes",
            },
            {
                "get_node": (lambda: cst.FormattedString(start="'", parts=(), end="'")),
                "expected_re": "Invalid f-string prefix",
            },
            {
                "get_node": (
                    lambda: cst.FormattedString(start="f'", parts=(), end='"')
                ),
                "expected_re": "must have matching enclosing quotes",
            },
            {
                "get_node": (
                    lambda: cst.ConcatenatedString(
                        cst.SimpleString(
                            '"ab"', lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                        ),
                        cst.SimpleString('"c"'),
                    )
                ),
                "expected_re": "Cannot concatenate parenthesized",
            },
            {
                "get_node": (
                    lambda: cst.ConcatenatedString(
                        cst.SimpleString('"ab"'),
                        cst.SimpleString(
                            '"c"', lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                        ),
                    )
                ),
                "expected_re": "Cannot concatenate parenthesized",
            },
            {
                "get_node": (
                    lambda: cst.ConcatenatedString(
                        cst.SimpleString('"ab"'), cst.SimpleString('b"c"')
                    )
                ),
                "expected_re": "Cannot concatenate string and bytes",
            },
            # This isn't valid code: `"a" inb"abc"`
            {
                "get_node": (
                    lambda: cst.Comparison(
                        cst.SimpleString('"a"'),
                        [
                            cst.ComparisonTarget(
                                cst.In(whitespace_after=cst.SimpleWhitespace("")),
                                cst.SimpleString('b"abc"'),
                            )
                        ],
                    )
                ),
                "expected_re": "Must have at least one space around comparison operator.",
            },
            # Also not valid: `"a" inb"a"b"bc"` (no space after `in`, so the
            # operator would run into the bytes prefix of the concatenation)
            {
                "get_node": (
                    lambda: cst.Comparison(
                        cst.SimpleString('"a"'),
                        [
                            cst.ComparisonTarget(
                                cst.In(whitespace_after=cst.SimpleWhitespace("")),
                                cst.ConcatenatedString(
                                    cst.SimpleString('b"a"'), cst.SimpleString('b"bc"')
                                ),
                            )
                        ],
                    )
                ),
                "expected_re": "Must have at least one space around comparison operator.",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that each malformed atom node is rejected.

        Each case supplies ``get_node`` and ``expected_re``:

        * ``get_node`` is a zero-argument factory that builds the node.
        * ``expected_re`` is a pattern for the expected error message.

        Both are forwarded to ``assert_invalid``. The lambda defers node
        construction, presumably so ``assert_invalid`` can observe the error
        raised while building the node.
        """
        self.assert_invalid(**kwargs)

    @data_provider(
        tuple(
            {
                "code": "u'x'",
                "parser": parse_expression_as(python_version=version),
                "expect_success": expect_success,
            }
            for version, expect_success in (("3.3", True), ("3.1", False))
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        """Check that the ``u`` string prefix is gated on the target grammar version.

        ``u'x'`` is expected to parse with the 3.3 grammar and to fail with the
        3.1 grammar. Each case (``code``, ``parser``, ``expect_success``) is
        forwarded unchanged to ``assert_parses``.
        """
        self.assert_parses(**kwargs)