示例#1
0
def test_recursion():
    """
    Check dead-code elimination on a let-bound, self-recursive function.

    First, the let body calls ``f``, so DCE must leave the program
    graph-equal to its type-inferred original. Then the same binding is wrapped
    around a body that never uses ``f``, and DCE must reduce it to that body.

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    # The else-branch calls the let-bound ``f`` itself. Calling some other,
    # unbound variable would make the function non-recursive, and the test
    # would no longer exercise what its name and docstring describe.
    funcbody = relay.If(equal(n, relay.const(0)),
                        data,
                        relay.Call(f, [subtract(n, relay.const(1)),
                                       log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
    orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    # ``f`` is called in the let body, so the binding is live; the DCE result
    # is compared against the type-inferred original program.
    dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination())
    expected = transform.OptimizeOnExpr(orig, transform.InferType())
    assert graph_equal(dced, expected)
    # Same binding, but the body never references ``f``: the whole let
    # should collapse to just its body.
    dced = transform.OptimizeOnExpr(relay.Let(f, value, e.three),
                                    transform.DeadCodeElimination())
    assert alpha_equal(dced, e.three)
示例#2
0
def check_func(func, ref_func):
    """Assert that ``func`` and ``ref_func`` are graph-equal after type inference.

    Both arguments are passed through ``ir_pass.infer_type`` (``func`` first,
    then ``ref_func``). The two results are compared with
    ``ir_pass.graph_equal``.
    """
    inferred, inferred_ref = (ir_pass.infer_type(fn) for fn in (func, ref_func))
    assert ir_pass.graph_equal(inferred, inferred_ref)
示例#3
0
def check_json_roundtrip(node):
    """Assert that ``node`` survives a JSON save/load round trip unchanged.

    The node is serialized with ``tvm.save_json`` and then rebuilt with
    ``tvm.load_json``. The rebuilt object must be ``graph_equal`` to the
    original.
    """
    restored = tvm.load_json(tvm.save_json(node))
    assert graph_equal(restored, node)
示例#4
0
def check_func(func, ref_func):
    """Type-infer both arguments, then assert they are graph-equal.

    ``ir_pass.infer_type`` runs on ``func`` and then on ``ref_func``. The
    inferred pair is handed positionally to ``ir_pass.graph_equal``.
    """
    typed_pair = [ir_pass.infer_type(candidate) for candidate in (func, ref_func)]
    assert ir_pass.graph_equal(*typed_pair)
示例#5
0
def check_json_roundtrip(node):
    """Round-trip ``node`` through TVM's JSON format and check graph equality.

    The JSON string produced by ``tvm.save_json`` is fed back into
    ``tvm.load_json``. The reloaded object is asserted ``graph_equal`` to
    ``node``.
    """
    serialized = tvm.save_json(node)
    reloaded = tvm.load_json(serialized)
    assert graph_equal(reloaded, node)