예제 #1
def logpdf(model):
    """Build and compile the model's joint log-probability function.

    The model graph is deep-copied so the model itself is left untouched. Its
    sample statements are rewritten into log-probability terms by
    `_logpdf_core`, and a final, returned node adds all of those terms
    together.

    Args:
        model: an MCX model exposing `graph` and `namespace`.

    Returns:
        Whatever `compile_graph` produces for the `<graph name>_logpdf`
        function.
    """
    transformed = _logpdf_core(copy.deepcopy(model.graph))

    def to_sum_of_logpdf(*terms):
        """Return the CST of the sum of `terms`.

        A single term is returned unchanged. Otherwise the last two terms are
        added first and the remaining ones are appended in their original
        order: `((t[-2] + t[-1]) + t[0]) + t[1] + ...`. Calling it with no
        term raises IndexError.
        """

        def plus(lhs, rhs):
            return cst.BinaryOperation(lhs, cst.Add(), rhs)

        if len(terms) == 1:
            return terms[0]

        total = plus(terms[-2], terms[-1])
        for term in terms[:-2]:
            total = plus(total, term)
        return total

    # One contribution per SampleOp left in the transformed graph.
    contributions = [n for n in transformed if isinstance(n, SampleOp)]
    sum_node = Op(to_sum_of_logpdf, transformed.name, "logpdf", is_returned=True)
    transformed.add(sum_node, *contributions)

    return compile_graph(transformed, model.namespace, f"{transformed.name}_logpdf")
예제 #2
def compile_op(node: Op, graph: GraphicalModel):
    """Compile an Op by recursively compiling and including its upstream nodes.

    A predecessor that has a name (a random or deterministic variable) is
    referenced by that name. An anonymous predecessor is compiled recursively
    and its CST is inlined.

    Args:
        node: the Op whose CST is being rebuilt.
        graph: the graphical model that contains `node`.

    Returns:
        The CST returned by `node.cst_generator`. It is called with the
        compiled positional arguments, ordered by their edge `position`, and
        with the compiled keyword arguments.

    Raises:
        IndexError: if two incoming edges claim the same argument position.
    """
    op_args = {}
    op_kwargs = {}
    for predecessor in graph.predecessors(node):
        edge = graph[predecessor][node]

        # Edges without attributes carry no argument information: skip them.
        # NOTE(review): the original author was unsure where such edges come
        # from and suggested deleting this check; confirm before removing.
        if not edge:
            continue

        # If a predecessor has a name, it is either a random or a
        # deterministic variable and we only need to reference its name.
        if predecessor.name is not None:
            pred_ast = cst.Name(value=predecessor.name)
        else:
            pred_ast = compile_op(predecessor, graph)

        # Positional arguments are placed according to the `position` list
        # stored on the edge. A predecessor used at several positions is
        # therefore fed at each of them.
        # NOTE(review): an older comment warned that arguments follow graph
        # insertion order. That no longer holds for positional arguments, which
        # are sorted below; confirm against `GraphicalModel.add`.
        if edge["type"] == "arg":
            for idx in edge["position"]:
                if idx in op_args:
                    raise IndexError("Duplicate argument position")
                op_args[idx] = pred_ast
        else:
            for key in edge["key"]:
                op_kwargs[key] = pred_ast

    args = [op_args[idx] for idx in sorted(op_args)]

    return node.cst_generator(*args, **op_kwargs)
