コード例 #1
0
 def leave_Subscript(
     self, original_node: cst.Subscript, updated_node: cst.Subscript
 ) -> cst.Subscript:
     """
     Widen ``Union[...]`` subscripts so they also accept matcher combinators.

     Subscripts whose value is not the name ``Union`` are returned untouched.
     For a ``Union``, three members are appended to the existing slice: a
     ``OneOf[...]`` and an ``AllOf[...]`` wrapper around the union extended
     with a MatchIfTrue member, followed by a MatchIfTrue member built from
     the original node's concrete types.
     """
     if not updated_node.value.deep_equals(cst.Name("Union")):
         return updated_node

     # Concrete types (DoNotCare removed) taken from the untransformed node.
     concrete_types = _remove_do_not_care(original_node)
     # The already-transformed union, extended with a MatchIfTrue member.
     union_with_predicate = _add_match_if_true(
         _remove_do_not_care(updated_node), concrete_types
     )
     # OneOf/AllOf accept everything the union does plus MatchIfTrue, so that
     # e.g. OneOf(SomeType(), MatchIfTrue(lambda)) is allowed without forcing
     # a recursive matches() call inside the lambda.
     combinator_members = [
         cst.ExtSlice(cst.Index(_add_generic(wrapper, union_with_predicate)))
         for wrapper in ("OneOf", "AllOf")
     ]
     # Built from the original node so the MatchIfTrue Callable only takes
     # real CST node types, never OneOf/AllOf/MatchIfTrue values that were
     # added to child unions.
     predicate_member = _get_match_if_true(concrete_types)
     return updated_node.with_changes(
         slice=[*updated_node.slice, *combinator_members, predicate_member]
     )
コード例 #2
0
def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.ExtSlice:
    """
    Build a ``MatchIfTrue[Callable[[<node type>], bool]]`` Union member.

    ``oldtype`` is first passed through ``_convert_match_nodes_to_cst_nodes``
    so that quoted matcher forward references are turned back into the
    CSTNode classes they refer to (the names map 1:1). The resulting type is
    wrapped in an ExtSlice/Index pair so it can go directly into a Union slice.
    """
    # MatchIfTrue's predicate takes the original node type and returns bool.
    node_type = _convert_match_nodes_to_cst_nodes(oldtype)
    predicate_params = cst.List([cst.Element(node_type)])
    predicate_type = cst.Subscript(
        cst.Name("Callable"),
        slice=[
            cst.ExtSlice(cst.Index(predicate_params)),
            cst.ExtSlice(cst.Index(cst.Name("bool"))),
        ],
    )
    match_if_true = cst.Subscript(
        cst.Name("MatchIfTrue"), cst.Index(predicate_type)
    )
    return cst.ExtSlice(cst.Index(match_if_true))
コード例 #3
0
 def leave_Subscript(self, original_node: cst.Subscript,
                     updated_node: cst.Subscript) -> cst.Subscript:
     """
     Clear bookkeeping for this subscript and widen it if it was flagged.

     ``original_node`` is removed from ``self.in_match_if_true`` when present.
     If it is also present in ``self.fixup_nodes``, it is removed from there
     and the returned node gets ``AtLeastN[...]`` and ``AtMostN[...]``
     members (each wrapping ``original_node``) appended to its slice.
     Otherwise ``updated_node`` is returned unchanged.
     """
     if original_node in self.in_match_if_true:
         self.in_match_if_true.remove(original_node)
     if original_node not in self.fixup_nodes:
         return updated_node

     self.fixup_nodes.remove(original_node)
     count_wrappers = [
         cst.ExtSlice(cst.Index(_add_generic(wrapper, original_node)))
         for wrapper in ("AtLeastN", "AtMostN")
     ]
     return updated_node.with_changes(
         slice=[*updated_node.slice, *count_wrappers])
コード例 #4
0
    def test_deprecated_construction(self) -> None:
        # Build ``foo[1, 2]`` from ExtSlice/Index nodes passed by keyword and
        # check the rendered module source.
        elements = [
            cst.ExtSlice(slice=cst.Index(value=cst.Integer(value=digit)))
            for digit in ("1", "2")
        ]
        expression = cst.Expr(
            value=cst.Subscript(value=cst.Name(value="foo"), slice=elements)
        )
        module = cst.Module(
            body=[cst.SimpleStatementLine(body=[expression])]
        )

        self.assertEqual(module.code, "foo[1, 2]\n")
コード例 #5
0
def _get_wrapped_union_type(node: cst.BaseExpression, addition: cst.ExtSlice,
                            *additions: cst.ExtSlice) -> cst.Subscript:
    """
    Wrap ``node`` together with one or more extra members in ``Union[...]``.

    ``node`` becomes the first member (wrapped in ExtSlice/Index); ``addition``
    and any further ``additions`` follow in the order given. ``addition`` is a
    required positional parameter so that at least two members are always
    supplied, which keeps the signature type-safe.
    """
    members = [cst.ExtSlice(cst.Index(node)), addition]
    members.extend(additions)
    return cst.Subscript(cst.Name("Union"), members)
コード例 #6
0
def _get_do_not_care() -> cst.ExtSlice:
    """
    Return a ``DoNotCareSentinel`` member ready to be placed in a Union slice.
    """
    sentinel_name = cst.Name("DoNotCareSentinel")
    return cst.ExtSlice(cst.Index(sentinel_name))
