Ejemplo n.º 1
0
def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.SubscriptElement:
    """
    Build a ``MatchIfTrue[Callable[[T], bool]]`` element for a Union slice.

    ``T`` is produced from ``oldtype`` by ``_convert_match_nodes_to_cst_nodes``;
    the finished expression is wrapped in a ``SubscriptElement`` so it can be
    appended directly to a ``Union[...]`` slice list.
    """
    # MatchIfTrue takes in the original node type, and returns a boolean. So,
    # convert our quoted classes (forward refs to other matchers) back to the
    # CSTNode they refer to. We can do this because there's always a 1:1 name
    # mapping.
    callable_params = cst.List(
        [cst.Element(_convert_match_nodes_to_cst_nodes(oldtype))]
    )
    callable_type = cst.Subscript(
        cst.Name("Callable"),
        slice=(
            cst.SubscriptElement(cst.Index(callable_params)),
            cst.SubscriptElement(cst.Index(cst.Name("bool"))),
        ),
    )
    match_if_true_type = cst.Subscript(
        cst.Name("MatchIfTrue"),
        slice=(cst.SubscriptElement(cst.Index(callable_type)), ),
    )
    return cst.SubscriptElement(cst.Index(match_if_true_type))
Ejemplo n.º 2
0
 def leave_Subscript(
     self, original_node: cst.Subscript, updated_node: cst.Subscript
 ) -> cst.Subscript:
     """
     Widen ``Union[...]`` annotations with OneOf/AllOf/MatchIfTrue entries.

     For a ``Union`` subscript, three elements are appended to its slice:
     ``OneOf[...]`` and ``AllOf[...]`` (each wrapping the expression built by
     ``_add_match_if_true``), and a ``MatchIfTrue`` entry built from the
     do-not-care-free *original* node. Any other subscript is returned
     unchanged.
     """
     if updated_node.value.deep_equals(cst.Name("Union")):
         # Take the original node, remove do not care so we have concrete types.
         concrete_only_expr = _remove_do_not_care(original_node)
         # Take the current subscript, add a MatchIfTrue node to it.
         match_if_true_expr = _add_match_if_true(
             _remove_do_not_care(updated_node), concrete_only_expr
         )
         return updated_node.with_changes(
             slice=[
                 *updated_node.slice,
                 # Make sure that OneOf/AllOf types are widened to take all of the
                 # original SomeTypes, and also takes a MatchIfTrue, so that
                 # you can do something like OneOf(SomeType(), MatchIfTrue(lambda)).
                 # We could explicitly enforce that MatchIfTrue could not be used
                 # inside OneOf/AllOf clauses, but then if you want to mix and match you
                 # would have to use a recursive matches() inside your lambda which
                 # is super ugly.
                 # SubscriptElement is used rather than the deprecated ExtSlice so
                 # every entry in this slice list has the same node type as the
                 # MatchIfTrue element appended below.
                 cst.SubscriptElement(
                     cst.Index(_add_generic("OneOf", match_if_true_expr))
                 ),
                 cst.SubscriptElement(
                     cst.Index(_add_generic("AllOf", match_if_true_expr))
                 ),
                 # We use original node here, because we don't want MatchIfTrue
                 # to get modifications to child Union classes. If we allow
                 # that, we get MatchIfTrue nodes whose Callable takes in
                 # OneOf/AllOf and MatchIfTrue values, which is incorrect. MatchIfTrue
                 # only takes in cst nodes, and returns a boolean.
                 _get_match_if_true(concrete_only_expr),
             ]
         )
     return updated_node
Ejemplo n.º 3
0
 def annotation(self):
     """
     Return a CST annotation for this dict type.

     When both the key and the value types are unknown, the bare builtin
     ``dict`` name is returned; otherwise ``Dict[<key>, <value>]``.
     """
     if is_unknown(self.key) and is_unknown(self.value):
         return cst.Name("dict")
     key_then_value = (self.key, self.value)
     return cst.Subscript(
         cst.Name("Dict"),
         [cst.SubscriptElement(cst.Index(part.annotation)) for part in key_then_value],
     )
