def annotate(func, compiler):
    """
    Bind known constants into ``func`` and partition it for ``compiler``.

    Originally documented as an annotator for Core ML; the target is whatever
    ``compiler`` is handed to ``AnnotateTarget``.

    Every parameter of ``func`` whose ``name_hint`` is a key of ``params`` is
    replaced by ``relay.const(params[name_hint])``. Note that ``params`` is
    not an argument of this function: it is resolved from the enclosing
    scope at call time.

    The bound function becomes ``"main"`` of a fresh ``tvm.IRModule`` which
    is run through SimplifyInference, FoldConstant, FoldScaleAxis,
    AnnotateTarget, MergeCompilerRegions and PartitionGraph under
    ``relay.build_config(opt_level=3)``. The resulting module is returned.
    """
    # Free variables with a known value become constants.
    constants = {
        var: relay.const(params[var.name_hint])
        for var in func.params
        if var.name_hint in params
    }
    bound_func = relay.bind(func, constants)

    # Annotate and partition the whole graph for the requested compiler.
    module = tvm.IRModule()
    module["main"] = bound_func
    passes = [
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
        transform.AnnotateTarget(compiler),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    pipeline = tvm.transform.Sequential(passes)
    with relay.build_config(opt_level=3):
        module = pipeline(module)
    return module
def partition_for_dnnl(mod, params=None):
    """Bind parameters, simplify the graph and offload supported ops to DNNL.

    Parameters
    ----------
    mod : Module
        The module to run passes on. When ``params`` is truthy its ``"main"``
        function is replaced by the parameter-bound version.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    mod : Module
        Annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    graph_cleanup = [
        transform.CanonicalizeOps(),
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
        # SimplifyExpr folds consecutive add ops, which simplifies the
        # `conv2d-bias_add-bn-relu` pattern; fold constants again afterwards.
        transform.SimplifyExpr(),
        transform.FoldConstant(),
    ]
    dnnl_offload = [
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    pipeline = tvm.transform.Sequential(graph_cleanup + dnnl_offload)

    with tvm.transform.PassContext(opt_level=3):
        partitioned = pipeline(mod)
    return partitioned
def test_partial_constant():
    """Test the subgraph with (const, var, const, var) arguments."""
    if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    ishape = (10, 10)

    def random_input():
        # Uniform samples in [0, 1) with the test shape and dtype.
        return np.random.uniform(0, 1, ishape).astype(dtype)

    in_1, in_2, in_3, in_4 = [
        relay.var(name, shape=ishape, dtype=dtype)
        for name in ("in_1", "in_2", "in_3", "in_4")
    ]

    # in_1 + in_2, then in_3 added three times. in_4 is a parameter of the
    # function but does not appear in its body.
    total = relay.add(in_1, in_2)
    for _ in range(3):
        total = relay.add(total, in_3)

    func = relay.Function([in_1, in_2, in_3, in_4], total)
    ref_mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))

    # in_1 and in_3 are bound as constants; in_2 and in_4 stay runtime inputs.
    data1 = random_input()
    data3 = random_input()
    params = {
        "in_1": tvm.nd.array(data1, device=tvm.cpu(0)),
        "in_3": tvm.nd.array(data3, device=tvm.cpu(0)),
    }
    ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)

    passes = [
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    opt_pass = tvm.transform.Sequential(passes)
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        mod = opt_pass(ref_mod)

    data2 = random_input()
    data4 = random_input()
    runtime_inputs = {"in_2": data2, "in_4": data4}
    check_result(mod, ref_mod, runtime_inputs, (10, 10), tol=1e-5)
def test_constant():
    """Test the subgraph with (var, const, ...) arguments."""
    if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    ishape = (1, 32, 14, 14)
    wshape = (32, 32, 3, 3)

    # conv2d -> batch_norm -> relu over a single data input.
    data = relay.var("data", shape=ishape, dtype=dtype)
    weight = relay.var("weight", shape=wshape, dtype=dtype)
    gamma = relay.var("bn_gamma")
    beta = relay.var("bn_beta")
    moving_mean = relay.var("bn_mean")
    moving_var = relay.var("bn_var")

    conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3), padding=(1, 1))
    bn_result = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
    # Only the first element of the batch_norm result feeds the relu.
    out = relay.nn.relu(bn_result[0])

    func = relay.Function(relay.analysis.free_vars(out), out)
    ref_mod, params = tvm.relay.testing.create_workload(func)
    ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)

    simplify_passes = [
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ]
    remove_bn_pass = tvm.transform.Sequential(simplify_passes)

    dnnl_patterns = get_pattern_table("dnnl")
    partition_passes = [
        transform.MergeComposite(dnnl_patterns),
        transform.AnnotateTarget("dnnl"),
        transform.PartitionGraph(),
    ]
    composite_partition = tvm.transform.Sequential(partition_passes)

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        # The simplified module is the reference; the partitioned one is
        # derived from it.
        ref_mod = remove_bn_pass(ref_mod)
        mod = composite_partition(ref_mod)

    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
    check_result(mod, ref_mod, {"data": i_data}, (1, 32, 14, 14), tol=1e-5)
def quantize_model(model, params, input_dtype, input_shape, qeval='power2'):
    """Quantize ``model['main']`` with relay.quantize using per-layer scales.

    The function binds ``params`` into ``model['main']``, runs a
    simplification pipeline, annotates the graph, obtains layer scales
    (from a pickle cache file when present, otherwise via
    ``calibrate_on_dataset``, whose result is then written to the cache),
    calibrates, realizes and finally constant-folds the module.

    Parameters
    ----------
    model : mapping-like
        Object whose ``'main'`` entry is the function to quantize.
    params : dict
        Parameters bound into the graph and forwarded to
        ``calibrate_on_dataset``.
    input_dtype, input_shape
        Forwarded unchanged to ``calibrate_on_dataset``.
    qeval : str
        ``'power2'``: every positive scale is rounded up to the next power of
        two, non-positive scales become ``1.0``, and ``weight_scales`` is
        ``'power2'``. ``'max'``: scales are used as computed and
        ``weight_scales`` is ``'max'``.

    Returns
    -------
    mod
        The quantized module.

    Raises
    ------
    ValueError
        If ``qeval`` is neither ``'power2'`` nor ``'max'``.
    """
    # Function-scope import: ``np.math`` was merely an alias of this stdlib
    # module and is gone in NumPy 2.0, so it must not be used here.
    import math

    skip_conv_layers = [0]
    with relay.quantize.qconfig(store_lowbit_output=False,
                                skip_conv_layers=skip_conv_layers):
        from tvm.relay.quantize.quantize import _bind_params
        graph = _bind_params(model['main'], params)
        mod = relay.Module.from_expr(graph)
        optimize = _transform.Sequential([
            _transform.SimplifyInference(),
            _transform.FoldConstant(),
            _transform.FoldScaleAxis(),
            _transform.CanonicalizeOps(),
            _transform.FoldConstant(),
        ])

        with relay.build_config(opt_level=4):
            mod = optimize(mod)
        mod = relay.quantize.annotate()(mod)

        # Find the scales; they are cached per (video file, model name).
        cache_file = '%s_%s_scales.pkl' % (VIDEO_FILE, MODEL_NAME)
        if os.path.exists(cache_file):
            print("Using cached layer statistics...")
            # NOTE: pickle.load can execute arbitrary code from the file.
            # Only reuse cache files that this script wrote itself.
            with open(cache_file, 'rb') as f:
                scales = pickle.load(f)
        else:
            print("Compute layer statistics...")
            scales = calibrate_on_dataset(mod['main'], params, input_dtype, input_shape)
            with open(cache_file, 'wb') as f:
                pickle.dump(scales, f)

        if qeval == 'power2':
            # Round each positive scale up to the nearest power of two.
            scales = [
                2 ** math.ceil(math.log(scale, 2)) if scale > 0 else 1.0
                for scale in scales
            ]
            weight_scales = 'power2'
        elif qeval == 'max':
            weight_scales = 'max'
        else:
            raise ValueError("Invalid quantiziation eval: " + qeval)

        mod['main'] = relay.quantize.calibrate(mod['main'],
                                               weight_scales=weight_scales,
                                               scales=scales)
        mod = relay.quantize.realize()(mod)
        mod = relay.transform.FoldConstant()(mod)

    return mod
def simplify_model(mod):
    """
    Simplify the execution graph of ``mod``.

    The minimum goal is to merge BatchNorm into convolution: the BN primitive
    is decomposed into simple operations that can be evaluated as constant
    expressions and then merged into the nearest conv/dense primitive.

    Runs InferType, FoldConstant, SimplifyInference and FoldScaleAxis, in
    that order, and returns the transformed module.
    """
    passes = [
        transform.InferType(),
        transform.FoldConstant(),
        transform.SimplifyInference(),
        transform.FoldScaleAxis(),
    ]
    pipeline = tvm.transform.Sequential(passes)
    simplified = pipeline(mod)
    return simplified
def get_partitoned_mod(mod):
    """Simplify ``mod`` and partition it for the "dnnl" target.

    A nested pipeline (InferType, SimplifyInference, FoldConstant,
    FoldScaleAxis) runs first, followed by AnnotateTarget("dnnl"),
    MergeCompilerRegions and PartitionGraph. Everything executes in a
    PassContext with ``opt_level=3`` and "AlterOpLayout" disabled.

    Returns the partitioned module.
    """
    simplify = tvm.transform.Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])
    # The simplification pipeline is kept as a single nested stage.
    stages = [
        simplify,
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    partition = tvm.transform.Sequential(stages)

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        partitioned = partition(mod)
    return partitioned
def get_partitoned_mod(mod, params, pattern_table):
    """Bind ``params`` and partition ``mod`` into "dnnl" composite regions.

    ``mod["main"]`` is replaced by its parameter-bound version, then a nested
    pipeline (InferType, SimplifyInference, FoldConstant, FoldScaleAxis) is
    followed by MergeComposite(``pattern_table``), AnnotateTarget("dnnl") and
    PartitionGraph. The passes run under ``relay.build_config`` with
    ``opt_level=3`` and "AlterOpLayout" disabled.

    Returns the partitioned module.
    """
    # Binding the parameters is required for constant folding.
    mod["main"] = bind_params_by_name(mod["main"], params)

    simplify = tvm.transform.Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])
    stages = [
        simplify,
        transform.MergeComposite(pattern_table),
        transform.AnnotateTarget("dnnl"),
        transform.PartitionGraph(),
    ]
    partition = tvm.transform.Sequential(stages)

    with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
        partitioned = partition(mod)
    return partitioned
def partition_for_bnns(mod, params=None):
    """Bind parameters and offload supported operators to BNNS.

    Parameters
    ----------
    mod : Module
        The module to run passes on. When ``params`` is truthy its ``"main"``
        function is replaced by the parameter-bound version.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    passes = [
        transform.InferType(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
        transform.DynamicToStatic(),
        transform.AlterOpLayout(),
        # TODO(apeskov): workaround. AlterOpLayout transforms the shapes of
        #   constants, so expand_dims ops may appear after them and break
        #   BNNS fusing. FoldConstant therefore has to run again right
        #   before the BNNS composite passes.
        transform.FoldConstant(),
        transform.MergeComposite(get_pattern_table("bnns")),
        transform.AnnotateTarget("bnns"),
        # MergeCompilerRegions is deliberately left out. If per-layer
        # performance statistics are not needed, it can be enabled here:
        #   transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    pipeline = tvm.transform.Sequential(passes)
    partitioned = pipeline(mod)
    return partitioned
def match_pass_name(name):
    """Return a newly constructed ``transform`` pass selected by ``name``.

    Recognised names are listed in the table below. 'FuseOps' is constructed
    with the positional argument ``3``; every other pass is constructed
    without arguments.

    Raises
    ------
    Exception
        If ``name`` does not equal any recognised pass name.
    """
    # (accepted name, attribute on ``transform``, constructor arguments).
    # The attribute is looked up only for the entry that matches, and names
    # are compared with ``==`` in this order.
    known_passes = (
        ('FoldScaleAxis', 'FoldScaleAxis', ()),
        ('BackwardFoldScaleAxis', 'BackwardFoldScaleAxis', ()),
        ('ForwardFoldScaleAxis', 'ForwardFoldScaleAxis', ()),
        ('FuseOps', 'FuseOps', (3,)),
        ('FoldConstant', 'FoldConstant', ()),
        # Accepted spelling ends in a lowercase "d"; the attribute does not.
        ('CombineParallelConv2d', 'CombineParallelConv2D', ()),
        ('AlterOpLayout', 'AlterOpLayout', ()),
        ('EliminateCommonSubexpr', 'EliminateCommonSubexpr', ()),
        ('PartialEvaluate', 'PartialEvaluate', ()),
        ('CanonicalizeCast', 'CanonicalizeCast', ()),
        ('CanonicalizeOps', 'CanonicalizeOps', ()),
    )
    for accepted, attribute, args in known_passes:
        if name == accepted:
            return getattr(transform, attribute)(*args)
    raise Exception('Name {} does not match any pass'.format(name))
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs.

    When ``params`` is not None, ``mod["main"]`` is first replaced by its
    parameter-bound version. The module is simplified (InferType,
    SimplifyInference, FoldConstant, FoldScaleAxis) at ``opt_level=3`` and
    then partitioned using the pattern table registered under "cutlass".

    Returns the partitioned module.
    """
    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)

    simplify = Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])
    with PassContext(opt_level=3):
        mod = simplify(mod)

    patterns = relay.op.contrib.get_pattern_table("cutlass")
    # The partitioning stage runs outside the opt_level=3 context above.
    partition = Sequential([
        transform.InferType(),
        transform.MergeComposite(patterns),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])
    return partition(mod)
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs.

    Builds the list of (name, pattern, check) entries for dense,
    batch_matmul and conv2d, optionally binds ``params`` into
    ``mod["main"]`` (when ``params`` is not None), simplifies the module at
    ``opt_level=3`` and finally merges, annotates and partitions it for the
    "cutlass" target.

    Returns the partitioned module.
    """
    # Dense (GEMM) patterns.
    gemm = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
    gemm_bias = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
    gemm_bias_relu = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
    gemm_bias_gelu_fp16 = (
        "cutlass.dense_bias_gelu_fp16",
        make_gemm_pattern(True, "gelu"),
        check_gemm,
    )
    gemm_bias_gelu_fp32 = (
        "cutlass.dense_bias_gelu_fp32",
        make_gemm_pattern(True, "gelu", out_dtype="float32"),
        check_gemm,
    )

    # The order of this list is the order handed to MergeComposite.
    cutlass_patterns = [
        gemm_bias_gelu_fp16,
        gemm_bias_gelu_fp32,
        gemm_bias_relu,
        gemm_bias,
        gemm,
    ]
    cutlass_patterns.append(
        ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul)
    )

    # conv2d + bias followed by an activation.
    activated_conv2d = (
        ("cutlass.conv2d_bias_hardswish", "hardswish"),
        ("cutlass.conv2d_bias_silu", "silu"),
        ("cutlass.conv2d_bias_relu", "relu"),
        ("cutlass.conv2d_bias_sigmoid", "sigmoid"),
    )
    for pattern_name, activation in activated_conv2d:
        cutlass_patterns.append((
            pattern_name,
            make_conv2d_pattern(with_bias=True, with_act=activation),
            check_conv2d,
        ))
    cutlass_patterns.append(
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d)
    )
    cutlass_patterns.append(("cutlass.conv2d", make_conv2d_pattern(), check_conv2d))

    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)

    simplify = Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])
    with PassContext(opt_level=3):
        mod = simplify(mod)

    partition = Sequential([
        transform.InferType(),
        transform.MergeComposite(cutlass_patterns),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])
    return partition(mod)
