def parameters(
    self, type: typing.Literal["function", "classmethod", "method"]
) -> cst.Parameters:
    """Build the libcst ``Parameters`` node for this signature.

    Required parameters are emitted with an annotation only; optional ones
    additionally get ``...`` as their default value.  For ``"classmethod"``
    and ``"method"`` a bare ``cls`` / ``self`` parameter is placed at the
    front of the positional-only parameters.

    Args:
        type: Kind of callable the parameters are generated for.

    Returns:
        The assembled ``cst.Parameters`` node.
    """

    def required(entries):
        # Annotated parameter without a default.
        return [
            cst.Param(cst.Name(name), cst.Annotation(info.annotation))
            for name, info in entries.items()
        ]

    def optional(entries):
        # Annotated parameter whose default is rendered as `...`.
        return [
            cst.Param(
                cst.Name(name),
                cst.Annotation(info.annotation),
                default=cst.Ellipsis(),
            )
            for name, info in entries.items()
        ]

    def variadic(entry):
        # `entry` is indexed as (name, info).
        return cst.Param(cst.Name(entry[0]), cst.Annotation(entry[1].annotation))

    positional_only = required(self.pos_only_required) + optional(
        possibly_order_dict(
            self.pos_only_optional, self.pos_only_optional_ordering
        )
    )
    # The implicit receiver always comes first.
    if type == "classmethod":
        positional_only = [cst.Param(cst.Name("cls")), *positional_only]
    elif type == "method":
        positional_only = [cst.Param(cst.Name("self")), *positional_only]

    positional_or_keyword = required(self.pos_or_kw_required) + optional(
        possibly_order_dict(
            self.pos_or_kw_optional, self.pos_or_kw_optional_ordering
        )
    )

    if self.var_pos:
        star_arg = variadic(self.var_pos)
    else:
        star_arg = cst.MaybeSentinel.DEFAULT

    if self.var_kw:
        star_kwarg = variadic(self.var_kw)
    else:
        star_kwarg = None

    keyword_only = required(self.kw_only_required) + optional(self.kw_only_optional)

    return cst.Parameters(
        posonly_params=positional_only,
        params=positional_or_keyword,
        star_arg=star_arg,
        star_kwarg=star_kwarg,
        kwonly_params=keyword_only,
    )
