def test_top_level_nested_if():
    """ToBasicBlockNormalForm hoists exprs shared by nested ifs into lets at their dominating block."""
    x = relay.var('x', shape=(), dtype='bool')
    y = relay.var('y', shape=(), dtype='float32')
    z = relay.var('z', shape=(), dtype='float32')
    cond_t = relay.const(True)
    cond_f = relay.const(False)
    one = relay.const(1, dtype='float32')
    three = relay.const(3, dtype='float32')
    y2 = relay.add(y, y)
    z2 = relay.add(z, z)
    true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
    false_branch = relay.If(cond_f, z2, one)
    body = relay.If(x, true_branch, false_branch)
    # Text form of `body` (kept as a no-op string literal, as in the original):
    '\n free_var %x: bool\n if (%x) {\n if (True) {\n free_var %z: float32\n %0 = add(%z, %z);\n free_var %y: float32\n %1 = add(%y, %y);\n add(%0, %1)\n } else {\n add(3f, %1)\n }\n } else {\n if (False) {\n %0\n } else {\n 1f\n }\n }\n '

    def expected():
        # Same graph, but with the shared adds bound by explicit lets at
        # the blocks that dominate all of their uses.
        x = relay.var('x', shape=(), dtype='bool')
        y = relay.var('y', shape=(), dtype='float32')
        z = relay.var('z', shape=(), dtype='float32')
        cond_t = relay.const(True)
        cond_f = relay.const(False)
        one = relay.const(1, dtype='float32')
        three = relay.const(3, dtype='float32')
        y2 = relay.var('y2')
        z2 = relay.var('z2')
        true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
        true_branch = relay.Let(y2, relay.add(y, y), true_branch)
        false_branch = relay.If(cond_f, z2, one)
        body = relay.If(x, true_branch, false_branch)
        body = relay.Let(z2, relay.add(z, z), body)
        return body

    bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()])
    # Text form of `bblock` after the pass (no-op string literal):
    '\n free_var %z: float32\n let %x: float32 = add(%z, %z) /* ty=float32 */;\n free_var %x1: bool\n if (%x1) {\n free_var %y: float32\n let %x2: float32 = add(%y, %y) /* ty=float32 */;\n if (True /* ty=bool */) {\n add(%x, %x2) /* ty=float32 */\n } else {\n add(3f /* ty=float32 */, %x2) /* ty=float32 */\n }\n } else {\n if (False /* ty=bool */) {\n %x\n } else {\n 1f /* ty=float32 */\n }\n }\n '
    expected_output = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
def test_nested_if():
    """ToBasicBlockNormalForm hoists add(y, y) to the block dominating the inner if."""
    x = relay.var('x', shape=(), dtype='bool')
    y = relay.var('y', shape=(), dtype='float32')
    cond_t = relay.const(True)
    cond_f = relay.const(False)
    one = relay.const(1, dtype='float32')
    two = relay.const(2, dtype='float32')
    three = relay.const(3, dtype='float32')
    y2 = relay.add(y, y)
    true_branch = relay.If(cond_t, y2, relay.add(three, y2))
    false_branch = relay.If(cond_f, two, one)
    body = relay.If(x, true_branch, false_branch)
    # Text form of `body` (kept as a no-op string literal, as in the original):
    '\n free_var %x: bool\n if (%x) {\n if (True) {\n free_var %y: float32\n %0 = add(%y, %y);\n %0\n } else {\n add(3f, %0)\n }\n } else {\n if (False) {\n 2f\n } else {\n 1f\n }\n }\n '

    def expected():
        # The shared add(y, y) is let-bound above the inner if of the
        # true branch, where all of its uses live.
        x = relay.var('x', shape=(), dtype='bool')
        y = relay.var('y', shape=(), dtype='float32')
        cond_t = relay.const(True)
        cond_f = relay.const(False)
        one = relay.const(1, dtype='float32')
        two = relay.const(2, dtype='float32')
        three = relay.const(3, dtype='float32')
        y2 = relay.var('y2')
        true_branch = relay.If(cond_t, y2, relay.add(three, y2))
        true_branch = relay.Let(y2, relay.add(y, y), true_branch)
        false_branch = relay.If(cond_f, two, one)
        body = relay.If(x, true_branch, false_branch)
        return body

    bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()])
    # Text form of `bblock` after the pass (no-op string literal):
    '\n free_var %x: bool\n if (%x) {\n free_var %y: float32\n let %x1: float32 = add(%y, %y) /* ty=float32 */;\n if (True /* ty=bool */) {\n %x1\n } else {\n add(3f /* ty=float32 */, %x1) /* ty=float32 */\n }\n } else {\n if (False /* ty=bool */) {\n 2f /* ty=float32 */\n } else {\n 1f /* ty=float32 */\n }\n }\n '
    expected_output = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
    check_basic_block_normal_form(bblock)
