def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta):
    """Annotate a region of ``expr`` with bitpack_start / bitpack_end markers.

    ``expr`` is converted to A-normal form and its Let bindings are walked in
    order while a running operator index is advanced by ``_operator_idx_inc``.

    * The first call to ``start_name`` whose index equals ``start_name_idx``
      (any index when ``start_name_idx`` is None) is wrapped in
      ``annotation.bitpack_start``.
    * When a call to ``stop_name`` whose index equals ``stop_name_idx`` (any
      index when None) is reached, the binding immediately *before* it is
      wrapped in ``annotation.bitpack_end``. The stop call and everything
      after it are kept unchanged.

    The annotated expression is converted back to graph normal form.

    We assume stop_name only appears once for simplicity.
    This constraint will be lifted in the future.
    bitpack_start and bitpack_end are both inclusive.

    Parameters
    ----------
    expr : relay expression to annotate.
    start_name, stop_name : operator names that delimit the region.
    start_name_idx, stop_name_idx : operator index to match, or None to match
        any occurrence.
    count_meta : forwarded to ``_operator_idx_inc`` when advancing the index.

    Raises
    ------
    AssertionError
        If the stop operator is matched before the start operator, or if the
        end of the binding chain is reached without matching the stop operator.
    BT
        Escapes uncaught if the matched stop operator is the very first
        binding (there is no earlier frame to catch it).
    """
    bitpack_start = op.op.get("annotation.bitpack_start")
    bitpack_end = op.op.get("annotation.bitpack_end")
    anf = run_opt_pass(expr, transform.ToANormalForm())
    operator_current_idx = 0

    def _recursion(anf, start_found, stop_found, operator_current_idx):
        """Helper to obtain the subgraph."""
        # A Function is rebuilt around its recursively annotated body.
        if isinstance(anf, relay.Function):
            return relay.Function(
                anf.params,
                _recursion(anf.body, start_found, stop_found, operator_current_idx),
                anf.ret_type,
                anf.type_params,
                anf.attrs,
            )
        if isinstance(anf, relay.expr.Let):
            value = anf.value
            if isinstance(value, relay.expr.Call):
                if isinstance(value.op, tvm.ir.Op):
                    if value.op.name == start_name and not start_found:
                        if operator_current_idx == start_name_idx or start_name_idx is None:
                            value = relay.expr.Call(bitpack_start, [value])
                            start_found = True
                    elif value.op.name == stop_name:
                        if operator_current_idx == stop_name_idx or stop_name_idx is None:
                            # Unwind to the caller's frame (the previous binding),
                            # which catches BT below and appends bitpack_end.
                            raise BT()
            # The index is advanced for every binding, not only for Op calls.
            operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx)
            try:
                return relay.expr.Let(
                    anf.var,
                    value,
                    _recursion(anf.body, start_found, stop_found, operator_current_idx),
                )
            except BT:
                # The next binding is the stop operator: close the region here.
                assert start_found
                assert not stop_found
                stop_found = True
                value = relay.expr.Call(bitpack_end, [value])
                # todo: check anf.body has no more stop_name beside that one
                # The untouched remainder (starting at the stop call) is reattached.
                return relay.expr.Let(anf.var, value, anf.body)
        else:
            # Tail of the binding chain. stop_found is only set inside the
            # except handler above, which returns without recursing, so reaching
            # here means the stop operator was never matched and this fails.
            assert start_found
            assert stop_found
            return anf
    annotated = _recursion(anf, False, False, operator_current_idx)
    return run_opt_pass(annotated, transform.ToGraphNormalForm())
def test_implicit_share():
    """ToGraphNormalForm must keep implicitly shared let-bound values correct.

    Builds ``let x = 1; let y = x + x; let z = y + y; z + z``. Here ``y`` and
    ``z`` are each referenced twice. The test converts the program to graph
    normal form and checks that:

    * the result no longer contains Let nodes;
    * both versions still evaluate to 8.0.
    """
    x = relay.Var("x")
    y = relay.Var("y")
    z = relay.Var("z")
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    g = run_opt_pass(f, transform.ToGraphNormalForm())
    assert Feature.fLet in detect_feature(f)
    assert Feature.fLet not in detect_feature(g)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
