def leave_AnnAssign(
    self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
) -> Union[cst.BaseSmallStatement, cst.RemovalSentinel]:
    """Rewrite an annotated assignment as a plain assignment.

    ``foo: str = "x"`` becomes ``foo = "x"``. A declaration that has not been
    initialized (``foo: str``, ``self.foo: str``, ``foo[0]: str``) becomes
    ``foo = ...`` so that later node traversal always sees an ``Assign`` that
    carries a value and does not hit exceptions on ``value=None``.

    Args:
        original_node: The node before its children were visited. Unused;
            kept for the ``CSTTransformer`` callback signature.
        updated_node: The node with child transformations already applied.

    Returns:
        A ``cst.Assign`` that replaces the annotated assignment.
    """
    # Build from ``updated_node`` so rewrites that child ``leave_*`` callbacks
    # already applied to the target or value are kept rather than discarded.
    value = updated_node.value
    if value is None:
        # Any target kind (Name, Attribute, Subscript) may be declared without
        # a value; all of them get a placeholder ``...`` so the resulting
        # ``Assign`` is never left with ``value=None``.
        value = cst.Ellipsis()
    return cst.Assign(
        targets=[cst.AssignTarget(target=updated_node.target)],
        value=value,
        # Keep an explicit trailing semicolon if the source had one.
        semicolon=updated_node.semicolon,
    )
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign):
    """Obfuscate the target and value of an annotated assignment.

    A declaration without a value (``x: T``) is given an explicit ``None``
    value. When ``self.delete_annotations`` is set, the annotation is dropped
    and a plain ``Assign`` with single spaces around ``=`` is returned instead.
    """
    # The value is obfuscated before the target; keep that order in case
    # ``obf_universal`` is stateful (e.g. allocates names as it goes).
    if updated_node.value is not None:
        value = self.obf_universal(updated_node.value, 'v', 'a', 'ca')
    else:
        value = cst.Name(value='None', lpar=[], rpar=[])
    target = self.obf_universal(updated_node.target, 'v', 'a', 'ca')
    renamed = updated_node.with_changes(value=value, target=target)

    if not self.delete_annotations:
        return renamed

    spacing = cst.SimpleWhitespace(value=' ')
    return cst.Assign(
        targets=[
            cst.AssignTarget(
                target=renamed.target,
                whitespace_before_equal=spacing,
                whitespace_after_equal=spacing,
            )
        ],
        value=renamed.value,
        semicolon=renamed.semicolon,
    )
def make_assign(
    lhs: cst.BaseAssignTargetExpression,
    rhs: cst.BaseExpression,
) -> cst.Assign:
    """Build the single-target assignment node ``lhs = rhs``."""
    only_target = cst.AssignTarget(lhs)
    return cst.Assign(targets=[only_target], value=rhs)
def test_or_operator_matcher_false(self) -> None:
    """A ``|``-combined matcher rejects nodes that match neither side."""
    true_or_false = m.Name("True") | m.Name("False")

    # None is neither True nor False.
    self.assertFalse(matches(cst.Name("None"), true_or_false))

    # Assigning None to a target is not the same as assigning True or False.
    none_assignment = cst.Assign((cst.AssignTarget(cst.Name("x")),), cst.Name("None"))
    self.assertFalse(matches(none_assignment, m.Assign(value=true_or_false)))
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign):
    """Mark value-less declarations for a later pass and desugar the rest.

    - ``x: int`` (no value) stays an ``AnnAssign``, but its target is replaced
      by a ``Name`` spelled ``"__COMMENT__"`` followed by the target's text,
      so that a second pass can find the declaration and comment it out.
    - ``x: int = 1`` becomes the plain assignment ``x = 1``.

    Args:
        original_node: The node before its children were visited. Unused;
            kept for the ``CSTTransformer`` callback signature.
        updated_node: The node with child transformations already applied.
    """
    # Use ``updated_node`` so child rewrites of the target are not discarded.
    target = updated_node.target
    if updated_node.value is None:
        if isinstance(target, cst.Name):
            target_text = target.value
        else:
            # Attribute/Subscript targets (``self.x: int``, ``a[0]: int``) have
            # a node, not a str, in ``.value``; concatenating it raised
            # TypeError. Render the target's source text instead.
            # NOTE(review): the resulting Name text contains "." or "[" and is
            # only meaningful as a textual marker — confirm the second pass
            # keys on the "__COMMENT__" prefix, not on a valid identifier.
            target_text = cst.Module(body=[]).code_for_node(target)
        return updated_node.with_changes(target=cst.Name("__COMMENT__" + target_text))
    return cst.Assign(
        targets=[cst.AssignTarget(target=target)],
        value=updated_node.value,
    )
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign):
    """Strip type annotations from annotated assignments.

    ``x: T = v`` is rewritten to ``x = v``; a bare declaration such as
    ``some_var: str`` is removed from its parent statement.
    """
    assigned = updated_node.value
    if assigned is not None:
        return cst.Assign(
            targets=[cst.AssignTarget(target=updated_node.target)],
            value=assigned,
        )
    # A bare declaration assigns nothing, so this transform drops it entirely.
    # NOTE(review): such a statement is not completely inert — in a module or
    # class body it adds an ``__annotations__`` entry, and inside a function it
    # makes the name local to that scope (PEP 526). Confirm removal is intended.
    return cst.RemoveFromParent()
