Esempio n. 1
0
 def test_does_not_match_true(self) -> None:
     """Check that ``m.DoesNotMatch`` matches when its inner matcher does not."""

     def call_foo(value):
         # Concrete node for a call to ``foo`` with one positional argument.
         return libcst.Call(libcst.Name("foo"), (libcst.Arg(value), ))

     # A call whose single argument has a value other than the name None.
     value_is_not_none = m.Call(
         args=(m.Arg(value=m.DoesNotMatch(m.Name("None"))), ))
     self.assertTrue(
         matches(call_foo(libcst.Name("True")), value_is_not_none))

     # The negation may wrap the whole argument matcher ...
     arg_is_not_none = m.Call(
         args=(m.DoesNotMatch(m.Arg(m.Name("None"))), ))
     self.assertTrue(
         matches(call_foo(libcst.Integer("1")), arg_is_not_none))

     # ... or the entire argument sequence.
     args_are_not_two = m.Call(
         args=m.DoesNotMatch((m.Arg(m.Integer("2")), )))
     self.assertTrue(
         matches(call_foo(libcst.Integer("1")), args_are_not_two))

     # A call whose argument is neither True nor False.
     neither_bool = m.DoesNotMatch(m.OneOf(m.Name("True"), m.Name("False")))
     self.assertTrue(
         matches(call_foo(libcst.Integer("1")),
                 m.Call(args=(m.Arg(value=neither_bool), ))))

     # A name whose value does not match the regex for True.
     not_true_regex = m.DoesNotMatch(m.MatchRegex(r"True"))
     self.assertTrue(
         matches(libcst.Name("False"), m.Name(value=not_true_regex)))
