Example #1
0
def partition_for_cmsisnn(mod, params=None, mod_name="default", **opts):
    """Greedily annotate and partition the graph so that supported
    operators are offloaded to CMSIS-NN on Cortex-M.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.
    mod_name: str, optional
        The module name

    Returns
    -------
    ret : Module
        annotated and partitioned module.
    """
    if params:
        # Bind constants into the main function so the passes can see them.
        mod["main"] = bind_params_by_name(mod["main"], params)

    cmsisnn_passes = [
        transform.InferType(),
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("cmsis-nn"),
        transform.PartitionGraph(mod_name=mod_name),
        # CMSIS-NN specific post-partitioning constant handling.
        GenerateCMSISNNConstants(),
        ScalarToTensorConstants(),
        ExtractConstantsFromPartitionedFunction(),
        transform.InferType(),
    ]
    return tvm.transform.Sequential(cmsisnn_passes)(mod)
Example #2
0
def partition_for_arm_compute_lib(mod, params=None):
    """Annotate and partition the graph, greedily offloading supported
    operators to Arm Compute Library.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    if params:
        # Fold constants into the main function before pattern matching.
        mod["main"] = bind_params_by_name(mod["main"], params)

    acl_pipeline = tvm.transform.Sequential(
        [
            transform.MergeComposite(arm_compute_lib_pattern_table()),
            transform.AnnotateTarget("arm_compute_lib"),
            transform.PartitionGraph(),
        ]
    )
    return acl_pipeline(mod)
Example #3
0
def partition_for_dnnl(mod, params=None):
    """Partition the graph greedily offloading supported operators to DNNL.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    mod : Module
        Annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    # Simplification passes run first so DNNL patterns can match cleanly;
    # folding consecutive add ops exposes `conv2d-bias_add-bn-relu`.
    dnnl_pipeline = tvm.transform.Sequential(
        [
            transform.CanonicalizeOps(),
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
            transform.SimplifyExpr(),
            transform.FoldConstant(),
            transform.MergeComposite(pattern_table()),
            transform.AnnotateTarget("dnnl"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        return dnnl_pipeline(mod)
Example #4
0
def partition_for_ethosn78(mod, params=None, **opts):
    """Annotate and partition the graph, greedily offloading supported
    operators to the Arm Ethos-N NPU.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    # A missing or wrong -variant option is rejected up front.
    variant = opts.get("variant", "") if opts else ""
    if variant.lower() != "ethos-n78":
        raise ValueError("When targeting Ethos(TM)-N78, -variant=Ethos-N78 should be set.")

    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    ethosn_passes = [
        transform.InferType(),
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("ethos-n"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    return tvm.transform.Sequential(ethosn_passes)(mod)
Example #5
0
def partition_for_cublas(
    mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
) -> tvm.IRModule:
    """Annotate and partition the graph for cuBLAS offload.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to partition.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        Constant input parameters.

    Returns
    -------
    tvm.IRModule
        The partitioned module.
    """
    # NOTE(review): `params` is accepted for interface consistency with the
    # other partitioners but is not bound into the module here.
    cublas_passes = [
        transform.InferType(),
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("cublas"),
        transform.PartitionGraph(),
        transform.InferType(),
    ]
    return tvm.transform.Sequential(cublas_passes)(mod)
Example #6
0
def partition_for_arm_compute_lib(mod,
                                  params=None,
                                  disabled_ops=None,
                                  **opts):
    """Partition the graph greedily offloading supported
    operators to Arm Compute Library.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.
    disabled_ops : Optional[list]
        Ops we do not want to offload to ACL. Defaults to ["concatenate"].

    Returns
    -------
    ret : annotated and partitioned module.
    """
    # Fix: the previous signature used `disabled_ops=["concatenate"]`, a
    # mutable default argument shared across every call — a caller mutating
    # the list would silently change the default for all later calls. Use a
    # None sentinel and materialise the documented default per call instead.
    if disabled_ops is None:
        disabled_ops = ["concatenate"]

    if params:
        # Bind constants into the main function before pattern matching.
        mod["main"] = bind_params_by_name(mod["main"], params)

    seq = tvm.transform.Sequential([
        transform.InferType(),
        transform.MergeComposite(arm_compute_lib_pattern_table(disabled_ops)),
        transform.AnnotateTarget("arm_compute_lib", False),
        transform.PartitionGraph(),
    ])

    return seq(mod)
def partition_for_cmsisnn(mod, params=None, **opts):
    """Annotate and partition the graph so that supported operators are
    offloaded to CMSIS-NN on Cortex-M.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : Module
        annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    cmsisnn_passes = [
        transform.InferType(),
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("cmsisnn"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    return tvm.transform.Sequential(cmsisnn_passes)(mod)
Example #8
0
def partition_for_clml(mod, params=None):
    """Annotate and partition the graph, greedily offloading supported
    operators to the CLML Library.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    if params:
        # Bind constants so FoldConstant below can operate on them.
        mod["main"] = bind_params_by_name(mod["main"], params)

    clml_passes = [
        transform.InferType(),
        transform.FoldConstant(),
        transform.MergeComposite(clml_pattern_table()),
        transform.AnnotateTarget("clml", False),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    return tvm.transform.Sequential(clml_passes)(mod)
Example #9
0
def partition_for_cutlass(mod):
    """Partition the input module into CUTLASS-supported subgraphs."""
    # GEMM-family patterns; the registration list below keeps the fused
    # variants ahead of the plain dense, preserving the original table order.
    plain_dense = ("cutlass.dense", make_gemm_pattern(False, None))
    dense_with_bias = ("cutlass.dense_bias", make_gemm_pattern(True, None))
    dense_bias_relu = ("cutlass.dense_bias_relu",
                       make_gemm_pattern(True, "relu"))
    dense_bias_gelu_fp16 = ("cutlass.dense_bias_gelu_fp16",
                            make_gemm_pattern(True, "gelu"))
    dense_bias_gelu_fp32 = (
        "cutlass.dense_bias_gelu_fp32",
        make_gemm_pattern(True, "gelu", out_dtype="float32"),
    )
    cutlass_patterns = [
        dense_bias_gelu_fp16,
        dense_bias_gelu_fp32,
        dense_bias_relu,
        dense_with_bias,
        plain_dense,
        ("cutlass.batch_matmul", make_batch_matmul_pattern()),
        # TODO(masahi): Add more conv2d patterns
        ("cutlass.conv2d", make_conv2d_pattern()),
    ]
    # Apply each partitioning pass directly, in order.
    for cutlass_pass in (
        transform.MergeComposite(cutlass_patterns),
        transform.AnnotateTarget(["cutlass"]),
        transform.PartitionGraph(),
    ):
        mod = cutlass_pass(mod)
    return mod
Example #10
0
def partition_for_tensorrt(
    mod: tvm.IRModule,
    params: Optional[Dict[str, tvm.nd.NDArray]] = None,
    # CAUTION: Can't use default Target("tensorrt") here since the target kind is only available
    #          if is_tensorrt_compiler_enabled() == True.
    target: Optional[tvm.target.Target] = None,
) -> tvm.IRModule:
    """Partition all functions in mod to greedily offload supported operators to TensorRT.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to partition.
    target : tvm.target.Target
        A target of kind "tensorrt" describing additional partitioning and compilation options.
    params : Optional[Dict[str, tvm.nd.NDArray]]
        Constant input parameters.

    Returns
    -------
    partitioned_mod : tvm.IRModule
        The partitioned module.

    """
    assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if it is enabled"
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)
    if target is None:
        # Use a default target. The get_tensorrt_target() function will similarly create an
        # equivalent default target when compilation continues after partitioning.
        target = tvm.target.Target("tensorrt")

    # Desired layouts for the convolution flavours before partitioning.
    conv_layouts = {
        "nn.conv1d": ["NCW", "default"],
        "nn.conv2d": ["NCHW", "default"],
        "nn.conv3d": ["NCDHW", "default"],
        "nn.conv2d_transpose": ["NCHW", "default"],
    }
    trt_pipeline = tvm.transform.Sequential(
        [
            transform.InferType(),
            RemoveDropoutPass(),
            transform.RemoveUnusedFunctions(),
            transform.ConvertLayout(conv_layouts),
            transform.FoldConstant(),
            transform.MergeComposite(pattern_table()),
            transform.AnnotateTarget("tensorrt"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
            transform.InferType(),
        ]
    )
    with target:
        partitioned = trt_pipeline(mod)
        return prune_tensorrt_subgraphs(partitioned)
Example #11
0
def test_constant():
    """Test the subgraph with (var, const, ...) arguments."""
    if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
        print("skip because DNNL codegen is not available")
        return

    dtype = "float32"
    ishape = (1, 32, 14, 14)
    wshape = (32, 32, 3, 3)

    # Build a conv2d -> batch_norm -> relu graph.
    data = relay.var("data", shape=ishape, dtype=dtype)
    weight = relay.var("weight", shape=wshape, dtype=dtype)
    bn_gamma = relay.var("bn_gamma")
    bn_beta = relay.var("bn_beta")
    bn_mmean = relay.var("bn_mean")
    bn_mvar = relay.var("bn_var")

    conv = relay.nn.conv2d(data=data,
                           weight=weight,
                           kernel_size=(3, 3),
                           padding=(1, 1))
    bn = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean, bn_mvar)
    out = relay.nn.relu(bn[0])

    func = relay.Function(relay.analysis.free_vars(out), out)
    ref_mod, params = tvm.relay.testing.create_workload(func)
    # Bind params so batch-norm folding sees constants.
    ref_mod["main"] = bind_params_by_name(ref_mod["main"], params)

    remove_bn_pass = tvm.transform.Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])

    dnnl_patterns = get_pattern_table("dnnl")
    composite_partition = tvm.transform.Sequential([
        transform.MergeComposite(dnnl_patterns),
        transform.AnnotateTarget("dnnl"),
        transform.PartitionGraph(),
    ])

    with tvm.transform.PassContext(opt_level=3,
                                   disabled_pass=["AlterOpLayout"]):
        ref_mod = remove_bn_pass(ref_mod)
        mod = composite_partition(ref_mod)

    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
    check_result(mod, ref_mod, {"data": i_data}, (1, 32, 14, 14), tol=1e-5)
Example #12
0
def get_pass_order(use_patterns):
    """
    Get the pass ordering based on using predicates or patterns.

    Parameters
    ----------
    use_patterns: Bool
        True if pass needs to work with op patterns

    Returns
    -------
    ret : Sequential
        Pass object
    """
    # The two orderings share the same prefix; build it once instead of
    # duplicating the whole list in each ternary branch (the original code
    # repeated every pass in both arms).
    common_passes = [
        transform.InferType(),
        RemoveDropoutPass(),
        transform.RemoveUnusedFunctions(),
        transform.ConvertLayout({
            "nn.conv1d": ["NCW", "default"],
            "nn.conv2d": ["NCHW", "default"],
            "nn.conv3d": ["NCDHW", "default"],
            "nn.conv2d_transpose": ["NCHW", "default"],
        }),
        transform.FoldConstant(),
    ]
    if use_patterns:
        # Pattern-based flow: merge composites first, then inline them after
        # partitioning. pattern_table() is only evaluated in this branch,
        # matching the original ternary's lazy evaluation.
        tail_passes = [
            transform.MergeComposite(pattern_table()),
            transform.AnnotateTarget("tensorrt"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
            transform.InlineComposites("tensorrt"),
            transform.InferType(),
        ]
    else:
        # Predicate-based flow: annotate ops directly without composites.
        tail_passes = [
            transform.AnnotateTarget("tensorrt"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
            transform.InferType(),
        ]
    return tvm.transform.Sequential(common_passes + tail_passes)
Example #13
0
def partition_for_ethosn(mod, params=None, **opts):
    """Annotate and partition the graph, greedily offloading supported
    operators to the Arm Ethos-N NPU.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    opts = opts or {}
    if "variant" not in opts:
        raise ValueError(
            "Please specify a variant in the target string, e.g. -variant=n78."
        )

    variant = opts["variant"]
    # -variant=ethos-n78 deprecated in favour of -variant=n78
    if variant.lower() == "ethos-n78":
        warnings.warn(
            "Please use '-variant=n78' instead of the deprecated "
            "'-variant=ethos-n78', which will be removed in TVM v0.9.",
            DeprecationWarning,
        )
    elif variant != "n78":
        raise ValueError(
            "When targeting Ethos(TM)-N78, -variant=n78 should be set.")

    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    ethosn_passes = [
        transform.InferType(),
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("ethos-n"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    return tvm.transform.Sequential(ethosn_passes)(mod)
    def get_partitoned_mod(mod, params, pattern_table):
        # NOTE(review): this nested helper appears after the enclosing
        # function's `return`, so as written it is unreachable dead code —
        # presumably a paste/scrape artifact; confirm before relying on it.
        # The name keeps the original "partitoned" (sic) spelling.
        #
        # Fold `params` constants, strip batch-norm via simplification, then
        # merge/annotate/partition the module for the "dnnl" codegen.
        # This is required for constant folding
        mod["main"] = bind_params_by_name(mod["main"], params)

        # Simplification pipeline: type inference, inference simplification
        # (removes batch-norm style training ops), then constant/scale folding.
        remove_bn_pass = tvm.transform.Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        # Full flow: simplify first, then merge the supplied composite
        # patterns, annotate for DNNL and partition.
        composite_partition = tvm.transform.Sequential([
            remove_bn_pass,
            transform.MergeComposite(pattern_table),
            transform.AnnotateTarget("dnnl"),
            transform.PartitionGraph()
        ])

        # AlterOpLayout is disabled so layouts stay as DNNL patterns expect.
        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
            return composite_partition(mod)
Example #15
0
def partition_for_bnns(mod, params=None):
    """Annotate and partition the graph, greedily offloading supported
    operators to BNNS.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    bnns_passes = [
        transform.InferType(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
        transform.DynamicToStatic(),
        transform.AlterOpLayout(),
        # TODO(apeskov): WA. AlterOpLayout call lead to constants shape transformation
        #   Some expand_dims op may appears after constants. It breaks BNNS fusing.
        #   So we have to call FoldConstant right before bnns composite passes.
        transform.FoldConstant(),
        transform.MergeComposite(get_pattern_table("bnns")),
        transform.AnnotateTarget("bnns"),
        #   If you no need in per layer performance statistic you can
        #   uncomment next line
        # transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    return tvm.transform.Sequential(bnns_passes)(mod)
Example #16
0
def partition_for_ethosn77(mod, params=None, **opts):
    """Annotate and partition the graph, greedily offloading supported
    operators to the Arm Ethos-N NPU.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.

    Returns
    -------
    ret : annotated and partitioned module.
    """
    if opts:
        # These options only apply to the N78 variant; reject them here.
        rejected = [opts.get(key, None) for key in ("tops", "ple_ratio", "sram_size")]
        if any(rejected):
            raise ValueError(
                "Setting tops, ple_ratio or sram_size has no effect when targeting Ethos(TM)-N77"
            )

    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    ethosn_passes = [
        transform.InferType(),
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("ethos-n"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
    return tvm.transform.Sequential(ethosn_passes)(mod)
Example #17
0
def partition_for_cudnn(mod: tvm.IRModule) -> tvm.IRModule:
    """Annotate and partition the graph for cuDNN offload.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to partition.

    Returns
    -------
    tvm.IRModule
        The partitioned module.
    """
    cudnn_passes = [
        transform.InferType(),
        transform.MergeComposite(pattern_table()),
        transform.AnnotateTarget("cudnn"),
        transform.PartitionGraph(),
        transform.InferType(),
    ]
    return tvm.transform.Sequential(cudnn_passes)(mod)
Example #18
0
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs."""
    if params is not None:
        # Bind constants so the batch-norm simplification below can fold them.
        mod["main"] = bind_params_by_name(mod["main"], params)
        bn_removal = Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        with PassContext(opt_level=3):
            mod = bn_removal(mod)

    cutlass_patterns = relay.op.contrib.get_pattern_table("cutlass")

    # Constants are left unbound (bind_constants=False) so the CUTLASS
    # codegen receives them separately.
    partition_pipeline = Sequential([
        transform.InferType(),
        transform.MergeComposite(cutlass_patterns),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])
    return partition_pipeline(mod)
Example #19
0
def partition_for_tensorrt(
    mod: tvm.IRModule,
    params: Optional[Dict[str, tvm.nd.NDArray]] = None,
    version: Optional[Tuple[int, int, int]] = None,
    use_implicit_batch: bool = True,
    remove_no_mac_subgraphs: bool = False,
    max_workspace_size: int = 1 << 30,
    use_fp16: bool = False,
    use_uint8: bool = False,
) -> Tuple[tvm.IRModule, Dict[str, Any]]:
    """Partition the graph greedily offloading supported operators to TensorRT.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to run passes on.
    params : Optional[Dict[str, tvm.nd.NDArray]]
        Constant input parameters.
    version : Optional[Tuple[int, int, int]]
        TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled with
        USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead.
    use_implicit_batch : bool
        Use TensorRT implicit batch mode (default true). Setting to false will enable explicit batch
        mode which will widen supported operators to include those which modify the batch dimension,
        but may reduce performance for some models.
    remove_no_mac_subgraphs : bool
        Removes subgraphs which have been partitioned for TensorRT if they do not have any
        multiply-accumulate operations. The removed subgraphs will go through TVM's standard
        compilation instead. Can improve performance.
    max_workspace_size : int
        How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation.
        See TensorRT documentation for more info.
    use_fp16: bool
        Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is required to be enabled
        if FP16 inputs tensors and weights are used.
        Note that TensorRT will still choose a higher-precision kernel if it results in overall
        lower runtime, or if no low-precision implementation exists.
    use_uint8: bool
        Allows, TRT to automatically convert FP32 inputs to UINT8.

    Returns
    -------
    mod_and_config : Tuple[tvm.IRModule, Dict[str, Any]]
        A tuple of 1) annotated and partitioned module and 2) "relay.ext.tensorrt.options"
        configuration which should be given to PassContext when building.

    """
    # Codegen options; consumed via the "relay.ext.tensorrt.options" config
    # key in the PassContext below, and also returned to the caller.
    config: Dict[str, Any] = {
        "use_implicit_batch": use_implicit_batch,
        "max_workspace_size": max_workspace_size,
        "remove_no_mac_subgraphs": remove_no_mac_subgraphs,
        "use_fp16": use_fp16,
        "use_uint8": use_uint8,
    }
    if version:
        # Caller pinned an explicit version; must be (major, minor, patch).
        assert isinstance(version, tuple) and len(version) == 3
        config["tensorrt_version"] = version
    else:
        # Query the linked TensorRT version from the runtime; an empty tuple
        # means TVM was built without TensorRT, so fall back to 6.0.1.
        linked_version = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
        if not linked_version:
            logger.warning(
                "TVM was not built against TensorRT and no version was provided to "
                "partition_for_tensorrt. Defaulting to 6.0.1"
            )
            linked_version = (6, 0, 1)
        config["tensorrt_version"] = linked_version

    if params:
        # Bind constants into the module so FoldConstant below can see them.
        mod["main"] = bind_params_by_name(mod["main"], params)

    seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            RemoveDropoutPass(),
            transform.RemoveUnusedFunctions(),
            # Convert convolutions to the layouts TensorRT expects.
            transform.ConvertLayout(
                {
                    "nn.conv1d": ["NCW", "default"],
                    "nn.conv2d": ["NCHW", "default"],
                    "nn.conv3d": ["NCDHW", "default"],
                    "nn.conv2d_transpose": ["NCHW", "default"],
                }
            ),
            transform.FoldConstant(),
            transform.MergeComposite(pattern_table()),
            transform.AnnotateTarget("tensorrt"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
            transform.InferType(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
        mod = seq(mod)
        # TODO(mbs): Revisit
        # mod = prune_tensorrt_subgraphs(mod)
    return mod, config
Example #20
0
def partition_for_dnnl(mod, params=None, alter_layout=True):
    """Partition the graph greedily offloading supported operators to DNNL.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.
    alter_layout : bool
        Whether to run AlterOpLayout (plus constant folding) with the
        DNNL-specific layout hooks before partitioning.

    Returns
    -------
    mod : Module
        Annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    # Temporarily swap in DNNL's legalization hooks for (grouped) conv2d and
    # conv2d_transpose while the simplification pipeline runs; TempOpAttr
    # restores the original attribute on exit.
    with TempOpAttr("nn.conv2d", "FTVMLegalize", dnnl.legalize_group_conv):
        with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize",
                        dnnl.legalize_group_conv):
            seq = tvm.transform.Sequential([
                transform.CanonicalizeOps(),
                transform.InferType(),
                transform.SimplifyInference(),
                transform.FoldConstant(),
                transform.FoldScaleAxis(),
                # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
                transform.SimplifyExpr(),
                transform.FoldConstant(),
                # alter group conv /conv_transpose layout to `GOIHW` / `GIOHW`
                transform.Legalize(),
                transform.FoldConstant(),
            ])
            with tvm.transform.PassContext(opt_level=3):
                mod = seq(mod)
    if alter_layout:
        # Install DNNL-specific layout-altering hooks for every conv flavour
        # before running AlterOpLayout; each TempOpAttr is restored on exit.
        with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", dnnl.alter_conv):
            with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", dnnl.alter_conv):
                with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout",
                                dnnl.alter_conv):
                    with TempOpAttr("nn.conv2d_transpose", "FTVMAlterOpLayout",
                                    dnnl.alter_conv_transpose):
                        with TempOpAttr("nn.conv3d_transpose",
                                        "FTVMAlterOpLayout",
                                        dnnl.alter_conv_transpose):
                            alter_layout_seq = tvm.transform.Sequential([
                                transform.AlterOpLayout(),
                                transform.FoldConstant(),
                            ])
                            with tvm.transform.PassContext(opt_level=3):
                                mod = alter_layout_seq(mod)

    # Finally merge DNNL composite patterns, annotate and partition.
    byoc_seq = tvm.transform.Sequential([
        transform.MergeComposite(dnnl.pattern_table()),
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ])
    with tvm.transform.PassContext(opt_level=3):
        mod = byoc_seq(mod)
    return mod
Example #21
0
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs."""
    # Each entry is (pattern_name, pattern, predicate). The list order below
    # keeps fused variants ahead of plain ops — presumably the earlier
    # pattern wins when several match; confirm with MergeComposite semantics.
    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None),
                      check_gemm)
    dense_bias_relu_pat = ("cutlass.dense_bias_relu",
                           make_gemm_pattern(True, "relu"), check_gemm)
    dense_bias_gelu_fp16_pat = (
        "cutlass.dense_bias_gelu_fp16",
        make_gemm_pattern(True, "gelu"),
        check_gemm,
    )
    dense_bias_gelu_fp32_pat = (
        "cutlass.dense_bias_gelu_fp32",
        make_gemm_pattern(True, "gelu", out_dtype="float32"),
        check_gemm,
    )

    dense_patterns = [
        dense_bias_gelu_fp16_pat,
        dense_bias_gelu_fp32_pat,
        dense_bias_relu_pat,
        dense_bias_pat,
        dense_pat,
        ("cutlass.batch_matmul", make_batch_matmul_pattern(),
         check_batch_matmul),
    ]

    conv2d_patterns = [
        (
            "cutlass.conv2d_bias_hardswish",
            make_conv2d_pattern(with_bias=True, with_act="hardswish"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_silu",
            make_conv2d_pattern(with_bias=True, with_act="silu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_relu",
            make_conv2d_pattern(with_bias=True, with_act="relu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_sigmoid",
            make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
            check_conv2d,
        ),
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True),
         check_conv2d),
        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
    ]

    residual_block_patterns = []

    # Derive conv2d + residual (add/multiply, optionally followed by relu)
    # patterns. `conv2d_patterns[:-1]` skips the plain "cutlass.conv2d"
    # entry, so only the bias-carrying variants get residual forms.
    for with_act, postfix in [("relu", "_relu"), (None, "")]:
        for name, pat, _ in conv2d_patterns[:-1]:
            for bin_op in ["add", "multiply"]:
                residual_block_patterns.append((
                    name + "_residual_" + bin_op + postfix,
                    make_residual_block_pattern(pat, bin_op,
                                                with_act=with_act),
                    partial(check_conv2d_residual, binary_op=bin_op),
                ))

    # Residual patterns first (most specific), then GEMM, then conv2d.
    cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns

    if params is not None:
        # Bind constants so the batch-norm simplification below can fold them.
        mod["main"] = bind_params_by_name(mod["main"], params)
        remove_bn_pass = Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        with PassContext(opt_level=3):
            mod = remove_bn_pass(mod)

    # bind_constants=False keeps constants outside the partitioned functions
    # so the CUTLASS codegen receives them separately.
    seq = Sequential([
        transform.InferType(),
        transform.MergeComposite(cutlass_patterns),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])

    return seq(mod)
Example #22
0
def partition_for_cutlass(mod, params=None):
    """Partition the input module into CUTLASS-supported subgraphs."""
    # Each entry is (pattern_name, pattern, predicate). Fused variants are
    # listed before their plain counterparts — presumably earlier entries
    # take priority during matching; confirm with MergeComposite semantics.
    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None),
                      check_gemm)
    dense_bias_relu_pat = ("cutlass.dense_bias_relu",
                           make_gemm_pattern(True, "relu"), check_gemm)
    dense_bias_gelu_fp16_pat = (
        "cutlass.dense_bias_gelu_fp16",
        make_gemm_pattern(True, "gelu"),
        check_gemm,
    )
    dense_bias_gelu_fp32_pat = (
        "cutlass.dense_bias_gelu_fp32",
        make_gemm_pattern(True, "gelu", out_dtype="float32"),
        check_gemm,
    )
    cutlass_patterns = [
        dense_bias_gelu_fp16_pat,
        dense_bias_gelu_fp32_pat,
        dense_bias_relu_pat,
        dense_bias_pat,
        dense_pat,
        ("cutlass.batch_matmul", make_batch_matmul_pattern(),
         check_batch_matmul),
        (
            "cutlass.conv2d_bias_hardswish",
            make_conv2d_pattern(with_bias=True, with_act="hardswish"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_silu",
            make_conv2d_pattern(with_bias=True, with_act="silu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_relu",
            make_conv2d_pattern(with_bias=True, with_act="relu"),
            check_conv2d,
        ),
        (
            "cutlass.conv2d_bias_sigmoid",
            make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
            check_conv2d,
        ),
        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True),
         check_conv2d),
        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
    ]

    if params is not None:
        # Bind constants so the batch-norm simplification below can fold them.
        mod["main"] = bind_params_by_name(mod["main"], params)
        remove_bn_pass = Sequential([
            transform.InferType(),
            transform.SimplifyInference(),
            transform.FoldConstant(),
            transform.FoldScaleAxis(),
        ])
        with PassContext(opt_level=3):
            mod = remove_bn_pass(mod)

    # bind_constants=False keeps constants outside the partitioned functions
    # so the CUTLASS codegen receives them separately.
    seq = Sequential([
        transform.InferType(),
        transform.MergeComposite(cutlass_patterns),
        transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
        transform.PartitionGraph(bind_constants=False),
    ])

    return seq(mod)