Exemple #1
0
def test_higher_order_nested():
    """Run the normal-form check on two closures that share a sub-expression.

    The program is a function of a nullary, bool-returning predicate ``f``
    and a tensor ``s``. Its body is an ``if`` on ``f()`` whose branches are
    function literals; both literals reference the same ``add(s, s)`` node.
    """
    # NOTE(review): the IR dump below shows %1 bound inside the first closure
    # and reused by the second. Presumably the checker is meant to reject
    # this program and an expected-failure marker is missing -- confirm.
    predicate = relay.Var("f", relay.FuncType([], relay.scalar_type("bool")))
    s_in = relay.var("s", shape=(1,), dtype="float32")
    # Created once so both closures point at the very same node object.
    s_twice = relay.add(s_in, s_in)

    x_in = relay.var("x", shape=(1,), dtype="float32")
    on_true = relay.Function([x_in], relay.add(x_in, s_twice))
    z_in = relay.Var("z")
    on_false = relay.Function([z_in], relay.add(z_in, s_twice))

    outer = relay.Function(
        [predicate, s_in],
        relay.If(predicate(), on_true, on_false),
    )
    """
    fn (%f: fn () -> bool, %s: Tensor[(1), float32]) {
      %0 = %f();
      if (%0) {
        fn (%x: Tensor[(1), float32]) {
          %1 = add(%s, %s);
          add(%x, %1)
        }
      } else {
        fn (%z) {
          add(%z, %1)
        }
      }
    }
    """
    check_basic_block_normal_form(outer)
def test_recursion():
    """Check that a self-recursive global function survives the pass.

    Program:
       let f(n: i64) -> i64 = {
          m = (n * 2)
          if (n == 0) {
              return m;
          } else {
              return m + f(n - 1);
          }
       }
       f(5);

    ``f(5)`` is evaluated before and after ``ToBasicBlockNormalForm`` is
    applied to the module and is expected to be 30 both times. The
    rewritten definition is then handed to ``check_basic_block_normal_form``.
    """
    mod = tvm.IRModule()
    i64 = relay.TensorType((), "int64")
    f = relay.GlobalVar("f")
    n = relay.Var("n", i64)
    # `m` is one node used by both the true branch and the recursive branch.
    m = n * relay.const(2, "int64")
    cond = relay.equal(n, relay.const(0, "int64"))
    false_branch = m + f(n - relay.const(1, "int64"))
    funcbody = relay.If(cond, m, false_branch)
    mod[f] = relay.Function([n], funcbody, i64, [])
    check_eval(f(relay.const(5, "int64")), 30.0, mod=mod)

    mod = transform.ToBasicBlockNormalForm()(mod)
    # Keep the global var and the rewritten function under separate names
    # instead of rebinding `f`.
    opt_f = mod[f]
    check_eval(opt_f(relay.const(5, "int64")), 30.0, mod=mod)
    check_basic_block_normal_form(opt_f)
Exemple #3
0
def test_nested_if():
    """Compare the pass output for a nested ``if`` with a hand-built program.

    ``add(y, y)`` is a single node used by both arms of the inner ``if``
    that sits in the outer true branch. The expected program binds that sum
    with a ``let`` placed in the outer true branch, around the inner ``if``.
    """
    outer_cond = relay.var('x', shape=(), dtype='bool')
    y_in = relay.var('y', shape=(), dtype='float32')
    # One node, referenced from both arms of the inner `if`.
    y_doubled = relay.add(y_in, y_in)
    then_arm = relay.If(
        relay.const(True),
        y_doubled,
        relay.add(relay.const(3, dtype='float32'), y_doubled),
    )
    else_arm = relay.If(
        relay.const(False),
        relay.const(2, dtype='float32'),
        relay.const(1, dtype='float32'),
    )
    body = relay.If(outer_cond, then_arm, else_arm)
    """
    free_var %x: bool
    if (%x) {
      if (True) {
        free_var %y: float32
        %0 = add(%y, %y);
        %0
      } else {
        add(3f, %0)
      }
    } else {
      if (False) {
        2f
      } else {
        1f
      }
    }
    """

    def expected():
        # Same program, with the shared sum named by an explicit let.
        cond_var = relay.var('x', shape=(), dtype='bool')
        y_var = relay.var('y', shape=(), dtype='float32')
        bound = relay.var('y2')
        inner_if = relay.If(
            relay.const(True),
            bound,
            relay.add(relay.const(3, dtype='float32'), bound),
        )
        let_arm = relay.Let(bound, relay.add(y_var, y_var), inner_if)
        const_arm = relay.If(
            relay.const(False),
            relay.const(2, dtype='float32'),
            relay.const(1, dtype='float32'),
        )
        return relay.If(cond_var, let_arm, const_arm)

    bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm()])
    """
    free_var %x: bool
    if (%x) {
      free_var %y: float32
      let %x1: float32 = add(%y, %y) /* ty=float32 */;
      if (True /* ty=bool */) {
        %x1
      } else {
        add(3f /* ty=float32 */, %x1) /* ty=float32 */
      }
    } else {
      if (False /* ty=bool */) {
        2f /* ty=float32 */
      } else {
        1f /* ty=float32 */
      }
    }
    """
    expected_output = run_opt_pass(expected(), transform.InferType())
    # Both programs have free variables, so they are matched up by position.
    assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
    check_basic_block_normal_form(bblock)
