Пример #1
0
    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
        """Rewrite a flattened API call so its parameters travel in ``request={...}``.

        The call is returned unchanged when its method name is not a key of
        ``self.METHOD_TO_PARAMS`` (or the callee is not an attribute access),
        or when it already passes a ``request`` keyword.

        Args:
            original: The call node before its children were visited; only used
                to look up the method name.
            updated: The call node after its children were visited.

        Returns:
            ``updated`` itself, or a copy whose arguments are a single
            ``request`` dict followed by the control parameters as keywords.
        """
        try:
            key = original.func.attr.value
            kword_params = self.METHOD_TO_PARAMS[key]
        except (AttributeError, KeyError):
            # Either not a method from the API or too convoluted to be sure.
            return updated

        # If the existing code is valid, keyword args come after positional args.
        # Therefore, all positional args must map to the first parameters.
        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
        if any(k.keyword.value == "request" for k in kwargs):
            # We've already fixed this file, don't fix it again.
            return updated

        # Control parameters stay outside the request dict.
        kwargs, ctrl_kwargs = partition(
            lambda a: a.keyword.value not in self.CTRL_PARAMS,
            kwargs
        )

        # Positional args beyond the request parameters are control parameters
        # passed positionally; turn them into keyword arguments.
        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
        ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
                           for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))

        # Positional args take their name from their position in kword_params.
        # Keyword args already carry their name: pairing them positionally
        # would file e.g. `foo(second=x)` under the first parameter's name.
        named_values = [
            (name, arg.value) for name, arg in zip(kword_params, args)
        ]
        named_values.extend((kw.keyword.value, kw.value) for kw in kwargs)

        request_arg = cst.Arg(
            value=cst.Dict([
                cst.DictElement(
                    cst.SimpleString("'{}'".format(name)),
                    cst.Element(value=value)
                )
                for name, value in named_values]),
            keyword=cst.Name("request")
        )

        return updated.with_changes(
            args=[request_arg] + ctrl_kwargs
        )
Пример #2
0
class ExpressionTest(UnitTest):
    """Checks the full-name helpers and ``evaluated_value`` on literal nodes."""

    @data_provider((
        ("a string", "a string"),
        (cst.Name("a_name"), "a_name"),
        (cst.parse_expression("a.b.c"), "a.b.c"),
        (cst.parse_expression("a.b()"), "a.b"),
        (cst.parse_expression("a.b.c[i]"), "a.b.c"),
        (cst.parse_statement("def fun():  pass"), "fun"),
        (cst.parse_statement("class cls:  pass"), "cls"),
        (
            cst.Decorator(
                ensure_type(cst.parse_expression("a.b.c.d"), cst.Attribute)),
            "a.b.c.d",
        ),
        # Not a supported node type, so no name is expected.
        (cst.parse_statement("(a.b()).c()"), None),
    ))
    def test_get_full_name_for_expression(
        self,
        input: Union[str, cst.CSTNode],
        output: Optional[str],
    ) -> None:
        """Both helpers agree on the name; the raising variant raises for None."""
        self.assertEqual(get_full_name_for_node(input), output)
        if output is not None:
            self.assertEqual(get_full_name_for_node_or_raise(input), output)
            return
        with self.assertRaises(Exception):
            get_full_name_for_node_or_raise(input)

    def test_simplestring_evaluated_value(self) -> None:
        """A string literal keeps its source text and evaluates like literal_eval."""
        source = '"a string."'
        parsed = ensure_type(cst.parse_expression(source), cst.SimpleString)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))

    def test_integer_evaluated_value(self) -> None:
        """An integer literal keeps its source text and evaluates like literal_eval."""
        source = "5"
        parsed = ensure_type(cst.parse_expression(source), cst.Integer)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))

    def test_float_evaluated_value(self) -> None:
        """A float literal keeps its source text and evaluates like literal_eval."""
        source = "5.5"
        parsed = ensure_type(cst.parse_expression(source), cst.Float)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))

    def test_complex_evaluated_value(self) -> None:
        """An imaginary literal keeps its source text and evaluates like literal_eval."""
        source = "5j"
        parsed = ensure_type(cst.parse_expression(source), cst.Imaginary)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))
Пример #3
0
def _get_wrapped_union_type(
    node: cst.BaseExpression,
    addition: cst.SubscriptElement,
    *additions: cst.SubscriptElement,
) -> cst.Subscript:
    """
    Build ``Union[node, addition, *additions]`` as a subscript node.

    ``node`` is wrapped in an index element; the additions are used as given.
    At least one addition is required by the signature, for type safety.
    """
    members = [cst.SubscriptElement(cst.Index(node)), addition]
    members.extend(additions)
    return cst.Subscript(cst.Name("Union"), members)
