def test_optimized_plated_einsum_adjoint(equation, plates, backend):
    """Compare funsor's ``adjoint`` of a plated einsum against Pyro's backward.

    The forward einsum over ``funsor_operands`` is built under the ``reflect``
    interpretation. ``adjoint`` is then evaluated on that expression with
    respect to each funsor operand. Expected values are the
    ``_pyro_backward_result`` tensors that Pyro's einsum backward pass
    (``pyro_einsum(..., modulo_total=False)`` followed by ``_pyro_backward()``)
    attaches to the corresponding raw ``operands``.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)

    # Only the forward einsum is constructed under ``reflect``; the adjoint
    # itself is computed outside that interpretation.
    with interpretation(reflect):
        fwd_expr = einsum(equation, *funsor_operands, plates=plates,
                          backend=backend)
    actuals = adjoint(fwd_expr, funsor_operands)

    # Reference values: run Pyro's einsum and its backward pass on the raw
    # operands so each one carries a ``_pyro_backward_result``.
    for operand in operands:
        pyro_require_backward(operand)
    expected_out = pyro_einsum(equation, *operands, modulo_total=False,
                               plates=plates, backend=backend)[0]
    expected_out._pyro_backward()

    for inp, tv, fv in zip(inputs, operands, funsor_operands):
        actual = actuals[fv]
        expected = tv._pyro_backward_result
        if inp:
            # Put the funsor's inputs in the order given by ``inp`` so that
            # ``actual.data`` can be compared element-wise with ``expected``.
            actual = actual.align(tuple(inp))
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
def test_einsum_adjoint(einsum_impl, equation, backend):
    """Compare ``AdjointTape.adjoint`` of an einsum against Pyro's backward.

    The forward einsum (computed by ``einsum_impl``) is recorded on an
    ``AdjointTape``. ``tape.adjoint`` is then evaluated with the backend's
    ``(sum_op, prod_op)`` pair from ``BACKEND_ADJOINT_OPS``. Expected values
    are the ``_pyro_backward_result`` tensors that Pyro's einsum backward pass
    (``pyro_einsum(..., modulo_total=True)`` followed by ``_pyro_backward()``)
    attaches to the corresponding raw ``operands``.
    """
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    sum_op, prod_op = BACKEND_ADJOINT_OPS[backend]

    # Record the forward computation on the tape, then run the adjoint pass
    # once recording has finished.
    with AdjointTape() as tape:
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = tape.adjoint(sum_op, prod_op, fwd_expr, funsor_operands)

    # Reference values: run Pyro's einsum and its backward pass on the raw
    # operands so each one carries a ``_pyro_backward_result``.
    for operand in operands:
        pyro_require_backward(operand)
    expected_out = pyro_einsum(equation, *operands, modulo_total=True,
                               backend=backend)[0]
    expected_out._pyro_backward()

    for inp, tv, fv in zip(inputs, operands, funsor_operands):
        actual = actuals[fv]
        expected = tv._pyro_backward_result
        if inp:
            # Put the funsor's inputs in the order given by ``inp`` so that
            # ``actual.data`` can be compared element-wise with ``expected``.
            actual = actual.align(tuple(inp))
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)