def test_if():
    """Compare the pass output for an ``if`` with a shared operand to a let form.

    The comparison is made twice: once on the bare expression (with free
    variables mapped) and once with the expression wrapped in a function
    of ``x``.
    """

    def if_expr(x):
        """Build an ``if`` whose two branches reuse one ``add(x, 1)`` node.

        free_var %x: float32
        %0 = equal(%x, 2f);
        if (%0) {
          %1 = add(%x, 1f);
          multiply(%1, 2f)
        } else {
          multiply(%1, 1f)
        }
        """
        f32_one = relay.const(1, dtype="float32")
        f32_two = relay.const(2, dtype="float32")
        x_plus_one = relay.add(x, f32_one)
        return relay.If(
            relay.equal(x, f32_two),
            relay.multiply(x_plus_one, f32_two),
            relay.multiply(x_plus_one, f32_one),
        )

    def expected_if_expr(x):
        """Build the same ``if`` with ``add(x, 1)`` bound to a ``let`` variable.

        free_var %x: float32
        let %v1: float32 = add(%x, 1f /* ty=float32 */) /* ty=float32 */;
        %0 = equal(%x, 2f /* ty=float32 */) /* ty=bool */;
        if (%0) {
          multiply(%v1, 2f /* ty=float32 */) /* ty=float32 */
        } else {
          multiply(%v1, 1f /* ty=float32 */) /* ty=float32 */
        }
        """
        f32_one = relay.const(1, dtype="float32")
        f32_two = relay.const(2, dtype="float32")
        bound = relay.var("v1")
        branching = relay.If(
            relay.equal(x, f32_two),
            relay.multiply(bound, f32_two),
            relay.multiply(bound, f32_one),
        )
        return relay.Let(bound, relay.add(x, f32_one), branching)

    def normalize(expr):
        # Fresh pass objects on every call: normal form first, then types.
        return run_opt_pass(
            expr, [transform.ToBasicBlockNormalForm(), transform.InferType()])

    def infer(expr):
        return run_opt_pass(expr, transform.InferType())

    x = relay.var("x", shape=(), dtype="float32")
    body = if_expr(x)
    expected_body = expected_if_expr(x)

    # Bare expressions: compared with free variables mapped.
    bblock = normalize(body)
    assert tvm.ir.structural_equal(
        bblock, infer(expected_body), map_free_vars=True)
    check_basic_block_normal_form(bblock)

    # Same comparison with x bound as a function parameter.
    bblock = normalize(relay.Function([x], body))
    assert tvm.ir.structural_equal(
        bblock, infer(relay.Function([x], expected_body)))
    check_basic_block_normal_form(bblock)
Exemple #5
0
 def test_let1_1():
     """Check that a single ``let`` over a constant is left unchanged.

     The program binds a variable (Relay name ``y``) to ``4.0`` and returns
     that variable added to itself. The type-inferred program must be
     structurally equal to what the pass produces from it.
     """
     bound = relay.Var('y')
     four = relay.const(4.0, 'float32')
     typed = run_opt_pass(
         relay.Let(bound, four, relay.add(bound, bound)),
         transform.InferType(),
     )
     normalized = run_opt_pass(typed, transform.ToBasicBlockNormalForm())
     assert tvm.ir.structural_equal(typed, normalized)
     check_basic_block_normal_form(normalized)
Exemple #6
0
 def test_let1():
     """Check that ``let x = 4.0; x`` is left unchanged by the pass.

     The type-inferred program must be structurally equal to the result of
     running ``ToBasicBlockNormalForm`` on it.
     """
     bound = relay.Var('x')
     four = relay.const(4.0, 'float32')
     typed = run_opt_pass(relay.Let(bound, four, bound), transform.InferType())
     """
        let %x: float32 = 4f /* ty=float32 */;
        %x
        """
     normalized = run_opt_pass(typed, transform.ToBasicBlockNormalForm())
     assert tvm.ir.structural_equal(typed, normalized)
     check_basic_block_normal_form(normalized)
