def test_triangle_number():
    """Check that `dcpe` reduces a recursive triangle-number program to 55.

    The program is `let f = fn(x) { if x == 0 then 0 else x + f(x - 1) }`
    applied to the constant 10, with `x` typed as a scalar int32 tensor.
    """
    scalar_int = relay.TensorType([], "int32")
    n = Var("x", scalar_int)
    rec = Var("f")
    is_zero = op.equal(n, const(0))
    triangle = Function([n], If(is_zero, const(0), n + rec(n - const(1))))
    program = Let(rec, triangle, rec(const(10)))
    typed_program = run_infer_type(program)
    assert_alpha_equal(dcpe(typed_program), const(55))
def hd_impl():
    """Build a polymorphic function that returns the head of a cons cell.

    The function takes `x` of type `p.l(a)` and matches it against a single
    `p.cons` clause whose result is the variable bound to the first field.
    The match contains only that one clause. `p` comes from the enclosing
    scope.

    Returns:
        A `Function` with return type `a` and type parameter list `[a]`.
    """
    a = TypeVar("a")
    x = Var("x", p.l(a))
    # y binds the first constructor field, z the second; only y is returned.
    y = Var("y")
    z = Var("z")
    cons_case = Clause(PatternConstructor(p.cons, [PatternVar(y), PatternVar(z)]), y)
    return Function([x], Match(x, [cons_case]), a, [a])
def test_nat_id():
    """Check that `dcpe` reduces a call to a global identity on nats.

    Registers `nat_id = fn(x) { x }` in the module, applies it to the nat
    expression for 3, and expects the result to be alpha-equal to that same
    nat expression.
    """
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = nat_id(make_nat_expr(p, 3))
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res, make_nat_expr(p, 3))
def test_swap_loop():
    """Check that `dcpe` leaves a call to a non-terminating global unchanged.

    `loop(x, y)` is defined as `loop(y, x)`; the call `loop(1, 2)` on nat
    expressions must come back alpha-equal to itself.
    """
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_ty = prelude.nat()
    first = Var("x", nat_ty)
    second = Var("y", nat_ty)
    loop_gv = GlobalVar("loop")
    module[loop_gv] = Function([first, second], loop_gv(second, first), nat_ty)
    one = make_nat_expr(prelude, 1)
    two = make_nat_expr(prelude, 2)
    program = loop_gv(one, two)
    evaluated = dcpe(program, mod=module)
    assert alpha_equal(program, evaluated)
def test_swap_loop():
    """Check that `dcpe` leaves a call to a self-swapping global unchanged.

    `loop(x, y)` is defined as `loop(y, x)`. The call `loop(1, 2)` is wrapped
    in a zero-argument function before being passed to `dcpe`, and the body
    of the result must be structurally equal to the original call.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_ty = prelude.nat()
    first = Var("x", nat_ty)
    second = Var("y", nat_ty)
    loop_gv = GlobalVar("loop")
    module[loop_gv] = Function([first, second], loop_gv(second, first), nat_ty)
    call = loop_gv(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    wrapped = Function([], call)
    evaluated = dcpe(wrapped, mod=module)
    assert tvm.ir.structural_equal(call, evaluated.body)
def test_nat_id():
    """Check that `dcpe` reduces a call to a global identity on nats.

    Registers `nat_id = fn(x) { x }`, applies it to the nat expression for 3
    inside a zero-argument wrapper function, and expects the body of the
    result to be structurally equal to the nat expression for 3.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = nat_id(make_nat_expr(p, 3))
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_swap_loop():
    """Check that `dcpe` leaves a call to a self-swapping global unchanged.

    The nat type is loaded via `import_from_std("nat.rly")`. `loop(x, y)` is
    defined as `loop(y, x)`; the wrapped call `loop(1, 2)` must come back
    with a body structurally equal to the original call.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    prelude.mod.import_from_std("nat.rly")
    nat_ctor, _, _ = prelude.mod.get_type("nat")
    first = Var("x", nat_ctor())
    second = Var("y", nat_ctor())
    loop_gv = GlobalVar("loop")
    module[loop_gv] = Function([first, second], loop_gv(second, first), nat_ctor())
    call = loop_gv(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    wrapped = Function([], call)
    evaluated = dcpe(wrapped, mod=module)
    assert tvm.ir.structural_equal(call, evaluated.body)
def test_nat_id():
    """Check that `dcpe` reduces a call to a global identity on nats.

    The nat type is loaded via `import_from_std("nat.rly")`. `nat_id` is
    registered as `fn(x) { x }`, applied to the nat expression for 3 inside a
    zero-argument wrapper, and the body of the result must be structurally
    equal to the nat expression for 3.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat, _, _ = p.mod.get_type("nat")
    x = Var("x", nat())
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = nat_id(make_nat_expr(p, 3))
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_match_nat_id():
    """Check that `dcpe` reduces an identity on nats written with `Match`.

    `nat_id` rebuilds its argument clause by clause (`z -> z`, `s(y) -> s(y)`).
    Applying it to the nat expression for 3 must give a result alpha-equal to
    that nat expression.
    """
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_ty = prelude.nat()
    arg = Var("x", nat_ty)
    pred = Var("y", nat_ty)
    nat_id = GlobalVar("nat_id")
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_pattern = PatternConstructor(prelude.s, [PatternVar(pred)])
    succ_clause = Clause(succ_pattern, prelude.s(pred))
    module[nat_id] = Function([arg], Match(arg, [zero_clause, succ_clause]))
    program = nat_id(make_nat_expr(prelude, 3))
    evaluated = dcpe(program, mod=module)
    assert alpha_equal(evaluated, make_nat_expr(prelude, 3))
