def annotate(func, compiler):
        """
        Annotate and partition ``func`` for an external compiler.

        Originally written as the Core ML annotator; the actual target is
        whatever ``compiler`` names. Parameters of ``func`` whose
        ``name_hint`` is a key of ``params`` are first bound to constants,
        then the module is simplified (SimplifyInference, FoldConstant,
        FoldScaleAxis) and annotated, merged and partitioned for
        ``compiler`` at opt_level 3.

        Parameters
        ----------
        func : relay.Function
            Function to annotate.
        compiler : str
            External compiler name passed to ``AnnotateTarget``.

        Returns
        -------
        mod : tvm.IRModule
            Module whose ``main`` is the bound, simplified and partitioned
            function.
        """
        # NOTE(review): ``params`` is not an argument of this function; it is
        # resolved from an enclosing or module scope not visible here and is
        # presumably a name -> value mapping — confirm.
        # Bind parameters whose names appear in ``params`` to constant values.
        bind_dict = {}
        for arg in func.params:
            name = arg.name_hint
            if name in params:
                bind_dict[arg] = relay.const(params[name])

        func = relay.bind(func, bind_dict)

        # Wrap the bound function in a fresh module as its entry point.
        mod = tvm.IRModule()
        mod["main"] = func

        seq = tvm.transform.Sequential([
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
            transform.AnnotateTarget(compiler),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ])

        # ``relay.build_config`` is deprecated and only forwards to
        # ``tvm.transform.PassContext``; use the pass context directly, as the
        # other pipelines in this file do.
        with tvm.transform.PassContext(opt_level=3):
            mod = seq(mod)

        return mod
    def partition():
        """Build a standalone batch_norm graph, annotate it and partition it.

        The graph takes ``data`` of shape (1, 16, 224, 224) and four (16,)
        batch-norm vectors (all float32) and returns the full batch_norm
        output tuple. ``nn.batch_norm`` and ``nn.conv2d`` are annotated for
        "test_compiler" by ``WhiteListAnnotator``. The module then goes through
        InferType, PartitionGraph, SimplifyInference, FoldConstant,
        AlterOpLayout and Inline at opt_level 3.

        Returns
        -------
        tvm.IRModule
            The partitioned module.
        """
        data = relay.var("data", relay.TensorType((1, 16, 224, 224),
                                                  "float32"))
        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
        bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
        bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))

        bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean,
                                        bn_mvar)

        # Return every batch_norm output (not just [0]) so the whole tuple is
        # part of the annotated region.
        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar],
                              bn_output.astuple())
        mod = tvm.IRModule()
        mod["main"] = func
        op_list = ["nn.batch_norm", "nn.conv2d"]
        mod = WhiteListAnnotator(op_list, "test_compiler")(mod)

        opt_pass = tvm.transform.Sequential([
            transform.InferType(),
            transform.PartitionGraph(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.AlterOpLayout(),
            transform.Inline(),
        ])

        # ``relay.build_config`` is a deprecated forwarder to
        # ``tvm.transform.PassContext``; use the pass context directly.
        with tvm.transform.PassContext(opt_level=3):
            mod = opt_pass(mod)

        return mod
示例#3
0
def partition_for_dnnl(mod, params=None):
    """Offload DNNL-supported operators by greedily partitioning the graph.

    Parameters
    ----------
    mod : Module
        Module to transform. When ``params`` is non-empty, its ``main``
        function is rebound in place before the passes run.
    params : Optional[Dict[str, NDArray]]
        Constant values for inputs of ``main``, bound by name.

    Returns
    -------
    mod : Module
        The annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    # Simplification passes that run before pattern matching.
    simplify = [
        transform.CanonicalizeOps(),
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
        # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
        transform.SimplifyExpr(),
        transform.FoldConstant(),
    ]
    # BYOC passes: fuse DNNL composite patterns, annotate, merge, partition.
    offload = [
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    pipeline = tvm.transform.Sequential(simplify + offload)

    with tvm.transform.PassContext(opt_level=3):
        return pipeline(mod)
示例#4
0
def test_partial_constant():
    """Test the subgraph with (const, var, const, var) arguments.

    ``in_1`` and ``in_3`` are bound as constants while ``in_2`` and ``in_4``
    stay variables, so the DNNL subgraph is expected to receive alternating
    constant and variable inputs. The partitioned module is compared against
    the unpartitioned reference on random data. The test is skipped when the
    DNNL JSON runtime is not available.
    """
    if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    ishape = (10, 10)

    in_1 = relay.var("in_1", shape=ishape, dtype=dtype)
    in_2 = relay.var("in_2", shape=ishape, dtype=dtype)
    in_3 = relay.var("in_3", shape=ishape, dtype=dtype)
    in_4 = relay.var("in_4", shape=ishape, dtype=dtype)

    add1 = relay.add(in_1, in_2)
    add2 = relay.add(add1, in_3)
    add3 = relay.add(add2, in_3)
    # The last add consumes in_4. Previously it reused in_3, which left in_4
    # (and the data4 fed to check_result) disconnected from the graph, so the
    # second variable argument of the (const, var, const, var) pattern was
    # never exercised.
    add4 = relay.add(add3, in_4)

    func = relay.Function([in_1, in_2, in_3, in_4], add4)
    ref_mod = tvm.IRModule.from_expr(func)
    ref_mod = relay.transform.InferType()(ref_mod)

    data1 = np.random.uniform(0, 1, ishape).astype(dtype)
    data3 = np.random.uniform(0, 1, ishape).astype(dtype)

    # Bind in_1 and in_3 as constants; in_2 and in_4 remain runtime inputs.
    params = {
        "in_1": tvm.nd.array(data1, device=tvm.cpu(0)),
        "in_3": tvm.nd.array(data3, device=tvm.cpu(0)),
    }
    ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)

    opt_pass = tvm.transform.Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ])

    with tvm.transform.PassContext(opt_level=3,
                                   disabled_pass=["AlterOpLayout"]):
        mod = opt_pass(ref_mod)

    data2 = np.random.uniform(0, 1, ishape).astype(dtype)
    data4 = np.random.uniform(0, 1, ishape).astype(dtype)
    check_result(mod,
                 ref_mod, {
                     "in_2": data2,
                     "in_4": data4
                 }, (10, 10),
                 tol=1e-5)
示例#5
0
def test_constant():
    """Check DNNL partitioning when the subgraph gets (var, const, ...) arguments.

    conv2d -> batch_norm -> relu is built with unbound weights.
    ``create_workload`` supplies values for them, which are bound as
    constants. Batch norm is simplified and folded away, and the graph is
    partitioned for DNNL through its composite patterns. The partitioned
    module is then compared with the reference on random input. The test is
    skipped when the DNNL JSON runtime is not available.
    """
    if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    data_shape = (1, 32, 14, 14)
    kernel_shape = (32, 32, 3, 3)

    data = relay.var("data", shape=data_shape, dtype=dtype)
    kernel = relay.var("weight", shape=kernel_shape, dtype=dtype)
    gamma, beta, moving_mean, moving_var = [
        relay.var(var_name)
        for var_name in ("bn_gamma", "bn_beta", "bn_mean", "bn_var")
    ]

    conv = relay.nn.conv2d(data=data, weight=kernel, kernel_size=(3, 3),
                           padding=(1, 1))
    normalized = relay.nn.batch_norm(conv, gamma, beta, moving_mean,
                                     moving_var)[0]
    activated = relay.nn.relu(normalized)

    graph = relay.Function(relay.analysis.free_vars(activated), activated)
    ref_mod, weights = tvm.relay.testing.create_workload(graph)
    ref_mod["main"] = bind_params_by_name(ref_mod["main"], weights)

    bn_removal = tvm.transform.Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])
    partitioning = tvm.transform.Sequential([
        transform.MergeComposite(get_pattern_table("dnnl")),
        transform.AnnotateTarget("dnnl"),
        transform.PartitionGraph(),
    ])

    with tvm.transform.PassContext(opt_level=3,
                                   disabled_pass=["AlterOpLayout"]):
        ref_mod = bn_removal(ref_mod)
        mod = partitioning(ref_mod)

    sample = np.random.uniform(0, 1, data_shape).astype(dtype)
    check_result(mod, ref_mod, {"data": sample}, (1, 32, 14, 14), tol=1e-5)
示例#6
0
def quantize_model(model, params, input_dtype, input_shape, qeval='power2'):
    """Quantize a Relay model using cached or freshly computed layer scales.

    Parameters
    ----------
    model : relay Module
        Module whose ``main`` function is quantized.
    params : dict
        Parameters bound into ``main`` before quantization.
    input_dtype, input_shape
        Forwarded to ``calibrate_on_dataset`` when layer statistics are
        computed (i.e. when no cache file exists yet).
    qeval : str
        ``'power2'`` rounds every positive scale up to a power of two
        (non-positive scales become 1.0) and uses ``'power2'`` weight scales.
        ``'max'`` keeps the raw scales and uses ``'max'`` weight scales.

    Returns
    -------
    The quantized, realized and constant-folded module.

    Raises
    ------
    ValueError
        If ``qeval`` is neither ``'power2'`` nor ``'max'``.
    """
    import math

    # Validate the mode before doing any work. An invalid value previously
    # surfaced only after a potentially long calibration run (and after the
    # cache file had been written).
    if qeval == 'power2':
        weight_scales = 'power2'
    elif qeval == 'max':
        weight_scales = 'max'
    else:
        raise ValueError("Invalid quantization eval: " + str(qeval))

    skip_conv_layers = [0]
    with relay.quantize.qconfig(store_lowbit_output=False,
                                skip_conv_layers=skip_conv_layers):
        from tvm.relay.quantize.quantize import _bind_params
        graph = _bind_params(model['main'], params)
        mod = relay.Module.from_expr(graph)
        optimize = _transform.Sequential([
            _transform.SimplifyInference(),
            _transform.FoldConstant(),
            _transform.FoldScaleAxis(),
            _transform.CanonicalizeOps(),
            _transform.FoldConstant()
        ])

        with relay.build_config(opt_level=4):
            mod = optimize(mod)
            mod = relay.quantize.annotate()(mod)

            # Find scales. Raw (unrounded) scales are cached per
            # VIDEO_FILE/MODEL_NAME only; delete the file if the model,
            # params or qconfig change.
            cache_file = '%s_%s_scales.pkl' % (VIDEO_FILE, MODEL_NAME)
            if os.path.exists(cache_file):
                print("Using cached layer statistics...")
                # NOTE: pickle.load can execute arbitrary code; only load
                # cache files this script produced itself.
                with open(cache_file, 'rb') as f:
                    scales = pickle.load(f)
            else:
                print("Compute layer statistics...")
                scales = calibrate_on_dataset(mod['main'], params, input_dtype,
                                              input_shape)
                with open(cache_file, 'wb') as f:
                    pickle.dump(scales, f)

            if qeval == 'power2':
                # math.log2 replaces np.math.log(scale, 2). The np.math alias
                # is gone in NumPy 2.0, and log(x, 2) can land one ulp above
                # the true value for exact powers of two, which ceil() would
                # then round up to the next power.
                scales = [
                    2**math.ceil(math.log2(scale)) if scale > 0 else 1.0
                    for scale in scales
                ]

            mod['main'] = relay.quantize.calibrate(mod['main'],
                                                   weight_scales=weight_scales,
                                                   scales=scales)
            mod = relay.quantize.realize()(mod)
            mod = relay.transform.FoldConstant()(mod)

    return mod
示例#7
0
def simplify_model(mod):
    """
    Simplify the execution graph.

    The main goal is to merge BatchNorm into the neighbouring convolution.
    SimplifyInference decomposes the BN primitive into simple arithmetic.
    FoldConstant then evaluates the parts that depend only on constants.
    Finally, FoldScaleAxis merges the resulting scale into the nearest
    conv/dense primitive.

    Parameters
    ----------
    mod : tvm.IRModule
        Module to simplify.

    Returns
    -------
    tvm.IRModule
        The simplified module.

    NOTE(review): FoldScaleAxis is presumably gated by the active PassContext
    opt_level (the other pipelines in this file run it under opt_level=3).
    This function does not open a PassContext itself, so call it inside
    ``tvm.transform.PassContext(opt_level=3)`` — confirm.
    """
    seq = tvm.transform.Sequential([
        transform.InferType(),
        # Decompose BN *before* folding. Previously FoldConstant ran first,
        # so the arithmetic produced by SimplifyInference stayed unfolded
        # when FoldScaleAxis ran.
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])
    return seq(mod)
示例#8
0
    def get_partitoned_mod(mod):
        """Remove batch norm from ``mod`` and partition it for DNNL.

        A nested pipeline (InferType, SimplifyInference, FoldConstant,
        FoldScaleAxis) runs first. It is followed by AnnotateTarget("dnnl"),
        MergeCompilerRegions and PartitionGraph, all at opt_level 3 with
        AlterOpLayout disabled.
        """
        bn_removal = tvm.transform.Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        stages = [
            bn_removal,
            transform.AnnotateTarget("dnnl"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
        pipeline = tvm.transform.Sequential(stages)

        with tvm.transform.PassContext(opt_level=3,
                                       disabled_pass=["AlterOpLayout"]):
            partitioned = pipeline(mod)
        return partitioned
    def get_partitoned_mod(mod, params, pattern_table):
        """Bind ``params`` into ``mod``, remove batch norm and partition for DNNL.

        Parameters
        ----------
        mod : tvm.IRModule
            Module to partition; its ``main`` is rebound in place.
        params : dict
            Values bound into ``main`` by parameter name.
        pattern_table : list
            Composite patterns passed to ``MergeComposite``.

        Returns
        -------
        tvm.IRModule
            The partitioned module.
        """
        # This is required for constant folding
        mod["main"] = bind_params_by_name(mod["main"], params)

        remove_bn_pass = tvm.transform.Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        composite_partition = tvm.transform.Sequential([
            remove_bn_pass,
            transform.MergeComposite(pattern_table),
            transform.AnnotateTarget("dnnl"),
            transform.PartitionGraph()
        ])

        # ``relay.build_config`` is a deprecated forwarder to
        # ``tvm.transform.PassContext``; use the pass context directly, as the
        # sibling helper does.
        with tvm.transform.PassContext(opt_level=3,
                                       disabled_pass=["AlterOpLayout"]):
            return composite_partition(mod)
示例#10
0
    def partition():
        """Build a conv2d + batch_norm graph, annotate it and partition it.

        Inputs are ``data`` (1, 3, 224, 224), ``weight`` (16, 3, 3, 3) and
        four (16,) batch-norm vectors, all float32. The function returns the
        whole batch_norm output tuple. conv2d and batch_norm are annotated for
        "test_compiler" by ``WhiteListAnnotator``. The module then goes through
        InferType, PartitionGraph, SimplifyInference, FoldConstant and
        AlterOpLayout at opt_level 3.

        Returns
        -------
        tvm.IRModule
            The partitioned module.
        """
        def f32_var(var_name, shape):
            # Local shorthand for a float32 Relay input variable.
            return relay.var(var_name, relay.TensorType(shape, "float32"))

        data = f32_var("data", (1, 3, 224, 224))
        weight = f32_var("weight", (16, 3, 3, 3))
        gamma, beta, mean, var = [
            f32_var(var_name, (16, ))
            for var_name in ("bn_gamma", "bn_beta", "bn_mean", "bn_var")
        ]

        conv = relay.nn.conv2d(data=data,
                               weight=weight,
                               kernel_size=(3, 3),
                               channels=16,
                               padding=(1, 1))
        normalized = relay.nn.batch_norm(conv, gamma, beta, mean, var)

        inputs = [data, weight, gamma, beta, mean, var]
        mod = tvm.IRModule()
        mod["main"] = relay.Function(inputs, normalized.astuple())
        annotator = WhiteListAnnotator(["nn.batch_norm", "nn.conv2d"],
                                       "test_compiler")
        mod = annotator(mod)

        pipeline = tvm.transform.Sequential([
            transform.InferType(),
            transform.PartitionGraph(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.AlterOpLayout(),
        ])
        with tvm.transform.PassContext(opt_level=3):
            partitioned = pipeline(mod)
        return partitioned
示例#11
0
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs.

    Parameters
    ----------
    mod : tvm.IRModule
        Module to partition. When ``params`` is given, its ``main`` is
        rebound in place first.
    params : Optional[dict]
        Values bound into ``main`` by name. When given, InferType,
        SimplifyInference, FoldConstant and FoldScaleAxis also run (at
        opt_level 3) before pattern matching.

    Returns
    -------
    tvm.IRModule
        Module with CUTLASS composite regions partitioned out.
    """
    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)
        bn_removal = Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        with PassContext(opt_level=3):
            mod = bn_removal(mod)

    offload = Sequential([
        transform.InferType(),
        transform.MergeComposite(
            relay.op.contrib.get_pattern_table("cutlass")),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])
    return offload(mod)
