Exemplo n.º 1
0
def test_rvs_to_value_vars():
    """Check that ``rvs_to_value_vars`` swaps random variables for value variables.

    The graph ``d = log(c + b) + 2`` is rewritten twice:

    * without ``apply_transforms``, the direct RV inputs ``c`` and ``b`` must be
      replaced by their value variables and ``a``'s value variable must be absent;
    * with ``apply_transforms=True``, the value variables of ``a``, ``b`` and
      ``c`` must all appear among the ancestors.

    In both cases no ``RandomVariable`` may remain in the rewritten graph.
    """
    with pm.Model() as model:
        a = pm.Uniform("a", 0.0, 1.0)
        b = pm.Uniform("b", 0, a + 1.0)
        c = pm.Normal("c")
        d = at.log(c + b) + 2.0

    a_val = model.rvs_to_values[a]
    assert a_val.tag.transform

    b_val = model.rvs_to_values[b]
    c_val = model.rvs_to_values[c]

    def rv_free_ancestors(out):
        """Collect every ancestor of ``out`` and assert none is produced by a RandomVariable."""
        found = list(walk_model((out,), walk_past_rvs=True))
        leftover_rvs = [
            node
            for node in found
            if node.owner and isinstance(node.owner.op, RandomVariable)
        ]
        # There shouldn't be any `RandomVariable`s in the resulting graph
        assert len(leftover_rvs) == 0
        return found

    (out,), _ = rvs_to_value_vars((d,))

    # Expected structure: add(log(add(c, b)), 2.0)
    assert out.owner.op == at.add
    log_node = out.owner.inputs[0]
    assert log_node.owner.op == at.log
    sum_node = log_node.owner.inputs[0]
    assert sum_node.owner.op == at.add

    # The random variables must have been replaced by their value variables
    assert sum_node.owner.inputs[0] == c_val
    assert sum_node.owner.inputs[1] == b_val

    ancestors = rv_free_ancestors(out)
    assert b_val in ancestors
    assert c_val in ancestors
    assert a_val not in ancestors

    (out,), _ = rvs_to_value_vars((d,), apply_transforms=True)

    ancestors = rv_free_ancestors(out)
    for value_var in (a_val, b_val, c_val):
        assert value_var in ancestors
Exemplo n.º 2
0
def test_logpt_basic():
    """Make sure we can compute a log-likelihood for a hierarchical model with transforms.

    The log-probability graph of ``b`` (whose bounds depend on ``a`` and ``c``)
    must contain no ``RandomVariable`` and must reference the value variables
    of ``a``, ``b`` and ``c``.
    """
    with Model() as model:
        a = Uniform("a", 0.0, 1.0)
        c = Normal("c")
        lower = c * a + 2.0
        b = Uniform("b", lower, lower + 1.0)

    a_val = model.rvs_to_values[a]
    assert a_val.tag.transform

    b_val = model.rvs_to_values[b]
    assert b_val.tag.transform

    c_val = model.rvs_to_values[c]

    b_logp = logpt(b, b_val, sum=False)

    ancestors = list(walk_model((b_logp,), walk_past_rvs=True))

    # There shouldn't be any `RandomVariable`s in the resulting graph
    remaining_rvs = [
        node
        for node in ancestors
        if node.owner and isinstance(node.owner.op, RandomVariable)
    ]
    assert len(remaining_rvs) == 0

    for value_var in (b_val, c_val, a_val):
        assert value_var in ancestors
Exemplo n.º 3
0
def test_walk_model():
    """``walk_model`` visits RVs only past RV boundaries when asked, and respects ``stop_at_vars``.

    ``c`` feeds ``a`` through ``e = log(c)``. It is expected to be visited only
    when ``walk_past_rvs=True`` is given and ``e`` is not listed in ``stop_at_vars``.
    """
    d = at.vector("d")
    b = at.vector("b")
    c = uniform(0.0, d)
    c.name = "c"
    e = at.log(c)
    a = normal(e, b)
    a.name = "a"

    graph = at.exp(a + 1)

    # (walk_model keyword arguments, whether `c` should be reached)
    cases = [
        ({}, False),
        ({"walk_past_rvs": True}, True),
        ({"walk_past_rvs": True, "stop_at_vars": {e}}, False),
    ]
    for kwargs, expect_c in cases:
        visited = list(walk_model((graph,), **kwargs))
        assert a in visited
        if expect_c:
            assert c in visited
        else:
            assert c not in visited
Exemplo n.º 4
0
def replace_with_values(vars_needed, replacements=None, model=None):
    R"""
    Replace random variable nodes in the graph with values given by the replacements dict.
    Uses untransformed versions of the inputs, performs some basic input validation.

    Parameters
    ----------
    vars_needed: list of TensorVariables
        A list of variable outputs
    replacements: dict with string keys, numeric values, optional
        The variable name and values to be replaced in the model graph.
        ``None`` (the default) is treated as an empty dict. Entries whose name
        is not needed to compute ``vars_needed`` are ignored.
    model: Model
        A PyMC model object

    Returns
    -------
    If none of ``vars_needed`` depends on a named (non-shared) model variable,
    a tuple with the result of ``.eval()`` on each entry of ``vars_needed``.
    Otherwise, the result of calling the compiled function with the
    filtered ``replacements`` as keyword arguments.

    Raises
    ------
    ValueError
        If a named model variable that ``vars_needed`` depends on has no
        entry in ``replacements``.
    """
    model = modelcontext(model)
    # The default used to crash with AttributeError on `.items()`; treat it as
    # "no replacements" so the caller gets the informative ValueError below.
    if replacements is None:
        replacements = {}

    inputs, input_names = [], []
    for rv in walk_model(vars_needed, walk_past_rvs=True):
        # Only named model variables become function inputs; shared variables
        # are skipped (presumably because they carry their own stored value).
        if rv in model.named_vars.values() and not isinstance(rv, SharedVariable):
            inputs.append(rv)
            input_names.append(rv.name)

    # Then it's deterministic, no inputs are required, can eval and return
    if not inputs:
        return tuple(v.eval() for v in vars_needed)

    # Remove unneeded inputs
    needed_names = set(input_names)
    replacements = {
        name: val for name, val in replacements.items() if name in needed_names
    }
    missing = needed_names - replacements.keys()

    # Error if more inputs are needed. Validate before compiling so bad input
    # fails fast; sort the names so the message is deterministic.
    if missing:
        missing_str = ", ".join(sorted(missing))
        raise ValueError(
            f"Values for {missing_str} must be included in `replacements`.")

    fn = compile_pymc(
        inputs,
        vars_needed,
        allow_input_downcast=True,
        accept_inplace=True,
        on_unused_input="ignore",
    )

    return fn(**replacements)