Example #1
0
class ReturnCreateTest(CSTNodeTest):
    """Construction tests for hand-built ``cst.Return`` nodes."""

    @data_provider(
        (
            # Bare ``return`` with no value: spans columns 0-6 of line 1.
            dict(
                node=cst.SimpleStatementLine([cst.Return()]),
                code="return\n",
                expected_position=CodeRange((1, 0), (1, 6)),
            ),
            # ``return`` with a value and no explicit whitespace argument.
            dict(
                node=cst.SimpleStatementLine([cst.Return(cst.Name("abc"))]),
                code="return abc\n",
                expected_position=CodeRange((1, 0), (1, 10)),
            ),
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward each case's keyword arguments to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # A value glued directly onto the keyword must be rejected.
            # The node is built lazily, presumably so the construction error
            # is raised inside ``assert_invalid`` rather than at class creation.
            dict(
                get_node=lambda: cst.Return(
                    cst.Name("abc"),
                    whitespace_after_return=cst.SimpleWhitespace(""),
                ),
                expected_re="Must have at least one space after 'return'.",
            ),
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward each case's keyword arguments to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
Example #2
0
    def leave_FunctionDef(self, original_node: cst.FunctionDef,
                          updated_node: cst.FunctionDef) -> cst.FunctionDef:
        """Queue attribute write-backs and the folded return for ``self.scope``.

        Only acts when leaving the function being transformed
        (``original_node is self.scope``). It then appends the following
        to ``self.tail``:

        * one assignment per ``(name, attr)`` in ``self.names_to_attr``, whose
          value folds the guarded states recorded for ``name`` plus a default
          write-back of the local ``name`` itself;
        * if any returns were recorded in ``self.returns``, one ``return``
          statement whose value folds all of them together.

        The function node itself is returned unchanged; ``self.tail`` is
        presumably spliced into the output elsewhere (not visible here).

        Raises:
            SyntaxError: if folding the recorded returns raises
                ``IncompleteGaurdError`` (i.e. the returns cannot be proven
                to cover every path).
        """
        if original_node is not self.scope:
            return updated_node

        tail = self.tail
        for name, attr in self.names_to_attr.items():
            # Build a new list instead of appending in place: the original
            # ``.get(...).append(...)`` silently mutated the list stored in
            # ``self.attr_states`` whenever ``name`` was already present.
            # The trailing ``([], cst.Name(name))`` entry is the default
            # write-back of the local's own value (empty guard list).
            state = [*self.attr_states.get(name, []), ([], cst.Name(name))]
            attr_val = _fold_conditions(_simplify_gaurds(state), self.strict)
            tail.append(to_stmt(make_assign(attr, attr_val)))

        if self.returns:
            try:
                return_val = _fold_conditions(
                    _simplify_gaurds(self.returns), self.strict)
            except IncompleteGaurdError:
                raise SyntaxError(
                    'Cannot prove function always returns') from None
            tail.append(
                cst.SimpleStatementLine([cst.Return(value=return_val)]))

        return updated_node
Example #3
0
    def leave_Raise(self, node: cst.Raise,
                    updated_node: cst.Raise) -> Union[cst.Return, cst.Raise]:
        """Rewrite a matching ``raise`` inside a coroutine into a ``return``.

        A ``raise`` is rewritten only when ``self.in_coroutine`` reports that
        the current position is inside a coroutine AND the original node
        matches ``gen_return_matcher``. Presumably the matcher targets a
        ``raise gen.Return(...)``-style statement; confirm against its
        definition. Every other ``raise`` is returned untouched.

        The replacement ``return`` takes its value and post-keyword whitespace
        from ``self.pluck_gen_return_value`` and keeps the original semicolon.
        """
        rewritable = (self.in_coroutine(self.coroutine_stack)
                      and m.matches(node, gen_return_matcher))
        if not rewritable:
            return updated_node

        value, ws_after_keyword = self.pluck_gen_return_value(updated_node)
        return cst.Return(
            semicolon=updated_node.semicolon,
            whitespace_after_return=ws_after_keyword,
            value=value,
        )
class FooterBehaviorTest(UnitTest):
    """Parser tests for where header/footer comments end up in the CST.

    Each ``data_provider`` entry maps a case name to source ``code`` and the
    ``expected_module`` that ``parse_module`` should produce for it. The keys
    match the keyword parameters of ``test_parsers``.
    """

    @data_provider({
        # The most basic example: a module consisting of a single newline.
        "simple_module": {
            "code": "\n",
            "expected_module": cst.Module(body=())
        },
        # A module with only a header comment and no statements.
        "header_only_module": {
            "code":
            "# This is a header comment\n",
            "expected_module":
            cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[],
            ),
        },
        # A module with a header comment, one statement and a footer comment.
        "simple_header_footer_module": {
            "code":
            "# This is a header comment\npass\n# This is a footer comment\n",
            "expected_module":
            cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[cst.SimpleStatementLine([cst.Pass()])],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # A column-0 comment after an ``if`` block must become the module's
        # footer rather than a footer of the ``if`` statement's indented block.
        "simple_reparented_footer_module": {
            "code":
            "# This is a header comment\nif True:\n    pass\n# This is a footer comment\n",
            "expected_module":
            cst.Module(
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                body=[
                    cst.If(
                        test=cst.Name(value="True"),
                        body=cst.IndentedBlock(
                            header=cst.TrailingWhitespace(),
                            body=[
                                cst.SimpleStatementLine(
                                    body=[cst.Pass()],
                                    trailing_whitespace=cst.TrailingWhitespace(
                                    ),
                                )
                            ],
                        ),
                    )
                ],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # Trailing comments at different indentation levels are distributed to
        # the footer of the block whose indentation they match: the inner
        # ``if`` block, the outer ``if`` block, or the module itself.
        "complex_reparented_footer_module": {
            "code":
            ("# This is a header comment\nif True:\n    if True:\n        pass"
             +
             "\n        # This is an inner indented block comment\n    # This "
             +
             "is an outer indented block comment\n# This is a footer comment\n"
             ),
            "expected_module":
            cst.Module(
                body=[
                    cst.If(
                        test=cst.Name(value="True"),
                        body=cst.IndentedBlock(
                            body=[
                                cst.If(
                                    test=cst.Name(value="True"),
                                    body=cst.IndentedBlock(
                                        body=[
                                            cst.SimpleStatementLine(
                                                body=[cst.Pass()])
                                        ],
                                        footer=[
                                            cst.EmptyLine(comment=cst.Comment(
                                                value=
                                                "# This is an inner indented block comment"
                                            ))
                                        ],
                                    ),
                                )
                            ],
                            footer=[
                                cst.EmptyLine(comment=cst.Comment(
                                    value=
                                    "# This is an outer indented block comment"
                                ))
                            ],
                        ),
                    )
                ],
                header=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a header comment"))
                ],
                footer=[
                    cst.EmptyLine(comment=cst.Comment(
                        value="# This is a footer comment"))
                ],
            ),
        },
        # A comment after an indented block that precedes a statement at the
        # same (outer) level becomes that statement's leading line instead of
        # a footer of the indented block.
        "statement_comment_reparent": {
            "code":
            "if foo:\n    return\n# comment\nx = 7\n",
            "expected_module":
            cst.Module(body=[
                cst.If(
                    test=cst.Name(value="foo"),
                    body=cst.IndentedBlock(body=[
                        cst.SimpleStatementLine(body=[
                            cst.Return(
                                whitespace_after_return=cst.SimpleWhitespace(
                                    value=""))
                        ])
                    ]),
                ),
                cst.SimpleStatementLine(
                    body=[
                        cst.Assign(
                            targets=[
                                cst.AssignTarget(target=cst.Name(value="x"))
                            ],
                            value=cst.Integer(value="7"),
                        )
                    ],
                    leading_lines=[
                        cst.EmptyLine(comment=cst.Comment(value="# comment"))
                    ],
                ),
            ]),
        },
        # Even with completely empty lines interleaved, an indented block's
        # footer takes every line up to and including the last line indented
        # at the block's level. That way comments lining up with the block's
        # indentation are not re-parented to the next statement just because
        # one or more blank lines separate them from the block's body.
        "statement_comment_with_empty_lines": {
            "code":
            ("def foo():\n    if True:\n        pass\n\n        # Empty " +
             "line before me\n\n    else:\n        pass\n"),
            "expected_module":
            cst.Module(body=[
                cst.FunctionDef(
                    name=cst.Name(value="foo"),
                    params=cst.Parameters(),
                    body=cst.IndentedBlock(body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.SimpleStatementLine(body=[cst.Pass()])
                                ],
                                footer=[
                                    cst.EmptyLine(indent=False),
                                    cst.EmptyLine(comment=cst.Comment(
                                        value="# Empty line before me")),
                                ],
                            ),
                            # The blank line after the comment is NOT part of
                            # the footer; it becomes a leading line of ``else``.
                            orelse=cst.Else(
                                body=cst.IndentedBlock(body=[
                                    cst.SimpleStatementLine(body=[cst.Pass()])
                                ]),
                                leading_lines=[cst.EmptyLine(indent=False)],
                            ),
                        )
                    ]),
                )
            ]),
        },
    })
    def test_parsers(self, code: str, expected_module: cst.CSTNode) -> None:
        """Parse the dedented ``code`` and assert it deep-equals ``expected_module``.

        On failure the message shows the reprs of both trees.
        """
        parsed_module = parse_module(dedent(code))
        self.assertTrue(
            deep_equals(parsed_module, expected_module),
            msg=
            f"\n{parsed_module!r}\nis not deeply equal to \n{expected_module!r}",
        )
