def test_explicit_bound():
    """Check that ToANormalForm introduces Let into a let-free function.

    The function body is built purely from nested ``op.add`` calls on the
    constant 1, so no Let is expected before the pass and one is expected
    afterwards.  Both versions are handed to ``check_eval`` with 8.0.
    """
    one = relay.const(1)
    two = op.add(one, one)
    four = op.add(two, two)
    graph_fn = relay.Function([], op.add(four, four))
    assert Feature.fLet not in detect_feature(graph_fn)

    anf_fn = run_opt_pass(graph_fn, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(anf_fn)

    # Original function first, then the transformed one.
    for fn in (graph_fn, anf_fn):
        check_eval(fn(), 8.0)
def test_no_explicit_bind():
    """Check that ToBasicBlockNormalForm adds no Let to a let-free function.

    The function body is built from nested ``op.add`` calls on the constant
    1.  Neither the original nor the transformed function may report
    ``Feature.fLet``; both are handed to ``check_eval`` with 8.0, and the
    transformed function is passed to ``check_basic_block_normal_form``.
    """
    x = relay.const(1)
    y = op.add(x, x)
    z = op.add(y, y)
    f = relay.Function([], op.add(z, z))
    # Textual form of f, as noted by the original author:
    #
    #   fn () {
    #     %0 = add(1, 1);
    #     %1 = add(%0, %0);
    #     add(%1, %1)
    #   }
    assert Feature.fLet not in detect_feature(f)

    bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm())
    assert Feature.fLet not in detect_feature(bblock)

    check_eval(f(), 8.0)
    check_eval(bblock(), 8.0)
    check_basic_block_normal_form(bblock)
def test_implicit_share():
    """Check that ToGraphNormalForm removes Let from a let-chained function.

    The function is built from three nested ``relay.Let`` bindings, so the
    original must report ``Feature.fLet`` and the transformed one must not.
    Both are handed to ``check_eval`` with no arguments and 8.0.
    """
    var_x, var_y, var_z = (relay.Var(name) for name in ("x", "y", "z"))

    # Build the let chain innermost-first.
    let_z = relay.Let(var_z, op.add(var_y, var_y), op.add(var_z, var_z))
    let_y = relay.Let(var_y, op.add(var_x, var_x), let_z)
    let_fn = relay.Function([], relay.Let(var_x, relay.const(1), let_y))

    graph_fn = run_opt_pass(let_fn, transform.ToGraphNormalForm())

    assert Feature.fLet in detect_feature(let_fn)
    assert Feature.fLet not in detect_feature(graph_fn)

    for fn in (let_fn, graph_fn):
        check_eval(fn, [], 8.0)
def test_round_trip():
    """Run a let-chained function through graph normal form and back to ANF.

    The original must report ``Feature.fLet`` and the graph-normal-form
    version must not.  The original, the graph form and the re-ANF'd form
    are each handed to ``check_eval`` with no arguments and 8.0.
    """
    var_x, var_y, var_z = (relay.Var(name) for name in ("x", "y", "z"))

    # Build the let chain innermost-first.
    let_z = relay.Let(var_z, op.add(var_y, var_y), op.add(var_z, var_z))
    let_y = relay.Let(var_y, op.add(var_x, var_x), let_z)
    let_fn = relay.Function([], relay.Let(var_x, relay.const(1), let_y))

    graph_fn = run_opt_pass(let_fn, transform.ToGraphNormalForm())
    anf_fn = run_opt_pass(graph_fn, transform.ToANormalForm())

    assert Feature.fLet in detect_feature(let_fn)
    assert Feature.fLet not in detect_feature(graph_fn)

    for fn in (let_fn, graph_fn, anf_fn):
        check_eval(fn, [], 8.0)
def test_ad():
    """Check the exact feature set of the gradient of ``fn (x) { x + x }``.

    The input is a (10, 10) float32 tensor variable.  The function is
    type-inferred, passed through ``gradient``, wrapped in an IRModule and
    type-inferred again; the features of its ``"main"`` entry must equal
    the set listed below.
    """
    tensor_ty = relay.TensorType((10, 10), "float32")
    inp = relay.var("x", tensor_ty)
    typed_fn = run_infer_type(relay.Function([inp], inp + inp))

    grad_mod = tvm.IRModule.from_expr(gradient(typed_fn))
    grad_mod = relay.transform.InferType()(grad_mod)

    found = detect_feature(grad_mod["main"])
    assert found == {
        Feature.fVar,
        Feature.fTuple,
        Feature.fTupleGetItem,
        Feature.fFunction,
        Feature.fOp,
        Feature.fCall,
        Feature.fLet,
        Feature.fRefCreate,
        Feature.fRefRead,
        Feature.fRefWrite,
    }