def test_round_trip():
    """Round-trip a let-bound program through graph form and back to ANF.

    Builds ``let x = 1; let y = x + x; let z = y + y; z + z`` (value 8). It
    converts the program to graph normal form and then back to A-normal form.
    It checks that the graph form has no Let nodes and that all three versions
    evaluate to 8.0.
    """
    x = relay.Var("x")
    y = relay.Var("y")
    z = relay.Var("z")
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    g = run_opt_pass(f, transform.ToGraphNormalForm())
    h = run_opt_pass(g, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(f)
    assert Feature.fLet not in detect_feature(g)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
    check_eval(h, [], 8.0)
def test_round_trip():
    """Round-trip a let-bound program through graph form and back to ANF.

    Builds ``let x = 1; let y = x + x; let z = y + y; z + z`` (value 8). It
    converts the program to graph normal form and then back to A-normal form.
    It checks that the graph form has no Let nodes and that all three versions
    evaluate to 8.0.
    """
    # NOTE(review): ``test_round_trip`` is defined more than once in this file;
    # only the last definition is visible to pytest. Remove or rename the
    # duplicate.
    x = relay.Var("x")
    y = relay.Var("y")
    z = relay.Var("z")
    body = relay.Let(z, op.add(y, y), op.add(z, z))
    body = relay.Let(y, op.add(x, x), body)
    f = relay.Function([], relay.Let(x, relay.const(1), body))
    # Apply passes through run_opt_pass, the helper the rest of this file uses,
    # rather than transform.OptimizeOnExpr.
    g = run_opt_pass(f, transform.ToGraphNormalForm())
    h = run_opt_pass(g, transform.ToANormalForm())
    assert Feature.fLet in detect_feature(f)
    assert Feature.fLet not in detect_feature(g)
    check_eval(f, [], 8.0)
    check_eval(g, [], 8.0)
    check_eval(h, [], 8.0)
def merge_transform_to_mxnet_model(
    mod,
    mean=(123.0, 117.0, 104.0),
    std=(58.395, 57.12, 57.37),
    data_shape=(224, 224, 3),
):
    """Add image transform logic into the model.

    Builds a pre-processing chain on a new float32 ``data`` input of shape
    ``data_shape``::

        expand_dims(axis=0) -> subtract(mean) -> divide(std) -> transpose(0, 3, 1, 2)

    It then splices the chain in front of the network: the first binding of
    ``mod["main"]`` (in A-normal form) is rebuilt as an ``nn.conv2d`` on the
    pre-processed tensor, reusing the original weights and conv attributes.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``main`` body is assumed to start with a two-argument
        ``nn.conv2d`` (input, weight). This is not checked here.
    mean, std : sequence of float
        Per-channel values to subtract and divide by. The defaults are the
        values this function previously hard-coded.
    data_shape : tuple of int
        Shape of the new input before the leading batch axis is added.

    Returns
    -------
    tvm.IRModule
        ``mod`` itself; ``mod["main"]`` is replaced in place.
    """
    # Build float32 constants directly instead of a float64 constant followed by
    # an astype (cast) node in the graph; the resulting values are the same.
    sub_data = relay.const(np.array(mean, dtype="float32"))
    divide_data = relay.const(np.array(std, dtype="float32"))

    data = relay.var("data", relay.TensorType(data_shape, "float32"))
    simple_net = relay.expand_dims(data, axis=0, num_newaxis=1)
    # TODO: relay does not support dynamic shape here yet; add a resize step
    # (relay.image.resize to (224, 224), "NHWC", "bilinear") once it does.
    simple_net = relay.subtract(simple_net, sub_data)
    simple_net = relay.divide(simple_net, divide_data)
    # NHWC -> NCHW
    simple_net = relay.transpose(simple_net, (0, 3, 1, 2))

    # Merge the transform into the pretrained network entry.
    entry = mod["main"]
    anf = run_opt_pass(entry.body, transform.ToANormalForm())
    call = anf.value
    # Only the weights are reused; the original input is replaced by the
    # pre-processed tensor. Unpacking it into ``_`` (instead of ``data``) keeps
    # the new input variable above from being shadowed.
    _, weights = call.args
    attrs = call.attrs
    # NOTE(review): data_layout / kernel_layout are not forwarded, so conv2d's
    # defaults apply; this assumes the original conv2d used them as well.
    first_op = op.nn.conv2d(
        simple_net,
        weights,
        strides=attrs.strides,
        padding=attrs.padding,
        dilation=attrs.dilation,
        groups=attrs.groups,
        channels=attrs.channels,
        kernel_size=attrs.kernel_size,
        out_dtype=attrs.out_dtype,
    )
    net = relay.expr.Let(anf.var, first_op, anf.body)
    net = run_opt_pass(net, transform.ToGraphNormalForm())
    # NOTE(review): ``net`` is the rewritten body expression, not a
    # relay.Function; this relies on IRModule accepting it for "main".
    mod["main"] = net
    return mod
def _recursion(anf, pipeline_mods, split_conf, constant_expr):
    # Enumerate all operators of the compute graph (walking its A-normal-form
    # Let chain), then split the compute graph into a group of subgraphs
    # according to ``split_conf``.
    #
    # NOTE(review): this duplicates the ``_recursion`` helper nested inside
    # ``graph_split``. The ``nonlocal`` declarations below need an enclosing
    # function that binds operator_index_map, new_input_idx and snode_dep; at
    # module level they are a SyntaxError. The helpers parse_dependency,
    # get_dep_var and merge_constant_expr must also come from that scope.
    #
    # Parameters (as used below):
    #   anf           -- current node: a relay Function, a Let binding, or the
    #                    tail expression of the chain.
    #   pipeline_mods -- list that receives one IRModule per split-off later
    #                    subgraph (inserted at index 0).
    #   split_conf    -- list of {"op_name": ..., "op_index": ...} dicts; the
    #                    head entry is popped at each split.
    #   constant_expr -- Let chain of the constants seen so far, or falsy.
    # Returns the rebuilt expression for the current subgraph.
    nonlocal operator_index_map
    nonlocal new_input_idx
    nonlocal snode_dep
    # Dependency record ({"nodes", "ref_nodes"}) of the subgraph being built.
    cur_node_dep = snode_dep[len(snode_dep) - 1]
    if isinstance(anf, tvm.relay.Function):
        return tvm.relay.Function(
            anf.params,
            _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
            anf.ret_type,
            anf.type_params,
            anf.attrs,
        )
    if isinstance(anf, tvm.relay.expr.Let):
        value = anf.value
        # Record the constant expr to make sure all subgraphs can find the correct constant.
        if isinstance(value, tvm.relay.expr.Constant):
            if not constant_expr:
                constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
            else:
                constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr)
        if isinstance(value, tvm.relay.expr.Call):
            # NOTE(review): new_args is never used in this function.
            new_args = []
            # Build the current var list.
            cur_node_dep["nodes"][anf.var] = 0
            # Get the dependency information of the nodes (arguments produced by
            # earlier subgraphs are replaced by fresh data_n_* inputs).
            value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx)
            if isinstance(value.op, tvm.ir.Op):
                # 0-based occurrence count per operator name.
                if value.op.name in operator_index_map:
                    operator_index_map[value.op.name] += 1
                else:
                    operator_index_map[value.op.name] = 0
                split_operator_name = split_conf[0]["op_name"] if split_conf else ""
                split_operator_index = split_conf[0]["op_index"] if split_conf else ""
                # If an operator name and its repeat count in the network match the
                # values of the 'split configuration', this is where the graph is
                # split. The check runs for every Op call, so the cut lands on the
                # first binding at which the count condition holds.
                if (
                    split_conf
                    and split_operator_name in operator_index_map
                    and operator_index_map[split_operator_name] >= split_operator_index
                ):
                    # Do graph splitting.
                    split_conf.pop(0)
                    # Fresh dependency record for the subgraph that starts after this binding.
                    snode_dep.append({"nodes": {}, "ref_nodes": {}})
                    ann = _recursion(
                        anf.body,
                        pipeline_mods,
                        split_conf,
                        constant_expr,
                    )
                    snode_dep.pop()
                    dep_vars = get_dep_var(snode_dep)
                    # When nodes of the current subgraph are dependency nodes of another
                    # subgraph, they are set as the outputs of the current subgraph.
                    # NOTE(review): with exactly one dependency var, anf.var is
                    # returned even when that var is a different binding -- confirm.
                    body = relay.Tuple(dep_vars) if len(dep_vars) > 1 else anf.var
                    # When an operator of the current subgraph uses a constant of a
                    # previous subgraph as an argument of a "relay.expr.call", that
                    # constant may become a free variable if it does not exist in the
                    # current subgraph. Merge the previous constants into the subgraph
                    # to avoid this.
                    if constant_expr:
                        ann = merge_constant_expr(constant_expr, ann)
                    ann = run_opt_pass(ann, transform.ToGraphNormalForm())
                    mod = tvm.IRModule.from_expr(ann)
                    pipeline_mods.insert(0, mod)
                    # Return the last node of the current subgraph.
                    return tvm.relay.expr.Let(anf.var, value, body)
        return tvm.relay.expr.Let(
            anf.var,
            value,
            _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
        )
    else:
        # Tail expression of the binding chain.
        return anf
