コード例 #1
0
    def test_gen_dotted_names(self) -> None:
        """Check the names yielded by ``_gen_dotted_names`` for three shapes.

        Covers a bare name, a two-part attribute, and an attribute whose value
        is a call on a three-part attribute.
        """

        def dotted_names(expression):
            # Only the name half of each yielded (name, node) pair is compared.
            return {
                dotted for dotted, _node in _gen_dotted_names(expression)
            }

        # a
        bare_name = cst.Name(value="a")
        self.assertEqual(dotted_names(bare_name), {"a"})

        # a.b
        two_parts = cst.Attribute(
            value=cst.Name(value="a"),
            attr=cst.Name(value="b"),
        )
        self.assertEqual(dotted_names(two_parts), {"a.b", "a"})

        # a.b.c().d
        callee = cst.Attribute(
            value=cst.Attribute(
                value=cst.Name(value="a"),
                attr=cst.Name(value="b"),
            ),
            attr=cst.Name(value="c"),
        )
        attribute_of_call = cst.Attribute(
            value=cst.Call(func=callee, args=[]),
            attr=cst.Name(value="d"),
        )
        self.assertEqual(
            dotted_names(attribute_of_call), {"a.b.c", "a.b", "a"})
コード例 #2
0
ファイル: cst_utc.py プロジェクト: KGerring/metaproj
	def datetime_datetime_replace(
			self, original_node: cst.Call, updated_node: cst.Call
	) -> cst.Call:
		"""Return ``updated_node`` rewritten as ``datetime.datetime.now(UTC)``.

		``self._update_imports()`` is invoked first; then the callee and the
		argument list of ``updated_node`` are both replaced.
		"""
		self._update_imports()

		# datetime.datetime
		datetime_class = cst.Attribute(
				value=cst.Name(value="datetime"),
				attr=cst.Name(value="datetime"),
		)
		# datetime.datetime.now
		now_callee = cst.Attribute(
				value=datetime_class,
				attr=cst.Name(value="now"),
		)
		utc_argument = cst.Arg(value=cst.Name(value="UTC"))
		return updated_node.with_changes(func=now_callee, args=[utc_argument])
コード例 #3
0
ファイル: target_functions.py プロジェクト: shwinnn/mcx
 def sampleop_to_logpdf(cst_generator, *args, **kwargs):
     """Build the call ``<generated>.logpdf_sum(<var_name>)``.

     ``var_name`` is popped from ``kwargs``; the remaining arguments are
     forwarded to ``cst_generator`` to produce the receiver expression.
     """
     var_name = kwargs.pop("var_name")
     generated = cst_generator(*args, **kwargs)
     logpdf_sum = cst.Attribute(value=generated, attr=cst.Name("logpdf_sum"))
     return cst.Call(func=logpdf_sum, args=[cst.Arg(value=var_name)])
コード例 #4
0
 def leave_SimpleString(
     self, original_node: cst.SimpleString, updated_node: cst.SimpleString
 ) -> Union[cst.SimpleString, cst.Attribute]:
     """Replace a string literal found in ``CST_DIR`` with ``cst.<literal>``.

     Any other string literal is returned unchanged.
     """
     literal = updated_node.evaluated_value
     if literal not in CST_DIR:
         return updated_node
     return cst.Attribute(value=cst.Name("cst"), attr=cst.Name(literal))
コード例 #5
0
ファイル: target_functions.py プロジェクト: tblazina/mcx
 def choice_ast(rng_key):
     """Build the call ``mcx.jax.choice(rng_key, <name>.shape[0])``.

     ``<name>`` is ``nodes[0].name``; ``nodes`` comes from the enclosing scope.
     """
     # mcx.jax.choice
     choice_callee = cst.Attribute(
         value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
         attr=cst.Name("choice"),
     )
     # <name>.shape[0]
     shape_attribute = cst.Attribute(
         cst.Name(nodes[0].name), cst.Name("shape"))
     first_index = cst.SubscriptElement(cst.Index(cst.Integer("0")))
     first_dimension = cst.Subscript(shape_attribute, [first_index])

     return cst.Call(
         func=choice_callee,
         args=[cst.Arg(rng_key), cst.Arg(first_dimension)],
     )
コード例 #6
0
ファイル: target_functions.py プロジェクト: shwinnn/mcx
 def to_sampler(cst_generator, *args, **kwargs):
     """Build the call ``<generated>.sample(<rng_key>)``.

     ``rng_key`` is popped from ``kwargs``; the remaining arguments are
     forwarded to ``cst_generator`` to produce the receiver expression.
     """
     key_node = kwargs.pop("rng_key")
     generated = cst_generator(*args, **kwargs)
     sample_callee = cst.Attribute(value=generated, attr=cst.Name("sample"))
     return cst.Call(func=sample_callee, args=[cst.Arg(value=key_node)])
コード例 #7
0
    def test_with_dots(self) -> None:
        """Check ``util.with_dots`` on a name, a nested attribute, and a bad input."""
        plain_name = cst.Name(value="foo")
        self.assertEqual("foo", util.with_dots(plain_name))

        # foo.bar.baz
        foo_bar = cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar"))
        foo_bar_baz = cst.Attribute(value=foo_bar, attr=cst.Name("baz"))
        self.assertEqual("foo.bar.baz", util.with_dots(foo_bar_baz))

        # A plain string is not a node and must be rejected.
        with self.assertRaisesRegex(TypeError, "Can't with_dots"):
            util.with_dots("foo.bar")  # type: ignore
