def __init__(self, *, name="f", **kwargs):
    """Build this network's body and register it as a global function.

    Parameters
    ----------
    name : str, keyword-only
        Prefix of the global function name; the current value of
        ``Network.cnt`` is appended to it, and the counter is then bumped.
    **kwargs
        Forwarded unchanged to ``self.build``.
    """
    name = f"{name}_{Network.cnt}"
    Network.cnt += 1

    # Reuse the module/prelude of the network on top of the stack when
    # there is one; otherwise create a fresh module with the nat
    # definitions installed.  (Truthiness replaces the original
    # ``len(...) is not 0`` identity comparison against an int literal.)
    if Network.stack:
        enclosing = Network.stack[-1]
        mod = enclosing.mod
        p = enclosing.p
    else:
        mod = Module()
        p = Prelude(mod)
        add_nat_definitions(p)

    self.mod = mod
    self.p = p
    self.inputs = []
    self.weights = OrderedSet()
    self.sub_network = OrderedSet()
    self.f = relay.GlobalVar(name)
    self.recurse = relay.Var("recurse")
    self.use_recurse = False
    self.ret_type = None

    # ``build`` runs only after every attribute above has been set.
    body = self.build(**kwargs)
    assert isinstance(body, relay.Expr)

    if self.use_recurse:
        # Bind ``self.recurse`` to a function over copies of the inputs
        # whose body is ``self.call_from_outside`` applied to those copies.
        inputs = [copy_var(v) for v in self.inputs]
        body = relay.Let(
            self.recurse,
            relay.Function(inputs, self.call_from_outside(*inputs)),
            body,
        )

    self.mod[self.f] = relay.Function(
        self.inputs + self.all_weights(), body, self.ret_type
    )
def test_add_convert():
    """Run the prelude's ``add`` through ``compile`` and check 12 + 34 == 46."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    compiled_add = compile(prelude.add, module)
    lhs = int_to_nat(prelude, 12)
    rhs = int_to_nat(prelude, 34)
    total = compiled_add(lhs, rhs)

    assert nat_to_int(total) == 46
def test_double():
    """Check that ``dcpe`` turns ``double(3)`` into the nat expression for 6."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    three = make_nat_expr(prelude, 3)
    doubled = prelude.double(three)
    evaluated = dcpe(doubled, mod=module)

    assert alpha_equal(evaluated, make_nat_expr(prelude, 6))
def test_nat_3():
    """Compile a nullary function returning ``s(s(s(z)))`` and expect 3."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    three = prelude.s(prelude.s(prelude.s(prelude.z())))
    thunk = compile(Function([], three), module)
    result = thunk()

    assert nat_to_int(result) == 3
def test_double():
    """Check that ``dcpe`` reduces ``double(3)``, wrapped in a nullary function, to 6."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    doubled = prelude.double(make_nat_expr(prelude, 3))
    # ``dcpe`` is handed a function here, so the result is compared via its body.
    wrapped = Function([], doubled)
    evaluated = dcpe(wrapped, mod=module)

    assert tvm.ir.structural_equal(evaluated.body, make_nat_expr(prelude, 6))
def test_global_match_nat_id():
    """Match on the nat 3 with clauses that rebuild the scrutinee; expect 3 back."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    nat_type = prelude.nat()
    pred = Var("x", nat_type)

    # z -> z ; s(x) -> s(x)
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_pattern = PatternConstructor(prelude.s, [PatternVar(pred)])
    succ_clause = Clause(succ_pattern, prelude.s(pred))

    match_expr = Match(make_nat_expr(prelude, 3), [zero_clause, succ_clause])
    evaluated = dcpe(match_expr, mod=module)

    assert alpha_equal(evaluated, make_nat_expr(prelude, 3))
def test_global_match_nat_id():
    """Match on the nat 3 inside a nullary function; the evaluated body must be 3."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    nat_type = prelude.nat()
    pred = Var("x", nat_type)

    # z -> z ; s(x) -> s(x)
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_pattern = PatternConstructor(prelude.s, [PatternVar(pred)])
    succ_clause = Clause(succ_pattern, prelude.s(pred))

    match_expr = Match(make_nat_expr(prelude, 3), [zero_clause, succ_clause])
    wrapped = Function([], match_expr)
    evaluated = dcpe(wrapped, mod=module)

    assert tvm.ir.structural_equal(evaluated.body, make_nat_expr(prelude, 3))
