def test_recursion():
    """Check dead-code elimination on a recursive let-bound function.

    Program:
        let f(n: i32, data: f32) -> f32 = {
            if (n == 0) {
                return data;
            } else {
                return f(n - 1, log(data));
            }
        }
        f(2, 10000);

    Case 1: the let body calls ``f``. The DCE result must be graph-equal to
    the type-inferred original program.
    Case 2: the let body is just ``e.three`` and never mentions ``f``. The
    only use of ``f`` is then inside its own definition, so DCE must reduce
    the expression to ``e.three``.
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    # The else-branch calls ``f`` itself, not a separate unbound variable.
    # This makes the function genuinely recursive, as the docstring describes.
    funcbody = relay.If(
        equal(n, relay.const(0)),
        data,
        relay.Call(f, [subtract(n, relay.const(1)), log(data)]),
    )
    value = relay.Function([n, data], funcbody, e.float32, [])

    # Case 1: f is used by the let body, so nothing may be eliminated.
    orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
    orig = transform.OptimizeOnExpr(orig, transform.InferType())
    assert graph_equal(dced, orig)

    # Case 2: f is referenced only from within its own (recursive) definition,
    # so the whole binding is dead and should disappear.
    dced = transform.OptimizeOnExpr(relay.Let(f, value, e.three), transform.DeadCodeElimination())
    assert alpha_equal(dced, e.three)
def check_func(func, ref_func):
    """Assert that two expressions are graph-equal after type inference.

    Both ``func`` and ``ref_func`` go through ``ir_pass.infer_type`` first,
    in that order. The typed results are then compared with
    ``ir_pass.graph_equal``.
    """
    typed_func, typed_ref = (ir_pass.infer_type(expr) for expr in (func, ref_func))
    assert ir_pass.graph_equal(typed_func, typed_ref)
def check_json_roundtrip(node):
    """Assert that ``node`` survives a JSON serialize/deserialize round trip.

    The node is written with ``tvm.save_json`` and read back with
    ``tvm.load_json``. The reloaded copy must be graph-equal to the original.
    """
    restored = tvm.load_json(tvm.save_json(node))
    assert graph_equal(restored, node)