def test_cps_pe():
    """to_cps introduces refs; PartialEvaluate + DCE must remove every RefCreate."""

    def destroy_ref(x):
        # Round-trip through CPS, then verify the PE/DCE pipeline leaves
        # no RefCreate nodes behind.
        x = run_infer_type(x)
        x = to_cps(x)
        x = run_infer_type(x)
        y = un_cps(x)
        y = run_infer_type(y)
        # TODO(mbs): Revisit once DCE can eliminate dead writes.
        x = run_opt_pass(
            x,
            tvm.transform.Sequential(
                [
                    transform.PartialEvaluate(),
                    transform.InferType(),
                    transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
                ]
            ),
        )
        assert Feature.fRefCreate not in detect_feature(x)

    # A function that writes one of two closures into a ref, then calls it.
    unit = relay.Function([], relay.const(0.0, dtype="float32"))
    f_ref = relay.Var("f_ref")
    one = relay.const(1.0, dtype="float32")
    two = relay.const(2.0, dtype="float32")
    cond = relay.var(shape=(), dtype="uint1", name_hint="cond")
    true_branch = relay.RefWrite(f_ref, relay.Function([], one))
    false_branch = relay.RefWrite(f_ref, relay.Function([], two))
    if_expr = relay.If(cond, true_branch, false_branch)
    stmt = relay.Let(
        f_ref,
        relay.RefCreate(unit),
        relay.Let(relay.Var("x"), if_expr, relay.Call(relay.RefRead(f_ref), [])),
    )
    F = relay.Function([cond], stmt)
    destroy_ref(F)

    # Gradient of a scalar if.
    G = relay.Function([cond], relay.If(cond, one, two))
    G = run_infer_type(G)
    G = relay.transform.gradient(G)
    destroy_ref(G)

    # Gradient of a tensor-valued if feeding an add.
    x = relay.var("x", shape=(1, 16))
    y = relay.var("y", shape=(1, 16))
    z = relay.var("z", shape=(1, 16))
    cond = relay.var("cond", shape=(), dtype="uint1")
    H = relay.If(cond, x, y)
    H = relay.add(H, z)
    H = relay.Function([cond, x, y, z], H)
    H = run_infer_type(H)
    H = relay.transform.gradient(H)
    destroy_ref(H)
def compute(self, input_size, hidden_size, output_size):
    """Build the forward RNN cell plus the recursive letter-sampling loop.

    Returns (loop parameters, Let-bound loop body, None).
    """
    self.category_var = category = relay.var('category', shape=(1, data.N_CATEGORIES))
    self.inp_topi_var = inp_topi = relay.var('input', shape=(), dtype='int32')
    self.hidden_var = hidden = relay.var('hidden', shape=(1, hidden_size))
    self.hidden = initialize(self.hidden_var)
    n_letter = relay.const(data.N_LETTERS)
    one_diag = relay.const(np.diag(np.ones(58)).astype('float32'))
    boxed_one = relay.const(np.array([1]).astype('int32'))
    # One-hot encode the input letter index by selecting a row of the identity.
    inp = op.take(one_diag, op.multiply(boxed_one, inp_topi), axis=0)
    combined = op.concatenate(
        [op.concatenate([category, inp], axis=1), hidden], axis=1)
    hidden = self.linear(data.N_CATEGORIES + input_size + hidden_size,
                         hidden_size, combined, name='i2h')
    output = self.linear(data.N_CATEGORIES + input_size + hidden_size,
                         output_size, combined, name='i2o')
    output_combined = op.concatenate([hidden, output], axis=1)
    output = self.linear(hidden_size + output_size, output_size,
                         output_combined, name='o2o')
    # output = op.nn.dropout(output, 0.1)  # attributes has not been registered
    output = op.nn.log_softmax(output, axis=1)
    topi = op.argmax(output)
    # (new hidden state, sampled letter, end-of-sequence flag)
    body = relay.Tuple([
        hidden,
        topi,
        op.equal(topi, op.subtract(n_letter, relay.const(1))),
    ])
    fwd_para = [self.category_var, self.inp_topi_var, self.hidden_var]
    fwd_func = relay.Function(fwd_para, body)
    self.fwd = relay.Var('fwd')

    # Recursive loop: sample until the EOS letter or until `max` steps.
    max_len = relay.var('max', shape=(), dtype='int32')
    inp_para = [max_len] + [copy_var(v) for v in fwd_para]
    fwd_res = self.fwd(*inp_para[1:])
    fwd_res_0 = relay.TupleGetItem(fwd_res, 0)
    fwd_res_1 = relay.TupleGetItem(fwd_res, 1)
    fwd_res_2 = relay.TupleGetItem(fwd_res, 2)
    else_else_branch = self.prelude.cons(
        fwd_res_1,
        self.recurse(op.subtract(max_len, relay.const(1)),
                     inp_para[1], fwd_res_1, fwd_res_0))
    else_branch = relay.If(fwd_res_2, self.prelude.nil(), else_else_branch)
    body = relay.If(op.equal(max_len, relay.const(0)),
                    self.prelude.nil(), else_branch)
    return inp_para, relay.Let(self.fwd, fwd_func, body), None