Ejemplo n.º 4
0
 def leave_Subscript(self, original_node: cst.Subscript,
                     updated_node: cst.Subscript) -> cst.Subscript:
     """
     Clean up bookkeeping for ``original_node`` and widen flagged subscripts.

     ``original_node`` is removed from ``self.in_match_if_true`` when present.
     If it is also in ``self.fixup_nodes``, it is removed from there and the
     result gets ``AtLeastN[<original>]`` and ``AtMostN[<original>]`` appended
     to its slice. Otherwise ``updated_node`` is returned unchanged.
     """
     if original_node in self.in_match_if_true:
         self.in_match_if_true.remove(original_node)
     if original_node not in self.fixup_nodes:
         return updated_node
     self.fixup_nodes.remove(original_node)
     count_wrappers = [
         cst.SubscriptElement(cst.Index(_add_generic(generic, original_node)))
         for generic in ("AtLeastN", "AtMostN")
     ]
     return updated_node.with_changes(
         slice=[*updated_node.slice, *count_wrappers])
Ejemplo n.º 5
0
 def leave_SubscriptElement(self, original_node, updated_node):
     """
     Possibly replace this subscript element's slice with ``Any``.

     When ``self.max_annot`` exceeds ``self.max_annot_depth`` and
     ``self.current_annot_depth`` equals ``self.max_annot_depth``,
     ``self.max_annot`` is decremented and the element's slice becomes
     ``Index(Name('Any'))``. In every other case ``updated_node`` is returned
     unchanged.
     """
     collapse = (self.max_annot > self.max_annot_depth
                 and self.current_annot_depth == self.max_annot_depth)
     if not collapse:
         return updated_node
     self.max_annot -= 1
     any_slice = cst.Index(value=cst.Name(value='Any'))
     return updated_node.with_changes(slice=any_slice)
Ejemplo n.º 6
0
 def leave_Subscript(
     self, original_node: cst.Subscript, updated_node: cst.Subscript
 ) -> cst.Subscript:
     """
     Let ``Sequence[...]`` annotations also accept a do-not-care entry.

     ``Sequence[Union[...]]`` gets ``_get_do_not_care()`` appended to the
     existing union. Any other ``Sequence[T]`` has ``T`` replaced by
     ``_get_wrapped_union_type(T, _get_do_not_care())``. Subscripts whose
     base is not ``Sequence`` are returned unchanged.

     Raises:
         Exception: if the ``Sequence`` subscript's slice is not a ``cst.Index``.
     """
     if not updated_node.value.deep_equals(cst.Name("Sequence")):
         return updated_node
     # NOTE(review): this treats ``Subscript.slice`` as a single ``cst.Index``;
     # with a slice that is a sequence of ``SubscriptElement``s the check below
     # fails and the exception is raised. Confirm the targeted libcst version.
     index = updated_node.slice
     if not isinstance(index, cst.Index):
         raise Exception("Unexpected slice type for Sequence!")
     element_type = index.value
     is_union = (isinstance(element_type, cst.Subscript)
                 and element_type.value.deep_equals(cst.Name("Union")))
     if is_union:
         # Special case for Sequence[Union]: widen the existing union instead
         # of nesting another one, so that we make more collapsed types.
         widened_union = element_type.with_changes(
             slice=[*element_type.slice, _get_do_not_care()]
         )
         return updated_node.with_changes(
             slice=index.with_changes(value=widened_union)
         )
     # A sequence of some node: add the do-not-care entry so that a person can
     # add a do not care to a sequence that otherwise has valid matcher nodes.
     return updated_node.with_changes(
         slice=cst.Index(
             _get_wrapped_union_type(element_type, _get_do_not_care())
         )
     )
Ejemplo n.º 7
0
 def leave_Subscript(self, original_node: cst.Subscript,
                     updated_node: cst.Subscript) -> cst.Subscript:
     """
     Append ``OneOf[...]`` and ``AllOf[...]`` entries to ``Union`` subscripts.

     Both generics wrap the union with ``DoNotCareSentinel`` members removed
     (via ``_remove_types``). Non-``Union`` subscripts are returned unchanged.
     """
     if not updated_node.value.deep_equals(cst.Name("Union")):
         return updated_node
     # NOTE(review): types are stripped from ``updated_node``, so changes
     # already applied to nested subscripts are carried into the OneOf/AllOf
     # entries. If nested changes are meant to be discarded, ``original_node``
     # would be needed here instead - confirm intent.
     concrete = _remove_types(updated_node, ["DoNotCareSentinel"])
     logic_entries = [
         cst.SubscriptElement(cst.Index(_add_generic(wrapper, concrete)))
         for wrapper in ("OneOf", "AllOf")
     ]
     return updated_node.with_changes(
         slice=[*updated_node.slice, *logic_entries])
