def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl):
    """Check a two-stage (nested) funsor einsum against pyro's einsum.

    The result of ``eqn1`` is fed as the leading operand of a second einsum
    with equation ``outputs1[0] + "," + eqn2``. Each stage is optionally
    passed through ``apply_optimizer`` (``optimize1`` / ``optimize2``), and the
    two stages may use different backends (``backend1`` / ``backend2``).
    """
    inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,))
    inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))
    # normalize the probs for ground-truth comparison
    operands1 = [operand.abs() / operand.abs().sum(-1, keepdim=True) for operand in operands1]
    expected1 = pyro_einsum(eqn1, *operands1, backend=backend1, modulo_total=True)[0]
    # Second stage: expected1 is the leading operand, labelled by the output
    # indices of the first equation.
    expected2 = pyro_einsum(outputs1[0] + "," + eqn2, *([expected1] + operands2), backend=backend2, modulo_total=True)[0]
    with interpretation(normalize):
        # Rebuild the first-stage operands as Categorical funsors over the
        # normalized probs: every index but the last is a batch input, and the
        # last index is the categorical value. `.exp()` presumably turns a
        # log-density back into a probability — confirm against Categorical.
        funsor_operands1 = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes1[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes1[inp[-1]]))).exp()
            for inp, operand in zip(inputs1, operands1)
        ]
        output1_naive = einsum_impl(eqn1, *funsor_operands1, backend=backend1)
        # Optimization (when enabled) is applied lazily under `reflect`.
        with interpretation(reflect):
            output1 = apply_optimizer(output1_naive) if optimize1 else output1_naive
        output2_naive = einsum_impl(outputs1[0] + "," + eqn2, *([output1] + funsor_operands2), backend=backend2)
        with interpretation(reflect):
            output2 = apply_optimizer(output2_naive) if optimize2 else output2_naive
    # Evaluate both (possibly lazy) stages under the default interpretation.
    actual1 = reinterpret(output1)
    actual2 = reinterpret(output2)
    assert torch.allclose(expected1, actual1.data)
    assert torch.allclose(expected2, actual2.data)
def test_nested_einsum_complete_sharing(eqn1, eqn2, einsum_impl1, einsum_impl2, backend1, backend2):
    """Repeating a nested einsum inside one ``memoize()`` returns identical objects.

    The inner einsum over ``eqn1`` and the outer einsum that consumes its
    result are each evaluated twice in the same ``memoize()`` context. The
    second evaluation of each must be the very same object as the first.
    """
    _, outputs1, _, _, funsor_operands1 = make_einsum_example(eqn1, sizes=(3,))
    _, _, _, _, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))
    outer_eqn = outputs1[0] + "," + eqn2

    with memoize():
        inner_first = einsum_impl1(eqn1, *funsor_operands1, backend=backend1)
        outer_first = einsum_impl2(outer_eqn, *([inner_first] + funsor_operands2), backend=backend2)
        inner_second = einsum_impl1(eqn1, *funsor_operands1, backend=backend1)
        outer_second = einsum_impl2(outer_eqn, *([inner_second] + funsor_operands2), backend=backend2)

    assert inner_first is inner_second
    assert outer_first is outer_second
