Example #1
0
 def _list_call(self, node: cst.Call) -> Union[cst.Call, cst.List, cst.ListComp]:
     """Simplify a ``list(...)`` call into literal or comprehension syntax.

     Rewrites performed:

     * ``list()`` -> ``[]``
     * ``list([a, b])`` -> ``[a, b]`` and ``list([x for ...])`` -> ``[x for ...]``
     * ``list(x for ...)`` -> ``[x for ...]``
     * ``list((a, b))`` -> ``[a, b]``

     Anything else is returned unchanged. That includes multiple arguments,
     a starred/double-starred or keyword argument, and an argument that is
     not one of the literals above.

     Args:
         node: the call to simplify. Presumably the caller has already checked
             that it calls ``list``; this method does not inspect ``node.func``.

     Returns:
         The replacement node, or ``node`` itself when no rewrite applies.
     """
     if not node.args:
         return cst.List(elements=[])
     if len(node.args) != 1:
         return node
     arg = node.args[0]
     # `list(*[x])` means `list(x)`, not `list([x])`, and a keyword argument is
     # not the positional iterable; unwrapping either would change meaning.
     if arg.star or arg.keyword is not None:
         return node
     value = arg.value
     if isinstance(value, (cst.ListComp, cst.List)):
         # Already list syntax; the list() wrapper is redundant.
         return value
     if isinstance(value, cst.GeneratorExp):
         # Reuse the generator's element and for-clauses inside brackets.
         return cst.ListComp(elt=value.elt, for_in=value.for_in)
     if isinstance(value, cst.Tuple):
         return cst.List(elements=value.elements)
     return node
Example #2
0
def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.SubscriptElement:
    """
    Build ``MatchIfTrue[Callable[[<node type>], bool]]`` for ``oldtype``.

    The result is wrapped in a SubscriptElement so it can be dropped straight
    into the slice of a Union[...] annotation.
    """
    # MatchIfTrue takes in the original node type and returns a boolean, so
    # convert quoted classes (forward refs to other matchers) back to the
    # CSTNode they refer to. This works because there's always a 1:1 name
    # mapping.
    node_type = _convert_match_nodes_to_cst_nodes(oldtype)

    # Callable[[<node type>], bool]
    param_list = cst.SubscriptElement(cst.Index(cst.List([cst.Element(node_type)])))
    return_type = cst.SubscriptElement(cst.Index(cst.Name("bool")))
    callable_type = cst.Subscript(
        cst.Name("Callable"),
        slice=(param_list, return_type),
    )

    # MatchIfTrue[Callable[[<node type>], bool]]
    match_if_true = cst.Subscript(
        cst.Name("MatchIfTrue"),
        slice=(cst.SubscriptElement(cst.Index(callable_type)),),
    )
    return cst.SubscriptElement(cst.Index(match_if_true))
Example #3
0
 def to_list_cst(*list_elements):
     """Return a ``cst.List`` node holding the positional arguments, in order."""
     return cst.List(elements=list_elements)
