def test_recursive():
    """Check that LambdaLift lifts a recursive, let-bound local function.

    ``while_loop`` is bound with ``relay.Let`` and calls itself from its own
    body. Its body also refers to ``x``, a parameter of ``main`` that is not
    one of the loop's own parameters (``i``, ``s``). After lifting, the module
    is expected to contain exactly two functions.
    """
    mod = tvm.IRModule()

    x = relay.var("x", shape=(2,))
    i = relay.var("i", shape=(), dtype="int32")
    s = relay.var("s", shape=(2,))
    cond = i < relay.const(10, dtype="int32")

    # Recursive loop: while i < 10, recurse with (i + 1, s + x); otherwise return s.
    loop = relay.var("while_loop")
    sb = relay.scope_builder.ScopeBuilder()
    with sb.if_scope(cond):
        ii = i + relay.const(1, dtype="int32")
        ss = s + x
        sb.ret(loop(ii, ss))
    with sb.else_scope():
        sb.ret(s)
    func = relay.Function([i, s], sb.get())

    ret = relay.Let(
        loop,
        func,
        loop(relay.const(0, dtype="int32"), relay.zeros(shape=(2,), dtype="float32")),
    )
    mod["main"] = relay.Function([x], ret)

    # Type-check the module before lifting so the pass sees inferred types.
    mod = relay.transform.InferType()(mod)
    new_mod = transform.LambdaLift()(mod)
    assert len(new_mod.functions) == 2
def test_closure():
    """Check that LambdaLift lifts a closure returned from another function.

    ``inner_func`` uses ``y``, which is a parameter of ``outer_func`` rather
    than of ``inner_func`` itself. ``main`` applies ``outer_func`` to get the
    closure and then calls it. After lifting, the module is expected to
    contain exactly three functions.
    """
    mod = tvm.IRModule()

    x = relay.var("x", shape=(2,))
    y = relay.var("y", shape=(2,))
    inner_func = relay.Function([x], x + y)
    outer_func = relay.Function([y], inner_func)
    clo = outer_func(relay.ones(shape=(2,), dtype="float32"))
    mod["main"] = relay.Function(
        [], relay.Call(clo, [relay.zeros(shape=(2,), dtype="float32")])
    )

    # Type-check the module before lifting so the pass sees inferred types.
    mod = relay.transform.InferType()(mod)
    new_mod = transform.LambdaLift()(mod)
    assert len(new_mod.functions) == 3
def test_basic():
    """Check that LambdaLift lifts a nested, non-capturing local function.

    ``level2_func`` (an elementwise add) is defined inline inside
    ``level1_func`` and called there. After lifting, the module is expected
    to contain exactly two functions.

    NOTE(review): this file defines ``test_basic`` a second time further
    down. That later definition rebinds the name, so pytest only ever
    collects that one and this definition never runs.
    """
    # relay.Module() was replaced by tvm.IRModule(); use the current API.
    mod = tvm.IRModule()

    x2 = relay.var("x2", shape=(10, 5))
    y2 = relay.var("y2", shape=(1, 5))
    level2_func = relay.Function([x2, y2], relay.op.add(x2, y2))

    x1 = relay.var("x1", shape=(10, 5))
    y1 = relay.var("y1", shape=(1, 5))
    level1_func = relay.Function([x1, y1], level2_func(x1, y1))

    mod["main"] = level1_func
    # Type-check the module before lifting so the pass sees inferred types.
    mod = relay.transform.InferType()(mod)
    new_mod = transform.LambdaLift()(mod)
    assert len(new_mod.functions) == 2
def test_basic():
    """A function defined and called inline in ``main`` gets lifted to a global.

    ``main`` takes two operands and applies an anonymous add function to
    them. After type inference and lambda lifting, the module should hold
    that lifted add alongside ``main``, i.e. two functions in total.
    """
    mod = tvm.IRModule()

    # Inner (to-be-lifted) function: elementwise add of a (10, 5) and a (1, 5) tensor.
    lhs_inner = relay.var("x2", shape=(10, 5))
    rhs_inner = relay.var("y2", shape=(1, 5))
    add_fn = relay.Function([lhs_inner, rhs_inner], relay.op.add(lhs_inner, rhs_inner))

    # Outer function: forwards its own two parameters to the inner one.
    lhs_outer = relay.var("x1", shape=(10, 5))
    rhs_outer = relay.var("y1", shape=(1, 5))
    mod["main"] = relay.Function([lhs_outer, rhs_outer], add_fn(lhs_outer, rhs_outer))

    typed_mod = relay.transform.InferType()(mod)
    lifted = transform.LambdaLift()(typed_mod)
    assert len(lifted.functions) == 2