def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl):
    """Check that the result of one funsor einsum can be fed into a second einsum.

    The operands of ``eqn1`` are normalized and wrapped as ``Categorical``
    probabilities. The (optionally optimized) result is used as the first operand of
    ``outputs1[0] + "," + eqn2``. Both stages are compared with ``pyro_einsum``
    (``modulo_total=True``) evaluated on the equivalent raw tensors.
    """
    inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,))
    _, outputs2, _, operands2, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))

    # normalize the probs for ground-truth comparison
    operands1 = [operand.abs() / operand.abs().sum(-1, keepdim=True)
                 for operand in operands1]

    # The second einsum consumes the first one's output as its leading operand.
    nested_eqn = outputs1[0] + "," + eqn2

    expected1 = pyro_einsum(eqn1, *operands1, backend=backend1, modulo_total=True)[0]
    expected2 = pyro_einsum(nested_eqn, *([expected1] + operands2),
                            backend=backend2, modulo_total=True)[0]

    with interpretation(normalize):
        funsor_operands1 = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes1[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes1[inp[-1]]))).exp()
            for inp, operand in zip(inputs1, operands1)
        ]

        output1_naive = einsum_impl(eqn1, *funsor_operands1, backend=backend1)
        with interpretation(reflect):
            output1 = apply_optimizer(output1_naive) if optimize1 else output1_naive
        output2_naive = einsum_impl(nested_eqn, *([output1] + funsor_operands2),
                                    backend=backend2)
        with interpretation(reflect):
            output2 = apply_optimizer(output2_naive) if optimize2 else output2_naive

    actual1 = reinterpret(output1)
    actual2 = reinterpret(output2)

    # Align to the declared output order before comparing raw data, as the other
    # einsum tests do. With sizes=(3,) every dim has the same size, so a permuted
    # input order would not be revealed by a shape mismatch alone.
    if outputs1[0]:
        actual1 = actual1.align(tuple(outputs1[0]))
    if outputs2[0]:
        actual2 = actual2.align(tuple(outputs2[0]))

    assert expected1.shape == actual1.data.shape
    assert expected2.shape == actual2.data.shape
    assert torch.allclose(expected1, actual1.data)
    assert torch.allclose(expected2, actual2.data)
def test_optimized_plated_einsum_adjoint(equation, plates, backend):
    """Compare funsor ``adjoint`` on a lazily built plated einsum with Pyro's backward pass.

    The forward expression is constructed under ``reflect`` and passed to
    ``adjoint``. Each operand's adjoint must match the ``_pyro_backward_result``
    produced by ``pyro_einsum`` with ``modulo_total=False`` for the same plates and
    backend.
    """
    inputs, _outputs, _sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        lazy_expr = einsum(equation, *funsor_operands, plates=plates, backend=backend)
    adjoints = adjoint(lazy_expr, funsor_operands)

    for raw_operand in operands:
        pyro_require_backward(raw_operand)
    pyro_out = pyro_einsum(
        equation,
        *operands,
        modulo_total=False,
        plates=plates,
        backend=backend,
    )[0]
    pyro_out._pyro_backward()

    for dims, raw_operand, funsor_operand in zip(inputs, operands, funsor_operands):
        grad = adjoints[funsor_operand]
        reference = raw_operand._pyro_backward_result
        # Reorder the adjoint's inputs to this operand's dim order before comparing data.
        if dims:
            grad = grad.align(tuple(dims))
        assert isinstance(grad, funsor.Tensor)
        assert reference.shape == grad.data.shape
        assert torch.allclose(reference, grad.data, atol=1e-7)