Exemple #7
0
def test_function():
    """Check the pass on a one-parameter function computing ``x + x``.

    The result must still be a ``relay.Function``, and both the original
    and the rewritten function are evaluated on ``4.0`` with 8 expected.
    """
    scalar_f32 = relay.TensorType((), 'float32')
    arg = relay.Var('x', scalar_f32)
    doubler = relay.Function([arg], (arg + arg))
    four = relay.const(4.0, 'float32')

    bblock = run_opt_pass(doubler, transform.ToBasicBlockNormalForm())
    assert isinstance(bblock, relay.Function)
    # Original first, rewritten second.
    for fn in (doubler, bblock):
        check_eval(fn(four), 8)
    check_basic_block_normal_form(bblock)
Exemple #8
0
def test_no_explicit_bind():
    """Check a chain of shared adds with no ``let`` before or after the pass.

    Each ``add`` takes the previous result as both operands. The test
    asserts the ``fLet`` feature is absent from the function both before and
    after ``ToBasicBlockNormalForm``, and that both versions evaluate to 8.
    """
    one = relay.const(1)
    two = op.add(one, one)
    four = op.add(two, two)
    f = relay.Function([], op.add(four, four))
    """
    fn () {
      %0 = add(1, 1);
      %1 = add(%0, %0);
      add(%1, %1)
    }
    """
    assert Feature.fLet not in detect_feature(f)
    bblock = run_opt_pass(f, transform.ToBasicBlockNormalForm())
    assert Feature.fLet not in detect_feature(bblock)
    # Original first, rewritten second.
    for fn in (f, bblock):
        check_eval(fn(), 8.0)
    check_basic_block_normal_form(bblock)
Exemple #9
0
 def test_let3():
     """Check that three nested ``let`` bindings are left unchanged.

     From the outside in, the program binds ``y = 3.0``, ``x = 4.0`` and
     ``z = x + y``, and returns ``x + z``. The type-inferred program must be
     structurally equal to the output of the pass.
     """
     x, y, z = (relay.Var(name) for name in ('x', 'y', 'z'))
     three = relay.const(3.0, 'float32')
     four = relay.const(4.0, 'float32')
     nested = relay.Let(
         y, three,
         relay.Let(
             x, four,
             relay.Let(z, (x + y), (x + z)),
         ),
     )
     typed = run_opt_pass(nested, transform.InferType())
     normalized = run_opt_pass(typed, transform.ToBasicBlockNormalForm())
     assert tvm.ir.structural_equal(typed, normalized)
     check_basic_block_normal_form(normalized)
 def test_let3():
     """Check that a stack of three ``let`` bindings passes through unchanged.

     The innermost binding is ``z = x + y`` with result ``x + z``; it is
     wrapped by ``x = 4.0`` and then by ``y = 3.0``. The pass output must be
     structurally equal to the type-inferred input.
     """
     x = relay.Var("x")
     y = relay.Var("y")
     z = relay.Var("z")
     program = x + z
     # Listed innermost binding first; each step wraps what was built so far.
     for bound, value in (
         (z, x + y),
         (x, relay.const(4.0, "float32")),
         (y, relay.const(3.0, "float32")),
     ):
         program = relay.Let(bound, value, program)
     program = run_opt_pass(program, transform.InferType())
     rewritten = run_opt_pass(program, transform.ToBasicBlockNormalForm())
     assert tvm.ir.structural_equal(program, rewritten)
     check_basic_block_normal_form(rewritten)
Exemple #11
0
def test_ref():
    """Check the pass on a program that creates, reads and writes a reference.

    From the outside in: ``i`` is a reference created from 1, ``iv`` reads
    it, ``u`` holds the result of writing 2 to it, ``uv`` reads it again,
    and the result is ``iv + uv``. Both the original and the rewritten
    program are evaluated with 3 expected.
    """
    ref = relay.Var('i')
    first_read = relay.Var('iv')
    write_result = relay.Var('u')
    second_read = relay.Var('uv')

    program = relay.add(first_read, second_read)
    # Listed innermost binding first; each step wraps what was built so far.
    for bound, value in (
        (second_read, relay.RefRead(ref)),
        (write_result, relay.RefWrite(ref, relay.const(2))),
        (first_read, relay.RefRead(ref)),
        (ref, relay.RefCreate(relay.const(1))),
    ):
        program = relay.Let(bound, value, program)

    check_eval(program, 3)
    rewritten = run_opt_pass(program, transform.ToBasicBlockNormalForm())
    check_eval(rewritten, 3)
    check_basic_block_normal_form(rewritten)