def test_assign_target(self) -> None:
    """Assignment-target substitutions accept plain expressions or AssignTargets."""
    variants = (
        # Plain expressions inserted into assignment-target slots.
        (cst.Name("first"), cst.Name("second")),
        # AssignTarget nodes inserted directly, which the templater special-cases.
        (cst.AssignTarget(cst.Name("first")), cst.AssignTarget(cst.Name("second"))),
    )
    for first_target, second_target in variants:
        statement = parse_template_statement(
            "{a} = {b} = {val}",
            a=first_target,
            b=second_target,
            val=cst.Integer("5"),
        )
        self.assertEqual(self.code(statement), "first = second = 5\n")
def test_or_operator_matcher_true(self) -> None:
    """A ``|``-combined matcher accepts a node that matches any alternative."""
    true_or_false = m.Name("True") | m.Name("False")

    # Either boolean identifier satisfies the two-way alternative.
    for literal in ("True", "False"):
        self.assertTrue(matches(cst.Name(literal), true_or_false))

    # Chaining another ``|`` widens the alternatives to True, False or None.
    self.assertTrue(matches(cst.Name("None"), true_or_false | m.Name("None")))

    # An assignment of True to an unspecified target matches on its value.
    true_assignment = cst.Assign((cst.AssignTarget(cst.Name("x")),), cst.Name("True"))
    self.assertTrue(matches(true_assignment, m.Assign(value=true_or_false)))
