def fused_batch_norm_manual_setdim(shape):
    """Manual setdim for fused batch norm with dynamic shape."""
    from akg import dim
    info = dim.Dim()
    for i, d in enumerate(DYNAMIC_SETDIM_MAP.get(shape, [])):
        info.setdim(index=0, axis=i, tilel1=d, tilel0=1)
    return str(info)
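# Hedged usage sketch: look up a shape key and build the attrs dict that
# akg.build expects. The shape key below is hypothetical; real entries live
# in the module's DYNAMIC_SETDIM_MAP tuning table.
def _demo_fused_bn_setdim(shape=(32, 4, 112, 112, 16)):
    return {"dim": fused_batch_norm_manual_setdim(shape)}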
def dropout_set_dim_func(data_tensor, data_mask, keep_prob):
    shape = [x.value for x in data_tensor.shape if x.value != 1]
    dtype = data_tensor.dtype
    storage = 49152
    if dtype.lower() == 'float16':
        dnum = 1
    else:
        dnum = 2
    info = dim.Dim()
    list_info = []

    def cal_max_divisor(a, threshold):
        for i in range(threshold, 0, -1):
            if a % i == 0:
                return i
        return 1

    for i in range(len(shape) - 1, -1, -1):
        if dnum >= storage:
            list_info.append((i, 1))
        elif dnum * shape[i] > storage:
            list_info.append((i, cal_max_divisor(shape[i], storage // dnum)))
        dnum *= shape[i]
    for i in reversed(list_info):
        info.setdim(index=0, axis=i[0], tilel1=i[1], tilel0=1)
    return str(info)
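# Worked example of dropout_set_dim_func's UB budget walk for a float32 tensor
# of shape (256, 1024) (dnum starts at 2, budget storage = 49152):
#   axis 1: 2 * 1024 = 2048 <= 49152     -> no cut, dnum becomes 2048
#   axis 0: 2048 * 256 = 524288 > 49152  -> cut with the largest divisor of 256
#           not exceeding 49152 // 2048 = 24, i.e. tilel1 = 16
# so only axis 0 is tiled, with an L1 tile of 16.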
def set_dims_group(cut_h, cut_co, cut_m, cut_k, cut_n, out_shape_5d, _c_i, _c_o,
                   group, _k_h, _k_w, _s_h, block_size):
    info = dim.Dim()
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_5d
    tile_out_h = (cut_h - _k_h) // _s_h + 1
    if out_n > 1:
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)
    if out_c1 > 1:
        info.setdim(index=0, axis=0, tilel1=cut_co // block_size, tilel0=0)
    if out_h > 1:
        info.setdim(index=0, axis='H', tilel1=tile_out_h, tilel0=0)
    if out_w > 1:
        info.setdim(index=0, axis=3, tilel1=out_w, tilel0=0)
    if out_c0 > 1:
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)
    # only one kc1 slice per group is supported, so the branch below never fires
    assert _c_i // block_size // group == 1
    if _c_i // block_size // group > 1:
        info.setdim(index=0, axis=5, tilel1=_c_i // block_size // group, tilel0=0)
    if _k_h > 1:
        info.setdim(index=0, axis=5, tilel1=_k_h, tilel0=0)
    if _k_w > 1:
        info.setdim(index=0, axis=5, tilel1=_k_w, tilel0=0)
    return str(info)
def set_dims(fmap_shape, filter_shape, pad_, stride_, dilation_, tile_hh,
             tile_coco, tile_mm, tile_kk, tile_nn, block_size):
    """Set dim info in attrs."""
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size
    in_c1 = in_c // block_size
    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size
    padding = (pad_[0], pad_[0], pad_[1], pad_[1])
    p_top, p_bottom, p_left, p_right = padding
    s_h, s_w = stride_[0], stride_[1]
    d_h, d_w = dilation_[0], dilation_[1]
    if tile_hh == in_h:
        tile_hh += p_top + p_bottom
    tile_coco = (tile_coco + block_size - 1) // block_size * block_size
    tile_mm = (tile_mm + block_size - 1) // block_size * block_size
    tile_kk = (tile_kk + block_size - 1) // block_size * block_size
    tile_nn = (tile_nn + block_size - 1) // block_size * block_size
    k_h_d = (k_h - 1) * d_h + 1
    k_w_d = (k_w - 1) * d_w + 1
    out_h = (in_h + p_top + p_bottom - k_h_d) // s_h + 1
    tile_out_h = (tile_hh - k_h_d) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w_d) // s_w + 1
    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0
    if tile_coco > 0:
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1
    # set dim
    info = dim.Dim()
    if out_n > 1:
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if out_c1 > 1:
        info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1
    if out_h > 1:
        info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if out_w > 1:
        info.setdim(index=0, axis=3, tilel1=out_w, tilel0=0)  # w
    if out_c0 > 1:
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0
    if in_c1 > 1:
        info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1
    if k_h > 1:
        info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh
    if k_w > 1:
        info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw
    return str(info)
def smooth_l1_loss_grad_get_dim(shape):
    """
    Get the dim attr for smooth L1 loss grad.

    Args:
        shape: shape of the prediction tensor (e.g. [8, 4718, 4]).

    Returns:
        dim string for akg.op.build(attrs=...), or '' if no cut is needed.
    """
    # example shape: [8, 4718, 4]
    # cut dim: ((1,1), (1024,1024))
    tensor_size = 1
    for i in shape[:-1]:
        tensor_size *= i
    # if tensor_size >= threshold, cut
    ub_size = 256 * 1024
    # estimated maximum number of data copies in UB
    num_data_copies = 32
    data_size = 4
    # do not cut the last dim
    max_tensor_size = int(ub_size / data_size / num_data_copies / shape[-1])
    if tensor_size > max_tensor_size:
        # find the largest divisor of tensor_size to be the tile size;
        # currently the dim size must be divisible by the tile size
        tile_size = 1
        for i in range(max_tensor_size, 1, -1):
            if tensor_size % i == 0:
                tile_size = i
                break
        # generate the setdim string
        info = dim.Dim()
        # do not cut the last dim
        for i in range(0, len(shape) - 2):
            info.setdim(index=0, axis=i, tilel1=1, tilel0=1)
        # cut the -2 dim
        info.setdim(index=0, axis=len(shape) - 2, tilel1=tile_size, tilel0=tile_size)
        return str(info)
    return ''
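# Worked example for smooth_l1_loss_grad_get_dim with shape [8, 4096, 4]:
#   tensor_size     = 8 * 4096 = 32768
#   max_tensor_size = 256 * 1024 / 4 / 32 / 4 = 512
#   32768 > 512 and 32768 % 512 == 0, so tile_size = 512
# giving tiles (1, 1) on axis 0 and (512, 512) on axis 1; the last dim is
# never cut.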
def test_quant(fmap_shape):
    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    assert in_c % 32 == 0
    input_shape_nc1hwc0 = (in_n, in_c // 16, in_h, in_w, 16)
    in_n, in_c1, in_h, in_w, in_c0 = input_shape_nc1hwc0

    # placeholders (NC1HWC0)
    FMap = akg.tvm.placeholder(input_shape_nc1hwc0, dtype='float16', name='FMap')
    ScaleQ = akg.tvm.placeholder((16,), dtype='float16', name='ScaleQ')
    OffsetQ = akg.tvm.placeholder((16,), dtype='float16', name='OffsetQ')

    out_shape_nc1hwc0 = (in_n, in_c // 32, in_h, in_w, 32)
    print(out_shape_nc1hwc0)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    # quantize: pack two fp16 C0=16 blocks into one int8 C0=32 block,
    # so the input c1 index is 2 * c1 + c0 // 16
    Quant = akg.tvm.compute(
        out_shape_nc1hwc0,
        lambda n, c1, h, w, c0: (FMap[n, 2 * c1 + c0 // 16, h, w, c0 % 16] *
                                 ScaleQ[0] + OffsetQ[0]).astype('int8'),
        name='output')

    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=2, tilel0=0)
    info.setdim(index=0, axis=0, tilel1=32, tilel0=0)
    info.setdim(index=0, axis=0, tilel1=32, tilel0=0)
    info.setdim(index=0, axis=0, tilel1=16, tilel0=0)

    # schedule
    s = akg.tvm.create_schedule(Quant.op)
    with akg.build_config(add_lower_pass=utils.debug_mode(0), dump_pass_ir=True):
        mod = akg.build(s, [FMap, ScaleQ, OffsetQ, Quant], 'cce',
                        name='cce_quant', attrs={'dim': str(info)},
                        polyhedral=True)
    source_code = mod.imported_modules[0].get_source()
    print(source_code)
def set_dims(tiling):
    """Set dim for tiling."""
    info = dim.Dim()
    for d, tile_d in enumerate(tiling):
        if len(tile_d) == 2:  # only c1 and c0 tile
            index = 0
            axis = d
            c1 = tile_d[0]
            c0 = tile_d[1]
        elif len(tile_d) == 4:  # index, axis, c1, c0
            index = tile_d[0]
            axis = tile_d[1]
            c1 = tile_d[2]
            c0 = tile_d[3]
        else:
            raise RuntimeError(
                "Each element in tiling should be length-2 (c1_tile, c0_tile) "
                "or length-4 (band_index, axis_index, c1_tile, c0_tile)")
        info.setdim(index=index, axis=axis, tilel1=c1, tilel0=c0)
    return str(info)
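# Usage sketch for set_dims: the two tuple layouts below describe the same
# tiling (band 0, axes 0 and 1, with L1/L0 tiles 32 and 16) and should yield
# the same dim string.
def _demo_set_dims():
    short_form = set_dims(((32, 32), (16, 16)))
    long_form = set_dims(((0, 0, 32, 32), (0, 1, 16, 16)))
    return short_form, long_form  # expected to be equal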
def gen_static_dim():
    # closes over the enclosing conv parameters (out_n, c1_cut, tile_out_h, ...)
    info = dim.Dim()
    if out_n > 1:
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if out_c1 > 1:
        info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1
    if out_h > 1:
        info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if out_w > 1:
        info.setdim(index=0, axis="W", tilel1=tile_out_w, tilel0=0)  # w
    if out_c0 > 1:
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0
    if in_c1 > 1:
        info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1
    if k_h > 1:
        info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh
    if k_w > 1:
        info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw
    info.setdim(index=0, axis="KC0", tilel1=block_size, tilel0=0)  # kc0
    return info
def gen_dynamic_dim():
    # closes over the enclosing conv parameters and the dynamic/tiling flags
    info = dim.Dim()
    if dynamic:
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    elif out_n > 1:
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if dynamic_tiling:
        info.setdim(index=0, axis=0, tilel1=c1_cut_fake, tilel0=0)  # c1
    elif dynamic or out_c1 > 1:
        info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1
    if dynamic_tiling:
        info.setdim(index=0, axis="H", tilel1=tile_out_h_fake, tilel0=0)  # h
    elif dynamic or out_h > 1:
        info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if dynamic_tiling:
        info.setdim(index=0, axis="W", tilel1=tile_out_w_fake, tilel0=0)  # w
    elif dynamic or out_w > 1:
        info.setdim(index=0, axis="W", tilel1=tile_out_w, tilel0=0)  # w
    if dynamic or out_c0 > 1:
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0
    if dynamic and not use_autotiling:
        info.setdim(index=0, axis=5, tilel1=dynamic_ci_c1, tilel0=0)  # kc1
    elif dynamic or in_c1 > 1:
        info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1
    if dynamic or k_h > 1:
        info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh
    if dynamic or k_w > 1:
        info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw
    info.setdim(index=0, axis="KC0", tilel1=block_size, tilel0=0)  # kc0
    return info
def test_CCE_Conv(fmap_shape, filter_shape, pad_, stride_,
                  tile_hh=0, tile_coco=0, tile_mm=0, tile_kk=0, tile_nn=0,
                  bypass_l1=False, use_bias=False, kernel_name="quant_conv",
                  cce_path='.'):
    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    # out_shape_nc1hwc0 = (in_n, in_c // 32, in_h, in_w, 32)
    in_n, in_c1, in_h, in_w, in_c0 = input_shape_nc1hwc0

    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    kernel_shape_nc1hwc0 = (k_n, k_c // 32, k_h, k_w, 32)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_fractal = (k_c // 32 * k_h * k_w, k_n // 16, 16, 32)
    f_ko, f_no, f_ni, f_ki = kernel_shape_fractal

    # bias shape
    bias_shape_nc1hwc0 = (1, k_n // block_size, 1, 1, block_size)

    # padding ((padding_h, padding_w) -> (top, bottom, left, right))
    padding = (pad_[0], pad_[0], pad_[1], pad_[1])
    p_top, p_bottom, p_left, p_right = padding
    # stride (stride_h, stride_w)
    s_h, s_w = stride_

    # A placeholder (NC1HWC0)
    A = akg.tvm.placeholder(input_shape_nc1hwc0, dtype=conv_dtype, name='FMap')
    # B placeholder (fractal)
    B = akg.tvm.placeholder(kernel_shape_fractal, dtype='int8', name='Filter')
    ScaleQ = akg.tvm.placeholder((16,), dtype='float16', name='ScaleQ')
    OffsetQ = akg.tvm.placeholder((16,), dtype='float16', name='OffsetQ')

    out_shape_nc1hwc0 = (in_n, in_c // 32, in_h, in_w, 32)
    q_n, q_c1, q_h, q_w, q_c0 = out_shape_nc1hwc0

    # quantize: pack two fp16 C0=16 blocks into one int8 C0=32 block,
    # so the input c1 index is 2 * qc1 + qc0 // 16
    Quant = akg.tvm.compute(
        out_shape_nc1hwc0,
        lambda qn, qc1, qh, qw, qc0: (A[qn, 2 * qc1 + qc0 // 16, qh, qw, qc0 % 16] *
                                      ScaleQ[0] + OffsetQ[0]).astype('int8'),
        name='QuantOUT', attrs={'no_inline': 1})

    if use_bias:
        bias_name = 'bias'
        bias_value = akg.tvm.placeholder(bias_shape_nc1hwc0, dtype=conv_dtype,
                                         name=bias_name)
    else:
        bias_name = 'None'

    # create reduction variables
    kc1 = akg.tvm.reduce_axis((0, k_c1), name='kc1')
    kh = akg.tvm.reduce_axis((0, k_h), name='kh')
    kw = akg.tvm.reduce_axis((0, k_w), name='kw')
    kc0 = akg.tvm.reduce_axis((0, k_c0), name='kc0')

    out_h = (in_h + p_top + p_bottom - k_h) // s_h + 1
    tile_out_h = (tile_hh - k_h) // s_h + 1
    h_window_cut = (tile_hh - k_h) // s_h + 1  # h window covered by one tile
    out_w = (in_w + p_left + p_right - k_w) // s_w + 1
    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    if tile_coco > 0:
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # set dim
    index = 0
    info = dim.Dim()
    if q_c1 > 1:
        info.setdim(index=index, axis="KO", tilel1=q_c1, tilel0=q_c1)
    if q_h > 1:
        info.setdim(index=index, axis="C1", tilel1=tile_out_h, tilel0=tile_out_h)
    if q_w > 1:
        info.setdim(index=index, axis="C0", tilel1=q_w, tilel0=q_w)
    if q_c0 > 1:
        info.setdim(index=index, axis="KI", tilel1=q_c0, tilel0=q_c0)
    index += 1
    if out_c1 > 1:
        info.setdim(index=index, axis="C1", tilel1=c1_cut, tilel0=0)  # c1
    if out_h > 1:
        info.setdim(index=index, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if out_w > 1:
        info.setdim(index=index, axis="W", tilel1=out_w, tilel0=0)  # w
    if out_c0 > 1:
        info.setdim(index=index, axis="C0", tilel1=out_c0, tilel0=0)  # c0
    if in_c1 > 1:
        info.setdim(index=index, axis="KC1", tilel1=in_c1 // 2, tilel0=0)  # kc1
    if k_h > 1:
        info.setdim(index=index, axis="KH", tilel1=k_h, tilel0=0)  # kh
    if k_w > 1:
        info.setdim(index=index, axis="KW", tilel1=k_w, tilel0=0)  # kw
    info = str(info)

    # compute the convolution
    output_name = "output0"
    output_bias_name = "output1"
    # print out_shape_nc1hwc0
    C = akg.tvm.compute(
        out_shape_nc1hwc0,
        lambda n, c1, h, w, c0: akg.tvm.sum(
            akg.tvm.if_then_else(
                akg.tvm.any((h * s_h + kh) < p_top,
                            (h * s_h + kh) > (in_h + p_top - 1),
                            (w * s_w + kw) < p_left,
                            (w * s_w + kw) > (in_w + p_left - 1)),
                akg.tvm.const(0.0, 'int8'),
                Quant[n, kc1, (h * s_h + kh - p_top), (w * s_w + kw - p_left), kc0])
            * B[(kc1 * k_h + kh) * k_w + kw, c1, c0, kc0],
            axis=[kc1, kh, kw, kc0]),
        name=output_name,
        attrs={
            "pragma_conv_kernel_n": k_n,
            "pragma_conv_kernel_h": k_h,
            "pragma_conv_kernel_w": k_w,
            "pragma_conv_padding_top": p_top,
            "pragma_conv_padding_bottom": p_bottom,
            "pragma_conv_padding_left": p_left,
            "pragma_conv_padding_right": p_right,
            "pragma_conv_dilation_h": 1,
            "pragma_conv_dilation_w": 1,
            "pragma_conv_bypass_l1": 1 if bypass_l1 else 0,
            "pragma_conv_stride_h": s_h,
            "pragma_conv_stride_w": s_w,
            "pragma_conv_fm_n": in_n,
            "pragma_conv_fm_c": in_c,
            "pragma_conv_fm_h": in_h,
            "pragma_conv_fm_w": in_w,
            "pragma_conv_h_cut": (h_window_cut - 1) * s_h + k_h,
            "pragma_conv_w_cut": (in_w + p_left + p_right),
            "pragma_conv_co_cut": c1_cut * k_c0,
            "pragma_conv_m_cut": tile_mm,
            "pragma_conv_k_cut": tile_kk,
            "pragma_conv_n_cut": tile_nn,
            "feature": Quant.op.name,
            "filter": B.op.name,
            "bias": bias_name,
            "res": output_name,
            "res_bias": output_bias_name
        })

    if use_bias:
        cube = akg.tvm.compute(
            out_shape_nc1hwc0,
            lambda n, c1, h, w, c0: C[n, c1, h, w, c0] + bias_value[0, c1, 0, 0, c0],
            name=output_bias_name)
    else:
        cube = C

    if fusion:
        # leaky relu
        negative_slope = 0.0
        slope_tmp = akg.tvm.const(negative_slope, dtype=conv_dtype)
        out = akg.lang.cce.vmuls(cube, slope_tmp)  # negative_slope * x
        out = akg.lang.cce.vmax(out, cube)         # max(x, negative_slope * x)
    else:
        out = cube

    # schedule
    s = akg.tvm.create_schedule(out.op)
    attrs = {"pragma_reschedule": 1, "dim": info}
    with akg.build_config(add_lower_pass=cce.debug_mode(0), dump_pass_ir=True):
        if use_bias:
            args = [A, B, ScaleQ, OffsetQ, bias_value, out]
        else:
            args = [A, B, ScaleQ, OffsetQ, out]
        mod = akg.build(s, args, "cce", name=kernel_name, attrs=attrs,
                        polyhedral=True)
    source_code = mod.imported_modules[0].get_source()
    # print(source_code)
    # utils.create_code(kernel_name, cce_path, source_code)
    if run_cce:
        run_conv(mod, fmap_shape, filter_shape, pad_[0], stride_[0], use_bias)
def group_conv(N, H, W, CI, CO, group, KH, KW, PAD_H, PAD_W, SH, SW,
               cutH, cutCo, cutM, cutK, cutN, block_size,
               use_bias=False, kernel_name='conv'):
    """
    Group convolution: split the FeatureMap channels into groups, where every
    group has its own filter kernel.

    Args:
        N: batch size.
        H, W: height and width of the FeatureMap.
        CI: channels of the FeatureMap.
        CO: number of filters.
        group: number of channel groups.
        KH, KW: height and width of the filter.
        PAD_H, PAD_W: padding in the vertical / horizontal direction.
        SH, SW: stride in the vertical / horizontal direction.
        cutH, cutCo, cutM, cutK, cutN: tiling sizes.
        block_size: an int.
        use_bias: a bool.

    Returns:
        OH, OW, the input placeholders A and B, the conv output C
        (5D, shape (N, CO // block_size, OH, OW, block_size)), and the built module.
    """
    conv_dtype = "float16"

    if cutH == H:
        cutH += PAD_H + PAD_H

    assert CO % group == 0 and CI % group == 0
    assert CO % block_size == 0 and (CI // group) % block_size == 0

    # (N, CI, H, W) -> (N, CI // block_size, H, W, block_size)
    A = akg.tvm.placeholder((N, CI // block_size, H, W, block_size),
                            dtype=conv_dtype, name="A")
    # (CO, CI // group, KH, KW) ->
    # (CI // group // block_size * KH * KW, CO // block_size, block_size, block_size)
    B = akg.tvm.placeholder((CI // group // block_size * KH * KW,
                             CO // block_size, block_size, block_size),
                            dtype=conv_dtype, name="B")
    bias = akg.tvm.placeholder((1, CO // block_size, 1, 1, block_size),
                               dtype=conv_dtype, name="bias")

    OH = (H + 2 * PAD_H - KH) // SH + 1
    OW = (W + 2 * PAD_W - KW) // SW + 1

    kc1 = akg.tvm.reduce_axis((0, CI // block_size // group), name="kc1")
    kh = akg.tvm.reduce_axis((0, KH), name="kh")
    kw = akg.tvm.reduce_axis((0, KW), name="kw")
    kc0 = akg.tvm.reduce_axis((0, block_size), name="kc0")

    p_top, p_bottom, p_left, p_right = PAD_H, PAD_H, PAD_W, PAD_W
    output_name = "output"
    output_bias_name = "output_bias"

    C = akg.tvm.compute(
        (N, CO // block_size, OH, OW, block_size),
        lambda n, c1, h, w, c0: akg.lang.ascend.mmad(
            akg.tvm.if_then_else(
                akg.tvm.any((h * SH + kh) < p_top,
                            (h * SH + kh) > (H + p_top - 1),
                            (w * SW + kw) < p_left,
                            (w * SW + kw) > (W + p_left - 1)),
                akg.tvm.const(0.0, conv_dtype),
                A[n, c1 // ((CO // block_size) // group) *
                  ((CI // block_size) // group) + kc1,
                  (h * SH + kh - p_top), (w * SW + kw - p_left), kc0])
            * B[(kc1 * KH + kh) * KW + kw, c1, c0, kc0],
            axis=[kc1, kh, kw, kc0]),
        attrs={
            "pragma_conv_kernel_n": CO,
            "pragma_conv_kernel_h": KH,
            "pragma_conv_kernel_w": KW,
            "pragma_conv_padding_top": p_top,
            "pragma_conv_padding_bottom": p_bottom,
            "pragma_conv_padding_left": p_left,
            "pragma_conv_padding_right": p_right,
            "pragma_conv_bypass_l1": 1,
            "pragma_conv_stride_h": SH,
            "pragma_conv_stride_w": SW,
            "pragma_conv_fm_n": N,
            "pragma_conv_fm_c": CI,
            "pragma_conv_fm_h": H,
            "pragma_conv_fm_w": W,
            "pragma_conv_dilation_h": 1,
            "pragma_conv_dilation_w": 1,
            "pragma_conv_h_cut": cutH,
            "pragma_conv_w_cut": W + 2 * PAD_W,
            "pragma_conv_co_cut": cutCo,
            "pragma_conv_m_cut": cutM,
            "pragma_conv_k_cut": cutK,
            "pragma_conv_n_cut": cutN,
            "feature": A.op.name,
            "filter": B.op.name,
            "bias": bias.op.name,
            "res": output_name,
            "res_bias": output_bias_name
        },
        name=output_name)

    if use_bias:
        out = akg.tvm.compute(
            C.shape,
            lambda n, c1, h, w, c0: C[n, c1, h, w, c0] + bias[0, c1, 0, 0, c0],
            name=output_bias_name)
        bufs = [A, B, bias, out]
    else:
        out = C
        bufs = [A, B, out]

    # create schedule for cce
    s = akg.tvm.create_schedule([out.op])

    # set cut / tiling
    out_n, out_c1, out_h, out_w, out_c0 = akg.topi.util.get_const_tuple(out.shape)

    # set dim
    tile_out_h = (cutH - KH) // SH + 1
    info = dim.Dim()
    if out_n > 1:
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if out_c1 > 1:
        info.setdim(index=0, axis=0, tilel1=cutCo // block_size, tilel0=0)  # c1
    if out_h > 1:
        info.setdim(index=0, axis='H', tilel1=tile_out_h, tilel0=0)  # h
    if out_w > 1:
        info.setdim(index=0, axis=3, tilel1=out_w, tilel0=0)  # w
    if out_c0 > 1:
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0
    # only one kc1 slice per group is supported, so the branch below never fires
    assert CI // block_size // group == 1
    if CI // block_size // group > 1:
        info.setdim(index=0, axis=5, tilel1=CI // block_size // group, tilel0=0)  # kc1
    if KH > 1:
        info.setdim(index=0, axis=5, tilel1=KH, tilel0=0)  # kh
    if KW > 1:
        info.setdim(index=0, axis=5, tilel1=KW, tilel0=0)  # kw

    # build
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod = akg.build(s, bufs, "cce", name=kernel_name,
                        attrs={"dim": str(info)}, polyhedral=True)
    return OH, OW, A, B, C, mod
def add_a_conv_compute(fmap_shape, filter_shape, pad_, stride_, dilation_,
                       tile_hh=0, tile_coco=0, tile_mm=0, tile_kk=0, tile_nn=0,
                       bypass_l1=False, use_bias=False, block_size=16,
                       conv_dtype='float16'):
    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size
    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size
    # padding ((padding_h, padding_w) -> (top, bottom, left, right))
    padding = (pad_[0], pad_[0], pad_[1], pad_[1])
    p_top, p_bottom, p_left, p_right = padding
    # stride (stride_h, stride_w)
    s_h, s_w = stride_
    # dilation (dilation_h, dilation_w)
    d_h, d_w = dilation_

    if tile_hh == in_h:
        tile_hh += p_top + p_bottom
    tile_coco = (tile_coco + block_size - 1) // block_size * block_size
    tile_mm = (tile_mm + block_size - 1) // block_size * block_size
    tile_kk = (tile_kk + block_size - 1) // block_size * block_size
    tile_nn = (tile_nn + block_size - 1) // block_size * block_size

    c0 = block_size
    c1_cut = tile_coco // c0
    h_window_cut = (tile_hh - k_h) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w) // s_w + 1

    kernel_name = ("add_a_conv_layer_" + str(in_n) + "_" + str(in_c) + "_" +
                   str(in_h) + "_" + str(in_w) + "_" + str(k_n) + "_" +
                   str(in_c) + "_" + str(k_h) + "_" + str(k_w) + "_" +
                   str(p_top) + "_" + str(s_h))

    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    in_n, in_c1, in_h, in_w, in_c0 = input_shape_nc1hwc0
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_fractal = (k_c // block_size * k_h * k_w, k_n // block_size,
                            block_size, block_size)
    # bias shape
    bias_shape_nc1hwc0 = (1, k_n // block_size, 1, 1, block_size)

    # a_value placeholder (NC1HWC0)
    a_tmp = akg.tvm.placeholder(input_shape_nc1hwc0, dtype=conv_dtype, name='a_tmp')
    a_value = akg.tvm.compute(a_tmp.shape,
                              lambda n, kc1, h, w, kc0: a_tmp[n, kc1, h, w, kc0] + 1,
                              name='a_value', attrs={'no_inline': 1})
    # b_value placeholder (fractal)
    b_value = akg.tvm.placeholder(kernel_shape_fractal, dtype=conv_dtype,
                                  name='b_value')

    if use_bias:
        bias_name = 'bias'
        bias_value = akg.tvm.placeholder(bias_shape_nc1hwc0, dtype=conv_dtype,
                                         name=bias_name)
    else:
        bias_name = 'None'
        bias_value = None

    # create reduction variables
    kc1 = akg.tvm.reduce_axis((0, k_c1), name='kc1')
    kh = akg.tvm.reduce_axis((0, k_h), name='kh')
    kw = akg.tvm.reduce_axis((0, k_w), name='kw')
    kc0 = akg.tvm.reduce_axis((0, k_c0), name='kc0')

    k_h_d = (k_h - 1) * d_h + 1
    k_w_d = (k_w - 1) * d_w + 1
    out_h = (in_h + p_top + p_bottom - k_h_d) // s_h + 1
    tile_out_h = (tile_hh - k_h_d) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w_d) // s_w + 1

    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    _, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    if tile_coco > 0:
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # set dim
    if s_h > k_h:
        a_cut_h = tile_out_h * s_h
    else:
        a_cut_h = (tile_out_h - 1) * s_h + k_h_d
    a_cut_w = (out_w - 1) * s_w + k_w_d

    index = 0
    info = dim.Dim()
    if in_c1 > 1:
        info.setdim(index=index, axis="C1", tilel1=in_c1, tilel0=in_c1)  # c1
    if in_h > 1:
        info.setdim(index=index, axis="H", tilel1=a_cut_h, tilel0=a_cut_h)  # h
    if in_w > 1:
        info.setdim(index=index, axis="W", tilel1=a_cut_w, tilel0=a_cut_w)  # w
    if in_c0 > 1:
        info.setdim(index=index, axis="C0", tilel1=in_c0, tilel0=in_c0)  # c0
    index += 1
    if out_c1 > 1:
        info.setdim(index=index, axis="C1", tilel1=c1_cut, tilel0=0)  # c1
    if out_h > 1:
        info.setdim(index=index, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if out_w > 1:
        info.setdim(index=index, axis="W", tilel1=out_w, tilel0=0)  # w
    if out_c0 > 1:
        info.setdim(index=index, axis="C0", tilel1=out_c0, tilel0=0)  # c0
    if in_c1 > 1:
        info.setdim(index=index, axis=5, tilel1=in_c1, tilel0=0)  # kc1
    if k_h > 1:
        info.setdim(index=index, axis=5, tilel1=k_h, tilel0=0)  # kh
    if k_w > 1:
        info.setdim(index=index, axis=5, tilel1=k_w, tilel0=0)  # kw

    # compute the convolution
    output_name = "c_value"
    output_bias_name = "OUT"
    c_value = akg.tvm.compute(
        out_shape_nc1hwc0,
        lambda n, c1, h, w, c0: akg.lang.cce.mmad(
            akg.tvm.if_then_else(
                akg.tvm.any((h * s_h + kh) < p_top,
                            (h * s_h + kh) > (in_h + p_top - 1),
                            (w * s_w + kw) < p_left,
                            (w * s_w + kw) > (in_w + p_left - 1)),
                akg.tvm.const(0.0, 'float16'),
                a_value[n, kc1, (h * s_h + (kh * d_h) - p_top),
                        (w * s_w + (kw * d_w) - p_left), kc0])
            * b_value[(kc1 * k_h + kh) * k_w + kw, c1, c0, kc0],
            axis=[kc1, kh, kw, kc0]),
        name=output_name,
        attrs={
            "pragma_conv_kernel_n": k_n,
            "pragma_conv_kernel_h": k_h,
            "pragma_conv_kernel_w": k_w,
            "pragma_conv_padding_top": p_top,
            "pragma_conv_padding_bottom": p_bottom,
            "pragma_conv_padding_left": p_left,
            "pragma_conv_padding_right": p_right,
            "pragma_conv_bypass_l1": 1 if bypass_l1 else 0,
            "pragma_conv_stride_h": s_h,
            "pragma_conv_stride_w": s_w,
            "pragma_conv_dilation_h": d_h,
            "pragma_conv_dilation_w": d_w,
            "pragma_conv_fm_n": in_n,
            "pragma_conv_fm_c": in_c,
            "pragma_conv_fm_h": in_h,
            "pragma_conv_fm_w": in_w,
            "pragma_conv_h_cut": (h_window_cut - 1) * s_h + k_h_d,
            "pragma_conv_w_cut": (in_w + p_left + p_right),
            "pragma_conv_co_cut": c1_cut * k_c0,
            "pragma_conv_m_cut": tile_mm,
            "pragma_conv_k_cut": tile_kk,
            "pragma_conv_n_cut": tile_nn,
            "feature": a_value.op.name,
            "filter": b_value.op.name,
            "bias": bias_name,
            "res": output_name,
            "res_bias": output_bias_name
        })

    if use_bias:
        cube = akg.tvm.compute(
            out_shape_nc1hwc0,
            lambda n, c1, h, w, c0: c_value[n, c1, h, w, c0] +
                                    bias_value[0, c1, 0, 0, c0],
            name=output_bias_name)
    else:
        cube = c_value

    return cube, a_tmp, b_value, bias_value, kernel_name, str(info)
def conv_backprop_input_compute(data, output_shape, filter_shape, input_shape,
                                pad_, stride_, block_size=16, attrs=None, key=None):
    """Core computation of conv_backprop_input."""
    _, in_c, w_h, w_w = filter_shape

    # stride (stride_h, stride_w)
    stride_h, stride_w = stride_
    if stride_h != stride_w:
        raise ValueError("stride_h must be equal to stride_w.")

    # output shape (NCHW -> NC1HWC0)
    in_nn, in_cc, in_hh, in_ww = output_shape
    if in_c % block_size != 0:
        raise ValueError("in_c must be divided by block_size.")
    input_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh, in_ww, block_size)
    in_nn, _, in_hh, in_ww, _ = input_shape_nc1hwc0
    input_trans_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh * stride_h,
                                 in_ww * stride_w, block_size)
    in_n, in_c1, in_h, in_w, _ = input_trans_shape_nc1hwc0

    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    if k_c % block_size != 0:
        raise ValueError("k_c must be divided by block_size.")
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_trans = (k_n // block_size * k_h * k_w, k_c // block_size,
                          block_size, block_size)
    k_c1 = k_n // block_size
    k_n = k_c

    _, _, input_h, input_w = input_shape

    # padding ((padding_h, padding_w) -> (top, bottom, left, right))
    padding = (pad_[0], pad_[1], pad_[2], pad_[3])
    pad_t, pad_b, pad_l, pad_r = padding
    # padHT -> padHT'
    p_top = k_h - pad_t - 1
    # padHB -> padHB'
    p_bottom = input_h + pad_t - stride_h * ((input_h + pad_t + pad_b - k_h) // stride_h + 1)
    # padWL -> padWL'
    p_left = k_w - pad_l - 1
    # padWR -> padWR'
    p_right = input_w + pad_l - stride_w * ((input_w + pad_l + pad_r - k_w) // stride_w + 1)
    s_h = 1
    s_w = 1

    # NC1HWC0
    a_value = data[0]
    if data[1].dtype == 'float32':
        b_value = cast.cast(data[1], 'float16')
        tiling_args = cast_tiling_args
    else:
        b_value = data[1]
        tiling_args = conv_backprop_input_tiling_args

    # create reduction variables
    kc1 = akg.tvm.reduce_axis((0, k_c1), name='kc1')
    kh = akg.tvm.reduce_axis((0, k_h), name='kh')
    kw = akg.tvm.reduce_axis((0, k_w), name='kw')
    kc0 = akg.tvm.reduce_axis((0, k_c0), name='kc0')

    use_auto_tiling = False
    if attrs is not None and 'conv_tile' in attrs and len(attrs['conv_tile']) >= 5:
        tile_value = attrs['conv_tile']
    elif key in tiling_args:
        tile_value = tiling_args[key]
    else:
        use_auto_tiling = True

    out_h = (in_h + p_top + p_bottom - k_h) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w) // s_w + 1
    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    # set dim
    info = dim.Dim()
    index_ = 0
    if not use_auto_tiling:
        tile_hh = tile_value[0]
        if tile_hh == input_h:
            tile_hh += pad_t + pad_b
        tile_coco = tile_value[1]
        tile_coco = (tile_coco + block_size - 1) // block_size * block_size
        tile_mm = tile_value[2]
        tile_mm = (tile_mm + block_size - 1) // block_size * block_size
        tile_kk = tile_value[3]
        if tile_kk % (block_size * w_h * w_w) != 0:
            logging.warning("Warning: tile_k must be a multiple of (block_size * w_h * w_w)")
        tile_kk = ((tile_kk + block_size * w_h * w_w - 1) //
                   (block_size * w_h * w_w) * (block_size * w_h * w_w))
        tile_nn = tile_value[4]
        tile_nn = (tile_nn + block_size - 1) // block_size * block_size
        tile_ww = input_w
        if len(tile_value) >= 6 and tile_value[5] > 0:
            tile_ww = tile_value[5]
        if tile_ww == input_w:
            tile_ww += pad_l + pad_r
        if tile_hh == in_h:
            tile_hh += p_top + p_bottom
        tile_out_h = (tile_hh - k_h) // s_h + 1
        if tile_ww == in_w:
            tile_ww += p_left + p_right
        tile_out_w = (tile_ww - k_w) // s_w + 1
        if tile_coco > 0:
            c1_cut = tile_coco // block_size
        else:
            c1_cut = out_c1
        if out_n > 1:
            info.setdim(index=index_, axis=0, tilel1=1, tilel0=0)  # n
        if out_c1 > 1:
            info.setdim(index=index_, axis=1, tilel1=c1_cut, tilel0=0)  # c1
        if out_h > 1:
            info.setdim(index=index_, axis="H", tilel1=tile_out_h, tilel0=0)  # h
        if out_w > 1:
            info.setdim(index=index_, axis="W", tilel1=tile_out_w, tilel0=0)  # w
        if out_c0 > 1:
            info.setdim(index=index_, axis=4, tilel1=out_c0, tilel0=0)  # c0
        if in_c1 > 1:
            info.setdim(index=index_, axis=5, tilel1=in_c1, tilel0=0)  # kc1
        if k_h > 1:
            info.setdim(index=index_, axis=5, tilel1=k_h, tilel0=0)  # kh
        if k_w > 1:
            info.setdim(index=index_, axis=5, tilel1=k_w, tilel0=0)  # kw
        info = str(info)
    else:
        info = ""

    # compute the convolution below
    output_name = "output0"

    # weight_trans [ko, no, ni, ki]
    # weight_trans [co_1, kh, kw, ci_1, ci_0, co_0]
    #   kw   = ko % k_w
    #   kh   = ko // k_w % k_h
    #   co_1 = ko // k_w // k_h
    #   ci_1 = no
    # -->
    # weight [ci_1, kh', kw', co_1, co_0, ci_0]
    # weight [no, k_h - ko // k_w % k_h - 1, k_w - ko % k_w - 1,
    #         ko // k_w // k_h, co_0, ci_0]
    b_trans = akg.tvm.compute(
        kernel_shape_trans,
        lambda ko, no, ni, ki: b_value[
            ((no * k_h + k_h - 1 - ko // k_w % k_h) * k_w + k_w - 1 - ko % k_w),
            ko // (k_h * k_w), ki, ni],
        name='B_trans')

    if stride_h > 1 or stride_w > 1:
        @akg.tvm.hybrid.script
        def data_trans_hybrid(output, inputs, const_zero):
            """Implements B[n, c1, h * strideH, w * strideW, c0] = A[n, c1, h, w, c0]."""
            stride_h = output.shape[2] // inputs.shape[2]
            stride_w = output.shape[3] // inputs.shape[3]
            b = allocate(output.shape, output.dtype, 'local')
            for n in range(output.shape[0]):
                for c1 in range(output.shape[1]):
                    for h in range(output.shape[2]):
                        for w in range(output.shape[3]):
                            for c0 in range(output.shape[4]):
                                b[n, c1, h, w, c0] = const_zero
                                if h % stride_h == 0 and w % stride_w == 0:
                                    b[n, c1, h, w, c0] = inputs[n, c1, h // stride_h,
                                                                w // stride_w, c0]
            return b

        a_trans_init = akg.tvm.placeholder(input_trans_shape_nc1hwc0,
                                           dtype="float16", name='a_trans')
        const_zero = akg.tvm.const(0, 'float16')
        a_trans = data_trans_hybrid(a_trans_init, a_value, const_zero)
    else:
        a_trans = a_value

    conv_attrs = {
        "pragma_conv_kernel_n": k_n,
        "pragma_conv_kernel_h": k_h,
        "pragma_conv_kernel_w": k_w,
        "pragma_conv_padding_top": p_top,
        "pragma_conv_padding_bottom": p_bottom,
        "pragma_conv_padding_left": p_left,
        "pragma_conv_padding_right": p_right,
        "pragma_conv_bypass_l1": 0,
        "pragma_conv_backprop_input": 1,
        "pragma_conv_stride_h": s_h,
        "pragma_conv_stride_w": s_w,
        "pragma_conv_dilation_h": 1,
        "pragma_conv_dilation_w": 1,
        "pragma_conv_fm_n": in_n,
        "pragma_conv_fm_c": in_c,
        "pragma_conv_fm_h": in_h,
        "pragma_conv_fm_w": in_w,
        "feature": a_trans.op.name,
        "filter": b_value.op.name,
        "bias": 'None',
        "res": output_name
    }
    if not use_auto_tiling:
        conv_attrs["pragma_conv_h_cut"] = (tile_out_h - 1) * s_h + k_h
        conv_attrs["pragma_conv_w_cut"] = (tile_out_w - 1) * s_w + k_w
        conv_attrs["pragma_conv_co_cut"] = c1_cut * k_c0
        conv_attrs["pragma_conv_m_cut"] = tile_mm
        conv_attrs["pragma_conv_k_cut"] = tile_kk
        conv_attrs["pragma_conv_n_cut"] = tile_nn

    res_c = akg.tvm.compute(
        out_shape_nc1hwc0,
        lambda n, c1, h, w, c0: akg.lang.cce.mmad(
            (akg.tvm.if_then_else(
                akg.tvm.any((h * s_h + kh) < p_top,
                            (h * s_h + kh) > (in_h + p_top - 1),
                            (w * s_w + kw) < p_left,
                            (w * s_w + kw) > (in_w + p_left - 1)),
                akg.tvm.const(0.0, 'float16'),
                a_trans[n, kc1, (h * s_h + kh - p_top), (w * s_w + kw - p_left), kc0])
             * b_trans[(kc1 * k_h + kh) * k_w + kw, c1, c0, kc0]).astype("float32"),
            axis=[kc1, kh, kw, kc0]),
        name=output_name, attrs=conv_attrs)

    res_c = cast.cast(res_c, "float16")
    return res_c, {"dim": info, "pragma_reschedule": 1, "pragma_rmselfdep": 0}
def avg_pool_5d_hybrid(a_value, kernel, stride, strategy):
    """Avgpool for the 5D case via hybrid script."""
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    shape = get_shape(a_value)
    if len(shape) != 5:
        raise ValueError("Only support 5-dim pooling!")
    if len(kernel) != 2:
        raise ValueError("Only support 2-dim kernel!")
    batch_size, c1_, in_size_h, in_size_w, c0_ = shape
    dtype = a_value.dtype
    [pad_height_head, _, pad_width_head, _], [out_size_h, out_size_w] = \
        cal_pad_shapes_by_strategy(shape, kernel, stride, strategy)
    avg_pre = akg.tvm.const(1.0000 / (kernel_w * kernel_h), dtype=dtype)
    zero = akg.tvm.const(0.0, dtype=dtype)

    @script(capture=locals())
    def avg_pool_hybrid(inputs, zero, avg_pre):
        output = output_tensor((batch_size, c1_, out_size_h, out_size_w, c0_),
                               inputs.dtype)
        for n in range(batch_size):
            for c1 in range(c1_):
                # head (output row 0)
                for ow in range(out_size_w):
                    for c0 in range(c0_):
                        output[n, c1, 0, ow, c0] = zero
                for ow in range(out_size_w):
                    for kh in range(kernel_h):
                        for kw in range(kernel_w):
                            for c0 in range(c0_):
                                if (kh >= pad_height_head) \
                                        and (ow * stride_w + kw - pad_width_head >= 0) \
                                        and (ow * stride_w + kw <= in_size_w + pad_width_head - 1):
                                    output[n, c1, 0, ow, c0] = output[n, c1, 0, ow, c0] + \
                                        inputs[n, c1, kh - pad_height_head,
                                               ow * stride_w + kw - pad_width_head, c0]
                                else:
                                    output[n, c1, 0, ow, c0] += zero
                for ow in range(out_size_w):
                    for c0 in range(c0_):
                        output[n, c1, 0, ow, c0] *= avg_pre
                # tail (output rows 1 .. out_size_h - 1)
                for oh in range(out_size_h - 1):
                    for ow in range(out_size_w):
                        for c0 in range(c0_):
                            output[n, c1, oh + 1, ow, c0] = zero
                for oh in range(out_size_h - 1):
                    for ow in range(out_size_w):
                        for kh in range(kernel_h):
                            for kw in range(kernel_w):
                                for c0 in range(c0_):
                                    if ((oh + 1) * stride_h + kh <= in_size_h + pad_height_head - 1) \
                                            and (ow * stride_w + kw >= pad_width_head) \
                                            and (ow * stride_w + kw <= in_size_w + pad_width_head - 1):
                                        output[n, c1, oh + 1, ow, c0] = output[n, c1, oh + 1, ow, c0] + \
                                            inputs[n, c1, (oh + 1) * stride_h + kh - pad_height_head,
                                                   ow * stride_w + kw - pad_width_head, c0]
                                    else:
                                        output[n, c1, oh + 1, ow, c0] += zero
                for oh in range(out_size_h - 1):
                    for ow in range(out_size_w):
                        for c0 in range(c0_):
                            output[n, c1, oh + 1, ow, c0] *= avg_pre
        return output

    res_value = avg_pool_hybrid(a_value, zero, avg_pre)

    # set dim
    info = dim.Dim()
    # first part
    info.setdim(index=0, axis=0, tilel1=out_size_w, tilel0=0)  # ow
    info.setdim(index=0, axis=1, tilel1=c0_, tilel0=0)  # c0
    info.setdim(index=0, axis=2, tilel1=kernel_h, tilel0=0)  # kh
    # second part
    info.setdim(index=1, axis=0, tilel1=out_size_h - 1, tilel0=0)  # oh-1
    info.setdim(index=1, axis=1, tilel1=out_size_w, tilel0=0)  # ow
    info.setdim(index=1, axis=2, tilel1=c0_, tilel0=0)  # c0
    info.setdim(index=1, axis=3, tilel1=kernel_h, tilel0=0)  # kh
    info = str(info)
    attrs = {DIM: info}
    return res_value, attrs
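# Note on the dim info above (an interpretation, not confirmed by the source):
# the hybrid body lowers to two loop nests, the head row and the tail rows,
# so the Dim uses two bands; index=0 appears to tile the head nest
# (ow, c0, kh) and index=1 the tail nest (oh, ow, c0, kh).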
def conv_backprop_filter_compute(data, input_shape, filter_shape, output_shape,
                                 pad_, stride_, dilation_, block_size=16,
                                 attrs=None, key=None):
    """Core computation of conv_backprop_filter_compute."""
    # stride (stride_h, stride_w)
    stride_h, stride_w = stride_
    if stride_h != stride_w:
        raise ValueError("stride_h must be equal to stride_w.")

    # conv_backprop_filter input shape (NCHW -> NC1HWC0 -> fractal): load2d L0A
    input_n, input_c, input_h, input_w = output_shape
    if input_c % block_size != 0:
        raise ValueError("output channel must be divided by block_size.")
    if input_n > 32:
        raise ValueError("Batch must be less than or equal to 32.")
    input_shape_nc1hwc0 = (input_n, input_c // block_size, input_h, input_w, block_size)
    input_n, input_c1, input_h, input_w, input_c0 = input_shape_nc1hwc0
    mo = (input_h * input_w + block_size - 1) // block_size
    mi = block_size
    input_trans_shape_fractal = (input_n, input_c1, mo, input_c0, mi)

    # conv_backprop_filter kernel shape (NCHW -> NC1HWC0): img2col L0B
    k_n, k_c, k_h, k_w = input_shape
    if k_c % block_size != 0:
        raise ValueError("input channel must be divided by block_size.")
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0

    # conv_backprop_filter output shape (NCHW -> NC1HWC0)
    out_n, out_c, out_h, out_w = filter_shape
    if out_n != input_c:
        raise ValueError("out_n must be equal to input_c.")
    output_shape_nc1hwc0 = (out_n, out_c // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, _ = output_shape_nc1hwc0
    output_shape_fractal = (out_c1, out_h, out_w, out_n // block_size,
                            block_size, block_size)
    out_c1, out_h, out_w, out_mo, out_mi, out_ni = output_shape_fractal

    # padding ((padding_h, padding_w) -> (top, bottom, left, right))
    padding = (pad_[0], pad_[1], pad_[2], pad_[3])
    p_top, p_bottom, p_left, p_right = padding
    s_h, s_w = stride_

    data_a = data[0]
    o_n, o_c1, o_h, o_w, o_c0 = data_a.shape
    mo = (o_h * o_w + block_size - 1) // block_size
    mi = block_size
    a_shape_fractal = (o_n, o_c1, mo, mi, o_c0)
    a_fractal = akg.tvm.placeholder(a_shape_fractal, dtype=data_a.dtype,
                                    name="backprop")
    a_buf = akg.tvm.decl_buffer(a_shape_fractal, a_fractal.dtype, name="backprop")
    data_b = data[1]

    tiling_args = batch_conv_backprop_filter_tiling_args
    use_autotiling = False
    if k_n == 1:
        tiling_args = conv_backprop_filter_tiling_args
    if attrs is not None and 'conv_tile' in attrs and len(attrs['conv_tile']) >= 8:
        tile = attrs['conv_tile']
    elif key in tiling_args:
        tile = tiling_args[key]
    else:
        use_autotiling = True

    in_h = k_h
    in_w = k_w
    if not use_autotiling:
        # set dim
        info = dim.Dim()
        index_ = 0
        # tile = [Ci, KH, KW, Co, Batch, H, W, M, K, N]
        tile_ci = tile[0]
        if tile_ci > k_c1 * k_c0:
            tile_ci = k_c1 * k_c0
        tile_ci = (tile_ci + block_size - 1) // block_size
        tile_kh = tile[1]
        if tile_kh > out_h:
            tile_kh = out_h
        tile_kw = tile[2]
        if tile_kw > out_w:
            tile_kw = out_w
        tile_coco = tile[3]
        if tile_coco > input_c1 * input_c0:
            tile_coco = input_c1 * input_c0
        tile_coco = (tile_coco + block_size - 1) // block_size
        tile_batch = tile[4]
        if tile_batch > input_n:
            tile_batch = input_n
        if tile_batch != 1:
            raise ValueError("tile_batch must be 1.")
        d_h, d_w = dilation_
        tile_hh = tile[5]
        if tile_hh == in_h:
            tile_hh = in_h + p_top + p_bottom
        elif tile_hh > in_h + p_top + p_bottom:
            tile_hh = in_h + p_top + p_bottom
        h_win_cut = (tile_hh - ((out_h - 1) * d_h + 1)) // s_h + 1
        tile_ww = tile[6]
        if tile_ww == in_w:
            tile_ww = in_w + p_left + p_right
        elif tile_ww > in_w + p_left + p_right:
            tile_ww = in_w + p_left + p_right
        w_win_cut = (tile_ww - ((out_w - 1) * d_w + 1)) // s_w + 1
        tile_mm = tile[7]
        tile_kk = tile[8]
        tile_nn = tile[9]
        tile_mm = (tile_mm + block_size - 1) // block_size * block_size
        tile_kk = (tile_kk + block_size - 1) // block_size * block_size
        tile_nn = (tile_nn + block_size - 1) // block_size * block_size
        if out_c1 > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_ci, tilel0=tile_ci)
        if out_h > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_kh, tilel0=tile_kh)
        if out_w > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_kw, tilel0=tile_kw)
        if out_mo > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_coco, tilel0=tile_coco)
        if out_mi > 1:
            info.setdim(index=index_, axis=0, tilel1=out_mi, tilel0=out_mi)  # mi: don't tile
        if out_ni > 1:
            info.setdim(index=index_, axis=0, tilel1=out_ni, tilel0=out_ni)  # ni: don't tile
        if input_n > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_batch, tilel0=tile_batch)  # batch tile
        if k_h > 1:
            info.setdim(index=index_, axis="H", tilel1=h_win_cut, tilel0=h_win_cut)  # out_h
        if k_w > 1:
            info.setdim(index=index_, axis="W", tilel1=w_win_cut, tilel0=w_win_cut)  # out_w
        info = str(info)
    else:
        info = ""

    # compute the convolution
    output_name = "filter"
    a_trans = akg.tvm.compute(
        input_trans_shape_fractal,
        lambda n, co1, mo, co0, mi: a_fractal[n, co1, mo, mi, co0],
        name='dy_trans')

    # create reduction variables
    no = akg.tvm.reduce_axis((0, input_n), name='no')
    ho = akg.tvm.reduce_axis((0, input_h), name='ho')
    wo = akg.tvm.reduce_axis((0, input_w), name='wo')

    conv_filter_attr = {
        "pragma_conv_kernel_n": out_n,
        "pragma_conv_kernel_h": out_h,
        "pragma_conv_kernel_w": out_w,
        "pragma_conv_padding_top": p_top,
        "pragma_conv_padding_bottom": p_bottom,
        "pragma_conv_padding_left": p_left,
        "pragma_conv_padding_right": p_right,
        "pragma_conv_bypass_l1": 0,
        "pragma_conv_backprop_filter": 1,
        "pragma_conv_stride_h": s_h,
        "pragma_conv_stride_w": s_w,
        "pragma_conv_dilation_h": 1,
        "pragma_conv_dilation_w": 1,
        "pragma_conv_fm_n": k_n,
        "pragma_conv_fm_c": k_c,
        "pragma_conv_fm_h": k_h,
        "pragma_conv_fm_w": k_w,
        "feature": data_b.op.name,
        "filter": a_fractal.op.name,
        "bias": 'None',
        "res": output_name
    }
    if not use_autotiling:
        conv_filter_attr["pragma_conv_batch_cut"] = tile_batch
        conv_filter_attr["pragma_conv_h_cut"] = (h_win_cut - 1) * s_h + ((out_h - 1) * d_h + 1)
        conv_filter_attr["pragma_conv_w_cut"] = (w_win_cut - 1) * s_w + ((out_w - 1) * d_w + 1)
        conv_filter_attr["pragma_conv_co_cut"] = tile_coco * block_size
        conv_filter_attr["pragma_conv_cin_cut"] = tile_ci * block_size
        conv_filter_attr["pragma_conv_m_cut"] = tile_mm
        conv_filter_attr["pragma_conv_k_cut"] = tile_kk
        conv_filter_attr["pragma_conv_n_cut"] = tile_nn
        conv_filter_attr["pragma_conv_kh_cut"] = tile_kh
        conv_filter_attr["pragma_conv_kw_cut"] = tile_kw

    res_c = akg.tvm.compute(
        output_shape_fractal,
        lambda c1, h, w, mo, mi, ni: akg.lang.cce.mmad(
            (akg.tvm.if_then_else(
                akg.tvm.any((h + s_h * ho) < p_top,
                            (h + s_h * ho) > (in_h + p_top - 1),
                            (w + s_w * wo) < p_left,
                            (w + s_w * wo) > (in_w + p_left - 1)),
                akg.tvm.const(0.0, 'float16'),
                a_trans[no, mo, (input_w * ho + wo) // 16, mi, (input_w * ho + wo) % 16])
             * data_b[no, c1, (ho * s_h + h - p_top), (wo * s_w + w - p_left), ni]
             ).astype("float32"),
            axis=[no, ho, wo]),
        name=output_name, attrs=conv_filter_attr)

    return res_c, {"dim": info, "pragma_reschedule": 1,
                   "pragma_conv_special_dma": 1,
                   utils.BINDS: {data_a: a_buf, a_fractal: a_buf}}
def depthwise_set_dim_func(data, N, H, W, CI, k_ch, KH, KW, PAD_H, PAD_W,
                           SH, SW, block_size, use_bias=False):
    key = [N, H, W, CI, k_ch, KH, KW, PAD_H, PAD_W, SH, SW]
    hash_key = str(tuple(key))
    clear = True
    if hash_key in depthwise_set_dim_map:
        cutH, cutCo, _, _, _ = depthwise_set_dim_map[hash_key]
        clear = False
    else:
        # no tuned (cutH, cutCo, cutM, cutK, cutN) entry; fall back to defaults
        cutH = (KH - 1) * KH + 1
        cutCo = 16

    group = CI // block_size
    CO = CI * k_ch
    OH = (H + 2 * PAD_H - KH) // SH + 1
    OW = (W + 2 * PAD_W - KW) // SW + 1
    out_n, out_c1, out_h, out_w, out_c0 = [N, CO // block_size, OH, OW, block_size]

    # set dim
    tile_out_h = (cutH - KH) // SH + 1
    info = dim.Dim()
    if out_n > 1:
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if out_c1 > 1:
        info.setdim(index=0, axis=0, tilel1=cutCo // block_size, tilel0=0)  # c1
    if out_h > 1:
        info.setdim(index=0, axis='H', tilel1=tile_out_h, tilel0=0)  # h
    if out_w > 1:
        info.setdim(index=0, axis=3, tilel1=out_w, tilel0=0)  # w
    if out_c0 > 1:
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0
    # depthwise always has one kc1 slice per group, so this branch never fires
    assert CI // block_size // group == 1
    if CI // block_size // group > 1:
        info.setdim(index=0, axis=5, tilel1=CI // block_size // group, tilel0=0)  # kc1
    if KH > 1:
        info.setdim(index=0, axis=5, tilel1=KH, tilel0=0)  # kh
    if KW > 1:
        info.setdim(index=0, axis=5, tilel1=KW, tilel0=0)  # kw
    if clear:
        info = ""
    return str(info)
def conv_02(fmap_shape, filter_shape, pad_, stride_, dilation_,
            tile_hh=0, tile_coco=0, tile_mm=0, tile_kk=0, tile_nn=0,
            bypass_l1=False, use_bias=False, block_size=16, conv_dtype='float16'):
    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size
    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size

    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    in_n, _, in_h, in_w, _ = input_shape_nc1hwc0
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, _, k_h, k_w, _ = kernel_shape_nc1hwc0
    kernel_shape_fractal = (k_c // block_size * k_h * k_w, k_n // block_size,
                            block_size, block_size)

    # A placeholder (NC1HWC0)
    A = akg.tvm.placeholder(input_shape_nc1hwc0, dtype=conv_dtype, name="input0")
    # B placeholder (fractal)
    B = akg.tvm.placeholder(kernel_shape_fractal, dtype=conv_dtype, name="input1")
    if use_bias:
        bias_shape_nc1hwc0 = (1, k_n // block_size, 1, 1, block_size)
        bias_name = "input2"
        bias_value = akg.tvm.placeholder(bias_shape_nc1hwc0, dtype=conv_dtype,
                                         name=bias_name)
    else:
        bias_name = 'None'
        bias_value = None

    conv_forward = conv_compute_forward(fmap_shape, filter_shape, pad_, stride_,
                                        dilation_, A, B, bias_value, tile_hh,
                                        tile_coco, tile_mm, tile_kk, tile_nn,
                                        bypass_l1, use_bias, block_size, conv_dtype)

    k_hw = k_h * k_w
    const_shift = k_hw - 1

    # B in Fractal format; result in Fractal format
    def flip_weight(B, k_c, k_hw, const_shift):
        out_shape = (B.shape[1].value * k_hw, k_c // block_size,
                     block_size, block_size)
        B_flip = akg.tvm.compute(
            out_shape,
            lambda i0, i1, i2, i3: B[i1 * k_hw + const_shift - truncmod(i0, k_hw),
                                     floordiv(i0, k_hw), i3, i2],
            name=B.name + "_flipped")
        return B_flip

    # H in 5D format; result in 5D format
    def strided_head(H, s_h, s_w):
        n, c1, h, w, c0 = H.shape
        out_shape = (n, c1, (h - 1) * s_h + 1, (w - 1) * s_w + 1, c0)
        H_strided = akg.tvm.compute(
            out_shape,
            lambda i0, i1, i2, i3, i4: akg.tvm.expr.Select(
                akg.tvm.any(truncmod(i2, s_h) != 0, truncmod(i3, s_w) != 0),
                akg.tvm.const(0.0, dtype="float16"),
                H[i0, i1, floordiv(i2, s_h), floordiv(i3, s_w), i4]),
            name=H.name + "_strided")
        return H_strided

    # A in 5D format; result in 5D format
    def transpose_data(A):
        out_shape = (A.shape[1].value * block_size, A.shape[0].value // block_size,
                     A.shape[2].value, A.shape[3].value, block_size)
        A_transpose = akg.tvm.compute(
            out_shape,
            lambda j0, j1, j2, j3, j4: A[j1 * block_size + j4,
                                         floordiv(j0, block_size), j2, j3,
                                         truncmod(j0, block_size)],
            name=A.name + "_transposed")
        return A_transpose

    # Head in 5D format; result in Fractal format
    def transpose_convert_head(Head):
        out_shape = ((Head.shape[0].value // block_size) * Head.shape[2].value *
                     Head.shape[3].value, Head.shape[1].value,
                     block_size, block_size)
        tmp_6D_shape = (Head.shape[0].value // block_size, block_size,
                        Head.shape[1].value, Head.shape[2].value,
                        Head.shape[3].value, block_size)
        Head_6D = akg.topi.reshape(Head, tmp_6D_shape)
        Head_6D_transpose = akg.topi.transpose(Head_6D, (0, 3, 4, 2, 5, 1))
        Head_transpose_convert = akg.topi.reshape(Head_6D_transpose, out_shape)
        return Head_transpose_convert

    HEAD = akg.tvm.placeholder(conv_forward.shape, name="Head", dtype='float16')
    Head_transposed_NCHW = (HEAD.shape[1].value * HEAD.shape[4].value,
                            HEAD.shape[0].value, HEAD.shape[2].value,
                            HEAD.shape[3].value)
    s_h, s_w = stride_
    Head_strided_NCHW = (HEAD.shape[0].value,
                         HEAD.shape[1].value * HEAD.shape[4].value,
                         (HEAD.shape[2].value - 1) * s_h + 1,
                         (HEAD.shape[3].value - 1) * s_w + 1)
    A_transposed_NCHW = (in_c, in_n, in_h, in_w)
    K_flip_rot_NCHW = (k_c, k_n, k_h, k_w)

    Head_transposed_converted = transpose_convert_head(HEAD)
    pld_Head_transposed_converted = akg.tvm.placeholder(
        Head_transposed_converted.shape, name="Head_trans_fractal",
        dtype=conv_dtype)
    A_transposed = transpose_data(A)
    pld_A_transposed = akg.tvm.placeholder(A_transposed.shape, name="A_trans",
                                           dtype=conv_dtype)

    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=1, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=2, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=3, tilel1=1, tilel0=1)

    B_flip = flip_weight(B, k_c, k_hw, const_shift)
    pld_B_flipped = akg.tvm.placeholder(B_flip.shape, name="B_flip",
                                        dtype=conv_dtype)
    s_flipped = akg.tvm.create_schedule(B_flip.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_weight_flipped = akg.build(s_flipped, [B, B_flip], "cce",
                                       name=B.name + "_flipped",
                                       attrs={"dim": str(info)}, polyhedral=True)

    s_transposed_converted = akg.tvm.create_schedule(Head_transposed_converted.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_head_transposed_converted = akg.build(
            s_transposed_converted, [HEAD, Head_transposed_converted], "cce",
            name="H_trans_converted", attrs={"dim": str(info)}, polyhedral=True)

    Head_strided = strided_head(HEAD, s_h, s_w)
    pld_Head_strided = akg.tvm.placeholder(Head_strided.shape,
                                           name="Head_trans_5D", dtype=conv_dtype)
    s_strided = akg.tvm.create_schedule(Head_strided.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_head_strided = akg.build(s_strided, [HEAD, Head_strided], "cce",
                                     name="H_strided", attrs={"dim": str(info)},
                                     polyhedral=True)

    s_transposed = akg.tvm.create_schedule(A_transposed.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_transposed = akg.build(s_transposed, [A, A_transposed], "cce",
                                   name="A_transposed", attrs={"dim": str(info)},
                                   polyhedral=True)

    ad_attrs = {"ad_conv_enable": 1, "ad_conv_reuse_conv": 1}
    jacs = list(akg.differentiate(conv_forward, [A], HEAD, ad_attrs,
                                  [pld_Head_strided, pld_B_flipped, None]))
    info = set_dims(Head_strided_NCHW, (k_c, k_n, k_h, k_w), (k_h - 1, k_w - 1),
                    (1, 1), (1, 1), tile_hh, tile_coco, tile_mm, tile_kk,
                    tile_nn, block_size)
    sjac = akg.tvm.create_schedule([jacs[0].op])
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_AD_data = akg.build(sjac, [pld_Head_strided, pld_B_flipped, jacs[0]],
                                "cce", name="conv_AD_data",
                                attrs={"dim": str(info)}, polyhedral=True)

    conv_data = conv_compute_forward(Head_strided_NCHW, K_flip_rot_NCHW,
                                     (k_h - 1, k_h - 1, k_w - 1, k_w - 1),
                                     (1, 1), (1, 1), pld_Head_strided,
                                     pld_B_flipped, None, tile_hh, tile_coco,
                                     tile_mm, tile_kk, tile_nn, bypass_l1,
                                     use_bias, block_size, conv_dtype)
    info = set_dims(Head_strided_NCHW, (k_c, k_n, k_h, k_w), (k_h - 1, k_w - 1),
                    (1, 1), (1, 1), tile_hh, tile_coco, tile_mm, tile_kk,
                    tile_nn, block_size)
    s_data = akg.tvm.create_schedule(conv_data.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        _ = akg.build(s_data, [pld_Head_strided, pld_B_flipped, conv_data],
                      "cce", name="conv_data", attrs={"dim": str(info)},
                      polyhedral=True)

    ad_attrs = {"ad_conv_enable": 1, "ad_conv_reuse_conv": 1}
    jacs = list(akg.differentiate(conv_forward, [B], HEAD, ad_attrs,
                                  [pld_A_transposed,
                                   pld_Head_transposed_converted, None]))
    info = set_dims(A_transposed_NCHW, Head_transposed_NCHW, (0, 0), (1, 1),
                    (s_h, s_w), tile_hh, tile_coco, tile_mm, tile_kk, tile_nn,
                    block_size)
    sjac = akg.tvm.create_schedule([jacs[0].op])
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_AD_weight = akg.build(
            sjac, [pld_A_transposed, pld_Head_transposed_converted, jacs[0]],
            "cce", name="conv_AD_weight", attrs={"dim": str(info)},
            polyhedral=True)

    conv_weight = conv_compute_forward(
        A_transposed_NCHW, Head_transposed_NCHW, (0, 0, 0, 0), (1, 1),
        (s_h, s_w), pld_A_transposed, pld_Head_transposed_converted, None,
        tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, bypass_l1, use_bias,
        block_size, conv_dtype)
    info = set_dims(A_transposed_NCHW, Head_transposed_NCHW, (0, 0), (1, 1),
                    (s_h, s_w), tile_hh, tile_coco, tile_mm, tile_kk, tile_nn,
                    block_size)
    s_weight = akg.tvm.create_schedule(conv_weight.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        akg.build(s_weight,
                  [pld_A_transposed, pld_Head_transposed_converted, conv_weight],
                  "cce", name="conv_weight", attrs={"dim": str(info)},
                  polyhedral=True)

    return (mod_AD_data, mod_AD_weight, mod_transposed,
            mod_head_transposed_converted, mod_head_strided, mod_weight_flipped)
def get_dim_info(self, arg, is_conv=False):
    info = dim.Dim()
    tile_dims = []
    dims = None
    enable_multicore = None
    dynamic = False
    partial_dynamic = False
    bypass_l1 = False
    if "dynamic" in arg:
        dynamic = True
        if isinstance(arg, tuple):
            arg = list(arg)
            arg.remove("dynamic")
            arg = tuple(arg)
        else:
            arg.remove("dynamic")
    if "partial_dynamic" in arg:
        partial_dynamic = True
        arg.remove("partial_dynamic")
    if "bypassL1" in arg:
        bypass_l1 = True
        arg.remove("bypassL1")

    if is_conv:
        dy = dynamic or partial_dynamic
        if len(arg) == 4:
            conv_tile = arg[3]
            if len(conv_tile) > 0:
                if not dy:
                    return {
                        "dim": str(info),
                        "conv_tile": conv_tile,
                        "enable_multicore": True,
                        "bypass": 1 if bypass_l1 else 0,
                    }
                return {
                    "dim": str(info),
                    "conv_tile": conv_tile,
                    "dynamic": dynamic,
                    "partial_dynamic": partial_dynamic,
                    "bypass": 1 if bypass_l1 else 0,
                }
        elif dy and len(arg) == 3:
            return {
                "dynamic": dynamic,
                "partial_dynamic": partial_dynamic,
                "bypass": 1 if bypass_l1 else 0,
            }

    if len(arg) == 5 and not arg[-1]:
        dims = arg[3]
        for d in range(len(dims)):
            tile_dims.append(dims[d][0])
    elif (len(arg) == 5 and arg[-1]) or len(arg) == 4:
        if isinstance(arg[3], (bool, int)):
            # only multicore info
            enable_multicore = arg[3]
        elif isinstance(arg[3][-1], (bool, int)):
            # dim info and multicore info
            enable_multicore = arg[3][-1]
            dims = arg[3][0]
        else:
            # only dim info
            dims = arg[3]

    if dims is not None:
        for i in range(len(dims)):
            if isinstance(dims[i][0], int):
                # only one index, ((l1,l0),(l1,l0),...)
                i_dims = dims
            else:
                # multiple indices, (((l1,l0),(l1,l0),...), ((l1,l0),(l1,l0),...))
                i_dims = dims[i]
            for d in range(len(i_dims)):
                info.setdim(index=i, axis=d, tilel1=i_dims[d][0], tilel0=i_dims[d][1])

    if len(arg) == 5 and not arg[-1]:
        return {"tile": tile_dims}
    res = {"dim": str(info), "dynamic": dynamic}
    if enable_multicore:
        res["enable_multicore"] = enable_multicore
    return res
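# Sketch of the two dim layouts get_dim_info accepts in arg[3] (values are
# illustrative only): a single band as ((l1, l0), ...), or multiple bands as
# (((l1, l0), ...), ((l1, l0), ...)); a trailing bool in arg[3] toggles
# multicore, e.g. arg[3] = (((16, 16), (16, 0)), True).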
def group_conv_ad(_n, _h, _w, _c_i, _c_o, group, _k_h, _k_w, pad_h, pad_w,
                  _s_h, _s_w, cut_h, cut_co, cut_m, cut_k, cut_n,
                  block_size, use_bias=False, kernel_name='group_conv'):
    conv_dtype = 'float16'

    _a = akg.tvm.placeholder((_n, _c_i // block_size, _h, _w, block_size),
                             name="input0", dtype=conv_dtype)
    _b = akg.tvm.placeholder(((_c_i // group) // block_size * _k_h * _k_w,
                              _c_o // block_size, block_size, block_size),
                             name="input1", dtype=conv_dtype)

    mod_forward = group_conv_forward(_n, _h, _w, _c_i, _c_o, group, _k_h,
                                     _k_w, _a, _b, None, pad_h, pad_w,
                                     _s_h, _s_w, cut_h, cut_co, cut_m,
                                     cut_k, cut_n, block_size)
    _o_h = mod_forward.shape[2].value
    _o_w = mod_forward.shape[3].value

    head = akg.tvm.placeholder(mod_forward.shape, name="head",
                               dtype=conv_dtype)

    # (_n,_c_o,_o_h,_o_w)--(stride)-->(_n,_c_o,(_o_h-1)*_s_h+1,(_o_w-1)*_s_w+1)
    # --(5d)-->(_n,_c_o/16,(_o_h-1)*_s_h+1,(_o_w-1)*_s_w+1,16)
    pld_head_strided = akg.tvm.placeholder(
        (_n, _c_o // block_size, (_o_h - 1) * _s_h + 1,
         (_o_w - 1) * _s_w + 1, block_size),
        name="head_strided_5d", dtype=conv_dtype)

    # (_c_o,_c_i//group,_k_h,_k_w)--(flip)-->(_c_i,_c_o//group,_k_h,_k_w)
    # --(Fractal)-->((_c_o//group)/16*_k_h*_k_w, _c_i/16,16,16)
    pld_b_flipped = akg.tvm.placeholder(
        ((_c_o // group) // block_size * _k_h * _k_w,
         _c_i // block_size, block_size, block_size),
        name="b_flip", dtype=conv_dtype)

    # b in Fractal format; result in Fractal format
    b_group_flipped = group_flip_weight(_b, _k_h, _k_w, group,
                                        _c_o // group // block_size,
                                        _c_i // group // block_size,
                                        block_size)
    s_gr_fl = akg.tvm.create_schedule([b_group_flipped.op])
    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=1, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=2, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=3, tilel1=1, tilel0=1)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=False):
        mod_b_group_flip = akg.build(s_gr_fl, [_b, b_group_flipped], "cce",
                                     name="b_group_flip",
                                     attrs={"dim": str(info)},
                                     polyhedral=True)

    head_strided = strided_head(head, _s_h, _s_w)
    s_striding = akg.tvm.create_schedule(head_strided.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=False):
        mod_head_strided = akg.build(s_striding, [head, head_strided], "cce",
                                     name="h_strided",
                                     attrs={"dim": str(info)},
                                     polyhedral=True)

    a_transposed = transpose_regroup(_a, block_size, group)
    s_transposed_nc = akg.tvm.create_schedule(a_transposed.op)
    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=16, tilel0=16)
    info.setdim(index=0, axis=1, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=2, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=3, tilel1=1, tilel0=1)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_transposed_nc = akg.build(s_transposed_nc, [_a, a_transposed],
                                      "cce", name="a_transposed",
                                      attrs={"dim": str(info)},
                                      polyhedral=True)

    head_transposed_convert = transpose_convert_head(head, block_size)
    s_transposed_convert = akg.tvm.create_schedule(head_transposed_convert.op)
    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=1, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=2, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=3, tilel1=1, tilel0=1)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        # distinct kernel name so it does not collide with mod_transposed_nc
        mod_transposed_convert = akg.build(
            s_transposed_convert, [head, head_transposed_convert], "cce",
            name="head_transposed_convert", attrs={"dim": str(info)},
            polyhedral=True)

    # Begin with the ad kernels
    ad_attrs = {"ad_conv_enable": 1}
    _jacs_data = list(akg.differentiate(mod_forward, [_a], head, ad_attrs,
                                        [pld_head_strided, pld_b_flipped,
                                         None]))
    cut_h_e, cut_co_e, cut_m_e, cut_k_e, cut_n_e = (
        (_o_h - 1) * _s_h + 1 + 2 * (_k_h - 1 - pad_h), 16, _h * _w, 48, 16)
    cut_m_e = ((cut_m_e + block_size - 1) // block_size) * block_size
    info = set_dims_group(cut_h_e, cut_co_e, cut_m_e, cut_k_e, cut_n_e,
                          expr_to_int(_a.shape), _c_o, _c_i, group,
                          _k_h, _k_w, _s_h, block_size)

    s_data = akg.tvm.create_schedule([_jacs_data[0].op])
    # low_data = akg.lower(s_data, [pld_head_strided, pld_b_flipped,
    #                               _jacs_data[0]], simple_mode=True)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=False):
        mod_ad_data = akg.build(s_data,
                                [pld_head_strided, pld_b_flipped,
                                 _jacs_data[0]],
                                "cce", name="conv_ad_data",
                                attrs={"dim": info}, polyhedral=True)

    # (_n,_c_i,_h,_w)--(trans)-->(_c_i,_n,_h,_w)--(regroup)-->
    # (_c_i//group,_n*group,_h,_w)--(5d)-->(_c_i//group,(_n*group)/16,_h,_w,16)
    pld_x_trans = akg.tvm.placeholder(
        (_c_i // group, (_n * group) // block_size, _h, _w, block_size),
        name="x_trans_5d", dtype=conv_dtype)

    # (_n,_c_o,_o_h,_o_w)--(trans)-->(_c_o,_n,_o_h,_o_w)
    # --(Fractal)-->(_n/16*_o_h*_o_w, _c_o/16,16,16)
    pld_head_trans_converted = akg.tvm.placeholder(
        (_n // block_size * _o_h * _o_w, _c_o // block_size,
         block_size, block_size),
        name="head_trans_convert", dtype=conv_dtype)

    # ad_attrs = {"ad_conv_enable": 1}
    _jacs_weights = list(akg.differentiate(mod_forward, [_b], head, ad_attrs,
                                           [pld_x_trans,
                                            pld_head_trans_converted,
                                            None]))
    cut_h_e, cut_co_e, cut_m_e, cut_k_e, cut_n_e = (
        _h + 2 * pad_h, 16, _k_h * _k_w, 48, 16)
    cut_m_e = ((cut_m_e + block_size - 1) // block_size) * block_size
    info = set_dims_group(cut_h_e, cut_co_e, cut_m_e, cut_k_e, cut_n_e,
                          (_c_i // group, _c_o // block_size, _k_h, _k_w,
                           block_size),
                          _n * group, _c_o, group, _o_h, _o_w, 1, block_size)

    s_weights = akg.tvm.create_schedule([_jacs_weights[0].op])
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_ad_weights = akg.build(s_weights,
                                   [pld_x_trans, pld_head_trans_converted,
                                    _jacs_weights[0]],
                                   "cce", name="conv_ad_weights",
                                   attrs={"dim": info}, polyhedral=True)

    print("Forward input data shape: ", _a.shape)
    print("Forward input weight shape: ", _b.shape)
    print("Forward output shape: ", mod_forward.shape)
    print("Backward wrt. DATA input data shape: ", pld_head_strided.shape)
    print("Backward wrt. DATA input weight shape: ", pld_b_flipped.shape)
    print("Backward wrt. DATA output shape: ", _jacs_data[0].shape)
    print("Backward wrt. WEIGHT input data shape: ", pld_x_trans.shape)
    print("Backward wrt. WEIGHT input weight shape: ",
          pld_head_trans_converted.shape)
    print("Backward wrt. WEIGHT output shape: ", _jacs_weights[0].shape)

    return (mod_ad_data, mod_ad_weights, mod_b_group_flip, mod_head_strided,
            mod_transposed_nc, mod_transposed_convert)
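# Hedged usage sketch (not from the original source): building the six
# kernels for a small grouped convolution. All shape and cut values below
# are illustrative assumptions, not tuned tilings; they do satisfy the
# _c_i // block_size // group == 1 constraint asserted in set_dims_group.
#
#   mods = group_conv_ad(_n=1, _h=28, _w=28, _c_i=64, _c_o=64, group=4,
#                        _k_h=3, _k_w=3, pad_h=1, pad_w=1, _s_h=1, _s_w=1,
#                        cut_h=30, cut_co=16, cut_m=16, cut_k=16, cut_n=16,
#                        block_size=16)
#   (mod_ad_data, mod_ad_weights, mod_b_group_flip, mod_head_strided,
#    mod_transposed_nc, mod_transposed_convert) = mods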
def cast_conv_set_dim_func(data, fmap_shape, filter_shape, pad_, stride_,
                           dilation_, use_bias=False, block_size=16,
                           attrs=None):
    if isinstance(stride_, int):
        stride_ = [stride_] * 2
    elif isinstance(stride_, (list, tuple)) and 1 == len(stride_):
        stride_ = list(stride_) * 2
    elif isinstance(stride_, (list, tuple)) and 2 == len(stride_):
        pass
    else:
        raise RuntimeError('illegal stride parameter')

    if isinstance(pad_, int):
        pad_ = [pad_] * 4
    elif isinstance(pad_, (list, tuple)) and 1 == len(pad_):
        pad_ = list(pad_) * 4
    elif isinstance(pad_, (list, tuple)) and 4 == len(pad_):
        pass
    else:
        raise RuntimeError('illegal pad parameter')

    if isinstance(dilation_, int):
        dilation_ = [dilation_] * 2
    elif isinstance(dilation_, (list, tuple)) and 1 == len(dilation_):
        dilation_ = list(dilation_) * 2
    elif isinstance(dilation_, (list, tuple)) and 2 == len(dilation_):
        pass
    else:
        raise RuntimeError('illegal dilation parameter')

    key = []
    key.append(tuple(fmap_shape))
    key.append(tuple(filter_shape))
    key.append(tuple(pad_))
    key.append(tuple(stride_))
    key.append(tuple(dilation_))
    hash_key = str(tuple(key))

    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size

    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size

    # padding ((padding_h, padding_w) ->
    #          (padding_top, padding_bottom, padding_left, padding_right))
    padding = (pad_[0], pad_[0], pad_[1], pad_[1])
    p_top, p_bottom, p_left, p_right = padding

    # stride (stride_h, stride_w)
    s_h, s_w = stride_

    # dilation (dilation_h, dilation_w)
    d_h, d_w = dilation_

    k_w_d = (k_w - 1) * d_w + 1
    out_w = (in_w + p_left + p_right - k_w_d) // s_w + 1

    bypass_list = [0, 1]
    bypass = 0
    if attrs is not None and 'conv_tile' in attrs and len(
            attrs['conv_tile']) >= 5:
        tile_hh = attrs['conv_tile'][0]
        tile_coco = attrs['conv_tile'][1]
        tile_mm = attrs['conv_tile'][2]
        tile_kk = attrs['conv_tile'][3]
        tile_nn = attrs['conv_tile'][4]
        if len(attrs['conv_tile']) > 5:
            tile_ww = attrs['conv_tile'][5]
        else:
            tile_ww = (out_w - 1) * s_w + k_w_d
        if 'bypass' in attrs:
            bypass = attrs['bypass']
    elif hash_key in cast_conv_set_dim_map:
        configs = cast_conv_set_dim_map[hash_key]
        if isinstance(configs, tuple):
            tiles = configs[0]
            if "bypass" in configs[1]:
                bypass = configs[1]["bypass"]
        else:
            tiles = configs
        if len(tiles) > 5:
            tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww = tiles
        else:
            tile_hh, tile_coco, tile_mm, tile_kk, tile_nn = tiles
            tile_ww = (out_w - 1) * s_w + k_w_d
    else:
        tile_hh = (k_h - 1) * d_h + 1 + p_top * s_h
        tile_ww = (out_w - 1) * s_w + k_w_d
        tile_coco = 16
        tile_mm = 16
        tile_kk = 16
        tile_nn = 16

    if bypass not in bypass_list:
        raise RuntimeError("conv_cce only supports bypass in %s while bypass is %d"
                           % (",".join(str(b) for b in bypass_list), bypass))

    if (tile_hh == in_h):
        tile_hh += p_top + p_bottom
    tile_coco = (tile_coco + block_size - 1) // block_size * block_size
    tile_mm = (tile_mm + block_size - 1) // block_size * block_size
    tile_kk = (tile_kk + block_size - 1) // block_size * block_size
    tile_nn = (tile_nn + block_size - 1) // block_size * block_size

    c0 = block_size
    c1_cut = tile_coco // c0
    h_window_cut = (tile_hh - k_h) // s_h + 1

    out_w = (in_w + p_left + p_right - k_w) // s_w + 1

    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    in_n, in_c1, in_h, in_w, in_c0 = input_shape_nc1hwc0

    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0

    # dilation-aware output and tile windows (supersede the values above)
    k_h_d = (k_h - 1) * d_h + 1
    k_w_d = (k_w - 1) * d_w + 1
    out_h = (in_h + p_top + p_bottom - k_h_d) // s_h + 1
    tile_out_h = (tile_hh - k_h_d) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w_d) // s_w + 1
    tile_out_w = (tile_ww - k_w_d) // s_w + 1

    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    if (tile_coco > 0):
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # set dim
    info = dim.Dim()
    if (out_n > 1):
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if (out_c1 > 1):
        info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1
    if (out_h > 1):
        info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if (out_w > 1):
        info.setdim(index=0, axis="W", tilel1=tile_out_w, tilel0=0)  # w
    if (out_c0 > 1):
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0
    if (in_c1 > 1):
        info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1
    if (k_h > 1):
        info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh
    if (k_w > 1):
        info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw

    return str(info)  # ct_util.set_dims_by_key(hash_key, conv_set_dim_map)
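# Hedged usage sketch (not from the original source): querying the dim
# string for a 1x64x56x56 fmap with a 64x64x3x3 filter. The placeholder
# shape is the NC1HWC0 layout of the fmap; all values are illustrative
# assumptions, and with no attrs or map entry the default tilings apply.
#
#   data = akg.tvm.placeholder((1, 4, 56, 56, 16), name="data",
#                              dtype="float16")
#   dim_str = cast_conv_set_dim_func(data, (1, 64, 56, 56), (64, 64, 3, 3),
#                                    pad_=1, stride_=1, dilation_=1)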