def expected():
    """Reference graph: add(y, y) let-bound at the block dominating the inner if."""
    x = relay.var('x', shape=(), dtype='bool')
    y = relay.var('y', shape=(), dtype='float32')
    cond_t = relay.const(True)
    cond_f = relay.const(False)
    one = relay.const(1, dtype='float32')
    two = relay.const(2, dtype='float32')
    three = relay.const(3, dtype='float32')
    y2 = relay.var('y2')
    true_branch = relay.If(cond_t, y2, relay.add(three, y2))
    true_branch = relay.Let(y2, relay.add(y, y), true_branch)
    false_branch = relay.If(cond_f, two, one)
    body = relay.If(x, true_branch, false_branch)
    return body
def test_if():
    """ToANormalForm let-binds the condition and both branch values of an if."""
    cond = relay.const(True)
    x = relay.If(cond, relay.const(2), relay.const(3))
    anf = run_opt_pass(x, [transform.ToANormalForm(), transform.InferType()])
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    # Each branch body is let-bound; the whole if and the condition are too.
    true_branch = relay.Let(a, relay.const(2), a)
    false_branch = relay.Let(b, relay.const(3), b)
    expected_output = relay.If(c, true_branch, false_branch)
    expected_output = relay.Let(d, expected_output, d)
    expected_output = relay.Let(c, cond, expected_output)
    expected_output = run_opt_pass(expected_output, transform.InferType())
    assert tvm.ir.structural_equal(anf, expected_output)
def test_if():
    """ANF of an if (legacy to_a_normal_form / alpha_equal API)."""
    cond = relay.const(True)
    x = relay.If(cond, relay.const(2), relay.const(3))
    anf = infer_type(to_a_normal_form(x))
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    # Each branch body is let-bound; the whole if and the condition are too.
    true_branch = relay.Let(a, relay.const(2), a)
    false_branch = relay.Let(b, relay.const(3), b)
    expected_output = relay.If(c, true_branch, false_branch)
    expected_output = relay.Let(d, expected_output, d)
    expected_output = relay.Let(c, cond, expected_output)
    expected_output = infer_type(expected_output)
    assert alpha_equal(anf, expected_output)
def test_cps_pe():
    """to_cps introduces refs; PartialEvaluate + DCE must remove every RefCreate (legacy API)."""

    def destroy_ref(x):
        # Round-trip through CPS, then verify PE + DCE leaves no RefCreate.
        x = run_infer_type(x)
        x = to_cps(x)
        x = run_infer_type(x)
        y = un_cps(x)
        y = run_infer_type(y)
        x = run_opt_pass(
            x,
            transform.Sequential([
                transform.PartialEvaluate(),
                transform.DeadCodeElimination(inline_once=True)
            ]))
        assert Feature.fRefCreate not in detect_feature(x)

    # A function that writes one of two closures into a ref, then calls it.
    unit = relay.Function([], relay.const(0., dtype='float32'))
    f_ref = relay.Var("f_ref")
    one = relay.const(1., dtype='float32')
    two = relay.const(2., dtype='float32')
    cond = relay.var(shape=(), dtype='uint1', name_hint='cond')
    true_branch = relay.RefWrite(f_ref, relay.Function([], one))
    false_branch = relay.RefWrite(f_ref, relay.Function([], two))
    if_expr = relay.If(cond, true_branch, false_branch)
    stmt = relay.Let(
        f_ref, relay.RefCreate(unit),
        relay.Let(relay.Var("x"), if_expr,
                  relay.Call(relay.RefRead(f_ref), [])))
    F = relay.Function([cond], stmt)
    destroy_ref(F)

    # Gradient of a scalar if.
    G = relay.Function([cond], relay.If(cond, one, two))
    G = run_infer_type(G)
    G = relay.transform.gradient(G)
    destroy_ref(G)

    # Gradient of a tensor-valued if feeding an add.
    x = relay.var("x", shape=(1, 16))
    y = relay.var("y", shape=(1, 16))
    z = relay.var("z", shape=(1, 16))
    cond = relay.var("cond", shape=(), dtype='uint1')
    H = relay.If(cond, x, y)
    H = relay.add(H, z)
    H = relay.Function([cond, x, y, z], H)
    H = run_infer_type(H)
    H = relay.transform.gradient(H)
    destroy_ref(H)
