def compile_placeholder(node: Placeholder, graph: GraphicalModel):
    """Build the CST of a placeholder node.

    The CST generator of every predecessor of ``node`` in ``graph`` (e.g. the
    node holding an argument's default value) is called without arguments,
    in the order ``graph.predecessors`` yields them. The results are then
    handed positionally to ``node.cst_generator``, whose return value is
    returned.
    """
    upstream = [parent.cst_generator() for parent in graph.predecessors(node)]
    return node.cst_generator(*upstream)
def compile_op(node: Op, graph: GraphicalModel):
    """Compile an Op by recursively compiling and including its upstream nodes.

    Each predecessor of ``node`` contributes one CST expression:

    - a ``cst.Name`` referencing it when the predecessor has a name (it is
      then a random or a deterministic variable);
    - otherwise the predecessor's own compiled expression, obtained by a
      recursive call.

    How the expression is passed to ``node.cst_generator`` depends on the
    edge data ``graph[predecessor][node]``:

    - ``type == "arg"``: ``edge["position"]`` is a list of positional indices
      the expression fills (one predecessor may fill several positions);
    - any other type: ``edge["key"]`` is a list of keyword names.

    Positional arguments are passed sorted by index. Indices are not checked
    for gaps: positions {0, 2} yield two positional arguments.

    Returns
    -------
    Whatever ``node.cst_generator`` returns for those arguments.

    Raises
    ------
    IndexError
        If two edges claim the same positional index.
    """
    positional = {}
    keywords = {}
    for predecessor in graph.predecessors(node):
        # Named predecessors are variables: reference them by name instead of
        # inlining their definition.
        if predecessor.name is not None:
            pred_ast = cst.Name(value=predecessor.name)
        else:
            pred_ast = compile_op(predecessor, graph)

        # Positional arguments are ordered by their `position` index below,
        # not by the order in which predecessors are iterated.
        edge = graph[predecessor][node]
        if edge["type"] == "arg":
            for idx in edge["position"]:
                if idx in positional:
                    raise IndexError("Duplicate argument position")
                positional[idx] = pred_ast
        else:
            for key in edge["key"]:
                keywords[key] = pred_ast

    args = [positional[idx] for idx in sorted(positional)]
    return node.cst_generator(*args, **keywords)
def __init__(self, namespace):
    """Initialise the parser state.

    Creates a fresh, empty ``GraphicalModel`` stored as ``self.model`` and
    keeps ``namespace`` as ``self.namespace`` (presumably the names visible
    to the model definition — confirm against callers).
    """
    self.model, self.namespace = GraphicalModel(), namespace
