Exemplo n.º 1
0
def test_replace_shared_variables():
    """Exercise ``replace_shared_variables`` with and without a ``default_update``.

    A plain shared variable must not survive as a graph input after replacement;
    a shared variable carrying a ``default_update`` must be rejected with a
    ``ValueError``.
    """
    shared_x = aesara.shared(5, name="shared_x")

    # No default_update yet: no SharedVariable should remain among the graph inputs.
    replaced = replace_shared_variables([shared_x])
    assert not any(
        isinstance(inp, SharedVariable) for inp in graph_inputs(replaced)
    )

    # Attaching a default_update makes the same call raise.
    shared_x.default_update = shared_x + 1
    with pytest.raises(ValueError, match="shared variables with default_update"):
        replace_shared_variables([shared_x])
Exemplo n.º 2
0
def test_missing_symmetric():
    """Check that logpt works when a partially observed variable has equal observed and
    unobserved dimensions.

    This would fail in a previous implementation because the two variables would be
    equivalent and one of them would be discarded during MergeOptimization while
    building the logpt graph.
    """
    with Model() as m:
        # One observed entry and one missing (NaN) entry, so the observed and the
        # missing parts have the same size. The call only needs to register "x" in
        # the model; the returned variable itself is not used.
        Gamma("x", alpha=3, beta=10, observed=np.array([1, np.nan]))

    x_obs_rv = m["x_observed"]
    x_obs_vv = m.rvs_to_values[x_obs_rv]

    x_unobs_rv = m["x_missing"]
    x_unobs_vv = m.rvs_to_values[x_unobs_rv]

    logp = logpt([x_obs_rv, x_unobs_rv], {x_obs_rv: x_obs_vv, x_unobs_rv: x_unobs_vv})
    logp_inputs = list(graph_inputs([logp]))
    # Both value variables must still be inputs of the logp graph, i.e. neither
    # was merged away.
    assert x_obs_vv in logp_inputs
    assert x_unobs_vv in logp_inputs