Example #1
def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl):
    inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,))
    inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))

    # normalize the probs for ground-truth comparison
    operands1 = [operand.abs() / operand.abs().sum(-1, keepdim=True)
                 for operand in operands1]

    expected1 = pyro_einsum(eqn1, *operands1, backend=backend1, modulo_total=True)[0]
    expected2 = pyro_einsum(outputs1[0] + "," + eqn2, *([expected1] + operands2),
                            backend=backend2, modulo_total=True)[0]

    with interpretation(normalize):
        funsor_operands1 = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes1[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes1[inp[-1]]))).exp()
            for inp, operand in zip(inputs1, operands1)
        ]

        output1_naive = einsum_impl(eqn1, *funsor_operands1, backend=backend1)
        with interpretation(reflect):
            output1 = apply_optimizer(output1_naive) if optimize1 else output1_naive
        output2_naive = einsum_impl(outputs1[0] + "," + eqn2, *([output1] + funsor_operands2), backend=backend2)
        with interpretation(reflect):
            output2 = apply_optimizer(output2_naive) if optimize2 else output2_naive

    actual1 = reinterpret(output1)
    actual2 = reinterpret(output2)

    assert torch.allclose(expected1, actual1.data)
    assert torch.allclose(expected2, actual2.data)
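
In the plain sum-product setting, the nesting exercised here (contract eqn1, then feed its output into eqn2) agrees with one flat contraction over all operands. A minimal sketch of that identity using only opt_einsum and torch (equations and shapes chosen purely for illustration):

import torch
import opt_einsum

x, y, z = torch.rand(3, 3), torch.rand(3, 3), torch.rand(3, 3)

# contract "ab,bc->c", then feed the intermediate into "c,cd->d" ...
inner = opt_einsum.contract("ab,bc->c", x, y, backend="torch")
nested = opt_einsum.contract("c,cd->d", inner, z, backend="torch")

# ... which matches the single flat contraction "ab,bc,cd->d"
flat = opt_einsum.contract("ab,bc,cd->d", x, y, z, backend="torch")
assert torch.allclose(nested, flat)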
Example #2
def test_nested_einsum_complete_sharing(eqn1, eqn2, einsum_impl1, einsum_impl2, backend1, backend2):

    inputs1, outputs1, sizes1, operands1, funsor_operands1 = make_einsum_example(eqn1, sizes=(3,))
    inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))

    with memoize():
        output1_1 = einsum_impl1(eqn1, *funsor_operands1, backend=backend1)
        output2_1 = einsum_impl2(outputs1[0] + "," + eqn2, *([output1_1] + funsor_operands2), backend=backend2)

        output1_2 = einsum_impl1(eqn1, *funsor_operands1, backend=backend1)
        output2_2 = einsum_impl2(outputs1[0] + "," + eqn2, *([output1_2] + funsor_operands2), backend=backend2)

    assert output1_1 is output1_2
    assert output2_1 is output2_2
Example #3
def test_optimized_plated_einsum_adjoint(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)

    with interpretation(reflect):
        fwd_expr = einsum(equation,
                          *funsor_operands,
                          plates=plates,
                          backend=backend)
    actuals = adjoint(fwd_expr, funsor_operands)

    for operand in operands:
        pyro_require_backward(operand)  # so the backward pass records a per-operand result
    expected_out = pyro_einsum(equation,
                               *operands,
                               modulo_total=False,
                               plates=plates,
                               backend=backend)[0]
    expected_out._pyro_backward()

    for i, (inp, tv, fv) in enumerate(zip(inputs, operands, funsor_operands)):
        actual = actuals[fv]
        expected = tv._pyro_backward_result
        if inp:
            actual = actual.align(tuple(inp))
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
Example #4
def test_normalize_einsum(equation, plates, backend, einsum_impl):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        expr = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(normalize):
        transformed_expr = reinterpret(expr)

    assert isinstance(transformed_expr, Contraction)
    check_funsor(transformed_expr, expr.inputs, expr.output)

    assert all(isinstance(v, (Number, Tensor, Contraction)) for v in transformed_expr.terms)

    with interpretation(normalize):
        transformed_expr2 = reinterpret(transformed_expr)

    assert transformed_expr2 is transformed_expr  # check normalization

    with interpretation(eager):
        actual = reinterpret(transformed_expr)
        expected = reinterpret(expr)

    assert_close(actual, expected, rtol=1e-4)

    actual = eval(quote(expected))  # requires torch, bint
    assert_close(actual, expected)
Example #5
def test_einsum(equation, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    expected = opt_einsum.contract(equation, *operands, backend=backend)

    with interpretation(reflect):
        naive_ast = naive_einsum(equation, *funsor_operands, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *funsor_operands, backend=backend)

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)
    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data, rtol=1e-5, atol=1e-8)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
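
make_einsum_example is a funsor test utility that is not shown in these snippets. The sketch below is a hypothetical, simplified stand-in (the funsor operands are omitted), inferred only from how the tests consume its return values; it is not funsor's actual implementation:

import itertools

import torch


def toy_make_einsum_example(equation, sizes=(3,)):
    """Hypothetical stand-in for the first four return values of
    funsor.testing.make_einsum_example; funsor operands are omitted."""
    inputs_str, outputs_str = equation.split("->")
    inputs = inputs_str.split(",")
    outputs = outputs_str.split(",")
    dims = sorted(set(equation) - set(",->"))
    dim_sizes = dict(zip(dims, itertools.cycle(sizes)))
    operands = [torch.rand(tuple(dim_sizes[d] for d in inp)) for inp in inputs]
    return inputs, outputs, dim_sizes, operands


inputs, outputs, sizes, operands = toy_make_einsum_example("ab,bc->c")
assert outputs == ["c"] and operands[0].shape == (3, 3)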
Example #6
def test_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    expected = pyro_einsum(equation,
                           *operands,
                           plates=plates,
                           backend=backend,
                           modulo_total=False)[0]
    with interpretation(reflect):
        naive_ast = naive_plated_einsum(equation,
                                        *funsor_operands,
                                        plates=plates,
                                        backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_plated_einsum(equation,
                                 *funsor_operands,
                                 plates=plates,
                                 backend=backend)

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual,
                 actual_optimized,
                 atol=1e-3 if backend == 'torch' else 1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
Example #7
def test_einsum_adjoint(einsum_impl, equation, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    sum_op, prod_op = BACKEND_ADJOINT_OPS[backend]

    with AdjointTape() as tape:  # interpretation(reflect):
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = tape.adjoint(sum_op, prod_op, fwd_expr, funsor_operands)

    for operand in operands:
        pyro_require_backward(operand)
    expected_out = pyro_einsum(equation,
                               *operands,
                               modulo_total=True,
                               backend=backend)[0]
    expected_out._pyro_backward()

    for i, (inp, tv, fv) in enumerate(zip(inputs, operands, funsor_operands)):
        actual = actuals[fv]
        expected = tv._pyro_backward_result
        if inp:
            actual = actual.align(tuple(inp))
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
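
Loosely speaking, for an ordinary sum-product contraction the adjoint of an operand is the gradient of the fully contracted scalar with respect to that operand. A small torch.autograd sketch of that relationship (an analogy only, not the pyro/funsor adjoint machinery used above):

import torch

x = torch.rand(2, 3, requires_grad=True)
y = torch.rand(3, 4)

# fully contract everything out to a single scalar
total = torch.einsum("ab,bc->", x, y)

# d(total)/dx[a, b] = sum_c y[b, c], independent of a
(grad_x,) = torch.autograd.grad(total, x)
assert torch.allclose(grad_x, y.sum(-1).expand(2, 3))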
Example #8
def test_optimized_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    expected = pyro_einsum.einsum(equation,
                                  *operands,
                                  plates=plates,
                                  backend=backend)[0]
    actual = einsum(equation, *funsor_operands, plates=plates, backend=backend)

    if len(equation) < 10:
        actual_naive = naive_plated_einsum(equation,
                                           *funsor_operands,
                                           plates=plates,
                                           backend=backend)
        assert_close(actual, actual_naive)

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
Example #9
def test_einsum_complete_sharing(equation, plates, backend, einsum_impl, same_lazy):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        lazy_expr1 = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)
        lazy_expr2 = lazy_expr1 if same_lazy else \
            einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with memoize():
        expr1 = reinterpret(lazy_expr1)
        expr2 = reinterpret(lazy_expr2)
    expr3 = reinterpret(lazy_expr1)  # outside the memoize() scope, a fresh object is built

    assert expr1 is expr2
    assert expr1 is not expr3
Example #10
def test_nested_complete_sharing_direct():

    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example("ab,bc,cd->d")
    ab, bc, cd = funsor_operands

    # avoids the complicated internal interpreter usage of the nested optimized einsum tests above
    with interpretation(reflect):
        c1 = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d1 = (c1 * cd).reduce(ops.add, frozenset({"c"}))

        # this does not trigger a second alpha-renaming
        c2 = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d2 = (c2 * cd).reduce(ops.add, frozenset({"c"}))

    with memoize():
        assert reinterpret(c1) is reinterpret(c2)
        assert reinterpret(d1) is reinterpret(d2)
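
The identity assertions above amount to hash-consing: while the memoization cache is active, structurally equal builds return the very same object. A toy sketch of that idea in plain Python, independent of funsor's actual memoize implementation:

_cons_cache = {}

def cons(op, *args):
    """Return the identical object for structurally equal (op, args) keys."""
    key = (op, args)
    if key not in _cons_cache:
        _cons_cache[key] = (op, args)  # stand-in for constructing a real term
    return _cons_cache[key]

t1 = cons("reduce", "add", cons("mul", "ab", "bc"))
t2 = cons("reduce", "add", cons("mul", "ab", "bc"))
assert t1 is t2  # structurally equal builds share one object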
Example #11
def test_optimized_einsum(equation, backend, einsum_impl):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, backend=backend)[0]
    with interpretation(normalize):
        naive_ast = einsum_impl(equation, *funsor_operands, backend=backend)
    optimized_ast = apply_optimizer(naive_ast)
    actual = reinterpret(optimized_ast)  # eager by default

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
Example #12
def test_einsum_adjoint_unary_marginals(einsum_impl, equation, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    equation = ",".join(inputs) + "->"

    targets = [Variable(k, bint(sizes[k])) for k in set(sizes)]
    with interpretation(reflect):
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = adjoint(fwd_expr, targets)

    for target in targets:
        actual = actuals[target]

        expected = opt_einsum.contract(equation + target.name,
                                       *operands,
                                       backend=backend)
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
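
The expected marginal above is formed by appending the target variable's name to the output of the otherwise fully contracted equation, so that one index is kept instead of summed out. A small illustration of that construction:

import torch
import opt_einsum

x, y = torch.rand(2, 3), torch.rand(3, 4)

total = opt_einsum.contract("ab,bc->", x, y, backend="torch")
marg_b = opt_einsum.contract("ab,bc->b", x, y, backend="torch")
assert torch.allclose(marg_b.sum(), total)  # summing the marginal recovers the total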
Example #13
def test_einsum_categorical(equation):
    if get_backend() == "jax":
        from funsor.jax.distributions import Categorical
    else:
        from funsor.torch.distributions import Categorical

    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    operands = [ops.abs(operand) / ops.abs(operand).sum(-1)[..., None]
                for operand in operands]

    expected = opt_einsum.contract(equation, *operands,
                                   backend=BACKEND_TO_EINSUM_BACKEND[get_backend()])

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, Bint[sizes[d]]) for d in inp[:-1]])
            ))(value=Variable(inp[-1], Bint[sizes[inp[-1]]])).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
Example #14
def test_einsum_categorical(equation):
    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    operands = [operand.abs() / operand.abs().sum(-1, keepdim=True)
                for operand in operands]

    expected = opt_einsum.contract(equation, *operands, backend='torch')

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes[inp[-1]]))).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
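
The normalization step above turns each row of an operand into a probability vector so it can parameterize Categorical, and exponentiating the resulting log-probability recovers the table entry for a given value. A torch-only sanity check of that idea (shapes chosen arbitrarily):

import torch

probs = torch.randn(3, 4).abs()
probs = probs / probs.sum(-1, keepdim=True)  # each row now sums to 1
assert torch.allclose(probs.sum(-1), torch.ones(3))

dist = torch.distributions.Categorical(probs=probs)
value = torch.tensor(2)
# exp(log_prob(value)) gives back the probability of that value for every row
assert torch.allclose(dist.log_prob(value).exp(), probs[:, 2])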