def test_alter_layout_scalar_regression():
    """Regression test for AlterOpLayout where a scalar used to fail.

    An NHWC conv2d + bias add feeds a batch_norm whose mean and variance are
    computed from the conv output itself. The pipeline runs these steps:

    1. ConvertLayout turns conv2d to NCHW and keeps batch_norm NHWC.
    2. SimplifyInference decomposes batch_norm.
    3. AlterOpLayout runs with a temporary hook that rewrites conv2d to
       NCHW16c.

    The result must be structurally equal to ``expected()``.

    NOTE(review): the failing "scalar" is presumably one of the scalar
    constants (1.0, 1e-05) introduced when batch_norm is decomposed —
    confirm against the original bug report.
    """

    def before():
        # NHWC/HWIO conv2d -> add(bias) -> batch_norm(axis=3), returning
        # only the normalized tensor (tuple element 0).
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var("weight", shape=(3, 3, 64, 16))
        bias = relay.var("bias", shape=(1, 1, 1, 16))
        y = relay.nn.conv2d(
            x,
            weight,
            channels=16,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
        )
        y = relay.add(y, bias)
        # Statistics are reduced over every axis except 3 (exclude=True),
        # so mean/var are computed tensors rather than bound parameters.
        mean = relay.mean(y, axis=3, exclude=True)
        var = relay.variance(y, axis=3, exclude=True)
        gamma = relay.var("gamma")
        beta = relay.var("beta")
        y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3)
        y = y[0]
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def alter_conv2d(attrs, inputs, tinfos, out_type):
        # Temporary FTVMAlterOpLayout hook: re-emit conv2d with the same
        # attributes but with data_layout forced to NCHW16c.
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW16c"
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var("weight", shape=(3, 3, 64, 16))
        bias = relay.var("bias", shape=(1, 1, 1, 16))
        # Input and weight are converted NHWC -> NCHW -> NCHW16c and
        # HWIO -> OIHW for the altered conv2d.
        x = relay.layout_transform(x, src_layout="NHWC", dst_layout="NCHW")
        x = relay.layout_transform(x, src_layout="NCHW", dst_layout="NCHW16c")
        weight = relay.layout_transform(weight, src_layout="HWIO", dst_layout="OIHW")
        y = relay.nn.conv2d(
            x, weight, channels=16, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
        )
        # The bias follows the same NHWC -> NCHW -> NCHW16c path.
        bias = relay.layout_transform(bias, src_layout="NHWC", dst_layout="NCHW")
        bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
        add = relay.add(y, bias)
        # Statistics are computed back in NHWC, matching before().
        y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
        y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC")
        mean = relay.mean(y, axis=3, exclude=True)
        var = relay.variance(y, axis=3, exclude=True)
        # Decomposed batch_norm: denom = gamma / sqrt(var + 1e-05), expanded
        # and moved to NCHW16c to scale the conv output.
        denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
        gamma = relay.var("gamma", shape=(16,))
        denom = denom * gamma
        denom_expand1 = relay.expand_dims(denom, axis=1, num_newaxis=2)
        denom_expand2 = relay.expand_dims(denom_expand1, axis=0)
        denom_nchwc16 = relay.layout_transform(
            denom_expand2, src_layout="NCHW", dst_layout="NCHW16c"
        )
        out = add * denom_nchwc16
        # Shift term: beta - mean * denom, laid out the same way.
        beta = relay.var("beta", shape=(16,))
        numerator = (-mean) * denom + beta
        numerator_expand1 = relay.expand_dims(numerator, axis=1, num_newaxis=2)
        numerator_expand2 = relay.expand_dims(numerator_expand1, axis=0)
        numerator_nchwc16 = relay.layout_transform(
            numerator_expand2, src_layout="NCHW", dst_layout="NCHW16c"
        )
        out = out + numerator_nchwc16
        # Final result returns to the original NHWC layout.
        out = relay.layout_transform(out, src_layout="NCHW16c", dst_layout="NCHW")
        y = relay.layout_transform(out, src_layout="NCHW", dst_layout="NHWC")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    # alter_conv2d is installed only for the duration of this block.
    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        a = before()
        desired_layouts = {"nn.conv2d": ["NCHW", "default"], "nn.batch_norm": ["NHWC", "default"]}
        a = run_opt_pass(
            a,
            [
                transform.InferType(),
                relay.transform.ConvertLayout(desired_layouts),
                transform.SimplifyInference(),
                transform.CanonicalizeOps(),
                transform.AlterOpLayout(),
            ],
        )
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
示例#13
0