コード例 #8
0
 def pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
     node: cst.Yield,
 ) -> cst.BaseExpression:
     """Build ``asyncio.gather(*<value>)`` where ``<value>`` is ``node.value``."""
     gather_callee = cst.Attribute(
         value=cst.Name("asyncio"), attr=cst.Name("gather"))
     starred_argument = cst.Arg(value=node.value, star="*")
     return cst.Call(func=gather_callee, args=[starred_argument])
コード例 #9
0
 def get_name_node(name: str) -> Union[cst.Name, cst.Attribute]:
     """Turn a dotted name such as ``"a.b.c"`` into nested ``Attribute`` nodes.

     A name without dots becomes a single ``cst.Name``.
     """
     # Inverse `_get_alias_name`.
     prefix, dot, last = name.rpartition(".")
     if not dot:
         return cst.Name(name)
     # Everything before the final dot is resolved recursively.
     return cst.Attribute(  # type: ignore
         value=get_name_node(prefix), attr=get_name_node(last))
コード例 #10
0
 def annotation(self) -> typing.Union[cst.Name, cst.Attribute]:
     """Return ``<module>.<name>`` (or just ``<name>``) as a CST expression.

     Falls back to the name ``Unknown`` when building any node raises a
     CST validation error.
     """
     if self.module:
         parts = self.module.split(".") + [self.name]
     else:
         parts = [self.name]
     head, tail = parts[0], parts[1:]
     try:
         result: typing.Union[cst.Name, cst.Attribute] = cst.Name(head)
         for part in tail:
             # Wrap what has been built so far as the value of the next access.
             result = cst.Attribute(expr := result, cst.Name(part)) if False else cst.Attribute(result, cst.Name(part))
     except cst._nodes.base.CSTValidationError:
         return cst.Name("Unknown")
     return result
コード例 #11
0
    def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Retarget calls matching ``gen_sleep_matcher`` to ``asyncio.sleep``.

        Only applies while inside a coroutine; ``asyncio`` is recorded as a
        required import whenever a call is rewritten.
        """
        inside_coroutine = self.in_coroutine(self.coroutine_stack)
        if inside_coroutine and m.matches(updated_node, gen_sleep_matcher):
            self.required_imports.add("asyncio")
            asyncio_sleep = cst.Attribute(
                value=cst.Name("asyncio"), attr=cst.Name("sleep"))
            return updated_node.with_changes(func=asyncio_sleep)

        return updated_node
コード例 #12
0
    def leave_SimpleString(
        self, original_node: cst.SimpleString, updated_node: cst.SimpleString
    ) -> Union[cst.SimpleString, cst.Attribute]:
        """Replace a string literal found in ``CST_DIR`` with ``cst.<literal>``.

        Literals that fail to evaluate with a ``SyntaxError`` and literals not
        in ``CST_DIR`` are returned unchanged.
        """
        try:
            literal = ast.literal_eval(updated_node.value)
        except SyntaxError:
            return updated_node

        if literal not in CST_DIR:
            return updated_node
        return cst.Attribute(value=cst.Name("cst"), attr=cst.Name(literal))
コード例 #13
0
ファイル: six_io.py プロジェクト: jhance/py3ify
 def leave_Attribute(
     self, original: libcst.Attribute, updated: libcst.Attribute
 ) -> Any:
     """Rewrite ``six.<name>`` to ``io.<name>`` when ``<name>`` is in ``_IO_OBJECTS``."""
     # The checks short-circuit in the same order as the original nesting.
     is_six_io_object = (
         m.matches(updated.value, m.Name("six"))
         and m.matches(updated.attr, m.Name())
         and updated.attr.value in _IO_OBJECTS
     )
     if not is_six_io_object:
         return updated
     return libcst.Attribute(value=libcst.Name("io"), attr=updated.attr)
コード例 #14
0
ファイル: wxpython.py プロジェクト: expobrain/python-codemods
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Rename the callee of calls matched by the short or full matcher maps.

        Short-form matches (symbol imported without the ``wx`` prefix) also
        schedule removal of the old import and addition of ``import wx``.
        A tuple ``renamed`` yields ``wx.<renamed[0]>.<renamed[1]>``.
        """

        def nested_wx_attribute(parts):
            # wx.<parts[0]>.<parts[1]>
            return cst.Attribute(
                value=cst.Attribute(value=cst.Name(value="wx"),
                                    attr=cst.Name(value=parts[0])),
                attr=cst.Name(value=parts[1]),
            )

        # Matches calls with symbols without the wx prefix
        for symbol, matcher, renamed in self.matchers_short_map:
            if symbol not in self.wx_imports:
                continue
            if not matchers.matches(updated_node, matcher):
                continue

            # Remove the symbol's import
            RemoveImportsVisitor.remove_unused_import_by_node(
                self.context, original_node)

            # Add import of top level wx package
            AddImportsVisitor.add_needed_import(self.context, "wx")

            if isinstance(renamed, tuple):
                new_func = nested_wx_attribute(renamed)
            else:
                new_func = cst.Attribute(value=cst.Name(value="wx"),
                                         attr=cst.Name(value=renamed))
            return updated_node.with_changes(func=new_func)

        # Matches full calls like wx.MySymbol
        for matcher, renamed in self.matchers_full_map:
            if not matchers.matches(updated_node, matcher):
                continue

            if isinstance(renamed, tuple):
                new_func = nested_wx_attribute(renamed)
            else:
                # Keep the existing receiver and swap only the attribute name.
                new_func = updated_node.func.with_changes(
                    attr=cst.Name(value=renamed))
            return updated_node.with_changes(func=new_func)

        # Nothing matched: returns updated node as-is
        return updated_node
