def check_grad(func, mod=None):
    """
    Test that directional gradient calculated by reverse mode
    is close to the one calculated by finite difference.

    Parameters
    ----------
    func
        The function whose gradient is checked. It is registered in ``mod``
        under a fresh global name.
    mod
        Module that receives the helper globals. A new ``relay.Module()`` is
        created when this is ``None``.

    Raises
    ------
    AssertionError
        If a parameter is not tensor-typed, or if the absolute difference
        between the two directional derivatives is not below ``0.001``.
    """
    global CHECK_GRAD_COUNTER
    if mod is None:
        mod = relay.Module()

    # Every global registered by this call carries the current counter value
    # as a suffix; the counter is bumped once all five names are created.
    suffix = str(CHECK_GRAD_COUNTER)
    func_name = GlobalVar("func_" + suffix)
    back_func_name = GlobalVar("back_func_" + suffix)
    finite_difference_func_name = GlobalVar("finite_difference_" + suffix)
    reverse_mode_func_name = GlobalVar("reverse_mode_" + suffix)
    check_func_name = GlobalVar("check_func_" + suffix)
    CHECK_GRAD_COUNTER = CHECK_GRAD_COUNTER + 1

    epsilon = relay.const(0.01)
    mod[func_name] = func
    mod[back_func_name] = gradient(mod[func_name], mod=mod)

    params = mod[func_name].params
    # One random direction per parameter, drawn from that parameter's type.
    directions = [rand_from_type(p.checked_type) for p in params]
    ft = TensorType(())

    def scalar_function(body):
        # Wrap ``body`` in a function sharing the signature pieces of the
        # registered ``func`` and returning the scalar type ``ft``.
        return Function(params,
                        body,
                        ft,
                        mod[func_name].type_params,
                        mod[func_name].attrs)

    # Reverse mode: sum over all parameters of sum(grad_i * direction_i).
    sb = ScopeBuilder()
    bf = sb.let("bf", TupleGetItem(back_func_name(*params), 1))
    weighted_grads = []
    for index, param in enumerate(params):
        grad_i = TupleGetItem(bf, index)
        direction_i = directions[index]
        assert isinstance(param.checked_type, TensorType)
        weighted_grads.append(op.cast(grad_i * direction_i, 'float32'))
    accumulated = relay.const(0.0)
    for weighted in weighted_grads:
        accumulated = accumulated + op.reduce.sum(weighted)
    sb.ret(accumulated)
    mod[reverse_mode_func_name] = scalar_function(sb.get())

    # Finite difference: sum((f(x + eps * d) - f(x)) / eps).
    shifted_args = [p + epsilon * d for p, d in zip(params, directions)]
    finite_difference_result = op.reduce.sum(
        (func_name(*shifted_args) - func_name(*params)) / epsilon)
    mod[finite_difference_func_name] = scalar_function(finite_difference_result)

    # The checker returns the absolute gap between both estimates.
    gap = op.abs(
        reverse_mode_func_name(*params) - finite_difference_func_name(*params))
    mod[check_func_name] = scalar_function(gap)

    ex = create_executor(mod=mod)
    res = ex.evaluate(
        check_func_name(*[rand_from_type(p.checked_type) for p in params]))
    assert res.data.asnumpy() < 0.001
def test_head_cons():
    """Assert that ``dcpe`` turns ``fn(x) { hd(cons(x, nil)) }`` into ``fn(x) { x }``."""
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    elem_ty = TypeVar("t")
    x = Var("x", elem_ty)
    head_of_singleton = prelude.hd(prelude.cons(x, prelude.nil()))
    f = Function([x], head_of_singleton, None, [elem_ty])
    res = dcpe(f, mod)
    expected = Function([x], x, elem_ty, [elem_ty])
    assert tvm.ir.structural_equal(res, expected)
def test_empty_ad():
    """Assert that ``dcpe(gradient(identity))`` equals ``fn(d) { (d, (ones_like(d),)) }``."""
    tensor_ty = TensorType((10, 10), "float32")
    d = Var("d", tensor_ty)
    identity = Function([d], d)
    g = dcpe(gradient(identity))
    grads = Tuple([op.ones_like(d)])
    expected = Function([d], Tuple([d, grads]))
    assert alpha_equal(g, expected)
