예제 #1
0
 def hd_impl():
     """Build the polymorphic Relay head function ``forall a. p.l(a) -> a``.

     The function matches its single list argument against ``p.cons(y, z)``
     and returns the head ``y``. There is deliberately no clause for the
     empty list, so the match only handles non-empty lists.

     Relies on a prelude object ``p`` (providing ``p.l`` and ``p.cons``)
     that is defined outside this function.
     """
     a = TypeVar("a")
     x = Var("x", p.l(a))
     # Pattern variables bound by the cons clause: y is the head, z the tail.
     y = Var("y")
     z = Var("z")
     cons_case = Clause(
         PatternConstructor(p.cons,
                            [PatternVar(y), PatternVar(z)]), y)
     # Return type is the element type `a`, which is also the only type parameter.
     return Function([x], Match(x, [cons_case]), a, [a])
예제 #2
0
 def build_impl(self, input_size, memory_size, dtype="float32"):
     """Build the body of a Tree-LSTM over a ``self.p.tree`` of input tensors.

     Each tree node carries a payload of shape ``(1, input_size)``. The single
     ``rose`` clause splits a node into its payload and its list of children.
     It maps ``lam(["x"], self)`` over the children and feeds the payload plus
     the mapped children to an ``LSTMCell``.

     Side effect: sets ``self.ret_type`` to a pair of ``(1, memory_size)``
     tensor types.

     Returns:
         A ``Match`` over the network's tree input.
     """
     memory_type = TensorType(shape=(1, memory_size), dtype=dtype)
     self.ret_type = TupleType([memory_type, memory_type])

     payload_type = TensorType(shape=(1, input_size), dtype=dtype)
     tree_type = self.p.tree(payload_type)
     # NOTE(review): `self.input` presumably registers this var as the network parameter — confirm.
     tree_input = self.input(Var("tlstm_input", tree_type))

     # Variables bound by the rose(payload, children) pattern.
     payload = Var("i", payload_type)
     children = Var("c", self.p.l(tree_type))

     lstm_cell = LSTMCell(input_size=input_size, memory_size=memory_size, dtype=dtype)

     rose_pattern = PatternConstructor(self.p.rose, [PatternVar(payload), PatternVar(children)])
     # NOTE(review): `lam(["x"], self)` presumably wraps this network as a
     # one-argument function so it can be mapped over the children — confirm.
     child_results = self.p.map(lam(["x"], self), children)
     rose_case = Clause(rose_pattern, lstm_cell(payload, child_results))
     return Match(tree_input, [rose_case])
예제 #3
0
def test_abs_diff():
    """Partial evaluation fully reduces a recursive absolute difference.

    ``diff`` computes |x - y| on Peano naturals. While both arguments are
    ``s(...)``, it strips one ``s`` from each and recurses with the arguments
    swapped. As soon as one argument is ``z``, it returns the other one.
    Calling it on 7 and 3 inside a nullary function and running ``dcpe`` must
    leave the literal 4 as the function body.
    """
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat_type = p.nat()

    lhs = Var("x", nat_type)
    rhs = Var("y", nat_type)
    lhs_pred = Var("x'", nat_type)
    rhs_pred = Var("y'", nat_type)
    abs_diff = GlobalVar("diff")

    # Inner match on y, reached once x = s(x'):
    #   y = z     -> x
    #   y = s(y') -> diff(y', x')
    match_rhs = Match(rhs, [
        Clause(PatternConstructor(p.z, []), lhs),
        Clause(PatternConstructor(p.s, [PatternVar(rhs_pred)]), abs_diff(rhs_pred, lhs_pred)),
    ])
    # Outer match on x:
    #   x = z     -> y
    #   x = s(x') -> inner match on y
    match_lhs = Match(lhs, [
        Clause(PatternConstructor(p.z, []), rhs),
        Clause(PatternConstructor(p.s, [PatternVar(lhs_pred)]), match_rhs),
    ])
    mod[abs_diff] = Function([lhs, rhs], match_lhs)

    call = abs_diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
    res = dcpe(Function([], call), mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 4))