def build_impl(self, input_size, memory_size, dtype="float32"):
    """Build the body of the network as a match over a `rose` tree input.

    Sets `self.ret_type` to a pair of `(1, memory_size)` tensors, registers a
    tree-typed input via `self.input`, and returns a `Match` with a single
    `rose` clause. The clause applies an `LSTMCell` to the node value and to
    the result of mapping `self` over the node's children.

    Parameters
    ----------
    input_size : size used for the `(1, input_size)` node tensors.
    memory_size : size used for the `(1, memory_size)` result tensors.
    dtype : element dtype of all tensors built here.
    """
    state_ty = TensorType(shape=(1, memory_size), dtype=dtype)
    self.ret_type = TupleType([state_ty, state_ty])
    tree_type = self.p.tree(TensorType(shape=(1, input_size), dtype=dtype))
    tree = self.input(Var("tlstm_input", tree_type))
    node_value = Var("i", TensorType(shape=(1, input_size), dtype=dtype))
    subtrees = Var("c", self.p.l(tree_type))
    cell = LSTMCell(input_size=input_size, memory_size=memory_size, dtype=dtype)
    rose_pattern = PatternConstructor(
        self.p.rose, [PatternVar(node_value), PatternVar(subtrees)]
    )
    # Built after the pattern so the evaluation order matches the clause layout.
    child_results = self.p.map(lam(["x"], self), subtrees)
    rose_case = Clause(rose_pattern, cell(node_value, child_results))
    return Match(tree, [rose_case])