def test_ref():
    """Assert that ``dcpe`` reduces a square computed through a reference to ``d * d``."""
    d = relay.Var("d")
    r = relay.Var("r")
    x = relay.Var("x")
    # let r = ref(d); let x = (r := !r * !r); !r
    final_read = relay.RefRead(r)
    store_square = RefWrite(r, RefRead(r) * RefRead(r))
    body = Let(r, RefCreate(d), Let(x, store_square, final_read))
    square = Function([d], body)
    assert alpha_equal(dcpe(square), Function([d], d * d))
def test_tuple_get_item():
    """Assert DCE keeps a plain tuple projection and drops an unused let under it."""
    def close_over(expr):
        # Bind the free variables of ``expr`` as function parameters.
        return Function(free_vars(expr), expr)

    pair_ty = relay.TupleType([e.float32, e.float32])
    t = relay.Var('t', pair_ty)
    a = relay.Var('a')
    g = relay.TupleGetItem(t, 0)

    dced = transform.OptimizeOnExpr(g, transform.DeadCodeElimination())
    assert alpha_equal(close_over(dced), close_over(g))

    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
    dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
    assert alpha_equal(close_over(dced), close_over(g))
def test_tuple_get_item():
    """Assert DCE keeps a plain tuple projection and drops an unused let under it."""
    def as_function(expr):
        # Bind the free variables of ``expr`` as function parameters.
        return Function(free_vars(expr), expr)

    pair_ty = relay.TupleType([e.float32, e.float32])
    t = relay.Var('t', pair_ty)
    a = relay.Var('a')
    g = relay.TupleGetItem(t, 0)

    dced = run_opt_pass(g, transform.DeadCodeElimination())
    assert tvm.ir.structural_equal(as_function(dced), as_function(g))

    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
    dced = run_opt_pass(orig, transform.DeadCodeElimination())
    assert tvm.ir.structural_equal(as_function(dced), as_function(g))
def test_empty_ad():
    """Assert that ``dcpe(identity, grad=True)`` equals the type-inferred
    ``fn(d) { (d, (ones_like(d),)) }``."""
    tensor_ty = TensorType((10, 10), "float32")
    d = Var("d", tensor_ty)
    identity = Function([d], d)
    g = dcpe(identity, grad=True)
    grads = Tuple([op.ones_like(d)])
    untyped_expected = Function([d], Tuple([d, grads]))
    expected = transform.OptimizeOnExpr(untyped_expected, transform.InferType())
    assert alpha_equal(g, expected)
def test_head_cons():
    """Assert that ``dcpe`` turns ``fn(x) { hd(cons(x, nil)) }`` into ``fn(x) { x }``."""
    mod = Module()
    prelude = Prelude(mod)
    hd = prelude.hd
    elem_ty = TypeVar("t")
    x = Var("x", elem_ty)
    singleton = prelude.cons(x, prelude.nil())
    f = Function([x], hd(singleton), None, [elem_ty])
    res = dcpe(f, mod)
    expected = Function([x], x, elem_ty, [elem_ty])
    assert alpha_equal(res, expected)
def test_empty_ad():
    """Assert that ``dcpe(identity, grad=True)`` equals the type-inferred
    ``fn(d) { (d, (ones_like(d),)) }``."""
    tensor_ty = TensorType((10, 10), "float32")
    d = Var("d", tensor_ty)
    identity = Function([d], d)
    g = dcpe(identity, grad=True)
    grads = Tuple([op.ones_like(d)])
    untyped_expected = Function([d], Tuple([d, grads]))
    expected = run_opt_pass(untyped_expected, transform.InferType())
    assert tvm.ir.structural_equal(g, expected)
def test_empty_ad():
    """Assert that ``dcpe`` with ``grad=True`` and ``ignore_impurity=True`` on
    the identity function equals the type-inferred ``fn(d) { (d, (ones_like(d),)) }``."""
    tensor_ty = TensorType((10, 10), "float32")
    d = Var("d", tensor_ty)
    identity = Function([d], d)
    # TODO(mbs): Revisit once DCE eliminates dead writes.
    g = dcpe(identity, grad=True, ignore_impurity=True)
    grads = Tuple([op.ones_like(d)])
    untyped_expected = Function([d], Tuple([d, grads]))
    expected = run_opt_pass(untyped_expected, transform.InferType())
    assert tvm.ir.structural_equal(g, expected)