Пример #4
0
 def leave_Module(self, node: cst.Module,
                  updated_node: cst.Module) -> cst.CSTNode:
     """Add an annotation-only declaration for every top-level annotation.

     One ``name: annotation`` statement (no assigned value) is built per
     entry of ``self.toplevel_annotations`` and inserted into the module body
     at the index reported by ``self._get_toplevel_index``.
     """
     statements = list(updated_node.body)
     insert_at = self._get_toplevel_index(statements)
     for target_name, annotation in self.toplevel_annotations.items():
         declaration = cst.AnnAssign(
             cst.Name(target_name),
             # pyre-fixme[16]: `CSTNode` has no attribute `annotation`.
             cst.Annotation(annotation.annotation),
             None,
         )
         # Every declaration goes in at the same index.
         statements.insert(insert_at, cst.SimpleStatementLine([declaration]))
     return updated_node.with_changes(body=tuple(statements))
Пример #5
0
 def leave_StarImport(
     original_node: cst.ImportFrom,
     updated_node: cst.ImportFrom,
     imp: ImportFrom,
 ) -> Union[cst.ImportFrom, cst.RemovalSentinel]:
     """Replace a star import with explicit names, or remove it.

     When ``imp.modules`` is non-empty, the import is rewritten to import
     exactly those names. Otherwise the statement is removed if ``imp.module``
     is truthy, and left as it was if not.
     """
     if not imp.modules:
         if imp.module:
             return cst.RemoveFromParent()
         return original_node

     aliases = []
     for module in imp.modules:
         aliases.append(cst.ImportAlias(cst.Name(module)))
     return updated_node.with_changes(names=aliases)
Пример #6
0
    def test_annotation(self) -> None:
        """A template annotation slot accepts an expression or an Annotation node."""
        template = "x: {type} = {val}"
        expected = "x: int = 5\n"

        # Plain expression in the annotation slot.
        from_expression = parse_template_statement(
            template,
            type=cst.Name("int"),
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(from_expression), expected)

        # Special case: a ready-made Annotation node in the same slot.
        from_annotation = parse_template_statement(
            template,
            type=cst.Annotation(cst.Name("int")),
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(from_annotation), expected)
Пример #7
0
    def test_args(self) -> None:
        """Call-argument slots accept either bare expressions or Arg nodes."""
        template = "foo({arg1}, {arg2})"
        expected = "foo(bar, baz)"

        # Plain expressions as the call arguments.
        from_expressions = parse_template_expression(
            template,
            arg1=cst.Name("bar"),
            arg2=cst.Name("baz"),
        )
        self.assertEqual(self.code(from_expressions), expected)

        # Special case: ready-made Arg nodes in the same slots.
        from_arg_nodes = parse_template_expression(
            template,
            arg1=cst.Arg(cst.Name("bar")),
            arg2=cst.Arg(cst.Name("baz")),
        )
        self.assertEqual(self.code(from_arg_nodes), expected)
Пример #8
0
 def test_at_most_n_matcher_no_args_true(self) -> None:
     """Argument-less ``AtMostN`` / ``ZeroOrOne`` wildcards accept short calls."""
     # foo(1)
     one_arg_call = libcst.Call(
         libcst.Name("foo"), (libcst.Arg(libcst.Integer("1")),)
     )
     # foo(1, 2)
     two_arg_call = libcst.Call(
         libcst.Name("foo"),
         (libcst.Arg(libcst.Integer("1")), libcst.Arg(libcst.Integer("2"))),
     )

     # At most two arguments: one argument is accepted.
     pattern = m.Call(m.Name("foo"), (m.AtMostN(n=2),))
     self.assertTrue(matches(one_arg_call, pattern))

     # At most two arguments: exactly two are accepted as well.
     pattern = m.Call(m.Name("foo"), (m.AtMostN(n=2),))
     self.assertTrue(matches(two_arg_call, pattern))

     # At most six arguments, the last one being the integer 1.
     pattern = m.Call(m.Name("foo"), [m.AtMostN(n=5), m.Arg(m.Integer("1"))])
     self.assertTrue(matches(one_arg_call, pattern))

     # At most six arguments, the last one being the integer 2.
     pattern = m.Call(m.Name("foo"), (m.AtMostN(n=5), m.Arg(m.Integer("2"))))
     self.assertTrue(matches(two_arg_call, pattern))

     # At most six arguments, the first one being the integer 1.
     pattern = m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.AtMostN(n=5)))
     self.assertTrue(matches(two_arg_call, pattern))

     # The integer 1 first, followed by a ZeroOrOne wildcard.
     pattern = m.Call(m.Name("foo"), (m.Arg(m.Integer("1")), m.ZeroOrOne()))
     self.assertTrue(matches(two_arg_call, pattern))