Example #4
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``tuple``/``list``/``set``/``dict`` calls that have literal syntax.

        Two call shapes are considered:

        * A single plain positional list/tuple literal argument, e.g.
          ``tuple([a, b])`` -> ``(a, b)``, reported with ``UNNECESSARY_LITERAL``.
        * An empty call to ``tuple``/``list``/``dict``, e.g. ``list()`` -> ``[]``,
          reported with ``UNNCESSARY_CALL``.

        ``set([])`` / ``set(())`` are reported with ``set()`` as the replacement,
        because there is no empty-set literal. ``dict(...)`` is only rewritten
        when every element is a two-item tuple/list pair. Other shapes are left
        alone.
        """
        # Only a plain positional argument can be unwrapped. `tuple(*[a, b])`
        # unpacks into separate arguments and `dict(a=[])` builds {"a": []},
        # so rewriting either one to a literal would change the program.
        literal_arg = m.Arg(value=m.List() | m.Tuple(), keyword=None, star="")
        if not (
            m.matches(
                node,
                m.Call(
                    func=m.Name("tuple") | m.Name("list") | m.Name("set")
                    | m.Name("dict"),
                    args=[literal_arg],
                ),
            ) or m.matches(
                node,
                m.Call(func=m.Name("tuple") | m.Name("list") | m.Name("dict"),
                       args=[]),
            )
        ):
            return

        call_name = cst.ensure_type(node.func, cst.Name).value

        if node.args:
            arg_value = node.args[0].value
            # The matcher above guarantees a List or Tuple here.
            literal = cst.ensure_type(
                arg_value,
                cst.List if isinstance(arg_value, cst.List) else cst.Tuple)
            elements = literal.elements
            message_formatter = UNNECESSARY_LITERAL
        else:
            # An empty call is an Unnecessary Call: rewrite it to the empty
            # literal. set() never gets here because the empty-call matcher
            # excludes it.
            elements = []
            message_formatter = UNNCESSARY_CALL

        if call_name == "tuple":
            new_node = cst.Tuple(elements=elements)
        elif call_name == "list":
            new_node = cst.List(elements=elements)
        elif call_name == "set":
            if len(elements) == 0:
                # set() has no equivalent empty literal, so `set([])` /
                # `set(())` get an unnecessary-literal suggestion of `set()`.
                self.report(
                    node,
                    UNNECESSARY_LITERAL.format(func=call_name),
                    replacement=node.deep_replace(
                        node, cst.Call(func=cst.Name("set"))),
                )
                return
            new_node = cst.Set(elements=elements)
        else:
            # call_name == "dict". Every element must be a (key, value) pair
            # written as a two-item tuple or list of plain (non-starred)
            # elements; `dict([(*a, b)])` has no dict-literal equivalent.
            # An empty sequence trivially qualifies and yields `{}`.
            pair = m.Element(
                m.Tuple(elements=[m.Element(), m.Element()])
                | m.List(elements=[m.Element(), m.Element()]))
            if not all(m.matches(ele, pair) for ele in elements):
                # Unrecognized form
                return
            dict_elements = []
            for ele in elements:
                pair_value = cst.ensure_type(ele, cst.Element).value
                pair_node = cst.ensure_type(
                    pair_value,
                    cst.Tuple if isinstance(pair_value, cst.Tuple) else cst.List,
                )
                key, value = pair_node.elements
                dict_elements.append(cst.DictElement(key.value, value.value))
            new_node = cst.Dict(elements=dict_elements)

        self.report(
            node,
            message_formatter.format(func=call_name),
            replacement=node.deep_replace(node, new_node),
        )
Example #5
0
class ListTest(CSTNodeTest):
    """Tests for ``cst.List``: valid node/code pairs and invalid constructions."""

    # The Tuple tests already cover most Element/StarredElement behavior, so
    # those cases are not duplicated here.
    @data_provider([
        # zero-element list
        {"node": cst.List(elements=[]), "code": "[]", "parser": parse_expression},
        # one-element list, sentinel comma value
        {
            "node": cst.List(elements=[cst.Element(value=cst.Name("single_element"))]),
            "code": "[single_element]",
            "parser": parse_expression,
        },
        # custom whitespace between brackets
        {
            "node": cst.List(
                elements=[cst.Element(value=cst.Name("single_element"))],
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace("\t")
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace("    ")
                ),
            ),
            "code": "[\tsingle_element    ]",
            "parser": parse_expression,
            "expected_position": CodeRange.create((1, 0), (1, 21)),
        },
        # two-element list, sentinel comma value
        {
            "node": cst.List(
                elements=[
                    cst.Element(value=cst.Name("one")),
                    cst.Element(value=cst.Name("two")),
                ]
            ),
            "code": "[one, two]",
            "parser": None,
        },
        # with parentheses
        {
            "node": cst.List(
                elements=[cst.Element(value=cst.Name("one"))],
                lpar=[cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "code": "([one])",
            "parser": None,
            "expected_position": CodeRange.create((1, 1), (1, 6)),
        },
        # starred element
        {
            "node": cst.List(
                elements=[
                    cst.StarredElement(value=cst.Name("one")),
                    cst.StarredElement(value=cst.Name("two")),
                ]
            ),
            "code": "[*one, *two]",
            "parser": None,
            "expected_position": CodeRange.create((1, 0), (1, 12)),
        },
        # missing spaces around list, always okay
        {
            "node": cst.For(
                target=cst.List(
                    elements=[
                        cst.Element(value=cst.Name("k"), comma=cst.Comma()),
                        cst.Element(value=cst.Name("v")),
                    ]
                ),
                iter=cst.Name("abc"),
                body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                whitespace_after_for=cst.SimpleWhitespace(""),
                whitespace_before_in=cst.SimpleWhitespace(""),
            ),
            "code": "for[k,v]in abc: pass\n",
            "parser": parse_statement,
        },
    ])
    def test_valid(self, **kwargs: Any) -> None:
        # Each data-provider entry is forwarded as keyword arguments.
        self.validate_node(**kwargs)

    @data_provider(
        (
            (
                lambda: cst.List(
                    elements=[cst.Element(value=cst.Name("mismatched"))],
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        # get_node is a factory so node construction happens under the assert.
        self.assert_invalid(get_node, expected_re)
Example #6
0
def make_list(elts):
    """Wrap each expression of ``elts`` in a ``cst.Element`` and return a ``cst.List``."""
    wrapped = []
    for expr in elts:
        wrapped.append(cst.Element(value=expr))
    return cst.List(elements=wrapped)