def test_match_nat_id():
    """Check that `dcpe` reduces an identity on nats written with `Match`.

    The nat type and its `z`/`s` constructors come from "nat.rly". `nat_id`
    rebuilds its argument clause by clause; applied to the nat expression for
    3 inside a zero-argument wrapper, the body of the result must be
    structurally equal to the nat expression for 3.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    prelude.mod.import_from_std("nat.rly")
    nat_ctor, zero, succ = prelude.mod.get_type("nat")
    arg = Var("x", nat_ctor())
    pred = Var("y", nat_ctor())
    nat_id = GlobalVar("nat_id")
    zero_clause = Clause(PatternConstructor(zero, []), zero())
    succ_pattern = PatternConstructor(succ, [PatternVar(pred)])
    succ_clause = Clause(succ_pattern, succ(pred))
    module[nat_id] = Function([arg], Match(arg, [zero_clause, succ_clause]))
    call = nat_id(make_nat_expr(prelude, 3))
    wrapped = Function([], call)
    evaluated = dcpe(wrapped, mod=module)
    assert tvm.ir.structural_equal(evaluated.body, make_nat_expr(prelude, 3))
def build_impl(self, input_size, memory_size, dtype="float32"):
    """Build the cell body from an input tensor and a list of child states.

    Registers two inputs through `self.input`: a `(1, input_size)` tensor and
    a list of pairs of `(1, memory_size)` tensors. Element 1 of each child
    pair is summed, combined with the input through `Linear` layers, split
    into three parts along axis 1, and passed through sigmoid/sigmoid/tanh.
    Element 0 of each child pair is scaled by a per-child sigmoid term and
    folded into the new first component.

    Returns
    -------
    A `Tuple` of the folded value `c` and `o * tanh(c)`.
    """
    state_ty = TensorType(shape=(1, memory_size), dtype=dtype)
    cell_input = self.input(
        var("lstmcell_input", shape=(1, input_size), dtype=dtype)
    )
    child_states = self.input(
        Var("lstmcell_children", self.p.l(TupleType([state_ty, state_ty])))
    )
    add_fn = lam(["x", "y"], lambda x, y: x + y)
    child_h_sum = self.p.foldl(
        add_fn,
        op.zeros(shape=(1, memory_size), dtype=dtype),
        self.p.map(lam(["z"], lambda z: TupleGetItem(z, 1)), child_states),
    )
    # Layers are created in the same order as before: ioux, iouh, fx, fh.
    ioux = Linear(input_size=input_size, output_size=memory_size * 3)(cell_input)
    iouh = Linear(input_size=memory_size, output_size=memory_size * 3)(child_h_sum)
    iou = ioux + iouh
    fx = Linear(input_size=input_size, output_size=memory_size)(cell_input)
    fh = Linear(input_size=memory_size, output_size=memory_size)
    part_i, part_o, part_u = op.split(iou, 3, axis=1)
    gate_i = op.sigmoid(part_i)
    gate_o = op.sigmoid(part_o)
    gate_u = op.tanh(part_u)

    def foreach_children(children):
        # Scale element 0 of the child pair by a sigmoid of its element 1.
        f = op.sigmoid(fh(TupleGetItem(children, 1)) + fx)
        return f * TupleGetItem(children, 0)

    scaled_children = self.p.map(lam(["z"], foreach_children), child_states)
    new_c = self.p.foldl(add_fn, gate_i * gate_u, scaled_children)
    return Tuple([new_c, gate_o * op.tanh(new_c)])
def test_if_ref():
    """Check that `tipe` preserves a reference update performed under an `If`.

    The function creates a reference holding 1, calls (in either branch of an
    `If` on `d`) a closure that writes `read + read` back into it, then reads
    the reference. Both the original and the `tipe`-transformed function are
    evaluated on `True` and must produce 2.
    """
    bool_ty = TensorType((), "bool")
    d = Var("d", bool_ty)
    ref = Var("r")
    update = Function([], RefWrite(ref, RefRead(ref) + RefRead(ref)))
    updater = Var("u")
    branch = If(d, updater(), updater())
    eff = Var("eff")
    read_after = Let(eff, branch, RefRead(ref))
    with_updater = Let(updater, update, read_after)
    f = Function([d], Let(ref, RefCreate(const(1)), with_updater))
    pe_f = tipe(f)
    # Evaluate both versions first, then compare each against 2.
    outcomes = [create_executor().evaluate(fn)(const(True)) for fn in (f, pe_f)]
    for outcome in outcomes:
        actual = outcome.numpy()
        np.testing.assert_allclose(actual, 2 * np.ones_like(actual))
def test_tuple():
    """Check that `dcpe` resolves a projection out of a tuple literal.

    `fn(x) { (4.0, x).1 }` must become alpha-equal to the type-inferred
    identity function `fn(x) { x }`.
    """
    ty_var = TypeVar("t")
    arg = Var("x", ty_var)
    pair = relay.Tuple([relay.const(4.0), arg])
    projecting_fn = Function([arg], TupleGetItem(pair, 1), None, [ty_var])
    identity = relay.Function([arg], arg, None, [ty_var])
    expected = transform.OptimizeOnExpr(identity, transform.InferType())
    assert alpha_equal(dcpe(projecting_fn), expected)
def test_tuple():
    """Check that `dcpe` resolves a projection out of a tuple literal.

    `fn(x) { (4.0, x).1 }` must become structurally equal to the identity
    function `fn(x) { x }` after the latter is run through `InferType`.
    """
    ty_var = TypeVar("t")
    arg = Var("x", ty_var)
    pair = relay.Tuple([relay.const(4.0), arg])
    projecting_fn = Function([arg], TupleGetItem(pair, 1), None, [ty_var])
    identity = relay.Function([arg], arg, None, [ty_var])
    expected = run_opt_pass(identity, transform.InferType())
    assert tvm.ir.structural_equal(dcpe(projecting_fn), expected)
def test_loop():
    """Check that `dcpe` leaves a call to an infinitely recursive global as a call.

    `loop(x)` is defined as `loop(x)`; evaluating `loop(1)` must produce a
    `Call` to `loop` with the constant 1, no attrs, and type args `[None]`.
    """
    module = Module()
    ty_var = TypeVar("t")
    arg = Var("x", ty_var)
    loop_gv = GlobalVar("loop")
    module[loop_gv] = Function([arg], loop_gv(arg), ty_var, [ty_var])
    evaluated = dcpe(loop_gv(const(1)), mod=module)
    expected_call = Call(loop_gv, [const(1)], None, [None])
    assert alpha_equal(evaluated, expected_call)
def test_empty_ad():
    """Check `dcpe` applied to the gradient of the identity function.

    For `fn(d) { d }` over a (10, 10) float32 tensor the expected result is
    `fn(d) { (d, (ones_like(d),)) }`.
    """
    tensor_ty = TensorType((10, 10), "float32")
    d = Var("d", tensor_ty)
    identity = Function([d], d)
    simplified_grad = dcpe(gradient(identity))
    grads = Tuple([op.ones_like(d)])
    expected = Function([d], Tuple([d, grads]))
    assert alpha_equal(simplified_grad, expected)
def test_head_cons():
    """Check that `dcpe` reduces `hd(cons(x, nil))` to `x`.

    The polymorphic function `fn(x) { hd(cons(x, nil())) }` must become
    structurally equal to `fn(x) -> t { x }`.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    ty_var = TypeVar("t")
    arg = Var("x", ty_var)
    singleton = prelude.cons(arg, prelude.nil())
    head_fn = Function([arg], prelude.hd(singleton), None, [ty_var])
    evaluated = dcpe(head_fn, module)
    expected = Function([arg], arg, ty_var, [ty_var])
    assert tvm.ir.structural_equal(evaluated, expected)
