def test_explicit_bound():
    """ToANormalForm should introduce explicit let-bindings for shared values.

    ``f`` is built as a DAG without any ``Let`` nodes: ``y = 1 + 1`` and
    ``z = y + y`` are each used twice, and the body is ``z + z``. After ANF
    conversion the result must contain ``Let``. Both the original and the
    converted function must still evaluate to 8.
    """
    x = relay.const(1)
    y = op.add(x, x)
    z = op.add(y, y)
    f = relay.Function([], op.add(z, z))
    assert Feature.fLet not in detect_feature(f)
    anf = run_opt_pass(f, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(anf)
    check_eval(f(), 8.0)
    check_eval(anf(), 8.0)
# Example #2
def test_no_explicit_bind():
    """ToBasicBlockNormalForm should not introduce let-bindings for a pure DAG.

    ``f`` has no ``Let`` nodes and no control flow. After the pass the result
    must still contain no ``Let``, must evaluate to the same value (8) and
    must satisfy the basic-block-normal-form checker.
    """
    x = relay.const(1)
    y = op.add(x, x)
    z = op.add(y, y)
    f = relay.Function([], op.add(z, z))
    # Textual form of ``f``:
    #   fn () {
    #     %0 = add(1, 1);
    #     %1 = add(%0, %0);
    #     add(%1, %1)
    #   }
    assert Feature.fLet not in detect_feature(f)
    bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm())
    assert Feature.fLet not in detect_feature(bblock)
    check_eval(f(), 8.0)
    check_eval(bblock(), 8.0)
    check_basic_block_normal_form(bblock)
def test_implicit_share():
    """ToGraphNormalForm should remove explicit let-bindings.

    ``f`` uses nested ``Let`` nodes to bind ``x = 1``, ``y = x + x`` and
    ``z = y + y``, and returns ``z + z``. The graph-normal-form result ``g``
    must contain no ``Let``, and both functions must evaluate to 8.
    """
    x = relay.Var('x')
    y = relay.Var('y')
    z = relay.Var('z')
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    g = run_opt_pass(f, transform.ToGraphNormalForm())
    assert Feature.fLet in detect_feature(f)
    assert Feature.fLet not in detect_feature(g)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
# Example #4
def test_round_trip():
    """Converting to graph normal form and back to ANF should preserve meaning.

    ``f`` uses explicit ``Let`` bindings. ``g = ToGraphNormalForm(f)`` must
    drop them, and ``h = ToANormalForm(g)`` must reintroduce them. All three
    functions must evaluate to 8.
    """
    x = relay.Var("x")
    y = relay.Var("y")
    z = relay.Var("z")
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    g = run_opt_pass(f, transform.ToGraphNormalForm())
    h = run_opt_pass(g, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(f)
    assert Feature.fLet not in detect_feature(g)
    # Complete the round trip: ANF must bind the shared values with lets again.
    assert Feature.fLet in detect_feature(h)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
    check_eval(h, [], 8.0)
def test_ad():
    """The program produced by ``gradient`` for ``x + x`` uses exactly these features.

    ``x`` is a 10x10 float32 tensor. The gradient program is put in a fresh
    module, type-inferred, and its ``main`` function is checked against the
    expected feature set.
    """
    tensor_ty = relay.TensorType((10, 10), "float32")
    x = relay.var("x", tensor_ty)
    forward = run_infer_type(relay.Function([x], x + x))
    grad_mod = relay.transform.InferType()(tvm.IRModule.from_expr(gradient(forward)))
    observed = detect_feature(grad_mod["main"])
    expected = {
        Feature.fVar,
        Feature.fTuple,
        Feature.fTupleGetItem,
        Feature.fFunction,
        Feature.fOp,
        Feature.fCall,
        Feature.fLet,
        Feature.fRefCreate,
        Feature.fRefRead,
        Feature.fRefWrite,
    }
    assert observed == expected
# Example #6
def test_prelude():
    """The Prelude module should use exactly this set of Relay features."""
    expected = {
        Feature.fVar, Feature.fGlobalVar, Feature.fConstant, Feature.fTuple,
        Feature.fTupleGetItem, Feature.fFunction, Feature.fOp, Feature.fCall,
        Feature.fLet, Feature.fIf, Feature.fConstructor, Feature.fMatch,
    }
    prelude = Prelude()
    assert detect_feature(prelude.mod) == expected
# Example #7
 def destroy_ref(x):
     """Assert that PartialEvaluate + DCE strips every ``RefCreate`` from CPS-converted ``x``.

     ``x`` is type-inferred, CPS-converted and type-inferred again. The CPS
     form is also run back through ``un_cps`` and type inference as a smoke
     check; that result is not otherwise used.

     NOTE(review): another ``destroy_ref`` appears later in this file.
     """
     cps_expr = run_infer_type(to_cps(run_infer_type(x)))
     run_infer_type(un_cps(cps_expr))
     pipeline = transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)])
     optimized = run_opt_pass(cps_expr, pipeline)
     assert Feature.fRefCreate not in detect_feature(optimized)