def test_recursion():
    """
    Program:
       let f(n: i32) -> i32 = {
          m = (n * 2)
          if (n == 0) {
              return m;
          } else {
              return m + f(n - 1);
          }
       }
       f(5);
    """
    mod = tvm.IRModule()
    i64 = relay.TensorType((), "int64")
    f = relay.GlobalVar("f")
    n = relay.Var("n", i64)
    m = n * relay.const(2, "int64")
    cond = relay.equal(n, relay.const(0, "int64"))
    false_branch = m + f(n - relay.const(1, "int64"))
    funcbody = relay.If(cond, m, false_branch)
    value = relay.Function([n], funcbody, i64, [])
    mod[f] = value
    check_eval(f(relay.const(5, "int64")), 30.0, mod=mod)
    old_f = mod[f]
    # The pass must preserve the evaluation semantics of the recursive function.
    mod = transform.ToBasicBlockNormalForm()(mod)
    f = mod[f]
    check_eval(f(relay.const(5, "int64")), 30.0, mod=mod)
    check_basic_block_normal_form(f)
def _make_env_find(self, m, rval_t):
    """Synthesize the recursive `$_env_find` global: look up `key` in an
    association-list environment, returning `dft` when the key is absent.
    Registers (GlobalVar, Function) under m['env_find'] and returns them.
    """
    ctr = m['ctr']
    gv = relay.GlobalVar(f"$_env_find<{ctr.name_hint}>")
    env = relay.Var("env", env_type(env_val()))
    key = relay.Var("key", relay.ty.scalar_type('int64'))
    dft = relay.Var("dft", rval_t)
    k = relay.Var("k")
    v = relay.Var("v")
    r = relay.Var("r")
    x = relay.Var("x")
    # Unwrap a stored value from its constructor (non-exhaustive match).
    extract_clause = adt.Clause(
        adt.PatternConstructor(ctr, [adt.PatternVar(x)]), x)
    # Empty environment: fall back to the default.
    empty_clause = adt.Clause(adt.PatternConstructor(empty_env, []), dft)
    # Cons cell: return the head's value on a key match, else recurse on the tail.
    cons_clause = adt.Clause(
        adt.PatternConstructor(
            cons_env,
            [adt.PatternVar(k), adt.PatternVar(v), adt.PatternVar(r)]),
        relay.If(relay.equal(key, k),
                 adt.Match(v, [extract_clause], complete=False),
                 relay.Call(gv, [r, key, dft])))
    body = adt.Match(env, [empty_clause, cons_clause])
    fn = relay.Function([env, key, dft], body, rval_t)
    m['env_find'] = (gv, fn)
    return gv, fn
