def _spatial_pack_data_only(wkl, sch, data):
    H, W = wkl.height, wkl.width
    CI, CO = wkl.in_filter, wkl.out_filter
    KH, KW = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    HCAT, WCAT = KH - 1, KW - 1

    VH = sch.vh
    VW = sch.vw
    VC = sch.vc
    UNROLL = sch.unroll

    TH = H + 2 * HPAD
    TW = W + 2 * WPAD
    OH = (H + 2 * HPAD - KH) // HSTR + 1
    OW = (W + 2 * WPAD - KW) // WSTR + 1

    dshape = (1, CI, H, W)
    dpshape = (1, CI, TH, TW)
    dvshape = (1, TH // (VH * HSTR), TW // (VW * WSTR), CI,
               VH * HSTR + HCAT, VW * WSTR + WCAT)

    # pad whenever either padding is nonzero
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    data_vec = tvm.compute(
        dvshape,
        lambda n, h, w, ci, vh, vw:
        data_pad[n][ci][h * VH * HSTR + vh][w * VW * WSTR + vw],
        name='data_vec')

    s = tvm.create_schedule(data_vec.op)
    # traverse(s, data_vec.op)

    # schedule for data_vec
    A0, A1 = data_pad, data_vec
    if DOPAD:
        s[A0].compute_inline()
    n, h, w, ci, vh, vw = s[A1].op.axis
    s[A1].fuse(vh, vw)
    if sch.ba == 1:
        oaxis = h
        paxis = h
    else:
        oh, ih = s[A1].split(h, sch.ba)
        oaxis = oh
        paxis = ih

    s[A1].parallel(paxis)
    s[A1].pragma(oaxis, "parallel_launch_point")
    s[A1].pragma(paxis, "parallel_stride_pattern")
    s[A1].pragma(oaxis, "parallel_barrier_when_finish")

    return data_vec, s

def decl_winograd(data, U, stride, padding, out_dtype):
    """declare winograd fast convolution F(2x2, 3x3) for conv2d"""
    N, C, H, W = [util.get_const_int(x) for x in data.shape]
    _, _, C, K = [util.get_const_int(x) for x in U.shape]
    HPAD, WPAD = 1, 1
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    m = 2
    r = 3
    alpha = m + r - 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m
    P = N * nH * nW

    # pack input tile
    input_tile = tvm.compute(
        (C, P, alpha, alpha),
        lambda c, b, eps, nu: tvm.select(
            b < P,
            data_pad[b // (nH * nW)][c][b // nW % nH * m + eps][b % nW * m + nu],
            tvm.const(0, data_pad.dtype)),
        name='d')

    V = decl_V_minimal(input_tile, alpha, C, P)

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute(
        (alpha, alpha, K, P),
        lambda eps, nu, k, b: tvm.sum(U[eps][nu][c][k] * V[eps][nu][c][b],
                                      axis=c),
        name='M')

    # inverse transform and unpack
    output = decl_output_minimal(M, N, K, H, W, P, m, nH, nW)
    return output

def _depthwise_spatial_pack(args, data, kernel, strides, padding, dilation, out_dtype):
    """depthwise_conv2d_arm_cpu's inner implementation"""
    is_var, u_vh, u_vw, u_vc = args
    out_dtype = out_dtype or data.dtype

    u_n, u_c, ih, iw = data.shape if is_var else get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        pre_packed = False
        u_c, um, ukh, ukw = kernel.shape if is_var else get_const_tuple(kernel.shape)
    else:  # kernel tensor is pre-packed
        pre_packed = True
        u_c, um, ukh, ukw, u_vc = kernel.shape if is_var else get_const_tuple(kernel.shape)
        u_c = u_c * u_vc

    dilated_kernel_h = (ukh - 1) * dilation_h + 1
    dilated_kernel_w = (ukw - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    hstr, wstr = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    u_oh = (ih + pad_top + pad_down - dilated_kernel_h) // hstr + 1
    u_ow = (iw + pad_left + pad_right - dilated_kernel_w) // wstr + 1

    # pack data
    hpad = pad_top + pad_down
    wpad = pad_left + pad_right
    dopad = hpad != 0 or wpad != 0
    if dopad:
        data_pad = pad(
            data,
            (0, 0, pad_top, pad_left),
            (0, 0, pad_down, pad_right),
            name="data_pad",
        )
    else:
        data_pad = data

    oh_div = u_oh // u_vh
    ow_div = u_ow // u_vw
    kvshape = (u_c // u_vc, um, ukh, ukw, u_vc)
    ovshape = (u_n, u_c * um // u_vc, oh_div, ow_div, u_vh, u_vw, u_vc)
    oshape = (u_n, u_c * um, oh_div * u_vh, ow_div * u_vw)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (u_n, oh_div, ow_div, u_c, ukh, ukw, u_vh, u_vw)
        data_vec = tvm.compute(
            dvshape,
            lambda n, h, w, c, kh, kw, vh, vw: data_pad[n][c][
                (h * u_vh + vh) * hstr + kh * dilation_h][
                    (w * u_vw + vw) * wstr + kw * dilation_w],
            name="data_vec_undilated",
        )
    else:
        dvshape = (u_n, oh_div, ow_div, u_c,
                   u_vh * hstr + ukh - 1, u_vw * wstr + ukw - 1)
        data_vec = tvm.compute(
            dvshape,
            lambda n, h, w, c, vh, vw: data_pad[n][c][h * u_vh * hstr + vh][
                w * u_vw * wstr + vw],
            name="data_vec",
        )

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = tvm.compute(
            kvshape,
            lambda co, m, kh, kw, vc: kernel[co * u_vc + vc][m][kh][kw],
            name="kernel_vec",
        )

    kh = tvm.reduce_axis((0, ukh), name="kh")
    kw = tvm.reduce_axis((0, ukw), name="kw")

    if dilation_h != 1 or dilation_w != 1:
        conv = tvm.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, (co * u_vc + vc) // um, kh, kw, vh, vw
                         ].astype(out_dtype) *
                kernel_vec[co // um, co % um, kh, kw, vc].astype(out_dtype),
                axis=[kh, kw],
            ),
            name="depthwise_conv",
        )
    else:
        conv = tvm.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, (co * u_vc + vc) // um,
                         vh * hstr + kh, vw * wstr + kw].astype(out_dtype) *
                kernel_vec[co // um, co % um, kh, kw, vc].astype(out_dtype),
                axis=[kh, kw],
            ),
            name="depthwise_conv",
        )

    output = tvm.compute(
        oshape,
        lambda n, co, h, w: conv[n][co // u_vc][h // u_vh][w // u_vw][
            h % u_vh][w % u_vw][co % u_vc],
        name="output_unpack",
        tag="spatial_depthwise_conv_nchw_output",
    )
    return output

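# Usage sketch (illustrative, not part of the original code): exercises
# _depthwise_spatial_pack with the pre-0.7 TVM API used throughout this file
# (`tvm.placeholder`, `tvm.create_schedule`, `tvm.build`). The shapes are
# chosen so the 2x2 vh/vw tiles divide the 56x56 output exactly;
# `_demo_depthwise_spatial_pack` is a hypothetical name.
def _demo_depthwise_spatial_pack():
    data = tvm.placeholder((1, 32, 56, 56), name="data")
    kernel = tvm.placeholder((32, 1, 3, 3), name="kernel")  # multiplier um = 1
    out = _depthwise_spatial_pack((False, 2, 2, 8), data, kernel,
                                  strides=1, padding=1, dilation=1,
                                  out_dtype="float32")
    s = tvm.create_schedule(out.op)
    return tvm.build(s, [data, kernel, out], target="llvm")
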
def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype,
                  VK=6, VP=8, packed_output=False):
    # return _baseline_winograd(cfg, data, kernel, strides, padding, layout, out_dtype)
    N, CI, IH, IW = get_const_tuple(data.shape)
    CO, _, KH, KW = get_const_tuple(kernel.shape)
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)

    assert layout == 'NCHW'
    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    A_data = np.array(
        [[1, 1, 1, 1, 1, 32, 32, 0],
         [0, 1, -1, 2, -2, 16, -16, 0],
         [0, 1, 1, 4, 4, 8, 8, 0],
         [0, 1, -1, 8, -8, 4, -4, 0],
         [0, 1, 1, 16, 16, 2, 2, 0],
         [0, 1, -1, 32, -32, 1, -1, 1]],
        dtype=np.float32).T
    G_data = np.array(
        [[1, 0, 0],
         [-2 / 9, -2 / 9, -2 / 9],
         [-2 / 9, 2 / 9, -2 / 9],
         [1 / 90, 1 / 45, 2 / 45],
         [1 / 90, -1 / 45, 2 / 45],
         [1 / 45, 1 / 90, 1 / 180],
         [1 / 45, -1 / 90, 1 / 180],
         [0, 0, 1]],
        dtype=np.float32)
    B_data = np.array(
        [[1, 0, -21 / 4, 0, 21 / 4, 0, -1, 0],
         [0, 1, 1, -17 / 4, -17 / 4, 1, 1, 0],
         [0, -1, 1, 17 / 4, -17 / 4, -1, 1, 0],
         [0, 1 / 2, 1 / 4, -5 / 2, -5 / 4, 2, 1, 0],
         [0, -1 / 2, 1 / 4, 5 / 2, -5 / 4, -2, 1, 0],
         [0, 2, 4, -5 / 2, -5, 1 / 2, 1, 0],
         [0, -2, 4, 5 / 2, -5, -1 / 2, 1, 0],
         [0, -1, 0, 21 / 4, 0, -21 / 4, 0, 1]],
        dtype=np.float32).T

    m = A_data.shape[1]
    r = 3
    alpha = m + r - 1

    C = CI
    H = (IH + 2 * HPAD - 3) // HSTR + 1
    W = (IW + 2 * WPAD - 3) // WSTR + 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m

    def round_up(a, b):
        return ((a + b - 1) // b) * b

    K = round_up(CO, VK)
    P = round_up(N * nH * nW, VP)
    assert K % VK == 0
    assert P % VP == 0

    G = const_matrix(G_data, 'G')
    r_kh = tvm.reduce_axis((0, KH), 'r_kh')
    r_kw = tvm.reduce_axis((0, KW), 'r_kw')
    assert K >= CO
    if K > CO:
        kernel_pad = pad(kernel, (0, 0, 0, 0), (K - CO, 0, 0, 0),
                         name="kernel_pad")
    else:
        kernel_pad = kernel

    input_tile = tvm.placeholder(shape=(P // VP, C, alpha, alpha, VP),
                                 dtype='float32', name="input_tile")
    U = tvm.placeholder(shape=(K // VK, alpha, alpha, C, VK),
                        dtype='float32', name="U")
    # U = tvm.compute(
    #     (K // VK, alpha, alpha, C, VK),
    #     lambda k, eps, nu, c, kk:
    #     tvm.sum(kernel_pad[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
    #             G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
    #     name='U')
    # # pack input tile
    # input_tile = tvm.compute(
    #     (P // VP, C, alpha, alpha, VP),
    #     lambda b, c, eps, nu, bb:
    #     data_pad[(b * VP + bb) // (nH * nW)][c]
    #             [(b * VP + bb) // nW % nH * m + eps]
    #             [(b * VP + bb) % nW * m + nu],
    #     name='d')

    def compute_B_T_dot_X(b, c, eps, nu, bb):
        temp_expr = {}
        for j in range(alpha):
            wd0 = input_tile[b][c][0][j][bb] - input_tile[b][c][6][j][bb]
            d4_sub_d2 = input_tile[b][c][4][j][bb] - input_tile[b][c][2][j][bb]
            wd7 = input_tile[b][c][7][j][bb] - input_tile[b][c][1][j][bb]
            d3_sub_d5 = input_tile[b][c][3][j][bb] - input_tile[b][c][5][j][bb]
            wd1 = input_tile[b][c][2][j][bb] + input_tile[b][c][6][j][bb]
            wd2 = input_tile[b][c][1][j][bb] + input_tile[b][c][5][j][bb]
            wd4 = input_tile[b][c][5][j][bb] + input_tile[b][c][1][j][bb] * 0.25
            wd5 = input_tile[b][c][6][j][bb] - input_tile[b][c][4][j][bb] * 5
            wd3 = input_tile[b][c][6][j][bb] + input_tile[b][c][2][j][bb] * 0.25
            wd6 = input_tile[b][c][1][j][bb] + input_tile[b][c][5][j][bb] * 0.25

            wd0 = wd0 + d4_sub_d2 * 5.25
            wd7 = wd7 + d3_sub_d5 * 5.25

            wd1 = wd1 - input_tile[b][c][4][j][bb] * 4.25
            wd2 = wd2 - input_tile[b][c][3][j][bb] * 4.25

            wd3 = wd3 - input_tile[b][c][4][j][bb] * 1.25
            wd5 = wd5 + input_tile[b][c][2][j][bb] * 4
            wd4 = wd4 - input_tile[b][c][3][j][bb] * 1.25
            wd6 = wd6 - input_tile[b][c][3][j][bb] * 1.25

            temp_expr[(0, j)] = wd0
            temp_expr[(1, j)] = wd1 + wd2
            temp_expr[(2, j)] = wd1 - wd2
            temp_expr[(3, j)] = wd3 + wd4 * 2
            temp_expr[(4, j)] = wd3 - wd4 * 2
            temp_expr[(5, j)] = wd5 + wd6 * 2
            temp_expr[(6, j)] = wd5 - wd6 * 2
            temp_expr[(7, j)] = wd7

        now = tvm.const(0.0, "float32")
        for ii in range(alpha):
            for jj in range(alpha):
                now = tvm.select(tvm.all(eps == ii, nu == jj),
                                 temp_expr[(ii, jj)],
                                 now)
        return now

    B_T_dot_X = tvm.compute((P // VP, C, alpha, alpha, VP),
                            compute_B_T_dot_X, name="B_T_dot_X")

    def compute_X_dot_B(b, eps, nu, c, bb):
        temp_expr = {}
        for i in range(alpha):
            wd0 = B_T_dot_X[b][c][i][0][bb] - B_T_dot_X[b][c][i][6][bb]
            d4_sub_d2 = B_T_dot_X[b][c][i][4][bb] - B_T_dot_X[b][c][i][2][bb]
            wd7 = B_T_dot_X[b][c][i][7][bb] - B_T_dot_X[b][c][i][1][bb]
            d3_sub_d5 = B_T_dot_X[b][c][i][3][bb] - B_T_dot_X[b][c][i][5][bb]
            wd1 = B_T_dot_X[b][c][i][2][bb] + B_T_dot_X[b][c][i][6][bb]
            wd2 = B_T_dot_X[b][c][i][1][bb] + B_T_dot_X[b][c][i][5][bb]
            wd4 = B_T_dot_X[b][c][i][5][bb] + B_T_dot_X[b][c][i][1][bb] * 0.25
            wd5 = B_T_dot_X[b][c][i][6][bb] - B_T_dot_X[b][c][i][4][bb] * 5
            wd3 = B_T_dot_X[b][c][i][6][bb] + B_T_dot_X[b][c][i][2][bb] * 0.25
            wd6 = B_T_dot_X[b][c][i][1][bb] + B_T_dot_X[b][c][i][5][bb] * 0.25

            wd0 = wd0 + d4_sub_d2 * 5.25
            wd7 = wd7 + d3_sub_d5 * 5.25

            wd1 = wd1 - B_T_dot_X[b][c][i][4][bb] * 4.25
            wd2 = wd2 - B_T_dot_X[b][c][i][3][bb] * 4.25

            wd3 = wd3 - B_T_dot_X[b][c][i][4][bb] * 1.25
            wd5 = wd5 + B_T_dot_X[b][c][i][2][bb] * 4
            wd4 = wd4 - B_T_dot_X[b][c][i][3][bb] * 1.25
            wd6 = wd6 - B_T_dot_X[b][c][i][3][bb] * 1.25

            temp_expr[(i, 0)] = wd0
            temp_expr[(i, 1)] = wd1 + wd2
            temp_expr[(i, 2)] = wd1 - wd2
            temp_expr[(i, 3)] = wd3 + wd4 * 2
            temp_expr[(i, 4)] = wd3 - wd4 * 2
            temp_expr[(i, 5)] = wd5 + wd6 * 2
            temp_expr[(i, 6)] = wd5 - wd6 * 2
            temp_expr[(i, 7)] = wd7

        now = tvm.const(0.0, "float32")
        for ii in range(alpha):
            for jj in range(alpha):
                now = tvm.select(tvm.all(eps == ii, nu == jj),
                                 temp_expr[(ii, jj)],
                                 now)
        return now

    V = tvm.compute((P // VP, alpha, alpha, C, VP), compute_X_dot_B, name="V")

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute(
        (K // VK, P // VP, alpha, alpha, VK, VP),
        lambda k, b, eps, nu, kk, bb: tvm.sum(
            U[k][eps][nu][c][kk] * V[b][eps][nu][c][bb], axis=c),
        name='M')

    def compute_A_T_dot_M(k, b, eps, nu, kk, bb):
        temp_expr = {}
        for j in range(alpha):
            m1_add_m2 = M[k][b][1][j][kk][bb] + M[k][b][2][j][kk][bb]
            m1_sub_m2 = M[k][b][1][j][kk][bb] - M[k][b][2][j][kk][bb]
            m3_add_m4 = M[k][b][3][j][kk][bb] + M[k][b][4][j][kk][bb]
            m3_sub_m4 = M[k][b][3][j][kk][bb] - M[k][b][4][j][kk][bb]
            m5_add_m6 = M[k][b][5][j][kk][bb] + M[k][b][6][j][kk][bb]
            m5_sub_m6 = M[k][b][5][j][kk][bb] - M[k][b][6][j][kk][bb]
            s0 = M[k][b][0][j][kk][bb] + m1_add_m2
            s5 = M[k][b][7][j][kk][bb] + m1_sub_m2

            s1 = m1_sub_m2 + m5_sub_m6 * 16
            s4 = m1_add_m2 + m3_add_m4 * 16
            s2 = m1_add_m2 + 8 * m5_add_m6
            s3 = m1_sub_m2 + 8 * m3_sub_m4

            s0 = s0 + m5_add_m6 * 32
            s5 = s5 + m3_sub_m4 * 32
            s1 = s1 + m3_sub_m4 * 2
            s4 = s4 + m5_add_m6 * 2

            s0 = s0 + m3_add_m4
            s5 = s5 + m5_sub_m6

            s2 = s2 + m3_add_m4 * 4
            s3 = s3 + m5_sub_m6 * 4

            temp_expr[(0, j)] = s0
            temp_expr[(1, j)] = s1
            temp_expr[(2, j)] = s2
            temp_expr[(3, j)] = s3
            temp_expr[(4, j)] = s4
            temp_expr[(5, j)] = s5

        now = tvm.const(0.0, "float32")
        for ii in range(m):
            for jj in range(alpha):
                now = tvm.select(tvm.all(eps == ii, nu == jj),
                                 temp_expr[(ii, jj)],
                                 now)
        return now

    A_T_dot_M = tvm.compute((K // VK, P // VP, m, alpha, VK, VP),
                            compute_A_T_dot_M, name="A_T_dot_M")

    def compute_X_dot_A(k, b, eps, nu, kk, bb):
        temp_expr = {}
        for i in range(m):
            m1_add_m2 = A_T_dot_M[k][b][i][1][kk][bb] + A_T_dot_M[k][b][i][2][kk][bb]
            m1_sub_m2 = A_T_dot_M[k][b][i][1][kk][bb] - A_T_dot_M[k][b][i][2][kk][bb]
            m3_add_m4 = A_T_dot_M[k][b][i][3][kk][bb] + A_T_dot_M[k][b][i][4][kk][bb]
            m3_sub_m4 = A_T_dot_M[k][b][i][3][kk][bb] - A_T_dot_M[k][b][i][4][kk][bb]
            m5_add_m6 = A_T_dot_M[k][b][i][5][kk][bb] + A_T_dot_M[k][b][i][6][kk][bb]
            m5_sub_m6 = A_T_dot_M[k][b][i][5][kk][bb] - A_T_dot_M[k][b][i][6][kk][bb]
            s0 = A_T_dot_M[k][b][i][0][kk][bb] + m1_add_m2
            s5 = A_T_dot_M[k][b][i][7][kk][bb] + m1_sub_m2

            s1 = m1_sub_m2 + m5_sub_m6 * 16
            s4 = m1_add_m2 + m3_add_m4 * 16
            s2 = m1_add_m2 + 8 * m5_add_m6
            s3 = m1_sub_m2 + 8 * m3_sub_m4

            s0 = s0 + m5_add_m6 * 32
            s5 = s5 + m3_sub_m4 * 32
            s1 = s1 + m3_sub_m4 * 2
            s4 = s4 + m5_add_m6 * 2

            s0 = s0 + m3_add_m4
            s5 = s5 + m5_sub_m6

            s2 = s2 + m3_add_m4 * 4
            s3 = s3 + m5_sub_m6 * 4

            temp_expr[(i, 0)] = s0
            temp_expr[(i, 1)] = s1
            temp_expr[(i, 2)] = s2
            temp_expr[(i, 3)] = s3
            temp_expr[(i, 4)] = s4
            temp_expr[(i, 5)] = s5

        now = tvm.const(0.0, "float32")
        for ii in range(m):
            for jj in range(m):
                now = tvm.select(tvm.all(eps == ii, nu == jj),
                                 temp_expr[(ii, jj)],
                                 now)
        return now

    Y = tvm.compute((K // VK, P // VP, m, m, VK, VP), compute_X_dot_A, name="Y")

    # unpack output
    def _output(n, k_, h, w):
        b_idx = n * nH * nW + (h // m) * nW + w // m
        b = b_idx // VP
        bb = b_idx % VP
        k = k_ // VK
        kk = k_ % VK
        return Y[k][b][h % m][w % m][kk][bb]

    output = tvm.compute((N, CO, H, W), _output,
                         name='output', tag='winograd_conv_output')

    if cfg:
        cfg.add_flop(2 * N * K * H * W * KH * KW * C)
    return Y, input_tile, U, output

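# Usage sketch (illustrative, not part of the original code): since
# `input_tile` and `U` are placeholders in this F(6x6, 3x3) declaration, both
# are bound as build inputs alongside the unpacked output; `cfg=None` skips
# the FLOP accounting. `_demo_decl_winograd_f6x6` is a hypothetical name and
# assumes the pre-0.7 TVM API used throughout this file.
def _demo_decl_winograd_f6x6():
    data = tvm.placeholder((1, 64, 56, 56), name="data")
    kernel = tvm.placeholder((64, 64, 3, 3), name="kernel")
    Y, input_tile, U, output = decl_winograd(
        None, data, kernel, strides=1, padding=1, layout='NCHW',
        out_dtype='float32')
    s = tvm.create_schedule(output.op)
    return tvm.build(s, [input_tile, U, output], target="llvm")
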
def _conv_spatial_pack_asm(args, data, kernel, strides, padding, dilation, out_dtype):
    """_conv_spatial_pack_asm"""
    is_var, vh_, vw_, vc_ = args

    # create workload according to raw arguments
    out_dtype = out_dtype or data.dtype
    n_, ci_, ih_, iw_ = data.shape if is_var else get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        pre_packed = False
        co_, _, kh_, kw_ = kernel.shape if is_var else get_const_tuple(kernel.shape)
    else:  # kernel tensor is pre-packed
        pre_packed = True
        co_, _, kh_, kw_, vc_ = kernel.shape if is_var else get_const_tuple(kernel.shape)
        co_ = co_ * vc_

    dilated_kernel_h = (kh_ - 1) * dilation_h + 1
    dilated_kernel_w = (kw_ - 1) * dilation_w + 1
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    hstr, wstr = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    oh_ = (ih_ + pad_top + pad_bottom - dilated_kernel_h) // hstr + 1
    ow_ = (iw_ + pad_left + pad_right - dilated_kernel_w) // wstr + 1
    data_pad = pad(data, [0, 0, pad_top, pad_left],
                   [0, 0, pad_bottom, pad_right])

    oh_div = oh_ // vh_
    ow_div = ow_ // vw_
    kvshape = (co_ // vc_, ci_, kh_, kw_, vc_)
    ovshape = (n_, co_ // vc_, oh_div, ow_div, vh_, vw_, vc_)
    oshape = (n_, co_, oh_div * vh_, ow_div * vw_)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (n_, oh_div, ow_div, kh_, kw_, vh_, vw_, ci_)
        data_vec = tvm.compute(
            dvshape,
            lambda n, h, w, kh, kw, vh, vw, ci: data_pad[n][ci][
                (h * vh_ + vh) * hstr + kh * dilation_h][
                    (w * vw_ + vw) * wstr + kw * dilation_w],
            name="data_vec_undilated",
        )
    else:
        dvshape = (
            n_,
            oh_div,
            ow_div,
            (vh_ - 1) * hstr + kh_,
            (vw_ - 1) * wstr + kw_,
            ci_,
        )
        data_vec = tvm.compute(
            dvshape,
            lambda n, h, w, vh, vw, ci: data_pad[n][ci][h * vh_ * hstr + vh][
                w * vw_ * wstr + vw],
            name="data_vec",
        )

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = tvm.compute(
            kvshape,
            lambda co, ci, kh, kw, vc: kernel[co * vc_ + vc][ci][kh][kw],
            name="kernel_vec",
        )

    ci = tvm.reduce_axis((0, ci_), name="ci")
    kh = tvm.reduce_axis((0, kh_), name="kh")
    kw = tvm.reduce_axis((0, kw_), name="kw")

    # asm begin----
    type_map = {
        "int8": "int32",
        "uint8": "uint32",
        "float32": "float32",
        "float16": "float16",
    }
    acum_dtype = type_map[data.dtype]
    attrs = {
        "SH": hstr,
        "SW": wstr,
        "PH": pad_top,
        "PW": pad_left,
        "DILA_H": dilation_h,
        "DILA_W": dilation_w,
        "VH": vh_,
        "VW": vw_,
        "VC": vc_,
        "ACUM_DTYPE": acum_dtype,
    }
    # asm end----

    if dilation_h != 1 or dilation_w != 1:
        conv = tvm.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, kh, kw, vh, vw, ci].astype(out_dtype) *
                kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
                axis=[ci, kh, kw],
            ),
            name="conv",
            attrs=attrs,
        )
    else:
        conv = tvm.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, vh * hstr + kh, vw * wstr + kw, ci
                         ].astype(out_dtype) *
                kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
                axis=[ci, kh, kw],
            ),
            name="conv",
            attrs=attrs,
        )

    output = tvm.compute(
        oshape,
        lambda n, co, h, w: conv[n][co // vc_][h // vh_][w // vw_][
            h % vh_][w % vw_][co % vc_],
        name="output_unpack",
        tag="asm_conv2d_output",
    )
    return output

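# Usage sketch (illustrative, not part of the original code): the `attrs`
# dict attached to the conv stage carries stride/padding/tiling metadata for
# the asm backend; otherwise this drives like the plain spatial pack above.
# `_demo_conv_spatial_pack_asm` is a hypothetical name; shapes are chosen so
# the 2x2 tiles divide the 56x56 output.
def _demo_conv_spatial_pack_asm():
    data = tvm.placeholder((1, 32, 56, 56), name="data")
    kernel = tvm.placeholder((64, 32, 3, 3), name="kernel")
    out = _conv_spatial_pack_asm((False, 2, 2, 8), data, kernel,
                                 strides=1, padding=1, dilation=1,
                                 out_dtype="float32")
    s = tvm.create_schedule(out.op)
    return tvm.build(s, [data, kernel, out], target="llvm")
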
def decl_winograd(data, U, stride, padding, out_dtype):
    """declare winograd fast convolution F(2x2, 3x3) for conv2d"""
    N, C, H, W = [util.get_const_int(x) for x in data.shape]
    _, _, C, K = [util.get_const_int(x) for x in U.shape]
    HPAD, WPAD = 1, 1
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    B_data = np.array(
        [[1, 0, 0, 0],
         [0, 1, -1, 1],
         [-1, 1, 1, 0],
         [0, 0, 0, -1]], out_dtype)

    A_data = np.array(
        [[1, 0],
         [1, 1],
         [1, -1],
         [0, -1]], out_dtype)

    m = 2
    r = 3
    alpha = m + r - 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m
    P = N * nH * nW

    # pack input tile
    input_tile = tvm.compute(
        (C, P, alpha, alpha),
        lambda c, b, eps, nu: tvm.select(
            b < P,
            data_pad[b // (nH * nW)][c][b // nW % nH * m + eps][b % nW * m + nu],
            tvm.const(0, data_pad.dtype)),
        name='d')

    # transform image
    B = const_array(B_data, 'B')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    V = tvm.compute(
        (alpha, alpha, C, P),
        lambda eps, nu, c, b: tvm.sum(
            input_tile[c][b][r_eps][r_nu] * B[r_eps][eps] * B[r_nu][nu],
            axis=[r_eps, r_nu]),
        name='V')

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute(
        (alpha, alpha, K, P),
        lambda eps, nu, k, b: tvm.sum(U[eps][nu][c][k] * V[eps][nu][c][b],
                                      axis=c),
        name='M')

    # inverse transform and unpack
    A = const_array(A_data, 'A')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    output = tvm.compute(
        (N, K, H, W),
        lambda n, k, h, w: tvm.sum(
            M[r_eps][r_nu][k][n * nH * nW + (h // m) * nW + w // m] *
            A[r_eps][h % m] * A[r_nu][w % m],
            axis=[r_eps, r_nu]),
        name='output')
    return output

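# Kernel-transform sketch (illustrative, not part of the original code): the
# decl_winograd above expects a pre-transformed U of shape (alpha, alpha, C, K).
# For F(2x2, 3x3) it can be produced with the standard G matrix as
# U[eps][nu][c][k] = sum_{kh,kw} G[eps][kh] * W[k][c][kh][kw] * G[nu][kw],
# assuming an OIHW kernel. `decl_winograd_U` is a hypothetical helper name.
def decl_winograd_U(kernel, out_dtype):
    K, C, KH, KW = [util.get_const_int(x) for x in kernel.shape]
    G_data = np.array(
        [[1, 0, 0],
         [1 / 2, 1 / 2, 1 / 2],
         [1 / 2, -1 / 2, 1 / 2],
         [0, 0, 1]], out_dtype)
    G = const_array(G_data, 'G')
    r_kh = tvm.reduce_axis((0, KH), 'r_kh')
    r_kw = tvm.reduce_axis((0, KW), 'r_kw')
    return tvm.compute(
        (4, 4, C, K),
        lambda eps, nu, c, k: tvm.sum(
            kernel[k][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
            axis=[r_kh, r_kw]),
        name='U')
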
def compute_conv2d_gemm_without_weight_transform(cfg, data, B_interleaved_t,
                                                 strides, padding, dilation,
                                                 out_dtype, kernel_size,
                                                 output_channels):
    """Compute conv2d by transforming the input,
    executing GEMM and transforming the output back"""
    batches, IH, IW, IC = get_const_tuple(data.shape)

    KH, KW = kernel_size
    OC = output_channels

    K_AREA = KH * KW

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = \
        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    if pad_top or pad_left:
        data_pad = nn.pad(data, [0, pad_top, pad_left, 0],
                          [0, pad_down, pad_right, 0],
                          name="data_pad")
    else:
        data_pad = data

    # --- Im2col
    M = OH * OW
    K = IC * K_AREA
    N = OC

    A_shape = (batches, M, K)
    if K_AREA == 1:
        A = te.compute(
            A_shape,
            lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y],
            name='data_flatten')
    else:
        # note the parentheses: the kernel-offset index is
        # (y // IC) // KW (resp. % KW), scaled by the dilation
        A = te.compute(
            A_shape,
            lambda n, x, y: data_pad[
                n,
                HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
                WSTR * (x % OW) + dilation_w * ((y // IC) % KW),
                y % IC],
            name='data_im2col')
    N_transformed = B_interleaved_t.shape[0]

    # --- Pad if necessary
    idxm = tvm.tir.indexmod

    pad_m = 0
    pad_k = 0

    if M % 4 != 0:
        pad_m = 4 - (M % 4)

    if K % 16 != 0:
        pad_k = 16 - (K % 16)

    M_padded = M + pad_m
    K_padded = K + pad_k

    pad_before = (0, 0, 0)
    pad_after = (0, pad_m, pad_k)

    if pad_m != 0 or pad_k != 0:
        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after,
                   name="A_padded")

    # --- GEMM: A*B'
    k = te.reduce_axis((0, K_padded), "k")

    A_interleaved = te.compute(
        (batches, M_padded // 4, K_padded // 16, 4, 16),
        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
        name='A_interleaved')

    C_interleaved = te.compute(
        (batches, M_padded // 4, N_transformed, 4, 4),
        lambda b, x, y, w, z: te.sum(
            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype) *
            B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
            axis=k),
        name='C_interleaved')

    # --- Unpack C
    C = te.compute(
        (batches, M, N),
        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
        name="C",
        tag='injective')

    # --- Produce the conv output
    out_shape = (batches, OH, OW, OC)
    out = te.compute(out_shape,
                     lambda b, x, y, z: C(b, y + OW * x, z),
                     name='conv2d_gemm_output')

    return out

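# Usage sketch (illustrative, not part of the original code):
# `B_interleaved_t` holds pre-transformed weights of shape
# (OC // 4, K // 16, 4, 16) with K = KH * KW * IC padded to a multiple of 16.
# Here IC = 32 and KH = KW = 3, so K = 288 and M = 28 * 28 = 784 need no extra
# padding. `_demo_conv2d_gemm` is a hypothetical name.
def _demo_conv2d_gemm():
    data = te.placeholder((1, 28, 28, 32), name="data")  # NHWC layout
    B_interleaved_t = te.placeholder((4, 18, 4, 16), name="weight_interleaved")
    out = compute_conv2d_gemm_without_weight_transform(
        None, data, B_interleaved_t, strides=1, padding=1, dilation=1,
        out_dtype="float32", kernel_size=(3, 3), output_channels=16)
    s = te.create_schedule(out.op)
    return tvm.build(s, [data, B_interleaved_t, out], target="llvm")
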
def _spatial_conv_all(wkl, sch, data, kernel, out_dtype):
    H, W = wkl.height, wkl.width
    CI, CO = wkl.in_filter, wkl.out_filter
    KH, KW = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    HCAT, WCAT = KH - 1, KW - 1

    VH = sch.vh
    VW = sch.vw
    VC = sch.vc
    UNROLL = sch.unroll

    TH = H + 2 * HPAD
    TW = W + 2 * WPAD
    OH = (H + 2 * HPAD - KH) // HSTR + 1
    OW = (W + 2 * WPAD - KW) // WSTR + 1

    dshape = (1, CI, H, W)
    dpshape = (1, CI, TH, TW)
    dvshape = (1, TH // (VH * HSTR), TW // (VW * WSTR), CI,
               VH * HSTR + HCAT, VW * WSTR + WCAT)

    # pad whenever either padding is nonzero
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    data_vec = tvm.compute(
        dvshape,
        lambda n, h, w, ci, vh, vw:
        data_pad[n][ci][h * VH * HSTR + vh][w * VW * WSTR + vw],
        name='data_vec')

    kshape = (CO, CI, KH, KW)
    kvshape = (CO // VC, CI, KH, KW, VC)
    kernel_vec = tvm.compute(
        kvshape,
        lambda co, ci, dh, dw, vc: kernel[co * VC + vc][ci][dh][dw],
        name='kernel_vec')

    ci = tvm.reduce_axis((0, CI), name='ci')
    dh = tvm.reduce_axis((0, KH), name='dh')
    dw = tvm.reduce_axis((0, KW), name='dw')

    ovshape = (1, CO // VC, OH // VH, OW // VW, VH, VW, VC)
    oshape = (1, CO, OH, OW)

    conv = tvm.compute(
        ovshape,
        lambda n, co, h, w, vh, vw, vc:
        tvm.sum(data_vec[n, h, w, ci, vh * HSTR + dh,
                         vw * WSTR + dw].astype(out_dtype) *
                kernel_vec[co, ci, dh, dw, vc].astype(out_dtype),
                axis=[ci, dh, dw]),
        name='conv')

    output = tvm.compute(
        oshape,
        lambda n, co, h, w:
        conv[n][co // VC][h // VH][w // VW][h % VH][w % VW][co % VC],
        name='output_unpack', tag='spatial_conv_output')

    # one schedule over the whole graph, created from the final output op so
    # that the scheduling applied to the intermediate stages below is retained
    C0, C = conv, output
    s = tvm.create_schedule(C.op)
    traverse(s, C.op)

    # schedule for data_vec
    A0, A1 = data_pad, data_vec
    if DOPAD:
        s[A0].compute_inline()
    _, h, _, _, _, _ = s[A1].op.axis
    if sch.ba == 1:
        oaxis = h
        paxis = h
    else:
        oh, ih = s[A1].split(h, sch.ba)
        oaxis = oh
        paxis = ih

    s[A1].parallel(paxis)
    s[A1].pragma(oaxis, "parallel_launch_point")
    s[A1].pragma(paxis, "parallel_stride_pattern")
    s[A1].pragma(oaxis, "parallel_barrier_when_finish")

    # schedule for kernel_vec
    B, B0 = kernel, kernel_vec
    co, _, _, _, _ = s[B0].op.axis
    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[B0].split(co, sch.bc)
        oaxis = oco
        paxis = ico

    s[B0].parallel(paxis)
    s[B0].pragma(oaxis, "parallel_launch_point")
    s[B0].pragma(paxis, "parallel_stride_pattern")
    s[B0].pragma(oaxis, "parallel_barrier_when_finish")

    # schedule for conv & unpack
    CC = s.cache_write(C0, "global")
    _, co, oh, ow, vh, vw, vc = s[C0].op.axis
    if UNROLL:
        s[C0].unroll(vw)
    s[C0].vectorize(vc)

    s[CC].compute_at(s[C0], ow)
    _, co, oh, ow, vh, vw, vc = s[CC].op.axis
    ci, dh, dw = s[CC].op.reduce_axis
    s[CC].reorder(ci, dh, vh, dw, vw, vc)
    if UNROLL:
        s[CC].unroll(vw)
    s[CC].vectorize(vc)

    n, co, h, w = s[C].op.axis
    co, vc = s[C].split(co, VC)
    oh, ow, vh, vw = s[C].tile(h, w, VH, VW)
    s[C].reorder(n, co, oh, ow, vh, vw, vc)
    # if C != C1:
    #     s[C1].compute_inline()
    s[C0].compute_at(s[C], ow)

    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[C].split(co, sch.bc)
        oaxis = oco
        paxis = ico

    s[C].parallel(paxis)
    s[C].pragma(oaxis, "parallel_launch_point")
    s[C].pragma(paxis, "parallel_stride_pattern")
    s[C].pragma(oaxis, "parallel_barrier_when_finish")

    return C, s

def _im2col_pack(wkl, sch, data, kernel, stride, padding, out_dtype):
    """ Compute convolution with im2col pack layout. """
    assert data.shape[0].value == 1, \
        "im2col pack convolution only support batch size=1"
    N = 1
    H, W = wkl.height, wkl.width
    CI = wkl.in_filter
    CO = wkl.out_filter
    KH, KW = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    OH = (H + 2 * HPAD - KH) // HSTR + 1
    OW = (W + 2 * WPAD - KW) // WSTR + 1

    P = sch.vp
    Q = sch.vq
    UNROLL = sch.unroll

    dshape = (N, CI, H, W)
    dpshape = (N, CI, H + 2 * HPAD, W + 2 * WPAD)
    dcshape = (N, OH, OW, CI, KH, KW)
    dvshape = (N, OH * OW // P, CI, KH, KW, P)

    kshape = (CO, CI, KH, KW)
    kvshape = (CO // Q, CI, KH, KW, Q)

    ovshape = (N, CO // Q, OH * OW // P, P, Q)
    oshape = (N, CO, OH, OW)

    ############### declaration

    # pad whenever either padding is nonzero
    DO_PAD = (HPAD != 0 or WPAD != 0)
    if DO_PAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    data_col = tvm.compute(
        dcshape,
        lambda n, oh, ow, ci, hk, wk:
        data_pad[n][ci][oh * HSTR + hk][ow * WSTR + wk],
        name='data_col')

    data_vec = tvm.compute(
        dvshape,
        lambda n, im, ci, hk, wk, vim:
        data_col[n][(im * P + vim) // OW][(im * P + vim) % OW][ci][hk][wk],
        name='data_vec')

    kernel_vec = tvm.compute(
        kvshape,
        lambda co, ci, dh, dw, vc: kernel[co * Q + vc][ci][dh][dw],
        name='kernel_vec')

    ci = tvm.reduce_axis((0, CI), name='ci')
    hk = tvm.reduce_axis((0, KH), name='hk')
    wk = tvm.reduce_axis((0, KW), name='wk')

    conv = tvm.compute(
        ovshape,
        lambda n, co, im, vim, vco:
        tvm.sum(data_vec[n][im][ci][hk][wk][vim].astype(out_dtype) *
                kernel_vec[co][ci][hk][wk][vco].astype(out_dtype),
                axis=[ci, hk, wk]),
        name='conv')

    output = tvm.compute(
        oshape,
        lambda n, co, h, w:
        conv[n][co // Q][(h * OW + w) // P][(h * OW + w) % P][co % Q],
        name='output_vec', tag='im2col_conv_output')

    return output

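# Usage sketch (illustrative, not part of the original code): `Workload` and
# `Schedule` are hypothetical namedtuples standing in for whatever wkl/sch
# objects the surrounding code defines; the fields match the attributes read
# above. Shapes are chosen so vp divides OH * OW and vq divides CO.
def _demo_im2col_pack():
    from collections import namedtuple
    Workload = namedtuple('Workload',
                          ['height', 'width', 'in_filter', 'out_filter',
                           'hkernel', 'wkernel', 'hpad', 'wpad',
                           'hstride', 'wstride'])
    Schedule = namedtuple('Schedule', ['vp', 'vq', 'unroll'])
    wkl = Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1)
    sch = Schedule(vp=8, vq=8, unroll=False)
    data = tvm.placeholder((1, 64, 56, 56), name="data")
    kernel = tvm.placeholder((64, 64, 3, 3), name="kernel")
    out = _im2col_pack(wkl, sch, data, kernel, stride=1, padding=1,
                       out_dtype="float32")
    s = tvm.create_schedule(out.op)
    return tvm.build(s, [data, kernel, out], target="llvm")
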