def leave_Subscript(
    self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.Subscript:
    """
    Widen every ``Union[...]`` subscript with ``OneOf``/``AllOf`` generic
    entries and a trailing ``MatchIfTrue`` entry. Subscripts whose value is
    not the name ``Union`` are returned untouched.
    """
    # Guard clause: only Union subscripts are rewritten.
    if not updated_node.value.deep_equals(cst.Name("Union")):
        return updated_node

    # The original node with do-not-care stripped gives us the concrete types.
    concrete_only_expr = _remove_do_not_care(original_node)
    # The current (already transformed) subscript, with a MatchIfTrue added.
    match_if_true_expr = _add_match_if_true(
        _remove_do_not_care(updated_node), concrete_only_expr
    )

    # OneOf/AllOf are widened so that they accept all of the original types
    # and also a MatchIfTrue, which allows things like
    # OneOf(SomeType(), MatchIfTrue(lambda)). Forbidding MatchIfTrue inside
    # OneOf/AllOf would be possible, but mixing and matching would then need
    # a recursive matches() call inside the lambda, which is super ugly.
    widened_generics = [
        cst.ExtSlice(cst.Index(_add_generic(generic_name, match_if_true_expr)))
        for generic_name in ("OneOf", "AllOf")
    ]

    # The MatchIfTrue entry is built from the original node on purpose: we do
    # not want it to pick up modifications made to child Union classes.
    # Otherwise its Callable would take OneOf/AllOf and MatchIfTrue values,
    # which is incorrect; MatchIfTrue only takes in cst nodes and returns a
    # boolean.
    match_if_true_entry = _get_match_if_true(concrete_only_expr)

    return updated_node.with_changes(
        slice=[*updated_node.slice, *widened_generics, match_if_true_entry]
    )
def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.ExtSlice:
    """
    Build the ``MatchIfTrue[Callable[[<node type>], bool]]`` entry, wrapped
    so that it can be placed directly into a Union's slice list.
    """
    # MatchIfTrue receives the original node type and returns a boolean, so
    # the quoted classes (forward refs to other matchers) are converted back
    # to the CSTNode they refer to. This works because there is always a 1:1
    # name mapping.
    node_type = _convert_match_nodes_to_cst_nodes(oldtype)

    # Callable[[node_type], bool]
    parameter_list = cst.List([cst.Element(node_type)])
    callable_type = cst.Subscript(
        cst.Name("Callable"),
        slice=[
            cst.ExtSlice(cst.Index(parameter_list)),
            cst.ExtSlice(cst.Index(cst.Name("bool"))),
        ],
    )

    # MatchIfTrue[Callable[...]], then wrapped as a Union slice entry.
    match_if_true_type = cst.Subscript(
        cst.Name("MatchIfTrue"),
        cst.Index(callable_type),
    )
    return cst.ExtSlice(cst.Index(match_if_true_type))
def leave_Subscript(
    self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.Subscript:
    """
    Finish visiting a subscript. The node is dropped from
    ``self.in_match_if_true`` if it is tracked there. If it is also listed in
    ``self.fixup_nodes`` it is removed from that collection and the returned
    subscript gains ``AtLeastN`` and ``AtMostN`` generic entries built from
    the original node; otherwise the updated node is returned unchanged.
    """
    if original_node in self.in_match_if_true:
        self.in_match_if_true.remove(original_node)

    # Nothing to fix up for this subscript.
    if original_node not in self.fixup_nodes:
        return updated_node

    self.fixup_nodes.remove(original_node)
    bounded_generics = [
        cst.ExtSlice(cst.Index(_add_generic(generic_name, original_node)))
        for generic_name in ("AtLeastN", "AtMostN")
    ]
    return updated_node.with_changes(
        slice=[*updated_node.slice, *bounded_generics]
    )
def test_deprecated_construction(self) -> None:
    # Build `foo[1, 2]` by hand out of ExtSlice/Index wrappers (the
    # construction style this test's name refers to as deprecated) and check
    # that the resulting module still renders the expected code.
    slice_elements = [
        cst.ExtSlice(slice=cst.Index(value=cst.Integer(value=digits)))
        for digits in ("1", "2")
    ]
    subscript = cst.Subscript(value=cst.Name(value="foo"), slice=slice_elements)
    statement = cst.SimpleStatementLine(body=[cst.Expr(value=subscript)])
    module = cst.Module(body=[statement])

    self.assertEqual(module.code, "foo[1, 2]\n")
def _get_wrapped_union_type(
    node: cst.BaseExpression, addition: cst.ExtSlice, *additions: cst.ExtSlice
) -> cst.Subscript:
    """
    Wrap ``node`` and one or more extra slice entries in a ``Union[...]``
    subscript. ``node`` is wrapped into a slice entry here; the additions
    are used as given. The signature requires at least one addition
    explicitly, for type safety.
    """
    union_members = [cst.ExtSlice(cst.Index(node)), addition]
    union_members.extend(additions)
    return cst.Subscript(cst.Name("Union"), union_members)
def _get_do_not_care() -> cst.ExtSlice:
    """
    Build the ``DoNotCareSentinel`` name, wrapped as a slice entry that can
    be placed into a Union.
    """
    sentinel_name = cst.Name("DoNotCareSentinel")
    return cst.ExtSlice(cst.Index(sentinel_name))
class SubscriptTest(CSTNodeTest):
    """
    Data-driven checks for Subscript nodes and their Index/Slice/ExtSlice
    parts: valid nodes are validated against their expected code (optionally
    through the expression parser), invalid constructions must be rejected.
    """

    # Each row is (node, expected code, check_parsing[, expected position]).
    @data_provider(
        (
            # Plain index subscript.
            (
                cst.Subscript(cst.Name("foo"), cst.Index(cst.Integer("5"))),
                "foo[5]",
                True,
            ),
            # Creating subscripts with slice/extslice (no explicit colons).
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=cst.Integer("2"),
                        step=cst.Integer("3"),
                    ),
                ),
                "foo[1:2:3]",
                False,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    (
                        cst.ExtSlice(
                            cst.Slice(
                                lower=cst.Integer("1"),
                                upper=cst.Integer("2"),
                                step=cst.Integer("3"),
                            )
                        ),
                        cst.ExtSlice(cst.Index(cst.Integer("5"))),
                    ),
                ),
                "foo[1:2:3, 5]",
                False,
                CodeRange((1, 0), (1, 13)),
            ),
            # Parsing subscripts with slice/extslice (explicit colons/comma).
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    ),
                ),
                "foo[1:2:3]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    (
                        cst.ExtSlice(
                            cst.Slice(
                                lower=cst.Integer("1"),
                                first_colon=cst.Colon(),
                                upper=cst.Integer("2"),
                                second_colon=cst.Colon(),
                                step=cst.Integer("3"),
                            ),
                            comma=cst.Comma(),
                        ),
                        cst.ExtSlice(cst.Index(cst.Integer("5"))),
                    ),
                ),
                "foo[1:2:3,5]",
                True,
            ),
            # Further slice shapes: creation.
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2")),
                ),
                "foo[1:2]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(lower=cst.Integer("1"), upper=None),
                ),
                "foo[1:]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(lower=None, upper=cst.Integer("2")),
                ),
                "foo[:2]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(
                        lower=cst.Integer("1"), upper=None, step=cst.Integer("3")
                    ),
                ),
                "foo[1::3]",
                False,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(lower=None, upper=None, step=cst.Integer("3")),
                ),
                "foo[::3]",
                False,
                CodeRange((1, 0), (1, 8)),
            ),
            # Further slice shapes: parsing.
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2")),
                ),
                "foo[1:2]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(lower=cst.Integer("1"), upper=None),
                ),
                "foo[1:]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(lower=None, upper=cst.Integer("2")),
                ),
                "foo[:2]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    ),
                ),
                "foo[1::3]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    ),
                ),
                "foo[::3]",
                True,
            ),
            # List-clone style slices: rendering.
            (
                cst.Subscript(cst.Name("foo"), cst.Slice(lower=None, upper=None)),
                "foo[:]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    ),
                ),
                "foo[::]",
                True,
            ),
            # List-clone style slices: parsing.
            (
                cst.Subscript(cst.Name("foo"), cst.Slice(lower=None, upper=None)),
                "foo[:]",
                True,
            ),
            (
                cst.Subscript(
                    cst.Name("foo"),
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    ),
                ),
                "foo[::]",
                True,
            ),
            # Wrapped in parentheses.
            (
                cst.Subscript(
                    lpar=(cst.LeftParen(),),
                    value=cst.Name("foo"),
                    slice=cst.Index(cst.Integer("5")),
                    rpar=(cst.RightParen(),),
                ),
                "(foo[5])",
                True,
            ),
            # Whitespace ownership at every position.
            (
                cst.Subscript(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    value=cst.Name("foo"),
                    lbracket=cst.LeftSquareBracket(
                        whitespace_after=cst.SimpleWhitespace(" ")
                    ),
                    slice=cst.Index(cst.Integer("5")),
                    rbracket=cst.RightSquareBracket(
                        whitespace_before=cst.SimpleWhitespace(" ")
                    ),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                    whitespace_after_value=cst.SimpleWhitespace(" "),
                ),
                "( foo [ 5 ] )",
                True,
            ),
            (
                cst.Subscript(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    value=cst.Name("foo"),
                    lbracket=cst.LeftSquareBracket(
                        whitespace_after=cst.SimpleWhitespace(" ")
                    ),
                    slice=cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        step=cst.Integer("3"),
                    ),
                    rbracket=cst.RightSquareBracket(
                        whitespace_before=cst.SimpleWhitespace(" ")
                    ),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                    whitespace_after_value=cst.SimpleWhitespace(" "),
                ),
                "( foo [ 1 : 2 : 3 ] )",
                True,
            ),
            (
                cst.Subscript(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    value=cst.Name("foo"),
                    lbracket=cst.LeftSquareBracket(
                        whitespace_after=cst.SimpleWhitespace(" ")
                    ),
                    slice=(
                        cst.ExtSlice(
                            slice=cst.Slice(
                                lower=cst.Integer("1"),
                                first_colon=cst.Colon(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                                upper=cst.Integer("2"),
                                second_colon=cst.Colon(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                                step=cst.Integer("3"),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ExtSlice(slice=cst.Index(cst.Integer("5"))),
                    ),
                    rbracket=cst.RightSquareBracket(
                        whitespace_before=cst.SimpleWhitespace(" ")
                    ),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                    whitespace_after_value=cst.SimpleWhitespace(" "),
                ),
                "( foo [ 1 : 2 : 3 , 5 ] )",
                True,
                CodeRange((1, 2), (1, 24)),
            ),
            # Index, Slice and ExtSlice validated on their own.
            (
                cst.Index(cst.Integer("5")),
                "5",
                False,
                CodeRange((1, 0), (1, 1)),
            ),
            (
                cst.Slice(
                    lower=None,
                    upper=None,
                    second_colon=cst.Colon(),
                    step=None,
                ),
                "::",
                False,
                CodeRange((1, 0), (1, 2)),
            ),
            (
                cst.ExtSlice(
                    slice=cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        step=cst.Integer("3"),
                    ),
                    comma=cst.Comma(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                ),
                "1 : 2 : 3 , ",
                False,
                CodeRange((1, 0), (1, 9)),
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        check_parsing: bool,
        position: Optional[CodeRange] = None,
    ) -> None:
        # The expression parser is handed to validate_node only for rows
        # flagged with check_parsing; all other rows are validated without
        # a parser argument.
        parser_args = (parse_expression,) if check_parsing else ()
        self.validate_node(node, code, *parser_args, expected_position=position)

    # Each row is (node factory, regex expected in the validation error).
    @data_provider(
        (
            (
                lambda: cst.Subscript(
                    cst.Name("foo"),
                    cst.Index(cst.Integer("5")),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            (
                lambda: cst.Subscript(
                    cst.Name("foo"),
                    cst.Index(cst.Integer("5")),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
            (
                lambda: cst.Subscript(cst.Name("foo"), ()),
                "empty ExtSlice",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        # Construction is deferred behind a callable so that assert_invalid
        # is the one that triggers it.
        self.assert_invalid(get_node, expected_re)