def test_recursion():
    """
    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    f1 = relay.Var("f1")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    funcbody = relay.If(
        equal(n, relay.const(0)),
        data,
        relay.Call(f1, [subtract(n, relay.const(1)), log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
    orig = relay.Let(
        f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    # The binding of f is live, so DCE must leave the program unchanged.
    dced = run_opt_pass(orig, transform.DeadCodeElimination())
    orig = run_opt_pass(orig, transform.InferType())
    assert graph_equal(dced, orig)
    # With an unused binding, DCE strips the let entirely.
    dced = run_opt_pass(relay.Let(f, value, e.three),
                        transform.DeadCodeElimination())
    assert alpha_equal(dced, e.three)
def test_recursion():
    """
    Program:
       let f(n: i32) -> i32 = {
          m = (n * 2)
          if (n == 0) {
              return m;
          } else {
              return m + f(n - 1);
          }
       }
       f(5);
    """
    mod = tvm.IRModule()
    i64 = relay.TensorType((), 'int64')
    f = relay.GlobalVar("f")
    n = relay.Var("n", i64)
    m = n * relay.const(2, 'int64')
    funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')),
                        m,
                        m + f(n - relay.const(1, 'int64')))
    value = relay.Function([n], funcbody, i64, [])
    mod[f] = value
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
    old_f = mod[f]
    # ToANormalForm must preserve the evaluation semantics.
    mod = transform.ToANormalForm()(mod)
    f = mod[f]
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
def test_recursion():
    """
    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    funcbody = relay.If(
        equal(n, relay.const(0)),
        data,
        # n is int32, so the decrement must be an integer constant.
        relay.Call(f, [subtract(n, relay.const(1)), log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
    # Bind f to the Function (not the bare If body) so the calls below
    # actually invoke a function; the first argument is int32 to match n.
    orig = relay.Let(
        f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    # The binding of f is live, so DCE must leave the program unchanged.
    assert alpha_equal(dead_code_elimination(orig), orig)
    # With an unused binding, DCE strips the let entirely.
    assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)),
                       e.three)
def _make_env_update(self, m, rval_t):
    """Synthesize the recursive `$_env_update` global: bind `key` to `val`
    in an association-list environment, replacing an existing entry with a
    matching key. Registers (GlobalVar, Function) under m['env_update']
    and returns them.
    """
    ctr = m['ctr']
    gv = relay.GlobalVar(f"$_env_update<{ctr.name_hint}>")
    env = relay.Var("env", env_type(env_val()))
    key = relay.Var("key", relay.ty.scalar_type('int64'))
    val = relay.Var("val", rval_t)
    k = relay.Var("k")
    v = relay.Var("v")
    r = relay.Var("r")
    # Empty environment: prepend the new binding.
    empty_clause = adt.Clause(adt.PatternConstructor(empty_env, []),
                              cons_env(key, ctr(val), env))
    # Cons cell: replace the head on a key match, else keep it and recurse.
    cons_clause = adt.Clause(
        adt.PatternConstructor(
            cons_env,
            [adt.PatternVar(k), adt.PatternVar(v), adt.PatternVar(r)]),
        relay.If(relay.equal(key, k),
                 cons_env(key, ctr(val), env),
                 cons_env(k, v, relay.Call(gv, [r, key, val]))))
    body = adt.Match(env, [empty_clause, cons_clause])
    fn = relay.Function([env, key, val], body, env_type(env_val()))
    m['env_update'] = (gv, fn)
    return gv, fn
def expected():
    """Reference graph: shared adds let-bound at their dominating blocks."""
    x = relay.var("x", shape=(), dtype="bool")
    y = relay.var("y", shape=(), dtype="float32")
    z = relay.var("z", shape=(), dtype="float32")
    cond_t = relay.const(True)
    cond_f = relay.const(False)
    one = relay.const(1, dtype="float32")
    three = relay.const(3, dtype="float32")
    y2 = relay.var("y2")
    z2 = relay.var("z2")
    true_branch = relay.If(cond_t, relay.add(z2, y2), relay.add(three, y2))
    true_branch = relay.Let(y2, relay.add(y, y), true_branch)
    false_branch = relay.If(cond_f, z2, one)
    body = relay.If(x, true_branch, false_branch)
    body = relay.Let(z2, relay.add(z, z), body)
    return body
def test_recursion():
    """
    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n")
    np = relay.Param(n, e.int32)
    data = relay.Var("data")
    datap = relay.Param(data, e.float32)
    # n is int32, so the decrement must be an integer constant.
    funcbody = relay.If(equal(n, convert(0)), data,
                        f(subtract(n, convert(1)), log(data)))
    value = relay.Function([np, datap], e.float32, funcbody, [])
    # Bind f to the Function (not the bare If body) so the calls below
    # actually invoke a function; the first argument is int32 to match n.
    orig = relay.Let(f, value, f(convert(2), convert(10000.0)), e.float32)
    # The binding of f is live, so DCE must leave the program unchanged.
    assert alpha_equal(dead_code_elimination(orig), orig)
    # With an unused binding, DCE strips the let entirely.
    assert alpha_equal(
        dead_code_elimination(relay.Let(f, value, e.three, e.float32)),
        e.three)
def test_recursion():
    """
    Program:
       let sum_twice(n: i32) -> i32 = {
          m = (n * 2)
          if (n == 0) {
              return m;
          } else {
              return m + sum(n - 1);
          }
       }
       sum_twice(5);
    """
    return  # cannot be run as fuse_ops need to recursively visit
    mod = relay.Module()
    i64 = relay.TensorType((), 'int64')
    f = relay.GlobalVar("f")
    n = relay.Var("n", i64)
    m = n * relay.const(2, 'int64')
    funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')),
                        m,
                        m + f(n - relay.const(1, 'int64')))
    value = relay.Function([n], funcbody, i64, [])
    mod[f] = value
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
    old_f = mod[f]
    mod = transform.ToANormalForm()(mod)
    f = mod[f]
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
def after():
    """Expected module: compiler_begin/end annotations around every partitioned op of the if graph."""
    data = relay.var("data", shape=(1, 32))
    eq1 = relay.var("e1", shape=[], dtype="float32")
    eq2 = relay.var("e2", shape=[], dtype="float32")
    # condition
    cb_1 = relay.annotation.compiler_begin(eq1, target)
    cb_2 = relay.annotation.compiler_begin(eq2, target)
    equality_condition = relay.equal(cb_1, cb_2)
    ce_1 = relay.annotation.compiler_end(equality_condition, target)
    # true branch
    cb_3 = relay.annotation.compiler_begin(data, target)
    true_branch = relay.tanh(cb_3)
    ce_2 = relay.annotation.compiler_end(true_branch, target)
    # false branch
    cb_4 = relay.annotation.compiler_begin(data, target)
    false_branch = relay.sigmoid(cb_4)
    ce_3 = relay.annotation.compiler_end(false_branch, target)
    # consumer of the if result
    if_condition = relay.If(ce_1, ce_2, ce_3)
    cb_5 = relay.annotation.compiler_begin(if_condition, target)
    erf_out = relay.erf(cb_5)
    ce_4 = relay.annotation.compiler_end(erf_out, target)
    func = relay.Function([data, eq1, eq2], ce_4)
    mod = tvm.IRModule.from_expr(func)
    return mod
def test_higher_order_nested():
    """check_basic_block_normal_form on an if whose branches are closures sharing a free sub-expression."""
    x = relay.var("x", dtype="float32", shape=(1,))
    s = relay.var("s", dtype="float32", shape=(1,))
    shared = relay.add(s, s)
    func_true = relay.Function([x], relay.add(x, shared))
    choice_t = relay.FuncType([], relay.scalar_type("bool"))
    f = relay.Var("f", choice_t)
    z = relay.Var("z")
    body = relay.If(f(), func_true, relay.Function([z], relay.add(z, shared)))
    top = relay.Function([f, s], body)
    # Text form of `top` (kept as a no-op string literal, as in the original):
    """
    fn (%f: fn () -> bool, %s: Tensor[(1), float32]) {
      %0 = %f();
      if (%0) {
        fn (%x: Tensor[(1), float32]) {
          %1 = add(%s, %s);
          add(%x, %1)
        }
      } else {
        fn (%z) {
          add(%z, %1)
        }
      }
    }
    """
    check_basic_block_normal_form(top)
def test_if_alpha_equal():
    """alpha_equal on If: identical graphs match; a differing cond, true, or false branch does not."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")
    if_sample = relay.If(v1, relay.const(1),
                         relay.Tuple([relay.const(2), relay.const(3)]))
    same = relay.If(v1, relay.const(1),
                    relay.Tuple([relay.const(2), relay.const(3)]))
    assert alpha_equal(if_sample, same)
    different_cond = relay.If(v2, relay.const(1),
                              relay.Tuple([relay.const(2), relay.const(3)]))
    assert not alpha_equal(if_sample, different_cond)
    different_true = relay.If(v1, relay.const(2),
                              relay.Tuple([relay.const(2), relay.const(3)]))
    assert not alpha_equal(if_sample, different_true)
    different_false = relay.If(v1, relay.const(1), relay.Tuple([]))
    assert not alpha_equal(if_sample, different_false)