Esempio n. 2
0
 def test_at_least_n_matcher_args_false(self) -> None:
     """Check ``m.AtLeastN`` patterns that must not match ``foo(1, 2, 3)``."""
     foo_1_2_3 = cst.Call(
         func=cst.Name("foo"),
         args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     # Every pattern starts with the integer 1; only the trailing part varies.
     rejected_tails = (
         # At least two string arguments after the first.
         m.AtLeastN(m.Arg(m.SimpleString()), n=2),
         # At least three wildcard arguments after the first.
         m.AtLeastN(m.Arg(), n=3),
         # At least two arguments that are the integer 2 after the first.
         m.AtLeastN(m.Arg(m.Integer("2")), n=2),
     )
     for tail in rejected_tails:
         pattern = m.Call(
             func=m.Name("foo"),
             args=(m.Arg(m.Integer("1")), tail),
         )
         self.assertFalse(matches(foo_1_2_3, pattern))
Esempio n. 3
0
 def test_at_least_n_matcher_args_true(self) -> None:
     """Check ``m.AtLeastN`` patterns that must match ``foo(1, 2, 3)``."""
     foo_1_2_3 = cst.Call(
         func=cst.Name("foo"),
         args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     # Every pattern starts with the integer 1; only the trailing part varies.
     accepted_tails = (
         # At least two wildcard arguments after the first.
         m.AtLeastN(m.Arg(), n=2),
         # At least two integer arguments of any value after the first.
         m.AtLeastN(m.Arg(m.Integer()), n=2),
         # At least two arguments that are the integer 2 or 3 after the first.
         m.AtLeastN(m.Arg(m.OneOf(m.Integer("2"), m.Integer("3"))), n=2),
     )
     for tail in accepted_tails:
         pattern = m.Call(
             func=m.Name("foo"),
             args=(m.Arg(m.Integer("1")), tail),
         )
         self.assertTrue(matches(foo_1_2_3, pattern))
Esempio n. 4
0
 def test_at_most_n_matcher_args_true(self) -> None:
     """Check ``m.AtMostN`` / ``m.ZeroOrOne`` argument patterns that must match.

     The third and fourth cases use the same node and differ only in how the
     optional second argument is spelled: ``AtMostN(..., n=1)`` versus
     ``ZeroOrOne(...)``.
     """
     # Match a function call to "foo" with at most two arguments, both of which
     # are the integer 1.
     self.assertTrue(
         matches(
             cst.Call(func=cst.Name("foo"),
                      args=(cst.Arg(cst.Integer("1")), )),
             m.Call(func=m.Name("foo"),
                    args=(m.AtMostN(m.Arg(m.Integer("1")), n=2), )),
         ))
     # Match a function call to "foo" with at most two arguments, both of which
     # can be the integer 1 or 2.
     self.assertTrue(
         matches(
             cst.Call(
                 func=cst.Name("foo"),
                 args=(cst.Arg(cst.Integer("1")),
                       cst.Arg(cst.Integer("2"))),
             ),
             m.Call(
                 func=m.Name("foo"),
                 args=(m.AtMostN(m.Arg(
                     m.OneOf(m.Integer("1"), m.Integer("2"))),
                                 n=2), ),
             ),
         ))
     # Match a function call to "foo" with at most two arguments, the first
     # one being the integer 1 and the second one, if included, being the
     # integer 2. The optional part is expressed with AtMostN so that this
     # case is not a duplicate of the ZeroOrOne case below.
     self.assertTrue(
         matches(
             cst.Call(
                 func=cst.Name("foo"),
                 args=(cst.Arg(cst.Integer("1")),
                       cst.Arg(cst.Integer("2"))),
             ),
             m.Call(
                 func=m.Name("foo"),
                 args=(m.Arg(m.Integer("1")),
                       m.AtMostN(m.Arg(m.Integer("2")), n=1)),
             ),
         ))
     # Same pattern spelled with ZeroOrOne: the first argument is the
     # integer 1 and the second one, if included, is the integer 2.
     self.assertTrue(
         matches(
             cst.Call(
                 func=cst.Name("foo"),
                 args=(cst.Arg(cst.Integer("1")),
                       cst.Arg(cst.Integer("2"))),
             ),
             m.Call(
                 func=m.Name("foo"),
                 args=(m.Arg(m.Integer("1")),
                       m.ZeroOrOne(m.Arg(m.Integer("2")))),
             ),
         ))
Esempio n. 5
0
 def test_at_least_n_matcher_no_args_false(self) -> None:
     """Check wildcard ``m.AtLeastN(n=...)`` patterns that reject ``foo(1, 2, 3)``."""
     foo_1_2_3 = cst.Call(
         func=cst.Name("foo"),
         args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     rejected_args = (
         # At least four arguments.
         (m.AtLeastN(n=4),),
         # At least four arguments, the first one being the value 1.
         (m.Arg(m.Integer("1")), m.AtLeastN(n=3)),
         # At least three arguments, the last one being the value 2.
         (m.AtLeastN(n=2), m.Arg(m.Integer("2"))),
     )
     for arg_pattern in rejected_args:
         pattern = m.Call(func=m.Name("foo"), args=arg_pattern)
         self.assertFalse(matches(foo_1_2_3, pattern))
Esempio n. 6
0
    def visit_BooleanOperation(self, node: cst.BooleanOperation) -> None:
        """Report an ``or`` chain in which operands with collected targets
        are replaced by a single ``isinstance(operand, ...)`` call.

        Nodes already present in ``self.seen_boolean_operations`` are skipped,
        and nothing is reported when ``collect_targets`` returns as many
        operands as ``unwrap`` produced.
        """
        if node in self.seen_boolean_operations:
            return

        flattened = tuple(self.unwrap(node))
        operands, targets = self.collect_targets(flattened)

        # Same number of operands means nothing was collapsed.
        if len(operands) == len(flattened):
            return

        def as_term(operand):
            # Operands without collected targets are kept unchanged.
            if operand not in targets:
                return operand
            checked = targets[operand]
            if len(checked) == 1:
                second = cst.Arg(value=checked[0])
            else:
                # Several targets are passed to isinstance as a tuple.
                second = cst.Arg(cst.Tuple([cst.Element(item) for item in checked]))
            return cst.Call(cst.Name("isinstance"), [cst.Arg(operand), second])

        combined = None
        for term in map(as_term, operands):
            if combined is None:
                combined = term
                continue
            # Left-fold the remaining terms into ``combined or term``.
            combined = cst.BooleanOperation(
                left=combined, right=term, operator=cst.Or()
            )

        if combined is None:
            return
        self.report(node, replacement=combined)
Esempio n. 7
0
 def pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
     node: cst.Yield, ) -> cst.BaseExpression:
     """Build a call to ``asyncio.gather`` whose only argument is the
     star-unpacked value of the given yield node."""
     gather = cst.Attribute(value=cst.Name("asyncio"),
                            attr=cst.Name("gather"))
     unpacked = cst.Arg(value=node.value, star="*")
     return cst.Call(func=gather, args=[unpacked])
Esempio n. 8
0
 def sampleop_to_logpdf(cst_generator, *args, **kwargs):
     """Build a call to the ``logpdf_sum`` attribute of the generated node.

     The ``var_name`` keyword is removed from ``kwargs`` and used as the
     call's single argument; everything else is forwarded to
     ``cst_generator``.
     """
     variable = kwargs.pop("var_name")
     generated = cst_generator(*args, **kwargs)
     logpdf_sum = cst.Attribute(generated, cst.Name("logpdf_sum"))
     return cst.Call(logpdf_sum, [cst.Arg(variable)])
Esempio n. 9
0
 def to_sampler(cst_generator, *args, **kwargs):
     """Build a call to the ``sample`` attribute of the generated node.

     The ``rng_key`` keyword is removed from ``kwargs`` and used as the
     call's single argument; everything else is forwarded to
     ``cst_generator``.
     """
     key = kwargs.pop("rng_key")
     generated = cst_generator(*args, **kwargs)
     sample = cst.Attribute(value=generated, attr=cst.Name("sample"))
     return cst.Call(func=sample, args=[cst.Arg(value=key)])
    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Report ``dataclasses.dataclass`` decorators that pass no ``frozen`` keyword.

        The reported replacement is the same decorator rewritten as a call
        with ``frozen=True`` appended to its existing arguments.
        """
        dataclass_name = QualifiedName(
            name="dataclasses.dataclass", source=QualifiedNameSource.IMPORT
        )
        for wrapper in node.decorators:
            decorator = wrapper.decorator
            if not QualifiedNameProvider.has_name(self, decorator, dataclass_name):
                continue

            if isinstance(decorator, cst.Call):
                func, args = decorator.func, decorator.args
            else:
                # A bare cst.Name or cst.Attribute decorator has no arguments.
                func, args = decorator, ()

            # pyre-fixme[29]: `typing.Union[typing.Callable(tuple.__iter__)[[], typing.Iterator[Variable[_T_co](covariant)]], typing.Callable(typing.Sequence.__iter__)[[], typing.Iterator[cst._nodes.expression.Arg]]]` is not a function.
            if any(m.matches(arg.keyword, m.Name("frozen")) for arg in args):
                # Any existing ``frozen`` keyword, whatever its value, is left alone.
                continue

            frozen_true = cst.Arg(
                keyword=cst.Name("frozen"),
                value=cst.Name("True"),
                # Emit ``frozen=True`` with no whitespace around the equals sign.
                equal=cst.AssignEqual(
                    whitespace_before=SimpleWhitespace(value=""),
                    whitespace_after=SimpleWhitespace(value=""),
                ),
            )
            new_decorator = cst.Call(func=func, args=list(args) + [frozen_true])
            self.report(
                wrapper, replacement=wrapper.with_changes(decorator=new_decorator)
            )
    def test_gen_dotted_names(self) -> None:
        """Check the set of names ``_gen_dotted_names`` yields for a plain
        name, a dotted attribute, and an attribute of a call result."""

        def dotted(expression):
            # Keep only the name half of each (name, node) pair.
            return {name for name, _ in _gen_dotted_names(expression)}

        def a_dot_b():
            return cst.Attribute(value=cst.Name(value="a"),
                                 attr=cst.Name(value="b"))

        self.assertEqual(dotted(cst.Name(value="a")), {"a"})

        self.assertEqual(dotted(a_dot_b()), {"a.b", "a"})

        # ``a.b.c().d``: only the names of the callee are expected.
        callee = cst.Attribute(value=a_dot_b(), attr=cst.Name(value="c"))
        attribute_of_call = cst.Attribute(
            value=cst.Call(func=callee, args=[]),
            attr=cst.Name(value="d"),
        )
        self.assertEqual(dotted(attribute_of_call), {"a.b.c", "a.b", "a"})
Esempio n. 12
0
    def _spot_reg_write(node: cst.Expr) -> Optional[NBAssign]:
        # Recognise expression statements of the form
        #
        #   state.gprs.get_reg(foo).write_unsigned(bar)
        #   state.gprs.get_reg(foo).write_signed(bar)
        #
        # and convert them to the assignments
        #
        #   GPRs[FOO] = bar
        #   GPRs[FOO] = to_2s_complement(bar)
        #
        # Returns None for anything that does not have this shape.

        call = node.value
        if not isinstance(call, cst.Call):
            return None

        # Both write methods take exactly one argument and are attribute calls.
        if not (len(call.args) == 1 and isinstance(call.func, cst.Attribute)):
            return None

        written = call.args[0].value
        method = call.func.attr.value

        if method == 'write_unsigned':
            rhs = written
        elif method == 'write_signed':
            # Signed writes are wrapped in a to_2s_complement(...) call.
            rhs = cst.Call(func=cst.Name('to_2s_complement'),
                           args=[cst.Arg(value=written)])
        else:
            return None

        # The receiver of the method should look like state.gprs.get_reg(foo);
        # match_get_reg gives None when it does not.
        target = ImplTransformer.match_get_reg(call.func.value)
        if target is None:
            return None

        return NBAssign.make(target, rhs)
Esempio n. 13
0
 def test_and_matcher_false(self) -> None:
     """Check ``m.AllOf`` combinations that must not match."""
     # A name cannot be both True and False, so this never matches.
     contradictory = m.AllOf(m.Name("True"), m.Name("False"))
     self.assertFalse(matches(cst.Name("None"), contradictory))

     foo_1_2_3 = cst.Call(
         func=cst.Name("foo"),
         args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     # Three wildcard arguments match, but the values are in reverse order.
     any_three = (m.Arg(), m.Arg(), m.Arg())
     reversed_values = tuple(
         m.Arg(m.Integer(digit)) for digit in ("3", "2", "1")
     )
     pattern = m.Call(
         func=m.Name("foo"),
         args=m.AllOf(any_three, reversed_values),
     )
     self.assertFalse(matches(foo_1_2_3, pattern))
Esempio n. 14
0
 def test_and_matcher_true(self) -> None:
     """Check ``m.AllOf`` combinations that must match."""
     # Any name, and also a name whose value matches the regex True.
     name_and_regex = m.AllOf(m.Name(), m.Name(value=m.MatchRegex(r"True")))
     self.assertTrue(matches(cst.Name("True"), name_and_regex))

     foo_1_2_3 = cst.Call(
         func=cst.Name("foo"),
         args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     # Three wildcard arguments, and also the exact values 1, 2, 3 in order.
     any_three = (m.Arg(), m.Arg(), m.Arg())
     exact_values = tuple(
         m.Arg(m.Integer(digit)) for digit in ("1", "2", "3")
     )
     pattern = m.Call(
         func=m.Name("foo"),
         args=m.AllOf(any_three, exact_values),
     )
     self.assertTrue(matches(foo_1_2_3, pattern))
Esempio n. 15
0
 def test_at_most_n_matcher_no_args_false(self) -> None:
     """Check wildcard ``m.AtMostN`` / ``m.ZeroOrOne`` patterns that reject
     ``foo(1, 2, 3)``."""
     foo_1_2_3 = cst.Call(
         func=cst.Name("foo"),
         args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     rejected_args = (
         # At most two arguments.
         (m.AtMostN(n=2),),
         # At most two arguments, the last one being the integer 3.
         (m.AtMostN(n=1), m.Arg(m.Integer("3"))),
         # The same bound, spelled with ZeroOrOne.
         (m.ZeroOrOne(), m.Arg(m.Integer("3"))),
     )
     for arg_pattern in rejected_args:
         pattern = m.Call(func=m.Name("foo"), args=arg_pattern)
         self.assertFalse(matches(foo_1_2_3, pattern))
Esempio n. 16
0
 def to_call_cst(*args, **kwargs):
     """Build a call node from the ``__name__`` keyword and the positional
     arguments."""
     # The callee arrives as a keyword argument rather than positionally.
     # The original author's recollection is that this is related to
     # arguments being passed in graph-insertion order at compilation, with
     # nodes deleted and re-inserted during the logpdf/sampler transforms.
     callee = kwargs["__name__"]
     return cst.Call(func=callee, args=args)
Esempio n. 17
0
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        An unnecessary list or set comprehension, i.e. one whose element has
        the same value as its loop target, is rewritten as a call to
        ``list`` or ``set`` on the iterable. Other nodes are returned as-is.
        """
        if updated_node.elt.value != updated_node.for_in.target.value:
            return updated_node

        if isinstance(updated_node, cst.ListComp):
            builtin = "list"
        else:
            builtin = "set"
        iterable = cst.Arg(updated_node.for_in.iter)
        return cst.Call(func=cst.Name(builtin), args=[iterable])
Esempio n. 18
0
 def test_adding_parens(self) -> None:
     """Check the code generated for a parenthesized ``with`` statement whose
     first item's comma is followed by parenthesized whitespace."""
     first_item = cst.WithItem(
         cst.Call(cst.Name("foo")),
         comma=cst.Comma(whitespace_after=cst.ParenthesizedWhitespace(), ),
     )
     second_item = cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma())
     with_statement = cst.With(
         (first_item, second_item),
         cst.SimpleStatementSuite((cst.Pass(), )),
         lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
         rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
     )
     rendered = cst.Module([]).code_for_node(with_statement)
     expected = ("with ( foo(),\n"
                 "bar(), ): pass\n")  # noqa
     self.assertEqual(rendered, expected)
Esempio n. 19
0
 def test_does_not_match_operator_false(self) -> None:
     """Check that matchers negated with ``~`` fail when the inner matcher
     matches."""

     def call_foo(value):
         # Concrete node for a call to ``foo`` with one positional argument.
         return libcst.Call(libcst.Name("foo"), (libcst.Arg(value), ))

     def call_with_value(value_matcher):
         # Pattern for any call whose single argument has the given value.
         return m.Call(args=(m.Arg(value=value_matcher), ))

     # "Argument value is not None" fails for foo(None).
     self.assertFalse(
         matches(call_foo(libcst.Name("None")),
                 call_with_value(~m.Name("None"))))

     # "Argument is not the integer 1" fails for foo(1).
     self.assertFalse(
         matches(call_foo(libcst.Integer("1")),
                 m.Call(args=((~m.Arg(m.Integer("1"))), ))))

     # "Argument value is neither True nor False" fails for foo(False).
     self.assertFalse(
         matches(call_foo(libcst.Name("False")),
                 call_with_value(~(m.Name("True") | m.Name("False")))))

     # Roundabout way of verifying ~(x&y) behavior.
     self.assertFalse(
         matches(call_foo(libcst.Name("False")),
                 call_with_value(~(m.Name() & m.Name("False")))))

     # Roundabout way of verifying (~x)|(~y) behavior.
     self.assertFalse(
         matches(call_foo(libcst.Name("True")),
                 call_with_value((~m.Name("True")) | (~m.Name("True")))))

     # "Name value does not match the regex True" fails for the name True.
     negated_regex = ~m.MatchRegex(r"True")
     self.assertFalse(
         matches(libcst.Name("True"), m.Name(value=negated_regex)))
Esempio n. 20
0
 def test_at_most_n_matcher_no_args_true(self) -> None:
     """Check wildcard ``m.AtMostN`` / ``m.ZeroOrOne`` patterns that must
     match ``foo(1)`` or ``foo(1, 2)``."""

     def call_foo(*digits):
         # Concrete node for ``foo`` called with the given integer literals.
         return libcst.Call(
             libcst.Name("foo"),
             tuple(libcst.Arg(libcst.Integer(digit)) for digit in digits),
         )

     cases = (
         # At most two arguments.
         (call_foo("1"), (m.AtMostN(n=2), )),
         (call_foo("1", "2"), (m.AtMostN(n=2), )),
         # At most six arguments, the last one being the integer 1
         # (this pattern is deliberately given as a list).
         (call_foo("1"), [m.AtMostN(n=5), m.Arg(m.Integer("1"))]),
         # At most six arguments, the last one being the integer 2.
         (call_foo("1", "2"), (m.AtMostN(n=5), m.Arg(m.Integer("2")))),
         # At most six arguments, the first one being the integer 1.
         (call_foo("1", "2"), (m.Arg(m.Integer("1")), m.AtMostN(n=5))),
         # At most two arguments, the first one being the integer 1.
         (call_foo("1", "2"), (m.Arg(m.Integer("1")), m.ZeroOrOne())),
     )
     for node, arg_pattern in cases:
         self.assertTrue(matches(node, m.Call(m.Name("foo"), arg_pattern)))
Esempio n. 21
0
 def leave_Assert(self, _, updated_node):  # noqa
     """Simplify asserts whose condition is a literal.

     A truthy literal removes the statement; a falsy literal turns it into
     ``raise AssertionError``, carrying the assert's message when there is
     one. Conditions that ``literal_eval`` cannot evaluate are left as-is.
     """
     condition_source = cst.Module("").code_for_node(updated_node.test)
     try:
         constant = literal_eval(condition_source)
     except Exception:
         # Not a literal expression: keep the assert untouched.
         return updated_node

     if constant:
         return cst.RemovalSentinel.REMOVE

     error = cst.Name("AssertionError")
     if updated_node.msg is not None:
         error = cst.Call(error, args=[cst.Arg(updated_node.msg)])
     return cst.Raise(error)
Esempio n. 22
0
 def test_zero_or_more_matcher_args_false(self) -> None:
     """Check ``m.ZeroOrMore`` patterns that must not match ``foo(1, 2, 3)``."""
     foo_1_2_3 = cst.Call(
         func=cst.Name("foo"),
         args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     # Every pattern starts with the integer 1; only the trailing part varies.
     rejected_tails = (
         # The rest of the arguments are strings.
         m.ZeroOrMore(m.Arg(m.SimpleString())),
         # The rest of the arguments are integers with the value 2.
         m.ZeroOrMore(m.Arg(m.Integer("2"))),
     )
     for tail in rejected_tails:
         pattern = m.Call(
             func=m.Name("foo"),
             args=(m.Arg(m.Integer("1")), tail),
         )
         self.assertFalse(matches(foo_1_2_3, pattern))
 def test_statement(self) -> None:
     """Check that a compound statement, a statement line and a small
     statement can each be substituted into a template's statement list."""
     if_statement = cst.If(
         test=cst.Name("foo"),
         body=cst.SimpleStatementSuite((cst.Pass(),),),
     )
     call_line = cst.SimpleStatementLine(
         (cst.Expr(cst.Call(cst.Name("bar"))),),
     )
     module = parse_template_module(
         "{statement1}\n{statement2}\n{statement3}\n",
         statement1=if_statement,
         statement2=call_line,
         statement3=cst.Pass(),
     )
     expected = "if foo: pass\nbar()\npass\n"
     self.assertEqual(module.code, expected)
Esempio n. 24
0
 def leave_IfExp(
         self,
         original_node: cst.IfExp,
         updated_node: cst.IfExp,
         ):
     """Replace a conditional expression with a call to the function named
     by ``self.phi_name``, passing its test, body and orelse in that order."""
     phi = cst.Name(value=self.phi_name)
     parts = (updated_node.test, updated_node.body, updated_node.orelse)
     return cst.Call(func=phi, args=[cst.Arg(value=part) for part in parts])
Esempio n. 25
0
 def test_at_most_n_matcher_args_false(self) -> None:
     """Check that ``m.AtMostN`` with an argument matcher rejects
     ``foo(1, 2, 3)`` when none of the arguments is the integer 4."""
     foo_1_2_3 = libcst.Call(
         libcst.Name("foo"),
         tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     # At most three arguments, all of which are the integer 4.
     up_to_three_fours = m.AtMostN(m.Arg(m.Integer("4")), n=3)
     pattern = m.Call(m.Name("foo"), (up_to_three_fours, ))
     self.assertFalse(matches(foo_1_2_3, pattern))
Esempio n. 26
0
 def leave_Call(self, original: libcst.Call,
                updated: libcst.Call) -> libcst.Call:
     """Rewrite ``six.NAME(arg)`` as ``arg.NEW_NAME()`` using _CONVERSION_MAP.

     Args:
         original: The call node before its children were transformed (unused).
         updated: The call node after its children were transformed.

     Returns:
         A new call on an attribute of the original single argument when the
         callee is ``six.NAME`` for a NAME in _CONVERSION_MAP and the call has
         exactly one argument; otherwise ``updated`` unchanged. A warning is
         emitted when the name matches but the argument count is not one.
     """
     # Check the shape of the callee before touching its fields: a callee
     # that is not an Attribute (for example another Call, as in f()(x))
     # has no ``attr`` and may have no ``value`` attribute either.
     if not m.matches(updated.func, m.Attribute(value=m.Name("six"))):
         return updated

     for orig_name, updated_name in _CONVERSION_MAP.items():
         if not m.matches(updated.func.attr, m.Name(orig_name)):
             continue
         if len(updated.args) != 1:
             self.warn(
                 f"Odd six.{orig_name} call does not have one argument. Cannot perform substitution."
             )
             continue
         value = updated.args[0].value
         return libcst.Call(func=libcst.Attribute(
             value=value,
             attr=libcst.Name(value=updated_name),
         ))
     return updated
Esempio n. 27
0
 def choice_ast(rng_key):
     """Build the call ``mcx.jax.choice(rng_key, NAME.shape[0])``.

     NAME is ``nodes[0].name``, where ``nodes`` comes from the enclosing
     scope rather than from this function's parameters.
     """
     choice = cst.Attribute(
         value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
         attr=cst.Name("choice"),
     )
     key_arg = cst.Arg(rng_key)
     shape = cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape"))
     first_dimension = cst.Subscript(
         shape,
         [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
     )
     return cst.Call(func=choice, args=[key_arg, cst.Arg(first_dimension)])
Esempio n. 28
0
 def _get_assert_replacement(self, node: cst.Assert):
     """Build ``if not TEST: raise AssertionError(MESSAGE)`` for an assert.

     Args:
         node: The assert statement to replace.

     Returns:
         A ``cst.If`` whose test is the negated assert condition and whose
         body raises ``AssertionError``. MESSAGE is the assert's own message
         expression when it has one, otherwise a string literal holding the
         source code of the assert statement.
     """
     if node.msg is not None:
         # The message is already an expression node; pass it through
         # unchanged. Taking repr() of a CST node does not produce a string
         # literal, so it cannot be wrapped in a SimpleString.
         message = node.msg
     else:
         source = str(cst.Module(body=[node]).code)
         message = cst.SimpleString(value=repr(source))

     test = node.test
     # ``not`` binds tighter than ``and``/``or``, conditional expressions and
     # lambdas, so those need parentheses to negate the whole condition
     # (``not a or b`` would otherwise mean ``(not a) or b``).
     needs_parens = isinstance(
         test, (cst.BooleanOperation, cst.IfExp, cst.Lambda))
     if needs_parens and not test.lpar:
         test = test.with_changes(
             lpar=[cst.LeftParen()],
             rpar=[cst.RightParen()],
         )

     return cst.If(
         test=cst.UnaryOperation(
             operator=cst.Not(),
             expression=test,
         ),
         body=cst.IndentedBlock(body=[
             cst.SimpleStatementLine(body=[
                 cst.Raise(exc=cst.Call(
                     func=cst.Name(value="AssertionError", ),
                     args=[
                         cst.Arg(value=message, ),
                     ],
                 ), ),
             ]),
         ], ),
     )
Esempio n. 29
0
    def _spot_reg_read(node: cst.Call) -> Optional[cst.BaseExpression]:
        # Recognise calls of the form
        #
        #    state.gprs.get_reg(FOO).read_unsigned()
        #    state.gprs.get_reg(FOO).read_signed()
        #
        # and return the expressions
        #
        #    GPRs[FOO]
        #    from_2s_complement(GPRs[FOO])
        #
        # respectively. Returns None for anything that does not have this
        # shape.

        # Neither read method takes arguments, and both are reached through
        # an attribute access (state.gprs.get_reg(FOO).read_X).
        if node.args or not isinstance(node.func, cst.Attribute):
            return None

        # Only the two read methods are of interest.
        method = node.func.attr.value
        if method not in ('read_signed', 'read_unsigned'):
            return None

        # The receiver must really be of the form "state.gprs.get_reg(FOO)";
        # match_get_reg gives None when it is not.
        reg_ref = ImplTransformer.match_get_reg(node.func.value)
        if reg_ref is None:
            return None

        if method == 'read_unsigned':
            return reg_ref

        # Signed reads are wrapped in a call to a fake sign decode function.
        return cst.Call(func=cst.Name('from_2s_complement'),
                        args=[cst.Arg(value=reg_ref)])
Esempio n. 30
0
 def test_or_matcher_false(self) -> None:
     """Check ``m.OneOf`` combinations that must not match."""
     # None is neither True nor False.
     self.assertFalse(
         matches(libcst.Name("None"), m.OneOf(m.Name("True"), m.Name("False")))
     )

     # Assigning None to a target is not the same as assigning True or
     # False to a target.
     assign_none = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
     )
     assigns_bool = m.Assign(value=m.OneOf(m.Name("True"), m.Name("False")))
     self.assertFalse(matches(assign_none, assigns_bool))

     # foo(1, 2, 3) has neither of the two alternative argument sequences.
     foo_1_2_3 = libcst.Call(
         libcst.Name("foo"),
         tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     three_two_one = tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1"))
     four_five_six = tuple(m.Arg(m.Integer(digit)) for digit in ("4", "5", "6"))
     pattern = m.Call(m.Name("foo"), m.OneOf(three_two_one, four_five_six))
     self.assertFalse(matches(foo_1_2_3, pattern))