def test_normalize_einsum(equation, plates, backend, einsum_impl):
    """Check that the ``normalize`` interpretation canonicalizes an einsum.

    The einsum described by ``equation`` and ``plates`` is first built under
    the ``reflect`` interpretation, then reinterpreted under ``normalize``.
    The test asserts that:

    * the normalized expression is a ``Contraction`` with the same inputs and
      output as the reflected expression;
    * every term of that ``Contraction`` is a ``Number``, ``Tensor`` or
      ``Contraction``;
    * normalizing the normalized expression again returns the identical
      object, so it is a fixed point;
    * eagerly evaluating the normalized and the reflected expressions gives
      close results (``rtol=1e-4``);
    * ``eval(quote(...))`` of the eager result reproduces that result.
    """
    # Only the funsor operands are used; the remaining example pieces are
    # unpacked explicitly so a wrong tuple arity still fails loudly.
    _inputs, _outputs, _sizes, _operands, funsor_operands = make_einsum_example(
        equation
    )

    # Build the expression under `reflect` so that `normalize` receives the
    # expression graph rather than an already-evaluated value
    # (presumably; depends on reflect's semantics).
    with interpretation(reflect):
        lazy_expr = einsum_impl(
            equation, *funsor_operands, backend=backend, plates=plates
        )

    with interpretation(normalize):
        normal_form = reinterpret(lazy_expr)

    assert isinstance(normal_form, Contraction)
    check_funsor(normal_form, lazy_expr.inputs, lazy_expr.output)
    allowed_term_types = (Number, Tensor, Contraction)
    assert all(isinstance(term, allowed_term_types) for term in normal_form.terms)

    # Re-normalizing must hand back the very same object (check normalization).
    with interpretation(normalize):
        renormalized = reinterpret(normal_form)
    assert renormalized is normal_form

    # Both routes must evaluate to (numerically) the same value.
    with interpretation(eager):
        eager_from_normal = reinterpret(normal_form)
        eager_from_lazy = reinterpret(lazy_expr)
    assert_close(eager_from_normal, eager_from_lazy, rtol=1e-4)

    # `quote` renders the value as evaluable text. The names it references
    # must be in scope here (the original note says torch and bint are needed).
    # The text is produced locally by the test, not read from untrusted input.
    round_tripped = eval(quote(eager_from_lazy))
    assert_close(round_tripped, eager_from_lazy)
def quote(self):
    """Return the result of applying the outer-scope ``quote`` to ``self``.

    NOTE(review): this relies on the name ``quote`` in the body resolving to
    a function outside this definition. That holds when this ``def`` lives
    inside a class body, because class scope is not visible to method bodies.
    If it were ever placed at module scope, it would shadow that function and
    call itself until RecursionError. Confirm its placement.
    """
    quoted = quote(self)
    return quoted