def graph_split(expr, split_conf, params=None):
    """Splitting the graph into a list of subgraphs.

    ``expr`` (optionally with ``params`` bound by name) is converted to
    A-normal form, type-inferred, and its Let bindings are walked in order.
    Each entry of ``split_conf`` is a dict with ``"op_name"`` and
    ``"op_index"`` keys and is consumed in order. The graph is cut after the
    first binding at which the 0-based occurrence count of ``op_name`` is
    ``>= op_index``.

    Values produced in one subgraph and used by a later one become outputs of
    the producing subgraph. In the consuming subgraph each such use is
    replaced by a fresh input variable named ``data_n_<k>``. Constants bound
    before a split are re-bound in the later subgraph so they do not become
    free variables.

    Parameters
    ----------
    expr : relay.Function
        Graph to split. Its body is used after the walk, so a Function is
        expected.
    split_conf : list of dict
        Split points. The caller's list is not modified; a shallow copy is
        consumed.
    params : dict, optional
        Parameters bound into ``expr`` by name before splitting.

    Returns
    -------
    list of tvm.IRModule
        One module per subgraph, in execution order.
    """

    def get_dep_var(sub_var_dep):
        # Vars of the innermost (current) subgraph that later subgraphs
        # reference, in order of first reference (dict insertion order).
        return list(sub_var_dep[-1]["ref_nodes"])

    def parse_dependency(value, snode_dep, new_input_idx):
        # Replace every argument of the call ``value`` that was bound in an
        # earlier subgraph (any snode_dep entry except the last) with a fresh
        # ``data_n_<idx>`` input, and record the reference on the producer.
        # NOTE: each such use gets its own new input, even when the same
        # producer var is used more than once.
        new_args = []
        need_update = False
        for var in value.args:
            is_free_var = False
            for dep in snode_dep[:-1]:
                if var in dep["nodes"]:
                    # Mark the previous subgraph node as a dependency.
                    dep["nodes"][var] += 1
                    dep["ref_nodes"][var] = dep["nodes"][var]
                    # The var of this call is a free var in the current subgraph.
                    is_free_var = True
            if is_free_var:
                need_update = True
                new_args.append(relay.var(f"data_n_{new_input_idx}", var.checked_type))
                new_input_idx += 1
            else:
                new_args.append(var)
        # Only rebuild the call when at least one argument was replaced.
        if need_update:
            value = tvm.relay.expr.Call(
                value.op, new_args, value.attrs, value.type_args, value.span
            )
        return value, snode_dep, new_input_idx

    def merge_constant_expr(constant_expr, expr):
        # ``constant_expr`` is a chain of Lets binding constants; rebuild it
        # with its innermost body replaced by ``expr`` so that the constants
        # are in scope for ``expr``.
        if not isinstance(constant_expr.body, tvm.relay.expr.Let):
            return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, expr)
        return tvm.relay.expr.Let(
            constant_expr.var, constant_expr.value, merge_constant_expr(constant_expr.body, expr)
        )

    def _recursion(anf, pipeline_mods, split_conf, constant_expr):
        # Walk the A-normal-form binding chain. At each split point the rest of
        # the chain becomes a separate module (inserted at the front of
        # pipeline_mods), and the current subgraph is closed with its outputs.
        nonlocal operator_index_map
        nonlocal new_input_idx
        nonlocal snode_dep
        # Dependency record ({"nodes", "ref_nodes"}) of the subgraph being built.
        cur_node_dep = snode_dep[-1]
        if isinstance(anf, tvm.relay.Function):
            return tvm.relay.Function(
                anf.params,
                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
                anf.ret_type,
                anf.type_params,
                anf.attrs,
            )
        if not isinstance(anf, tvm.relay.expr.Let):
            # Tail expression of the binding chain.
            return anf

        value = anf.value
        # Record constants so that every later subgraph can re-bind them.
        if isinstance(value, tvm.relay.expr.Constant):
            if constant_expr is None:
                constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
            else:
                constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr)

        if isinstance(value, tvm.relay.expr.Call):
            # Register this binding as a node of the current subgraph.
            cur_node_dep["nodes"][anf.var] = 0
            value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx)
            if isinstance(value.op, tvm.ir.Op):
                op_name = value.op.name
                # 0-based occurrence count per operator name.
                operator_index_map[op_name] = operator_index_map.get(op_name, -1) + 1
                if split_conf:
                    split_operator_name = split_conf[0]["op_name"]
                    split_operator_index = split_conf[0]["op_index"]
                    # The check runs for every Op call, so the cut lands on the
                    # first binding at which the count condition holds.
                    if (
                        split_operator_name in operator_index_map
                        and operator_index_map[split_operator_name] >= split_operator_index
                    ):
                        split_conf.pop(0)
                        # Fresh dependency record for the subgraph after this binding.
                        snode_dep.append({"nodes": {}, "ref_nodes": {}})
                        ann = _recursion(anf.body, pipeline_mods, split_conf, constant_expr)
                        snode_dep.pop()
                        # Vars of this subgraph used by later subgraphs become its outputs.
                        dep_vars = get_dep_var(snode_dep)
                        if len(dep_vars) > 1:
                            body = relay.Tuple(dep_vars)
                        elif dep_vars:
                            # Output the referenced var itself. It may be an earlier
                            # binding than anf.var, and only it matches the data_n_*
                            # input that replaced it in the later subgraph.
                            body = dep_vars[0]
                        else:
                            body = anf.var
                        # Constants bound before the split would be free variables
                        # in the later subgraph; re-bind them there.
                        if constant_expr is not None:
                            ann = merge_constant_expr(constant_expr, ann)
                        ann = run_opt_pass(ann, transform.ToGraphNormalForm())
                        pipeline_mods.insert(0, tvm.IRModule.from_expr(ann))
                        # Close the current subgraph with its outputs.
                        return tvm.relay.expr.Let(anf.var, value, body)

        return tvm.relay.expr.Let(
            anf.var,
            value,
            _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
        )

    # One dependency record per open subgraph: "nodes" maps each bound var to
    # its reference count; "ref_nodes" holds those referenced by later subgraphs.
    snode_dep = [{"nodes": {}, "ref_nodes": {}}]
    pipeline_mods = []
    operator_index_map = {}
    # Tracks the new inputs created by graph splitting.
    new_input_idx = 0
    constant_expr = None
    # Work on a copy so the caller's split_conf list is not consumed.
    subgraph_split_conf = split_conf.copy()
    # Bind the parameters.
    if params:
        expr = build_module.bind_params_by_name(expr, params)
    anf = run_opt_pass(expr, transform.ToANormalForm())
    # Types are needed for var.checked_type when creating the data_n_* inputs.
    anf = run_opt_pass(anf, transform.InferType())
    ann = _recursion(anf, pipeline_mods, subgraph_split_conf, constant_expr)
    # The first subgraph is what remains of the original function body.
    ann = run_opt_pass(ann.body, transform.ToGraphNormalForm())
    pipeline_mods.insert(0, tvm.IRModule.from_expr(ann))
    return pipeline_mods