def partition_for_dnnl(mod, params=None, alter_layout=True):
    """Partition the graph greedily offloading supported operators to DNNL.

    Parameters
    ----------
    mod : Module
        The module to run passes on. When ``params`` is truthy its ``"main"``
        function is replaced by the parameter-bound version.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.
    alter_layout : bool
        When truthy, AlterOpLayout followed by FoldConstant is run with
        ``FTVMAlterOpLayout`` set through ``TempOpAttr`` for the conv1d,
        conv2d, conv3d, conv2d_transpose and conv3d_transpose operators.

    Returns
    -------
    mod : Module
        Annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    # Stage 1: canonicalize, simplify and legalize, with FTVMLegalize set to
    # dnnl.legalize_group_conv for conv2d and conv2d_transpose.
    with TempOpAttr("nn.conv2d", "FTVMLegalize", dnnl.legalize_group_conv), \
            TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", dnnl.legalize_group_conv):
        prepare = tvm.transform.Sequential([
            transform.CanonicalizeOps(),
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
            # fold consecutive add ops to simplify pattern
            # `conv2d-bias_add-bn-relu`
            transform.SimplifyExpr(),
            transform.FoldConstant(),
            # alter group conv / conv_transpose layout to `GOIHW` / `GIOHW`
            transform.Legalize(),
            transform.FoldConstant(),
        ])
        with tvm.transform.PassContext(opt_level=3):
            mod = prepare(mod)

    # Stage 2 (optional): layout alteration with the DNNL hooks in place.
    if alter_layout:
        with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", dnnl.alter_conv), \
                TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", dnnl.alter_conv), \
                TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", dnnl.alter_conv), \
                TempOpAttr("nn.conv2d_transpose", "FTVMAlterOpLayout",
                           dnnl.alter_conv_transpose), \
                TempOpAttr("nn.conv3d_transpose", "FTVMAlterOpLayout",
                           dnnl.alter_conv_transpose):
            relayout = tvm.transform.Sequential([
                transform.AlterOpLayout(),
                transform.FoldConstant(),
            ])
            with tvm.transform.PassContext(opt_level=3):
                mod = relayout(mod)

    # Stage 3: merge composites, annotate for "dnnl" and partition.
    offload = tvm.transform.Sequential([
        transform.MergeComposite(dnnl.pattern_table()),
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ])
    with tvm.transform.PassContext(opt_level=3):
        mod = offload(mod)
    return mod
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs.

    Builds (name, pattern, check) entries for residual blocks, dense /
    batch_matmul and conv2d, optionally binds ``params`` into
    ``mod["main"]`` (when ``params`` is not None), simplifies the module at
    ``opt_level=3`` and finally merges, annotates and partitions it for the
    "cutlass" target.

    Returns the partitioned module.
    """
    # Dense (GEMM) patterns.
    gemm = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
    gemm_bias = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
    gemm_bias_relu = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
    gemm_bias_gelu_fp16 = (
        "cutlass.dense_bias_gelu_fp16",
        make_gemm_pattern(True, "gelu"),
        check_gemm,
    )
    gemm_bias_gelu_fp32 = (
        "cutlass.dense_bias_gelu_fp32",
        make_gemm_pattern(True, "gelu", out_dtype="float32"),
        check_gemm,
    )
    dense_patterns = [
        gemm_bias_gelu_fp16,
        gemm_bias_gelu_fp32,
        gemm_bias_relu,
        gemm_bias,
        gemm,
        ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
    ]

    # conv2d patterns: bias + activation variants, then bias only, then the
    # plain conv2d as the last entry.
    activated_conv2d = (
        ("cutlass.conv2d_bias_hardswish", "hardswish"),
        ("cutlass.conv2d_bias_silu", "silu"),
        ("cutlass.conv2d_bias_relu", "relu"),
        ("cutlass.conv2d_bias_sigmoid", "sigmoid"),
    )
    conv2d_patterns = [
        (
            pattern_name,
            make_conv2d_pattern(with_bias=True, with_act=activation),
            check_conv2d,
        )
        for pattern_name, activation in activated_conv2d
    ]
    conv2d_patterns.append(
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d)
    )
    conv2d_patterns.append(("cutlass.conv2d", make_conv2d_pattern(), check_conv2d))

    # Residual blocks wrap every conv2d pattern except the last (plain
    # conv2d) one, for each binary op, with and without a trailing relu.
    residual_block_patterns = [
        (
            base_name + "_residual_" + binary_op + suffix,
            make_residual_block_pattern(base_pattern, binary_op, with_act=activation),
            partial(check_conv2d_residual, binary_op=binary_op),
        )
        for activation, suffix in [("relu", "_relu"), (None, "")]
        for base_name, base_pattern, _ in conv2d_patterns[:-1]
        for binary_op in ["add", "multiply"]
    ]

    # The order of this list is the order handed to MergeComposite.
    cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns

    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)

    simplify = Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])
    with PassContext(opt_level=3):
        mod = simplify(mod)

    partition = Sequential([
        transform.InferType(),
        transform.MergeComposite(cutlass_patterns),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])
    return partition(mod)