def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.SubscriptElement:
    """
    Build a ``MatchIfTrue[Callable[[<node type>], bool]]`` entry that can be
    appended to the slice of a ``Union`` annotation.
    """
    # MatchIfTrue receives the original node type and returns a boolean, so
    # the quoted classes (forward refs to other matchers) are mapped back to
    # the CSTNode they refer to. This works because there's always a 1:1
    # name mapping.
    node_type = _convert_match_nodes_to_cst_nodes(oldtype)
    argument_list = cst.List([cst.Element(node_type)])
    predicate_type = cst.Subscript(
        cst.Name("Callable"),
        slice=(
            cst.SubscriptElement(cst.Index(argument_list)),
            cst.SubscriptElement(cst.Index(cst.Name("bool"))),
        ),
    )
    match_if_true = cst.Subscript(
        cst.Name("MatchIfTrue"),
        slice=(cst.SubscriptElement(cst.Index(predicate_type)),),
    )
    return cst.SubscriptElement(cst.Index(match_if_true))
def leave_Subscript(
    self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.Subscript:
    """
    Widen every ``Union[...]`` annotation with ``OneOf``/``AllOf`` entries and
    a ``MatchIfTrue`` entry. Subscripts of anything other than the name
    ``Union`` are returned unchanged.
    """
    if not updated_node.value.deep_equals(cst.Name("Union")):
        return updated_node

    # Concrete types only: the original node with the do-not-care removed.
    concrete_types = _remove_do_not_care(original_node)
    # The current (already transformed) subscript, do-not-care removed,
    # extended with a MatchIfTrue entry.
    widened_types = _add_match_if_true(
        _remove_do_not_care(updated_node), concrete_types
    )

    # OneOf/AllOf are widened to take all of the original SomeTypes and also
    # a MatchIfTrue, so one can write OneOf(SomeType(), MatchIfTrue(lambda)).
    # Forbidding MatchIfTrue inside OneOf/AllOf would force a recursive
    # matches() inside the lambda to mix the two, which is very ugly.
    one_of_entry = cst.ExtSlice(cst.Index(_add_generic("OneOf", widened_types)))
    all_of_entry = cst.ExtSlice(cst.Index(_add_generic("AllOf", widened_types)))

    # Built from the original node so MatchIfTrue does not pick up changes
    # made to nested Unions: its Callable must take only CST nodes (never
    # OneOf/AllOf/MatchIfTrue values) and return a boolean.
    predicate_entry = _get_match_if_true(concrete_types)

    return updated_node.with_changes(
        slice=[*updated_node.slice, one_of_entry, all_of_entry, predicate_entry]
    )
def annotation(self):
    """
    Return ``Dict[K, V]`` built from the key and value annotations, or bare
    ``dict`` when both the key and the value are unknown.
    """
    if is_unknown(self.key) and is_unknown(self.value):
        return cst.Name("dict")
    key_and_value = [
        cst.SubscriptElement(cst.Index(component.annotation))
        for component in (self.key, self.value)
    ]
    return cst.Subscript(cst.Name("Dict"), key_and_value)
def leave_Subscript(self, original_node: cst.Subscript, updated_node: cst.Subscript) -> cst.Subscript:
    """
    Clear bookkeeping for ``original_node``. If the node was queued in
    ``self.fixup_nodes``, append ``AtLeastN[<node>]`` and ``AtMostN[<node>]``
    entries (built from the original node) to the slice.
    """
    pending_match = self.in_match_if_true
    if original_node in pending_match:
        pending_match.remove(original_node)

    if original_node not in self.fixup_nodes:
        return updated_node

    self.fixup_nodes.remove(original_node)
    count_wrappers = [
        cst.SubscriptElement(cst.Index(_add_generic(wrapper, original_node)))
        for wrapper in ("AtLeastN", "AtMostN")
    ]
    return updated_node.with_changes(slice=[*updated_node.slice, *count_wrappers])
def leave_SubscriptElement(self, original_node, updated_node):
    """
    Collapse a subscript element to ``Any`` once the annotation depth budget
    is exceeded.

    When ``self.max_annot`` is above ``self.max_annot_depth`` and the current
    depth equals ``self.max_annot_depth``, decrement ``self.max_annot`` and
    replace the element's slice with ``Any``. Otherwise return the element
    untouched.
    """
    over_budget = self.max_annot > self.max_annot_depth
    if over_budget and self.current_annot_depth == self.max_annot_depth:
        self.max_annot -= 1
        any_type = cst.Name(value='Any')
        return updated_node.with_changes(slice=cst.Index(value=any_type))
    return updated_node
def leave_Subscript(
    self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.Subscript:
    """
    Add ``DoNotCareSentinel`` to ``Sequence[...]`` annotations whose slice is
    a single ``cst.Index``.

    ``Sequence[Union[...]]`` gets the sentinel appended to the inner Union.
    Any other ``Sequence[T]`` becomes ``Sequence[Union[T, DoNotCareSentinel]]``.
    Any other slice shape raises.
    """
    if not updated_node.value.deep_equals(cst.Name("Sequence")):
        return updated_node

    nodeslice = updated_node.slice
    if not isinstance(nodeslice, cst.Index):
        raise Exception("Unexpected slice type for Sequence!")

    element_type = nodeslice.value
    is_union = isinstance(element_type, cst.Subscript) and element_type.value.deep_equals(
        cst.Name("Union")
    )
    if is_union:
        # Special case for Sequence[Union] so the generated types stay
        # collapsed rather than nesting a second Union.
        widened_union = element_type.with_changes(
            slice=[*element_type.slice, _get_do_not_care()]
        )
        return updated_node.with_changes(
            slice=nodeslice.with_changes(value=widened_union)
        )

    # A sequence of some node: allow a do-not-care alongside otherwise valid
    # matcher nodes.
    wrapped = _get_wrapped_union_type(element_type, _get_do_not_care())
    return updated_node.with_changes(slice=cst.Index(wrapped))
def leave_Subscript(self, original_node: cst.Subscript, updated_node: cst.Subscript) -> cst.Subscript:
    """
    Append ``OneOf[...]`` and ``AllOf[...]`` entries to every ``Union[...]``
    subscript. Both are built from the Union with ``DoNotCareSentinel``
    removed.
    """
    if not updated_node.value.deep_equals(cst.Name("Union")):
        return updated_node

    # Strip DoNotCareSentinel so the combinators only accept concrete types.
    # NOTE(review): an earlier comment claimed the *original* node is used
    # here to discard nested changes, but the code strips from updated_node
    # (nested changes are kept). Confirm which behaviour is intended.
    concrete_only_expr = _remove_types(updated_node, ["DoNotCareSentinel"])
    combinator_entries = [
        cst.SubscriptElement(cst.Index(_add_generic(combinator, concrete_only_expr)))
        for combinator in ("OneOf", "AllOf")
    ]
    return updated_node.with_changes(slice=[*updated_node.slice, *combinator_entries])
def test_deprecated_construction(self) -> None:
    """A Subscript built from legacy ``ExtSlice`` elements still renders."""
    legacy_elements = [
        cst.ExtSlice(slice=cst.Index(value=cst.Integer(value=digit)))
        for digit in ("1", "2")
    ]
    subscript = cst.Subscript(value=cst.Name(value="foo"), slice=legacy_elements)
    statement = cst.SimpleStatementLine(body=[cst.Expr(value=subscript)])
    module = cst.Module(body=[statement])
    self.assertEqual(module.code, "foo[1, 2]\n")
def leave_SubscriptElement(self, original_node: cst.SubscriptElement, updated_node: cst.SubscriptElement):
    """
    Rewrite the slice of a subscript element inside a parametric type
    annotation.

    Only active while both ``self.type_annot_visited`` and
    ``self.parametric_type_annot_visited`` are truthy. Dispatch is on the
    shape of ``original_node.slice``:

    * ``Index(Subscript)``: resolve the subscripted value with
      ``__get_qualified_name``. If a name is found, replace it with the
      annotation from ``__name2annotation``, keeping ``updated_node``'s inner
      slice.
    * ``Index(Ellipsis)``: re-wrap a fresh ``Ellipsis``.
    * ``Index(SimpleString)`` / ``Index(List)``: re-wrap ``updated_node``'s
      slice value in a fresh ``Index``.
    * ``Index(Name("None"))``: return ``original_node``.
    * anything else: resolve ``original_node.slice.value``. If a name is
      found, replace it with the resolved annotation.

    In every remaining case ``original_node`` (not ``updated_node``) is
    returned.
    """
    if self.type_annot_visited and self.parametric_type_annot_visited:
        if match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(
                    value=match.Subscript()))):
            # Resolve the generic's own name (e.g. the ``X`` in ``X[...]``).
            q_name, _ = self.__get_qualified_name(
                original_node.slice.value.value)
            if q_name is not None:
                return updated_node.with_changes(slice=cst.Index(
                    value=cst.Subscript(
                        value=self.__name2annotation(q_name).annotation,
                        slice=updated_node.slice.value.slice)))
        elif match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(
                    value=match.Ellipsis()))):
            # TODO: Should the original node be returned?!
            return updated_node.with_changes(slice=cst.Index(
                value=cst.Ellipsis()))
        elif match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(
                    value=match.SimpleString(value=match.DoNotCare())))):
            return updated_node.with_changes(slice=cst.Index(
                value=updated_node.slice.value))
        elif match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(value=match.Name(
                    value='None')))):
            return original_node
        elif match.matches(
                original_node,
                match.SubscriptElement(slice=match.Index(
                    value=match.List()))):
            return updated_node.with_changes(slice=cst.Index(
                value=updated_node.slice.value))
        else:
            # NOTE(review): assumes the slice is an Index; a cst.Slice has no
            # ``value`` attribute and would raise AttributeError here.
            q_name, _ = self.__get_qualified_name(
                original_node.slice.value)
            if q_name is not None:
                return updated_node.with_changes(slice=cst.Index(
                    value=self.__name2annotation(q_name).annotation))
    return original_node