Exemple #12
0
def test_gradient_if():
    """Run the normal-form check on a higher-order gradient of an ``if``.

    The network is ``a + (if cond then a else a)`` as a function of
    ``cond``, ``a`` and ``y``. Its module is normalised, ``main`` is replaced
    by its higher-order gradient, and the module is normalised again. Both
    the re-normalised ``main`` and the gradient ``main`` it came from are
    checked.
    """
    a_in = relay.var('a', shape=(1, 16))
    y_in = relay.var('y', shape=(1, 16))
    cond = relay.var('cond', shape=(), dtype='uint1')
    # Both arms of the `if` return the same variable.
    picked = relay.If(cond, a_in, a_in)
    net = relay.Function([cond, a_in, y_in], relay.add(a_in, picked))

    mod = relay.transform.ToBasicBlockNormalForm()(tvm.IRModule.from_expr(net))
    mod['main'] = relay.transform.gradient(mod['main'], mode='higher_order')
    mod_grad = relay.transform.ToBasicBlockNormalForm()(mod)
    check_basic_block_normal_form(mod_grad['main'])
    check_basic_block_normal_form(mod['main'])
Exemple #13
0
def test_gradient_if():
    """Check normal form before and after re-normalising a gradient module.

    Steps: build ``fn(cond, a, y) = a + (if cond then a else a)``, normalise
    its module, replace ``main`` with the higher-order gradient of ``main``,
    normalise once more, then check ``main`` in both the final module and
    the gradient module.
    """
    inputs = {
        "cond": relay.var("cond", shape=(), dtype="uint1"),
        "a": relay.var("a", shape=(1, 16)),
        "y": relay.var("y", shape=(1, 16)),
    }
    a_var = inputs["a"]
    branch = relay.If(inputs["cond"], a_var, a_var)
    forward = relay.Function(
        [inputs["cond"], a_var, inputs["y"]],
        relay.add(a_var, branch),
    )

    mod = tvm.IRModule.from_expr(forward)
    mod = relay.transform.ToBasicBlockNormalForm()(mod)
    backward = relay.transform.gradient(mod["main"], mode="higher_order")
    mod["main"] = backward
    renormalized = relay.transform.ToBasicBlockNormalForm()(mod)
    for module in (renormalized, mod):
        check_basic_block_normal_form(module["main"])
Exemple #14
0
def test_recursion():
    """Check that a self-recursive global function survives the pass.

    Program:
       let f(n: i64) -> i64 = {
          m = (n * 2)
          if (n == 0) {
              return m;
          } else {
              return m + f(n - 1);
          }
       }
       f(5);

    ``f(5)`` is evaluated before and after ``ToBasicBlockNormalForm`` is
    applied to the module and is expected to be 30 both times. The
    rewritten definition is then handed to ``check_basic_block_normal_form``.
    """
    mod = tvm.IRModule()
    i64 = relay.TensorType((), 'int64')
    f = relay.GlobalVar('f')
    n = relay.Var('n', i64)
    # `m` is one node used by both the true branch and the recursive branch.
    m = (n * relay.const(2, 'int64'))
    cond = relay.equal(n, relay.const(0, 'int64'))
    false_branch = (m + f((n - relay.const(1, 'int64'))))
    funcbody = relay.If(cond, m, false_branch)
    mod[f] = relay.Function([n], funcbody, i64, [])
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)

    mod = transform.ToBasicBlockNormalForm()(mod)
    # Keep the global var and the rewritten function under separate names
    # instead of rebinding `f`.
    opt_f = mod[f]
    check_eval(opt_f(relay.const(5, 'int64')), 30.0, mod=mod)
    check_basic_block_normal_form(opt_f)