예제 #3
    def recursive_visit(self, node) -> Union[Constant, Op]:
        """Recursively visit the node and populate the graph with the traversed nodes.

        The recursion ends when the CST node being visited is a `Name` or a
        `BaseNumber` node. While we strictly follow libcst's CST
        decomposition, it may be desirable to simplify the graph for our
        purposes. For instance:

        - Slices and subscripts. There is no reason to detail the succession of
          nodes in the graph. It can either be a constant (only depends on
          numerical constants), or a function of other variables.
        - Functions like `np.dot`. `np` and `dot` are currently stored in
          different Ops. We should merge these.

        TODO: Implement a function that takes a GraphicalModel and applies
        these simplifications. This will be necessary when sampling
        deterministic functions.

        This function could be rewritten using functools'
        `singledispatchmethod`, but that is not available in Python 3.7. A
        backport library exists, but I preferred to avoid adding a dependency.

        Args:
            node: the libcst node to visit.

        Returns:
            The graph node that represents `node`. This is a `Name` for names,
            a `Constant` for numbers and an `Op` for every other supported node.

        Raises:
            TypeError: if the CST node type is not supported by the parser.
        """
        if isinstance(node, cst.Name):
            # A placeholder or a named op has its name registered in
            # `named_variables`. Otherwise the name is that of an attribute.
            try:
                name = self.named_variables[node.value]
            except KeyError:
                name = Name(lambda: node, node.value)
            return name

        if isinstance(node, cst.BaseNumber):
            new_node = Constant(lambda: node)
            return new_node

        # Parse function calls

        if isinstance(node, cst.Call):
            func = self.recursive_visit(node.func)
            args = [self.recursive_visit(arg) for arg in node.args]

            def to_call_cst(*call_args, **kwargs):
                # The called function is passed as the `__name__` keyword
                # argument rather than positionally.
                # NOTE(review): the original author believed this protects it
                # from the compiler's argument ordering, since nodes are
                # deleted and re-inserted by the logpdf and sampler
                # transformations. Confirm before changing.
                func = kwargs["__name__"]
                return cst.Call(func, call_args)

            op = Op(to_call_cst, self.scope)
            self.graph.add(op, *args, __name__=func)
            return op

        if isinstance(node, cst.Arg):
            value = self.recursive_visit(node.value)

            def to_arg_cst(value):
                return cst.Arg(value, node.keyword)

            op = Op(to_arg_cst, self.scope)
            self.graph.add(op, value)
            return op

        if isinstance(node, cst.Attribute):
            value = self.recursive_visit(node.value)
            attr = self.recursive_visit(node.attr)

            def to_attribute_cst(value, attr):
                return cst.Attribute(value, attr)

            op = Op(to_attribute_cst, self.scope)
            self.graph.add(op, value, attr)
            return op

        # Parse lists and tuples

        if isinstance(node, cst.List):
            elements = [self.recursive_visit(e) for e in node.elements]

            def to_list_cst(*list_elements):
                return cst.List(list_elements)

            op = Op(to_list_cst, self.scope)
            self.graph.add(op, *elements)

            return op

        if isinstance(node, cst.Element):
            value = self.recursive_visit(node.value)

            def to_element_cst(value):
                return cst.Element(value)

            op = Op(to_element_cst, self.scope)
            self.graph.add(op, value)

            return op

        # Parse slices and subscripts

        if isinstance(node, cst.Subscript):
            value = self.recursive_visit(node.value)
            slice_elements = [self.recursive_visit(s) for s in node.slice]

            def to_subscript_cst(value, *slice_elements):
                return cst.Subscript(value, slice_elements)

            op = Op(to_subscript_cst, self.scope)
            self.graph.add(op, value, *slice_elements)

            return op

        if isinstance(node, cst.SubscriptElement):
            sl = self.recursive_visit(node.slice)

            def to_subscript_element_cst(sl):
                return cst.SubscriptElement(sl)

            op = Op(to_subscript_element_cst, self.scope)
            self.graph.add(op, sl)

            return op

        if isinstance(node, cst.Index):
            value = self.recursive_visit(node.value)

            def to_index_cst(value):
                return cst.Index(value)

            op = Op(to_index_cst, self.scope)
            self.graph.add(op, value)

            return op

        # Parse Binary and Unary operations

        if isinstance(node, cst.BinaryOperation):
            left = self.recursive_visit(node.left)
            right = self.recursive_visit(node.right)

            def to_binary_operation_cst(left, right):
                return cst.BinaryOperation(left, node.operator, right=right)

            op = Op(to_binary_operation_cst, self.scope)
            self.graph.add(op, left, right)

            return op

        if isinstance(node, cst.UnaryOperation):
            expression = self.recursive_visit(node.expression)

            def to_unary_operation_cst(expression):
                return cst.UnaryOperation(node.operator, expression)

            op = Op(to_unary_operation_cst, self.scope)
            self.graph.add(op, expression)

            return op

        # In case we missed an important statement or expression, leave a
        # friendly error message and point the user to the issue tracker.
        raise TypeError(
            f"The CST node {node.__class__.__name__} is currently not supported by MCX's parser. "
            "Please open an issue on https://github.com/rlouf/mcx so we can integrate it."
        )