class ModelParser(ast.NodeVisitor):
    """Recursively parse the model definition's AST and translate it to a
    graphical model.

    Only a restricted subset of Python is accepted: node types without a
    dedicated visitor reach `generic_visit`, which raises a `SyntaxError`.
    """

    def __init__(self, namespace):
        self.model = GraphicalModel()
        # Names visible to the model definition; used to look up the object
        # on the right-hand side of a random variable assignment.
        self.namespace = namespace

    def generic_visit(self, node):
        node_type = type(node).__name__
        raise SyntaxError(f"Method to process {node_type} not specified")

    def visit_Module(self, node: ast.Module) -> GraphicalModel:
        """Parsing the source code into an Abstract Syntax Tree returns an
        `ast.Module` object. We check that the object being parsed is indeed
        a single function and pass on to another method to parse the function.

        Returns
        -------
        A graphical model instance that contains the graphical representation
        of the model.

        Raises
        ------
        SyntaxError
            If the module's body is empty or contains more than one object.
        SyntaxError
            If the module's body does not contain a function definition.
        """
        if len(node.body) != 1:
            raise SyntaxError("You must pass a single model definition.")

        model_fn = node.body[0]
        if not isinstance(model_fn, ast.FunctionDef):
            raise SyntaxError("The model must be defined inside a function")

        self.visit_model(model_fn)

        return self.model

    #
    # MCX-specific visitors
    #

    def visit_model(self, node: ast.FunctionDef) -> None:
        """Visit the function in which the model is defined.

        The function's arguments are recorded first, then each statement of
        its body is dispatched according to its type.

        Raises
        ------
        SyntaxError
            If the function's body contains unsupported constructs, i.e.
            anything other than an assignment, an expression or a return.
        """
        self.visit_model_arguments(node)
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                self.visit_deterministic(stmt)
            elif isinstance(stmt, ast.Expr):
                self.visit_Expr(stmt)
            elif isinstance(stmt, ast.Return):
                self.visit_Return(stmt)
            else:
                raise SyntaxError(
                    "Only variable, random variable assignments and transformations are currently supported"
                )

    def visit_model_arguments(self, node: ast.FunctionDef) -> None:
        """Record the model's name and its arguments.

        Each argument in `node.args.args` is added to the model with its
        default value expression, or `None` when it has no default.
        """
        self.model.graph["name"] = node.name

        argument_names: List[str] = [arg.arg for arg in node.args.args]

        # Python aligns default values with the *last* arguments, so pad the
        # front of the list with `None` for arguments without a default.
        default_values: List[Union[ast.expr, None]] = list(node.args.defaults)
        default_values = [None] * (len(argument_names) - len(default_values)) + default_values

        for name, value in zip(argument_names, default_values):
            self.model.add_argument(name, value)

    def visit_deterministic(self, node: ast.Assign) -> None:
        """Visit and add deterministic variables to the graphical model.

        Deterministic expression can be of two kinds:

        - Constant assignments;
        - Transformation of existing variables.

        Since constants can also be the result of a call of the form
        `np.array([0, 1])`, we need to walk down the assignments' values. If
        any `ast.Name` node is found, the assignment is a transformation,
        otherwise it is a constant.

        Raises
        ------
        SyntaxError
            If several variables are being assigned in one statement.
        SyntaxError
            If this is not a variable assignment.

        #TODO: check that the functions being called are in scope.
        """
        if len(node.targets) != 1:
            raise SyntaxError("You can only assign one variable per statement.")

        target = node.targets[0]
        if not isinstance(target, ast.Name):
            raise SyntaxError(
                "Require a name on the left-hand-side of a deterministic "
                f"variable assignment, got {target}"
            )

        var_name = target.id

        # Since Python 3.8 every literal (numbers included) parses to
        # `ast.Constant`, so there is no separate `ast.Num` case.
        if isinstance(node.value, ast.Constant):
            self.model.add_variable(var_name, value=node.value)
            return

        arg_names = find_variable_arguments(node)
        if arg_names:
            self.model.add_transformation(var_name, node.value, arg_names)
        else:
            self.model.add_variable(var_name, node.value)

    def visit_Expr(self, node: ast.Expr) -> None:
        """Visit a bare expression statement.

        Only comparisons are inspected; any other expression (a docstring, a
        bare call, ...) is ignored.
        """
        if isinstance(node.value, ast.Compare):
            self.visit_Compare(node.value)

    def visit_Compare(self, node: ast.Compare) -> None:
        """Treat `x is Dist(...)` comparisons as random variable assignments.

        Presumably the `<~` operator is rewritten to `is` before the source
        is parsed — confirm. Other comparisons are ignored.
        """
        if isinstance(node.ops[0], ast.Is):
            self.visit_RandAssign(node)

    def visit_RandAssign(self, node: ast.Compare) -> None:
        """Visit a random variable assignment, and add a new random node to
        the graph.

        Random variable assignments are distinguished from deterministic
        assignments by the use of the `<~` operator.

        Raises
        ------
        SyntaxError
            If there is no variable name on the left-hand-side of the `<~`
            operator.
        SyntaxError
            If the right-hand-side is not a function call.

        #TODO: check that the class being initialized is a Distribution instance.
        """
        if not isinstance(node.left, ast.Name):
            raise SyntaxError(
                "You need a variable name on the left of a random variable assignment"
                f", got {node.left}"
            )
        if not isinstance(node.comparators[0], ast.Call):
            raise SyntaxError(
                "Statements on the right of the `<~` operator must be distribution "
                f"initialization, got {node.comparators[0]}"
            )

        name = node.left.id
        distribution = node.comparators[0]
        args = self.visit_Call(distribution)

        # To allow model composition, whenever a `mcx` model appears at the
        # right-hand-side of a `<~` operator we merge its graph with the
        # current model's graph.
        dist_path = read_object_name(distribution.func)
        # NOTE: `eval` runs text taken from the model's own source in the
        # model's namespace. Never feed this parser untrusted source code.
        dist_obj = eval(dist_path, self.namespace)
        if isinstance(dist_obj, mcx.model):
            self.model = self.model.merge_models(name, dist_obj.graph, args)
        else:
            self.model.add_randvar(name, distribution, args)

    def visit_Call(self, node: ast.Call) -> List[Union[str, int, float, complex]]:
        """Return the names or values of the call's positional arguments."""
        return self.visit_Arguments(node.args)

    def visit_Arguments(self, args: List[Any]) -> List[Union[str, float, int, complex]]:
        """Visits and returns the arguments used to initialize the distribution.

        Returns
        -------
        A list of the names or values of the arguments passed to the
        distribution: the identifier for `ast.Name` nodes, the literal value
        for `ast.Constant` nodes.

        Raises
        ------
        SyntaxError
            If the distribution is initialized with anything different from a
            constant or a previously defined variable.
        """
        arguments: List[Union[str, float, int, complex]] = []
        for arg in args:
            if isinstance(arg, ast.Name):
                arguments.append(arg.id)
            elif isinstance(arg, ast.Constant):
                arguments.append(arg.value)
            else:
                raise SyntaxError(
                    "Expected a random variable or a constant to initialize "
                    f"distribution, got {astor.code_gen.to_source(arg)} instead.\n"
                    "Maybe you are trying to initialize a distribution directly, "
                    "or call a function inside the distribution initialization. "
                    "While this would be a perfectly legitimate move, it is currently "
                    "not supported in mcx. Use an intermediate variable instead: \n\n"
                    "Do not do `x <~ Normal(Normal(0, 1), 1)` or "
                    "`x <~ Normal(my_function(10), 1)`, instead do "
                    "`y <~ Normal(0, 1) & x <~ Normal(y, 1)` and "
                    "`y = my_function(10) & x <~ Normal(y, 1)`"
                )
        return arguments

    def visit_Return(self, node):
        """Visits the `return` expression of the model definition and mark
        the corresponding variables as returned in the graphical model.

        Raises
        ------
        SyntaxError
            If the model does not return a variable or a tuple.
        """
        if isinstance(node.value, ast.Tuple):
            # Tuple elements that are not plain names are skipped silently.
            for var in node.value.elts:
                if isinstance(var, ast.Name):
                    self.model.mark_as_returned(var.id)
        elif isinstance(node.value, ast.Name):
            self.model.mark_as_returned(node.value.id)
        else:
            raise SyntaxError(
                "Expected the generative model to return a (random) variable or a tuple "
                f"of (random) variables, got {node.value}"
            )
