def _set_call(self, node: cst.Call) -> Union[cst.Call, cst.Set, cst.SetComp]:
    """Rewrite a one-argument call into a set literal or set comprehension.

    The callee name is not inspected here; presumably the caller has already
    established that ``node`` is a call to ``set`` (verify against caller).

    Args:
        node: The call expression to simplify.

    Returns:
        * ``cst.Set`` when the argument is a non-empty list or tuple literal.
        * ``node`` with its arguments removed when the argument is an empty
          list or tuple literal.
        * ``cst.SetComp`` when the argument is a list comprehension, set
          comprehension, or generator expression.
        * ``node`` unchanged in every other case (wrong number of arguments,
          starred or keyword argument, any other kind of argument value).
    """
    if len(node.args) != 1:
        return node

    arg = node.args[0]
    # ``set(*[x])`` unpacks the list and then iterates ``x``, which is not the
    # same as ``{x}``. A keyword argument is not the positional iterable
    # either. In both cases leave the call alone instead of producing a
    # rewrite with different semantics.
    if arg.star or arg.keyword is not None:
        return node

    value = arg.value
    if isinstance(value, (cst.List, cst.Tuple)):
        if value.elements:
            return cst.Set(elements=value.elements)
        # There is no literal for an empty set (``{}`` is a dict), so the
        # result stays a call with the redundant argument dropped.
        return node.with_changes(args=[])

    if isinstance(value, (cst.ListComp, cst.SetComp, cst.GeneratorExp)):
        return cst.SetComp(elt=value.elt, for_in=value.for_in)

    return node
def visit_Call(self, node: cst.Call) -> None:
    """Report ``list``/``set``/``dict`` calls that wrap a comprehension.

    A call is reported when its callee is the bare name ``list``, ``set`` or
    ``dict`` and its only argument is a plain positional generator expression
    or list comprehension. The report carries a replacement node: the
    equivalent list, set or dict comprehension.

    For ``dict`` the comprehension element must be a two-element tuple or
    list literal (``(key, value)`` / ``[key, value]``); any other element
    shape is left unreported.

    Args:
        node: The call expression being visited.
    """
    if not m.matches(
        node,
        m.Call(
            func=m.Name("list") | m.Name("set") | m.Name("dict"),
            args=[
                # Only a plain positional argument qualifies. Without the
                # ``keyword``/``star`` constraints, ``dict(x=[...])`` and
                # ``list(*[...])`` would match and be rewritten into
                # something with different meaning.
                m.Arg(
                    value=m.GeneratorExp() | m.ListComp(),
                    keyword=None,
                    star="",
                )
            ],
        ),
    ):
        return

    call_name = cst.ensure_type(node.func, cst.Name).value
    arg_value = node.args[0].value
    if m.matches(arg_value, m.GeneratorExp()):
        exp = cst.ensure_type(arg_value, cst.GeneratorExp)
        message_formatter = UNNECESSARY_GENERATOR
    else:
        exp = cst.ensure_type(arg_value, cst.ListComp)
        message_formatter = UNNECESSARY_LIST_COMPREHENSION

    replacement = None
    if call_name == "list":
        replacement = node.deep_replace(
            node, cst.ListComp(elt=exp.elt, for_in=exp.for_in)
        )
    elif call_name == "set":
        replacement = node.deep_replace(
            node, cst.SetComp(elt=exp.elt, for_in=exp.for_in)
        )
    elif call_name == "dict":
        elt = exp.elt
        # Require exactly two non-starred elements. Passing matchers
        # positionally would bind them to ``elements`` and ``lpar`` /
        # ``lbracket`` and therefore accept a sequence of any length, which
        # makes the ``elements[1]`` access below unsafe.
        key_value_pair = [m.Element(), m.Element()]
        if m.matches(elt, m.Tuple(elements=key_value_pair)):
            elt = cst.ensure_type(elt, cst.Tuple)
            key = elt.elements[0].value
            value = elt.elements[1].value
        elif m.matches(elt, m.List(elements=key_value_pair)):
            elt = cst.ensure_type(elt, cst.List)
            key = elt.elements[0].value
            value = elt.elements[1].value
        else:
            # Unrecognized form
            return

        replacement = node.deep_replace(
            node,
            # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
            #  param but got `BaseExpression`.
            cst.DictComp(key=key, value=value, for_in=exp.for_in),
        )

    self.report(
        node,
        message_formatter.format(func=call_name),
        replacement=replacement,
    )