def annotation(self):
    """Return ``Union[...]`` built from the annotation of each entry in ``self.options``."""
    union_members = []
    for member in self.options:
        union_members.append(cst.SubscriptElement(cst.Index(member.annotation)))
    return cst.Subscript(cst.Name("Union"), union_members)
def annotation(self):
    """
    Return the annotation for a slice object.

    Bare ``slice`` is returned when start, stop and step are all unknown.
    Otherwise the result is ``slice[<start>, <stop>, <step>]``.

    Doesn't exist yet as generic type, but should
    https://github.com/python/typing/issues/159
    """
    components = (self.start, self.stop, self.step)
    if all(is_unknown(component) for component in components):
        return cst.Name("slice")
    # The third parameter describes the step; earlier code repeated
    # ``self.start`` here by mistake.
    return cst.Subscript(
        cst.Name("slice"),
        [
            cst.SubscriptElement(cst.Index(component.annotation))
            for component in components
        ],
    )
def _get_wrapped_union_type(node: cst.BaseExpression, addition: cst.ExtSlice, *additions: cst.ExtSlice) -> cst.Subscript:
    """
    Return ``Union[node, addition, *additions]``.

    ``addition`` is a required positional parameter, so the signature
    enforces at least one extra member for type safety.
    """
    members = [cst.ExtSlice(cst.Index(node)), addition]
    members.extend(additions)
    return cst.Subscript(cst.Name("Union"), members)