Пример #9
0
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Rename a deprecated method call and its renamed keyword arguments.

        A call matching ``self.deprecated_call_matcher`` has its attribute
        renamed to ``Append``. A call that then matches ``self.call_matcher``
        has each argument matching a key of ``self.args_matchers_map`` given
        the mapped keyword name.
        """
        # Migrate from the deprecated method AppendItem()
        if matchers.matches(updated_node, self.deprecated_call_matcher):
            append_name = cst.Name(value="Append")
            renamed_func = updated_node.func.with_changes(attr=append_name)
            updated_node = updated_node.with_changes(func=renamed_func)

        if not matchers.matches(updated_node, self.call_matcher):
            return updated_node

        # Update keywords
        new_args = list(updated_node.args)
        for arg_matcher, new_keyword in self.args_matchers_map.items():
            for position, current_arg in enumerate(updated_node.args):
                if not matchers.matches(current_arg, arg_matcher):
                    continue
                new_args[position] = current_arg.with_changes(
                    keyword=cst.Name(value=new_keyword))

            # Refresh the node after each matcher so the next one sees the
            # arguments as renamed so far.
            updated_node = updated_node.with_changes(args=new_args)

        return updated_node
Пример #10
0
    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report ``dataclasses.dataclass`` decorators that lack ``frozen``.

        For each such decorator, a replacement is reported in which the
        decorator is a call carrying the existing arguments plus
        ``frozen=True``.
        """
        for d in node.decorators:
            decorator = d.decorator
            is_dataclass = QualifiedNameProvider.has_name(
                self,
                decorator,
                QualifiedName(name="dataclasses.dataclass",
                              source=QualifiedNameSource.IMPORT),
            )
            if not is_dataclass:
                continue

            if isinstance(decorator, cst.Call):
                func, args = decorator.func, decorator.args
            else:
                # decorator is either cst.Name or cst.Attribute
                func, args = decorator, ()

            # pyre-\fixme[29]: `typing.Union[typing.Callable(tuple.__iter__)[[], typing.Iterator[Variable[_T_co](covariant)]],
            # typing.Callable(typing.Sequence.__iter__)[[], typing.Iterator[cst._nodes.expression.Arg]]]` is not a function.

            has_frozen = any(
                m.matches(arg.keyword, m.Name("frozen")) for arg in args)
            if has_frozen:
                continue

            # frozen=True, rendered without spaces around the equals sign.
            frozen_arg = cst.Arg(
                keyword=cst.Name("frozen"),
                value=cst.Name("True"),
                equal=cst.AssignEqual(
                    whitespace_before=SimpleWhitespace(value=""),
                    whitespace_after=SimpleWhitespace(value=""),
                ),
            )
            new_decorator = cst.Call(
                func=func,
                args=list(args) + [frozen_arg],
            )
            self.report(
                d, replacement=d.with_changes(decorator=new_decorator))
Пример #11
0
 def test_does_not_match_operator_false(self) -> None:
     """Negated (``~``) matchers reject the nodes they are meant to exclude."""

     def foo_call(value):
         # foo(<value>) as a concrete node.
         return cst.Call(func=cst.Name("foo"), args=(cst.Arg(value),))

     # The pattern wants one argument that isn't the value None.
     pattern = m.Call(args=(m.Arg(value=~m.Name("None")),))
     self.assertFalse(matches(foo_call(cst.Name("None")), pattern))

     # The pattern wants one argument that isn't the integer 1.
     pattern = m.Call(args=((~m.Arg(m.Integer("1"))),))
     self.assertFalse(matches(foo_call(cst.Integer("1")), pattern))

     # The pattern wants an argument which isn't True or False.
     pattern = m.Call(
         args=(m.Arg(value=~(m.Name("True") | m.Name("False"))),)
     )
     self.assertFalse(matches(foo_call(cst.Name("False")), pattern))

     # Roundabout way of verifying ~(x&y) behavior.
     pattern = m.Call(args=(m.Arg(value=~(m.Name() & m.Name("False"))),))
     self.assertFalse(matches(foo_call(cst.Name("False")), pattern))

     # Roundabout way of verifying (~x)|(~y) behavior
     pattern = m.Call(
         args=(m.Arg(value=(~m.Name("True")) | (~m.Name("True"))),)
     )
     self.assertFalse(matches(foo_call(cst.Name("True")), pattern))

     # The pattern wants a name that doesn't match the regex for True.
     name_pattern = m.Name(value=~m.MatchRegex(r"True"))
     self.assertFalse(matches(cst.Name("True"), name_pattern))