class SimpleCompTest(CSTNodeTest):
    """Tests for ``GeneratorExp``, ``ListComp`` and ``SetComp`` nodes.

    ``test_valid`` feeds node/code pairs (optionally with a parser and an
    expected code range) to ``validate_node``. ``test_invalid`` checks that
    constructing malformed nodes is rejected with the expected message.
    """

    @data_provider(
        [
            # simple GeneratorExp
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                "code": "(a for b in c)",
                "parser": parse_expression,
                "expected_position": CodeRange.create((1, 1), (1, 13)),
            },
            # simple ListComp
            {
                "node": cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                "code": "[a for b in c]",
                "parser": parse_expression,
                "expected_position": CodeRange.create((1, 0), (1, 14)),
            },
            # simple SetComp
            {
                "node": cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                "code": "{a for b in c}",
                "parser": parse_expression,
            },
            # async GeneratorExp
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        asynchronous=cst.Asynchronous(),
                    ),
                ),
                "code": "(a async for b in c)",
                "parser": parse_expression,
            },
            # a generator doesn't have to own its own parentheses
            {
                "node": cst.Call(
                    cst.Name("func"),
                    [
                        cst.Arg(
                            cst.GeneratorExp(
                                cst.Name("a"),
                                cst.CompFor(
                                    target=cst.Name("b"), iter=cst.Name("c")
                                ),
                                lpar=[],
                                rpar=[],
                            )
                        )
                    ],
                ),
                "code": "func(a for b in c)",
                "parser": parse_expression,
            },
            # add a few 'if' clauses
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(cst.Name("d")),
                            cst.CompIf(cst.Name("e")),
                            cst.CompIf(cst.Name("f")),
                        ],
                    ),
                ),
                "code": "(a for b in c if d if e if f)",
                "parser": parse_expression,
                "expected_position": CodeRange.create((1, 1), (1, 28)),
            },
            # nested/inner for-in clause
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"), iter=cst.Name("e")
                        ),
                    ),
                ),
                "code": "(a for b in c for d in e)",
                "parser": parse_expression,
            },
            # nested/inner for-in clause with an 'if' clause
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[cst.CompIf(cst.Name("d"))],
                        inner_for_in=cst.CompFor(
                            target=cst.Name("e"), iter=cst.Name("f")
                        ),
                    ),
                ),
                "code": "(a for b in c if d for e in f)",
                "parser": parse_expression,
            },
            # custom whitespace
            #
            # Every CompFor whitespace slot gets a distinct, non-default
            # width (5, 4, 3 and 2 spaces) so that a slot rendered in the
            # wrong place would change the generated code. With these widths
            # the expression (parentheses excluded) spans columns 2 to 30.
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before=cst.SimpleWhitespace("\t"),
                                whitespace_before_test=cst.SimpleWhitespace(
                                    "\t\t"
                                ),
                            )
                        ],
                        whitespace_before=cst.SimpleWhitespace("     "),
                        whitespace_after_for=cst.SimpleWhitespace("    "),
                        whitespace_before_in=cst.SimpleWhitespace("   "),
                        whitespace_after_in=cst.SimpleWhitespace("  "),
                    ),
                    lpar=[
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                    ],
                    rpar=[
                        cst.RightParen(
                            whitespace_before=cst.SimpleWhitespace("\f\f")
                        )
                    ],
                ),
                "code": "(\fa     for    b   in  c\tif\t\td\f\f)",
                "parser": parse_expression,
                "expected_position": CodeRange.create((1, 2), (1, 30)),
            },
            # custom whitespace around ListComp's brackets
            {
                "node": cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lbracket=cst.LeftSquareBracket(
                        whitespace_after=cst.SimpleWhitespace("\t")
                    ),
                    rbracket=cst.RightSquareBracket(
                        whitespace_before=cst.SimpleWhitespace("\t\t")
                    ),
                    lpar=[
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                    ],
                    rpar=[
                        cst.RightParen(
                            whitespace_before=cst.SimpleWhitespace("\f\f")
                        )
                    ],
                ),
                "code": "(\f[\ta for b in c\t\t]\f\f)",
                "parser": parse_expression,
                "expected_position": CodeRange.create((1, 2), (1, 19)),
            },
            # custom whitespace around SetComp's braces
            {
                "node": cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lbrace=cst.LeftCurlyBrace(
                        whitespace_after=cst.SimpleWhitespace("\t")
                    ),
                    rbrace=cst.RightCurlyBrace(
                        whitespace_before=cst.SimpleWhitespace("\t\t")
                    ),
                    lpar=[
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                    ],
                    rpar=[
                        cst.RightParen(
                            whitespace_before=cst.SimpleWhitespace("\f\f")
                        )
                    ],
                ),
                "code": "(\f{\ta for b in c\t\t}\f\f)",
                "parser": parse_expression,
            },
            # no whitespace between elements
            {
                "node": cst.GeneratorExp(
                    cst.Name("a", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]),
                    cst.CompFor(
                        target=cst.Name(
                            "b", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                        ),
                        iter=cst.Name(
                            "c", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                        ),
                        ifs=[
                            cst.CompIf(
                                cst.Name(
                                    "d",
                                    lpar=[cst.LeftParen()],
                                    rpar=[cst.RightParen()],
                                ),
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_before_test=cst.SimpleWhitespace(""),
                            )
                        ],
                        inner_for_in=cst.CompFor(
                            target=cst.Name(
                                "e",
                                lpar=[cst.LeftParen()],
                                rpar=[cst.RightParen()],
                            ),
                            iter=cst.Name(
                                "f",
                                lpar=[cst.LeftParen()],
                                rpar=[cst.RightParen()],
                            ),
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after_for=cst.SimpleWhitespace(""),
                            whitespace_before_in=cst.SimpleWhitespace(""),
                            whitespace_after_in=cst.SimpleWhitespace(""),
                        ),
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after_for=cst.SimpleWhitespace(""),
                        whitespace_before_in=cst.SimpleWhitespace(""),
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                    lpar=[cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "code": "((a)for(b)in(c)if(d)for(e)in(f))",
                "parser": parse_expression,
                "expected_position": CodeRange.create((1, 1), (1, 31)),
            },
            # no whitespace before/after GeneratorExp is valid
            {
                "node": cst.Comparison(
                    cst.GeneratorExp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.GeneratorExp(
                                cst.Name("d"),
                                cst.CompFor(
                                    target=cst.Name("e"), iter=cst.Name("f")
                                ),
                            ),
                        )
                    ],
                ),
                "code": "(a for b in c)is(d for e in f)",
                "parser": parse_expression,
            },
            # no whitespace before/after ListComp is valid
            {
                "node": cst.Comparison(
                    cst.ListComp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.ListComp(
                                cst.Name("d"),
                                cst.CompFor(
                                    target=cst.Name("e"), iter=cst.Name("f")
                                ),
                            ),
                        )
                    ],
                ),
                "code": "[a for b in c]is[d for e in f]",
                "parser": parse_expression,
            },
            # no whitespace before/after SetComp is valid
            {
                "node": cst.Comparison(
                    cst.SetComp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.SetComp(
                                cst.Name("d"),
                                cst.CompFor(
                                    target=cst.Name("e"), iter=cst.Name("f")
                                ),
                            ),
                        )
                    ],
                ),
                "code": "{a for b in c}is{d for e in f}",
                "parser": parse_expression,
            },
        ]
    )
    def test_valid(self, **kwargs: Any) -> None:
        # Each data-provider entry is forwarded unchanged as keyword
        # arguments to validate_node.
        self.validate_node(**kwargs)

    @data_provider(
        (
            # --- unbalanced parentheses on each comprehension kind ---
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            # --- missing mandatory whitespace around keywords ---
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        asynchronous=cst.Asynchronous(),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'async' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_after_for=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space after 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_before_in=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'in' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space after 'in' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before=cst.SimpleWhitespace(""),
                            )
                        ],
                    ),
                ),
                "Must have at least one space before 'if' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before_test=cst.SimpleWhitespace(""),
                            )
                        ],
                    ),
                ),
                "Must have at least one space after 'if' keyword.",
            ),
            # --- the same checks apply to a nested for-in clause ---
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"),
                            iter=cst.Name("e"),
                            whitespace_before=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "Must have at least one space before 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"),
                            iter=cst.Name("e"),
                            asynchronous=cst.Asynchronous(),
                            whitespace_before=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "Must have at least one space before 'async' keyword.",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        # get_node is a zero-argument factory so that node construction (and
        # any validation it triggers) happens inside assert_invalid.
        self.assert_invalid(get_node, expected_re)