Exemplo n.º 1
0
 def leave_AnnAssign(
     self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
 ) -> Union[cst.BaseSmallStatement, cst.RemovalSentinel]:
     """Rewrite an annotated assignment as a plain assignment.

     ``foo: str = "x"`` becomes ``foo = "x"``. Declaration-only forms that
     carry no value (``foo: str``, ``self.foo: str``, ``foo[0]: str``) become
     ``foo = ...`` so that later node traversal never meets an ``Assign``
     without a value.

     The replacement is built from ``updated_node`` so changes made to the
     target or value by child visits are kept; the statement's semicolon is
     carried over as well.
     """
     value = updated_node.value
     if value is None:
         # cst.Assign requires a value; a None here would fail later during
         # traversal/codegen, so use an Ellipsis placeholder for every kind of
         # target (Name, Attribute, Subscript, ...).
         value = cst.Ellipsis()
     return cst.Assign(
         targets=[cst.AssignTarget(target=updated_node.target)],
         value=value,
         semicolon=updated_node.semicolon,
     )
Exemplo n.º 2
0
    def leave_AnnAssign(self, original_node: cst.AnnAssign,
                        updated_node: cst.AnnAssign):
        """Obfuscate an annotated assignment and optionally drop its annotation.

        A declaration without a value (``x: T``) receives an explicit ``None``
        value; an existing value is passed through ``self.obf_universal``.
        The target is then obfuscated the same way. When
        ``self.delete_annotations`` is set, the result is rebuilt as a plain
        ``cst.Assign`` (single spaces around ``=``, semicolon kept).
        """
        current_value = updated_node.value
        # The value is obfuscated before the target; keep that call order.
        if current_value is not None:
            obfuscated_value = self.obf_universal(current_value, 'v', 'a', 'ca')
        else:
            obfuscated_value = cst.Name(value='None', lpar=[], rpar=[])
        obfuscated_target = self.obf_universal(updated_node.target, 'v', 'a', 'ca')
        result = updated_node.with_changes(value=obfuscated_value,
                                           target=obfuscated_target)

        if not self.delete_annotations:
            return result

        one_space = cst.SimpleWhitespace(value=' ')
        plain_target = cst.AssignTarget(
            target=result.target,
            whitespace_before_equal=one_space,
            whitespace_after_equal=one_space,
        )
        return cst.Assign(targets=[plain_target],
                          value=result.value,
                          semicolon=result.semicolon)
Exemplo n.º 3
0
def make_assign(
    lhs: cst.BaseAssignTargetExpression,
    rhs: cst.BaseExpression,
) -> cst.Assign:
    """Return the single-target assignment node ``lhs = rhs``."""
    sole_target = cst.AssignTarget(lhs)
    return cst.Assign(targets=[sole_target], value=rhs)
Exemplo n.º 4
0
 def test_or_operator_matcher_false(self) -> None:
     """A ``|``-combined matcher rejects nodes matching none of its alternatives."""
     true_or_false = m.Name("True") | m.Name("False")
     # None is neither True nor False.
     self.assertFalse(matches(cst.Name("None"), true_or_false))
     # Assigning None to a target is not assigning True or False to it.
     assign_none = cst.Assign((cst.AssignTarget(cst.Name("x")),), cst.Name("None"))
     self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))
Exemplo n.º 5
0
    def leave_AnnAssign(self, original_node: cst.AnnAssign,
                        updated_node: cst.AnnAssign):
        """Turn annotated assignments into plain ones and mark bare declarations.

        ``x: T = v`` becomes ``x = v`` (semicolon preserved). A bare
        declaration ``x: T`` is kept as an annotated statement, but its target
        is replaced by a ``__COMMENT__``-prefixed name so a second pass can
        comment it out. For ``self.x: T`` or ``d[k]: T`` the prefixed name is
        built from the target's rendered source (e.g. ``__COMMENT__self.x``).
        """
        target = updated_node.target
        if updated_node.value is None:
            if isinstance(target, cst.Name):
                target_text = target.value
            else:
                # Attribute/Subscript targets have no string `.value` (it is a
                # child node), so concatenating it would raise TypeError;
                # render the target's source text instead.
                target_text = cst.Module(body=[]).code_for_node(target)
            # Annotate assignments so they can be commented out by a second pass
            return updated_node.with_changes(
                target=cst.Name("__COMMENT__" + target_text))

        return cst.Assign(
            targets=[cst.AssignTarget(target=target)],
            value=updated_node.value,
            semicolon=updated_node.semicolon)