Пример #12
0
def _get_clean_type_from_subscript(
        aliases: List[Alias], typecst: cst.Subscript) -> cst.BaseExpression:
    """Clean a subscript type, extending it in place when it is a ``Union``.

    A ``Union[...]`` gets the match-metadata and match-if-true elements
    appended to its members and is then wrapped via ``_wrap_clean_type``.
    Any other subscript is delegated to ``_get_clean_type_from_expression``.
    """
    if typecst.value.deep_equals(cst.Name("Union")):
        name = _get_alias_name(typecst)
        # The union can be modified as-is to add the extra values.
        members = list(typecst.slice)
        members.append(_get_match_metadata())
        members.append(_get_match_if_true(typecst))
        value = typecst.with_changes(slice=members)
        return _wrap_clean_type(aliases, name, value)

    # Don't handle other types like "Literal", just widen them.
    return _get_clean_type_from_expression(aliases, typecst)
Пример #13
0
 def visit_Subscript(self, node: cst.Subscript) -> None:
     """Record ``MatchIfTrue`` subscripts and ``Sequence[Union[...]]`` unions.

     A ``MatchIfTrue[...]`` node is added to ``self.in_match_if_true``. For
     ``Sequence[...]`` seen while that set is empty, a directly contained
     ``Union[...]`` is added to ``self.fixup_nodes``.

     Raises:
         Exception: If a ``Sequence`` subscript does not have exactly one
             element.
     """
     container = node.value

     # A MatchIfTrue node is only recorded; it must not be modified.
     if container.deep_equals(cst.Name("MatchIfTrue")):
         self.in_match_if_true.add(node)
         return

     if not container.deep_equals(cst.Name("Sequence")):
         return

     # We don't want to add AtLeastN/AtMostN inside MatchIfTrue type blocks,
     # even for sequence types.
     if self.in_match_if_true:
         return

     slc = node.slice
     # TODO: We can remove the instance check after ExtSlice is deprecated.
     if not isinstance(slc, Sequence) or len(slc) != 1:
         raise Exception(
             "Unexpected number of sequence elements inside Sequence type "
             + "annotation!")

     nodeslice = slc[0].slice
     if not isinstance(nodeslice, cst.Index):
         return

     # If the direct descendant is a union, add it to be fixed up.
     possibleunion = nodeslice.value
     if not isinstance(possibleunion, cst.Subscript):
         return
     if possibleunion.value.deep_equals(cst.Name("Union")):
         self.fixup_nodes.add(possibleunion)
Пример #14
0
 def test_or_matcher_true(self) -> None:
     """``OneOf`` matches when any one of its alternatives matches."""
     # Match on either True or False identifier.
     bool_name = m.OneOf(m.Name("True"), m.Name("False"))
     self.assertTrue(matches(libcst.Name("True"), bool_name))

     # Match any assignment that assigns a value of True or False to an
     # unspecified target.
     assignment = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")), ),
         libcst.Name("True"),
     )
     assign_pattern = m.Assign(
         value=m.OneOf(m.Name("True"), m.Name("False")))
     self.assertTrue(matches(assignment, assign_pattern))

     # Match foo(1, 2, 3) against two alternative argument sequences; the
     # second alternative is the one that fits.
     foo_call = libcst.Call(
         libcst.Name("foo"),
         (
             libcst.Arg(libcst.Integer("1")),
             libcst.Arg(libcst.Integer("2")),
             libcst.Arg(libcst.Integer("3")),
         ),
     )
     descending = (
         m.Arg(m.Integer("3")),
         m.Arg(m.Integer("2")),
         m.Arg(m.Integer("1")),
     )
     ascending = (
         m.Arg(m.Integer("1")),
         m.Arg(m.Integer("2")),
         m.Arg(m.Integer("3")),
     )
     call_pattern = m.Call(m.Name("foo"), m.OneOf(descending, ascending))
     self.assertTrue(matches(foo_call, call_pattern))
Пример #15
0
    def body(
        self,
    ) -> typing.Iterable[typing.Union[cst.BaseCompoundStatement,
                                      cst.SimpleStatementLine]]:
        """Yield the statements of this module, in order.

        First ``from typing import *``, then the property assignments, then
        the function definitions, then one class definition per entry of
        ``self.classes`` in ``sort_items`` order.
        """
        star_import = cst.ImportFrom(cst.Name("typing"),
                                     names=cst.ImportStar())
        yield cst.SimpleStatementLine([star_import])

        yield from assign_properties(self.properties)

        yield from function_defs(self.function_overloads, self.functions,
                                 "function")

        for class_name, class_ in sort_items(self.classes):
            yield class_.class_def(class_name)
