def visit_Call(self, node: cst.Call) -> None:
    """Flag ``list``/``set``/``dict`` calls that merely wrap a comprehension.

    ``list(<expr>)``, ``set(<expr>)`` and ``dict(<expr>)``, where ``<expr>``
    is a generator expression or list comprehension passed as the single
    plain positional argument, are reported with an autofix that rewrites
    the call as the equivalent list, set, or dict comprehension.

    For ``dict`` the fix needs a key and a value, so the comprehension's
    element must be a tuple or list literal with exactly two non-starred
    entries (``(k, v)`` / ``[k, v]``); other element shapes are not reported.
    """
    if not m.matches(
        node,
        m.Call(
            func=m.Name("list") | m.Name("set") | m.Name("dict"),
            # keyword=None / star="": ``dict(x=[...])`` builds {"x": [...]}
            # and ``list(*[...])`` unpacks the list, so neither call is a
            # plain wrapper around the comprehension.
            args=[
                m.Arg(
                    value=m.GeneratorExp() | m.ListComp(),
                    keyword=None,
                    star="",
                )
            ],
        ),
    ):
        return

    call_name = cst.ensure_type(node.func, cst.Name).value
    arg_value = node.args[0].value
    if isinstance(arg_value, cst.GeneratorExp):
        exp = arg_value
        message_formatter = UNNECESSARY_GENERATOR
    else:
        exp = cst.ensure_type(arg_value, cst.ListComp)
        message_formatter = UNNECESSARY_LIST_COMPREHENSION

    replacement: cst.BaseExpression
    if call_name == "list":
        replacement = cst.ListComp(elt=exp.elt, for_in=exp.for_in)
    elif call_name == "set":
        replacement = cst.SetComp(elt=exp.elt, for_in=exp.for_in)
    else:  # "dict" -- the matcher above admits only these three names.
        elt = exp.elt
        # Checked explicitly: a matcher such as ``m.Tuple(m.DoNotCare(),
        # m.DoNotCare())`` binds its positional args to ``elements`` and
        # ``lpar`` and so accepts tuples of any length.
        if not (
            isinstance(elt, (cst.Tuple, cst.List))
            and len(elt.elements) == 2
            and all(isinstance(el, cst.Element) for el in elt.elements)
        ):
            # Unrecognized form (wrong arity or starred entry): no safe fix.
            return
        key, value = (el.value for el in elt.elements)
        replacement = cst.DictComp(key=key, value=value, for_in=exp.for_in)

    self.report(
        node,
        message_formatter.format(func=call_name),
        replacement=replacement,
    )
# Example #2
0
 def _dict_call(self, node: cst.Call) -> Union[cst.Call, cst.Dict, cst.DictComp]:
     """Rewrite a ``dict(...)`` call as an equivalent dict display when possible.

     Returns a new node for these shapes, and *node* unchanged otherwise:

     * ``dict()`` -> ``{}``
     * ``dict({k: v for ...})`` -> ``{k: v for ...}``
     * ``dict((k, v) for ...)`` / ``dict([(k, v) for ...])`` -> ``{k: v for ...}``
     * ``dict([(k1, v1), (k2, v2)])`` -> ``{k1: v1, k2: v2}``
     * ``dict([])`` / ``dict(())`` -> ``{}``

     Each pair may be a tuple or list literal but must have exactly two
     non-starred entries. A call with more than one argument, a keyword
     argument, or ``*``/``**`` unpacking is never rewritten.
     """

     def _literal_pair(expr: cst.BaseExpression):
         """Return ``(key, value)`` for a two-entry tuple/list literal, else None."""
         if not isinstance(expr, (cst.Tuple, cst.List)) or len(expr.elements) != 2:
             return None
         first, second = expr.elements
         # A starred entry (``(*ab, c)``) has no fixed key/value split.
         if not all(isinstance(e, cst.Element) for e in (first, second)):
             return None
         return first.value, second.value

     if not node.args:
         return cst.Dict(elements=[])
     if len(node.args) != 1:
         return node
     arg = node.args[0]
     # ``dict(x=...)`` builds ``{"x": ...}`` and ``dict(*it)`` / ``dict(**kw)``
     # unpack their operand, so the operand alone is not an equivalent display.
     if arg.keyword is not None or arg.star:
         return node
     value = arg.value
     if isinstance(value, cst.DictComp):
         return value
     if isinstance(value, (cst.ListComp, cst.GeneratorExp)):
         pair = _literal_pair(value.elt)
         if pair is None:
             return node
         return cst.DictComp(key=pair[0], value=pair[1], for_in=value.for_in)
     if isinstance(value, (cst.Tuple, cst.List)):
         elements = []
         for el in value.elements:
             # ``*rest`` splices in an unknown number of pairs: give up.
             pair = _literal_pair(el.value) if isinstance(el, cst.Element) else None
             if pair is None:
                 return node
             elements.append(cst.DictElement(key=pair[0], value=pair[1]))
         # An empty literal falls through the loop: dict([]) / dict(()) -> {}.
         return cst.Dict(elements=elements)
     return node