コード例 #15
0
ファイル: wxpython.py プロジェクト: expobrain/python-codemods
    def leave_Attribute(self, original_node: cst.Attribute,
                        updated_node: cst.Attribute) -> cst.Attribute:
        """Re-root attributes matched by ``self.matchers`` under ``wx.adv``.

        When a matcher hits, ``wx.adv`` is registered as a needed import and
        the attribute's value is replaced by ``wx.adv``.
        """
        hit = any(
            matchers.matches(updated_node, candidate)
            for candidate in self.matchers
        )
        if not hit:
            return updated_node

        # Ensure that wx.adv is imported
        AddImportsVisitor.add_needed_import(self.context, "wx.adv")

        wx_adv = cst.Attribute(
            value=cst.Name(value="wx"), attr=cst.Name(value="adv"))
        return updated_node.with_changes(value=wx_adv)
コード例 #16
0
 def leave_Call(self, original: libcst.Call,
                updated: libcst.Call) -> libcst.Call:
     """Rewrite ``six.<name>(arg)`` into ``arg.<new_name>()``.

     ``<name>`` / ``<new_name>`` pairs come from ``_CONVERSION_MAP``. Calls
     that do not have exactly one argument are reported through
     ``self.warn`` and left untouched, as are all other calls.
     """
     # Check the shape of the callee before touching ``.value`` / ``.attr``:
     # the callee of a call is not always an Attribute (for example the inner
     # call in ``f()()``), and reading those fields on such a node fails.
     if not m.matches(updated.func, m.Attribute(value=m.Name("six"))):
         return updated

     for orig_name, updated_name in _CONVERSION_MAP.items():
         if not m.matches(updated.func.attr, m.Name(orig_name)):
             continue
         if len(updated.args) != 1:
             self.warn(
                 f"Odd six.{orig_name} call does not have one argument. Cannot perform substitution."
             )
             continue
         # The single argument becomes the receiver of a zero-argument call.
         value = updated.args[0].value
         return libcst.Call(func=libcst.Attribute(
             value=value,
             attr=libcst.Name(value=updated_name),
         ))
     return updated
コード例 #17
0
 def leave_Subscript(
     self,
     original_node: libcst.Subscript,
     updated_node: Union[libcst.Subscript, libcst.SimpleString],
 ) -> Union[libcst.Subscript, libcst.SimpleString]:
     """Turn ``PathLike[...]`` into the quoted string ``'os.PathLike[...]'``.

     Subscripts whose value is not the name ``PathLike`` are returned as-is.
     """
     path_like = libcst.matchers.Name("PathLike")
     if not libcst.matchers.matches(original_node.value, path_like):
         return updated_node

     # os.PathLike
     qualified_name = libcst.Attribute(
         value=libcst.Name(value="os", lpar=[], rpar=[]),
         attr=libcst.Name(value="PathLike"),
     )
     # An empty module is used only to render the rewritten node as source.
     renderer = libcst.parse_module("")
     rendered = renderer.code_for_node(
         updated_node.with_changes(value=qualified_name))
     return libcst.SimpleString(f"'{rendered}'")
コード例 #18
0
ファイル: test_visitors.py プロジェクト: leonardt/ast_tools
def test_collect_targets():
    """Check ``collect_targets`` against a name, a subscript and an attribute target."""
    source = '''
x = [0, 1]
x[0] = 1
x.attr = 2
'''
    tree = cst.parse_module(source)

    # Expected targets, in source order: x, x[0], x.attr
    name_x = cst.Name(value='x')
    index_zero = cst.SubscriptElement(
        slice=cst.Index(value=cst.Integer('0')))
    subscript_x0 = cst.Subscript(value=name_x, slice=[index_zero])
    attribute_xa = cst.Attribute(value=name_x, attr=cst.Name('attr'))
    expected = (name_x, subscript_x0, attribute_xa)

    found = collect_targets(tree)
    for target, gold in zip(found, expected):
        assert target.deep_equals(gold)
コード例 #19
0
ファイル: six_constants.py プロジェクト: jhance/py3ify
import libcst
import libcst.matchers as m

from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand

# Replacement expression for each ``six`` constant, keyed by its attribute
# name on ``six``. The first five map to a plain builtin name.
_CONVERSION_MAP = {
    six_name: libcst.Name(builtin_name)
    for six_name, builtin_name in (
        ("class_types", "type"),
        ("integer_types", "int"),
        ("string_types", "str"),
        ("text_type", "str"),
        ("binary_type", "bytes"),
    )
}
# TODO need to import sys automatically if we do this
_CONVERSION_MAP["MAXSIZE"] = libcst.Attribute(
    value=libcst.Name("sys"), attr=libcst.Name("maxsize"))


class ConvertSixConstants(VisitorBasedCodemodCommand):
    """Codemod that swaps ``six.<constant>`` for its entry in ``_CONVERSION_MAP``."""

    def leave_Attribute(self, original: libcst.Attribute,
                        updated: libcst.Attribute) -> libcst.BaseExpression:
        """Return the mapped replacement for ``six.<name>``, else ``updated``.

        Only attributes whose value is the bare name ``six`` and whose
        attribute is a plain name are looked up; names missing from
        ``_CONVERSION_MAP`` are left unchanged.
        """
        is_six_attribute = (
            m.matches(updated.value, m.Name("six"))
            and m.matches(updated.attr, m.Name())
        )
        if not is_six_attribute:
            return updated
        return _CONVERSION_MAP.get(updated.attr.value, updated)
コード例 #20
0
ファイル: test_import.py プロジェクト: petersktang/LibCST
class ImportCreateTest(CSTNodeTest):
    """Tests for hand-constructed ``cst.Import`` nodes.

    ``test_valid`` feeds each node/code pair to ``validate_node``;
    ``test_invalid`` feeds each node factory and expected message pattern to
    ``assert_invalid``.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # NOTE(review): this case is identical to the previous one and
            # looks redundant — confirm before removing.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            # Every configurable whitespace slot below is set explicitly
            # (single or double space) and must show up in the expected code.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "import  foo . bar  as  baz ,  unittest  as  ut",
                "expected_position": CodeRange((1, 0), (1, 46)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one node/code case from the table above."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # No names at all
            {
                "get_node": lambda: cst.Import(names=()),
                "expected_re": "at least one ImportAlias",
            },
            # Empty identifier as the whole name
            {
                "get_node": lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)),
                "expected_re": "empty name identifier",
            },
            # Empty identifier on the left of the dot
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            # Empty identifier on the right of the dot
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            # Comma on the last (only) alias
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(),
                        ),
                    )
                ),
                "expected_re": "trailing comma",
            },
            # Empty whitespace after the ``import`` keyword
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that one node factory from the table above is rejected."""
        self.assert_invalid(**kwargs)
コード例 #21
0
ファイル: test_import.py プロジェクト: petersktang/LibCST
class ImportParseTest(CSTNodeTest):
    """Tests pairing ``import`` source strings with the expected ``cst.Import`` nodes.

    Unlike hand-built nodes with default commas, the expected nodes here
    spell out each separating ``cst.Comma`` and its trailing whitespace.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # NOTE(review): this case is identical to the previous one and
            # looks redundant — confirm before removing.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            # Every configurable whitespace slot below is set explicitly
            # (single or double space) and must show up in the expected code.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "import  foo . bar  as  baz ,  unittest  as  ut",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one node/code case using a statement-level parser.

        The parser handed to ``validate_node`` parses the code as a statement,
        checks that it is a ``cst.SimpleStatementLine``, and returns the first
        element of its body.
        """
        self.validate_node(
            parser=lambda code: ensure_type(
                parse_statement(code), cst.SimpleStatementLine
            ).body[0],
            **kwargs,
        )
コード例 #22
0
 def to_attribute_cst(value, attr):
     """Wrap ``value`` and ``attr`` in a ``cst.Attribute`` node (``value.attr``)."""
     attribute_node = cst.Attribute(value=value, attr=attr)
     return attribute_node