def add_functions(mod, funcs):
    """Register mutually recursive functions in ``mod`` in two phases.

    This is a workaround for the type checker with mutually recursive
    functions. Each global in ``funcs`` is first registered as a stub with
    the same params and return type but a placeholder body built by
    ``_placeholder_body``. Only afterwards is every stub replaced by its
    real definition.

    Parameters
    ----------
    mod : module
        Module updated in place.
    funcs : mapping
        Global var -> function to install.
    """
    # Phase 1: stub out every function so all globals exist.
    for global_var in funcs:
        real_func = funcs[global_var]
        mod[global_var] = relay.Function(
            real_func.params,
            _placeholder_body(real_func.ret_type),
            real_func.ret_type,
        )

    # Phase 2: install the real definitions over the stubs.
    for global_var in funcs:
        mod[global_var] = funcs[global_var]


# Module-level optimization pipeline: SimplifyInference, CanonicalizeOps,
# CanonicalizeCast, FuseOps and AlterOpLayout, wrapped in a Sequential whose
# own opt_level is 0.
# NOTE(review): the wrapped passes keep their own opt_levels, so whether each
# one actually runs presumably still depends on the active PassContext —
# confirm.
pass_set = transform.Sequential(
    passes=[
        transform.SimplifyInference(),
        transform.CanonicalizeOps(),
        transform.CanonicalizeCast(),
        transform.FuseOps(3),  # fuse_opt_level=3
        # Disabled: transform.CombineParallelConv2d(),
        transform.AlterOpLayout(),
        # Disabled, argument never filled in: transform.RewriteAnnotatedOps(???),
    ],
    opt_level=0)


