def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.SubscriptElement:
    """
    Construct a MatchIfTrue type node appropriate for going into a Union.

    Produces the element ``MatchIfTrue[Callable[[<node type>], bool]]``, where
    ``<node type>`` is ``oldtype`` after ``_convert_match_nodes_to_cst_nodes``.
    """
    # MatchIfTrue takes in the original node type, and returns a boolean. So,
    # convert our quoted classes (forward refs to other matchers) back to the
    # CSTNode they refer to. We can do this because there's always a 1:1 name
    # mapping.
    argument_types = cst.List([cst.Element(_convert_match_nodes_to_cst_nodes(oldtype))])
    predicate_type = cst.Subscript(
        cst.Name("Callable"),
        slice=(
            cst.SubscriptElement(cst.Index(argument_types)),
            cst.SubscriptElement(cst.Index(cst.Name("bool"))),
        ),
    )
    wrapper = cst.Subscript(
        cst.Name("MatchIfTrue"),
        slice=(cst.SubscriptElement(cst.Index(predicate_type)),),
    )
    return cst.SubscriptElement(cst.Index(wrapper))
def annotation(self):
    """
    Render this type as ``Union[...]``, with one subscript element per entry
    of ``self.options`` (each contributing its own ``.annotation``).
    """
    members = [cst.SubscriptElement(cst.Index(option.annotation)) for option in self.options]
    return cst.Subscript(cst.Name("Union"), members)
def _get_wrapped_union_type(node: cst.BaseExpression, addition: cst.SubscriptElement,
                            *additions: cst.SubscriptElement) -> cst.Subscript:
    """
    Take two or more nodes, wrap them in a union type. Function signature is
    explicitly defined as taking at least one addition for type safety.

    :param node: bare expression; it becomes the first ``Union`` member.
    :param addition: an already-built subscript element, appended after ``node``.
    :param additions: further already-built subscript elements, appended in order.
    :returns: ``Union[node, addition, *additions]``.
    """
    # Use SubscriptElement (not the legacy ExtSlice name) so this builder
    # agrees with every other subscript construction in this file.
    return cst.Subscript(
        cst.Name("Union"),
        [cst.SubscriptElement(cst.Index(node)), addition, *additions],
    )
def annotation(self):
    """
    Render this mapping type.

    Returns the bare builtin ``dict`` when both key and value types are
    unknown, otherwise ``Dict[<key annotation>, <value annotation>]``.
    """
    if is_unknown(self.key) and is_unknown(self.value):
        return cst.Name("dict")
    key_element = cst.SubscriptElement(cst.Index(self.key.annotation))
    value_element = cst.SubscriptElement(cst.Index(self.value.annotation))
    return cst.Subscript(cst.Name("Dict"), [key_element, value_element])
def test_deprecated_non_element_construction(self) -> None:
    """
    A Subscript built with a bare ``cst.Index`` as its ``slice`` (no
    SubscriptElement wrapper -- the form this test's name calls deprecated)
    still renders as ``foo[1]``.
    """
    subscript = cst.Subscript(
        value=cst.Name(value="foo"),
        slice=cst.Index(value=cst.Integer(value="1")),
    )
    line = cst.SimpleStatementLine(body=[cst.Expr(value=subscript)])
    module = cst.Module(body=[line])
    self.assertEqual(module.code, "foo[1]\n")
def annotation(self) -> cst.BaseExpression:
    """
    Render this type as ``Literal[...]`` over ``self.options``, or as plain
    ``str`` when there are no options.

    Each option is emitted as a string node built from ``repr(option)``.
    """
    if not self.options:
        return cst.Name("str")
    literal_members = [
        cst.SubscriptElement(cst.Index(cst.SimpleString(repr(choice))))
        for choice in self.options
    ]
    return cst.Subscript(cst.Name("Literal"), literal_members)
