def test_recursion():
    """
    Program:
       let sum_twice(n: i64) -> i64 = {
          m = (n * 2)
          if (n == 0) {
              return m;
          } else {
              return m + sum_twice(n - 1);
          }
       }
       sum_twice(5);
    """
    return  # cannot be run, as fuse_ops needs to visit recursively
    mod = relay.Module()
    i64 = relay.TensorType((), 'int64')
    f = relay.GlobalVar("f")
    n = relay.Var("n", i64)
    m = n * relay.const(2, 'int64')
    funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')), m,
                        m + f(n - relay.const(1, 'int64')))
    value = relay.Function([n], funcbody, i64, [])
    mod[f] = value
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
    old_f = mod[f]
    mod = transform.ToANormalForm()(mod)
    f = mod[f]
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
def test_recursion():
    """
    Program:
       let f(n: i64) -> i64 = {
          m = (n * 2)
          if (n == 0) {
              return m;
          } else {
              return m + f(n - 1);
          }
       }
       f(5);
    """
    mod = tvm.IRModule()
    i64 = relay.TensorType((), 'int64')
    f = relay.GlobalVar("f")
    n = relay.Var("n", i64)
    m = n * relay.const(2, 'int64')
    funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')), m,
                        m + f(n - relay.const(1, 'int64')))
    value = relay.Function([n], funcbody, i64, [])
    mod[f] = value
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
    old_f = mod[f]
    mod = transform.ToANormalForm()(mod)
    f = mod[f]
    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
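Note: the examples above (and several below) call helper functions `run_opt_pass` and `check_eval` that the snippets do not define. A minimal sketch of what they typically look like in TVM's test utilities follows; the exact signatures vary across TVM versions and test files (for instance, the round-trip example below passes an explicit argument list to `check_eval`), so treat this as an illustrative assumption rather than the canonical definition.

import numpy as np
import tvm
from tvm import relay
from tvm.relay import create_executor


def run_opt_pass(expr, passes):
    # Wrap the expression in a fresh module, run the given passes, and unwrap the result.
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
    seq = tvm.transform.Sequential(passes)
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def check_eval(expr, expected_result, mod=None, rtol=1e-07):
    # Evaluate the expression with the Relay interpreter and compare numerically.
    dev = tvm.device("llvm", 0)
    result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr)
    np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol)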
Example #3
def test_let_as_subexpr():
    def on_cpu(x):
        return relay.annotation.on_device(x, tvm.device("cpu"), constrain_result=True)

    x = relay.Var("x", relay.IncompleteType())
    c = relay.const(1)
    l = relay.Let(x, on_cpu(c + c), x)
    body = l * l

    anf = run_opt_pass(body, [transform.ToANormalForm(), transform.InferType()])

    v0 = relay.Var("v0", relay.IncompleteType())
    v1 = relay.Var("v1", relay.IncompleteType())
    v2 = relay.Var("v2", relay.IncompleteType())
    expected_output = relay.Let(
        v0,
        on_cpu(c),
        relay.Let(
            x,
            on_cpu(v0 + v0),
            relay.Let(v1, x, relay.Let(v2, v1 * v1, v2)),
        ),
    )
    expected_output = run_opt_pass(expected_output, transform.InferType())

    tvm.ir.assert_structural_equal(anf, expected_output)
Example #4
def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx,
                 count_meta):
    """We assume stop_name only appears once for simplicity.
    This constraint will be lifted in the future.
    bitpack_start and bitpack_end are both inclusive.
    """
    bitpack_start = op.op.get("annotation.bitpack_start")
    bitpack_end = op.op.get("annotation.bitpack_end")
    anf = run_opt_pass(expr, transform.ToANormalForm())
    operator_current_idx = 0

    def _recursion(anf, start_found, stop_found, operator_current_idx):
        """Helper to obtain the subgraph."""
        if isinstance(anf, relay.Function):
            return relay.Function(
                anf.params,
                _recursion(anf.body, start_found, stop_found,
                           operator_current_idx),
                anf.ret_type,
                anf.type_params,
                anf.attrs,
            )
        if isinstance(anf, relay.expr.Let):
            value = anf.value
            if isinstance(value, relay.expr.Call):
                if isinstance(value.op, tvm.ir.Op):
                    if value.op.name == start_name and not start_found:
                        if operator_current_idx == start_name_idx or start_name_idx is None:
                            value = relay.expr.Call(bitpack_start, [value])
                            start_found = True
                    elif value.op.name == stop_name:
                        if operator_current_idx == stop_name_idx or stop_name_idx is None:
                            raise BT()

            operator_current_idx = _operator_idx_inc(value, count_meta,
                                                     operator_current_idx)

            try:
                return relay.expr.Let(
                    anf.var,
                    value,
                    _recursion(anf.body, start_found, stop_found,
                               operator_current_idx),
                )
            except BT:
                assert start_found
                assert not stop_found
                stop_found = True
                value = relay.expr.Call(bitpack_end, [value])
                # TODO: check that anf.body contains no more occurrences of stop_name besides this one
                return relay.expr.Let(anf.var, value, anf.body)
        else:
            assert start_found
            assert stop_found
            return anf

    annotated = _recursion(anf, False, False, operator_current_idx)
    return run_opt_pass(annotated, transform.ToGraphNormalForm())
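For context, `get_subgraph` comes from VTA's graph-packing utilities: it brackets a region of an A-normal-form graph with `annotation.bitpack_start` and `annotation.bitpack_end` markers. A hedged usage sketch follows; the operator names and indices here are illustrative assumptions, not values the function prescribes.

# Hypothetical call: mark everything from the first nn.max_pool2d up to and
# including the nn.global_avg_pool2d for bit packing.
expr = get_subgraph(
    expr,
    start_name="nn.max_pool2d",
    stop_name="nn.global_avg_pool2d",
    start_name_idx=0,
    stop_name_idx=0,
    count_meta=False,
)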
def test_function():
    t = relay.TensorType((), 'float32')
    x = relay.Var("x", t)
    f = relay.Function([x], x + x)
    d = relay.const(4.0, 'float32')
    anf_f = run_opt_pass(f, transform.ToANormalForm())
    assert isinstance(anf_f, relay.Function)
    check_eval(f(d), 8)
    check_eval(anf_f(d), 8)
def test_let():
    x = relay.Var("x")
    y = relay.Var("y")
    d = relay.const(4.0, 'float32')
    body = relay.Let(y, x, x + y)
    body = relay.Let(x, d, body)
    check_eval(body, 8)
    opt_body = run_opt_pass(body, transform.ToANormalForm())
    check_eval(opt_body, 8)
def test_explicit_bound():
    x = relay.const(1)
    y = op.add(x, x)
    z = op.add(y, y)
    f = relay.Function([], op.add(z, z))
    assert not Feature.fLet in detect_feature(f)
    anf = run_opt_pass(f, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(anf)
    check_eval(f(), 8.0)
    check_eval(anf(), 8.0)
def test_ref():
    i = relay.Var('i')
    iv = relay.Var('iv')
    u = relay.Var('u')
    uv = relay.Var('uv')
    body = relay.add(iv, uv)
    body = relay.Let(uv, relay.RefRead(i), body)
    body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
    body = relay.Let(iv, relay.RefRead(i), body)
    body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
    check_eval(body, 3)
    opt_body = run_opt_pass(body, transform.ToANormalForm())
    check_eval(opt_body, 3)
def test_round_trip():
    x = relay.Var('x')
    y = relay.Var('y')
    z = relay.Var('z')
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    g = run_opt_pass(f, transform.ToGraphNormalForm())
    h = run_opt_pass(g, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(f)
    assert not Feature.fLet in detect_feature(g)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
    check_eval(h, [], 8.0)
def test_if():
    cond = relay.const(True)
    x = relay.If(cond, relay.const(2), relay.const(3))
    anf = run_opt_pass(x, [transform.ToANormalForm(), transform.InferType()])
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    true_branch = relay.Let(a, relay.const(2), a)
    false_branch = relay.Let(b, relay.const(3), b)
    expected_output = relay.If(c, true_branch, false_branch)
    expected_output = relay.Let(d, expected_output, d)
    expected_output = relay.Let(c, cond, expected_output)
    expected_output = run_opt_pass(expected_output, transform.InferType())
    assert tvm.ir.structural_equal(anf, expected_output)
def test_nat_add():
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat, z, s = p.mod.get_type("nat")
    add = p.mod.get_global_var("nat_add")
    dev = tvm.device("llvm", 0)
    intrp = create_executor(mod=mod, device=dev, target="llvm")
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
    assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
    expr = add(s(z()), s(z()))
    f = relay.GlobalVar("f")
    mod[f] = relay.Function([], expr)
    mod = transform.ToANormalForm()(mod)
    expr = mod["f"]
    assert count(p, intrp.evaluate(expr.body)) == 2
    assert Feature.fLet in detect_feature(mod[add])
Example #12
def test_if():
    cond = relay.const(True)
    x = relay.If(cond, relay.const(2), relay.const(3))
    anf = transform.OptimizeOnExpr(
        x, [transform.ToANormalForm(),
            transform.InferType()])
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    true_branch = relay.Let(a, relay.const(2), a)
    false_branch = relay.Let(b, relay.const(3), b)
    expected_output = relay.If(c, true_branch, false_branch)
    expected_output = relay.Let(d, expected_output, d)
    expected_output = relay.Let(c, cond, expected_output)
    expected_output = transform.OptimizeOnExpr(expected_output,
                                               transform.InferType())
    assert alpha_equal(anf, expected_output)
def test_nat_add():
    mod = tvm.IRModule()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat
    add = p.add
    s = p.s
    z = p.z
    ctx = tvm.context("llvm", 0)
    intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
    assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
    assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
    expr = add(s(z()), s(z()))
    f = relay.GlobalVar("f")
    mod[f] = relay.Function([], expr)
    mod = transform.ToANormalForm()(mod)
    expr = mod["f"]
    assert count(p, intrp.evaluate(expr.body)) == 2
    assert Feature.fLet in detect_feature(mod[add])
def test_order():
    z = relay.const(3)
    y = relay.const(2)
    x = relay.const(1)
    val = x + y * z
    check_eval(val, 7.0)
    anf = run_opt_pass(val, [transform.ToANormalForm(), transform.InferType()])
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    e = relay.Var('e', relay.IncompleteType())
    expected_output = e
    expected_output = relay.Let(e, a + d, expected_output)
    expected_output = relay.Let(d, b * c, expected_output)
    expected_output = relay.Let(c, z, expected_output)
    expected_output = relay.Let(b, y, expected_output)
    expected_output = relay.Let(a, x, expected_output)
    expected_output = run_opt_pass(expected_output, transform.InferType())
    assert tvm.ir.structural_equal(anf, expected_output)
Example #15
def test_order():
    z = relay.const(3)
    y = relay.const(2)
    x = relay.const(1)
    val = x + y * z
    check_eval(val, 7.0)
    anf = transform.OptimizeOnExpr(
        val, [transform.ToANormalForm(),
              transform.InferType()])
    a = relay.Var('a', relay.IncompleteType())
    b = relay.Var('b', relay.IncompleteType())
    c = relay.Var('c', relay.IncompleteType())
    d = relay.Var('d', relay.IncompleteType())
    e = relay.Var('e', relay.IncompleteType())
    expected_output = e
    expected_output = relay.Let(e, a + d, expected_output)
    expected_output = relay.Let(d, b * c, expected_output)
    expected_output = relay.Let(c, z, expected_output)
    expected_output = relay.Let(b, y, expected_output)
    expected_output = relay.Let(a, x, expected_output)
    expected_output = transform.OptimizeOnExpr(expected_output,
                                               transform.InferType())
    assert alpha_equal(anf, expected_output)
Example #16
def merge_transform_to_mxnet_model(mod):
    """ Add Image Transform Logic Into Model """
    svalue = np.array([123., 117., 104.])
    sub_data = relay.Constant(tvm.nd.array(svalue)).astype("float32")
    dvalue = np.array([58.395, 57.12, 57.37])
    divide_data = relay.Constant(tvm.nd.array(dvalue)).astype("float32")

    data_shape = (224, 224, 3)
    data = relay.var("data", relay.TensorType(data_shape, "float32"))

    simple_net = relay.expand_dims(data, axis=0, num_newaxis=1)
    # TODO: Relay does not support dynamic shapes yet; resize logic needs to be added in the future.
    # simple_net = relay.image.resize(simple_net, (224, 224), "NHWC", "bilinear", "align_corners")
    simple_net = relay.subtract(simple_net, sub_data)
    simple_net = relay.divide(simple_net, divide_data)
    simple_net = relay.transpose(simple_net, (0, 3, 1, 2))

    # Merge the transform into the pretrained model network.
    entry = mod["main"]
    anf = run_opt_pass(entry.body, transform.ToANormalForm())
    call = anf.value
    data, weights = call.args
    first_op = op.nn.conv2d(simple_net,
                            weights,
                            strides=call.attrs.strides,
                            padding=call.attrs.padding,
                            dilation=call.attrs.dilation,
                            groups=call.attrs.groups,
                            channels=call.attrs.channels,
                            kernel_size=call.attrs.kernel_size,
                            out_dtype=call.attrs.out_dtype)
    net = relay.expr.Let(anf.var, first_op, anf.body)
    net = run_opt_pass(net, transform.ToGraphNormalForm())

    mod['main'] = net
    return mod
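A hedged usage sketch for `merge_transform_to_mxnet_model`: the function assumes the module's `main` begins with an `nn.conv2d` call, as in a typical NCHW image classifier imported from MXNet. The model and input shape below are illustrative assumptions.

import mxnet as mx
from tvm import relay

# Hypothetical model; any classifier whose first operator is conv2d should fit.
block = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={"data": (1, 3, 224, 224)})
mod = merge_transform_to_mxnet_model(mod)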
Example #17
def graph_split(expr, split_conf, params=None):
    """Splitting the graph into a list of subgraphs"""

    def get_dep_var(sub_var_dep):
        return list(sub_var_dep[-1]["ref_nodes"])

    def parse_dependency(value, snode_dep, new_input_idx):
        new_args = []
        need_update = False
        for var in value.args:
            is_free_var = False
            for dep in snode_dep[:-1]:
                if var in dep["nodes"]:
                    # Mark the previous subgraph node as a dependency.
                    dep["nodes"][var] += 1
                    dep["ref_nodes"][var] = dep["nodes"][var]
                    # The var of this call is a free_var
                    is_free_var = True
            # if the var of this call is a free_var, recreate it and give it a fixed input name.
            if is_free_var:
                need_update = True
                new_args.append(relay.var(f"data_n_{new_input_idx}", var.checked_type))
                new_input_idx += 1
            else:
                new_args.append(var)
        # If the 'tvm.relay.expr.Call' has a free_var, recreate it with a new name like 'data_n_*'.
        if need_update:
            value = tvm.relay.expr.Call(
                value.op, new_args, value.attrs, value.type_args, value.span
            )
        return value, snode_dep, new_input_idx

    def merge_constant_expr(constant_expr, expr):
        # Merge the constant expression with another expression.
        if not isinstance(constant_expr.body, tvm.relay.expr.Let):
            return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, expr)

        return tvm.relay.expr.Let(
            constant_expr.var, constant_expr.value, merge_constant_expr(constant_expr.body, expr)
        )

    def _recursion(anf, pipeline_mods, split_conf, constant_expr):
        # Enumerate all operators of the compute graph, then split the compute
        # graph into a group of subgraphs.
        nonlocal operator_index_map
        nonlocal new_input_idx
        nonlocal snode_dep
        cur_node_dep = snode_dep[-1]
        if isinstance(anf, tvm.relay.Function):
            return tvm.relay.Function(
                anf.params,
                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
                anf.ret_type,
                anf.type_params,
                anf.attrs,
            )
        if isinstance(anf, tvm.relay.expr.Let):
            value = anf.value
            # Record the constant expr so that all subgraphs can find the correct constants.
            if isinstance(value, tvm.relay.expr.Constant):
                if not constant_expr:
                    constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
                else:
                    constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr)
            if isinstance(value, tvm.relay.expr.Call):
                new_args = []
                # build current var list
                cur_node_dep["nodes"][anf.var] = 0
                # Get the dependency information of the nodes.
                value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx)
                if isinstance(value.op, tvm.ir.Op):
                    if value.op.name in operator_index_map:
                        operator_index_map[value.op.name] += 1
                    else:
                        operator_index_map[value.op.name] = 0
                    split_operator_name = split_conf[0]["op_name"] if split_conf else ""
                    split_operator_index = split_conf[0]["op_index"] if split_conf else ""
                    # If an operator name and its repeat count in the network match the
                    # values in the 'split configuration', this is the place where we
                    # should split the graph.
                    if (
                        split_conf
                        and split_operator_name in operator_index_map
                        and operator_index_map[split_operator_name] >= split_operator_index
                    ):
                        # Do graph splitting.
                        split_conf.pop(0)
                        snode_dep.append({"nodes": {}, "ref_nodes": {}})
                        ann = _recursion(
                            anf.body,
                            pipeline_mods,
                            split_conf,
                            constant_expr,
                        )
                        snode_dep.pop()
                        dep_vars = get_dep_var(snode_dep)
                        # When nodes of the current subgraph are dependency nodes of
                        # another subgraph, we need to set them as outputs of the
                        # current subgraph.
                        body = relay.Tuple(dep_vars) if len(dep_vars) > 1 else anf.var
                        # When an operator of the current subgraph uses a constant from a
                        # previous subgraph as an argument of a relay.expr.Call, that
                        # constant may become a free variable if it does not exist in the
                        # current subgraph. Merge the previous constants into the current
                        # subgraph to avoid this issue.
                        if constant_expr:
                            ann = merge_constant_expr(constant_expr, ann)
                        ann = run_opt_pass(ann, transform.ToGraphNormalForm())
                        mod = tvm.IRModule.from_expr(ann)
                        pipeline_mods.insert(0, mod)
                        # Return the last node of the current subgraph.
                        return tvm.relay.expr.Let(anf.var, value, body)
            return tvm.relay.expr.Let(
                anf.var,
                value,
                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
            )
        else:
            return anf

    snode_dep = [{"nodes": {}, "ref_nodes": {}}]
    pipeline_mods = []
    operator_index_map = {}
    # Used to track new inputs created by graph splitting.
    new_input_idx = 0
    constant_expr = None
    subgraph_split_conf = split_conf.copy()
    # Binding the parameters.
    if params:
        expr = build_module.bind_params_by_name(expr, params)
    anf = run_opt_pass(expr, transform.ToANormalForm())
    anf = run_opt_pass(anf, transform.InferType())
    ann = _recursion(
        anf,
        pipeline_mods,
        subgraph_split_conf,
        constant_expr,
    )
    ann = run_opt_pass(ann.body, transform.ToGraphNormalForm())
    mod = tvm.IRModule.from_expr(ann)
    pipeline_mods.insert(0, mod)
    return pipeline_mods
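`graph_split` is driven by the `split_conf` argument: each entry names an operator and the occurrence index after which to cut the graph. A hedged usage sketch follows, mirroring how TVM's pipeline-executor tests drive this kind of split; the op name and index are illustrative.

# Hypothetical: cut the graph right after the first occurrence of nn.relu.
split_conf = [{"op_name": "nn.relu", "op_index": 0}]
pipeline_mods = graph_split(mod["main"], split_conf, params)
# pipeline_mods is a list of tvm.IRModule objects, one per subgraph, in execution order.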
Example #18
def test_nat_update():
    m = tvm.IRModule()
    p = Prelude(m)
    p.mod.import_from_std("nat.rly")
    m = transform.ToANormalForm()(m)
    transform.PartialEvaluate()(m)
def test_nat_update():
    m = tvm.IRModule()
    p = Prelude(m)
    add_nat_definitions(p)
    m = transform.ToANormalForm()(m)
    transform.PartialEvaluate()(m)