def test_if_ref():
    """Partial evaluation of a ref that both branches of an if update via the same closure."""
    shape = ()
    dtype = "bool"
    t = relay.TensorType(shape, dtype)
    d = relay.Var("d", t)
    r = relay.Var("r")
    # Doubles the value stored in r.
    update = relay.Function(
        [], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
    u = relay.Var("u")
    body = relay.If(d, u(), u())
    eff = relay.Var("eff")
    body = relay.Let(eff, body, relay.RefRead(r))
    f = relay.Function(
        [d],
        relay.Let(r, relay.RefCreate(relay.const(1)),
                  relay.Let(u, update, body)))
    f = infer_type(f)
    pe_f = infer_type(partial_evaluate(f))
    ex = create_executor()
    # Both the original and the partially evaluated function must return 2.
    f_res = ex.evaluate(f)(relay.const(True))
    pe_f_res = ex.evaluate(pe_f)(relay.const(True))
    np.testing.assert_allclose(f_res.asnumpy(),
                               2 * np.ones_like(f_res.asnumpy()))
    np.testing.assert_allclose(pe_f_res.asnumpy(),
                               2 * np.ones_like(pe_f_res.asnumpy()))
def test_if():
    """Type inference of an if with Any dims: branch types unify to [Any, 1]."""
    choice_t = relay.FuncType([], relay.scalar_type("bool"))
    f = relay.Var("f", choice_t)
    true_branch = relay.Var("True",
                            relay.TensorType([Any(), 1], dtype="float32"))
    false_branch = relay.Var("False",
                             relay.TensorType([Any(), Any()], dtype="float32"))
    top = relay.Function([f, true_branch, false_branch],
                         relay.If(f(), true_branch, false_branch))
    ft = infer_expr(top)
    tvm.ir.assert_structural_equal(
        ft.ret_type, relay.TensorType([Any(), 1], dtype="float32"))
def test_multiple_ifs(target, dev):
    """The VM evaluates sibling let-bound ifs in order."""
    mod = tvm.IRModule({})
    b = relay.var("b")
    v0 = relay.var("v0")
    v1 = relay.var("v1")
    v2 = relay.var("v2")
    v3 = relay.var("v3")
    # v2 and v3 each select between the two tuples based on b.
    out = relay.Tuple([v2, v3])
    out = relay.Let(v3, relay.If(b, v1, v0), out)
    out = relay.Let(v2, relay.If(b, v0, v1), out)
    out = relay.Let(v1, relay.Tuple([relay.const(1)]), out)
    out = relay.Let(v0, relay.Tuple([relay.const(0)]), out)
    fn = relay.Function([b], out)
    mod["main"] = fn
    func = relay.create_executor(device=dev, mod=mod, kind="vm").evaluate()
    res = vmobj_to_list(func(False))
    assert res == [1, 0]
def after(annotate_non_call_ops):
    """Expected annotated module for the while-loop graph.

    Non-call ops only receive default-target begin/end wrappers when
    `annotate_non_call_ops` is set.
    """
    var1 = relay.var("var1", shape=(2,))
    var2 = relay.var("var2", shape=(), dtype="int32")
    var3 = relay.var("var3", shape=(2,))
    var4 = relay.const(10, dtype="int32")
    # loop condition: var2 < 10
    cb_1 = relay.annotation.compiler_begin(var2, target)
    cb_2 = relay.annotation.compiler_begin(var4, target)
    less_condition = relay.less(cb_1, cb_2)
    ce_1 = relay.annotation.compiler_end(less_condition, target)
    loop = relay.var("while_loop")
    # loop body: increment the counter and accumulate
    cb_3 = relay.annotation.compiler_begin(var2, target)
    cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target)
    add_op_1 = relay.add(cb_3, cb_4)
    ce_2 = relay.annotation.compiler_end(add_op_1, target)
    cb_5 = relay.annotation.compiler_begin(ce_2, "default") if annotate_non_call_ops else ce_2
    cb_6 = relay.annotation.compiler_begin(var3, target)
    cb_7 = relay.annotation.compiler_begin(var1, target)
    add_op_2 = relay.add(cb_6, cb_7)
    ce_3 = relay.annotation.compiler_end(add_op_2, target)
    cb_8 = relay.annotation.compiler_begin(ce_3, "default") if annotate_non_call_ops else ce_3
    true_branch = loop(cb_5, cb_8)  # recursive call of the while loop
    ce_4 = (
        relay.annotation.compiler_end(true_branch, "default")
        if annotate_non_call_ops
        else true_branch
    )
    if_condition = relay.If(ce_1, ce_4, var3)
    # initial loop state
    const_1 = relay.const(0, dtype="int32")
    cb_9 = (
        relay.annotation.compiler_begin(const_1, "default")
        if annotate_non_call_ops
        else const_1
    )
    cb_10 = relay.annotation.compiler_begin(var1, target)
    zeros_like = relay.zeros_like(cb_10)
    ce_5 = relay.annotation.compiler_end(zeros_like, target)
    cb_11 = relay.annotation.compiler_begin(ce_5, "default") if annotate_non_call_ops else ce_5
    while_condition = loop(cb_9, cb_11)
    ce_6 = (
        relay.annotation.compiler_end(while_condition, "default")
        if annotate_non_call_ops
        else while_condition
    )
    func_1 = relay.Function([var2, var3], if_condition)
    ret = relay.Let(loop, func_1, ce_6)
    func_2 = relay.Function([var1], ret)
    mod = tvm.IRModule.from_expr(func_2)
    return mod