def test_map():
    """Check that `dcpe` unrolls `map` over a literal three-element list.

    Mapping a free variable `f` over `[1, 2, 3]` must give a result
    alpha-equal to `[f(1), f(2), f(3)]`.
    """
    module = Module()
    prelude = Prelude(module)
    fn = Var("f")
    values = (1, 2, 3)

    # Build the cons chains from the tail so the list order is 1, 2, 3.
    input_list = prelude.nil()
    for value in reversed(values):
        input_list = prelude.cons(const(value), input_list)
    program = prelude.map(fn, input_list)

    expected = prelude.nil()
    for value in reversed(values):
        expected = prelude.cons(fn(const(value)), expected)

    assert alpha_equal(dcpe(program, mod=module), expected)
def test_empty_ad():
    """Check `dcpe` with `grad=True` on the identity function.

    For `fn(d) { d }` over a (10, 10) float32 tensor the result must be
    structurally equal to the type-inferred `fn(d) { (d, (ones_like(d),)) }`.
    """
    tensor_ty = TensorType((10, 10), "float32")
    d = Var("d", tensor_ty)
    identity = Function([d], d)
    simplified_grad = dcpe(identity, grad=True)
    grads = Tuple([op.ones_like(d)])
    untyped_expected = Function([d], Tuple([d, grads]))
    expected = run_opt_pass(untyped_expected, transform.InferType())
    assert tvm.ir.structural_equal(simplified_grad, expected)
def test_empty_ad():
    """Check `dcpe` with `grad=True` on the identity function.

    For `fn(d) { d }` over a (10, 10) float32 tensor the result must be
    alpha-equal to the type-inferred `fn(d) { (d, (ones_like(d),)) }`.
    """
    tensor_ty = TensorType((10, 10), "float32")
    d = Var("d", tensor_ty)
    identity = Function([d], d)
    simplified_grad = dcpe(identity, grad=True)
    grads = Tuple([op.ones_like(d)])
    untyped_expected = Function([d], Tuple([d, grads]))
    expected = transform.OptimizeOnExpr(untyped_expected, transform.InferType())
    assert alpha_equal(simplified_grad, expected)
def test_head_cons():
    """Check that `dcpe` reduces `hd(cons(x, nil))` to `x`.

    The polymorphic function `fn(x) { hd(cons(x, nil())) }` must become
    alpha-equal to `fn(x) -> t { x }`.
    """
    module = Module()
    prelude = Prelude(module)
    ty_var = TypeVar("t")
    arg = Var("x", ty_var)
    singleton = prelude.cons(arg, prelude.nil())
    head_fn = Function([arg], prelude.hd(singleton), None, [ty_var])
    evaluated = dcpe(head_fn, module)
    expected = Function([arg], arg, ty_var, [ty_var])
    assert alpha_equal(evaluated, expected)
