Пример #1
0
    def check_integrity(self):
        """
        Verify that the cached graph bookkeeping matches the actual graph.

        Call this for a diagnosis if things go awry.

        The apply nodes and variables reachable between ``self.inputs`` and
        ``self.outputs`` are recomputed and compared against the cached
        ``self.apply_nodes``, ``self.variables`` and ``self.clients``.

        Raises
        ------
        Exception
            If the cached nodes or variables differ from the recomputed ones,
            if a clients list is inconsistent with the graph, if a client is
            not part of the graph, or if a variable without an owner is
            neither a declared input nor a `Constant`.

        """
        nodes = set(applys_between(self.inputs, self.outputs))
        if self.apply_nodes != nodes:
            missing = nodes.difference(self.apply_nodes)
            excess = self.apply_nodes.difference(nodes)
            raise Exception(
                "The nodes are inappropriately cached. "
                f"missing: {missing}, in excess: {excess}"
            )

        # Every (node, input index) pair must be registered as a client of
        # the variable feeding that input.
        for node in nodes:
            for i, variable in enumerate(node.inputs):
                clients = self.clients[variable]
                if (node, i) not in clients:
                    raise Exception(
                        f"Inconsistent clients list {(node, i)} in {clients}"
                    )

        variables = set(vars_between(self.inputs, self.outputs))
        # Normalise once so the comparison and both differences work on the
        # same set, whatever container type `self.variables` is.
        cached_variables = set(self.variables)
        if cached_variables != variables:
            missing = variables.difference(cached_variables)
            excess = cached_variables.difference(variables)
            raise Exception(
                "The variables are inappropriately cached. "
                f"missing: {missing}, in excess: {excess}"
            )

        for variable in variables:
            if (
                variable.owner is None
                and variable not in self.inputs
                and not isinstance(variable, Constant)
            ):
                raise Exception(f"Undeclared input: {variable}")
            for node, i in self.clients[variable]:
                # The string "output" marks a graph output; `i` then indexes
                # into `self.outputs` instead of a node's inputs.
                if node == "output":
                    if self.outputs[i] is not variable:
                        raise Exception(
                            f"Inconsistent clients list: {variable}, {self.outputs[i]}"
                        )
                    continue
                if node not in nodes:
                    raise Exception(
                        f"Client not in FunctionGraph: {variable}, {(node, i)}"
                    )
                if node.inputs[i] is not variable:
                    raise Exception(
                        f"Inconsistent clients list: {variable}, {node.inputs[i]}"
                    )
Пример #2
0
def test_variables_and_orphans():
    """Check the output and ordering of `vars_between` and `orphans_between`.

    The graph is ``o2 = MyOp(r3, MyOp(r1, r2))`` with ``r1`` and ``r2``
    declared as inputs, which leaves ``r3`` as the only orphan.
    """
    first_in, second_in, orphan = MyVariable(1), MyVariable(2), MyVariable(3)

    inner = MyOp(first_in, second_in)
    inner.name = "o1"
    outer = MyOp(orphan, inner)
    outer.name = "o2"

    between = vars_between([first_in, second_in], [outer])
    orphans = orphans_between([first_in, second_in], [outer])

    assert list(between) == [outer, inner, orphan, second_in, first_in]
    assert list(orphans) == [orphan]