Пример #16
0
 def test_or_matcher_false(self) -> None:
     """``OneOf`` fails when none of its alternatives matches."""
     # Fail to match since None is not True or False.
     bool_name = m.OneOf(m.Name("True"), m.Name("False"))
     self.assertFalse(matches(libcst.Name("None"), bool_name))

     # Fail to match since assigning None to a target is not the same as
     # assigning True or False to a target.
     assignment = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")), ),
         libcst.Name("None"),
     )
     assign_pattern = m.Assign(
         value=m.OneOf(m.Name("True"), m.Name("False")))
     self.assertFalse(matches(assignment, assign_pattern))

     # Fail to match foo(1, 2, 3): neither alternative argument sequence
     # fits the call.
     foo_call = libcst.Call(
         libcst.Name("foo"),
         (
             libcst.Arg(libcst.Integer("1")),
             libcst.Arg(libcst.Integer("2")),
             libcst.Arg(libcst.Integer("3")),
         ),
     )
     descending = (
         m.Arg(m.Integer("3")),
         m.Arg(m.Integer("2")),
         m.Arg(m.Integer("1")),
     )
     unrelated = (
         m.Arg(m.Integer("4")),
         m.Arg(m.Integer("5")),
         m.Arg(m.Integer("6")),
     )
     call_pattern = m.Call(m.Name("foo"), m.OneOf(descending, unrelated))
     self.assertFalse(matches(foo_call, call_pattern))
    def test_assign_target(self) -> None:
        """Assignment-target slots accept expressions or AssignTarget nodes."""
        template = "{a} = {b} = {val}"
        expected = "first = second = 5\n"

        # Plain expressions as the assignment targets.
        from_expressions = parse_template_statement(
            template,
            a=cst.Name("first"),
            b=cst.Name("second"),
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(from_expressions), expected)

        # Special case: ready-made AssignTarget nodes in the same slots.
        from_targets = parse_template_statement(
            template,
            a=cst.AssignTarget(cst.Name("first")),
            b=cst.AssignTarget(cst.Name("second")),
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(from_targets), expected)
Пример #18
0
    def test_findall_with_metadata_wrapper(self) -> None:
        """``findall`` with a metadata matcher needs a metadata source to match."""
        code = """
            a = 1
            b = True

            def foo(bar: int) -> bool:
                return False
        """

        def in_store_context():
            # Matcher for nodes whose expression context is STORE.
            return m.MatchMetadata(meta.ExpressionContextProvider,
                                   meta.ExpressionContext.STORE)

        wrapper = meta.MetadataWrapper(cst.parse_module(dedent(code)))

        # Searching over the wrapper implicitly uses it for metadata as well
        # as traversal.
        found = findall(wrapper, in_store_context())
        self.assertNodeSequenceEqual(found, [cst.Name("a"), cst.Name("b")])

        # An explicit resolver together with the bare tree works too.
        found = findall(
            wrapper.module,
            in_store_context(),
            metadata_resolver=wrapper,
        )
        self.assertNodeSequenceEqual(found, [cst.Name("a"), cst.Name("b")])

        # Without any metadata source nothing matches.
        found = findall(wrapper.module, in_store_context())
        self.assertNodeSequenceEqual(found, [])
Пример #19
0
    def test_findall_with_transformers(self) -> None:
        """``self.findall`` inside a transformer uses the transformer's metadata."""
        code = """
            a = 1
            b = True

            def foo(bar: int) -> bool:
                return False
        """

        class TestTransformer(m.MatcherDecoratableTransformer):
            """Collects every node in a STORE expression context."""

            METADATA_DEPENDENCIES: Sequence[meta.ProviderT] = (
                meta.ExpressionContextProvider, )

            def __init__(self) -> None:
                super().__init__()
                self.results: Sequence[cst.CSTNode] = ()

            def visit_Module(self, node: cst.Module) -> None:
                store_matcher = m.MatchMetadata(
                    meta.ExpressionContextProvider,
                    meta.ExpressionContext.STORE,
                )
                self.results = self.findall(node, store_matcher)

        wrapper = meta.MetadataWrapper(cst.parse_module(dedent(code)))
        transformer = TestTransformer()
        wrapper.visit(transformer)

        expected = [
            cst.Name("a"),
            cst.Name("b"),
            cst.Name("foo"),
            cst.Name("bar"),
        ]
        self.assertNodeSequenceEqual(transformer.results, expected)