예제 #4
def logpdf_contributions(model):
    """Return the variables' individual contributions to the logpdf.

    The compiled function returns a dictionary {'var_name': logpdf_contribution}.
    When there are several scopes it returns a nested dictionary
    {'scope': {'var_name': logpdf_contribution}} to avoid name conflicts.

    We cheat a little here. The function that returns the CST takes the
    contribution nodes as arguments, but does not use them: its output is
    fully determined before the node is added to the graph. We have no
    choice, because it is currently impossible to pass context (variable name
    and scope name) at compilation.

    Args:
        model: the MCX model; its graph is copied, not modified.

    Returns:
        The result of `compile_graph` for the `<graph name>_logpdf_contribs`
        function.
    """
    graph = copy.deepcopy(model.graph)
    graph = _logpdf_core(graph)

    # Add a new node: a dictionary that contains the contribution of each
    # variable to the log-probability.
    logpdf_contribs = [node for node in graph if isinstance(node, SampleOp)]

    # scope -> {variable name: name of the node that holds its logpdf}.
    # `_logpdf_core` renames each node to `logpdf_{scope}_{name}`, so that
    # prefix is stripped to recover the variable name. A dict (unlike a set)
    # iterates in insertion order, so the generated code is deterministic.
    scope_map = defaultdict(dict)
    for contrib in logpdf_contribs:
        var_name = contrib.name.replace(f"logpdf_{contrib.scope}_", "")
        scope_map[contrib.scope][var_name] = contrib.name

    def dict_elements(variables):
        # One `'var_name': logpdf_node_name` entry per variable.
        return [
            cst.DictElement(cst.SimpleString(f"'{var_name}'"), cst.Name(contrib_name))
            for var_name, contrib_name in variables.items()
        ]

    def to_dictionary_of_contributions(*_):
        # With a single scope we return a flat dictionary {'var': logpdf_var}.
        # `scope_map` is only read here, never mutated, so repeated calls
        # produce the same CST.
        if len(scope_map) == 1:
            (variables,) = scope_map.values()
            return cst.Dict(dict_elements(variables))

        # Otherwise we return a nested dictionary where the first level is
        # the scope, then the variables: {'model': {...}, 'submodel': {...}}.
        return cst.Dict(
            [
                cst.DictElement(
                    cst.SimpleString(f"'{scope}'"), cst.Dict(dict_elements(variables))
                )
                for scope, variables in scope_map.items()
            ]
        )

    # `is_returned=True` makes the dictionary the compiled function's return
    # value, as with the sum node built by `logpdf`.
    dict_node = Op(
        to_dictionary_of_contributions,
        graph.name,
        "logpdf_contributions",
        is_returned=True,
    )
    graph.add(dict_node, *logpdf_contribs)

    return compile_graph(graph, model.namespace, f"{graph.name}_logpdf_contribs")