Пример #3
0
    def check_integrity(self) -> None:
        """Check the integrity of nodes in the graph.

        The apply nodes and variables reachable between ``self.inputs`` and
        ``self.outputs`` are recomputed and compared against the cached
        ``self.apply_nodes``, ``self.variables`` and ``self.clients``.

        Raises
        ------
        Exception
            If the cached nodes or variables differ from the recomputed ones,
            if a clients list is inconsistent with the graph, if a client is
            not part of the graph, or if a variable without an owner is
            neither a declared input nor an `AtomicVariable`.
        AssertionError
            If a client entry is neither the string ``"output"`` nor an
            `Apply` instance.
        """
        nodes = set(applys_between(self.inputs, self.outputs))
        if self.apply_nodes != nodes:
            nodes_missing = nodes.difference(self.apply_nodes)
            nodes_excess = self.apply_nodes.difference(nodes)
            raise Exception(
                f"The following nodes are inappropriately cached:\nmissing: {nodes_missing}\nin excess: {nodes_excess}"
            )

        # Every (node, input index) pair must be registered as a client of
        # the variable feeding that input.
        for node in nodes:
            for i, variable in enumerate(node.inputs):
                clients = self.clients[variable]
                if (node, i) not in clients:
                    raise Exception(
                        f"Inconsistent clients list {(node, i)} in {clients}")

        variables = set(vars_between(self.inputs, self.outputs))
        # Normalise once so the comparison and both differences work on the
        # same set, whatever container type `self.variables` is.
        cached_variables = set(self.variables)
        if cached_variables != variables:
            vars_missing = variables.difference(cached_variables)
            vars_excess = cached_variables.difference(variables)
            raise Exception(
                f"The following variables are inappropriately cached:\nmissing: {vars_missing}\nin excess: {vars_excess}"
            )

        for variable in variables:
            if (variable.owner is None and variable not in self.inputs
                    and not isinstance(variable, AtomicVariable)):
                raise Exception(f"Undeclared input: {variable}")
            for cl_node, i in self.clients[variable]:
                # The string "output" marks a graph output; `i` then indexes
                # into `self.outputs` instead of a node's inputs.
                if cl_node == "output":
                    if self.outputs[i] is not variable:
                        raise Exception(
                            f"Inconsistent clients list: {variable}, {self.outputs[i]}"
                        )
                    continue

                # NOTE(review): presumably here to narrow the type for static
                # checkers; anything that is not "output" must be an `Apply`.
                assert isinstance(cl_node, Apply)

                if cl_node not in nodes:
                    raise Exception(
                        f"Client not in FunctionGraph: {variable}, {(cl_node, i)}"
                    )
                if cl_node.inputs[i] is not variable:
                    raise Exception(
                        f"Inconsistent clients list: {variable}, {cl_node.inputs[i]}"
                    )
Пример #4
0
def is_same_graph(var1, var2, givens=None):
    """
    Tell whether `var1` and `var2` are computed by the same graph.

    Two variables "perform the same computation" only when their graphs have
    the same structure; for example ``(x * (y * z))`` and ``((x * y) * z)``
    are reported as different.

    The result comes from `is_same_graph_with_merge`. Whenever the
    substitutions allow it, `equal_computations` is also called and its
    answer is asserted to agree. This is deliberately redundant: it checks
    that both implementations are interchangeable so that one of them can
    eventually be dropped.

    Parameters
    ----------
    var1
        The first Variable to compare.
    var2
        The second Variable to compare.
    givens
        Substitutions to apply, similar to the `givens` argument of
        `aesara.function`. A dict, or anything accepted by ``dict()``.
        The substitutions are tied to neither `var1` nor `var2`: a
        substituted variable present in both graphs is replaced in both.

    Returns
    -------
    The result of `is_same_graph_with_merge`.

    Examples
    --------

        ======  ======  ======  ======
        var1    var2    givens  output
        ======  ======  ======  ======
        x + 1   x + 1   {}      True
        x + 1   y + 1   {}      False
        x + 1   y + 1   {x: y}  True
        ======  ======  ======  ======

    """
    givens = {} if givens is None else givens
    if not isinstance(givens, dict):
        givens = dict(givens)

    # The merge-based answer is the one that gets returned.
    merge_result = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)

    cross_check = True
    in_xs = None
    in_ys = None

    if givens:
        # `equal_computations` needs paired input lists. They can only be
        # built when each substitution maps a variable found exclusively in
        # one graph to a variable found exclusively in the other one.
        in_xs, in_ys = [], []

        # All variables of each graph: index 0 for `var1`, index 1 for `var2`.
        graph_vars = [
            set(vars_between(graph_inputs([out]), [out])) for out in (var1, var2)
        ]

        for to_replace, replace_by in givens.items():
            # Membership as (in graph of var1, in graph of var2).
            src_membership = tuple(to_replace in members for members in graph_vars)
            dst_membership = tuple(replace_by in members for members in graph_vars)

            if src_membership == (True, False) and dst_membership == (False, True):
                # A variable of `var1` is replaced by one of `var2`.
                in_xs.append(to_replace)
                in_ys.append(replace_by)
            elif src_membership == (False, True) and dst_membership == (True, False):
                # A variable of `var2` is replaced by one of `var1`.
                in_xs.append(replace_by)
                in_ys.append(to_replace)
            else:
                # Any other layout is not handled by `equal_computations`.
                cross_check = False
                break

    if cross_check:
        other_result = equal_computations(
            xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys
        )
        assert other_result == merge_result

    return merge_result