Пример #20
0
def _get_clean_type_from_subscript(
        aliases: List[Alias], typecst: cst.Subscript) -> cst.BaseExpression:
    """Clean a subscript type, handling ``Sequence`` and ``Union`` specially.

    ``Sequence[X]`` has its single element type cleaned and is then wrapped
    in a union with the do-not-care and match-if-true elements. ``Union`` is
    delegated to ``_get_clean_type_from_union``; any other subscript goes to
    ``_get_clean_type_from_expression``.

    Raises:
        Exception: If a ``Sequence`` subscript has an unexpected shape.
    """
    container = typecst.value
    if not container.deep_equals(cst.Name("Sequence")):
        if container.deep_equals(cst.Name("Union")):
            # We can modify this as-is to add our extra values
            return _get_clean_type_from_union(aliases, typecst)
        # Don't handle other types like "Literal", just widen them.
        return _get_clean_type_from_expression(aliases, typecst)

    # Attempt to widen the sequence type and alias it.
    slc = typecst.slice
    # TODO: This instance check can go away once we deprecate ExtSlice
    if not isinstance(slc, Sequence):
        raise Exception("Logic error, expected Sequence to have children!")
    if len(slc) != 1:
        raise Exception(
            "Logic error, Sequence shouldn't have more than one param!")

    index = slc[0].slice
    if not isinstance(index, cst.Index):
        raise Exception(
            "Logic error, expecting Index for only Sequence element!")
    element_type = index.value

    if isinstance(element_type, cst.Subscript):
        cleaned = _get_clean_type_from_subscript(aliases, element_type)
    elif isinstance(element_type, (cst.Name, cst.SimpleString)):
        cleaned = _get_clean_type_from_expression(aliases, element_type)
    else:
        raise Exception("Logic error, unexpected type in Sequence!")

    widened_sequence = typecst.deep_replace(element_type, cleaned)
    return _get_wrapped_union_type(
        widened_sequence,
        _get_do_not_care(),
        _get_match_if_true(typecst),
    )
Пример #21
0
def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.ExtSlice:
    """
    Build the ``MatchIfTrue[Callable[[<node type>], bool]]`` element that goes
    into a Union.
    """
    # MatchIfTrue takes in the original node type, and returns a boolean. So,
    # lets convert our quoted classes (forward refs to other matchers) back to
    # the CSTNode they refer to. We can do this because there's always a 1:1
    # name mapping.
    node_type = _convert_match_nodes_to_cst_nodes(oldtype)

    # Callable[[node_type], bool]
    parameters = cst.ExtSlice(cst.Index(cst.List([cst.Element(node_type)])))
    returns = cst.ExtSlice(cst.Index(cst.Name("bool")))
    callable_type = cst.Subscript(
        cst.Name("Callable"),
        slice=[parameters, returns],
    )

    # MatchIfTrue[Callable[...]]
    match_if_true = cst.Subscript(
        cst.Name("MatchIfTrue"),
        cst.Index(callable_type),
    )
    return cst.ExtSlice(cst.Index(match_if_true))
Пример #22
0
 def test_zero_or_more_matcher_args_false(self) -> None:
     """``ZeroOrMore`` with an argument matcher rejects non-conforming tails."""
     # foo(1, 2, 3)
     foo_call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )

     # Fail to match a call to "foo" where the first argument is the integer
     # value 1, and the rest of the arguments are strings.
     strings_tail = m.Call(
         func=m.Name("foo"),
         args=(m.Arg(m.Integer("1")), m.ZeroOrMore(m.Arg(m.SimpleString()))),
     )
     self.assertFalse(matches(foo_call, strings_tail))

     # Fail to match a call to "foo" where the first argument is the integer
     # value 1, and the rest of the arguments are integers with the value 2.
     twos_tail = m.Call(
         func=m.Name("foo"),
         args=(m.Arg(m.Integer("1")), m.ZeroOrMore(m.Arg(m.Integer("2")))),
     )
     self.assertFalse(matches(foo_call, twos_tail))
Пример #23
0
 def leave_StarImport(original_node: cst.ImportFrom,
                      updated_node: cst.ImportFrom,
                      **kwargs) -> Union[cst.ImportFrom, RemovalSentinel]:
     """Replace a star import with explicit names, or remove it.

     Args:
         original_node: The import statement before its children were visited.
         updated_node: The import statement after its children were visited.
         **kwargs: Must contain ``imp``, a mapping read through its
             ``"modules"`` and ``"module"`` keys.

     Returns:
         ``updated_node`` importing exactly ``imp["modules"]`` when that is
         non-empty; otherwise a removal sentinel when ``imp["module"]`` is
         truthy, and ``original_node`` unchanged when it is not.

     Raises:
         KeyError: If ``imp`` is missing from ``kwargs``.
     """
     imp = kwargs["imp"]
     if imp["modules"]:
         # Iterate the names directly rather than joining them into one
         # comma-separated string and splitting it apart again.
         names_to_suggestion = [
             cst.ImportAlias(cst.Name(module)) for module in imp["modules"]
         ]
         return updated_node.with_changes(names=names_to_suggestion)
     if imp["module"]:
         return cst.RemoveFromParent()
     return original_node