def test_double():
    """Compile ``double(double(x))`` with ``double = fn(x) { x + x }`` and
    check that the input 1.5 yields 6.0."""
    mod = Module()
    double = GlobalVar('double')
    inner_x = var('x', shape=())
    mod[double] = Function([inner_x], inner_x + inner_x)

    outer_x = var('x', shape=())
    quadruple = Function([outer_x], double(double(outer_x)))
    cfunc = compile(quadruple, mod)

    a = tvm.nd.array(np.array(1.5, dtype='float32'))
    output = cfunc(a)
    expected = np.array(6.0, dtype='float32')
    np.testing.assert_allclose(output.asnumpy(), expected)
def test_ref():
    """Assert that ``dcpe`` reduces a square computed through a reference to
    the type-inferred ``fn(d) { d * d }``."""
    scalar_ty = relay.TensorType([], "float32")
    d = relay.Var("d", scalar_ty)
    r = relay.Var("r", relay.RefType(scalar_ty))
    x = relay.Var("x")
    # let r = ref(d); let x = (r := !r * !r); !r
    final_read = relay.RefRead(r)
    store_square = RefWrite(r, RefRead(r) * RefRead(r))
    body = Let(r, RefCreate(d), Let(x, store_square, final_read))
    square = Function([d], body)
    expected = run_opt_pass(Function([d], d * d), transform.InferType())
    assert tvm.ir.structural_equal(dcpe(square), expected)
def test_compose():
    """Compile ``compose(inc, double)`` and check that applying it to the nat 2 gives 5."""
    mod = Module()
    prelude = Prelude(mod)

    inc = GlobalVar('inc')
    inc_arg = relay.Var('x')
    mod[inc] = Function([inc_arg], prelude.s(inc_arg))

    func = GlobalVar('func')
    func_arg = relay.Var('x')
    composed = prelude.compose(inc, prelude.double)
    mod[func] = Function([func_arg], relay.Call(composed, [func_arg]))

    cfunc = compile(func, mod)
    two = prelude.s(prelude.s(prelude.z()))
    assert nat_to_int(cfunc(two)) == 5
def test_recur_sum_global():
    """Compile a recursive global ``sum`` and check that ``sum(10)`` evaluates to 55."""
    mod = Module()
    x = var('x', dtype='int32', shape=())
    sum_gv = GlobalVar('sum')
    zero = relay.const(0)
    # sum(x) = if x < 0 then 0 else x + sum(x - 1)
    recurse = x + sum_gv(x - relay.const(1))
    sum_body = relay.If(op.less(x, zero), zero, recurse)
    mod[sum_gv] = Function([x],
                           sum_body,
                           relay.TensorType(dtype='int32', shape=()))
    cfunc = compile(Function([], sum_gv(relay.const(10))), mod)
    output = cfunc()
    expected = np.array(55, dtype='int32')
    np.testing.assert_allclose(output.asnumpy(), expected)
def test_recur_sum_local():
    """Compile a let-bound recursive ``sum`` and check that ``sum(10)`` evaluates to 55."""
    mod = Module()
    x = var('x', dtype='int32', shape=())
    int_ty = relay.TensorType(dtype='int32', shape=())
    sum_var = relay.Var('sum', type_annotation=relay.FuncType([int_ty], int_ty))
    zero = relay.const(0)
    # sum(x) = if x < 0 then 0 else x + sum(x - 1)
    recurse = x + sum_var(x - relay.const(1))
    func = Function([x], relay.If(op.less(x, zero), zero, recurse), int_ty)
    body = relay.Let(sum_var, func, sum_var(relay.const(10)))
    cfunc = compile(Function([], body), mod)
    output = cfunc()
    expected = np.array(55, dtype='int32')
    np.testing.assert_allclose(output.asnumpy(), expected)
def test_loop():
    """Assert that ``dcpe`` leaves a call to a self-recursive global unchanged."""
    mod = Module()
    elem_ty = TypeVar("t")
    x = Var("x", elem_ty)
    loop = GlobalVar("loop")
    mod[loop] = Function([x], loop(x), elem_ty, [elem_ty])

    # Round-trip the expected call through the module's entry function and
    # compare against the body stored there.
    mod[mod.entry_func] = Function([], Call(loop, [const(1)]))
    expected = mod[mod.entry_func].body

    call = Function([], loop(const(1)))
    res = dcpe(call, mod=mod)
    assert alpha_equal(res.body, expected)