Пример #5
0
def compile_pymc(
    inputs,
    outputs,
    random_seed: SeedSequenceSeed = None,
    mode=None,
    **kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
    """Use ``aesara.function`` with specialized pymc rewrites always enabled.

    This function also ensures shared RandomState/Generator used by RandomVariables
    in the graph are updated across calls, to ensure independent draws.

    Parameters
    ----------
    inputs: list of TensorVariables, optional
        Inputs of the compiled Aesara function
    outputs: list of TensorVariables, optional
        Outputs of the compiled Aesara function
    random_seed: int, array-like of int or SeedSequence, optional
        Seed used to override any RandomState/Generator shared variables in the graph.
        If not specified, the value of original shared variables will still be overwritten.
    mode: optional
        Aesara mode used to compile the function
    **kwargs
        Forwarded to ``aesara.function``. An ``updates`` entry may be a mapping,
        an iterable of ``(variable, expression)`` pairs, or ``None``. It is merged
        with the automatically collected RNG updates, and user-supplied entries
        take precedence.

    Returns
    -------
    The function returned by ``aesara.function``.

    Included rewrites
    -----------------
    random_make_inplace
        Ensures that compiled functions containing random variables will produce new
        samples on each call.
    local_check_parameter_to_ninf_switch
        Replaces Aeppl's CheckParameterValue assertions in logp expressions with Switches
        that return -inf in place of the assert.

    Optional rewrites
    -----------------
    local_remove_check_parameter
        Replaces Aeppl's CheckParameterValue assertions in logp expressions. This is used
        as an alternative to the default local_check_parameter_to_ninf_switch whenever
        this function is called within a model context and the model `check_bounds` flag
        is set to False.
    """
    # Create an update mapping of RandomVariable's RNG so that it is automatically
    # updated after every function call
    rng_updates = {}
    output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
    for var in vars_between(inputs, output_to_list):
        node = var.owner
        if not (
            node
            and isinstance(node.op, (RandomVariable, MeasurableVariable))
            and var not in inputs
        ):
            continue

        if isinstance(node.op, RandomVariable):
            rng = node.inputs[0]
            if hasattr(rng, "default_update"):
                rng_updates[rng] = rng.default_update
            else:
                rng_updates[rng] = node.outputs[0]
        else:
            # Other measurable ops may expose an `update` callable that
            # returns the updates for their node.
            update_fn = getattr(node.op, "update", None)
            if update_fn is not None:
                rng_updates.update(update_fn(node))

    # We always reseed random variables as this provides RNGs with no chances of collision
    if rng_updates:
        reseed_rngs(rng_updates.keys(), random_seed)

    # If called inside a model context, see if check_bounds flag is set to False
    try:
        from pymc.model import modelcontext

        model = modelcontext(None)
        check_bounds = model.check_bounds
    except TypeError:
        check_bounds = True
    check_parameter_opt = (
        "local_check_parameter_to_ninf_switch"
        if check_bounds
        else "local_remove_check_parameter"
    )

    mode = get_mode(mode)
    opt_qry = mode.provided_optimizer.including(
        "random_make_inplace", check_parameter_opt
    )
    mode = Mode(linker=mode.linker, optimizer=opt_qry)

    # Take `updates` out of kwargs before the call so it is not passed twice.
    # `dict.update` accepts both mappings and iterables of pairs, and `None`
    # is treated as "no user updates".
    user_updates = kwargs.pop("updates", None)
    updates = dict(rng_updates)
    if user_updates is not None:
        updates.update(user_updates)

    aesara_function = aesara.function(
        inputs,
        outputs,
        updates=updates,
        mode=mode,
        **kwargs,
    )
    return aesara_function
Пример #6
0
def compile_pymc(
    inputs, outputs, mode=None, **kwargs
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
    """Use ``aesara.function`` with specialized pymc rewrites always enabled.

    Parameters
    ----------
    inputs
        Inputs of the compiled Aesara function
    outputs
        Outputs of the compiled Aesara function; a single output or a list/tuple
    mode: optional
        Aesara mode used to compile the function
    **kwargs
        Forwarded to ``aesara.function``. An ``updates`` entry may be a mapping,
        an iterable of ``(variable, expression)`` pairs, or ``None``. It is merged
        with the automatically collected RNG updates, and user-supplied entries
        take precedence.

    Returns
    -------
    The function returned by ``aesara.function``.

    Included rewrites
    -----------------
    random_make_inplace
        Ensures that compiled functions containing random variables will produce new
        samples on each call.
    local_check_parameter_to_ninf_switch
        Replaces Aeppl's CheckParameterValue assertions in logp expressions with Switches
        that return -inf in place of the assert.

    Optional rewrites
    -----------------
    local_remove_check_parameter
        Replaces Aeppl's CheckParameterValue assertions in logp expressions. This is used
        as an alternative to the default local_check_parameter_to_ninf_switch whenever
        this function is called within a model context and the model `check_bounds` flag
        is set to False.
    """
    # Create an update mapping of RandomVariable's RNG so that it is automatically
    # updated after every function call
    # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
    rng_updates = {}
    output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
    for var in vars_between(inputs, output_to_list):
        node = var.owner
        if not (
            node
            and isinstance(node.op, (RandomVariable, MeasurableVariable))
            and var not in inputs
        ):
            continue

        if isinstance(node.op, RandomVariable):
            rng = node.inputs[0]
            # An RNG that already carries a `default_update` gets no explicit
            # entry here.
            if not hasattr(rng, "default_update"):
                rng_updates[rng] = node.outputs[0]
        else:
            # Other measurable ops may expose an `update` callable that
            # returns the updates for their node.
            update_fn = getattr(node.op, "update", None)
            if update_fn is not None:
                rng_updates.update(update_fn(node))

    # If called inside a model context, see if check_bounds flag is set to False
    try:
        from pymc.model import modelcontext

        model = modelcontext(None)
        check_bounds = model.check_bounds
    except TypeError:
        check_bounds = True
    check_parameter_opt = (
        "local_check_parameter_to_ninf_switch" if check_bounds else "local_remove_check_parameter"
    )

    mode = get_mode(mode)
    opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
    mode = Mode(linker=mode.linker, optimizer=opt_qry)

    # Take `updates` out of kwargs before the call so it is not passed twice.
    # `dict.update` accepts both mappings and iterables of pairs, and `None`
    # is treated as "no user updates".
    user_updates = kwargs.pop("updates", None)
    updates = dict(rng_updates)
    if user_updates is not None:
        updates.update(user_updates)

    aesara_function = aesara.function(
        inputs,
        outputs,
        updates=updates,
        mode=mode,
        **kwargs,
    )
    return aesara_function