Пример #24
0
def assign_properties(
        p: typing.Dict[str, typing.Tuple[Metadata, Type]],
        is_classvar=False) -> typing.Iterable[cst.SimpleStatementLine]:
    """Yield one annotation-only declaration per property in ``p``.

    Properties are visited in ``sort_items`` order, skipping names rejected by
    ``bad_name``. Each statement is preceded by an empty line and one comment
    line per entry of ``metadata_lines(metadata)``. With ``is_classvar`` set,
    the annotation is wrapped in ``ClassVar[...]``.
    """
    for name, metadata_and_tp in sort_items(p):
        if bad_name(name):
            continue
        metadata, tp = metadata_and_tp

        annotation_expr = tp.annotation
        if is_classvar:
            annotation_expr = cst.Subscript(
                cst.Name("ClassVar"),
                [cst.SubscriptElement(cst.Index(annotation_expr))],
            )
        declaration = cst.AnnAssign(
            cst.Name(name),
            cst.Annotation(annotation_expr),
        )

        # A blank separator line first, then the metadata as comments.
        header = [cst.EmptyLine()]
        for text in metadata_lines(metadata):
            header.append(cst.EmptyLine(comment=cst.Comment("# " + text)))

        yield cst.SimpleStatementLine([declaration], leading_lines=header)
Пример #25
0
    def leave_Attribute(self,
                        orig: cst.Attribute,
                        updated: cst.Attribute) -> cst.BaseExpression:
        """Simplify attribute accesses for the documentation rendering.

        ``self.X`` becomes ``X``, ``state.dmem`` becomes ``DMEM`` and
        ``state.csrs.flags`` becomes ``FLAGs``. Anything else is returned
        unchanged.
        """
        base = updated.value

        if isinstance(base, cst.Name):
            stem = base.value

            # Strip out "self." references. In the ISS code, a field in the
            # instruction appears as self.field_name. In the documentation, we
            # can treat all the instruction fields as being in scope.
            if stem == 'self':
                return updated.attr

            # Replace state.dmem with DMEM. This is an object in the ISS code,
            # so you see things like state.dmem.load_u32(...). We keep the
            # "object-orientated" style (so DMEM.load_u32(...)) because we need
            # to distinguish between 32-bit and 256-bit loads.
            if stem == 'state' and updated.attr.value == 'dmem':
                return cst.Name(value='DMEM')

        if isinstance(base, cst.Attribute):
            # This attribute looks like A.B.C where B, C are names and A may be
            # a further attribute or it might be a name.
            root = base.value
            middle = base.attr.value
            leaf = updated.attr.value

            # Replace state.csrs.flags with FLAGs: the flag groups are stored
            # in the CSRs in the ISS and the implementation, but logically
            # exist somewhat separately, so we want named reads/writes from
            # them to look different.
            if isinstance(root, cst.Name):
                path = (root.value, middle, leaf)
                if path == ('state', 'csrs', 'flags'):
                    return cst.Name(value='FLAGs')

        return updated
Пример #26
0
 def leave_IfExp(
         self,
         original_node: cst.IfExp,
         updated_node: cst.IfExp,
         ):
     """Replace a conditional expression with a call to ``self.phi_name``.

     The call receives the test, the body and the else-branch of the
     conditional expression as positional arguments, in that order.
     """
     operands = (
         updated_node.test,
         updated_node.body,
         updated_node.orelse,
     )
     phi_args = [cst.Arg(value=operand) for operand in operands]
     return cst.Call(func=cst.Name(value=self.phi_name), args=phi_args)
