示例#1
0
    def _check_formatted_string(
        self,
        _original_node: libcst.FormattedString,
        updated_node: libcst.FormattedString,
    ) -> libcst.BaseExpression:
        """Replace a placeholder-free f-string with an equivalent plain string.

        Only f-strings consisting of exactly one ``FormattedStringText`` part
        are candidates. Anything else (an empty f-string, or one containing
        ``{...}`` expressions) is returned unchanged, so no expression part can
        be silently dropped by rewriting just the first part.

        Args:
            _original_node: The node before child transforms (unused).
            updated_node: The node after child transforms; this is the node
                that is inspected and possibly replaced.

        Returns:
            A ``libcst.SimpleString`` with the ``f``/``F`` prefix removed when
            the old and new literals evaluate to the same value; otherwise
            ``updated_node`` unchanged (a warning is emitted on a mismatch).
        """
        parts = updated_node.parts
        if len(parts) != 1 or not isinstance(parts[0], libcst.FormattedStringText):
            # Rewriting only parts[0] would drop any following {expression}
            # parts, so leave anything but a single text part alone.
            return updated_node

        old_string_inner = parts[0].value
        if "{{" in old_string_inner or "}}" in old_string_inner:
            # Doubled braces are the only escapes that mean something different
            # inside an f-string than inside a plain string; skip those.
            return updated_node

        old_string_literal = updated_node.start + old_string_inner + updated_node.end
        new_string_literal = (
            updated_node.start.replace("f", "").replace("F", "") +
            old_string_inner + updated_node.end)

        # Both literals are rebuilt from this node's own text and contain no
        # placeholders. ast.literal_eval cannot be used here because it does
        # not accept f-string (JoinedStr) nodes.
        old_string_evaled = eval(old_string_literal)  # noqa
        new_string_evaled = eval(new_string_literal)  # noqa
        if old_string_evaled != new_string_evaled:
            warn_message = (
                f"Attempted to codemod |{old_string_literal}| to " +
                f"|{new_string_literal}| but don't eval to the same! First is |{old_string_evaled}| and "
                + f"second is |{new_string_evaled}|")
            self.warn(warn_message)
            return updated_node

        return libcst.SimpleString(new_string_literal)
示例#2
0
    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
        """Fold a known API method call's arguments into a ``request`` dict.

        Positional arguments are mapped, in order, onto the method's entry in
        ``self.METHOD_TO_PARAMS``. Keyword arguments keep their own names.
        Control parameters (``self.CTRL_PARAMS``) stay keyword arguments,
        whether they were passed by keyword or as surplus positionals.

        Returns ``updated`` unchanged when the callee's attribute name is not
        a known method (or the callee has no ``.attr``), or when a
        ``request=`` keyword is already present.
        """
        try:
            key = original.func.attr.value
            kword_params = self.METHOD_TO_PARAMS[key]
        except (AttributeError, KeyError):
            # Either not a method from the API or too convoluted to be sure.
            return updated

        # If the existing code is valid, keyword args come after positional args.
        # Therefore, all positional args must map to the first parameters.
        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
        if any(k.keyword.value == "request" for k in kwargs):
            # We've already fixed this file, don't fix it again.
            return updated

        kwargs, ctrl_kwargs = partition(
            lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs)

        # Positionals beyond the method's own parameters are control params.
        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
        ctrl_kwargs.extend(
            cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
            for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))

        # Positional args take their names from the parameter order. Keyword
        # args already carry their name and must not be re-mapped by position;
        # e.g. ``f(b=1)`` with params ("a", "b") must yield 'b': 1, not 'a': 1.
        named_values = [
            (name, arg.value) for name, arg in zip(kword_params, args)
        ] + [(arg.keyword.value, arg.value) for arg in kwargs]

        request_arg = cst.Arg(
            value=cst.Dict([
                cst.DictElement(cst.SimpleString("'{}'".format(name)), value)
                for name, value in named_values
            ]),
            keyword=cst.Name("request"))

        return updated.with_changes(args=[request_arg] + ctrl_kwargs)
示例#3
0
    def leave_FunctionDef(self, original_node: cst.FunctionDef,
                          updated_node: cst.FunctionDef) -> cst.FunctionDef:
        """Move types documented in the docstring into real annotations.

        The docstring is parsed with ``gather_types``. A documented return
        type becomes the return annotation, documented parameter types become
        parameter annotations (via ``update_parameters``), and the docstring
        is replaced by the text ``gather_types`` hands back.

        Functions without a non-empty text docstring are returned unchanged.
        """
        docstring = None
        docstring_node = get_docstring_node(updated_node.body)
        if docstring_node and isinstance(docstring_node.value,
                                         (cst.SimpleString, cst.ConcatenatedString)):
            docstring = docstring_node.value.evaluated_value
        # evaluated_value may be None, or bytes for a b-prefixed literal;
        # only non-empty text can be handed to gather_types.
        if not isinstance(docstring, str) or not docstring:
            return updated_node
        new_docstring, types = gather_types(docstring)
        if types.get(RETURN):
            # Parse the return type the same way as the parameter types below,
            # so dotted or subscripted types (e.g. List[int]) become real
            # expressions instead of one malformed Name node.
            updated_node = updated_node.with_changes(returns=cst.Annotation(
                cst.parse_expression(types.pop(RETURN))))

        if types:

            def get_annotation(p: cst.Param) -> Optional[cst.Annotation]:
                pname = p.name.value
                if types.get(pname):
                    return cst.Annotation(cst.parse_expression(types[pname]))
                return None

            updated_node = updated_node.with_changes(params=update_parameters(
                updated_node.params, get_annotation, False))

        new_docstring_node = cst.SimpleString('"""%s"""' % new_docstring)
        return updated_node.deep_replace(docstring_node,
                                         cst.Expr(new_docstring_node))
示例#4
0
 def leave_Expr(
     self, original_node: cst.Expr, updated_node: cst.Expr
 ) -> Union[cst.BaseSmallStatement, cst.RemovalSentinel]:
     """Replace every bare string-literal statement with a fixed placeholder.

     The match is made against ``original_node``. Any expression statement
     whose value is a ``SimpleString`` gets the triple-quoted literal
     ``[docstring]`` as its new value; all other statements pass through.
     """
     is_string_statement = match.matches(
         original_node, match.Expr(value=match.SimpleString()))
     if not is_string_statement:
         return updated_node
     placeholder = cst.SimpleString(value='"""[docstring]"""')
     return updated_node.with_changes(value=placeholder)
 def test_simple_statement(self) -> None:
     """An ``assert`` template fills in both its test and message slots."""
     expected = 'assert True, "Somehow True is no longer True..."\n'
     message = cst.SimpleString('"Somehow True is no longer True..."')
     statement = parse_template_statement(
         "assert {test}, {msg}\n", test=cst.Name("True"), msg=message,
     )
     self.assertEqual(self.code(statement), expected)
示例#6
0
    def visit_FormattedString(self, node: cst.FormattedString) -> None:
        """Report a placeholder-free f-string together with a plain-string fix.

        Only f-strings made of a single text part are considered. Doubled
        braces are collapsed, since ``{{``/``}}`` are escapes only inside an
        f-string, and the ``f``/``F`` prefix is dropped from the opening quote.
        """
        single_text_part = m.FormattedString(parts=(m.FormattedStringText(),))
        if m.matches(node, single_text_part):
            text = cst.ensure_type(node.parts[0], cst.FormattedStringText).value
            # str.replace is a no-op when the pattern is absent, so no guard.
            text = text.replace("{{", "{").replace("}}", "}")
            opening = node.start.replace("f", "").replace("F", "")
            self.report(node, replacement=cst.SimpleString(opening + text + node.end))
示例#7
0
 def annotation(self) -> cst.BaseExpression:
     """Return the annotation expression describing this value's type.

     With a non-empty ``self.options`` this is ``Literal[...]`` over the
     ``repr`` of each option; otherwise it is the bare name ``str``.
     """
     if not self.options:
         return cst.Name("str")
     elements = [
         cst.SubscriptElement(cst.Index(cst.SimpleString(repr(choice))))
         for choice in self.options
     ]
     return cst.Subscript(cst.Name("Literal"), elements)
示例#8
0
    def to_dictionary_of_samples(random_variables, *_):
        """Build a ``cst.Dict`` literal mapping variable names to name nodes.

        With a single scope the result is flat: ``{'name': name, ...}``.
        With several scopes it is nested one level:
        ``{'scope': {'name': name, ...}, ...}``. Keys are written with
        ``repr`` so names or scopes containing quotes or backslashes still
        produce valid string literals.

        Args:
            random_variables: Objects exposing ``scope`` and ``name``.
            *_: Extra positional arguments, ignored.

        Returns:
            A ``cst.Dict`` node (empty if ``random_variables`` is empty).
        """
        scoped = defaultdict(dict)
        for var in random_variables:
            # A later variable with the same scope and name replaces the earlier.
            scoped[var.scope][var.name] = var

        def string_key(value):
            # repr(str(...)) yields a correctly quoted and escaped literal;
            # for ordinary identifiers it is exactly "'value'".
            return cst.SimpleString(repr(str(value)))

        def variables_dict(variables):
            return cst.Dict(
                [
                    cst.DictElement(string_key(var_name), cst.Name(var.name))
                    for var_name, var in variables.items()
                ]
            )

        # if there is only one scope (99% of models) we return a flat dictionary
        if len(scoped) == 1:
            (variables,) = scoped.values()
            return variables_dict(variables)

        # Otherwise we return a nested dictionary where the first level is
        # the scope, and then the variables.
        return cst.Dict(
            [
                cst.DictElement(string_key(scope), variables_dict(variables))
                for scope, variables in scoped.items()
            ]
        )