def annotation(self):
    """
    Render this tuple type as a CST annotation.

    - ``self.items`` unknown: the bare builtin ``tuple``.
    - ``self.items`` is a tuple (fixed length): ``Tuple[A, B, ...]`` with one
      element per item; an empty tuple renders as ``Tuple[()]``.
    - otherwise (a single item type): ``Tuple[<item>, ...]``.
    """
    if is_unknown(self.items):
        return cst.Name("tuple")
    if isinstance(self.items, tuple):
        if self.items:
            elements = [cst.SubscriptElement(cst.Index(s.annotation)) for s in self.items]
        else:
            # cst.Subscript rejects an empty slice sequence ("empty
            # SubscriptElement"); the empty tuple type is spelled Tuple[()].
            elements = [cst.SubscriptElement(cst.Index(cst.Tuple([])))]
        return cst.Subscript(cst.Name("Tuple"), elements)
    return cst.helpers.parse_template_expression(
        "Tuple[{item}, ...]", parser_config, item=self.items.annotation)
def choice_ast(rng_key):
    """
    Build the CST for ``mcx.jax.choice(rng_key, <name>.shape[0])``, where
    ``<name>`` is ``nodes[0].name``.

    NOTE(review): ``nodes`` is not a parameter; it is resolved from the
    enclosing/global scope -- confirm it is always non-empty when called.
    """
    choice_callee = cst.Attribute(
        value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
        attr=cst.Name("choice"),
    )
    leading_dim = cst.Subscript(
        cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape")),
        [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
    )
    call_args = [cst.Arg(rng_key), cst.Arg(leading_dim)]
    return cst.Call(func=choice_callee, args=call_args)
def leave_SubscriptElement(self, original_node: cst.SubscriptElement,
                           updated_node: cst.SubscriptElement):
    """
    Rewrite one element of a parametric type annotation's subscript.

    Acts only while both ``self.type_annot_visited`` and
    ``self.parametric_type_annot_visited`` are set. The element's ``Index``
    value is matched, in this order, against:

    - a nested ``Subscript``: its base is looked up with
      ``__get_qualified_name`` and, if a name comes back, replaced by
      ``__name2annotation(q_name).annotation`` while keeping the updated
      inner slice;
    - ``...``: rebuilt as a fresh ``Index(Ellipsis())``;
    - a string literal or a list literal: rebuilt from ``updated_node``'s
      value, i.e. not resolved;
    - the name ``None``: ``original_node`` is returned;
    - anything else: looked up with ``__get_qualified_name`` and, if a name
      comes back, replaced by ``__name2annotation(q_name).annotation``.

    In every case where no rewrite happens, ``original_node`` is returned.
    """
    if self.type_annot_visited and self.parametric_type_annot_visited:
        if match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(
                    value=match.Subscript()))):
            # Nested generic such as the ``Y[Z]`` in ``X[Y[Z]]``: resolve only
            # its base (``.value``); the second lookup result is unused.
            q_name, _ = self.__get_qualified_name(
                original_node.slice.value.value)
            if q_name is not None:
                return updated_node.with_changes(slice=cst.Index(
                    value=cst.Subscript(
                        value=self.__name2annotation(q_name).annotation,
                        slice=updated_node.slice.value.slice)))
        elif match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(
                    value=match.Ellipsis()))):
            # TODO: Should the original node be returned?!
            return updated_node.with_changes(slice=cst.Index(
                value=cst.Ellipsis()))
        elif match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(
                    value=match.SimpleString(value=match.DoNotCare())))):
            # String literal (presumably a forward reference): kept as-is so it
            # does not fall through to the name-resolution branch below.
            return updated_node.with_changes(slice=cst.Index(
                value=updated_node.slice.value))
        elif match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(value=match.Name(
                    value='None')))):
            return original_node
        elif match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(
                    value=match.List()))):
            # List literal (presumably e.g. a ``Callable[[...], R]`` parameter
            # list): kept as-is rather than resolved as a name.
            return updated_node.with_changes(slice=cst.Index(
                value=updated_node.slice.value))
        else:
            # Any other element value: replace it wholesale with the annotation
            # registered for its qualified name, when one is found.
            q_name, _ = self.__get_qualified_name(
                original_node.slice.value)
            if q_name is not None:
                return updated_node.with_changes(slice=cst.Index(
                    value=self.__name2annotation(q_name).annotation))
    return original_node