Exemple #15
0
def test_nat_add():
    """Check the pass on a module using the Prelude's ``nat`` addition.

    ``1 + 1`` (as ``s(z())`` plus ``s(z())``) is evaluated directly and
    again as the body of global ``f`` after the pass; ``count`` must report
    2 each time. ``add`` itself must not contain the ``fLet`` feature.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    nat, add, succ, zero = prelude.nat, prelude.add, prelude.s, prelude.z

    def one_plus_one():
        # A fresh expression on every call.
        return add(succ(zero()), succ(zero()))

    ctx = tvm.context('llvm', 0)
    intrp = create_executor(mod=mod, ctx=ctx, target='llvm')
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
    assert count(prelude, intrp.evaluate(one_plus_one())) == 2

    wrapped = one_plus_one()
    f = relay.GlobalVar('f')
    mod[f] = relay.Function([], wrapped)
    mod = transform.ToBasicBlockNormalForm()(mod)
    opt_expr = mod['f']
    assert count(prelude, intrp.evaluate(opt_expr.body)) == 2
    assert Feature.fLet not in detect_feature(mod[add])
    check_basic_block_normal_form(opt_expr)
Exemple #16
0
def test_invalid_if():
    """Run the normal-form check on an ``if`` whose branches share a free var.

    The true branch is the variable ``shared`` itself and the false branch
    is ``add(shared, shared)``; nothing binds ``shared``.
    """
    # NOTE(review): the name and the embedded text say this program violates
    # the normal form, yet the check is called unguarded. Presumably an
    # expected-failure marker is missing -- confirm.
    shared = relay.var("shared")
    body = relay.If(
        relay.var("cond", dtype="bool", shape=()),
        shared,
        relay.add(shared, shared),
    )
    """
    The program below violates basic block normal form, as the scope of %shared
    is ambiguous and should not be in that of true branch.

    free_var %cond: bool
    if (%cond) {
      free_var %shared
      %shared
    } else {
      add(%shared, %shared)
    }
    """
    check_basic_block_normal_form(body)
def test_nat_add():
    """Check the pass on a module using ``nat_add`` imported from ``nat.rly``.

    ``1 + 1`` (as ``s(z())`` plus ``s(z())``) is evaluated directly and
    again as the body of global ``f`` after ``InferType`` and the pass;
    ``count`` must report 2 each time. ``nat_add`` itself must not contain
    the ``fLet`` feature.
    """
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    prelude.mod.import_from_std("nat.rly")
    nat, zero, succ = prelude.mod.get_type("nat")
    add = prelude.mod.get_global_var("nat_add")

    def one_plus_one():
        # A fresh expression on every call.
        return add(succ(zero()), succ(zero()))

    dev = tvm.device("llvm", 0)
    intrp = create_executor(mod=mod, device=dev, target="llvm")
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
    assert count(prelude, intrp.evaluate(one_plus_one())) == 2

    wrapped = one_plus_one()
    f = relay.GlobalVar("f")
    mod[f] = relay.Function([], wrapped)
    # Type inference first, then the normal-form rewrite.
    for module_pass in (transform.InferType(), transform.ToBasicBlockNormalForm()):
        mod = module_pass(mod)
    opt_expr = mod["f"]
    assert count(prelude, intrp.evaluate(opt_expr.body)) == 2
    assert Feature.fLet not in detect_feature(mod[add])
    check_basic_block_normal_form(opt_expr)
Exemple #18
0
    def test_let2():
        """Compare the pass output for ``let x = 4; let y = x; x`` to a variant.

        The input returns ``x`` from the inner ``let``; the expected program
        built by ``expected`` returns ``y`` instead. The input is also
        evaluated with 4 expected.
        """

        def expected():
            # Same bindings as the input, but the inner body is `y`.
            x_var = relay.Var('x')
            y_var = relay.Var('y')
            four = relay.const(4.0, 'float32')
            return relay.Let(x_var, four, relay.Let(y_var, x_var, y_var))

        x = relay.Var('x')
        y = relay.Var('y')
        d = relay.const(4.0, 'float32')
        typed = run_opt_pass(
            relay.Let(x, d, relay.Let(y, x, x)),
            transform.InferType(),
        )
        check_eval(typed, 4)

        opt_body = run_opt_pass(typed, transform.ToBasicBlockNormalForm())
        expected_body = run_opt_pass(expected(), transform.InferType())
        assert tvm.ir.structural_equal(opt_body, expected_body)
        check_basic_block_normal_form(opt_body)
Exemple #19
0
def test_nat_add():
    """Check the pass on a module that calls the Prelude's ``nat`` addition.

    The sum ``s(z()) + s(z())`` is evaluated once directly and once as the
    body of global ``f`` after ``ToBasicBlockNormalForm``; ``count`` must
    report 2 both times. ``add`` itself must stay free of the ``fLet``
    feature.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    nat = p.nat
    nat_add = p.add
    succ = p.s
    zero = p.z

    executor = create_executor(mod=mod, ctx=tvm.context("llvm", 0), target="llvm")
    binary_nat = relay.FuncType([nat(), nat()], nat())
    assert mod[nat_add].checked_type == binary_nat

    direct = nat_add(succ(zero()), succ(zero()))
    assert count(p, executor.evaluate(direct)) == 2

    # A second, separately built copy of the sum becomes the body of `f`.
    in_function = nat_add(succ(zero()), succ(zero()))
    f = relay.GlobalVar("f")
    mod[f] = relay.Function([], in_function)
    mod = transform.ToBasicBlockNormalForm()(mod)
    opt_expr = mod["f"]
    assert count(p, executor.evaluate(opt_expr.body)) == 2
    assert Feature.fLet not in detect_feature(mod[nat_add])
    check_basic_block_normal_form(opt_expr)