Пример #27
0
 def leave_Subscript(self, original_node: cst.Subscript,
                     updated_node: cst.Subscript) -> cst.Subscript:
     """Widen ``Sequence[...]`` annotations with extra union members.

     ``Sequence[Union[...]]`` gets the do-not-care and match-metadata
     elements appended to the existing union. Any other ``Sequence[X]`` has
     ``X`` wrapped in a new union with those two elements. Subscripts other
     than ``Sequence`` are returned unchanged.

     Raises:
         Exception: If the ``Sequence`` subscript does not have exactly one
             element, or that element is not an ``Index``.
     """
     if not updated_node.value.deep_equals(cst.Name("Sequence")):
         return updated_node

     slc = updated_node.slice
     # TODO: We can remove the instance check after ExtSlice is deprecated.
     if not isinstance(slc, Sequence) or len(slc) != 1:
         raise Exception(
             "Unexpected number of sequence elements inside Sequence type "
             + "annotation!")

     nodeslice = slc[0].slice
     if not isinstance(nodeslice, cst.Index):
         raise Exception("Unexpected slice type for Sequence!")

     element = nodeslice.value
     is_union = (
         isinstance(element, cst.Subscript)
         and element.value.deep_equals(cst.Name("Union"))
     )
     if is_union:
         # Special case for Sequence[Union] so that we make more collapsed
         # types.
         return updated_node.with_deep_changes(
             element,
             slice=[
                 *element.slice,
                 _get_do_not_care(),
                 _get_match_metadata(),
             ],
         )

     # This is a sequence of some node, add DoNotCareSentinel here so that
     # a person can add a do not care to a sequence that otherwise has
     # valid matcher nodes.
     widened = _get_wrapped_union_type(
         element,
         _get_do_not_care(),
         _get_match_metadata(),
     )
     return updated_node.with_changes(
         slice=(cst.SubscriptElement(cst.Index(widened)), ))
Пример #28
0
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Rename calls to wx symbols, with or without the ``wx`` prefix.

        Calls matched through ``self.matchers_short_map`` (symbol present in
        ``self.wx_imports``) are rewritten to ``wx.<renamed>``; the symbol's
        import is scheduled for removal and an import of ``wx`` is scheduled.
        Calls matched through ``self.matchers_full_map`` get their attribute
        renamed. A tuple ``renamed`` produces ``wx.<first>.<second>`` in both
        cases. Unmatched calls are returned unchanged.
        """

        def nested_wx_attribute(parts):
            # wx.<parts[0]>.<parts[1]>
            return cst.Attribute(
                value=cst.Attribute(value=cst.Name(value="wx"),
                                    attr=cst.Name(value=parts[0])),
                attr=cst.Name(value=parts[1]),
            )

        # Matches calls with symbols without the wx prefix
        for symbol, matcher, renamed in self.matchers_short_map:
            if symbol not in self.wx_imports:
                continue
            if not matchers.matches(updated_node, matcher):
                continue

            # Remove the symbol's import
            RemoveImportsVisitor.remove_unused_import_by_node(
                self.context, original_node)

            # Add import of top level wx package
            AddImportsVisitor.add_needed_import(self.context, "wx")

            if isinstance(renamed, tuple):
                new_func = nested_wx_attribute(renamed)
            else:
                new_func = cst.Attribute(value=cst.Name(value="wx"),
                                         attr=cst.Name(value=renamed))
            return updated_node.with_changes(func=new_func)

        # Matches full calls like wx.MySymbol
        for matcher, renamed in self.matchers_full_map:
            if not matchers.matches(updated_node, matcher):
                continue

            if isinstance(renamed, tuple):
                new_func = nested_wx_attribute(renamed)
            else:
                new_func = updated_node.func.with_changes(
                    attr=cst.Name(value=renamed))
            return updated_node.with_changes(func=new_func)

        # Nothing matched
        return updated_node
    def test_simple_module(self) -> None:
        """A module template has each placeholder replaced wherever it occurs."""
        template = self.dedent(
                """
                from {module} import {obj}

                def foo() -> {obj}:
                    return {obj}()
                """
        )
        parsed = parse_template_module(
            template,
            module=cst.Name("foo"),
            obj=cst.Name("Bar"),
        )
        expected = self.dedent(
                """
                from foo import Bar

                def foo() -> Bar:
                    return Bar()
                """
        )
        self.assertEqual(parsed.code, expected)
Пример #30
0
def _get_alias_name(node: cst.CSTNode) -> Optional[str]:
    """Return the ``...MatchType`` alias name for a type node, if it has one.

    A name or string node yields ``<raw name>MatchType``. A ``Union[...]``
    subscript yields its members' raw names joined by ``Or`` plus
    ``MatchType``, or ``None`` when any member has no raw name. Every other
    node yields ``None``.
    """
    if isinstance(node, (cst.Name, cst.SimpleString)):
        return f"{_get_raw_name(node)}MatchType"

    if not isinstance(node, cst.Subscript):
        return None
    if not node.value.deep_equals(cst.Name("Union")):
        return None

    slc = node.slice
    # TODO: This instance check can go away once we deprecate ExtSlice
    if not isinstance(slc, Sequence):
        return None

    member_names = []
    for member in slc:
        member_names.append(_get_raw_name(member))

    # A single unnamed member means the whole union has no alias.
    for raw_name in member_names:
        if raw_name is None:
            return None
    return "Or".join(member_names) + "MatchType"