コード例 #23
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report membership checks wrapped in ``assertTrue`` / ``assertFalse``.

        Rewrites suggested through ``self.report``:

        * ``self.assertTrue(a in b)``      -> ``self.assertIn(a, b)``
        * ``self.assertTrue(not a in b)``  -> ``self.assertNotIn(a, b)``
        * ``self.assertTrue(a not in b)``  -> ``self.assertNotIn(a, b)``
        * ``self.assertFalse(a in b)``     -> ``self.assertNotIn(a, b)``

        Calls matching none of these shapes are ignored.
        """
        # Todo: Make use of single extract instead of having several
        # if else statements to make the code more robust and readable.

        def self_call(method_name, first_arg):
            # Matcher for ``self.<method_name>(...)`` whose argument list
            # matches ``[first_arg]``.
            return m.Call(
                func=m.Attribute(value=m.Name("self"),
                                 attr=m.Name(method_name)),
                args=[m.Arg(first_arg)],
            )

        in_test = m.Comparison(
            comparisons=[m.ComparisonTarget(operator=m.In())])
        not_in_test = m.Comparison(
            comparisons=[m.ComparisonTarget(m.NotIn())])
        negated_in_test = m.UnaryOperation(operator=m.Not(),
                                           expression=in_test)

        if m.matches(node, self_call("assertTrue", in_test)):
            # self.assertTrue(a in b) -> self.assertIn(a, b)
            replacement_name = "assertIn"
            comparison = ensure_type(node.args[0].value, cst.Comparison)
        elif m.matches(node, self_call("assertTrue", negated_in_test)):
            # self.assertTrue(not a in b) -> self.assertNotIn(a, b)
            replacement_name = "assertNotIn"
            negation = ensure_type(node.args[0].value, cst.UnaryOperation)
            comparison = ensure_type(negation.expression, cst.Comparison)
        elif (m.matches(node, self_call("assertTrue", not_in_test))
              or m.matches(node, self_call("assertFalse", in_test))):
            # self.assertTrue(a not in b) -> self.assertNotIn(a, b)
            # self.assertFalse(a in b) -> self.assertNotIn(a, b)
            replacement_name = "assertNotIn"
            comparison = ensure_type(node.args[0].value, cst.Comparison)
        else:
            return

        # The two operands of the membership test become the two arguments.
        new_call = node.with_changes(
            func=cst.Attribute(value=cst.Name("self"),
                               attr=cst.Name(replacement_name)),
            args=[
                cst.Arg(comparison.left),
                cst.Arg(comparison.comparisons[0].comparator),
            ],
        )
        self.report(node, replacement=new_call)
コード例 #24
0
ファイル: test_attribute.py プロジェクト: stjordanis/LibCST
class AttributeTest(CSTNodeTest):
    """Tests for ``cst.Attribute``: valid node/code pairs and invalid parenthesization."""

    @data_provider(
        (
            # Simple attribute access
            {
                "node": cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                "code": "foo.bar",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 7)),
            },
            # Parenthesized attribute access
            # (the expected position covers ``foo.bar`` only, not the parens)
            {
                "node": cst.Attribute(
                    lpar=(cst.LeftParen(),),
                    value=cst.Name("foo"),
                    attr=cst.Name("bar"),
                    rpar=(cst.RightParen(),),
                ),
                "code": "(foo.bar)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 8)),
            },
            # Make sure that spacing works
            # (whitespace inside the parens is also outside the expected position)
            {
                "node": cst.Attribute(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    value=cst.Name("foo"),
                    dot=cst.Dot(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    attr=cst.Name("bar"),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "code": "( foo . bar )",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 2), (1, 11)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one node/code case from the table above."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Opening paren with no closing paren
            {
                "get_node": (
                    lambda: cst.Attribute(
                        cst.Name("foo"), cst.Name("bar"), lpar=(cst.LeftParen(),)
                    )
                ),
                "expected_re": "left paren without right paren",
            },
            # Closing paren with no opening paren
            {
                "get_node": (
                    lambda: cst.Attribute(
                        cst.Name("foo"), cst.Name("bar"), rpar=(cst.RightParen(),)
                    )
                ),
                "expected_re": "right paren without left paren",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that one node factory from the table above is rejected."""
        self.assert_invalid(**kwargs)
コード例 #25
0
ファイル: test_statement.py プロジェクト: willcrichton/LibCST
class StatementTest(UnitTest):
    """Tests for import helpers: absolute-module resolution of ``from`` imports
    and the ``evaluated_name`` / ``evaluated_alias`` accessors of ``ImportAlias``.
    """

    @data_provider(
        (
            # Absolute imports: the current module is irrelevant.
            (None, "from a.b import c", "a.b"),
            ("x.y.z", "from a.b import c", "a.b"),
            # Relative import with no current module to resolve against.
            (None, "from ..w import c", None),
            # Relative imports that climb above the top-level package.
            ("x", "from ...y import z", None),
            ("x.y.z", "from .....w import c", None),
            ("x.y.z", "from ... import c", None),
            # Relative imports that resolve to an absolute module.
            ("x.y.z", "from . import c", "x.y"),
            ("x.y.z", "from .. import c", "x"),
            ("x.y.z", "from .w import c", "x.y.w"),
            ("x.y.z", "from ..w import c", "x.w"),
            ("x.y.z", "from ...w import c", "w"),
        )
    )
    def test_get_absolute_module(
        self,
        module: Optional[str],
        importfrom: str,
        output: Optional[str],
    ) -> None:
        """Check both resolution helpers against one ``from ... import`` line.

        ``module`` is the name of the importing module (or ``None``),
        ``importfrom`` is the statement source, and ``output`` is the expected
        absolute module, with ``None`` meaning "cannot be resolved".
        """
        statement = ensure_type(
            cst.parse_statement(importfrom), cst.SimpleStatementLine
        )
        assert len(statement.body) == 1, "Unexpected number of statements!"
        import_from = ensure_type(statement.body[0], cst.ImportFrom)

        # The lenient helper is expected to hand back ``output`` directly.
        resolved = get_absolute_module_for_import(module, import_from)
        self.assertEqual(resolved, output)

        # The strict helper must agree when resolution succeeds...
        if output is not None:
            self.assertEqual(
                get_absolute_module_for_import_or_raise(module, import_from),
                output,
            )
            return

        # ...and must raise when it does not.
        with self.assertRaises(Exception):
            get_absolute_module_for_import_or_raise(module, import_from)

    @data_provider(
        (
            # Plain and dotted names, no ``as`` clause.
            (cst.ImportAlias(name=cst.Name("foo")), "foo", None),
            (
                cst.ImportAlias(
                    name=cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar"))
                ),
                "foo.bar",
                None,
            ),
            # Plain and dotted names with an ``as`` clause.
            (
                cst.ImportAlias(
                    name=cst.Name("foo"),
                    asname=cst.AsName(name=cst.Name("baz")),
                ),
                "foo",
                "baz",
            ),
            (
                cst.ImportAlias(
                    name=cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar")),
                    asname=cst.AsName(name=cst.Name("baz")),
                ),
                "foo.bar",
                "baz",
            ),
        )
    )
    def test_importalias_helpers(
        self,
        alias_node: cst.ImportAlias,
        full_name: str,
        alias: Optional[str],
    ) -> None:
        """Check ``evaluated_name`` and ``evaluated_alias`` of an ``ImportAlias``."""
        evaluated_name = alias_node.evaluated_name
        self.assertEqual(evaluated_name, full_name)

        evaluated_alias = alias_node.evaluated_alias
        self.assertEqual(evaluated_alias, alias)