# Example #3
0
class DictCompTest(CSTNodeTest):
    """Tests for ``cst.DictComp`` nodes.

    ``test_valid`` hands each case to ``CSTNodeTest.validate_node``: a
    hand-built node, the source text it corresponds to, the parser for that
    text, and the node's expected ``CodeRange``.  ``test_invalid`` hands node
    factories to ``CSTNodeTest.assert_invalid`` along with a regex for the
    expected error message.
    """

    @data_provider(
        [
            # Simple DictComp with default whitespace.
            {
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                ),
                "code": "{k: v for a in b}",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 17)),
            },
            # Custom whitespace on both sides of the colon.
            {
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                    whitespace_before_colon=cst.SimpleWhitespace("\t"),
                    whitespace_after_colon=cst.SimpleWhitespace("\t\t"),
                ),
                "code": "{k\t:\t\tv for a in b}",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 19)),
            },
            # Custom whitespace just inside the braces.
            {
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                    lbrace=cst.LeftCurlyBrace(
                        whitespace_after=cst.SimpleWhitespace("\t")
                    ),
                    rbrace=cst.RightCurlyBrace(
                        whitespace_before=cst.SimpleWhitespace("\t\t")
                    ),
                ),
                "code": "{\tk: v for a in b\t\t}",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 20)),
            },
            # Parenthesized: the expected range starts after "(" and ends
            # before ")".
            {
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                    lpar=[cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "code": "({k: v for a in b})",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 18)),
            },
            # Dropping the spaces around a nested DictComp is always allowed.
            {
                "node": cst.DictComp(
                    key=cst.Name("a"),
                    value=cst.Name("b"),
                    for_in=cst.CompFor(
                        target=cst.Name("c"),
                        iter=cst.DictComp(
                            key=cst.Name("d"),
                            value=cst.Name("e"),
                            for_in=cst.CompFor(
                                target=cst.Name("f"), iter=cst.Name("g")
                            ),
                        ),
                        ifs=[
                            cst.CompIf(
                                test=cst.Name("h"),
                                whitespace_before=cst.SimpleWhitespace(""),
                            )
                        ],
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                ),
                "code": "{a: b for c in{d: e for f in g}if h}",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 36)),
            },
            # A parenthesized value lets `for` follow with no whitespace.
            {
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name(
                        "v", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                    ),
                    for_in=cst.CompFor(
                        target=cst.Name("a"),
                        iter=cst.Name("b"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "code": "{k: (v)for a in b}",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 18)),
            },
        ]
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        [
            # Left paren with no matching right paren.
            {
                "get_node": lambda: cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                    lpar=[cst.LeftParen()],
                ),
                "expected_re": "left paren without right paren",
            },
            # No whitespace before a synchronous `for` after a bare name.
            {
                "get_node": lambda: cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(
                        target=cst.Name("a"),
                        iter=cst.Name("b"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "expected_re": "Must have at least one space before 'for' keyword.",
            },
            # Same, but with an `async for` clause.
            {
                "get_node": lambda: cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(
                        target=cst.Name("a"),
                        iter=cst.Name("b"),
                        asynchronous=cst.Asynchronous(),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "expected_re": "Must have at least one space before 'async' keyword.",
            },
        ]
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)