def annotation(self) -> cst.BaseExpression:
    """
    Return ``Literal[...]`` over the ``repr`` of each option, or plain
    ``str`` when there are no options.
    """
    if not self.options:
        return cst.Name("str")
    choices = [
        cst.SubscriptElement(cst.Index(cst.SimpleString(repr(choice))))
        for choice in self.options
    ]
    return cst.Subscript(cst.Name("Literal"), choices)
def annotation(self):
    """
    Return the type annotation for this tuple type.

    * unknown items: bare ``tuple``;
    * a tuple of item types: ``Tuple[A, B, ...]`` (``Tuple[()]`` when there
      are no items);
    * otherwise: a homogeneous variable-length ``Tuple[<item>, ...]``.
    """
    if is_unknown(self.items):
        return cst.Name("tuple")
    if isinstance(self.items, tuple):
        if not self.items:
            # cst.Subscript rejects an empty slice ("empty SubscriptElement")
            # and ``Tuple[]`` is not valid syntax; the empty tuple type is
            # spelled ``Tuple[()]``.
            elements = [cst.SubscriptElement(cst.Index(cst.Tuple([])))]
        else:
            elements = [
                cst.SubscriptElement(cst.Index(s.annotation))
                for s in self.items
            ]
        return cst.Subscript(cst.Name("Tuple"), elements)
    return cst.helpers.parse_template_expression(
        "Tuple[{item}, ...]", parser_config, item=self.items.annotation)