def test_or_matcher_false(self) -> None:
    """``m.OneOf`` rejects nodes that match none of its options."""
    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

    # None is neither True nor False.
    self.assertFalse(matches(libcst.Name("None"), true_or_false))

    # Assigning None to a target is not assigning True or False.
    none_assignment = libcst.Assign(
        (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
    )
    self.assertFalse(matches(none_assignment, m.Assign(value=true_or_false)))

    # foo(1, 2, 3) matches neither the reversed argument list nor 4, 5, 6.
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    reversed_args = tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1"))
    other_args = tuple(m.Arg(m.Integer(digit)) for digit in ("4", "5", "6"))
    self.assertFalse(
        matches(call, m.Call(m.Name("foo"), m.OneOf(reversed_args, other_args)))
    )
def test_or_matcher_true(self) -> None:
    """``m.OneOf`` accepts a node that matches any one of its options."""
    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

    # The True identifier is one of the options.
    self.assertTrue(matches(libcst.Name("True"), true_or_false))

    # An assignment of True to an unspecified target matches on its value.
    true_assignment = libcst.Assign(
        (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("True")
    )
    self.assertTrue(matches(true_assignment, m.Assign(value=true_or_false)))

    # foo(1, 2, 3) matches because the second option lists 1, 2, 3 in order.
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    reversed_args = tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1"))
    in_order_args = tuple(m.Arg(m.Integer(digit)) for digit in ("1", "2", "3"))
    self.assertTrue(
        matches(call, m.Call(m.Name("foo"), m.OneOf(reversed_args, in_order_args)))
    )
class FooterBehaviorTest(UnitTest):
    """Check how parsed comments are attached to module and block footers.

    Each case parses ``code`` and asserts the result is deeply equal to
    ``expected_module``.

    Indented snippets use four spaces. The expected trees leave
    ``IndentedBlock.indent`` unset and ``Module.default_indent`` at libcst's
    default of four spaces, so any other indent width would not compare equal.
    """

    @data_provider(
        {
            # Literally the most basic example
            "simple_module": {
                "code": "\n",
                "expected_module": cst.Module(body=()),
            },
            # A module with a header comment
            "header_only_module": {
                "code": "# This is a header comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[],
                ),
            },
            # A module with a header and footer
            "simple_header_footer_module": {
                "code": "# This is a header comment\npass\n# This is a footer comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[cst.SimpleStatementLine([cst.Pass()])],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # A module whose footer comment must be taken from the if
            # statement's indented block.
            "simple_reparented_footer_module": {
                "code": "# This is a header comment\nif True:\n    pass\n# This is a footer comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                header=cst.TrailingWhitespace(),
                                body=[
                                    cst.SimpleStatementLine(
                                        body=[cst.Pass()],
                                        trailing_whitespace=cst.TrailingWhitespace(),
                                    )
                                ],
                            ),
                        )
                    ],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # Verify that footer comments are parsed and distributed to the
            # blocks whose indentation level they match.
            "complex_reparented_footer_module": {
                "code": (
                    "# This is a header comment\nif True:\n    if True:\n        pass"
                    + "\n        # This is an inner indented block comment\n    # This "
                    + "is an outer indented block comment\n# This is a footer comment\n"
                ),
                "expected_module": cst.Module(
                    body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.If(
                                        test=cst.Name(value="True"),
                                        body=cst.IndentedBlock(
                                            body=[
                                                cst.SimpleStatementLine(
                                                    body=[cst.Pass()]
                                                )
                                            ],
                                            footer=[
                                                cst.EmptyLine(
                                                    comment=cst.Comment(
                                                        value="# This is an inner indented block comment"
                                                    )
                                                )
                                            ],
                                        ),
                                    )
                                ],
                                footer=[
                                    cst.EmptyLine(
                                        comment=cst.Comment(
                                            value="# This is an outer indented block comment"
                                        )
                                    )
                                ],
                            ),
                        )
                    ],
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # Verify that comments belonging to statements are still owned by
            # them even when they follow an indented block.
            "statement_comment_reparent": {
                "code": "if foo:\n    return\n# comment\nx = 7\n",
                "expected_module": cst.Module(
                    body=[
                        cst.If(
                            test=cst.Name(value="foo"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.SimpleStatementLine(
                                        body=[
                                            cst.Return(
                                                whitespace_after_return=cst.SimpleWhitespace(
                                                    value=""
                                                )
                                            )
                                        ]
                                    )
                                ]
                            ),
                        ),
                        cst.SimpleStatementLine(
                            body=[
                                cst.Assign(
                                    targets=[
                                        cst.AssignTarget(target=cst.Name(value="x"))
                                    ],
                                    value=cst.Integer(value="7"),
                                )
                            ],
                            leading_lines=[
                                cst.EmptyLine(comment=cst.Comment(value="# comment"))
                            ],
                        ),
                    ]
                ),
            },
            # Verify that even if there are completely empty lines, we give the
            # block all lines up to and including the last line that is indented
            # correctly. That way, comments that line up with the indented
            # block's indentation level aren't parented to the next line just
            # because one or two blank lines separate them.
            "statement_comment_with_empty_lines": {
                "code": (
                    "def foo():\n    if True:\n        pass\n\n        # Empty "
                    + "line before me\n\n    else:\n        pass\n"
                ),
                "expected_module": cst.Module(
                    body=[
                        cst.FunctionDef(
                            name=cst.Name(value="foo"),
                            params=cst.Parameters(),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.If(
                                        test=cst.Name(value="True"),
                                        body=cst.IndentedBlock(
                                            body=[
                                                cst.SimpleStatementLine(
                                                    body=[cst.Pass()]
                                                )
                                            ],
                                            footer=[
                                                cst.EmptyLine(indent=False),
                                                cst.EmptyLine(
                                                    comment=cst.Comment(
                                                        value="# Empty line before me"
                                                    )
                                                ),
                                            ],
                                        ),
                                        orelse=cst.Else(
                                            body=cst.IndentedBlock(
                                                body=[
                                                    cst.SimpleStatementLine(
                                                        body=[cst.Pass()]
                                                    )
                                                ]
                                            ),
                                            leading_lines=[cst.EmptyLine(indent=False)],
                                        ),
                                    )
                                ]
                            ),
                        )
                    ]
                ),
            },
        }
    )
    def test_parsers(self, code: str, expected_module: cst.CSTNode) -> None:
        """Parse ``code`` and require a deep match with ``expected_module``."""
        parsed_module = parse_module(dedent(code))
        self.assertTrue(
            deep_equals(parsed_module, expected_module),
            msg=f"\n{parsed_module!r}\nis not deeply equal to \n{expected_module!r}",
        )
