class ReturnCreateTest(CSTNodeTest):
    """Construction tests for hand-built ``cst.Return`` nodes."""

    @data_provider(
        (
            # A bare ``return`` statement with no value.
            dict(
                node=cst.SimpleStatementLine([cst.Return()]),
                code="return\n",
                expected_position=CodeRange((1, 0), (1, 6)),
            ),
            # A ``return`` carrying a value.
            dict(
                node=cst.SimpleStatementLine([cst.Return(cst.Name("abc"))]),
                code="return abc\n",
                expected_position=CodeRange((1, 0), (1, 10)),
            ),
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            # A value directly after the keyword, with no whitespace between,
            # must be rejected.
            dict(
                get_node=lambda: cst.Return(
                    cst.Name("abc"),
                    whitespace_after_return=cst.SimpleWhitespace(""),
                ),
                expected_re="Must have at least one space after 'return'.",
            ),
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
    """Emit write-back and return statements when leaving ``self.scope``.

    For every entry of ``self.names_to_attr``, the recorded guarded states
    (plus a default write-back of the plain name) are folded into one value
    and an assignment to the attribute is appended to ``self.tail``. If any
    return values were recorded in ``self.returns`` they are folded into a
    single ``return`` statement, which is appended to ``self.tail`` as well.

    The node itself is always returned unchanged.

    Raises:
        SyntaxError: if the recorded return guards cannot be folded
            (``IncompleteGaurdError``), i.e. the function cannot be proven
            to always return.
    """
    if original_node is not self.scope:
        return updated_node

    epilogue = self.tail
    for var_name, attr in self.names_to_attr.items():
        # NOTE: when ``var_name`` is present this appends to the list stored
        # in ``self.attr_states`` itself, not to a copy.
        guarded_values = self.attr_states.get(var_name, [])
        # Default write-back: fall back to the initial value of the name.
        guarded_values.append(([], cst.Name(var_name)))
        folded_value = _fold_conditions(_simplify_gaurds(guarded_values), self.strict)
        epilogue.append(to_stmt(make_assign(attr, folded_value)))

    if self.returns:
        strict_mode = self.strict
        try:
            folded_return = _fold_conditions(_simplify_gaurds(self.returns), strict_mode)
        except IncompleteGaurdError:
            raise SyntaxError('Cannot prove function always returns') from None
        epilogue.append(cst.SimpleStatementLine([cst.Return(value=folded_return)]))

    return updated_node
def leave_Raise(self, node: cst.Raise, updated_node: cst.Raise) -> Union[cst.Return, cst.Raise]:
    """Turn a generator-return ``raise`` into a ``return`` inside coroutines.

    Only raises that occur while ``self.coroutine_stack`` indicates a
    coroutine, and whose original node matches ``gen_return_matcher``, are
    rewritten. The returned value and its leading whitespace come from
    ``self.pluck_gen_return_value``, and the statement's semicolon is carried
    over. Every other raise is returned untouched.
    """
    # ``and`` short-circuits, so the matcher only runs inside coroutines.
    should_rewrite = self.in_coroutine(self.coroutine_stack) and m.matches(
        node, gen_return_matcher
    )
    if not should_rewrite:
        return updated_node

    plucked_value, ws_after_keyword = self.pluck_gen_return_value(updated_node)
    return cst.Return(
        value=plucked_value,
        whitespace_after_return=ws_after_keyword,
        semicolon=updated_node.semicolon,
    )
class FooterBehaviorTest(UnitTest):
    """Checks how module/block headers and footers are assigned during parsing.

    Every snippet that contains an indented block uses four-space nesting.
    Each nested body must be indented deeper than its header for the snippet
    to be valid Python. The expected trees below (which set no explicit
    ``default_indent``) assume the conventional four-space indentation.
    """

    @data_provider(
        {
            # Literally the most basic example
            "simple_module": {"code": "\n", "expected_module": cst.Module(body=())},
            # A module with a header comment
            "header_only_module": {
                "code": "# This is a header comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[],
                ),
            },
            # A module with a header and footer
            "simple_header_footer_module": {
                "code": "# This is a header comment\npass\n# This is a footer comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[cst.SimpleStatementLine([cst.Pass()])],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # A module which should have a footer comment taken from the
            # if statement's indented block.
            "simple_reparented_footer_module": {
                "code": "# This is a header comment\nif True:\n    pass\n# This is a footer comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                header=cst.TrailingWhitespace(),
                                body=[
                                    cst.SimpleStatementLine(
                                        body=[cst.Pass()],
                                        trailing_whitespace=cst.TrailingWhitespace(),
                                    )
                                ],
                            ),
                        )
                    ],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # Verifying that we properly parse and spread out footer comments to the
            # relative indents they go with.
            "complex_reparented_footer_module": {
                # The inner comment is indented to the inner block and the
                # outer comment to the outer block, so each lands in the
                # footer of the block it lines up with.
                "code": (
                    "# This is a header comment\nif True:\n    if True:\n        pass"
                    + "\n        # This is an inner indented block comment\n    # This "
                    + "is an outer indented block comment\n# This is a footer comment\n"
                ),
                "expected_module": cst.Module(
                    body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.If(
                                        test=cst.Name(value="True"),
                                        body=cst.IndentedBlock(
                                            body=[
                                                cst.SimpleStatementLine(
                                                    body=[cst.Pass()]
                                                )
                                            ],
                                            footer=[
                                                cst.EmptyLine(
                                                    comment=cst.Comment(
                                                        value="# This is an inner indented block comment"
                                                    )
                                                )
                                            ],
                                        ),
                                    )
                                ],
                                footer=[
                                    cst.EmptyLine(
                                        comment=cst.Comment(
                                            value="# This is an outer indented block comment"
                                        )
                                    )
                                ],
                            ),
                        )
                    ],
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # Verify that comments belonging to statements are still owned even
            # after an indented block.
            "statement_comment_reparent": {
                "code": "if foo:\n    return\n# comment\nx = 7\n",
                "expected_module": cst.Module(
                    body=[
                        cst.If(
                            test=cst.Name(value="foo"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.SimpleStatementLine(
                                        body=[
                                            cst.Return(
                                                whitespace_after_return=cst.SimpleWhitespace(
                                                    value=""
                                                )
                                            )
                                        ]
                                    )
                                ]
                            ),
                        ),
                        cst.SimpleStatementLine(
                            body=[
                                cst.Assign(
                                    targets=[
                                        cst.AssignTarget(target=cst.Name(value="x"))
                                    ],
                                    value=cst.Integer(value="7"),
                                )
                            ],
                            leading_lines=[
                                cst.EmptyLine(comment=cst.Comment(value="# comment"))
                            ],
                        ),
                    ]
                ),
            },
            # Verify that even if there are completely empty lines, we give all lines
            # up to and including the last line that's indented correctly. That way
            # comments that line up with indented block's indentation level aren't
            # parented to the next line just because there's a blank line or two
            # between them.
            "statement_comment_with_empty_lines": {
                # The comment is indented to the inner ``if`` body (8 spaces),
                # with a blank line on either side of it.
                "code": (
                    "def foo():\n    if True:\n        pass\n\n        # Empty "
                    + "line before me\n\n    else:\n        pass\n"
                ),
                "expected_module": cst.Module(
                    body=[
                        cst.FunctionDef(
                            name=cst.Name(value="foo"),
                            params=cst.Parameters(),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.If(
                                        test=cst.Name(value="True"),
                                        body=cst.IndentedBlock(
                                            body=[
                                                cst.SimpleStatementLine(
                                                    body=[cst.Pass()]
                                                )
                                            ],
                                            footer=[
                                                cst.EmptyLine(indent=False),
                                                cst.EmptyLine(
                                                    comment=cst.Comment(
                                                        value="# Empty line before me"
                                                    )
                                                ),
                                            ],
                                        ),
                                        orelse=cst.Else(
                                            body=cst.IndentedBlock(
                                                body=[
                                                    cst.SimpleStatementLine(
                                                        body=[cst.Pass()]
                                                    )
                                                ]
                                            ),
                                            leading_lines=[cst.EmptyLine(indent=False)],
                                        ),
                                    )
                                ]
                            ),
                        )
                    ]
                ),
            },
        }
    )
    def test_parsers(self, code: str, expected_module: cst.CSTNode) -> None:
        """Parse ``code`` and require a deep match with ``expected_module``."""
        parsed_module = parse_module(dedent(code))
        self.assertTrue(
            deep_equals(parsed_module, expected_module),
            msg=f"\n{parsed_module!r}\nis not deeply equal to \n{expected_module!r}",
        )