def choice_ast(rng_key):
    """
    Build the call ``mcx.jax.choice(<rng_key>, <first node>.shape[0])``.

    NOTE(review): ``nodes`` is not a parameter; it is resolved from an
    enclosing or module scope.
    """
    choice_fn = cst.Attribute(
        value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
        attr=cst.Name("choice"),
    )
    shape_attr = cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape"))
    leading_dim = cst.Subscript(
        shape_attr,
        [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
    )
    call_args = [cst.Arg(rng_key), cst.Arg(leading_dim)]
    return cst.Call(func=choice_fn, args=call_args)
def flatten_union_subscript(self, _, updated_node):
    """
    Flatten nested ``Optional[...]`` and same-named ``Union[...]`` members
    into the enclosing subscript.

    ``None`` members, including the ``None`` implied by ``Optional``, are
    dropped where they occur and a single ``None`` is appended at the end.
    """
    flattened = []
    saw_none = False
    for element in updated_node.slice:
        member = element.slice.value
        if m.matches(member, m.Subscript(m.Name("Optional"))):
            flattened += member.slice  # peel off "Optional"
            saw_none = True
        elif m.matches(member, m.Subscript(m.Name("Union"))) and m.matches(
                updated_node.value, member.value):
            # Only Union is peeled: the matcher above requires the name Union.
            flattened += member.slice
        elif m.matches(member, m.Name("None")):
            saw_none = True
        else:
            flattened.append(element)
    if saw_none:
        flattened.append(
            cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
    return updated_node.with_changes(slice=flattened)
def test_collect_targets():
    """collect_targets returns the name, subscript and attribute targets, in order."""
    tree = cst.parse_module('''
x = [0, 1]
x[0] = 1
x.attr = 2
''')
    x = cst.Name(value='x')
    x0 = cst.Subscript(
        value=x,
        slice=[cst.SubscriptElement(slice=cst.Index(value=cst.Integer('0')))],
    )
    xa = cst.Attribute(
        value=x,
        attr=cst.Name('attr'),
    )
    golds = x, x0, xa
    # Materialise the result so its length can be checked: zip() stops at the
    # shorter input, so an empty or truncated result would otherwise pass.
    targets = list(collect_targets(tree))
    assert len(targets) == len(golds), targets
    for target, gold in zip(targets, golds):
        assert target.deep_equals(gold), (target, gold)
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """
    Yield one ``name: <annotation>`` declaration per entry of ``p``.

    Entries are ordered by ``sort_items``, and names rejected by ``bad_name``
    are skipped. Each declaration is preceded by a blank line and one ``# ``
    comment line per line from ``metadata_lines(metadata)``. With
    ``is_classvar`` the annotation is wrapped in ``ClassVar[...]``.
    """
    for name, metadata_and_tp in sort_items(p):
        if bad_name(name):
            continue
        metadata, tp = metadata_and_tp
        declared_type = tp.annotation
        if is_classvar:
            declared_type = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(declared_type))],
            )
        comment_lines = [
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        ]
        declaration = cst.AnnAssign(cst.Name(name), cst.Annotation(declared_type))
        yield cst.SimpleStatementLine(
            [declaration],
            leading_lines=[cst.EmptyLine(), *comment_lines],
        )