def test_prelude():
    """Check the exact feature set detected on a fresh Prelude's module."""
    prelude = Prelude()
    found = detect_feature(prelude.mod)
    assert found == {
        Feature.fVar,
        Feature.fGlobalVar,
        Feature.fConstant,
        Feature.fTuple,
        Feature.fTupleGetItem,
        Feature.fFunction,
        Feature.fOp,
        Feature.fCall,
        Feature.fLet,
        Feature.fIf,
        Feature.fConstructor,
        Feature.fMatch,
    }
def destroy_ref(x):
    """Assert that no RefCreate remains after CPS conversion and cleanup.

    ``x`` is type-inferred, converted with ``to_cps`` and type-inferred
    again.  The CPS form is then run through PartialEvaluate followed by
    DeadCodeElimination, and the result must not report
    ``Feature.fRefCreate``.
    """
    cps_expr = run_infer_type(to_cps(run_infer_type(x)))

    # The un-CPS'd expression is only type-inferred; its value is not used
    # afterwards, but the calls are kept so any error they raise still fails
    # the caller.
    run_infer_type(un_cps(cps_expr))

    cleanup = transform.Sequential(
        [
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ]
    )
    optimized = run_opt_pass(cps_expr, cleanup)
    assert Feature.fRefCreate not in detect_feature(optimized)
def test_nat_add():
    """Exercise ``nat_add`` before and after ToANormalForm.

    Loads ``nat.rly`` into the prelude module, checks the declared type of
    ``nat_add``, checks that ``count`` of the evaluated
    ``add(s(z()), s(z()))`` is 2 both directly and via a global ``f`` in the
    transformed module, and checks that the transformed ``nat_add`` reports
    ``Feature.fLet``.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    prelude.mod.import_from_std("nat.rly")
    nat_ty, zero, succ = prelude.mod.get_type("nat")
    nat_add = prelude.mod.get_global_var("nat_add")

    device = tvm.device("llvm", 0)
    executor = create_executor(mod=mod, device=device, target="llvm")

    expected_ty = relay.FuncType([nat_ty(), nat_ty()], nat_ty())
    assert mod[nat_add].checked_type == expected_ty

    direct = executor.evaluate(nat_add(succ(zero()), succ(zero())))
    assert count(prelude, direct) == 2

    # Register the same sum as global "f", then normalize the whole module.
    sum_expr = nat_add(succ(zero()), succ(zero()))
    mod[relay.GlobalVar("f")] = relay.Function([], sum_expr)
    anf_mod = transform.ToANormalForm()(mod)

    anf_f = anf_mod["f"]
    assert count(prelude, executor.evaluate(anf_f.body)) == 2
    assert Feature.fLet in detect_feature(anf_mod[nat_add])
def destroy_ref(x):
    """Assert that no RefCreate remains after CPS conversion and cleanup.

    ``x`` is type-inferred, converted with ``to_cps`` and type-inferred
    again.  The CPS form is then run through PartialEvaluate, InferType and
    DeadCodeElimination (with ``ignore_impurity=True``), and the result must
    not report ``Feature.fRefCreate``.
    """
    cps_expr = run_infer_type(to_cps(run_infer_type(x)))

    # The un-CPS'd expression is only type-inferred; its value is not used
    # afterwards, but the calls are kept so any error they raise still fails
    # the caller.
    run_infer_type(un_cps(cps_expr))

    # TODO(mbs): Revisit once DCE can eliminate dead writes.
    cleanup = tvm.transform.Sequential(
        [
            transform.PartialEvaluate(),
            transform.InferType(),
            transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
        ]
    )
    optimized = run_opt_pass(cps_expr, cleanup)
    assert Feature.fRefCreate not in detect_feature(optimized)
def test_nat_add():
    """Exercise the prelude's nat ``add`` before and after ToANormalForm.

    Calls ``add_nat_definitions`` on the prelude, checks the declared type
    of ``add``, checks that ``count`` of the evaluated
    ``add(s(z()), s(z()))`` is 2 both directly and via a global ``f`` in the
    transformed module, and checks that the transformed ``add`` reports
    ``Feature.fLet``.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    nat_ty = prelude.nat
    nat_add = prelude.add
    succ = prelude.s
    zero = prelude.z

    context = tvm.context("llvm", 0)
    executor = create_executor(mod=mod, ctx=context, target="llvm")

    expected_ty = relay.FuncType([nat_ty(), nat_ty()], nat_ty())
    assert mod[nat_add].checked_type == expected_ty

    direct = executor.evaluate(nat_add(succ(zero()), succ(zero())))
    assert count(prelude, direct) == 2

    # Register the same sum as global "f", then normalize the whole module.
    sum_expr = nat_add(succ(zero()), succ(zero()))
    mod[relay.GlobalVar("f")] = relay.Function([], sum_expr)
    anf_mod = transform.ToANormalForm()(mod)

    anf_f = anf_mod["f"]
    assert count(prelude, executor.evaluate(anf_f.body)) == 2
    assert Feature.fLet in detect_feature(anf_mod[nat_add])