Ejemplo n.º 8
0
    def test_deprecated_construction(self) -> None:
        """A Subscript built from deprecated ``ExtSlice`` elements still renders."""
        # ExtSlice is used on purpose: this test covers the deprecated
        # construction path, so do not migrate it to SubscriptElement.
        legacy_elements = [
            cst.ExtSlice(slice=cst.Index(value=cst.Integer(value=digit)))
            for digit in ("1", "2")
        ]
        subscript = cst.Subscript(value=cst.Name(value="foo"),
                                  slice=legacy_elements)
        statement = cst.SimpleStatementLine(body=[cst.Expr(value=subscript)])
        module = cst.Module(body=[statement])

        self.assertEqual(module.code, "foo[1, 2]\n")
Ejemplo n.º 9
0
    def leave_SubscriptElement(self, original_node: cst.SubscriptElement,
                               updated_node: cst.SubscriptElement) -> cst.SubscriptElement:
        """
        Rewrite one element of a parametric type annotation's subscript.

        Only acts while both ``self.type_annot_visited`` and
        ``self.parametric_type_annot_visited`` are truthy; otherwise
        ``original_node`` is returned. Names are resolved with
        ``self.__get_qualified_name`` and turned into annotations with
        ``self.__name2annotation``. The matcher checks below run in order and
        the first match wins:

        * ``Index(Subscript)``: resolve the nested subscript's base expression
          and, if resolved, rebuild the subscript around the resolved
          annotation while keeping the updated inner slice.
        * ``Index(Ellipsis)``: replace with a fresh ``Ellipsis``.
        * ``Index(SimpleString)`` / ``Index(List)``: re-wrap the updated value
          in a new ``Index``.
        * ``Index(Name("None"))``: return ``original_node``.
        * anything else: resolve the element's value itself and, if resolved,
          replace it with the resolved annotation.

        NOTE(review): whenever a name does not resolve (``q_name is None``),
        control falls through to ``return original_node``, discarding any
        changes already made to this element's children - confirm intended.
        """
        if self.type_annot_visited and self.parametric_type_annot_visited:
            if match.matches(
                    original_node,
                    match.SubscriptElement(slice=match.Index(
                        value=match.Subscript()))):
                # Resolve the base of the nested subscript (its ``.value``).
                q_name, _ = self.__get_qualified_name(
                    original_node.slice.value.value)
                if q_name is not None:
                    return updated_node.with_changes(slice=cst.Index(
                        value=cst.Subscript(
                            value=self.__name2annotation(q_name).annotation,
                            slice=updated_node.slice.value.slice)))
            elif match.matches(
                    original_node,
                    match.SubscriptElement(slice=match.Index(
                        value=match.Ellipsis()))):
                # TODO: Should the original node be returned?!
                return updated_node.with_changes(slice=cst.Index(
                    value=cst.Ellipsis()))
            elif match.matches(
                    original_node,
                    match.SubscriptElement(slice=match.Index(
                        value=match.SimpleString(value=match.DoNotCare())))):
                return updated_node.with_changes(slice=cst.Index(
                    value=updated_node.slice.value))
            elif match.matches(
                    original_node,
                    match.SubscriptElement(slice=match.Index(value=match.Name(
                        value='None')))):
                return original_node
            elif match.matches(
                    original_node,
                    match.SubscriptElement(slice=match.Index(
                        value=match.List()))):
                return updated_node.with_changes(slice=cst.Index(
                    value=updated_node.slice.value))
            else:
                # NOTE(review): reads ``.slice.value``, which assumes the
                # element's slice is an Index (a ``Slice`` such as ``a:b``
                # would not have ``.value``) - confirm that cannot occur here.
                q_name, _ = self.__get_qualified_name(
                    original_node.slice.value)
                if q_name is not None:
                    return updated_node.with_changes(slice=cst.Index(
                        value=self.__name2annotation(q_name).annotation))

        return original_node