示例#9
0
 def leave_Name(
         self, original_node: cst.Name,
         updated_node: cst.Name) -> Union[cst.Name, cst.SimpleString]:
     """Rewrite ``NoneType`` as ``None`` and quote names listed in ``CST_DIR``.

     A name found in ``CST_DIR`` that does not end in ``Sentinel`` becomes a
     string forward reference (its ``repr``). All other names pass through.
     """
     name = updated_node.value
     if name == "NoneType":
         # typing special-cases NoneType; spell it the ordinary way instead.
         return updated_node.with_changes(value="None")
     needs_forward_ref = name in CST_DIR and not name.endswith("Sentinel")
     return cst.SimpleString(repr(name)) if needs_forward_ref else updated_node
示例#10
0
    def leave_FormattedString(
        self, original_node: cst.FormattedString, updated_node: cst.FormattedString
    ) -> cst.BaseExpression:
        """Turn an f-string without placeholders into a plain string.

        The node's own quotes and any non-``f`` prefix (e.g. ``r``) are kept,
        and doubled braces are un-escaped, since ``{{``/``}}`` are escapes only
        inside f-strings. Any other f-string is returned as ``updated_node``,
        so changes already made to its children are preserved.
        """
        if len(updated_node.parts) == 1 and isinstance(
            updated_node.parts[0], cst.FormattedStringText
        ):
            text = updated_node.parts[0].value.replace("{{", "{").replace("}}", "}")
            # Reuse the node's own opening/closing quotes (minus the f/F
            # prefix) rather than hard-coding '"'. This keeps raw prefixes and
            # triple quotes, keeps text containing '"' valid, and still
            # satisfies SimpleString's validation because the quotes are
            # explicit.
            prefix = updated_node.start.replace("f", "").replace("F", "")
            return cst.SimpleString(value=prefix + text + updated_node.end)

        return updated_node
示例#11
0
 def leave_Expr(self, original_node, updated_node):
     """Shrink a triple-double-quoted string statement to its first line.

     Only expression statements whose value is a ``SimpleString`` starting
     with three double quotes are touched. The replacement keeps the first
     line that is not blank, or nothing if every line is blank.
     """
     if not m.matches(updated_node, m.Expr(m.SimpleString())):
         return updated_node
     literal = updated_node.value.value
     if not literal.startswith('"""'):
         return updated_node
     body_lines = literal[3:-3].splitlines()
     first_line = next((line for line in body_lines if line.strip() != ''), '')
     return updated_node.with_changes(
         value=cst.SimpleString(f'"""{first_line}"""'))
示例#12
0
文件: matcher.py 项目: sk-/craftier
def _flatten_concatenated_string(
    node: libcst.ConcatenatedString,
) -> Union[libcst.SimpleString, libcst.FormattedString]:
    """Merge an implicit string concatenation into a single string node.

    ``'a' 'b'`` becomes one ``SimpleString``. As soon as any piece is an
    f-string, the result is a single ``f'...'`` ``FormattedString`` whose
    adjacent literal text runs are merged and whose ``{...}`` expression
    parts are kept as-is.

    Args:
        node: The concatenation to flatten; it may be nested on the right.

    Returns:
        A ``SimpleString`` if every piece is a simple string, otherwise a
        ``FormattedString`` delimited by single quotes.
    """
    # ConcatenatedString is right-nested: left + (left + (... + right)).
    parts: List[Union[libcst.SimpleString, libcst.FormattedString]] = []
    rest: Union[libcst.ConcatenatedString, libcst.SimpleString,
                libcst.FormattedString] = node
    while isinstance(rest, libcst.ConcatenatedString):
        parts.append(rest.left)
        rest = rest.right
    parts.append(rest)

    if all(isinstance(n, libcst.SimpleString) for n in parts):
        # There's no idiom other than casting to tell mypy the list only has one
        # of the union elements. See https://github.com/python/mypy/issues/3497
        string_parts = cast(List[libcst.SimpleString], parts)
        content = "".join(n.evaluated_value for n in string_parts)
        return libcst.SimpleString(value=repr(content))

    formatted_parts: List[libcst.BaseFormattedStringContent] = []

    def append_text(value: str) -> None:
        # Merge into a trailing text run when there is one, else start a new one.
        if formatted_parts and isinstance(formatted_parts[-1],
                                          libcst.FormattedStringText):
            formatted_parts[-1] = libcst.FormattedStringText(
                value=formatted_parts[-1].value + value)
        else:
            formatted_parts.append(libcst.FormattedStringText(value=value))

    for part in parts:
        if isinstance(part, libcst.SimpleString):
            value = _repr_single(part.evaluated_value)[1:-1]
            # Inside an f-string a lone brace opens a replacement field, so
            # braces coming from a plain string must be doubled to stay
            # literal (assumes _repr_single emits braces unchanged).
            append_text(value.replace("{", "{{").replace("}", "}}"))
            continue
        # Re-evaluate each text run as a plain (non-f) literal with this
        # part's own prefix and quotes so escape sequences are decoded.
        # Doubled braces are kept verbatim by that round trip, which is what
        # the output f-string needs.
        prefix = part.start.replace("f", "").replace("F", "")
        for nested_part in part.parts:
            if isinstance(nested_part, libcst.FormattedStringText):
                append_text(
                    _repr_single(
                        ast.literal_eval(
                            f"{prefix}{nested_part.value}{part.end}"))[1:-1])
            else:
                # NOTE(review): expressions are copied verbatim; before Python
                # 3.12 an expression containing a single quote is not valid
                # inside the f'...' delimiters used below — confirm callers
                # never hit this.
                formatted_parts.append(nested_part)
    return libcst.FormattedString(parts=formatted_parts, start="f'", end="'")
示例#13
0
 def leave_Subscript(
     self,
     original_node: libcst.Subscript,
     updated_node: Union[libcst.Subscript, libcst.SimpleString],
 ) -> Union[libcst.Subscript, libcst.SimpleString]:
     """Rewrite ``PathLike[...]`` as the string ``'os.PathLike[...]'``.

     The subscript is re-rooted on ``os.PathLike``, rendered to source code,
     and wrapped in a string literal. Subscripts of anything other than the
     bare name ``PathLike`` are returned unchanged.
     """
     if libcst.matchers.matches(original_node.value,
                                libcst.matchers.Name("PathLike")):
         name_node = libcst.Attribute(
             value=libcst.Name(
                 value="os",
                 lpar=[],
                 rpar=[],
             ),
             attr=libcst.Name(value="PathLike"),
         )
         node_as_string = libcst.parse_module("").code_for_node(
             updated_node.with_changes(value=name_node))
         # repr() chooses quotes that do not clash with the rendered code
         # (e.g. PathLike['x']) and escapes when it must; for code without
         # quotes it is identical to wrapping in single quotes.
         updated_node = libcst.SimpleString(repr(node_as_string))
     return updated_node
示例#14
0
 def leave_SimpleString(self, original_node: cst.SimpleString,
                        updated_node: cst.SimpleString) -> cst.SimpleString:
     """Re-quote strings delimited by ``self.quote`` with another quote style.

     Candidate delimiters are tried in the order single, double,
     triple-single, triple-double. The first one that differs from the
     current quote and does not occur in the raw text is used. An exception
     is raised if no candidate fits, or if the re-quoted literal would
     evaluate to a different value. Strings using other quotes pass through.
     """
     if original_node.quote != self.quote:
         return original_node
     chosen = next(
         (
             quo for quo in ["'", '"', "'''", '"""']
             if quo != original_node.quote and quo not in original_node.raw_value
         ),
         None,
     )
     if chosen is None:
         raise Exception(
             f"Cannot find a good quote for escaping the SimpleString: {original_node.value}"
         )
     escaped_string = cst.SimpleString(
         original_node.prefix + chosen + original_node.raw_value + chosen)
     if escaped_string.evaluated_value != original_node.evaluated_value:
         raise Exception(
             f"Failed to escape string:\n  original:{original_node.value}\n  escaped:{escaped_string.value}"
         )
     return escaped_string
示例#15
0
 def _get_assert_replacement(self, node: cst.Assert):
     message = node.msg or str(cst.Module(body=[node]).code)
     return cst.If(
         test=cst.UnaryOperation(
             operator=cst.Not(),
             expression=node.test,  # Todo: parenthesize?
         ),
         body=cst.IndentedBlock(body=[
             cst.SimpleStatementLine(body=[
                 cst.Raise(exc=cst.Call(
                     func=cst.Name(value="AssertionError", ),
                     args=[
                         cst.Arg(value=cst.SimpleString(value=repr(message),
                                                        ), ),
                     ],
                 ), ),
             ]),
         ], ),
     )
