Esempio n. 1
0
def test_rowwise_approx(three_var_model, parametric_grouped_approxes):
    """Smoke-test fitting and sampling with a rowwise group over ``one``.

    ``parametric_grouped_approxes`` supplies a ``(group_class, kwargs)`` pair.
    The group class is instantiated for ``three_var_model.one`` with
    ``rowwise=True`` and combined with a ``Group(None, vfam='mf')`` in a
    single ``Approximation``, which is then fitted with ``pm.KLqp`` and
    sampled. The test is skipped when ``pm.opvi.BatchedGroupError`` is raised
    anywhere in that sequence.
    """
    # NOTE (original author): add to inference that supports aevb
    group_cls, group_kwargs = parametric_grouped_approxes
    with three_var_model:
        try:
            rowwise_group = group_cls(
                [three_var_model.one], rowwise=True, **group_kwargs
            )
            mean_field_group = Group(None, vfam='mf')
            combined = Approximation([rowwise_group, mean_field_group])
            fitted = pm.KLqp(combined).fit(3, obj_n_mc=2)
            fitted.sample(10)
            sampled_node = fitted.sample_node(three_var_model.one)
            sampled_node.eval()
        except pm.opvi.BatchedGroupError:
            pytest.skip('Does not support rowwise grouping')
def test_init_groups(three_var_model, raises, grouping):
    """Check that groups built from ``grouping`` are initialised consistently.

    ``grouping`` is a mapping whose keys are callables accepting a ``group``
    keyword and whose values are either ``None`` or an iterable of attribute
    names that are looked up on ``three_var_model``.

    ``raises`` is entered as a context manager around the whole body
    (presumably a ``pytest.raises``-style or no-op context supplied by a
    fixture -- confirm against the fixture definition).

    For every explicitly listed group the transformed variables must equal
    the variables stored on the initialised group, and the resulting
    ``Approximation`` must have the same ``ndim`` as the model.
    """
    with raises, three_var_model:
        approxes, groups = zip(*grouping.items())
        # Resolve attribute names to model variables; ``None`` entries are
        # passed through unchanged.
        groups = [
            None if names is None
            else [getattr(three_var_model, name) for name in names]
            for names in groups
        ]
        inited_groups = [
            factory(group=members)
            for factory, members in zip(approxes, groups)
        ]
        approx = Approximation(inited_groups)
        for inited, members in zip(inited_groups, groups):
            if members is not None:
                expected = {pm.util.get_transformed(var) for var in members}
                assert expected == set(inited.group)
        # The original attached this check to a ``for ... else`` clause; with
        # no ``break`` in the loop it always ran, so it is a plain statement.
        assert approx.ndim == three_var_model.ndim
def test_rowwise_approx(three_var_model, parametric_grouped_approxes):
    """Fit and sample an approximation that groups ``one`` rowwise.

    The ``(class, kwargs)`` pair from ``parametric_grouped_approxes`` is used
    to build a rowwise group for ``three_var_model.one``; it is paired with
    ``Group(None, vfam='mf')``, fitted through ``pm.KLqp`` and sampled.
    A ``pm.opvi.BatchedGroupError`` raised during any of these steps turns
    the test into a skip rather than a failure.
    """
    # NOTE (original author): add to inference that supports aevb
    approx_cls, extra_kwargs = parametric_grouped_approxes
    target = [three_var_model.one]
    with three_var_model:
        try:
            members = [
                approx_cls(target, rowwise=True, **extra_kwargs),
                Group(None, vfam='mf'),
            ]
            engine = pm.KLqp(Approximation(members))
            result = engine.fit(3, obj_n_mc=2)
            result.sample(10)
            result.sample_node(three_var_model.one).eval()
        except pm.opvi.BatchedGroupError:
            pytest.skip('Does not support rowwise grouping')
def three_var_aevb_approx(three_var_model, three_var_aevb_groups):
    """Build an ``Approximation`` from the given groups, bound to the model.

    ``three_var_aevb_groups`` is passed positionally and ``three_var_model``
    as the ``model`` keyword; the constructed object is returned as-is.
    """
    return Approximation(three_var_aevb_groups, model=three_var_model)