def get_binds(args, compact=False, binds=None):
    """Internal function to get binds and arg_list given arguments.

    Parameters
    ----------
    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    compact : bool
        If the statement has already been bound to a compact buffer.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to the Buffer which specifies the data
        layout requirement of the function. By default, a new compact buffer is
        created for each tensor in the argument.

    Returns
    -------
    binds: dict
        The bind specification

    arg_list: list
        The list of symbolic buffers of arguments.
    """
    binds = {} if binds is None else binds.copy()
    cfg = BuildConfig.current()
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
            buffer_type = "auto_broadcast" if any_dim and not compact else ""
            if x not in binds:
                buf = tvm.tir.decl_buffer(
                    x.shape,
                    dtype=x.dtype,
                    name=x.name,
                    data_alignment=cfg.data_alignment,
                    offset_factor=cfg.offset_factor,
                    buffer_type=buffer_type)
                binds[x] = buf
                arg_list.append(buf)
            else:
                arg_list.append(binds[x])
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, tvm.tir.Var):
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list
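

# Hedged usage sketch (not part of the driver API): pair an explicit Buffer
# with one tensor and let get_binds create default compact buffers for the
# rest. The helper name "_example_get_binds", the buffer name "Ab" and the
# alignment value are illustrative assumptions.
def _example_get_binds():
    import tvm
    from tvm import te

    A = te.placeholder((1024,), name="A")
    B = te.compute((1024,), lambda i: A[i] + 1.0, name="B")

    # A is bound explicitly; B falls back to a default compact buffer.
    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", data_alignment=64)
    binds, arg_list = get_binds([A, B], compact=False, binds={A: Ab})
    # binds maps both tensors to tir.Buffer objects; arg_list holds the
    # symbolic buffers in argument order.
    return binds, arg_list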
def form_irmodule(sch, args, name, binds):
    """According to the given schedule, form a function.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The given schedule to form the raw body

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        The binds information

    Returns
    -------
    The body formed according to the given schedule
    """
    # normalize schedule first
    cfg = BuildConfig.current()
    sch = sch.normalize()
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)

    compact = schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

    func = func.with_attr("global_symbol", name)

    if cfg.restricted_func:
        func = func.with_attr("tir.noalias", True)
    return tvm.IRModule({name: func})
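

# Hedged usage sketch: form an IRModule for a scheduled elementwise add. The
# helper name "_example_form_irmodule" and the function name "vector_add" are
# illustrative assumptions.
def _example_form_irmodule():
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute((n,), lambda i: A[i] + B[i], name="C")
    s = te.create_schedule(C.op)

    mod = form_irmodule(s, [A, B, C], "vector_add", binds=None)
    # The module holds a single PrimFunc named "vector_add" whose body is the
    # scheduled statement before the lowering passes run.
    return mod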
def _build_for_device(input_mod, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    input_mod : IRModule
        The input IRModule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : IRModule
        The host IRModule.

    mdev : tvm.module
        A module that contains device code.
    """
    target = _target.create(target)
    target_host = _target.create(target_host)
    device_type = ndarray.context(target.target_name, 0).device_type

    mod_mixed = input_mod
    mod_mixed = tvm.tir.transform.Apply(
        lambda f: f.with_attr("target", target))(mod_mixed)

    opt_mixed = [tvm.tir.transform.VerifyMemory()]
    if len(mod_mixed.functions) == 1:
        opt_mixed += [
            tvm.tir.transform.Apply(
                lambda f: f.with_attr("tir.is_entry_func", True))
        ]
    if BuildConfig.current().detect_global_barrier:
        opt_mixed += [tvm.tir.transform.ThreadSync("global")]
    opt_mixed += [
        tvm.tir.transform.ThreadSync("shared"),
        tvm.tir.transform.ThreadSync("warp"),
        tvm.tir.transform.InferFragment(),
        tvm.tir.transform.LowerThreadAllreduce(),
        tvm.tir.transform.MakePackedAPI(),
        tvm.tir.transform.SplitHostDevice()
    ]
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)

    # device optimizations
    opt_device = tvm.transform.Sequential([
        tvm.tir.transform.Filter(
            lambda f: "calling_conv" in f.attrs and f.attrs[
                "calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
        tvm.tir.transform.LowerWarpMemory(),
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.LowerDeviceStorageAccessInfo(),
        tvm.tir.transform.LowerIntrin()
    ])
    mod_dev = opt_device(mod_mixed)

    # host optimizations
    opt_host = tvm.transform.Sequential([
        tvm.tir.transform.Filter(
            lambda f: "calling_conv" not in f.attrs or f.attrs[
                "calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
        tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
        tvm.tir.transform.LowerTVMBuiltin(),
        tvm.tir.transform.LowerDeviceStorageAccessInfo(),
        tvm.tir.transform.LowerIntrin(),
        tvm.tir.transform.CombineContextCall()
    ])
    mod_host = opt_host(mod_mixed)

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert len(mod_dev.functions) == 0
    if "gpu" in target.keys and len(mod_dev.functions) == 0:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    rt_mod_dev = codegen.build_module(
        mod_dev, target) if len(mod_dev.functions) != 0 else None
    return mod_host, rt_mod_dev
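

# Hedged usage sketch: the "did you do bind?" warning above fires when a GPU
# target receives no device kernels, so bind the loop nest to GPU threads
# before building. tvm.build is the public entry that lowers the schedule and
# then dispatches to _build_for_device per target; the helper name
# "_example_build_cuda" is an illustrative assumption.
def _example_build_cuda():
    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op)

    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))

    return tvm.build(s, [A, B], target="cuda", target_host="llvm")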
def lower(sch, args, name="main", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to the Buffer which specifies the data
        layout requirement of the function. By default, a new compact buffer is
        created for each tensor in the argument.

    simple_mode : bool, optional
        Whether to output only a simple and compact statement; this will skip
        LoopPartition, API wrapper generation and Unrolling.

    Returns
    -------
    m : IRModule or Stmt
        The result IRModule if simple_mode=False; otherwise the Stmt before
        the API wrapper is generated is returned.
    """
    cfg = BuildConfig.current()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        mod = form_irmodule(sch, args, name, binds)
    else:
        mod = sch

    pass_list = lower_phase0
    # Phase 1
    pass_list += [
        tvm.tir.transform.InjectPrefetch(),
        tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
        tvm.tir.transform.NarrowDataType(32),
        tvm.tir.transform.Simplify(),
    ]
    pass_list += lower_phase1

    # Phase 2
    if not simple_mode:
        pass_list += [
            tvm.tir.transform.LoopPartition(cfg.partition_const_loop)
        ]

    pass_list += [
        tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize),
        tvm.tir.transform.InjectVirtualThread(),
        tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop),
        tvm.tir.transform.StorageRewrite(),
        tvm.tir.transform.UnrollLoop(cfg.auto_unroll_max_step,
                                     cfg.auto_unroll_max_depth,
                                     cfg.auto_unroll_max_extent,
                                     cfg.unroll_explicit),
    ]
    pass_list += lower_phase2

    # Phase 3
    pass_list += [
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.RemoveNoOp(),
    ]

    if not cfg.disable_select_rewriting:
        pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
    pass_list += lower_phase3

    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]

    optimize = tvm.transform.Sequential(pass_list)
    mod = optimize(mod)
    return mod
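

# Hedged usage sketch: lower a schedule into an IRModule for inspection before
# building. The helper name "_example_lower" and the function name "add_one"
# are illustrative assumptions.
def _example_lower():
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    mod = lower(s, [A, B], name="add_one")
    print(mod)  # IRModule containing the PrimFunc "add_one"
    return mod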
def lower(sch, args, name="default_function", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to the Buffer which specifies the data
        layout requirement of the function. By default, a new compact buffer is
        created for each tensor in the argument.

    simple_mode : bool, optional
        Whether to output only a simple and compact statement; this will skip
        LoopPartition, API wrapper generation and Unrolling.

    Returns
    -------
    m : IRModule or Stmt
        The result IRModule if simple_mode=False; otherwise the Stmt before
        the API wrapper is generated is returned.
    """
    cfg = BuildConfig.current()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        stmt = form_body(sch)

    for f in lower_phase0:
        stmt = f(stmt)

    compact = ir_pass.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
    stmt = ir_pass.NarrowDataType(stmt, 32)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    if cfg.disable_vectorize:
        stmt = ir_pass.SkipVectorize(stmt)
    else:
        stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    stmt = ir_pass.UnrollLoop(
        stmt,
        cfg.auto_unroll_max_step,
        cfg.auto_unroll_max_depth,
        cfg.auto_unroll_max_extent,
        cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)
    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)
    if simple_mode:
        return stmt

    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
        "global_symbol", tvm.runtime.String(name))
    if cfg.restricted_func:
        f = f.with_attr("tir.noalias", True)
    mod = tvm.IRModule({name: f})
    return tvm.tir.transform.MakePackedAPI()(mod)
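

# Hedged usage sketch for this older ir_pass-based path: simple_mode=True
# returns the raw Stmt for quick inspection, while the default path wraps the
# result into a packed-API IRModule. The helper name "_example_lower_simple"
# and the function name "scale_by_two" are illustrative assumptions.
def _example_lower_simple():
    from tvm import te

    A = te.placeholder((128, 128), name="A")
    B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")
    s = te.create_schedule(B.op)

    stmt = lower(s, [A, B], simple_mode=True)    # a tvm.tir.Stmt
    mod = lower(s, [A, B], name="scale_by_two")  # an IRModule
    return stmt, mod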
def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The list of lowered functions to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : list of LoweredFunc
        A list of lowered functions for the host.

    mdev : tvm.module
        A module that contains device code.
    """
    target = _target.create(target)
    target_host = _target.create(target_host)
    device_type = ndarray.context(target.target_name, 0).device_type

    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)

    mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
    opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
    if BuildConfig.current().detect_global_barrier:
        opt_mixed += [tvm.tir.transform.ThreadSync("global")]
    opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
                  tvm.tir.transform.ThreadSync("warp"),
                  tvm.tir.transform.InferFragment(),
                  tvm.tir.transform.LowerThreadAllreduce(),
                  tvm.tir.transform.BindDeviceType(),
                  tvm.tir.transform.SplitHostDevice()]
    mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)

    # device optimizations
    opt_device = tvm.ir.transform.Sequential(
        [tvm.tir.transform.Filter(
            lambda f: "calling_conv" in f.attrs and
            f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
         tvm.tir.transform.LowerWarpMemory(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin()])
    mod_dev = opt_device(mod_mixed)

    # host optimizations
    opt_host = tvm.ir.transform.Sequential(
        [tvm.tir.transform.Filter(
            lambda f: "calling_conv" not in f.attrs or
            f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
         tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
         tvm.tir.transform.LowerTVMBuiltin(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin(),
         tvm.tir.transform.CombineContextCall()])
    mod_host = opt_host(mod_mixed)

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert len(mod_dev.functions) == 0
    if "gpu" in target.keys and len(mod_dev.functions) == 0:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    rt_mod_dev = codegen.build_module(
        mod_dev, target) if len(mod_dev.functions) != 0 else None
    return mod_host, rt_mod_dev
def decl_tensor_intrin(op,
                       fcompute,
                       name="tensor_intrin",
                       binds=None,
                       scalar_params=None):
    """Declare a tensor intrinsic function.

    Parameters
    ----------
    op: Operation
        The symbolic description of the intrinsic operation

    fcompute: lambda function of inputs, outputs -> stmt
        Specifies the IR statement to do the computation.
        See the following note for the function signature of fcompute

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each input
             - **outs** (list of :any:`Buffer`) - Placeholder for each output

             **Returns**

             - **stmt** (:any:`Stmt`, or tuple of three stmts)
             - If a single stmt is returned, it represents the body
             - If a tuple of three stmts is returned, they correspond to body,
               reduce_init, and reduce_update

    name: str, optional
        The name of the intrinsic.

    binds: dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to the Buffer which specifies the data
        layout requirement of the function. By default, a new compact buffer is
        created for each tensor in the argument.

    scalar_params: a list of variables used by op, whose values will be passed
                   as scalar_inputs when the tensor intrinsic is called.

    Returns
    -------
    intrin: TensorIntrin
        A TensorIntrin that can be used in tensorize schedule.
    """
    if not isinstance(op, _tensor.Operation):
        raise TypeError("expect Operation")
    inputs = op.input_tensors
    binds = binds if binds else {}
    tensors = list(inputs)
    for i in range(op.num_outputs):
        tensors.append(op.output(i))

    binds_list = []
    for t in inputs:
        if not isinstance(t.op, PlaceholderOp):
            raise ValueError("Do not yet support composition op")

    cfg = BuildConfig.current()
    for t in tensors:
        buf = (binds[t] if t in binds else
               tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
                                   data_alignment=cfg.data_alignment,
                                   offset_factor=cfg.offset_factor))
        binds_list.append(buf)

    if scalar_params:
        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):],
                        scalar_params)
    else:
        body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
        scalar_params = []
    if isinstance(body, (tvm.tir.PrimExpr, tvm.tir.Stmt)):
        body = [body]
    body = [tvm.tir.Evaluate(x) if isinstance(x, tvm.tir.PrimExpr) else x
            for x in body]
    if len(body) < 3:
        body += [None] * (3 - len(body))
    return _ffi_api.TensorIntrin(
        name, op, inputs, binds_list, scalar_params, *body)
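

# Hedged usage sketch: declare a 16-element vector-add intrinsic whose body is
# a packed call to an external symbol, then tensorize an inner loop with it.
# The symbol "vadd16" and the helper name "_example_tensor_intrin" are
# illustrative assumptions; a real backend must provide the packed function.
def _example_tensor_intrin():
    import tvm
    from tvm import te

    n = 16
    x = te.placeholder((n,), name="x")
    y = te.placeholder((n,), name="y")
    z = te.compute((n,), lambda i: x[i] + y[i], name="z")

    def intrin_func(ins, outs):
        xb, yb = ins
        zb = outs[0]
        # A single PrimExpr is returned; decl_tensor_intrin wraps it in Evaluate.
        return tvm.tir.call_packed("vadd16", xb.access_ptr("r"),
                                   yb.access_ptr("r"), zb.access_ptr("w"))

    vadd = decl_tensor_intrin(z.op, intrin_func, name="vadd16")

    # Tensorize the inner 16-element chunk of a larger vector add.
    A = te.placeholder((1024,), name="A")
    B = te.placeholder((1024,), name="B")
    C = te.compute((1024,), lambda i: A[i] + B[i], name="C")
    s = te.create_schedule(C.op)
    outer, inner = s[C].split(C.op.axis[0], factor=n)
    s[C].tensorize(inner, vadd)
    return s, [A, B, C]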
def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The list of lowered functions to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : list of LoweredFunc
        A list of lowered functions for the host.

    mdev : tvm.module
        A module that contains device code.
    """
    @tvm.tir.transform.prim_func_pass(opt_level=0)
    class BindTarget:
        def __init__(self, target):
            self.target = target

        # pylint: disable=unused-argument
        def transform_function(self, func, mod, ctx):
            return func.with_attr("target", self.target)

    target = _target.create(target)
    device_type = ndarray.context(target.target_name, 0).device_type
    fhost = []
    fdevice = []
    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)

        if func.func_type == LoweredFunc.MixedFunc:
            if BuildConfig.current().detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
            func = ir_pass.ThreadSync(func, "warp")
            func = ir_pass.InferFragment(func)
            warp_size = target.thread_warp_size
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            fsplits = list(ir_pass.SplitHostDevice(func))
            fhost.append(fsplits[0])
            for x in fsplits[1:]:
                fdevice.append(x)
        elif func.func_type == LoweredFunc.HostFunc:
            fhost.append(func)
        elif func.func_type == LoweredFunc.DeviceFunc:
            fdevice.append(func)
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert not fdevice

    target_host = _target.create(target_host)

    # device optimizations
    mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
    opt_device = tvm.ir.transform.Sequential(
        [BindTarget(target),
         tvm.tir.transform.LowerWarpMemory(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin()])
    mod_dev = opt_device(mod_dev)

    # host optimizations
    mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
    opt_host = tvm.ir.transform.Sequential(
        [BindTarget(target_host),
         tvm.tir.transform.LowerTVMBuiltin(),
         tvm.tir.transform.LowerDeviceStorageAccessInfo(),
         tvm.tir.transform.LowerIntrin(),
         tvm.tir.transform.CombineContextCall()])
    mod_host = opt_host(mod_host)

    rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
    return mod_host, rt_mod_dev
def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The list of lowered functions to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : list of LoweredFunc
        A list of lowered functions for the host.

    mdev : tvm.module
        A module that contains device code.
    """
    target = _target.create(target)
    device_type = ndarray.context(target.target_name, 0).device_type
    fhost = []
    fdevice = []
    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)

        if func.func_type == LoweredFunc.MixedFunc:
            if BuildConfig.current().detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
            func = ir_pass.ThreadSync(func, "warp")
            func = ir_pass.InferFragment(func)
            warp_size = target.thread_warp_size
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            fsplits = list(ir_pass.SplitHostDevice(func))
            fhost.append(fsplits[0])
            for x in fsplits[1:]:
                fdevice.append(x)
        elif func.func_type == LoweredFunc.HostFunc:
            fhost.append(func)
        elif func.func_type == LoweredFunc.DeviceFunc:
            fdevice.append(func)
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    for i, func in enumerate(fdevice):
        warp_size = target.thread_warp_size
        fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)

    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
    fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert not fdevice

    target_host = _target.create(target_host)
    fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
    fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
    fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
    fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
    fhost = [ir_pass.CombineContextCall(x) for x in fhost]
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

    return fhost, mdev
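

# Hedged usage sketch: when the device target is not the host CPU, pass
# target_host explicitly so the host-side functions (fhost) are compiled for
# the right machine. The helper name "_example_target_host" is an illustrative
# assumption; for cross-compilation a triple can be appended to the host llvm
# target string (the exact flag name differs across TVM versions).
def _example_target_host():
    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))

    return tvm.build(s, [A, B], target="opencl", target_host="llvm")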