示例#16
0
def _convert_annotation(
    raw: str,
    quote_annotations: bool,
) -> cst.Annotation:
    """
    Convert a raw annotation - which is a string coming from a type
    comment - into a suitable libcst Annotation node.

    If `quote_annotations`, we'll always quote annotations unless they are builtin
    types. The reason for this is to make the codemod safer to apply
    on legacy code where type comments may well include invalid types
    that would crash at runtime. Without `quote_annotations`, annotations
    that fail to parse are quoted as well.

    Quoting uses double quotes unless `raw` itself contains a double quote
    (e.g. ``Literal["a"]``). In that case single quotes are used, and if both
    kinds occur the inner double quotes are escaped, so the result is always
    a well-formed string literal.
    """
    if _is_builtin(raw):
        return cst.Annotation(annotation=cst.Name(value=raw))
    if not quote_annotations:
        try:
            return cst.Annotation(annotation=cst.parse_expression(raw))
        except cst.ParserSyntaxError:
            # Unparseable annotation: fall through and quote it instead.
            pass
    if '"' not in raw:
        quoted = f'"{raw}"'
    elif "'" not in raw:
        quoted = f"'{raw}'"
    else:
        escaped = raw.replace("\\", "\\\\").replace('"', '\\"')
        quoted = f'"{escaped}"'
    return cst.Annotation(annotation=cst.SimpleString(quoted))
