def _list_call(self, node: cst.Call) -> Union[cst.Call, cst.List, cst.ListComp]:
    """
    Collapse a call into the equivalent list display or list comprehension.

    NOTE(review): nothing here checks ``node.func``; presumably the caller has
    already established that ``node`` is a call to the builtin ``list``.

    Rewrites performed:

    * no arguments             -> empty ``cst.List``
    * one list comprehension   -> that comprehension, unchanged
    * one generator expression -> ``cst.ListComp`` with the same ``elt``/``for_in``
    * one list display         -> that list, unchanged
    * one tuple display        -> ``cst.List`` holding the tuple's elements

    Anything else (several arguments, a starred or keyword argument, or any
    other kind of expression) is returned untouched.

    :param node: the call node to simplify.
    :returns: the replacement node, or ``node`` itself when no rewrite applies.
    """
    if not node.args:
        return cst.List(elements=[])
    if len(node.args) != 1:
        return node

    arg = node.args[0]
    # ``list(*[x])`` means ``list(x)``, not ``[x]``: a starred (or keyword)
    # argument is not the iterable being converted, so rewriting it would
    # change the meaning of the program. Leave such calls alone.
    if arg.star or arg.keyword is not None:
        return node

    value = arg.value
    if isinstance(value, (cst.ListComp, cst.List)):
        # Already a list expression; the surrounding call is redundant.
        return value
    if isinstance(value, cst.GeneratorExp):
        return cst.ListComp(elt=value.elt, for_in=value.for_in)
    if isinstance(value, cst.Tuple):
        return cst.List(elements=value.elements)
    return node
def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.SubscriptElement:
    """
    Construct a MatchIfTrue type node appropriate for going into a Union.

    The produced element has the shape
    ``MatchIfTrue[Callable[[<converted oldtype>], bool]]``.
    """
    # MatchIfTrue takes in the original node type, and returns a boolean.
    # So, convert our quoted classes (forward refs to other matchers) back
    # to the CSTNode they refer to. We can do this because there's always
    # a 1:1 name mapping.
    callback_param = cst.Element(_convert_match_nodes_to_cst_nodes(oldtype))

    # Callable[[<node type>], bool]
    callable_type = cst.Subscript(
        cst.Name("Callable"),
        slice=(
            cst.SubscriptElement(cst.Index(cst.List([callback_param]))),
            cst.SubscriptElement(cst.Index(cst.Name("bool"))),
        ),
    )

    # MatchIfTrue[Callable[...]]
    match_if_true_type = cst.Subscript(
        cst.Name("MatchIfTrue"),
        slice=(cst.SubscriptElement(cst.Index(callable_type)),),
    )

    return cst.SubscriptElement(cst.Index(match_if_true_type))
def to_list_cst(*list_elements):
    """
    Wrap the positional arguments in a ``cst.List`` node.

    The arguments are forwarded unchanged (as the varargs tuple) as the
    list's ``elements``; no wrapping or conversion is applied to them here.
    """
    return cst.List(elements=list_elements)
def visit_Call(self, node: cst.Call) -> None:
    """
    Report ``tuple``/``list``/``set``/``dict`` calls that can be written as a
    literal, attaching the literal as the suggested replacement.

    Two shapes are handled:

    * a call whose single positional argument is a list or tuple display,
      e.g. ``list((1, 2))``, ``set([1])`` or ``dict([(k, v)])``;
    * an argument-less ``tuple()``, ``list()`` or ``dict()`` call
      (``set()`` is skipped because there is no empty-set literal).

    ``dict`` calls are only rewritten when every element of the argument is
    itself a two-item list or tuple. ``set([])`` / ``set(())`` is reported
    with ``set()`` as the replacement.

    The message templates ``UNNCESSARY_CALL`` and ``UNNECESSARY_LITERAL`` are
    defined elsewhere in the module (the spelling of the first is the
    existing name and is kept as-is).

    :param node: the call node being visited.
    """
    literal_arg_call = m.Call(
        func=m.Name("tuple") | m.Name("list") | m.Name("set") | m.Name("dict"),
        args=[m.Arg(value=m.List() | m.Tuple())],
    )
    empty_call = m.Call(
        func=m.Name("tuple") | m.Name("list") | m.Name("dict"),
        args=[],
    )
    if not (m.matches(node, literal_arg_call) or m.matches(node, empty_call)):
        return

    exp = cst.ensure_type(node, cst.Call)
    call_name = cst.ensure_type(exp.func, cst.Name).value

    if not exp.args:
        # If this is an empty call, it's an Unnecessary Call where we rewrite
        # the call to a literal, except set().
        elements = []
        message_formatter = UNNCESSARY_CALL
    else:
        first_arg = exp.args[0]
        # The matcher above only constrains the argument's *value*, so it
        # also accepts ``dict(items=[])`` and ``list(*[x])``. Those are not
        # conversions of a literal (``dict(items=[])`` is ``{"items": []}``,
        # not ``{}``), so rewriting them would change behavior.
        if first_arg.keyword is not None or first_arg.star:
            return
        arg = first_arg.value
        elements = cst.ensure_type(
            arg, cst.List if isinstance(arg, cst.List) else cst.Tuple
        ).elements
        message_formatter = UNNECESSARY_LITERAL

    if call_name == "tuple":
        new_node = cst.Tuple(elements=elements)
    elif call_name == "list":
        new_node = cst.List(elements=elements)
    elif call_name == "set":
        # set() doesn't have an equivalent literal. If it was matched here
        # with an empty argument, suggest dropping the unnecessary literal.
        if not elements:
            self.report(
                node,
                UNNECESSARY_LITERAL.format(func=call_name),
                replacement=node.deep_replace(
                    node, cst.Call(func=cst.Name("set"))
                ),
            )
            return
        new_node = cst.Set(elements=elements)
    else:
        # Only ``dict`` remains. A non-empty argument must consist solely of
        # two-item lists/tuples, each of which becomes one key/value pair.
        pairs_matcher = m.ZeroOrMore(
            m.Element(m.Tuple(elements=[m.DoNotCare(), m.DoNotCare()]))
            | m.Element(m.List(elements=[m.DoNotCare(), m.DoNotCare()]))
        )
        if elements and not m.matches(
            exp.args[0].value,
            m.Tuple(elements=[pairs_matcher]) | m.List(elements=[pairs_matcher]),
        ):
            # Unrecognized form
            return

        dict_elements = []
        for ele in elements:
            pair = cst.ensure_type(
                ele.value,
                cst.Tuple if isinstance(ele.value, cst.Tuple) else cst.List,
            )
            dict_elements.append(
                cst.DictElement(pair.elements[0].value, pair.elements[1].value)
            )
        new_node = cst.Dict(elements=dict_elements)

    self.report(
        node,
        message_formatter.format(func=call_name),
        replacement=node.deep_replace(node, new_node),
    )
