Exemple #1
0
def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx,
                 count_meta):
    """Annotate a region of ``expr`` with bitpack start/end markers.

    ``expr`` is converted to A-normal form and its let-chain is walked in
    order while a running operator index is maintained (advanced by
    ``_operator_idx_inc``, which is defined elsewhere).

    * The first call to ``start_name`` whose running index equals
      ``start_name_idx`` is wrapped in ``annotation.bitpack_start``. When
      ``start_name_idx`` is None, the first occurrence is wrapped.
    * When a call to ``stop_name`` matching ``stop_name_idx`` is reached (any
      occurrence when it is None), the binding immediately *preceding* it is
      wrapped in ``annotation.bitpack_end``. The rest of the chain, including
      the stop call itself, is kept unchanged.

    The annotated expression is converted back to graph normal form.

    We assume stop_name only appears once for simplicity.
    This constraint will be lifted in the future.
    bitpack_start and bitpack_end are both inclusive.
    NOTE(review): the end marker is applied to the binding before the matched
    ``stop_name`` call, not to the call itself — confirm this is intended.

    Parameters
    ----------
    expr : relay expression
        Usually a ``relay.Function``; its body is walked.
    start_name, stop_name : str
        Operator names compared against ``value.op.name``.
    start_name_idx, stop_name_idx : int or None
        Running operator index at which the start/stop operator must appear.
        None means "any occurrence".
    count_meta :
        Passed through unchanged to ``_operator_idx_inc``. Its meaning is not
        visible here — TODO confirm.

    Returns
    -------
    The annotated expression in graph normal form.

    Raises
    ------
    AssertionError
        If the stop call is reached before the start call was annotated.
        Also raised if the end of the let-chain is reached without a matching
        stop call.
    """
    # Marker operators used to wrap the start/end of the region.
    bitpack_start = op.op.get("annotation.bitpack_start")
    bitpack_end = op.op.get("annotation.bitpack_end")
    anf = run_opt_pass(expr, transform.ToANormalForm())
    operator_current_idx = 0

    def _recursion(anf, start_found, stop_found, operator_current_idx):
        """Rebuild the let-chain, inserting the bitpack markers.

        ``start_found``/``stop_found`` record whether each marker has already
        been placed on the path to this binding. ``operator_current_idx`` is
        the running operator index *before* this binding.
        """
        if isinstance(anf, relay.Function):
            return relay.Function(
                anf.params,
                _recursion(anf.body, start_found, stop_found,
                           operator_current_idx),
                anf.ret_type,
                anf.type_params,
                anf.attrs,
            )
        if isinstance(anf, relay.expr.Let):
            value = anf.value
            if isinstance(value, relay.expr.Call):
                if isinstance(value.op, tvm.ir.Op):
                    if value.op.name == start_name and not start_found:
                        if operator_current_idx == start_name_idx or start_name_idx is None:
                            value = relay.expr.Call(bitpack_start, [value])
                            start_found = True
                    elif value.op.name == stop_name:
                        if operator_current_idx == stop_name_idx or stop_name_idx is None:
                            # Unwind to the caller, which is the enclosing let
                            # (the binding before this one). That caller places
                            # bitpack_end. If this is the first binding of the
                            # chain, BT propagates out of get_subgraph uncaught.
                            raise BT()

            # The (possibly bitpack_start-wrapped) value is what gets counted.
            operator_current_idx = _operator_idx_inc(value, count_meta,
                                                     operator_current_idx)

            try:
                return relay.expr.Let(
                    anf.var,
                    value,
                    _recursion(anf.body, start_found, stop_found,
                               operator_current_idx),
                )
            except BT:
                # The next binding is the matched stop call: close the region here.
                assert start_found
                assert not stop_found
                stop_found = True
                value = relay.expr.Call(bitpack_end, [value])
                # todo: check anf.body has no more stop_name beside that one
                # The remainder of the chain is kept as-is (not walked further).
                return relay.expr.Let(anf.var, value, anf.body)
        else:
            # End of the let-chain. stop_found is only ever set inside the
            # except-branch above, which returns without recursing, so it is
            # always False here. Reaching this point therefore means no
            # matching stop call was found, and the assertions fail.
            assert start_found
            assert stop_found
            return anf

    annotated = _recursion(anf, False, False, operator_current_idx)
    return run_opt_pass(annotated, transform.ToGraphNormalForm())
def test_implicit_share():
    """ToGraphNormalForm removes lets while preserving the shared computation.

    Builds ``let x = 1; let y = x + x; let z = y + y; z + z``. It checks that
    the let-based function ``f`` uses lets and its graph-normal-form
    counterpart ``g`` does not, and that both evaluate to 8.0.
    """
    x = relay.Var('x')
    y = relay.Var('y')
    z = relay.Var('z')
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    g = run_opt_pass(f, transform.ToGraphNormalForm())
    assert Feature.fLet in detect_feature(f)
    assert Feature.fLet not in detect_feature(g)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
