def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out):
    """Generate scatter ir for 1d inputs, using a sorting based approach.
    By sorting indices and comparing neighboring two indices, we can tell which
    elements in the indices tensor can scatter their update value into the output.
    Sorting of indices, and sorting of updates with respect to indices, can be done
    at the same time by thrust's sort_by_key function. It is important that sorting
    be done in a "stable" way via stable_sort, to guarantee deterministic output.
    Negative indices are assumed to have been converted to corresponding positive
    indices.

    Parameters
    ----------
    data : tir.Tensor
        The input data to the operator.

    indices_sorted : tir.Tensor
        The sorted index locations to update.

    updates_sorted : tir.Tensor
        The values to update, sorted by indices.

    out : tir.Tensor
        The output tensor.

    Returns
    -------
    ret : tir
        The computational ir.
    """
    n = data.shape[0]
    ib = tvm.tir.ir_builder.create()

    out_ptr = ib.buffer_ptr(out)
    data_ptr = ib.buffer_ptr(data)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads

    with ib.new_scope():
        nthread_bx = ceil_div(n, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx
        with ib.if_scope(tid < n):
            out_ptr[tid] = data_ptr[tid]

    indices_ptr = ib.buffer_ptr(indices_sorted)
    updates_ptr = ib.buffer_ptr(updates_sorted)

    ni = indices_sorted.shape[0]

    with ib.new_scope():
        nthread_bx = ceil_div(ni, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx

        with ib.if_scope(tid == ni - 1):
            # The last element can always update.
            index = indices_ptr[tid]
            update = updates_ptr[tid]
            out_ptr[index] = update
        with ib.else_scope():
            with ib.if_scope(tid < ni - 1):
                index = indices_ptr[tid]
                index_next = indices_ptr[tid + 1]

                # If the next neighbor in the sorted list of indices has a different index,
                # that means thread tid is the last one to have this index.
                # This thread can update the output.
                with ib.if_scope(index != index_next):
                    update = updates_ptr[tid]
                    out_ptr[index] = update

    return ib.get()
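# Illustrative NumPy restatement of the scatter semantics implemented by the IR
# above (a sketch, not part of the TVM code): stable-sort updates by index and
# let only the last duplicate write, which is exactly what the neighbor
# comparison on the sorted indices achieves.
import numpy as np


def scatter_1d_reference(data, indices, updates):
    out = data.copy()
    order = np.argsort(indices, kind="stable")  # mirrors thrust's stable sort_by_key
    ind_sorted = indices[order]
    upd_sorted = updates[order]
    ni = ind_sorted.shape[0]
    for tid in range(ni):
        # A position writes only if it holds the last occurrence of its index.
        if tid == ni - 1 or ind_sorted[tid] != ind_sorted[tid + 1]:
            out[ind_sorted[tid]] = upd_sorted[tid]
    return out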
def schedule_conv2d_winograd(cfg, s, output, pre_computed): """Schedule winograd template""" inverse = s[output].op.input_tensors[0] bgemm, A = s[inverse].op.input_tensors kernel_pack, data_pack_trans = s[bgemm].op.input_tensors data_pack = s[data_pack_trans].op.input_tensors[0] input_tile, B = s[data_pack].op.input_tensors pad_data = s[input_tile].op.input_tensors[0] # data transform s[B].compute_inline() s[A].compute_inline() # probably will improve real topology execution if autotvm.GLOBAL_SCOPE.in_tuning: # Padding to texture AA = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [input_tile]) bind_data_copy(s[AA]) s[input_tile].compute_inline() OL = s.cache_write(data_pack, "local") c, p, eps, nu, cb = s[data_pack].op.axis fused = s[data_pack].fuse(c, p, eps, nu) bx, tx = s[data_pack].split(fused, 128) s[data_pack].vectorize(cb) s[data_pack].bind(bx, te.thread_axis("blockIdx.x")) s[data_pack].bind(tx, te.thread_axis("threadIdx.x")) _, _, eps, nu, cb = s[OL].op.axis r_a, r_b = s[OL].op.reduce_axis s[OL].unroll(eps) s[OL].unroll(nu) s[OL].unroll(r_a) s[OL].unroll(r_b) s[OL].vectorize(cb) s[OL].compute_at(s[data_pack], tx) s[data_pack].set_scope(get_texture_storage(data_pack.shape)) s[data_pack_trans].compute_inline() # transform kernel if not pre_computed: kernel, G = s[kernel_pack].op.input_tensors eps, nu, ci, co, cob = s[kernel_pack].op.axis if autotvm.GLOBAL_SCOPE.in_tuning: # skip this part during tuning to make recrods accurate # this part will be pre-computed during pre-compute optimization pass s[G].pragma(s[G].op.axis[0], "debug_skip_region") s[kernel_pack].pragma(eps, "debug_skip_region") else: s[G].compute_inline() r_a, r_b = s[kernel_pack].op.reduce_axis for axis in [eps, nu, r_a, r_b]: s[kernel_pack].unroll(axis) fused = s[kernel_pack].fuse(ci, co) bb, tt = s[kernel_pack].split(fused, 128) s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b, cob) s[kernel_pack].vectorize(cob) s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x")) s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x")) else: kernel = kernel_pack if isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag: # manage scheduling of datacopy pack_data = pad_data.op.input_tensors[0] bind_data_copy(s[pack_data]) bind_data_copy(s[kernel]) elif isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() s[pad_data].compute_inline() ##### space definition begin ##### cfg.define_knob("auto_unroll_max_step", [0, 4, 16]) b1, b2, y, x, cb = s[bgemm].op.axis rcc = s[bgemm].op.reduce_axis[0] alpha = get_const_int(b1.dom.extent) cfg.define_split( "tile_y", y, num_outputs=3, filter=lambda entry: entry.size[2] <= 64 and entry.size[1] <= 16 ) min_x_div = 1 for bn in range(4, 0, -1): if bgemm.shape[3] % bn == 0: min_x_div = bn break cfg.define_split( "tile_x", x, num_outputs=3, filter=lambda entry: entry.size[2] <= 64 and entry.size[1] >= min_x_div and entry.size[1] <= 16, ) cfg.define_split("tile_rc", rcc, num_outputs=2) # TODO: Uncomment the following lines when multi_filter will be introduced # cfg.multi_filter( # filter=lambda entity: entity["tile_y"].size[2] * entity["tile_x"].size[2] in range(32,1024) # ) ##### space definition end ##### # batch gemm OL = s.cache_write(bgemm, "local") if ( autotvm.GLOBAL_SCOPE.in_tuning or isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag ): BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL]) bind_data_copy(s[BB]) by = s[bgemm].fuse(b1, b2, y) # tile and bind spatial axes 
bgemm_scope, by = s[bgemm].split(by, nparts=1) by, vy, ty = cfg["tile_y"].apply(s, bgemm, by) bx, vx, tx = cfg["tile_x"].apply(s, bgemm, x) s[bgemm].bind(by, te.thread_axis("blockIdx.y")) s[bgemm].bind(bx, te.thread_axis("blockIdx.x")) s[bgemm].bind(vy, te.thread_axis("vthread")) s[bgemm].bind(vx, te.thread_axis("vthread")) s[bgemm].bind(ty, te.thread_axis("threadIdx.y")) s[bgemm].bind(tx, te.thread_axis("threadIdx.x")) s[bgemm].reorder(bgemm_scope, by, bx, vy, vx, ty, tx, cb) s[bgemm].vectorize(cb) s[bgemm].set_scope(get_texture_storage(bgemm.shape)) # tile reduction axes s[OL].compute_at(s[bgemm], tx) b1, b2, y, x, cb = s[OL].op.axis (rcc, rcb) = s[OL].op.reduce_axis b = s[OL].fuse(b1, b2) s[OL].reorder(b, y, x, rcc, rcb, cb) # s[OL].unroll(rcb) s[OL].pragma(rcb, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) s[OL].pragma(rcb, "unroll_explicit", True) s[OL].vectorize(cb) # schedule inverse, output and fusion if output.op in s.outputs: OL = None else: OL = output s[OL].set_scope("local") output = s.outputs[0] if len(s[output].op.axis) == 4: n, co, h, w = s[output].op.axis cb = None else: n, co, h, w, cb = s[output].op.axis inverse_scope, n = s[output].split(n, nparts=1) fused = s[output].fuse(n, co, h, w) bb, tt = s[output].split(fused, 128) if cb is not None: s[output].reorder(bb, tt, cb) s[output].vectorize(cb) s[output].bind(bb, te.thread_axis("blockIdx.x")) s[output].bind(tt, te.thread_axis("threadIdx.x")) if OL is not None: s[OL].compute_at(s[output], tt) co, p, vh, vw, cb = s[inverse].op.axis r_a, r_b = s[inverse].op.reduce_axis for axis in [vh, vw, r_a, r_b]: s[inverse].unroll(axis) s[inverse].vectorize(cb) s[inverse].compute_at(s[output], tt) return s
# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import tvm from tvm import te import numpy as np import topi import unittest from tvm.contrib.nvcc import have_fp16, have_int8 from tvm.contrib import nvcc tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") def test_cuda_vectorize_add(): num_thread = 8 def check_cuda(dtype, n, lanes): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled..") return if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): print("Skip because gpu does not have fp16 support") return if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version): print("skip because gpu does not support int8")
def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
    """Low level IR to get valid count of bounding boxes
    given a score threshold. Also prepares to move valid boxes to the
    top of input data.

    Parameters
    ----------
    data : Buffer
        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].

    valid_indices: Buffer
        2D Buffer of destination positions of valid boxes (the exclusive scan of
        valid_boxes) with shape [batch_size, num_anchors].

    valid_boxes: Buffer
        2D Buffer of flags indicating valid data with shape [batch_size, num_anchors].

    Returns
    -------
    out : Buffer
        Sorted valid boxes

    out_indices : Buffer
        Indices of valid boxes in the original data
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.tir.ir_builder.create()

    data = ib.buffer_ptr(data)
    valid_indices = ib.buffer_ptr(valid_indices)
    valid_boxes = ib.buffer_ptr(valid_boxes)

    out = ib.buffer_ptr(out)
    out_indices = ib.buffer_ptr(out_indices)
    one = tvm.tir.const(1, dtype=out.dtype)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    nthread_by = batch_size
    with ib.new_scope():
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        ib.scope_attr(by, "thread_extent", nthread_by)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < num_anchors):
            i = by
            j = tid
            with ib.for_range(0, elem_length) as k:
                out[(i * num_anchors + j) * elem_length + k] = -one
            out_indices[i * num_anchors + j] = -1
    with ib.new_scope():
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        ib.scope_attr(by, "thread_extent", nthread_by)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < num_anchors):
            i = by
            j = tid
            with ib.if_scope(valid_boxes[i, tid] > 0):
                with ib.for_range(0, elem_length) as k:
                    out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
                        (i * num_anchors + j) * elem_length + k
                    ]
                out_indices[i * num_anchors + valid_indices[i, tid]] = j
    return ib.get()
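# Hypothetical NumPy counterpart of the compaction above (for reference only):
# valid_flags plays the role of valid_boxes, and the running position `pos`
# corresponds to the exclusive-scan values stored in valid_indices.
import numpy as np


def get_valid_counts_reference(data, valid_flags):
    batch_size, num_anchors, elem_length = data.shape
    out = np.full_like(data, -1.0)
    out_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
    for b in range(batch_size):
        pos = 0
        for j in range(num_anchors):
            if valid_flags[b, j] > 0:
                out[b, pos] = data[b, j]    # move valid boxes to the top
                out_indices[b, pos] = j     # remember their original position
                pos += 1
    return out, out_indices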
def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index): """Low level IR to identify bounding boxes given a score threshold. Parameters ---------- data : Buffer Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length]. score_threshold : Buffer or float32 Lower limit of score for valid bounding boxes. id_index : optional, int index of the class categories, -1 to disable. score_index: optional, int Index of the scores/confidence of boxes. Returns ------- valid_boxes: Buffer 2D Buffer indicating valid boxes with shape [batch_size, num_anchors]. """ batch_size = data.shape[0] num_anchors = data.shape[1] elem_length = data.shape[2] ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data) valid_boxes = ib.buffer_ptr(valid_boxes) if isinstance(score_threshold, float): score_threshold = tvm.tir.FloatImm("float32", score_threshold) id_index = tvm.tir.IntImm("int32", id_index) score_index = tvm.tir.IntImm("int32", score_index) max_threads = int( tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(num_anchors, max_threads) nthread_by = batch_size tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") by = te.thread_axis("blockIdx.y") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(by, "thread_extent", nthread_by) tid = bx * max_threads + tx with ib.if_scope(tid < num_anchors): i = by j = tid score = data[(i * num_anchors + j) * elem_length + score_index] with ib.if_scope( tvm.tir.all( score > score_threshold, tvm.tir.any( id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0), )): valid_boxes[i * num_anchors + j] = 1 with ib.else_scope(): valid_boxes[i * num_anchors + j] = 0 return ib.get()
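# A NumPy sketch of the validity predicate used above (illustrative only; it
# assumes the same [batch_size, num_anchors, elem_length] data layout).
import numpy as np


def get_valid_boxes_reference(data, score_threshold, id_index, score_index):
    scores = data[:, :, score_index]
    valid = scores > score_threshold
    if id_index >= 0:
        # A box also needs a non-negative class id when id_index is enabled.
        valid &= data[:, :, id_index] >= 0
    return valid.astype("int32")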
def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
    assert N == 1, "Only consider batch_size = 1 in this template"

    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    s = te.create_schedule([conv.op])

    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis

    cfg = autotvm.get_config()
    cfg.define_split("tile_f", f, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=3)
    cfg.define_split("tile_ry", ry, num_outputs=3)
    cfg.define_split("tile_rx", rx, num_outputs=3)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    cfg.define_knob("unroll_explicit", [0, 1])
    ##### space definition end #####

    # inline padding
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    data, raw_data = pad_data, data

    output = conv
    OL = s.cache_write(conv, "local")

    # create cache stage
    AA = s.cache_read(data, "shared", [OL])
    WW = s.cache_read(kernel, "shared", [OL])
    AL = s.cache_read(AA, "local", [OL])
    WL = s.cache_read(WW, "local", [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
    kernel_scope = n  # this is the scope to attach global config inside this kernel

    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(by, te.thread_axis("blockIdx.y"))
    s[output].bind(bx, te.thread_axis("blockIdx.x"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))
    s[output].bind(tf, te.thread_axis("threadIdx.z"))
    s[output].bind(ty, te.thread_axis("threadIdx.y"))
    s[output].bind(tx, te.thread_axis("threadIdx.x"))
    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
    ryo, rym, ryi = cfg["tile_ry"].apply(s, OL, ry)
    rxo, rxm, rxi = cfg["tile_rx"].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    s[AL].compute_at(s[OL], rxm)
    s[WL].compute_at(s[OL], rxm)

    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # tune unroll
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)

    return s, [raw_data, kernel, conv]
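# Sketch of how a template like the one above is typically tuned with AutoTVM.
# This assumes the function has been registered, e.g. with
# @autotvm.template("tutorial/conv2d_no_batching"); the task name, shapes and
# log file below are illustrative, not taken from this code.
import tvm
from tvm import autotvm

task = autotvm.task.create(
    "tutorial/conv2d_no_batching",
    args=(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)),
    target="cuda",
)
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4),
)
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(
    n_trial=20,
    measure_option=measure_option,
    callbacks=[autotvm.callback.log_to_file("conv2d.log")],
)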
def apply( self, sch, op, axes, axis_lens=None, max_unroll=None, vec_size=None, cfg=None, source=None ): """Apply annotation to an array of axes Parameters ---------- sch: tvm.te.schedule.Schedule The tvm schedule op: tvm.te.Operation The stage to be applied axes: Array of tvm.te.schedule.IterVar axis to split axis_lens: Array of int, optional the length of axes max_unroll: int, optional maximum unroll step vec_size: Array of int, optional valid vector lanes for vectorization cfg: ConfigEntity, optional cfg for recording error source: Array of Array tensor, optional source tensor for attaching cache Returns ------- axes : list of tvm.te.schedule.IterVar The transformed axes """ if source is not None: # special case : attach cache_read/cache_write for src, to in zip(source, self.anns): for t in src: sch[t].compute_at(sch[op], axes[to]) else: # other cases for i, ann in enumerate(self.anns): if ann == "none": pass elif ann == "unroll": if max_unroll and axis_lens[i] > max_unroll: cfg.raise_error("Too large factor for unrolling") sch[op].unroll(axes[i]) elif ann == "vec": if vec_size and axis_lens[i] not in vec_size: cfg.raise_error("Wrong size of lanes in vectorization") sch[op].vectorize(axes[i]) elif ann == "blockIdx.x": sch[op].bind(axes[i], thread_axis("blockIdx.x")) elif ann == "blockIdx.y": sch[op].bind(axes[i], thread_axis("blockIdx.y")) elif ann == "blockIdx.z": sch[op].bind(axes[i], thread_axis("blockIdx.z")) elif ann == "threadIdx.x": sch[op].bind(axes[i], thread_axis("threadIdx.x")) elif ann == "threadIdx.y": sch[op].bind(axes[i], thread_axis("threadIdx.y")) elif ann == "threadIdx.z": sch[op].bind(axes[i], thread_axis("threadIdx.z")) elif ann == "vthread": sch[op].bind(axes[i], thread_axis("vthread")) elif ann == "fuse": assert i < len(axes) - 1 axes[i + 1] = sch[op].fuse(axes[i], axes[i + 1]) else: raise RuntimeError("Invalid annotation " + ann) return axes
def _schedule_conv2d_NCHWc_int8(cfg, s, output): conv = output.op.input_tensors[0] packed_data, packed_kernel = conv.op.input_tensors if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag: pad_data = packed_data packed_data = pad_data.op.input_tensors[0] else: pad_data = packed_data if autotvm.GLOBAL_SCOPE.in_tuning: # skip this part during tuning to make recrods accurate # this part will be pre-computed during NNVM's pre-compute optimization pass s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region") s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region") else: if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel": # data and kernel are not pre-computed, schedule layout transform here schedule_injective_from_existing(s, packed_data) schedule_injective_from_existing(s, packed_kernel) if pad_data != packed_data: s[pad_data].compute_inline() # create cache stage AA = s.cache_read(pad_data, "shared", [conv]) WW = s.cache_read(packed_kernel, "shared", [conv]) s[conv].set_scope("local") # handle bias if output.op not in s.outputs: s[output].compute_inline() output = s.outputs[0].output(0) # tile and bind spatial axes if len(s[output].op.axis) == 5: n, f, y, x, c = s[output].op.axis else: # For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks # are created from scratch, therefore the real auto-tuning will still happen on 5D output. n, f, y, x = s[output].op.axis cfg.define_split("tile_n", cfg.axis(n), num_outputs=4) cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) cfg.define_split("tile_y", cfg.axis(y), num_outputs=4) cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) # this is the scope to attach global config inside this kernel kernel_scope, n = s[output].split(n, nparts=1) bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n) bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi) s[output].bind(bn, te.thread_axis("blockIdx.z")) s[output].bind(bf, te.thread_axis("blockIdx.y")) s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x")) s[output].bind(vn, te.thread_axis("vthread")) s[output].bind(vf, te.thread_axis("vthread")) s[output].bind(vy, te.thread_axis("vthread")) s[output].bind(vx, te.thread_axis("vthread")) cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf if cfg["fuse_yx"].val: s[output].bind(tn, te.thread_axis("threadIdx.z")) s[output].bind(tf, te.thread_axis("threadIdx.y")) tyx = s[output].fuse(ty, tx) s[output].bind(tyx, te.thread_axis("threadIdx.x")) s[conv].compute_at(s[output], tyx) # number of threads n_tz = cfg["tile_n"].size[2] n_ty = cfg["tile_f"].size[2] n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2] else: s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z")) s[output].bind(ty, te.thread_axis("threadIdx.y")) s[output].bind(tx, te.thread_axis("threadIdx.x")) s[conv].compute_at(s[output], tx) # number of threads n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2] n_ty = cfg["tile_y"].size[2] n_tx = cfg["tile_x"].size[2] # tile and bind reduction axes n, f, y, x, c = s[conv].op.axis rc, ry, rx, rc_block = s[conv].op.reduce_axis cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2) cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2) cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2) rco, rci = 
cfg["tile_rc"].apply(s, conv, rc) ryo, ryi = cfg["tile_ry"].apply(s, conv, ry) rxo, rxi = cfg["tile_rx"].apply(s, conv, rx) s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block) cfg.define_reorder("reorder_inner", [rco, ryo, rxo], policy="all") cfg["reorder_inner"].apply(s, conv, [rco, ryo, rxo]) cfg["reorder_inner"].apply(s, conv, [rci, ryi, rxi]) _, rc_block = s[conv].split(rc_block, factor=4) s[conv].tensorize(rc_block, _dp4a) cache_loc = [rco, ryo, rxo][cfg["reorder_inner"].perm[-1]] s[AA].compute_at(s[conv], cache_loc) s[WW].compute_at(s[conv], cache_loc) # cooperative fetching for load in [AA, WW]: c = s[load].op.axis[-1] c_outer, c = s[load].split(c, factor=4) s[load].vectorize(c) fused = s[load].op.axis[:-1] + [c_outer] fused = s[load].fuse(*fused) fused, tx = s[load].split(fused, factor=n_tx) fused, ty = s[load].split(fused, factor=n_ty) fused, tz = s[load].split(fused, factor=n_tz) s[load].bind(tz, te.thread_axis("threadIdx.z")) s[load].bind(ty, te.thread_axis("threadIdx.y")) s[load].bind(tx, te.thread_axis("threadIdx.x")) # double buffer cfg.define_knob("AA_double_buffer", [0, 1]) cfg.define_knob("WW_double_buffer", [0, 1]) if cfg["AA_double_buffer"].val: s[AA].double_buffer() if cfg["WW_double_buffer"].val: s[WW].double_buffer() # unroll cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) s[output].pragma(kernel_scope, "unroll_explicit", False) return s
def schedule_hwnc_tensorcore_cuda(cfg, s, Conv): """Schedule tensorcore template""" packed_data, packed_kernel = s[Conv].op.input_tensors ic, kh, kw, ii = s[Conv].op.reduce_axis pad_data = s[packed_data].op.input_tensors[0] block_x = te.thread_axis('blockIdx.x') block_y = te.thread_axis('blockIdx.y') block_z = te.thread_axis('blockIdx.z') thread_x = te.thread_axis('threadIdx.x') thread_y = te.thread_axis('threadIdx.y') thread_z = te.thread_axis('threadIdx.z') # Designate the memory hierarchy AS = s.cache_read(packed_data, 'shared', [Conv]) WS = s.cache_read(packed_kernel, 'shared', [Conv]) AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) ConvF = s.cache_write(Conv, 'wmma.accumulator') if Conv.op in s.outputs: output = Conv ConvS = s.cache_read(ConvF, 'shared', [Conv]) OL = ConvS else: output = s.outputs[0].output(0) s[Conv].set_scope('shared') OL = Conv out_dtype = Conv.dtype if isinstance(packed_kernel.op, te.tensor.ComputeOp) and packed_kernel.name == "packed_kernel": if autotvm.GLOBAL_SCOPE.in_tuning: s[packed_kernel].pragma( s[packed_kernel].op.axis[0], "debug_skip_region") else: with Target('cuda'): schedule_injective_from_existing(s, packed_kernel) if isinstance(pad_data.op, te.tensor.ComputeOp) and "pad" in pad_data.op.tag: s[pad_data].compute_inline() data = pad_data.op.input_tensors[0] if autotvm.GLOBAL_SCOPE.in_tuning: # skip this part during tuning to make recrods accurate # this part will be pre-computed during NNVM's pre-compute optimization pass s[pad_data].pragma(s[pad_data].op.axis[0], "debug_skip_region") else: data = pad_data s[data].compute_inline() data_dtype = data.dtype kernel_dtype = packed_kernel.dtype # Schedule for autotvm cfg.define_knob("block_row_warps", [1, 2, 4]) cfg.define_knob("block_col_warps", [1, 2, 4]) cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16]) cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16]) cfg.define_knob("chunk", [1, 2, 4, 8]) cfg.define_knob("fuse_pack", [0, 1]) cfg.define_knob("split_block_k_nums", [1, 2, 4, 8, 16, 32]) cfg.define_knob("vector_ws", [1, 8]) cfg.define_knob("vector_as", [1, 8, 16]) block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val warp_col_tiles = cfg["warp_col_tiles"].val chunk = cfg["chunk"].val vector_as = cfg["vector_as"].val vector_ws = cfg["vector_ws"].val split_block_k_nums = cfg["split_block_k_nums"].val fuse_pack = cfg["fuse_pack"].val if not fuse_pack: s[packed_data].compute_inline() else: with Target('cuda'): schedule_injective_from_existing(s, packed_data) if data_dtype in ['int4', 'uint4']: wmma_m = wmma_n = 8 wmma_k = 32 else: wmma_m = 8 wmma_n = 32 wmma_k = 16 warp_size = 32 # Schedule for output if len(s[output].op.axis) == 4: hc, wc, nc, oc, = output.op.axis nc, nnc = s[output].split(nc, factor=wmma_m) oc, ooc = s[output].split(oc, factor=wmma_n) else: hc, wc, nc, oc, nnc, ooc = output.op.axis kernel_scope, hc = s[output].split(hc, nparts=1) block_k = s[output].fuse(hc, wc) block_k, split_block_k = s[output].split( block_k, factor=split_block_k_nums) nc, nci = s[output].split(nc, factor=warp_row_tiles) block_i, nc = s[output].split(nc, factor=block_row_warps) oc, oci = s[output].split(oc, factor=warp_col_tiles) block_j, oc = s[output].split(oc, factor=block_col_warps) s[output].reorder(block_k, split_block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) t = s[output].fuse(nnc, ooc) _, tx = s[output].split(t, factor=warp_size) s[output].bind(block_k, block_z) 
s[output].bind(block_i, block_x) s[output].bind(block_j, block_y) s[output].bind(tx, thread_x) s[output].bind(nc, thread_y) s[output].bind(oc, thread_z) # Schedule wmma store s[OL].compute_at(s[output], block_j) hc, wc, nc, oc, nnc, ooc = OL.op.axis oc, oci = s[OL].split(oc, factor=warp_col_tiles) _, oc = s[OL].split(oc, factor=block_col_warps) nc, nci = s[OL].split(nc, factor=warp_row_tiles) _, nc = s[OL].split(nc, factor=block_row_warps) s[OL].reorder(nc, oc, nci, oci, nnc, ooc) s[OL].bind(nc, thread_y) s[OL].bind(oc, thread_z) # Schedule local computation s[ConvF].compute_at(s[OL], oc) _, _, n, o, nnf, oof = ConvF.op.axis ko, ki = s[ConvF].split(ic, factor=chunk) s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) cfg.define_reorder("reorder_inner", [ko, kh], policy="all") cfg["reorder_inner"].apply(s, ConvF, [ko, kh]) cfg["reorder_inner"].apply(s, ConvF, [ki, kw]) cfg.define_knob("compute_at_AS", [0, 1, 2, 3]) cfg.define_knob("compute_at_WS", [0, 1, 2, 3]) compute_at_AS = cfg["compute_at_AS"].val compute_at_WS = cfg["compute_at_WS"].val # Move intermediate computation into each output compute tile s[AF].compute_at(s[ConvF], kw) s[WF].compute_at(s[ConvF], kw) # Schedule for A's share memory if compute_at_AS == 0: s[AS].compute_at(s[ConvF], ki) elif compute_at_AS == 1: s[AS].compute_at(s[ConvF], kw) elif compute_at_AS == 2: s[AS].compute_at(s[ConvF], ko) else: s[AS].compute_at(s[ConvF], kh) _, _, n, _, nn, ii = AS.op.axis tx, xo = s[AS].split(n, nparts=block_row_warps) ty, _ = s[AS].split(xo, nparts=block_col_warps) t = s[AS].fuse(nn, ii) to, ti = s[AS].split(t, nparts=warp_size) ti, _t = s[AS].split(ti, factor=vector_as) s[AS].bind(tx, thread_y) s[AS].bind(ty, thread_z) s[AS].bind(to, thread_x) s[AS].vectorize(_t) # Schedule for W's share memory if compute_at_WS == 0: s[WS].compute_at(s[ConvF], ki) elif compute_at_WS == 1: s[WS].compute_at(s[ConvF], kw) elif compute_at_WS == 2: s[WS].compute_at(s[ConvF], ko) else: s[WS].compute_at(s[ConvF], kh) s[WS].compute_at(s[ConvF], kw) kh, kw, ic, o, ii, oo = WS.op.axis tx, xo = s[WS].split(o, nparts=block_row_warps) ty, _ = s[WS].split(xo, nparts=block_col_warps) t = s[WS].fuse(ii, oo) to, ti = s[WS].split(t, nparts=warp_size) ti, _t = s[WS].split(ti, factor=vector_ws) s[WS].bind(tx, thread_y) s[WS].bind(ty, thread_z) s[WS].bind(to, thread_x) s[WS].vectorize(ti) # double buffer cfg.define_knob('AS_double_buffer', [0, 1]) cfg.define_knob('WS_double_buffer', [0, 1]) if cfg['AS_double_buffer'].val: s[AS].double_buffer() if cfg['WS_double_buffer'].val: s[WS].double_buffer() # unroll cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) s[output].pragma(kernel_scope, 'unroll_explicit', False) shape = (wmma_m, wmma_n, wmma_k) AS_shape = (wmma_m, wmma_k) AL_shape = (wmma_m, wmma_k) WS_shape = (wmma_n, wmma_k) WL_shape = (wmma_n, wmma_k) CL_shape = (wmma_m, wmma_n) CS_shape = (wmma_m, wmma_n) AL_gemm = te.placeholder(AL_shape, name='A', dtype=data_dtype) WL_gemm = te.placeholder(WL_shape, name='B', dtype=kernel_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k") CL_compute = te.compute(CL_shape, lambda ii, jj: te.sum((AL_gemm[ii, k_gemm].astype( 'int32') * WL_gemm[jj, k_gemm].astype('int32')), axis=k_gemm), name='C') AL_strides = [wmma_k, 1] AS_strides = [wmma_k, 1] WL_strides = [wmma_k, 1] WS_strides = [wmma_k, 1] CL_strides = [wmma_n, 1] CS_strides = [wmma_n, 1] s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, 
"row_major", AS_shape, AL_shape, data_dtype)) s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, "col_major", WS_shape, WL_shape, kernel_dtype)) s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)) s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape)) return s
def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
    """Schedule group conv2d NCHW direct template"""
    workload = conv.op.attrs["workload"]
    groups = get_const_int(workload[6])
    num_filters = get_const_int(conv.shape[1])

    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis
    cfg.define_split("tile_n", n, num_outputs=4)
    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
    cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=2)
    cfg.define_split("tile_ry", ry, num_outputs=2)
    cfg.define_split("tile_rx", rx, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])

    target = tvm.target.Target.current()
    if target.kind.name in ["nvptx", "rocm"]:
        cfg.define_knob("unroll_explicit", [1])
    else:
        cfg.define_knob("unroll_explicit", [0, 1])

    pad_data, kernel = s[conv].op.input_tensors

    s[pad_data].compute_inline()

    if conv.op in s.outputs:
        output = conv
        OL = s.cache_write(conv, "local")
    else:
        output = s.outputs[0].output(0)
        s[conv].set_scope("local")
        OL = conv

    # create cache stage
    AA = s.cache_read(pad_data, "shared", [OL])
    WW = s.cache_read(kernel, "shared", [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    kernel_scope, n = s[output].split(n, nparts=1)

    g, f = s[output].split(f, nparts=groups)
    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
    bg, vg = cfg["tile_g"].apply(s, output, g)
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
    s[output].bind(bn, te.thread_axis("blockIdx.z"))
    s[output].bind(s[output].fuse(bg, bf), te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vn, te.thread_axis("vthread"))
    s[output].bind(vg, te.thread_axis("vthread"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
    if cfg["fuse_yx"].val:
        s[output].bind(tn, te.thread_axis("threadIdx.z"))
        s[output].bind(tf, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[OL].compute_at(s[output], tyx)

        # number of threads
        n_tz = cfg["tile_n"].size[2]
        n_ty = cfg["tile_f"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[OL].compute_at(s[output], tx)

        # number of threads
        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rci = cfg["tile_rc"].apply(s, OL, rc)
    ryo, ryi = cfg["tile_ry"].apply(s, OL, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)

    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        fused, tx = s[load].split(fused, factor=n_tx)
        fused, ty = s[load].split(fused, factor=n_ty)
        fused, tz = s[load].split(fused, factor=n_tz)
        s[load].bind(tz, te.thread_axis("threadIdx.z"))
        s[load].bind(ty, te.thread_axis("threadIdx.y"))
        s[load].bind(tx, te.thread_axis("threadIdx.x"))

    # unroll
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)

    N, CO, OH, OW = get_const_tuple(output.shape)
    _, CI_div_groups, KH, KW = get_const_tuple(kernel.shape)
    cfg.add_flop(2 * N * OH * OW * CO * CI_div_groups * KH * KW)
def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): """Schedule tensorcore template""" kh, kw, ic = s[Conv].op.reduce_axis out_dtype = Conv.dtype trans_paddata, kernel = s[Conv].op.input_tensors in_dtype = trans_paddata.dtype batch, _, _, _ = get_const_tuple(Conv.shape) _, _, _, out_channels = get_const_tuple(kernel.shape) paddata = s[trans_paddata].op.input_tensors # inline the pad and dtype transform s[trans_paddata].compute_inline() s[kernel].compute_inline() s[paddata[0]].compute_inline() # Designate the memory hierarchy AS = s.cache_read(trans_paddata, 'shared', [Conv]) WS = s.cache_read(kernel, 'shared', [Conv]) AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) ConvF = s.cache_write(Conv, 'wmma.accumulator') if Conv.op in s.outputs: output = Conv ConvS = s.cache_read(ConvF, 'shared', [Conv]) OL = ConvS else: output = s.outputs[0].output(0) s[Conv].set_scope('shared') OL = Conv # Schedule for autotvm cfg.define_knob("block_row_warps", [1, 2, 4]) cfg.define_knob("block_col_warps", [1, 2, 4]) cfg.define_knob("warp_row_tiles", [1, 2, 4]) cfg.define_knob("warp_col_tiles", [1, 2, 4]) cfg.define_knob("chunk", [1, 2, 4, 8]) cfg.define_knob("offset", [0, 8]) cfg.define_knob("vector_width", [1, 2, 4, 8]) if (batch % 16 == 0 and out_channels % 16 == 0): cfg.define_knob("wmma_m", [16, 8, 32]) elif (batch % 8 == 0 and out_channels % 32 == 0): cfg.define_knob("wmma_m", [8, 16, 32]) elif (batch % 32 == 0 and out_channels % 8 == 0): cfg.define_knob("wmma_m", [32, 16, 8]) # fallback support target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( target.id.name, target.model, 'conv2d_nhwc_tensorcore.cuda') cfg.fallback_with_reference_log(ref_log) block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val warp_col_tiles = cfg["warp_col_tiles"].val chunk = cfg["chunk"].val offset = cfg["offset"].val wmma_m = cfg["wmma_m"].val vector_width = cfg["vector_width"].val wmma_k = 16 if wmma_m == 16: wmma_n = 16 elif wmma_m == 8: wmma_n = 32 elif wmma_m == 32: wmma_n = 8 warp_size = 32 block_x = te.thread_axis('blockIdx.x') block_y = te.thread_axis('blockIdx.y') block_z = te.thread_axis('blockIdx.z') thread_x = te.thread_axis('threadIdx.x') thread_y = te.thread_axis('threadIdx.y') thread_z = te.thread_axis('threadIdx.z') # Define the intrin strides def get_strides(extents): return [np.prod(extents[i:]).tolist() for i in range(len(extents))] AS_align = chunk * wmma_k + offset WS_align = warp_col_tiles * block_col_warps * wmma_n + offset block_factor_n = wmma_m * warp_row_tiles * block_row_warps block_factor_o = wmma_n * warp_col_tiles * block_col_warps CS_align = block_factor_o + offset AS_strides = get_strides([1, 1, AS_align, 1]) AL_strides = get_strides([1, 1, wmma_k, 1]) WS_strides = get_strides([WS_align, 1]) WL_strides = get_strides([wmma_n * warp_col_tiles, 1]) CL_strides = get_strides([1, 1, wmma_n * warp_col_tiles, 1]) CS_strides = get_strides([1, 1, CS_align, 1]) # Schedule for output nc, hc, wc, oc = output.op.axis block_k = s[output].fuse(hc, wc) s[output].bind(block_k, block_z) block_i, nc = s[output].split(nc, factor=block_factor_n) block_j, oc = s[output].split(oc, factor=block_factor_o) s[output].reorder(block_k, block_i, block_j, nc, oc) t = s[output].fuse(nc, oc) t, ti = s[output].split(t, factor=vector_width) t, tx = s[output].split(t, factor=warp_size) t, ty = s[output].split(t, factor=block_row_warps) t, tz = s[output].split(t, 
factor=block_col_warps) s[output].bind(block_i, block_x) s[output].bind(block_j, block_y) s[output].bind(tz, thread_z) s[output].bind(ty, thread_y) s[output].bind(tx, thread_x) s[output].vectorize(ti) # Schedule wmma store s[OL].compute_at(s[output], block_j) nc, hc, wc, oc = OL.op.axis s[OL].reorder(hc, wc, nc, oc) s[OL].storage_align(wc, CS_align - 1, CS_align) oc, ooc = s[OL].split(oc, factor=wmma_n) oc, oci = s[OL].split(oc, factor=warp_col_tiles) _, oc = s[OL].split(oc, factor=block_col_warps) nc, nnc = s[OL].split(nc, factor=wmma_m) nc, nci = s[OL].split(nc, factor=warp_row_tiles) _, nc = s[OL].split(nc, factor=block_row_warps) s[OL].reorder(nc, oc, nci, oci, nnc, ooc) s[OL].bind(nc, thread_y) s[OL].bind(oc, thread_z) # Schedule wmma computation s[ConvF].compute_at(s[OL], oc) n, h, w, o = ConvF.op.axis n, nnf = s[ConvF].split(n, factor=wmma_m) o, oof = s[ConvF].split(o, factor=wmma_n) ic, ii = s[ConvF].split(ic, factor=wmma_k) ko, ki = s[ConvF].split(ic, factor=chunk) s[ConvF].reorder(kh, kw, ko, ki, n, o, nnf, oof, ii) s[AF].compute_at(s[ConvF], ki) s[WF].compute_at(s[ConvF], ki) # Schedule wmma load n, h, w, i = AF.op.axis n, nn = s[AF].split(n, factor=wmma_m) i, ii = s[AF].split(i, factor=wmma_k) s[AF].reorder(n, i, nn, ii) kh, kw, i, o = WF.op.axis i, ii = s[WF].split(i, factor=wmma_k) o, oo = s[WF].split(o, factor=wmma_n) s[WF].reorder(o, i, oo) s[WF].reorder(i, o, ii, oo) s[WS].compute_at(s[ConvF], ko) s[AS].compute_at(s[ConvF], ko) # Schedule for data's share memory n, h, w, i = AS.op.axis s[AS].reorder(h, w, n, i) s[AS].storage_align(w, AS_align - 1, AS_align) t = s[AS].fuse(n, i) t, ti = s[AS].split(t, factor=vector_width) t, tx = s[AS].split(t, factor=warp_size) t, ty = s[AS].split(t, factor=block_row_warps) _, tz = s[AS].split(t, factor=block_col_warps) s[AS].bind(ty, thread_y) s[AS].bind(tz, thread_z) s[AS].bind(tx, thread_x) s[AS].vectorize(ti) # Schedule for kernel's share memory kh, kw, ic, o = WS.op.axis t = s[WS].fuse(ic, o) s[WS].storage_align(ic, WS_align - 1, WS_align) t, ti = s[WS].split(t, factor=vector_width) t, tx = s[WS].split(t, factor=warp_size) t, ty = s[WS].split(t, factor=block_row_warps) _, tz = s[WS].split(t, factor=block_col_warps) s[WS].bind(ty, thread_y) s[WS].bind(tz, thread_z) s[WS].bind(tx, thread_x) s[WS].vectorize(ti) shape = (wmma_m, wmma_n, wmma_k) # tensorize the wmma process AS_shape = (wmma_m, 1, 1, wmma_k) AL_shape = (wmma_m, 1, 1, wmma_k) WS_shape = (wmma_k, wmma_n) WL_shape = (wmma_k, wmma_n) CL_shape = (wmma_m, 1, 1, wmma_n) CS_shape = (wmma_m, 1, 1, wmma_n) AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype) WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k") CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj: te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \ WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm), name='C') s[AF].tensorize(nn, intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, "row_major", AS_shape, AL_shape, in_dtype)) s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, "row_major", WS_shape, WL_shape, in_dtype)) s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)) s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape)) N, OH, OW, CO = get_const_tuple(output.shape) KH, KW, CI, _ = get_const_tuple(kernel.shape) cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
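# Quick check of the stride helper defined inside the schedule above: for
# extents e it returns the suffix products [prod(e[0:]), prod(e[1:]), ..., e[-1]],
# i.e. the strides handed to the wmma load/store/gemm intrinsics. Restated here
# only for illustration.
import numpy as np


def get_strides(extents):
    return [np.prod(extents[i:]).tolist() for i in range(len(extents))]


assert get_strides([1, 1, 16, 1]) == [16, 16, 16, 1]
assert get_strides([24, 1]) == [24, 1]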
def gen_ir_2d(data, indices, updates, axis, out, update_func): """Generate scatter ir for 2d inputs Parameters ---------- data : tir.Tensor The input data to the operator. indices : tir.Tensor The index locations to update. updates : tir.Tensor The values to update. axis : int The axis to scatter on out : tir.Tensor The output tensor. update_func: function The function to be applied to a destination and the corresponding update Returns ------- ret : tir The computational ir. """ n = data.shape[0] c = data.shape[1] ib = tvm.tir.ir_builder.create() out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) _memcpy_ir(ib, out_ptr, data_ptr, data.shape) indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) ni = indices.shape[0] ci = indices.shape[1] if axis == 0: with ib.new_scope(): j = te.thread_axis("blockIdx.x") ib.scope_attr(j, "thread_extent", ci) with ib.for_range(0, ni, name="i") as i: idx = i * ci + j index = indices_ptr[idx] with ib.if_scope(index < 0): update_func(out_ptr, (index + n) * c + j, updates_ptr[idx]) with ib.else_scope(): update_func(out_ptr, index * c + j, updates_ptr[idx]) else: with ib.new_scope(): i = te.thread_axis("blockIdx.x") ib.scope_attr(i, "thread_extent", ni) with ib.for_range(0, ci, name="j") as j: idx = i * ci + j index = indices_ptr[idx] with ib.if_scope(index < 0): update_func(out_ptr, i * c + (index + c), updates_ptr[idx]) with ib.else_scope(): update_func(out_ptr, i * c + index, updates_ptr[idx]) return ib.get()
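# Illustrative NumPy model of the 2-D scatter generated above, assuming
# update_func performs a plain assignment (scatter "update" mode). Negative
# indices wrap around exactly as in the IR.
import numpy as np


def scatter_2d_reference(data, indices, updates, axis):
    out = data.copy()
    ni, ci = indices.shape
    for i in range(ni):
        for j in range(ci):
            idx = indices[i, j]
            if axis == 0:
                out[idx, j] = updates[i, j]
            else:
                out[i, idx] = updates[i, j]
    return out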
def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
    # Note: `mode` ("update" or "add") is captured from the enclosing scope.
    ib = tvm.tir.ir_builder.create()

    data = ib.buffer_ptr(data_ptr)
    indices = ib.buffer_ptr(indices_ptr)
    updates = ib.buffer_ptr(updates_ptr)
    out = ib.buffer_ptr(out_ptr)

    # We combine all the indices dimensions but the first one into a single
    # dimension so we can iterate it in a single loop instead of an arbitrary
    # number of loops. We do the same thing for all the update dimensions.
    fused_indices_dimension = 1
    for i in indices_ptr.shape[1:]:
        fused_indices_dimension *= i

    fused_updates_dimension = 1
    for i in updates_ptr.shape[len(indices_ptr.shape) - 1:]:
        fused_updates_dimension *= i

    fused_shape = 1
    for i in data_ptr.shape:
        fused_shape *= i

    # For now we avoid parallelizing over dimensions indexed by `indices` as
    # there may be repeated indices and handling parallel accumulation can
    # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
    # work well when these dimensions are large enough to saturate memory
    # bandwidth, but performance will be bad when these dimensions are
    # small.
    bx = te.thread_axis("blockIdx.x")
    tx = te.thread_axis("threadIdx.x")
    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    tdim = min(max_threads, fused_updates_dimension)
    ib.scope_attr(tx, "thread_extent", tdim)
    bdim = ceil_div(fused_updates_dimension, tdim)
    ib.scope_attr(bx, "thread_extent", bdim)

    # Copy data into the output. This loop writes to the same portions of
    # memory as the following loop, so we do not need a memory sync.
    with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
        index = i * fused_updates_dimension + bx * tdim + tx
        with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
            out[index] = data[index]

    with ib.for_range(0, fused_indices_dimension) as i:
        j = bx * tdim + tx
        with ib.if_scope(j < fused_updates_dimension):
            offset = fused_updates_dimension
            index = j  # This is the x_M, .. x_{N-1} part of the index into out.
            # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
            # of the index into out.
            for l in reversed(range(indices_ptr.shape[0].value)):
                # indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                index += offset * indices[i + l * fused_indices_dimension]
                offset *= data_ptr.shape[l]
            if mode == "update":
                out[index] = updates[i * fused_updates_dimension + j]
            elif mode == "add":
                out[index] += updates[i * fused_updates_dimension + j]
            else:
                raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)

    return ib.get()
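# NumPy restatement of the scatter_nd indexing math above (a sketch under the
# same layout conventions: indices has shape (M, Y_0, ..., Y_{K-1}) and updates
# has shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})). For illustration only.
import numpy as np


def scatter_nd_reference(data, indices, updates, mode="update"):
    out = data.copy()
    m = indices.shape[0]
    for y in np.ndindex(indices.shape[1:]):
        # Gather the M leading output coordinates for this update slice.
        idx = tuple(int(indices[(l,) + y]) for l in range(m))
        if mode == "update":
            out[idx] = updates[y]
        elif mode == "add":
            out[idx] += updates[y]
        else:
            raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
    return out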
def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): """Generate scatter add ir for 1d inputs, using atomic_add instruction Parameters ---------- data : tir.Tensor The input data to the operator. indices : tir.Tensor The index locations to update. updates : tir.Tensor The values to update. axis : int The axis to scatter on out : tir.Tensor The output tensor. Returns ------- ret : tir The computational ir. """ assert axis == 0 n = data.shape[0] ib = tvm.tir.ir_builder.create() out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) max_threads = int( tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads with ib.new_scope(): nthread_bx = ceil_div(n, nthread_tx) tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * nthread_tx + tx with ib.if_scope(tid < n): out_ptr[tid] = data_ptr[tid] indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) ni = indices.shape[0] atomic_add_return = ib.allocate(updates.dtype, (1, ), name="atomic_add_return", scope="local") with ib.new_scope(): nthread_bx = ceil_div(ni, nthread_tx) tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * nthread_tx + tx with ib.if_scope(tid < ni): index = indices_ptr[tid] with ib.if_scope(index < 0): atomic_add_return[0] = atomic_add( tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]), updates_ptr[tid], ) with ib.else_scope(): atomic_add_return[0] = atomic_add( tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]), updates_ptr[tid], ) return ib.get()
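# The atomic adds above accumulate every duplicate index; np.add.at has the
# same semantics and serves as a compact reference (illustrative only).
import numpy as np


def scatter_add_1d_reference(data, indices, updates):
    out = data.copy()
    np.add.at(out, indices, updates)  # unbuffered, so repeated indices accumulate
    return out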
def _schedule_dense_int8(cfg, s, output): data, weight = s[output].op.input_tensors batch, in_dim = get_const_tuple(data.shape) out_dim, _ = get_const_tuple(weight.shape) in_dim_factor = 4 assert in_dim % in_dim_factor == 0, "Input dimension must divide {}".format(in_dim_factor) if in_dim % 16 == 0: in_dim_factor = 16 # create tuning space cfg.define_split("tile_y", batch, num_outputs=4) cfg.define_split("tile_x", out_dim, num_outputs=4) cfg.define_split("tile_k", in_dim // in_dim_factor, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) # create cache stage AA = s.cache_read(data, "shared", [output]) WW = s.cache_read(weight, "shared", [output]) CC = s.cache_write(output, "local") # handle bias if output.op not in s.outputs: s[output].compute_inline() output = s.outputs[0].output(0) n, x = s[output].op.axis # this is the scope to attach global config inside this kernel kernel_scope, n = s[output].split(n, nparts=1) ko = CC.op.reduce_axis[0] ko, ki = s[CC].split(ko, factor=4) ko, kt = cfg["tile_k"].apply(s, CC, ko) s[CC].tensorize(ki, _dp4a) by, vy, ty, yi = cfg["tile_y"].apply(s, output, n) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) s[output].reorder(by, bx, vy, vx, ty, tx, yi, xi) s[output].bind(by, te.thread_axis("blockIdx.y")) s[output].bind(bx, te.thread_axis("blockIdx.x")) s[output].bind(vy, te.thread_axis("vthread")) s[output].bind(vx, te.thread_axis("vthread")) s[output].bind(ty, te.thread_axis("threadIdx.y")) s[output].bind(tx, te.thread_axis("threadIdx.x")) n_ty = cfg["tile_y"].size[2] n_tx = cfg["tile_x"].size[2] s[CC].compute_at(s[output], tx) yo, xo = CC.op.axis[:2] s[CC].reorder(ko, kt, yo, xo, ki) for load in [AA, WW]: s[load].compute_at(s[CC], ko) outer, inner = s[load].split(s[load].op.axis[-1], factor=in_dim_factor) s[load].vectorize(inner) fused = s[load].op.axis[:-1] + [outer] fused = s[load].fuse(*fused) fused, tx = s[load].split(fused, factor=n_tx) fused, ty = s[load].split(fused, factor=n_ty) s[load].bind(tx, te.thread_axis("threadIdx.x")) s[load].bind(ty, te.thread_axis("threadIdx.y")) s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) s[output].pragma(kernel_scope, "unroll_explicit", False) return s
def schedule(attrs): cfg, s, output = attrs.auto_config, attrs.scheduler, attrs.outputs[0] th_vals, rd_vals = [attrs.get_extent(x) for x in output.op.axis], [ attrs.get_extent(x) for x in output.op.reduce_axis ] C = output A, B = C.op.input_tensors y, x = s[C].op.axis k = s[C].op.reduce_axis[0] # storage_align params factor = 16 offset = 8 layout = 'NN' for opt in attrs.options: if opt.startswith('layout/'): layout = opt[len('layout/'):] break ''' if dtype == 'int8': factor = 32 offset = 16 ''' # create cache stages AA = s.cache_read(A, "shared", [C]) if (layout == "NN" or layout == "TN"): s[AA].storage_align(AA.op.axis[0], factor, offset) AL = s.cache_read(AA, "local", [C]) BB = s.cache_read(B, "shared", [C]) if (layout == "TT" or layout == "NT"): s[BB].storage_align(BB.op.axis[0], factor, offset) BL = s.cache_read(BB, "local", [C]) CL = s.cache_write(C, "local") # autotvm search space definition cfg.define_knob("bx", [2, 4, 8]) cfg.define_knob("by", [16, 32, 64]) cfg.define_knob("step_k", [8, 16, 32]) cfg.define_knob("v", [4, 8]) by = cfg['by'].val bx = cfg['bx'].val step_k = cfg['step_k'].val v = cfg['v'].val # thread tile TX, TY = 8, 1 # warp tile cfg.define_knob("warp_m", [16, 8, 32]) warp_tile_m = cfg[ 'warp_m'].val # it could be 8, 16, 32 on CUDA version >= 10.0 warp_tile_k = 16 # it must be 16 # block tile tile_x = bx * TX tile_y = by * TY yo, ty = s[C].split(y, tile_y) ty, yi = s[C].split(ty, TY) # schedule for C stage xo, xi = s[C].split(x, tile_x) WX = min(warp_tile_m, tile_x) tz, xi = s[C].split(xi, WX) tx, xi = s[C].split(xi, TX) s[C].reorder(yo, xo, tz, ty, tx, yi, xi) s[C].bind(yo, te.thread_axis("blockIdx.y")) s[C].bind(xo, te.thread_axis("blockIdx.x")) s[C].bind(ty, te.thread_axis("threadIdx.y")) s[C].bind(tz, te.thread_axis("threadIdx.z")) s[C].bind(tx, te.thread_axis("threadIdx.x")) # schedule for CL stage ko, ki = s[CL].split(k, step_k * warp_tile_k) kl, ki = s[CL].split(ki, warp_tile_k) s[CL].compute_at(s[C], tx) yo, xo = CL.op.axis s[CL].reorder(ko, kl, ki, yo, xo) # schedule for AA stage s[AA].compute_at(s[CL], ko) xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx * v) tz, tx = s[AA].split(xi, factor=(WX // TX) * v) tx, vec = s[AA].split(tx, factor=v) fused = s[AA].fuse(s[AA].op.axis[0], xo) _, ty = s[AA].split(fused, factor=by) s[AA].bind(ty, te.thread_axis("threadIdx.y")) s[AA].bind(tz, te.thread_axis("threadIdx.z")) s[AA].bind(tx, te.thread_axis("threadIdx.x")) # vectorization is very important for float16/int8 inputs s[AA].vectorize(vec) # schedule for BB stage s[BB].compute_at(s[CL], ko) xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx * v) tz, tx = s[BB].split(xi, factor=(WX // TX) * v) tx, vec = s[BB].split(tx, factor=v) fused = s[BB].fuse(s[BB].op.axis[0], xo) _, ty = s[BB].split(fused, factor=by) s[BB].bind(ty, te.thread_axis("threadIdx.y")) s[BB].bind(tz, te.thread_axis("threadIdx.z")) s[BB].bind(tx, te.thread_axis("threadIdx.x")) s[BB].vectorize(vec) s[AL].compute_at(s[CL], kl) s[BL].compute_at(s[CL], kl) s[CL].pragma(ko, 'tensor_core')
def rnn_matexp(): n_num_step = 128 n_num_hidden = 1152 n_batch_size = 4 detect_global_barrier = DETECT_GLOBAL_BARRIER num_step = te.var("num_step") num_hidden = tvm.runtime.convert(n_num_hidden) batch_size = tvm.runtime.convert(n_batch_size) num_thread_y = 8 num_thread_x = 16 * 3 num_sm = 24 Whh = te.placeholder((num_hidden, num_hidden), name="Whh") s_init = te.compute((1, batch_size, num_hidden), lambda _, i, j: 1.0, name="init") s_state = te.placeholder((num_step, batch_size, num_hidden)) kh = te.reduce_axis((0, num_hidden), name="kh") s_update = te.compute( (num_step, batch_size, num_hidden), lambda t, i, j: te.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh), name="update", ) s_scan = tvm.te.scan(s_init, s_update, s_state) # schedule s = te.create_schedule(s_scan.op) CL = s_update SS = s.cache_read(s_state, "shared", [CL]) SL = s.cache_read(SS, "local", [CL]) WhhL = s.cache_read(Whh, "local", [CL]) ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y) CLF = s.rfactor(CL, ko) block_x = te.thread_axis((0, num_sm), "blockIdx.x") thread_x = te.thread_axis((0, num_thread_x), "threadIdx.x") thread_y = te.thread_axis((0, num_thread_y), "threadIdx.y") if PERSIST_KERNEL: s[s_scan.op].env_threads([block_x, thread_y, thread_x]) bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm) tx, xi = s[s_init].split(xi, nparts=num_thread_x) s[s_init].bind(bx, block_x) s[s_init].bind(tx, thread_x) bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm) tx, xi = s[s_update].split(xi, nparts=num_thread_x) s[s_update].bind(bx, block_x) s[s_update].bind(tx, thread_x) s[CL].bind(s[CL].op.reduce_axis[0], thread_y) s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0]) # Duplicate store predicate. s[CL].set_store_predicate(thread_y.equal(0)) if PERSIST_KERNEL: s[WhhL].compute_at(s[s_scan], thread_x) s[WhhL].unroll(WhhL.op.axis[0]) else: s[WhhL].compute_at(s[CLF], CLF.op.axis[3]) kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1) ko, ki = s[CLF].split(ki, factor=4) s[SS].compute_at(s[CLF], kr) s[SL].compute_at(s[CLF], ko) xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3) ty, xi = s[SS].split(xi, nparts=num_thread_y) tx, xi = s[SS].split(xi, nparts=num_thread_x) s[SS].bind(ty, thread_y) s[SS].bind(tx, thread_x) def check_device(target): with tvm.transform.PassContext( config={ "tir.UnrollLoop": { "auto_max_step": 128, }, "tir.detect_global_barrier": detect_global_barrier, }): f = tvm.build(s, [s_scan, Whh], target) dev = tvm.cuda(0) if target == "cuda" else tvm.cl(0) # launch the kernel. res_np = np.zeros( (n_num_step, n_batch_size, n_num_hidden)).astype("float32") Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32") Whh_np[:] = 2.0 / n_num_hidden Whh_np[:, n_num_hidden // 2:] = 0 res_a = tvm.nd.array(res_np, dev) Whh_a = tvm.nd.array(Whh_np, dev) # Skip first pass as it is compilation f(res_a, Whh_a) dev.sync() # measure time cost of second step. tstart = time.time() f(res_a, Whh_a) dev.sync() tgap = time.time() - tstart print("Time cost=%g" % tgap) # correctness if not SKIP_CHECK: res_cuda = res_a.asnumpy() res_cmp = np.ones_like(res_np).astype("float64") Whh_np = Whh_np.astype("float64") for t in range(1, n_num_step): res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np) for i in range(n_num_step): for j in range(n_num_hidden): if abs(res_cmp[i, 0, j] - res_cuda[i, 0, j]) > 1e-5: print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_cuda[i, 0, j])) tvm.testing.assert_allclose(res_cuda, res_cmp, rtol=1e-3) check_device("cuda")
def _callback(op):
    if op.tag == "conv2d_transpose_nchw":
        pad_data = op.input_tensors[0]
        kernel = op.input_tensors[1]
        conv = op.output(0)

        ##### space definition begin #####
        n, f, y, x = s[conv].op.axis
        rc = s[conv].op.reduce_axis[0]
        cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
        cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
        cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
        cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
        cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
        cfg.define_knob("auto_unroll_max_step", [64, 512, 1500])

        target = tvm.target.Target.current()
        if target.kind.name in ["nvptx", "rocm"]:
            cfg.define_knob("unroll_explicit", [1])
        else:
            cfg.define_knob("unroll_explicit", [0, 1])

        if cfg.is_fallback:
            N, F, Y, X = get_const_tuple(conv.shape)
            _fallback_schedule(N, F, Y, X)

        ##### space definition end #####

        if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
            s[kernel].compute_inline()

        if conv.op in s.outputs:
            output = conv
            OL = s.cache_write(conv, "local")
        else:
            output = s.outputs[0].output(0)
            s[conv].set_scope("local")
            OL = conv

        # create cache stage
        s[pad_data].set_scope("shared")
        AA = pad_data
        WW = s.cache_read(kernel, "shared", [OL])

        # tile and bind spatial axes
        n, f, y, x = s[output].op.axis
        kernel_scope, n = s[output].split(n, nparts=1)
        bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
        bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
        by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
        bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

        s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
        s[output].bind(bn, te.thread_axis("blockIdx.z"))
        s[output].bind(bf, te.thread_axis("blockIdx.y"))
        s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
        s[output].bind(vn, te.thread_axis("vthread"))
        s[output].bind(vf, te.thread_axis("vthread"))
        s[output].bind(vy, te.thread_axis("vthread"))
        s[output].bind(vx, te.thread_axis("vthread"))

        cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf

        if cfg["fuse_yx"].val:
            s[output].bind(tn, te.thread_axis("threadIdx.z"))
            s[output].bind(tf, te.thread_axis("threadIdx.y"))
            tyx = s[output].fuse(ty, tx)
            s[output].bind(tyx, te.thread_axis("threadIdx.x"))
            s[OL].compute_at(s[output], tyx)

            # number of threads
            n_tz = cfg["tile_n"].size[2]
            n_ty = cfg["tile_f"].size[2]
            n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
        else:
            s[output].bind(s[output].fuse(tn, tf), te.thread_axis("threadIdx.z"))
            s[output].bind(ty, te.thread_axis("threadIdx.y"))
            s[output].bind(tx, te.thread_axis("threadIdx.x"))
            s[OL].compute_at(s[output], tx)

            # number of threads
            n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
            n_ty = cfg["tile_y"].size[2]
            n_tx = cfg["tile_x"].size[2]

        # tile reduction axes
        n, f, y, x = s[OL].op.axis
        rc, ry, rx = s[OL].op.reduce_axis
        rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
        s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)

        s[AA].compute_at(s[OL], rx)
        s[WW].compute_at(s[OL], rx)

        # cooperative fetching
        for load in [AA, WW]:
            n, f, y, x = s[load].op.axis
            fused = s[load].fuse(f, y, x)
            tz, fused = s[load].split(fused, nparts=n_tz)
            ty, fused = s[load].split(fused, nparts=n_ty)
            tx, fused = s[load].split(fused, nparts=n_tx)
            s[load].bind(tz, te.thread_axis("threadIdx.z"))
            s[load].bind(ty, te.thread_axis("threadIdx.y"))
            s[load].bind(tx, te.thread_axis("threadIdx.x"))

        s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
        s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
def test_rpc_module(): # graph n = tvm.runtime.convert(1024) A = te.placeholder((n, ), name="A") B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") a_np = np.random.uniform(size=1024).astype(A.dtype) temp = utils.tempdir() # Establish remote connection with target hardware tracker = rpc.connect_tracker(tracker_host, tracker_port) remote = tracker.request(key, priority=0, session_timeout=60) # Compile the Graph for CPU target s = te.create_schedule(B.op) xo, xi = s[B].split(B.op.axis[0], factor=64) s[B].parallel(xi) s[B].pragma(xo, "parallel_launch_point") s[B].pragma(xi, "parallel_barrier_when_finish") f = tvm.build(s, [A, B], target, name="myadd_cpu") path_dso_cpu = temp.relpath("cpu_lib.so") f.export_library(path_dso_cpu, ndk.create_shared) # Execute the portable graph on cpu target print("Run CPU test ...") dev = remote.cpu(0) remote.upload(path_dso_cpu) f2 = remote.load_module("cpu_lib.so") a = tvm.nd.array(a_np, dev) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) time_f = f2.time_evaluator(f2.entry_name, dev, number=10) cost = time_f(a, b).mean print("%g secs/op\n" % cost) np.testing.assert_equal(b.numpy(), a.numpy() + 1) # Compile the Graph for OpenCL target if test_opencl: s = te.create_schedule(B.op) xo, xi = s[B].split(B.op.axis[0], factor=64) s[B].bind(xi, te.thread_axis("threadIdx.x")) s[B].bind(xo, te.thread_axis("blockIdx.x")) # Build the dynamic lib. # If we don't want to do metal and only use cpu, just set target to be target f = tvm.build(s, [A, B], tvm.target.Target("opencl", host=target), name="myadd") path_dso_cl = temp.relpath("dev_lib_cl.so") f.export_library(path_dso_cl, ndk.create_shared) print("Run GPU(OpenCL Flavor) test ...") dev = remote.cl(0) remote.upload(path_dso_cl) f1 = remote.load_module("dev_lib_cl.so") a = tvm.nd.array(a_np, dev) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) time_f = f1.time_evaluator(f1.entry_name, dev, number=10) cost = time_f(a, b).mean print("%g secs/op\n" % cost) np.testing.assert_equal(b.numpy(), a.numpy() + 1) # Compile the Graph for Vulkan target if test_vulkan: s = te.create_schedule(B.op) xo, xi = s[B].split(B.op.axis[0], factor=64) s[B].bind(xi, te.thread_axis("threadIdx.x")) s[B].bind(xo, te.thread_axis("blockIdx.x")) # Build the dynamic lib. # If we don't want to do metal and only use cpu, just set target to be target f = tvm.build(s, [A, B], tvm.target.Target("vulkan", host=target), name="myadd") path_dso_vulkan = temp.relpath("dev_lib_vulkan.so") f.export_library(path_dso_vulkan, ndk.create_shared) print("Run GPU(Vulkan Flavor) test ...") dev = remote.vulkan(0) remote.upload(path_dso_vulkan) f1 = remote.load_module("dev_lib_vulkan.so") a = tvm.nd.array(a_np, dev) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) time_f = f1.time_evaluator(f1.entry_name, dev, number=10) cost = time_f(a, b).mean print("%g secs/op\n" % cost) np.testing.assert_equal(b.numpy(), a.numpy() + 1)
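# The test above assumes tracker_host, tracker_port, key, target, test_opencl,
# test_vulkan and the ndk/utils helpers are defined elsewhere. A hedged sketch
# of how those globals might be configured (placeholder values only):
import os
from tvm import rpc
from tvm.contrib import ndk, utils

tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1")
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "9190"))
key = "android"                                   # RPC key the device registered with
target = "llvm -mtriple=arm64-linux-android"      # CPU/host target for the remote device
test_opencl = True
test_vulkan = False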
def copy_to_texture(stage): _y, _x, _v = s[stage].op.axis # TODO(csullivan): removing this vectorize results in numerical errors, autovectorize s[stage].vectorize(_v) s[stage].bind(_y, te.thread_axis("blockIdx.x")) s[stage].bind(_x, te.thread_axis("threadIdx.x"))
def get_valid_indices_ir(valid_boxes, valid_count, valid_indices): """Low level IR to get the ouput indices of valid boxes and the count of valid boxes Parameters ---------- valid_boxes: Buffer 2D Buffer indicating valid boxes with shape [batch_size, num_anchors]. Returns ------- valid_count: Buffer 1D Buffer of number of valid boxes per batch [batch_size]. valid_indices: Buffer 2D Buffer indicating output sorted indcies of valid boxes [batch_size, num_anchors]. """ batch_size = valid_boxes.shape[0] num_anchors = valid_boxes.shape[1] ib = tvm.tir.ir_builder.create() valid_boxes = ib.buffer_ptr(valid_boxes) valid_count = ib.buffer_ptr(valid_count) valid_indices = ib.buffer_ptr(valid_indices) max_threads = int( tvm.target.Target.current(allow_none=False).max_num_threads) with ib.if_scope(num_anchors > 0): # Copy boxes to valid_indices with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(num_anchors, max_threads) nthread_by = batch_size tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") by = te.thread_axis("blockIdx.y") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(by, "thread_extent", nthread_by) tid = bx * nthread_tx + tx with ib.if_scope(tid < num_anchors): valid_indices[by, tid] = valid_boxes[by, tid] nthread_tx = max_threads nthread_bx = ceil_div(num_anchors, max_threads) nthread_by = batch_size ## The following algorithm performs parallel exclusive scan to get ## a tensor that can later be used to select valid indices # Up Sweep of exclusive scan lim = tvm.tir.generic.cast( tvm.tir.ceil( tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64") with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width with ib.new_scope(): tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr( bx, "thread_extent", tvm.tir.generic.cast( ceil_div(num_anchors, max_threads * width), "int32"), ) tid = bx * nthread_tx + tx by = te.thread_axis("blockIdx.y") ib.scope_attr(by, "thread_extent", nthread_by) start = ib.allocate("int64", (1, ), name="start", scope="local") middle = ib.allocate("int64", (1, ), name="middle", scope="local") end = ib.allocate("int64", (1, ), name="end", scope="local") start[0] = width * tid with ib.if_scope(start[0] < num_anchors): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) end[0] = tvm.te.min(start[0] + width, num_anchors) with ib.if_scope(middle[0] < num_anchors): valid_indices[by * num_anchors + end[0] - 1] += valid_indices[by * num_anchors + middle[0] - 1] # Down Sweep of exclusive scan with ib.new_scope(): bx = te.thread_axis("blockIdx.x") ib.scope_attr(bx, "thread_extent", batch_size) with ib.if_scope(bx < batch_size): valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1] valid_indices[(bx + 1) * num_anchors - 1] = 0 with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << (lim - l2_width - 1) with ib.new_scope(): tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr( bx, "thread_extent", tvm.tir.generic.cast( ceil_div(num_anchors, max_threads * width), "int32"), ) tid = bx * nthread_tx + tx by = te.thread_axis("blockIdx.y") ib.scope_attr(by, "thread_extent", nthread_by) start = ib.allocate("int64", (1, ), name="start", scope="local") middle = ib.allocate("int64", (1, ), name="middle", scope="local") end = ib.allocate("int64", (1, ), name="end", scope="local") tmp = ib.allocate("int32", 
                                    (1, ), name="tmp", scope="local")
                start[0] = width * tid
                with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
                    end[0] = tvm.tir.min(start[0] + width, num_anchors)
                    with ib.if_scope(middle[0] < num_anchors):
                        tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
                        valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[by * num_anchors + end[0] - 1]
                        valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
    with ib.else_scope():
        with ib.new_scope():
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(bx, "thread_extent", batch_size)
            with ib.if_scope(bx < batch_size):
                valid_count[bx] = 0
    return ib.get()
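# The two passes above are the classic Blelloch work-efficient exclusive scan
# (up-sweep builds segment sums, down-sweep pushes prefixes back). A NumPy
# sketch of the same algorithm on a single row, for intuition only:
import numpy as np

def blelloch_exclusive_scan(a):
    n = len(a)
    size = 1
    while size < n:
        size *= 2
    buf = np.zeros(size, dtype=np.int64)
    buf[:n] = a
    # Up-sweep: accumulate partial sums at the right end of each segment.
    width = 2
    while width <= size:
        buf[width - 1::width] += buf[width // 2 - 1::width]
        width *= 2
    # Down-sweep: zero the root, then swap-and-add prefixes back down the tree.
    buf[-1] = 0
    width = size
    while width >= 2:
        old_right = buf[width - 1::width].copy()
        buf[width - 1::width] += buf[width // 2 - 1::width]
        buf[width // 2 - 1::width] = old_right
        width //= 2
    return buf[:n]

flags = np.array([1, 0, 1, 1, 0, 1])
print(blelloch_exclusive_scan(flags))                 # [0 1 1 2 3 3]
print(np.concatenate(([0], np.cumsum(flags)[:-1])))   # same prefix sums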
def schedule_conv2d_1x1_WCHNc_CRSKk(data, filt, packed_data, packed_filter, conv): # data: [W, C, H*N, c] # filter: [C, R*S*K, k] # output: [W, K, H, N, k] # conv2d( [N, C, H, W, c] , [1, 1, C, K, k] # inputs: (1, 128//4, 56, 56, 4), (1, 1, 128, 128//4, 4) # data: (56, 128//4, 56*1, 4) = (56, 32, 56, 4) # filt: (128, 1*1*128//4, 4) = (128, 32, 4) # conv: (56, 32, 56, 1, 4) s = te.create_schedule(conv.op) cfg = autotvm.get_config() s[packed_data].compute_inline() s[packed_filter].compute_inline() A, B, C = packed_data, packed_filter, conv At = s.cache_read(A, "global.texture", [C]) Bt = s.cache_read(B, "global.texture", [C]) Al = s.cache_read(At, "local", [C]) Bl = s.cache_read(Bt, "local", [C]) Cl = s.cache_write(C, "local") def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) block, thread = s[stage].split(fused, factor=32) s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) copy_to_texture(At) copy_to_texture(Bt) _w, _ko, _h, _n, _ki = s[C].op.axis kernel_scope, _n = s[C].split(_n, nparts=1) cfg.define_split("tile_f", _ko, num_outputs=4) cfg.define_split("tile_w", _w, num_outputs=4) cfg.define_split("tile_h", _h, num_outputs=4) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) bk, vk, tk, ki = cfg["tile_f"].apply(s, C, _ko) bw, vw, tw, wi = cfg["tile_w"].apply(s, C, _w) bh, vh, th, hi = cfg["tile_h"].apply(s, C, _h) s[C].reorder(bh, _n, vh, th, hi) bhn = s[C].fuse(bh, _n) s[C].bind(bk, te.thread_axis("blockIdx.z")) s[C].bind(bhn, te.thread_axis("blockIdx.y")) s[C].bind(bw, te.thread_axis("blockIdx.x")) s[C].bind(vk, te.thread_axis("vthread")) s[C].bind(vh, te.thread_axis("vthread")) s[C].bind(vw, te.thread_axis("vthread")) s[C].bind(tk, te.thread_axis("threadIdx.z")) s[C].bind(th, te.thread_axis("threadIdx.y")) s[C].bind(tw, te.thread_axis("threadIdx.x")) s[C].reorder(bw, bk, bhn, vw, vk, vh, tw, tk, th, ki, hi, wi, _ki) s[C].vectorize(_ki) # TODO(csullivan): Try uneven workgroup split # _wo, _wi = s[C].split(_w, factor=4) # #_hno, _hni = s[C].split(_hn, factor=8) # #s[C].reorder(_wo, _wi, _ko, _hno, _hni, _ki) # s[C].reorder(_wo, _ko, _hn, _ki, _wi) # s[C].unroll(_wi) # # mace: # # const int out_ch_blk = get_global_id(0); # # const int out_w_blk = get_global_id(1); # # const int out_hb = get_global_id(2); # bx = te.thread_axis("blockIdx.x") # by = te.thread_axis("blockIdx.y") # bz = te.thread_axis("blockIdx.z") # s[C].bind(_ko, bx) # s[C].bind(_wo, by) # s[C].bind(_hn, bz) # s[Cl].compute_at(s[C], _hn) s[Cl].compute_at(s[C], th) _wl, _kol, _hl, _nl, _kil = s[Cl].op.axis _khl, _kwl, _cl, _cl4 = s[Cl].op.reduce_axis cfg.define_split("tile_c", _cl, num_outputs=2) cfg.define_split("tile_kh", _khl, num_outputs=2) cfg.define_split("tile_kw", _kwl, num_outputs=2) _clo, _cli = cfg["tile_c"].apply(s, Cl, _cl) _khlo, _khli = cfg["tile_kh"].apply(s, Cl, _khl) _kwlo, _kwli = cfg["tile_kw"].apply(s, Cl, _kwl) # s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x) s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli, _kol, _hl, _nl, _kil, _wl) # s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli) # s[Cl].reorder(_cl, _cl4, _kil, _wl) s[Cl].unroll(_cl4) s[Cl].unroll(_wl) s[Cl].vectorize(_kil) _wla, _cla, _hnla, _cl4a = s[Al].op.axis s[Al].compute_at(s[Cl], _cli) s[Al].vectorize(_cl4a) s[Al].unroll(_wla) _clb, _rskolb, _kilb = s[Bl].op.axis s[Bl].compute_at(s[Cl], _cli) s[Bl].vectorize(_kilb) s[Bl].unroll(_clb) s[C].pragma(kernel_scope, "auto_unroll_max_step", 
cfg["auto_unroll_max_step"].val) WO, K, HO, N, K4 = get_const_tuple(C.shape) RSC, _, _ = get_const_tuple(B.shape) cfg.add_flop(2 * N * K * K4 * HO * WO * RSC) return s
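# Quick sanity check of the add_flop bookkeeping for the example shapes in the
# comments at the top of the schedule (conv (56, 32, 56, 1, 4), filter (128, 32, 4)):
WO, K, HO, N, K4 = 56, 32, 56, 1, 4
RSC = 128
print(2 * N * K * K4 * HO * WO * RSC)     # 102760448
# Same count as the usual 2 * N * H * W * C_in * C_out for a 1x1 convolution:
print(2 * 1 * 56 * 56 * 128 * 128)        # 102760448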
def nms_ir( data, sorted_index, valid_count, indices, out, box_indices, num_valid_boxes, max_output_size, iou_threshold, force_suppress, top_k, coord_start, id_index, score_index, return_indices, ): """Low level IR routing for transform location in multibox_detection operator. Parameters ---------- data : Buffer Buffer of output boxes with class and score. sorted_index : Buffer Buffer of output box indexes sorted by score. valid_count : Buffer Buffer of number of valid output boxes. indices : Buffer indices in original tensor, with shape [batch_size, num_anchors], represents the index of box in original data. It could be the third output out_indices of get_valid_counts. The values in the second dimension are like the output of arange(num_anchors) if get_valid_counts is not used before non_max_suppression. out : Buffer Output buffer, to be filled with sorted boxes. box_indices : Buffer A indices tensor mapping sorted indices to original indices This is the first output of NMS when return_indices=True. num_valid_boxes : Buffer Record the number of boxes that have survived IOU tests. This is the second output of NMS when return_indices=True. max_output_size : int Max number of output valid boxes for each instance. By default all valid boxes are returned. iou_threshold : float Overlapping(IoU) threshold to suppress object with smaller score. force_suppress : boolean Whether to suppress all detections regardless of class_id. top_k : int Keep maximum top k detections before nms, -1 for no limit. coord_start : int Start index of the consecutive 4 coordinates. id_index : int index of the class categories, -1 to disable. score_index : optional, int Index of the scores/confidence of boxes. return_indices : boolean Whether to return box indices in input data. Returns ------- stmt : Stmt The result IR statement. 
""" def get_boundaries(output, box_idx): l = tvm.te.min( output[box_idx], output[box_idx + 2], ) t = tvm.te.min( output[box_idx + 1], output[box_idx + 3], ) r = tvm.te.max( output[box_idx], output[box_idx + 2], ) b = tvm.te.max( output[box_idx + 1], output[box_idx + 3], ) return l, t, r, b def calculate_overlap(out_tensor, box_a_idx, box_b_idx): """Calculate overlap of two boxes.""" a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) # Overlapping width and height w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) # Overlapping area area = h * w # total area of the figure formed by box a and box b # except for overlapping area u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area return tvm.tir.Select(u <= 0.0, 0.0, area / u) batch_size = data.shape[0] num_anchors = data.shape[1] box_data_length = data.shape[2] ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data) sorted_index = ib.buffer_ptr(sorted_index) valid_count = ib.buffer_ptr(valid_count) indices = ib.buffer_ptr(indices) num_valid_boxes = ib.buffer_ptr(num_valid_boxes) out = ib.buffer_ptr(out) box_indices = ib.buffer_ptr(box_indices) if isinstance(iou_threshold, float): iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) top_k = tvm.tir.IntImm("int32", top_k) coord_start = tvm.tir.IntImm("int32", coord_start) id_index = tvm.tir.IntImm("int32", id_index) score_index = tvm.tir.IntImm("int32", score_index) force_suppress = tvm.tir.IntImm("int32", 1 if force_suppress else 0) max_threads = int( tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(num_anchors, max_threads) nthread_by = batch_size tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") by = te.thread_axis("blockIdx.y") ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) i = by base_idx = i * num_anchors * box_data_length with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Reorder output nkeep = if_then_else( tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) j = bx * max_threads + tx with ib.if_scope(j < num_anchors): box_indices[i * num_anchors + j] = -1 with ib.if_scope(j < nkeep): # Fill in out with sorted boxes with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = data[( base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)] with ib.else_scope(): # Indices > nkeep are discarded with ib.if_scope(j < num_anchors): with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = -1.0 with ib.else_scope(): with ib.if_scope(j < valid_count[i]): with ib.for_range(0, box_data_length) as k: offset = base_idx + j * box_data_length + k out[offset] = data[offset] box_indices[i * num_anchors + j] = j with ib.new_scope(): nthread_by = batch_size nthread_tx = max_threads by = te.thread_axis("blockIdx.y") tx = te.thread_axis("threadIdx.x") ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(tx, "thread_extent", nthread_tx) i = by base_idx = i * num_anchors * box_data_length num_valid_boxes_local = ib.allocate("int32", (1, ), name="num_valid_boxes_local", scope="local") num_valid_boxes_local[0] = 0 nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) def nms_inner_loop(ib, j): # The box 
j is valid, invalidate other boxes that overlap with j above iou_threshold # When return_indices is False, no need to populate box_indices if return_indices: with ib.if_scope(tx + 0 == 0): orig_idx = sorted_index[i * num_anchors + j] box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] num_valid_boxes_local[0] += 1 offset_j = j * box_data_length num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) with ib.for_range(0, num_iter_per_thread) as _k: k = j + 1 + _k * nthread_tx + tx offset_k = k * box_data_length with ib.if_scope( tvm.tir.all( k < nkeep, out[base_idx + offset_k + score_index] > 0, # is the box k still valid? tvm.tir.any( force_suppress > 0, id_index < 0, out[base_idx + offset_k + id_index] == out[base_idx + offset_j + id_index], ), )): iou = calculate_overlap( out, base_idx + offset_j + coord_start, base_idx + offset_k + coord_start, ) with ib.if_scope(iou >= iou_threshold): # invalidate the box k out[base_idx + offset_k + score_index] = -1.0 with ib.if_scope(id_index >= 0): out[base_idx + offset_k + id_index] = -1.0 # Make sure to do the next loop in a lock step ib.emit( tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) if isinstance(max_output_size, int): max_output_size = tvm.tir.const(max_output_size) with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms with ib.for_range(0, nkeep) as j: # Proceed to the inner loop if the box j is still valid with ib.if_scope( out[base_idx + (j * box_data_length) + score_index] > -1.0): with ib.if_scope(max_output_size > 0): # No need to do more iteration if we already reach max_output_size boxes with ib.if_scope( num_valid_boxes_local[0] < max_output_size): nms_inner_loop(ib, j) with ib.else_scope(): nms_inner_loop(ib, j) with ib.if_scope(tx + 0 == 0): num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): num_valid_boxes[i] = 0 return ib.get()
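# calculate_overlap above is a standard IoU, with min/max applied so the two
# corner points of each box may come in either order. A plain-Python reference
# for two boxes given as [x0, y0, x1, y1]:
def iou_reference(box_a, box_b):
    a_l, a_r = min(box_a[0], box_a[2]), max(box_a[0], box_a[2])
    a_t, a_b = min(box_a[1], box_a[3]), max(box_a[1], box_a[3])
    b_l, b_r = min(box_b[0], box_b[2]), max(box_b[0], box_b[2])
    b_t, b_b = min(box_b[1], box_b[3]), max(box_b[1], box_b[3])
    w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
    h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
    inter = w * h
    union = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - inter
    return 0.0 if union <= 0.0 else inter / union

print(iou_reference([0, 0, 2, 2], [1, 1, 3, 3]))  # 1/7 ~= 0.1429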
def schedule_depthwise_conv2d_NCHWc_KCRSk_acc32(cfg, s, output): """schedule optimized for batch size = 1""" conv = output.op.input_tensors[0] ##### space definition begin ##### n, fc, y, x, fb = s[conv].op.axis ry, rx = s[conv].op.reduce_axis cfg.define_split("tile_fc", fc, num_outputs=4) cfg.define_split("tile_y", y, num_outputs=4) cfg.define_split("tile_x", x, num_outputs=4) cfg.define_split("tile_ry", ry, num_outputs=2) cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) pad_data, flattened_kernel = s[conv].op.input_tensors kernel = s[flattened_kernel].op.input_tensors[0] s[flattened_kernel].compute_inline() s[pad_data].compute_inline() if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() kernel = flattened_kernel if conv.op in s.outputs: output = conv OL = s.cache_write(conv, "local") else: output = s.outputs[0].output(0) s[conv].set_scope("local") OL = conv # create cache stage AT = s.cache_read(pad_data, "global.texture", [OL]) WT = s.cache_read(kernel, "global.texture", [OL]) def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) block, thread = s[stage].split(fused, factor=32) s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) copy_to_texture(AT) copy_to_texture(WT) AA = s.cache_read(AT, "shared", [OL]) WW = s.cache_read(WT, "shared", [OL]) # tile and bind spatial axes n, fc, y, x, fb = s[output].op.axis kernel_scope, n = s[output].split(n, nparts=1) bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) bf = s[output].fuse(n, bf) s[output].bind(bf, te.thread_axis("blockIdx.z")) s[output].bind(by, te.thread_axis("blockIdx.y")) s[output].bind(bx, te.thread_axis("blockIdx.x")) s[output].bind(vf, te.thread_axis("vthread")) s[output].bind(vy, te.thread_axis("vthread")) s[output].bind(vx, te.thread_axis("vthread")) s[output].bind(tf, te.thread_axis("threadIdx.z")) s[output].bind(ty, te.thread_axis("threadIdx.y")) s[output].bind(tx, te.thread_axis("threadIdx.x")) s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) s[output].vectorize(fb) s[OL].compute_at(s[output], tx) # tile reduction axes n, fc, y, x, fb = s[OL].op.axis ry, rx = s[OL].op.reduce_axis ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) s[OL].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb) s[OL].vectorize(fb) # s[OL].unroll() s[AA].compute_at(s[OL], rxo) s[WW].compute_at(s[OL], rxo) # cooperative fetching for load in [AA, WW]: if load == WW: n, fyx, v = s[load].op.axis fused = s[load].fuse(n, fyx) else: n, f, y, x, v = s[load].op.axis fused = s[load].fuse(n, f, y, x) tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) s[load].bind(tz, te.thread_axis("threadIdx.z")) s[load].bind(ty, te.thread_axis("threadIdx.y")) s[load].bind(tx, te.thread_axis("threadIdx.x")) s[load].vectorize(v) # unroll s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) N, OCC, OH, OW, OCB = get_const_tuple(output.shape) ICC, MKHKW, ICB = get_const_tuple(kernel.shape) M = (OCC * OCB) // (ICC * ICB) KHKW = MKHKW // M if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW)
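# The FLOP accounting at the end assumes the flattened kernel layout
# (ICC, M*KH*KW, ICB). A worked example with made-up shapes (not from any real
# workload): a 3x3 depthwise kernel with channel multiplier M = 1.
N, OCC, OH, OW, OCB = 1, 16, 56, 56, 4
ICC, MKHKW, ICB = 16, 9, 4
M = (OCC * OCB) // (ICC * ICB)   # depth multiplier, here 1
KHKW = MKHKW // M                # 9 multiply-accumulate taps per output element
print(M, KHKW, 2 * N * OH * OW * OCC * OCB * KHKW)   # 1 9 3612672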
# # .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/conv_gpu_blocking.png # :align: center # :height: 308px # :width: 317px # # tile consts tile = 8 num_thread = 8 block_factor = tile * num_thread step = 8 vthread = 2 # Get the GPU thread indices block_x = te.thread_axis("blockIdx.x") block_y = te.thread_axis("blockIdx.y") block_z = te.thread_axis("blockIdx.z") thread_x = te.thread_axis((0, num_thread), "threadIdx.x") thread_y = te.thread_axis((0, num_thread), "threadIdx.y") thread_xz = te.thread_axis((0, vthread), "vthread", name="vx") thread_yz = te.thread_axis((0, vthread), "vthread", name="vy") # Split the workloads hi, wi, fi, ni = s[B].op.axis bz = s[B].fuse(hi, wi) by, fi = s[B].split(fi, factor=block_factor) bx, ni = s[B].split(ni, factor=block_factor) # Bind the iteration variables to GPU thread indices s[B].bind(bz, block_z)
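# With these constants each block covers block_factor = tile * num_thread = 64
# outputs along a split axis. Assuming the splits continue with nparts=vthread
# and then nparts=num_thread (as in the usual form of this schedule), a toy
# decomposition of the per-block index looks like:
tile, num_thread, vthread = 8, 8, 2
block_factor = tile * num_thread                   # 64
inner = block_factor // (vthread * num_thread)     # 4 outputs handled serially per thread
covered = sorted(
    v * (block_factor // vthread) + t * inner + i
    for v in range(vthread)
    for t in range(num_thread)
    for i in range(inner)
)
assert covered == list(range(block_factor))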
def mcpu_auto_schedule(s, output, prefix): hyper_params = [[-1, 2, 8, 4], [-1, 1, 512, 1]] slice_data, slice_reduce = [], [] for i in range(len(output.op.axis)): slice_data.append( cfg.define_split(f"{prefix}:D{i}", attrs.get_extent(output.op.axis[i]), num_outputs=4, init_vals=[ hyper_params[i % len(hyper_params)], ])) for i in range(len(output.op.reduce_axis)): slice_reduce.append( cfg.define_split(f"{prefix}:R{i}", attrs.get_extent(output.op.reduce_axis[i]), num_outputs=2, init_vals=[ [-1, 4], ])) unroll = cfg.define_knob(f"{prefix}:UN", [1, 4, 8, 16, 32, 64], init_vals=[ 1, ] if attrs.backend == 'c-mcpu_avx512' else [ 0, ]) output_local, = s.cache_write([output], "local") slice_axes = [] for i in range(len(output.op.axis)): slice_axes.append( cfg.apply_split(s, output_local, output_local.op.axis[i], slice_data[i])) if output.op.reduce_axis: reduce_at = cfg.define_knob( f"{prefix}:RA", [x for x in range(len(output.op.reduce_axis))], init_vals=[ 0, ]) output_local_K_o, output_local_K_i = cfg.apply_split( s, output_local, output_local.op.reduce_axis[reduce_at], slice_reduce[reduce_at]) output_local_K_o, output_local_K_i = [output_local_K_o ], [output_local_K_i] else: output_local_K_o, output_local_K_i = [], [] first, second, third, fourth = [x[0] for x in slice_axes], [ x[1] for x in slice_axes ], [x[2] for x in slice_axes], [x[3] for x in slice_axes] s[output_local].reorder(*(first + second + output_local_K_o + third + output_local_K_i + fourth)) slice_global_axes = [] for i in range(len(output.op.axis)): if cfg.define_knob(f"{prefix}:_{i}", [False, True], init_vals=[ 0, ]): slice_global_axes.append( cfg.apply_split(s, output, output.op.axis[i], [ -1, slice_data[i][1], int(np.product(slice_data[i][2:])) ])) else: slice_global_axes.append( cfg.apply_split( s, output, output.op.axis[i], [-1, 1, int(np.product(slice_data[i][1:]))])) s[output].reorder(*([x[0] for x in slice_global_axes] + [x[1] for x in slice_global_axes] + [x[2] for x in slice_global_axes])) s[output_local].compute_at(s[output], slice_global_axes[-1][1]) s[output].bind(s[output].fuse(*[x[0] for x in slice_global_axes]), te.thread_axis('threadIdx.x')) s[output_local].pragma(first[0], "auto_unroll_max_step", unroll) s[output_local].pragma(first[0], "unroll_explicit", True) # s[output_local].vectorize(fourth[-1]) s[output_local].unroll(fourth[-1])
from tvm import te batch = 128 in_channel = 64 in_size = 62 # cnhw -> nchw A = te.placeholder((in_channel, batch, in_size, in_size), name='A') A_ch = te.compute((batch, in_channel, in_size, in_size), lambda n, c, h, w: A[c, n, h, w], name='A_change') s = te.create_schedule(A_ch.op) block_x = te.thread_axis("blockIdx.x") block_y = te.thread_axis("blockIdx.y") block_z = te.thread_axis("blockIdx.z") thread_x = te.thread_axis("threadIdx.x") thread_y = te.thread_axis("threadIdx.y") blockdim, threaddim = 32, 31 n, c, h, w = s[A_ch].op.axis hw = s[A_ch].fuse(h, w) no, ni = s[A_ch].split(n, nparts=blockdim) co, ci = s[A_ch].split(c, nparts=blockdim) hwo, hwi = s[A_ch].split(hw, nparts=31 * 31) s[A_ch].reorder(no, co, hwo, ni, ci, hwi) s[A_ch].bind(no, block_y) s[A_ch].bind(co, block_x) s[A_ch].bind(hwo, thread_x)
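# The compute above is just a CNHW -> NCHW transpose; a NumPy one-liner
# expressing the same index mapping, useful for checking the schedule's output:
import numpy as np

a_np = np.random.uniform(size=(in_channel, batch, in_size, in_size)).astype("float32")
a_ref = a_np.transpose(1, 0, 2, 3)    # A_change[n, c, h, w] = A[c, n, h, w]
print(a_ref.shape)                    # (128, 64, 62, 62)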
def _schedule_dense_large_batch(cfg, s, C): """Schedule float32/64 dense with large batch size""" A, B = C.op.input_tensors batch, in_dim = get_const_tuple(A.shape) out_dim, _ = get_const_tuple(B.shape) k = C.op.reduce_axis[0] # create tuning space try: block_cand = [64, 128] vthread_cand = [2 ** x for x in range(1, 7)] n_thread_cand = [2 ** x for x in range(3, 7)] cfg.define_split( "tile_x", batch, num_outputs=4, filter=lambda x: ( x.size[1] in vthread_cand and x.size[2] in n_thread_cand and (x.size[1] * x.size[2] * x.size[3]) in block_cand ), ) cfg.define_split( "tile_y", out_dim, num_outputs=4, filter=lambda x: ( x.size[1] in vthread_cand and x.size[2] in n_thread_cand and (x.size[1] * x.size[2] * x.size[3]) in block_cand ), ) cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2) except IndexError: # Index error happens when no entities left after filtering, which was designed # to prune tuning space for better search efficiency. logger.debug("Tuning space was created without pruning due to unfit shapes") cfg.define_split("tile_x", batch, num_outputs=4) cfg.define_split("tile_y", out_dim, num_outputs=4) cfg.define_split("tile_k", in_dim, num_outputs=3) if cfg.is_fallback: if batch > 1: cfg["tile_x"] = SplitEntity([-1, 2, 16, 2]) else: cfg["tile_x"] = SplitEntity([1, 1, 1, 1]) if out_dim > 1: cfg["tile_y"] = SplitEntity([-1, 2, 16, 2]) else: cfg["tile_y"] = SplitEntity([1, 1, 1, 1]) if in_dim > 8: cfg["tile_k"] = SplitEntity([-1, 8, 1]) else: cfg["tile_k"] = SplitEntity([-1, 1, 1]) # Explicit memory access AA = s.cache_read(A, "shared", [C]) BB = s.cache_read(B, "shared", [C]) AL = s.cache_read(AA, "local", [C]) BL = s.cache_read(BB, "local", [C]) CC = s.cache_write(C, "local") # Deal with op fusion if C.op not in s.outputs: s[C].compute_inline() C = s.outputs[0].output(0) # Split and reorder computation bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0]) by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1]) s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi) s[CC].compute_at(s[C], tx) # Binding s[C].bind(by, te.thread_axis("blockIdx.y")) s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tyz, te.thread_axis("vthread")) s[C].bind(txz, te.thread_axis("vthread")) s[C].bind(ty, te.thread_axis("threadIdx.y")) s[C].bind(tx, te.thread_axis("threadIdx.x")) # Split reduction yo, xo = CC.op.axis ko, kt, ki = cfg["tile_k"].apply(s, CC, k) s[CC].reorder(ko, kt, ki, yo, xo) s[AA].compute_at(s[CC], ko) s[BB].compute_at(s[CC], ko) s[CC].unroll(kt) s[AL].compute_at(s[CC], kt) s[BL].compute_at(s[CC], kt) # Schedule for A's shared memory load num_thread_x = cfg["tile_x"].size[2] ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x) _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4) tx, xi = s[AA].split(xi, nparts=num_thread_x) s[AA].bind(ty, te.thread_axis("threadIdx.y")) s[AA].bind(tx, te.thread_axis("threadIdx.x")) s[AA].double_buffer() # Schedule for B' shared memory load num_thread_y = cfg["tile_y"].size[2] ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y) _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4) tx, xi = s[BB].split(xi, nparts=num_thread_y) s[BB].bind(ty, te.thread_axis("threadIdx.y")) s[BB].bind(tx, te.thread_axis("threadIdx.x")) s[BB].double_buffer()
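# The shapes unpacked above imply the usual dense convention: A is
# (batch, in_dim), B is (out_dim, in_dim), and the reduction k runs over
# in_dim, i.e. C = A @ B.T. A NumPy reference (illustrative shapes):
import numpy as np

batch, in_dim, out_dim = 256, 512, 1024
A_np = np.random.randn(batch, in_dim).astype("float32")
B_np = np.random.randn(out_dim, in_dim).astype("float32")
C_np = A_np @ B_np.T          # shape (batch, out_dim)
print(C_np.shape)             # (256, 1024)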
""" import tvm from tvm import te import numpy as np n = te.var('n') m = te.var('m') A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda i: A[i] * 2, name='B') s = te.create_schedule(B.op) bx, tx = s[B].split(B.op.axis[0], factor=64) s[B].bind(bx, te.thread_axis("blockIdx.x")) s[B].bind(tx, te.thread_axis("threadIdx.x")) print(tvm.lower(s, [A, B], simple_mode=True)) """ Results: primfn(A_1: handle, B_1: handle) -> () attr = {"tir.noalias": True, "global_symbol": "main"} buffers = {A: Buffer(A_2: handle, float32, [n: int32], [stride: int32], type="auto"), B: Buffer(B_2: handle, float32, [n], [stride_1: int32], type="auto")} buffer_map = {B_1: B, A_1: A} { attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 63), 64); attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
def gen_ir_4d(data, indices, updates, axis, out, update_func): """Generate scatter ir for 4d inputs Parameters ---------- data : tir.Tensor The input data to the operator. indices : tir.Tensor The index locations to update. updates : tir.Tensor The values to update. axis : int The axis to scatter on out : tir.Tensor The output tensor. update_func: function The function to be applied to a destination and the corresponding update Returns ------- ret : tir The computational ir. """ warp_size = tvm.target.Target.current(False).thread_warp_size n = data.shape[0] c = data.shape[1] h = data.shape[2] w = data.shape[3] ib = tvm.tir.ir_builder.create() out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) _memcpy_ir(ib, out_ptr, data_ptr, data.shape) indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) ni = indices.shape[0] ci = indices.shape[1] hi = indices.shape[2] wi = indices.shape[3] if axis == 0: with ib.new_scope(): j = te.thread_axis("blockIdx.y") ib.scope_attr(j, "thread_extent", ci) k = te.thread_axis("blockIdx.z") ib.scope_attr(k, "thread_extent", hi) tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", warp_size) with ib.for_range(0, ni, name="i") as i: with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: l = l_ * warp_size + tx with ib.if_scope(l < wi): idx = ((i * ci + j) * hi + k) * wi + l index = indices_ptr[idx] with ib.if_scope(index < 0): update_func(out_ptr, (((index + n) * c + j) * h + k) * w + l, updates_ptr[idx]) with ib.else_scope(): update_func(out_ptr, ((index * c + j) * h + k) * w + l, updates_ptr[idx]) elif axis == 1: with ib.new_scope(): i = te.thread_axis("blockIdx.x") ib.scope_attr(i, "thread_extent", ni) k = te.thread_axis("blockIdx.z") ib.scope_attr(k, "thread_extent", hi) tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", warp_size) with ib.for_range(0, ci, name="j") as j: with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: l = l_ * warp_size + tx with ib.if_scope(l < wi): idx = ((i * ci + j) * hi + k) * wi + l index = indices_ptr[idx] with ib.if_scope(index < 0): update_func(out_ptr, ((i * c + (index + c)) * h + k) * w + l, updates_ptr[idx]) with ib.else_scope(): update_func(out_ptr, ((i * c + index) * h + k) * w + l, updates_ptr[idx]) elif axis == 2: with ib.new_scope(): i = te.thread_axis("blockIdx.x") ib.scope_attr(i, "thread_extent", ni) j = te.thread_axis("blockIdx.y") ib.scope_attr(j, "thread_extent", ci) tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", warp_size) with ib.for_range(0, hi, name="k") as k: with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: l = l_ * warp_size + tx with ib.if_scope(l < wi): idx = ((i * ci + j) * hi + k) * wi + l index = indices_ptr[idx] with ib.if_scope(index < 0): update_func(out_ptr, ((i * c + j) * h + (index + h)) * w + l, updates_ptr[idx]) with ib.else_scope(): update_func(out_ptr, ((i * c + j) * h + index) * w + l, updates_ptr[idx]) else: with ib.new_scope(): i = te.thread_axis("blockIdx.x") ib.scope_attr(i, "thread_extent", ni) j = te.thread_axis("blockIdx.y") ib.scope_attr(j, "thread_extent", ci) k = te.thread_axis("blockIdx.z") ib.scope_attr(k, "thread_extent", hi) with ib.for_range(0, wi, name="l") as l: idx = ((i * ci + j) * hi + k) * wi + l index = indices_ptr[idx] with ib.if_scope(index < 0): update_func(out_ptr, ((i * c + j) * h + k) * w + (index + w), updates_ptr[idx]) with ib.else_scope(): update_func(out_ptr, ((i * c + j) * h + k) * w + index, updates_ptr[idx]) return ib.get()
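# The semantics implemented by the IR above (with update_func being plain
# assignment) match an elementwise scatter along `axis` with Python-style
# negative-index wrap-around. A NumPy reference sketch; the helper name and
# shapes below are chosen only for illustration:
import numpy as np

def scatter_4d_reference(data, indices, updates, axis):
    out = data.copy()
    for idx in np.ndindex(indices.shape):
        pos = list(idx)
        # `index % dim` maps a negative index to index + dim, as in the IR.
        pos[axis] = indices[idx] % data.shape[axis]
        out[tuple(pos)] = updates[idx]
    return out

data = np.zeros((2, 3, 2, 2), dtype="float32")
indices = np.array([[[[0, -1]], [[1, 0]], [[2, 1]]]])          # shape (1, 3, 1, 2)
updates = np.ones(indices.shape, dtype="float32")
print(scatter_4d_reference(data, indices, updates, axis=1)[0, :, 0, :])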