示例#17
0
class AssertParsingTest(CSTNodeTest):
    """Parsing/rendering cases for ``assert`` statements.

    Every case uses ``_assert_parser`` and supplies no expected position.
    """

    @data_provider((
        # Simple assert
        dict(
            node=cst.Assert(cst.Name("True")),
            code="assert True",
            parser=_assert_parser,
            expected_position=None,
        ),
        # Assert with message
        dict(
            node=cst.Assert(
                cst.Name("True"),
                cst.SimpleString('"Value should be true"'),
                comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
            ),
            code='assert True, "Value should be true"',
            parser=_assert_parser,
            expected_position=None,
        ),
        # Whitespace oddities test: parenthesized test, no space after keyword
        dict(
            node=cst.Assert(
                cst.Name(
                    "True",
                    lpar=(cst.LeftParen(),),
                    rpar=(cst.RightParen(),),
                ),
                whitespace_after_assert=cst.SimpleWhitespace(""),
            ),
            code="assert(True)",
            parser=_assert_parser,
            expected_position=None,
        ),
        # Whitespace rendering test
        dict(
            node=cst.Assert(
                test=cst.Name("True"),
                msg=cst.SimpleString('"Value should be true"'),
                whitespace_after_assert=cst.SimpleWhitespace("  "),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
            ),
            code='assert  True  ,  "Value should be true"',
            parser=_assert_parser,
            expected_position=None,
        ),
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Hand each case's fields straight to ``validate_node``."""
        self.validate_node(**kwargs)
示例#18
0
 def _string_leave(
     self,
     original_node: Union[cst.SimpleString, cst.Name],
     updated_node: Union[cst.SimpleString, cst.Name],
 ) -> Union[cst.SimpleString, cst.Pass]:
     """Replace the visited node, whatever it was, with an empty string literal."""
     empty_literal = '""'
     return cst.SimpleString(empty_literal)
示例#19
0
def make_string(s):
    """Return a double-quoted ``cst.SimpleString`` whose value is ``s``.

    Backslashes, double quotes and line breaks are escaped so the resulting
    literal is valid source and evaluates back to exactly ``s``.
    """
    # Note: in a single-quoted literal '\"' is just '"', so escaping needs an
    # explicit backslash. Backslashes are escaped too, so existing ones are
    # not reinterpreted as escape sequences.
    escapes = str.maketrans({
        "\\": "\\\\",
        '"': '\\"',
        "\n": "\\n",
        "\r": "\\r",
    })
    escaped = s.translate(escapes)
    return cst.SimpleString(f'"{escaped}"')
示例#20
0
class LambdaParserTest(CSTNodeTest):
    """Round-trip cases for ``lambda`` expressions.

    Each case pairs a hand-built ``cst.Lambda`` with the exact code string it
    is expected to correspond to. ``test_valid`` passes both to
    ``validate_node`` together with ``parse_expression``.
    """
    @data_provider((
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(
                    cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(cst.Name("baz"), star=""),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(
                        cst.Name("bar"),
                        default=cst.SimpleString('"one"'),
                        equal=cst.AssignEqual(),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ), ),
                    default_params=(cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ), ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params (bare "*" marker via ParamStar)
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(cst.Name("baz"), star=""),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(
                            cst.Name("first"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params"), star="*")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params: 5",
        ),
        # Named star_arg followed by kwonly_params (no annotations involved)
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_kwarg on its own (no star_arg)
        (
            cst.Lambda(
                cst.Parameters(
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg together with star_kwarg
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**"),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace: spaces inside the parentheses and around the colon
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                params=cst.Parameters(),
                colon=cst.Colon(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda  : 5 )",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Validate one (node, code) case with ``parse_expression``.

        ``position`` is forwarded as the fourth argument to ``validate_node``;
        none of the cases above supply it.
        """
        self.validate_node(node, code, parse_expression, position)
示例#21
0
class CallTest(CSTNodeTest):
    """Round-trip and validation tests for ``cst.Call`` (plus one ``cst.Arg``).

    Each ``test_valid`` case is a dict of keyword arguments forwarded to
    ``validate_node``:

    * ``node``: the CST node under test.
    * ``code``: the source text the node is expected to render to.
    * ``parser``: the parse function used to also check the parse direction.
      ``None`` marks a render-only case (see the "render test" vs
      "parser test" comments below).
    * ``expected_position``: the expected ``CodeRange`` of the node within
      ``code``. It is presumably unchecked when ``None``; confirm in
      ``validate_node``.

    Each ``test_invalid`` case pairs a ``get_node`` factory with an
    ``expected_re`` error-message pattern.
    """

    @data_provider((
        # Simple call
        {
            "node": cst.Call(cst.Name("foo")),
            "code": "foo()",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(cst.Name("foo"),
                     whitespace_before_args=cst.SimpleWhitespace(" ")),
            "code":
            "foo( )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Call with attribute dereference
        {
            "node": cst.Call(cst.Attribute(cst.Name("foo"), cst.Name("bar"))),
            "code": "foo.bar()",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Positional arguments render test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(cst.Integer("1")), )),
            "code": "foo(1)",
            "parser": None,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(cst.Integer("1")),
                    cst.Arg(cst.Integer("2")),
                    cst.Arg(cst.Integer("3")),
                ),
            ),
            "code":
            "foo(1, 2, 3)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Positional arguments parse test
        # Parse cases spell out the exact Comma/whitespace nodes expected from
        # the parser; the render cases above leave them defaulted and still
        # render as "1, 2, 3".
        {
            "node": cst.Call(cst.Name("foo"),
                             (cst.Arg(value=cst.Integer("1")), )),
            "code": "foo(1)",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    value=cst.Integer("1"),
                    whitespace_after_arg=cst.SimpleWhitespace(" "),
                ), ),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
            ),
            "code":
            "foo ( 1 )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    value=cst.Integer("1"),
                    comma=cst.Comma(
                        whitespace_after=cst.SimpleWhitespace(" ")),
                ), ),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
            ),
            "code":
            "foo ( 1, )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(value=cst.Integer("3")),
                ),
            ),
            "code":
            "foo(1, 2, 3)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Keyword arguments render test
        # A defaulted AssignEqual renders with a space on each side ("one = 1").
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(keyword=cst.Name("one"), value=cst.Integer("1")), ),
            ),
            "code":
            "foo(one = 1)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(keyword=cst.Name("one"), value=cst.Integer("1")),
                    cst.Arg(keyword=cst.Name("two"), value=cst.Integer("2")),
                    cst.Arg(keyword=cst.Name("three"), value=cst.Integer("3")),
                ),
            ),
            "code":
            "foo(one = 1, two = 2, three = 3)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Keyword arguments parser test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    keyword=cst.Name("one"),
                    equal=cst.AssignEqual(),
                    value=cst.Integer("1"),
                ), ),
            ),
            "code":
            "foo(one = 1)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        keyword=cst.Name("one"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("two"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("three"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("3"),
                    ),
                ),
            ),
            "code":
            "foo(one = 1, two = 2, three = 3)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Iterator expansion render test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="*", value=cst.Name("one")), )),
            "code":
            "foo(*one)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(star="*", value=cst.Name("one")),
                    cst.Arg(star="*", value=cst.Name("two")),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(*one, *two, *three)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Iterator expansion parser test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="*", value=cst.Name("one")), )),
            "code":
            "foo(*one)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        star="*",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(*one, *two, *three)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Dictionary expansion render test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="**", value=cst.Name("one")), )),
            "code":
            "foo(**one)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(star="**", value=cst.Name("one")),
                    cst.Arg(star="**", value=cst.Name("two")),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(**one, **two, **three)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Dictionary expansion parser test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="**", value=cst.Name("one")), )),
            "code":
            "foo(**one)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        star="**",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(**one, **two, **three)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Complicated mingling rules render test
        # Exercises every ordering that is legal in a call: *iterables may
        # follow keyword args, and keyword args may follow **dicts.
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(value=cst.Name("pos1")),
                    cst.Arg(star="*", value=cst.Name("list1")),
                    cst.Arg(value=cst.Name("pos2")),
                    cst.Arg(value=cst.Name("pos3")),
                    cst.Arg(star="*", value=cst.Name("list2")),
                    cst.Arg(value=cst.Name("pos4")),
                    cst.Arg(star="*", value=cst.Name("list3")),
                    cst.Arg(keyword=cst.Name("kw1"), value=cst.Integer("1")),
                    cst.Arg(star="*", value=cst.Name("list4")),
                    cst.Arg(keyword=cst.Name("kw2"), value=cst.Integer("2")),
                    cst.Arg(star="*", value=cst.Name("list5")),
                    cst.Arg(keyword=cst.Name("kw3"), value=cst.Integer("3")),
                    cst.Arg(star="**", value=cst.Name("dict1")),
                    cst.Arg(keyword=cst.Name("kw4"), value=cst.Integer("4")),
                    cst.Arg(star="**", value=cst.Name("dict2")),
                ),
            ),
            "code":
            "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Complicated mingling rules parser test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        value=cst.Name("pos1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw1"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw2"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list5"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw3"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("dict1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw4"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("dict2")),
                ),
            ),
            "code":
            "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Test whitespace
        {
            "node":
            cst.Call(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                func=cst.Name("foo"),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
                args=(
                    cst.Arg(
                        keyword=None,
                        value=cst.Name("pos1"),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                    ),
                    cst.Arg(
                        star="*",
                        whitespace_after_star=cst.SimpleWhitespace("  "),
                        keyword=None,
                        value=cst.Name("list1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw1"),
                        equal=cst.AssignEqual(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        keyword=None,
                        whitespace_after_star=cst.SimpleWhitespace(" "),
                        value=cst.Name("dict1"),
                        whitespace_after_arg=cst.SimpleWhitespace(" "),
                    ),
                ),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "code":
            "( foo ( pos1 ,  *  list1, kw1=1, ** dict1 ) )",
            "parser":
            parse_expression,
            # Columns 2..43 cover "foo ( ... )" only: the outer lpar/rpar
            # and the single spaces just inside them are excluded.
            "expected_position":
            CodeRange((1, 2), (1, 43)),
        },
        # Test args
        # A bare Arg rendered on its own; the range ends at column 8, right
        # after "list1", so the trailing comma is not part of the position.
        {
            "node":
            cst.Arg(
                star="*",
                whitespace_after_star=cst.SimpleWhitespace("  "),
                keyword=None,
                value=cst.Name("list1"),
                comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
            ),
            "code":
            "*  list1, ",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 8)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Forward one case dict (node/code/parser/expected_position) to
        ``validate_node``."""
        self.validate_node(**kwargs)

    # Nodes are wrapped in lambdas, presumably so that construction-time
    # validation errors are raised inside the test rather than while this
    # data table is being built at class-definition time.
    @data_provider((
        # Basic expression parenthesizing tests.
        {
            "get_node":
            lambda: cst.Call(func=cst.Name("foo"), lpar=(cst.LeftParen(), )),
            "expected_re":
            "left paren without right paren",
        },
        {
            "get_node":
            lambda: cst.Call(func=cst.Name("foo"), rpar=(cst.RightParen(), )),
            "expected_re":
            "right paren without left paren",
        },
        # Test that we handle keyword stuff correctly.
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(equal=cst.AssignEqual(),
                              value=cst.SimpleString("'baz'")), ),
            ),
            "expected_re":
            "Must have a keyword when specifying an AssignEqual",
        },
        # Test that we separate *, ** and keyword args correctly
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(
                    star="*",
                    keyword=cst.Name("bar"),
                    value=cst.SimpleString("'baz'"),
                ), ),
            ),
            "expected_re":
            "Cannot specify a star and a keyword together",
        },
        # Test for expected star inputs only
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                # pyre-ignore: Ignore type on 'star' since we're testing behavior
                # when somebody isn't using a type checker.
                args=(cst.Arg(star="***", value=cst.SimpleString("'baz'")), ),
            ),
            "expected_re":
            r"Must specify either '', '\*' or '\*\*' for star",
        },
        # Test ordering exceptions
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(star="**", value=cst.Name("bar")),
                    cst.Arg(star="*", value=cst.Name("baz")),
                ),
            ),
            "expected_re":
            "Cannot have iterable argument unpacking after keyword argument unpacking",
        },
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(star="**", value=cst.Name("bar")),
                    cst.Arg(value=cst.Name("baz")),
                ),
            ),
            "expected_re":
            "Cannot have positional argument after keyword argument unpacking",
        },
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(keyword=cst.Name("arg"),
                            value=cst.SimpleString("'baz'")),
                    cst.Arg(value=cst.SimpleString("'bar'")),
                ),
            ),
            "expected_re":
            "Cannot have positional argument after keyword argument",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward one case dict (get_node/expected_re) to ``assert_invalid``.

        Presumably ``assert_invalid`` calls ``get_node`` and expects an error
        whose message matches ``expected_re``; confirm in ``CSTNodeTest``.
        """
        self.assert_invalid(**kwargs)
示例#22
0
class SimpleStatementTest(CSTNodeTest):
    """Render/parse tests for ``SimpleStatementLine`` and ``SimpleStatementSuite``.

    Each case is a dict of keyword arguments forwarded to ``validate_node``:
    ``node`` (the tree), ``code`` (its expected source, including the trailing
    newline), ``parser`` (``None`` means render-only; see the "No test for
    parsing" case) and an optional ``expected_position`` ``CodeRange``. Some
    cases omit ``parser`` entirely; presumably ``validate_node`` then treats
    them as render-only. Confirm in ``CSTNodeTest``.
    """

    @data_provider((
        # a single-element SimpleStatementLine
        # pyre-fixme[6]: Incompatible parameter type
        {
            "node": cst.SimpleStatementLine((cst.Pass(), )),
            "code": "pass\n",
            "parser": parse_statement,
        },
        # a multi-element SimpleStatementLine
        {
            "node":
            cst.SimpleStatementLine(
                (cst.Pass(semicolon=cst.Semicolon()), cst.Continue())),
            "code":
            "pass;continue\n",
            "parser":
            parse_statement,
        },
        # a multi-element SimpleStatementLine with whitespace
        {
            "node":
            cst.SimpleStatementLine((
                cst.Pass(semicolon=cst.Semicolon(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                )),
                cst.Continue(),
            )),
            "code":
            "pass ;  continue\n",
            "parser":
            parse_statement,
        },
        # A more complicated SimpleStatementLine
        # The position (1, 0)-(1, 19) spans the whole statement text, not the
        # trailing newline.
        {
            "node":
            cst.SimpleStatementLine((
                cst.Pass(semicolon=cst.Semicolon()),
                cst.Continue(semicolon=cst.Semicolon()),
                cst.Break(),
            )),
            "code":
            "pass;continue;break\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange.create((1, 0), (1, 19)),
        },
        # a multi-element SimpleStatementLine, inferred semicolons
        # Semicolons left at their default render as "; ".
        {
            "node":
            cst.SimpleStatementLine((cst.Pass(), cst.Continue(), cst.Break())),
            "code":
            "pass; continue; break\n",
            "parser":
            None,  # No test for parsing, since we are using sentinels.
        },
        # some expression statements
        {
            "node": cst.SimpleStatementLine((cst.Expr(cst.Name("None")), )),
            "code": "None\n",
            "parser": parse_statement,
        },
        {
            "node": cst.SimpleStatementLine((cst.Expr(cst.Name("True")), )),
            "code": "True\n",
            "parser": parse_statement,
        },
        {
            "node": cst.SimpleStatementLine((cst.Expr(cst.Name("False")), )),
            "code": "False\n",
            "parser": parse_statement,
        },
        {
            "node": cst.SimpleStatementLine((cst.Expr(cst.Ellipsis()), )),
            "code": "...\n",
            "parser": parse_statement,
        },
        # Test some numbers
        {
            "node": cst.SimpleStatementLine((cst.Expr(cst.Integer("5")), )),
            "code": "5\n",
            "parser": parse_statement,
        },
        {
            "node": cst.SimpleStatementLine((cst.Expr(cst.Float("5.5")), )),
            "code": "5.5\n",
            "parser": parse_statement,
        },
        {
            "node": cst.SimpleStatementLine((cst.Expr(cst.Imaginary("5j")), )),
            "code": "5j\n",
            "parser": parse_statement,
        },
        # Test some numbers with parens
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.Integer("5",
                            lpar=(cst.LeftParen(), ),
                            rpar=(cst.RightParen(), ))), )),
            "code":
            "(5)\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange.create((1, 0), (1, 3)),
        },
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.Float("5.5",
                          lpar=(cst.LeftParen(), ),
                          rpar=(cst.RightParen(), ))), )),
            "code":
            "(5.5)\n",
            "parser":
            parse_statement,
        },
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.Imaginary("5j",
                              lpar=(cst.LeftParen(), ),
                              rpar=(cst.RightParen(), ))), )),
            "code":
            "(5j)\n",
            "parser":
            parse_statement,
        },
        # Test some strings
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(cst.SimpleString('"abc"')), )),
            "code":
            '"abc"\n',
            "parser":
            parse_statement,
        },
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.ConcatenatedString(cst.SimpleString('"abc"'),
                                       cst.SimpleString('"def"'))), )),
            "code":
            '"abc""def"\n',
            "parser":
            parse_statement,
        },
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.ConcatenatedString(
                    left=cst.SimpleString('"abc"'),
                    whitespace_between=cst.SimpleWhitespace(" "),
                    right=cst.ConcatenatedString(
                        left=cst.SimpleString('"def"'),
                        whitespace_between=cst.SimpleWhitespace(" "),
                        right=cst.SimpleString('"ghi"'),
                    ),
                )), )),
            "code":
            '"abc" "def" "ghi"\n',
            "parser":
            parse_statement,
            "expected_position":
            CodeRange.create((1, 0), (1, 17)),
        },
        # Test parenthesis rules
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.Ellipsis(lpar=(cst.LeftParen(), ),
                             rpar=(cst.RightParen(), ))), )),
            "code":
            "(...)\n",
            "parser":
            parse_statement,
        },
        # Test parenthesis with whitespace ownership
        # Whitespace inside parens belongs to the paren nodes: whitespace_after
        # on LeftParen, whitespace_before on RightParen.
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.Ellipsis(
                    lpar=(cst.LeftParen(
                        whitespace_after=cst.SimpleWhitespace(" ")), ),
                    rpar=(cst.RightParen(
                        whitespace_before=cst.SimpleWhitespace(" ")), ),
                )), )),
            "code":
            "( ... )\n",
            "parser":
            parse_statement,
        },
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.Ellipsis(
                    lpar=(
                        cst.LeftParen(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                        cst.LeftParen(
                            whitespace_after=cst.SimpleWhitespace("  ")),
                        cst.LeftParen(
                            whitespace_after=cst.SimpleWhitespace("   ")),
                    ),
                    rpar=(
                        cst.RightParen(
                            whitespace_before=cst.SimpleWhitespace("   ")),
                        cst.RightParen(
                            whitespace_before=cst.SimpleWhitespace("  ")),
                        cst.RightParen(
                            whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                )), )),
            "code":
            "( (  (   ...   )  ) )\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange.create((1, 0), (1, 21)),
        },
        # Test parenthesis rules with expressions
        # The parenthesized whitespace spans lines, so the statement's position
        # ends on line 4, just after the closing paren.
        {
            "node":
            cst.SimpleStatementLine((cst.Expr(
                cst.Ellipsis(
                    lpar=(cst.LeftParen(
                        whitespace_after=cst.ParenthesizedWhitespace(
                            first_line=cst.TrailingWhitespace(),
                            empty_lines=(cst.EmptyLine(
                                comment=cst.Comment("# Wow, a comment!")), ),
                            indent=True,
                            last_line=cst.SimpleWhitespace("    "),
                        )), ),
                    rpar=(cst.RightParen(
                        whitespace_before=cst.ParenthesizedWhitespace(
                            first_line=cst.TrailingWhitespace(),
                            empty_lines=(),
                            indent=True,
                            last_line=cst.SimpleWhitespace(""),
                        )), ),
                )), )),
            "code":
            "(\n# Wow, a comment!\n    ...\n)\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange.create((1, 0), (4, 1)),
        },
        # test trailing whitespace
        # The trailing comment is not part of the statement's position.
        {
            "node":
            cst.SimpleStatementLine(
                (cst.Pass(), ),
                trailing_whitespace=cst.TrailingWhitespace(
                    whitespace=cst.SimpleWhitespace("  "),
                    comment=cst.Comment("# trailing comment"),
                ),
            ),
            "code":
            "pass  # trailing comment\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange.create((1, 0), (1, 4)),
        },
        # test leading comment
        # Leading lines are not part of the statement's position either.
        {
            "node":
            cst.SimpleStatementLine(
                (cst.Pass(), ),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# comment")), ),
            ),
            "code":
            "# comment\npass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange.create((2, 0), (2, 4)),
        },
        # test indentation
        # DummyIndentedBlock is presumably a test helper that renders its child
        # under the given indent ("    "); no "parser" key is given here.
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.SimpleStatementLine(
                    (cst.Pass(), ),
                    leading_lines=(cst.EmptyLine(
                        comment=cst.Comment("# comment")), ),
                ),
            ),
            "code":
            "    # comment\n    pass\n",
            "expected_position":
            CodeRange.create((2, 4), (2, 8)),
        },
        # test suite variant
        # With default leading_whitespace the suite renders a single leading
        # space (" pass"); the next case sets it to "" explicitly.
        {
            "node": cst.SimpleStatementSuite((cst.Pass(), )),
            "code": " pass\n",
            "expected_position": CodeRange.create((1, 1), (1, 5)),
        },
        {
            "node":
            cst.SimpleStatementSuite(
                (cst.Pass(), ), leading_whitespace=cst.SimpleWhitespace("")),
            "code":
            "pass\n",
            "expected_position":
            CodeRange.create((1, 0), (1, 4)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Forward one case dict (node/code/parser/expected_position) to
        ``validate_node``."""
        self.validate_node(**kwargs)
示例#23
0
    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
        """Rewrite a call to a known API method into the request-based form.

        The call is returned unchanged in two cases:

        * the callee is not an attribute access whose name is a key of
          ``self.METHOD_TO_PARAMS``;
        * a ``request`` keyword is already present, meaning the file was
          fixed before.

        Otherwise the arguments are split into request parameters (the names
        in ``self.METHOD_TO_PARAMS[method]``) and control parameters (the
        names in ``self.CTRL_PARAMS``):

        * Positional arguments are matched, in order, to the request
          parameter names.
        * Surplus positional arguments are matched, in order, to the control
          parameter names. Any beyond that are dropped.
        * Keyword arguments keep the name the caller wrote.

        When ``self._use_keywords`` is true, the result is a call made only of
        ``name=value`` keyword arguments. Otherwise the request parameters
        are collected into a single ``request={...}`` dict literal, and the
        control parameters are passed next to it as keyword arguments.

        Args:
            original: The call before its children were transformed. It is
                only used to read the method name.
            updated: The call with transformed children. This is the node that
                is returned, possibly with new ``args``.

        Returns:
            ``updated`` itself, or a copy of it with rewritten ``args``.
        """

        def keyword_arg(name: str, value: cst.BaseExpression) -> cst.Arg:
            # Render as ``name=value`` with no whitespace around ``=``.
            return cst.Arg(
                value=value,
                keyword=cst.Name(value=name),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace(""),
                    whitespace_after=cst.SimpleWhitespace(""),
                ),
            )

        try:
            key = original.func.attr.value
            kword_params = self.METHOD_TO_PARAMS[key]
        except (AttributeError, KeyError):
            # Either not a method from the API or too convoluted to be sure.
            return updated

        # If the existing code is valid, keyword args come after positional args.
        # Therefore, all positional args must map to the first parameters.
        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
        if any(k.keyword.value == "request" for k in kwargs):
            # We've already fixed this file, don't fix it again.
            return updated

        kwargs, ctrl_kwargs = partition(
            lambda a: a.keyword.value not in self.CTRL_PARAMS, kwargs)

        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]

        # Request parameters as (name, value) pairs. Only positional arguments
        # take their name from the signature. Keyword arguments keep their own
        # name, so ``f(x, c=1)`` is not silently rewritten as ``b=1``, and a
        # keyword argument is never dropped for running past ``kword_params``.
        request_items = [
            (name, arg.value) for name, arg in zip(kword_params, args)
        ]
        request_items.extend((arg.keyword.value, arg.value) for arg in kwargs)

        # Control parameters: those already passed by keyword come first, then
        # any surplus positional arguments, named after CTRL_PARAMS in order.
        ctrl_items = [(arg.keyword.value, arg.value) for arg in ctrl_kwargs]
        ctrl_items.extend(
            (ctrl, arg.value) for arg, ctrl in zip(ctrl_args, self.CTRL_PARAMS))
        new_ctrl_kwargs = [keyword_arg(name, value) for name, value in ctrl_items]

        if self._use_keywords:
            new_kwargs = [
                keyword_arg(name, value) for name, value in request_items
            ]
            return updated.with_changes(args=new_kwargs + new_ctrl_kwargs)

        # DictElement takes the value expression directly. cst.Element is the
        # node for list/tuple/set elements, not for dict values.
        request_arg = keyword_arg(
            "request",
            cst.Dict([
                cst.DictElement(cst.SimpleString('"{}"'.format(name)), value)
                for name, value in request_items
            ]),
        )
        # Control parameters stay separate keyword arguments next to
        # ``request``, as in the keyword branch above. They are presumably
        # call options rather than request fields (CTRL_PARAMS is defined
        # outside this view; confirm).
        return updated.with_changes(args=[request_arg] + new_ctrl_kwargs)