def test_nat_id():
    """Check that ``dcpe`` reduces a call to a global identity on nat 3 to 3.

    The original also created a ``Var("y", nat)`` that was never used; it
    has been dropped.
    """
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)

    nat = p.nat()
    x = Var("x", nat)

    # Register the identity function on nat as a global.
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)

    orig = nat_id(make_nat_expr(p, 3))
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res, make_nat_expr(p, 3))
def test_swap_loop():
    """Check that ``dcpe`` leaves a call to a self-calling, argument-swapping global unchanged."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    nat_type = prelude.nat()
    first = Var("x", nat_type)
    second = Var("y", nat_type)

    # loop(x, y) = loop(y, x)
    loop = GlobalVar("loop")
    swapped_call = loop(second, first)
    module[loop] = Function([first, second], swapped_call, nat_type)

    program = loop(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    evaluated = dcpe(program, mod=module)

    assert alpha_equal(program, evaluated)
def test_nat_id():
    """Check that ``dcpe`` reduces a global identity call on nat 3 to 3.

    The call is wrapped in a nullary function before evaluation, so the
    comparison is made against the body of the result.  The original also
    created a ``Var("y", nat)`` that was never used; it has been dropped.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)

    nat = p.nat()
    x = Var("x", nat)

    # Register the identity function on nat as a global.
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)

    orig = nat_id(make_nat_expr(p, 3))
    orig = Function([], orig)
    res = dcpe(orig, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_compose():
    """Apply ``compose(inc, double)`` to the nat 2 and expect 5."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    # inc(x) = s(x)
    inc_param = relay.Var('x')
    inc = GlobalVar('inc')
    module[inc] = Function([inc_param], prelude.s(inc_param))

    # func(x) = compose(inc, double)(x)
    func_param = relay.Var('x')
    func = GlobalVar('func')
    composed = prelude.compose(inc, prelude.double)
    module[func] = Function([func_param], relay.Call(composed, [func_param]))

    compiled = compile(func, module)
    two = prelude.s(prelude.s(prelude.z()))

    assert nat_to_int(compiled(two)) == 5
def test_swap_loop():
    """Check that ``dcpe`` leaves a call to a self-calling, argument-swapping global unchanged.

    The program is wrapped in a nullary function before evaluation, so the
    original call is compared against the body of the result.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    nat_type = prelude.nat()
    first = Var("x", nat_type)
    second = Var("y", nat_type)

    # loop(x, y) = loop(y, x)
    loop = GlobalVar("loop")
    swapped_call = loop(second, first)
    module[loop] = Function([first, second], swapped_call, nat_type)

    program = loop(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    wrapped = Function([], program)
    evaluated = dcpe(wrapped, mod=module)

    assert tvm.ir.structural_equal(program, evaluated.body)
def test_match_nat_id():
    """Check that ``dcpe`` reduces a match-based identity on nat 3 to 3."""
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    nat_type = prelude.nat()
    scrutinee = Var("x", nat_type)
    pred = Var("y", nat_type)
    nat_id = GlobalVar("nat_id")

    # z -> z ; s(y) -> s(y)
    zero_clause = Clause(PatternConstructor(prelude.z, []), prelude.z())
    succ_pattern = PatternConstructor(prelude.s, [PatternVar(pred)])
    succ_clause = Clause(succ_pattern, prelude.s(pred))

    identity_body = Match(scrutinee, [zero_clause, succ_clause])
    module[nat_id] = Function([scrutinee], identity_body)

    call = nat_id(make_nat_expr(prelude, 3))
    evaluated = dcpe(call, mod=module)

    assert alpha_equal(evaluated, make_nat_expr(prelude, 3))
def test_nat_add():
    """Check the type of the prelude's ``add``, that 1 + 1 counts to 2, and that its text contains "let"."""
    module = relay.Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    nat, add, s, z = prelude.nat, prelude.add, prelude.s, prelude.z

    device = tvm.context("llvm", 0)
    interpreter = create_executor(mod=module, ctx=device, target="llvm")

    add_type = module[add].checked_type
    assert add_type == relay.FuncType([nat(), nat()], nat())

    # Evaluate 1 + 1 directly ...
    direct = interpreter.evaluate(add(s(z()), s(z())))
    assert count(direct) == 2

    # ... and again after conversion to A-normal form.
    anf_expr = to_a_normal_form(add(s(z()), s(z())), module)
    assert count(interpreter.evaluate(anf_expr)) == 2

    assert "let" in module[add].astext()
def test_constructor_tag_round_trip():
    """Look up each constructor tag from one module in a second, identically built module."""
    first_mod = tvm.IRModule()
    first_prelude = Prelude(first_mod)
    add_nat_definitions(first_prelude)

    second_mod = tvm.IRModule()
    second_prelude = Prelude(second_mod)
    add_nat_definitions(second_prelude)

    # ensure hashes match across modules
    first_ctors = constructor_list(first_prelude)
    second_ctors = constructor_list(second_prelude)
    for index, source_ctor in enumerate(first_ctors):
        found = second_mod.get_constructor(source_ctor.tag)
        assert found == second_ctors[index]
        assert found.name_hint == first_ctors[index].name_hint
def test_recursion():
    """Round-trip a ``nat_iterate``-based function through ``to_cps``/``un_cps`` and evaluate it.

    The function iterates ``x + x`` three times, so the output is expected
    to be 8 times the input.
    """
    module = relay.Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    shape = (10, 10)
    dtype = 'float32'
    tensor_type = relay.TensorType(shape, dtype)

    x = relay.var("x", tensor_type)
    double = relay.Function([x], x + x)

    i = relay.var("i", tensor_type)
    iterated = prelude.nat_iterate(double, make_nat_expr(prelude, 3))
    func = relay.Function([i], iterated(i))

    # Install the function as the entry point, then replace it with its CPS
    # form and finally with the un-CPS'd version of that.
    module[module.entry_func] = func
    cps_func = to_cps(module[module.entry_func], mod=module)
    module[module.entry_func] = cps_func
    restored_func = un_cps(module[module.entry_func])
    module[module.entry_func] = restored_func

    executor = create_executor(mod=module)
    input_nd = rand(dtype, *shape)
    forward = executor.evaluate(module.entry_func)(input_nd)

    tvm.testing.assert_allclose(forward.asnumpy(), 8 * input_nd.asnumpy())
def test_constructor_tag_differences():
    """Check that adjacent constructors of one ADT have tags that differ by exactly 1."""
    # ensure that if we have the type data for a given ADT, the tags
    # for the constructors of the *same ADT* are simple offsets from
    # each other
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    for adt in adt_list(prelude):
        type_data = module[adt]
        for idx in range(len(type_data.constructors) - 1):
            lower = type_data.constructors[idx]
            upper = type_data.constructors[idx + 1]
            assert upper.tag - lower.tag == 1
            # make sure there is something present at the MSB
            assert lower.tag - idx != 0
            assert upper.tag - (idx + 1) != 0
def test_nat_add():
    """Check the type of the prelude's ``add``, that 1 + 1 counts to 2, and that ``add`` uses let."""
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    nat, add, s, z = prelude.nat, prelude.add, prelude.s, prelude.z

    device = tvm.context("llvm", 0)
    # The executor is created from the module as it is before the ANF pass.
    interpreter = create_executor(mod=module, ctx=device, target="llvm")

    add_type = module[add].checked_type
    assert add_type == relay.FuncType([nat(), nat()], nat())

    direct = interpreter.evaluate(add(s(z()), s(z())))
    assert count(prelude, direct) == 2

    # Register 1 + 1 as a nullary global "f", run the ANF pass over the
    # whole module, and evaluate the body of the transformed "f".
    one_plus_one = add(s(z()), s(z()))
    f = relay.GlobalVar("f")
    module[f] = relay.Function([], one_plus_one)
    module = transform.ToANormalForm()(module)
    anf_func = module["f"]
    assert count(prelude, interpreter.evaluate(anf_func.body)) == 2

    assert Feature.fLet in detect_feature(module[add])
def test_pow():
    """Take the gradient of a ``nat_iterate``-based function and check value and gradient.

    The function iterates ``x + x`` three times; both the forward output
    and the gradient are expected to carry a factor of 8.
    """
    module = relay.Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    shape = (10, 10)
    dtype = 'float32'
    tensor_type = relay.TensorType(shape, dtype)

    x = relay.var("x", tensor_type)
    double = relay.Function([x], x + x)

    i = relay.var("i", tensor_type)
    iterated = prelude.nat_iterate(double, make_nat_expr(prelude, 3))
    func = relay.Function([i], iterated(i))

    back_func = relay.ir_pass.infer_type(gradient(func, mod=module), mod=module)

    # The gradient function returns (forward value, (gradient w.r.t. i,)).
    actual_type = back_func.checked_type
    grads_type = relay.TupleType([tensor_type])
    assert actual_type == relay.FuncType(
        [tensor_type], relay.TupleType([tensor_type, grads_type]))

    input_nd = rand(dtype, *shape)
    executor = create_executor(mod=module)
    forward, (grad_i, ) = executor.evaluate(back_func)(input_nd)

    tvm.testing.assert_allclose(forward.asnumpy(), 8 * input_nd.asnumpy())
    tvm.testing.assert_allclose(grad_i.asnumpy(),
                                8 * np.ones_like(grad_i.asnumpy()))
def test_abs_diff():
    """Check that ``dcpe`` reduces ``diff(7, 3)`` to the nat expression for 4."""
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    module = Module()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    nat_type = prelude.nat()
    x = Var("x", nat_type)
    y = Var("y", nat_type)
    x_pred = Var("x'", nat_type)
    y_pred = Var("y'", nat_type)
    diff = GlobalVar("diff")

    # Inner match on y (reached when x = s(x')):
    #   z     -> x
    #   s(y') -> diff(y', x')
    y_zero_clause = Clause(PatternConstructor(prelude.z, []), x)
    y_succ_clause = Clause(
        PatternConstructor(prelude.s, [PatternVar(y_pred)]),
        diff(y_pred, x_pred))

    # Outer match on x:
    #   z     -> y
    #   s(x') -> <inner match on y>
    x_zero_clause = Clause(PatternConstructor(prelude.z, []), y)
    x_succ_clause = Clause(
        PatternConstructor(prelude.s, [PatternVar(x_pred)]),
        Match(y, [y_zero_clause, y_succ_clause]))

    module[diff] = Function([x, y], Match(x, [x_zero_clause, x_succ_clause]))

    call = diff(make_nat_expr(prelude, 7), make_nat_expr(prelude, 3))
    wrapped = Function([], call)
    evaluated = dcpe(wrapped, mod=module)

    assert alpha_equal(evaluated.body, make_nat_expr(prelude, 4))
def test_nat_update():
    """Run ``ToANormalForm`` then ``PartialEvaluate`` over a module holding the nat definitions.

    There are no assertions; the test only requires both passes to finish
    without raising.
    """
    module = tvm.IRModule()
    prelude = Prelude(module)
    add_nat_definitions(prelude)

    to_anf = transform.ToANormalForm()
    module = to_anf(module)

    partial_eval = transform.PartialEvaluate()
    partial_eval(module)
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import tvm from tvm import relay from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr import numpy as np mod = relay.Module() p = Prelude(mod) add_nat_definitions(p) def count(e): return count_(p, e) ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") z = p.z s = p.s nat = p.nat double = p.double add = p.add