def test_constant_unification(self):
    """Check that a pattern constant unifies with an equal-valued graph constant.

    The pattern ``(op1, z, "1")`` names the constant ``z``. The graph instead
    contains ``op1(x, y)``, where ``x`` is a different Constant object holding
    the same value (2). The asserted output shows that the subterm is matched
    and rewritten to ``op2(y, z)``.
    """
    x = Constant(MyType(), 2, name="x")
    y = MyVariable("y")
    z = Constant(MyType(), 2, name="z")
    inner = op1(x, y)
    graph = FunctionGraph([y], [op1(inner, y)])
    rewrite = PatternOptimizer((op1, z, "1"), (op2, "1", z))
    rewrite.optimize(graph)
    assert str(graph) == "[Op1(Op2(y, z), y)]"
def test_constant_merging(self):
    """Check that MergeOptimizer collapses equal constants and duplicate applies.

    ``y`` and ``z`` are distinct Constant objects with the same value. After
    merging, all three ``op2`` applications are rendered as one shared node
    (``*1``). Either surviving constant is accepted, because the test does not
    pin down which one is kept.
    """
    x = MyVariable("x")
    y = Constant(MyType(), 2, name="y")
    z = Constant(MyType(), 2, name="z")
    out = op1(op2(x, y), op2(x, y), op2(x, z))
    graph = FunctionGraph([x, y, z], [out])
    MergeOptimizer().optimize(graph)
    accepted = (
        "[Op1(*1 -> Op2(x, y), *1, *1)]",
        "[Op1(*1 -> Op2(x, z), *1, *1)]",
    )
    assert str(graph) in accepted
def test_identical_constant_args(self):
    """Check that two equal constants fed to the same op merge into one input.

    ``config.compute_test_value`` is set to ``"off"`` only while the apply node
    is built. The previous value is restored in ``finally``, even if
    construction raises.
    """
    x = MyVariable("x")
    y = Constant(MyType(), 2, name="y")
    z = Constant(MyType(), 2, name="z")
    saved_flag = config.compute_test_value
    config.compute_test_value = "off"
    try:
        node_out = op1(y, z)
    finally:
        config.compute_test_value = saved_flag
    graph = FunctionGraph([x, y, z], [node_out])
    MergeOptimizer().optimize(graph)
    rendered = str(graph)
    # Either constant may be the one kept after merging.
    assert rendered in ("[Op1(y, y)]", "[Op1(z, z)]")
def test_c_fail_error():
    """Check that a failure inside an OpWiseCLinker-compiled function raises.

    The graph contains ``add_fail``. The test expects calling the compiled
    function to raise ``RuntimeError`` (presumably ``add_fail`` deliberately
    signals an error from its C code -- confirm against its definition).
    """
    _, y, z = inputs()
    const_x = Constant(tdouble, 7.2, name="x")
    expr = add_fail(mul(const_x, y), mul(y, z))
    linker = OpWiseCLinker().accept(Env([y, z], [expr]))
    compiled = linker.make_function()
    with pytest.raises(RuntimeError):
        compiled(1.5, 3.0)
def test_opwiseclinker_constant():
    """Check that OpWiseCLinker evaluates a graph containing a named Constant.

    The graph is ``add(mul(x, y), mul(y, z))`` with ``x`` fixed to the
    constant 7.2. ``fn(1.5, 3.0)`` is therefore expected to equal
    ``7.2 * 1.5 + 1.5 * 3.0 == 15.3``.
    """
    # Only y and z are graph inputs; x is replaced by a constant below.
    _, y, z = inputs()
    x = Constant(tdouble, 7.2, name="x")
    e = add(mul(x, y), mul(y, z))
    lnk = OpWiseCLinker().accept(Env([y, z], [e]))
    fn = lnk.make_function()
    res = fn(1.5, 3.0)
    # The value is produced by compiled C floating-point code, so exact
    # equality with the Python literal 15.3 is fragile. Compare with a small
    # absolute tolerance, as the literal-inlining test in this file does.
    assert abs(res - 15.3) < 1e-9
def test_clinker_literal_inlining():
    """Check that CLinker inlines an unnamed Constant as a literal in the C code.

    With ``x = y = 2.0`` the expected result is ``-0.12345678``: the first
    term evaluates to 4 and the second to ``0 - 4.12345678``, assuming
    ``bad_sub`` subtracts for these inputs. The generated source must also
    contain the constant's digits verbatim.
    """
    # z is replaced by a constant, so only x and y are graph inputs.
    x, y, _ = inputs()
    z = Constant(tdouble, 4.12345678)
    e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
    lnk = CLinker().accept(Env([x, y], [e]))
    fn = lnk.make_function()
    assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9
    code = lnk.code_gen()
    # The constant's value should appear as a literal in the generated code.
    assert "4.12345678" in code
def test_constant(self):
    """Check that perform_linker evaluates a graph with a Constant operand.

    With the second operand fixed to 2.0, ``mul(add(x, 2), div(x, 2))`` at
    ``x = 1.0`` is expected to give ``(1 + 2) * (1 / 2) == 1.5``.
    """
    x, _, _ = inputs()
    two = Constant(tdouble, 2.0)
    out = mul(add(x, two), div(x, two))
    fn = perform_linker(FunctionGraph([x], [out])).make_function()
    assert fn(1.0) == 1.5