def _default_dense_pack_config(cfg, M, N, K):
    vec_width = get_fp32_len()
    tilex_ii = 1
    for bn in range(vec_width * 2, 0, -1):
        if N % bn == 0:
            tilex_ii = bn
            break
    NN = N // tilex_ii
    tilex_oi = 1
    while NN // tilex_oi > 4:
        if (NN // tilex_oi) % 2 == 1:
            break
        tilex_oi *= 2

    tiley_ii = 8
    while M % tiley_ii != 0:
        tiley_ii //= 2
    MM = M // tiley_ii
    tiley_oi = 1
    while MM // tiley_oi > 4:
        if (MM // tiley_oi) % 2 == 1:
            break
        tiley_oi *= 2

    cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii])
    cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii])
    cfg["tile_k"] = SplitEntity([K, 1])
def _fallback_schedule_int8(cfg, wkl):
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    oc_bn = 16
    assert wkl.out_filter % oc_bn == 0

    ic_bn = 1
    for bn in range(oc_bn, 0, -4):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break
    assert wkl.in_filter % 4 == 0

    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
def _default_dense_pack_config(cfg, M, N, K):
    # Generate default schedule for dynamic shape.
    if isinstance(M, tvm.expr.Var):
        M = 16
    if isinstance(N, tvm.expr.Var):
        N = 16
    if isinstance(K, tvm.expr.Var):
        K = 16

    vec_width = get_fp32_len()
    tilex_ii = 1
    for bn in range(vec_width * 2, 0, -1):
        if N % bn == 0:
            tilex_ii = bn
            break
    NN = N // tilex_ii
    tilex_oi = 1
    while NN // tilex_oi > 4:
        if (NN // tilex_oi) % 2 == 1:
            break
        tilex_oi *= 2

    tiley_ii = 8
    while M % tiley_ii != 0:
        tiley_ii //= 2
    MM = M // tiley_ii
    tiley_oi = 1
    while MM // tiley_oi > 4:
        if (MM // tiley_oi) % 2 == 1:
            break
        tiley_oi *= 2

    cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii])
    cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii])
    cfg["tile_k"] = SplitEntity([K, 1])
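# Worked example (not in the original source), assuming get_fp32_len() returns 16:
# for M=32, N=100, K=256 the loops above pick tilex_ii=25 (the largest factor of N
# that is <= vec_width * 2 = 32), NN=4, tilex_oi=1, tiley_ii=8, MM=4, tiley_oi=1,
# so the fallback becomes tile_y=[4, 1, 8], tile_x=[4, 1, 25], tile_k=[256, 1].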
def _fallback_schedule_int8(cfg, wkl):
    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
    HSTR, WSTR = wkl.stride_h, wkl.stride_w
    out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1

    oc_bn = 16
    assert wkl.out_filter % oc_bn == 0

    ic_bn = 1
    for bn in range(oc_bn, 0, -4):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break
    assert wkl.in_filter % 4 == 0

    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
def _fallback_schedule(cfg, wkl):
    simd_width = get_fp32_len()
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    for ow_factor in range(out_width, 0, -1):
        if out_width % ow_factor == 0:
            for oh_factor in range(out_height, 0, -1):
                if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
                    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
                    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
                    cfg["tile_oh"] = OtherOptionEntity(oh_factor)
                    cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
                    return
    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def _fallback_schedule(cfg, wkl):
    simd_width = get_fp32_len()
    DPAD, HPAD, WPAD = wkl.dpad, wkl.hpad, wkl.wpad
    DSTR, HSTR, WSTR = wkl.dstride, wkl.hstride, wkl.wstride
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    reg_n = 1
    for n in range(7, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
def _fallback_schedule(cfg, wkl):
    simd_width = get_simd_32bit_lanes()
    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
    HSTR, WSTR = wkl.stride_h, wkl.stride_w
    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
    out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
def _default_dense_nopack_config(cfg, M, N, K):
    vec_width = get_fp32_len()
    tilek_bn = 1
    for bn in range(vec_width * 2, 0, -1):
        if K % bn == 0:
            tilek_bn = bn
            break
    cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn])
    cfg["tile_x"] = SplitEntity([N, 1])
    cfg["tile_y"] = SplitEntity([1, M])
def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
    """Fallback schedule for 1x1 conv2d int8 on cpu.

    The innermost pattern normally takes two int8/uint8 tensors,
    data[num_int8_elements] and kernel[int32_lanes, num_int8_elements],
    and produces an int32/uint32 dot-product output[int32_lanes].

    Parameters
    ----------
    int32_lanes : int
        Number of int32/uint32 lanes produced by one intrinsic call.
        This is related to the output channel.
    num_int8_elements : int
        Number of int8/uint8 input elements multiplied and reduced per lane.
        This is related to the input channel.
    """
    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
    HSTR, WSTR = wkl.stride_h, wkl.stride_w
    out_height = (wkl.height + pt + pb - wkl.kernel_h) // HSTR + 1
    out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1

    assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
        wkl.out_filter,
        int32_lanes,
    )
    assert wkl.in_filter % num_int8_elements == 0, "wkl.in_filter=%d, num_int8_elements=%d" % (
        wkl.in_filter,
        num_int8_elements,
    )

    oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
    ic_bn = 1
    for bn in range(oc_bn, 0, -4):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    for ow_factor in range(out_width, 0, -1):
        if out_width % ow_factor == 0:
            for oh_factor in range(out_height, 0, -1):
                if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
                    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
                    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
                    cfg["tile_oh"] = OtherOptionEntity(oh_factor)
                    cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
                    return
    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def schedule_dense_small_batch(cfg, s, C):
    """Schedule float32/64 dense with small batch size"""
    A, _ = C.op.input_tensors
    _, in_dim = get_const_tuple(A.shape)
    cfg.define_split('tile_k', in_dim, num_outputs=2)
    if cfg.is_fallback:
        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])

    _, kf = cfg['tile_k'].apply(s, C, C.op.reduce_axis[0])
    CF = s.rfactor(C, kf)

    if C.op in s.outputs:
        Out = C
    else:
        Out = s.outputs[0].output(0)
        s[C].compute_at(s[Out], s[Out].op.axis[1])
    s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
    s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))

    tx = s[C].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[C].bind(tx, thread_x)
    s[CF].compute_at(s[C], tx)
    s[C].set_store_predicate(thread_x.var.equal(0))
    s[Out].set_store_predicate(thread_x.var.equal(0))
def _callback(op):
    if op.tag == "sparse_dense_bsrmm":
        y_bsrmm = op.input_tensors[0]
        assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
        out = s.outputs[0].output(0)

        (_, c) = s[y_bsrmm].op.reduce_axis

        (m_o, n_o) = s[out].op.axis
        s[out].bind(m_o, te.thread_axis("blockIdx.x"))
        s[out].bind(n_o, te.thread_axis("blockIdx.y"))
        s[y_bsrmm].compute_at(s[out], n_o)

        thread_x = te.thread_axis("threadIdx.x")

        cfg.define_split("tile_c", c, num_outputs=2)
        if cfg.is_fallback:
            cfg["tile_c"] = SplitEntity([-1, 8])
        _, ci = cfg['tile_c'].apply(s, y_bsrmm, c)

        y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
        tx = s[y_bsrmm].op.reduce_axis[0]
        s[y_bsrmm].bind(tx, thread_x)
        s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
        s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
        s[out].set_store_predicate(thread_x.var.equal(0))
def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
    """Fallback schedule for conv2d int8 on cpu.

    The innermost pattern normally takes two int8/uint8 tensors,
    data[num_int8_elements] and kernel[int32_lanes, num_int8_elements],
    and produces an int32/uint32 dot-product output[int32_lanes].

    Parameters
    ----------
    int32_lanes : int
        Number of int32/uint32 lanes produced by one intrinsic call.
        This is related to the output channel.
    num_int8_elements : int
        Number of int8/uint8 input elements multiplied and reduced per lane.
        This is related to the input channel.
    """
    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
    HSTR, WSTR = wkl.stride_h, wkl.stride_w
    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
    out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1

    assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
        wkl.out_filter,
        int32_lanes,
    )
    assert wkl.in_filter % num_int8_elements == 0, "wkl.in_filter=%d, num_int8_elements=%d" % (
        wkl.in_filter,
        num_int8_elements,
    )

    oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
    ic_bn = 1
    for bn in range(oc_bn, 0, -4):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
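# Not part of the original source: a minimal NumPy sketch of the innermost int8
# dot-product pattern described in the two docstrings above. The example values
# int32_lanes=16 and num_int8_elements=4 are illustrative assumptions only.
import numpy as np

int32_lanes, num_int8_elements = 16, 4
data = np.random.randint(-128, 128, size=(num_int8_elements,), dtype=np.int8)
kernel = np.random.randint(-128, 128, size=(int32_lanes, num_int8_elements), dtype=np.int8)
# Each of the int32_lanes output lanes accumulates a dot product over
# num_int8_elements inputs, widened to int32 before the reduction.
output = kernel.astype(np.int32) @ data.astype(np.int32)
assert output.shape == (int32_lanes,)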
def _default_dense_nopack_config(cfg, M, N, K):
    # Generate default schedule for dynamic shape.
    if isinstance(M, tvm.expr.Var):
        M = 16
    if isinstance(N, tvm.expr.Var):
        N = 16
    if isinstance(K, tvm.expr.Var):
        K = 16

    vec_width = get_fp32_len()
    tilek_bn = 1
    for bn in range(vec_width * 2, 0, -1):
        if K % bn == 0:
            tilek_bn = bn
            break
    cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn])
    cfg["tile_x"] = SplitEntity([N, 1])
    cfg["tile_y"] = SplitEntity([1, M])
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
    if is_depthwise:
        raise RuntimeError("Depthwise not supported for intel graphics.")

    batch_size, in_channel, height, width = get_const_tuple(data.shape)
    out_channel, _, hkernel, _ = get_const_tuple(kernel.shape)
    HSTR, _ = strides

    ic_bn = 1
    oc_bn, oc_bn_upper = 16, 16
    for i in range(oc_bn_upper, 0, -1):
        if out_channel % i == 0:
            oc_bn = i
            break

    if HSTR == 2:
        if out_channel + hkernel == 515:
            block_oh = 4
            block_ow = 4
        else:
            block_oh = 4
            block_ow = 5
    elif hkernel == 3:
        if out_channel == 512:
            block_oh = 2
            block_ow = 7
        else:
            block_oh = 2
            block_ow = 14
    else:
        block_oh = 1
        block_ow = 16

    cfg["tile_ic"] = SplitEntity([in_channel // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([out_channel // oc_bn, oc_bn])
    cfg["block_oh"] = OtherOptionEntity(block_oh)
    cfg["block_ow"] = OtherOptionEntity(block_ow)
def get_default_conv2d_config(cfg, fc, y, x):
    """Defines conv2d default parameters for split axis for Adreno conv2d and depthwise conv2d"""
    # look for vthread params:
    vy = 1
    for n in range(5, 0, -1):
        if y % n == 0:
            vy = n
            break

    vx = 1
    for n in range(5, 0, -1):
        if x % n == 0 and vy * n < 9:
            vx = n
            break

    y = y // vy
    x = x // vx

    tfc = 1
    for n in range(64, 0, -1):
        if fc % n == 0:
            tfc = n
            break
    ty = 1
    for n in range(16, 0, -1):
        if y % n == 0 and tfc * n <= 512:
            ty = n
            break
    tx = 1
    for n in range(16, 0, -1):
        if x % n == 0 and tfc * ty * n <= 512:
            tx = n
            break

    fc = fc // tfc
    y = y // ty
    x = x // tx

    cfg["tile_fc"] = SplitEntity([fc, 1, tfc])
    cfg["tile_y"] = SplitEntity([y, vy, ty])
    cfg["tile_x"] = SplitEntity([x, vx, tx])
def _fallback_schedule(cfg, wkl): """ Get default schedule for the workload Parameters ---------- cfg : tvm.autotvm.task.space.FallbackConfigEntity Fallback config to be updated wkl : topi.nn.depthwise_conv2d.Workload Convolution workload """ simd_width = get_simd_32bit_lanes() pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr HSTR, WSTR = wkl.stride_h, wkl.stride_w dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1 out_width = (wkl.width - dilated_kernel_w + pl + pr) // WSTR + 1 oc_bn = 1 for bn in range(simd_width, 0, -1): if wkl.out_filter % bn == 0: oc_bn = bn break ic_bn = 1 for bn in range(oc_bn, 0, -1): if wkl.in_filter % bn == 0: ic_bn = bn break reg_n = 1 for n in range(31, 0, -1): if out_width % n == 0: reg_n = n break cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) cfg["unroll_kw"] = OtherOptionEntity(False)
def _fallback_schedule(N, F, Y, X):  # pylint: disable=unused-argument
    # Note: this is a nested helper; `cfg` is taken from the enclosing scope.
    # split N (batch dimension)
    if N > 1:
        cfg["tile_n"] = SplitEntity([-1, 1, 1, 4])
    else:
        cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
    # split F (output channel dimension)
    if F > 1:
        cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
    # split Y (height dimension)
    y_split_factor = 1
    for candidate in range(5, 17):
        if Y % candidate == 0:
            y_split_factor = candidate
            break
    cfg["tile_y"] = SplitEntity([-1, 1, 1, y_split_factor])
    # split X (width dimension)
    x_split_factor = 1
    for candidate in range(5, 17):
        if X % candidate == 0:
            x_split_factor = candidate
            break
    cfg["tile_x"] = SplitEntity([-1, x_split_factor, 1, 1])
    # split RC (input channel dimension, which is a reduction axis)
    cfg["tile_rc"] = SplitEntity([-1, 1, 16])
    # other configurations
    cfg["fuse_yx"] = OtherOptionEntity(False)
    cfg["unroll_explicit"] = OtherOptionEntity(True)
    cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)
def _fallback_schedule(cfg, wkl): """ Get default schedule for the workload Parameters ---------- cfg : tvm.autotvm.task.space.FallbackConfigEntity Fallback config to be updated wkl : topi.nn.depthwise_conv2d.Workload Convolution workload """ simd_width = get_fp32_len() HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 oc_bn = 1 for bn in range(simd_width, 0, -1): if wkl.out_filter % bn == 0: oc_bn = bn break ic_bn = 1 for bn in range(oc_bn, 0, -1): if wkl.in_filter % bn == 0: ic_bn = bn break reg_n = 1 for n in range(31, 0, -1): if out_width % n == 0: reg_n = n break cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) cfg["unroll_kw"] = OtherOptionEntity(False)
def _fallback_schedule(cfg, wkl):
    simd_width = 4  # assume ARM SIMD width is 4
    pad_left, pad_right = wkl.padl, wkl.padr
    stride_w = wkl.stride_w
    out_width = (wkl.width + pad_left + pad_right - wkl.kernel_w) // stride_w + 1
    groups = wkl.groups
    kernels_per_group = wkl.out_filter // groups
    kernel_depth = wkl.in_filter // groups

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if kernels_per_group % bn == 0:
            oc_bn = bn
            break
    if oc_bn > kernels_per_group:
        oc_bn = kernels_per_group

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if kernel_depth % bn == 0:
            ic_bn = bn
            break
    if ic_bn > kernel_depth:
        ic_bn = kernel_depth

    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
def _fallback_schedule(cfg, wkl):
    simd_width = get_simd_32bit_lanes()
    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
    HSTR, WSTR = wkl.stride_h, wkl.stride_w
    dilated_kernel_h = (wkl.kernel_h - 1) * wkl.dilation_h + 1
    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
    out_height = (wkl.height + pt + pb - dilated_kernel_h) // HSTR + 1
    out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    for ow_factor in range(out_width, 0, -1):
        if out_width % ow_factor == 0:
            for oh_factor in range(out_height, 0, -1):
                if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
                    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
                    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
                    cfg["tile_oh"] = OtherOptionEntity(oh_factor)
                    cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
                    return
    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def _schedule_dense_small_batch(cfg, s, C):
    A, weights = C.op.input_tensors
    if len(weights.op.input_tensors) == 1 and weights.op.input_tensors[0] == A:
        s[weights].compute_inline()

    _, in_dim_weights = get_const_tuple(weights.shape)
    _, in_dim_A = get_const_tuple(A.shape)

    if isinstance(in_dim_A, int):
        in_dim = in_dim_A
    elif isinstance(in_dim_weights, int):
        in_dim = in_dim_weights
    else:
        in_dim = None

    if in_dim is not None:
        cfg.define_split("tile_k", in_dim, num_outputs=2)
        if cfg.is_fallback:
            cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
        _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
    else:
        tile_k = 64
        _, kf = s[C].split(C.op.reduce_axis[0], tile_k)

    CF = s.rfactor(C, kf)

    if C.op in s.outputs:
        Out = C
    else:
        Out = s.outputs[0].output(0)
        s[C].compute_at(s[Out], s[Out].op.axis[1])
    s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
    s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))

    tx = s[C].op.reduce_axis[0]
    thread_x = te.thread_axis("threadIdx.x")
    s[C].bind(tx, thread_x)
    s[CF].compute_at(s[C], tx)
    s[C].set_store_predicate(thread_x.var.equal(0))
    s[Out].set_store_predicate(thread_x.var.equal(0))
def schedule_depthwise_conv2d_nhwc(cfg, outs): """Create the schedule for depthwise_conv2d_nchw_spatial_pack""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) out = outs[0] ##### space definition begin ##### n, h, w, c = s[out].op.axis # Split the number of input/output channels cfg.define_split("tile_c", c, num_outputs=2) # Split the height of the convolution _, hi = cfg.define_split("tile_h", h, num_outputs=2) # Split the width of the convolution _, wi = cfg.define_split("tile_w", w, num_outputs=2) # Additional out (e.g., requantization, bias addition, etc..) # 0: locate the output on the second last axis of the main compuation # 1: locate the output closest to the main computation cfg.define_knob("locate_output", [0, 1]) # Determine if we should unroll the computation of the inner tile cfg.define_knob("unroll_tile", [True, False]) # fallback support if cfg.is_fallback: cfg["tile_c"] = SplitEntity([-1, 8]) cfg["tile_h"] = SplitEntity([-1, 2]) cfg["tile_w"] = SplitEntity([-1, 2]) cfg["locate_output"] = OtherOptionEntity(1) cfg["unroll_tile"] = OtherOptionEntity(True) ##### space definition end ##### def schedule_conv(conv): conv_data = conv.op.input_tensors[0] kernel_data = conv.op.input_tensors[1] in_type = conv_data.dtype _, _, IC, channel_multiplier = get_const_tuple(kernel_data.shape) n, w, h, c = conv.op.axis r_h, r_w = conv.op.reduce_axis ho, hi = cfg["tile_h"].apply(s, conv, h) wo, wi = cfg["tile_w"].apply(s, conv, w) co, ci = cfg["tile_c"].apply(s, conv, c) split_val = cfg["tile_c"].size[-1] use_tensorization = ( (in_type == "int16") and (split_val == 8) and (IC % split_val == 0) and (channel_multiplier == 1) and is_aarch64_arm() ) data_pad_value = -1 if conv_data.name == "data_pad": assert isinstance(conv_data.op, tvm.te.ComputeOp) # Define a strategy for padding computation cfg.define_knob("data_pad_strategy", [1, 2, 3]) if cfg.is_fallback: # We cannot inline padding when tensorizing. 
# So, if we can tensorize, let's compute_at the closest axis cfg["data_pad_strategy"] = ( OtherOptionEntity(2) if use_tensorization else OtherOptionEntity(3) ) # Compute padding on the third to last axis of the computation if cfg["data_pad_strategy"].val == 1: s[conv_data].vectorize(list(s[conv_data].op.axis)[-1]) s[conv_data].compute_at(s[conv], ho) # Compute padding on the second to last axis of the computation if cfg["data_pad_strategy"].val == 2: s[conv_data].vectorize(list(s[conv_data].op.axis)[-1]) s[conv_data].compute_at(s[conv], wo) # Inline padding during computation if cfg["data_pad_strategy"].val == 3: s[conv_data].compute_inline() data_pad_value = cfg["data_pad_strategy"].val if use_tensorization and data_pad_value != 3: smlal = smlal_int16_int32() s[conv].tensorize(ci, smlal) else: s[conv].vectorize(ci) if cfg["unroll_tile"].val: s[conv].unroll(r_h) s[conv].unroll(r_w) s[conv].unroll(wi) s[conv].unroll(hi) s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci) fused_n_ho = s[conv].fuse(n, ho) return fused_n_ho def schedule_conv_out(out): n, h, w, c = out.op.axis co, ci = cfg["tile_c"].apply(s, out, c) wo, wi = cfg["tile_w"].apply(s, out, w) ho, hi = cfg["tile_h"].apply(s, out, h) s[out].reorder(n, ho, wo, co, hi, wi, ci) if cfg["unroll_tile"]: s[out].unroll(wi) s[out].unroll(hi) if out.dtype in ["int8", "uint8"]: # In case of quantized convolution further split the channel in batches of 4 elements # so that we can use arm intrinsics to run fixed_point_multiplication ci_outer, ci_inner = s[out].split(ci, 4) s[out].vectorize(ci_inner) s[out].unroll(ci_outer) fused_n_ho = s[out].fuse(n, ho) return hi, wi, fused_n_ho def _callback(op): if op.name == "depthwise_conv2d_nhwc_output": conv = op.output(0) if conv != out: hi, wi, p_axis = schedule_conv_out(out) schedule_conv(conv) if cfg["locate_output"].val == 0: s[conv].compute_at(s[out], hi) if cfg["locate_output"].val == 1: s[conv].compute_at(s[out], wi) else: p_axis = schedule_conv(out) s[out].parallel(p_axis) traverse_inline(s, outs[0].op, _callback) return s
def schedule(Apad, W, B):
    """Schedule conv2d_hwcn"""
    sch[Apad].compute_inline()
    AA = sch.cache_read(Apad, "shared", [B])
    WW = sch.cache_read(W, "shared", [B])
    AL = sch.cache_read(AA, "local", [B])
    WL = sch.cache_read(WW, "local", [B])

    if B.op in sch.outputs:
        Out = B
        BL = sch.cache_write(Out, "local")
    else:
        Out = sch.outputs[0].output(0)
        sch[B].set_scope("local")
        BL = B

    hi, wi, fi, ni = sch[Out].op.axis

    # Create tuning space
    n_thread_cand = [1, 2, 4, 8, 16, 32]
    vthread_cand = [1, 2, 4, 8]
    cfg.define_split(
        'tile_fi',
        fi,
        num_outputs=4,
        filter=lambda x: (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
    cfg.define_split(
        'tile_ni',
        ni,
        num_outputs=4,
        filter=lambda x: (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))

    if cfg.is_fallback:
        cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
        cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])

    # Scheduling
    step = 8

    bz = sch[Out].fuse(hi, wi)
    by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
    bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
    sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)

    sch[Out].bind(bz, tvm.thread_axis('blockIdx.z'))
    sch[Out].bind(by, tvm.thread_axis('blockIdx.y'))
    sch[Out].bind(bx, tvm.thread_axis('blockIdx.x'))
    sch[Out].bind(tyz, tvm.thread_axis('vthread'))
    sch[Out].bind(txz, tvm.thread_axis('vthread'))
    sch[Out].bind(ty, tvm.thread_axis('threadIdx.y'))
    sch[Out].bind(tx, tvm.thread_axis('threadIdx.x'))

    # Schedule BL local write
    sch[BL].compute_at(sch[Out], tx)
    yi, xi, fi, ni = sch[BL].op.axis
    ry, rx, rc = sch[BL].op.reduce_axis
    rco, rci = sch[BL].split(rc, factor=step)
    sch[BL].reorder(rco, ry, rx, rci, fi, ni)
    fuse_index = sch[BL].fuse(ry, rx)
    fuse_index = sch[BL].fuse(fuse_index, rco)
    rx = fuse_index

    sch[AA].compute_at(sch[BL], rx)
    sch[WW].compute_at(sch[BL], rx)
    sch[AL].compute_at(sch[BL], rci)
    sch[WL].compute_at(sch[BL], rci)

    # Schedule for A's shared memory load
    yi, xi, ci, ni = sch[AA].op.axis
    ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
    tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
    _, ni = sch[AA].split(ni, factor=4)
    sch[AA].reorder(ty, tx, yi, xi, ci, ni)
    sch[AA].bind(ty, tvm.thread_axis('threadIdx.y'))
    sch[AA].bind(tx, tvm.thread_axis('threadIdx.x'))
    sch[AA].vectorize(ni)

    # Schedule for W's shared memory load
    yi, xi, ci, fi = sch[WW].op.axis
    ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
    tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
    _, fi = sch[WW].split(fi, factor=4)
    sch[WW].reorder(ty, tx, yi, xi, ci, fi)
    sch[WW].bind(ty, tvm.thread_axis('threadIdx.y'))
    sch[WW].bind(tx, tvm.thread_axis('threadIdx.x'))
    sch[WW].vectorize(fi)
def schedule_depthwise_conv2d_nhwc(cfg, outs): """Create the schedule for depthwise_conv2d_nchw_spatial_pack""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) out = outs[0] ##### space definition begin ##### n, h, w, c = s[out].op.axis cfg.define_split("tile_c", c, num_outputs=2) _, hi = cfg.define_split("tile_h", h, num_outputs=2) _, wi = cfg.define_split("tile_w", w, num_outputs=2) cfg.define_knob("locate_output", [0, 1]) # fallback support if cfg.is_fallback: cfg["tile_c"] = SplitEntity([-1, 8]) cfg["tile_h"] = SplitEntity([-1, 2]) cfg["tile_w"] = SplitEntity([-1, 2]) cfg["locate_output"] = OtherOptionEntity(1) ##### space definition end ##### def schedule_conv(conv): conv_data = conv.op.input_tensors[0] n, w, h, c = conv.op.axis r_h, r_w = conv.op.reduce_axis ho, hi = cfg["tile_h"].apply(s, conv, h) wo, wi = cfg["tile_w"].apply(s, conv, w) co, ci = cfg["tile_c"].apply(s, conv, c) if conv_data.name == "data_pad": assert isinstance(conv_data.op, tvm.te.ComputeOp) # Define a policy for padding computation cfg.define_knob("data_pad_inline", [1, 2, 3]) if cfg.is_fallback: cfg["data_pad_inline"] = OtherOptionEntity(3) if cfg["data_pad_inline"].val == 1: s[conv_data].vectorize(list(s[conv_data].op.axis)[-1]) s[conv_data].compute_at(s[conv], ho) if cfg["data_pad_inline"].val == 2: s[conv_data].vectorize(list(s[conv_data].op.axis)[-1]) s[conv_data].compute_at(s[conv], wo) if cfg["data_pad_inline"].val == 3: s[conv_data].compute_inline() s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci) fused_n_ho = s[conv].fuse(n, ho) s[conv].vectorize(ci) return fused_n_ho def schedule_conv_out(out): n, h, w, c = out.op.axis co, ci = cfg["tile_c"].apply(s, out, c) wo, wi = cfg["tile_w"].apply(s, out, w) ho, hi = cfg["tile_h"].apply(s, out, h) s[out].reorder(n, ho, wo, co, hi, wi) if out.dtype in ["int8", "uint8"]: # In case of quantized convolution further split the channel in batches of 4 elements # so that we can use arm intrinsics to run fixed_point_multiplication ci_outer, ci_inner = s[out].split(ci, 4) s[out].vectorize(ci_inner) fused_n_ho = s[out].fuse(n, ho) return hi, wi, fused_n_ho def _callback(op): if op.name == "depthwise_conv2d_nhwc_output": conv = op.output(0) if conv != out: hi, wi, p_axis = schedule_conv_out(out) schedule_conv(conv) if cfg["locate_output"].val == 0: s[conv].compute_at(s[out], hi) if cfg["locate_output"].val == 1: s[conv].compute_at(s[out], wi) else: p_axis = schedule_conv(out) s[out].parallel(p_axis) traverse_inline(s, outs[0].op, _callback) return s
def _schedule(cfg, op):
    C = op.output(0)
    A, B = s[C].op.input_tensors
    if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A:
        s[B].compute_inline()
    _, M, N = get_const_tuple(C.shape)
    AA = s.cache_read(A, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BB = s.cache_read(B, "shared", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")
    if op not in s.outputs:
        s[C].compute_inline()
        C = s.outputs[0].output(0)

    b, y, x = s[C].op.axis
    (k,) = s[CC].op.reduce_axis

    cfg.define_split("tile_y", y, num_outputs=3)
    cfg.define_split("tile_x", x, num_outputs=3)
    cfg.define_split("tile_k", k, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
    target = tvm.target.Target.current()
    if target.kind.name in ["nvptx", "rocm"]:
        # llvm-based backends cannot do non-explicit unrolling
        cfg.define_knob("unroll_explicit", [1])
    else:
        cfg.define_knob("unroll_explicit", [0, 1])

    if cfg.is_fallback:
        y_bn = get_max_power2_factor(M, 64)
        x_bn = get_max_power2_factor(N, 64)
        y_nthreads = min(y_bn, 8)
        x_nthreads = min(x_bn, 8)
        cfg["tile_x"] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
        cfg["tile_y"] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
        cfg["tile_k"] = SplitEntity([-1, 8])
        cfg["auto_unroll_max_step"] = OtherOptionEntity(16)

    by, ty, yi = cfg["tile_y"].apply(s, C, y)
    bx, tx, xi = cfg["tile_x"].apply(s, C, x)

    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")

    s[C].reorder(b, by, bx, ty, tx, yi, xi)
    s[C].bind(b, te.thread_axis("blockIdx.z"))
    s[C].bind(by, te.thread_axis("blockIdx.y"))
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].pragma(yi, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[C].pragma(yi, "unroll_explicit", cfg["unroll_explicit"].val)

    s[CC].compute_at(s[C], tx)
    _, yi, xi = s[CC].op.axis
    ko, ki = cfg["tile_k"].apply(s, CC, k)
    s[CC].reorder(ko, ki, yi, xi)
    s[CC].pragma(ki, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[CC].pragma(ki, "unroll_explicit", cfg["unroll_explicit"].val)

    s[AA].compute_at(s[CC], ko)
    s[AL].compute_at(s[CC], ki)
    s[BB].compute_at(s[CC], ko)
    s[BL].compute_at(s[CC], ki)

    _, y, k = s[AA].op.axis
    ty, yi = s[AA].split(y, nparts=cfg["tile_y"].size[1])
    tx, ki = s[AA].split(k, nparts=cfg["tile_x"].size[1])
    s[AA].reorder(ty, tx, yi, ki)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].pragma(yi, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[AA].pragma(yi, "unroll_explicit", cfg["unroll_explicit"].val)

    _, x, k = s[BB].op.axis
    ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1])
    tx, ki = s[BB].split(k, nparts=cfg["tile_x"].size[1])
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s[BB].reorder(ty, tx, xi, ki)
    s[BB].pragma(xi, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[BB].pragma(xi, "unroll_explicit", cfg["unroll_explicit"].val)
def schedule_dense_large_batch(cfg, s, C):
    """Schedule float32/64 dense with large batch size"""
    A, B = C.op.input_tensors
    batch, in_dim = get_const_tuple(A.shape)
    out_dim, _ = get_const_tuple(B.shape)
    k = C.op.reduce_axis[0]

    # create tuning space
    try:
        block_cand = [64, 128]
        vthread_cand = [2**x for x in range(1, 7)]
        n_thread_cand = [2**x for x in range(3, 7)]
        cfg.define_split(
            'tile_x',
            batch,
            num_outputs=4,
            filter=lambda x: (x.size[1] in vthread_cand and
                              x.size[2] in n_thread_cand and
                              (x.size[1] * x.size[2] * x.size[3]) in block_cand))
        cfg.define_split(
            'tile_y',
            out_dim,
            num_outputs=4,
            filter=lambda x: (x.size[1] in vthread_cand and
                              x.size[2] in n_thread_cand and
                              (x.size[1] * x.size[2] * x.size[3]) in block_cand))
        cfg.define_split('tile_k', in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
    except IndexError:
        # Index error happens when no entities left after filtering, which was designed
        # to prune tuning space for better search efficiency.
        logger.debug('Tuning space was created without pruning due to unfit shapes')
        cfg.define_split('tile_x', batch, num_outputs=4)
        cfg.define_split('tile_y', out_dim, num_outputs=4)
        cfg.define_split('tile_k', in_dim, num_outputs=3)

    if cfg.is_fallback:
        if batch > 1:
            cfg['tile_x'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_x'] = SplitEntity([1, 1, 1, 1])
        if out_dim > 1:
            cfg['tile_y'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_y'] = SplitEntity([1, 1, 1, 1])
        if in_dim > 8:
            cfg['tile_k'] = SplitEntity([-1, 8, 1])
        else:
            cfg['tile_k'] = SplitEntity([-1, 1, 1])

    # Explicit memory access
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    # Deal with op fusion
    if C.op not in s.outputs:
        s[C].compute_inline()
        C = s.outputs[0].output(0)

    # Split and reorder computation
    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    # Binding
    s[C].bind(by, tvm.thread_axis("blockIdx.y"))
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tyz, tvm.thread_axis("vthread"))
    s[C].bind(txz, tvm.thread_axis("vthread"))
    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    # Split reduction
    yo, xo = CC.op.axis
    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)

    # Schedule for A's shared memory load
    num_thread_x = cfg['tile_x'].size[2]
    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread_x)
    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[AA].double_buffer()

    # Schedule for B's shared memory load
    num_thread_y = cfg['tile_y'].size[2]
    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread_y)
    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[BB].double_buffer()
def _default_batch_matmul_config(cfg, M, N, K):
    cfg["tile_k"] = SplitEntity([K // 16, 16])
    x_bn = get_max_power2_factor(N, 8)
    cfg["tile_x"] = SplitEntity([N // x_bn, x_bn])
    y_bn = get_max_power2_factor(M, 8)
    cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])
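# Not part of the original source: the helper get_max_power2_factor used by the
# fallback configs above is assumed to return the largest power-of-two factor of n,
# optionally capped at max_value. A minimal sketch under that assumption:
def _max_power2_factor_sketch(n, max_value=None):
    x = 1
    while n % 2 == 0:
        if max_value is not None and max_value < x * 2:
            break
        x *= 2
        n //= 2
    return x

# e.g. _max_power2_factor_sketch(48) -> 16, _max_power2_factor_sketch(48, 8) -> 8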