def test_multiple_ifs():
    """The VM evaluates sibling let-bound ifs in order (legacy ctx API)."""
    mod = tvm.IRModule({})
    b = relay.var('b')
    v0 = relay.var('v0')
    v1 = relay.var('v1')
    v2 = relay.var('v2')
    v3 = relay.var('v3')
    # v2 and v3 each select between the two tuples based on b.
    out = relay.Tuple([v2, v3])
    out = relay.Let(v3, relay.If(b, v1, v0), out)
    out = relay.Let(v2, relay.If(b, v0, v1), out)
    out = relay.Let(v1, relay.Tuple([relay.const(1)]), out)
    out = relay.Let(v0, relay.Tuple([relay.const(0)]), out)
    fn = relay.Function([b], out)
    mod['main'] = fn
    ctx = tvm.runtime.ndarray.context('llvm', 0)
    vm = relay.create_executor(ctx=ctx, mod=mod, kind='vm')
    res = vmobj_to_list(vm.evaluate()(False))
    assert res == [1, 0]
def test_ifelse():
    """Parser round-trip for a simple if/else expression."""
    assert parses_as(
        """
        if (True) {
            0
        } else {
            1
        }
        """,
        relay.If(relay.const(True), relay.const(0), relay.const(1)))
def test_if():
    """ANF of an if via the legacy OptimizeOnExpr entry point."""
    cond = relay.const(True)
    x = relay.If(cond, relay.const(2), relay.const(3))
    anf = transform.OptimizeOnExpr(
        x, [transform.ToANormalForm(), transform.InferType()])
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    # Each branch body is let-bound; the whole if and the condition are too.
    true_branch = relay.Let(a, relay.const(2), a)
    false_branch = relay.Let(b, relay.const(3), b)
    expected_output = relay.If(c, true_branch, false_branch)
    expected_output = relay.Let(d, expected_output, d)
    expected_output = relay.Let(c, cond, expected_output)
    expected_output = transform.OptimizeOnExpr(expected_output,
                                               transform.InferType())
    assert alpha_equal(anf, expected_output)