def visit_FunctionDef(self, node: cst.FunctionDef) -> Union[None, bool]:
    """Visit a function definition.

    When we traverse the Concrete Syntax Tree of a MCX model, a function
    definition can represent several objects.

    The main model definition
    ~~~~~~~~~~~~~~~~~~~~~~~~~

        >>> @mcx.model
        ... def my_model(*args):
        ...     # do things
        ...     return

    A regular function
    ~~~~~~~~~~~~~~~~~~~

        >>> @mcx.model
        ... def my_model(*args):
        ...     x <~ Normal(0, 1)
        ...     y <~ Normal(0, 1)
        ...
        ...     def add(a, b):
        ...         return a + b
        ...
        ...     z = add(x, y)
        ...     return z

    A closure
    ~~~~~~~~~

    It is perfectly reasonable (and is necessary to work with nonparametrics)
    to define a model like the following:

        >>> @mcx.model
        ... def mint_coin():
        ...     p <~ Beta(1, 1)
        ...
        ...     @mcx.model
        ...     def coin():
        ...         head <~ Bernoulli(p)
        ...         return head
        ...
        ...     return coin

    A submodel
    ~~~~~~~~~~

        >>> @mcx.model
        ... def linreg(x):
        ...     scale <~ HalfCauchy(0, 1)
        ...
        ...     @mcx.model
        ...     def Horseshoe(mu=0, tau=1., s=1.):
        ...         scale <~ HalfCauchy(0, s)
        ...         noise <~ Normal(0, tau)
        ...         res = scale * noise + mu
        ...         return res
        ...
        ...     coefs <~ Horseshoe(np.zeros(x.shape[1]))
        ...     predictions <~ Normal(np.matmul(x, coefs), scale)
        ...     return predictions

    We can even have nested submodels.

    Returns
    -------
    ``False`` for a definition nested inside the model being parsed (its
    children are not visited), ``None`` for the model definition itself.
    """
    # Standard python functions defined within a model need to be included
    # as is in the resulting source code. So do submodels.
    if hasattr(self, "graph"):
        function_name = node.name.value
        if is_model_definition(node, self.namespace):
            nested_node = ModelOp(lambda: node, function_name)
        else:
            nested_node = FunctionOp(lambda: node, function_name)
        self.graph.add(nested_node)
        # Register under the identifier string, like the model arguments
        # below: `node.name` is a CST `Name` node, not the name itself.
        self.named_variables[function_name] = nested_node
        return False  # don't visit the node's children

    # Each time we enter a model definition we create a new GraphicalModel
    # which is returned after the definition's children have been visited.
    # The current version does not support nested models but will.
    self.graph: GraphicalModel = GraphicalModel()
    self.graph.name = node.name.value
    self.scope = node.name.value

    def argument_cst(name, default=None):
        return cst.Param(cst.Name(name), default=default)

    for argument in node.params.params:
        name = argument.name.value
        try:
            # Parse the argument's default value, if any. Arguments without a
            # default are detected through a TypeError raised in this block
            # (presumably by `recursive_visit(None)` — confirm).
            default = self.recursive_visit(argument.default)
            argument_node = Placeholder(partial(argument_cst, name), name, False, True)
            self.graph.add(argument_node, default)
        except TypeError:
            argument_node = Placeholder(partial(argument_cst, name), name)
            self.graph.add(argument_node)
        self.named_variables[name] = argument_node

    return None