def leave_Subscript(self, original_node: cst.Subscript, updated_node: cst.Subscript) -> cst.Subscript:
    """
    Add ``DoNotCareSentinel`` and ``MetadataMatchType`` to ``Sequence[...]``
    annotations.

    ``Sequence[Union[...]]`` gets both appended to the inner Union. Any other
    ``Sequence[T]`` becomes ``Sequence[Union[T, DoNotCareSentinel,
    MetadataMatchType]]``. Raises if the Sequence does not have exactly one
    element or its element is not an Index.
    """
    if not updated_node.value.deep_equals(cst.Name("Sequence")):
        return updated_node

    slc = updated_node.slice
    # TODO: We can remove the instance check after ExtSlice is deprecated.
    if not isinstance(slc, Sequence) or len(slc) != 1:
        raise Exception(
            "Unexpected number of sequence elements inside Sequence type " +
            "annotation!")

    nodeslice = slc[0].slice
    if not isinstance(nodeslice, cst.Index):
        raise Exception("Unexpected slice type for Sequence!")

    element_type = nodeslice.value
    if isinstance(element_type, cst.Subscript) and element_type.value.deep_equals(
            cst.Name("Union")):
        # Special case for Sequence[Union] so the generated types stay
        # collapsed rather than nesting a second Union.
        return updated_node.with_deep_changes(
            element_type,
            slice=[
                *element_type.slice,
                _get_do_not_care(),
                _get_match_metadata(),
            ],
        )

    # A sequence of some node: allow a do-not-care (and metadata matchers)
    # alongside otherwise valid matcher nodes.
    wrapped = _get_wrapped_union_type(
        element_type,
        _get_do_not_care(),
        _get_match_metadata(),
    )
    return updated_node.with_changes(
        slice=(cst.SubscriptElement(cst.Index(wrapped)),))