Exemple #20
0
def test_higher_order_return():
    """Run the normal-form check on a function returning two closures.

    The outer function of ``x`` returns a tuple of two function literals,
    one taking ``y`` and one taking ``z``. Both add their parameter to the
    same ``add(x, x)`` node.
    """
    # NOTE(review): the IR dump below shows %0 bound inside the first closure
    # and reused by the second. Presumably the checker is meant to reject
    # this program and an expected-failure marker is missing -- confirm.
    x_in = relay.var("x", shape=(1,), dtype="float32")
    # Built once so both closures reference the very same node.
    x_twice = relay.add(x_in, x_in)

    def closure_taking(param_name):
        param = relay.var(param_name, shape=(1,), dtype="float32")
        return relay.Function([param], relay.add(x_twice, param))

    pair = relay.Tuple([closure_taking("y"), closure_taking("z")])
    outer = relay.Function([x_in], pair)
    """
    fn (%x: Tensor[(1), float32]) {
      %1 = fn (%y: Tensor[(1), float32]) {
        %0 = add(%x, %x);
        add(%0, %y)
      };
      %2 = fn (%z: Tensor[(1), float32]) {
        add(%0, %z)
      };
      (%1, %2)
    }
    """
    check_basic_block_normal_form(outer)
Exemple #21
0
def test_invalid_if2():
    """Run the normal-form check on an ``if`` whose branches share ``add(x, 1)``.

    fn (%x: float32) {
      %0 = equal(%x, 2f);
      if (%0) {
        %1 = add(%x, 1f);
        multiply(%1, 2f)
      } else {
        multiply(%1, 1f)
      }
    }
    """
    # NOTE(review): the name says "invalid" yet the check is called
    # unguarded. Presumably an expected-failure marker is missing -- confirm.
    x_in = relay.var("x", shape=(), dtype="float32")
    f32_one = relay.const(1, dtype="float32")
    f32_two = relay.const(2, dtype="float32")
    # One node, used by both branches without any binding.
    x_plus_one = relay.add(x_in, f32_one)
    branching = relay.If(
        relay.equal(x_in, f32_two),
        relay.multiply(x_plus_one, f32_two),
        relay.multiply(x_plus_one, f32_one),
    )
    check_basic_block_normal_form(relay.Function([x_in], branching))
Exemple #22
0
def test_func():
    """Run the normal-form check on a function returning a pair of closures.

    Each closure adds its own parameter (``y`` or ``z``) to one shared
    ``add(x, x)`` node, where ``x`` is the outer function's parameter.
    """
    # NOTE(review): the IR dump below shows %0 bound inside the first closure
    # and reused by the second. Presumably the checker is meant to reject
    # this program and an expected-failure marker is missing -- confirm.
    outer_param = relay.var('x', shape=(1,), dtype='float32')
    inner_params = [
        relay.var(name, shape=(1,), dtype='float32') for name in ('y', 'z')
    ]
    # Built once so both closures reference the very same node.
    shared_sum = relay.add(outer_param, outer_param)
    closures = [
        relay.Function([param], relay.add(shared_sum, param))
        for param in inner_params
    ]
    body = relay.Function([outer_param], relay.Tuple(closures))
    """
    fn (%x: Tensor[(1), float32]) {
      %1 = fn (%y: Tensor[(1), float32]) {
        %0 = add(%x, %x);
        add(%0, %y)
      };
      %2 = fn (%z: Tensor[(1), float32]) {
        add(%0, %z)
      };
      (%1, %2)
    }
    """
    check_basic_block_normal_form(body)
Exemple #23
0
def test_valid_if():
    """Run the normal-form check on an ``if`` whose shared var is let-bound.

    ``shared`` appears in both branches of the ``if``, and the whole ``if``
    is wrapped in ``let shared = shared_bound``.
    """
    shared = relay.var("shared")
    branching = relay.If(
        relay.var("cond", dtype="bool", shape=()),
        shared,
        relay.add(shared, shared),
    )
    source = relay.var("shared_bound", shape=(1,), dtype="float32")
    scoped = relay.Let(shared, source, branching)
    """
    The program below uses let binding to control the scope of %shared, which
    follows the basic block normal form.

    free_var %shared_bound: Tensor[(1), float32]
    let %shared = %shared_bound;
    free_var %cond: bool
    if (%cond) {
      %shared
    } else {
      add(%shared, %shared)
    }
    """
    check_basic_block_normal_form(scoped)