def compile_to_logpdf(graph: GraphicalModel, namespace: Dict) -> Artifact:
    """Compile a graphical model into a log-probability density function.

    Example
    -------

    Let us consider a simple linear regression example:

        >>> @mcx.model
        ... def linear_regression(x, lmbda=1.):
        ...     scale <~ Exponential(lmbda)
        ...     coeff <~ Normal(0, 1)
        ...     y = np.dot(x, coeff)
        ...     predictions <~ Normal(y, scale)
        ...     return predictions

    MCX parses this definition into a graphical model. This function compiles
    the graph into a python function that returns the value of the
    log-probability density function, roughly:

        >>> def linear_regression_logpdf(scale, coeff, predictions, x, lmbda=1.):
        ...     logpdf = 0
        ...     logpdf += Exponential(lmbda).logpdf(scale)
        ...     logpdf += Normal(0, 1).logpdf(coeff)
        ...     y = np.dot(x, coeff)
        ...     logpdf += Normal(y, scale).logpdf(predictions)
        ...     return logpdf

    The random variables come first in the signature (in the graph's node
    order), followed by the model's arguments without a default value, then
    those with a default value. The logpdf is then partially applied on the
    dataset {(x, predictions)} for inference.

    Parameters
    ----------
    graph:
        A probabilistic graphical model.
    namespace:
        The names contained in the model definition's global scope. The
        generated code is executed in this dict, so the compiled function is
        also stored in it under its name.

    Returns
    -------
    An ``Artifact`` holding, in order: the compiled logpdf function, the names
    of all its arguments in order, the function's name, and its source code
    (useful for inspection by the user).
    """
    # NOTE(review): this module defines `compile_to_logpdf` more than once;
    # only the last definition is bound at import time — confirm which one
    # is meant to be used.
    fn_name = graph.name + "_logpdf"

    #
    # ARGUMENTS
    #

    kwarg_nodes = [
        node[1]["content"]
        for node in graph.nodes(data=True)
        if isinstance(node[1]["content"], Argument)
        and node[1]["content"].default_value is not None
    ]

    # The (keyword) arguments of the model definition and random variables
    # are passed as arguments to the logpdf.
    # NOTE(review): Argument nodes are converted with `to_logpdf_iadd()` too;
    # presumably this returns an `ast.arg` for them — confirm.
    model_kwargs = [kwarg.to_logpdf_iadd() for kwarg in kwarg_nodes]
    model_arguments = [
        node[1]["content"].to_logpdf_iadd()
        for node in graph.nodes(data=True)
        if isinstance(node[1]["content"], Argument)
        and node[1]["content"].default_value is None
    ]
    random_variables = [
        ast.arg(arg=node[1]["content"].name, annotation=None)
        for node in graph.nodes(data=True)
        if isinstance(node[1]["content"], RandVar)
    ]

    logpdf_arguments = random_variables + model_arguments + model_kwargs

    # We propagate the kwargs' default values; they apply to the trailing
    # arguments, which is where `model_kwargs` sits.
    defaults = [kwarg.default_value for kwarg in kwarg_nodes]

    #
    # FUNCTION BODY
    # To write the function body, we traverse the graph in topological order
    # while incrementing the value of the logpdf.
    #

    body: List[ast.stmt] = []
    body.append(
        ast.Assign(
            targets=[ast.Name(id="logpdf", ctx=ast.Store())],
            value=ast.Constant(value=0),
        )
    )

    ordered_nodes = [
        graph.nodes[node]["content"]
        for node in nx.topological_sort(graph)
        if not isinstance(graph.nodes[node]["content"], Argument)
    ]
    for node in ordered_nodes:
        body.append(node.to_logpdf_iadd())

    returned = ast.Return(value=ast.Name(id="logpdf", ctx=ast.Load()))
    body.append(returned)

    logpdf_ast = ast.Module(
        body=[
            ast.FunctionDef(
                name=fn_name,
                args=ast.arguments(
                    args=logpdf_arguments,
                    vararg=None,
                    kwarg=None,
                    posonlyargs=[],
                    kwonlyargs=[],
                    defaults=defaults,
                    kw_defaults=[],
                ),
                body=body,
                decorator_list=[],
            )
        ],
        type_ignores=[],
    )
    logpdf_ast = ast.fix_missing_locations(logpdf_ast)

    logpdf = compile(logpdf_ast, filename="<ast>", mode="exec")
    exec(logpdf, namespace)
    fn = namespace[fn_name]

    argument_names = [arg.arg for arg in logpdf_arguments]

    return Artifact(fn, argument_names, fn_name, astor.code_gen.to_source(logpdf_ast))