Example #5
0
class ReturnParseTest(CSTNodeTest):
    """Pairs ``return`` source text with the node ``parse_statement`` should build."""

    @data_provider(
        (
            # Bare ``return``: nothing follows the keyword.
            dict(
                node=cst.SimpleStatementLine(
                    [cst.Return(whitespace_after_return=cst.SimpleWhitespace(""))]
                ),
                code="return\n",
                parser=parse_statement,
            ),
            # A single space separates the keyword from the value.
            dict(
                node=cst.SimpleStatementLine(
                    [
                        cst.Return(
                            cst.Name("abc"),
                            whitespace_after_return=cst.SimpleWhitespace(" "),
                        )
                    ]
                ),
                code="return abc\n",
                parser=parse_statement,
            ),
            # Several spaces after the keyword are expected to be kept as-is.
            dict(
                node=cst.SimpleStatementLine(
                    [
                        cst.Return(
                            cst.Name("abc"),
                            whitespace_after_return=cst.SimpleWhitespace("   "),
                        )
                    ]
                ),
                code="return   abc\n",
                parser=parse_statement,
            ),
            # A parenthesized value may follow the keyword with no space.
            dict(
                node=cst.SimpleStatementLine(
                    [
                        cst.Return(
                            cst.Name(
                                "abc",
                                lpar=[cst.LeftParen()],
                                rpar=[cst.RightParen()],
                            ),
                            whitespace_after_return=cst.SimpleWhitespace(""),
                        )
                    ]
                ),
                code="return(abc)\n",
                parser=parse_statement,
            ),
            # A trailing semicolon is attached to the ``Return`` node.
            dict(
                node=cst.SimpleStatementLine(
                    [
                        cst.Return(
                            cst.Name("abc"),
                            whitespace_after_return=cst.SimpleWhitespace(" "),
                            semicolon=cst.Semicolon(),
                        )
                    ]
                ),
                code="return abc;\n",
                parser=parse_statement,
            ),
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward each case's keyword arguments to ``validate_node``."""
        self.validate_node(**kwargs)
Example #6
0
 def test_tuple_return(self) -> None:
     """A tuple expression placed under a ``Return`` parent gets parenthesized."""
     tuple_expr = libcst.parse_expression("1, 2, 3")
     parent = libcst.Return()
     result = parenthesize.parenthesize_using_previous(tuple_expr, parent)
     self.assert_has_parentheses(result)
Example #7
0
def compile_graph(graph: GraphicalModel, namespace: dict, fn_name: str):
    """Compile MCX's graph into a python (executable) function.

    The graph is translated into a libcst module that contains a single
    function named ``fn_name``. That module's source is executed with ``exec``
    in ``namespace``, so ``namespace`` gains a ``fn_name`` entry as a side
    effect.

    Args:
        graph: the model graph to compile.
        namespace: globals dict the generated source is executed in; it
            presumably must provide every name the generated code refers to.
        fn_name: name given to the generated function.

    Returns:
        A ``(fn, code)`` tuple: the compiled function object and its source.
    """

    # Model arguments are passed in the following order:
    #  1. (samplers only) rng_key;
    #  2. the model's positional arguments (placeholders without a default);
    #  3. (logpdf only) random variables;
    #  4. the model's keyword arguments (placeholders with a default).
    # Placeholders with a default go last, presumably because Python requires
    # parameters with defaults to follow those without.
    placeholders = list(graph.placeholders)
    maybe_rng_key = [
        compile_placeholder(node, graph) for node in placeholders
        if node.name == "rng_key"
    ]
    # NOTE(review): random variables are taken in *reverse* placeholder order;
    # presumably graph.placeholders lists them in reverse model order — confirm.
    maybe_random_variables = [
        compile_placeholder(node, graph) for node in reversed(placeholders)
        if node.is_random_variable
    ]
    model_args = [
        compile_placeholder(node, graph) for node in placeholders
        if not node.is_random_variable and node.name != "rng_key"
        and not node.has_default
    ]
    model_kwargs = [
        compile_placeholder(node, graph) for node in placeholders
        if not node.is_random_variable and node.name != "rng_key"
        and node.has_default
    ]
    args = maybe_rng_key + model_args + maybe_random_variables + model_kwargs

    def _assign(name, value):
        # Build one ``name = value`` line of the generated function body.
        return cst.SimpleStatementLine(body=[
            cst.Assign(
                targets=[cst.AssignTarget(target=cst.Name(value=name))],
                value=value,
            )
        ])

    # Every statement in the function corresponds to either a constant
    # definition or a variable assignment. We use a topological sort to
    # respect the dependency order.
    stmts = []
    returns = []
    for node in nx.topological_sort(graph):

        # Unnamed nodes are skipped: they cannot be bound to a variable.
        if node.name is None:
            continue

        # The two checks are deliberately independent (not ``elif``) to keep
        # the original behavior in case a node type satisfies both.
        if isinstance(node, Constant):
            stmts.append(_assign(node.name, node.cst_generator()))

        if isinstance(node, Op):
            stmts.append(_assign(node.name, compile_op(node, graph)))

            if node.is_returned:
                # NOTE(review): one ``return`` is emitted per returned Op;
                # if several are flagged, only the first one is reachable.
                returns.append(
                    cst.SimpleStatementLine(
                        body=[cst.Return(value=cst.Name(value=node.name))]))

    # Assemble the function's CST using the previously translated nodes.
    ast_fn = cst.Module(body=[
        cst.FunctionDef(
            name=cst.Name(value=fn_name),
            params=cst.Parameters(params=args),
            body=cst.IndentedBlock(body=stmts + returns),
        )
    ])

    code = ast_fn.code
    # NOTE: this executes generated source; only compile graphs built from
    # trusted model code.
    exec(code, namespace)
    fn = namespace[fn_name]

    return fn, code