def get_binds(args, compact=False, binds=None):
    """Internal function to get binds and arg_list given arguments.

    Parameters
    ----------
    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    compact : bool
        Whether the statement is already bound to a compact buffer.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps a Tensor to the Buffer which specifies the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the arguments.

    Returns
    -------
    binds : dict
        The bind specification.

    arg_list : list
        The list of symbolic buffers for the arguments.
    """
    binds = {} if binds is None else binds.copy()
    cfg = BuildConfig.current()
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
            buffer_type = "auto_broadcast" if any_dim and not compact else ""
            if x not in binds:
                buf = tvm.tir.decl_buffer(
                    x.shape,
                    dtype=x.dtype,
                    name=x.name,
                    data_alignment=cfg.data_alignment,
                    offset_factor=cfg.offset_factor,
                    buffer_type=buffer_type)
                binds[x] = buf
                arg_list.append(buf)
            else:
                arg_list.append(binds[x])
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, tvm.tir.Var):
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list
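
# A minimal usage sketch for get_binds (illustrative only, assuming the
# standard tvm.te API is importable in this context): one tensor is pre-bound
# to a user-supplied buffer via `binds`, the other gets a freshly declared one.
def _example_get_binds():
    from tvm import te
    A = te.placeholder((1024,), name="A")
    B = te.compute((1024,), lambda i: A[i] + 1.0, name="B")
    # Pre-bind A to a buffer with a custom data alignment.
    a_buf = tvm.tir.decl_buffer(A.shape, A.dtype, name="A_buf", data_alignment=64)
    binds, arg_list = get_binds([A, B], compact=False, binds={A: a_buf})
    assert arg_list[0] is a_buf        # the user-supplied buffer is reused
    assert binds[B].dtype == B.dtype   # B received a new compact buffer
    return binds, arg_list
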
def form_irmodule(sch, args, name, binds):
    """According to the given schedule, form a function.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The given schedule from which to form the raw body.

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        The binds information

    Returns
    -------
    mod : IRModule
        The IRModule formed according to the given schedule.
    """
    # normalize schedule first
    cfg = BuildConfig.current()
    sch = sch.normalize()
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)

    compact = schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

    func = func.with_attr("global_symbol", name)
    if cfg.restricted_func:
        func = func.with_attr("tir.noalias", True)
    return tvm.IRModule({name: func})
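
# An illustrative sketch (not part of the public API) of driving form_irmodule
# on a simple elementwise compute, assuming the tvm.te API is available here.
def _example_form_irmodule():
    from tvm import te
    A = te.placeholder((128,), name="A")
    B = te.compute((128,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op)
    # The returned IRModule carries a single PrimFunc registered under `name`.
    mod = form_irmodule(s, [A, B], name="main", binds=None)
    assert "main" in [gv.name_hint for gv in mod.get_global_vars()]
    return mod
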
def _build_for_device(input_mod, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    input_mod : IRModule
        The input IRModule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    mod_host : IRModule
        The host IRModule.

    rt_mod_dev : tvm.runtime.Module
        A module that contains device code.
    """
    target = _target.create(target)
    target_host = _target.create(target_host)
    device_type = ndarray.context(target.target_name, 0).device_type

    mod_mixed = input_mod
    mod_mixed = tvm.tir.transform.Apply(
        lambda f: f.with_attr("target", target))(mod_mixed)

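    # Passes applied to the mixed host/device module before the host/device split.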
    opt_mixed = [tvm.tir.transform.VerifyMemory()]
    if len(mod_mixed.functions) == 1:
        opt_mixed += [
            tvm.tir.transform.Apply(
                lambda f: f.with_attr("tir.is_entry_func", True))
        ]
    if BuildConfig.current().detect_global_barrier:
        opt_mixed += [tvm.tir.transform.ThreadSync("global")]
    opt_mixed += [
        tvm.tir.transform.ThreadSync("shared"),
        tvm.tir.transform.ThreadSync("warp"),
        tvm.tir.transform.InferFragment(),
        tvm.tir.transform.LowerThreadAllreduce(),
        tvm.tir.transform.MakePackedAPI(),
        tvm.tir.transform.SplitHostDevice()
    ]
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)

    # device optimizations
    opt_device = tvm.transform.Sequential([
        tvm.tir.transform.Filter(
            lambda f: "calling_conv" in f.attrs and f.attrs[
                "calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
        tvm.tir.transform.LowerWarpMemory(),
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.LowerDeviceStorageAccessInfo(),
        tvm.tir.transform.LowerIntrin()
    ])
    mod_dev = opt_device(mod_mixed)

    # host optimizations
    opt_host = tvm.transform.Sequential([
        tvm.tir.transform.Filter(
            lambda f: "calling_conv" not in f.attrs or f.attrs[
                "calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
        tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
        tvm.tir.transform.LowerTVMBuiltin(),
        tvm.tir.transform.LowerDeviceStorageAccessInfo(),
        tvm.tir.transform.LowerIntrin(),
        tvm.tir.transform.CombineContextCall()
    ])
    mod_host = opt_host(mod_mixed)

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert len(mod_dev.functions) == 0
    if "gpu" in target.keys and len(mod_dev.functions) == 0:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    rt_mod_dev = codegen.build_module(
        mod_dev, target) if len(mod_dev.functions) != 0 else None
    return mod_host, rt_mod_dev
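
# A hedged sketch of how this internal helper is typically driven (the real
# entry point is the build() flow): lower a schedule into an IRModule first,
# then split and compile it for a concrete target.  This assumes the
# IRModule-based `_build_for_device` above, the IRModule-based `lower` defined
# below, and a plain CPU target, for which the device module is usually empty
# (rt_mod_dev is None).
def _example_build_for_device():
    from tvm import te
    A = te.placeholder((64,), name="A")
    B = te.compute((64,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    mod = lower(s, [A, B], name="main")
    mod_host, rt_mod_dev = _build_for_device(mod, "llvm", "llvm")
    return mod_host, rt_mod_dev
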
def lower(sch, args, name="main", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps a Tensor to the Buffer which specifies the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the arguments.

    simple_mode : bool, optional
        Whether to output only a simple and compact statement; this will skip
        LoopPartition, API wrapper generation and unrolling.

    Returns
    -------
    m : IRModule or Stmt
       The result IRModule if simple_mode=False; otherwise the Stmt before
       the API is made is returned.
    """
    cfg = BuildConfig.current()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        mod = form_irmodule(sch, args, name, binds)
    else:
        mod = sch

    pass_list = lower_phase0
    # Phase 1
    pass_list += [
        tvm.tir.transform.InjectPrefetch(),
        tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
        tvm.tir.transform.NarrowDataType(32),
        tvm.tir.transform.Simplify(),
    ]
    pass_list += lower_phase1

    # Phase 2
    if not simple_mode:
        pass_list += [
            (tvm.tir.transform.LoopPartition(cfg.partition_const_loop))
        ]

    pass_list += [
        tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize),
        tvm.tir.transform.InjectVirtualThread(),
        tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop),
        tvm.tir.transform.StorageRewrite(),
        tvm.tir.transform.UnrollLoop(cfg.auto_unroll_max_step,
                                     cfg.auto_unroll_max_depth,
                                     cfg.auto_unroll_max_extent,
                                     cfg.unroll_explicit),
    ]
    pass_list += lower_phase2

    # Phase 3
    pass_list += [
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.RemoveNoOp(),
    ]

    if not cfg.disable_select_rewriting:
        pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
    pass_list += lower_phase3

    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]

    optimize = tvm.transform.Sequential(pass_list)
    mod = optimize(mod)
    return mod
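
# An illustrative sketch of calling the IRModule-based lower() above on a small
# reduction, assuming the tvm.te API is available; simple_mode=True skips
# LoopPartition and unrolling so the printed TIR stays compact.
def _example_lower():
    from tvm import te
    n = 256
    A = te.placeholder((n, n), name="A")
    k = te.reduce_axis((0, n), name="k")
    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="row_sum")
    s = te.create_schedule(B.op)
    mod = lower(s, [A, B], name="row_sum", simple_mode=True)
    print(mod)  # inspect the lowered TIR
    return mod
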
def lower(sch, args, name="default_function", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps a Tensor to the Buffer which specifies the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the arguments.

    simple_mode : bool, optional
        Whether to output only a simple and compact statement; this will skip
        LoopPartition, API wrapper generation and unrolling.

    Returns
    -------
    m : IRModule or Stmt
       The result IRModule if simple_mode=False; otherwise the Stmt before
       the API is made is returned.
    """
    cfg = BuildConfig.current()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(
            add_lower_pass)
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        stmt = form_body(sch)

    for f in lower_phase0:
        stmt = f(stmt)

    compact = ir_pass.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64,
                                  cfg.instrument_bound_checkers)
    stmt = ir_pass.NarrowDataType(stmt, 32)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    if cfg.disable_vectorize:
        stmt = ir_pass.SkipVectorize(stmt)
    else:
        stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    stmt = ir_pass.UnrollLoop(stmt, cfg.auto_unroll_max_step,
                              cfg.auto_unroll_max_depth,
                              cfg.auto_unroll_max_extent, cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)
    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)
    if simple_mode:
        return stmt

    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr("global_symbol",
                                                   tvm.runtime.String(name))
    if cfg.restricted_func:
        f = f.with_attr("tir.noalias", True)
    mod = tvm.IRModule({name: f})
    return tvm.tir.transform.MakePackedAPI()(mod)
def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The list of lowered functions to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    mod_host : IRModule
        The host IRModule.

    rt_mod_dev : tvm.runtime.Module
        A module that contains device code.
    """
    target = _target.create(target)
    target_host = _target.create(target_host)
    device_type = ndarray.context(target.target_name, 0).device_type

    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)

    mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
    opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
    if BuildConfig.current().detect_global_barrier:
        opt_mixed += [tvm.tir.transform.ThreadSync("global")]
    opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
                  tvm.tir.transform.ThreadSync("warp"),
                  tvm.tir.transform.InferFragment(),
                  tvm.tir.transform.LowerThreadAllreduce(),
                  tvm.tir.transform.BindDeviceType(),
                  tvm.tir.transform.SplitHostDevice()]
    mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)


    # device optimizations
    opt_device = tvm.ir.transform.Sequential(
        [tvm.tir.transform.Filter(
            lambda f: "calling_conv" in f.attrs and
            f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
         tvm.tir.transform.LowerWarpMemory(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin()])
    mod_dev = opt_device(mod_mixed)

    # host optimizations
    opt_host = tvm.ir.transform.Sequential(
        [tvm.tir.transform.Filter(
            lambda f: "calling_conv" not in f.attrs or
            f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
         tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
         tvm.tir.transform.LowerTVMBuiltin(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin(),
         tvm.tir.transform.CombineContextCall()])
    mod_host = opt_host(mod_mixed)

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert len(mod_dev.functions) == 0
    if "gpu" in target.keys and len(mod_dev.functions) == 0:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
    return mod_host, rt_mod_dev
def decl_tensor_intrin(op,
                       fcompute,
                       name="tensor_intrin",
                       binds=None, scalar_params=None):
    """Declare a tensor intrinsic function.

    Parameters
    ----------
    op : Operation
        The symbolic description of the intrinsic operation.

    fcompute : lambda function of inputs, outputs -> stmt
        Specifies the IR statement that performs the computation.
        See the following note for the function signature of fcompute.

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each input
             - **outs** (list of :any:`Buffer`) - Placeholder for each output

             **Returns**

             - **stmt** (:any:`Stmt`, or tuple of three stmts)
             - If a single stmt is returned, it represents the body.
             - If a tuple of three stmts is returned, they correspond to body,
               reduce_init and reduce_update.

    name : str, optional
        The name of the intrinsic.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps a Tensor to the Buffer which specifies the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the arguments.

    scalar_params : list of Var, optional
        Variables used by op, whose values will be passed as scalar inputs
        when the tensor intrinsic is called.

    Returns
    -------
    intrin : TensorIntrin
        A TensorIntrin that can be used in tensorize schedule.
    """
    if not isinstance(op, _tensor.Operation):
        raise TypeError("expect Operation")
    inputs = op.input_tensors
    binds = binds if binds else {}
    tensors = list(inputs)
    for i in range(op.num_outputs):
        tensors.append(op.output(i))

    binds_list = []
    for t in inputs:
        if not isinstance(t.op, PlaceholderOp):
            raise ValueError("Do not yet support composition op")

    cfg = BuildConfig.current()
    for t in tensors:
        buf = (binds[t] if t in binds else
               tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
                                   data_alignment=cfg.data_alignment,
                                   offset_factor=cfg.offset_factor))
        binds_list.append(buf)

    if scalar_params:
        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
    else:
        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
        scalar_params = []
    if isinstance(body, (tvm.tir.PrimExpr, tvm.tir.Stmt)):
        body = [body]
    body = [tvm.tir.Evaluate(x) if isinstance(x, tvm.tir.PrimExpr) else x for x in body]
    if len(body) < 3:
        body += [None] * (3 - len(body))
    return _ffi_api.TensorIntrin(
        name, op, inputs, binds_list, scalar_params, *body)
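
# A hedged sketch of declaring a simple vector-add tensor intrinsic with
# decl_tensor_intrin, assuming tvm.te and tvm.tir are available here.  The
# emitted body calls "vadd_16", a placeholder extern symbol used purely for
# illustration.
def _example_decl_tensor_intrin():
    from tvm import te
    n = 16
    x = te.placeholder((n,), name="x")
    y = te.placeholder((n,), name="y")
    z = te.compute((n,), lambda i: x[i] + y[i], name="z")

    def intrin_func(ins, outs):
        xb, yb = ins
        zb = outs[0]
        # The returned expression becomes the intrinsic body (wrapped in Evaluate).
        return tvm.tir.call_packed("vadd_16",
                                   xb.access_ptr("r"),
                                   yb.access_ptr("r"),
                                   zb.access_ptr("w"))

    return decl_tensor_intrin(z.op, intrin_func, name="vadd_16_intrin")
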
def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The list of lowered functions to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    mod_host : IRModule
        The host IRModule.

    rt_mod_dev : tvm.runtime.Module
        A module that contains device code.
    """
    @tvm.tir.transform.prim_func_pass(opt_level=0)
    class BindTarget:
        def __init__(self, target):
            self.target = target

        # pylint: disable=unused-argument
        def transform_function(self, func, mod, ctx):
            return func.with_attr("target", self.target)

    target = _target.create(target)
    device_type = ndarray.context(target.target_name, 0).device_type
    fhost = []
    fdevice = []
    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)
        if func.func_type == LoweredFunc.MixedFunc:
            if BuildConfig.current().detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
            func = ir_pass.ThreadSync(func, "warp")
            func = ir_pass.InferFragment(func)
            warp_size = target.thread_warp_size
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            fsplits = list(ir_pass.SplitHostDevice(func))
            fhost.append(fsplits[0])
            for x in fsplits[1:]:
                fdevice.append(x)
        elif func.func_type == LoweredFunc.HostFunc:
            fhost.append(func)
        elif func.func_type == LoweredFunc.DeviceFunc:
            fdevice.append(func)
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert not fdevice

    target_host = _target.create(target_host)

    # device optimizations
    mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
    opt_device = tvm.ir.transform.Sequential(
        [BindTarget(target),
         tvm.tir.transform.LowerWarpMemory(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin()])
    mod_dev = opt_device(mod_dev)

    # host optimizations
    mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
    opt_host = tvm.ir.transform.Sequential(
        [BindTarget(target_host),
         tvm.tir.transform.LowerTVMBuiltin(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin(),
         tvm.tir.transform.CombineContextCall()])
    mod_host = opt_host(mod_host)

    rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
    return mod_host, rt_mod_dev
def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The list of lowered functions to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : list of LoweredFunc
        A list of lowered functions for the host.

    mdev : tvm.module
        A module that contains device code.
    """
    target = _target.create(target)
    device_type = ndarray.context(target.target_name, 0).device_type
    fhost = []
    fdevice = []
    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)
        if func.func_type == LoweredFunc.MixedFunc:
            if BuildConfig.current().detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
            func = ir_pass.ThreadSync(func, "warp")
            func = ir_pass.InferFragment(func)
            warp_size = target.thread_warp_size
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            fsplits = list(ir_pass.SplitHostDevice(func))
            fhost.append(fsplits[0])
            for x in fsplits[1:]:
                fdevice.append(x)
        elif func.func_type == LoweredFunc.HostFunc:
            fhost.append(func)
        elif func.func_type == LoweredFunc.DeviceFunc:
            fdevice.append(func)
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    for i, func in enumerate(fdevice):
        warp_size = target.thread_warp_size
        fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)

    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
    fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert not fdevice

    target_host = _target.create(target_host)
    fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
    fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
    fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
    fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
    fhost = [ir_pass.CombineContextCall(x) for x in fhost]
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

    return fhost, mdev