def test_plated_einsum(equation, plates, backend):
    """Check the naive plated einsum, both eagerly and after optimization, against Pyro.

    The naive plated einsum is built under ``reflect``, passed through
    ``apply_optimizer`` and reinterpreted. It must agree with the eagerly evaluated
    naive result. That eager result must in turn match ``pyro_einsum``
    (``modulo_total=False``) in shape and values, and carry every output dim with
    the expected size.
    """
    _inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    einsum_kwargs = dict(plates=plates, backend=backend)

    expected = pyro_einsum(equation, *operands, modulo_total=False, **einsum_kwargs)[0]

    with interpretation(reflect):
        lazy_naive = naive_plated_einsum(equation, *funsor_operands, **einsum_kwargs)
        lazy_optimized = apply_optimizer(lazy_naive)
    optimized = reinterpret(lazy_optimized)  # eager by default
    eager = naive_plated_einsum(equation, *funsor_operands, **einsum_kwargs)

    out_dims = tuple(outputs[0])
    if out_dims:
        eager = eager.align(out_dims)
        optimized = optimized.align(out_dims)

    tolerance = 1e-3 if backend == 'torch' else 1e-4
    assert_close(eager, optimized, atol=tolerance)

    assert expected.shape == eager.data.shape
    assert torch.allclose(expected, eager.data)
    for out in outputs:
        for dim in out:
            assert dim in eager.inputs
            assert eager.inputs[dim].dtype == sizes[dim]
def test_einsum_adjoint(einsum_impl, equation, backend):
    """Compare adjoints recorded by ``AdjointTape`` with Pyro's einsum backward pass.

    The forward einsum is traced on an ``AdjointTape``. ``tape.adjoint`` is then run
    with the (sum, product) ops that ``BACKEND_ADJOINT_OPS`` maps ``backend`` to. Each
    operand's adjoint must match the ``_pyro_backward_result`` computed by
    ``pyro_einsum(..., modulo_total=True)``.
    """
    inputs, _, _, operands, funsor_operands = make_einsum_example(equation)
    sum_op, prod_op = BACKEND_ADJOINT_OPS[backend]

    with AdjointTape() as tape:
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = tape.adjoint(sum_op, prod_op, fwd_expr, funsor_operands)

    for operand in operands:
        pyro_require_backward(operand)
    expected_out = pyro_einsum(equation, *operands,
                               modulo_total=True,
                               backend=backend)[0]
    expected_out._pyro_backward()

    for inp, tv, fv in zip(inputs, operands, funsor_operands):
        actual = actuals[fv]
        expected = tv._pyro_backward_result
        # Put the adjoint's inputs in this operand's dim order before comparing data.
        if inp:
            actual = actual.align(tuple(inp))
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
def test_optimized_einsum(equation, backend, einsum_impl):
    """Check that an optimized einsum expression evaluates to Pyro's einsum result.

    The einsum is built under ``normalize``, passed through ``apply_optimizer`` and
    reinterpreted eagerly. The resulting ``funsor.Tensor`` must match
    ``pyro_einsum`` in shape and values, and expose every output dim with the
    expected size.
    """
    _inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, backend=backend)[0]

    with interpretation(normalize):
        normalized_ast = einsum_impl(equation, *funsor_operands, backend=backend)
    result = reinterpret(apply_optimizer(normalized_ast))  # eager by default

    assert isinstance(result, funsor.Tensor) and len(outputs) == 1
    out_dims = tuple(outputs[0])
    if out_dims:
        result = result.align(out_dims)

    assert expected.shape == result.data.shape
    assert torch.allclose(expected, result.data)
    for out in outputs:
        for dim in out:
            assert dim in result.inputs
            assert result.inputs[dim].dtype == sizes[dim]
def test_optimized_plated_einsum(equation, plates, backend):
    """Check funsor's plated ``einsum`` against Pyro and, for short equations, the naive version.

    When ``len(equation) < 10`` the result is also cross-checked with
    ``naive_plated_einsum``. The result must be a ``funsor.Tensor`` that matches
    ``pyro_einsum`` in shape and values, and carries every output dim with the
    expected size.
    """
    _inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, plates=plates, backend=backend)[0]
    result = einsum(equation, *funsor_operands, plates=plates, backend=backend)

    # Cross-check against the naive implementation only for short equations.
    if len(equation) < 10:
        naive_result = naive_plated_einsum(
            equation, *funsor_operands, plates=plates, backend=backend)
        assert_close(result, naive_result)

    assert isinstance(result, funsor.Tensor) and len(outputs) == 1
    out_dims = tuple(outputs[0])
    if out_dims:
        result = result.align(out_dims)

    assert expected.shape == result.data.shape
    assert torch.allclose(expected, result.data)
    for out in outputs:
        for dim in out:
            assert dim in result.inputs
            assert result.inputs[dim].dtype == sizes[dim]