def test_recursion():
    """Check dead-code elimination on a let-bound function.

    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);

    When the let body calls ``f``, the DCE output must be graph-equal to the
    type-inferred original.  When the let body (``e.three``) never mentions
    ``f``, DCE must drop the binding and leave just ``e.three``.
    """
    fn_var = relay.Var("f")
    # NOTE(review): the inner call below targets a separate, unbound var
    # "f1" rather than ``fn_var``, so the function body is not actually
    # self-recursive as the program above suggests -- confirm intent.
    inner_callee = relay.Var("f1")
    count = relay.Var("n", e.int32)
    payload = relay.Var("data", e.float32)

    is_zero = equal(count, relay.const(0))
    recurse = relay.Call(inner_callee, [subtract(count, relay.const(1)),
                                        log(payload)])
    body = relay.If(is_zero, payload, recurse)
    fn_value = relay.Function([count, payload], body, e.float32, [])

    program = relay.Let(
        fn_var, fn_value,
        relay.Call(fn_var, [relay.const(2), relay.const(10000.0)]))
    # f is referenced by the let body, so nothing may be eliminated.
    eliminated = run_opt_pass(program, transform.DeadCodeElimination())
    typed = run_opt_pass(program, transform.InferType())
    assert graph_equal(eliminated, typed)

    # With a body that ignores f, the binding is dead and must vanish.
    unused_binding = relay.Let(fn_var, fn_value, e.three)
    eliminated = run_opt_pass(unused_binding,
                              transform.DeadCodeElimination())
    assert alpha_equal(eliminated, e.three)
Esempio n. 2
0
def parses_as(code, expr):
    # type: (str, relay.Expr) -> bool
    """Return whether ``parse_text(code)`` is graph-equal to *expr*."""
    return graph_equal(parse_text(code), expr)
Esempio n. 3
0
def test_op_assoc():
    """Each unparenthesized expression parses like its fully grouped form."""
    cases = [
        ("1 * 1 + 1 < 1 == 1", "(((1 * 1) + 1) < 1) == 1"),
        ("1 == 1 < 1 + 1 * 1", "1 == (1 < (1 + (1 * 1)))"),
    ]
    for flat, grouped in cases:
        assert graph_equal(parse_text(flat), parse_text(grouped))
Esempio n. 4
0
def test_parens():
    """Redundant parentheses keep the parse; precedence-changing ones do not."""
    expr_text = "1 * 1 + 1"
    assert graph_equal(parse_text(expr_text), parse_text("(1 * 1) + 1"))
    assert not graph_equal(parse_text(expr_text), parse_text("1 * (1 + 1)"))
Esempio n. 5
0
def check_json_roundtrip(node):
    """Assert that *node* survives a ``tvm.ir`` JSON save/load cycle."""
    restored = tvm.ir.load_json(tvm.ir.save_json(node))
    assert graph_equal(restored, node)
Esempio n. 6
0
def check_func(func, ref_func):
    """Assert *func* and *ref_func* are graph-equal after type inference."""
    typed_func, typed_ref = run_infer_type(func), run_infer_type(ref_func)
    assert analysis.graph_equal(typed_func, typed_ref)
Esempio n. 7
0
def check_json_roundtrip(node):
    """Assert that *node* round-trips through ``tvm.ir`` JSON unchanged.

    ``graph_equal`` is imported locally, so this helper does not rely on a
    module-level import of it.
    """
    from tvm.relay.analysis import graph_equal
    serialized = tvm.ir.save_json(node)
    assert graph_equal(tvm.ir.load_json(serialized), node)