def visit_Call(self, node: cst.Call) -> None:
    """Report ``list``/``set``/``dict`` calls that merely wrap a comprehension.

    A call is considered when its function is the bare name ``list``, ``set``
    or ``dict`` and it has exactly one argument whose value is a generator
    expression or a list comprehension.  The report carries an equivalent
    list, set or dict comprehension as the suggested replacement.

    For ``dict`` the comprehension element must be a tuple or list of exactly
    two plain (non-starred) items, used as key and value.  Any other element
    shape is not reported.
    """
    if not m.matches(
        node,
        m.Call(
            func=m.Name("list") | m.Name("set") | m.Name("dict"),
            args=[m.Arg(value=m.GeneratorExp() | m.ListComp())],
        ),
    ):
        return

    call_name = cst.ensure_type(node.func, cst.Name).value
    arg_value = node.args[0].value
    if m.matches(arg_value, m.GeneratorExp()):
        exp = cst.ensure_type(arg_value, cst.GeneratorExp)
        message_formatter = UNNECESSARY_GENERATOR
    else:
        exp = cst.ensure_type(arg_value, cst.ListComp)
        message_formatter = UNNECESSARY_LIST_COMPREHENSION

    if call_name == "list":
        replacement = node.deep_replace(
            node, cst.ListComp(elt=exp.elt, for_in=exp.for_in)
        )
    elif call_name == "set":
        replacement = node.deep_replace(
            node, cst.SetComp(elt=exp.elt, for_in=exp.for_in)
        )
    else:
        # The matcher above only admits list/set/dict, so this is "dict".
        elt = exp.elt
        # Require exactly two plain elements.  A bare positional DoNotCare
        # would bind to ``elements`` as a whole and accept any length, which
        # crashes on 1-tuples and silently drops items from longer ones.
        key_value_pair = [m.Element(), m.Element()]
        if m.matches(elt, m.Tuple(elements=key_value_pair)):
            pair = cst.ensure_type(elt, cst.Tuple).elements
        elif m.matches(elt, m.List(elements=key_value_pair)):
            pair = cst.ensure_type(elt, cst.List).elements
        else:
            # Unrecognized form
            return
        replacement = node.deep_replace(
            node,
            # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
            #  param but got `BaseExpression`.
            cst.DictComp(key=pair[0].value, value=pair[1].value, for_in=exp.for_in),
        )

    self.report(
        node, message_formatter.format(func=call_name), replacement=replacement
    )
def _dict_call(self, node: cst.Call) -> Union[cst.Call, cst.Dict, cst.DictComp]:
    """Return a dict display or comprehension equivalent to a ``dict(...)`` call.

    Rewrites performed:

    * no arguments -> empty ``cst.Dict``
    * a single dict comprehension argument -> that comprehension
    * a single list comprehension / generator expression whose element is a
      two-item tuple or list -> ``cst.DictComp``
    * a single tuple/list literal whose items are all two-item tuples or
      lists (or which is empty) -> ``cst.Dict``

    Anything else returns ``node`` unchanged.  That includes calls with more
    than one argument, a keyword or starred argument, or starred elements
    inside the literal or the pairs.
    """

    def split_pair(expr):
        # Return (key, value) when expr is a 2-item tuple/list of plain
        # elements, else None.  Starred items expand to an unknown number of
        # values, so they can never be treated as one half of a pair.
        if not isinstance(expr, (cst.Tuple, cst.List)):
            return None
        if len(expr.elements) != 2:
            return None
        first, second = expr.elements
        if not (isinstance(first, cst.Element) and isinstance(second, cst.Element)):
            return None
        return first.value, second.value

    if not node.args:
        return cst.Dict(elements=[])
    if len(node.args) != 1:
        return node

    arg = node.args[0]
    # ``dict(a=[...])`` and ``dict(*x)`` / ``dict(**x)`` do not pass the value
    # as the positional iterable-of-pairs, so none of the rewrites apply.
    if arg.keyword is not None or arg.star:
        return node

    value = arg.value
    if isinstance(value, cst.DictComp):
        return value

    if isinstance(value, (cst.ListComp, cst.GeneratorExp)):
        pair = split_pair(value.elt)
        if pair is None:
            return node
        return cst.DictComp(key=pair[0], value=pair[1], for_in=value.for_in)

    if isinstance(value, (cst.Tuple, cst.List)):
        elements = []
        for item in value.elements:
            pair = split_pair(item.value) if isinstance(item, cst.Element) else None
            if pair is None:
                # One non-pair item makes the whole literal unconvertible.
                return node
            elements.append(cst.DictElement(key=pair[0], value=pair[1]))
        # Also covers the empty literal: ``dict([])`` -> ``{}``.
        return cst.Dict(elements=elements)

    return node