Exemplo n.º 6
0
    def leave_AnnAssign(self, original_node: cst.AnnAssign,
                        updated_node: cst.AnnAssign):
        """Strip annotations: ``x: T = v`` -> ``x = v``; drop bare ``x: T``."""
        assigned = updated_node.value
        if assigned is not None:
            return cst.Assign(
                targets=[cst.AssignTarget(target=updated_node.target)],
                value=assigned)
        # A bare declaration such as `some_var: str` assigns nothing, so the
        # whole statement is removed from its parent.
        return cst.RemoveFromParent()
    def test_assign_target(self) -> None:
        """Template targets accept both bare expressions and AssignTarget nodes."""
        template = "{a} = {b} = {val}"
        expected_code = "first = second = 5\n"
        substitutions = (
            # Plain expressions inserted as assignment targets.
            dict(
                a=cst.Name("first"),
                b=cst.Name("second"),
                val=cst.Integer("5"),
            ),
            # AssignTarget nodes, which the template code special-cases.
            dict(
                a=cst.AssignTarget(cst.Name("first")),
                b=cst.AssignTarget(cst.Name("second")),
                val=cst.Integer("5"),
            ),
        )
        for placeholders in substitutions:
            statement = parse_template_statement(template, **placeholders)
            self.assertEqual(self.code(statement), expected_code)
Exemplo n.º 8
0
 def test_or_operator_matcher_true(self) -> None:
     """A ``|``-combined matcher accepts a node matching any alternative."""
     true_or_false = m.Name("True") | m.Name("False")
     # Each alternative of the two-way OR matches on its own.
     for literal in ("True", "False"):
         self.assertTrue(matches(cst.Name(literal), true_or_false))
     # Chaining another `|` adds None as a third alternative.
     self.assertTrue(matches(cst.Name("None"), true_or_false | m.Name("None")))
     # Any assignment of True or False matches, whatever the target.
     assign_true = cst.Assign((cst.AssignTarget(cst.Name("x")),), cst.Name("True"))
     self.assertTrue(matches(assign_true, m.Assign(value=true_or_false)))