Ejemplo n.º 10
0
 def annotation(self):
     """
     Return ``Union[...]`` with one element per entry of ``self.options``.

     Each option contributes its own ``.annotation``, in iteration order.
     """
     members = []
     for option in self.options:
         members.append(cst.SubscriptElement(cst.Index(option.annotation)))
     return cst.Subscript(cst.Name("Union"), members)
Ejemplo n.º 11
0
    def annotation(self):
        """
        Doesn't exist yet as generic type, but should

        https://github.com/python/typing/issues/159

        Returns the bare ``slice`` name when the start, stop and step types are
        all unknown; otherwise ``slice[<start>, <stop>, <step>]``.
        """
        components = (self.start, self.stop, self.step)
        if all(is_unknown(part) for part in components):
            return cst.Name("slice")
        return cst.Subscript(
            cst.Name("slice"),
            # One element per component, in start, stop, step order.
            [cst.SubscriptElement(cst.Index(part.annotation)) for part in components],
        )
Ejemplo n.º 12
0
def _get_wrapped_union_type(node: cst.BaseExpression,
                            addition: cst.SubscriptElement,
                            *additions: cst.SubscriptElement) -> cst.Subscript:
    """
    Take two or more nodes, wrap them in a union type. Function signature is
    explicitly defined as taking at least one addition for type safety.

    ``node`` is a bare expression and is wrapped in a ``SubscriptElement``
    here; ``addition`` and ``additions`` must already be slice elements (for
    example the result of ``_get_do_not_care()``) and are appended verbatim.
    """
    # SubscriptElement rather than the deprecated ExtSlice, so the wrapped node
    # has the same element type as the additions supplied by callers.
    return cst.Subscript(
        cst.Name("Union"),
        [cst.SubscriptElement(cst.Index(node)), addition, *additions],
    )
Ejemplo n.º 13
0
 def annotation(self) -> cst.BaseExpression:
     """
     Return ``Literal[...]`` over ``self.options``, or ``str`` when there are none.

     Each option is rendered with ``repr`` into a ``cst.SimpleString``.
     """
     if not self.options:
         return cst.Name("str")
     literal_members = [
         cst.SubscriptElement(cst.Index(cst.SimpleString(repr(choice))))
         for choice in self.options
     ]
     return cst.Subscript(cst.Name("Literal"), literal_members)
Ejemplo n.º 14
0
    def annotation(self):
        """
        Return an annotation describing this tuple type.

        * unknown items -> the bare builtin ``tuple``;
        * a tuple of item types -> ``Tuple[a, b, ...]`` with one element per
          item, or ``Tuple[()]`` (the typing spelling of the empty tuple) when
          that tuple is empty;
        * otherwise ``self.items.annotation`` is used as the variadic element
          type: ``Tuple[<item>, ...]``.
        """
        if is_unknown(self.items):
            return cst.Name("tuple")
        if isinstance(self.items, tuple):
            if not self.items:
                # libcst rejects a Subscript with no elements, so spell the
                # empty tuple type explicitly instead of emitting ``Tuple[]``.
                return cst.Subscript(
                    cst.Name("Tuple"),
                    [cst.SubscriptElement(cst.Index(cst.Tuple([])))],
                )
            return cst.Subscript(
                cst.Name("Tuple"),
                [
                    cst.SubscriptElement(cst.Index(s.annotation))
                    for s in self.items
                ],
            )

        return cst.helpers.parse_template_expression(
            "Tuple[{item}, ...]", parser_config, item=self.items.annotation)