class ListTest(CSTNodeTest):
    # A lot of Element/StarredElement tests are provided by the tests for Tuple,
    # so we don't need to duplicate them here.
    @data_provider([
        # zero-element list
        {"node": cst.List([]), "code": "[]", "parser": parse_expression},
        # one-element list, sentinel comma value
        {
            "node": cst.List([cst.Element(cst.Name("single_element"))]),
            "code": "[single_element]",
            "parser": parse_expression,
        },
        # custom whitespace between brackets
        {
            "node": cst.List(
                [cst.Element(cst.Name("single_element"))],
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace("\t")
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")
                ),
            ),
            "code": "[\tsingle_element ]",
            "parser": parse_expression,
            # The code above is 18 characters long ("[" + tab + 14-character
            # name + space + "]"), so the node ends at column 18. As in the
            # other cases here, the end column equals the code length.
            "expected_position": CodeRange.create((1, 0), (1, 18)),
        },
        # two-element list, sentinel comma value
        {
            "node": cst.List(
                [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))]
            ),
            "code": "[one, two]",
            "parser": None,
        },
        # with parenthesis
        {
            "node": cst.List(
                [cst.Element(cst.Name("one"))],
                lpar=[cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "code": "([one])",
            "parser": None,
            # The expected range covers "[one]" only, not the parentheses.
            "expected_position": CodeRange.create((1, 1), (1, 6)),
        },
        # starred element
        {
            "node": cst.List([
                cst.StarredElement(cst.Name("one")),
                cst.StarredElement(cst.Name("two")),
            ]),
            "code": "[*one, *two]",
            "parser": None,
            "expected_position": CodeRange.create((1, 0), (1, 12)),
        },
        # missing spaces around list, always okay
        {
            "node": cst.For(
                target=cst.List([
                    cst.Element(cst.Name("k"), comma=cst.Comma()),
                    cst.Element(cst.Name("v")),
                ]),
                iter=cst.Name("abc"),
                body=cst.SimpleStatementSuite([cst.Pass()]),
                whitespace_after_for=cst.SimpleWhitespace(""),
                whitespace_before_in=cst.SimpleWhitespace(""),
            ),
            "code": "for[k,v]in abc: pass\n",
            "parser": parse_statement,
        },
    ])
    def test_valid(self, **kwargs: Any) -> None:
        # Each data-provider entry above is forwarded as keyword arguments.
        self.validate_node(**kwargs)

    @data_provider((
        (
            # Two opening parentheses but only one closing parenthesis.
            lambda: cst.List(
                [cst.Element(cst.Name("mismatched"))],
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        # The node is built lazily via get_node; expected_re is the pattern
        # passed along to assert_invalid.
        self.assert_invalid(get_node, expected_re)
def make_list(elts):
    """
    Build a ``cst.List`` node from an iterable of expression values.

    Each item of ``elts`` is wrapped in its own ``cst.Element`` (passed as the
    element's ``value``), preserving iteration order.
    """
    wrapped = []
    for expression in elts:
        wrapped.append(cst.Element(value=expression))
    return cst.List(elements=wrapped)