コード例 #26
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``self.assertTrue/assertFalse(<None check>)`` calls.

        The single argument must be ``x is None`` / ``x is not None``,
        optionally wrapped in ``not``. The replacement is
        ``self.assertIsNone(x)`` when the number of negations (``assertFalse``,
        a leading ``not``, ``is not``) is even, and ``self.assertIsNotNone(x)``
        when it is odd.
        """
        none_target = m.ComparisonTarget(
            m.SaveMatchedNode(m.OneOf(m.Is(), m.IsNot()), "comparison_type"),
            comparator=m.Name("None"),
        )
        assert_method = m.SaveMatchedNode(
            m.OneOf(m.Name("assertTrue"), m.Name("assertFalse")),
            "assertion_name",
        )
        none_check = m.SaveMatchedNode(
            m.OneOf(
                m.Comparison(comparisons=[none_target]),
                m.UnaryOperation(
                    operator=m.Not(),
                    expression=m.Comparison(comparisons=[none_target]),
                ),
            ),
            "argument",
        )
        captured = m.extract(
            node,
            m.Call(
                func=m.Attribute(value=m.Name("self"), attr=assert_method),
                args=[m.Arg(none_check)],
            ),
        )
        if not captured:
            return

        def first_if_sequence(value):
            # A capture may come back as a sequence of nodes; only the first
            # one is used.
            return value[0] if isinstance(value, Sequence) else value

        assertion_name = first_if_sequence(captured["assertion_name"])
        argument = first_if_sequence(captured["argument"])
        comparison_type = first_if_sequence(captured["comparison_type"])

        # Strip the optional leading ``not`` to reach the comparison itself.
        if m.matches(argument, m.Comparison()):
            comparison = argument
        else:
            comparison = ensure_type(argument, cst.UnaryOperation).expression
        subject = ensure_type(comparison, cst.Comparison).left

        # Each of these flips the meaning of the assertion once.
        negation_count = sum(
            (
                m.matches(assertion_name, m.Name("assertFalse")),
                m.matches(argument, m.UnaryOperation()),
                m.matches(comparison_type, m.IsNot()),
            )
        )
        replacement_name = "assertIsNotNone" if negation_count % 2 else "assertIsNone"

        new_call = node.with_changes(
            func=cst.Attribute(value=cst.Name("self"), attr=cst.Name(replacement_name)),
            args=[cst.Arg(subject)],
        )
        if new_call is not node:
            self.report(node, replacement=new_call)
コード例 #27
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``assertTrue``/``assertFalse`` calls that test for ``None``.

        Recognised shapes and their replacements:

        * ``self.assertTrue(x is not None)``  -> ``self.assertIsNotNone(x)``
        * ``self.assertTrue(not x is None)``  -> ``self.assertIsNotNone(x)``
        * ``self.assertFalse(x is None)``     -> ``self.assertIsNotNone(x)``
        * ``self.assertTrue(x is None)``      -> ``self.assertIsNone(x)``
        * ``self.assertFalse(x is not None)`` -> ``self.assertIsNone(x)``
        * ``self.assertFalse(not x is None)`` -> ``self.assertIsNone(x)``

        The first shape that matches is reported with its replacement call;
        any other call is left alone.
        """

        def none_check(operator):
            # Matches ``<expr> <operator> None`` with exactly one comparison.
            return m.Comparison(
                comparisons=[
                    m.ComparisonTarget(operator, comparator=m.Name("None"))
                ]
            )

        def negated(comparison):
            # Matches ``not <comparison>``.
            return m.UnaryOperation(operator=m.Not(), expression=comparison)

        # (current method, argument matcher, argument wrapped in ``not``?,
        #  replacement method), tried in this order.
        rewrites = (
            ("assertTrue", none_check(m.IsNot()), False, "assertIsNotNone"),
            ("assertTrue", negated(none_check(m.Is())), True, "assertIsNotNone"),
            ("assertFalse", none_check(m.Is()), False, "assertIsNotNone"),
            ("assertTrue", none_check(m.Is()), False, "assertIsNone"),
            ("assertFalse", none_check(m.IsNot()), False, "assertIsNone"),
            ("assertFalse", negated(none_check(m.Is())), True, "assertIsNone"),
        )

        for current_name, argument_shape, has_not, replacement_name in rewrites:
            call_shape = m.Call(
                func=m.Attribute(
                    value=m.Name("self"), attr=m.Name(current_name)
                ),
                args=[m.Arg(argument_shape)],
            )
            if not m.matches(node, call_shape):
                continue

            # Dig the tested expression out of the sole argument, stepping
            # through the ``not`` wrapper when the shape has one.
            tested = node.args[0].value
            if has_not:
                tested = ensure_type(tested, cst.UnaryOperation).expression
            subject = ensure_type(tested, cst.Comparison).left

            new_call = node.with_changes(
                func=cst.Attribute(
                    value=cst.Name("self"), attr=cst.Name(replacement_name)
                ),
                args=[cst.Arg(subject)],
            )
            self.report(node, replacement=new_call)
            return
