def test_decorators(self) -> None:
    # The decorator placeholder must accept both a plain expression and a
    # ready-made Decorator node, and render identically in both cases.
    template = "@{decorator}\ndef foo(): pass\n"
    expected_code = "@bar\ndef foo(): pass\n"

    # Ordinary case: a bare Name is used as the decorator's expression.
    from_name = parse_template_statement(template, decorator=cst.Name("bar"))
    self.assertEqual(self.code(from_name), expected_code)

    # Special case: a full Decorator node fills the same placeholder.
    from_decorator = parse_template_statement(
        template, decorator=cst.Decorator(cst.Name("bar"))
    )
    self.assertEqual(self.code(from_decorator), expected_code)
Exemplo n.º 2
0
    def test_parameters(self) -> None:
        # Parameter placeholders can be filled with a bare name, a Param
        # node, or a whole Parameters node. Each variant is rendered back to
        # source and compared with the expected function definition.
        single_template = "def foo({arg}): pass"
        list_template = "def foo({args}): pass"
        one_param_code = "def foo(bar): pass\n"

        # Ordinary case: a plain Name becomes the parameter.
        from_name = parse_template_statement(single_template, arg=cst.Name("bar"))
        self.assertEqual(self.code(from_name), one_param_code)

        # Special case: a Param node is accepted for the same placeholder.
        from_param = parse_template_statement(
            single_template, arg=cst.Param(cst.Name("bar"))
        )
        self.assertEqual(self.code(from_param), one_param_code)

        # Special case: a Parameters node stands in for the parameter list.
        one_param_list = cst.Parameters((cst.Param(cst.Name("bar")),))
        from_param_list = parse_template_statement(list_template, args=one_param_list)
        self.assertEqual(self.code(from_param_list), one_param_code)

        # A Parameters node carrying several positional parameters plus a
        # ``**`` parameter fills out the complete signature.
        full_param_list = cst.Parameters(
            params=(
                cst.Param(cst.Name("bar")),
                cst.Param(cst.Name("baz")),
            ),
            star_kwarg=cst.Param(cst.Name("rest")),
        )
        from_full_list = parse_template_statement(list_template, args=full_param_list)
        self.assertEqual(
            self.code(from_full_list), "def foo(bar, baz, **rest): pass\n"
        )
    def test_annotation(self) -> None:
        # The annotation placeholder must accept either a plain expression or
        # an Annotation node; both are expected to render the same source.
        template = "x: {type} = {val}"
        expected_code = "x: int = 5\n"

        # Ordinary case: the annotation is given as a bare Name expression.
        from_expression = parse_template_statement(
            template, type=cst.Name("int"), val=cst.Integer("5")
        )
        self.assertEqual(self.code(from_expression), expected_code)

        # Special case: the annotation is given as an Annotation node.
        from_annotation = parse_template_statement(
            template, type=cst.Annotation(cst.Name("int")), val=cst.Integer("5")
        )
        self.assertEqual(self.code(from_annotation), expected_code)
 def test_simple_statement(self) -> None:
     # Both placeholders of a one-line ``assert`` statement are filled with
     # expression nodes and the rendered source is checked verbatim.
     condition = cst.Name("True")
     message = cst.SimpleString('"Somehow True is no longer True..."')
     rendered = self.code(
         parse_template_statement(
             "assert {test}, {msg}\n", test=condition, msg=message
         )
     )
     self.assertEqual(
         rendered, 'assert True, "Somehow True is no longer True..."\n'
     )
    def test_assign_target(self) -> None:
        # Assignment-target placeholders must accept either plain expressions
        # or AssignTarget nodes; both are expected to render the same source.
        template = "{a} = {b} = {val}"
        expected_code = "first = second = 5\n"

        # Ordinary case: the targets are bare Name expressions.
        from_names = parse_template_statement(
            template,
            a=cst.Name("first"),
            b=cst.Name("second"),
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(from_names), expected_code)

        # Special case: the targets are already wrapped in AssignTarget.
        from_targets = parse_template_statement(
            template,
            a=cst.AssignTarget(cst.Name("first")),
            b=cst.AssignTarget(cst.Name("second")),
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(from_targets), expected_code)
Exemplo n.º 6
0
    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        """Replace list/dict parameter defaults with ``None``.

        Every parameter in ``updated_node.params.params`` whose default is a
        ``cst.List`` or ``cst.Dict`` node gets ``None`` as its default
        instead. For each such parameter one statement generated from
        ``DEFAULT_INIT_TEMPLATE`` is inserted into the function body, after
        any leading docstrings. Only ``params.params`` is inspected.

        Args:
            original_node: The function definition before its children were
                visited. Not used.
            updated_node: The function definition with its children already
                transformed; every change is applied on top of it.

        Returns:
            ``updated_node`` itself when no parameter has a list/dict
            default, otherwise a copy with rewritten parameters and body.
        """
        modified_defaults: List = []
        mutable_args: List[Tuple[cst.Name, Union[cst.List, cst.Dict]]] = []

        for param in updated_node.params.params:
            default = param.default
            has_mutable_default = m.matches(
                param, m.Param(default=m.OneOf(m.List(), m.Dict()))
            )
            # The isinstance check repeats what the matcher established, so
            # that type checkers can narrow ``default``. Both conditions
            # share one branch so that a parameter is always kept, never
            # dropped from the signature.
            if not has_mutable_default or not isinstance(
                default, (cst.List, cst.Dict)
            ):
                modified_defaults.append(param)
                continue

            mutable_args.append((param.name, default))
            modified_defaults.append(param.with_changes(default=cst.Name("None")))

        if not mutable_args:
            # Return the updated node rather than the original one: the
            # children (e.g. nested function definitions) may already have
            # been rewritten, and returning ``original_node`` would discard
            # those changes.
            return updated_node

        modified_params: cst.Parameters = updated_node.params.with_changes(
            params=modified_defaults
        )

        initializations: List[
            Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]
        ] = [
            # We use generation by template here since construction of the
            # resulting 'if' can be burdensome due to many nested objects
            # involved. Additional line is attached so that we may control
            # exact spacing between generated statements.
            parse_template_statement(
                DEFAULT_INIT_TEMPLATE,
                config=self.module_config,
                arg=arg,
                init=init,
            ).with_changes(leading_lines=[EMPTY_LINE])
            for arg, init in mutable_args
        ]

        # Docstring should always go right after the function definition,
        # so we take special care to insert our initializations after the
        # last docstring found.
        # NOTE(review): this assumes every body statement accepts
        # ``leading_lines``; a one-line body such as ``def f(x=[]): pass``
        # presumably needs separate handling - confirm.
        body_statements = updated_node.body.body
        docstrings = takewhile(is_docstring, body_statements)
        function_code = dropwhile(is_docstring, body_statements)

        # It is not possible to insert empty line after the statement line,
        # because whitespace is owned by the next statement after it.
        # A body consisting only of docstrings has no such next statement;
        # the default of ``None`` avoids a StopIteration in that case.
        first_statement = next(function_code, None)
        if first_statement is None:
            remaining_code = ()
        else:
            remaining_code = (
                first_statement.with_changes(leading_lines=[EMPTY_LINE]),
                *function_code,
            )

        modified_body = (
            *docstrings,
            *initializations,
            *remaining_code,
        )

        return updated_node.with_changes(
            params=modified_params,
            body=updated_node.body.with_changes(body=modified_body),
        )