Esempio n. 1
0
def test_einsum_complete_sharing_reuse_cache(equation, plates, backend, einsum_impl, same_lazy):
    """Check that a ``memoize`` cache can be carried across separate contexts.

    A lazy einsum expression is built under ``reflect``. The second expression
    is either the very same object (``same_lazy``) or an independently built
    copy. Both are reinterpreted under two ``memoize`` contexts, where the
    second context is handed the cache yielded by the first. They must come
    out as the identical object. Reinterpreting outside any ``memoize``
    context must produce a different object.
    """
    _, _, _, _, funsor_operands = make_einsum_example(equation)

    def build_lazy():
        return einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(reflect):
        first_lazy = build_lazy()
        if same_lazy:
            second_lazy = first_lazy
        else:
            second_lazy = build_lazy()

    with memoize({}) as shared_cache:
        eager_first = reinterpret(first_lazy)

    # Hand the first context's cache to a brand-new memoize context.
    with memoize(shared_cache):
        eager_second = reinterpret(second_lazy)

    # No memoize context here, so nothing is shared with the results above.
    eager_unshared = reinterpret(first_lazy)

    assert eager_first is eager_second
    assert eager_first is not eager_unshared
Esempio n. 2
0
def test_nested_einsum_complete_sharing(eqn1, eqn2, einsum_impl1, einsum_impl2, backend1, backend2):
    """Check that a two-stage einsum pipeline is fully shared under ``memoize``.

    The result of an ``eqn1`` einsum is fed as the first operand of an
    ``eqn2`` einsum. Running the whole pipeline twice inside a single
    ``memoize`` context must return identical objects at both stages.
    """
    _, outputs1, _, _, funsor_operands1 = make_einsum_example(eqn1, sizes=(3,))
    _, _, _, _, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))

    # The inner result is passed as an extra leading operand, so its output
    # spec is prepended to the outer equation.
    nested_eqn = outputs1[0] + "," + eqn2

    def run_pipeline():
        inner = einsum_impl1(eqn1, *funsor_operands1, backend=backend1)
        outer = einsum_impl2(nested_eqn, inner, *funsor_operands2, backend=backend2)
        return inner, outer

    with memoize():
        inner_a, outer_a = run_pipeline()
        inner_b, outer_b = run_pipeline()

    assert inner_a is inner_b
    assert outer_a is outer_b
Esempio n. 3
0
def test_memoize_sample(check_sample):
    """Check what ``memoize`` deduplicates for ``Normal`` funsors.

    Two ``Normal`` funsors built from the same parameters inside one
    ``memoize`` context must be the same object. Identity of the samples
    drawn from them is only checked when ``check_sample`` is set.
    """
    with memoize():
        loc = torch.tensor(0.)
        scale = torch.tensor(1.)
        dist_a = Normal(loc, scale, 'x')
        dist_b = Normal(loc, scale, 'x')
        draw_a = dist_a.sample(frozenset({'x'}))
        draw_a_again = dist_a.sample(frozenset({'x'}))
        draw_b = dist_b.sample(frozenset({'x'}))

    # Construction is memoized, so identical arguments give one object.
    assert dist_a is dist_b

    # Sampling is not memoized, so these identity checks are expected to
    # fail; they only run when the caller opts in via ``check_sample``.
    if check_sample:
        assert draw_a is draw_a_again
        assert draw_a is draw_b
Esempio n. 4
0
def test_nested_complete_sharing_direct():
    """Check sharing of nested sum-product reductions built by hand.

    The expression for ``"ab,bc,cd->d"`` is built twice lazily. The first
    stage multiplies the ``ab`` and ``bc`` operands and sums out ``a`` and
    ``b``. The second stage multiplies by ``cd`` and sums out ``c``.
    Reinterpreting both copies inside one ``memoize`` context must yield
    identical objects at each stage.
    """
    _, _, _, _, funsor_operands = make_einsum_example("ab,bc,cd->d")
    ab, bc, cd = funsor_operands

    def sum_product(lhs, rhs, summed_vars):
        return (lhs * rhs).reduce(ops.add, summed_vars)

    # Built directly rather than through the optimized nested einsum tests,
    # which involve more complicated internal interpreter usage.
    with interpretation(reflect):
        c_first = sum_product(ab, bc, frozenset({"a", "b"}))
        d_first = sum_product(c_first, cd, frozenset({"c"}))

        # Rebuilding the same expression does not trigger a second alpha-renaming.
        c_second = sum_product(ab, bc, frozenset({"a", "b"}))
        d_second = sum_product(c_second, cd, frozenset({"c"}))

    with memoize():
        assert reinterpret(c_first) is reinterpret(c_second)
        assert reinterpret(d_first) is reinterpret(d_second)
Esempio n. 5
0
def test_memoize_sample(check_sample):
    """Check what ``memoize`` deduplicates for ``Normal`` funsors on any backend.

    ``Normal`` is taken from the jax or torch distributions module, depending
    on the active backend. Two ``Normal`` funsors built from the same
    parameters inside one ``memoize`` context must be the same object.
    Identity of the samples drawn from them is only checked when
    ``check_sample`` is set.
    """
    backend = get_backend()
    if backend == "jax":
        from funsor.jax.distributions import Normal
    else:
        from funsor.torch.distributions import Normal

    # torch samples take no explicit key; other backends get one uint32 key per draw.
    # NOTE(review): on the non-torch path every draw gets a different key, so the
    # opt-in identity checks below could not hold even if sampling were memoized.
    # Confirm this is intended.
    if backend == "torch":
        rng_keys = (None, None, None)
    else:
        rng_keys = np.array([[0, 1], [0, 2], [0, 3]], dtype=np.uint32)
    key_a, key_a_again, key_b = rng_keys

    with memoize():
        loc = numeric_array(0.)
        scale = numeric_array(1.)
        dist_a = Normal(loc, scale, 'x')
        dist_b = Normal(loc, scale, 'x')
        draw_a = dist_a.sample(frozenset({'x'}), rng_key=key_a)
        draw_a_again = dist_a.sample(frozenset({'x'}), rng_key=key_a_again)
        draw_b = dist_b.sample(frozenset({'x'}), rng_key=key_b)

    # Construction is memoized, so identical arguments give one object.
    assert dist_a is dist_b

    # Sampling is not memoized, so these identity checks are expected to
    # fail; they only run when the caller opts in via ``check_sample``.
    if check_sample:
        assert draw_a is draw_a_again
        assert draw_a is draw_b