def hd_impl():
    """Build the Relay implementation of ``hd`` (head of a ``p.l`` list).

    The function is generic over the element type ``a``. It takes one
    parameter ``x : p.l(a)`` and matches it against a single ``p.cons(y, z)``
    clause, returning ``y``. No clause for any other constructor is given in
    this definition.

    NOTE(review): ``p`` is a free name here, presumably a module-level
    ``Prelude`` instance; confirm against the enclosing file.

    Returns:
        A Relay ``Function``. ``a`` is passed as its return type and ``[a]``
        as its type parameters.
    """
    a = TypeVar("a")
    x = Var("x", p.l(a))
    y = Var("y")
    z = Var("z")
    cons_case = Clause(
        PatternConstructor(p.cons, [PatternVar(y), PatternVar(z)]),
        y)
    # The original re-created ``y``/``z`` here; those bindings were never
    # used (the clause above already captured the first pair), so they are
    # dropped.
    return Function([x], Match(x, [cons_case]), a, [a])
def build_impl(self, input_size, memory_size, dtype="float32"):
    """Build the Relay body of this network over ``self.p.tree`` inputs.

    Side effect: sets ``self.ret_type`` to a 2-tuple of
    ``(1, memory_size)`` tensor types.

    The network input is a tree whose node values are ``(1, input_size)``
    tensors. On a ``rose(i, c)`` node, the body applies an ``LSTMCell`` to
    the node value ``i`` and to the result of mapping this network (wrapped
    via ``lam(["x"], self)``) over the children ``c``.

    Returns:
        A Relay ``Match`` over the input tree with that single ``rose`` clause.
    """
    def tensor(width):
        # Every tensor in this network is a single row of ``width`` elements.
        return TensorType(shape=(1, width), dtype=dtype)

    memory_ty = tensor(memory_size)
    self.ret_type = TupleType([memory_ty, memory_ty])

    tree_ty = self.p.tree(tensor(input_size))
    tree_in = self.input(Var("tlstm_input", tree_ty))

    node_val = Var("i", tensor(input_size))
    children = Var("c", self.p.l(tree_ty))
    lstm = LSTMCell(input_size=input_size, memory_size=memory_size, dtype=dtype)

    # Construction order matches the original: pattern, recursive map, cell call.
    pattern = PatternConstructor(
        self.p.rose, [PatternVar(node_val), PatternVar(children)])
    child_states = self.p.map(lam(["x"], self), children)
    body = lstm(node_val, child_states)
    return Match(tree_in, [Clause(pattern, body)])
def test_abs_diff():
    """Check that ``dcpe`` reduces ``diff(7, 3)`` to the nat literal 4.

    ``diff`` is a recursive global over Peano naturals. It matches on its
    first argument, then on its second, and recurses on both predecessors.
    """
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    nat_t = prelude.nat()

    lhs = Var("x", nat_t)
    rhs = Var("y", nat_t)
    lhs_pred = Var("x'", nat_t)
    rhs_pred = Var("y'", nat_t)
    diff = GlobalVar("diff")

    # Inner match on the second argument, reached once the first is s(x').
    on_rhs = Match(rhs, [
        Clause(PatternConstructor(prelude.z, []), lhs),
        Clause(PatternConstructor(prelude.s, [PatternVar(rhs_pred)]),
               diff(rhs_pred, lhs_pred)),
    ])
    on_lhs = Match(lhs, [
        Clause(PatternConstructor(prelude.z, []), rhs),
        Clause(PatternConstructor(prelude.s, [PatternVar(lhs_pred)]), on_rhs),
    ])
    mod[diff] = Function([lhs, rhs], on_lhs)

    thunk = Function([], diff(make_nat_expr(prelude, 7),
                              make_nat_expr(prelude, 3)))
    res = dcpe(thunk, mod=mod)
    assert alpha_equal(res.body, make_nat_expr(prelude, 4))