def test_head_cons():
    """Assert that ``dcpe`` turns ``fn(x) { hd(cons(x, nil)) }`` into the
    ``main`` of a module built from ``fn(x) { x }``."""
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    elem_ty = TypeVar("t")
    x = Var("x", elem_ty)
    _, cons, nil = prelude.mod.get_type("List")
    hd = prelude.mod.get_global_var("hd")
    f = Function([x], hd(cons(x, nil())), None, [elem_ty])
    res = dcpe(f, mod)
    identity = Function([x], x, elem_ty, [elem_ty])
    expected_mod = tvm.IRModule.from_expr(identity)
    assert tvm.ir.structural_equal(res, expected_mod["main"])
def test_loop():
    """Assert that ``dcpe`` leaves a call to a self-recursive global unchanged."""
    mod = tvm.IRModule()
    elem_ty = TypeVar("t")
    x = Var("x", elem_ty)
    loop = GlobalVar("loop")
    mod[loop] = Function([x], loop(x), elem_ty, [elem_ty])

    # Round-trip the expected call through "main" and compare against the
    # body stored in the module.
    mod["main"] = Function([], Call(loop, [const(1)]))
    expected = mod["main"].body

    call = Function([], loop(const(1)))
    res = dcpe(call, mod=mod)
    assert tvm.ir.structural_equal(res.body, expected)
def test_ref():
    """Assert that ``dcpe`` reduces a square computed through a reference to
    the type-inferred ``fn(d) { d * d }``."""
    scalar_ty = relay.TensorType([], "float32")
    d = relay.Var("d", scalar_ty)
    r = relay.Var("r", relay.RefType(scalar_ty))
    x = relay.Var("x")
    # let r = ref(d); let x = (r := !r * !r); !r
    final_read = relay.RefRead(r)
    store_square = RefWrite(r, RefRead(r) * RefRead(r))
    body = Let(r, RefCreate(d), Let(x, store_square, final_read))
    square = Function([d], body)
    expected = transform.OptimizeOnExpr(Function([d], d * d),
                                        transform.InferType())
    assert alpha_equal(dcpe(square), expected)
