def test_recursion():
    """Run DeadCodeElimination over a let-bound function and compare results.

    Intended program (as sketched by the original author):

        let f(n: i32, data: f32) -> f32 = {
            if (n == 0) {
                return data;
            } else {
                return f(n - 1, log(data));
            }
        }
        f(2, 10000);

    Two checks are made:
      * the pass output for the full program is ``graph_equal`` to the
        type-inferred original;
      * a ``Let`` binding the same function but with body ``e.three`` comes
        out ``alpha_equal`` to ``e.three``.
    """
    fn_var = relay.Var("f")
    # NOTE(review): the call inside the function body targets this separate
    # Var ("f1"), not the let-bound "f" above, so it differs from the program
    # in the docstring -- confirm whether that is intentional.
    callee_var = relay.Var("f1")
    n_var = relay.Var("n", e.int32)
    data_var = relay.Var("data", e.float32)

    is_base_case = equal(n_var, relay.const(0))
    step_call = relay.Call(
        callee_var,
        [subtract(n_var, relay.const(1)), log(data_var)],
    )
    body = relay.If(is_base_case, data_var, step_call)
    func = relay.Function([n_var, data_var], body, e.float32, [])

    entry_call = relay.Call(fn_var, [relay.const(2), relay.const(10000.0)])
    program = relay.Let(fn_var, func, entry_call)

    # The pass runs on the un-inferred program; the expected value is the same
    # program after InferType.
    eliminated = run_opt_pass(program, transform.DeadCodeElimination())
    expected = run_opt_pass(program, transform.InferType())
    assert graph_equal(eliminated, expected)

    constant_body = relay.Let(fn_var, func, e.three)
    eliminated = run_opt_pass(constant_body, transform.DeadCodeElimination())
    assert alpha_equal(eliminated, e.three)
def parses_as(code, expr):
    # type: (str, relay.Expr) -> bool
    """Return the ``graph_equal`` comparison of ``parse_text(code)`` and ``expr``."""
    return graph_equal(parse_text(code), expr)
def test_op_assoc():
    """Check that unparenthesized operator chains parse like their explicit forms.

    Each pair holds an expression without parentheses and the fully
    parenthesized expression it must be ``graph_equal`` to once parsed.
    """
    cases = (
        ("1 * 1 + 1 < 1 == 1", "(((1 * 1) + 1) < 1) == 1"),
        ("1 == 1 < 1 + 1 * 1", "1 == (1 < (1 + (1 * 1)))"),
    )
    for implicit, explicit in cases:
        assert graph_equal(parse_text(implicit), parse_text(explicit))
def test_parens():
    """Check that parentheses change the parse only when they regroup operands.

    ``1 * 1 + 1`` must be ``graph_equal`` to ``(1 * 1) + 1`` and must not be
    ``graph_equal`` to ``1 * (1 + 1)``.
    """
    bare = parse_text("1 * 1 + 1")
    same_grouping = parse_text("(1 * 1) + 1")
    assert graph_equal(bare, same_grouping)

    # Parse the bare form again so each assertion works on its own parse result.
    bare_again = parse_text("1 * 1 + 1")
    other_grouping = parse_text("1 * (1 + 1)")
    assert not graph_equal(bare_again, other_grouping)
def check_json_roundtrip(node):
    """Assert that ``node`` is ``graph_equal`` to itself after a JSON round trip.

    The node is serialized with ``tvm.ir.save_json`` and read back with
    ``tvm.ir.load_json``.
    """
    restored = tvm.ir.load_json(tvm.ir.save_json(node))
    assert graph_equal(restored, node)
def check_func(func, ref_func):
    """Assert that ``func`` and ``ref_func`` are graph-equal after type inference.

    Both arguments are passed through ``run_infer_type`` (``func`` first) and
    the results are compared with ``analysis.graph_equal``.
    """
    typed_func = run_infer_type(func)
    typed_ref = run_infer_type(ref_func)
    assert analysis.graph_equal(typed_func, typed_ref)
def check_json_roundtrip(node):
    """Assert that ``node`` is ``graph_equal`` to itself after a JSON round trip.

    The node is serialized with ``tvm.ir.save_json`` and read back with
    ``tvm.ir.load_json``; ``graph_equal`` is imported inside the function.
    """
    from tvm.relay.analysis import graph_equal

    restored = tvm.ir.load_json(tvm.ir.save_json(node))
    assert graph_equal(restored, node)