def test_nat_add():
    """ToANormalForm keeps ``nat_add`` correct and introduces let-bindings.

    ``nat_add`` is loaded from ``nat.rly``. The test evaluates
    ``nat_add(s(z()), s(z()))`` directly, then again as the body of a global
    ``f`` after ANF conversion. Both times the result must count to 2, and
    the converted ``nat_add`` must contain ``Let`` nodes.

    NOTE(review): a later ``test_nat_add`` in this file shadows this
    definition, so pytest will not collect it.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat, z, s = p.mod.get_type("nat")
    add = p.mod.get_global_var("nat_add")
    intrp = create_executor(mod=mod, device=tvm.device("llvm", 0), target="llvm")
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())

    def one_plus_one():
        # A fresh ``s(z()) + s(z())`` call expression each time it is requested.
        return add(s(z()), s(z()))

    assert count(p, intrp.evaluate(one_plus_one())) == 2
    entry = relay.GlobalVar("f")
    mod[entry] = relay.Function([], one_plus_one())
    mod = transform.ToANormalForm()(mod)
    assert count(p, intrp.evaluate(mod["f"].body)) == 2
    assert Feature.fLet in detect_feature(mod[add])
# Example #9
 def destroy_ref(x):
     """Assert that PartialEvaluate + DCE strips every ``RefCreate`` from CPS-converted ``x``.

     ``x`` is type-inferred, CPS-converted and type-inferred again. The CPS
     form is also run back through ``un_cps`` and type inference as a smoke
     check; that result is not otherwise used. DCE runs with
     ``ignore_impurity=True``.
     """
     cps_expr = run_infer_type(to_cps(run_infer_type(x)))
     run_infer_type(un_cps(cps_expr))
     # TODO(mbs): Revisit once DCE can eliminate dead writes.
     passes = [
         transform.PartialEvaluate(),
         transform.InferType(),
         transform.DeadCodeElimination(inline_once=True, ignore_impurity=True),
     ]
     optimized = run_opt_pass(cps_expr, tvm.transform.Sequential(passes))
     assert Feature.fRefCreate not in detect_feature(optimized)
def test_nat_add():
    """ToANormalForm keeps the Prelude's ``add`` correct and introduces lets.

    ``add(s(z()), s(z()))`` is evaluated directly, then again as the body of a
    global ``f`` after ANF conversion. Both times the result must count to 2,
    and the converted ``add`` must contain ``Let`` nodes.

    NOTE(review): a later ``test_nat_add`` in this file shadows this
    definition, so pytest will not collect it.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat, add, s, z = p.nat, p.add, p.s, p.z
    intrp = create_executor(mod=mod, ctx=tvm.context("llvm", 0), target="llvm")
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
    assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
    sum_expr = add(s(z()), s(z()))
    entry = relay.GlobalVar("f")
    mod[entry] = relay.Function([], sum_expr)
    mod = transform.ToANormalForm()(mod)
    assert count(p, intrp.evaluate(mod["f"].body)) == 2
    assert Feature.fLet in detect_feature(mod[add])
# Example #11
def test_nat_add():
    """ToBasicBlockNormalForm keeps the Prelude's ``add`` correct and let-free.

    ``add(s(z()), s(z()))`` is evaluated directly, then again as the body of a
    global ``f`` after the pass. Both times the result must count to 2. The
    rewritten ``add`` must still contain no ``Let`` nodes, and ``f`` must pass
    the basic-block-normal-form checker.

    NOTE(review): a later ``test_nat_add`` in this file shadows this
    definition, so pytest will not collect it.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    nat = p.nat
    add = p.add
    s = p.s
    z = p.z
    ctx = tvm.context('llvm', 0)
    intrp = create_executor(mod=mod, ctx=ctx, target='llvm')
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
    assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
    expr = add(s(z()), s(z()))
    f = relay.GlobalVar('f')
    mod[f] = relay.Function([], expr)
    mod = transform.ToBasicBlockNormalForm()(mod)
    opt_expr = mod['f']
    assert count(p, intrp.evaluate(opt_expr.body)) == 2
    assert Feature.fLet not in detect_feature(mod[add])
    check_basic_block_normal_form(opt_expr)
# Example #12
def test_nat_add():
    """ToBasicBlockNormalForm keeps ``nat_add`` correct and let-free.

    ``nat_add`` is loaded from ``nat.rly``. The test evaluates
    ``nat_add(s(z()), s(z()))`` directly, then again as the body of a global
    ``f`` after type inference and the pass. Both times the result must
    count to 2. The rewritten ``nat_add`` must contain no ``Let`` nodes, and
    ``f`` must pass the basic-block-normal-form checker.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat, z, s = p.mod.get_type("nat")
    add = p.mod.get_global_var("nat_add")
    ctx = tvm.context("llvm", 0)
    intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
    assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
    expr = add(s(z()), s(z()))
    f = relay.GlobalVar("f")
    mod[f] = relay.Function([], expr)
    # Type-check the newly added ``f`` before running the normal-form pass.
    mod = transform.InferType()(mod)
    mod = transform.ToBasicBlockNormalForm()(mod)
    opt_expr = mod["f"]
    assert count(p, intrp.evaluate(opt_expr.body)) == 2
    assert Feature.fLet not in detect_feature(mod[add])
    check_basic_block_normal_form(opt_expr)