def test_nat_id():
    """Assert that ``dcpe`` reduces ``nat_id(3)`` to the nat expression for 3.

    ``nat_id`` is registered in the module as the identity function on the
    ``nat`` type imported from ``nat.rly``.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat, _, _ = p.mod.get_type("nat")
    x = Var("x", nat())
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = Function([], nat_id(make_nat_expr(p, 3)))
    res = dcpe(orig, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_swap_loop():
    """Assert that ``dcpe`` leaves a call to an argument-swapping infinite loop unchanged."""
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    prelude.mod.import_from_std("nat.rly")
    nat, _, _ = prelude.mod.get_type("nat")
    x = Var("x", nat())
    y = Var("y", nat())
    loop = GlobalVar("loop")
    # loop(x, y) = loop(y, x)
    mod[loop] = Function([x, y], loop(y, x), nat())
    prog = loop(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    res = dcpe(Function([], prog), mod=mod)
    assert tvm.ir.structural_equal(prog, res.body)
def test_nat_id():
    """Assert that ``dcpe`` reduces ``nat_id(3)`` to the nat expression for 3.

    ``nat_id`` is registered in the module as the identity function on the
    prelude's ``nat`` type.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = Function([], nat_id(make_nat_expr(p, 3)))
    res = dcpe(orig, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_swap_loop():
    """Assert that ``dcpe`` leaves a call to an argument-swapping infinite loop unchanged."""
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    nat = prelude.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    loop = GlobalVar("loop")
    # loop(x, y) = loop(y, x)
    mod[loop] = Function([x, y], loop(y, x), nat)
    prog = loop(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    res = dcpe(Function([], prog), mod=mod)
    assert tvm.ir.structural_equal(prog, res.body)
def test_ref():
    """Assert that ``dcpe`` with ``ignore_impurity=True`` reduces a square
    computed through a reference to the type-inferred ``fn(d) { d * d }``."""
    scalar_ty = relay.TensorType([], "float32")
    d = relay.Var("d", scalar_ty)
    r = relay.Var("r", relay.RefType(scalar_ty))
    x = relay.Var("x")
    # let r = ref(d); let x = (r := !r * !r); !r
    final_read = relay.RefRead(r)
    store_square = RefWrite(r, RefRead(r) * RefRead(r))
    body = Let(r, RefCreate(d), Let(x, store_square, final_read))
    square = Function([d], body)
    expected = run_opt_pass(Function([d], d * d), transform.InferType())
    # TODO(mbs): Revisit once DCE eliminates dead writes.
    actual = dcpe(square, ignore_impurity=True)
    assert tvm.ir.structural_equal(actual, expected)
def test_match_nat_id():
    """Assert that ``dcpe`` reduces a match-based identity applied to the nat 3
    to the nat expression for 3."""
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    prelude.mod.import_from_std("nat.rly")
    nat, z, s = prelude.mod.get_type("nat")
    x = Var("x", nat())
    y = Var("y", nat())
    nat_id = GlobalVar("nat_id")
    # Each clause rebuilds the constructor it matched on.
    clauses = [
        Clause(PatternConstructor(z, []), z()),
        Clause(PatternConstructor(s, [PatternVar(y)]), s(y)),
    ]
    mod[nat_id] = Function([x], Match(x, clauses))
    orig = Function([], nat_id(make_nat_expr(prelude, 3)))
    res = dcpe(orig, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(prelude, 3))
def test_match_nat_id():
    """Assert that ``dcpe`` reduces a match-based identity applied to the nat 3
    to the nat expression for 3."""
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    nat = prelude.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    nat_id = GlobalVar("nat_id")
    # Each clause rebuilds the constructor it matched on.
    clauses = [
        Clause(PatternConstructor(prelude.z, []), prelude.z()),
        Clause(PatternConstructor(prelude.s, [PatternVar(y)]), prelude.s(y)),
    ]
    mod[nat_id] = Function([x], Match(x, clauses))
    orig = Function([], nat_id(make_nat_expr(prelude, 3)))
    res = dcpe(orig, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(prelude, 3))
def test_map():
    """Assert that ``dcpe`` reduces mapping an identity global over ``[1, 2, 3]``
    to the list ``[1, 2, 3]`` itself."""
    mod = Module()
    prelude = Prelude(mod)

    def one_two_three():
        # A fresh cons-list expression holding the constants 1, 2 and 3.
        return prelude.cons(
            const(1),
            prelude.cons(const(2), prelude.cons(const(3), prelude.nil())))

    f = GlobalVar("f")
    elem_ty = TypeVar("t")
    a = Var("a", elem_ty)
    mod[f] = Function([a], a, elem_ty, [elem_ty])

    orig = prelude.map(f, one_two_three())
    # Store the expected list as "main" and read it back from the module.
    mod["main"] = Function([], one_two_three())
    expected = mod["main"]

    res = dcpe(Function([], orig), mod=mod)
    assert alpha_equal(res.body, expected.body)
def test_triangle_number():
    """Assert that ``dcpe`` folds a let-bound recursive triangle-number
    function applied to 10 down to the constant 55."""
    int_ty = relay.TensorType([], "int32")
    x = Var("x", int_ty)
    f_var = Var("f")
    # f(x) = if x == 0 then 0 else x + f(x - 1)
    is_zero = op.equal(x, const(0))
    f = Function([x], If(is_zero, const(0), x + f_var(x - const(1))))
    program = Let(f_var, f, f_var(const(10)))
    orig = run_infer_type(program)
    assert_alpha_equal(dcpe(orig), const(55))
def test_42():
    """Compile a nullary function returning the constant 42 and check its output."""
    mod = Module()
    answer = Function([], relay.const(42))
    cfunc = compile(answer, mod)
    output = cfunc()
    expected = np.array(42.0, dtype='float32')
    np.testing.assert_allclose(output.asnumpy(), expected)
def test_nat_3():
    """Compile a nullary function building ``s(s(s(z)))`` and check it converts to 3."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    three = prelude.s(prelude.s(prelude.s(prelude.z())))
    cfunc = compile(Function([], three), mod)
    output = cfunc()
    assert nat_to_int(output) == 3