def test_parameters(self) -> None:
    """Templates accept a Name, a Param or a Parameters node inside a signature."""
    single_param_code = "def foo(bar): pass\n"

    # A plain Name can be substituted where a parameter is expected.
    from_name = parse_template_statement(
        "def foo({arg}): pass",
        arg=cst.Name("bar"),
    )
    self.assertEqual(self.code(from_name), single_param_code)

    # A Param node is accepted as a special case.
    from_param = parse_template_statement(
        "def foo({arg}): pass",
        arg=cst.Param(cst.Name("bar")),
    )
    self.assertEqual(self.code(from_param), single_param_code)

    # A whole Parameters node is accepted as a special case too.
    one_param_list = cst.Parameters(
        (cst.Param(cst.Name("bar")),),
    )
    from_parameters = parse_template_statement(
        "def foo({args}): pass",
        args=one_param_list,
    )
    self.assertEqual(self.code(from_parameters), single_param_code)

    # A Parameters node may carry several entries, including **kwargs.
    several_params = cst.Parameters(
        params=(
            cst.Param(cst.Name("bar")),
            cst.Param(cst.Name("baz")),
        ),
        star_kwarg=cst.Param(cst.Name("rest")),
    )
    from_several = parse_template_statement(
        "def foo({args}): pass",
        args=several_params,
    )
    self.assertEqual(
        self.code(from_several),
        "def foo(bar, baz, **rest): pass\n",
    )
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
    """Report ``@classmethod`` functions whose first parameter is not named CLS.

    Functions without a ``builtins.classmethod`` decorator are ignored.  When
    it is considered safe, the report carries a replacement node in which the
    first parameter (and every recorded assignment/reference to it) is renamed
    to CLS; otherwise the violation is reported without a replacement.

    NOTE(review): only ``node.params.params`` is inspected here; whether
    positional-only parameters need separate handling should be confirmed.
    """
    if not any(
        QualifiedNameProvider.has_name(
            self,
            decorator.decorator,
            QualifiedName(name="builtins.classmethod", source=QualifiedNameSource.BUILTIN),
        )
        for decorator in node.decorators
    ):
        return  # If it's not a @classmethod, we are not interested.
    if not node.params.params:
        # No params, but there must be the 'cls' param.
        # Note that pyre[47] already catches this, but we also generate
        # an autofix, so it still makes sense for us to report it here.
        new_params = node.params.with_changes(params=(cst.Param(name=cst.Name(value=CLS)),))
        repl = node.with_changes(params=new_params)
        self.report(node, replacement=repl)
        return
    p0_name = node.params.params[0].name
    if p0_name.value == CLS:
        return  # All good.
    # Rename all assignments and references of the first param within the
    # function scope, as long as they are done via a Name node.
    # We rely on the parser to correctly derive all
    # assignments and references within the FunctionScope.
    # The Param node's scope is our classmethod's FunctionScope.
    scope = self.get_metadata(ScopeProvider, p0_name, None)
    if not scope:
        # Cannot autofix without scope metadata. Only report in this case.
        # Not sure how to repro+cover this in a unit test...
        # If metadata creation fails, then the whole lint fails, and if it succeeds,
        # then there is valid metadata. But many other lint rule implementations contain
        # a defensive scope None check like this one, so I assume it is necessary.
        self.report(node)
        return
    if scope[CLS]:
        # The scope already has another assignment to "cls".
        # Trying to rename the first param to "cls" as well may produce broken code.
        # We should therefore refrain from suggesting an autofix in this case.
        self.report(node)
        return
    # Collect every node that has to be renamed: the assignment targets
    # themselves plus all references recorded for each assignment.
    refs: List[Union[cst.Name, cst.Attribute]] = []
    assignments = scope[p0_name.value]
    for a in assignments:
        if isinstance(a, Assignment):
            assign_node = a.node
            if isinstance(assign_node, cst.Name):
                refs.append(assign_node)
            elif isinstance(assign_node, cst.Param):
                refs.append(assign_node.name)
            # There are other types of possible assignment nodes: ClassDef,
            # FunctionDef, Import, etc. We deliberately do not handle those here.
        refs += [r.node for r in a.references]
    # Produce the renamed function and attach it to the report as the autofix.
    repl = node.visit(_RenameTransformer(refs, CLS))
    self.report(node, replacement=repl)