def compile_to_loglikelihoods(graph: GraphicalModel, namespace: Dict) -> Artifact:
    """Compile a graphical model into a function that returns the
    log-likelihood contribution of each random variable.

    Example
    -------

    Let us consider a simple linear regression example:

        >>> @mcx.model
        ... def linear_regression(x, lmbda=1.):
        ...     scale <~ Exponential(lmbda)
        ...     coeff <~ Normal(0, 1)
        ...     y = np.dot(x, coeff)
        ...     predictions <~ Normal(y, scale)
        ...     return predictions

    We can get the log-likelihood contribution of each random variable with a
    function that looks roughly like:

        >>> def linear_regression_loglikelihoods(scale, coeff, predictions, x, lmbda=1.):
        ...     logpdf_scale = Exponential(lmbda).logpdf(scale)
        ...     logpdf_coeff = Normal(0, 1).logpdf(coeff)
        ...     y = np.dot(x, coeff)
        ...     logpdf_predictions = Normal(y, scale).logpdf(predictions)
        ...     return {
        ...         "scale": logpdf_scale,
        ...         "coeff": logpdf_coeff,
        ...         "predictions": logpdf_predictions,
        ...     }

    The random variables come first in the signature (in the graph's node
    order), followed by the model's arguments without a default value, then
    those with a default value.

    Parameters
    ----------
    graph:
        A probabilistic graphical model.
    namespace:
        The names contained in the model definition's global scope. The
        generated code is executed in this dict, so the compiled function is
        also stored in it under its name.

    Returns
    -------
    An ``Artifact`` holding, in order: the compiled function, the names of all
    its arguments in order, the function's name and its source code.
    """
    fn_name = graph.name + "_loglikelihoods"

    #
    # ARGUMENTS
    #

    kwarg_nodes = [
        node[1]["content"]
        for node in graph.nodes(data=True)
        if isinstance(node[1]["content"], Argument)
        and node[1]["content"].default_value is not None
    ]

    # The (keyword) arguments of the model definition and random variables
    # are passed as arguments to the generated function.
    model_kwargs = [kwarg.to_logpdf() for kwarg in kwarg_nodes]
    model_arguments = [
        node[1]["content"].to_logpdf()
        for node in graph.nodes(data=True)
        if isinstance(node[1]["content"], Argument)
        and node[1]["content"].default_value is None
    ]
    random_variables = [
        ast.arg(arg=node[1]["content"].name, annotation=None)
        for node in graph.nodes(data=True)
        if isinstance(node[1]["content"], RandVar)
    ]

    logpdf_arguments = random_variables + model_arguments + model_kwargs

    # We propagate the kwargs' default values; they apply to the trailing
    # arguments, which is where `model_kwargs` sits.
    defaults = [kwarg.default_value for kwarg in kwarg_nodes]

    #
    # FUNCTION BODY
    # Every non-argument node, taken in topological order, contributes the
    # statement produced by its `to_logpdf()` method.
    #

    body: List[ast.stmt] = []
    ordered_nodes = [
        graph.nodes[node]["content"]
        for node in nx.topological_sort(graph)
        if not isinstance(graph.nodes[node]["content"], Argument)
    ]
    for node in ordered_nodes:
        body.append(node.to_logpdf())

    # Returned values
    #
    # A dict mapping each random variable's name to the `logpdf_<name>`
    # variable, which the random variables' `to_logpdf()` statements are
    # expected to assign. Deterministic variables are not returned.
    returned = ast.Return(
        value=ast.Dict(
            keys=[ast.Constant(value=name.arg, kind=None) for name in random_variables],
            values=[
                ast.Name(id=f"logpdf_{name.arg}", ctx=ast.Load())
                for name in random_variables
            ],
        )
    )
    body.append(returned)

    logpdf_ast = ast.Module(
        body=[
            ast.FunctionDef(
                name=fn_name,
                args=ast.arguments(
                    args=logpdf_arguments,
                    vararg=None,
                    kwarg=None,
                    posonlyargs=[],
                    kwonlyargs=[],
                    defaults=defaults,
                    kw_defaults=[],
                ),
                body=body,
                decorator_list=[],
            )
        ],
        type_ignores=[],
    )
    logpdf_ast = ast.fix_missing_locations(logpdf_ast)

    logpdf = compile(logpdf_ast, filename="<ast>", mode="exec")
    exec(logpdf, namespace)
    fn = namespace[fn_name]

    argument_names = [arg.arg for arg in logpdf_arguments]

    return Artifact(fn, argument_names, fn_name, astor.code_gen.to_source(logpdf_ast))