def test_multiple_ifs():
    """The VM evaluates sibling let-bound ifs in order (legacy ctx API, double-quoted variant)."""
    mod = tvm.IRModule({})
    b = relay.var("b")
    v0 = relay.var("v0")
    v1 = relay.var("v1")
    v2 = relay.var("v2")
    v3 = relay.var("v3")
    # v2 and v3 each select between the two tuples based on b.
    out = relay.Tuple([v2, v3])
    out = relay.Let(v3, relay.If(b, v1, v0), out)
    out = relay.Let(v2, relay.If(b, v0, v1), out)
    out = relay.Let(v1, relay.Tuple([relay.const(1)]), out)
    out = relay.Let(v0, relay.Tuple([relay.const(0)]), out)
    fn = relay.Function([b], out)
    mod["main"] = fn
    ctx = tvm.runtime.ndarray.context("llvm", 0)
    vm = relay.create_executor(ctx=ctx, mod=mod, kind="vm")
    res = vmobj_to_list(vm.evaluate()(False))
    assert res == [1, 0]
def expected():
    """Expected module: both branch closures carry the Inline attr and are called through an If."""
    mod = tvm.IRModule({})
    fn1 = relay.Function([], relay.const(1))
    fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
    fn2 = relay.Function([], relay.const(2))
    fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
    p = relay.var("p", "bool")
    mod["main"] = relay.Function([p], relay.Call(relay.If(p, fn1, fn2), []))
    return mod
def test_ifelse():
    """Text parser round-trip for a simple if/else (legacy fromtext API)."""
    assert alpha_equal(
        relay.fromtext("""
        if (True) {
            0
        } else {
            1
        }
        """),
        relay.If(relay.const(True), relay.const(0), relay.const(1)))