def test_rowwise_approx(three_var_model, parametric_grouped_approxes):
    """Fit an approximation that puts ``one`` in a rowwise group.

    The parametrized group class is built over ``three_var_model.one`` with
    ``rowwise=True``, and the remaining variables get a mean-field group.
    The combined approximation is fitted for 3 steps with ``pm.KLqp``, then
    sampled and used to evaluate a sample node for ``one``. Group families
    that raise ``pm.opvi.BatchedGroupError`` anywhere in this sequence cause
    the test to be skipped rather than failed.
    """
    group_cls, group_kwargs = parametric_grouped_approxes
    with three_var_model:
        try:
            groups = [
                group_cls([three_var_model.one], rowwise=True, **group_kwargs),
                Group(None, vfam="mf"),
            ]
            fitted = pm.KLqp(Approximation(groups)).fit(3, obj_n_mc=2)
            fitted.sample(10)
            fitted.sample_node(three_var_model.one).eval()
        except pm.opvi.BatchedGroupError:
            pytest.skip("Does not support rowwise grouping")
def test_clear_cache():
    """Check the approximation's memoization cache and how it survives pickling.

    After a short ADVI fit, the approximation's ``_cache`` must hold at least
    one populated entry. Emptying every entry must leave them all empty. A
    cloudpickle round trip must drop ``_cache`` entirely, and fitting the
    restored approximation with ``pm.KLqp`` must repopulate it.
    """
    import cloudpickle

    def _empty_cache_entries(approx):
        # Empty each per-method cache in place instead of clearing the outer
        # ``_cache`` mapping. Clearing the outer mapping would leave nothing
        # to iterate, so the ``all(...)`` checks below would pass vacuously.
        for entry in approx._cache.values():
            entry.clear()

    with pm.Model():
        pm.Normal("n", 0, 1)
        inference = ADVI()
        inference.fit(n=10)
        assert any(len(c) != 0 for c in inference.approx._cache.values())
        _empty_cache_entries(inference.approx)
        # The entry keys are still present, so the emptiness check below
        # actually inspects something.
        assert inference.approx._cache
        assert all(len(c) == 0 for c in inference.approx._cache.values())

        new_a = cloudpickle.loads(cloudpickle.dumps(inference.approx))
        assert not hasattr(new_a, "_cache")

        inference_new = pm.KLqp(new_a)
        inference_new.fit(n=10)
        assert any(len(c) != 0 for c in inference_new.approx._cache.values())
        _empty_cache_entries(inference_new.approx)
        assert inference_new.approx._cache
        assert all(len(c) == 0 for c in inference_new.approx._cache.values())
def test_sample_aevb(three_var_aevb_approx, aevb_initial):
    """Sample from an AEVB approximation after resizing ``aevb_initial``.

    The approximation is fitted for a single step, with ``aevb_initial``
    replaced by a one-row zero array of matching dtype. Its value is then set
    to random float32 arrays with 7 and then 13 rows. After each resize, 500
    draws are taken and the variable names, the draw count and the
    per-variable shapes are checked. Only the leading dimension of ``one``
    is expected to track the number of rows.
    """
    inference = pm.KLqp(three_var_aevb_approx)
    one_row = np.zeros_like(aevb_initial.get_value())[:1]
    inference.fit(1, more_replacements={aevb_initial: one_row})

    expected_varnames = {"one", "one_log__", "two", "three"}
    n_draws = 500
    for n_rows in (7, 13):
        aevb_initial.set_value(np.random.rand(n_rows, 7).astype("float32"))
        trace = three_var_aevb_approx.sample(n_draws, return_inferencedata=False)
        assert set(trace.varnames) == expected_varnames
        assert len(trace) == n_draws
        first_point = trace[0]
        assert first_point["one"].shape == (n_rows, 2)
        assert first_point["two"].shape == (10,)
        assert first_point["three"].shape == (10, 1, 2)