Exemplo n.º 9
0
 def test_or_matcher_false(self) -> None:
     """``m.OneOf`` rejects nodes that satisfy none of its options."""
     true_or_false = m.OneOf(m.Name("True"), m.Name("False"))
     # None is neither True nor False.
     self.assertFalse(matches(libcst.Name("None"), true_or_false))
     # Assigning None is not the same as assigning True or False.
     assign_none = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
     )
     self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))
     # foo(1, 2, 3) matches neither (3, 2, 1) nor (4, 5, 6) as its args.
     call = libcst.Call(
         libcst.Name("foo"),
         tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     reversed_args = tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1"))
     other_args = tuple(m.Arg(m.Integer(digit)) for digit in ("4", "5", "6"))
     self.assertFalse(
         matches(call, m.Call(m.Name("foo"), m.OneOf(reversed_args, other_args)))
     )
Exemplo n.º 10
0
 def test_or_matcher_true(self) -> None:
     """``m.OneOf`` accepts a node that satisfies any one of its options."""
     true_or_false = m.OneOf(m.Name("True"), m.Name("False"))
     # A bare True identifier is one of the options.
     self.assertTrue(matches(libcst.Name("True"), true_or_false))
     # Any assignment of True or False matches, whatever the target.
     assign_true = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("True")
     )
     self.assertTrue(matches(assign_true, m.Assign(value=true_or_false)))
     # foo(1, 2, 3) matches the second option's argument sequence exactly.
     call = libcst.Call(
         libcst.Name("foo"),
         tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     reversed_args = tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1"))
     in_order_args = tuple(m.Arg(m.Integer(digit)) for digit in ("1", "2", "3"))
     self.assertTrue(
         matches(call, m.Call(m.Name("foo"), m.OneOf(reversed_args, in_order_args)))
     )
Exemplo n.º 11
0
class FooterBehaviorTest(UnitTest):
    """Parser tests for where comments end up after parsing a module.

    Each ``data_provider`` case pairs source ``code`` with the CST that
    ``parse_module(dedent(code))`` is expected to produce. The cases cover the
    module header, the module footer, indented-block footers and statement
    ``leading_lines``.
    """

    @data_provider({
        # The most basic example: a module consisting of a single newline.
        "simple_module": {
            "code": "\n",
            "expected_module": cst.Module(body=())
        },
        # A module with a header comment
        "header_only_module": {
            "code":
            "# This is a header comment\n",
            "expected_module":
            cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[],
            ),
        },
        # A module with a header and footer
        "simple_header_footer_module": {
            "code":
            "# This is a header comment\npass\n# This is a footer comment\n",
            "expected_module":
            cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[cst.SimpleStatementLine([cst.Pass()])],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # A module whose footer comment should be taken away from the
        # if statement's indented block and moved to the module footer.
        "simple_reparented_footer_module": {
            "code":
            "# This is a header comment\nif True:\n    pass\n# This is a footer comment\n",
            "expected_module":
            cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[
                    cst.If(
                        test=cst.Name(value="True"),
                        body=cst.IndentedBlock(
                            header=cst.TrailingWhitespace(),
                            body=[
                                cst.SimpleStatementLine(
                                    body=[cst.Pass()],
                                    trailing_whitespace=cst.TrailingWhitespace(
                                    ),
                                )
                            ],
                        ),
                    )
                ],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # Verify that footer comments are parsed and distributed to the
        # blocks whose indentation level they match.
        "complex_reparented_footer_module": {
            "code":
            ("# This is a header comment\nif True:\n    if True:\n        pass"
             +
             "\n        # This is an inner indented block comment\n    # This "
             +
             "is an outer indented block comment\n# This is a footer comment\n"
             ),
            "expected_module":
            cst.Module(
                body=[
                    cst.If(
                        test=cst.Name(value="True"),
                        body=cst.IndentedBlock(
                            body=[
                                cst.If(
                                    test=cst.Name(value="True"),
                                    body=cst.IndentedBlock(
                                        body=[
                                            cst.SimpleStatementLine(
                                                body=[cst.Pass()])
                                        ],
                                        # Comment indented to the inner block.
                                        footer=[
                                            cst.EmptyLine(comment=cst.Comment(
                                                value=
                                                "# This is an inner indented block comment"
                                            ))
                                        ],
                                    ),
                                )
                            ],
                            # Comment indented to the outer block.
                            footer=[
                                cst.EmptyLine(comment=cst.Comment(
                                    value=
                                    "# This is an outer indented block comment"
                                ))
                            ],
                        ),
                    )
                ],
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # Verify that a comment preceding a statement is still owned by that
        # statement (as leading_lines) even when it follows an indented block.
        "statement_comment_reparent": {
            "code":
            "if foo:\n    return\n# comment\nx = 7\n",
            "expected_module":
            cst.Module(body=[
                cst.If(
                    test=cst.Name(value="foo"),
                    body=cst.IndentedBlock(body=[
                        cst.SimpleStatementLine(body=[
                            cst.Return(
                                whitespace_after_return=cst.SimpleWhitespace(
                                    value=""))
                        ])
                    ]),
                ),
                cst.SimpleStatementLine(
                    body=[
                        cst.Assign(
                            targets=[
                                cst.AssignTarget(target=cst.Name(value="x"))
                            ],
                            value=cst.Integer(value="7"),
                        )
                    ],
                    leading_lines=[
                        cst.EmptyLine(comment=cst.Comment(value="# comment"))
                    ],
                ),
            ]),
        },
        # Verify that even if there are completely empty lines, we give all lines
        # up to and including the last line that's indented correctly. That way
        # comments that line up with indented block's indentation level aren't
        # parented to the next line just because there's a blank line or two
        # between them.
        "statement_comment_with_empty_lines": {
            "code":
            ("def foo():\n    if True:\n        pass\n\n        # Empty " +
             "line before me\n\n    else:\n        pass\n"),
            "expected_module":
            cst.Module(body=[
                cst.FunctionDef(
                    name=cst.Name(value="foo"),
                    params=cst.Parameters(),
                    body=cst.IndentedBlock(body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.SimpleStatementLine(body=[cst.Pass()])
                                ],
                                footer=[
                                    cst.EmptyLine(indent=False),
                                    cst.EmptyLine(comment=cst.Comment(
                                        value="# Empty line before me")),
                                ],
                            ),
                            # The blank line after the comment belongs to
                            # the `else` clause, not the `if` body.
                            orelse=cst.Else(
                                body=cst.IndentedBlock(body=[
                                    cst.SimpleStatementLine(body=[cst.Pass()])
                                ]),
                                leading_lines=[cst.EmptyLine(indent=False)],
                            ),
                        )
                    ]),
                )
            ]),
        },
    })
    def test_parsers(self, code: str, expected_module: cst.CSTNode) -> None:
        """Parse dedented ``code`` and require it to deep-equal ``expected_module``."""
        parsed_module = parse_module(dedent(code))
        self.assertTrue(
            deep_equals(parsed_module, expected_module),
            msg=
            f"\n{parsed_module!r}\nis not deeply equal to \n{expected_module!r}",
        )