def test_abs_diff():
    """Check that `dcpe` fully evaluates a recursive difference on nats.

    `diff(x, y)` returns `y` when `x` is zero, `x` when `y` is zero, and
    otherwise recurses as `diff(y', x')` on the two predecessors. Applying it
    to 7 and 3 inside a zero-argument wrapper must give a body structurally
    equal to the nat expression for 4.
    """
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    module = tvm.IRModule()
    prelude = Prelude(module)
    prelude.mod.import_from_std("nat.rly")
    nat_ctor, zero, succ = prelude.mod.get_type("nat")
    x = Var("x", nat_ctor())
    y = Var("y", nat_ctor())
    x_pred = Var("x'", nat_ctor())
    y_pred = Var("y'", nat_ctor())
    diff = GlobalVar("diff")

    # Inner match on y, used once x is known to be a successor.
    y_zero_clause = Clause(PatternConstructor(zero, []), x)
    y_succ_clause = Clause(
        PatternConstructor(succ, [PatternVar(y_pred)]), diff(y_pred, x_pred)
    )

    # Outer match on x.
    x_zero_clause = Clause(PatternConstructor(zero, []), y)
    inner_match = Match(y, [y_zero_clause, y_succ_clause])
    x_succ_clause = Clause(PatternConstructor(succ, [PatternVar(x_pred)]), inner_match)

    module[diff] = Function([x, y], Match(x, [x_zero_clause, x_succ_clause]))
    call = diff(make_nat_expr(prelude, 7), make_nat_expr(prelude, 3))
    wrapped = Function([], call)
    evaluated = dcpe(wrapped, mod=module)
    assert tvm.ir.structural_equal(evaluated.body, make_nat_expr(prelude, 4))
def test_global_match_nat_id():
    """Check that `dcpe` reduces a top-level `Match` that rebuilds a nat.

    Matching the nat expression for 3 against `z -> z` and `s(x) -> s(x)`
    must give a result alpha-equal to the nat expression for 3.
    """
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_ty = prelude.nat()
    pred = Var("x", nat_ty)
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_pattern = PatternConstructor(prelude.s, [PatternVar(pred)])
    succ_clause = Clause(succ_pattern, prelude.s(pred))
    program = Match(make_nat_expr(prelude, 3), [zero_clause, succ_clause])
    evaluated = dcpe(program, mod=module)
    assert alpha_equal(evaluated, make_nat_expr(prelude, 3))
def test_empty_ad():
    """Check `dcpe` with `grad=True` and `ignore_impurity=True` on the identity.

    For `fn(d) { d }` over a (10, 10) float32 tensor the result must be
    structurally equal to the type-inferred `fn(d) { (d, (ones_like(d),)) }`.
    """
    tensor_ty = TensorType((10, 10), "float32")
    d = Var("d", tensor_ty)
    identity = Function([d], d)
    # TODO(mbs): Revisit once DCE eliminates dead writes.
    simplified_grad = dcpe(identity, grad=True, ignore_impurity=True)
    grads = Tuple([op.ones_like(d)])
    untyped_expected = Function([d], Tuple([d, grads]))
    expected = run_opt_pass(untyped_expected, transform.InferType())
    assert tvm.ir.structural_equal(simplified_grad, expected)
def test_abs_diff():
    """Check that `dcpe` fully evaluates a recursive difference on nats.

    `diff(x, y)` returns `y` when `x` is zero, `x` when `y` is zero, and
    otherwise recurses as `diff(y', x')` on the two predecessors. Applying it
    to 7 and 3 inside a zero-argument wrapper must give a body alpha-equal to
    the nat expression for 4.
    """
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_ty = prelude.nat()
    x = Var("x", nat_ty)
    y = Var("y", nat_ty)
    x_pred = Var("x'", nat_ty)
    y_pred = Var("y'", nat_ty)
    diff = GlobalVar("diff")

    # Inner match on y, used once x is known to be a successor.
    y_zero_clause = Clause(PatternConstructor(prelude.z, []), x)
    y_succ_clause = Clause(
        PatternConstructor(prelude.s, [PatternVar(y_pred)]), diff(y_pred, x_pred)
    )

    # Outer match on x.
    x_zero_clause = Clause(PatternConstructor(prelude.z, []), y)
    inner_match = Match(y, [y_zero_clause, y_succ_clause])
    x_succ_clause = Clause(
        PatternConstructor(prelude.s, [PatternVar(x_pred)]), inner_match
    )

    module[diff] = Function([x, y], Match(x, [x_zero_clause, x_succ_clause]))
    call = diff(make_nat_expr(prelude, 7), make_nat_expr(prelude, 3))
    wrapped = Function([], call)
    evaluated = dcpe(wrapped, mod=module)
    assert alpha_equal(evaluated.body, make_nat_expr(prelude, 4))
