def test_rowwise_approx(three_var_model, parametric_grouped_approxes):
    """Smoke-test fitting and sampling when one group is built with ``rowwise=True``.

    The parametrised approximation class is instantiated over
    ``three_var_model.one`` with ``rowwise=True``. It is combined with a
    ``Group(None, vfam='mf')`` (presumably a mean-field group for the remaining
    variables) into an ``Approximation``. That approximation is fitted with
    ``pm.KLqp`` for 3 iterations (``obj_n_mc=2``), sampled, and used to
    evaluate a sample node for ``one``. If any of these steps raises
    ``pm.opvi.BatchedGroupError``, the test is skipped instead of failed.
    """
    # TODO: add to inference that supports aevb
    approx_cls, approx_kwargs = parametric_grouped_approxes
    with three_var_model:
        try:
            rowwise_group = approx_cls([three_var_model.one], rowwise=True, **approx_kwargs)
            remainder_group = Group(None, vfam='mf')
            klqp = pm.KLqp(Approximation([rowwise_group, remainder_group]))
            fitted = klqp.fit(3, obj_n_mc=2)
            fitted.sample(10)
            node = fitted.sample_node(three_var_model.one)
            node.eval()
        except pm.opvi.BatchedGroupError:
            pytest.skip('Does not support rowwise grouping')
def test_init_groups(three_var_model, raises, grouping):
    """Check that explicitly initialised groups combine into a consistent Approximation.

    ``grouping`` maps approximation classes to either a list of attribute
    names on ``three_var_model`` or ``None``. Each name is resolved with
    ``getattr`` on the model, and each class is instantiated with
    ``group=<resolved variables or None>``. The resulting groups are combined
    into an ``Approximation``. The test then asserts that:

    * every group given explicit variables holds exactly their transformed
      counterparts (``pm.util.get_transformed``), and
    * the combined approximation's ``ndim`` equals the model's ``ndim``.

    The whole body runs inside the ``raises`` context manager. Presumably this
    is ``pytest.raises`` for invalid groupings and a no-op context otherwise;
    confirm against the fixture.
    """
    with raises, three_var_model:
        approx_classes, var_names = zip(*grouping.items())
        resolve = functools.partial(getattr, three_var_model)
        groups = [
            list(map(resolve, names)) if names is not None else None
            for names in var_names
        ]
        inited_groups = [approx_cls(group=g) for approx_cls, g in zip(approx_classes, groups)]
        approx = Approximation(inited_groups)
        for inited, g in zip(inited_groups, groups):
            # Groups built from ``None`` have no explicit variable list to compare against.
            if g is not None:
                assert {pm.util.get_transformed(z) for z in g} == set(inited.group)
        # The combined approximation's dimensionality must match the model's.
        # (Previously written as a ``for ... else`` with no ``break``, so it always ran.)
        assert approx.ndim == three_var_model.ndim
def three_var_aevb_approx(three_var_model, three_var_aevb_groups):
    """Return an ``Approximation`` over ``three_var_model`` built from ``three_var_aevb_groups``."""
    return Approximation(three_var_aevb_groups, model=three_var_model)