def test_optimized_plated_einsum_adjoint(equation, plates, backend):
    """Compare funsor adjoints of an optimized plated einsum with pyro's backward pass.

    The forward expression is built lazily under ``reflect``. ``adjoint`` then
    gives one result per funsor operand. Each result is compared with the
    ``_pyro_backward_result`` read from the matching raw operand after
    ``_pyro_backward()`` has run on pyro's einsum output.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        fwd_expr = einsum(equation, *funsor_operands, plates=plates, backend=backend)
    actuals = adjoint(fwd_expr, funsor_operands)

    for operand in operands:
        pyro_require_backward(operand)
    expected_out = pyro_einsum(equation, *operands,
                               modulo_total=False,
                               plates=plates,
                               backend=backend)[0]
    expected_out._pyro_backward()

    for inp, tv, fv in zip(inputs, operands, funsor_operands):
        actual = actuals[fv]
        expected = tv._pyro_backward_result
        if inp:
            # Order the funsor's named inputs as in this operand's index string
            # so `.data` can be compared positionally with `expected`.
            actual = actual.align(tuple(inp))
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
def test_normalize_einsum(equation, plates, backend, einsum_impl):
    """Check that ``normalize`` turns a lazy einsum into a fixed-point Contraction.

    The lazily built expression is reinterpreted under ``normalize``. The
    result must:

    * be a ``Contraction`` with the same inputs and output as the original;
    * have only Number, Tensor or Contraction terms;
    * be unchanged by a second normalization;
    * evaluate eagerly to the same value as the original expression;
    * round-trip through ``quote`` and ``eval``.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    with interpretation(reflect):
        expr = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)
    with interpretation(normalize):
        transformed_expr = reinterpret(expr)
    assert isinstance(transformed_expr, Contraction)
    check_funsor(transformed_expr, expr.inputs, expr.output)
    assert all(isinstance(v, (Number, Tensor, Contraction)) for v in transformed_expr.terms)
    # Normalizing an already-normalized expression must return the same object.
    with interpretation(normalize):
        transformed_expr2 = reinterpret(transformed_expr)
    assert transformed_expr2 is transformed_expr  # check normalization
    with interpretation(eager):
        actual = reinterpret(transformed_expr)
        expected = reinterpret(expr)
    assert_close(actual, expected, rtol=1e-4)
    # eval() with no explicit namespaces resolves names in the quoted source
    # against this function's locals and the module globals, so every name the
    # quoted source mentions must be importable in this module.
    actual = eval(quote(expected))  # requires torch, bint
    assert_close(actual, expected)