def test_abs_diff():
    """Partial evaluation fully reduces a recursive absolute difference.

    The ``nat`` ADT is imported from the standard ``nat.rly`` file. ``diff``
    computes |x - y|: while both arguments are ``s(...)``, it strips one ``s``
    from each and recurses with the arguments swapped, and it returns the
    other argument once one of them is ``z``. Calling it on 7 and 3 inside a
    nullary function and running ``dcpe`` must leave exactly the literal 4 as
    the body.
    """
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat_adt, zero, succ = p.mod.get_type("nat")
    nat_type = nat_adt()

    lhs = Var("x", nat_type)
    rhs = Var("y", nat_type)
    lhs_pred = Var("x'", nat_type)
    rhs_pred = Var("y'", nat_type)
    abs_diff = GlobalVar("diff")

    # Inner match on y, reached once x = s(x'):
    #   y = z     -> x
    #   y = s(y') -> diff(y', x')
    match_rhs = Match(rhs, [
        Clause(PatternConstructor(zero, []), lhs),
        Clause(PatternConstructor(succ, [PatternVar(rhs_pred)]), abs_diff(rhs_pred, lhs_pred)),
    ])
    # Outer match on x:
    #   x = z     -> y
    #   x = s(x') -> inner match on y
    match_lhs = Match(lhs, [
        Clause(PatternConstructor(zero, []), rhs),
        Clause(PatternConstructor(succ, [PatternVar(lhs_pred)]), match_rhs),
    ])
    mod[abs_diff] = Function([lhs, rhs], match_lhs)

    call = abs_diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
    res = dcpe(Function([], call), mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4))
예제 #5
0
def test_global_match_nat_id():
    """An identity ``match`` on the literal 3 partially evaluates to 3.

    Each clause rebuilds the constructor it matched (z -> z, s(x) -> s(x)),
    so ``dcpe`` applied directly to the match must yield the literal itself.
    """
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    pred = Var("x", p.nat())
    clauses = [
        Clause(PatternConstructor(p.z, []), p.z()),
        Clause(PatternConstructor(p.s, [PatternVar(pred)]), p.s(pred)),
    ]
    identity_match = Match(make_nat_expr(p, 3), clauses)
    res = dcpe(identity_match, mod=mod)
    assert alpha_equal(res, make_nat_expr(p, 3))
예제 #6
0
def test_global_match_nat_id():
    """An identity ``match`` on the literal 3 partially evaluates to 3.

    The ``nat`` ADT is imported from the standard ``nat.rly`` file. Each clause
    rebuilds the constructor it matched. The match is wrapped in a nullary
    function before ``dcpe``, and the resulting body is compared structurally
    with the literal.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat_adt, zero, succ = p.mod.get_type("nat")
    pred = Var("x", nat_adt())
    clauses = [
        Clause(PatternConstructor(zero, []), zero()),
        Clause(PatternConstructor(succ, [PatternVar(pred)]), succ(pred)),
    ]
    wrapped = Function([], Match(make_nat_expr(p, 3), clauses))
    res = dcpe(wrapped, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
예제 #7
0
def test_global_match_nat_id():
    """An identity ``match`` on the literal 3 partially evaluates to 3.

    Uses the prelude's ``add_nat_definitions`` constructors. Each clause
    rebuilds the constructor it matched. The match is wrapped in a nullary
    function before ``dcpe``, and the resulting body is compared structurally
    with the literal.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    pred = Var("x", p.nat())
    clauses = [
        Clause(PatternConstructor(p.z, []), p.z()),
        Clause(PatternConstructor(p.s, [PatternVar(pred)]), p.s(pred)),
    ]
    wrapped = Function([], Match(make_nat_expr(p, 3), clauses))
    res = dcpe(wrapped, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
예제 #8
0
def test_match_nat_id():
    """Partially evaluating a call to a global nat identity function yields 3.

    ``nat_id`` matches its argument and rebuilds the same constructor
    (z -> z, s(y) -> s(y)). Running ``dcpe`` on ``nat_id(3)`` with the
    function registered in the module must produce the literal 3.
    """
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat_type = p.nat()
    arg = Var("x", nat_type)
    pred = Var("y", nat_type)
    nat_id = GlobalVar("nat_id")
    identity_body = Match(arg, [
        Clause(PatternConstructor(p.z, []), p.z()),
        Clause(PatternConstructor(p.s, [PatternVar(pred)]), p.s(pred)),
    ])
    mod[nat_id] = Function([arg], identity_body)
    res = dcpe(nat_id(make_nat_expr(p, 3)), mod=mod)
    assert alpha_equal(res, make_nat_expr(p, 3))