def _logpdf_core(graph: GraphicalModel):
    """Transform the SampleOps to statements that compute the logpdf
    associated with the variables' values.

    The graph is mutated in place:

    - every submodel sample (``SampleModelOp``) gets an intermediate ``Op``
      that reads ``<rv>['<returned variable>']``; the submodel's former
      consumers are re-wired to it;
    - every random variable ``a`` gets a new placeholder ``a``, which
      compiles to a function parameter (``cst.Param``);
    - the random variable's expression becomes
      ``Dist(...).logpdf_sum(a)`` (or ``<model_name>.logpdf(..., **a)`` for a
      submodel). It is renamed ``logpdf_<scope>_<name>`` and receives the
      placeholder through its ``var_name`` keyword;
    - nodes that consumed the sampled value now consume the placeholder;
    - no ``Op`` is flagged as returned anymore.

    Returns
    -------
    The same ``graph`` object, after mutation.
    """
    placeholders = []
    logpdf_nodes = []

    def sampleop_to_logpdf(cst_generator, *args, **kwargs):
        name = kwargs.pop("var_name")
        return cst.Call(
            cst.Attribute(cst_generator(*args, **kwargs), cst.Name("logpdf_sum")),
            [cst.Arg(name)],
        )

    def samplemodelop_to_logpdf(model_name, *args, **kwargs):
        name = kwargs.pop("var_name")
        return cst.Call(
            cst.Attribute(cst.Name(model_name), cst.Name("logpdf")),
            list(args) + [cst.Arg(name, star="**")],
        )

    def placeholder_to_param(name: str):
        return cst.Param(cst.Name(name))

    # Iterate over a snapshot: the loop adds nodes and edges to the graph.
    for node in list(graph.random_variables):
        if not isinstance(node, SampleModelOp):
            continue

        rv_name = node.name
        returned_var_name = node.graph.returned_variables[0].name

        def sample_index(rv, returned_var, *_):
            return cst.Subscript(
                cst.Name(rv),
                [cst.SubscriptElement(cst.SimpleString(f"'{returned_var}'"))],
            )

        chosen_sample = Op(
            partial(sample_index, rv_name, returned_var_name),
            graph.name,
            f"{rv_name}_value",
        )

        # Detach the submodel sample from its consumers, insert the
        # subscript node after it, and re-attach the consumers to the
        # subscript node with their original edge data.
        original_edges = list(graph.out_edges(node))
        edge_data = [graph.get_edge_data(*edge) for edge in original_edges]
        for edge in original_edges:
            graph.remove_edge(*edge)
        graph.add(chosen_sample, node)
        for (_, consumer), datum in zip(original_edges, edge_data):
            graph.add_edge(chosen_sample, consumer, **datum)

    # We need to loop through the nodes in reverse order because of the compilation
    # quirk which makes it that nodes added first to the graph appear first in the
    # functions arguments. This should be taken care of properly before merging.
    for node in reversed(list(graph.random_variables)):
        # Create a new placeholder node with the random variable's name.
        # It represents the value that will be passed to the logpdf.
        name = node.name
        rv_placeholder = Placeholder(
            partial(placeholder_to_param, name), name, is_random_variable=True
        )
        placeholders.append(rv_placeholder)

        # Transform the SampleOps from `a <~ Normal(0, 1)` into
        # `logpdf_<scope>_a = Normal(0, 1).logpdf_sum(a)`.
        if isinstance(node, SampleModelOp):
            node.cst_generator = partial(samplemodelop_to_logpdf, node.model_name)
        else:
            node.cst_generator = partial(sampleop_to_logpdf, node.cst_generator)
        node.name = f"logpdf_{node.scope}_{node.name}"
        logpdf_nodes.append(node)

    for placeholder, node in zip(placeholders, logpdf_nodes):
        # Add the placeholder to the graph and link it to the expression that
        # computes the logpdf. The placeholder `a` becomes an argument of the
        # compiled function and is fed to the expression's generator as its
        # `var_name` keyword, producing `... .logpdf_sum(a)`.
        graph.add_node(placeholder)
        graph.add_edge(placeholder, node, type="kwargs", key=["var_name"])

        # Remove edges from the former SampleOp and replace by new placeholder.
        # For instance, assume that part of our model is:
        #
        # >>> a <~ Normal(0, 1)
        # >>> x = jnp.log(a)
        #
        # Transformed to a logpdf this would look like:
        #
        # >>> logpdf_a = Normal(0, 1).logpdf_sum(a)
        # >>> x = jnp.log(a)
        #
        # Where a is now a placeholder, passed as an argument. The following
        # code links this placeholder to the expression `jnp.log(a)` and removes
        # the edge from `a <~ Normal(0, 1)`.
        #
        # We cannot remove edges while iterating over the graph, hence the
        # two-step process.
        successors = list(graph.successors(node))
        for successor in successors:
            edge_data = graph.get_edge_data(node, successor)
            graph.add_edge(placeholder, successor, **edge_data)
        for successor in successors:
            graph.remove_edge(node, successor)

    # The original MCX model may return one or many variables. None of
    # these variables should be returned, so we turn the `is_returned` flag
    # to `False`.
    for node in graph.nodes():
        if isinstance(node, Op):
            node.is_returned = False

    return graph
