def annotate(func, compiler):
    """
    An annotator for Core ML.
    """
    # Bind free variables to the constant values.
    # Note: `params` is expected to be provided by the enclosing scope.
    bind_dict = {}
    for arg in func.params:
        name = arg.name_hint
        if name in params:
            bind_dict[arg] = relay.const(params[name])

    func = relay.bind(func, bind_dict)

    # Annotate the entire graph for Core ML
    mod = tvm.IRModule()
    mod["main"] = func

    seq = tvm.transform.Sequential(
        [
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
            transform.AnnotateTarget(compiler),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )

    with relay.build_config(opt_level=3):
        mod = seq(mod)

    return mod
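# A minimal usage sketch for `annotate` above; everything in it is an
# illustrative assumption, not part of the original snippet. `annotate` reads a
# `params` dict from its enclosing scope, so the sketch publishes one via
# `global`, and it assumes the Core ML BYOC codegen is registered under the
# target name "coremlcompiler".
def _example_annotate_for_coreml():
    import tvm.relay.testing

    global params
    mod, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
    # Partition the supported operators into "coremlcompiler" regions.
    partitioned = annotate(mod["main"], "coremlcompiler")
    print(partitioned)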
def partition():
    data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))

    bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
    bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
    bn_mmean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
    bn_mvar = relay.var("bn_var", relay.TensorType((16,), "float32"))

    bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)

    func = relay.Function(
        [data, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn_output.astuple()
    )
    mod = tvm.IRModule()
    mod["main"] = func

    op_list = ["nn.batch_norm", "nn.conv2d"]
    mod = WhiteListAnnotator(op_list, "test_compiler")(mod)

    opt_pass = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.PartitionGraph(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.AlterOpLayout(),
            transform.Inline(),
        ]
    )

    with relay.build_config(opt_level=3):
        mod = opt_pass(mod)

    return mod
def partition_for_dnnl(mod, params=None):
    """Partition the graph greedily offloading supported operators to DNNL.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    mod : Module
        Annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    seq = tvm.transform.Sequential(
        [
            transform.CanonicalizeOps(),
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
            # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
            transform.SimplifyExpr(),
            transform.FoldConstant(),
            transform.MergeComposite(pattern_table()),
            transform.AnnotateTarget("dnnl"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )

    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)

    return mod
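# A minimal usage sketch for `partition_for_dnnl` above; the workload and
# target below are illustrative assumptions. Running it requires a TVM build
# with the DNNL (oneDNN) BYOC codegen enabled.
def _example_partition_for_dnnl():
    import tvm.relay.testing

    mod, params = tvm.relay.testing.resnet.get_workload(num_layers=18, batch_size=1)
    mod = partition_for_dnnl(mod, params)
    # Offloaded regions appear as functions marked with Compiler="dnnl".
    print(mod)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")
    return lib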
def test_partial_constant():
    """Test the subgraph with (const, var, const, var) arguments."""
    if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    ishape = (10, 10)

    in_1 = relay.var("in_1", shape=ishape, dtype=dtype)
    in_2 = relay.var("in_2", shape=ishape, dtype=dtype)
    in_3 = relay.var("in_3", shape=ishape, dtype=dtype)
    in_4 = relay.var("in_4", shape=ishape, dtype=dtype)

    add1 = relay.add(in_1, in_2)
    add2 = relay.add(add1, in_3)
    add3 = relay.add(add2, in_3)
    add4 = relay.add(add3, in_3)
    func = relay.Function([in_1, in_2, in_3, in_4], add4)

    ref_mod = tvm.IRModule.from_expr(func)
    ref_mod = relay.transform.InferType()(ref_mod)

    data1 = np.random.uniform(0, 1, ishape).astype(dtype)
    data3 = np.random.uniform(0, 1, ishape).astype(dtype)

    params = {
        "in_1": tvm.nd.array(data1, device=tvm.cpu(0)),
        "in_3": tvm.nd.array(data3, device=tvm.cpu(0)),
    }
    ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)

    opt_pass = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
            transform.AnnotateTarget("dnnl"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        mod = opt_pass(ref_mod)

    data2 = np.random.uniform(0, 1, ishape).astype(dtype)
    data4 = np.random.uniform(0, 1, ishape).astype(dtype)

    check_result(mod, ref_mod, {"in_2": data2, "in_4": data4}, (10, 10), tol=1e-5)
def test_constant():
    """Test the subgraph with (var, const, ...) arguments."""
    if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    ishape = (1, 32, 14, 14)
    wshape = (32, 32, 3, 3)

    data = relay.var("data", shape=ishape, dtype=dtype)
    weight = relay.var("weight", shape=wshape, dtype=dtype)
    bn_gamma = relay.var("bn_gamma")
    bn_beta = relay.var("bn_beta")
    bn_mmean = relay.var("bn_mean")
    bn_mvar = relay.var("bn_var")

    layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3), padding=(1, 1))
    bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
    out = bn_output[0]
    out = relay.nn.relu(out)

    func = relay.Function(relay.analysis.free_vars(out), out)
    ref_mod, params = tvm.relay.testing.create_workload(func)
    ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)

    remove_bn_pass = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ]
    )

    dnnl_patterns = get_pattern_table("dnnl")
    composite_partition = tvm.transform.Sequential(
        [
            transform.MergeComposite(dnnl_patterns),
            transform.AnnotateTarget("dnnl"),
            transform.PartitionGraph(),
        ]
    )

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        ref_mod = remove_bn_pass(ref_mod)
        mod = composite_partition(ref_mod)

    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
    check_result(mod, ref_mod, {"data": i_data}, (1, 32, 14, 14), tol=1e-5)
def quantize_model(model, params, input_dtype, input_shape, qeval='power2'):
    skip_conv_layers = [0]
    with relay.quantize.qconfig(store_lowbit_output=False,
                                skip_conv_layers=skip_conv_layers):
        from tvm.relay.quantize.quantize import _bind_params
        graph = _bind_params(model['main'], params)

        mod = relay.Module.from_expr(graph)
        optimize = _transform.Sequential([_transform.SimplifyInference(),
                                          _transform.FoldConstant(),
                                          _transform.FoldScaleAxis(),
                                          _transform.CanonicalizeOps(),
                                          _transform.FoldConstant()])

        with relay.build_config(opt_level=4):
            mod = optimize(mod)
        mod = relay.quantize.annotate()(mod)

        # find scale
        cache_file = '%s_%s_scales.pkl' % (VIDEO_FILE, MODEL_NAME)
        if os.path.exists(cache_file):
            print("Using cached layer statistics...")
            with open(cache_file, 'rb') as f:
                scales = pickle.load(f)
        else:
            print("Compute layer statistics...")
            scales = calibrate_on_dataset(mod['main'], params, input_dtype, input_shape)
            with open(cache_file, 'wb') as f:
                pickle.dump(scales, f)

        if qeval == 'power2':
            scales = list(map(lambda scale: 2**np.math.ceil(np.math.log(scale, 2))
                              if scale > 0 else 1.0, scales))
            weight_scales = 'power2'
        elif qeval == 'max':
            weight_scales = 'max'
        else:
            raise ValueError("Invalid quantization eval: " + qeval)

        mod['main'] = relay.quantize.calibrate(mod['main'],
                                               weight_scales=weight_scales,
                                               scales=scales)
        mod = relay.quantize.realize()(mod)
        mod = relay.transform.FoldConstant()(mod)

    return mod
def simplify_model(mod):
    """
    Simplify the execution graph.

    At a minimum, merge BatchNorm into convolution: decompose the BN primitive
    into simple operations that can be evaluated as constant expressions and
    then folded into the nearest conv/dense primitive.
    """
    seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.FoldConstant(),
            transform.SimplifyInference(),
            transform.FoldScaleAxis(),
        ]
    )
    return seq(mod)
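# A small illustration of what `simplify_model` above achieves on a toy
# conv2d + batch_norm graph (the graph itself is an assumption for the demo):
# SimplifyInference lowers batch_norm to scale/shift arithmetic and
# FoldScaleAxis pushes the scaling into the convolution. FoldScaleAxis is an
# opt_level 3 pass, so the call is wrapped in a PassContext at that level.
def _example_simplify_model():
    data = relay.var("data", shape=(1, 8, 16, 16), dtype="float32")
    weight = relay.const(np.ones((8, 8, 3, 3), dtype="float32"))
    gamma = relay.const(np.ones((8,), dtype="float32"))
    beta = relay.const(np.zeros((8,), dtype="float32"))
    mean = relay.const(np.zeros((8,), dtype="float32"))
    var = relay.const(np.ones((8,), dtype="float32"))

    conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
    out = relay.nn.batch_norm(conv, gamma, beta, mean, var)[0]
    mod = tvm.IRModule.from_expr(relay.Function([data], out))

    with tvm.transform.PassContext(opt_level=3):
        simplified = simplify_model(mod)
    # The explicit batch_norm op is gone from the simplified module.
    print(simplified)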
def get_partitoned_mod(mod):
    remove_bn_pass = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ]
    )

    byoc_pass = tvm.transform.Sequential(
        [
            remove_bn_pass,
            transform.AnnotateTarget("dnnl"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        return byoc_pass(mod)
def get_partitoned_mod(mod, params, pattern_table):
    # This is required for constant folding
    mod["main"] = bind_params_by_name(mod["main"], params)

    remove_bn_pass = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ]
    )

    composite_partition = tvm.transform.Sequential(
        [
            remove_bn_pass,
            transform.MergeComposite(pattern_table),
            transform.AnnotateTarget("dnnl"),
            transform.PartitionGraph(),
        ]
    )

    with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
        return composite_partition(mod)
def partition():
    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
    weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
    bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
    bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
    bn_mmean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
    bn_mvar = relay.var("bn_var", relay.TensorType((16,), "float32"))

    conv = relay.nn.conv2d(
        data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
    )
    bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean, bn_mvar)

    func = relay.Function(
        [data, weight, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn_output.astuple()
    )
    mod = tvm.IRModule()
    mod["main"] = func

    op_list = ["nn.batch_norm", "nn.conv2d"]
    mod = WhiteListAnnotator(op_list, "test_compiler")(mod)

    opt_pass = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.PartitionGraph(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.AlterOpLayout(),
        ]
    )

    with tvm.transform.PassContext(opt_level=3):
        mod = opt_pass(mod)

    return mod
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs."""
    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)

    remove_bn_pass = Sequential(
        [
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ]
    )

    with PassContext(opt_level=3):
        mod = remove_bn_pass(mod)

    cutlass_patterns = relay.op.contrib.get_pattern_table("cutlass")

    seq = Sequential(
        [
            transform.InferType(),
            transform.MergeComposite(cutlass_patterns),
            transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
            transform.PartitionGraph(bind_constants=False),
        ]
    )

    return seq(mod)
def test_alter_layout_scalar_regression():
    """regression test where scalar fails"""

    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var("weight", shape=(3, 3, 64, 16))
        bias = relay.var("bias", shape=(1, 1, 1, 16))
        y = relay.nn.conv2d(
            x,
            weight,
            channels=16,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
        )
        y = relay.add(y, bias)
        mean = relay.mean(y, axis=3, exclude=True)
        var = relay.variance(y, axis=3, exclude=True)
        gamma = relay.var("gamma")
        beta = relay.var("beta")
        y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3)
        y = y[0]
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def alter_conv2d(attrs, inputs, tinfos, out_type):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW16c"
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var("weight", shape=(3, 3, 64, 16))
        bias = relay.var("bias", shape=(1, 1, 1, 16))
        x = relay.layout_transform(x, src_layout="NHWC", dst_layout="NCHW")
        x = relay.layout_transform(x, src_layout="NCHW", dst_layout="NCHW16c")
        weight = relay.layout_transform(weight, src_layout="HWIO", dst_layout="OIHW")
        y = relay.nn.conv2d(
            x, weight, channels=16, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
        )
        bias = relay.layout_transform(bias, src_layout="NHWC", dst_layout="NCHW")
        bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
        add = relay.add(y, bias)
        y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
        y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC")
        mean = relay.mean(y, axis=3, exclude=True)
        var = relay.variance(y, axis=3, exclude=True)
        denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
        gamma = relay.var("gamma", shape=(16,))
        denom = denom * gamma
        denom_expand1 = relay.expand_dims(denom, axis=1, num_newaxis=2)
        denom_expand2 = relay.expand_dims(denom_expand1, axis=0)
        denom_nchwc16 = relay.layout_transform(
            denom_expand2, src_layout="NCHW", dst_layout="NCHW16c"
        )
        out = add * denom_nchwc16
        beta = relay.var("beta", shape=(16,))
        numerator = (-mean) * denom + beta
        numerator_expand1 = relay.expand_dims(numerator, axis=1, num_newaxis=2)
        numerator_expand2 = relay.expand_dims(numerator_expand1, axis=0)
        numerator_nchwc16 = relay.layout_transform(
            numerator_expand2, src_layout="NCHW", dst_layout="NCHW16c"
        )
        out = out + numerator_nchwc16
        out = relay.layout_transform(out, src_layout="NCHW16c", dst_layout="NCHW")
        y = relay.layout_transform(out, src_layout="NCHW", dst_layout="NHWC")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        a = before()
        desired_layouts = {"nn.conv2d": ["NCHW", "default"], "nn.batch_norm": ["NHWC", "default"]}
        a = run_opt_pass(
            a,
            [
                transform.InferType(),
                relay.transform.ConvertLayout(desired_layouts),
                transform.SimplifyInference(),
                transform.CanonicalizeOps(),
                transform.AlterOpLayout(),
            ],
        )
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def add_functions(mod, funcs):
    """Workaround for type checker and mutually recursive functions."""
    for gv in funcs:
        func = funcs[gv]
        body = _placeholder_body(func.ret_type)
        mod[gv] = relay.Function(func.params, body, func.ret_type)

    for gv in funcs:
        mod[gv] = funcs[gv]


pass_set = transform.Sequential(
    passes=[
        transform.SimplifyInference(),
        transform.CanonicalizeOps(),
        transform.CanonicalizeCast(),
        transform.FuseOps(3),
        # transform.CombineParallelConv2d(),
        transform.AlterOpLayout(),
        # transform.RewriteAnnotatedOps(???),
    ],
    opt_level=0,
)


def optimize(mod):
    """Optimize all the functions in a module.

    Modules are the only mutable piece of Relay. We write an
    optimization pass over the module which destructively updates each
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs."""
    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
    dense_bias_gelu_fp16_pat = (
        "cutlass.dense_bias_gelu_fp16",
        make_gemm_pattern(True, "gelu"),
        check_gemm,
    )
    dense_bias_gelu_fp32_pat = (
        "cutlass.dense_bias_gelu_fp32",
        make_gemm_pattern(True, "gelu", out_dtype="float32"),
        check_gemm,
    )

    cutlass_patterns = [
        dense_bias_gelu_fp16_pat,
        dense_bias_gelu_fp32_pat,
        dense_bias_relu_pat,
        dense_bias_pat,
        dense_pat,
        ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
        (
            "cutlass.conv2d_bias_hardswish",
            make_conv2d_pattern(with_bias=True, with_act="hardswish"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_silu",
            make_conv2d_pattern(with_bias=True, with_act="silu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_relu",
            make_conv2d_pattern(with_bias=True, with_act="relu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_sigmoid",
            make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
            check_conv2d,
        ),
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d),
        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
    ]

    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)

    remove_bn_pass = Sequential(
        [
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ]
    )

    with PassContext(opt_level=3):
        mod = remove_bn_pass(mod)

    seq = Sequential(
        [
            transform.InferType(),
            transform.MergeComposite(cutlass_patterns),
            transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
            transform.PartitionGraph(bind_constants=False),
        ]
    )

    return seq(mod)
def partition_for_dnnl(mod, params=None, alter_layout=True):
    """Partition the graph greedily offloading supported operators to DNNL.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.
    alter_layout : bool
        Whether to run the layout-altering passes so DNNL-preferred layouts are used.

    Returns
    -------
    mod : Module
        Annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    with TempOpAttr("nn.conv2d", "FTVMLegalize", dnnl.legalize_group_conv):
        with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", dnnl.legalize_group_conv):
            seq = tvm.transform.Sequential(
                [
                    transform.CanonicalizeOps(),
                    transform.InferType(),
                    transform.SimplifyInference(),
                    transform.FoldConstant(),
                    transform.FoldScaleAxis(),
                    # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
                    transform.SimplifyExpr(),
                    transform.FoldConstant(),
                    # alter group conv /conv_transpose layout to `GOIHW` / `GIOHW`
                    transform.Legalize(),
                    transform.FoldConstant(),
                ]
            )
            with tvm.transform.PassContext(opt_level=3):
                mod = seq(mod)

    if alter_layout:
        with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", dnnl.alter_conv):
            with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", dnnl.alter_conv):
                with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", dnnl.alter_conv):
                    with TempOpAttr(
                        "nn.conv2d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose
                    ):
                        with TempOpAttr(
                            "nn.conv3d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose
                        ):
                            alter_layout_seq = tvm.transform.Sequential(
                                [
                                    transform.AlterOpLayout(),
                                    transform.FoldConstant(),
                                ]
                            )
                            with tvm.transform.PassContext(opt_level=3):
                                mod = alter_layout_seq(mod)

    byoc_seq = tvm.transform.Sequential(
        [
            transform.MergeComposite(dnnl.pattern_table()),
            transform.AnnotateTarget("dnnl"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )

    with tvm.transform.PassContext(opt_level=3):
        mod = byoc_seq(mod)

    return mod
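# A minimal usage sketch for this `partition_for_dnnl` variant; the workload is
# an illustrative assumption. The `alter_layout` flag controls whether the
# AlterOpLayout passes run under the temporarily registered DNNL hooks, letting
# DNNL pick its preferred (blocked) layouts; disabling it keeps the original
# layouts, which can be handy when debugging pattern coverage.
def _example_partition_for_dnnl_no_alter_layout():
    import tvm.relay.testing

    mod, params = tvm.relay.testing.resnet.get_workload(num_layers=18, batch_size=1)
    mod = partition_for_dnnl(mod, params, alter_layout=False)
    print(mod)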
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs."""
    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
    dense_bias_gelu_fp16_pat = (
        "cutlass.dense_bias_gelu_fp16",
        make_gemm_pattern(True, "gelu"),
        check_gemm,
    )
    dense_bias_gelu_fp32_pat = (
        "cutlass.dense_bias_gelu_fp32",
        make_gemm_pattern(True, "gelu", out_dtype="float32"),
        check_gemm,
    )

    dense_patterns = [
        dense_bias_gelu_fp16_pat,
        dense_bias_gelu_fp32_pat,
        dense_bias_relu_pat,
        dense_bias_pat,
        dense_pat,
        ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
    ]

    conv2d_patterns = [
        (
            "cutlass.conv2d_bias_hardswish",
            make_conv2d_pattern(with_bias=True, with_act="hardswish"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_silu",
            make_conv2d_pattern(with_bias=True, with_act="silu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_relu",
            make_conv2d_pattern(with_bias=True, with_act="relu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_sigmoid",
            make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
            check_conv2d,
        ),
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d),
        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
    ]

    residual_block_patterns = []
    for with_act, postfix in [("relu", "_relu"), (None, "")]:
        for name, pat, _ in conv2d_patterns[:-1]:
            for bin_op in ["add", "multiply"]:
                residual_block_patterns.append(
                    (
                        name + "_residual_" + bin_op + postfix,
                        make_residual_block_pattern(pat, bin_op, with_act=with_act),
                        partial(check_conv2d_residual, binary_op=bin_op),
                    )
                )

    cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns

    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)

    remove_bn_pass = Sequential(
        [
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ]
    )

    with PassContext(opt_level=3):
        mod = remove_bn_pass(mod)

    seq = Sequential(
        [
            transform.InferType(),
            transform.MergeComposite(cutlass_patterns),
            transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
            transform.PartitionGraph(bind_constants=False),
        ]
    )

    return seq(mod)
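# A minimal usage sketch for `partition_for_cutlass` above; the toy fp16 dense
# graph is an assumption for illustration, and whether it is actually offloaded
# depends on the pattern checkers (check_gemm etc.). Offloaded subgraphs show
# up as functions tagged Compiler="cutlass", with inner composite functions
# named after the matched pattern (e.g. "cutlass.dense_bias_relu"); constants
# stay unbound because of PartitionGraph(bind_constants=False).
def _example_partition_for_cutlass():
    data = relay.var("data", shape=(16, 64), dtype="float16")
    weight = relay.var("weight", shape=(32, 64), dtype="float16")
    bias = relay.var("bias", shape=(32,), dtype="float16")
    out = relay.nn.relu(relay.nn.bias_add(relay.nn.dense(data, weight), bias))
    mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))

    params = {
        "weight": tvm.nd.array(np.random.uniform(-1, 1, (32, 64)).astype("float16")),
        "bias": tvm.nd.array(np.random.uniform(-1, 1, (32,)).astype("float16")),
    }
    print(partition_for_cutlass(mod, params))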