def annotation(self):
    """
    Doesn't exist yet as generic type, but should
    https://github.com/python/typing/issues/159

    Returns the bare builtin ``slice`` when start, stop and step are all
    unknown; otherwise ``slice[<start>, <stop>, <step>]``, using each
    component's ``.annotation``.
    """
    if is_unknown(self.start) and is_unknown(self.stop) and is_unknown(
            self.step):
        return cst.Name("slice")
    return cst.Subscript(
        cst.Name("slice"),
        [
            cst.SubscriptElement(cst.Index(self.start.annotation)),
            cst.SubscriptElement(cst.Index(self.stop.annotation)),
            # Third type argument is the step type.
            cst.SubscriptElement(cst.Index(self.step.annotation)),
        ],
    )
def leave_Annotation(self, original_node: cst.Annotation,
                     updated_node: cst.Annotation):
    """
    Swap a visited annotation for the one registered under its qualified name.

    Only acts when ``self.type_annot_visited`` is set; that flag, and
    ``self.parametric_type_annot_visited`` when it is set, are cleared here.

    - Parametric annotation (``X[...]``): the base ``X`` is looked up with
      ``__get_qualified_name``; if a name comes back, the base is replaced by
      ``__name2annotation(q_name).annotation`` and the updated slice is kept.
    - Otherwise the whole annotation expression is looked up the same way and
      replaced outright.

    ``original_node`` is returned whenever no replacement is made.
    """
    if not self.type_annot_visited:
        return original_node
    self.type_annot_visited = False

    if self.parametric_type_annot_visited:
        self.parametric_type_annot_visited = False
        base_name, _ = self.__get_qualified_name(original_node.annotation.value)
        if base_name is None:
            return original_node
        generic = cst.Subscript(
            value=self.__name2annotation(base_name).annotation,
            slice=updated_node.annotation.slice)
        return updated_node.with_changes(annotation=generic)

    plain_name, _ = self.__get_qualified_name(original_node.annotation)
    if plain_name is None:
        return original_node
    return updated_node.with_changes(
        annotation=self.__name2annotation(plain_name).annotation)
def test_collect_targets():
    """
    ``collect_targets`` returns the targets of each assignment, in order:
    a plain name (``x``), a subscript (``x[0]``) and an attribute (``x.attr``).
    """
    tree = cst.parse_module('''
x = [0, 1]
x[0] = 1
x.attr = 2
''')
    x = cst.Name(value='x')
    x0 = cst.Subscript(
        value=x,
        slice=[cst.SubscriptElement(slice=cst.Index(value=cst.Integer('0')))],
    )
    xa = cst.Attribute(
        value=x,
        attr=cst.Name('attr'),
    )
    golds = x, x0, xa
    targets = list(collect_targets(tree))
    # zip() stops at the shorter input, so without this check a missing (or
    # extra) target would let the element-wise comparison pass silently.
    assert len(targets) == len(golds)
    assert all(t.deep_equals(g) for t, g in zip(targets, golds))