Exemple #3
0
def test_round_trip():
    """A-normal form -> graph normal form -> A-normal form keeps the result.

    ``f`` uses explicit lets, ``g`` is its graph-normal form (no lets), and
    ``h`` converts ``g`` back to A-normal form. All three must evaluate to 8.0.
    """
    x = relay.Var("x")
    y = relay.Var("y")
    z = relay.Var("z")
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    g = run_opt_pass(f, transform.ToGraphNormalForm())
    h = run_opt_pass(g, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(f)
    assert Feature.fLet not in detect_feature(g)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
    check_eval(h, [], 8.0)
Exemple #4
0
def test_round_trip():
    """A-normal form -> graph normal form -> A-normal form keeps the result.

    Same program as the ``run_opt_pass`` variant, but driven through
    ``transform.OptimizeOnExpr``. All three forms must evaluate to 8.0.
    """
    x = relay.Var('x')
    y = relay.Var('y')
    z = relay.Var('z')
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm())
    h = transform.OptimizeOnExpr(g, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(f)
    assert Feature.fLet not in detect_feature(g)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
    check_eval(h, [], 8.0)
Exemple #5
0
def merge_transform_to_mxnet_model(mod):
    """Add image pre-processing in front of the first convolution of ``mod["main"]``.

    A new float32 input ``data`` of shape (224, 224, 3) is created and
    transformed as follows:

    * expanded to a batch of one;
    * has ``[123, 117, 104]`` subtracted and is divided by
      ``[58.395, 57.12, 57.37]``, both broadcast along the last axis;
    * is transposed with axes (0, 3, 1, 2).

    The result replaces the data input of a rebuilt copy of the model's first
    ``nn.conv2d``, which keeps the original weights and attributes.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``main`` body, in A-normal form, starts with an
        ``nn.conv2d`` binding.

    Returns
    -------
    tvm.IRModule
        The same module object, with ``main`` replaced.

    Raises
    ------
    ValueError
        If the first binding of ``main`` is not an ``nn.conv2d`` call.
    """
    # Normalization constants, built directly as float32. This yields the
    # same values as a float64 constant followed by a cast, without the
    # extra Cast node.
    sub_data = relay.const(np.array([123., 117., 104.], dtype="float32"))
    divide_data = relay.const(np.array([58.395, 57.12, 57.37], dtype="float32"))

    data_shape = (224, 224, 3)
    data = relay.var("data", relay.TensorType(data_shape, "float32"))

    simple_net = relay.expand_dims(data, axis=0, num_newaxis=1)
    # To do, relay not support dynamic shape now, future need to add resize logic
    # simple_net = relay.image.resize(simple_net, (224, 224), "NHWC", "bilinear", "align_corners")
    simple_net = relay.subtract(simple_net, sub_data)
    simple_net = relay.divide(simple_net, divide_data)
    simple_net = relay.transpose(simple_net, (0, 3, 1, 2))

    # Merge the transform into the pretrained model network.
    entry = mod["main"]
    anf = run_opt_pass(entry.body, transform.ToANormalForm())
    call = anf.value if isinstance(anf, relay.expr.Let) else None
    if not (
        isinstance(call, relay.expr.Call)
        and isinstance(call.op, tvm.ir.Op)
        and call.op.name == "nn.conv2d"
    ):
        raise ValueError("expected the first operation of mod['main'] to be nn.conv2d")
    # Only the weights are reused. The original data argument is replaced by
    # the transformed input (and must not shadow the ``data`` var above).
    _, weights = call.args
    first_op = op.nn.conv2d(simple_net,
                            weights,
                            strides=call.attrs.strides,
                            padding=call.attrs.padding,
                            dilation=call.attrs.dilation,
                            groups=call.attrs.groups,
                            channels=call.attrs.channels,
                            kernel_size=call.attrs.kernel_size,
                            # The weights keep their original layout, so the
                            # rebuilt conv must interpret them the same way.
                            kernel_layout=call.attrs.kernel_layout,
                            out_dtype=call.attrs.out_dtype)
    net = relay.expr.Let(anf.var, first_op, anf.body)
    net = run_opt_pass(net, transform.ToGraphNormalForm())

    mod['main'] = net
    return mod
Exemple #6
0
 def _recursion(anf, pipeline_mods, split_conf, constant_expr):
     # Enumurate all operators of compute graph, then split the compute graph into a group of
     # subgraph.
     nonlocal operator_index_map
     nonlocal new_input_idx
     nonlocal snode_dep
     cur_node_dep = snode_dep[len(snode_dep) - 1]
     if isinstance(anf, tvm.relay.Function):
         return tvm.relay.Function(
             anf.params,
             _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
             anf.ret_type,
             anf.type_params,
             anf.attrs,
         )
     if isinstance(anf, tvm.relay.expr.Let):
         value = anf.value
         # record the constant expr to make sure all sugraphs can find correct constant.
         if isinstance(value, tvm.relay.expr.Constant):
             if not constant_expr:
                 constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
             else:
                 constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr)
         if isinstance(value, tvm.relay.expr.Call):
             new_args = []
             # build current var list
             cur_node_dep["nodes"][anf.var] = 0
             # Get the dependency information of the nodes.
             value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx)
             if isinstance(value.op, tvm.ir.Op):
                 if value.op.name in operator_index_map:
                     operator_index_map[value.op.name] += 1
                 else:
                     operator_index_map[value.op.name] = 0
                 split_operator_name = split_conf[0]["op_name"] if split_conf else ""
                 split_operator_index = split_conf[0]["op_index"] if split_conf else ""
                 # if a operator name and repeating count in the network match with the values
                 # of the 'split configuration', then this place is where we should do the
                 # graph splitting.
                 if (
                     split_conf
                     and split_operator_name in operator_index_map
                     and operator_index_map[split_operator_name] >= split_operator_index
                 ):
                     # Do graph splitting.
                     split_conf.pop(0)
                     snode_dep.append({"nodes": {}, "ref_nodes": {}})
                     ann = _recursion(
                         anf.body,
                         pipeline_mods,
                         split_conf,
                         constant_expr,
                     )
                     snode_dep.pop()
                     dep_vars = get_dep_var(snode_dep)
                     # When the nodes of the current subgraph are the depedency node of another
                     # subgraph, we need to set them as the output of current subgraph.
                     body = relay.Tuple(dep_vars) if len(dep_vars) > 1 else anf.var
                     # when the operator of current subgraph uses previous subgraph constant
                     # as the argument of a "relay.expr.call", such constant may become a free
                     # varaible if the constant does not exist in the current subgraph.
                     # merge the previous constant with current subgraph to avoid such issue.
                     if constant_expr:
                         ann = merge_constant_expr(constant_expr, ann)
                     ann = run_opt_pass(ann, transform.ToGraphNormalForm())
                     mod = tvm.IRModule.from_expr(ann)
                     pipeline_mods.insert(0, mod)
                     # Return the last node of the current subgraph.
                     return tvm.relay.expr.Let(anf.var, value, body)
         return tvm.relay.expr.Let(
             anf.var,
             value,
             _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
         )
     else:
         return anf