def test_head_cons():
    """Check that `dcpe` reduces `hd(cons(x, nil))` to `x`.

    The list constructors and `hd` are looked up from the prelude's module.
    The result must be structurally equal to `main` of a module built from
    `fn(x) -> t { x }`.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    ty_var = TypeVar("t")
    arg = Var("x", ty_var)
    _, cons, nil = prelude.mod.get_type("List")
    hd = prelude.mod.get_global_var("hd")
    singleton = cons(arg, nil())
    head_fn = Function([arg], hd(singleton), None, [ty_var])
    evaluated = dcpe(head_fn, module)
    identity = Function([arg], arg, ty_var, [ty_var])
    expected_mod = tvm.IRModule.from_expr(identity)
    assert tvm.ir.structural_equal(evaluated, expected_mod["main"])
def test_global_match_nat_id():
    """Check that `dcpe` reduces a `Match` that rebuilds a nat.

    Matching the nat expression for 3 against `z -> z` and `s(x) -> s(x)`,
    wrapped in a zero-argument function, must give a body structurally equal
    to the nat expression for 3.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)
    nat_ty = prelude.nat()
    pred = Var("x", nat_ty)
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_pattern = PatternConstructor(prelude.s, [PatternVar(pred)])
    succ_clause = Clause(succ_pattern, prelude.s(pred))
    match_expr = Match(make_nat_expr(prelude, 3), [zero_clause, succ_clause])
    wrapped = Function([], match_expr)
    evaluated = dcpe(wrapped, mod=module)
    assert tvm.ir.structural_equal(evaluated.body, make_nat_expr(prelude, 3))
def test_loop():
    """Check that `dcpe` leaves a call to an infinitely recursive global as a call.

    `loop(x)` is defined as `loop(x)`. The expected expression is obtained by
    storing `fn() { loop(1) }` as `main` in the module and reading its body
    back. A second wrapper around `loop(1)` is passed to `dcpe`, and its body
    must be structurally equal to the expected one.
    """
    module = tvm.IRModule()
    ty_var = TypeVar("t")
    arg = Var("x", ty_var)
    loop_gv = GlobalVar("loop")
    module[loop_gv] = Function([arg], loop_gv(arg), ty_var, [ty_var])

    # Round-trip the expected call through the module before comparing.
    module["main"] = Function([], Call(loop_gv, [const(1)]))
    expected_body = module["main"].body

    wrapped_call = Function([], loop_gv(const(1)))
    evaluated = dcpe(wrapped_call, mod=module)
    assert tvm.ir.structural_equal(evaluated.body, expected_body)
def test_function_invalidate():
    """Check that `tipe` preserves a reference read made after a write.

    The function creates a reference holding 0, binds a closure that reads
    it, re-binds that closure through an `If` on `d`, writes 1 to the
    reference, and then calls the re-bound closure. Both the original and the
    `tipe`-transformed function are evaluated on `True` and must produce 1.
    """
    bool_ty = TensorType((), "bool")
    d = Var("d", bool_ty)
    ref = Var("r")
    fetch = Function([], RefRead(ref))
    fet = Var("fetch")
    fet_obscured = Var("fetch_obscured")
    u = Var("u")

    # Build the let-chain from the innermost expression outwards.
    call_fetch = If(d, fet_obscured(), fet_obscured())
    after_write = Let(u, RefWrite(ref, const(1)), call_fetch)
    with_obscured = Let(fet_obscured, If(d, fet, fet), after_write)
    with_fetch = Let(fet, fetch, with_obscured)
    with_ref = Let(ref, RefCreate(const(0)), with_fetch)
    f = Function([d], with_ref)
    pe_f = tipe(f)

    # Evaluate both versions first, then compare each against 1.
    outcomes = [create_executor().evaluate(fn)(const(True)) for fn in (f, pe_f)]
    for outcome in outcomes:
        actual = outcome.numpy()
        np.testing.assert_allclose(actual, np.ones_like(actual))