コード例 #7
0
class SubscriptTest(CSTNodeTest):
    """
    Tests for constructing, rendering and (optionally) parsing Subscript,
    Index, Slice and ExtSlice nodes, plus rejection of invalid Subscripts.
    """
    # Each case is (node, expected code, check_parsing[, expected position]),
    # matching the parameters of test_valid below.
    @data_provider((
        # Simple subscript expression
        (
            cst.Subscript(cst.Name("foo"), cst.Index(cst.Integer("5"))),
            "foo[5]",
            True,
        ),
        # Test creation of subscript with slice/extslice.
        # These are not parse-checked; the parse-checked cases further down
        # spell out first_colon / second_colon / comma explicitly.
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(
                    lower=cst.Integer("1"),
                    upper=cst.Integer("2"),
                    step=cst.Integer("3"),
                ),
            ),
            "foo[1:2:3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.ExtSlice(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            upper=cst.Integer("2"),
                            step=cst.Integer("3"),
                        )),
                    cst.ExtSlice(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3, 5]",
            False,
            # "foo[1:2:3, 5]" is 13 characters; the range spans all of it.
            CodeRange((1, 0), (1, 13)),
        ),
        # Test parsing of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(),
                    step=cst.Integer("3"),
                ),
            ),
            "foo[1:2:3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.ExtSlice(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(),
                    ),
                    cst.ExtSlice(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3,5]",
            True,
        ),
        # Some more wild slice creations
        # NOTE(review): the first three cases in this group are identical
        # (same node, code and check_parsing=True) to the first three cases of
        # the "wild slice parsings" group below, so they add no extra coverage.
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2")),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(cst.Name("foo"),
                          cst.Slice(lower=cst.Integer("1"), upper=None)),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(cst.Name("foo"),
                          cst.Slice(lower=None, upper=cst.Integer("2"))),
            "foo[:2]",
            True,
        ),
        # Not parse-checked: second_colon is left unset here, whereas the
        # parse-checked "foo[1::3]" case below passes second_colon=cst.Colon().
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(lower=cst.Integer("1"),
                          upper=None,
                          step=cst.Integer("3")),
            ),
            "foo[1::3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(lower=None, upper=None, step=cst.Integer("3")),
            ),
            "foo[::3]",
            False,
            # "foo[::3]" is 8 characters; the range spans all of it.
            CodeRange((1, 0), (1, 8)),
        ),
        # Some more wild slice parsings
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2")),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(cst.Name("foo"),
                          cst.Slice(lower=cst.Integer("1"), upper=None)),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(cst.Name("foo"),
                          cst.Slice(lower=None, upper=cst.Integer("2"))),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(
                    lower=cst.Integer("1"),
                    upper=None,
                    second_colon=cst.Colon(),
                    step=cst.Integer("3"),
                ),
            ),
            "foo[1::3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(
                    lower=None,
                    upper=None,
                    second_colon=cst.Colon(),
                    step=cst.Integer("3"),
                ),
            ),
            "foo[::3]",
            True,
        ),
        # Valid list clone operations rendering
        # NOTE(review): both cases in this group are identical to the two
        # cases in the "list clone operations parsing" group that follows.
        (
            cst.Subscript(cst.Name("foo"), cst.Slice(lower=None, upper=None)),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(lower=None,
                          upper=None,
                          second_colon=cst.Colon(),
                          step=None),
            ),
            "foo[::]",
            True,
        ),
        # Valid list clone operations parsing
        (
            cst.Subscript(cst.Name("foo"), cst.Slice(lower=None, upper=None)),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                cst.Slice(lower=None,
                          upper=None,
                          second_colon=cst.Colon(),
                          step=None),
            ),
            "foo[::]",
            True,
        ),
        # In parentheses
        (
            cst.Subscript(
                lpar=(cst.LeftParen(), ),
                value=cst.Name("foo"),
                slice=cst.Index(cst.Integer("5")),
                rpar=(cst.RightParen(), ),
            ),
            "(foo[5])",
            True,
        ),
        # Verify spacing
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=cst.Index(cst.Integer("5")),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 5 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    step=cst.Integer("3"),
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(
                    cst.ExtSlice(
                        slice=cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                    ),
                    cst.ExtSlice(slice=cst.Index(cst.Integer("5"))),
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ,  5 ] )",
            True,
            # The range starts at "foo" and ends just after "]", i.e. it
            # excludes the surrounding parentheses and their padding.
            CodeRange((1, 2), (1, 24)),
        ),
        # Test Index, Slice, ExtSlice
        (cst.Index(cst.Integer("5")), "5", False, CodeRange((1, 0), (1, 1))),
        (
            cst.Slice(lower=None,
                      upper=None,
                      second_colon=cst.Colon(),
                      step=None),
            "::",
            False,
            CodeRange((1, 0), (1, 2)),
        ),
        (
            cst.ExtSlice(
                slice=cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    step=cst.Integer("3"),
                ),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
            ),
            "1 : 2 : 3 ,  ",
            False,
            # The range covers "1 : 2 : 3" (9 characters) but not the
            # trailing comma and its surrounding whitespace.
            CodeRange((1, 0), (1, 9)),
        ),
    ))
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        check_parsing: bool,
        position: Optional[CodeRange] = None,
    ) -> None:
        # When check_parsing is set, parse_expression is also handed to
        # validate_node so the code is exercised through the parser as well
        # (exact comparison semantics live in CSTNodeTest.validate_node).
        if check_parsing:
            self.validate_node(node,
                               code,
                               parse_expression,
                               expected_position=position)
        else:
            self.validate_node(node, code, expected_position=position)

    # Each case is (factory building an invalid node, expected_re); the
    # pattern is presumably matched against the raised error's message —
    # see CSTNodeTest.assert_invalid.
    @data_provider((
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                cst.Index(cst.Integer("5")),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                cst.Index(cst.Integer("5")),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        # An empty tuple of slice elements must be rejected.
        (lambda: cst.Subscript(cst.Name("foo"), ()), "empty ExtSlice"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        self.assert_invalid(get_node, expected_re)