import tvm
from tvm import te
from tvm import topi
import numpy as np

######################################################################
# Basic example
# -------------
# Let's revisit the sum of rows operation (equivalent to :code:`B = numpy.sum(A, axis=1)`).
# To compute the sum of rows of a two-dimensional TVM tensor A, we should
# specify the symbolic operation as well as the schedule as follows.
#
n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name='A')
k = te.reduce_axis((0, m), "k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
s = te.create_schedule(B.op)

######################################################################
# To examine the IR code in a human-readable format, we can do
#
print(tvm.lower(s, [A], simple_mode=True))

######################################################################
# However, for such a common operation we had to define the reduce axis ourselves,
# as well as the explicit computation with :code:`te.compute`.
# Imagine how many details we would need to provide for more complicated operations.
# Fortunately, we can replace those two lines with a simple call to :code:`topi.sum`,
# much like :code:`numpy.sum`.
#
C = topi.sum(A, axis=1)
ts = te.create_schedule(C.op)
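######################################################################
# The replacement can be checked quickly (a sketch; it reuses the tensors
# defined above and assumes an LLVM-enabled TVM build): lowering the
# :code:`topi.sum` schedule shows the same reduction pattern, and running it
# matches :code:`numpy.sum` numerically.
#
print(tvm.lower(ts, [A, C], simple_mode=True))

func = tvm.build(ts, [A, C], "llvm")
a_np = np.random.uniform(size=(32, 16)).astype(A.dtype)
c_nd = tvm.nd.array(np.zeros((32,), dtype=C.dtype))
func(tvm.nd.array(a_np), c_nd)
np.testing.assert_allclose(c_nd.numpy(), a_np.sum(axis=1), rtol=1e-5)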
def group_conv2d_nhwc(Input, Filter, stride, padding, dilation, groups, out_dtype=None): """Group convolution operator in NHWC layout. Parameters ---------- Input : tvm.te.Tensor 4-D with shape [batch, in_height, in_width, in_channel] Filter : tvm.te.Tensor 4-D with shape [filter_height, filter_width, in_channel // groups, num_filter] stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] padding : int or a list/tuple of 2 or 4 ints padding size, or [pad_height, pad_width] for 2 ints, or [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation : int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] groups : int number of groups out_dtype : str The output type. This is used for mixed precision. Returns ------- Output : tvm.te.Tensor 4-D with shape [batch, out_height, out_width, out_channel] """ if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 assert isinstance(dilation, int) or len(dilation) == 2 if isinstance(stride, int): stride_h = stride_w = stride else: stride_h, stride_w = stride if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = get_const_tuple(Input.shape) kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) assert in_channel % groups == 0, "input channels must divide group size" assert num_filter % groups == 0, "output channels must divide group size" pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w)) # compute the output shape out_channel = num_filter out_height = simplify( (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1 ) out_width = simplify( (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1 ) # compute graph pad_before = [0, pad_top, pad_left, 0] pad_after = [0, pad_down, pad_right, 0] temp = pad(Input, pad_before, pad_after, name="pad_temp") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") rc = te.reduce_axis((0, in_channel // groups), name="rc") return te.compute( (batch, out_height, out_width, out_channel), lambda nn, yy, xx, ff: te.sum( temp[ nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ff // (num_filter // groups) * (in_channel // groups) + rc, ].astype(out_dtype) * Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc], ), tag="group_conv2d_nhwc", )
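# A usage sketch for the operator above (illustrative shapes and dtypes; it
# assumes the helpers the operator calls -- get_const_tuple, get_pad_tuple,
# pad, simplify -- are in scope as in the TOPI sources, e.g. imported from
# tvm.topi.utils and tvm.topi.nn):
data = te.placeholder((1, 56, 56, 64), name="data", dtype="int8")
# in_channel // groups == 64 // 4 == 16, num_filter == 128
weight = te.placeholder((3, 3, 16, 128), name="weight", dtype="int8")
out = group_conv2d_nhwc(data, weight, stride=1, padding=1, dilation=1, groups=4, out_dtype="int32")
print(out.shape)  # expected: [1, 56, 56, 128]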
data_shape = (batch_size // env.BATCH, in_channels // env.BLOCK_IN, env.BATCH, env.BLOCK_IN) weight_shape = (out_channels // env.BLOCK_OUT, in_channels // env.BLOCK_IN, env.BLOCK_OUT, env.BLOCK_IN) output_shape = (batch_size // env.BATCH, out_channels // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT) num_ops = in_channels * out_channels * batch_size * 2 # Reduction axes ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name='ic') ic_tns = te.reduce_axis((0, env.BLOCK_IN), name='ic_tns') # Input placeholder tensors data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype) weight = te.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype) # Copy buffers data_buf = te.compute(data_shape, lambda *i: data(*i), "data_buf") weight_buf = te.compute(weight_shape, lambda *i: weight(*i), "weight_buf") # Declare matrix multiply computation
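# The declaration itself reduces over both input-channel axes defined above.
# This is a sketch in the spirit of the VTA matrix-multiply tutorial;
# env.acc_dtype is the accumulator type provided by the VTA environment.
res_gemm = te.compute(
    output_shape,
    lambda bo, co, bi, ci: te.sum(
        data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype)
        * weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype),
        axis=[ic, ic_tns],
    ),
    name="res_gemm",
)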
def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): """Convolution operator in NCHW layout. Parameters ---------- Input : tvm.te.Tensor 4-D with shape [batch, in_channel, in_height, in_width] Filter : tvm.te.Tensor 4-D with shape [num_filter, in_channel, filter_height, filter_width] stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] padding : int or a list/tuple of 2 or 4 ints padding size, or [pad_height, pad_width] for 2 ints, or [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] Returns ------- Output : tvm.te.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 assert isinstance(dilation, int) or len(dilation) == 2 if isinstance(stride, int): stride_h = stride_w = stride else: stride_h, stride_w = stride if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation batch, in_channel, in_height, in_width = Input.shape num_filter, channel, kernel_h, kernel_w = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w) ) out_channel = num_filter out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) # compute graph pad_before = [0, 0, pad_top, pad_left] pad_after = [0, 0, pad_down, pad_right] temp = pad(Input, pad_before, pad_after, name="pad_temp") rc = te.reduce_axis((0, in_channel), name="rc") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") return te.compute( (batch, out_channel, out_height, out_width), lambda nn, ff, yy, xx: te.sum( temp[nn, rc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype( out_dtype ) * Filter[ff, rc, ry, rx].astype(out_dtype), axis=[rc, ry, rx], ), tag="conv2d_nchw", )
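# A usage sketch (illustrative shapes; as above, the TOPI helpers the operator
# calls -- get_pad_tuple, pad, simplify -- are assumed to be in scope): a 7x7
# convolution with stride 2 and padding 3 halves the spatial dimensions.
data = te.placeholder((1, 3, 224, 224), name="data")
weight = te.placeholder((64, 3, 7, 7), name="weight")
out = conv2d_nchw(data, weight, stride=2, padding=3, dilation=1)
print(out.shape)  # expected: [1, 64, 112, 112]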
def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="float32"): """Conv2D operator for nChw[x]c layout. Parameters ---------- data : tvm.te.Tensor 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] kernel : tvm.te.Tensor 6-D with shape [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block, num_filter_block] stride : int or a list/tuple of two ints stride size, or [stride_height, stride_width] padding : int or a list/tuple of 2 or 4 ints padding size, or [pad_height, pad_width] for 2 ints, or [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] layout : str Input data layout out_layout : str Output data layout out_dtype : str output data type Returns ------- output : tvm.te.Tensor 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] """ # layout and out_layout are not used here, # we keep them for debug convenience when dumping autotvm workload HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride) dilation_h, dilation_w = ( dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) ) n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn target = tvm.target.Target.current(allow_none=False) oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn groups = ic_chunk // ic_chunk_group dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w) ) HPAD = pad_top + pad_down WPAD = pad_left + pad_right # output shape out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 oshape = (n, oc_chunk, out_height, out_width, oc_bn) pad_before = (0, 0, pad_top, pad_left, 0) pad_after = (0, 0, pad_down, pad_right, 0) # DOPAD DOPAD = HPAD != 0 or WPAD != 0 if DOPAD: data_pad = pad(data, pad_before, pad_after, name="data_pad") else: data_pad = data ic = te.reduce_axis((0, in_channel), name="ic") kh = te.reduce_axis((0, kernel_height), name="kh") kw = te.reduce_axis((0, kernel_width), name="kw") idxdiv = tvm.tir.indexdiv idxmod = tvm.tir.indexmod return te.compute( oshape, lambda n, oc_chunk, oh, ow, oc_block: te.sum( data_pad[ n, idxdiv(ic, ic_bn), oh * HSTR + kh * dilation_h, ow * WSTR + kw * dilation_w, idxmod(ic, ic_bn), ].astype(out_dtype) * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block], axis=[ic, kh, kw], ), name="conv2d_NCHWc", tag="conv2d_NCHWc", )
def dot_int8_int8_int32_neon_82(int32_lanes, dtype="uint"): """ Int8 dot product by every 4 elements using ARM v8.2 udot. This function takes two arrays of int8 datatype -- data[4] and kernel[int32_lanes][4] -- and computes a dot product of data[4] with every 4 elements of kernels, resulting in output[int32_lanes] of uint32 datatype. The pseudo code is as follows. .. code-block:: c void dot_int8_int8_int32(int8 data[4], int8 kernel[16][4], int32 output[16]){ for (int i = 0; i < int32_lanes; i++){ out[i] = 0; for (int k = 0; k < 4; k++){ out[i] += data[k] * kernel[i][k] } } } Physically, the kernel array sits in a vector register and the data[4] is broadcasted to another vector register. This function returns a TensorIntrin that can be used to tensorize a schedule. Parameters ---------- int32_lanes : int How many int32/uint32 to produce dtype : str, optional, {"uint", "int"} Whether it works on unsigned int or signed int Returns ------- intrin : TensorIntrin The ARM uint8 TensorIntrin that can be used in tensorizing schedule """ num_int8_elements = 4 # 4 int8 elements in int32 data = te.placeholder((num_int8_elements, ), dtype="%s8" % dtype, name="data") kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="%s8" % dtype, name="kernel") k = te.reduce_axis((0, num_int8_elements), name="k") C = te.compute( (int32_lanes, ), lambda i: te.sum(data[k].astype("%s32" % dtype) * kernel[i, k].astype( "%s32" % dtype), axis=k), name="C", ) a_buffer = tvm.tir.decl_buffer(data.shape, dtype="%s8" % dtype, name="a_buffer", offset_factor=1, strides=[1]) b_buffer = tvm.tir.decl_buffer( kernel.shape, dtype="%s8" % dtype, name="b_buffer", offset_factor=1, strides=[te.var("s"), 1], ) def _intrin_func(ins, outs): def _instr(index): ib = tvm.tir.ir_builder.create() if index == 1: ib.emit(outs[0].vstore( 0, tvm.tir.const(0, "%s32x%d" % (dtype, int32_lanes)))) return ib.get() dtype_a = "%s8x%d" % (dtype, num_int8_elements) dtype_b = "%s8x%d" % (dtype, int32_lanes * num_int8_elements) dtype_c = "%s32x%d" % (dtype, int32_lanes) a_int8 = ins[0].vload([0], dtype_a) re_int32 = tvm.tir.call_intrin("%s32" % dtype, "tir.reinterpret", a_int8) # broadcast a vec_ai32 = re_int32.astype(dtype_c) vec_a = tvm.tir.call_intrin(dtype_b, "tir.reinterpret", vec_ai32) vec_b = ins[1].vload([0, 0], dtype_b) vec_c = outs[0].vload([0], dtype_c) inst = "udot" if dtype == "uint" else "sdot" inst = "llvm.aarch64.neon.%s.v%di32.v%di8" % ( inst, int32_lanes, int32_lanes * num_int8_elements, ) vdot = tvm.tir.call_llvm_pure_intrin(dtype_c, inst, tvm.tir.const(2, "uint32"), vec_c, vec_a, vec_b) ib.emit(outs[0].vstore(0, vdot)) return ib.get() # body, reset, update return _instr(0), _instr(1), _instr(2) buffer_params = {"offset_factor": 1} return te.decl_tensor_intrin( C.op, _intrin_func, binds={ data: a_buffer, kernel: b_buffer }, default_buffer_params=buffer_params, )
def gemm_acc_4x4_int8_int8_int32(dtype): """ Int8 4x4 matrix multiplication and accumulation using sdot/udot instructions. This function takes two arrays of int8 datatype -- A[4][4] and B[4][4] and produces a 4x4 matrix which is equal to A*B'. The pseudo code is as follows. .. code-block:: c void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){ for (int i = 0; i < 4; i++){ for (int j = 0; i < 4; i++){ for (int k = 0; k < 4; k++){ C[i][j] += A[i][k] * B[j][k] } } } Notes: * The tiling strategy is picked to maximize register usage. Parameters ---------- dtype : str, {"uint8", "int8"} Whether it works on unsigned int or signed int Returns ------- intrin : TensorIntrin The Arm TensorIntrin that can be used in tensorizing schedule """ assert dtype in ["uint8", "int8"] # This needs to be a variable number of "rows" since TVM # "thinks" I only need to compute one row because of # padding A = te.placeholder((te.var("rows"), 4), dtype, name="A") B = te.placeholder((4, 4), dtype, name="B") dtype_vec = dtype + "x16" k = te.reduce_axis((0, 4), name="k") C = te.compute( (te.var("rows"), 4), lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k), name="C", ) aa_buffer = tvm.tir.decl_buffer(A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]) bb_buffer = tvm.tir.decl_buffer(B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]) cc_buffer = tvm.tir.decl_buffer(C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]) llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot" def _intrin_func(ins, outs): def _instr(index): ib = tvm.tir.ir_builder.create() if index == 1: for i in range(0, 4): ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4"))) return ib.get() # Load all the elements of tile A. # vec_a = [a, b, c, d, # e, f, g, h, # l, m, n, o, # p, q, r, s]; vec_a = ins[0].vload([0, 0], dtype_vec) # Replicate 4 times the i-th row of A. For instance, # vec_a[0] = [a, b, c, d, # a, b, c, d, # a, b, c, d, # a, b, c, d,]; vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)] # Load all the elements of B. Remember that B # is transposed: # vec_b = [0, 4, 8, 12, # 1, 5, 9, 13, # 2, 6, 10, 14, # 3, 7, 11, 15,]; vec_b = ins[1].vload([0, 0], dtype_vec) # Execute the dot product for i in range(0, 4): vec_c = outs[0].vload([i, 0], "int32x4") # Compute the product between the i-th row of A # and all the rows of B. Remember that sdot/udot # subdive the input vectors in 16 elements # and then take the dot product among each group. # The result is stored in a int32x4 register # # For instance, for i=0, we have: # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12, # a*1+b*5+c*9+d*13, # a*2+b*6+c*10+d*14, # a*3+b*7+c*11+d*15] vdot = tvm.tir.call_llvm_intrin( "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_b, vec_aa[i], ) # Store the result ib.emit(outs[0].vstore([i, 0], vdot)) return ib.get() # body, reset, update return _instr(0), _instr(1), _instr(2) buffer_params = {"offset_factor": 1} return te.decl_tensor_intrin( C.op, _intrin_func, binds={ A: aa_buffer, B: bb_buffer, C: cc_buffer }, default_buffer_params=buffer_params, )
import tvm
from tvm import te
import numpy as np

######################################################################
# Define Matrix Multiplication
# ----------------------------
# Take matrix multiplication as our example.
# Matmul first multiplies the corresponding elements between the two matrices,
# then accumulates across a certain axis.
# The following lines describe the computation :code:`A * B^T` in TVM.
#
N, M, L = 1024, 512, 64
A = te.placeholder((N, L), name='A')
B = te.placeholder((M, L), name='B')
k = te.reduce_axis((0, L), name='k')
C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name='C')
s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))

######################################################################
# Schedule the Matmul
# -------------------
# Now, suppose we have an accelerator that supports
# matrix-vector multiplication (GEMV) as a hardware primitive,
# which can take a reduce axis of arbitrary size,
# but the other axis needs to be no larger than 16.
# Thus we break down the matmul loops to make the innermost loops a (16x64) GEMV.
#
factor = 16
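######################################################################
# With the factor chosen, we split the second output axis so that the
# innermost loops form the (16x64) GEMV to be tensorized later (a sketch
# following the tensorize tutorial; it reuses :code:`s`, :code:`C` and
# :code:`factor` from above).
#
x, y = C.op.axis
(z,) = C.op.reduce_axis
yo, yi = s[C].split(y, factor=factor)
s[C].reorder(x, yo, yi, z)
print(tvm.lower(s, [A, B, C], simple_mode=True))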
def test_gemm(): # graph nn = 2048 n = te.var("n") n = tvm.runtime.convert(nn) m, l = n, n A = te.placeholder((l, n), name="A") B = te.placeholder((l, m), name="B") k = te.reduce_axis((0, l), name="k") C = te.compute((m, n), lambda ii, jj: te.sum(A[k, jj] * B[k, ii], axis=k), name="C") # schedule s = te.create_schedule(C.op) AA = s.cache_read(A, "shared", [C]) BB = s.cache_read(B, "shared", [C]) AL = s.cache_read(AA, "local", [C]) BL = s.cache_read(BB, "local", [C]) CC = s.cache_write(C, "local") scale = 8 num_thread = 8 block_factor = scale * num_thread block_x = te.thread_axis("blockIdx.x") thread_x = te.thread_axis((0, num_thread), "threadIdx.x") block_y = te.thread_axis("blockIdx.y") thread_y = te.thread_axis((0, num_thread), "threadIdx.y") thread_xz = te.thread_axis((0, 2), "vthread", name="vx") thread_yz = te.thread_axis((0, 2), "vthread", name="vy") by, yi = s[C].split(C.op.axis[0], factor=block_factor) bx, xi = s[C].split(C.op.axis[1], factor=block_factor) s[C].bind(by, block_y) s[C].bind(bx, block_x) s[C].reorder(by, bx, yi, xi) tyz, yi = s[C].split(yi, nparts=2) ty, yi = s[C].split(yi, nparts=num_thread) txz, xi = s[C].split(xi, nparts=2) tx, xi = s[C].split(xi, nparts=num_thread) s[C].bind(tyz, thread_yz) s[C].bind(txz, thread_xz) s[C].bind(ty, thread_y) s[C].bind(tx, thread_x) s[C].reorder(tyz, txz, ty, tx, yi, xi) s[CC].compute_at(s[C], tx) yo, xo = CC.op.axis ko, ki = s[CC].split(k, factor=8) kt, ki = s[CC].split(ki, factor=1) s[CC].reorder(ko, kt, ki, yo, xo) s[AA].compute_at(s[CC], ko) s[BB].compute_at(s[CC], ko) s[CC].unroll(kt) s[AL].compute_at(s[CC], kt) s[BL].compute_at(s[CC], kt) # Schedule for A's shared memory load ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread) _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4) tx, xi = s[AA].split(xi, nparts=num_thread) s[AA].bind(ty, thread_y) s[AA].bind(tx, thread_x) s[AA].vectorize(xi) # Schedule for B' shared memory load ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread) _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4) tx, xi = s[BB].split(xi, nparts=num_thread) s[BB].bind(ty, thread_y) s[BB].bind(tx, thread_x) s[BB].vectorize(xi) s[AA].double_buffer() s[BB].double_buffer() # correctness def check_device(device): dev = tvm.device(device, 0) if not dev.exist: print("Skip because %s is not enabled" % device) return print("Device %s" % device) f = tvm.build(s, [A, B, C], device) # launch the kernel. n, m, l = nn, nn, nn a_np = np.random.uniform(size=(n, l)).astype(A.dtype) b_np = np.random.uniform(size=(m, l)).astype(B.dtype) a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), dev) for i in range(2): f(a, b, c) tvm.testing.assert_allclose(c.numpy(), np.dot(b_np.T, a_np), rtol=1e-5) num_flops = 2 * nn * nn * nn num_runs = 10 timer_f = f.time_evaluator(f.entry_name, dev, number=num_runs) t = timer_f(a, b, c).mean GFLOPS = num_flops / (t * 1e3) / 1e6 print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS)) for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]: with tvm.transform.PassContext( config={ "tir.UnrollLoop": { "auto_max_step": 128, "explicit_unroll": device != "cuda" } }): check_device(device)
def measure_bandwidth_sum( total_item, item_per_thread, stride, base_type, bits, lanes, target, target_host, remote, dev, n_times, ): """measure memory bandwidth of gpu by product reduction for a given type The IR for measurement is for each thread for i in 1..num_per_thread: y[global_id] = y[global_id] * x[base + i * stride] Parameters ---------- total_item: int number of elements in input array item_per_thread: int number of elements each thread accumulates stride: int stride in memory access base_type: str can be "int", "float" bits: int can be 16, 32 lanes: int lane of the vector type, can be 1, 2, 4, 8, 16 target: :any:`tvm.target.Target` the target and option of the compilation. target_host : str or :any:`tvm.target.Target` host compilation target dev: Device the device of array remote: tvm.rpc.RPCSession remote rpc session n_times: int number of runs for taking mean Returns ------- GBPS: float gigabyte per second """ target, target_host = Target.check_and_update_host_consist(target, target_host) n, m = total_item, item_per_thread n //= lanes base_type = str(base_type) + str(bits) dtype = base_type if lanes == 1 else base_type + "x" + str(lanes) k = te.reduce_axis((0, m), name="k") x = te.placeholder((n,), dtype=dtype, name="x") op = te.comm_reducer(lambda x, y: x * y, lambda t: tvm.tir.const(1, dtype=t), name="sum") y = te.compute( (n // m,), lambda i: op(x[i // stride * stride * m + i % stride + k * stride], axis=k) ) s = te.create_schedule(y.op) yo, yi = s[y].split(y.op.axis[0], target.max_num_threads) s[y].bind(yo, te.thread_axis("blockIdx.x")) s[y].bind(yi, te.thread_axis("threadIdx.x")) s[y].unroll(k) try: func = tvm.build(s, [x, y], target) x = tvm.nd.empty((n,), dtype=dtype, device=dev) y = tvm.nd.empty((n // m,), dtype=dtype, device=dev) func = _convert_to_remote(func, remote) time_f = func.time_evaluator(func.entry_name, dev, number=n_times) time = time_f(x, y).mean except tvm._ffi.base.TVMError: # build error (occur when device does not support half) return -1 return 1.0 * (total_item * bits / 8) / 1e9 / time
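# A quick sanity check of the bandwidth formula returned above (illustrative
# numbers, not a measurement): moving 2**25 32-bit elements touches
# 2**25 * 32 / 8 = 134,217,728 bytes; if the kernel takes 10 ms, the reported
# bandwidth is about 13.4 GB/s.
total_item, bits, time = 2 ** 25, 32, 0.01
print(1.0 * (total_item * bits / 8) / 1e9 / time)  # ~13.42 GB/s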
def test_gemm_gpu(N, times, bn, num_block, num_thread): assert bn <= N assert num_thread * num_thread * 16 <= N assert num_block * num_block * 2 <= N A = te.placeholder((N, N), name="A") B = te.placeholder((N, N), name="Btmp") k = te.reduce_axis((0, N), name="k") packedB = te.compute((N, N / bn, bn), lambda x, y, z: B[x, y * bn + z], name="B") C = te.compute( (N, N), lambda ii, jj: te.sum(A[ii, k] * packedB[k, jj / bn, jj % bn], axis=k), name="C") s = te.create_schedule(C.op) CC = s.cache_write(C, "local") block_x = te.thread_axis("blockIdx.x") block_y = te.thread_axis("blockIdx.y") thread_x = te.thread_axis("threadIdx.x") thread_y = te.thread_axis("threadIdx.y") thread_xz = te.thread_axis((0, 2), "vthread", name="vx") thread_yz = te.thread_axis((0, 2), "vthread", name="vy") pby, pbi = s[packedB].split(packedB.op.axis[0], nparts=num_thread) pbx, pbj = s[packedB].split(packedB.op.axis[1], nparts=num_thread) s[packedB].bind(pby, thread_y) s[packedB].bind(pbx, thread_x) pbz, pbk = s[packedB].split(packedB.op.axis[2], factor=8) s[packedB].vectorize(pbk) by, yi = s[C].split(C.op.axis[0], nparts=num_block) bx, xi = s[C].split(C.op.axis[1], nparts=num_thread) s[C].bind(by, block_y) s[C].bind(bx, thread_y) s[C].reorder(by, bx, yi, xi) tyz, yi = s[C].split(yi, nparts=2) ty, yi = s[C].split(yi, nparts=num_block) txz, xi = s[C].split(xi, nparts=2) tx, xi = s[C].split(xi, nparts=num_thread) s[C].reorder(tyz, txz, ty, tx, yi, xi) s[C].bind(tyz, thread_yz) s[C].bind(txz, thread_xz) s[C].bind(ty, block_x) s[C].bind(tx, thread_x) xyi, xxi = s[C].split(xi, factor=8) s[C].reorder(tyz, txz, ty, tx, yi, xyi, xxi) s[C].vectorize(xxi) s[CC].compute_at(s[C], yi) yo, xo = CC.op.axis s[CC].reorder(k, yo, xo) xo, xi = s[CC].split(xo, factor=8) s[CC].vectorize(xi) ko, ki = s[CC].split(k, factor=2) s[CC].unroll(ki) print(tvm.lower(s, [A, B, C], simple_mode=True)) f = tvm.build(s, [A, B, C], "opencl", target_host=target, name="gemm_gpu") temp = utils.tempdir() path_dso = temp.relpath("gemm_gpu.so") f.export_library(path_dso, ndk.create_shared) # connect to the proxy remote = rpc.connect(proxy_host, proxy_port, key=key) dev = remote.cl(0) remote.upload(path_dso) f = remote.load_module("gemm_gpu.so") evaluate(f, dev, N, times)
def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile): out_dtype = out_dtype or data.dtype N, C, IH, IW = get_const_tuple(data.shape) if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation if len(kernel.shape) == 4: pre_packed = False C, M, KH, KW = get_const_tuple(kernel.shape) else: # kernel tensor is pre packed pre_packed = True C, M, KH, KW, VC = get_const_tuple(kernel.shape) C = C * VC dilated_kernel_h = (KH - 1) * dilation_h + 1 dilated_kernel_w = (KW - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w)) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1 OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 # pack data HPAD = pad_top + pad_down WPAD = pad_left + pad_right DOPAD = HPAD != 0 or WPAD != 0 if DOPAD: data_pad = nn.pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right), name="data_pad") else: data_pad = data # fallback support # Currently, Mali schedule doesn't use it like conv2d. if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( "arm_cpu", "rk3399", "depthwise_conv2d_nchw_spatial_pack.arm_cpu") cfg.fallback_with_reference_log(ref_log) # ==================== define configuration space ==================== n, c, oh, ow = cfg.axis(N), cfg.axis(C), cfg.axis(OH), cfg.axis(OW) kh, kw = cfg.reduce_axis(KH), cfg.reduce_axis(KW) # Currently, Mali schedule doesn't use it like conv2d. # Leave num_tile for possible future use of Mali schedule if num_tile == 2: # for arm cpu co, vc = cfg.define_split("tile_co", c, num_outputs=2) oh, vh = cfg.define_split("tile_oh", oh, num_outputs=2) ow, vw = cfg.define_split("tile_ow", ow, num_outputs=2) else: raise RuntimeError("Invalid num_tile") cfg.define_reorder( "reorder_0", [n, co, oh, ow, kh, kw, vh, vw, vc], policy="candidate", candidate=[[n, co, oh, ow, kh, kw, vh, vw, vc], [n, co, oh, ow, kh, kw, vc, vh, vw]], ) cfg.define_reorder( "reorder_1", [n, co, oh, ow, vh, vw, vc], policy="candidate", candidate=[ [n, co, oh, ow, vh, vw, vc], [n, co, oh, ow, vc, vh, vw], [n, co, oh, ow, vh, vc, vw], ], ) cfg.define_annotate("ann_reduce", [kh, kw], policy="try_unroll") cfg.define_annotate("ann_spatial", [vh, vw, vc], policy="try_unroll_vec") # ==================================================================== VC = cfg["tile_co"].size[-1] VH = cfg["tile_oh"].size[-1] VW = cfg["tile_ow"].size[-1] kvshape = (C // VC, M, KH, KW, VC) ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC) oshape = (N, C * M, OH, OW) if dilation_h != 1 or dilation_w != 1: # undilate input data dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW) data_vec = te.compute( dvshape, lambda n, h, w, c, kh, kw, vh, vw: data_pad[n][c][ (h * VH + vh) * HSTR + kh * dilation_h][ (w * VW + vw) * WSTR + kw * dilation_w], name="data_vec_undilated", ) else: dvshape = (N, OH // VH, OW // VW, C, VH * HSTR + KH - 1, VW * WSTR + KW - 1) data_vec = te.compute( dvshape, lambda n, h, w, c, vh, vw: data_pad[n][c][h * VH * HSTR + vh][ w * VW * WSTR + vw], name="data_vec", ) if pre_packed: kernel_vec = kernel else: kernel_vec = te.compute( kvshape, lambda co, m, kh, kw, vc: kernel[co * VC + vc][m][kh][kw], name="kernel_vec") kh = te.reduce_axis((0, KH), name="kh") kw = te.reduce_axis((0, KW), name="kw") idxdiv = tvm.tir.indexdiv idxmod = tvm.tir.indexmod if dilation_h != 1 or dilation_w != 1: conv = te.compute( 
ovshape, lambda n, co, h, w, vh, vw, vc: te.sum( data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw]. astype(out_dtype) * kernel_vec[idxdiv( co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype), axis=[kh, kw], ), name="depthwise_conv", ) else: conv = te.compute( ovshape, lambda n, co, h, w, vh, vw, vc: te.sum( data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh, vw * WSTR + kw].astype(out_dtype) * kernel_vec[idxdiv( co, M), idxmod(co, M), kh, kw, vc].astype( out_dtype), axis=[kh, kw], ), name="depthwise_conv", ) output = te.compute( oshape, lambda n, co, h, w: conv[n, idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW), idxmod(h, VH), idxmod(w, VW), idxmod(co, VC), ], name="output_unpack", tag="spatial_depthwise_conv2d_nchw_output", ) return output
def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype): """TOPI compute callback for depthwise_conv2d nhwc Parameters ---------- cfg: ConfigEntity The config for this template data : tvm.te.Tensor 4-D with shape [batch, in_height, in_width, in_channel] kernel : tvm.te.Tensor 4-D with shape [filter_height, filter_width, in_channel, channel_multiplier] strides : list of two ints [stride_height, stride_width] padding : list of two ints [pad_height, pad_width] dilation : list of two ints [dilation_height, dilation_width] out_dtype: str The output type. This is used for mixed precision. Returns ------- output : tvm.te.Tensor 4-D with shape [batch, out_height, out_width, out_channel] """ out_dtype = out_dtype or data.dtype N, IH, IW, IC = get_const_tuple(data.shape) if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape) dilated_kernel_h = (KH - 1) * dilation_h + 1 dilated_kernel_w = (KW - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w)) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1 OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 if pad_top or pad_left or pad_down or pad_right: data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad") else: data_pad = data output_shape = (N, OH, OW, IC * channel_multiplier) idxdiv = tvm.tir.indexdiv idxmod = tvm.tir.indexmod reduce_h = te.reduce_axis((0, KH), name="reduce_h") reduce_w = te.reduce_axis((0, KW), name="reduce_w") out = te.compute( output_shape, lambda n, h, w, c: te.sum( data_pad[n, HSTR * h + dilation_h * reduce_h, w * WSTR + reduce_w * dilation_w, idxdiv(c, channel_multiplier), ].astype(out_dtype) * kernel[reduce_h, reduce_w, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)].astype(out_dtype), axis=[reduce_h, reduce_w], ), name="depthwise_conv2d_nhwc_output", ) return out
def f(n): rv = te.reduce_axis((0, n)) return te.sum(X[rv], axis=rv)
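# A minimal usage sketch for the helper above (it assumes X is a 1-D placeholder
# that the closure refers to, with tvm and te imported as elsewhere in this
# section): each call creates a fresh reduction axis, so the helper can serve
# directly as the body of a compute.
n = te.var("n")
X = te.placeholder((n,), name="X")
Y = te.compute((1,), lambda _: f(n), name="Y")
s = te.create_schedule(Y.op)
print(tvm.lower(s, [X, Y], simple_mode=True))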
def gemm_acc_2x2_int8_int8_int32(dtype): """ Int8 2x2 matrix multiplication using smmla/ummla instructions This function takes two arrays of int8 datatype -- A[2][8] and B[2][8] and produces a 2x2 matrix which is equal to A*B' The pseudo code is as follows. .. code-block:: c void mmla_2x2_int8_int8_int32(int8 A[2][8], int8 B[2][8], int32 C[2][2]){ for (int i = 0; i < 2; i++){ for (int j = 0; i < 2; i++){ for (int k = 0; k < 8; k++){ C[i][j] += A[i][k] * B[j][k] } } } Parameters ---------- dtype : str, {"uint8", "int8"} Whether it works on unsigned int or signed int Returns ------- intrin : TensorIntrin The Arm TensorIntrin that can be used in tensorizing schedule """ assert dtype in ["uint8", "int8"] A = te.placeholder((2, 8), dtype, name="A") B = te.placeholder((2, 8), dtype, name="B") dtype_vec = dtype + "x16" k = te.reduce_axis((0, 8), name="k") C = te.compute( (2, 2), lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k), name="C", ) aa_buffer = tvm.tir.decl_buffer(A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]) bb_buffer = tvm.tir.decl_buffer(B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]) cc_buffer = tvm.tir.decl_buffer(C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]) llvm_intrin = "llvm.aarch64.neon.smmla" if dtype == "int8" else "llvm.aarch64.neon.ummla" def _intrin_func(ins, outs): def _instr(index): ib = tvm.tir.ir_builder.create() if index == 1: ib.emit(outs[0].vstore([0, 0], tvm.tir.const(0, "int32x4"))) return ib.get() # Load in vec_a the two rows of A # vec_a = [a, b, c, d, e, f, g, h; # i, j, k, l, m, n, o, p,] vec_a = ins[0].vload([0, 0], dtype_vec) # Load in vec_b the two rows of B # vec_b = [0, 2, 4, 6, 8, 10, 12, 14; # 1, 3, 5, 7, 9, 11, 13, 14,] vec_b = ins[1].vload([0, 0], dtype_vec) # Execute the matrix multiplication via (s/u)mmla: # vec_c = [a*0 + b*2 + c*4 + d*6 +e*8 + f*10 + g*12 + h*14; # a*1 + b*3 + c*5 + d*7 +e*9 + f*11 + g*13 + h*15; # i*0 + j*2 + k*4 + l*6 +m*8 + n*10 + o*12 + p*14; # i*1 + j*3 + k*5 + l*7 +m*9 + n*11 + o*13 + p*15] vec_c = outs[0].vload([0, 0], "int32x4") vmmla = tvm.tir.call_llvm_intrin( "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_a, vec_b, ) # Store the result ib.emit(outs[0].vstore([0, 0], vmmla)) return ib.get() # body, reset, update return _instr(0), _instr(1), _instr(2) buffer_params = {"offset_factor": 1} return te.decl_tensor_intrin( C.op, _intrin_func, binds={ A: aa_buffer, B: bb_buffer, C: cc_buffer }, default_buffer_params=buffer_params, )
def test_gemm(): # graph nn = 1024 n = tvm.runtime.convert(nn) m = n l = n A = te.placeholder((n, l), name='A') B = te.placeholder((m, l), name='B') k = te.reduce_axis((0, l), name='k') C = te.compute((n, m), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k), name='CC') # schedule s = te.create_schedule(C.op) xtile, ytile = 32, 32 scale = 8 num_thread = 8 block_factor = scale * num_thread block_x = te.thread_axis("blockIdx.x") thread_x = te.thread_axis("threadIdx.x") block_y = te.thread_axis("blockIdx.y") thread_y = te.thread_axis("threadIdx.y") CC = s.cache_write(C, "local") AA = s.cache_read(A, "shared", [CC]) BB = s.cache_read(B, "shared", [CC]) by, yi = s[C].split(C.op.axis[0], factor=block_factor) bx, xi = s[C].split(C.op.axis[1], factor=block_factor) s[C].reorder(by, bx, yi, xi) s[C].bind(by, block_y) s[C].bind(bx, block_x) ty, yi = s[C].split(yi, nparts=num_thread) tx, xi = s[C].split(xi, nparts=num_thread) s[C].reorder(ty, tx, yi, xi) s[C].bind(ty, thread_y) s[C].bind(tx, thread_x) yo, xo = CC.op.axis s[CC].reorder(k, yo, xo) s[CC].compute_at(s[C], tx) s[AA].compute_at(s[CC], k) s[BB].compute_at(s[CC], k) s[AA].double_buffer() s[BB].double_buffer() ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread) tx, xi = s[AA].split(xi, nparts=num_thread) s[AA].bind(ty, thread_y) s[AA].bind(tx, thread_x) ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread) tx, xi = s[BB].split(xi, nparts=num_thread) s[BB].bind(ty, thread_y) s[BB].bind(tx, thread_x) # lowering test s = s.normalize() # one line to build the function. def check_device(device): ctx = tvm.context(device, 0) if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return with tvm.target.create(device): f = tvm.build(s, [A, B, C]) # launch the kernel. n = nn m = n l = n a_np = np.random.uniform(size=(n, l)).astype(A.dtype) b_np = np.random.uniform(size=(m, l)).astype(B.dtype) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) ftimer = f.time_evaluator(f.entry_name, ctx, number=1) tcost = ftimer(a, b, c).mean print("%s: exec=%g sec/op" % (ctx, tcost)) tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) check_device("vulkan") check_device("nvptx -mcpu=sm_20") check_device("rocm") check_device("metal") check_device("opencl") check_device("cuda")
def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type): """ Int8 4x4 matrix multiplication and accumulation using a sequence of umull -> uadalp -> umull2 -> uadalp instructions. This function takes two arrays of int8 data type A[4][K] and B[4][K], and produces a 4x4 matrix which is equal to A*B'. The pseudo code is as follows. .. code-block:: c void gemm_4x4_int8_int8_int32(int8 A[4][K], int8 B[4][K], int32 C[4][4]){ for (int i = 0; i < 4; i++){ for (int j = 0; j < 4; j++){ for (int k = 0; k < K; k++){ C[i][j] += A[i][k] * B[j][k] } } } Notes: * The tiling strategy is picked to maximize register usage. Parameters ---------- M : int rows of the matrix A N : int columns of the matrix B K : int columns of matrix A unroll : bool Unroll the loop accumulation if True in_type : str, {'uint8', 'int8'} Returns ------- intrin : TensorIntrin The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule """ assert in_type in ["uint8", "int8"] A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A") B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B") dtype_vec = in_type + "x16" idxm = tvm.tir.indexmod k = te.reduce_axis((0, K), "k") C = te.compute( (te.var("m"), te.var("n")), lambda x, y: te.sum( A[k // 16, x, idxm(k, 16)].astype("int32") * B[ k // 16, y, idxm(k, 16)].astype("int32"), axis=k, ), name="C", ) a_buffer = tvm.tir.decl_buffer( A.shape, dtype=in_type, name="a_buffer", offset_factor=1, strides=[te.var("sa_1"), te.var("sa_2"), 1], ) b_buffer = tvm.tir.decl_buffer( B.shape, dtype=in_type, name="b_buffer", offset_factor=1, strides=[te.var("sb_1"), te.var("sb_2"), 1], ) c_buffer = tvm.tir.decl_buffer(C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]) # Intrinsics used in the following algorithm umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull" uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp" addp_intrin = "llvm.aarch64.neon.addp" def uadalp(a, b): """Add pair and accumulate Parameters: ---------- a: int16x8 vector b: int16x8 vector Returns: -------- return a int32x4 vector Pseudocode: ---------- a += (b0+b1, b2+b3, b4+b5, b6+b7) """ return a + tvm.tir.call_llvm_pure_intrin("int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b) def umull(a, b): """Multiply long (higher part) Parameters: ---------- a: int8x16 vector b: int8x16 vector Returns: -------- return a int16x8 vector Pseudocode: ---------- c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7) """ a_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a) b_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b) c = tvm.tir.call_llvm_pure_intrin("int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high) return c def umull2(a, b): """Multiply long (lower part) Parameters: ---------- a: int8x16 vector b: int8x16 vector Returns: -------- return a int16x8 vector Pseudocode: ---------- c = (a8*b8, a9*b9, a10*b10, a11*b11, a12*b12, a13*b13, a14*b14, a15*b15) """ a_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a) b_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b) c = tvm.tir.call_llvm_pure_intrin("int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low) return c def addp(a, b): """Add two vectors in pairs Parameters: ---------- a: int32x4 vector b: int32x4 vector Returns: -------- return a int32x4 vector Pseudocode: ---------- c = (a0+a1, a2+a3, b0+b1, b0+b3) """ return tvm.tir.call_llvm_pure_intrin("int32x4", addp_intrin, 
tvm.tir.const(2, "uint32"), a, b) def accumulation_loop(M, N, ins, acc, tile_idx): """Internal tile accumulation. This function takes two arrays of int8 data type A[tile_idx][4][16] and B[tile_idx][4][16], produces a 4x4 matrix which is equal to A*B' and accumulates into C[4][4] The pseudo code is as follows. .. code-block:: c void gemm_4x4_int8_int8_int32(int8 A[tile_idx][4][K], int8 B[tile_idx][4][K], int32 C[4][4]){ for (int i = 0; i < 4; i++){ for (int j = 0; j < 4; j++){ for (int k = 0; k < 16; k++){ C[i][j] += A[tile_idx][i][k] * B[tile_idx][j][k] } } } Notes: * The tiling strategy is picked to maximize register usage. Parameters: ---------- M : int Number of total rows of the output matrix N : int Number of total columns of the output matrix ins : list of tvm.tir.buffer Input buffers acc : tvm.tir.ir_builder.BufferVar Bank of register accumulators tiled_idx : int Index of a sub-tile of A and B in A[tile_idx][:][:] and B[tile_idx][:][:]. Please note that 0 <= tile_idx <= K//16 """ a0 = ins[0].vload([tile_idx, 0, 0], dtype_vec) a1 = tvm.tir.const(0, "int8x16") if M > 1: a1 = ins[0].vload([tile_idx, 1, 0], dtype_vec) a2 = tvm.tir.const(0, "int8x16") if M > 2: a2 = ins[0].vload([tile_idx, 2, 0], dtype_vec) a3 = tvm.tir.const(0, "int8x16") if M > 3: a3 = ins[0].vload([tile_idx, 3, 0], dtype_vec) b0 = ins[1].vload([tile_idx, 0, 0], dtype_vec) b1 = tvm.tir.const(0, "int8x16") if N > 1: b1 = ins[1].vload([tile_idx, 1, 0], dtype_vec) b2 = tvm.tir.const(0, "int8x16") if N > 2: b2 = ins[1].vload([tile_idx, 2, 0], dtype_vec) b3 = tvm.tir.const(0, "int8x16") if N > 3: b3 = ins[1].vload([tile_idx, 3, 0], dtype_vec) # First half # Lower part of a0 * {b0,b1,b2,b3} d00 = umull(a0, b0) d01 = umull(a0, b1) d02 = umull(a0, b2) d03 = umull(a0, b3) # Lower part of a1 * {b0,b1,b2,b3} d10 = umull(a1, b0) d11 = umull(a1, b1) d12 = umull(a1, b2) d13 = umull(a1, b3) # Accumulate acc[0] = uadalp(acc[0], d00) acc[1] = uadalp(acc[1], d01) acc[2] = uadalp(acc[2], d02) acc[3] = uadalp(acc[3], d03) acc[4] = uadalp(acc[4], d10) acc[5] = uadalp(acc[5], d11) acc[6] = uadalp(acc[6], d12) acc[7] = uadalp(acc[7], d13) # Higher part of a0 * {b0,b1,b2,b3} d00 = umull2(a0, b0) d01 = umull2(a0, b1) d02 = umull2(a0, b2) d03 = umull2(a0, b3) # Higher part of a1 * {b0,b1,b2,b3} d10 = umull2(a1, b0) d11 = umull2(a1, b1) d12 = umull2(a1, b2) d13 = umull2(a1, b3) # Accumulate again acc[0] = uadalp(acc[0], d00) acc[1] = uadalp(acc[1], d01) acc[2] = uadalp(acc[2], d02) acc[3] = uadalp(acc[3], d03) acc[4] = uadalp(acc[4], d10) acc[5] = uadalp(acc[5], d11) acc[6] = uadalp(acc[6], d12) acc[7] = uadalp(acc[7], d13) # Second half # Lower part of a2 * {b0,b1,b2,b3} d00 = umull(a2, b0) d01 = umull(a2, b1) d02 = umull(a2, b2) d03 = umull(a2, b3) # Lower part of a3 * {b0,b1,b2,b3} d10 = umull(a3, b0) d11 = umull(a3, b1) d12 = umull(a3, b2) d13 = umull(a3, b3) # Accumulate acc[8] = uadalp(acc[8], d00) acc[9] = uadalp(acc[9], d01) acc[10] = uadalp(acc[10], d02) acc[11] = uadalp(acc[11], d03) acc[12] = uadalp(acc[12], d10) acc[13] = uadalp(acc[13], d11) acc[14] = uadalp(acc[14], d12) acc[15] = uadalp(acc[15], d13) # Higher part of a2 * {b0,b1,b2,b3} d00 = umull2(a2, b0) d01 = umull2(a2, b1) d02 = umull2(a2, b2) d03 = umull2(a2, b3) # Lower part of a3 * {b0,b1,b2,b3} d10 = umull2(a3, b0) d11 = umull2(a3, b1) d12 = umull2(a3, b2) d13 = umull2(a3, b3) # Accumulate acc[8] = uadalp(acc[8], d00) acc[9] = uadalp(acc[9], d01) acc[10] = uadalp(acc[10], d02) acc[11] = uadalp(acc[11], d03) acc[12] = uadalp(acc[12], d10) acc[13] = uadalp(acc[13], d11) 
acc[14] = uadalp(acc[14], d12) acc[15] = uadalp(acc[15], d13) def _intrin_func(ins, outs): def _instr(): ib = tvm.tir.ir_builder.create() # Allocate a local buffer (possibly translates to registers) acc = ib.allocate("int32x4", 16, name="accs", scope="local") m = outs[0].shape[0] n = outs[0].shape[1] # Initialization for i in range(0, 16): acc[i] = tvm.tir.const(0, "int32x4") if unroll: for i in range(0, int(K // 16)): accumulation_loop(M, N, ins, acc, i) else: with ib.for_range(0, K // 16, name="i") as i: accumulation_loop(M, N, ins, acc, i) # Final accumulations # acc[4*r + c] contains the partial accumulations of element C[r][c] # # In particular: # acc[4*r] contains the partial sums of a[r,0:K].*b[0,0:K] -> (a,b,c,d) # acc[4*r+1] contains the partial sums of a[r, 0:K].*b[1,0:K] -> (e,f,g,h) # acc[4*r+2] contains the partial sums of a[r, 0:K].*b[2,0:K] -> (i,j,k,l) # acc[4*r+3] contains the partial sums of a[r, 0:K].*b[3,0:K] -> (m,n,o,p) # # Please note that 0<= r, c < 4 acc[0] = addp(acc[0], acc[1]) # (a+b, c+d, e+f, g+h) acc[1] = addp(acc[2], acc[3]) # (i+j, k+l, m+n, o+p) acc[0] = addp(acc[0], acc[1]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) acc[4] = addp(acc[4], acc[5]) # (a+b, c+d, e+f, g+h) acc[5] = addp(acc[6], acc[7]) # (i+j, k+l, m+n, o+p) acc[4] = addp(acc[4], acc[5]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) acc[8] = addp(acc[8], acc[9]) # (a+b, c+d, e+f, g+h) acc[9] = addp(acc[10], acc[11]) # (i+j, k+l, m+n, o+p) acc[8] = addp(acc[8], acc[9]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) acc[12] = addp(acc[12], acc[13]) # (a+b, c+d, e+f, g+h) acc[13] = addp(acc[14], acc[15]) # (i+j, k+l, m+n, o+p) acc[12] = addp(acc[12], acc[13]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) # Store the result if N > 3: out_0 = acc[0] out_1 = acc[4] out_2 = acc[8] out_3 = acc[12] elif N > 2: out_0 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[0]) out_1 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[4]) out_2 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[8]) out_3 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[12]) elif N > 1: out_0 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[0]) out_1 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[4]) out_2 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[8]) out_3 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[12]) else: out_0 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[0]) out_1 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[4]) out_2 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[8]) out_3 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[12]) ib.emit(outs[0].vstore([0, 0], out_0)) if M > 1: ib.emit(outs[0].vstore([1, 0], out_1)) if M > 2: ib.emit(outs[0].vstore([2, 0], out_2)) if M > 3: ib.emit(outs[0].vstore([3, 0], out_3)) return ib.get() # body, reset, update return _instr() buffer_params = {"offset_factor": 1} return te.decl_tensor_intrin( C.op, _intrin_func, binds={ A: a_buffer, B: b_buffer, C: c_buffer }, default_buffer_params=buffer_params, )
def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed): """Compute declaration for winograd""" tile_size = _infer_tile_size(data, kernel) N, CI, H, W = get_const_tuple(data.shape) if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides if not pre_computed: # kernel tensor is raw tensor, do strict check if dilation_h != 1 or dilation_w != 1: kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w)) CO, CI, KH, KW = get_const_tuple(kernel.shape) alpha = KW + tile_size - 1 assert HSTR == 1 and WSTR == 1 and KH == KW else: # kernel tensor is pre-transfomred. this op is created by alter op layout. # dilation is not supported alpha, _, CI, CO = get_const_tuple(kernel.shape) KH = KW = alpha + 1 - tile_size assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad") r = KW m = tile_size A, B, G = winograd_transform_matrices(m, r, out_dtype) H = (H + pt + pb - KH) // HSTR + 1 W = (W + pl + pr - KW) // WSTR + 1 nH, nW = (H + m - 1) // m, (W + m - 1) // m P = N * nH * nW # transform kernel if not pre_computed: r_kh = te.reduce_axis((0, KH), name="r_kh") r_kw = te.reduce_axis((0, KW), name="r_kw") kernel_pack = te.compute( (alpha, alpha, CI, CO), lambda eps, nu, ci, co: te.sum(kernel[co][ci][r_kh][r_kw] * G[eps][ r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name="kernel_pack", ) else: kernel_pack = kernel idxdiv = tvm.tir.indexdiv idxmod = tvm.tir.indexmod # pack input tile input_tile = te.compute( (CI, P, alpha, alpha), lambda c, p, eps, nu: data_pad[idxdiv(p, (nH * nW))][c][idxmod( idxdiv(p, nW), nH) * m + eps][idxmod(p, nW) * m + nu], name="d", ) # transform data r_a = te.reduce_axis((0, alpha), "r_a") r_b = te.reduce_axis((0, alpha), "r_a") data_pack = te.compute( (alpha, alpha, CI, P), lambda eps, nu, ci, p: te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]), name="data_pack", ) # do batch gemm ci = te.reduce_axis((0, CI), name="ci") bgemm = te.compute( (alpha, alpha, CO, P), lambda eps, nu, co, p: te.sum(kernel_pack[eps][nu][ci][co] * data_pack[ eps][nu][ci][p], axis=[ci]), name="bgemm", ) # inverse transform r_a = te.reduce_axis((0, alpha), "r_a") r_b = te.reduce_axis((0, alpha), "r_a") inverse = te.compute( (CO, P, m, m), lambda co, p, vh, vw: te.sum( bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]), name="inverse", ) # output output = te.compute( (N, CO, H, W), lambda n, co, h, w: inverse[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), idxmod(h, m), idxmod(w, m)], name="output", tag="conv2d_nchw_winograd", ) cfg.add_flop(2 * N * CO * H * W * CI * KH * KW) return output
def dot_int8_int8_int32_neon(): """ Int8 dot product using vmlal instructions .. code-block:: c void dot_int8_int8_int32(int8 data[4], int8 kernel[4][4], int32 output[4]){ for (int i = 0; i < 4; i++){ out[i] = 0; for (int k = 0; k < 4; k++){ out[i] += data[k] * kernel[i][k] } } } We use the smull and saddlp instructions to compute the dot product. smull : int8x16 -> int8x16 -> int16x8 elementwise multiplication saddlp: int16x8 -> int32x4 pairwise addition of elements Data is broadcast across the register int8 elements | data | data | | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | smull int8 elements | kernel[i] | kernel[i+1] | | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | = int16 elements | data * kernel[i] | data * kernel[i+1] | | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | saddlp = int32 elements | partial sum(data * kernel[i]) | partial sum(data * kernel[i+1]) | | 0 | 1 | 2 | 3 | We apply the above kernel twice and use addp to compute the second set of pairwise additions int32 elements (narrowed for so they fit on a line) | psum d*k[i] | psum d*k[i+1] | | psum d*k[i+2] | psum d*k[i+3] | | 0 | 1 | 2 | 3 | addp | 4 | 5 | 6 | 7 | = |sum d*ki |sum d*ki1|sum d*ki2|sum d*ki3| | 0 | 1 | 2 | 3 | """ int32_lanes = 4 # 4 int32 lanes = 128 num_int8_elements = 4 # 4 int8 elements in int32 data = te.placeholder((num_int8_elements, ), dtype="int8", name="data") kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel") k = te.reduce_axis((0, num_int8_elements), name="k") C = te.compute( (int32_lanes, ), lambda i: te.sum( data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k), name="C", ) a_buffer = tvm.tir.decl_buffer(data.shape, dtype="int8", name="a_buffer", offset_factor=1, strides=[1]) b_buffer = tvm.tir.decl_buffer(kernel.shape, dtype="int8", name="b_buffer", offset_factor=1, strides=[te.var("ldw"), 1]) def _intrin_func(ins, outs): def _instr(index): int_8xl = "int8x8" int_32xl = "int32x4" ib = tvm.tir.ir_builder.create() if index == 1: ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl))) return ib.get() def pairwise_add_mul(idx): # this broadcasts data to the vector size a_int8 = ins[0].vload([0], "int8x4") re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8) vec_ai32 = re_int32.astype("int32x2") vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32) vec_b = ins[1].vload([idx * 2, 0], int_8xl) # we take two inputs at a time multiply = tvm.tir.call_llvm_pure_intrin( "int16x8", "llvm.aarch64.neon.smull.v8i16", # saturating pairwise multiplication tvm.tir.const(2, "uint32"), vec_a, vec_b, ) pairwise_reduction = tvm.tir.call_llvm_pure_intrin( "int32x4", "llvm.aarch64.neon.saddlp.v4i32.v8i16", tvm.tir.const(1, "uint32"), multiply, ) return pairwise_reduction pair_1 = pairwise_add_mul(0) pair_2 = pairwise_add_mul(1) quad_reduction = tvm.tir.call_llvm_pure_intrin( "int32x4", "llvm.aarch64.neon.addp.v4i32", tvm.tir.const(2, "uint32"), pair_1, pair_2, ) if index == 0: ib.emit(outs[0].vstore(0, quad_reduction)) else: ib.emit(outs[0].vstore( 0, quad_reduction + outs[0].vload([0], int_32xl))) return ib.get() # body, reset, update return _instr(0), _instr(1), _instr(2) buffer_params = {"offset_factor": 1} return te.decl_tensor_intrin( C.op, _intrin_func, binds={ data: a_buffer, kernel: b_buffer }, default_buffer_params=buffer_params, )
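# A scalar reference of what the intrinsic above computes (a numpy sketch for
# clarity only; it is not part of the TensorIntrin itself):
import numpy as np
data_np = np.random.randint(-8, 8, size=(4,)).astype("int32")
kernel_np = np.random.randint(-8, 8, size=(4, 4)).astype("int32")
out_np = (kernel_np * data_np[None, :]).sum(axis=1)  # out[i] = sum_k data[k] * kernel[i][k]
print(out_np)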
def maxpool2d_logical( shape_nhwc, window_shape, stride, padding, dtype, storage_scope="global", ): """ Maxpool2d TE wherein the input activation is defined by its logical NHWC shape. The packed physical layout for the activation is nhwc8h8w32c. """ block_H, block_W, block_C = get_block_shape() shape = get_packed_shape(shape_nhwc) logical_output_shape = ( shape_nhwc[0], (shape_nhwc[1] - window_shape[0] + padding[0] + padding[1]) // stride[0] + 1, (shape_nhwc[2] - window_shape[1] + padding[2] + padding[3]) // stride[0] + 1, shape_nhwc[3], ) output_shape = get_packed_shape(logical_output_shape) N, H, W, C = shape_nhwc X = te.placeholder(shape_nhwc, dtype=dtype) # Combination of padding required by maxpool operator and padding to evenly divisible # number of blocks. Note that this padding should be inlined in the schedule so # as to avoid input copying. pad_h = (block_H - ((H + padding[1]) % block_H)) % block_H pad_w = (block_W - ((W + padding[3]) % block_W)) % block_W X_pad = topi.nn.pad(X, [0, padding[0], padding[2], 0], [0, pad_h, pad_w, 0], pad_value=0) # Calculate packed layout X_packed = te.compute( shape, lambda n, ho, wo, co, hi, wi, ci: X_pad[ n, ho * block_H + hi, wo * block_W + wi, co * block_C + ci ], ) rh = te.reduce_axis((0, window_shape[0]), name="rh") rw = te.reduce_axis((0, window_shape[1]), name="rw") def compute(n, ho, wo, co, hi, wi, ci): # Construct blockized strided maxpool height indices h = ho * block_H + hi h_contig = h * stride[0] + rh h_block_id = h_contig // block_H h_block_offset = h_contig % block_H # Construct blockized strided maxpool width indices w = wo * block_W + wi w_contig = w * stride[1] + rw w_block_id = w_contig // block_W w_block_offset = w_contig % block_W return te.max( X_packed[n, h_block_id, w_block_id, co, h_block_offset, w_block_offset, ci], axis=[rh, rw], ) Y = te.compute(output_shape, compute) s = te.create_schedule(Y.op) # Ensure the padding and array packing is performed inline s[X_pad].compute_inline() s[X_packed].compute_inline() binds = {} if storage_scope and storage_scope != "global": with tvm.transform.PassContext(): Xb = tvm.tir.decl_buffer(shape, name="Xb", dtype=dtype, scope=storage_scope) Yb = tvm.tir.decl_buffer(output_shape, name="Yb", dtype=dtype, scope=storage_scope) binds = {X: Xb, Y: Yb} return (s, [X, Y], binds)
def gemm_acc_nx16_int8_int8_int32(dtype, rows): """ Int8 nx16 matrix multiplication and accumulation using sdot/udot instructions This function takes two arrays of int8 datatype -- A[n][4] and B[4][16] and produces a rowsx16 matrix which is equal to A*B' The pseudo code is as follows. .. code-block:: c void mmla_nx16_int8_int8_int32(int8 A[n][16], int8 B[4][16][4], int32 output[n][16]){ for (int i = 0; i < n; i++){ for (int j = 0; i < 16; i++){ for (int k = 0; k < 16; k++){ out[i][j] += A[i][k] * B[k//4][j][k%4] } } } } Notes: * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16 we need 4 tiles of B to compute a single row of the output. The first 4 values of k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on * The tiling strategy is picked to maximize register usage. Parameters ---------- dtype : str, {"uint8", "int8"} Whether it works on unsigned int or signed int rows : int Number of of the output rows "n" Returns ------- intrin : TensorIntrin The Arm TensorIntrin that can be used in tensorizing schedule """ assert dtype in ["uint8", "int8"] A = te.placeholder((rows, 16), dtype, name="A") B = te.placeholder((4, 16, 4), dtype, name="B") dtype_vec = dtype + "x16" idxm = tvm.tir.indexmod k = te.reduce_axis((0, 16), name="k") C = te.compute( (rows, 16), lambda i, j: te.sum(A[i, k].astype("int32") * B[ k // 4, j, idxm(k, 4)].astype("int32"), axis=k), name="C", ) aa_buffer = tvm.tir.decl_buffer(A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]) bb_buffer = tvm.tir.decl_buffer( B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb0"), te.var("sb1"), 1], ) cc_buffer = tvm.tir.decl_buffer(C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]) llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot" def _intrin_func(ins, outs): def _instr(index): ib = tvm.tir.ir_builder.create() if index == 1: for i in range(0, rows): ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x16"))) return ib.get() # Iterate on the number of rows of the output for k in range(0, rows): # Load 16 elements of A # vec_a = [a, b, c, d, e, f, g, h, l, m, n, o, p, q, r, s]; vec_a = ins[0].vload([k, 0], dtype_vec) # Iterate over each of the 4 rowsx4 tiles of the output for j in range(0, 4): # Accumulate over each of the 4 (16x4) tiles contained in B for i in range(0, 4): # Replicate a single 4-element group of A (A[k, i:i+4]) vec_aa = select_word(vec_a, i, dtype_vec) # Load 4 rows (each rows with 4 elements) from B (B[i:i+4, j:j+4]) # vec_b = [0, 16, 32, 48, # 1, 17, 33, 49, # 2, 18, 34, 50, # 3, 19, 35, 51,]; vec_b = ins[1].vload([i, 4 * j, 0], dtype_vec) # Accumulate in the correct part of the output vec_c = outs[0].vload([k, 4 * j], "int32x4") # Compute the dot product between the rowsx4 tile # from A and the 4x4 tile from B # # For instance, for i=0, we have: # sdot(vec_aa[0], vec_b) = [a*0+b*16+c*32+d*48, # a*1+b*17+c*33+d*49, # a*2+b*18+c*34+d*50, # a*3+b*19+c*35+d*51] vdot = tvm.tir.call_llvm_intrin( "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_b, vec_aa, ) ib.emit(outs[0].vstore([k, 4 * j], vdot)) return ib.get() # body, reset, update return _instr(0), _instr(1), _instr(2) buffer_params = {"offset_factor": 1} return te.decl_tensor_intrin( C.op, _intrin_func, binds={ A: aa_buffer, B: bb_buffer, C: cc_buffer }, default_buffer_params=buffer_params, )
def _schedule(cfg, s, C): A, B = s[C].op.input_tensors if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A: s[B].compute_inline() batch, m_dim, k_dim = get_const_tuple(A.shape) batch, n_dim, k_dim = get_const_tuple(B.shape) data_dtype = A.dtype out_dtype = C.dtype # Explicit memory access AS = s.cache_read(A, "shared", [C]) BS = s.cache_read(B, "shared", [C]) AF = s.cache_read(AS, "wmma.matrix_a", [C]) BF = s.cache_read(BS, "wmma.matrix_b", [C]) CF = s.cache_write(C, "wmma.accumulator") CS = s.cache_read(CF, "shared", [C]) # fallback support target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( target.kind.name, target.model, "batch_matmul_tensorcore.cuda" ) cfg.fallback_with_reference_log(ref_log) # Deal with op fusion, such as bias/relu and slice after padding if C.op not in s.outputs and "injective" in s.outputs[0].tag: s[C].compute_inline() C = s.outputs[0].output(0) # create tuning space cfg.define_knob("block_row_warps", [1, 2, 4]) cfg.define_knob("block_col_warps", [1, 2, 4]) cfg.define_knob("warp_row_tiles", [1, 2, 4]) cfg.define_knob("warp_col_tiles", [1, 2, 4]) cfg.define_knob("chunk", [1, 2, 4, 8]) cfg.define_knob("offset", [0, 8]) cfg.define_knob("offsetCS", [0, 8]) cfg.define_knob("vec", [1, 2, 4, 8]) # Ensure that the default parameters are applicable when autotvm is not in use if data_dtype in ["float16", "uint8", "int8"]: if m_dim % 32 == 0 and n_dim % 8 == 0: cfg.define_knob("wmma_m", [32, 16, 8]) elif m_dim % 16 == 0 and n_dim % 16 == 0: cfg.define_knob("wmma_m", [16, 8, 32]) elif m_dim % 8 == 0 and n_dim % 32 == 0: cfg.define_knob("wmma_m", [8, 16, 32]) wmma_k = 16 wmma_m = cfg["wmma_m"].val if wmma_m == 16: wmma_n = 16 elif wmma_m == 8: wmma_n = 32 elif wmma_m == 32: wmma_n = 8 elif data_dtype in ["int4", "uint4"]: wmma_m = wmma_n = 8 wmma_k = 32 else: raise ValueError("data dtype %s is not yet supported" % data_dtype) warp_size = 32 block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val warp_col_tiles = cfg["warp_col_tiles"].val chunk = cfg["chunk"].val offset = cfg["offset"].val offsetCS = cfg["offsetCS"].val vec = cfg["vec"].val # Define the stride of intrin functions AS_align = chunk * wmma_k + offset BS_align = chunk * wmma_k + offset CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS AS_stride = [AS_align, 1] BS_stride = [BS_align, 1] AF_stride = [wmma_k, 1] BF_stride = [wmma_k, 1] CF_stride = [warp_col_tiles * wmma_n, 1] CS_stride = [CS_align, 1] block_x = te.thread_axis("blockIdx.x") block_y = te.thread_axis("blockIdx.y") block_z = te.thread_axis("blockIdx.z") thread_x = te.thread_axis("threadIdx.x") thread_y = te.thread_axis("threadIdx.y") thread_z = te.thread_axis("threadIdx.z") # Schedule for dense computation block_factor_m = wmma_m * warp_row_tiles * block_row_warps block_factor_n = wmma_n * warp_col_tiles * block_col_warps b, m, n = C.op.axis block_i, bc = s[C].split(m, factor=block_factor_m) block_j, oc = s[C].split(n, factor=block_factor_n) s[C].reorder(b, block_i, block_j, bc, oc) t = s[C].fuse(bc, oc) t, vi = s[C].split(t, factor=vec) t, tx = s[C].split(t, factor=warp_size) t, ty = s[C].split(t, factor=block_row_warps) t, tz = s[C].split(t, factor=block_col_warps) s[C].bind(block_i, block_x) s[C].bind(block_j, block_y) s[C].bind(b, block_z) s[C].bind(tz, thread_z) s[C].bind(ty, thread_y) s[C].bind(tx, thread_x) s[C].vectorize(vi) # Schedule for wmma store s[CS].compute_at(s[C], block_j) bs, bb, oo = 
CS.op.axis s[CS].storage_align(bb, CS_align - 1, CS_align) bb, bbi = s[CS].split(bb, factor=wmma_m) oo, ooi = s[CS].split(oo, factor=wmma_n) bb, bbii = s[CS].split(bb, factor=warp_row_tiles) oo, ooii = s[CS].split(oo, factor=warp_col_tiles) s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi) # Schedule for wmma computation s[CF].compute_at(s[CS], oo) bs, warp_i, warp_j = CF.op.axis warp_i, _ii = s[CF].split(warp_i, factor=wmma_m) warp_j, _jj = s[CF].split(warp_j, factor=wmma_n) (k,) = CF.op.reduce_axis k, _k = s[CF].split(k, factor=wmma_k) ko, ki = s[CF].split(k, factor=chunk) s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k) # Schedule for wmma_matrix_a load s[AF].compute_at(s[CF], ki) bs, b, i = AF.op.axis b, b_ii = s[AF].split(b, factor=wmma_m) i, i_jj = s[AF].split(i, factor=wmma_k) s[AF].reorder(bs, b, i, b_ii, i_jj) # Schedule for wmma_matrix_b load s[BF].compute_at(s[CF], ki) bs, o, i = BF.op.axis o, o_ii = s[BF].split(o, factor=wmma_n) i, i_ii = s[BF].split(i, factor=wmma_k) s[BF].reorder(bs, o, i, o_ii, i_ii) # Schedule for A's(B's) shared memory load def shared_schedule(stage, strides): s[stage].compute_at(s[CF], ko) bs, xo, yo = stage.op.axis s[stage].storage_align(xo, strides - 1, strides) t = s[stage].fuse(xo, yo) t, vi = s[stage].split(t, factor=vec) t, tx = s[stage].split(t, factor=warp_size) t, ty = s[stage].split(t, factor=block_row_warps) _, tz = s[stage].split(t, factor=block_col_warps) s[stage].bind(ty, thread_y) s[stage].bind(tz, thread_z) s[stage].bind(tx, thread_x) s[stage].vectorize(vi) shared_schedule(AS, AS_align) shared_schedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm") CL_compute = te.compute( (wmma_m, wmma_n), lambda ii, jj: te.sum( AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[jj, k_gemm].astype(out_dtype), axis=k_gemm, ), name="CL_compute", ) # lower the computation loops down to TensorCore hardware intrinsics # by mapping the dense tensorcore to tensor intrinsics s[AF].tensorize( b_ii, intrin_wmma_load_matrix_A( AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), data_dtype, ), ) s[BF].tensorize( o_ii, intrin_wmma_load_matrix_W( BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), data_dtype, ), ) s[CF].tensorize( _ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape), ) s[CS].tensorize( bbi, intrin_wmma_store_matrix( CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n) ), )
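# Hedged sketch of how a schedule helper like _schedule is typically driven through
# AutoTVM. The task name, shapes, and float16 inputs are assumptions for illustration,
# and the intrin_wmma_* helpers used above are assumed to be importable; the real TVM
# entry points wrap this differently inside topi.cuda.
@autotvm.template("example/batch_matmul_tensorcore")
def batch_matmul_tensorcore_template(batch, M, N, K):
    A = te.placeholder((batch, M, K), name="A", dtype="float16")
    B = te.placeholder((batch, N, K), name="B", dtype="float16")
    k = te.reduce_axis((0, K), name="k")
    # Batch matmul with the second operand transposed, matching the A/B/C
    # shapes that _schedule reads back via get_const_tuple.
    C = te.compute(
        (batch, M, N),
        lambda b, i, j: te.sum(
            A[b, i, k].astype("float32") * B[b, j, k].astype("float32"), axis=k
        ),
        name="T_batch_matmul_NT",
    )
    s = te.create_schedule(C.op)
    cfg = autotvm.get_config()
    _schedule(cfg, s, C)
    return s, [A, B, C]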
def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype="float32"): """Convolution operator in NHWC layout. Parameters ---------- Input : tvm.te.Tensor 4-D with shape [batch, in_height, in_width, in_channel] Filter : tvm.te.Tensor 4-D with shape [filter_height, filter_width, in_channel, num_filter] stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] padding : int or a list/tuple of 2 or 4 ints padding size, or [pad_height, pad_width] for 2 ints, or [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] Returns ------- output : tvm.te.Tensor 4-D with shape [batch, out_height, out_width, out_channel] """ assert isinstance(stride, int) or len(stride) == 2 assert isinstance(dilation, int) or len(dilation) == 2 if isinstance(stride, int): stride_h = stride_w = stride else: stride_h, stride_w = stride if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape kernel_h, kernel_w, channel, num_filter = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w) ) out_channel = num_filter out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) pad_before = [0, pad_top, pad_left, 0] pad_after = [0, pad_down, pad_right, 0] PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") rc = te.reduce_axis((0, in_channel), name="rc") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") Output = te.compute( (batch, out_height, out_width, out_channel), lambda nn, yy, xx, ff: te.sum( PaddedInput[ nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc ].astype(out_dtype) * Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc], ), name="Conv2dOutput", tag="conv2d_nhwc", ) return Output
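# Minimal sketch with assumed sizes: declare a conv2d_nhwc workload and lower the
# default (unoptimized) schedule to inspect the generated TIR. The pad/simplify/
# get_pad_tuple helpers used above are assumed to come from the surrounding module.
data = te.placeholder((1, 56, 56, 64), name="data", dtype="float32")
weight = te.placeholder((3, 3, 64, 128), name="weight", dtype="float32")
conv = conv2d_nhwc(data, weight, stride=1, padding=1, dilation=1)
s = te.create_schedule(conv.op)
print(tvm.lower(s, [data, weight, conv], simple_mode=True))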
# Algorithm A = te.placeholder((in_size, in_size, in_channel, batch), name="A") W = te.placeholder((kernel, kernel, in_channel, out_channel), name="W") out_size = (in_size - kernel + 2 * pad) // stride + 1 # Pad input Apad = te.compute( (in_size + 2 * pad, in_size + 2 * pad, in_channel, batch), lambda yy, xx, cc, nn: tvm.tir.if_then_else( tvm.tir.all(yy >= pad, yy - pad < in_size, xx >= pad, xx - pad < in_size), A[yy - pad, xx - pad, cc, nn], tvm.tir.const(0.0, "float32"), ), name="Apad", ) # Create reduction variables rc = te.reduce_axis((0, in_channel), name="rc") ry = te.reduce_axis((0, kernel), name="ry") rx = te.reduce_axis((0, kernel), name="rx") # Compute the convolution B = te.compute( (out_size, out_size, out_channel, batch), lambda yy, xx, ff, nn: te.sum( Apad[yy * stride + ry, xx * stride + rx, rc, nn] * W[ry, rx, rc, ff], axis=[ry, rx, rc] ), name="B", ) ############################################################################### # Memory Hierarchy # ----------------
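# A minimal sketch of the memory-hierarchy step that follows this header. The stage
# names AA/WW/AL/WL/BL mirror the usual TVM GPU convolution tutorial and are
# assumptions here: the padded input and the filter get shared-memory and per-thread
# cache stages, and the output is accumulated through a local (register) stage.
s = te.create_schedule(B.op)
s[Apad].compute_inline()                  # fold the padding into its consumer
AA = s.cache_read(Apad, "shared", [B])    # activation tile staged in shared memory
WW = s.cache_read(W, "shared", [B])       # filter tile staged in shared memory
AL = s.cache_read(AA, "local", [B])       # per-thread activation fragment
WL = s.cache_read(WW, "local", [B])       # per-thread filter fragment
BL = s.cache_write(B, "local")            # accumulate in registers, then write out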
def conv2d_NCHWc_int8( data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32" ): """Conv2D operator for nChw[x]c layout. Parameters ---------- data : tvm.te.Tensor 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] kernel : tvm.te.Tensor 7-D with shape [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4, num_filter_block, 4] stride : int or a list/tuple of two ints stride size, or [stride_height, stride_width] padding : int or a list/tuple of 2 or 4 ints padding size, or [pad_height, pad_width] for 2 ints, or [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] layout : str Input data layout out_layout : str Output data layout out_dtype : str output data type Returns ------- output : tvm.te.Tensor 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] """ # layout and out_layout are not used here, # we keep them for debug convenience when dumping autotvm workload HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride) dilation_h, dilation_w = ( dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) ) n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple( kernel.shape ) num_filter = oc_chunk * oc_bn groups = ic_chunk // ic_chunk_group dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w) ) HPAD = pad_top + pad_down WPAD = pad_left + pad_right # output shape out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 oshape = (n, oc_chunk, out_height, out_width, oc_bn) pad_before = (0, 0, pad_top, pad_left, 0) pad_after = (0, 0, pad_down, pad_right, 0) # DOPAD DOPAD = HPAD != 0 or WPAD != 0 if DOPAD: data_pad = pad(data, pad_before, pad_after, name="data_pad") else: data_pad = data ic = te.reduce_axis((0, in_channel), name="ic") kh = te.reduce_axis((0, kernel_height), name="kh") kw = te.reduce_axis((0, kernel_width), name="kw") if groups == 1: n_elems = 4 ic_outer = te.reduce_axis((0, in_channel // ic_bn), name="ic_outer") ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner") ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner") return te.compute( oshape, lambda n, oc_chunk, oh, ow, oc_block: te.sum( data_pad[ n, ic_outer, oh * HSTR + kh * dilation_h, ow * WSTR + kw * dilation_w, ic_f_inner * n_elems + ic_s_inner, ].astype(out_dtype) * kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner].astype( out_dtype ), axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner], ), name="conv2d_NCHWc_int8", tag="conv2d_NCHWc_int8", ) # for int8 group conv support n_elems = 4 ic_chunk = in_channel // ic_bn ic_outer = te.reduce_axis((0, ic_chunk // groups), name="ic_outer") ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner") ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner") oshape = (n, oc_chunk, out_height, out_width, oc_bn) return te.compute( oshape, lambda n, occ, oh, ow, oc_block: te.sum( data_pad[ n, (occ * oc_bn // (oc_chunk * oc_bn // groups)) * (ic_chunk // groups) + ic_outer, oh * HSTR + kh, ow * WSTR + kw, ic_f_inner * n_elems + ic_s_inner, ].astype(out_dtype) * 
kernel[occ, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner].astype(out_dtype), axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner], ), name="conv2d_NCHWc_int8", tag="conv2d_NCHWc_int8", )
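# Hedged usage sketch with assumed shapes: a 64-channel uint8 activation in NCHW16c
# layout and an int8 kernel pre-packed into the 7-D layout documented above
# (oc_chunk, ic_chunk, kh, kw, ic_bn//4, oc_bn, 4).
data = te.placeholder((1, 4, 56, 56, 16), name="data", dtype="uint8")
kernel = te.placeholder((8, 4, 3, 3, 4, 16, 4), name="kernel", dtype="int8")
out = conv2d_NCHWc_int8(
    data, kernel, stride=1, padding=1, dilation=1,
    layout="NCHW16c", out_layout="NCHW16c",
)
s = te.create_schedule(out.op)
print(tvm.lower(s, [data, kernel, out], simple_mode=True))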
def intrin_gemm_MxKxN(M, K, N, in_dtype, out_dtype, stride_w=1): """Defines a v7e-m DSP-accelerated transposed matmul.""" # we generate a unique ID for every intrinsic definition, to prevent name # collisions in the generated source (e.g., if there are multiple operators # in the same module that use the same intrinsic) # # TODO(weberlo, areusch): to cut down on memory usage, we should cache each intrinsic # instantiation and include it only once, eliminating the need for unique # IDs UNIQ_ID_LEN = 8 uniq_id = "".join(random.choices(string.ascii_uppercase, k=UNIQ_ID_LEN)) if isinstance(M, tvm.tir.IntImm): M = M.value if isinstance(K, tvm.tir.IntImm): K = K.value if isinstance(N, tvm.tir.IntImm): N = N.value # TODO(weberlo, areusch): support more dtypes? assert in_dtype in ("int8", "int16") assert out_dtype == "int32" A = te.placeholder((M * stride_w - (stride_w - 1), K), name="a", dtype=in_dtype) B = te.placeholder((N, K), name="b", dtype=in_dtype) k = te.reduce_axis((0, K), name="k") C = te.compute( (M, N), lambda i, j: te.sum(A[i * stride_w, k].astype(out_dtype) * B[j, k]. astype(out_dtype), axis=k), name="c", ) A_buf = tvm.tir.decl_buffer(A.shape, A.dtype, name="A", offset_factor=1, strides=[te.var("A_s"), 1]) B_buf = tvm.tir.decl_buffer(B.shape, B.dtype, name="B", offset_factor=1, strides=[te.var("B_s"), 1]) C_buf = tvm.tir.decl_buffer(C.shape, C.dtype, name="C", offset_factor=1, strides=[te.var("C_s"), 1]) def intrin_func(ins, outs): aa, bb = ins cc = outs[0] gemm_func_prefix = "gemm" if in_dtype == "int8" else "gemm16" def _reduce_update(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( "int32", f"{gemm_func_prefix}_{M}x{K}x{N}_update_{uniq_id}", aa.access_ptr("r"), bb.access_ptr("r"), cc.access_ptr("w"), aa.strides[0] * stride_w, bb.strides[0], cc.strides[0], )) return ib.get() def _reduce_reset(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0])) return ib.get() def _body(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( "int32", f"{gemm_func_prefix}_{M}x{K}x{N}_body_{uniq_id}", aa.access_ptr("r"), bb.access_ptr("r"), cc.access_ptr("w"), aa.strides[0] * stride_w, bb.strides[0], cc.strides[0], )) return ib.get() return _body(), _reduce_reset(), _reduce_update() intrin_decl = te.decl_tensor_intrin(C.op, intrin_func, binds={ A: A_buf, B: B_buf, C: C_buf }) return intrin_decl, uniq_id
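# Hedged sketch: the intrinsic is declared for a concrete (M, K, N) tile and later
# applied with tensorize. The tile sizes below are illustrative only, and the matching
# C micro-kernels (gemm_MxKxN_body/update/reset_<uniq_id>) must still be supplied to
# the module for the call_extern statements above to resolve at build time.
gemm_intrin, uniq_id = intrin_gemm_MxKxN(2, 16, 4, "int8", "int32")
# ...later, inside a schedule whose inner matmul tile matches 2x16x4:
# s[C].tensorize(inner_m_axis, gemm_intrin)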
A = te.placeholder((in_height, in_width, in_channel, batch), name='A')
W = te.placeholder((kernel_height, kernel_width, in_channel, out_channel), name='W')
out_width = (in_width - kernel_width + 2 * pad_width) // stride + 1
out_height = (in_height - kernel_height + 2 * pad_height) // stride + 1
# Pad input
Apad = te.compute(
    (in_height + 2 * pad_height, in_width + 2 * pad_width, in_channel, batch),
    lambda yy, xx, cc, nn: tvm.tir.if_then_else(
        tvm.tir.all(yy >= pad_height, yy - pad_height < in_height,
                    xx >= pad_width, xx - pad_width < in_width),
        A[yy - pad_height, xx - pad_width, cc, nn],
        tvm.tir.const(0., "float32")),
    name='Apad')
# Create reduction variables
rc = te.reduce_axis((0, in_channel), name='rc')
ry = te.reduce_axis((0, kernel_height), name='ry')
rx = te.reduce_axis((0, kernel_width), name='rx')
# Compute the convolution
B = te.compute(
    (out_height, out_width, out_channel, batch),
    lambda yy, xx, ff, nn: te.sum(Apad[yy * stride + ry, xx * stride + rx, rc, nn] *
                                  W[ry, rx, rc, ff],
                                  axis=[ry, rx, rc]),
    name='B')

###############################################################################
# Memory Hierarchy
# ----------------
#
# We first specify the memory hierarchy for buffers.
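# Hedged sketch of the step that typically follows the memory-hierarchy declaration:
# tile the output and bind the tiles to the GPU grid and thread hierarchy. The tile
# factors are illustrative, the shared/local cache stages are omitted here for brevity,
# and the sizes used above (in_height, pad_height, stride, ...) are assumed defined.
s = te.create_schedule(B.op)
s[Apad].compute_inline()
BL = s.cache_write(B, "local")
yo, xo, fo, no = s[B].op.axis
by, yi = s[B].split(yo, factor=8)
bx, xi = s[B].split(xo, factor=8)
s[B].reorder(by, bx, yi, xi, fo, no)
s[B].bind(by, te.thread_axis("blockIdx.y"))
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(yi, te.thread_axis("threadIdx.y"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
s[BL].compute_at(s[B], xi)   # accumulate each output point in registers
print(tvm.lower(s, [A, W, B], simple_mode=True))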
def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile, output_padding): assert layout == "NCHW", "Only support NCHW" out_dtype = out_dtype or data.dtype N, CI, IH, IW = get_const_tuple(data.shape) _, CO, KH, KW = get_const_tuple(kernel.shape) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) opad_h, opad_w = output_padding assert opad_h < HSTR and opad_w < WSTR pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW)) bpad_top, bpad_bottom = KH - 1 - pad_top, KH - 1 - pad_bottom + opad_h bpad_left, bpad_right = KW - 1 - pad_left, KW - 1 - pad_right + opad_w OH = (IH - 1) * HSTR - pad_top - pad_bottom + KH + opad_h OW = (IW - 1) * WSTR - pad_left - pad_right + KW + opad_w dilated_input = dilate(data, [1, 1, HSTR, WSTR]) data_pad = pad(dilated_input, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right]) # ==================== define configuration space ==================== n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW) ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) if num_tile == 2: # for arm cpu co, vc = cfg.define_split('tile_co', co, num_outputs=2) oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2) ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2) elif num_tile == 3: # for mali gpu co, _, vc = cfg.define_split('tile_co', co, num_outputs=3) oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3) ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3) else: raise RuntimeError("Invalid num_tile") cfg.define_reorder("reorder_0", [n, co, oh, ow, ci, kh, kw, vh, vw, vc], policy='candidate', candidate=[ [n, co, oh, ow, ci, kh, kw, vh, vw, vc], [n, co, oh, ow, ci, kh, kw, vc, vh, vw]]) cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll') cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') # ==================================================================== VC = cfg["tile_co"].size[-1] VH = cfg["tile_oh"].size[-1] VW = cfg["tile_ow"].size[-1] dvshape = (N, OH // VH, OW // VW, CI, VH + KH-1, VW + KW-1) kvshape = (CO // VC, CI, KH, KW, VC) ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC) oshape = (N, CO, OH, OW) data_vec = te.compute(dvshape, lambda n, h, w, ci, vh, vw: data_pad[n][ci][h*VH + vh][w*VW + vw], name='data_vec') kernel_vec = te.compute(kvshape, lambda co, ci, kh, kw, vc: kernel[ci][co*VC+vc][kh][kw], name='kernel_vec_conv2d_transpose') ci = te.reduce_axis((0, CI), name='ci') kh = te.reduce_axis((0, KH), name='kh') kw = te.reduce_axis((0, KW), name='kw') conv = te.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \ te.sum(data_vec[n, h, w, ci, vh + kh, vw + kw].astype(out_dtype) * kernel_vec[co, ci, KH - 1 - kh, KW - 1 - kw, vc].astype(out_dtype), axis=[ci, kh, kw]), name='conv') idxdiv = tvm.tir.indexdiv idxmod = tvm.tir.indexmod output = te.compute(oshape, lambda n, co, h, w: conv[n, idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW), idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)], name='output_unpack', tag='spatial_conv2d_transpose_output') return output
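# Hedged illustration of the preprocessing the declaration above relies on: a transposed
# convolution equals dilating the input by the stride (inserting zeros), padding by
# kernel - 1 - pad on each border, and running a plain convolution with a spatially
# flipped kernel. Shapes are illustrative; get_const_tuple is assumed imported as above.
x = te.placeholder((1, 1, 4, 4), name="x")
x_dil = topi.nn.dilate(x, [1, 1, 2, 2])                   # stride 2 -> zeros between rows/cols
x_pad = topi.nn.pad(x_dil, [0, 0, 2, 2], [0, 0, 2, 2])    # KH = KW = 3, pad = 0 -> border of 2
print(get_const_tuple(x_dil.shape), get_const_tuple(x_pad.shape))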
import pytest


def test_vectorize_commreduce():
    V = te.placeholder((128,), name='V')
    ax = te.reduce_axis((0, 128), name='ax')
    O = te.compute((1,), lambda _: te.sum(V[ax], axis=[ax]))
    s = te.create_schedule(O.op)
    # Vectorizing a commutative-reduce axis is not allowed, so this must throw.
    with pytest.raises(Exception):
        s[O].vectorize(ax)
def gemm_int8(n, m, l): A = te.placeholder((n, l), name='A', dtype='int8') B = te.placeholder((m, l), name='B', dtype='int8') k = te.reduce_axis((0, l), name='k') C = te.compute( (n, m), lambda i, j: te.sum(A[i, k].astype('int32') * B[j, k].astype('int32'), axis=k), name='C') cfg = autotvm.get_config() s = te.create_schedule(C.op) y, x = C.op.axis AA = s.cache_read(A, 'shared', [C]) BB = s.cache_read(B, 'shared', [C]) AL = s.cache_read(AA, 'local', [C]) BL = s.cache_read(BB, 'local', [C]) CC = s.cache_write(C, 'local') k = CC.op.reduce_axis[0] cfg.define_split('tile_k', cfg.axis(k), num_outputs=3, filter=lambda entity: entity.size[2] == 4 and \ entity.size[0] * 2 >= entity.size[1]) ko, kt, ki = cfg['tile_k'].apply(s, CC, k) s[CC].tensorize(ki, intrin_dp4a) block_x = te.thread_axis('blockIdx.x') block_y = te.thread_axis('blockIdx.y') thread_x = te.thread_axis('threadIdx.x') thread_y = te.thread_axis('threadIdx.y') def block_size_filter(entity): return entity.size[0] * 2 >= entity.size[1] * 2 and \ entity.size[1] <= 16 and entity.size[3] <= 4 cfg.define_split('tile_y', cfg.axis(y), num_outputs=4, filter=block_size_filter) cfg.define_split('tile_x', cfg.axis(x), num_outputs=4, filter=block_size_filter) by, tyz, ty, yi = cfg['tile_y'].apply(s, C, y) bx, txz, tx, xi = cfg['tile_x'].apply(s, C, x) s[C].bind(by, block_y) s[C].bind(bx, block_x) s[C].bind(tyz, te.thread_axis('vthread')) s[C].bind(txz, te.thread_axis('vthread')) s[C].bind(ty, thread_y) s[C].bind(tx, thread_x) s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi) s[CC].compute_at(s[C], tx) yo, xo = CC.op.axis s[CC].reorder(ko, kt, yo, xo, ki) s[CC].unroll(kt) for stage in [AL, BL]: s[stage].compute_at(s[CC], kt) _, xi = s[stage].split(stage.op.axis[1], factor=4) s[stage].vectorize(xi) s[stage].double_buffer() cfg.define_knob('storage_align', [16, 48]) for stage in [AA, BB]: s[stage].storage_align(s[stage].op.axis[0], cfg['storage_align'].val, 0) s[stage].compute_at(s[CC], ko) fused = s[stage].fuse(*s[stage].op.axis) ty, tx = s[stage].split(fused, nparts=cfg['tile_y'].size[2]) tx, xi = s[stage].split(tx, nparts=cfg['tile_x'].size[2]) _, xi = s[stage].split(xi, factor=16) s[stage].bind(ty, thread_y) s[stage].bind(tx, thread_x) s[stage].vectorize(xi) cfg.define_knob('auto_unroll_max_step', [512, 1500]) s[C].pragma(by, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) s[C].pragma(by, 'unroll_explicit', False) cfg.add_flop(n * m * l * 2) return s, [A, B, C]
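# Hedged sketch of driving the template above with AutoTVM. It assumes gemm_int8 has
# been registered, e.g. via @autotvm.template("example/gemm_int8"), that the
# intrin_dp4a intrinsic it tensorizes with is defined in the same module, and that a
# CUDA device is available for measurement; the tuner choice and trial count are
# illustrative only.
task = autotvm.task.create("example/gemm_int8", args=(2048, 2048, 2048), target="cuda")
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(
    n_trial=20,
    measure_option=autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=4, repeat=3),
    ),
    callbacks=[autotvm.callback.log_to_file("gemm_int8.log")],
)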