def optimize(mod):
    """Optimize all the functions in a module.

    Modules are the only mutable piece of Relay.  We write an
    optimization pass over the module which destructively updates each
示例#14
0
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs.

    Patterns are tried in list order, most specific first: the dense/GEMM
    variants and batch_matmul, then the conv2d variants.

    Parameters
    ----------
    mod : tvm.IRModule
        Module to partition. When ``params`` is given, its ``main`` is
        rebound in place first.
    params : Optional[dict]
        Values bound into ``main`` by name. When given, InferType,
        SimplifyInference, FoldConstant and FoldScaleAxis also run (at
        opt_level 3) before pattern matching.

    Returns
    -------
    tvm.IRModule
        Module with CUTLASS composite regions partitioned out.
    """
    dense_patterns = [
        ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu"),
         check_gemm),
        ("cutlass.dense_bias_gelu_fp32",
         make_gemm_pattern(True, "gelu", out_dtype="float32"), check_gemm),
        ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"),
         check_gemm),
        ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm),
        ("cutlass.dense", make_gemm_pattern(False, None), check_gemm),
        ("cutlass.batch_matmul", make_batch_matmul_pattern(),
         check_batch_matmul),
    ]
    # conv2d + bias + activation variants, followed by bias-only and plain.
    conv2d_patterns = [
        (pattern_name, make_conv2d_pattern(with_bias=True, with_act=act),
         check_conv2d)
        for pattern_name, act in (
            ("cutlass.conv2d_bias_hardswish", "hardswish"),
            ("cutlass.conv2d_bias_silu", "silu"),
            ("cutlass.conv2d_bias_relu", "relu"),
            ("cutlass.conv2d_bias_sigmoid", "sigmoid"),
        )
    ]
    conv2d_patterns.append(
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True),
         check_conv2d))
    conv2d_patterns.append(
        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d))
    cutlass_patterns = dense_patterns + conv2d_patterns

    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)
        bn_removal = Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        with PassContext(opt_level=3):
            mod = bn_removal(mod)

    offload = Sequential([
        transform.InferType(),
        transform.MergeComposite(cutlass_patterns),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])
    return offload(mod)
