def destroy_ref(x):
    """Check that no reference creation survives CPS conversion plus cleanup.

    The expression is type-inferred, converted with ``to_cps`` and
    type-inferred again.  The CPS form is also passed through ``un_cps`` and
    type inference; that result is not inspected further.  Finally
    ``PartialEvaluate`` and ``DeadCodeElimination(inline_once=True)`` are run
    on the CPS form and the result is asserted to contain no
    ``Feature.fRefCreate``.
    """
    typed = run_infer_type(x)
    cps = to_cps(typed)
    cps = run_infer_type(cps)

    # Round-trip out of CPS; the value is unused beyond these two calls.
    restored = un_cps(cps)
    restored = run_infer_type(restored)

    cleanup = transform.Sequential(
        [
            transform.PartialEvaluate(),
            transform.DeadCodeElimination(inline_once=True),
        ]
    )
    optimized = run_opt_pass(cps, cleanup)
    assert Feature.fRefCreate not in detect_feature(optimized)
def test_recursion():
    """Round-trip a ``nat_iterate``-based function through to_cps / un_cps.

    The entry function applies ``p.nat_iterate(double, make_nat_expr(p, 3))``
    to its input, where ``double`` is ``x + x``.  After converting the entry
    function with ``to_cps`` and back with ``un_cps``, the evaluated output
    is compared against ``8 * input``.
    """
    mod = relay.Module()
    p = Prelude(mod)
    add_nat_definitions(p)

    shape = (10, 10)
    dtype = 'float32'
    tensor_ty = relay.TensorType(shape, dtype)

    x = relay.var("x", tensor_ty)
    double = relay.Function([x], x + x)

    i = relay.var("i", tensor_ty)
    iterated = p.nat_iterate(double, make_nat_expr(p, 3))
    func = relay.Function([i], iterated(i))

    mod[mod.entry_func] = func

    cps_func = to_cps(mod[mod.entry_func], mod=mod)
    mod[mod.entry_func] = cps_func

    direct_func = un_cps(mod[mod.entry_func])
    mod[mod.entry_func] = direct_func

    ex = create_executor(mod=mod)
    i_nd = rand(dtype, *shape)
    run_entry = ex.evaluate(mod.entry_func)
    forward = run_entry(i_nd)

    actual = forward.asnumpy()
    expected = 8 * i_nd.asnumpy()
    tvm.testing.assert_allclose(actual, expected)
def destroy_ref(x):
    """Check that no reference creation survives CPS conversion plus cleanup.

    The expression is type-inferred, converted with ``to_cps`` and
    type-inferred again.  The CPS form is also passed through ``un_cps`` and
    type inference; that result is not inspected further.  Finally a
    ``PartialEvaluate`` -> ``InferType`` -> ``DeadCodeElimination`` pipeline
    is run on the CPS form and the result is asserted to contain no
    ``Feature.fRefCreate``.
    """
    typed = run_infer_type(x)
    cps = to_cps(typed)
    cps = run_infer_type(cps)

    # Round-trip out of CPS; the value is unused beyond these two calls.
    restored = un_cps(cps)
    restored = run_infer_type(restored)

    # TODO(mbs): Revisit once DCE can eliminate dead writes.
    cleanup = tvm.transform.Sequential(
        [
            transform.PartialEvaluate(),
            transform.InferType(),
            transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
        ]
    )
    optimized = run_opt_pass(cps, cleanup)
    assert Feature.fRefCreate not in detect_feature(optimized)
def test_recursion():
    """Round-trip a ``nat_iterate``-based function through to_cps / un_cps.

    ``nat_iterate`` is looked up as a global var after importing ``nat.rly``
    into the prelude's module.  ``main`` applies
    ``nat_iterate(double, make_nat_expr(p, 3))`` to its input, where
    ``double`` is ``x + x``.  ``InferType`` is run on the module before each
    of the ``to_cps`` and ``un_cps`` steps, and the evaluated output is
    compared against ``8 * input``.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat_iterate = p.mod.get_global_var("nat_iterate")

    shape = (10, 10)
    dtype = "float32"
    tensor_ty = relay.TensorType(shape, dtype)

    x = relay.var("x", tensor_ty)
    double = relay.Function([x], x + x)

    i = relay.var("i", tensor_ty)
    iterated = nat_iterate(double, make_nat_expr(p, 3))
    func = relay.Function([i], iterated(i))

    mod["main"] = func
    mod = relay.transform.InferType()(mod)

    cps_main = to_cps(mod["main"], mod=mod)
    mod["main"] = cps_main
    mod = relay.transform.InferType()(mod)

    direct_main = un_cps(mod["main"])
    mod["main"] = direct_main

    i_nd = rand(dtype, *shape)
    run_main = create_executor(mod=mod).evaluate()
    forward = run_main(i_nd)

    actual = forward.numpy()
    expected = 8 * i_nd.numpy()
    tvm.testing.assert_allclose(actual, expected)