def test_tuple():
    """Projecting field 1 out of a literal tuple should reduce to that field.

    The function under test returns ``(4.0, x)[1]``; after ``dcpe`` it is
    expected to be alpha-equal to the type-inferred identity function on ``x``.
    """
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)
    pair = relay.Tuple([relay.const(4.0), arg])
    projection = TupleGetItem(pair, 1)
    func = Function([arg], projection, None, [ty_param])

    # Reference result: the identity function, run through type inference.
    identity = relay.Function([arg], arg, None, [ty_param])
    identity = transform.OptimizeOnExpr(identity, transform.InferType())

    assert alpha_equal(dcpe(func), identity)
def test_tuple():
    """Projecting field 1 out of a literal tuple should reduce to that field.

    The function under test returns ``(4.0, x)[1]``; after ``dcpe`` it is
    expected to be structurally equal to the type-inferred identity function.
    """
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)
    pair = relay.Tuple([relay.const(4.0), arg])
    projection = TupleGetItem(pair, 1)
    func = Function([arg], projection, None, [ty_param])

    # Reference result: the identity function, run through type inference.
    identity = relay.Function([arg], arg, None, [ty_param])
    identity = run_opt_pass(identity, transform.InferType())

    assert tvm.ir.structural_equal(dcpe(func), identity)
def test_head_cons():
    """``hd(cons(x, nil))`` should reduce to ``x`` under ``dcpe``."""
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)

    # Build a one-element list and immediately take its head.
    head_of_singleton = prelude.hd(prelude.cons(arg, prelude.nil()))
    func = Function([arg], head_of_singleton, None, [ty_param])

    reduced = dcpe(func, mod)
    identity = Function([arg], arg, ty_param, [ty_param])
    assert tvm.ir.structural_equal(reduced, identity)
def test_loop():
    """A call to a self-recursive global should come back from ``dcpe`` as a call.

    ``loop`` is defined as ``loop(x) = loop(x)``; the test expects
    ``dcpe(loop(1))`` to be alpha-equal to a plain ``Call`` node on ``loop``.
    """
    mod = Module()
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)
    loop = GlobalVar("loop")
    mod[loop] = Function([arg], loop(arg), ty_param, [ty_param])

    reduced = dcpe(loop(const(1)), mod=mod)
    reference = Call(loop, [const(1)], None, [None])
    assert alpha_equal(reduced, reference)
def test_head_cons():
    """``hd(cons(x, nil))`` should reduce to ``x`` under ``dcpe``."""
    mod = Module()
    prelude = Prelude(mod)
    head_fn = prelude.hd
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)

    # Build a one-element list and immediately take its head.
    head_of_singleton = head_fn(prelude.cons(arg, prelude.nil()))
    func = Function([arg], head_of_singleton, None, [ty_param])

    reduced = dcpe(func, mod)
    identity = Function([arg], arg, ty_param, [ty_param])
    assert alpha_equal(reduced, identity)
def hd_impl():
    """Build a Relay function that returns the head of a list.

    The function takes one parameter of type ``p.l(a)`` and matches it against
    the ``p.cons`` constructor, returning the first bound pattern variable.
    Only the cons case is given a clause.

    ``p`` is read from the enclosing scope.

    Returns:
        A ``Function`` with type parameter ``a`` and return type ``a``.
    """
    a = TypeVar("a")
    x = Var("x", p.l(a))
    y = Var("y")
    z = Var("z")
    cons_case = Clause(
        PatternConstructor(p.cons, [PatternVar(y), PatternVar(z)]), y)
    # The original re-bound y and z to fresh Vars here; they were never read
    # again, so those dead assignments have been dropped.
    return Function([x], Match(x, [cons_case]), a, [a])