def leave_Annotation(self, original_node: cst.Annotation) -> None:
    """
    Report annotations that ``contains_union_with_none`` flags, offering an
    ``Optional[...]`` replacement when one is safe.

    A replacement is attached only if ``Optional`` is present in the node's
    scope and the union has exactly one non-``None`` member and at most one
    ``None``; otherwise the report carries ``replacement=None``.
    """
    if not self.contains_union_with_none(original_node):
        return
    scope = self.get_metadata(cst.metadata.ScopeProvider, original_node, None)
    replacement = None
    if scope is not None and "Optional" in scope:
        union = cst.ensure_type(original_node.annotation, cst.Subscript)
        none_count = 0
        kept_slices = []
        for element in union.slice:
            if m.matches(element, m.SubscriptElement(m.Index(m.Name("None")))):
                none_count += 1
            else:
                kept_slices.append(element.slice)
        if none_count <= 1 and len(kept_slices) == 1:
            replacement = original_node.with_changes(
                annotation=cst.Subscript(
                    value=cst.Name("Optional"),
                    slice=(cst.SubscriptElement(kept_slices[0]),),
                )
            )
    # TODO(T57106602) refactor lint replacement once extract exists
    self.report(original_node, replacement=replacement)
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """
    Yield one annotated declaration line (``name: <type>``) per entry of *p*.

    Entries are visited in ``sort_items`` order; names rejected by
    ``bad_name`` are skipped. With ``is_classvar`` the type is wrapped as
    ``ClassVar[<type>]``. Each line is preceded by a blank line plus one
    ``# ...`` comment line per item returned by ``metadata_lines(metadata)``.
    """
    for prop_name, entry in sort_items(p):
        if bad_name(prop_name):
            continue
        metadata, tp = entry
        type_expr = tp.annotation
        if is_classvar:
            type_expr = cst.Subscript(
                cst.Name("ClassVar"), [cst.SubscriptElement(cst.Index(type_expr))])
        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        ]
        declaration = cst.AnnAssign(cst.Name(prop_name), cst.Annotation(type_expr))
        yield cst.SimpleStatementLine(
            [declaration],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
def annotation(self):
    """
    Render this class-object type: ``Type[<name annotation>]`` when
    ``self.name`` is set, otherwise the bare builtin ``type``.
    """
    if self.name is not None:
        target = cst.SubscriptElement(cst.Index(self.name.annotation))
        return cst.Subscript(cst.Name("Type"), [target])
    return cst.Name("type")
def to_subscript_cst(value, *slice_elements):
    """
    Wrap *value* in a ``cst.Subscript`` whose slice is the tuple of
    *slice_elements* exactly as received (callers pass ready-made elements).
    """
    return cst.Subscript(value=value, slice=slice_elements)
def _add_generic(name: str, oldtype: cst.BaseExpression) -> cst.BaseExpression:
    """Return the CST for ``<name>[<oldtype>]`` (a one-argument generic)."""
    only_element = cst.SubscriptElement(slice=cst.Index(value=oldtype))
    return cst.Subscript(value=cst.Name(name), slice=(only_element,))
def sample_index(placeholder, idx):
    """Build the CST expression ``placeholder[idx]`` (a single index element)."""
    index_element = cst.SubscriptElement(slice=cst.Index(value=idx))
    return cst.Subscript(value=placeholder, slice=[index_element])
class AnnAssignTest(CSTNodeTest):
    """Rendering, parsing and validation tests for ``cst.AnnAssign``."""

    # Each valid case supplies the node, its expected source text, an optional
    # parser to round-trip through, and an optional expected CodeRange; the
    # dict is forwarded unchanged to validate_node.
    @data_provider((
        # Simple assignment creation case.
        {
            "node": cst.AnnAssign(cst.Name("foo"),
                                  cst.Annotation(cst.Name("str")),
                                  cst.Integer("5")),
            "code": "foo: str = 5",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 12)),
        },
        # Annotation creation without assignment
        {
            "node": cst.AnnAssign(cst.Name("foo"),
                                  cst.Annotation(cst.Name("str"))),
            "code": "foo: str",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 8)),
        },
        # Complex annotation creation
        {
            "node": cst.AnnAssign(
                cst.Name("foo"),
                cst.Annotation(
                    cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    )),
                cst.Integer("5"),
            ),
            "code": "foo: Optional[str] = 5",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 22)),
        },
        # Simple assignment parser case.
        {
            "node": cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code": "foo: str = 5\n",
            "parser": parse_statement,
            "expected_position": None,
        },
        # Annotation without assignment
        {
            "node": cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                value=None,
            ), )),
            "code": "foo: str\n",
            "parser": parse_statement,
            "expected_position": None,
        },
        # Complex annotation
        {
            "node": cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code": "foo: Optional[str] = 5\n",
            "parser": parse_statement,
            "expected_position": None,
        },
        # Whitespace test
        {
            "node": cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace(" "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
                value=cst.Integer("5"),
            ),
            "code": "foo : Optional[str] = 5",
            "parser": None,
            # NOTE(review): the "code" string above is 23 characters long,
            # and every other case here uses end column == length, yet this
            # expects end column 26. Runs of spaces in the SimpleWhitespace
            # values and "code" may have been collapsed to single spaces in
            # this copy -- confirm against upstream.
            "expected_position": CodeRange((1, 0), (1, 26)),
        },
        # Same whitespace layout, round-tripped through the statement parser.
        {
            "node": cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace(" "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
                value=cst.Integer("5"),
            ), )),
            "code": "foo : Optional[str] = 5\n",
            "parser": parse_statement,
            "expected_position": None,
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # An AssignEqual without a value must be rejected with this message.
    @data_provider(({
        "get_node": (lambda: cst.AnnAssign(
            target=cst.Name("foo"),
            annotation=cst.Annotation(cst.Name("str")),
            equal=cst.AssignEqual(),
            value=None,
        )),
        "expected_re": "Must have a value when specifying an AssignEqual.",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
def make_aref(name: str, idx: cst.BaseExpression) -> cst.Subscript:
    """Build the array-reference expression ``<name>[<idx>]``."""
    return cst.Subscript(
        value=cst.Name(name),
        slice=[cst.SubscriptElement(slice=cst.Index(value=idx))],
    )
def make_index(arr, idx):
    """Return the CST for ``arr[idx]`` with a single index element."""
    element = cst.SubscriptElement(slice=cst.Index(value=idx))
    return cst.Subscript(value=arr, slice=[element])
class SubscriptTest(CSTNodeTest):
    """
    Rendering, parsing, position and validation tests for ``cst.Subscript``
    and its ``Index`` / ``Slice`` / ``SubscriptElement`` children.
    """

    # Each valid case is (node, expected code, check_parsing[, expected
    # CodeRange]). When check_parsing is True, test_valid also hands
    # parse_expression to validate_node.
    @data_provider((
        # Simple subscript expression
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
            ),
            "foo[5]",
            True,
        ),
        # Test creation of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=cst.Integer("2"),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            upper=cst.Integer("2"),
                            step=cst.Integer("3"),
                        )),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3, 5]",
            False,
            CodeRange((1, 0), (1, 13)),
        ),
        # Test parsing of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(),
                    ),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3,5]",
            True,
        ),
        # Some more wild slice creations
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=None, step=cst.Integer("3"))), ),
            ),
            "foo[::3]",
            False,
            CodeRange((1, 0), (1, 8)),
        ),
        # Some more wild slice parsings
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[::3]",
            True,
        ),
        # Valid list clone operations rendering
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # Valid list clone operations parsing
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # In parenthesis
        (
            cst.Subscript(
                lpar=(cst.LeftParen(), ),
                value=cst.Name("foo"),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "(foo[5])",
            True,
        ),
        # Verify spacing
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 5 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        step=cst.Integer("3"),
                    )), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(
                    cst.SubscriptElement(
                        slice=cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    cst.SubscriptElement(slice=cst.Index(cst.Integer("5"))),
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 , 5 ] )",
            True,
            # NOTE(review): in the literal as shown, "foo" starts at column 2
            # and ']' ends at column 23 (end-exclusive, matching the other
            # cases here), yet the expected end is column 24. A run of spaces
            # may have been collapsed in this copy -- confirm against upstream.
            CodeRange((1, 2), (1, 24)),
        ),
        # Test Index, Slice, SubscriptElement
        (cst.Index(cst.Integer("5")), "5", False, CodeRange((1, 0), (1, 1))),
        (
            cst.Slice(lower=None, upper=None, second_colon=cst.Colon(), step=None),
            "::",
            False,
            CodeRange((1, 0), (1, 2)),
        ),
        (
            cst.SubscriptElement(
                slice=cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    step=cst.Integer("3"),
                ),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
            ),
            "1 : 2 : 3 , ",
            False,
            # The expected range stops at column 9, i.e. it excludes the comma.
            CodeRange((1, 0), (1, 9)),
        ),
    ))
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        check_parsing: bool,
        position: Optional[CodeRange] = None,
    ) -> None:
        if check_parsing:
            self.validate_node(node, code, parse_expression, expected_position=position)
        else:
            self.validate_node(node, code, expected_position=position)

    # Each invalid case is (node factory, regex expected in the validation error).
    @data_provider((
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        (lambda: cst.Subscript(cst.Name("foo"), ()), "empty SubscriptElement"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        self.assert_invalid(get_node, expected_re)
def sample_index(rv, returned_var, *_):
    """
    Build the CST expression ``<rv>['<returned_var>']``.

    :param rv: name of the indexed variable (rendered as a ``cst.Name``).
    :param returned_var: key, rendered inside single quotes as a string literal.
    Any further positional arguments are accepted and ignored.
    """
    key = cst.SimpleString(f"'{returned_var}'")
    # A SubscriptElement's slice must be a slice node (Index/Slice), not a bare
    # expression; wrap the key in cst.Index like the other subscript builders.
    return cst.Subscript(
        cst.Name(rv),
        [cst.SubscriptElement(cst.Index(key))],
    )