class SubscriptTest(CSTNodeTest):
    """
    Tests for ``cst.Subscript`` and the ``Index`` / ``Slice`` /
    ``SubscriptElement`` nodes that make up its slice.
    """

    # Each case is (node, expected code, check_parsing[, expected position]).
    # When check_parsing is False, no parser is passed to validate_node (see
    # test_valid below).
    @data_provider((
        # Simple subscript expression
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
            ),
            "foo[5]",
            True,
        ),
        # Test creation of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=cst.Integer("2"),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            upper=cst.Integer("2"),
                            step=cst.Integer("3"),
                        )),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3, 5]",
            False,
            CodeRange((1, 0), (1, 13)),
        ),
        # Test parsing of subscript with slice/extslice.
        # (Explicit Colon()/Comma() tokens match what the parser produces.)
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(),
                    ),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3,5]",
            True,
        ),
        # Some more wild slice creations
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=None, step=cst.Integer("3"))), ),
            ),
            "foo[::3]",
            False,
            CodeRange((1, 0), (1, 8)),
        ),
        # Some more wild slice parsings
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[::3]",
            True,
        ),
        # Valid list clone operations rendering
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # Valid list clone operations parsing
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # In parenthesis
        (
            cst.Subscript(
                lpar=(cst.LeftParen(), ),
                value=cst.Name("foo"),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "(foo[5])",
            True,
        ),
        # Verify spacing: every whitespace slot (parens, after the value,
        # inside the brackets, around colons and commas) is set explicitly.
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 5 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        step=cst.Integer("3"),
                    )), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(
                    cst.SubscriptElement(
                        slice=cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    cst.SubscriptElement(slice=cst.Index(cst.Integer("5"))),
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 , 5 ] )",
            True,
            # Starts after "( " (column 2) and ends before " )".
            CodeRange((1, 2), (1, 24)),
        ),
        # Test Index, Slice, SubscriptElement
        (cst.Index(cst.Integer("5")), "5", False, CodeRange((1, 0), (1, 1))),
        (
            cst.Slice(lower=None, upper=None, second_colon=cst.Colon(), step=None),
            "::",
            False,
            CodeRange((1, 0), (1, 2)),
        ),
        (
            cst.SubscriptElement(
                slice=cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    step=cst.Integer("3"),
                ),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
            ),
            "1 : 2 : 3 , ",
            False,
            # The expected range covers "1 : 2 : 3" only, not the trailing
            # comma and its whitespace.
            CodeRange((1, 0), (1, 9)),
        ),
    ))
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        check_parsing: bool,
        position: Optional[CodeRange] = None,
    ) -> None:
        """
        Delegate each case to ``validate_node``. ``parse_expression`` is
        passed as the parser only when ``check_parsing`` is true.
        """
        if check_parsing:
            self.validate_node(node, code, parse_expression, expected_position=position)
        else:
            self.validate_node(node, code, expected_position=position)

    # Each case is (node factory, regex expected to match the error raised
    # when the node is constructed); see assert_invalid.
    @data_provider((
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        (lambda: cst.Subscript(cst.Name("foo"), ()), "empty SubscriptElement"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode], expected_re: str) -> None:
        """Delegate each invalid-construction case to ``assert_invalid``."""
        self.assert_invalid(get_node, expected_re)
def make_aref(name: str, idx: cst.BaseExpression) -> cst.Subscript:
    """Build the subscript expression ``<name>[<idx>]``."""
    return cst.Subscript(
        value=cst.Name(name),
        slice=[cst.SubscriptElement(slice=cst.Index(value=idx))],
    )
def to_index_cst(value):
    """Wrap ``value`` in a ``cst.Index`` so it can serve as a subscript slice."""
    index_node = cst.Index(value=value)
    return index_node
def _get_match_metadata() -> cst.SubscriptElement:
    """Return a ``MetadataMatchType`` member ready to be added to a ``Union`` slice."""
    metadata_type = cst.Name("MetadataMatchType")
    return cst.SubscriptElement(cst.Index(metadata_type))
def _get_do_not_care() -> cst.SubscriptElement:
    """Return a ``DoNotCareSentinel`` member ready to be added to a ``Union`` slice."""
    sentinel_type = cst.Name("DoNotCareSentinel")
    return cst.SubscriptElement(cst.Index(sentinel_type))
def make_index(arr, idx):
    """Build ``<arr>[<idx>]`` with ``arr`` as the subscripted value and ``idx`` as its index."""
    element = cst.SubscriptElement(slice=cst.Index(value=idx))
    return cst.Subscript(value=arr, slice=[element])
def _add_generic(name: str, oldtype: cst.BaseExpression) -> cst.BaseExpression:
    """Wrap ``oldtype`` in the generic ``name``, producing ``<name>[<oldtype>]``."""
    element = cst.SubscriptElement(cst.Index(oldtype))
    return cst.Subscript(cst.Name(name), (element,))
def annotation(self):
    """Return ``Type[<name annotation>]``, or bare ``type`` when ``self.name`` is None."""
    if self.name is not None:
        inner = cst.SubscriptElement(cst.Index(self.name.annotation))
        return cst.Subscript(cst.Name("Type"), [inner])
    return cst.Name("type")
class AnnAssignTest(CSTNodeTest):
    """Rendering, parsing and validation tests for ``cst.AnnAssign``."""

    # Each case is passed as keyword arguments to validate_node. "parser" is
    # None for creation-only cases. "expected_position" ends at len(code) for
    # single-line statements.
    @data_provider((
        # Simple assignment creation case.
        {
            "node": cst.AnnAssign(cst.Name("foo"), cst.Annotation(cst.Name("str")), cst.Integer("5")),
            "code": "foo: str = 5",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 12)),
        },
        # Annotation creation without assignment
        {
            "node": cst.AnnAssign(cst.Name("foo"), cst.Annotation(cst.Name("str"))),
            "code": "foo: str",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 8)),
        },
        # Complex annotation creation
        {
            "node": cst.AnnAssign(
                cst.Name("foo"),
                cst.Annotation(
                    cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    )),
                cst.Integer("5"),
            ),
            "code": "foo: Optional[str] = 5",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 22)),
        },
        # Simple assignment parser case.
        {
            "node": cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code": "foo: str = 5\n",
            "parser": parse_statement,
            "expected_position": None,
        },
        # Annotation without assignment
        {
            "node": cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                value=None,
            ), )),
            "code": "foo: str\n",
            "parser": parse_statement,
            "expected_position": None,
        },
        # Complex annotation
        {
            "node": cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code": "foo: Optional[str] = 5\n",
            "parser": parse_statement,
            "expected_position": None,
        },
        # Whitespace test: multi-character whitespace after the indicator and
        # on both sides of "=". The code string below is 26 characters long,
        # which matches the expected end column.
        {
            "node": cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace("  "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                value=cst.Integer("5"),
            ),
            "code": "foo :  Optional[str]  =  5",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 26)),
        },
        # Same whitespace layout, round-tripped through the parser.
        {
            "node": cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace("  "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                value=cst.Integer("5"),
            ), )),
            "code": "foo :  Optional[str]  =  5\n",
            "parser": parse_statement,
            "expected_position": None,
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Delegate each valid case to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(({
        "get_node": (lambda: cst.AnnAssign(
            target=cst.Name("foo"),
            annotation=cst.Annotation(cst.Name("str")),
            equal=cst.AssignEqual(),
            value=None,
        )),
        "expected_re": "Must have a value when specifying an AssignEqual.",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Delegate each invalid-construction case to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
def _get_do_not_care() -> cst.ExtSlice:
    """Return a ``DoNotCareSentinel`` member (legacy ``ExtSlice`` form) for a ``Union`` slice."""
    sentinel_type = cst.Name("DoNotCareSentinel")
    return cst.ExtSlice(cst.Index(sentinel_type))
def test_subscript(self) -> None:
    """
    parse_template_expression accepts Name, Index, Slice and
    SubscriptElement replacements in subscript positions.
    """
    # Test that we can insert various subscript slices into an
    # acceptable spot. Each case is
    # (template, replacements, expected rendered code).
    cases = (
        (
            "Optional[{type}]",
            {"type": cst.Name("int")},
            "Optional[int]",
        ),
        (
            "Tuple[{type1}, {type2}]",
            {"type1": cst.Name("int"), "type2": cst.Name("str")},
            "Tuple[int, str]",
        ),
        (
            "Optional[{type}]",
            {"type": cst.Index(cst.Name("int"))},
            "Optional[int]",
        ),
        (
            "Optional[{type}]",
            {"type": cst.SubscriptElement(cst.Index(cst.Name("int")))},
            "Optional[int]",
        ),
        (
            "foo[{slice}]",
            {"slice": cst.Slice(cst.Integer("5"), cst.Integer("6"))},
            "foo[5:6]",
        ),
        (
            "foo[{slice}]",
            {
                "slice": cst.SubscriptElement(
                    cst.Slice(cst.Integer("5"), cst.Integer("6"))
                )
            },
            "foo[5:6]",
        ),
        (
            "foo[{slice1}, {slice2}]",
            {
                "slice1": cst.Slice(cst.Integer("5"), cst.Integer("6")),
                "slice2": cst.Index(cst.Integer("7")),
            },
            "foo[5:6, 7]",
        ),
        (
            "foo[{slice1}, {slice2}]",
            {
                "slice1": cst.SubscriptElement(
                    cst.Slice(cst.Integer("5"), cst.Integer("6"))
                ),
                "slice2": cst.SubscriptElement(cst.Index(cst.Integer("7"))),
            },
            "foo[5:6, 7]",
        ),
    )
    for template, replacements, expected in cases:
        # subTest keeps going after a failure and reports which case broke.
        with self.subTest(template=template, expected=expected):
            expression = parse_template_expression(template, **replacements)
            self.assertEqual(self.code(expression), expected)