示例#15
0
文件: test_dnnl.py 项目: wenxcs/tvm
def partition_for_dnnl(mod, params=None, alter_layout=True):
    """Partition the graph greedily offloading supported operators to DNNL.

    Three stages run in order, each under a PassContext at opt_level 3:

    1. Simplification: CanonicalizeOps, InferType, SimplifyInference,
       FoldScaleAxis, SimplifyExpr and Legalize, with FoldConstant in
       between. During this stage ``dnnl.legalize_group_conv`` is
       temporarily installed as the ``FTVMLegalize`` hook of ``nn.conv2d``
       and ``nn.conv2d_transpose``.
    2. Only if ``alter_layout`` is true: AlterOpLayout + FoldConstant, with
       ``dnnl.alter_conv`` / ``dnnl.alter_conv_transpose`` temporarily
       installed as the ``FTVMAlterOpLayout`` hooks of the conv and
       conv_transpose ops.
    3. MergeComposite with ``dnnl.pattern_table()``, AnnotateTarget("dnnl"),
       MergeCompilerRegions and PartitionGraph.

    Parameters
    ----------
    mod : Module
        The module to run passes on. When ``params`` is non-empty, its
        ``main`` function is rebound in place.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters, bound by name.
    alter_layout : bool
        Whether to run stage 2 (default True).

    Returns
    -------
    mod : Module
        Annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    # Stage 1: simplify with the group-conv legalization hooks in place.
    with TempOpAttr("nn.conv2d", "FTVMLegalize", dnnl.legalize_group_conv), \
            TempOpAttr("nn.conv2d_transpose", "FTVMLegalize",
                       dnnl.legalize_group_conv):
        simplify = tvm.transform.Sequential([
            transform.CanonicalizeOps(),
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
            # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
            transform.SimplifyExpr(),
            transform.FoldConstant(),
            # alter group conv /conv_transpose layout to `GOIHW` / `GIOHW`
            transform.Legalize(),
            transform.FoldConstant(),
        ])
        with tvm.transform.PassContext(opt_level=3):
            mod = simplify(mod)

    # Stage 2 (optional): DNNL-specific layout alteration for convolutions.
    if alter_layout:
        with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", dnnl.alter_conv), \
                TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", dnnl.alter_conv), \
                TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", dnnl.alter_conv), \
                TempOpAttr("nn.conv2d_transpose", "FTVMAlterOpLayout",
                           dnnl.alter_conv_transpose), \
                TempOpAttr("nn.conv3d_transpose", "FTVMAlterOpLayout",
                           dnnl.alter_conv_transpose):
            relayout = tvm.transform.Sequential([
                transform.AlterOpLayout(),
                transform.FoldConstant(),
            ])
            with tvm.transform.PassContext(opt_level=3):
                mod = relayout(mod)

    # Stage 3: BYOC annotation and partitioning for the "dnnl" target.
    offload = tvm.transform.Sequential([
        transform.MergeComposite(dnnl.pattern_table()),
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ])
    with tvm.transform.PassContext(opt_level=3):
        return offload(mod)
示例#16
0
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs.

    Patterns are tried in this order:

    1. Residual-block fusions.
    2. Dense/GEMM variants and batch_matmul.
    3. conv2d variants.

    Parameters
    ----------
    mod : tvm.IRModule
        Module to partition. When ``params`` is given, its ``main`` is
        rebound in place first.
    params : Optional[dict]
        Values bound into ``main`` by name. When given, InferType,
        SimplifyInference, FoldConstant and FoldScaleAxis also run (at
        opt_level 3) before pattern matching.

    Returns
    -------
    tvm.IRModule
        Module with CUTLASS composite regions partitioned out.
    """
    dense_patterns = [
        ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu"),
         check_gemm),
        ("cutlass.dense_bias_gelu_fp32",
         make_gemm_pattern(True, "gelu", out_dtype="float32"), check_gemm),
        ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"),
         check_gemm),
        ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm),
        ("cutlass.dense", make_gemm_pattern(False, None), check_gemm),
        ("cutlass.batch_matmul", make_batch_matmul_pattern(),
         check_batch_matmul),
    ]

    # conv2d + bias + activation variants, followed by bias-only and plain.
    conv2d_patterns = [
        (pattern_name, make_conv2d_pattern(with_bias=True, with_act=act),
         check_conv2d)
        for pattern_name, act in (
            ("cutlass.conv2d_bias_hardswish", "hardswish"),
            ("cutlass.conv2d_bias_silu", "silu"),
            ("cutlass.conv2d_bias_relu", "relu"),
            ("cutlass.conv2d_bias_sigmoid", "sigmoid"),
        )
    ]
    conv2d_patterns.append(
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True),
         check_conv2d))
    conv2d_patterns.append(
        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d))

    # Every conv2d pattern except the plain one (last entry) gets residual
    # variants: {add, multiply} crossed with {relu, no activation}.
    residual_block_patterns = []
    fusable_convs = conv2d_patterns[:-1]
    for act, suffix in (("relu", "_relu"), (None, "")):
        for conv_name, conv_pat, _ in fusable_convs:
            residual_block_patterns.extend(
                (
                    conv_name + "_residual_" + op_name + suffix,
                    make_residual_block_pattern(conv_pat, op_name,
                                                with_act=act),
                    partial(check_conv2d_residual, binary_op=op_name),
                )
                for op_name in ("add", "multiply")
            )

    cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns

    if params is not None:
        mod["main"] = bind_params_by_name(mod["main"], params)
        bn_removal = Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        with PassContext(opt_level=3):
            mod = bn_removal(mod)

    offload = Sequential([
        transform.InferType(),
        transform.MergeComposite(cutlass_patterns),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])
    return offload(mod)