class DictCompTest(CSTNodeTest):
    """Round-trip and validation cases for ``cst.DictComp``."""

    # Each case pairs a hand-built node with the source it must render to,
    # and the code range the node is expected to occupy.
    @data_provider(
        [
            # plain comprehension, all defaults
            {
                "code": "{k: v for a in b}",
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                ),
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 17)),
            },
            # non-default whitespace on both sides of the colon
            {
                "code": "{k\t:\t\tv for a in b}",
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                    whitespace_before_colon=cst.SimpleWhitespace("\t"),
                    whitespace_after_colon=cst.SimpleWhitespace("\t\t"),
                ),
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 19)),
            },
            # non-default whitespace just inside the braces
            {
                "code": "{\tk: v for a in b\t\t}",
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                    lbrace=cst.LeftCurlyBrace(
                        whitespace_after=cst.SimpleWhitespace("\t")
                    ),
                    rbrace=cst.RightCurlyBrace(
                        whitespace_before=cst.SimpleWhitespace("\t\t")
                    ),
                ),
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 20)),
            },
            # wrapped in one pair of parentheses
            {
                "code": "({k: v for a in b})",
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                    lpar=[cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 18)),
            },
            # a DictComp never needs surrounding spaces, even after `in`
            # or before `if`
            {
                "code": "{a: b for c in{d: e for f in g}if h}",
                "node": cst.DictComp(
                    key=cst.Name("a"),
                    value=cst.Name("b"),
                    for_in=cst.CompFor(
                        target=cst.Name("c"),
                        iter=cst.DictComp(
                            key=cst.Name("d"),
                            value=cst.Name("e"),
                            for_in=cst.CompFor(
                                target=cst.Name("f"), iter=cst.Name("g")
                            ),
                        ),
                        ifs=[
                            cst.CompIf(
                                test=cst.Name("h"),
                                whitespace_before=cst.SimpleWhitespace(""),
                            )
                        ],
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                ),
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 36)),
            },
            # `for` may directly follow a parenthesized value
            {
                "code": "{k: (v)for a in b}",
                "node": cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name(
                        "v", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                    ),
                    for_in=cst.CompFor(
                        target=cst.Name("a"),
                        iter=cst.Name("b"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 18)),
            },
        ]
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Each case lazily builds a node and gives a regex for the expected
    # error message.
    @data_provider(
        [
            # opening parenthesis with no matching close
            {
                "get_node": lambda: cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
                    lpar=[cst.LeftParen()],
                ),
                "expected_re": "left paren without right paren",
            },
            # no whitespace between a bare name and `for`
            {
                "get_node": lambda: cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(
                        target=cst.Name("a"),
                        iter=cst.Name("b"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "expected_re": "Must have at least one space before 'for' keyword.",
            },
            # no whitespace between a bare name and `async`
            {
                "get_node": lambda: cst.DictComp(
                    key=cst.Name("k"),
                    value=cst.Name("v"),
                    for_in=cst.CompFor(
                        target=cst.Name("a"),
                        iter=cst.Name("b"),
                        asynchronous=cst.Asynchronous(),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "expected_re": "Must have at least one space before 'async' keyword.",
            },
        ]
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)