def test_head_cons():
    """``hd(cons(x, nil))`` should reduce to ``x`` under ``dcpe``.

    The list constructors and ``hd`` are looked up from the prelude's module,
    and the expected function is read back from a module built with
    ``IRModule.from_expr``.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)

    _list_ty, cons, nil = prelude.mod.get_type("List")
    head_fn = prelude.mod.get_global_var("hd")

    # Build a one-element list and immediately take its head.
    head_of_singleton = head_fn(cons(arg, nil()))
    func = Function([arg], head_of_singleton, None, [ty_param])
    reduced = dcpe(func, mod)

    identity = Function([arg], arg, ty_param, [ty_param])
    reference_mod = tvm.IRModule.from_expr(identity)
    assert tvm.ir.structural_equal(reduced, reference_mod["main"])
def test_loop():
    """A call to a self-recursive global should come back from ``dcpe`` as a call.

    ``loop`` is defined as ``loop(x) = loop(x)``. The expected body is obtained
    by storing ``loop(1)`` as the module's ``"main"`` function and reading its
    body back, then compared with the body of the ``dcpe`` result.
    """
    mod = tvm.IRModule()
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)
    loop = GlobalVar("loop")
    mod[loop] = Function([arg], loop(arg), ty_param, [ty_param])

    # Round-trip the reference call through the module before comparing.
    mod["main"] = Function([], Call(loop, [const(1)]))
    reference_body = mod["main"].body

    wrapper = Function([], loop(const(1)))
    reduced = dcpe(wrapper, mod=mod)
    assert tvm.ir.structural_equal(reduced.body, reference_body)
def test_loop():
    """A call to a self-recursive global should come back from ``dcpe`` as a call.

    ``loop`` is defined as ``loop(x) = loop(x)``. The expected body is obtained
    by storing ``loop(1)`` under the module's entry function and reading its
    body back, then compared with the body of the ``dcpe`` result.
    """
    mod = Module()
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)
    loop = GlobalVar("loop")
    mod[loop] = Function([arg], loop(arg), ty_param, [ty_param])

    # Round-trip the reference call through the module before comparing.
    mod[mod.entry_func] = Function([], Call(loop, [const(1)]))
    reference_body = mod[mod.entry_func].body

    wrapper = Function([], loop(const(1)))
    reduced = dcpe(wrapper, mod=mod)
    assert alpha_equal(reduced.body, reference_body)
def test_map():
    """Mapping an identity global over ``[1, 2, 3]`` should yield ``[1, 2, 3]``.

    The expected list is stored as the module's ``"main"`` function and read
    back before its body is compared with the body of the ``dcpe`` result.
    """
    mod = Module()
    prelude = Prelude(mod)

    # Global identity function used as the mapped callable.
    identity = GlobalVar("f")
    ty_param = TypeVar("t")
    param = Var("a", ty_param)
    mod[identity] = Function([param], param, ty_param, [ty_param])

    source_list = prelude.cons(
        const(1), prelude.cons(const(2), prelude.cons(const(3), prelude.nil())))
    mapped = prelude.map(identity, source_list)

    reference_list = prelude.cons(
        const(1), prelude.cons(const(2), prelude.cons(const(3), prelude.nil())))
    mod["main"] = Function([], reference_list)
    reference = mod["main"]

    reduced = dcpe(Function([], mapped), mod=mod)
    assert alpha_equal(reduced.body, reference.body)
def test_map():
    """Mapping an identity global over ``[1, 2, 3]`` should yield ``[1, 2, 3]``.

    The expected list is stored as the module's ``"main"`` function, the module
    is run through ``InferType``, and the resulting body is compared with the
    body of the ``dcpe`` result.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    _list_ty, cons, nil = prelude.mod.get_type("List")
    map_fn = prelude.mod.get_global_var("map")

    # Global identity function used as the mapped callable.
    identity = GlobalVar("f")
    ty_param = TypeVar("t")
    param = Var("a", ty_param)
    mod[identity] = Function([param], param, ty_param, [ty_param])

    source_list = cons(const(1), cons(const(2), cons(const(3), nil())))
    mapped = map_fn(identity, source_list)

    reference_list = cons(const(1), cons(const(2), cons(const(3), nil())))
    mod["main"] = Function([], reference_list)
    mod = transform.InferType()(mod)
    reference = mod["main"]

    reduced = dcpe(Function([], mapped), mod=mod)
    assert tvm.ir.structural_equal(reduced.body, reference.body)
def test_head_cons():
    """A let-bound head function applied to ``cons(x, nil)`` should reduce to ``x``.

    The head function is built locally by ``hd_impl`` and bound with ``Let``.
    The enclosing function is type-inferred against the prelude module before
    ``dcpe`` is applied, and the result is compared with the identity function.
    """
    mod = Module()
    p = Prelude(mod)

    def hd_impl():
        """Return a Relay function that matches a list on ``p.cons`` and yields its head."""
        a = TypeVar("a")
        x = Var("x", p.l(a))
        y = Var("y")
        z = Var("z")
        cons_case = Clause(
            PatternConstructor(p.cons, [PatternVar(y), PatternVar(z)]), y)
        # The original re-bound y and z to fresh Vars here; they were never
        # read again, so those dead assignments have been dropped.
        return Function([x], Match(x, [cons_case]), a, [a])

    t = TypeVar("t")
    x = Var("x", t)
    hd = Var("hd")
    body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
    f = Function([x], body, None, [t])
    f = infer_type(f, mod=mod)
    res = dcpe(f)
    assert alpha_equal(res, Function([x], x, t, [t]))
def test_tuple():
    """Projecting field 1 out of a literal tuple should reduce to that field.

    The function under test returns ``(4.0, x)[1]``; after ``dcpe`` it is
    expected to be alpha-equal to the identity function on ``x``.
    """
    ty_param = TypeVar("t")
    arg = Var("x", ty_param)
    pair = relay.Tuple([relay.const(4.0), arg])
    func = Function([arg], TupleGetItem(pair, 1), None, [ty_param])

    reduced = dcpe(func)
    identity = relay.Function([arg], arg, None, [ty_param])
    assert alpha_equal(reduced, identity)