def logpdf(model):
    """Returns a function that computes the log-probability."""
    graph = copy.deepcopy(model.graph)
    graph = _logpdf_core(graph)

    # Create a new `logpdf` node that is the sum of the contributions of each
    # variable.
    def to_sum_of_logpdf(*args):
        def add(left, right):
            return cst.BinaryOperation(left, cst.Add(), right)

        args = list(args)
        if len(args) == 1:
            return args[0]
        elif len(args) == 2:
            left, right = args
            return add(left, right)

        right = args.pop()
        left = args.pop()
        expr = add(left, right)
        for arg in args:
            expr = add(expr, arg)

        return expr

    logpdf_contribs = [node for node in graph if isinstance(node, SampleOp)]
    sum_node = Op(to_sum_of_logpdf, graph.name, "logpdf", is_returned=True)
    graph.add(sum_node, *logpdf_contribs)

    return compile_graph(graph, model.namespace, f"{graph.name}_logpdf")
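# A sketch of the transformation `logpdf` performs, assuming a model with two
# random variables (the generated code below is illustrative; the exact output
# depends on `_logpdf_core` and `compile_graph`):
#
#     def model(x):
#         a <~ Normal(0, 1)
#         b <~ Normal(a, 1)
#         return b
#
# is compiled into something like:
#
#     def model_logpdf(a, b, x):
#         logpdf_model_a = Normal(0, 1).logpdf_sum(a)
#         logpdf_model_b = Normal(a, 1).logpdf_sum(b)
#         logpdf = logpdf_model_a + logpdf_model_b
#         return logpdf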
def compile_op(node: Op, graph: GraphicalModel):
    """Compile an Op by recursively compiling and including its upstream nodes."""
    op_args = {}
    op_kwargs = {}
    for predecessor in graph.predecessors(node):

        # I am not sure what this is about, consider deleting.
        if graph[predecessor][node] == {}:
            pass

        # If a predecessor has a name, it is either a random or a
        # deterministic variable; we only need to reference its name here.
        if predecessor.name is not None:
            pred_ast = cst.Name(value=predecessor.name)
        else:
            pred_ast = compile_op(predecessor, graph)

        # To rebuild the node's CST we need to pass the compiled CSTs of its
        # predecessors as arguments to the generator function.
        #
        # !! Important !!
        #
        # The compiler feeds the arguments in the order in which they were
        # added to the graph, not the order in which they were given to
        # `graph.add`. This is a potential source of mysterious bugs and
        # should be corrected. It also means that we do not feed repeated
        # arguments several times when we should.
        edge = graph[predecessor][node]
        if edge["type"] == "arg":
            for idx in edge["position"]:
                if idx in op_args:
                    raise IndexError("Duplicate argument position")
                op_args[idx] = pred_ast
        else:
            for key in edge["key"]:
                op_kwargs[key] = pred_ast

    args = [op_args[idx] for idx in sorted(op_args.keys())]
    return node.cst_generator(*args, **op_kwargs)
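# For reference, the edge metadata read above comes in two shapes, inferred
# from how edges are created elsewhere in this module:
#
#     {"type": "arg", "position": [0]}         # positional argument(s)
#     {"type": "kwargs", "key": ["rng_key"]}   # keyword argument(s)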
def recursive_visit(self, node) -> Union[Constant, Op]:
    """Recursively visit the node and populate the graph with the traversed
    nodes.

    The recursion ends when the CST node being visited is a `Name` or a
    `BaseNumber` node.

    While we strictly follow libcst's CST decomposition, it may be desirable
    to simplify the graph for our purposes. For instance:

    - Slices and subscripts. There is no reason to detail the succession of
      nodes in the graph: a subscript is either a constant (it only depends
      on numerical constants) or a function of other variables.
    - Functions like `np.dot`. `np` and `dot` are currently stored in
      different Ops. We should merge these.

    TODO: Implement a function that takes a GraphicalModel and applies these
    simplifications. This will be necessary when sampling deterministic
    functions.

    Note
    ----
    This function could be rewritten using functools' `singledispatchmethod`,
    but it is not available in Python 3.7. While there is a library that
    backports it, I preferred to avoid adding a dependency.

    """
    if isinstance(node, cst.Name):
        # If the node corresponds to a placeholder or a named op, its name
        # has been registered in `named_variables`. Otherwise it corresponds
        # to the name of an attribute.
        try:
            name = self.named_variables[node.value]
        except KeyError:
            name = Name(lambda: node, node.value)
        return name

    if isinstance(node, cst.BaseNumber):
        new_node = Constant(lambda: node)
        return new_node

    # Parse function calls
    if isinstance(node, cst.Call):
        func = self.recursive_visit(node.func)
        args = [self.recursive_visit(arg) for arg in node.args]

        def to_call_cst(*args, **kwargs):
            # I don't exactly remember why we pass `func` as a keyword
            # argument, but I think it has something to do with the fact that
            # at compilation the arguments are passed in the order they were
            # introduced in the graph, and nodes are deleted and re-inserted
            # when transforming the graph to get logpdfs and samplers.
            func = kwargs["__name__"]
            return cst.Call(func, args)

        op = Op(to_call_cst, self.scope)
        self.graph.add(op, *args, __name__=func)
        return op

    if isinstance(node, cst.Arg):
        value = self.recursive_visit(node.value)

        def to_arg_cst(value):
            return cst.Arg(value, node.keyword)

        op = Op(to_arg_cst, self.scope)
        self.graph.add(op, value)
        return op

    if isinstance(node, cst.Attribute):
        value = self.recursive_visit(node.value)
        attr = self.recursive_visit(node.attr)

        def to_attribute_cst(value, attr):
            return cst.Attribute(value, attr)

        op = Op(to_attribute_cst, self.scope)
        self.graph.add(op, value, attr)
        return op

    # Parse lists and tuples
    if isinstance(node, cst.List):
        elements = [self.recursive_visit(e) for e in node.elements]

        def to_list_cst(*list_elements):
            return cst.List(list_elements)

        op = Op(to_list_cst, self.scope)
        self.graph.add(op, *elements)
        return op

    if isinstance(node, cst.Element):
        value = self.recursive_visit(node.value)

        def to_element_cst(value):
            return cst.Element(value)

        op = Op(to_element_cst, self.scope)
        self.graph.add(op, value)
        return op

    # Parse slices and subscripts
    if isinstance(node, cst.Subscript):
        value = self.recursive_visit(node.value)
        slice_elements = [self.recursive_visit(s) for s in node.slice]

        def to_subscript_cst(value, *slice_elements):
            return cst.Subscript(value, slice_elements)

        op = Op(to_subscript_cst, self.scope)
        self.graph.add(op, value, *slice_elements)
        return op

    if isinstance(node, cst.SubscriptElement):
        sl = self.recursive_visit(node.slice)

        def to_subscript_element_cst(sl):
            return cst.SubscriptElement(sl)

        op = Op(to_subscript_element_cst, self.scope)
        self.graph.add(op, sl)
        return op

    if isinstance(node, cst.Index):
        value = self.recursive_visit(node.value)

        def to_index_cst(value):
            return cst.Index(value)

        op = Op(to_index_cst, self.scope)
        self.graph.add(op, value)
        return op

    # Parse binary and unary operations
    if isinstance(node, cst.BinaryOperation):
        left = self.recursive_visit(node.left)
        right = self.recursive_visit(node.right)

        def to_binary_operation_cst(left, right):
            return cst.BinaryOperation(left, node.operator, right=right)

        op = Op(to_binary_operation_cst, self.scope)
        self.graph.add(op, left, right)
        return op

    if isinstance(node, cst.UnaryOperation):
        expression = self.recursive_visit(node.expression)

        def to_unary_operation_cst(expression):
            return cst.UnaryOperation(node.operator, expression)

        op = Op(to_unary_operation_cst, self.scope)
        self.graph.add(op, expression)
        return op

    # In case we missed an important statement or expression, leave a friendly
    # error message and redirect the user to the issue tracker to let us know.
    raise TypeError(
        f"The CST node {node.__class__.__name__} is currently not supported by MCX's parser. "
        "Please open an issue on https://github.com/rlouf/mcx so we can integrate it."
    )
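# As a rough illustration of the traversal above, visiting the expression
# `np.dot(X, coef)` yields a chain of Ops (names are illustrative):
#
#     Name('np'), Name('dot')  ->  Op(to_attribute_cst)   # np.dot
#     Op(to_arg_cst) for X and for coef
#     all of the above         ->  Op(to_call_cst)        # np.dot(X, coef)
#
# The docstring's TODO about merging `np` and `dot` refers to collapsing such
# chains into a single node.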
def logpdf_contributions(model):
    """Return the variables' individual contributions to the logpdf.

    The function returns a dictionary `{'var_name': logpdf_contribution}`.
    When there are several scopes it returns a nested dictionary
    `{'scope': {'var_name': logpdf_contribution}}` to avoid name conflicts.

    We cheat a little here: the function that returns the CST takes the
    contrib nodes as arguments, but these are not used; the content of the
    function is fully determined before the node is added to the graph. We do
    not have a choice, because it is currently impossible to pass context
    (variable name and scope name) at compilation.

    """
    graph = copy.deepcopy(model.graph)
    graph = _logpdf_core(graph)

    # Add a new node, a dictionary that contains the contribution of each
    # variable to the log-probability.
    logpdf_contribs = [node for node in graph if isinstance(node, SampleOp)]

    scopes = set()
    scope_map = defaultdict(dict)
    for contrib in logpdf_contribs:
        var_name = contrib.name.replace(f"logpdf_{contrib.scope}_", "")
        scope_map[contrib.scope][var_name] = contrib.name
        scopes.add(contrib.scope)

    def to_dictionary_of_contributions(*_):
        # If there is only one scope we return a flat dictionary
        # {'var': logpdf_var}.
        num_scopes = len(scopes)
        if num_scopes == 1:
            scope = scopes.pop()
            return cst.Dict(
                [
                    cst.DictElement(
                        cst.SimpleString(f"'{var_name}'"), cst.Name(contrib_name)
                    )
                    for var_name, contrib_name in scope_map[scope].items()
                ]
            )

        # Otherwise we return a nested dictionary where the first level is the
        # scope, and the second the variables: {'model': {}, 'submodel': {}}.
        return cst.Dict(
            [
                cst.DictElement(
                    cst.SimpleString(f"'{scope}'"),
                    cst.Dict(
                        [
                            cst.DictElement(
                                cst.SimpleString(f"'{var_name}'"),
                                cst.Name(contrib_name),
                            )
                            for var_name, contrib_name in scope_map[scope].items()
                        ]
                    ),
                )
                for scope in scopes
            ]
        )

    dict_node = Op(
        to_dictionary_of_contributions,
        graph.name,
        "logpdf_contributions",
        is_returned=True,
    )
    graph.add(dict_node, *logpdf_contribs)

    return compile_graph(graph, model.namespace, f"{graph.name}_logpdf_contribs")
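# A sketch of the dictionary the generated function returns (variable and
# scope names are illustrative):
#
#     {'a': logpdf_model_a, 'b': logpdf_model_b}     # single scope
#     {'model': {'a': ...}, 'submodel': {'a': ...}}  # several scopes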
def sample_posterior_predictive(model, node_names):
    """Sample from the posterior predictive distribution.

    Example
    -------
    We transform MCX models of the form:

    >>> def linear_regression(X, lmbda=1.):
    ...     scale <~ Exponential(lmbda)
    ...     coef <~ Normal(jnp.zeros(X.shape[0]), 1)
    ...     y = jnp.dot(X, coef)
    ...     pred <~ Normal(y, scale)
    ...     return pred

    into:

    >>> def linear_regression_pred(rng_key, scale, coef, X, lmbda=1.):
    ...     idx = jax.random.choice(rng_key, scale.shape[0])
    ...     scale_sample = scale[idx]
    ...     coef_sample = coef[idx]
    ...     y = jnp.dot(X, coef_sample)
    ...     pred = Normal(y, scale_sample).sample(rng_key)
    ...     return pred

    """
    graph = copy.deepcopy(model.graph)
    nodes = [graph.find_node(name) for name in node_names]

    # We will need to pass a RNG key to the function to sample from the
    # distributions; we create a placeholder for this key.
    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")
    graph.add_node(rng_node)

    # To take a sample from the posterior distribution we first choose a
    # sample id at random, `idx = mcx.jax.choice(rng_key, num_samples)`. We
    # later index each array of samples passed by this `idx`.
    def choice_ast(rng_key):
        return cst.Call(
            func=cst.Attribute(
                value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
                attr=cst.Name("choice"),
            ),
            args=[
                cst.Arg(rng_key),
                cst.Arg(
                    cst.Subscript(
                        cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape")),
                        [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
                    )
                ),
            ],
        )

    choice_node = Op(choice_ast, graph.name, "idx")
    graph.add(choice_node, rng_node)

    # Remove all edges incoming to the nodes that are targeted by the
    # intervention.
    to_remove = []
    for e in graph.in_edges(nodes):
        to_remove.append(e)
    for edge in to_remove:
        graph.remove_edge(*edge)

    # Each SampleOp that is intervened on is replaced by a placeholder that is
    # indexed by the id of the sample being taken.
    for node in reversed(nodes):
        rv_name = node.name

        # Add the placeholder
        placeholder = Placeholder(
            partial(lambda name: cst.Param(cst.Name(name)), rv_name),
            rv_name,
            is_random_variable=True,
        )
        graph.add_node(placeholder)

        def sample_index(placeholder, idx):
            return cst.Subscript(placeholder, [cst.SubscriptElement(cst.Index(idx))])

        chosen_sample = Op(sample_index, graph.name, rv_name + "_sample")
        graph.add(chosen_sample, placeholder, choice_node)

        original_edges = []
        for e in graph.out_edges(node):
            data = graph.get_edge_data(*e)
            original_edges.append(e)
            graph.add_edge(chosen_sample, e[1], **data)
        for e in original_edges:
            graph.remove_edge(*e)

        graph.remove_node(node)

    # Recursively remove every node that has no outgoing edge and is not
    # returned.
    graph = remove_dangling_nodes(graph)

    # Replace the remaining SampleOps by sampling instructions.
    def to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        return cst.Call(
            func=cst.Attribute(
                value=cst_generator(*args, **kwargs), attr=cst.Name("sample")
            ),
            args=[cst.Arg(value=rng_key)],
        )

    random_variables = []
    for node in reversed(list(graph.nodes())):
        if not isinstance(node, SampleOp):
            continue
        node.cst_generator = partial(to_sampler, node.cst_generator)
        random_variables.append(node)

    # Link the `rng_key` placeholder to the sampling expressions.
    for var in random_variables:
        graph.add_edge(rng_node, var, type="kwargs", key=["rng_key"])

    return compile_graph(
        graph, model.namespace, f"{graph.name}_sample_posterior_predictive"
    )
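# A hypothetical usage sketch (assuming the compiled artifact is callable, and
# that `scale` and `coef` hold arrays of posterior samples):
#
#     pred_fn = sample_posterior_predictive(model, ("scale", "coef"))
#     # pred_fn(rng_key, scale, coef, X) -> one posterior predictive sample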
def sample_joint(model):
    """Obtain forward samples from the joint distribution defined by the model."""
    graph = copy.deepcopy(model.graph)
    namespace = model.namespace

    def to_dictionary_of_samples(random_variables, *_):
        scopes = [rv.scope for rv in random_variables]
        names = [rv.name for rv in random_variables]
        scoped = defaultdict(dict)
        for scope, var_name, var in zip(scopes, names, random_variables):
            scoped[scope][var_name] = var

        # If there is only one scope (99% of models) we return a flat
        # dictionary.
        if len(set(scopes)) == 1:
            scope = scopes[0]
            return cst.Dict(
                [
                    cst.DictElement(
                        cst.SimpleString(f"'{var_name}'"),
                        cst.Name(var.name),
                    )
                    for var_name, var in scoped[scope].items()
                ]
            )

        # Otherwise we return a nested dictionary where the first level is the
        # scope, and the second the variables.
        return cst.Dict(
            [
                cst.DictElement(
                    cst.SimpleString(f"'{scope}'"),
                    cst.Dict(
                        [
                            cst.DictElement(
                                cst.SimpleString(f"'{var_name}'"),
                                cst.Name(var.name),
                            )
                            for var_name, var in scoped[scope].items()
                        ]
                    ),
                )
                for scope in scoped.keys()
            ]
        )

    # No node is returned anymore.
    for node in graph.nodes():
        if isinstance(node, Op):
            node.is_returned = False

    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")

    # Update the SampleOps to return a sample from the distribution so that
    # `a <~ Normal(0, 1)` becomes `a = Normal(0, 1).sample(rng_key)`.
    def distribution_to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        return cst.Call(
            func=cst.Attribute(cst_generator(*args, **kwargs), cst.Name("sample")),
            args=[cst.Arg(value=rng_key)],
        )

    def model_to_sampler(model_name, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        return cst.Call(
            func=cst.Attribute(cst.Name(value=model_name), cst.Name("sample")),
            args=[cst.Arg(value=rng_key)] + list(args),
        )

    random_variables = []
    for node in reversed(list(graph.random_variables)):
        if isinstance(node, SampleModelOp):
            node.cst_generator = partial(model_to_sampler, node.model_name)
        else:
            node.cst_generator = partial(distribution_to_sampler, node.cst_generator)
        random_variables.append(node)

    # Link the `rng_key` placeholder to the sampling expressions.
    graph.add(rng_node)
    for var in random_variables:
        graph.add_edge(rng_node, var, type="kwargs", key=["rng_key"])

    for node in graph.random_variables:
        if not isinstance(node, SampleModelOp):
            continue

        rv_name = node.name
        returned_var_name = node.graph.returned_variables[0].name

        def sample_index(rv, returned_var, *_):
            return cst.Subscript(
                cst.Name(rv),
                [cst.SubscriptElement(cst.SimpleString(f"'{returned_var}'"))],
            )

        chosen_sample = Op(
            partial(sample_index, rv_name, returned_var_name),
            graph.name,
            rv_name + "_value",
        )

        original_edges = []
        data = []
        out_nodes = []
        for e in graph.out_edges(node):
            datum = graph.get_edge_data(*e)
            data.append(datum)
            original_edges.append(e)
            out_nodes.append(e[1])
        for e in original_edges:
            graph.remove_edge(*e)

        graph.add(chosen_sample, node)
        for e, d in zip(out_nodes, data):
            graph.add_edge(chosen_sample, e, **d)

    tuple_node = Op(
        partial(to_dictionary_of_samples, graph.random_variables),
        graph.name,
        "forward_samples",
        is_returned=True,
    )
    graph.add(tuple_node, *graph.random_variables)

    return compile_graph(graph, namespace, f"{graph.name}_sample_forward")
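# A sketch of what `sample_joint` generates for a model containing
# `a <~ Normal(0, 1)` (illustrative; the exact output depends on
# `compile_graph`):
#
#     def model_sample_forward(rng_key):
#         a = Normal(0, 1).sample(rng_key)
#         forward_samples = {'a': a}
#         return forward_samples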
def _logpdf_core(graph: GraphicalModel):
    """Transform the SampleOps into statements that compute the logpdf
    associated with the variables' values.
    """
    placeholders = []
    logpdf_nodes = []

    def sampleop_to_logpdf(cst_generator, *args, **kwargs):
        name = kwargs.pop("var_name")
        return cst.Call(
            cst.Attribute(cst_generator(*args, **kwargs), cst.Name("logpdf_sum")),
            [cst.Arg(name)],
        )

    def samplemodelop_to_logpdf(model_name, *args, **kwargs):
        name = kwargs.pop("var_name")
        return cst.Call(
            cst.Attribute(cst.Name(model_name), cst.Name("logpdf")),
            list(args) + [cst.Arg(name, star="**")],
        )

    def placeholder_to_param(name: str):
        return cst.Param(cst.Name(name))

    for node in graph.random_variables:
        if not isinstance(node, SampleModelOp):
            continue

        rv_name = node.name
        returned_var_name = node.graph.returned_variables[0].name

        def sample_index(rv, returned_var, *_):
            return cst.Subscript(
                cst.Name(rv),
                [cst.SubscriptElement(cst.SimpleString(f"'{returned_var}'"))],
            )

        chosen_sample = Op(
            partial(sample_index, rv_name, returned_var_name),
            graph.name,
            f"{rv_name}_value",
        )

        original_edges = []
        data = []
        out_nodes = []
        for e in graph.out_edges(node):
            datum = graph.get_edge_data(*e)
            data.append(datum)
            original_edges.append(e)
            out_nodes.append(e[1])
        for e in original_edges:
            graph.remove_edge(*e)

        graph.add(chosen_sample, node)
        for e, d in zip(out_nodes, data):
            graph.add_edge(chosen_sample, e, **d)

    # We need to loop over the nodes in reverse order because of the
    # compilation quirk that makes nodes added to the graph first appear first
    # in the function's arguments. This should be taken care of properly
    # before merging.
    for node in reversed(list(graph.random_variables)):
        # Create a new placeholder node with the random variable's name. It
        # represents the value that will be passed to the logpdf.
        name = node.name
        rv_placeholder = Placeholder(
            partial(placeholder_to_param, name), name, is_random_variable=True
        )
        placeholders.append(rv_placeholder)

        # Transform the SampleOps from `a <~ Normal(0, 1)` into
        # `logpdf_a = Normal(0, 1).logpdf_sum(a)`.
        if isinstance(node, SampleModelOp):
            node.cst_generator = partial(samplemodelop_to_logpdf, node.model_name)
        else:
            node.cst_generator = partial(sampleop_to_logpdf, node.cst_generator)
        node.name = f"logpdf_{node.scope}_{node.name}"
        logpdf_nodes.append(node)

    for placeholder, node in zip(placeholders, logpdf_nodes):
        # Add the placeholder to the graph and link it to the expression that
        # computes the logpdf. So far the expression looks like:
        #
        # >>> logpdf_a = Normal(0, 1).logpdf_sum(_)
        #
        # `a` is the placeholder and will appear in the function's arguments.
        # Below we assign it to `_`.
        graph.add_node(placeholder)
        graph.add_edge(placeholder, node, type="kwargs", key=["var_name"])

        # Remove the edges from the former SampleOp and replace them with
        # edges from the new placeholder. For instance, assume that part of
        # our model is:
        #
        # >>> a <~ Normal(0, 1)
        # >>> x = jnp.log(a)
        #
        # Transformed to a logpdf this looks like:
        #
        # >>> logpdf_a = Normal(0, 1).logpdf_sum(a)
        # >>> x = jnp.log(a)
        #
        # where `a` is now a placeholder, passed as an argument. The following
        # code links this placeholder to the expression `jnp.log(a)` and
        # removes the edge from `a <~ Normal(0, 1)`.
        #
        # We cannot remove edges while iterating over the graph, hence the
        # two-step process.
        successors = list(graph.successors(node))
        for s in successors:
            edge_data = graph.get_edge_data(node, s)
            graph.add_edge(placeholder, s, **edge_data)
        for s in successors:
            graph.remove_edge(node, s)

    # The original MCX model may return one or many variables. None of these
    # should be returned by the logpdf, so we set their `is_returned` flag to
    # `False`.
    for node in graph.nodes():
        if isinstance(node, Op):
            node.is_returned = False

    return graph
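# For reference, the two rewrites above generate logpdf calls of the following
# shapes (illustrative):
#
#     logpdf_model_a = Normal(0, 1).logpdf_sum(a)   # SampleOp
#     logpdf_model_m = submodel.logpdf(x, **m)      # SampleModelOp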