示例#24
0
 def leave_Float(self, original_node: cst.Float, updated_node: cst.Float):
     """Swap every float literal for a fixed placeholder string literal.

     Neither the original nor the updated node is inspected. The result is
     always a new ``SimpleString`` whose source text is ``"[number]"`` with
     the double quotes included.
     """
     placeholder_source = "\"[number]\""
     return cst.SimpleString(value=placeholder_source)
示例#25
0
class AtomTest(CSTNodeTest):
    """Construction and validation tests for LibCST atom nodes.

    Covers names, integer/float/imaginary literals, ellipsis, simple,
    formatted and concatenated strings, plus a few ``Comparison`` cases that
    exercise the spacing rules around string operands of ``in``.
    """

    @data_provider((
        # Simple identifier
        {
            "node": cst.Name("test"),
            "code": "test",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Parenthesized identifier
        {
            "node":
            cst.Name("test",
                     lpar=(cst.LeftParen(), ),
                     rpar=(cst.RightParen(), )),
            "code":
            "(test)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 5)),
        },
        # Decimal integers
        {
            "node": cst.Integer("12345"),
            "code": "12345",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Integer("0000"),
            "code": "0000",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Integer("1_234_567"),
            "code": "1_234_567",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Integer("0_000"),
            "code": "0_000",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Binary integers
        {
            "node": cst.Integer("0b0000"),
            "code": "0b0000",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Integer("0B1011_0100"),
            "code": "0B1011_0100",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Octal integers
        {
            "node": cst.Integer("0o12345"),
            "code": "0o12345",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Integer("0O12_345"),
            "code": "0O12_345",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Hex numbers
        {
            "node": cst.Integer("0x123abc"),
            "code": "0x123abc",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Integer("0X12_3ABC"),
            "code": "0X12_3ABC",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Parenthesized integers
        {
            "node":
            cst.Integer("123",
                        lpar=(cst.LeftParen(), ),
                        rpar=(cst.RightParen(), )),
            "code":
            "(123)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 4)),
        },
        # Non-exponent floats
        {
            "node": cst.Float("12345."),
            "code": "12345.",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("00.00"),
            "code": "00.00",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("12.21"),
            "code": "12.21",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float(".321"),
            "code": ".321",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("1_234_567."),
            "code": "1_234_567.",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("0.000_000"),
            "code": "0.000_000",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Exponent floats
        {
            "node": cst.Float("12345.e10"),
            "code": "12345.e10",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("00.00e10"),
            "code": "00.00e10",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("12.21e10"),
            "code": "12.21e10",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float(".321e10"),
            "code": ".321e10",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("1_234_567.e10"),
            "code": "1_234_567.e10",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("0.000_000e10"),
            "code": "0.000_000e10",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("1e+10"),
            "code": "1e+10",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Float("1e-10"),
            "code": "1e-10",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Parenthesized floats
        {
            "node":
            cst.Float("123.4",
                      lpar=(cst.LeftParen(), ),
                      rpar=(cst.RightParen(), )),
            "code":
            "(123.4)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 6)),
        },
        # Imaginary numbers
        {
            "node": cst.Imaginary("12345j"),
            "code": "12345j",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Imaginary("1_234_567J"),
            "code": "1_234_567J",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Imaginary("12345.e10j"),
            "code": "12345.e10j",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Imaginary(".321J"),
            "code": ".321J",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Parenthesized imaginary
        {
            "node":
            cst.Imaginary("123.4j",
                          lpar=(cst.LeftParen(), ),
                          rpar=(cst.RightParen(), )),
            "code":
            "(123.4j)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 7)),
        },
        # Simple ellipses
        {
            "node": cst.Ellipsis(),
            "code": "...",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Parenthesized ellipses
        {
            "node":
            cst.Ellipsis(lpar=(cst.LeftParen(), ), rpar=(cst.RightParen(), )),
            "code":
            "(...)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 4)),
        },
        # Simple strings
        {
            "node": cst.SimpleString('""'),
            "code": '""',
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.SimpleString("''"),
            "code": "''",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.SimpleString('"test"'),
            "code": '"test"',
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.SimpleString('b"test"'),
            "code": 'b"test"',
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.SimpleString('r"test"'),
            "code": 'r"test"',
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.SimpleString('"""test"""'),
            "code": '"""test"""',
            "parser": parse_expression,
            "expected_position": None,
        },
        # Validate parens
        {
            "node":
            cst.SimpleString('"test"',
                             lpar=(cst.LeftParen(), ),
                             rpar=(cst.RightParen(), )),
            "code":
            '("test")',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.SimpleString('rb"test"',
                             lpar=(cst.LeftParen(), ),
                             rpar=(cst.RightParen(), )),
            "code":
            '(rb"test")',
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 9)),
        },
        # Test that _safe_to_use_with_word_operator allows no space around quotes
        {
            "node":
            cst.Comparison(
                cst.SimpleString('"a"'),
                [
                    cst.ComparisonTarget(
                        cst.In(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.SimpleString('"abc"'),
                    )
                ],
            ),
            "code":
            '"a"in"abc"',
            "parser":
            parse_expression,
        },
        {
            "node":
            cst.Comparison(
                cst.SimpleString('"a"'),
                [
                    cst.ComparisonTarget(
                        cst.In(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.ConcatenatedString(cst.SimpleString('"a"'),
                                               cst.SimpleString('"bc"')),
                    )
                ],
            ),
            "code":
            '"a"in"a""bc"',
            "parser":
            parse_expression,
        },
        # Parenthesis make no spaces around a prefix okay
        {
            "node":
            cst.Comparison(
                cst.SimpleString('b"a"'),
                [
                    cst.ComparisonTarget(
                        cst.In(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.SimpleString(
                            'b"abc"',
                            lpar=[cst.LeftParen()],
                            rpar=[cst.RightParen()],
                        ),
                    )
                ],
            ),
            "code":
            'b"a"in(b"abc")',
            "parser":
            parse_expression,
        },
        {
            "node":
            cst.Comparison(
                cst.SimpleString('b"a"'),
                [
                    cst.ComparisonTarget(
                        cst.In(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.ConcatenatedString(
                            cst.SimpleString('b"a"'),
                            cst.SimpleString('b"bc"'),
                            lpar=[cst.LeftParen()],
                            rpar=[cst.RightParen()],
                        ),
                    )
                ],
            ),
            "code":
            'b"a"in(b"a"b"bc")',
            "parser":
            parse_expression,
        },
        # Empty formatted strings
        {
            "node": cst.FormattedString(start='f"', parts=(), end='"'),
            "code": 'f""',
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.FormattedString(start="f'", parts=(), end="'"),
            "code": "f''",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.FormattedString(start='f"""', parts=(), end='"""'),
            "code": 'f""""""',
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.FormattedString(start="f'''", parts=(), end="'''"),
            "code": "f''''''",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Non-empty formatted strings
        {
            "node":
            cst.FormattedString(parts=(cst.FormattedStringText("foo"), )),
            "code": 'f"foo"',
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node":
            cst.FormattedString(
                parts=(cst.FormattedStringExpression(cst.Name("foo")), )),
            "code":
            'f"{foo}"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.FormattedString(parts=(
                cst.FormattedStringText("foo "),
                cst.FormattedStringExpression(cst.Name("bar")),
                cst.FormattedStringText(" baz"),
            )),
            "code":
            'f"foo {bar} baz"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.FormattedString(parts=(
                cst.FormattedStringText("foo "),
                cst.FormattedStringExpression(cst.Call(cst.Name("bar"))),
                cst.FormattedStringText(" baz"),
            )),
            "code":
            'f"foo {bar()} baz"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Formatted strings with conversions and format specifiers
        {
            "node":
            cst.FormattedString(parts=(cst.FormattedStringExpression(
                cst.Name("foo"), conversion="s"), )),
            "code":
            'f"{foo!s}"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.FormattedString(parts=(cst.FormattedStringExpression(
                cst.Name("foo"), format_spec=()), )),
            "code":
            'f"{foo:}"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.FormattedString(parts=(cst.FormattedStringExpression(
                cst.Name("today"),
                format_spec=(cst.FormattedStringText("%B %d, %Y"), ),
            ), )),
            "code":
            'f"{today:%B %d, %Y}"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.FormattedString(parts=(cst.FormattedStringExpression(
                cst.Name("foo"),
                format_spec=(cst.FormattedStringExpression(cst.Name("bar")), ),
            ), )),
            "code":
            'f"{foo:{bar}}"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.FormattedString(parts=(cst.FormattedStringExpression(
                cst.Name("foo"),
                format_spec=(
                    cst.FormattedStringExpression(cst.Name("bar")),
                    cst.FormattedStringText("."),
                    cst.FormattedStringExpression(cst.Name("baz")),
                ),
            ), )),
            "code":
            'f"{foo:{bar}.{baz}}"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.FormattedString(parts=(cst.FormattedStringExpression(
                cst.Name("foo"),
                conversion="s",
                format_spec=(cst.FormattedStringExpression(cst.Name("bar")), ),
            ), )),
            "code":
            'f"{foo!s:{bar}}"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Validate parens
        {
            "node":
            cst.FormattedString(
                start='f"',
                parts=(),
                end='"',
                lpar=(cst.LeftParen(), ),
                rpar=(cst.RightParen(), ),
            ),
            "code":
            '(f"")',
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 4)),
        },
        # Concatenated strings
        {
            "node":
            cst.ConcatenatedString(cst.SimpleString('"ab"'),
                                   cst.SimpleString('"c"')),
            "code":
            '"ab""c"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.ConcatenatedString(
                cst.SimpleString('"ab"'),
                cst.ConcatenatedString(cst.SimpleString('"c"'),
                                       cst.SimpleString('"d"')),
            ),
            "code":
            '"ab""c""d"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # mixed SimpleString and FormattedString
        {
            "node":
            cst.ConcatenatedString(
                cst.FormattedString([cst.FormattedStringText("ab")]),
                cst.SimpleString('"c"'),
            ),
            "code":
            'f"ab""c"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.ConcatenatedString(
                cst.SimpleString('"ab"'),
                cst.FormattedString([cst.FormattedStringText("c")]),
            ),
            "code":
            '"ab"f"c"',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Concatenated parenthesized strings
        {
            "node":
            cst.ConcatenatedString(
                lpar=(cst.LeftParen(), ),
                left=cst.SimpleString('"ab"'),
                right=cst.SimpleString('"c"'),
                rpar=(cst.RightParen(), ),
            ),
            "code":
            '("ab""c")',
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Validate spacing
        {
            "node":
            cst.ConcatenatedString(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                left=cst.SimpleString('"ab"'),
                whitespace_between=cst.SimpleWhitespace(" "),
                right=cst.SimpleString('"c"'),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "code":
            '( "ab" "c" )',
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 2), (1, 10)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Validate each (node, code, parser, expected_position) case.

        Every case is handed unchanged to ``validate_node``.
        """
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    @data_provider(({
        "node":
        cst.FormattedStringExpression(
            cst.Name("today"),
            format_spec=(cst.FormattedStringText("%B %d, %Y"), ),
        ),
        "code":
        "{today:%B %d, %Y}",
        "parser":
        None,
        "expected_position":
        CodeRange((1, 0), (1, 17)),
    }, ))
    def test_valid_no_parse(self, **kwargs: Any) -> None:
        """Validate nodes whose rendered code is not a standalone expression.

        These cases pass ``parser=None`` to ``validate_node``.
        """
        # Test some nodes that aren't valid source code by themselves
        self.validate_node(**kwargs)

    @data_provider((
        # Expression wrapping parenthesis rules
        {
            "get_node": (lambda: cst.Name("foo", lpar=(cst.LeftParen(), ))),
            "expected_re": "left paren without right paren",
        },
        {
            "get_node": lambda: cst.Name("foo", rpar=(cst.RightParen(), )),
            "expected_re": "right paren without left paren",
        },
        {
            "get_node": lambda: cst.Ellipsis(lpar=(cst.LeftParen(), )),
            "expected_re": "left paren without right paren",
        },
        {
            "get_node": lambda: cst.Ellipsis(rpar=(cst.RightParen(), )),
            "expected_re": "right paren without left paren",
        },
        {
            "get_node": lambda: cst.Integer("5", lpar=(cst.LeftParen(), )),
            "expected_re": "left paren without right paren",
        },
        {
            "get_node": lambda: cst.Integer("5", rpar=(cst.RightParen(), )),
            "expected_re": "right paren without left paren",
        },
        {
            "get_node": lambda: cst.Float("5.5", lpar=(cst.LeftParen(), )),
            "expected_re": "left paren without right paren",
        },
        {
            "get_node": lambda: cst.Float("5.5", rpar=(cst.RightParen(), )),
            "expected_re": "right paren without left paren",
        },
        {
            "get_node":
            (lambda: cst.Imaginary("5j", lpar=(cst.LeftParen(), ))),
            "expected_re": "left paren without right paren",
        },
        {
            "get_node":
            (lambda: cst.Imaginary("5j", rpar=(cst.RightParen(), ))),
            "expected_re": "right paren without left paren",
        },
        {
            "get_node":
            (lambda: cst.SimpleString("'foo'", lpar=(cst.LeftParen(), ))),
            "expected_re":
            "left paren without right paren",
        },
        {
            "get_node":
            (lambda: cst.SimpleString("'foo'", rpar=(cst.RightParen(), ))),
            "expected_re":
            "right paren without left paren",
        },
        {
            "get_node":
            (lambda: cst.FormattedString(parts=(), lpar=(cst.LeftParen(), ))),
            "expected_re":
            "left paren without right paren",
        },
        {
            "get_node":
            (lambda: cst.FormattedString(parts=(), rpar=(cst.RightParen(), ))),
            "expected_re":
            "right paren without left paren",
        },
        {
            "get_node": (lambda: cst.ConcatenatedString(
                cst.SimpleString("'foo'"),
                cst.SimpleString("'foo'"),
                lpar=(cst.LeftParen(), ),
            )),
            "expected_re":
            "left paren without right paren",
        },
        {
            "get_node": (lambda: cst.ConcatenatedString(
                cst.SimpleString("'foo'"),
                cst.SimpleString("'foo'"),
                rpar=(cst.RightParen(), ),
            )),
            "expected_re":
            "right paren without left paren",
        },
        # Node-specific rules
        {
            "get_node": (lambda: cst.Name("")),
            "expected_re": "empty name identifier",
        },
        {
            "get_node": (lambda: cst.Name(r"\/")),
            "expected_re": "not a valid identifier",
        },
        {
            "get_node": (lambda: cst.Integer("")),
            "expected_re": "not a valid integer",
        },
        {
            "get_node": (lambda: cst.Integer("012345")),
            "expected_re": "not a valid integer",
        },
        {
            "get_node": (lambda: cst.Integer("_12345")),
            "expected_re": "not a valid integer",
        },
        {
            "get_node": (lambda: cst.Integer("0b2")),
            "expected_re": "not a valid integer",
        },
        {
            "get_node": (lambda: cst.Integer("0o8")),
            "expected_re": "not a valid integer",
        },
        {
            "get_node": (lambda: cst.Integer("0xg")),
            "expected_re": "not a valid integer",
        },
        {
            "get_node": (lambda: cst.Integer("123.45")),
            "expected_re": "not a valid integer",
        },
        {
            "get_node": (lambda: cst.Integer("12345j")),
            "expected_re": "not a valid integer",
        },
        {
            "get_node": (lambda: cst.Float("12.3.45")),
            "expected_re": "not a valid float",
        },
        {
            "get_node": (lambda: cst.Float("12")),
            "expected_re": "not a valid float",
        },
        {
            "get_node": (lambda: cst.Float("12.3j")),
            "expected_re": "not a valid float",
        },
        {
            "get_node": (lambda: cst.Imaginary("_12345j")),
            "expected_re": "not a valid imaginary",
        },
        {
            "get_node": (lambda: cst.Imaginary("0b0j")),
            "expected_re": "not a valid imaginary",
        },
        {
            "get_node": (lambda: cst.Imaginary("0o0j")),
            "expected_re": "not a valid imaginary",
        },
        {
            "get_node": (lambda: cst.Imaginary("0x0j")),
            "expected_re": "not a valid imaginary",
        },
        {
            "get_node": (lambda: cst.SimpleString('wee""')),
            "expected_re": "Invalid string prefix",
        },
        {
            "get_node": (lambda: cst.SimpleString("'")),
            "expected_re": "must have enclosing quotes",
        },
        {
            "get_node": (lambda: cst.SimpleString('"')),
            "expected_re": "must have enclosing quotes",
        },
        {
            "get_node": (lambda: cst.SimpleString("\"'")),
            "expected_re": "must have matching enclosing quotes",
        },
        {
            "get_node": (lambda: cst.SimpleString("")),
            "expected_re": "must have enclosing quotes",
        },
        {
            "get_node": (lambda: cst.SimpleString("'bla")),
            "expected_re": "must have matching enclosing quotes",
        },
        {
            "get_node": (lambda: cst.SimpleString("f''")),
            "expected_re": "Invalid string prefix",
        },
        {
            "get_node": (lambda: cst.SimpleString("'''bla''")),
            "expected_re": "must have matching enclosing quotes",
        },
        {
            "get_node": (lambda: cst.SimpleString("'''bla\"\"\"")),
            "expected_re": "must have matching enclosing quotes",
        },
        {
            "get_node":
            (lambda: cst.FormattedString(start="'", parts=(), end="'")),
            "expected_re": "Invalid f-string prefix",
        },
        {
            "get_node":
            (lambda: cst.FormattedString(start="f'", parts=(), end='"')),
            "expected_re":
            "must have matching enclosing quotes",
        },
        {
            "get_node": (lambda: cst.ConcatenatedString(
                cst.SimpleString('"ab"',
                                 lpar=(cst.LeftParen(), ),
                                 rpar=(cst.RightParen(), )),
                cst.SimpleString('"c"'),
            )),
            "expected_re":
            "Cannot concatenate parenthesized",
        },
        {
            "get_node": (lambda: cst.ConcatenatedString(
                cst.SimpleString('"ab"'),
                cst.SimpleString('"c"',
                                 lpar=(cst.LeftParen(), ),
                                 rpar=(cst.RightParen(), )),
            )),
            "expected_re":
            "Cannot concatenate parenthesized",
        },
        {
            "get_node": (lambda: cst.ConcatenatedString(
                cst.SimpleString('"ab"'), cst.SimpleString('b"c"'))),
            "expected_re":
            "Cannot concatenate string and bytes",
        },
        # This isn't valid code: `"a" inb"abc"`
        {
            "get_node": (lambda: cst.Comparison(
                cst.SimpleString('"a"'),
                [
                    cst.ComparisonTarget(
                        cst.In(whitespace_after=cst.SimpleWhitespace("")),
                        cst.SimpleString('b"abc"'),
                    )
                ],
            )),
            "expected_re":
            "Must have at least one space around comparison operator.",
        },
        # Also not valid: `"a" inb"a"b"bc"` (no space after `in` before a
        # prefixed string)
        {
            "get_node": (lambda: cst.Comparison(
                cst.SimpleString('"a"'),
                [
                    cst.ComparisonTarget(
                        cst.In(whitespace_after=cst.SimpleWhitespace("")),
                        cst.ConcatenatedString(cst.SimpleString('b"a"'),
                                               cst.SimpleString('b"bc"')),
                    )
                ],
            )),
            "expected_re":
            "Must have at least one space around comparison operator.",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that each ``get_node`` factory is rejected.

        ``expected_re`` describes the expected validation error; the check
        itself is delegated to ``assert_invalid``.
        """
        self.assert_invalid(**kwargs)
示例#26
0
 def leave_Integer(self, original_node: cst.Integer, updated_node: cst.Integer):
     """Swap any integer literal for the string literal ``"[number]"``.

     Both node arguments are ignored; each integer is replaced
     unconditionally with the same placeholder string node.
     """
     placeholder = '"[number]"'
     return cst.SimpleString(value=placeholder)
示例#27
0
class AssertConstructionTest(CSTNodeTest):
    """Construction tests for ``cst.Assert`` nodes.

    Every valid case passes ``parser=None``; the invalid cases cover the
    whitespace-after-``assert`` and trailing-comma validation rules.
    """

    @data_provider((
        # An assert with only a test expression.
        dict(
            node=cst.Assert(cst.Name("True")),
            code="assert True",
            parser=None,
            expected_position=None,
        ),
        # An assert that also carries a message expression.
        dict(
            node=cst.Assert(
                cst.Name("True"),
                cst.SimpleString('"Value should be true"'),
            ),
            code='assert True, "Value should be true"',
            parser=None,
            expected_position=None,
        ),
        # A parenthesized test lets the space after 'assert' be dropped
        # (contrast with the first invalid case below).
        dict(
            node=cst.Assert(
                cst.Name(
                    "True",
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                whitespace_after_assert=cst.SimpleWhitespace(""),
            ),
            code="assert(True)",
            parser=None,
            expected_position=CodeRange((1, 0), (1, 12)),
        ),
        # Custom whitespace around every token is rendered verbatim.
        dict(
            node=cst.Assert(
                whitespace_after_assert=cst.SimpleWhitespace("  "),
                test=cst.Name("True"),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                msg=cst.SimpleString('"Value should be true"'),
            ),
            code='assert  True  ,  "Value should be true"',
            parser=None,
            expected_position=CodeRange((1, 0), (1, 39)),
        ),
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Hand each valid case to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider((
        # No space after 'assert' with an unparenthesized test.
        dict(
            get_node=lambda: cst.Assert(
                cst.Name("True"),
                whitespace_after_assert=cst.SimpleWhitespace(""),
            ),
            expected_re="Must have at least one space after 'assert'",
        ),
        # A comma without a following message.
        dict(
            get_node=lambda: cst.Assert(test=cst.Name("True"),
                                        comma=cst.Comma()),
            expected_re="Cannot have trailing comma after 'test'",
        ),
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Hand each invalid-node factory to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
示例#28
0
class LambdaCreationTest(CSTNodeTest):
    """Construction, rendering and validation tests for ``cst.Lambda`` nodes."""

    @data_provider((
        # Each case is (node, expected code[, expected CodeRange]).
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar")),
                                       cst.Param(cst.Name("baz")))),
                cst.Integer("5"),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                )),
                cst.Integer("5"),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(cst.Name("bar")), ),
                    default_params=(cst.Param(cst.Name("baz"),
                                              default=cst.Integer("5")), ),
                ),
                cst.Integer("5"),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params
        (
            cst.Lambda(
                cst.Parameters(kwonly_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz")),
                )),
                cst.Integer("5"),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(cst.Name("first"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("second"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
            CodeRange((1, 0), (1, 84)),
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("params"))),
                cst.Integer("5"),
            ),
            "lambda *params: 5",
        ),
        # star_arg combined with kwonly_params (no annotations: lambdas forbid them)
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, star_arg, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_kwarg alone
        (
            cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))),
                cst.Integer("5"),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and star_kwarg together
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    star_kwarg=cst.Param(cst.Name("kwparams")),
                ),
                cst.Integer("5"),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace; the expected position excludes the parentheses.
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                whitespace_after_lambda=cst.SimpleWhitespace("  "),
                params=cst.Parameters(),
                colon=cst.Colon(whitespace_after=cst.SimpleWhitespace(" ")),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda  : 5 )",
            CodeRange((1, 2), (1, 13)),
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Check ``node`` with ``validate_node`` against its expected code.

        ``position`` is forwarded as ``expected_position`` and stays ``None``
        for the cases that do not pin a source range.
        """
        self.validate_node(node, code, expected_position=position)

    @data_provider((
        # Each case is (zero-arg factory building the node, expected error regex).
        # Parentheses must be balanced.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        # 'lambda' needs trailing whitespace once any kind of param is present
        # (one case per param slot).
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"), default=cst.Integer("5")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        # Malformed individual params.
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    equal=cst.AssignEqual())),
                cst.Integer("5"),
            ),
            "Must have a default when specifying an AssignEqual.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    star="***")),
                cst.Integer("5"),
            ),
            r"Must specify either '', '\*' or '\*\*' for star.",
        ),
        # Defaults must live in default_params, and only there.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("bar"), default=cst.SimpleString('"one"')), )),
                cst.Integer("5"),
            ),
            "Cannot have defaults for params",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(cst.Name("bar")), )),
                cst.Integer("5"),
            ),
            "Must have defaults for default_params",
        ),
        # A bare ParamStar is only meaningful when kwonly params follow it.
        (
            lambda: cst.Lambda(cst.Parameters(star_arg=cst.ParamStar()),
                               cst.Integer("5")),
            "Must have at least one kwonly param if ParamStar is used.",
        ),
        # A param's star prefix must match the slot it is placed in.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"), )
                               ),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("bar"),
                    default=cst.SimpleString('"one"'),
                    star="*",
                ), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("bar"),
                                                        star="*"), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("bar"), star="**")),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*'",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")
                               ),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*\*'",
        ),
        # Lambda params may not carry annotations (one case per param slot).
        # These cases also omit the space after 'lambda'; the expected message
        # shows the annotation error is the one reported.
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"),
                    default=cst.Integer("5"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"),
                                                  annotation=cst.Annotation(
                                                      cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"),
                                                    annotation=cst.Annotation(
                                                        cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Check that the node built by ``get_node`` is rejected by ``assert_invalid``.

        The rejection message must match ``expected_re``. Nodes are built
        lazily so that construction happens inside the assertion.
        """
        self.assert_invalid(get_node, expected_re)
示例#29
0
 def docstring(self, indent: int) -> cst.BaseExpression:
     """Build a triple-quoted string literal listing ``metadata_lines(self.metadata)``.

     Every line, including the closing quotes, is prefixed with
     ``4 * (indent + 1)`` spaces, so the text sits one level deeper than
     ``indent``. The opening quotes are followed directly by a newline.
     """
     pad = " " * (4 * (indent + 1))
     separator = "\n" + pad
     body = separator.join(metadata_lines(self.metadata))
     literal = '"""' + separator + body + separator + '"""'
     return cst.SimpleString(literal)
示例#30
0
 def leave_Imaginary(self, original_node: cst.Imaginary,
                     updated_node: cst.Imaginary):
     """Replace any imaginary-number literal with the string literal ``"[number]"``.

     Neither the original nor the updated node is inspected: every imaginary
     literal maps to the same placeholder.
     """
     placeholder = '"[number]"'
     return cst.SimpleString(value=placeholder)