Ejemplo n.º 15
0
 def choice_ast(rng_key):
     """
     Build the CST for ``mcx.jax.choice(<rng_key>, <nodes[0].name>.shape[0])``.

     ``nodes`` is not a parameter; it is looked up in an enclosing (or
     global) scope, and its first entry's ``name`` gives the array whose
     leading dimension becomes the second argument.
     """
     jax_namespace = cst.Attribute(cst.Name("mcx"), cst.Name("jax"))
     choice_fn = cst.Attribute(value=jax_namespace, attr=cst.Name("choice"))
     key_arg = cst.Arg(rng_key)
     first_dim = cst.Subscript(
         cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape")),
         [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
     )
     return cst.Call(func=choice_fn, args=[key_arg, cst.Arg(first_dim)])
Ejemplo n.º 16
0
 def flatten_union_subscript(self, _, updated_node):
     """
     Flatten nested ``Optional``/``Union`` members and move ``None`` to the end.

     For each element of ``updated_node.slice``:

     * ``Optional[X]`` contributes the elements of its own slice and records
       that ``None`` was seen;
     * ``Union[...]`` is inlined when ``updated_node.value`` matches the inner
       subscript's base (NOTE(review): an older comment mentioned "Literal"
       too, but only ``Union`` passes this check);
     * a bare ``None`` is dropped but recorded;
     * anything else is kept as-is.

     If ``None`` was recorded at least once, a single ``None`` element is
     appended last.
     """
     flattened = []
     saw_none = False
     for element in updated_node.slice:
         inner = element.slice.value
         if m.matches(inner, m.Subscript(m.Name("Optional"))):
             flattened += inner.slice  # peel off "Optional", remember the None
             saw_none = True
         elif m.matches(inner, m.Subscript(m.Name("Union"))) and m.matches(
                 updated_node.value, inner.value):
             flattened += inner.slice  # peel off a same-kind nested union
         elif m.matches(inner, m.Name("None")):
             saw_none = True
         else:
             flattened.append(element)
     if saw_none:
         flattened.append(
             cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
     return updated_node.with_changes(slice=flattened)
Ejemplo n.º 17
0
def test_collect_targets():
    """``collect_targets`` returns the Name, Subscript and Attribute targets in order."""
    tree = cst.parse_module('''
x = [0, 1]
x[0] = 1
x.attr = 2
''')
    x = cst.Name(value='x')
    x0 = cst.Subscript(
        value=x,
        slice=[cst.SubscriptElement(slice=cst.Index(value=cst.Integer('0')))],
    )
    xa = cst.Attribute(
        value=x,
        attr=cst.Name('attr'),
    )

    golds = x, x0, xa

    # Materialize so the length can be checked even if a generator is returned.
    targets = list(collect_targets(tree))
    # zip() stops at the shorter input, so without this check a result with
    # missing (or no) targets would pass vacuously.
    assert len(targets) == len(golds), targets
    assert all(t.deep_equals(g) for t, g in zip(targets, golds))
Ejemplo n.º 18
0
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """
    Yield one annotated-assignment statement per property in ``p``.

    Entries are visited in ``sort_items`` order; names rejected by
    ``bad_name`` are skipped. Each statement is ``name: <annotation>``,
    wrapped as ``ClassVar[<annotation>]`` when ``is_classvar`` is true. It is
    preceded by a blank line and then one ``# ...`` comment line per entry
    of ``metadata_lines(metadata)``.
    """
    for prop_name, metadata_and_tp in sort_items(p):
        if bad_name(prop_name):
            continue
        metadata, tp = metadata_and_tp
        annotation_expr = tp.annotation
        if is_classvar:
            annotation_expr = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation_expr))],
            )
        statement = cst.AnnAssign(cst.Name(prop_name),
                                  cst.Annotation(annotation_expr))
        leading = [cst.EmptyLine()]
        leading.extend(
            cst.EmptyLine(comment=cst.Comment("# " + text))
            for text in metadata_lines(metadata)
        )
        yield cst.SimpleStatementLine([statement], leading_lines=leading)
Ejemplo n.º 19
0
 def leave_Subscript(self, original_node: cst.Subscript,
                     updated_node: cst.Subscript) -> cst.Subscript:
     """
     Let ``Sequence[...]`` annotations accept do-not-care and metadata entries.

     ``Sequence[Union[...]]`` has ``_get_do_not_care()`` and
     ``_get_match_metadata()`` appended to the existing union. Any other
     ``Sequence[T]`` becomes ``Sequence[<_get_wrapped_union_type(T, ...)>]``
     with the same two additions. Subscripts whose base is not ``Sequence``
     are returned unchanged.

     Raises:
         Exception: if the ``Sequence`` slice is not a single-element
             sequence, or that element's slice is not a ``cst.Index``.
     """
     if not updated_node.value.deep_equals(cst.Name("Sequence")):
         return updated_node
     elements = updated_node.slice
     # TODO: We can remove the instance check after ExtSlice is deprecated.
     if not isinstance(elements, Sequence) or len(elements) != 1:
         raise Exception(
             "Unexpected number of sequence elements inside Sequence type "
             + "annotation!")
     index = elements[0].slice
     if not isinstance(index, cst.Index):
         raise Exception("Unexpected slice type for Sequence!")
     element_type = index.value
     is_union = (isinstance(element_type, cst.Subscript)
                 and element_type.value.deep_equals(cst.Name("Union")))
     if is_union:
         # Special case for Sequence[Union]: extend the existing union rather
         # than nesting a new one, so that we make more collapsed types.
         return updated_node.with_deep_changes(
             element_type,
             slice=[
                 *element_type.slice,
                 _get_do_not_care(),
                 _get_match_metadata(),
             ],
         )
     # A sequence of some node: add the do-not-care entry so that a person can
     # add a do not care to a sequence that otherwise has valid matcher nodes.
     wrapped = _get_wrapped_union_type(
         element_type,
         _get_do_not_care(),
         _get_match_metadata(),
     )
     return updated_node.with_changes(
         slice=(cst.SubscriptElement(cst.Index(wrapped)), ))