Exemple #24
0
def test_valid_if2():
    """Run the normal-form check on an ``if`` whose shared operand is let-bound.

    fn (%x: float32) {
      let %v1 = add(%x, 1f);
      %0 = equal(%x, 2f);
      if (%0) {
        multiply(%v1, 2f)
      } else {
        multiply(%v1, 1f)
      }
    }
    """
    x_in = relay.var('x', shape=(), dtype='float32')
    f32_one = relay.const(1, dtype='float32')
    f32_two = relay.const(2, dtype='float32')
    bound = relay.var('v1')
    branching = relay.If(
        relay.equal(x_in, f32_two),
        relay.multiply(bound, f32_two),
        relay.multiply(bound, f32_one),
    )
    # The let gives `add(x, 1)` a name that both branches refer to.
    scoped = relay.Let(bound, relay.add(x_in, f32_one), branching)
    check_basic_block_normal_form(relay.Function([x_in], scoped))
Exemple #25
0
def test_if():
    """Compare the pass output for an ``if`` with a shared operand to a let form.

    The comparison is made twice: once on the bare expression (with free
    variables mapped) and once with the expression wrapped in a function
    of ``x``. Only ``ToBasicBlockNormalForm`` is applied to the input; the
    expected program goes through ``InferType``.
    """

    def if_expr(x):
        """Build an ``if`` whose two branches reuse one ``add(x, 1)`` node.

        free_var %x: float32
        %0 = equal(%x, 2f);
        if (%0) {
          %1 = add(%x, 1f);
          multiply(%1, 2f)
        } else {
          multiply(%1, 1f)
        }
        """
        f32_one = relay.const(1, dtype='float32')
        f32_two = relay.const(2, dtype='float32')
        x_plus_one = relay.add(x, f32_one)
        return relay.If(
            relay.equal(x, f32_two),
            relay.multiply(x_plus_one, f32_two),
            relay.multiply(x_plus_one, f32_one),
        )

    def expected_if_expr(x):
        """Build the same ``if`` with ``add(x, 1)`` bound to a ``let`` variable.

        free_var %x: float32
        let %v1: float32 = add(%x, 1f /* ty=float32 */) /* ty=float32 */;
        %0 = equal(%x, 2f /* ty=float32 */) /* ty=bool */;
        if (%0) {
          multiply(%v1, 2f /* ty=float32 */) /* ty=float32 */
        } else {
          multiply(%v1, 1f /* ty=float32 */) /* ty=float32 */
        }
        """
        f32_one = relay.const(1, dtype='float32')
        f32_two = relay.const(2, dtype='float32')
        bound = relay.var('v1')
        branching = relay.If(
            relay.equal(x, f32_two),
            relay.multiply(bound, f32_two),
            relay.multiply(bound, f32_one),
        )
        return relay.Let(bound, relay.add(x, f32_one), branching)

    def normalize(expr):
        return run_opt_pass(expr, transform.ToBasicBlockNormalForm())

    def infer(expr):
        return run_opt_pass(expr, transform.InferType())

    x = relay.var('x', shape=(), dtype='float32')
    body = if_expr(x)
    expected_body = expected_if_expr(x)

    # Bare expressions: compared with free variables mapped.
    bblock = normalize(body)
    assert tvm.ir.structural_equal(
        bblock, infer(expected_body), map_free_vars=True)
    check_basic_block_normal_form(bblock)

    # Same comparison with x bound as a function parameter.
    bblock = normalize(relay.Function([x], body))
    assert tvm.ir.structural_equal(
        bblock, infer(relay.Function([x], expected_body)))
    check_basic_block_normal_form(bblock)
Exemple #26
0
def test_one_block():
    """Run the normal-form check on ``x + (x + x)``, which has no branching."""
    x_in = relay.var("x")
    inner_sum = relay.add(x_in, x_in)
    check_basic_block_normal_form(relay.add(x_in, inner_sum))
Exemple #27
0
def test_let():
    """Run the normal-form check on ``let y = x; y`` with ``x`` left free."""
    source = relay.var('x')
    alias = relay.var('y')
    check_basic_block_normal_form(relay.Let(alias, source, alias))