def compile_to_logpdf(graph: GraphicalModel, namespace: Dict) -> Artifact:
    """Compile a graphical model into a log-probability density function.

    The generated function is named ``<graph.name>_logpdf``. Its parameters
    are, in order: the random variables, the model's arguments without a
    default value, then the model's arguments with a default value (which
    keep their defaults). Its body initialises ``logpdf = 0`` and then
    contains one statement per non-argument node, in topological order, as
    produced by that node's ``to_logpdf()``. It ends with ``return logpdf``.

    Arguments
    ---------
    graph:
        A probabilistic graphical model.
    namespace:
        The names contained in the model definition's global scope. The
        generated code is executed in this dict, so the compiled function is
        also stored in it under its name.

    Returns
    -------
    An ``Artifact`` holding, in order: the compiled function, the names of all
    its parameters in order, its name, and its source code (useful for
    inspection by the user).
    """
    # NOTE(review): `logpdf` is only ever seeded with 0 here; whether it is
    # accumulated depends on what each node's `to_logpdf()` emits. The
    # loglikelihood compiler in this module reads per-variable `logpdf_<name>`
    # values produced by the same method, in which case this function would
    # always return 0. Also, `compile_to_logpdf` is defined more than once in
    # this module; the last definition wins. Confirm both.
    fn_name = graph.name + "_logpdf"

    contents = [data["content"] for _, data in graph.nodes(data=True)]

    # Parameters: arguments with defaults must come last in the signature.
    keyword_nodes = [
        item
        for item in contents
        if isinstance(item, Argument) and item.default_value is not None
    ]
    keyword_params = [item.to_logpdf() for item in keyword_nodes]
    positional_params = [
        item.to_logpdf()
        for item in contents
        if isinstance(item, Argument) and item.default_value is None
    ]
    rv_params = [
        ast.arg(arg=item.name, annotation=None)
        for item in contents
        if isinstance(item, RandVar)
    ]
    parameters = rv_params + positional_params + keyword_params
    default_exprs = [item.default_value for item in keyword_nodes]

    # Body: seed the accumulator, emit one statement per node, return it.
    statements = [
        ast.Assign(
            targets=[ast.Name(id="logpdf", ctx=ast.Store())],
            value=ast.Constant(value=0),
        )
    ]
    sorted_contents = [
        graph.nodes[key]["content"]
        for key in nx.topological_sort(graph)
        if not isinstance(graph.nodes[key]["content"], Argument)
    ]
    statements.extend([item.to_logpdf() for item in sorted_contents])
    statements.append(ast.Return(value=ast.Name(id="logpdf", ctx=ast.Load())))

    signature = ast.arguments(
        args=parameters,
        vararg=None,
        kwarg=None,
        posonlyargs=[],
        kwonlyargs=[],
        defaults=default_exprs,
        kw_defaults=[],
    )
    function_def = ast.FunctionDef(
        name=fn_name, args=signature, body=statements, decorator_list=[]
    )
    module = ast.fix_missing_locations(ast.Module(body=[function_def], type_ignores=[]))

    exec(compile(module, filename="<ast>", mode="exec"), namespace)

    return Artifact(
        namespace[fn_name],
        [param.arg for param in parameters],
        fn_name,
        astor.code_gen.to_source(module),
    )