Ejemplo n.º 20
0
class SubscriptTest(CSTNodeTest):
    """
    Tests for ``cst.Subscript`` and its slice nodes (``SubscriptElement``,
    ``Index``, ``Slice``): rendering, optional parsing, positions, and
    validation failures.
    """
    @data_provider((
        # Simple subscript expression
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
            ),
            "foo[5]",
            True,
        ),
        # Test creation of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=cst.Integer("2"),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            upper=cst.Integer("2"),
                            step=cst.Integer("3"),
                        )),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3, 5]",
            False,
            CodeRange((1, 0), (1, 13)),
        ),
        # Test parsing of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(),
                    ),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3,5]",
            True,
        ),
        # Some more wild slice creations
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"),
                              upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=None,
                              step=cst.Integer("3"))), ),
            ),
            "foo[::3]",
            False,
            CodeRange((1, 0), (1, 8)),
        ),
        # Some more wild slice parsings
        # NOTE(review): the first three cases in this group are identical to
        # the first three "wild slice creations" cases above (same node, code
        # and check_parsing=True); consider deduplicating.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"),
                              upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[::3]",
            True,
        ),
        # Valid list clone operations rendering
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # Valid list clone operations parsing
        # NOTE(review): both cases in this group are identical to the two
        # "rendering" cases above; consider deduplicating.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # In parenthesis
        (
            cst.Subscript(
                lpar=(cst.LeftParen(), ),
                value=cst.Name("foo"),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "(foo[5])",
            True,
        ),
        # Verify spacing
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 5 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        step=cst.Integer("3"),
                    )), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(
                    cst.SubscriptElement(
                        slice=cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                    ),
                    cst.SubscriptElement(slice=cst.Index(cst.Integer("5"))),
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ,  5 ] )",
            True,
            CodeRange((1, 2), (1, 24)),
        ),
        # Test Index, Slice, SubscriptElement
        (cst.Index(cst.Integer("5")), "5", False, CodeRange((1, 0), (1, 1))),
        (
            cst.Slice(lower=None,
                      upper=None,
                      second_colon=cst.Colon(),
                      step=None),
            "::",
            False,
            CodeRange((1, 0), (1, 2)),
        ),
        (
            cst.SubscriptElement(
                slice=cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    step=cst.Integer("3"),
                ),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
            ),
            "1 : 2 : 3 ,  ",
            False,
            CodeRange((1, 0), (1, 9)),
        ),
    ))
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        check_parsing: bool,
        position: Optional[CodeRange] = None,
    ) -> None:
        """
        Validate that ``node`` corresponds to ``code``.

        Each provider entry is ``(node, code, check_parsing[, position])``.
        When ``check_parsing`` is true, ``parse_expression`` is also passed to
        ``validate_node`` (presumably so ``code`` is parsed and compared as
        well); ``position``, when given, is the expected ``CodeRange``.
        """
        if check_parsing:
            self.validate_node(node,
                               code,
                               parse_expression,
                               expected_position=position)
        else:
            self.validate_node(node, code, expected_position=position)

    @data_provider((
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        (lambda: cst.Subscript(cst.Name("foo"), ()), "empty SubscriptElement"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """
        Check that building an invalid Subscript fails.

        Each provider entry is a zero-argument factory plus a regex;
        ``assert_invalid`` presumably asserts that calling the factory raises
        an error whose message matches ``expected_re``.
        """
        self.assert_invalid(get_node, expected_re)
Ejemplo n.º 21
0
def make_aref(name: str, idx: cst.BaseExpression) -> cst.Subscript:
    """Return the CST for the array reference ``name[idx]``."""
    index_element = cst.SubscriptElement(slice=cst.Index(value=idx))
    array_name = cst.Name(name)
    return cst.Subscript(value=array_name, slice=[index_element])
Ejemplo n.º 22
0
 def to_index_cst(value):
     """Wrap ``value`` in a ``cst.Index`` slice node."""
     return cst.Index(value=value)
Ejemplo n.º 23
0
def _get_match_metadata() -> cst.SubscriptElement:
    """
    Return a ``MetadataMatchType`` slice element, ready to append to a Union.
    """
    metadata_type = cst.Name("MetadataMatchType")
    return cst.SubscriptElement(cst.Index(metadata_type))
Ejemplo n.º 24
0
def _get_do_not_care() -> cst.SubscriptElement:
    """
    Return a ``DoNotCareSentinel`` slice element, ready to append to a Union.
    """
    sentinel_type = cst.Name("DoNotCareSentinel")
    return cst.SubscriptElement(cst.Index(sentinel_type))
Ejemplo n.º 25
0
def make_index(arr, idx):
    """Return the CST for ``arr[idx]``; both arguments are used as-is as CST nodes."""
    index_element = cst.SubscriptElement(slice=cst.Index(value=idx))
    return cst.Subscript(value=arr, slice=[index_element])
Ejemplo n.º 26
0
def _add_generic(name: str, oldtype: cst.BaseExpression) -> cst.BaseExpression:
    """Wrap ``oldtype`` in a one-parameter generic, producing ``name[oldtype]``."""
    generic_name = cst.Name(name)
    sole_parameter = cst.SubscriptElement(cst.Index(oldtype))
    return cst.Subscript(generic_name, (sole_parameter, ))
Ejemplo n.º 27
0
 def annotation(self):
     """
     Return ``Type[<name annotation>]``, or the builtin ``type`` when
     ``self.name`` is None.
     """
     if self.name is not None:
         inner = cst.SubscriptElement(cst.Index(self.name.annotation))
         return cst.Subscript(cst.Name("Type"), [inner])
     return cst.Name("type")
Ejemplo n.º 28
0
class AnnAssignTest(CSTNodeTest):
    """
    Tests for ``cst.AnnAssign`` nodes.

    ``test_valid`` passes each data-provider entry to ``validate_node``;
    ``test_invalid`` passes each entry to ``assert_invalid``.
    """

    # Each entry pairs a node with the code it should correspond to.
    # Entries with "parser": None use a bare AnnAssign and supply an
    # expected_position. Entries using parse_statement wrap the AnnAssign
    # in a SimpleStatementLine (hence the trailing "\n" in "code") and
    # supply no position.
    @data_provider((
        # Simple assignment creation case.
        {
            "node":
            cst.AnnAssign(cst.Name("foo"), cst.Annotation(cst.Name("str")),
                          cst.Integer("5")),
            "code":
            "foo: str = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 12)),
        },
        # Annotation creation without assignment (no value, so no "= ...").
        {
            "node": cst.AnnAssign(cst.Name("foo"),
                                  cst.Annotation(cst.Name("str"))),
            "code": "foo: str",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 8)),
        },
        # Complex annotation creation: a subscripted type, Optional[str].
        {
            "node":
            cst.AnnAssign(
                cst.Name("foo"),
                cst.Annotation(
                    cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    )),
                cst.Integer("5"),
            ),
            "code":
            "foo: Optional[str] = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 22)),
        },
        # Simple assignment parser case.
        # The code has no space before ":", so the expected node spells out
        # an empty whitespace_before_indicator.
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo: str = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Annotation without assignment (parser case; no AssignEqual given).
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Name("str"),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                value=None,
            ), )),
            "code":
            "foo: str\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Complex annotation (parser case).
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(""),
                ),
                equal=cst.AssignEqual(),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo: Optional[str] = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Whitespace test: explicit, non-default spacing around ":" and "="
        # (construction case).
        {
            "node":
            cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace("  "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                value=cst.Integer("5"),
            ),
            "code":
            "foo :  Optional[str]  =  5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 26)),
        },
        # Whitespace parser test: the same spacing as the previous entry,
        # parsed as a statement.
        {
            "node":
            cst.SimpleStatementLine((cst.AnnAssign(
                target=cst.Name("foo"),
                annotation=cst.Annotation(
                    annotation=cst.Subscript(
                        cst.Name("Optional"),
                        (cst.SubscriptElement(cst.Index(cst.Name("str"))), ),
                    ),
                    whitespace_before_indicator=cst.SimpleWhitespace(" "),
                    whitespace_after_indicator=cst.SimpleWhitespace("  "),
                ),
                equal=cst.AssignEqual(
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
                value=cst.Integer("5"),
            ), )),
            "code":
            "foo :  Optional[str]  =  5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Run ``validate_node`` on one valid data-provider entry."""
        self.validate_node(**kwargs)

    # An AssignEqual supplied together with value=None must be rejected
    # with the error message below.
    @data_provider(({
        "get_node": (lambda: cst.AnnAssign(
            target=cst.Name("foo"),
            annotation=cst.Annotation(cst.Name("str")),
            equal=cst.AssignEqual(),
            value=None,
        )),
        "expected_re":
        "Must have a value when specifying an AssignEqual.",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Run ``assert_invalid`` on one invalid data-provider entry."""
        self.assert_invalid(**kwargs)
Ejemplo n.º 29
0
def _get_do_not_care() -> cst.ExtSlice:
    """
    Build the ``DoNotCareSentinel`` member used when assembling the
    subscript slices of a Union.
    """
    sentinel_type = cst.Name("DoNotCareSentinel")
    return cst.ExtSlice(cst.Index(sentinel_type))
Ejemplo n.º 30
0
    def test_subscript(self) -> None:
        """
        Check that subscript slice nodes of several kinds can be substituted
        into template placeholders inside ``[...]`` and render back to the
        expected code.
        """
        # Test that we can insert various subscript slices into an
        # acceptable spot.

        # A single index, supplied as a bare expression, an Index, or a
        # SubscriptElement wrapping an Index.
        expression = parse_template_expression(
            "Optional[{type}]", type=cst.Name("int"),
        )
        self.assertEqual(
            self.code(expression), "Optional[int]",
        )
        expression = parse_template_expression(
            "Tuple[{type1}, {type2}]", type1=cst.Name("int"), type2=cst.Name("str"),
        )
        self.assertEqual(
            self.code(expression), "Tuple[int, str]",
        )

        expression = parse_template_expression(
            "Optional[{type}]", type=cst.Index(cst.Name("int")),
        )
        self.assertEqual(
            self.code(expression), "Optional[int]",
        )
        expression = parse_template_expression(
            "Optional[{type}]", type=cst.SubscriptElement(cst.Index(cst.Name("int"))),
        )
        self.assertEqual(
            self.code(expression), "Optional[int]",
        )

        # A slice, supplied bare and wrapped in a SubscriptElement. (These two
        # checks previously appeared twice verbatim; the duplicate was removed.)
        expression = parse_template_expression(
            "foo[{slice}]", slice=cst.Slice(cst.Integer("5"), cst.Integer("6")),
        )
        self.assertEqual(
            self.code(expression), "foo[5:6]",
        )
        expression = parse_template_expression(
            "foo[{slice}]",
            slice=cst.SubscriptElement(cst.Slice(cst.Integer("5"), cst.Integer("6"))),
        )
        self.assertEqual(
            self.code(expression), "foo[5:6]",
        )

        # Two placeholders in one subscript, mixing a slice and an index,
        # supplied bare and wrapped in SubscriptElements.
        expression = parse_template_expression(
            "foo[{slice1}, {slice2}]",
            slice1=cst.Slice(cst.Integer("5"), cst.Integer("6")),
            slice2=cst.Index(cst.Integer("7")),
        )
        self.assertEqual(
            self.code(expression), "foo[5:6, 7]",
        )
        expression = parse_template_expression(
            "foo[{slice1}, {slice2}]",
            slice1=cst.SubscriptElement(cst.Slice(cst.Integer("5"), cst.Integer("6"))),
            slice2=cst.SubscriptElement(cst.Index(cst.Integer("7"))),
        )
        self.assertEqual(
            self.code(expression), "foo[5:6, 7]",
        )