Exemple #7
0
def graph_split(expr, split_conf, params=None):
    """Split a relay function into a list of pipeline subgraph modules.

    Parameters
    ----------
    expr : relay.Function
        Function to split. Its body is returned via ``ann.body``, so a
        Function is expected.
    split_conf : list of dict
        Each entry has ``"op_name"`` and ``"op_index"``. A split happens at
        the first operator call at which the 0-based occurrence count of
        ``op_name`` is ``>= op_index``. Entries are consumed in order. A copy
        is used, so the caller's list is not modified.
    params : dict, optional
        If given, bound into ``expr`` with ``bind_params_by_name``.

    Returns
    -------
    list of tvm.IRModule
        The subgraphs in execution order (first subgraph first). Values
        crossing a split become new inputs named ``data_n_<k>``.
    """

    def get_dep_var(sub_var_dep):
        # Vars of the innermost subgraph that later subgraphs reference.
        return list(sub_var_dep[-1]["ref_nodes"])

    def parse_dependency(value, snode_dep, new_input_idx):
        """Replace call arguments produced by earlier subgraphs with new inputs."""
        new_args = []
        need_update = False
        for var in value.args:
            is_free_var = False
            for dep in snode_dep[:-1]:
                if var in dep["nodes"]:
                    # Mark the previous subgraph node as a dependency.
                    dep["nodes"][var] += 1
                    dep["ref_nodes"][var] = dep["nodes"][var]
                    # The var of this call is a free_var.
                    is_free_var = True
            # If the var of this call is a free_var, recreate it with a fixed input name.
            # NOTE(review): every such argument occurrence gets a fresh data_n_* input,
            # even if the same var is referenced more than once.
            if is_free_var:
                need_update = True
                new_args.append(relay.var(f"data_n_{new_input_idx}", var.checked_type))
                new_input_idx += 1
            else:
                new_args.append(var)
        # If the 'tvm.relay.expr.Call' has a free_var, recreate it with the 'data_n_*' args.
        if need_update:
            value = tvm.relay.expr.Call(
                value.op, new_args, value.attrs, value.type_args, value.span
            )
        return value, snode_dep, new_input_idx

    def merge_constant_expr(constant_expr, expr):
        # Splice `expr` in as the innermost body of the constant let-chain.
        if not isinstance(constant_expr.body, tvm.relay.expr.Let):
            return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, expr)

        return tvm.relay.expr.Let(
            constant_expr.var, constant_expr.value, merge_constant_expr(constant_expr.body, expr)
        )

    def _recursion(anf, pipeline_mods, split_conf, constant_expr):
        """Walk the A-normal-form let-chain and cut it at configured split points.

        Enumerates all operators of the compute graph and splits the graph into
        a group of subgraphs.

        * Returns the (possibly truncated) expression of the current subgraph.
        * Each later subgraph is converted to graph normal form, wrapped in an
          IRModule, and inserted at the front of ``pipeline_mods``.
        * ``constant_expr`` is a let-chain of the constants bound so far. It is
          prepended to later subgraphs so those constants remain reachable.
        """
        nonlocal operator_index_map
        nonlocal new_input_idx
        nonlocal snode_dep
        cur_node_dep = snode_dep[-1]
        if isinstance(anf, tvm.relay.Function):
            return tvm.relay.Function(
                anf.params,
                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
                anf.ret_type,
                anf.type_params,
                anf.attrs,
            )
        if not isinstance(anf, tvm.relay.expr.Let):
            return anf
        value = anf.value
        # Record constant bindings so that every subgraph can find the correct constant.
        if isinstance(value, tvm.relay.expr.Constant):
            if not constant_expr:
                constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
            else:
                constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr)
        if isinstance(value, tvm.relay.expr.Call):
            # Register this binding as a node of the current subgraph.
            cur_node_dep["nodes"][anf.var] = 0
            # Arguments produced by earlier subgraphs become new "data_n_*" inputs.
            value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx)
            if isinstance(value.op, tvm.ir.Op):
                if value.op.name in operator_index_map:
                    operator_index_map[value.op.name] += 1
                else:
                    operator_index_map[value.op.name] = 0
                split_operator_name = split_conf[0]["op_name"] if split_conf else ""
                split_operator_index = split_conf[0]["op_index"] if split_conf else ""
                # If an operator name and its repeat count in the network match the
                # values of the split configuration, this is where the graph is split.
                if (
                    split_conf
                    and split_operator_name in operator_index_map
                    and operator_index_map[split_operator_name] >= split_operator_index
                ):
                    # Do graph splitting.
                    split_conf.pop(0)
                    snode_dep.append({"nodes": {}, "ref_nodes": {}})
                    ann = _recursion(
                        anf.body,
                        pipeline_mods,
                        split_conf,
                        constant_expr,
                    )
                    snode_dep.pop()
                    dep_vars = get_dep_var(snode_dep)
                    # Nodes of the current subgraph that later subgraphs depend on
                    # become its outputs. A single dependency may be a var other
                    # than anf.var, so output that var instead of assuming anf.var.
                    if len(dep_vars) > 1:
                        body = relay.Tuple(dep_vars)
                    elif dep_vars:
                        body = dep_vars[0]
                    else:
                        body = anf.var
                    # A later subgraph may use a constant bound in an earlier one as
                    # a call argument, which would otherwise become a free variable.
                    # Prepend the recorded constants to avoid that.
                    if constant_expr:
                        ann = merge_constant_expr(constant_expr, ann)
                    ann = run_opt_pass(ann, transform.ToGraphNormalForm())
                    mod = tvm.IRModule.from_expr(ann)
                    pipeline_mods.insert(0, mod)
                    # Return the last node of the current subgraph.
                    return tvm.relay.expr.Let(anf.var, value, body)
        return tvm.relay.expr.Let(
            anf.var,
            value,
            _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
        )

    snode_dep = [{"nodes": {}, "ref_nodes": {}}]
    pipeline_mods = []
    operator_index_map = {}
    # Used to track new inputs caused by graph splitting.
    new_input_idx = 0
    constant_expr = None
    # Work on a copy: _recursion pops entries as split points are consumed.
    subgraph_split_conf = split_conf.copy()
    # Bind the parameters.
    if params:
        expr = build_module.bind_params_by_name(expr, params)
    anf = run_opt_pass(expr, transform.ToANormalForm())
    anf = run_opt_pass(anf, transform.InferType())
    ann = _recursion(
        anf,
        pipeline_mods,
        subgraph_split_conf,
        constant_expr,
    )
    # The first subgraph: free vars of its body become the module's inputs.
    ann = run_opt_pass(ann.body, transform.ToGraphNormalForm())
    mod = tvm.IRModule.from_expr(ann)
    pipeline_mods.insert(0, mod)
    return pipeline_mods