def test_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, plates=plates, backend=backend,
                           modulo_total=False)[0]

    # Build the naive plated einsum expression lazily, optimize it, then evaluate it.
    with interpretation(reflect):
        naive_ast = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    # The optimized result should agree with the naive result.
    assert_close(actual, actual_optimized, atol=1e-3 if backend == 'torch' else 1e-4)

    # Both should agree with Pyro's reference implementation.
    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_optimized_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, plates=plates, backend=backend)[0]
    actual = einsum(equation, *funsor_operands, plates=plates, backend=backend)

    # For small equations, also compare against the unoptimized implementation.
    if len(equation) < 10:
        actual_naive = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
        assert_close(actual, actual_naive)

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    # The optimized einsum should agree with Pyro's reference implementation.
    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
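

# NOTE: a sketch of the module-level setup these tests appear to assume; the
# exact import paths, PLATED_EINSUM_EXAMPLES, and the backend strings below are
# illustrative assumptions, not taken from this file:
#
#   import pytest
#   import torch
#
#   import funsor
#   from funsor.testing import assert_close, make_einsum_example
#   from pyro.ops.contract import einsum as pyro_einsum
#   # plus funsor's einsum, naive_plated_einsum, interpretation/reflect,
#   # reinterpret, and apply_optimizer helpers, whose import paths depend on
#   # the funsor version in use
#
#   @pytest.mark.parametrize('equation,plates', PLATED_EINSUM_EXAMPLES)
#   @pytest.mark.parametrize('backend', ['torch', 'pyro.ops.einsum.torch_log'])
#   def test_plated_einsum(equation, plates, backend):
#       ...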