def test_nested_if():
    """Compare the pass output for a nested ``if`` with a hand-built program.

    ``add(y, y)`` is a single node used by both arms of the inner ``if``
    that sits in the outer true branch. The expected program binds that sum
    with a ``let`` placed in the outer true branch, around the inner ``if``.
    The input goes through ``ToBasicBlockNormalForm`` and then ``InferType``.
    """
    outer_cond = relay.var("x", shape=(), dtype="bool")
    y_in = relay.var("y", shape=(), dtype="float32")
    # One node, referenced from both arms of the inner `if`.
    y_doubled = relay.add(y_in, y_in)
    then_arm = relay.If(
        relay.const(True),
        y_doubled,
        relay.add(relay.const(3, dtype="float32"), y_doubled),
    )
    else_arm = relay.If(
        relay.const(False),
        relay.const(2, dtype="float32"),
        relay.const(1, dtype="float32"),
    )
    body = relay.If(outer_cond, then_arm, else_arm)
    """
    free_var %x: bool
    if (%x) {
      if (True) {
        free_var %y: float32
        %0 = add(%y, %y);
        %0
      } else {
        add(3f, %0)
      }
    } else {
      if (False) {
        2f
      } else {
        1f
      }
    }
    """
    def expected():
        # Same program, with the shared sum named by an explicit let.
        cond_var = relay.var("x", shape=(), dtype="bool")
        y_var = relay.var("y", shape=(), dtype="float32")
        bound = relay.var("y2")
        inner_if = relay.If(
            relay.const(True),
            bound,
            relay.add(relay.const(3, dtype="float32"), bound),
        )
        let_arm = relay.Let(bound, relay.add(y_var, y_var), inner_if)
        const_arm = relay.If(
            relay.const(False),
            relay.const(2, dtype="float32"),
            relay.const(1, dtype="float32"),
        )
        return relay.If(cond_var, let_arm, const_arm)

    pipeline = [transform.ToBasicBlockNormalForm(), transform.InferType()]
    bblock = run_opt_pass(body, pipeline)
    """
    free_var %x: bool
    if (%x) {
      free_var %y: float32
      let %x1: float32 = add(%y, %y) /* ty=float32 */;
      if (True /* ty=bool */) {
        %x1
      } else {
        add(3f /* ty=float32 */, %x1) /* ty=float32 */
      }
    } else {
      if (False /* ty=bool */) {
        2f /* ty=float32 */
      } else {
        1f /* ty=float32 */
      }
    }
    """
    expected_output = run_opt_pass(expected(), transform.InferType())
    # Both programs have free variables, so they are matched up by position.
    assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True)
    check_basic_block_normal_form(bblock)
Exemple #29
0
def test_let():
    """Run the normal-form check on a ``let`` that aliases a free variable.

    The program is ``let y = x; y``.
    """
    free_x, bound_y = relay.var("x"), relay.var("y")
    program = relay.Let(bound_y, free_x, bound_y)
    check_basic_block_normal_form(program)
Exemple #30
0

def if_expr(x):
    """Build an ``if`` on ``x == 2`` whose branches reuse one ``add(x, 1)`` node.

    free_var %x: float32
    %0 = equal(%x, 2f);
    if (%0) {
      %1 = add(%x, 1f);
      multiply(%1, 2f)
    } else {
      multiply(%1, 1f)
    }

    Returns the ``relay.If`` expression.
    """
    f32_one = relay.const(1, dtype='float32')
    f32_two = relay.const(2, dtype='float32')
    # One node, referenced from both branches.
    x_plus_one = relay.add(x, f32_one)
    return relay.If(
        relay.equal(x, f32_two),
        relay.multiply(x_plus_one, f32_two),
        relay.multiply(x_plus_one, f32_one),
    )


def expected_if_expr(x):
    """Build an ``if`` on ``x == 2`` with ``add(x, 1)`` bound to a ``let`` var.

    free_var %x: float32
    let %v1: float32 = add(%x, 1f /* ty=float32 */) /* ty=float32 */;
    %0 = equal(%x, 2f /* ty=float32 */) /* ty=bool */;
    if (%0) {
      multiply(%v1, 2f /* ty=float32 */) /* ty=float32 */
    } else {
      multiply(%v1, 1f /* ty=float32 */) /* ty=float32 */
    }

    Returns the ``relay.Let`` expression wrapping the ``if``.
    """
    f32_one = relay.const(1, dtype='float32')
    f32_two = relay.const(2, dtype='float32')
    bound = relay.var('v1')
    branching = relay.If(
        relay.equal(x, f32_two),
        relay.multiply(bound, f32_two),
        relay.multiply(bound, f32_one),
    )
    return relay.Let(bound, relay.add(x, f32_one), branching)
# Module-level driver: build the let-bound `if` from `expected_if_expr` on an
# int16 variable of shape (1, 4), apply ToBasicBlockNormalForm, and run the
# normal-form check on the result.
# NOTE(review): expected_if_expr combines its argument with float32 scalar
# constants; whether an int16 (1, 4) input type-checks depends on what
# run_opt_pass does -- confirm.
ZmncS = relay.var('y0', shape=(1, 4), dtype='int16')
SdJ2V = expected_if_expr(ZmncS)
P30df = transform.ToBasicBlockNormalForm()
WJw4J = run_opt_pass(SdJ2V, P30df)
check_basic_block_normal_form(WJw4J)