class AssignTest(CSTNodeTest):
    """Construction, parsing and validation tests for ``cst.Assign``."""

    @data_provider(
        (
            # Built directly: a single target.
            {
                "node": cst.Assign(
                    targets=(cst.AssignTarget(target=cst.Name("foo")),),
                    value=cst.Integer("5"),
                ),
                "code": "foo = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 7)),
            },
            # Built directly: chained targets.
            {
                "node": cst.Assign(
                    targets=(
                        cst.AssignTarget(target=cst.Name("foo")),
                        cst.AssignTarget(target=cst.Name("bar")),
                    ),
                    value=cst.Integer("5"),
                ),
                "code": "foo = bar = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 13)),
            },
            # Built directly: no whitespace around "=".
            {
                "node": cst.Assign(
                    targets=(
                        cst.AssignTarget(
                            target=cst.Name("foo"),
                            whitespace_before_equal=cst.SimpleWhitespace(""),
                            whitespace_after_equal=cst.SimpleWhitespace(""),
                        ),
                    ),
                    value=cst.Integer("5"),
                ),
                "code": "foo=5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 5)),
            },
            # Parsed with parse_statement: a single target.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Assign(
                            targets=(cst.AssignTarget(target=cst.Name("foo")),),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Parsed with parse_statement: chained targets.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Assign(
                            targets=(
                                cst.AssignTarget(target=cst.Name("foo")),
                                cst.AssignTarget(target=cst.Name("bar")),
                            ),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo = bar = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Parsed with parse_statement: no whitespace around "=".
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Assign(
                            targets=(
                                cst.AssignTarget(
                                    target=cst.Name("foo"),
                                    whitespace_before_equal=cst.SimpleWhitespace(""),
                                    whitespace_after_equal=cst.SimpleWhitespace(""),
                                ),
                            ),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo=5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Each case must render, parse and report positions as described."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": (lambda: cst.Assign(targets=(), value=cst.Integer("5"))),
                "expected_re": "at least one AssignTarget",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """An ``Assign`` with no targets must fail validation."""
        self.assert_invalid(**kwargs)
class AssignTest(CSTNodeTest):
    """Construction, parsing, validation and type-check tests for ``cst.Assign``."""

    @data_provider(
        (
            # Built directly: a single target.
            {
                "node": cst.Assign(
                    targets=(cst.AssignTarget(target=cst.Name("foo")),),
                    value=cst.Integer("5"),
                ),
                "code": "foo = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 7)),
            },
            # Built directly: chained targets.
            {
                "node": cst.Assign(
                    targets=(
                        cst.AssignTarget(target=cst.Name("foo")),
                        cst.AssignTarget(target=cst.Name("bar")),
                    ),
                    value=cst.Integer("5"),
                ),
                "code": "foo = bar = 5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 13)),
            },
            # Built directly: no whitespace around "=".
            {
                "node": cst.Assign(
                    targets=(
                        cst.AssignTarget(
                            target=cst.Name("foo"),
                            whitespace_before_equal=cst.SimpleWhitespace(""),
                            whitespace_after_equal=cst.SimpleWhitespace(""),
                        ),
                    ),
                    value=cst.Integer("5"),
                ),
                "code": "foo=5",
                "parser": None,
                "expected_position": CodeRange((1, 0), (1, 5)),
            },
            # Parsed with parse_statement: a single target.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Assign(
                            targets=(cst.AssignTarget(target=cst.Name("foo")),),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Parsed with parse_statement: chained targets.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Assign(
                            targets=(
                                cst.AssignTarget(target=cst.Name("foo")),
                                cst.AssignTarget(target=cst.Name("bar")),
                            ),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo = bar = 5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
            # Parsed with parse_statement: no whitespace around "=".
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Assign(
                            targets=(
                                cst.AssignTarget(
                                    target=cst.Name("foo"),
                                    whitespace_before_equal=cst.SimpleWhitespace(""),
                                    whitespace_after_equal=cst.SimpleWhitespace(""),
                                ),
                            ),
                            value=cst.Integer("5"),
                        ),
                    )
                ),
                "code": "foo=5\n",
                "parser": parse_statement,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Each case must render, parse and report positions as described."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": (lambda: cst.Assign(targets=(), value=cst.Integer("5"))),
                "expected_re": "at least one AssignTarget",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """An ``Assign`` with no targets must fail validation."""
        self.assert_invalid(**kwargs)

    @data_provider(
        (
            {
                "get_node": (
                    lambda: cst.Assign(
                        # pyre-ignore: Incompatible parameter type [6]
                        targets=[
                            cst.BinaryOperation(
                                left=cst.Name("x"),
                                operator=cst.Add(),
                                right=cst.Integer("1"),
                            ),
                        ],
                        value=cst.Name("y"),
                    )
                ),
                "expected_re": "Expected an instance of .*statement.AssignTarget.*",
            },
        )
    )
    def test_invalid_types(self, **kwargs: Any) -> None:
        """A target that is not an ``AssignTarget`` must be rejected by type checks."""
        self.assert_invalid_types(**kwargs)
def make_assign(lhs, rhs):
    """Wrap the assignment ``lhs = rhs`` in its own ``SimpleStatementLine``."""
    assignment = cst.Assign(targets=[cst.AssignTarget(lhs)], value=rhs)
    return cst.SimpleStatementLine([assignment])
name = scope_expr.capture.pattern.pattern return scope_elem_type( **{ **({'name': cst.Name(name), } if 'name' in scope_elem_type.__slots__ else {}), **({'body': cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])]), } if 'body' in scope_elem_type.__slots__ else {}), **({ 'params': cst.Parameters(), 'decorators': (), **({ 'asynchronous': cst.Asynchronous() } if scope_expr.properties.get('async') else {}) } if scope_elem_type is cst.FunctionDef else {}), **({ 'targets': [cst.AssignTarget(target=cst.Name(name))], 'value': cst.Name("None") } if scope_elem_type is cst.Assign else {}), **scope_elem_extra_kwargs } ) # TODO: may be preferable to convert the selector into a bytecode of path manip instructions def select(root: cst.Module, transform: Transform) -> Transform: matches: List[Match] = [] # TODO: dont root search at global scope, that's not the original design # I'll probably need to change the parser to store the prefixing nesting op # NOTE: I have no idea how mypy works yet, I'm just pretending it's typescript # NOTE: I saw other typing usage briefly and I'm pretty sure it doesn't work this way
def compile_graph(graph: GraphicalModel, namespace: dict, fn_name):
    """Compile MCX's graph into a python (executable) function.

    The graph is translated into a libcst ``FunctionDef`` named ``fn_name``,
    rendered to source code, and executed in ``namespace``.

    Args:
        graph: The model graph to compile.
        namespace: Globals used to ``exec`` the generated code. The compiled
            function is stored in it under ``fn_name``.
        fn_name: Name given to the generated function.

    Returns:
        A ``(fn, code)`` tuple holding the compiled function and its source text.
    """
    # Model arguments are passed in the following order:
    # 1. (samplers only) rng_key;
    # 2. (all) the model's positional arguments (placeholders without defaults);
    # 3. (logpdf only) random variables, taken from ``graph.placeholders`` in
    #    reverse order;
    # 4. (all) the model's keyword arguments (placeholders with defaults). These
    #    come last because parameters with defaults must follow those without.
    maybe_rng_key = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if node.name == "rng_key"
    ]
    maybe_random_variables = [
        compile_placeholder(node, graph)
        for node in reversed(list(graph.placeholders))
        if node.is_random_variable
    ]
    model_args = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if not node.is_random_variable
        and node.name != "rng_key"
        and not node.has_default
    ]
    model_kwargs = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if not node.is_random_variable
        and node.name != "rng_key"
        and node.has_default
    ]
    args = maybe_rng_key + model_args + maybe_random_variables + model_kwargs

    def _assign(name, value):
        # One ``name = value`` statement line.
        return cst.SimpleStatementLine(
            body=[
                cst.Assign(
                    targets=[cst.AssignTarget(target=cst.Name(value=name))],
                    value=value,
                )
            ]
        )

    # Every statement in the function corresponds to either a constant
    # definition or a variable assignment. A topological sort makes each
    # statement come after the statements it depends on.
    stmts = []
    returns = []
    for node in nx.topological_sort(graph):
        if node.name is None:
            continue
        if isinstance(node, Constant):
            stmts.append(_assign(node.name, node.cst_generator()))
        if isinstance(node, Op):
            stmts.append(_assign(node.name, compile_op(node, graph)))
        if node.is_returned:
            # NOTE(review): each returned node adds its own ``return``
            # statement; if several nodes are returned, only the first one
            # emitted can ever execute. Confirm the graph allows at most one.
            returns.append(
                cst.SimpleStatementLine(
                    body=[cst.Return(value=cst.Name(value=node.name))]
                )
            )

    # Assemble the function's CST from the translated nodes.
    ast_fn = cst.Module(
        body=[
            cst.FunctionDef(
                name=cst.Name(value=fn_name),
                params=cst.Parameters(params=args),
                body=cst.IndentedBlock(body=stmts + returns),
            )
        ]
    )
    code = ast_fn.code
    # NOTE: ``code`` is generated from the graph's nodes, so executing it is
    # only as safe as the graph itself; never build the graph from untrusted text.
    exec(code, namespace)
    fn = namespace[fn_name]
    return fn, code