def test_einsum(equation, backend):
    """Compare naive and optimized funsor einsum with ``opt_einsum.contract``.

    The naive expression is built and optimized lazily under ``reflect``. The
    optimized AST is evaluated and compared with an eager naive evaluation, and
    the naive result is checked against opt_einsum on the raw operands.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = opt_einsum.contract(equation, *operands, backend=backend)

    with interpretation(reflect):
        naive_ast = naive_einsum(equation, *funsor_operands, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *funsor_operands, backend=backend)

    # Separate asserts so a failure reports which condition broke.
    assert isinstance(actual, funsor.Tensor)
    assert len(outputs) == 1
    output_names = tuple(outputs[0])
    if output_names:
        # Put the named inputs in the equation's output order so `.data`
        # lines up positionally with `expected`.
        actual = actual.align(output_names)
        actual_optimized = actual_optimized.align(output_names)

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data, rtol=1e-5, atol=1e-8)
    for output in outputs:
        for output_dim in output:
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_plated_einsum(equation, plates, backend):
    """Compare naive and optimized plated funsor einsum with pyro's plated einsum.

    The naive plated expression is built and optimized lazily under
    ``reflect``. The optimized AST is evaluated and compared with an eager
    naive evaluation, and the naive result is checked against pyro's result on
    the raw operands.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, plates=plates, backend=backend, modulo_total=False)[0]

    with interpretation(reflect):
        naive_ast = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)

    # Fail with a clear message if the result is not a Tensor, rather than
    # later on `.data`.
    assert isinstance(actual, funsor.Tensor)
    output_names = tuple(outputs[0])
    if output_names:
        actual = actual.align(output_names)
        actual_optimized = actual_optimized.align(output_names)

    # Looser tolerance when backend == 'torch'.
    assert_close(actual, actual_optimized, atol=1e-3 if backend == 'torch' else 1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for output_dim in output:
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_einsum_adjoint(einsum_impl, equation, backend):
    """Compare adjoints computed via ``AdjointTape`` with pyro's backward pass.

    ``BACKEND_ADJOINT_OPS[backend]`` supplies the ``(sum_op, prod_op)`` pair
    passed to ``tape.adjoint``. Each per-operand adjoint is compared with the
    ``_pyro_backward_result`` of the matching raw operand.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    sum_op, prod_op = BACKEND_ADJOINT_OPS[backend]

    # The forward einsum runs inside the tape context; `tape.adjoint` below
    # consumes the resulting expression.
    with AdjointTape() as tape:
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = tape.adjoint(sum_op, prod_op, fwd_expr, funsor_operands)

    for operand in operands:
        pyro_require_backward(operand)
    expected_out = pyro_einsum(equation, *operands,
                               modulo_total=True,
                               backend=backend)[0]
    expected_out._pyro_backward()

    for inp, tv, fv in zip(inputs, operands, funsor_operands):
        actual = actuals[fv]
        expected = tv._pyro_backward_result
        if inp:
            # Match the positional layout of the raw operand before comparing.
            actual = actual.align(tuple(inp))
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
def test_optimized_plated_einsum(equation, plates, backend):
    """Compare optimized plated funsor einsum with pyro's plated einsum.

    For equations shorter than 10 characters, the result is also checked
    against ``naive_plated_einsum``.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    # NOTE(review): other tests in this file call `pyro_einsum(...)` directly,
    # while this one calls `pyro_einsum.einsum(...)`. Which form is correct
    # depends on how `pyro_einsum` is imported here — confirm.
    expected = pyro_einsum.einsum(equation, *operands, plates=plates, backend=backend)[0]
    actual = einsum(equation, *funsor_operands, plates=plates, backend=backend)

    if len(equation) < 10:
        # Presumably limited to short equations to keep the naive path cheap.
        actual_naive = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
        assert_close(actual, actual_naive)

    # Separate asserts so a failure reports which condition broke.
    assert isinstance(actual, funsor.Tensor)
    assert len(outputs) == 1
    output_names = tuple(outputs[0])
    if output_names:
        actual = actual.align(output_names)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for output_dim in output:
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_einsum_complete_sharing(equation, plates, backend, einsum_impl, same_lazy):
    """Reinterpreting equal lazy einsums inside one ``memoize()`` shares the result.

    Two lazy expressions are built under ``reflect``. They are the same object
    when ``same_lazy`` is true, and two separate constructions otherwise.
    Inside one ``memoize()`` context both must evaluate to the identical
    object. A third evaluation, outside that context, must give a different
    object.
    """
    _, _, _, _, funsor_operands = make_einsum_example(equation)

    def build_lazy():
        return einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(reflect):
        first_lazy = build_lazy()
        if same_lazy:
            second_lazy = first_lazy
        else:
            second_lazy = build_lazy()

    with memoize():
        shared_a = reinterpret(first_lazy)
        shared_b = reinterpret(second_lazy)
    fresh = reinterpret(first_lazy)

    assert shared_a is shared_b
    assert shared_a is not fresh
def test_nested_complete_sharing_direct():
    """Structurally equal hand-built reductions share results inside ``memoize()``.

    The contraction for "ab,bc,cd->d" is built by hand twice under
    ``reflect``. This avoids the complicated internal interpreter usage of the
    nested optimized einsum tests. Inside one ``memoize()`` context, the two
    copies of each intermediate must reinterpret to the identical object.
    """
    _, _, _, _, funsor_operands = make_einsum_example("ab,bc,cd->d")
    ab, bc, cd = funsor_operands

    def contract_chain():
        # Sum out a and b from ab*bc, then sum out c from that result times cd.
        c_part = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d_part = (c_part * cd).reduce(ops.add, frozenset({"c"}))
        return c_part, d_part

    with interpretation(reflect):
        c1, d1 = contract_chain()
        # Building the same chain again does not trigger a second alpha-renaming.
        c2, d2 = contract_chain()

    with memoize():
        assert reinterpret(c1) is reinterpret(c2)
        assert reinterpret(d1) is reinterpret(d2)
def test_optimized_einsum(equation, backend, einsum_impl):
    """Check an optimized funsor einsum against pyro's einsum.

    The naive expression is built under ``normalize``, passed through
    ``apply_optimizer``, and then reinterpreted (eager by default) before being
    compared with the raw-tensor result.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, backend=backend)[0]
    with interpretation(normalize):
        naive_ast = einsum_impl(equation, *funsor_operands, backend=backend)
    optimized_ast = apply_optimizer(naive_ast)
    actual = reinterpret(optimized_ast)  # eager by default

    # Separate asserts so a failure reports which condition broke.
    assert isinstance(actual, funsor.Tensor)
    assert len(outputs) == 1
    output_names = tuple(outputs[0])
    if output_names:
        actual = actual.align(output_names)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for output_dim in output:
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_einsum_adjoint_unary_marginals(einsum_impl, equation, backend):
    """Per-index adjoints of a fully contracted einsum match opt_einsum.

    The equation is rewritten to contract every index (output part "->" left
    empty). For each index ``k``, the adjoint with respect to
    ``Variable(k, bint(sizes[k]))`` is compared with ``opt_einsum.contract``
    over the same inputs, keeping only ``k`` in the output.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    equation = ",".join(inputs) + "->"

    # Iteration order of a set of strings changes from run to run under hash
    # randomization; sort so targets (and failures) come in a reproducible order.
    targets = [Variable(k, bint(sizes[k])) for k in sorted(set(sizes))]
    with interpretation(reflect):
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = adjoint(fwd_expr, targets)

    for target in targets:
        actual = actuals[target]
        # Reference: the same contraction with only this target's index kept.
        expected = opt_einsum.contract(equation + target.name, *operands,
                                       backend=backend)
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
def test_einsum_categorical(equation):
    """Einsum over Categorical funsors matches opt_einsum (backend-agnostic).

    Operands are normalized along their last axis and wrapped as Categorical
    funsors: all but the last index are batch inputs, and the last index is
    the value. Naive and optimized funsor einsums are compared with each other,
    and the naive result is compared with ``opt_einsum.contract`` on the
    normalized raw operands.
    """
    if get_backend() == "jax":
        from funsor.jax.distributions import Categorical
    else:
        from funsor.torch.distributions import Categorical

    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    # Normalize along the last axis; compute |operand| once per operand.
    abs_operands = [ops.abs(operand) for operand in operands]
    operands = [a / a.sum(-1)[..., None] for a in abs_operands]

    expected = opt_einsum.contract(equation, *operands,
                                   backend=BACKEND_TO_EINSUM_BACKEND[get_backend()])

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, Bint[sizes[d]]) for d in inp[:-1]])
            ))(value=Variable(inp[-1], Bint[sizes[inp[-1]]])).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    # The operands were built lazily under `reflect`; evaluate them first.
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    output_names = tuple(outputs[0])
    if output_names:
        actual = actual.align(output_names)
        actual_optimized = actual_optimized.align(output_names)

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data)
    for output in outputs:
        for output_dim in output:
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_einsum_categorical(equation):
    """Einsum over Categorical funsors matches opt_einsum (torch backend).

    Operands are normalized along their last axis and wrapped as Categorical
    funsors: all but the last index are batch inputs, and the last index is
    the value. Naive and optimized funsor einsums are compared with each other,
    and the naive result is compared with ``opt_einsum.contract`` on the
    normalized raw operands.
    """
    # NOTE(review): another `test_einsum_categorical` appears earlier in this
    # file. If both live in the same module, this definition shadows that one
    # and pytest only collects this one — confirm and rename if so.
    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    # Normalize along the last axis; compute |operand| once per operand.
    abs_operands = [operand.abs() for operand in operands]
    operands = [a / a.sum(-1, keepdim=True) for a in abs_operands]

    expected = opt_einsum.contract(equation, *operands, backend='torch')

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes[inp[-1]]))).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    # The operands were built lazily under `reflect`; evaluate them first.
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    output_names = tuple(outputs[0])
    if output_names:
        actual = actual.align(output_names)
        actual_optimized = actual_optimized.align(output_names)

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for output_dim in output:
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]