def sample_predictive(model):
    """Compile a function that samples from the model's predictive distribution.

    Works on a deep copy of ``model.graph``: every random variable's CST
    generator is wrapped so that it emits a sampling call taking ``rng_key``,
    and an ``rng_key`` placeholder is wired to each of them.
    """
    graph = copy.deepcopy(model.graph)
    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")

    # `a <~ Normal(0, 1)` is turned into `a = Normal(0, 1).sample(rng_key)`.
    def distribution_to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        distribution = cst_generator(*args, **kwargs)
        sample_method = cst.Attribute(distribution, cst.Name("sample"))
        return cst.Call(func=sample_method, args=[cst.Arg(value=rng_key)])

    # A nested model is called directly, with the key as first argument.
    def model_to_sampler(model_name, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        call_args = [cst.Arg(value=rng_key)] + list(args)
        return cst.Call(func=cst.Name(value=model_name), args=call_args)

    sampled_ops = []
    for op in reversed(list(graph.random_variables)):
        if isinstance(op, SampleModelOp):
            generator = partial(model_to_sampler, op.model_name)
        else:
            generator = partial(distribution_to_sampler, op.cst_generator)
        op.cst_generator = generator
        sampled_ops.append(op)

    # Feed the `rng_key` placeholder to every sampling expression.
    graph.add(rng_node)
    for op in sampled_ops:
        graph.add_edge(rng_node, op, type="kwargs", key=["rng_key"])

    compiled_name = f"{graph.name}_sample"
    return compile_graph(graph, model.namespace, compiled_name)
def test_from_function_data(self) -> None:
    """Check how FunctionAnnotationKind.from_function_data classifies functions."""

    def classify(returns_annotated, annotated_count, in_class, params):
        # Thin keyword-forwarding wrapper around the function under test.
        return FunctionAnnotationKind.from_function_data(
            is_return_annotated=returns_annotated,
            annotated_parameter_count=annotated_count,
            is_method_or_classmethod=in_class,
            parameters=params,
        )

    three_parameters = [
        cst.Param(name=cst.Name(param_name), annotation=None)
        for param_name in ("x1", "x2", "x3")
    ]

    self.assertEqual(
        classify(True, 3, False, three_parameters),
        FunctionAnnotationKind.FULLY_ANNOTATED,
    )
    self.assertEqual(
        classify(True, 0, False, three_parameters),
        FunctionAnnotationKind.PARTIALLY_ANNOTATED,
    )
    self.assertEqual(
        classify(False, 0, False, three_parameters),
        FunctionAnnotationKind.NOT_ANNOTATED,
    )
    self.assertEqual(
        classify(False, 1, False, three_parameters),
        FunctionAnnotationKind.PARTIALLY_ANNOTATED,
    )

    # An untyped `self` parameter of a method does not count for partial
    # annotation. As per PEP 484, we need an explicitly annotated parameter.
    self.assertEqual(
        classify(False, 1, True, three_parameters),
        FunctionAnnotationKind.NOT_ANNOTATED,
    )
    self.assertEqual(
        classify(False, 2, True, three_parameters),
        FunctionAnnotationKind.PARTIALLY_ANNOTATED,
    )

    # An annotated `self` suffices to make Pyre typecheck the method.
    annotated_self = cst.Param(
        name=cst.Name("self"),
        annotation=cst.Annotation(cst.Name("Foo")),
    )
    self.assertEqual(
        classify(False, 1, True, [annotated_self]),
        FunctionAnnotationKind.PARTIALLY_ANNOTATED,
    )
    self.assertEqual(
        classify(False, 0, True, []),
        FunctionAnnotationKind.NOT_ANNOTATED,
    )
def argument_cst(name, default=None):
    """Return a ``cst.Param`` named ``name`` with the given ``default``."""
    identifier = cst.Name(name)
    return cst.Param(identifier, default=default)
class LambdaParserTest(CSTNodeTest):
    # Each data_provider row is (expected node, source code); the test hands
    # both, together with `parse_expression`, to `validate_node`.
    @data_provider((
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(
                    cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(cst.Name("baz"), star=""),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(
                        cst.Name("bar"),
                        default=cst.SimpleString('"one"'),
                        equal=cst.AssignEqual(),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ), ),
                    default_params=(cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ), ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(cst.Name("baz"), star=""),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(
                            cst.Name("first"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params"), star="*")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params: 5",
        ),
        # Named star_arg followed by kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_kwarg on its own
        (
            cst.Lambda(
                cst.Parameters(
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and star_kwarg together
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**"),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                params=cst.Parameters(),
                colon=cst.Colon(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda : 5 )",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Validate one (node, code) row against `parse_expression`.

        `position` is forwarded to `validate_node` unchanged; none of the rows
        above supplies it, so it stays at its default of None.
        """
        self.validate_node(node, code, parse_expression, position)
class LambdaCreationTest(CSTNodeTest):
    # Rows for test_valid are (node, expected code[, expected CodeRange]);
    # the nodes are built without explicit commas/stars/whitespace.
    @data_provider((
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar")),
                                       cst.Param(cst.Name("baz")))),
                cst.Integer("5"),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                )),
                cst.Integer("5"),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(cst.Name("bar")), ),
                    default_params=(cst.Param(cst.Name("baz"),
                                              default=cst.Integer("5")), ),
                ),
                cst.Integer("5"),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params
        (
            cst.Lambda(
                cst.Parameters(kwonly_params=(
                    cst.Param(cst.Name("bar"),
                              default=cst.SimpleString('"one"')),
                    cst.Param(cst.Name("baz")),
                )),
                cst.Integer("5"),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(cst.Name("first"),
                                  default=cst.Float("1.0")),
                        cst.Param(cst.Name("second"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"),
                                  default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
            CodeRange((1, 0), (1, 84)),
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("params"))),
                cst.Integer("5"),
            ),
            "lambda *params: 5",
        ),
        # Named star_arg followed by kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(cst.Name("first")),
                        cst.Param(cst.Name("second")),
                    ),
                    default_params=(
                        cst.Param(cst.Name("third"),
                                  default=cst.Float("1.0")),
                        cst.Param(cst.Name("fourth"),
                                  default=cst.Float("1.5")),
                    ),
                    star_arg=cst.Param(cst.Name("params")),
                    kwonly_params=(
                        cst.Param(cst.Name("bar"),
                                  default=cst.SimpleString('"one"')),
                        cst.Param(cst.Name("baz")),
                        cst.Param(cst.Name("biz"),
                                  default=cst.SimpleString('"two"')),
                    ),
                ),
                cst.Integer("5"),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_kwarg on its own
        (
            cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))),
                cst.Integer("5"),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and star_kwarg together
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params")),
                    star_kwarg=cst.Param(cst.Name("kwparams")),
                ),
                cst.Integer("5"),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace
        # NOTE(review): the expected end column (13) looks one larger than the
        # literals shown here would give; whitespace inside a string literal
        # may have been collapsed when this file was reflowed - confirm.
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
                params=cst.Parameters(),
                colon=cst.Colon(whitespace_after=cst.SimpleWhitespace(" ")),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda : 5 )",
            CodeRange((1, 2), (1, 13)),
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Validate one row: the node against `code`, and `position` as the expected position."""
        self.validate_node(node, code, expected_position=position)

    # Rows for test_invalid are (zero-argument node factory, regex passed to
    # `assert_invalid`).  Construction is deferred behind a lambda so that it
    # happens inside `assert_invalid` rather than while building this table.
    @data_provider((
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"), default=cst.Integer("5")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "at least one space after lambda",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    equal=cst.AssignEqual())),
                cst.Integer("5"),
            ),
            "Must have a default when specifying an AssignEqual.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"),
                                                    star="***")),
                cst.Integer("5"),
            ),
            r"Must specify either '', '\*' or '\*\*' for star.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("bar"), default=cst.SimpleString('"one"')), )),
                cst.Integer("5"),
            ),
            "Cannot have defaults for params",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(cst.Name("bar")), )),
                cst.Integer("5"),
            ),
            "Must have defaults for default_params",
        ),
        (
            lambda: cst.Lambda(cst.Parameters(star_arg=cst.ParamStar()),
                               cst.Integer("5")),
            "Must have at least one kwonly param if ParamStar is used.",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"), )
                               ),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("bar"),
                    default=cst.SimpleString('"one"'),
                    star="*",
                ), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(cst.Name("bar"),
                                                        star="*"), )),
                cst.Integer("5"),
            ),
            "Expecting a star prefix of ''",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("bar"),
                                                  star="**")),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*'",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")
                               ),
                cst.Integer("5"),
            ),
            r"Expecting a star prefix of '\*\*'",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(default_params=(cst.Param(
                    cst.Name("arg"),
                    default=cst.Integer("5"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_arg=cst.Param(cst.Name("arg"),
                                                  annotation=cst.Annotation(
                                                      cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(kwonly_params=(cst.Param(
                    cst.Name("arg"),
                    annotation=cst.Annotation(cst.Name("str")),
                ), )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
        (
            lambda: cst.Lambda(
                cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"),
                                                    annotation=cst.Annotation(
                                                        cst.Name("str")))),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(""),
            ),
            "Lambda params cannot have type annotations",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Pass one (factory, regex) row to `assert_invalid`."""
        self.assert_invalid(get_node, expected_re)
'async': lambda _: {cst.FunctionDef}
}

# Maps a nesting operator to a function returning the child nodes to look at.
# The `None` key falls back to all of the node's children.
nesting_op_children_getter = {
    '.': lambda node: node.body.body if isinstance(node, cst.ClassDef) else (),
    '(': lambda node: node.params.params if isinstance(node, cst.FunctionDef) else (),
    ',': getNextParam,
    None: lambda node: node.children
}

# Default constructor kwargs keyed by a path of node classes; a FunctionDef
# under a ClassDef gets a single `self` parameter.
per_path_default_kwargs = RelativePathDict({
    (cst.ClassDef, cst.FunctionDef): lambda path: {
        # TODO: if node is not decorated as staticmethod or class method!
        'params': cst.Parameters(params=[cst.Param(cst.Name('self'))])
    },
})

# NOTE: it is bad design to allow users to craft coincidentally unambiguous
# scope expressions; instead, it would be better to have plenty of unambiguous
# property keys (i.e. func, class, async, static, etc.) for the user to craft with.
def possibleElemTypes(scope_expr: ScopeExpr) -> Set[cst.CSTNode]:
    """Return the node classes compatible with every property of `scope_expr`.

    Starting from `node_types`, the result is combined via `and_` with the
    set that `possible_node_classes_per_prop[key](val)` yields for each
    (key, val) pair in `scope_expr.properties`.
    (presumably `and_` is `operator.and_`, i.e. set intersection - confirm)
    """
    return reduce(and_, map(lambda key, val: possible_node_classes_per_prop[key](val), scope_expr.properties.keys(), scope_expr.properties.values()), node_types)
def sample_posterior_predictive(model, node_names):
    """Compile a sampler for the posterior predictive distribution.

    Example
    -------

    An MCX model such as

    >>> def linear_regression(X, lmbda=1.):
    ...     scale <~ Exponential(lmbda)
    ...     coef <~ Normal(jnp.zeros(X.shape[0]), 1)
    ...     y = jnp.dot(X, coef)
    ...     pred <~ Normal(y, scale)
    ...     return pred

    is rewritten, for ``node_names = ["scale", "coef"]``, into

    >>> def linear_regression_pred(rng_key, scale, coef, X, lmbda=1.):
    ...     idx = mcx.jax.choice(rng_key, scale.shape[0])
    ...     scale_sample = scale[idx]
    ...     coef_sample = coef[idx]
    ...     y = jnp.dot(X, coef_sample)
    ...     pred = Normal(y, scale_sample).sample(rng_key)
    ...     return pred

    """
    graph = copy.deepcopy(model.graph)
    nodes = [graph.find_node(name) for name in node_names]

    # Sampling needs a RNG key, which enters the function as a placeholder.
    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")
    graph.add_node(rng_node)

    # One posterior draw is selected at random with
    # `idx = mcx.jax.choice(rng_key, num_samples)`; every array of samples
    # passed in is later indexed with this `idx`.
    def choice_ast(rng_key):
        mcx_jax = cst.Attribute(cst.Name("mcx"), cst.Name("jax"))
        first_shape = cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape"))
        num_samples = cst.Subscript(
            first_shape,
            [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
        )
        return cst.Call(
            func=cst.Attribute(value=mcx_jax, attr=cst.Name("choice")),
            args=[cst.Arg(rng_key), cst.Arg(num_samples)],
        )

    choice_node = Op(choice_ast, graph.name, "idx")
    graph.add(choice_node, rng_node)

    # Cut every edge coming into the nodes targeted by the intervention.
    incoming_edges = list(graph.in_edges(nodes))
    for edge in incoming_edges:
        graph.remove_edge(*edge)

    # Each intervened SampleOp becomes a placeholder indexed by `idx`.
    for target in reversed(nodes):
        rv_name = target.name

        placeholder = Placeholder(
            partial(lambda name: cst.Param(cst.Name(name)), rv_name),
            rv_name,
            is_random_variable=True,
        )
        graph.add_node(placeholder)

        def sample_index(placeholder, idx):
            element = cst.SubscriptElement(cst.Index(idx))
            return cst.Subscript(placeholder, [element])

        chosen_sample = Op(sample_index, graph.name, rv_name + "_sample")
        graph.add(chosen_sample, placeholder, choice_node)

        # Re-route the target's outgoing edges so they start at the chosen sample.
        rerouted = []
        for edge in graph.out_edges(target):
            attributes = graph.get_edge_data(*edge)
            rerouted.append(edge)
            graph.add_edge(chosen_sample, edge[1], **attributes)

        for edge in rerouted:
            graph.remove_edge(*edge)

        graph.remove_node(target)

    # Recursively drop nodes that have no outgoing edge and are not returned.
    graph = remove_dangling_nodes(graph)

    # Remaining SampleOps are replaced by a `.sample(rng_key)` call.
    def to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        distribution = cst_generator(*args, **kwargs)
        return cst.Call(
            func=cst.Attribute(value=distribution, attr=cst.Name("sample")),
            args=[cst.Arg(value=rng_key)],
        )

    random_variables = []
    for candidate in reversed(list(graph.nodes())):
        if isinstance(candidate, SampleOp):
            candidate.cst_generator = partial(to_sampler, candidate.cst_generator)
            random_variables.append(candidate)

    # Connect the `rng_key` placeholder to every sampling expression.
    for variable in random_variables:
        graph.add_edge(rng_node, variable, type="kwargs", key=["rng_key"])

    compiled_name = f"{graph.name}_sample_posterior_predictive"
    return compile_graph(graph, model.namespace, compiled_name)
def sample_joint(model):
    """Compile a function returning forward samples from the model's joint distribution.

    The compiled function returns a dictionary of the sampled random
    variables (nested by scope when several scopes are present) instead of
    the model's original return value.
    """
    graph = copy.deepcopy(model.graph)
    namespace = model.namespace

    def to_dictionary_of_samples(random_variables, *_):
        scopes = [rv.scope for rv in random_variables]
        scoped = defaultdict(dict)
        for rv in random_variables:
            scoped[rv.scope][rv.name] = rv

        def samples_of(scope):
            # `{'name': name, ...}` for the variables of one scope.
            elements = []
            for var_name, var in scoped[scope].items():
                elements.append(
                    cst.DictElement(
                        cst.SimpleString(f"'{var_name}'"),
                        cst.Name(var.name),
                    )
                )
            return cst.Dict(elements)

        # With a single scope (99% of models) the dictionary is flat.
        if len(set(scopes)) == 1:
            return samples_of(scopes[0])

        # Otherwise the first level is keyed by scope, the second by variable.
        return cst.Dict(
            [
                cst.DictElement(cst.SimpleString(f"'{scope}'"), samples_of(scope))
                for scope in scoped
            ]
        )

    # Nothing but the dictionary of samples is returned anymore.
    for existing in graph.nodes():
        if isinstance(existing, Op):
            existing.is_returned = False

    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")

    # `a <~ Normal(0, 1)` is turned into `a = Normal(0, 1).sample(rng_key)`.
    def distribution_to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        distribution = cst_generator(*args, **kwargs)
        return cst.Call(
            func=cst.Attribute(distribution, cst.Name("sample")),
            args=[cst.Arg(value=rng_key)],
        )

    def model_to_sampler(model_name, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        sample_method = cst.Attribute(cst.Name(value=model_name), cst.Name("sample"))
        return cst.Call(
            func=sample_method,
            args=[cst.Arg(value=rng_key)] + list(args),
        )

    random_variables = []
    for op in reversed(list(graph.random_variables)):
        if isinstance(op, SampleModelOp):
            generator = partial(model_to_sampler, op.model_name)
        else:
            generator = partial(distribution_to_sampler, op.cst_generator)
        op.cst_generator = generator
        random_variables.append(op)

    # Feed the `rng_key` placeholder to every sampling expression.
    graph.add(rng_node)
    for op in random_variables:
        graph.add_edge(rng_node, op, type="kwargs", key=["rng_key"])

    # A sampled sub-model is followed by an op that picks its returned
    # variable by name; that op takes over the sub-model's outgoing edges.
    for node in graph.random_variables:
        if not isinstance(node, SampleModelOp):
            continue

        rv_name = node.name
        returned_var_name = node.graph.returned_variables[0].name

        def sample_index(rv, returned_var, *_):
            key = cst.SubscriptElement(cst.SimpleString(f"'{returned_var}'"))
            return cst.Subscript(cst.Name(rv), [key])

        chosen_sample = Op(
            partial(sample_index, rv_name, returned_var_name),
            graph.name,
            rv_name + "_value",
        )

        # Snapshot the outgoing edges together with their attributes.
        outgoing = [
            (edge, graph.get_edge_data(*edge)) for edge in graph.out_edges(node)
        ]
        for edge, _ in outgoing:
            graph.remove_edge(*edge)

        graph.add(chosen_sample, node)
        for edge, attributes in outgoing:
            graph.add_edge(chosen_sample, edge[1], **attributes)

    tuple_node = Op(
        partial(to_dictionary_of_samples, graph.random_variables),
        graph.name,
        "forward_samples",
        is_returned=True,
    )
    graph.add(tuple_node, *graph.random_variables)

    compiled_name = f"{graph.name}_sample_forward"
    return compile_graph(graph, namespace, compiled_name)
def placeholder_to_param(name: str):
    """Return a bare ``cst.Param`` whose name is ``name``."""
    identifier = cst.Name(name)
    return cst.Param(identifier)