class ReturnParseTest(CSTNodeTest):
    """Round-trip parsing tests for ``return`` statements via ``parse_statement``."""

    @data_provider(
        (
            # Bare ``return``: no whitespace after the keyword.
            {
                "node": cst.SimpleStatementLine(
                    [cst.Return(whitespace_after_return=cst.SimpleWhitespace(""))]
                ),
                "code": "return\n",
                "parser": parse_statement,
            },
            # A single space between the keyword and the value.
            {
                "node": cst.SimpleStatementLine(
                    [
                        cst.Return(
                            cst.Name("abc"),
                            whitespace_after_return=cst.SimpleWhitespace(" "),
                        )
                    ]
                ),
                "code": "return abc\n",
                "parser": parse_statement,
            },
            # Several spaces between the keyword and the value must be kept
            # verbatim. Previously this case duplicated the single-space case above.
            {
                "node": cst.SimpleStatementLine(
                    [
                        cst.Return(
                            cst.Name("abc"),
                            whitespace_after_return=cst.SimpleWhitespace("   "),
                        )
                    ]
                ),
                "code": "return   abc\n",
                "parser": parse_statement,
            },
            # A parenthesized value needs no whitespace after the keyword.
            {
                "node": cst.SimpleStatementLine(
                    [
                        cst.Return(
                            cst.Name(
                                "abc", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                            ),
                            whitespace_after_return=cst.SimpleWhitespace(""),
                        )
                    ]
                ),
                "code": "return(abc)\n",
                "parser": parse_statement,
            },
            # A trailing semicolon is attached to the ``Return`` node.
            {
                "node": cst.SimpleStatementLine(
                    [
                        cst.Return(
                            cst.Name("abc"),
                            whitespace_after_return=cst.SimpleWhitespace(" "),
                            semicolon=cst.Semicolon(),
                        )
                    ]
                ),
                "code": "return abc;\n",
                "parser": parse_statement,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)
def test_tuple_return(self) -> None:
    """A bare tuple placed under a ``return`` should come back parenthesized."""
    tuple_expr = libcst.parse_expression("1, 2, 3")
    enclosing = libcst.Return()
    result = parenthesize.parenthesize_using_previous(tuple_expr, enclosing)
    self.assert_has_parentheses(result)
def compile_graph(graph: GraphicalModel, namespace: dict, fn_name: str):
    """Compile MCX's graph into a python (executable) function.

    Args:
        graph: The model graph to compile.
        namespace: Globals mapping handed to ``exec``. The generated function
            is defined in, and then read back from, this mapping, so it is
            mutated.
        fn_name: Name of the generated function.

    Returns:
        A ``(fn, code)`` pair: the compiled function object and the Python
        source text it was generated from.
    """
    # Parameters of the generated function, in the order they are emitted:
    #  1. (samplers only) rng_key;
    #  2. the model's arguments without a default value;
    #  3. (logpdf only) random variables, taken from ``graph.placeholders``
    #     in reverse order;
    #  4. the model's keyword arguments (placeholders with a default). These
    #     are kept last, as Python requires parameters with defaults to
    #     follow those without.
    placeholders = list(graph.placeholders)
    maybe_rng_key = [
        compile_placeholder(node, graph)
        for node in placeholders
        if node.name == "rng_key"
    ]
    maybe_random_variables = [
        compile_placeholder(node, graph)
        for node in reversed(placeholders)
        if node.is_random_variable
    ]
    model_args = [
        compile_placeholder(node, graph)
        for node in placeholders
        if not node.is_random_variable
        and node.name != "rng_key"
        and not node.has_default
    ]
    model_kwargs = [
        compile_placeholder(node, graph)
        for node in placeholders
        if not node.is_random_variable
        and node.name != "rng_key"
        and node.has_default
    ]
    args = maybe_rng_key + model_args + maybe_random_variables + model_kwargs

    def assignment(name, value):
        """Build the single-statement line ``name = value``."""
        return cst.SimpleStatementLine(
            body=[
                cst.Assign(
                    targets=[cst.AssignTarget(target=cst.Name(value=name))],
                    value=value,
                )
            ]
        )

    # Every statement in the function corresponds to either a constant definition or
    # a variable assignment. We use a topological sort to respect the
    # dependency order.
    stmts = []
    returns = []
    for node in nx.topological_sort(graph):
        if node.name is None:
            continue
        if isinstance(node, Constant):
            stmts.append(assignment(node.name, node.cst_generator()))
        if isinstance(node, Op):
            stmts.append(assignment(node.name, compile_op(node, graph)))
        if node.is_returned:
            returns.append(
                cst.SimpleStatementLine(
                    body=[cst.Return(value=cst.Name(value=node.name))]
                )
            )

    # Assemble the function's CST using the previously translated nodes.
    ast_fn = cst.Module(
        body=[
            cst.FunctionDef(
                name=cst.Name(value=fn_name),
                params=cst.Parameters(params=args),
                body=cst.IndentedBlock(body=stmts + returns),
            )
        ]
    )
    code = ast_fn.code

    # NOTE(security): ``exec`` runs source generated from ``graph``. Only
    # compile graphs built from trusted model definitions.
    exec(code, namespace)
    fn = namespace[fn_name]

    return fn, code