def test_nat_add():
    """Exercise the prelude's nat ``add`` under ToBasicBlockNormalForm.

    Checks the declared type of ``add``, checks that ``count`` of the
    evaluated ``add(s(z()), s(z()))`` is 2 both directly and via a global
    ``f`` in the transformed module, checks that the transformed ``add``
    does not report ``Feature.fLet``, and passes the transformed ``f`` to
    ``check_basic_block_normal_form``.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    nat_ty = prelude.nat
    nat_add = prelude.add
    succ = prelude.s
    zero = prelude.z

    context = tvm.context("llvm", 0)
    executor = create_executor(mod=mod, ctx=context, target="llvm")

    expected_ty = relay.FuncType([nat_ty(), nat_ty()], nat_ty())
    assert mod[nat_add].checked_type == expected_ty

    direct = executor.evaluate(nat_add(succ(zero()), succ(zero())))
    assert count(prelude, direct) == 2

    # Register the same sum as global "f", then normalize the whole module.
    sum_expr = nat_add(succ(zero()), succ(zero()))
    mod[relay.GlobalVar("f")] = relay.Function([], sum_expr)
    bbnf_mod = transform.ToBasicBlockNormalForm()(mod)

    bbnf_f = bbnf_mod["f"]
    assert count(prelude, executor.evaluate(bbnf_f.body)) == 2
    assert Feature.fLet not in detect_feature(bbnf_mod[nat_add])
    check_basic_block_normal_form(bbnf_f)
def test_nat_add():
    """Exercise ``nat_add`` under InferType + ToBasicBlockNormalForm.

    Loads ``nat.rly`` into the prelude module, checks the declared type of
    ``nat_add``, checks that ``count`` of the evaluated
    ``add(s(z()), s(z()))`` is 2 both directly and via a global ``f`` in the
    transformed module, checks that the transformed ``nat_add`` does not
    report ``Feature.fLet``, and passes the transformed ``f`` to
    ``check_basic_block_normal_form``.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    prelude.mod.import_from_std("nat.rly")
    nat_ty, zero, succ = prelude.mod.get_type("nat")
    nat_add = prelude.mod.get_global_var("nat_add")

    context = tvm.context("llvm", 0)
    executor = create_executor(mod=mod, ctx=context, target="llvm")

    expected_ty = relay.FuncType([nat_ty(), nat_ty()], nat_ty())
    assert mod[nat_add].checked_type == expected_ty

    direct = executor.evaluate(nat_add(succ(zero()), succ(zero())))
    assert count(prelude, direct) == 2

    # Register the same sum as global "f", then type-infer and normalize the
    # whole module.
    sum_expr = nat_add(succ(zero()), succ(zero()))
    mod[relay.GlobalVar("f")] = relay.Function([], sum_expr)
    typed_mod = transform.InferType()(mod)
    bbnf_mod = transform.ToBasicBlockNormalForm()(typed_mod)

    bbnf_f = bbnf_mod["f"]
    assert count(prelude, executor.evaluate(bbnf_f.body)) == 2
    assert Feature.fLet not in detect_feature(bbnf_mod[nat_add])
    check_basic_block_normal_form(bbnf_f)