def test_abs_diff():
    """Check that ``dcpe`` reduces ``diff(7, 3)`` to the nat literal 4.

    ``nat`` and its constructors come from the standard ``nat.rly`` library.
    ``diff`` matches on its first argument, then its second, and recurses on
    both predecessors.
    """
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    prelude.mod.import_from_std("nat.rly")
    nat_ty, zero, succ = prelude.mod.get_type("nat")

    lhs = Var("x", nat_ty())
    rhs = Var("y", nat_ty())
    lhs_pred = Var("x'", nat_ty())
    rhs_pred = Var("y'", nat_ty())
    diff = GlobalVar("diff")

    # Inner match on the second argument, reached once the first is s(x').
    on_rhs = Match(rhs, [
        Clause(PatternConstructor(zero, []), lhs),
        Clause(PatternConstructor(succ, [PatternVar(rhs_pred)]),
               diff(rhs_pred, lhs_pred)),
    ])
    on_lhs = Match(lhs, [
        Clause(PatternConstructor(zero, []), rhs),
        Clause(PatternConstructor(succ, [PatternVar(lhs_pred)]), on_rhs),
    ])
    mod[diff] = Function([lhs, rhs], on_lhs)

    thunk = Function([], diff(make_nat_expr(prelude, 7),
                              make_nat_expr(prelude, 3)))
    res = dcpe(thunk, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(prelude, 4))
def test_global_match_nat_id():
    """A match that rebuilds its nat scrutinee should ``dcpe`` back to 3."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    pred = Var("x", prelude.nat())

    # z -> z, s(x) -> s(x): an identity on naturals.
    clauses = [
        Clause(PatternConstructor(prelude.z, []), prelude.z()),
        Clause(PatternConstructor(prelude.s, [PatternVar(pred)]),
               prelude.s(pred)),
    ]
    expr = Match(make_nat_expr(prelude, 3), clauses)
    res = dcpe(expr, mod=mod)
    assert alpha_equal(res, make_nat_expr(prelude, 3))
def test_global_match_nat_id():
    """A match that rebuilds its nat scrutinee should ``dcpe`` back to 3.

    Uses ``nat`` and its constructors from the standard ``nat.rly`` library.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    prelude.mod.import_from_std("nat.rly")
    nat_ty, zero, succ = prelude.mod.get_type("nat")
    pred = Var("x", nat_ty())

    # z -> z, s(x) -> s(x): an identity on naturals.
    clauses = [
        Clause(PatternConstructor(zero, []), zero()),
        Clause(PatternConstructor(succ, [PatternVar(pred)]), succ(pred)),
    ]
    scrutinee = make_nat_expr(prelude, 3)
    thunk = Function([], Match(scrutinee, clauses))
    res = dcpe(thunk, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(prelude, 3))
def test_global_match_nat_id():
    """A match that rebuilds its nat scrutinee should ``dcpe`` back to 3."""
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    pred = Var("x", prelude.nat())

    # z -> z, s(x) -> s(x): an identity on naturals.
    clauses = [
        Clause(PatternConstructor(prelude.z, []), prelude.z()),
        Clause(PatternConstructor(prelude.s, [PatternVar(pred)]),
               prelude.s(pred)),
    ]
    scrutinee = make_nat_expr(prelude, 3)
    thunk = Function([], Match(scrutinee, clauses))
    res = dcpe(thunk, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(prelude, 3))
def test_match_nat_id():
    """Calling a global identity-on-nat function on 3 should ``dcpe`` to 3."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    nat_t = prelude.nat()

    arg = Var("x", nat_t)
    pred = Var("y", nat_t)
    nat_id = GlobalVar("nat_id")

    # nat_id(z) = z, nat_id(s(y)) = s(y)
    clauses = [
        Clause(PatternConstructor(prelude.z, []), prelude.z()),
        Clause(PatternConstructor(prelude.s, [PatternVar(pred)]),
               prelude.s(pred)),
    ]
    mod[nat_id] = Function([arg], Match(arg, clauses))

    res = dcpe(nat_id(make_nat_expr(prelude, 3)), mod=mod)
    assert alpha_equal(res, make_nat_expr(prelude, 3))