예제 #5
def sample_posterior_predictive(model, node_names):
    """Sample from the posterior predictive distribution.

    We transform MCX models of the form:

        >>> def linear_regression(X, lmbda=1.):
        ...    scale <~ Exponential(lmbda)
        ...    coef <~ Normal(jnp.zeros(X.shape[0]), 1)
        ...    y = jnp.dot(X, coef)
        ...    pred <~ Normal(y, scale)
        ...    return pred

    into:

        >>> def linear_regression_pred(rng_key, scale, coef, X, lmbda=1.):
        ...    idx = jax.random.choice(rng_key, scale.shape[0])
        ...    scale_sample = scale[idx]
        ...    coef_sample = coef[idx]
        ...    y = jnp.dot(X, coef_sample)
        ...    pred = Normal(y, scale_sample).sample(rng_key)
        ...    return pred

    Args:
        model: the MCX model; its graph is copied, not modified.
        node_names: names of the random variables whose posterior samples are
            passed to the compiled function. One sample index is drawn from
            the first axis of the first variable and used to index all of them.

    Returns:
        The result of `compile_graph` for the
        `<graph name>_sample_posterior_predictive` function.
    """
    graph = copy.deepcopy(model.graph)
    nodes = [graph.find_node(name) for name in node_names]

    # We will need to pass an RNG key to the function to sample from
    # the distributions; we create a placeholder for this key.
    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")

    # To take a sample from the posterior distribution we first choose a sample
    # id at random, `idx = mcx.jax.choice(rng_key, num_samples)`. We later index
    # each array of samples passed to the function by this `idx`.
    def choice_ast(rng_key):
        # `num_samples` is `<first variable>.shape[0]`.
        num_samples = cst.Subscript(
            cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape")),
            [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
        )
        return cst.Call(
            func=cst.Attribute(
                value=cst.Attribute(cst.Name("mcx"), cst.Name("jax")),
                attr=cst.Name("choice"),
            ),
            args=[cst.Arg(rng_key), cst.Arg(num_samples)],
        )

    choice_node = Op(choice_ast, graph.name, "idx")
    graph.add(choice_node, rng_node)

    # Remove all edges incoming to the nodes that are targeted by the
    # intervention. The edges are collected first because they cannot be
    # removed while iterating over the graph.
    for edge in list(graph.in_edges(nodes)):
        graph.remove_edge(*edge)

    def sample_index(placeholder, idx):
        # `<placeholder>[idx]`
        return cst.Subscript(placeholder, [cst.SubscriptElement(cst.Index(idx))])

    # Each SampleOp that is intervened on is replaced by a placeholder that is
    # indexed by the index of the sample being taken.
    for node in reversed(nodes):
        rv_name = node.name

        # The placeholder becomes a parameter of the compiled function that
        # carries the variable's name.
        placeholder = Placeholder(
            partial(lambda name: cst.Param(cst.Name(name)), rv_name), rv_name
        )

        chosen_sample = Op(sample_index, graph.name, rv_name + "_sample")
        graph.add(chosen_sample, placeholder, choice_node)

        # Redirect the node's consumers to the indexed sample, then detach the
        # original node. The edges are collected first so that none is removed
        # while iterating.
        original_edges = list(graph.out_edges(node))
        for source, target in original_edges:
            data = graph.get_edge_data(source, target)
            graph.add_edge(chosen_sample, target, **data)

        for edge in original_edges:
            graph.remove_edge(*edge)

    # Recursively remove every node that has no outgoing edge and is not
    # returned.
    graph = remove_dangling_nodes(graph)

    # Replace SampleOps by sampling instructions, so that
    # `a <~ Normal(0, 1)` becomes `a = Normal(0, 1).sample(rng_key)`.
    def to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        return cst.Call(
            func=cst.Attribute(
                value=cst_generator(*args, **kwargs), attr=cst.Name("sample")
            ),
            args=[cst.Arg(value=rng_key)],
        )

    random_variables = []
    for node in reversed(list(graph.nodes())):
        if not isinstance(node, SampleOp):
            continue
        node.cst_generator = partial(to_sampler, node.cst_generator)
        random_variables.append(node)

    # Link the `rng_key` placeholder to every sampling expression.
    # NOTE(review): the key given to `choice` is also passed to every
    # `.sample` call. If these are JAX PRNG keys, the draws are not
    # independent; consider splitting the key.
    for var in random_variables:
        graph.add_edge(rng_node, var, type="kwargs", key=["rng_key"])

    return compile_graph(
        graph, model.namespace, f"{graph.name}_sample_posterior_predictive"
    )
예제 #6
def sample_joint(model):
    """Obtain forward samples from the joint distribution defined by the model.

    Every sample statement is turned into a `.sample(rng_key)` call. The
    compiled function returns a dictionary of the sampled values:
    `{'var_name': value}` when the model has a single scope, and
    `{'scope': {'var_name': value}}` otherwise.

    Args:
        model: the MCX model; its graph is copied, not modified.

    Returns:
        The result of `compile_graph` for the `<graph name>_sample_forward`
        function.
    """
    graph = copy.deepcopy(model.graph)
    namespace = model.namespace

    def to_dictionary_of_samples(random_variables, *_):
        # Group the variable names by scope. Dicts keep first-seen order, so
        # the generated code is deterministic.
        scoped = defaultdict(dict)
        for rv in random_variables:
            scoped[rv.scope][rv.name] = rv

        def dict_elements(variables):
            # One `'var_name': var_name` entry per variable.
            return [
                cst.DictElement(cst.SimpleString(f"'{var_name}'"), cst.Name(var_name))
                for var_name in variables
            ]

        # If there is only one scope (99% of models) we return a flat dictionary.
        if len(scoped) == 1:
            (variables,) = scoped.values()
            return cst.Dict(dict_elements(variables))

        # Otherwise we return a nested dictionary where the first level is
        # the scope, and then the variables.
        return cst.Dict(
            [
                cst.DictElement(
                    cst.SimpleString(f"'{scope}'"), cst.Dict(dict_elements(variables))
                )
                for scope, variables in scoped.items()
            ]
        )

    # No node of the original model is returned anymore; the dictionary of
    # samples built at the end is the only returned value.
    for node in graph.nodes():
        if isinstance(node, Op):
            node.is_returned = False

    rng_node = Placeholder(lambda: cst.Param(cst.Name(value="rng_key")), "rng_key")

    # Update the SampleOps to return a sample from the distribution, so that
    # `a <~ Normal(0, 1)` becomes `a = Normal(0, 1).sample(rng_key)`.
    def distribution_to_sampler(cst_generator, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        return cst.Call(
            func=cst.Attribute(cst_generator(*args, **kwargs), cst.Name("sample")),
            args=[cst.Arg(value=rng_key)],
        )

    # Submodels are sampled with `model_name.sample(rng_key, *args)`.
    def model_to_sampler(model_name, *args, **kwargs):
        rng_key = kwargs.pop("rng_key")
        return cst.Call(
            func=cst.Attribute(cst.Name(value=model_name), cst.Name("sample")),
            args=[cst.Arg(value=rng_key)] + list(args),
        )

    random_variables = []
    for node in reversed(list(graph.random_variables)):
        if isinstance(node, SampleModelOp):
            node.cst_generator = partial(model_to_sampler, node.model_name)
        else:
            node.cst_generator = partial(distribution_to_sampler, node.cst_generator)
        random_variables.append(node)

    # Link the `rng_key` placeholder to the sampling expressions.
    # NOTE(review): the same key is passed to every `.sample` call. If these
    # are JAX PRNG keys, the draws are not independent; consider splitting
    # the key.
    for var in random_variables:
        graph.add_edge(rng_node, var, type="kwargs", key=["rng_key"])

    def sample_index(rv, returned_var, *_):
        # `rv['returned_var']`
        # NOTE(review): assumes a sampled submodel yields a mapping keyed by
        # variable name; confirm against SampleModelOp.
        return cst.Subscript(
            cst.Name(rv),
            [cst.SubscriptElement(cst.Index(cst.SimpleString(f"'{returned_var}'")))],
        )

    # Consumers of a submodel read the value of its returned variable. The
    # list is a snapshot because the loop adds nodes to the graph.
    for node in list(graph.random_variables):
        if not isinstance(node, SampleModelOp):
            continue

        rv_name = node.name
        returned_var_name = node.graph.returned_variables[0].name

        chosen_sample = Op(
            partial(sample_index, rv_name, returned_var_name),
            graph.name,
            rv_name + "_value",
        )

        # Collect the outgoing edges before removing them: edges cannot be
        # removed while iterating over them.
        out_edges = [
            (target, graph.get_edge_data(source, target))
            for source, target in graph.out_edges(node)
        ]
        for target, _ in out_edges:
            graph.remove_edge(node, target)

        graph.add(chosen_sample, node)
        for target, edge_data in out_edges:
            graph.add_edge(chosen_sample, target, **edge_data)

    # `is_returned=True` makes the dictionary the compiled function's return
    # value.
    tuple_node = Op(
        partial(to_dictionary_of_samples, graph.random_variables),
        graph.name,
        "forward_samples",
        is_returned=True,
    )
    graph.add(tuple_node, *graph.random_variables)

    return compile_graph(graph, namespace, f"{graph.name}_sample_forward")
예제 #7
def _logpdf_core(graph: GraphicalModel):
    """Transform the SampleOps into statements that compute the logpdf
    associated with the variables' values.

    The graph is modified in place and returned. Each random variable `a`
    gets a placeholder `a` (a parameter of the compiled function), and its
    sample statement becomes `logpdf_{scope}_a = <dist>.logpdf_sum(a)`. Nodes
    that consumed the former sample now read the placeholder, and no node is
    marked as returned.

    Args:
        graph: the graphical model to transform.

    Returns:
        The same `graph` object, transformed.
    """
    placeholders = []
    logpdf_nodes = []

    def sampleop_to_logpdf(cst_generator, *args, **kwargs):
        # `<distribution>.logpdf_sum(<value>)`
        name = kwargs.pop("var_name")
        return cst.Call(
            cst.Attribute(cst_generator(*args, **kwargs), cst.Name("logpdf_sum")),
            [cst.Arg(name)],
        )

    def samplemodelop_to_logpdf(model_name, *args, **kwargs):
        # `model_name.logpdf(*args, **<value>)`
        name = kwargs.pop("var_name")
        return cst.Call(
            cst.Attribute(cst.Name(model_name), cst.Name("logpdf")),
            list(args) + [cst.Arg(name, star="**")],
        )

    def placeholder_to_param(name: str):
        return cst.Param(cst.Name(name))

    def sample_index(rv, returned_var, *_):
        # `rv['returned_var']`
        # NOTE(review): assumes a submodel's value is a mapping keyed by
        # variable name; confirm against SampleModelOp.
        return cst.Subscript(
            cst.Name(rv),
            [cst.SubscriptElement(cst.Index(cst.SimpleString(f"'{returned_var}'")))],
        )

    # Consumers of a submodel read the value of its returned variable. The
    # list is a snapshot because the loop adds nodes to the graph.
    for node in list(graph.random_variables):
        if not isinstance(node, SampleModelOp):
            continue

        rv_name = node.name
        returned_var_name = node.graph.returned_variables[0].name

        chosen_sample = Op(
            partial(sample_index, rv_name, returned_var_name),
            graph.name,
            rv_name + "_value",
        )

        # Collect the outgoing edges before removing them: edges cannot be
        # removed while iterating over them.
        out_edges = [
            (target, graph.get_edge_data(source, target))
            for source, target in graph.out_edges(node)
        ]
        for target, _ in out_edges:
            graph.remove_edge(node, target)

        graph.add(chosen_sample, node)
        for target, edge_data in out_edges:
            graph.add_edge(chosen_sample, target, **edge_data)

    # We need to loop through the nodes in reverse order because of a
    # compilation quirk: nodes added first to the graph appear first in the
    # function's arguments. This should be taken care of properly.
    for node in reversed(list(graph.random_variables)):

        # Create a new placeholder node with the random variable's name.
        # It represents the value that will be passed to the logpdf.
        name = node.name
        rv_placeholder = Placeholder(
            partial(placeholder_to_param, name), name, is_random_variable=True
        )

        # Transform the SampleOps from `a <~ Normal(0, 1)` into
        # `logpdf_a = Normal(0, 1).logpdf_sum(a)`.
        if isinstance(node, SampleModelOp):
            node.cst_generator = partial(samplemodelop_to_logpdf, node.model_name)
        else:
            node.cst_generator = partial(sampleop_to_logpdf, node.cst_generator)
        node.name = f"logpdf_{node.scope}_{node.name}"

        placeholders.append(rv_placeholder)
        logpdf_nodes.append(node)

    for placeholder, node in zip(placeholders, logpdf_nodes):
        # Add the placeholder to the graph and link it to the expression that
        # computes the logpdf. So far the expression looks like:
        #    >>> logpdf_a = Normal(0, 1).logpdf_sum(_)
        # `a` is the placeholder and will appear in the arguments of the
        # function. Below we assign it to `_`.
        graph.add_edge(placeholder, node, type="kwargs", key=["var_name"])

        # Remove edges from the former SampleOp and replace them with edges
        # from the new placeholder. For instance, assume that part of our
        # model is:
        #     >>> a <~ Normal(0, 1)
        #     >>> x = jnp.log(a)
        # Transformed to a logpdf this would look like:
        #     >>> logpdf_a = Normal(0, 1).logpdf_sum(a)
        #     >>> x = jnp.log(a)
        # where `a` is now a placeholder, passed as an argument. The following
        # code links this placeholder to the expression `jnp.log(a)` and
        # removes the edge from `a <~ Normal(0, 1)`. We cannot remove edges
        # while iterating over the graph, hence the two-step process.
        successors = list(graph.successors(node))
        for s in successors:
            edge_data = graph.get_edge_data(node, s)
            graph.add_edge(placeholder, s, **edge_data)

        for s in successors:
            graph.remove_edge(node, s)

    # The original MCX model may return one or many variables. None of
    # these variables should be returned, so we set the `is_returned` flag
    # to `False`.
    for node in graph.nodes():
        if isinstance(node, Op):
            node.is_returned = False

    return graph