Exemplo n.º 12
0
class AssignTest(CSTNodeTest):
    """Construction, parsing and validation tests for ``cst.Assign``."""

    # Each case gives the node, its expected source ``code``, a ``parser``
    # (None for the node-construction cases, ``parse_statement`` for the
    # parser cases) and the expected position (None for the parser cases).
    @data_provider((
        # Simple assignment creation case.
        {
            "node":
            cst.Assign((cst.AssignTarget(cst.Name("foo")), ),
                       cst.Integer("5")),
            "code":
            "foo = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 7)),
        },
        # Multiple targets creation
        {
            "node":
            cst.Assign(
                (
                    cst.AssignTarget(cst.Name("foo")),
                    cst.AssignTarget(cst.Name("bar")),
                ),
                cst.Integer("5"),
            ),
            "code":
            "foo = bar = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 13)),
        },
        # Whitespace test for creating nodes
        {
            "node":
            cst.Assign(
                (cst.AssignTarget(
                    cst.Name("foo"),
                    whitespace_before_equal=cst.SimpleWhitespace(""),
                    whitespace_after_equal=cst.SimpleWhitespace(""),
                ), ),
                cst.Integer("5"),
            ),
            "code":
            "foo=5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 5)),
        },
        # Simple assignment parser case.
        {
            "node":
            cst.SimpleStatementLine((cst.Assign(
                (cst.AssignTarget(cst.Name("foo")), ), cst.Integer("5")), )),
            "code":
            "foo = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Multiple targets parser
        {
            "node":
            cst.SimpleStatementLine((cst.Assign(
                (
                    cst.AssignTarget(cst.Name("foo")),
                    cst.AssignTarget(cst.Name("bar")),
                ),
                cst.Integer("5"),
            ), )),
            "code":
            "foo = bar = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Whitespace test parser
        {
            "node":
            cst.SimpleStatementLine((cst.Assign(
                (cst.AssignTarget(
                    cst.Name("foo"),
                    whitespace_before_equal=cst.SimpleWhitespace(""),
                    whitespace_after_equal=cst.SimpleWhitespace(""),
                ), ),
                cst.Integer("5"),
            ), )),
            "code":
            "foo=5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Run ``validate_node`` on each valid Assign case above."""
        self.validate_node(**kwargs)

    # An Assign with an empty ``targets`` tuple must be rejected, with an error
    # message matching "at least one AssignTarget".
    @data_provider(({
        "get_node": (lambda: cst.Assign(targets=(), value=cst.Integer("5"))),
        "expected_re":
        "at least one AssignTarget",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that each invalid Assign case fails with the expected message."""
        self.assert_invalid(**kwargs)
Exemplo n.º 13
0
class AssignTest(CSTNodeTest):
    """Construction, parsing, validation and type-check tests for ``cst.Assign``."""

    # Each case gives the node, its expected source ``code``, a ``parser``
    # (None for the node-construction cases, ``parse_statement`` for the
    # parser cases) and the expected position (None for the parser cases).
    @data_provider((
        # Simple assignment creation case.
        {
            "node":
            cst.Assign((cst.AssignTarget(cst.Name("foo")), ),
                       cst.Integer("5")),
            "code":
            "foo = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 7)),
        },
        # Multiple targets creation
        {
            "node":
            cst.Assign(
                (
                    cst.AssignTarget(cst.Name("foo")),
                    cst.AssignTarget(cst.Name("bar")),
                ),
                cst.Integer("5"),
            ),
            "code":
            "foo = bar = 5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 13)),
        },
        # Whitespace test for creating nodes
        {
            "node":
            cst.Assign(
                (cst.AssignTarget(
                    cst.Name("foo"),
                    whitespace_before_equal=cst.SimpleWhitespace(""),
                    whitespace_after_equal=cst.SimpleWhitespace(""),
                ), ),
                cst.Integer("5"),
            ),
            "code":
            "foo=5",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 0), (1, 5)),
        },
        # Simple assignment parser case.
        {
            "node":
            cst.SimpleStatementLine((cst.Assign(
                (cst.AssignTarget(cst.Name("foo")), ), cst.Integer("5")), )),
            "code":
            "foo = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Multiple targets parser
        {
            "node":
            cst.SimpleStatementLine((cst.Assign(
                (
                    cst.AssignTarget(cst.Name("foo")),
                    cst.AssignTarget(cst.Name("bar")),
                ),
                cst.Integer("5"),
            ), )),
            "code":
            "foo = bar = 5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
        # Whitespace test parser
        {
            "node":
            cst.SimpleStatementLine((cst.Assign(
                (cst.AssignTarget(
                    cst.Name("foo"),
                    whitespace_before_equal=cst.SimpleWhitespace(""),
                    whitespace_after_equal=cst.SimpleWhitespace(""),
                ), ),
                cst.Integer("5"),
            ), )),
            "code":
            "foo=5\n",
            "parser":
            parse_statement,
            "expected_position":
            None,
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Run ``validate_node`` on each valid Assign case above."""
        self.validate_node(**kwargs)

    # An Assign with an empty ``targets`` tuple must be rejected, with an error
    # message matching "at least one AssignTarget".
    @data_provider(({
        "get_node": (lambda: cst.Assign(targets=(), value=cst.Integer("5"))),
        "expected_re":
        "at least one AssignTarget",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Check that each invalid Assign case fails with the expected message."""
        self.assert_invalid(**kwargs)

    # A non-AssignTarget (here a BinaryOperation) placed directly in
    # ``targets`` must be reported as a type error naming AssignTarget.
    @data_provider((
        {
            "get_node": (
                lambda: cst.Assign(
                    # pyre-ignore: Incompatible parameter type [6]
                    targets=[
                        cst.BinaryOperation(
                            left=cst.Name("x"),
                            operator=cst.Add(),
                            right=cst.Integer("1"),
                        ),
                    ],
                    value=cst.Name("y"),
                )),
            "expected_re":
            "Expected an instance of .*statement.AssignTarget.*",
        }, ))
    def test_invalid_types(self, **kwargs: Any) -> None:
        """Check that each wrongly-typed Assign case fails the type check."""
        self.assert_invalid_types(**kwargs)
Exemplo n.º 14
0
def make_assign(lhs, rhs):
    """Return a statement line holding the single assignment ``lhs = rhs``."""
    assignment = cst.Assign(targets=[cst.AssignTarget(lhs)], value=rhs)
    return cst.SimpleStatementLine([assignment])
Exemplo n.º 15
0
    # The captured pattern text is used as the new element's name.
    # NOTE(review): the enclosing `def` is not visible here — confirm the
    # types of `scope_expr` / `scope_elem_type` against the caller.
    name = scope_expr.capture.pattern.pattern
    # Build the constructor kwargs piecewise: each conditional dict is merged
    # in only when `scope_elem_type` supports that field (checked through its
    # `__slots__`, or by identity for FunctionDef / Assign). Because
    # `scope_elem_extra_kwargs` is merged last, its keys override these defaults.
    return scope_elem_type(
        **{
            **({'name': cst.Name(name), }
               if 'name' in scope_elem_type.__slots__ else {}),
            # Minimal body: a single `pass` statement.
            **({'body': cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])]), }
               if 'body' in scope_elem_type.__slots__ else {}),
            # Functions get empty parameters, no decorators, and `async` only
            # when the selector requested it.
            **({
                'params': cst.Parameters(),
                'decorators': (),
                **({
                    'asynchronous': cst.Asynchronous()
                } if scope_expr.properties.get('async') else {})
            } if scope_elem_type is cst.FunctionDef else {}),
            # Assignments are created as `<name> = None`.
            **({
                'targets': [cst.AssignTarget(target=cst.Name(name))],
                'value': cst.Name("None")
            } if scope_elem_type is cst.Assign else {}),
            **scope_elem_extra_kwargs
        }
    )


# TODO: may be preferable to convert the selector into a bytecode of path manip instructions
def select(root: cst.Module, transform: Transform) -> Transform:
    # NOTE(review): the body appears truncated in this view; only the
    # accumulator for matches is visible. Also note this local `matches`
    # would shadow any module-level `matches` name inside this function.
    matches: List[Match] = []

    # TODO: dont root search at global scope, that's not the original design
    # I'll probably need to change the parser to store the prefixing nesting op
    # NOTE: I have no idea how mypy works yet, I'm just pretending it's typescript
    # NOTE: I saw other typing usage briefly and I'm pretty sure it doesn't work this way
Exemplo n.º 16
0
def _assign_to_name(name, value):
    """Build the statement line ``<name> = <value>``."""
    return cst.SimpleStatementLine(body=[
        cst.Assign(
            targets=[cst.AssignTarget(target=cst.Name(value=name))],
            value=value,
        )
    ])


def compile_graph(graph: GraphicalModel, namespace: dict, fn_name):
    """Compile MCX's graph into a python (executable) function.

    Args:
        graph: the model graph to compile.
        namespace: dict used as the globals of ``exec``; it is mutated in
            place and receives the compiled function under ``fn_name``.
        fn_name: name of the generated function.

    Returns:
        A ``(fn, code)`` tuple: the compiled function and its source code.
    """
    # Model arguments are passed in the following order:
    #  1. (samplers only) rng_key;
    #  2. the model's arguments that have no default value;
    #  3. (logpdf only) random variables, taken from graph.placeholders in
    #     reverse order (presumably placeholders are stored in reverse model
    #     order — confirm);
    #  4. the model's keyword arguments (those with a default value).
    # Parameters with defaults must come last in a Python signature, which is
    # why the random variables sit between the two groups of model arguments.
    maybe_rng_key = [
        compile_placeholder(node, graph) for node in graph.placeholders
        if node.name == "rng_key"
    ]
    maybe_random_variables = [
        compile_placeholder(node, graph)
        for node in reversed(list(graph.placeholders))
        if node.is_random_variable
    ]
    model_args = [
        compile_placeholder(node, graph) for node in graph.placeholders
        if not node.is_random_variable and node.name != "rng_key"
        and not node.has_default
    ]
    model_kwargs = [
        compile_placeholder(node, graph) for node in graph.placeholders
        if not node.is_random_variable and node.name != "rng_key"
        and node.has_default
    ]
    args = maybe_rng_key + model_args + maybe_random_variables + model_kwargs

    # Every statement in the function corresponds to either a constant
    # definition or an op assignment. The topological sort respects the
    # dependency order, so each name is assigned before it is used.
    stmts = []
    returns = []
    for node in nx.topological_sort(graph):

        if node.name is None:
            continue

        if isinstance(node, Constant):
            stmts.append(_assign_to_name(node.name, node.cst_generator()))

        if isinstance(node, Op):
            stmts.append(_assign_to_name(node.name, compile_op(node, graph)))

            # NOTE(review): if several ops are marked `is_returned`, one
            # `return` is emitted per op and only the first will execute.
            if node.is_returned:
                returns.append(
                    cst.SimpleStatementLine(
                        body=[cst.Return(value=cst.Name(value=node.name))]))

    # Assemble the function's CST using the previously translated nodes.
    ast_fn = cst.Module(body=[
        cst.FunctionDef(
            name=cst.Name(value=fn_name),
            params=cst.Parameters(params=args),
            body=cst.IndentedBlock(body=stmts + returns),
        )
    ])

    code = ast_fn.code
    # exec runs the generated source with `namespace` as its globals; the
    # graph's contents must therefore be trusted.
    exec(code, namespace)
    fn = namespace[fn_name]

    return fn, code