コード例 #28
0
ファイル: test_call.py プロジェクト: PhamAn12/GeneticProg
class CallTest(CSTNodeTest):
    """Data-driven tests for ``cst.Call`` nodes (plus one bare ``cst.Arg``).

    ``test_valid`` cases pair a hand-built node with its expected source code;
    cases labelled "render test" pass ``parser=None`` while "parse"/"parser"
    cases pass ``parse_expression``. ``test_invalid`` cases pair a node
    factory with the error pattern expected for it.
    """

    @data_provider((
        # Simple call
        {
            "node": cst.Call(cst.Name("foo")),
            "code": "foo()",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(cst.Name("foo"),
                     whitespace_before_args=cst.SimpleWhitespace(" ")),
            "code":
            "foo( )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Call with attribute dereference
        {
            "node": cst.Call(cst.Attribute(cst.Name("foo"), cst.Name("bar"))),
            "code": "foo.bar()",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Positional arguments render test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(cst.Integer("1")), )),
            "code": "foo(1)",
            "parser": None,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(cst.Integer("1")),
                    cst.Arg(cst.Integer("2")),
                    cst.Arg(cst.Integer("3")),
                ),
            ),
            "code":
            "foo(1, 2, 3)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Positional arguments parse test
        {
            "node": cst.Call(cst.Name("foo"),
                             (cst.Arg(value=cst.Integer("1")), )),
            "code": "foo(1)",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    value=cst.Integer("1"),
                    whitespace_after_arg=cst.SimpleWhitespace(" "),
                ), ),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
            ),
            "code":
            "foo ( 1 )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    value=cst.Integer("1"),
                    comma=cst.Comma(
                        whitespace_after=cst.SimpleWhitespace(" ")),
                ), ),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
            ),
            "code":
            "foo ( 1, )",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(value=cst.Integer("3")),
                ),
            ),
            "code":
            "foo(1, 2, 3)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Keyword arguments render test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(keyword=cst.Name("one"), value=cst.Integer("1")), ),
            ),
            "code":
            "foo(one = 1)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(keyword=cst.Name("one"), value=cst.Integer("1")),
                    cst.Arg(keyword=cst.Name("two"), value=cst.Integer("2")),
                    cst.Arg(keyword=cst.Name("three"), value=cst.Integer("3")),
                ),
            ),
            "code":
            "foo(one = 1, two = 2, three = 3)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Keyword arguments parser test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    keyword=cst.Name("one"),
                    equal=cst.AssignEqual(),
                    value=cst.Integer("1"),
                ), ),
            ),
            "code":
            "foo(one = 1)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        keyword=cst.Name("one"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("two"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("three"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("3"),
                    ),
                ),
            ),
            "code":
            "foo(one = 1, two = 2, three = 3)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Iterator expansion render test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="*", value=cst.Name("one")), )),
            "code":
            "foo(*one)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(star="*", value=cst.Name("one")),
                    cst.Arg(star="*", value=cst.Name("two")),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(*one, *two, *three)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Iterator expansion parser test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="*", value=cst.Name("one")), )),
            "code":
            "foo(*one)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        star="*",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(*one, *two, *three)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Dictionary expansion render test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="**", value=cst.Name("one")), )),
            "code":
            "foo(**one)",
            "parser":
            None,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(star="**", value=cst.Name("one")),
                    cst.Arg(star="**", value=cst.Name("two")),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(**one, **two, **three)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Dictionary expansion parser test
        {
            "node":
            cst.Call(cst.Name("foo"),
                     (cst.Arg(star="**", value=cst.Name("one")), )),
            "code":
            "foo(**one)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        star="**",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
            ),
            "code":
            "foo(**one, **two, **three)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Complicated mingling rules render test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(value=cst.Name("pos1")),
                    cst.Arg(star="*", value=cst.Name("list1")),
                    cst.Arg(value=cst.Name("pos2")),
                    cst.Arg(value=cst.Name("pos3")),
                    cst.Arg(star="*", value=cst.Name("list2")),
                    cst.Arg(value=cst.Name("pos4")),
                    cst.Arg(star="*", value=cst.Name("list3")),
                    cst.Arg(keyword=cst.Name("kw1"), value=cst.Integer("1")),
                    cst.Arg(star="*", value=cst.Name("list4")),
                    cst.Arg(keyword=cst.Name("kw2"), value=cst.Integer("2")),
                    cst.Arg(star="*", value=cst.Name("list5")),
                    cst.Arg(keyword=cst.Name("kw3"), value=cst.Integer("3")),
                    cst.Arg(star="**", value=cst.Name("dict1")),
                    cst.Arg(keyword=cst.Name("kw4"), value=cst.Integer("4")),
                    cst.Arg(star="**", value=cst.Name("dict2")),
                ),
            ),
            "code":
            "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)",
            "parser":
            None,
            "expected_position":
            None,
        },
        # Complicated mingling rules parser test
        {
            "node":
            cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        value=cst.Name("pos1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw1"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw2"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list5"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw3"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("dict1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw4"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("dict2")),
                ),
            ),
            "code":
            "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)",
            "parser":
            parse_expression,
            "expected_position":
            None,
        },
        # Test whitespace
        {
            "node":
            cst.Call(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                func=cst.Name("foo"),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
                args=(
                    cst.Arg(
                        keyword=None,
                        value=cst.Name("pos1"),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                    ),
                    cst.Arg(
                        star="*",
                        whitespace_after_star=cst.SimpleWhitespace("  "),
                        keyword=None,
                        value=cst.Name("list1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw1"),
                        equal=cst.AssignEqual(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        keyword=None,
                        whitespace_after_star=cst.SimpleWhitespace(" "),
                        value=cst.Name("dict1"),
                        whitespace_after_arg=cst.SimpleWhitespace(" "),
                    ),
                ),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "code":
            "( foo ( pos1 ,  *  list1, kw1=1, ** dict1 ) )",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 2), (1, 43)),
        },
        # Test args
        {
            "node":
            cst.Arg(
                star="*",
                whitespace_after_star=cst.SimpleWhitespace("  "),
                keyword=None,
                value=cst.Name("list1"),
                comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
            ),
            "code":
            "*  list1, ",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 8)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Forward one valid case (``node``, ``code``, ``parser``,
        ``expected_position``) to ``self.validate_node`` as keyword arguments.
        """
        self.validate_node(**kwargs)

    @data_provider((
        # Basic expression parenthesizing tests.
        {
            "get_node":
            lambda: cst.Call(func=cst.Name("foo"), lpar=(cst.LeftParen(), )),
            "expected_re":
            "left paren without right paren",
        },
        {
            "get_node":
            lambda: cst.Call(func=cst.Name("foo"), rpar=(cst.RightParen(), )),
            "expected_re":
            "right paren without left paren",
        },
        # Test that we handle keyword stuff correctly.
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(equal=cst.AssignEqual(),
                              value=cst.SimpleString("'baz'")), ),
            ),
            "expected_re":
            "Must have a keyword when specifying an AssignEqual",
        },
        # Test that we separate *, ** and keyword args correctly
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(
                    star="*",
                    keyword=cst.Name("bar"),
                    value=cst.SimpleString("'baz'"),
                ), ),
            ),
            "expected_re":
            "Cannot specify a star and a keyword together",
        },
        # Test for expected star inputs only
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                # pyre-ignore: Ignore type on 'star' since we're testing behavior
                # when somebody isn't using a type checker.
                args=(cst.Arg(star="***", value=cst.SimpleString("'baz'")), ),
            ),
            "expected_re":
            r"Must specify either '', '\*' or '\*\*' for star",
        },
        # Test ordering exceptions
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(star="**", value=cst.Name("bar")),
                    cst.Arg(star="*", value=cst.Name("baz")),
                ),
            ),
            "expected_re":
            "Cannot have iterable argument unpacking after keyword argument unpacking",
        },
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(star="**", value=cst.Name("bar")),
                    cst.Arg(value=cst.Name("baz")),
                ),
            ),
            "expected_re":
            "Cannot have positional argument after keyword argument unpacking",
        },
        {
            "get_node":
            lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(keyword=cst.Name("arg"),
                            value=cst.SimpleString("'baz'")),
                    cst.Arg(value=cst.SimpleString("'bar'")),
                ),
            ),
            "expected_re":
            "Cannot have positional argument after keyword argument",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward one invalid case (``get_node``, ``expected_re``) to
        ``self.assert_invalid`` as keyword arguments.
        """
        self.assert_invalid(**kwargs)
コード例 #29
0
def name_to_node(name: str) -> Union[cst.Name, cst.Attribute]:
    """Build the CST expression for a (possibly dotted) name.

    A name without dots becomes a single ``cst.Name``. A dotted name such as
    ``"a.b.c"`` becomes left-nested attributes: ``Attribute(Attribute(a, b), c)``.
    """
    first, *remaining = name.split(".")
    node: Union[cst.Name, cst.Attribute] = cst.Name(first)
    # Wrap from the left so the last component ends up as the outermost attr.
    for component in remaining:
        node = cst.Attribute(value=node, attr=cst.Name(component))
    return node
コード例 #30
0
 def annotation(self):
     """Return a CST attribute expression for the dotted name ``types.ModuleType``."""
     module_name = cst.Name("types")
     type_name = cst.Name("ModuleType")
     return cst.Attribute(value=module_name, attr=type_name)