def _get_param_more_row(tvm_ib, src_shape, dtype):
    """calculate parameters for the more-row ir builder make function"""
    device_core_num = AICORE_NUM
    float_size = cce.cce_intrin.get_bit_len(dtype) // 8
    cp_align_len = cce_params.BLOCK_REDUCE_INT8 // float_size
    ub_ele = ((UB_SIZE_B - 32) // 2) // float_size
    _, n_no, n_ni, c_0 = src_shape
    row_ele = n_no * n_ni * c_0
    num_row_one_core = ub_ele // row_ele
    num_row_one_group = num_row_one_core * device_core_num
    num_row_in_data = src_shape[0]
    num_group_index = num_row_in_data // num_row_one_group
    num_group_mod = num_row_in_data % num_row_one_group
    block_index = tvm.thread_axis("blockIdx.x")
    tvm_ib.scope_attr(block_index, "thread_extent", device_core_num)
    param_map = {"num_group_index": num_group_index,
                 "num_group_mod": num_group_mod,
                 "row_ele": row_ele,
                 "float_size": float_size,
                 "cp_align_len": cp_align_len,
                 "num_row_one_core": num_row_one_core,
                 "num_row_one_group": num_row_one_group,
                 "block_index": block_index}
    return param_map
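# --- Illustrative sketch (not part of the original module): the grouping
# arithmetic above, with assumed values of 32 cores, float16 (2 bytes),
# 248 KB of UB and src_shape = (10000, 2, 16, 16).
def _more_row_param_demo():
    device_core_num = 32                              # assumed AICORE_NUM
    float_size = 2                                    # float16
    ub_ele = ((248 * 1024 - 32) // 2) // float_size   # elements per UB half
    row_ele = 2 * 16 * 16                             # n_no * n_ni * c_0
    num_row_one_core = ub_ele // row_ele              # rows one core buffers
    num_row_one_group = num_row_one_core * device_core_num
    num_row_in_data = 10000
    # full multi-core passes, plus a tail group handled separately
    return (num_row_in_data // num_row_one_group,
            num_row_in_data % num_row_one_group)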
def __init__(self, input_x, output_y, num, axis, kernel_name):
    self.input_x = input_x
    self.num = num
    self.kernel_name = kernel_name
    self.dtype = input_x.get("dtype").lower()
    self.dim_info_vars = []
    self.ub_tensor_list = []
    self.res_tensor_list = []
    self.virtual_node = None
    self.sch_list = []
    self.arg_list = []
    self.rules = []
    self.compile_vars = {}
    self.dim_vars = []
    self.dim_bounds = []
    self.output_shape = []
    self.x_reshape = None
    self.left_range = None
    self.right_range = None
    self._normalize_shape()
    self._trans_input_shape(axis)
    self.new_axis = 1
    self._input_placeholder = None
    self.block_idx = tvm.thread_axis('blockIdx.x')
    self.ub_size = cce_conf.get_soc_spec(cce_conf.UB_SIZE)
    self.core_num = cce_conf.get_soc_spec(cce_conf.CORE_NUM)
def upsample(x, y, scale=1.0, stride_h=2, stride_w=2, kernel_name="upsample"):
    """
    calculating data

    Parameters
    ----------
    x : dict
        include shape, dtype and format
    y : dict
        output
    scale : float
        the value by which the changed axes are scaled, default value is 1.0
    stride_h : int
        the stride along axis h
    stride_w : int
        the stride along axis w
    kernel_name : str
        kernel name, default value is "upsample"

    Returns
    -------
    None
    """
    upsample_check(x, stride_h, stride_w, kernel_name)
    dtype = x.get("dtype")
    op_list, ins_list, tensor_dic, feature, y = \
        gen_upsample(x, dtype, scale, stride_h, stride_w)
    schedule = tvm.create_schedule(y.op)
    # skip the res buffer
    buffer_mapping(schedule, op_list[:-1])
    tilling_spilt_axis_dic = \
        tilling_spilt_axis(schedule, tensor_dic, stride_h, stride_w)
    cal_axis_dic, axis = \
        cal_axis_spilt(x, stride_h, stride_w, tilling_spilt_axis_dic,
                       tensor_dic, schedule)
    axis_list = upsample_compute(schedule, cal_axis_dic, tensor_dic)
    res_op = tensor_dic.get("res")
    ins_emit(schedule, op_list, axis_list, ins_list)
    if axis == 0:
        schedule[y].bind(cal_axis_dic.get("axis_xo"),
                         tvm.thread_axis("blockIdx.x"))
    else:
        res_out, _ = bind_multcore(axis, x, schedule, res_op)
        schedule[y].bind(res_out, tvm.thread_axis("blockIdx.x"))
    with build_config:
        tvm.build(schedule, [feature, y], "cce", name=kernel_name)
def _tile_axis(data_list, shape, dtype):
    """calculate the tile parameters."""
    sch = data_list[0]
    data_ub = data_list[1]
    data_out = data_list[2]
    ub_size = tbe_platform.cce_conf.get_soc_spec(tbe_platform.cce_conf.UB_SIZE)
    dtype_size = tbe_platform.cce_intrin.get_bit_len(dtype) // 8
    total_cnt = ub_size // dtype_size
    ele_cnt = shape[0]
    axis = 0
    factor = ele_cnt
    if ele_cnt > total_cnt:
        factor = total_cnt
    core_num = ele_cnt // factor
    if core_num <= MAX_BLOCK:
        axis_outer, axis_inner = sch[data_out].split(data_out.op.axis[axis],
                                                     factor=factor)
        if core_num != 1:
            sch[data_out].bind(axis_outer, tvm.thread_axis('blockIdx.x'))
        sch[data_ub].compute_at(sch[data_out], axis_outer)
        sch[data_ub].emit_insn(data_ub.op.axis[axis], 'dma_copy')
        sch[data_out].emit_insn(axis_inner, 'dma_copy')
    else:
        factor_new = 1
        core_num_new = 1
        for i in reversed(list(range(1, MAX_BLOCK))):
            if core_num % i == 0:
                factor_new = core_num // i
                core_num_new = i
                break
        axis_outer, axis_inner = sch[data_out].split(data_out.op.axis[axis],
                                                     factor=factor)
        last_outer, last_inner = sch[data_out].split(axis_outer,
                                                     factor=factor_new)
        if core_num_new != 1:
            sch[data_out].bind(last_outer, tvm.thread_axis('blockIdx.x'))
        sch[data_ub].compute_at(sch[data_out], last_inner)
        sch[data_ub].emit_insn(data_ub.op.axis[axis], 'dma_copy')
        sch[data_out].emit_insn(axis_inner, 'dma_copy')
    return sch
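# --- Hedged sketch of the tiling arithmetic in _tile_axis, assuming a
# hypothetical 256 KB UB, float16 elements and a flat 1M-element tensor.
def _tile_axis_demo():
    ub_size = 256 * 1024                  # assumed UB capacity in bytes
    dtype_size = 2                        # float16
    total_cnt = ub_size // dtype_size     # 131072 elements fit in UB
    ele_cnt = 1 << 20                     # tensor length
    factor = min(ele_cnt, total_cnt)      # inner split factor
    core_num = ele_cnt // factor          # outer extent bound to blockIdx.x
    return factor, core_num               # (131072, 8)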
def _schedule_last_axis(sch, shape, in_data, output, dtype):
    """schedule for the last-axis situation"""
    # the five return values are: the first-axis factor, the first-inner
    # factor, the second-axis factor, the multi-core flag and the
    # compute-at axis
    axis_inner_ft, axis_inner2_ft, axis2_inner_ft, cores, compute_axis = \
        _get_core_num_last_axis(shape, dtype)
    axis_outer, axis_inner = sch[output].split(output.op.axis[0],
                                               factor=axis_inner_ft)
    axis_inner_outer, axis_inner_inner = sch[output].split(
        axis_inner, factor=axis_inner2_ft)
    axis_two_outer, axis_two_inner = sch[output].split(output.op.axis[1],
                                                       factor=axis2_inner_ft)
    input_axis_outer, input_axis_inner = sch[in_data].split(
        in_data.op.axis[0], factor=axis_inner_ft)
    input_axis_inner_outer, input_axis_inner_inner = sch[in_data].split(
        input_axis_inner, factor=axis_inner2_ft)
    input_axis_two_outer, input_axis_two_inner = sch[in_data].split(
        in_data.op.axis[1], factor=axis2_inner_ft)
    if compute_axis == "axis_inner_inner":
        sch[in_data].compute_at(sch[output], axis_inner_inner)
        sch[in_data].emit_insn(input_axis_two_inner,
                               insn_cmd.DMA_COPY)  # gm-ub
        sch[output].emit_insn(axis_two_inner, insn_cmd.DMA_COPY)  # ub-gm
    elif compute_axis == "axis_inner_outer":
        sch[in_data].compute_at(sch[output], axis_inner_outer)
        sch[in_data].emit_insn(input_axis_inner_inner,
                               insn_cmd.DMA_COPY)  # gm-ub
        sch[output].emit_insn(axis_inner_inner, insn_cmd.DMA_COPY)  # ub-gm
    elif compute_axis == "axis_two":
        sch[in_data].compute_at(sch[output], axis_two_outer)
        sch[in_data].emit_insn(input_axis_two_inner,
                               insn_cmd.DMA_COPY)  # gm-ub
        sch[output].emit_insn(axis_two_inner, insn_cmd.DMA_COPY)  # ub-gm
    else:
        sch[in_data].compute_at(sch[output], axis_outer)
        sch[in_data].emit_insn(input_axis_inner_inner,
                               insn_cmd.DMA_COPY)  # gm-ub
        sch[output].emit_insn(axis_inner_inner, insn_cmd.DMA_COPY)  # ub-gm
    if cores:
        thread_block = tvm.thread_axis("blockIdx.x")
        sch[output].bind(axis_outer, thread_block)
def max_pool3d_schedule(res, sch):
    """max_pool3d schedule"""
    tensor_d = res.op.input_tensors[0]
    tensor_h = tensor_d.op.input_tensors[0]
    tensor_w = tensor_h.op.input_tensors[0]
    tensor_in_ub = tensor_w.op.input_tensors[0]

    # set scope
    sch[tensor_in_ub].set_scope(cce.scope_ubuf)
    sch[tensor_w].set_scope(cce.scope_ubuf)
    sch[tensor_h].set_scope(cce.scope_ubuf)
    sch[tensor_d].set_scope(cce.scope_ubuf)

    # double buffer
    sch[tensor_in_ub].double_buffer()
    sch[tensor_in_ub].preload()
    sch[tensor_w].double_buffer()
    sch[tensor_h].double_buffer()
    sch[tensor_d].double_buffer()

    # bind core
    res_1o, _ = sch[res].split(res.op.axis[1], factor=1)
    thread_block = tvm.thread_axis("blockIdx.x")
    sch[res].bind(res_1o, thread_block)

    # tiling
    cut_h = max_pool3d_tiling(tensor_in_ub.shape)
    res_2o, res_2i = sch[res].split(res.op.axis[2], factor=cut_h)

    # compute at
    sch[tensor_w].compute_at(sch[res], res_2o)
    sch[tensor_h].compute_at(sch[res], res_2o)
    sch[tensor_d].compute_at(sch[res], res_2o)
    sch[tensor_in_ub].compute_at(sch[res], res_2o)

    # emit insn
    sch[tensor_in_ub].emit_insn(tensor_in_ub.op.axis[0], 'dma_copy')
    sch[tensor_w].emit_insn(tensor_w.op.axis[0], 'vector_max')
    sch[tensor_h].emit_insn(tensor_h.op.axis[0], 'vector_max')
    sch[tensor_d].emit_insn(tensor_d.op.axis[0], 'vector_max')
    sch[res].emit_insn(res_2i, 'dma_copy')
def __init__(self, ib_, dtype):
    self.ib_ = ib_
    self.dtype = dtype
    self.type_size = tbe_platform.cce_intrin.get_bit_len(dtype) // 8
    self.cp_align_len = cce_params.BLOCK_REDUCE_INT8 // self.type_size
    self.unified_buffer_len = tbe_platform.get_soc_spec(
        tbe_platform.cce_conf.UB_SIZE) // self.type_size
    self.vec_align_len = cce_params.VECTOR_INST_BLOCK_WIDTH // self.type_size
    self.uint8_max_value = 255
    self.last_block = ib_.allocate("int32", (1, ), name="last_block",
                                   scope=cce_params.scope_reg)
    self.device_core_num = tbe_platform.get_soc_spec(
        tbe_platform.cce_conf.CORE_NUM)
    self.block = tvm.thread_axis("blockIdx.x")
    self.ib_.scope_attr(self.block, "thread_extent", self.device_core_num)
    self.input_ub = 0
    self.output_ub = 0
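# --- Hedged sketch (assumed constants) of the alignment lengths computed in
# the constructor above, for float16: this assumes BLOCK_REDUCE_INT8 is 32
# (one 32-byte copy block) and VECTOR_INST_BLOCK_WIDTH is 256 (bytes handled
# per vector repeat).
def _align_len_demo():
    type_size = 2                      # float16
    cp_align_len = 32 // type_size     # 16 elements per copy block
    vec_align_len = 256 // type_size   # 128 elements per vector repeat
    return cp_align_len, vec_align_len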
def _unpack_schedule(input_place, output_shape, y, num, axis, dtype):
    """Create unpack schedule.

    Parameters
    ----------
    input_place: TVM tensor
        the tensor of input.
    output_shape: tuple or list
        the shape of output tensor.
    y: tuple or list
        the list of output tensor.
    num: int
        the length of the dim axis.
    axis: int
        the axis to unpack along.
    dtype: str
        the dtype of input.

    Returns
    -------
    sch: schedule
        the created schedule.
    build_list: list
        the list of input and output tensors, tensor type is TVM tensor.
    """
    _, ele_each_block, device_core_num = _get_public_param(dtype)
    befordim, afterdim = output_shape[0], output_shape[-1]
    block_idx = tvm.thread_axis('blockIdx.x')

    # scenario where multi-core can be enabled
    if befordim >= ele_each_block and afterdim < ele_each_block:
        befordim_in = ele_each_block // afterdim + 1
        befordim_out = (befordim + befordim_in - 1) // befordim_in
        while (befordim + befordim_out - 1) // befordim_out * afterdim \
                < ele_each_block:
            befordim_out -= 1
        if befordim_out >= device_core_num:
            befordim_out = device_core_num
        afterdim_in = afterdim

        gm2ub_tensor, ub2ub_tensor_list, ub2gm_tensor_list, virtual_node = \
            _unpack_compute_scalar(input_place, y, num, axis)

        res_op = []
        build_list = [input_place]
        for ub2gm_tensor in ub2gm_tensor_list:
            res_op.append(ub2gm_tensor.op)
            build_list.append(ub2gm_tensor)

        sch = tvm.create_schedule(virtual_node.op)
        sch[gm2ub_tensor].set_scope(tbe_platform.scope_ubuf)
        for tensor in ub2ub_tensor_list:
            sch[tensor].set_scope(tbe_platform.scope_ubuf)

        befordim_outer, befordim_inner = sch[virtual_node].split(
            virtual_node.op.axis[0], nparts=befordim_out)
        afterdim_outer, afterdim_inner = sch[virtual_node].split(
            virtual_node.op.axis[2], factor=afterdim_in)
        sch[virtual_node].reorder(befordim_outer, afterdim_outer,
                                  befordim_inner, afterdim_inner)
        fused_axis = sch[virtual_node].fuse(befordim_outer, afterdim_outer)
        sch[virtual_node].bind(fused_axis, block_idx)

        new_shape = ((befordim + befordim_out - 1) // befordim_out, num,
                     afterdim_in)
        split_axis, split_factor = _tiling_axis(new_shape, dtype)
        if split_axis == 0:
            axis_outer, axis_inner = sch[virtual_node].split(
                befordim_inner, factor=split_factor)
        else:
            axis_outer, axis_inner = sch[virtual_node].split(
                afterdim_inner, factor=split_factor)

        sch[gm2ub_tensor].compute_at(sch[virtual_node], axis_outer)
        sch[gm2ub_tensor].emit_insn(gm2ub_tensor.op.axis[split_axis],
                                    insn_cmd.DMA_COPY)
        for i in range(num):
            sch[ub2gm_tensor_list[i]].compute_at(sch[virtual_node],
                                                 axis_outer)
            sch[ub2ub_tensor_list[i]].compute_at(sch[virtual_node],
                                                 axis_outer)
            sch[ub2ub_tensor_list[i]].emit_insn(
                ub2ub_tensor_list[i].op.axis[split_axis], insn_cmd.DATA_MOV)
            sch[ub2gm_tensor_list[i]].emit_insn(
                ub2gm_tensor_list[i].op.axis[split_axis], insn_cmd.DMA_COPY)
        sch[virtual_node].emit_insn(axis_inner, insn_cmd.PHONY_INSN)
    else:
        gm2ub_tensor_list, ub2gm_tensor_list, virtual_node = \
            _unpack_compute_copy(input_place, y, num, axis)

        res_op = []
        build_list = [input_place]
        for ub2gm_tensor in ub2gm_tensor_list:
            res_op.append(ub2gm_tensor.op)
            build_list.append(ub2gm_tensor)

        sch = tvm.create_schedule(virtual_node.op)
        for tensor in gm2ub_tensor_list:
            sch[tensor].set_scope(tbe_platform.scope_ubuf)

        # scenario where multi-core can be enabled
        if afterdim >= ele_each_block:
            if befordim >= device_core_num:
                befordim_out = device_core_num
                afterdim_in = afterdim
            elif befordim == 1:
                befordim_out = befordim
                afterdim_in = (afterdim + device_core_num - 1) \
                    // device_core_num
            else:
                afterdim_outer = device_core_num // befordim
                afterdim_in = (afterdim + afterdim_outer - 1) \
                    // afterdim_outer
                while afterdim % afterdim_in < ele_each_block:
                    afterdim_in += 1
                befordim_out = befordim

            befordim_outer, befordim_inner = sch[virtual_node].split(
                virtual_node.op.axis[0], nparts=befordim_out)
            afterdim_outer, afterdim_inner = sch[virtual_node].split(
                virtual_node.op.axis[2], factor=afterdim_in)
            sch[virtual_node].reorder(befordim_outer, afterdim_outer,
                                      befordim_inner, afterdim_inner)
            fused_axis = sch[virtual_node].fuse(befordim_outer,
                                                afterdim_outer)
            sch[virtual_node].bind(fused_axis, block_idx)

            new_shape = ((befordim + befordim_out - 1) // befordim_out, 1,
                         afterdim_in)
            split_axis, split_factor = _tiling_axis(new_shape, dtype)
            if split_axis == 0:
                axis_outer, axis_inner = sch[virtual_node].split(
                    befordim_inner, factor=split_factor)
            else:
                axis_outer, axis_inner = sch[virtual_node].split(
                    afterdim_inner, factor=split_factor)
        else:
            split_axis, split_factor = _tiling_axis(output_shape, dtype)
            axis_outer, axis_inner = sch[virtual_node].split(
                virtual_node.op.axis[split_axis], factor=split_factor)

        for i in range(num):
            storage_axis = split_axis - 1 if split_axis != 0 else 0
            sch[gm2ub_tensor_list[i]].storage_align(
                gm2ub_tensor_list[i].op.axis[storage_axis], ele_each_block, 0)
            sch[gm2ub_tensor_list[i]].double_buffer()
            sch[gm2ub_tensor_list[i]].compute_at(sch[virtual_node],
                                                 axis_outer)
            sch[ub2gm_tensor_list[i]].compute_at(sch[virtual_node],
                                                 axis_outer)
            sch[gm2ub_tensor_list[i]].emit_insn(
                gm2ub_tensor_list[i].op.axis[split_axis], insn_cmd.DMA_COPY)
            sch[ub2gm_tensor_list[i]].emit_insn(
                ub2gm_tensor_list[i].op.axis[split_axis], insn_cmd.DMA_COPY)
        sch[virtual_node].emit_insn(axis_inner, insn_cmd.PHONY_INSN)

    return sch, build_list
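# --- Hedged sketch of the multi-core split chosen by _unpack_schedule when
# befordim is large and afterdim is smaller than one block (assumed values:
# 32-element blocks, 32 cores, befordim=1000, afterdim=4).
def _unpack_split_demo():
    ele_each_block, device_core_num = 32, 32
    befordim, afterdim = 1000, 4
    befordim_in = ele_each_block // afterdim + 1        # 9 rows fill a block
    befordim_out = (befordim + befordim_in - 1) // befordim_in
    # shrink until each part moves at least one full block
    while (befordim + befordim_out - 1) // befordim_out * afterdim \
            < ele_each_block:
        befordim_out -= 1
    return min(befordim_out, device_core_num)           # clamp to core count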
def _kernel_ir(dst, src, const_1):
    """dropout_do_mask kernel"""
    ir_builder = tvm.ir_builder.create()
    place_holders = [src[0], src[1], dst[0], src[2]]  # input & output params
    # 0: max_elements
    # 1: cnt_per_vsel (VECTOR_INST_BLOCK_WIDTH=256 bytes is the maximum
    #    process unit in vector processing)
    # 2: mask_cnt_per_vsel
    platform_paras = [
        _get_ub_max_elements(place_holders[0].dtype),
        VECTOR_INST_BLOCK_WIDTH //
        (cce_util.get_type_bits(place_holders[0].dtype) // 8),
        (VECTOR_INST_BLOCK_WIDTH //
         (cce_util.get_type_bits(place_holders[0].dtype) // 8)) //
        cce_util.get_type_bits(place_holders[1].dtype)
    ]
    target_core_num, mask_num_each_core, core_num_one_more, \
        num_remain_by_128, is_not_align = \
        _get_target_core_num(src[0], src[1])
    if num_remain_by_128 != 0 and is_not_align:
        # 0: loop_for_ub  1: loop_for_128
        # 2: remain_data_ub (after tiling by UB max process elements)
        # 3: remain_ele
        loops_remains = [
            int(place_holders[0].shape[0]) // platform_paras[0],
            platform_paras[0] // ELEMS_BATCH_PROCESS_FP16,
            int(place_holders[0].shape[0]) % platform_paras[0],
            num_remain_by_128
        ]
        _do_operation(ir_builder, place_holders, platform_paras,
                      loops_remains, const_1, 0, 0, num_remain_by_128,
                      is_not_align)
    else:
        block_index = tvm.thread_axis("blockIdx.x")
        ir_builder.scope_attr(block_index, "thread_extent", target_core_num)
        with ir_builder.if_scope(block_index < core_num_one_more):
            shape_each_core = (mask_num_each_core + 1) * \
                ELEMS_BATCH_PROCESS_FP16 * 8
            block_offset = shape_each_core * block_index
            # 0: loop_for_ub  1: loop_for_128
            # 2: remain_data_ub (after tiling by UB max process elements)
            # 3: remain_ele
            loops_remains = [
                int(shape_each_core) // platform_paras[0],
                platform_paras[0] // ELEMS_BATCH_PROCESS_FP16,
                int(shape_each_core) % platform_paras[0],
                num_remain_by_128
            ]
            _do_operation(ir_builder, place_holders, platform_paras,
                          loops_remains, const_1, block_offset,
                          shape_each_core, num_remain_by_128, is_not_align)
        with ir_builder.else_scope():
            shape_each_core = mask_num_each_core * \
                ELEMS_BATCH_PROCESS_FP16 * 8
            block_offset = ELEMS_BATCH_PROCESS_FP16 * 8 * core_num_one_more + \
                shape_each_core * block_index
            if num_remain_by_128:
                with ir_builder.if_scope(block_index == target_core_num - 1):
                    shape_each_core += num_remain_by_128 * 8
            # 0: loop_for_ub  1: loop_for_128
            # 2: remain_data_ub (after tiling by UB max process elements)
            # 3: remain_ele
            loops_remains = [
                int(shape_each_core) // platform_paras[0],
                platform_paras[0] // ELEMS_BATCH_PROCESS_FP16,
                int(shape_each_core) % platform_paras[0],
                0
            ]
            _do_operation(ir_builder, place_holders, platform_paras,
                          loops_remains, const_1, block_offset,
                          shape_each_core, num_remain_by_128, is_not_align)
    return ir_builder.get()
def strided_slice_d(input_x, output_x, begin, end, strides=None, begin_mask=0,
                    end_mask=0, ellipsis_mask=0, new_axis_mask=0,
                    shrink_axis_mask=0, kernel_name="strided_slice_d"):
    """
    Extracts a strided slice of a tensor (generalized python array indexing).
    Roughly speaking, this op extracts a slice of size (end-begin)/stride
    from the given input_ tensor. Starting at the location specified by
    begin, the slice continues by adding stride to the index until all
    dimensions are not less than end. Note that a stride can be negative,
    which causes a reverse slice.

    Parameters
    ----------
    input_x : dict
        shape and dtype of input
    output_x : dict
        shape and dtype of out
    begin: list
        represents the index of the first value to select.
    end: list
        represents the index of the last value to select.
    strides: list or tuple
        step length to select.
    begin_mask: int
        a bitmask where a bit i being 1 means to ignore the begin value and
        instead use the largest interval possible.
    end_mask: int
        analogous to `begin_mask`.
    ellipsis_mask: int
        a bitmask where bit `i` being 1 means the `i`th position is actually
        an ellipsis.
    new_axis_mask: int
        a bitmask where bit `i` being 1 means the `i`th specification creates
        a new shape 1 dimension.
    shrink_axis_mask: int
        a bitmask where bit `i` implies that the `i`th specification should
        shrink the dimensionality.
    kernel_name : str
        cce kernel name, default value is "strided_slice_d"

    Returns
    -------
    None
    """
    input_shape = input_x.get("shape")
    input_dtype = input_x.get("dtype").lower()
    check_list = ("float16", "float32", "int32", "uint8", "bool", "int8")
    check_dtype(input_dtype, check_list, param_name="input_x")
    check_shape(input_shape, param_name="input_x")

    begin = list(begin)
    end = list(end)

    if not _check_parameter(input_shape, begin, end, strides, ellipsis_mask,
                            new_axis_mask, shrink_axis_mask):
        raise RuntimeError("Parameter Invalid!")

    if strides is None:
        strides = _fill_list_with_ones(len(input_shape))
    else:
        strides = list(strides)

    input_tensor = tvm.placeholder(input_shape, dtype=input_dtype,
                                   name='input_tensor')
    [output, out_shape] = strided_slice_d_compute(
        input_tensor, output_x, begin, end, strides, begin_mask, end_mask,
        ellipsis_mask, new_axis_mask, shrink_axis_mask,
        kernel_name=kernel_name)

    # pylint: disable=locally-disabled,unnecessary-lambda
    out_tensor = tvm.compute(out_shape, lambda *i: output(*i),
                             name='out_tensor', tag='strided_slice_d|3')

    input_size = functools_reduce(lambda x, y: x * y, input_shape[0:])
    out_size = functools_reduce(lambda x, y: x * y, out_shape[0:])
    output_dtype = output_x.get("dtype").lower()
    output_shape = output_x.get("shape")
    if input_size == out_size:
        if output_dtype == "bool":
            input_x["dtype"] = "int8"
            output_x["dtype"] = "int8"
        if len(output_shape) == 0:
            output_x["shape"] = (1, )
        copy_only(input_x, output_x, kernel_name)
        return

    output_shape_one = list(output_shape)
    if ellipsis_mask == 0 and shrink_axis_mask != 0:
        for i, _ in enumerate(list(input_shape)):
            if (shrink_axis_mask & 2**i) == 2**i:
                output_shape_one.insert(i, 1)
    output_shape = tuple(output_shape_one)

    # for RL tune getting res
    fusion_manager.set_op_res(out_tensor)

    ret, sch = rl_bank.query_rl_bank([out_tensor])
    if ret and sch:
        with build_config:
            tvm.build(sch, [input_tensor, out_tensor], "cce",
                      name=kernel_name)
        return

    sch = tvm.create_schedule(out_tensor.op)
    sch[output].set_scope(tbe_platform.scope_ubuf)

    sch_input_shape = []
    for dim in output.shape:
        sch_input_shape.append(dim.value)

    check_result = _check_last_axis_situation(sch_input_shape, begin, end,
                                              strides)
    if check_result:
        _schedule_last_axis(sch, sch_input_shape, output, out_tensor,
                            input_dtype)
        with build_config:
            tvm.build(sch, [input_tensor, out_tensor], "cce",
                      name=kernel_name)
        return

    if _check_tik_branch(input_shape, output_shape, begin, end, strides):
        begin_shape = copy.deepcopy(begin)
        end_shape = copy.deepcopy(end)
        stride_shape = list(strides)
        stride_shape = copy.deepcopy(stride_shape)
        input_list = list(input_shape)
        # update begin_shape, end_shape
        begin_shape, end_shape, stride_shape = _init_parameter(
            input_list, begin_shape, end_shape, stride_shape, begin_mask,
            end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)

        head_size = 1
        for i in range(0, (len(input_shape) - 1)):
            head_size = head_size * input_shape[i]
        if input_dtype == "float32" and input_shape[-1] == 2 and \
                begin_shape[len(begin_shape) - 1] == 0 and \
                end_shape[len(begin_shape) - 1] == 1 and head_size > 128:
            strided_slice_two_turn_one(input_x, output_x, kernel_name)
            return
        if input_list[-1] > 80 and output_shape[-1] == 80:
            res1 = strided_slice_last_dim_only(input_shape, input_dtype,
                                               output_shape, begin_shape,
                                               kernel_name)
            if res1:
                return
        if input_list[-1] >= 32 and input_list[-1] < 7500 and \
                len(output_shape) > 1 and output_shape[-1] >= 32:
            res = strided_slice_last_dim_mte(input_shape, input_dtype,
                                             output_shape, begin_shape,
                                             kernel_name)
            if res:
                return
        res = strided_slice_last_dim(input_shape, input_dtype, output_shape,
                                     begin_shape, end_shape, stride_shape,
                                     kernel_name)
        if res:
            return
        else:
            res1 = strided_slice_last_dim_one(input_shape, input_dtype,
                                              output_shape, begin_shape,
                                              kernel_name)
            if res1:
                return

    split_axis, split_factor = _tilling_axis(out_shape, dtype=input_dtype)
    core_state = _get_multicore(out_shape, input_dtype, split_axis,
                                split_factor)
    axis_outer, axis_inner = sch[out_tensor].split(
        out_tensor.op.axis[split_axis], factor=split_factor)

    if split_axis == 0:
        core_num = _get_target_core_num(out_shape[split_axis] // split_factor)
        axis_outer_outer, axis_outer_inter = sch[out_tensor].split(
            axis_outer, nparts=core_num)
    else:
        core_num = _get_target_core_num(out_shape[0])
        axis_outer_outer, axis_outer_inter = sch[out_tensor].split(
            out_tensor.op.axis[0], nparts=core_num)
        for i in range(1, split_axis):
            axis_outer_inter = sch[out_tensor].fuse(axis_outer_inter,
                                                    out_tensor.op.axis[i])
        axis_outer_inter = sch[out_tensor].fuse(axis_outer_inter, axis_outer)

    sch[output].compute_at(sch[out_tensor], axis_outer_inter)
    sch[output].emit_insn(output.op.axis[0], insn_cmd.DMA_COPY)  # gm-ub

    if len(out_shape) >= 2:
        # convert bits to Bytes
        dtype_bytes_size = tbe_platform.cce_intrin.get_bit_len(
            input_dtype) // 8
        # 32 means one block size (32 Bytes); divide by the dtype size to get
        # the number of elements that can be stored in one block.
        element = 32 // dtype_bytes_size
        align_axis = _get_align_axis(out_shape)
        sch[output].storage_align(output.op.axis[align_axis], element, 0)

    if core_state:
        thread_block = tvm.thread_axis("blockIdx.x")
        sch[out_tensor].bind(axis_outer_outer, thread_block)

    sch[out_tensor].emit_insn(axis_inner, insn_cmd.DMA_COPY)  # ub-gm

    with build_config:
        tvm.build(sch, [input_tensor, out_tensor], "cce", name=kernel_name)
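# --- Hedged sketch of the shrink_axis_mask handling above: re-inserting the
# shrunk dimensions as size-1 axes (pure Python, assumed values).
def _shrink_axis_demo():
    input_shape = (8, 16, 32)
    output_shape = [16, 32]     # axis 0 was shrunk away
    shrink_axis_mask = 0b001    # bit i set -> axis i was shrunk
    for i, _ in enumerate(input_shape):
        if shrink_axis_mask & 2**i:
            output_shape.insert(i, 1)
    return tuple(output_shape)  # (1, 16, 32)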
def avg_pool3d_schedule(res, sch, ksize, strides):
    """
    avg_pool3d schedule

    Parameters
    ----------
    res: last tensor of compute
    sch: schedule object
    ksize: kernel size of the pooling window
    strides: strides of the pooling window

    Returns
    -------
    None
    """
    res_cast = res.op.input_tensors[0]
    tensor_a = res_cast.op.input_tensors[0]
    tensor_d_hw = tensor_a.op.input_tensors[0]
    tensor_in_ub_cast = tensor_d_hw.op.input_tensors[0]
    tensor_in_ub = tensor_in_ub_cast.op.input_tensors[0]
    input_shape = [int(i) for i in tensor_in_ub.shape]

    # set scope
    sch[tensor_in_ub].set_scope(tbe_platform.scope_ubuf)
    sch[tensor_in_ub_cast].set_scope(tbe_platform.scope_ubuf)
    sch[tensor_d_hw].set_scope(tbe_platform.scope_ubuf)
    sch[tensor_a].set_scope(tbe_platform.scope_ubuf)
    sch[res_cast].set_scope(tbe_platform.scope_ubuf)

    core_num = tbe_platform.cce_conf.get_soc_spec(
        tbe_platform.cce_conf.CORE_NUM)

    ax_res_n = res.op.axis[0]
    ax_res_do = res.op.axis[1]
    ax_res_c1 = res.op.axis[2]
    ax_res_hw = res.op.axis[3]
    ax_res_c0 = res.op.axis[4]
    ax_dhw_n = tensor_d_hw.op.axis[0]
    ax_dhw_do = tensor_d_hw.op.axis[1]
    ax_dhw_c1 = tensor_d_hw.op.axis[2]
    ax_dhw_hw = tensor_d_hw.op.axis[3]
    ax_dhw_c0 = tensor_d_hw.op.axis[4]
    ax_dhw_rd = tensor_d_hw.op.reduce_axis[0]
    ax_dhw_rhw = tensor_d_hw.op.reduce_axis[1]

    factor_c1, factor_dout, factor_reduce_d, factor_reduce_hw = _tiling_param(
        input_shape, ksize, strides, core_num)

    reduce_hw_o, reduce_hw_i = sch[tensor_d_hw].split(
        ax_dhw_rhw, factor=factor_reduce_hw)
    reduce_d_o, reduce_d_i = sch[tensor_d_hw].split(
        ax_dhw_rd, factor=factor_reduce_d)
    dhw_do_o, dhw_do_i = sch[tensor_d_hw].split(ax_dhw_do,
                                                factor=factor_dout)
    sch[tensor_d_hw].reorder(ax_dhw_n, ax_dhw_c1, ax_dhw_hw, dhw_do_o,
                             reduce_d_o, reduce_hw_o, dhw_do_i, reduce_d_i,
                             reduce_hw_i, ax_dhw_c0)

    ax_res_c1_o, ax_res_c1_i = sch[res].split(ax_res_c1, factor=factor_c1)
    ax_res_do_o, ax_res_do_i = sch[res].split(ax_res_do, factor=factor_dout)
    sch[res].reorder(ax_res_n, ax_res_c1_o, ax_res_do_o, ax_res_c1_i,
                     ax_res_do_i, ax_res_hw, ax_res_c0)

    ax_fused = sch[res].fuse(ax_res_n, ax_res_c1_o)
    block = tvm.thread_axis("blockIdx.x")
    sch[res].bind(ax_fused, block)

    sch[tensor_in_ub].compute_at(sch[tensor_d_hw], reduce_hw_o)
    sch[tensor_in_ub_cast].compute_at(sch[tensor_d_hw], reduce_hw_o)
    sch[res_cast].compute_at(sch[res], ax_res_do_o)
    sch[tensor_a].compute_at(sch[res], ax_res_do_o)
    sch[tensor_d_hw].compute_at(sch[res], ax_res_do_o)

    sch[tensor_in_ub].emit_insn(sch[tensor_in_ub].op.axis[0],
                                insn_cmd.DMA_COPY)
    sch[tensor_in_ub_cast].emit_insn(sch[tensor_in_ub_cast].op.axis[0],
                                     insn_cmd.CAST)
    sch[tensor_d_hw].emit_insn(dhw_do_i, insn_cmd.REDUCE_SUM)
    sch[tensor_a].emit_insn(sch[tensor_a].op.axis[0], insn_cmd.MUL)
    sch[res_cast].emit_insn(sch[res_cast].op.axis[0], insn_cmd.CAST)
    sch[res].emit_insn(ax_res_c1_i, insn_cmd.DMA_COPY)
def _max_pool_grad_grad_with_argmax_schedule(compute_list, sch_list):
    """
    Computes second-order gradients of the maxpooling function.

    Parameters
    ----------
    compute_list: list
        all of the results of the maxpooling computation, including
        grad_in_l1, grad_im2col, grad_fractal, grad_fractal_transp,
        argmax_ub, tensor_zero_ub, grad_grad_col, grad_grad, res.
    sch_list: list
        sch of the maxpooling, include sch.

    Returns
    -------
    None
    """
    sch = sch_list[0]
    res = compute_list[-1]
    grad_in_l1 = compute_list[0]
    grad_im2col = compute_list[1]
    grad_fractal = compute_list[2]
    grad_fractal_transp = compute_list[3]
    argmax_ub = compute_list[4]
    tensor_zero_ub = compute_list[5]
    grad_grad_col = compute_list[6]
    grad_grad = compute_list[7]

    setfmatrix_map = res.op.attrs['setfmatrix_dict']
    setfmatrix_dict = {}
    for key, value in setfmatrix_map.items():
        if hasattr(value, "value"):
            setfmatrix_dict[key] = value.value
        else:
            setfmatrix_dict[key] = value

    extract_map = res.op.attrs['extract_params']
    extract_params = {}
    for key, value in extract_map.items():
        if hasattr(value, "value"):
            extract_params[key] = value.value
        else:
            extract_params[key] = value

    padding = extract_params['padding_mode']
    fmap_shape = extract_params['fmap_shape']
    shape_max_pool_h = extract_params['shape_max_pool_h']
    shape_max_pool_w = extract_params['shape_max_pool_w']
    stride_h = setfmatrix_dict["conv_stride_h"]
    stride_w = setfmatrix_dict["conv_stride_w"]
    kernel_h = setfmatrix_dict["conv_kernel_h"]
    kernel_w = setfmatrix_dict["conv_kernel_w"]

    # these calculations are on CB
    sch[grad_in_l1].set_scope(tbe_platform.scope_cbuf)
    sch[grad_im2col].set_scope(tbe_platform.scope_cbuf)
    # these calculations are on UB
    sch[grad_fractal].set_scope(tbe_platform.scope_ubuf)
    sch[argmax_ub].set_scope(tbe_platform.scope_ubuf)
    sch[tensor_zero_ub].set_scope(tbe_platform.scope_ubuf)
    sch[grad_grad_col].set_scope(tbe_platform.scope_ubuf)
    sch[grad_grad].set_scope(tbe_platform.scope_ubuf)

    # compute inline
    sch[grad_fractal_transp].compute_inline()

    # the last axis of the grad_im2col instruction has to be an integer
    # multiple of 16
    sch[grad_grad].buffer_align((1, 1), (1, 1), (1, 1), (1, BLOCK_SIZE),
                                (1, BLOCK_SIZE), (1, 1))
    sch[grad_im2col].buffer_align((1, 1), (1, shape_max_pool_w), (1, 1),
                                  (1, 1), (1, 1), (1, BLOCK_SIZE))

    # get tiling shape value
    max_l1_valid_size = tbe_platform.cce_conf.get_soc_spec(
        tbe_platform.cce_conf.L1_SIZE)
    max_ub_size = tbe_platform.cce_conf.get_soc_spec(
        tbe_platform.cce_conf.UB_SIZE)
    max_next_valid_size = max_ub_size * 16 * kernel_h * kernel_w // \
        (49 * kernel_h * kernel_w + 16)

    is_tiling_valid, shape_in_l1, is_l1_double_buffer, \
        shape_after_load3d, is_l0_ub_double_buffer = \
        get_load3d_tiling(fmap_shape, (kernel_h, kernel_w),
                          (stride_h, stride_w), padding, max_l1_valid_size,
                          max_next_valid_size, "float16")
    if (is_tiling_valid, shape_in_l1, is_l1_double_buffer,
            shape_after_load3d, is_l0_ub_double_buffer) == \
            (False, None, None, None, None):
        raise RuntimeError(
            "Not supported fmap shape = (%u, %u, %u, %u, %u),"
            " kernel = (1, %u, %u, 1),"
            " stride = (1, %u, %u, 1)" %
            (fmap_shape[0], fmap_shape[1], fmap_shape[2], fmap_shape[3],
             fmap_shape[4], kernel_h, kernel_w, stride_h, stride_w))

    _, _, l1_hi, l1_wi, _ = shape_in_l1

    def _get_output_length(l1_hi, l1_wi, stride, kernel_size):
        if fmap_shape[2].value == l1_hi:
            tile_l1_ho = shape_max_pool_h
        else:
            tile_l1_ho = (l1_hi + stride[0] - kernel_size[0]) // stride[0]
        if fmap_shape[3].value == l1_wi:
            tile_l1_wo = shape_max_pool_w
        else:
            tile_l1_wo = (l1_wi + stride[1] - kernel_size[1]) // stride[1]
        return tile_l1_ho, tile_l1_wo

    tile_l1_ho, tile_l1_wo = _get_output_length(l1_hi, l1_wi,
                                                (stride_h, stride_w),
                                                (kernel_h, kernel_w))
    (_, ub_howo, _, ub_khkw, _) = shape_after_load3d

    # tiling
    split_factor_howo = ub_howo

    # cut grad_grad
    grad_grad_n_outer, grad_grad_n_inner = sch[grad_grad].split(
        grad_grad.op.axis[0], factor=1)
    grad_grad_c1_outer, grad_grad_c1_inner = sch[grad_grad].split(
        grad_grad.op.axis[1], factor=1)
    grad_grad_howo_outer, grad_grad_howo_inner = sch[grad_grad].split(
        grad_grad.op.axis[2], factor=(split_factor_howo + 15) // 16)
    grad_grad_k_outer, grad_grad_k_inner = sch[grad_grad].split(
        grad_grad.op.reduce_axis[0], factor=ub_khkw)
    sch[grad_grad].reorder(grad_grad_n_outer, grad_grad_c1_outer,
                           grad_grad_howo_outer, grad_grad_k_outer,
                           grad_grad_n_inner, grad_grad_c1_inner,
                           grad_grad_howo_inner, grad_grad.op.axis[3],
                           grad_grad_k_inner, grad_grad.op.axis[4])

    # cut res
    res_n_outer, res_n_inner = sch[res].split(res.op.axis[0], factor=1)
    res_c1_outer, res_c1_inner = sch[res].split(res.op.axis[1], factor=1)
    # gm->l1
    res_howo_outer, res_howo_inner = sch[res].split(
        res.op.axis[2], factor=(tile_l1_ho * tile_l1_wo))
    # l1->ub
    res_mwo_outer, res_mwo_inner = sch[res].split(res_howo_inner,
                                                  factor=split_factor_howo)
    sch[res].reorder(res_n_outer, res_c1_outer, res_howo_outer,
                     res_mwo_outer, res_n_inner, res_c1_inner, res_mwo_inner,
                     res.op.axis[3])
    res_fused_n_c1_howo_outer = sch[res].fuse(res_n_outer, res_c1_outer,
                                              res_howo_outer)

    core_number = tbe_platform.cce_conf.get_soc_spec(
        tbe_platform.cce_conf.CORE_NUM)

    sch[grad_in_l1].compute_at(sch[res], res_fused_n_c1_howo_outer)
    sch[grad_im2col].compute_at(sch[res], res_fused_n_c1_howo_outer)
    sch[tensor_zero_ub].compute_at(sch[res], res_fused_n_c1_howo_outer)
    sch[grad_fractal].compute_at(sch[grad_grad], grad_grad_k_outer)
    sch[argmax_ub].compute_at(sch[grad_grad], grad_grad_k_outer)
    sch[grad_grad_col].compute_at(sch[grad_grad], grad_grad_k_outer)
    sch[grad_grad].compute_at(sch[res], res_mwo_outer)

    sch[grad_in_l1].emit_insn(grad_in_l1.op.axis[0], insn_cmd.DMA_COPY)
    sch[grad_im2col].emit_insn(grad_im2col.op.axis[0], 'set_fmatrix',
                               setfmatrix_dict)
    sch[grad_fractal].emit_insn(grad_fractal.op.axis[0], insn_cmd.IM2COL)
    sch[argmax_ub].emit_insn(argmax_ub.op.axis[0], insn_cmd.DMA_COPY)
    sch[tensor_zero_ub].emit_insn(tensor_zero_ub.op.axis[0], insn_cmd.DUP)
    sch[grad_grad_col].emit_insn(grad_grad_col.op.axis[0], insn_cmd.SELECT)
    sch[grad_grad].emit_insn(grad_grad_n_inner, insn_cmd.REDUCE_SUM)
    sch[res].emit_insn(res_n_inner, insn_cmd.DMA_COPY)

    # for double buffer
    if is_l0_ub_double_buffer:
        sch[grad_fractal].double_buffer()
        sch[argmax_ub].double_buffer()
        sch[grad_grad_col].double_buffer()
        sch[grad_grad].double_buffer()
        sch[grad_im2col].double_buffer()
    if is_l1_double_buffer:
        sch[grad_in_l1].double_buffer()

    # for multi cores
    block = tvm.thread_axis("blockIdx.x")
    sch[res].bind(res_fused_n_c1_howo_outer, block)
def _strided_slice_assign_schedule(schedule_list, out, input_value_shape,
                                   input_shape, data_dtype):
    """
    strided_slice_assign schedule function

    Parameters
    ----------
    schedule_list : tuple
        include tvm.tensor of ref and tvm.tensor of input_value in ub
    out : tvm.tensor
        tvm.tensor of out
    input_value_shape : list
        input_value shape
    input_shape : list
        ref shape
    data_dtype : str
        input data dtype

    Returns
    -------
    sch : tvm.schedule
        the compute schedule
    """
    dtype_bytes_size = tbe_platform.cce_intrin.get_bit_len(data_dtype) // 8
    # one block is 32 Bytes; divide by the dtype size to get the number of
    # elements that can be stored in one block.
    one_block_bytes_size = cce_params.VECTOR_INST_BLOCK_WIDTH // \
        cce_params.VECTOR_INST_BLOCK_NUM
    element = one_block_bytes_size // dtype_bytes_size
    input_tensor = schedule_list[0]
    input_value_ub = schedule_list[1]
    sch = tvm.create_schedule(out.op)
    sch[input_value_ub].set_scope(tbe_platform.scope_ubuf)

    split_axis, split_factor, not_storage_align = _tilling_axis(
        input_value_shape, dtype=data_dtype)
    if not_storage_align:
        split_factor = input_shape[0]
    axis_outer, axis_inner = sch[out].split(out.op.axis[split_axis],
                                            factor=split_factor)

    # multi core
    device_core_num = tbe_platform.cce_conf.get_soc_spec(
        tbe_platform.cce_conf.CORE_NUM)
    if split_axis == 0:
        init_core_num = (input_value_shape[0] + split_factor - 1) \
            // split_factor
        if init_core_num > device_core_num:
            forward_axis_outer, forward_axis_inner = sch[out].split(
                axis_outer, nparts=device_core_num)
            sch[out].bind(forward_axis_outer, tvm.thread_axis('blockIdx.x'))
            compute_at_axis = forward_axis_inner
        else:
            sch[out].bind(axis_outer, tvm.thread_axis('blockIdx.x'))
            compute_at_axis = axis_outer
    else:
        all_move_count = input_value_shape[0]
        if all_move_count > device_core_num:
            fused_axis_outer, _ = sch[out].split(sch[out].op.axis[0],
                                                 nparts=device_core_num)
            sch[out].bind(fused_axis_outer, tvm.thread_axis('blockIdx.x'))
        else:
            sch[out].bind(sch[out].op.axis[0], tvm.thread_axis('blockIdx.x'))
        compute_at_axis = axis_outer

    # compute_at
    sch[input_value_ub].compute_at(sch[out], compute_at_axis)

    # storage_align
    if not_storage_align:
        pass
    elif len(input_value_shape) == 1:
        sch[input_value_ub].storage_align(compute_at_axis, element, 0)
    else:
        sch[input_value_ub].storage_align(input_value_ub.op.axis[-2],
                                          element, 0)

    # emit insn
    sch[input_value_ub].emit_insn(input_value_ub.op.axis[split_axis],
                                  insn_cmd.DMA_COPY)
    sch[out].emit_insn(axis_inner, insn_cmd.DMA_COPY, {"no_overlap": 1})
    sch[input_value_ub].double_buffer()
    return sch
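# --- Hedged sketch of the block-alignment arithmetic above: one 32-byte
# block holds 32 // dtype_size elements (dtype sizes below are assumptions
# for common types).
def _elements_per_block_demo(dtype="float16"):
    dtype_bytes = {"float16": 2, "float32": 4, "int8": 1}[dtype]
    one_block_bytes_size = 256 // 8   # VECTOR_INST_BLOCK_WIDTH / BLOCK_NUM
    return one_block_bytes_size // dtype_bytes   # 16 for float16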
def _assign_sub_schedule(schedule_list, res, shape, dtype, data_a):
    """
    assign_sub schedule function

    Parameters
    ----------
    schedule_list : list
        list of tensors for schedule.
    res : tvm.tensor
        tensor of result
    shape : list or tuple
        shape of ref and value.
    dtype : str
        the type of ref and value.
    data_a : tvm.tensor
        tensor of ref, reused by the result.

    Returns
    -------
    sch: tvm.schedule
        the compute schedule
    """
    # list of tensors for 'elewise_single_cast'
    cast_list = (schedule_list[INDEX_TWO], schedule_list[INDEX_THREE],
                 schedule_list[INDEX_FIVE])
    sch = tvm.create_schedule(res.op)
    for cal_res in schedule_list:
        sch[cal_res].set_scope(cce.scope_ubuf)
    for cal_res in schedule_list:
        sch[cal_res].double_buffer()

    # choose an appropriate method of tiling the tensor
    split_axis, split_factor = _tilling_axis(shape, dtype=dtype)
    axis_outer, axis_inner = sch[res].split(res.op.axis[split_axis],
                                            factor=split_factor)
    out_extent = (int(res.shape[0]) + split_factor - 1) // split_factor

    # if out extent > 1, bind to multi-core thread axis
    if out_extent > 1:
        block_index = tvm.thread_axis('blockIdx.x')
        if out_extent > cce.CceProductParams().getParams("Device_core_num"):
            thread_axis, axis_outer = sch[res].split(
                axis_outer,
                nparts=cce.CceProductParams().getParams("Device_core_num"))
            sch[res].bind(thread_axis, block_index)
        else:
            sch[res].bind(axis_outer, block_index)

    # compute_at
    for cal_res in schedule_list:
        sch[cal_res].compute_at(sch[res], axis_outer)

    # rewrite the variable
    sch[data_a].reused_by(res)
    sch[schedule_list[INDEX_ZERO]].reused_by(schedule_list[INDEX_FIVE])

    sch[schedule_list[INDEX_ZERO]].emit_insn(
        schedule_list[INDEX_ZERO].op.axis[split_axis], 'dma_copy')
    sch[schedule_list[INDEX_ONE]].emit_insn(
        schedule_list[INDEX_ONE].op.axis[split_axis], 'dma_copy')
    if dtype in ("int8", "uint8"):
        for cal_res in cast_list:
            sch[cal_res].emit_insn(cal_res.op.axis[split_axis],
                                   'elewise_single_cast')
        sch[schedule_list[INDEX_TWO]].reused_by(schedule_list[INDEX_FOUR])
    sch[schedule_list[INDEX_FOUR]].emit_insn(
        schedule_list[INDEX_FOUR].op.axis[split_axis], 'elewise_binary_sub')
    sch[res].emit_insn(axis_inner, 'dma_copy')
    return sch
def extract_image_patches_schedule(res, sch_list):
    """
    :param res: the multi-results in the operator
    :param sch_list: schedule list
    """
    sch = sch_list[0]
    setfmatrix_map = res.op.attrs['setfmatrix_dict']
    setfmatrix_dict = {}
    for key, value in setfmatrix_map.items():
        if hasattr(value, "value"):
            setfmatrix_dict[key] = value.value
        else:
            setfmatrix_dict[key] = value

    extract_map = res.op.attrs['extract_params']
    extract_params = {}
    for key, value in extract_map.items():
        if hasattr(value, "value"):
            extract_params[key] = value.value
        else:
            extract_params[key] = value

    out_h = extract_params['out_h']
    out_w = extract_params['out_w']
    fmap_shape = extract_params['fmap_shape']
    c_in_real = extract_params["c_in_real"]
    fmap_n = fmap_shape[0].value
    fmap_c1 = fmap_shape[1].value
    fmap_h = fmap_shape[2].value
    fmap_w = fmap_shape[3].value
    fmap_c0 = fmap_shape[4].value
    kernel_h = setfmatrix_dict['conv_kernel_h']
    kernel_w = setfmatrix_dict['conv_kernel_w']
    dilate_h = setfmatrix_dict['conv_dilation_h']
    dilate_w = setfmatrix_dict['conv_dilation_w']
    stride_h = setfmatrix_dict['conv_stride_h']
    stride_w = setfmatrix_dict['conv_stride_w']

    ub_res = res.op.input_tensors[0]
    workspace_res = ub_res.op.input_tensors[0]
    ub_merge_co = workspace_res.op.input_tensors[0]
    ub_merge_hw = ub_merge_co.op.input_tensors[0]
    ub_transpose = ub_merge_hw.op.input_tensors[0]
    ub_split_c1 = ub_transpose.op.input_tensors[0]
    fmap_fractal = ub_split_c1.op.input_tensors[0]
    fmap_im2col = fmap_fractal.op.input_tensors[0]
    fmap_in_l1 = fmap_im2col.op.input_tensors[0]

    sch[fmap_in_l1].set_scope(tbe_platform.scope_cbuf)
    sch[fmap_im2col].set_scope(tbe_platform.scope_cbuf)
    sch[fmap_fractal].set_scope(tbe_platform.scope_ubuf)
    sch[ub_split_c1].set_scope(tbe_platform.scope_ubuf)
    sch[ub_transpose].set_scope(tbe_platform.scope_ubuf)
    sch[ub_merge_hw].set_scope(tbe_platform.scope_ubuf)
    sch[ub_merge_co].set_scope(tbe_platform.scope_ubuf)
    sch[workspace_res].set_scope(tbe_platform.scope_gm)
    sch[ub_res].set_scope(tbe_platform.scope_ubuf)

    dtype_input = ub_res.dtype
    if dtype_input in ("int8", "uint8"):
        BLOCK_SIZE_ALIGN = BLOCK_SIZE_INT8
        type_size = INT8_SIZE
    else:
        BLOCK_SIZE_ALIGN = BLOCK_SIZE
        type_size = FP16_SIZE

    out_hw_up16 = ((out_h * out_w - 1) // BLOCK_SIZE + 1) * BLOCK_SIZE
    dilated_kernel_h = (kernel_h - 1) * dilate_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilate_w + 1
    lcm_out_w = BLOCK_SIZE // math.gcd(out_w, BLOCK_SIZE) * out_w
    cut_h_col = (BLOCK_SIZE // math.gcd(out_w, BLOCK_SIZE) - 1) * stride_h \
        + 1 + dilated_kernel_h // 2
    if cut_h_col > fmap_h:
        cut_h_col = fmap_h
    # cut_h_col while cut_hw = BLOCK_SIZE
    cut_w_row_s = (BLOCK_SIZE - 1) * stride_w + 1
    cut_h_row_s = (((cut_w_row_s - 1) // fmap_w + 1) - 1) * stride_h + 1
    cut_w_row = cut_w_row_s + dilated_kernel_w - 1
    cut_h_row = cut_h_row_s + dilated_kernel_h - 1
    if lcm_out_w > out_hw_up16:
        lcm_out_w = out_hw_up16
    extract_params['lcm_out_w'] = lcm_out_w
    extract_params['cut_h_col'] = cut_h_col
    extract_params['cut_w_row'] = cut_w_row
    extract_params['cut_h_row'] = cut_h_row
    extract_params['dilated_kernel_h'] = dilated_kernel_h
    extract_params['dilated_kernel_w'] = dilated_kernel_w

    sch[ub_res].buffer_align((1, 1), (1, 1), (1, 1), (1, BLOCK_SIZE_ALIGN))
    sch[fmap_im2col].buffer_align((1, 1), (out_w, out_w), (1, 1), (1, 1),
                                  (1, 1), (1, BLOCK_SIZE_ALIGN))
    sch[fmap_fractal].buffer_align((1, 1), (1, 1), (1, 1), (1, BLOCK_SIZE),
                                   (1, BLOCK_SIZE_ALIGN))

    used_ub_size = UB_SIZE // type_size // DOUBLE_BUFFER
    avg_split_ub_size = used_ub_size // NEED_UB_SPACE_NUM
    howo = out_h * out_w
    khkw = kernel_h * kernel_w
    c_out = khkw * fmap_c1 * fmap_c0

    if c_in_real % BLOCK_SIZE_ALIGN == 0:
        n_factor = 1
        howo_factor = howo
        khkw_factor = khkw
        c_factor = c_in_real
        max_v = fmap_c1
        max_v_cut_col, max_v_cut_row, max_v_cut_col_p, max_v_cut_min, \
            move_rate_cut_col, move_rate_cut_row, move_rate_cut_col_p = \
            get_tiling_param(setfmatrix_dict, extract_params, used_ub_size,
                             type_size, avg_split_ub_size)
        move_rate = move_rate_cut_col
        if move_rate < move_rate_cut_row:
            move_rate = move_rate_cut_row
        if move_rate < move_rate_cut_col_p:
            move_rate = move_rate_cut_col_p
        if lcm_out_w * c_out <= avg_split_ub_size \
                and khkw * fmap_c1 <= LOAD3D_REPEAT_TIME_LIMIT:
            max_v = avg_split_ub_size // lcm_out_w // c_out
            if lcm_out_w * max_v < howo:
                # if True cut n and howo, else only cut n
                howo_factor = lcm_out_w * max_v
        elif move_rate == move_rate_cut_col and max_v_cut_col > 0:
            # cut howo col
            howo_factor = lcm_out_w
            max_v = max_v_cut_col
            khkw_factor = 1
            c_factor = BLOCK_SIZE_ALIGN * max_v
        elif move_rate == move_rate_cut_row and max_v_cut_row > 0:
            # cut howo row
            howo_factor = BLOCK_SIZE
            khkw_factor = khkw
            max_v = max_v_cut_row
            c_factor = BLOCK_SIZE_ALIGN * max_v
        elif move_rate == move_rate_cut_col_p and max_v_cut_col_p > 0:
            # cut howo col partially
            howo_factor = BLOCK_SIZE * max_v_cut_col_p
            c_factor = c_in_real
            khkw_factor = khkw
            max_v = fmap_c1
        else:
            # cut howo khkw c
            howo_factor = BLOCK_SIZE
            max_v = max_v_cut_min
            khkw_factor = 1
            c_factor = BLOCK_SIZE_ALIGN * max_v

        device_core_num = tbe_platform.cce_conf.get_soc_spec(
            tbe_platform.cce_conf.CORE_NUM)
        res_n_inner_outer, res_n_inner = sch[res].split(res.op.axis[0],
                                                        factor=n_factor)
        res_n_outer_outer, res_n_outer = sch[res].split(
            res_n_inner_outer, nparts=device_core_num)
        res_howo_outer, res_howo_inner = sch[res].split(res.op.axis[1],
                                                        factor=howo_factor)
        res_khkw_outer, res_khkw_inner = sch[res].split(res.op.axis[2],
                                                        factor=khkw_factor)
        res_c_inner_outer, res_c_inner = sch[res].split(
            res.op.axis[3], factor=BLOCK_SIZE_ALIGN)
        res_c_outer, res_c_outer_inner = sch[res].split(
            res_c_inner_outer, factor=c_factor // BLOCK_SIZE_ALIGN)
        sch[res].reorder(res_n_outer_outer, res_n_outer, res_howo_outer,
                         res_khkw_outer, res_c_outer, res_n_inner,
                         res_c_outer_inner, res_howo_inner, res_khkw_inner,
                         res_c_inner)

        if L1_SIZE >= fmap_h * fmap_w * fmap_c0 * fmap_c1 * \
                type_size * DOUBLE_BUFFER:
            sch[fmap_im2col].compute_at(sch[res], res_n_outer)
            sch[fmap_in_l1].compute_at(sch[res], res_n_outer)
        elif L1_SIZE >= cut_h_row * fmap_w * fmap_c0 * fmap_c1 * type_size * \
                DOUBLE_BUFFER and move_rate != move_rate_cut_col:
            sch[fmap_im2col].compute_at(sch[res], res_howo_outer)
            sch[fmap_in_l1].compute_at(sch[res], res_howo_outer)
        elif L1_SIZE >= cut_h_col * fmap_w * fmap_c0 * fmap_c1 * \
                type_size * DOUBLE_BUFFER and move_rate == move_rate_cut_col:
            sch[fmap_im2col].compute_at(sch[res], res_howo_outer)
            sch[fmap_in_l1].compute_at(sch[res], res_howo_outer)
        else:
            sch[fmap_im2col].compute_at(sch[res], res_c_outer)
            sch[fmap_in_l1].compute_at(sch[res], res_c_outer)

        sch[ub_transpose].compute_at(sch[res], res_c_outer)
        sch[fmap_fractal].compute_at(sch[res], res_c_outer)
        sch[workspace_res].compute_inline()
        sch[ub_res].compute_inline()
        sch[ub_merge_co].compute_inline()
        sch[ub_merge_hw].compute_inline()
        sch[ub_split_c1].compute_inline()

        block = tvm.thread_axis("blockIdx.x")
        sch[res].bind(res_n_outer_outer, block)

        sch[fmap_in_l1].emit_insn(fmap_in_l1.op.axis[0], insn_cmd.DMA_COPY)
        sch[fmap_im2col].emit_insn(fmap_im2col.op.axis[0],
                                   insn_cmd.SET_FMATRIX, setfmatrix_dict)
        sch[fmap_fractal].emit_insn(fmap_fractal.op.axis[0], insn_cmd.IM2COL)
        sch[ub_transpose].emit_insn(ub_transpose.op.axis[0],
                                    insn_cmd.DMA_COPY)
        sch[res].emit_insn(res_n_inner, insn_cmd.DMA_COPY)
    else:
        c1_factor = BLOCK_SIZE_ALIGN
        res_n_outer, res_n_inner = sch[res].split(res.op.axis[0], factor=1)
        res_c1_outer, res_c1_inner = sch[res].split(res.op.axis[3],
                                                    factor=c_in_real)
        sch[ub_res].compute_at(sch[res], res_c1_outer)

        workspace_res_n_outer, workspace_res_n_inner = sch[
            workspace_res].split(workspace_res.op.axis[0], factor=1)
        workspace_res_howo_outer, workspace_res_howo_inner = sch[
            workspace_res].split(workspace_res.op.axis[1], factor=lcm_out_w)
        workspace_res_khkw_outer, workspace_res_khkw_inner = sch[
            workspace_res].split(workspace_res.op.axis[2], factor=1)
        workspace_res_c1_outer, workspace_res_c1_inner = sch[
            workspace_res].split(workspace_res.op.axis[3], factor=c1_factor)
        sch[workspace_res].reorder(
            workspace_res_n_outer, workspace_res_howo_outer,
            workspace_res_khkw_outer, workspace_res_c1_outer,
            workspace_res_n_inner, workspace_res_howo_inner,
            workspace_res_khkw_inner, workspace_res_c1_inner)

        sch[ub_merge_co].compute_at(sch[workspace_res],
                                    workspace_res_c1_outer)
        sch[ub_merge_hw].compute_at(sch[workspace_res],
                                    workspace_res_c1_outer)
        sch[ub_transpose].compute_at(sch[workspace_res],
                                     workspace_res_c1_outer)
        sch[ub_split_c1].compute_at(sch[workspace_res],
                                    workspace_res_c1_outer)
        sch[fmap_fractal].compute_at(sch[workspace_res],
                                     workspace_res_c1_outer)
        sch[fmap_im2col].compute_at(sch[workspace_res],
                                    workspace_res_howo_outer)
        sch[fmap_in_l1].compute_at(sch[workspace_res],
                                   workspace_res_howo_outer)

        if c_in_real > BLOCK_SIZE_ALIGN:
            sch[workspace_res].compute_at(sch[res], res_n_outer)
            block = tvm.thread_axis("blockIdx.x")
            sch[res].bind(res_n_outer, block)
            sch[ub_split_c1].compute_inline()
            sch[ub_transpose].compute_inline()
            sch[ub_merge_co].compute_inline()

        sch[fmap_in_l1].emit_insn(fmap_in_l1.op.axis[0], insn_cmd.DMA_COPY)
        sch[fmap_im2col].emit_insn(fmap_im2col.op.axis[0],
                                   insn_cmd.SET_FMATRIX, setfmatrix_dict)
        sch[fmap_fractal].emit_insn(fmap_fractal.op.axis[0], insn_cmd.IM2COL)
        sch[ub_split_c1].emit_insn(ub_split_c1.op.axis[0], insn_cmd.DMA_COPY)
        sch[ub_transpose].emit_insn(ub_transpose.op.axis[0],
                                    insn_cmd.DMA_COPY)
        sch[ub_merge_hw].emit_insn(ub_merge_hw.op.axis[0], insn_cmd.DMA_COPY)
        sch[ub_merge_co].emit_insn(ub_merge_co.op.axis[0], insn_cmd.DMA_COPY)
        sch[workspace_res].emit_insn(workspace_res_c1_inner,
                                     insn_cmd.DMA_COPY)
        sch[ub_res].emit_insn(ub_res.op.axis[3], insn_cmd.DMA_COPY)
        sch[res].emit_insn(res_c1_inner, insn_cmd.DMA_COPY,
                           {"no_overlap": 1})

        sch[fmap_in_l1].double_buffer()
        sch[fmap_im2col].double_buffer()
        sch[fmap_fractal].double_buffer()
        sch[ub_transpose].double_buffer()
        sch[ub_res].double_buffer()
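# --- Hedged sketch of the cut-size arithmetic used above: lcm_out_w is the
# least common multiple of out_w and BLOCK_SIZE, and cut_h_col estimates the
# input rows needed to produce it (assumed values below).
def _cut_size_demo(out_w=7, stride_h=2, dilated_kernel_h=3, block_size=16):
    import math
    # lcm(out_w, block_size) expressed via gcd, as in the schedule above
    lcm_out_w = block_size // math.gcd(out_w, block_size) * out_w   # 112
    cut_h_col = (block_size // math.gcd(out_w, block_size) - 1) * stride_h \
        + 1 + dilated_kernel_h // 2
    return lcm_out_w, cut_h_col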
def basic_lstm_cell_schedule(tensor_list, scope_list, operation_list,
                             build_list, product_info, tilling_info,
                             kernel_name):
    """do the schedule for the LSTM compute."""
    ht = tensor_list["ht"]
    schedule_list = [ht.op]
    s = tvm.create_schedule(schedule_list)

    for key in tensor_list.keys():
        if key in scope_list.keys():
            s[tensor_list[key]].set_scope(scope_list[key])
        if key in operation_list.keys():
            s[tensor_list[key]].emit_insn(s[tensor_list[key]].op.axis[0],
                                          operation_list[key])

    s[tensor_list["tensor_xh_l1_ot"]].reused_by(
        tensor_list["tensor_xh_l1_it"], tensor_list["tensor_xh_l1_ft"],
        tensor_list["tensor_xh_l1_jt"])
    s[tensor_list["tensor_xh_l0a_ot"]].reused_by(
        tensor_list["tensor_xh_l0a_it"], tensor_list["tensor_xh_l0a_ft"],
        tensor_list["tensor_xh_l0a_jt"])

    # handle matmul info
    mad_pattern = cce.GEMM_MODE

    # split matmul
    symbol = ["ot", "it", "jt", "ft"]
    l1_factor = tilling_info["l1_factor"]
    for t in symbol:
        s[tensor_list["tensor_b_loc_" + t]].reused_by(
            tensor_list["tensor_matmul_l0c_" + t],
            tensor_list["tensor_matmul_result_l0c_" + t])
        tmp = tensor_list["tensor_matmul_l0c_" + t]
        block_n_o, block_n_i = s[tmp].split(
            tmp.op.axis[1], factor=tilling_info["block_n_factor"])
        block_out_o, block_out_i = s[tmp].split(
            tmp.op.axis[0], factor=tilling_info["block_out_factor"])
        l1_n_outer, l1_n_inner = s[tmp].split(
            block_n_i, factor=tilling_info["n_factor"])  # safe
        l1_out_outer, l1_out_inner = s[tmp].split(
            block_out_i, factor=tilling_info["out_factor"])
        l1_k_outer, l1_k_inner = s[tmp].split(
            tmp.op.reduce_axis[0], factor=tilling_info["k_factor"])
        l0_n_outer, l0_n_inner = s[tmp].split(
            l1_n_inner, factor=tilling_info["n_factor"])
        l0_out_outer, l0_out_inner = s[tmp].split(
            l1_out_inner, factor=tilling_info["out_factor"])
        l0_k_outer, l0_k_inner = s[tmp].split(
            l1_k_inner, factor=tilling_info["k_factor"])
        s[tmp].reorder(block_n_o, block_out_o, l1_n_outer, l1_out_outer,
                       l1_k_outer, l0_n_outer, l0_out_outer, l0_k_outer,
                       l0_n_inner, l0_out_inner, tmp.op.axis[2],
                       tmp.op.axis[3], l0_k_inner, tmp.op.reduce_axis[1])

        s[tensor_list["tensor_xh_l0a_" + t]].compute_at(s[tmp], l0_k_outer)
        s[tensor_list["tensor_w_l0b_" + t]].compute_at(s[tmp], l0_k_outer)
        if l1_factor != 1:
            s[tensor_list["tensor_xh_l1_" + t]].split(
                s[tensor_list["tensor_xh_l1_" + t]].op.axis[1],
                factor=l1_factor)
        s[tensor_list["tensor_xh_l1_" + t]].compute_at(s[tmp], l1_k_outer)
        s[tensor_list["tensor_w_l1_" + t]].compute_at(s[tmp], l1_k_outer)

        mad_dict = {
            "mad_pattern": mad_pattern,
            "k_outer": [l1_k_outer, l0_k_outer],
            "init_bias": 1
        }
        s[tmp].emit_insn(l0_n_inner, 'mad', mad_dict)

    # split ht
    ht_0 = ht.shape[0].value
    ht_1 = ht.shape[1].value
    axis_1_o, axis_1_i = s[ht].split(ht.op.axis[1],
                                     factor=tilling_info["block_n_factor"])
    axis_1_i_0, axis_1_i_i = s[ht].split(axis_1_i,
                                         factor=tilling_info["n_factor"])
    axis_0_o, axis_0_i = s[ht].split(
        ht.op.axis[0], factor=tilling_info["block_out_factor"])
    axis_0_o_o, axis_0_o_i = s[ht].split(axis_0_o, factor=1)
    axis_0_i_o, axis_0_i_i = s[ht].split(axis_0_i,
                                         factor=tilling_info["out_factor"])
    s[ht].reorder(axis_1_o, axis_0_o_o, axis_0_o_i, axis_1_i_0, axis_0_i_o,
                  axis_1_i_i, axis_0_i_i)
    compute_at_axis = axis_0_o_i

    for t in symbol:
        s[tensor_list["tensor_xh_l1_" + t]].double_buffer()
        s[tensor_list["tensor_w_l1_" + t]].double_buffer()
        s[tensor_list["tensor_b_ub_" + t]].double_buffer()
    s[tensor_list["tensor_c_ub"]].double_buffer()
    s[tensor_list["it_ub"]].double_buffer()
    s[tensor_list["ft_ub"]].double_buffer()
    s[tensor_list["ot_ub"]].double_buffer()
    s[tensor_list["jt_ub"]].double_buffer()

    block_num = cceconf.CceProductParams().getParams("Device_core_num")
    if (ht_1 // tilling_info["block_n_factor"]) > 1:
        core_outer = s[ht].split(axis_1_o, nparts=block_num)
        s[ht].bind(core_outer[0], tvm.thread_axis("blockIdx.x"))
    else:
        core_outer = s[ht].split(axis_0_o_o, nparts=block_num)
        s[ht].bind(core_outer[0], tvm.thread_axis("blockIdx.x"))

    special_symbol = {
        "tensor_xh_l0a_it", "tensor_xh_l0a_ft", "tensor_xh_l0a_ot",
        "tensor_xh_l0a_jt", "tensor_w_l0b_it", "tensor_w_l0b_ft",
        "tensor_w_l0b_ot", "tensor_w_l0b_jt", "tensor_xh_l1_it",
        "tensor_xh_l1_ft", "tensor_xh_l1_ot", "tensor_xh_l1_jt",
        "tensor_w_l1_it", "tensor_w_l1_ft", "tensor_w_l1_ot",
        "tensor_w_l1_jt", "ht"
    }
    for key in tensor_list.keys():
        if key not in special_symbol:
            s[tensor_list[key]].compute_at(s[ht], compute_at_axis)

    # Move result back (Fake)
    tensor_list["it_ub_fake_ub"] = s.cache_read(
        tensor_list["it_ub_fake_true"], cce.scope_ubuf,
        [tensor_list["tensor_ij_ub"]])
    tensor_list["jt_ub_fake_ub"] = s.cache_read(
        tensor_list["jt_ub_fake_true"], cce.scope_ubuf,
        [tensor_list["tensor_ij_ub"]])
    tensor_list["ft_ub_fake_ub"] = s.cache_read(
        tensor_list["ft_ub_fake_true"], cce.scope_ubuf,
        [tensor_list["tensor_cf_ub"]])
    tensor_list["ot_ub_fake_ub"] = s.cache_read(
        tensor_list["ot_ub_fake_true"], cce.scope_ubuf,
        [tensor_list["tensor_ht_ub"]])
    for t in ["ot", "it", "ft", "jt"]:
        s[tensor_list[t + "_ub_fake_ub"]].compute_at(s[ht], compute_at_axis)
        s[tensor_list[t + "_ub2gm"]].reused_by(
            tensor_list[t + "_ub_fake_ub"], tensor_list[t + "_ub_fake_true"])
        s[tensor_list[t + "_ub2gm"]].unreused_by(
            tensor_list["tensor_" + t + "_ub_true"])
        s[tensor_list[t + "_ub_fake_ub"]].reused_by(reuse_data=True)
        s[tensor_list[t + "_ub_fake_ub"]].emit_insn(
            s[tensor_list[t + "_ub_fake_ub"]].op.axis[0], 'phony_insn')
        s[tensor_list[t + "_ub_fake"]].compute_inline()

    # ct
    tensor_ct_ub_fake_ub = s.cache_read(
        tensor_list["tensor_ct_ub_fake"], cce.scope_ubuf,
        [tensor_list["tensor_ct_ub_fake_true"]])
    s[tensor_ct_ub_fake_ub].compute_at(s[ht], compute_at_axis)
    s[tensor_list["tensor_ct_ub"]].reused_by(
        tensor_ct_ub_fake_ub, tensor_list["tensor_ct_ub_fake_true"])
    s[tensor_ct_ub_fake_ub].emit_insn(s[tensor_ct_ub_fake_ub].op.axis[0],
                                      'phony_insn')
    s[tensor_list["tensor_ct_ub_fake"]].compute_inline()

    # tanhct
    tensor_tanhct_ub_fake_ub = s.cache_read(
        tensor_list["tensor_tanhct_ub_fake"], cce.scope_ubuf,
        [tensor_list["tensor_tanhct_ub_fake_true"]])
    s[tensor_tanhct_ub_fake_ub].compute_at(s[ht], compute_at_axis)
    s[tensor_list["tensor_ub_tanh_ct"]].reused_by(
        tensor_list["tensor_tanhct_ub_fake_true"])
    s[tensor_tanhct_ub_fake_ub].emit_insn(
        s[tensor_tanhct_ub_fake_ub].op.axis[0], 'phony_insn')
    s[tensor_list["tensor_tanhct_ub_fake"]].compute_inline()

    # ht
    s[ht].emit_insn(s[ht].op.axis[2], 'dma_copy')

    build_symbol = ["x", "h", "c", "w", "b", "ct", "ht", "it", "jt", "ft",
                    "ot", "tanhct"]
    new_build_list = []
    for t in build_symbol:
        if t in build_list.keys():
            new_build_list += [build_list[t]]

    with build_config:
        tvm.build(s, new_build_list, "cce", name=kernel_name)
def avg_pool_grad_schedule(res): """ the tiling avg pool grad schedule """ s = tvm.create_schedule(res.op) mad_cast = res.op.input_tensors[0] mad_res = mad_cast.op.input_tensors[0] dout_col_pad = mad_res.op.input_tensors[0] weight_rotated = mad_res.op.input_tensors[1] weight = weight_rotated.op.input_tensors[0] dout_col = dout_col_pad.op.input_tensors[0] dout_dilated = dout_col.op.input_tensors[0] dout_mul = dout_dilated.op.input_tensors[0] dout = dout_mul.op.input_tensors[0] dvealuemean = dout_mul.op.input_tensors[1] dout_ubuf = s.cache_read(dout, tbe_platform.scope_ubuf, [dout_mul]) dvealuemean_ubuf = s.cache_read(dvealuemean, tbe_platform.scope_ubuf, [dout_mul]) dout_mul_ubuf = s.cache_write(dout_mul, tbe_platform.scope_ubuf) dout_cbuf_nc1hwc0 = s.cache_write(dout_dilated, tbe_platform.scope_cbuf) dout_dilated_ubuf = s.cache_write(dout_cbuf_nc1hwc0, tbe_platform.scope_ubuf) dout_cbuf_row_major = s.cache_write(dout_col, tbe_platform.scope_cbuf) dout_ca = s.cache_write(dout_col_pad, tbe_platform.scope_ca) s[dout_mul].compute_inline() s[dout_dilated].compute_inline() s[dout_col].compute_inline() s[dout_col_pad].compute_inline() weight_cbuf = s.cache_read(weight, tbe_platform.scope_cbuf, [weight_rotated]) weight_cb = s.cache_write(weight_rotated, tbe_platform.scope_cb) s[weight_rotated].compute_inline() mad_cc = s.cache_write(mad_res, tbe_platform.scope_cc) mad_ubuf = s.cache_write(mad_cast, tbe_platform.scope_ubuf) s[mad_res].compute_inline() s[mad_cast].compute_inline() # get shape value dilated_pad_top = res.op.attrs['dilated_pad'][0].value dilated_pad_bottom = res.op.attrs['dilated_pad'][1].value dilated_pad_left = res.op.attrs['dilated_pad'][2].value dilated_pad_right = res.op.attrs['dilated_pad'][3].value k_height = res.op.attrs['weight_height'].value k_width = res.op.attrs['weight_width'].value block_size = dout.op.shape[len(dout.op.shape) - 1].value _, _, _, dout_dilated_h, dout_dilated_w, _ = dout_dilated.shape input_w = dout_dilated_w.value + dilated_pad_left \ + dilated_pad_right - k_width + 1 input_h = dout_dilated_h.value + dilated_pad_top \ + dilated_pad_bottom - k_height + 1 stride = dout_dilated.op.attrs["strides"][0].value weight_shape = [int(i.value) for i in weight.shape] dout_shape = [int(i.value) for i in dout.shape] dout_dilated_shape = [int(i.value) for i in dout_dilated.shape] mad_cc_axis_n, mad_cc_axis_cg, mad_cc_axis_co1, mad_cc_axis_howomad, \ mad_cc_axis_co0 = mad_cc.op.axis mad_ubuf_axis_n, mad_ubuf_axis_cg, mad_ubuf_axis_co1, \ mad_ubuf_axis_howomad, mad_ubuf_axis_co0 = mad_ubuf.op.axis mad_res_shape = [int(i.value) for i in mad_res.shape] res_block_n, res_block_cgroup, _, _, _ = mad_res_shape #tiling res_l1, tile_input_h, tile_dile_h_ub, tile_m, \ tile_k, tile_n = avg_pool_grad_tiling( input_w, input_h, weight_shape, dout_shape, res, stride) mad_cc_Ncut_o, mad_cc_Ncut_i = s[mad_cc].split(mad_cc_axis_n, factor=1) mad_cc_mcut_o, mad_cc_mcut_i = s[mad_cc].split(mad_cc_axis_howomad, factor=tile_m) mad_cc_kcut_o, mad_cc_kcut_i = s[mad_cc].split(mad_cc.op.reduce_axis[0], factor=tile_k) mad_cc_ncut_o, mad_cc_ncut_i = s[mad_cc].split(mad_cc_axis_co1, factor=tile_n) s[mad_cc].reorder(mad_cc_Ncut_o, mad_cc_axis_cg, mad_cc_ncut_o, mad_cc_mcut_o, mad_cc_kcut_o, mad_cc_Ncut_i, mad_cc_ncut_i, mad_cc_mcut_i, mad_cc_axis_co0, mad_cc_kcut_i, mad_cc.op.reduce_axis[1]) s[dout_ca].compute_at(s[mad_cc], mad_cc_kcut_o) s[weight_cb].compute_at(s[mad_cc], mad_cc_kcut_o) mad_ubuf_Ncut_o, mad_ubuf_Ncut_i = s[mad_ubuf].split(mad_ubuf_axis_n, factor=1) mad_ubuf_mcut_o, mad_ubuf_mcut_i = 
s[mad_ubuf].split(mad_ubuf_axis_howomad, factor=tile_m) mad_ubuf_ncut_o, mad_ubuf_ncut_i = s[mad_ubuf].split(mad_ubuf_axis_co1, factor=tile_n) s[mad_ubuf].reorder(mad_ubuf_Ncut_o, mad_ubuf_axis_cg, mad_ubuf_ncut_o, mad_ubuf_mcut_o, mad_ubuf_Ncut_i, mad_ubuf_ncut_i, mad_ubuf_mcut_i, mad_ubuf_axis_co0) s[mad_cc].compute_at(s[mad_ubuf], mad_ubuf_mcut_o) conv_Ncut_o, conv_Ncut_i = s[res].split(res.op.axis[0], factor=1) conv_hcut_o, conv_hcut_i = s[res].split(res.op.axis[3], factor=(res_l1)) conv_mcut_o, conv_mcut_i = s[res].split(conv_hcut_i, factor=tile_m) s[res].reorder(conv_Ncut_o, res.op.axis[1], conv_hcut_o, conv_mcut_o, conv_Ncut_i, res.op.axis[2], conv_mcut_i, res.op.axis[4]) s[mad_ubuf].buffer_align((1, 1), (1, 1), (1, 1), (1, block_size), (1, block_size)) s[mad_ubuf].compute_at(s[res], conv_mcut_o) s[dout_cbuf_row_major].buffer_align((1, 1), (1, 1), (input_w, input_w), (1, 1), (1, 1), (1, 1), (1, block_size)) s[dout_cbuf_row_major].compute_at(s[res], conv_hcut_o) s[dout_cbuf_nc1hwc0].compute_at(s[res], conv_hcut_o) s[weight_cbuf].compute_at(s[res], conv_hcut_o) dout_dilated_w = dout_dilated_shape[4] ub_l1hcut_o, ub_l1hcut_i = s[dout_cbuf_nc1hwc0].split( dout_cbuf_nc1hwc0.op.axis[3], factor=tile_dile_h_ub) if stride > 1: dila_o_h, dila_i_h = s[dout_dilated_ubuf].split( dout_dilated_ubuf.op.axis[3], factor=stride) dila_o_w, dila_i_w = s[dout_dilated_ubuf].split( dout_dilated_ubuf.op.axis[4], factor=stride) s[dout_dilated_ubuf].reorder(dila_i_h, dila_i_w, dila_o_h, dila_o_w) s[dout_dilated_ubuf].unroll(dila_i_h) s[dout_dilated_ubuf].unroll(dila_i_w) s[dout_dilated_ubuf].compute_at(s[dout_cbuf_nc1hwc0], ub_l1hcut_o) s[dout_dilated_ubuf].emit_insn(dout_dilated_ubuf.op.axis[0], insn_cmd.DMA_PADDING) else: s[dout_dilated_ubuf].compute_inline() s[dout_mul_ubuf].compute_at(s[dout_cbuf_nc1hwc0], ub_l1hcut_o) s[dout_ubuf].compute_at(s[dout_cbuf_nc1hwc0], ub_l1hcut_o) s[dvealuemean_ubuf].compute_at(s[dout_cbuf_nc1hwc0], ub_l1hcut_o) s[dout_ubuf].emit_insn(dout_ubuf.op.axis[0], insn_cmd.DMA_COPY) s[dvealuemean_ubuf].emit_insn(dvealuemean_ubuf.op.axis[0], insn_cmd.DMA_COPY) s[dout_mul_ubuf].emit_insn(dout_mul_ubuf.op.axis[0], insn_cmd.MUL) s[dout_cbuf_nc1hwc0].emit_insn(ub_l1hcut_i, insn_cmd.DMA_COPY) # emit convolution params. 
    setfmatrix_dict = {
        "conv_kernel_h": res.op.attrs['weight_height'],
        "conv_kernel_w": res.op.attrs['weight_width'],
        "conv_padding_top": res.op.attrs['dilated_pad'][0],
        "conv_padding_bottom": res.op.attrs['dilated_pad'][1],
        "conv_padding_left": res.op.attrs['dilated_pad'][2],
        "conv_padding_right": res.op.attrs['dilated_pad'][3],
        "conv_stride_h": res.op.attrs['dilated_strides'][0],
        "conv_stride_w": res.op.attrs['dilated_strides'][1],
        "conv_fm_c": dout_dilated.shape[2] * dout_dilated.shape[5],
        "conv_fm_h": dout_dilated.shape[3],
        "conv_fm_w": dout_dilated.shape[4]
    }
    s[dout_cbuf_row_major].emit_insn(dout_cbuf_row_major.op.axis[1],
                                     insn_cmd.SET_FMATRIX, setfmatrix_dict)
    s[dout_ca].emit_insn(dout_ca.op.axis[1], insn_cmd.IM2COL)
    s[weight_cbuf].emit_insn(weight_cbuf.op.axis[0], insn_cmd.DMA_COPY)
    s[weight_cb].emit_insn(weight_cb.op.axis[3], insn_cmd.DMA_COPY)
    s[mad_ubuf].emit_insn(mad_ubuf_Ncut_i, insn_cmd.DMA_COPY)
    mad_dict = {
        "mad_pattern": tbe_platform.cce_params.CONV_MODE,
        "k_outer": mad_cc_kcut_o
    }
    s[mad_cc].emit_insn(mad_cc_Ncut_i, insn_cmd.MAD, mad_dict)
    s[res].emit_insn(conv_Ncut_i, insn_cmd.DMA_COPY)
    s[dout_ca].double_buffer()
    s[weight_cb].double_buffer()
    s[mad_cc].double_buffer()
    # for multi cores
    if res_block_n < 16:
        res_NNCut_o, res_NNCut_i = s[res].split(conv_Ncut_o,
                                                nparts=res_block_n)
        res_ccCut_o, res_ccCut_i = s[res].split(res.op.axis[1],
                                                nparts=res_block_cgroup)
        s[res].reorder(res_NNCut_o, res_ccCut_o, res_NNCut_i, res_ccCut_i)
        out_fused = s[res].fuse(res_NNCut_o, res_ccCut_o)
        out_fused_out, _ = s[res].split(
            out_fused, nparts=res_block_n * res_block_cgroup)
        bind_out, _ = s[res].split(out_fused_out, 1)
        blockidx = tvm.thread_axis("blockIdx.x")
        s[res].bind(bind_out, blockidx)
    else:
        block = tvm.thread_axis("blockIdx.x")
        s[res].bind(conv_Ncut_o, block)
    return s
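
# A minimal standalone sketch (not part of the original op) of the size
# arithmetic used in avg_pool_grad_schedule above: the schedule recovers
# the forward input size from the dilated dout size, the dilated paddings
# and the kernel size via input = dilated + pad_before + pad_after - k + 1.
# The helper name and the example numbers are hypothetical; only the
# formula mirrors the schedule.
def _dilated_out_size(dilated_len, pad_before, pad_after, kernel_len):
    """Length produced by a stride-1 sliding window over a padded map."""
    return dilated_len + pad_before + pad_after - kernel_len + 1

# e.g. a 3x3 kernel with pad 1 on both sides preserves the spatial size
assert _dilated_out_size(13, 1, 1, 3) == 13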
def dynamic_lstm(input_x, weight, bias, output_h, kernel_name="dynamic_lstm"):
    """
    dynamic LSTM: build the compute graph and schedule, then build the kernel

    Parameters
    ----------
    input_x : dict
        A dict object, contains a Tensor's type, shape and format;
        the type can be float32, the format can be [FRACTAL_NZ]
    weight : dict
        A dict object, contains a Tensor's type, shape and format;
        the type can be float32, the format can be [FRACTAL_ZN_LSTM]
    bias : dict
        A dict object, contains a Tensor's type, shape and format;
        the type can be float32, the format can be [ND]
    output_h : dict
        A dict object, contains a Tensor's type, shape and format;
        the type can be float32, the format can be [FRACTAL_NZ]
    kernel_name : str
        kernel name, default value is "dynamic_lstm"

    Returns
    -------
    None
    """
    check_dtype(input_x, weight, bias, output_h)
    shape_x_input = input_x.get("shape")
    shape_w_input = weight.get("shape")
    shape_b_input = bias.get("shape")
    shape_output = output_h.get("shape")
    check(shape_x_input, shape_w_input, shape_b_input, shape_output)
    scan_one_num = 1
    t_size = shape_x_input[0] + scan_one_num
    m_size = shape_x_input[2]
    k_size = shape_w_input[0]
    n_size = shape_w_input[1]
    hidden_size = shape_output[1]
    block_size = n_size // hidden_size
    in_x = k_size - hidden_size
    shape_b = (1, k_size, block_size, hidden_size, 16, 16)
    shape_c = (1, block_size, hidden_size, m_size, 16, 16)
    shape_bias = (1, block_size, hidden_size, 1, 1, 16)
    shape_x = (t_size, in_x, m_size, 16, 16)
    shape_h = (1, k_size - in_x, m_size, 16, 16)
    shape_i = (1, hidden_size, m_size, 16, 16)
    shape_i_t = (t_size, hidden_size, m_size, 16, 16)
    core_num = cce.get_soc_spec("CORE_NUM")
    # each core uses 4 int64 flags, i.e. one 32B-aligned block
    shape_sync = (4 * core_num,)
    k0_size = 16
    input_dtype = input_x.get("dtype")
    data_dtype = 'float16'
    sync_dtype = 'int64'
    # define placeholders
    input_x = tvm.placeholder(shape_x, dtype=input_dtype, name='input_x')
    weight = tvm.placeholder(shape_b, dtype=input_dtype, name='weight')
    bias = tvm.placeholder(shape_bias, name='bias', dtype=input_dtype)
    s_state_h = tvm.placeholder(shape_h, dtype=input_dtype, name='state_h')
    s_state_c = tvm.placeholder(shape_i, dtype=input_dtype, name='state_c')
    sync0 = tvm.placeholder(shape_sync, name="sync0", dtype='int64')
    # compute
    # weight needs to be moved to UB first and cast to float16
    weight_ub = \
        tvm.compute(
            shape_b,
            lambda *indices: weight(*indices),
            name="weight_ub")
    weight_fp16 = \
        tvm.compute(shape_b,
                    lambda *indices: weight_ub(*indices).astype(data_dtype),
                    name="weight_fp16")
    # input and s_state_h also need to be moved to UB and cast to float16
    shape_a_z_bigz = (t_size, m_size, k_size, 16, 16)
    # input and s_state_h are Nz and need to be transposed to zZ,
    # so swap axes 1 and 2
    a_ub = tvm.compute(shape_a_z_bigz,
                       lambda *indice:
                       tvm.select(indice[2] < in_x,
                                  input_x[indice[0], indice[2], indice[1],
                                          indice[3], indice[4]],
                                  s_state_h[0, indice[2] - in_x, indice[1],
                                            indice[3], indice[4]]
                                  ),
                       name="a_ub", tag="concat")
    shape_a_z_bigz_1 = (1, m_size, k_size, 16, 16)
    a_ub_fp16 = \
        tvm.compute(shape_a_z_bigz_1,
                    lambda *indices: a_ub(*indices).astype(data_dtype),
                    name="a_ub_fp16")
    a_l1 = tvm.compute(shape_a_z_bigz_1,
                       lambda *indices: a_ub_fp16(*indices),
                       name='a_l1')
    b_l1 = tvm.compute(shape_b,
                       lambda *indices: weight_fp16(*indices),
                       name='b_l1')
    # shape_a_z_bigz_1 = (1, m_size, k_size, 16, 16)
    a_l0a = tvm.compute(shape_a_z_bigz, lambda *indices: a_l1(*indices),
                        name="a_l0a")
    b_l0b = tvm.compute(shape_b, lambda *indices: b_l1(*indices),
                        name="b_l0b")
    k1 = tvm.reduce_axis((0, k_size), name='k1')
    k0 = tvm.reduce_axis((0, k0_size), name='k0')
    c_l0c = tvm.compute(shape_c,
                        lambda t, nb_0, nb_1, mb, mp, np:
                        tvm.sum((a_l0a[t, mb, k1, mp, k0] * \
                                 b_l0b[t, k1, nb_0, nb_1, np, k0]) \
                                .astype('float32'),
                                axis=[k1, k0]),
                        name='c_l0c')
    c_ub = tvm.compute(shape_c, lambda *indices: c_l0c(*indices),
                       name="c_ub")
    bias_ub = tvm.compute(shape_bias,
                          lambda *indices: bias(*indices),
                          name='bias_ub')
    bias_bc_ub = te.lang.cce.broadcast(bias_ub, shape_c)
    c_ub_bias = te.lang.cce.vadd(c_ub, bias_bc_ub)
    # split the matmul result into the four gates
    i_t_index = 0
    j_t_index = 1
    f_t_index = 2
    o_t_index = 3
    i_t = \
        tvm.compute(shape_i,
                    lambda t, i, j, k, l:
                    c_ub_bias(t, i_t_index, i, j, k, l),
                    name="i_t")
    j_t = \
        tvm.compute(shape_i,
                    lambda t, i, j, k, l:
                    c_ub_bias(t, j_t_index, i, j, k, l),
                    name="j_t")
    f_t = \
        tvm.compute(shape_i,
                    lambda t, i, j, k, l:
                    c_ub_bias(t, f_t_index, i, j, k, l),
                    name="f_t")
    o_t = \
        tvm.compute(shape_i,
                    lambda t, i, j, k, l:
                    c_ub_bias(t, o_t_index, i, j, k, l),
                    name="o_t")
    f_t_sigmoid = sigmoid_compute(f_t)
    i_t_sigmoid = sigmoid_compute(i_t)
    o_t_sigmoid = sigmoid_compute(o_t)
    j_t_tanh = tanh_compute(j_t)
    c_t_tmp1 = te.lang.cce.vmul(s_state_c, f_t_sigmoid)
    c_t_tmp2 = te.lang.cce.vmul(j_t_tanh, i_t_sigmoid)
    update_c = te.lang.cce.vadd(c_t_tmp1, c_t_tmp2)
    update_c_gm = tvm.compute(shape_i_t,
                              lambda t, i, j, k, l: update_c(0, i, j, k, l),
                              name="update_c_gm")
    c_t_tanh = tanh_compute(update_c)
    update_h = te.lang.cce.vmul(c_t_tanh, o_t_sigmoid)
    update_h_gm = tvm.compute(shape_i_t,
                              lambda t, i, j, k, l: update_h(0, i, j, k, l),
                              name="update_h_gm")
    update_hc_vn = \
        tvm.compute(
            shape_i_t,
            lambda t, i, j, k, l: update_c_gm(0, i, j, k, l) +\
                                  update_h_gm(t, i, j, k, l),
            name="update_hc_vn")
    update_c_gm_vn = \
        tvm.compute(
            shape_i_t,
            lambda t, i, j, k, l: update_hc_vn(0, i, j, k, l),
            name="update_c_gm_vn")
    update_h_gm_vn = \
        tvm.compute(
            shape_i_t,
            lambda t, i, j, k, l: update_hc_vn(0, i, j, k, l),
            name="update_h_gm_vn")
    update_c_ub = \
        tvm.compute(
            shape_i,
            lambda t, i, j, k, l: update_c_gm_vn(t, i, j, k, l),
            name="update_c_ub")
    update_c_gm_2 = \
        tvm.compute(shape_i_t,
                    lambda t, i, j, k, l: update_c_ub(0, i, j, k, l),
                    name="update_c_gm_2")
    update_h_ub = \
        tvm.compute(
            shape_i,
            lambda t, i, j, k, l: update_h_gm_vn(t, i, j, k, l),
            name="update_h_ub")
    update_h_gm_2 = \
        tvm.compute(
            shape_i_t,
            lambda t, i, j, k, l: update_h_ub(0, i, j, k, l) +\
                                  update_c_gm_2(t, i, j, k, l),
            name="update_h_gm_2")
    update_h_gm_2_dummy = \
        tvm.compute(shape_i_t,
                    lambda t, i, j, k, l: update_h_gm_2(t, i, j, k, l),
                    name="update_h_gm_2_dummy")
    # state init
    init_shape = (1, hidden_size, m_size, 16, 16)
    s_state_h_ub = \
        tvm.compute(shape_h,
                    lambda *indices: tvm.const(0.0, dtype=input_dtype),
                    name='s_state_h_ub')
    s_state_c_ub = \
        tvm.compute(shape_i,
                    lambda *indices: tvm.const(0.0, dtype=input_dtype),
                    name='s_state_c_ub')
    s_init_h = \
        tvm.compute(
            init_shape,
            lambda _, i, j, k, l: s_state_h_ub[0, i, j, k, l],
            name="s_init_h")
    s_init_c = \
        tvm.compute(
            init_shape,
            lambda _, i, j, k, l: s_state_c_ub[0, i, j, k, l],
            name="s_init_c")
    # scan
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h_ub, update_c_ub],
        [s_state_h, s_state_c],
        scan_update=[update_h_gm_2, update_h_gm_2_dummy],
        name="lstm_scan")
    # end compute
    # schedule
    s = tvm.create_schedule([scan_h.op, scan_c.op])
    new_build_list = [input_x, weight, bias, update_h_gm, update_c_gm,
                      sync0, update_h_gm_vn, update_c_gm_vn]

    def gen_reversed_subgraph_list(out_tensor, tensor_list):
        """
        traverse tensors by depth-first search
        """
        if out_tensor is None:
            return
        stack = [out_tensor]
        visited_list = []
        while stack:
            cur_tensor = stack.pop()
            visited_list.append(cur_tensor)
            for in_tensor in cur_tensor.op.input_tensors:
                if in_tensor not in visited_list:
                    stack.append(in_tensor)
                if "elewise" in in_tensor.op.tag or \
                        "broadcast" == in_tensor.op.tag:
                    if in_tensor not in tensor_list:
                        tensor_list.append(in_tensor)

    elewise_tensors = []
    gen_reversed_subgraph_list(update_h_gm, elewise_tensors)
    barrier_tensor = c_ub_bias
    elewise_before_barrier_tensors = [bias_bc_ub]
    # set scope
    s[a_l1].set_scope(cce.scope_cbuf)
    s[b_l1].set_scope(cce.scope_cbuf)
    s[a_l0a].set_scope(cce.scope_ca)
    s[b_l0b].set_scope(cce.scope_cb)
    s[c_l0c].set_scope(cce.scope_cc)
    s[c_ub].set_scope(cce.scope_ubuf)
    s[s_init_h].set_scope(cce.scope_ubuf)
    s[bias_ub].set_scope(cce.scope_ubuf)
    s[bias_bc_ub].set_scope(cce.scope_ubuf)
    s[scan_h].set_scope(cce.scope_ubuf)
    s[scan_c].set_scope(cce.scope_ubuf)
    s[update_h_ub].set_scope(cce.scope_ubuf)
    s[update_c_ub].set_scope(cce.scope_ubuf)
    s[s_state_h_ub].set_scope(cce.scope_ubuf)
    s[s_state_c_ub].set_scope(cce.scope_ubuf)
    s[weight_ub].set_scope(cce.scope_ubuf)
    s[weight_fp16].set_scope(cce.scope_ubuf)
    s[a_ub].set_scope(cce.scope_ubuf)
    s[a_ub_fp16].set_scope(cce.scope_ubuf)
    for tensor in elewise_tensors:
        s[tensor].set_scope(cce.scope_ubuf)
    # compute inline
    compute_inline_tensors = [i_t, j_t, f_t, o_t]
    for tensor in compute_inline_tensors:
        s[tensor].compute_inline()
    # matmul tiling
    factor_l1_m, factor_l1_n, factor_l1_k, \
        factor_l0_m, factor_l0_n, factor_l0_k = \
        _get_lstm_tiling(m_size, k_size, n_size)
    l1_n_outer, l1_n_inner = \
        s[c_l0c].split(c_l0c.op.axis[2], factor=factor_l1_n // block_size)
    l1_m_outer, l1_m_inner = \
        s[c_l0c].split(c_l0c.op.axis[3], factor=factor_l1_m)
    l1_k_outer, l1_k_inner = \
        s[c_l0c].split(c_l0c.op.reduce_axis[0], factor=factor_l1_k)
    l0_n_outer, l0_n_inner = s[c_l0c].split(l1_n_inner, factor=factor_l0_n)
    l0_m_outer, l0_m_inner = s[c_l0c].split(l1_m_inner, factor=factor_l0_m)
    l0_k_outer, l0_k_inner = s[c_l0c].split(l1_k_inner, factor=factor_l0_k)
    s[c_l0c].reorder(l1_n_outer, c_l0c.op.axis[1], l1_m_outer, l1_k_outer,
                     l0_n_outer, l0_m_outer, l0_k_outer, l0_n_inner,
                     l0_m_inner, c_l0c.op.axis[3 + 1], c_l0c.op.axis[4 + 1],
                     l0_k_inner, c_l0c.op.reduce_axis[1])
    s[weight_ub].compute_at(s[c_l0c], l1_k_outer)
    s[weight_fp16].compute_at(s[c_l0c], l1_k_outer)
    s[a_ub].compute_at(s[c_l0c], l1_k_outer)
    s[a_ub_fp16].compute_at(s[c_l0c], l1_k_outer)
    s[a_l0a].compute_at(s[c_l0c], l0_k_outer)
    s[b_l0b].compute_at(s[c_l0c], l0_k_outer)
    s[a_l1].compute_at(s[c_l0c], l1_k_outer)
    s[b_l1].compute_at(s[c_l0c], l1_k_outer)
    ub_n_outer, ub_n_inner = \
        s[c_ub].split(c_ub.op.axis[2], factor=factor_l1_n // block_size)
    ub_m_outer, ub_m_inner = s[c_ub].split(c_ub.op.axis[3],
                                           factor=factor_l1_m)
    s[c_ub].reorder(ub_n_outer, c_ub.op.axis[1], ub_m_outer, ub_n_inner,
                    ub_m_inner, c_ub.op.axis[4], c_ub.op.axis[5])
    s[c_l0c].compute_at(s[c_ub], ub_n_outer)
    # elewise compute_at
    barrier_outer, barrier_inner = \
        s[barrier_tensor].split(barrier_tensor.op.axis[2],
                                factor=factor_l1_n // block_size)
    s[barrier_tensor].reorder(
        barrier_tensor.op.axis[0], barrier_outer,
        barrier_tensor.op.axis[1], barrier_inner,
        barrier_tensor.op.axis[3], barrier_tensor.op.axis[4],
        barrier_tensor.op.axis[5])
    s[c_ub].compute_at(s[barrier_tensor], barrier_outer)
    s[bias_ub].compute_at(s[barrier_tensor], barrier_outer)
    for tensor in elewise_before_barrier_tensors:
        s[tensor].compute_at(s[barrier_tensor], barrier_outer)
    vn_outer, vn_inner = \
        s[update_hc_vn].split(update_hc_vn.op.axis[0 + 1],
                              factor=factor_l1_n // block_size)
    second_split_factor = \
        (hidden_size // (factor_l1_n // block_size)) // core_num
    vn_o_outer, vn_o_inner = \
        s[update_hc_vn].split(vn_outer, factor=second_split_factor)
    s[barrier_tensor].compute_at(s[update_hc_vn], vn_o_inner)
    for tensor in elewise_tensors:
        if tensor not in elewise_before_barrier_tensors:
            s[tensor].compute_at(s[update_hc_vn], vn_o_inner)
    s[update_c_gm].compute_at(s[update_hc_vn], vn_o_inner)
    s[update_h_gm].compute_at(s[update_hc_vn], vn_o_inner)
    second_split_factor = hidden_size // core_num
    res_h_outer, res_h_inner = \
        s[update_h_gm_2].split(update_h_gm_2.op.axis[1], factor=hidden_size)
    s[update_hc_vn].compute_at(s[update_h_gm_2], update_h_gm_2.op.axis[0])
    s[update_c_gm_vn].compute_at(s[update_h_gm_2], res_h_outer)
    s[update_h_gm_vn].compute_at(s[update_h_gm_2], res_h_outer)
    s[update_c_ub].compute_at(s[update_h_gm_2], res_h_outer)
    s[update_c_gm_2].compute_at(s[update_h_gm_2], res_h_outer)
    s[update_h_ub].compute_at(s[update_h_gm_2], res_h_outer)
    s[update_h_gm_vn].bind_buffer(
        update_h_gm_vn.op.axis[0], 0, scan_h.op.scan_axis + res_h_outer)
    s[update_c_gm_vn].bind_buffer(
        update_c_gm_vn.op.axis[0], 0, scan_h.op.scan_axis + res_h_outer)
    # bind
    s[update_hc_vn].bind(vn_o_outer, tvm.thread_axis("blockIdx.x"))
    # multi core sync
    s[update_hc_vn].pragma(update_hc_vn.op.axis[0],
                           pragma_type="multicore_sync_wait_after",
                           pragma_value=sync0[0])
    s[update_hc_vn].pragma(update_hc_vn.op.axis[0],
                           pragma_type="multicore_sync_set_after",
                           pragma_value=sync0[0])
    # bind buffers to the scan axis so they extend across time steps
    s[input_x].bind_buffer(0, 0, scan_h.op.scan_axis)
    s[update_h_gm].buffer_tile((scan_h.op.scan_axis*1, 1), (None, None),
                               (None, None), (None, None), (None, None))
    s[update_c_gm].buffer_tile((scan_h.op.scan_axis*1, 1), (None, None),
                               (None, None), (None, None), (None, None))
    s[update_h_gm_2].buffer_tile((0, 1), (None, None), (None, None),
                                 (None, None), (None, None))
    s[update_c_gm_2].buffer_tile((0, 1), (None, None), (None, None),
                                 (None, None), (None, None))
    # buffer reuse
    s[update_h_gm].reused_by(update_h_gm_vn)
    s[update_c_gm].reused_by(update_c_gm_vn)
    # emit_insn
    s[a_l1].emit_insn(a_l1.op.axis[0], 'dma_copy')
    s[b_l1].emit_insn(b_l1.op.axis[0], 'dma_copy')
    s[a_l0a].emit_insn(a_l0a.op.axis[0], 'dma_copy')
    s[b_l0b].emit_insn(b_l0b.op.axis[0], 'dma_copy')
    s[weight_ub].emit_insn(weight_ub.op.axis[0], 'dma_copy')
    s[weight_fp16].emit_insn(weight_fp16.op.axis[0], 'vector_conv')
    s[a_ub].emit_insn(a_ub.op.axis[0], 'dma_copy')
    s[a_ub_fp16].emit_insn(a_ub_fp16.op.axis[0], 'vector_conv')
    mad_dict = {"mad_pattern": 0, "k_outer": [l1_k_outer, l0_k_outer]}
    s[c_l0c].emit_insn(l0_n_inner, 'mad', mad_dict)
    s[c_ub].emit_insn(ub_n_inner, 'dma_copy')
    s[s_init_h].emit_insn(s_init_h.op.axis[0], 'dma_copy')
    s[s_init_c].emit_insn(s_init_c.op.axis[0], 'dma_copy')
    s[bias_bc_ub].emit_insn(bias_bc_ub.op.axis[0], 'unified_broadcast')
    s[s_state_h_ub].emit_insn(s_state_h_ub.op.axis[0], 'broadcast')
    s[s_state_c_ub].emit_insn(s_state_c_ub.op.axis[0], 'broadcast')
    s[barrier_tensor].emit_insn(barrier_tensor.op.axis[1], 'vector_add')
    for tensor in elewise_tensors:
        if tensor != barrier_tensor:
            insn = get_emit_insn_map(tensor)
            s[tensor].emit_insn(tensor.op.axis[0], insn)
    s[bias_ub].emit_insn(bias_ub.op.axis[0], 'dma_copy')
    s[update_c_gm].emit_insn(s[update_c_gm].op.axis[1], 'dma_copy')
    s[update_h_gm].emit_insn(s[update_h_gm].op.axis[1], 'dma_copy')
    s[update_c_ub].emit_insn(update_c_ub.op.axis[1], 'dma_copy')
    s[update_h_ub].emit_insn(update_h_ub.op.axis[1], 'dma_copy')
    s[update_hc_vn].emit_insn(vn_inner, 'phony_insn')
    s[update_c_gm_vn].emit_insn(s[update_c_gm_vn].op.axis[0], 'phony_insn')
    s[update_h_gm_vn].emit_insn(s[update_h_gm_vn].op.axis[0], 'phony_insn')
    s[update_h_gm_2].emit_insn(res_h_inner, 'phony_insn')
    s[update_c_gm_2].emit_insn(s[update_c_gm_2].op.axis[0], 'phony_insn')
    s[update_h_gm_2_dummy].emit_insn(
        update_h_gm_2_dummy.op.axis[0], 'phony_insn')

    def _write_workspace_info(shape_list, dtype_list, sync_num, kernel_name):
        """
        modify json after build
        """
        def _write_code(wkspace_dict, fname):
            fname = os.path.realpath(fname)
            if fname.startswith(os.getcwd()):
                if os.path.exists(fname):
                    with open(fname, "r") as f:
                        load_dict = json.load(f)
                    load_dict.update(wkspace_dict)
                    with open(fname, "w") as f:
                        json.dump(load_dict, f, sort_keys=True, indent=4,
                                  separators=(',', ':'))

        def _get_data_width(ele):
            """
            get data width
            """
            m_sea = re.search(r'\d+', ele)
            if m_sea:
                return int(m_sea.group(0)) // 8
            return 0

        if not os.path.exists("kernel_meta"):
            os.mkdir("kernel_meta")
            os.chmod("kernel_meta",
                     stat.S_IRWXU + stat.S_IRGRP + stat.S_IXGRP)
        num = len(shape_list)
        wkspace_dict = {}
        if num:
            total_size = [functools_reduce(lambda x, y: x * y, list_i)
                          for list_i in shape_list]
            addr_type_list = []
            for i, element in enumerate(dtype_list):
                total_size[i] = total_size[i] * _get_data_width(element)
                addr_type_list.append(0)
            wkspace_dict["workspace"] = {"num": num,
                                         "size": total_size,
                                         "type": addr_type_list}
        if sync_num:
            parameters_list = \
                (len(new_build_list) - 2 - sync_num) * [0, ] + \
                sync_num * [1, ]
            wkspace_dict["parameters"] = parameters_list
        if wkspace_dict:
            _write_code(wkspace_dict, "kernel_meta/" + kernel_name + ".json")

    with build_config:
        tvm.build(s, new_build_list, "cce", name=kernel_name)
    _write_workspace_info(
        [shape_i_t, shape_sync],
        [input_dtype, sync_dtype],
        1, kernel_name)
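
# A minimal standalone sketch of the "workspace" entry that
# _write_workspace_info above merges into the generated kernel JSON: each
# workspace size is the element count times the byte width parsed from the
# dtype string. The helper name, shapes and dtypes below are illustrative
# only; the arithmetic mirrors the function above.
import json
import re
from functools import reduce

def _workspace_entry(shape_list, dtype_list):
    sizes = [reduce(lambda x, y: x * y, shape) for shape in shape_list]
    widths = [int(re.search(r'\d+', d).group(0)) // 8 for d in dtype_list]
    return {"workspace": {"num": len(shape_list),
                          "size": [s * w for s, w in zip(sizes, widths)],
                          "type": [0] * len(shape_list)}}

print(json.dumps(_workspace_entry([(2, 16, 3, 16, 16), (8,)],
                                  ["float32", "int64"])))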
def assign(ref, value, output, kernel_name="assign"):
    """
    algorithm: assign
    calculating: update 'ref' by assigning 'value' to it

    Parameters
    ----------
    ref: dict
        dict of input_ref, include shape and dtype
    value: dict
        dict of input_value, include shape and dtype;
        must have the same shape and dtype as input_ref
    output: dict
        dict of output
    kernel_name : str
        cce kernel name, default value is "assign"

    Returns
    -------
    None
    """
    ref_shape = util.scalar2tensor_one(ref.get("shape"))
    value_shape = util.scalar2tensor_one(value.get("shape"))
    dtype = ref.get("dtype").lower()
    _check_params(ref_shape, value_shape, dtype, kernel_name)
    data_b = tvm.placeholder(value_shape, dtype=dtype, name='data_b')
    data_b_ub = tvm.compute(value_shape, lambda *i: data_b(*i),
                            name='data_b_ub')
    data_a = tvm.compute(ref_shape, lambda *i: data_b_ub(*i), name='data_a')
    sch = tvm.create_schedule(data_a.op)
    sch[data_b_ub].set_scope(cce.scope_ubuf)
    split_axis, split_factor = _tilling_axis(ref_shape, dtype, True)
    core_bind_axis, core_bind_split_factor = _core_bind_axis(ref_shape)
    if core_bind_axis < split_axis:
        core_bind_axis_outer, core_bind_axis_inner = sch[data_a].split(
            data_a.op.axis[core_bind_axis], factor=core_bind_split_factor)
        if core_bind_axis == 0:
            axis_outer = core_bind_axis_outer
        else:
            axis_outer = data_a.op.axis[0]
            for axis_index in range(1, core_bind_axis):
                axis_outer = sch[data_a].fuse(axis_outer,
                                              data_a.op.axis[axis_index])
            axis_outer = sch[data_a].fuse(axis_outer, core_bind_axis_outer)
        axis_inner = core_bind_axis_inner
        for axis_index in range(core_bind_axis + 1, split_axis):
            axis_inner = sch[data_a].fuse(axis_inner,
                                          data_a.op.axis[axis_index])
        tilling_axis_outer, tilling_axis_inner = sch[data_a].split(
            data_a.op.axis[split_axis], factor=split_factor)
        axis_inner = sch[data_a].fuse(axis_inner, tilling_axis_outer)
    else:
        if split_axis == 0:
            axis_outer, tilling_axis_inner = sch[data_a].split(
                data_a.op.axis[split_axis], factor=split_factor)
            core_num = _get_target_core_num(
                ref_shape[split_axis] // split_factor)
            axis_outer, axis_inner = sch[data_a].split(axis_outer,
                                                       nparts=core_num)
        else:
            temp_shape = list(ref_shape[:split_axis])
            temp_shape.append(ref_shape[split_axis] // split_factor)
            if split_axis == core_bind_axis and \
                    core_bind_split_factor > split_factor:
                core_bind_axis, core_bind_split_factor \
                    = _core_bind_axis(temp_shape)
                axis_outer, tilling_axis_inner = sch[data_a].split(
                    data_a.op.axis[split_axis], factor=split_factor)
                if core_bind_axis == split_axis:
                    core_bind_axis_outer, axis_inner \
                        = sch[data_a].split(axis_outer,
                                            factor=core_bind_split_factor)
                else:
                    factor = ref_shape[split_axis] // split_factor
                    core_bind_axis_outer, axis_inner \
                        = sch[data_a].split(axis_outer, factor=factor)
                axis_outer = data_a.op.axis[0]
                for axis_index in range(1, split_axis):
                    axis_outer = sch[data_a].fuse(
                        axis_outer, data_a.op.axis[axis_index])
                axis_outer = sch[data_a].fuse(axis_outer,
                                              core_bind_axis_outer)
            else:
                core_bind_axis_outer, core_bind_axis_inner = \
                    sch[data_a].split(data_a.op.axis[core_bind_axis],
                                      factor=core_bind_split_factor)
                axis_outer = data_a.op.axis[0]
                for axis_index in range(1, core_bind_axis):
                    axis_outer = sch[data_a].fuse(
                        axis_outer, data_a.op.axis[axis_index])
                axis_outer = sch[data_a].fuse(axis_outer,
                                              core_bind_axis_outer)
                axis_inner = core_bind_axis_inner
                tilling_axis_inner = axis_inner
                split_axis = core_bind_axis
    sch[data_a].bind(axis_outer, tvm.thread_axis('blockIdx.x'))
    sch[data_b_ub].compute_at(sch[data_a], axis_inner)
    sch[data_b_ub].emit_insn(data_b_ub.op.axis[split_axis],
                             insn_cmd.DMA_COPY)
    sch[data_a].emit_insn(tilling_axis_inner, insn_cmd.DMA_COPY)
    with build_config:
        tvm.build(sch, [data_a, data_b], "cce", name=kernel_name)
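
# A minimal standalone sketch of the tiling decision pattern assign relies
# on: cap the tile at what fits in UB, then derive the core count from the
# resulting loop count. The helper name and the UB_SIZE/MAX_BLOCK constants
# are illustrative assumptions, not the real soc-spec queries.
UB_SIZE = 256 * 1024      # assumed UB capacity in bytes
MAX_BLOCK = 32            # assumed maximum usable cores

def _choose_tile(elem_cnt, dtype_bytes):
    max_in_ub = UB_SIZE // dtype_bytes
    factor = min(elem_cnt, max_in_ub)
    core_num = min(max(elem_cnt // factor, 1), MAX_BLOCK)
    return factor, core_num

# a large fp32 tensor: tile bounded by UB, loop spread over multiple cores
print(_choose_tile(10 ** 7, 4))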
def prior_box_d(feature, img, data_h, data_w, box_height, box_width, y,
                min_size, max_size, img_h=0, img_w=0, step_h=0.0,
                step_w=0.0, flip=True, clip=False, offset=0.5,
                variance=[0.1], kernel_name="prior_box"):
    """
    calculating data

    Parameters
    ----------
    feature : dict
        shape and dtype of the input feature map
    img : dict
        shape and dtype of the input image
    data_h : dict
        shape and dtype of the height data
    data_w : dict
        shape and dtype of the width data
    box_height : dict
        shape and dtype of the prior box heights
    box_width : dict
        shape and dtype of the prior box widths
    y : dict
        shape and dtype of output
    min_size : list
        minimum box sizes
    max_size : list
        maximum box sizes
    img_h, img_w : int
        image height and width, default value is 0
    step_h, step_w : float
        step along h and w, default value is 0.0
    flip : bool
        whether to flip each aspect ratio, default value is True
    clip : bool
        whether to clip the boxes to [0, 1], default value is False
    offset : float
        offset of the box center, default value is 0.5
    variance : list
        variances to adjust the prior boxes, default value is [0.1]
    kernel_name : str
        kernel name, default value is "prior_box"

    Returns
    -------
    None
    """
    img_h, img_w, step_h, step_w = prior_box_check(feature, img, data_h, \
        data_w, min_size, max_size, img_h, img_w, step_h, step_w,
        variance, kernel_name)
    shape_img = img.get("shape")
    shape_feature = feature.get("shape")
    dtype = feature.get("dtype")
    if img_h == 0 or img_w == 0:
        img_h = shape_img[INDEX_H]
        img_w = shape_img[INDEX_W]
    rec_img_h = 1.0 / img_h
    rec_img_w = 1.0 / img_w
    if step_h == 0 or step_w == 0:
        step_h = 1.0 * img_h / shape_feature[INDEX_H]
        step_w = 1.0 * img_w / shape_feature[INDEX_W]
    scale = 0.5
    op_list, ins_list, tensor_dic, y, tensor_list = \
        prior_box_compute(feature, img, data_h, data_w, box_height,
                          box_width, y, rec_img_h, rec_img_w, step_h,
                          step_w, clip, offset, scale, variance)
    UB_SIZE_LIMIT = \
        tbe_platform.cce_conf.get_soc_spec(tbe_platform.cce_conf.UB_SIZE)
    UB_SIZE_LIMIT = UB_SIZE_LIMIT / 21
    schedule = tvm.create_schedule(y.op)
    dtype_bytes_size = tbe_platform.cce_intrin.get_bit_len(dtype) // 8
    # one block is 32 bytes; divide by the dtype size to get the number
    # of elements that can be stored in one block
    element = 32 // dtype_bytes_size
    align(schedule, op_list, tensor_dic, clip, element, 0)
    # multi core
    npart_0, npart_1, npart_2, npart_3, npart_4, split_axis_0, \
        split_size, fuse_num = \
        multicore_factor_calculate(tensor_dic.get("y").shape, element)
    xr1o, xr1i = schedule[y].split(y.op.axis[0], nparts=npart_0)
    xr2o, xr2i = schedule[y].split(y.op.axis[1], nparts=npart_1)
    xho, xhi = schedule[y].split(y.op.axis[2], nparts=npart_2)
    xwo, xwi = schedule[y].split(y.op.axis[3], nparts=npart_3)
    xno, xni = schedule[y].split(y.op.axis[4], nparts=npart_4)
    schedule[y].reorder(xr1o, xr2o, xho, xwo, xno, xr1i, xr2i, xhi, xwi,
                        xni, y.op.axis[5])
    block_axis = schedule[y].fuse(xr1o, xr2o, xho, xwo, xno)
    schedule[y].bind(block_axis, tvm.thread_axis("blockIdx.x"))
    # tiling strategy
    split_flag, split_axis, split_factor = \
        tiling_factor_calculate(tensor_dic.get("y").shape, split_axis_0, \
            split_size, dtype, UB_SIZE_LIMIT, fuse_num)
    if split_flag:
        if split_axis == 0:
            xo, xi = schedule[y].split(xr1i, factor=split_factor)
        elif split_axis == 1:
            xo, xi = schedule[y].split(xr2i, factor=split_factor)
        elif split_axis == 2:
            xo, xi = schedule[y].split(xhi, factor=split_factor)
        elif split_axis == 3:
            xo, xi = schedule[y].split(xwi, factor=split_factor)
        elif split_axis == 4:
            xo, xi = schedule[y].split(xni, factor=split_factor)
        prior_compute(schedule, op_list, xo)
        buffer_mapping(schedule, op_list)
        double_buf(schedule, op_list)
        axis_list = get_ins_emit_axis(op_list, xi)
        ins_emit(schedule, op_list, axis_list, ins_list)
    else:
        # schedule optimize
        prior_compute(schedule, op_list, block_axis)
        buffer_mapping(schedule, op_list)
        double_buf(schedule, op_list)
        # instructions replace
        if split_axis_0 == 0:
            axis_list = get_ins_emit_axis(op_list, xr1i)
        elif split_axis_0 == 1:
            axis_list = get_ins_emit_axis(op_list, xr2i)
        elif split_axis_0 == 2:
            axis_list = get_ins_emit_axis(op_list, xhi)
        elif split_axis_0 == 3:
            axis_list = get_ins_emit_axis(op_list, xwi)
        elif split_axis_0 == 4 or split_axis_0 == 5:
            axis_list = get_ins_emit_axis(op_list, xni)
        ins_emit(schedule, op_list, axis_list, ins_list)
    with build_config:
        tvm.build(schedule, tensor_list, "cce", name=kernel_name)
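
# The multicore_factor_calculate helper used above is not shown in this
# file; a plausible minimal sketch of the idea it implements: walk the
# outer axes of y and split the first axis large enough to occupy the
# remaining cores, leaving the later axes unsplit (nparts=1). Purely
# illustrative; names and the exact policy are assumptions.
def _pick_nparts(shape, core_num):
    nparts = [1] * len(shape)
    for axis, dim in enumerate(shape):
        if dim >= core_num:
            nparts[axis] = core_num
            return nparts, axis
        core_num = max(core_num // max(dim, 1), 1)
        nparts[axis] = dim
    return nparts, len(shape) - 1

# 2 * 4 * 4 outer parts fill 32 cores; the split stops at axis 2
print(_pick_nparts([2, 4, 16, 16, 2], 32))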
def basic_rnn_cell_schedule(self, schedule_list):
    """
    Build the schedule for BasicRNNCell

    Parameters
    ----------
    schedule_list: list
        the output tensors that need to be scheduled

    Returns
    -------
    sch: tvm schedule
        schedule operator
    """
    sch = tvm.create_schedule(schedule_list)
    batch_dim = int(self.datas["x"].shape[1])
    input_dim = int(self.datas["x"].shape[0])
    hidden_dim = int(self.datas["w_ho"].shape[0])
    emit_cmd_list = self.emit_cmd
    tensors = self.tensor_list1.copy()
    tensors.update(self.tensor_list2)
    scope_list = self.scope_list
    for key in scope_list:
        sch[tensors[key]].set_scope(scope_list[key])
    for key in emit_cmd_list:
        tensor = tensors[key]
        op_name = emit_cmd_list[key]
        if key == "ub_whh_ht_cont":
            sch[tensor].reorder(tensor.op.axis[2], tensor.op.axis[1],
                                tensor.op.axis[0], tensor.op.axis[3])
            sch[tensor].emit_insn(sch[tensor].op.axis[1], op_name)
        else:
            sch[tensor].emit_insn(sch[tensor].op.axis[0], op_name)
    tilling_info1 = get_tilling(batch_dim, input_dim, hidden_dim)
    mad_tensors_1 = {
        "l0c": tensors["l0c_wht_xt"],
        "l1_left": tensors["l1_x"],
        "l1_right": tensors["l1_w_xh"],
        "l0a": tensors["l0a_x"],
        "l0b": tensors["l0b_w_xh"],
    }
    # matmul schedule for l0c_wht_xt
    matmul_schedule(sch, mad_tensors_1, tilling_info1, True)
    sch[tensors["l0c_bias_h"]].reused_by(tensors["l0c_wht_xt"],
                                         tensors["l0c_wht_xt_bias_h"])
    tilling_info2 = get_tilling(batch_dim, hidden_dim, hidden_dim)
    if self.expose_hidden:
        mad_tensors_2 = {
            "l0c": tensors["l0c_whh_ht"],
            "l1_left": tensors["l1_h_0"],
            "l1_right": tensors["l1_w_hh"],
            "l0a": tensors["l0a_h_0"],
            "l0b": tensors["l0b_w_hh"],
        }
        compute_at_axis = matmul_schedule(sch, mad_tensors_2,
                                          tilling_info2, False)
        if self.dtypes["h_0"] == "float32":
            sch[tensors["ub_h_0"]].compute_at(sch[mad_tensors_2["l0c"]],
                                              compute_at_axis)
            sch[tensors["h_0_fp16"]].compute_at(sch[mad_tensors_2["l0c"]],
                                                compute_at_axis)
    # split ht
    gm_ht = tensors["gm_ht"]
    m_o, m_i = sch[gm_ht].split(gm_ht.op.axis[1],
                                factor=tilling_info2["m_l0"])
    m_o_o, m_o_i = sch[gm_ht].split(m_o, factor=tilling_info2["m_l1"])
    n_o, n_i = sch[gm_ht].split(gm_ht.op.axis[0],
                                factor=tilling_info2["n_l0"])
    n_o_o, n_o_i = sch[gm_ht].split(n_o, factor=tilling_info2["n_l1"])
    sch[gm_ht].reorder(m_o_o, m_o_i, n_o_o, n_o_i, n_i, m_i,
                       gm_ht.op.axis[2], gm_ht.op.axis[3])
    ht_tensors = self.get_ht_tensors()
    for key in ht_tensors:
        sch[ht_tensors[key]].compute_at(sch[gm_ht], n_o_i)
    sch[gm_ht].emit_insn(sch[gm_ht].op.axis[2], "dma_copy")
    mad_tensors_3 = {
        "l0c": tensors["l0c_who_ht"],
        "l1_left": tensors["l1_ht"],
        "l1_right": tensors["l1_w_ho"],
        "l0a": tensors["l0a_ht"],
        "l0b": tensors["l0b_w_ho"],
    }
    compute_at_axis = matmul_schedule(sch, mad_tensors_3,
                                      tilling_info2, True)
    if self.dtypes["h_t"] == "float32":
        sch[tensors["ub_ht_new"]].compute_at(sch[mad_tensors_3["l0c"]],
                                             compute_at_axis)
        sch[tensors["ub_ht_fp16"]].compute_at(sch[mad_tensors_3["l0c"]],
                                              compute_at_axis)
    sch[gm_ht].compute_at(sch[mad_tensors_3["l0c"]], compute_at_axis)
    sch[tensors["l0c_bias_o"]].reused_by(tensors["l0c_who_ht"],
                                         tensors["l0c_who_ht_bias_o"])
    # split ot
    gm_ot = tensors["gm_ot"]
    m_o, m_i = sch[gm_ot].split(gm_ot.op.axis[1],
                                factor=tilling_info2["m_l0"])
    m_o_o, m_o_i = sch[gm_ot].split(m_o, factor=tilling_info2["m_l1"])
    n_o, n_i = sch[gm_ot].split(gm_ot.op.axis[0],
                                factor=tilling_info2["n_l0"])
    n_o_o, n_o_i = sch[gm_ot].split(n_o, factor=tilling_info2["n_l1"])
    sch[gm_ot].reorder(m_o_o, m_o_i, n_o_o, n_o_i, n_i, m_i,
                       gm_ot.op.axis[2], gm_ot.op.axis[3])
    ot_tensors = self.get_ot_tensors()
    for key in ot_tensors:
        sch[ot_tensors[key]].compute_at(sch[gm_ot], n_o_i)
    sch[gm_ot].emit_insn(sch[gm_ot].op.axis[2], "dma_copy")
    res_empty = tensors["res_empty"]
    m_o, m_i = sch[res_empty].split(res_empty.op.axis[1],
                                    factor=tilling_info2["m_l0"])
    m_o_o, m_o_i = sch[res_empty].split(m_o, factor=tilling_info2["m_l1"])
    n_o, n_i = sch[res_empty].split(res_empty.op.axis[0],
                                    factor=tilling_info2["n_l0"])
    n_o_o, n_o_i = sch[res_empty].split(n_o, factor=tilling_info2["n_l1"])
    sch[res_empty].reorder(m_o_o, m_o_i, n_o_o, n_o_i, n_i, m_i,
                           res_empty.op.axis[2], res_empty.op.axis[3])
    sch[gm_ht].compute_at(sch[res_empty], m_o_i)
    sch[gm_ot].compute_at(sch[res_empty], m_o_i)
    sch[res_empty].emit_insn(sch[res_empty].op.axis[2], "phony_insn")
    bind, _ = sch[res_empty].split(m_o_o, factor=tilling_info2["block"])
    sch[res_empty].bind(bind, tvm.thread_axis("blockIdx.x"))
    return sch
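
# A minimal standalone sketch of the two-level split applied to gm_ht and
# gm_ot above: an axis is split by the L0 factor, the outer part is split
# again by the L1 factor, producing the (outer-outer, outer-inner, inner)
# nest the schedule then reorders. The factors below are illustrative
# stand-ins for the values get_tilling returns.
def _two_level_split(length, l0, l1):
    tiles = []
    for o_o in range((length + l0 * l1 - 1) // (l0 * l1)):
        for o_i in range(l1):
            base = (o_o * l1 + o_i) * l0
            if base < length:
                tiles.append((o_o, o_i, base, min(base + l0, length)))
    return tiles

# length 16, l0=4, l1=2 -> [(0, 0, 0, 4), (0, 1, 4, 8), (1, 0, 8, 12), ...]
print(_two_level_split(16, 4, 2))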
def __init__(self, ib_, dtype_list, shape_list, nbins):
    self.ir_builder = ib_
    self.input_dtype = dtype_list[0]
    self.input_range_dtype = dtype_list[1]
    self.output_dtype = dtype_list[2]
    self.nbins = nbins
    self.data_shape, self.data_range_shape, _ = shape_list
    # if the input dtype is not fp32, the input data will be cast (vconv)
    self.is_need_mid_dtype = False
    if self.input_dtype not in ("float32", ):
        self.is_need_mid_dtype = True
        self.mid_dtype = "float32"
    else:
        self.mid_dtype = self.input_dtype
    # get dtype sizes: float16 is 2 bytes, float32 is 4 bytes
    self.input_dtype_size = \
        tbe_platform.cce_intrin.get_bit_len(self.input_dtype) // \
        cce_params.VECTOR_INST_BLOCK_NUM
    self.output_dtype_size = \
        tbe_platform.cce_intrin.get_bit_len(self.output_dtype) // \
        cce_params.VECTOR_INST_BLOCK_NUM
    self.mid_dtype_size = \
        tbe_platform.cce_intrin.get_bit_len(self.mid_dtype) // \
        cce_params.VECTOR_INST_BLOCK_NUM
    # get the block align length: one 32B block holds 16 fp16 or 8 fp32
    self.input_align_len = cce_params.BLOCK_REDUCE_INT8 // \
        self.input_dtype_size
    self.output_align_len = cce_params.BLOCK_REDUCE_INT8 // \
        self.output_dtype_size
    self.mid_align_len = cce_params.BLOCK_REDUCE_INT8 // \
        self.mid_dtype_size
    # get the vector width: 8 blocks, i.e. 16*8 fp16 or 8*8 fp32 elements
    self.input_vec_align_len = cce_params.VECTOR_INST_BLOCK_WIDTH // \
        self.input_dtype_size
    self.output_vec_align_len = cce_params.VECTOR_INST_BLOCK_WIDTH // \
        self.output_dtype_size
    self.mid_vec_align_len = cce_params.VECTOR_INST_BLOCK_WIDTH // \
        self.mid_dtype_size
    # for set_vec_mask
    self.uint64_all_one = tvm.const(MAX_VALUE_UINT64, "uint64")
    # get the running platform (mini or cloud)
    self.compile_plat = tbe_platform.cce_conf.get_soc_spec("SOC_VERSION")
    self.segment_size_calcu_histogram = SEGMENT_SIZE_CALCU_HISTOGRAM
    if self.compile_plat in ("Ascend310", ):
        self.segment_size_calcu_histogram = \
            SEGMENT_SIZE_CALCU_HISTOGRAM_MINI
    # cce pipe string PIPE_ALL
    self.args_str = tvm.call_pure_intrin("int32", "tvm_cce_string_print",
                                         "PIPE_ALL")
    self.deqscale = tvm.call_pure_intrin("float16", "tvm_cce_string_print",
                                         "(half)1.000000e+00f")
    # tmp params for compute
    self.offset = 0
    self.out_begin = 0
    self.out_end = 0
    self.ub_size = 0
    self.src_ub = kernel_api.ib_new_alloc(self.ir_builder,
                                          self.input_dtype,
                                          [SEGMENT_SIZE_COPY_GM_TO_UB],
                                          "src_ub",
                                          scope=tbe_platform.scope_ubuf)
    self.get_ub_size(SEGMENT_SIZE_COPY_GM_TO_UB, self.input_dtype_size)
    self.range_src_ub = kernel_api.ib_new_alloc(
        self.ir_builder, self.input_range_dtype, [self.input_align_len],
        "range_src_ub", scope=tbe_platform.scope_ubuf)
    self.get_ub_size(self.input_align_len, self.mid_dtype_size)
    # offset: for vcadd
    self.vcadd_ub = kernel_api.ib_new_alloc(
        self.ir_builder, self.mid_dtype,
        [self.segment_size_calcu_histogram], "vcadd_ub",
        scope=tbe_platform.scope_ubuf)
    self.get_ub_size(self.segment_size_calcu_histogram,
                     self.mid_dtype_size)
    # offset: output in des ub
    if self.input_dtype != self.mid_dtype:
        self.src_mid_input_ub = \
            kernel_api.ib_new_alloc(self.ir_builder, self.mid_dtype,
                                    [SEGMENT_SIZE_COPY_GM_TO_UB],
                                    "src_mid_input_ub",
                                    scope=tbe_platform.scope_ubuf)
        self.get_ub_size(SEGMENT_SIZE_COPY_GM_TO_UB, self.mid_dtype_size)
        self.src_mid_input_range_ub = \
            kernel_api.ib_new_alloc(self.ir_builder, self.mid_dtype,
                                    [self.mid_vec_align_len],
                                    "src_mid_input_range_ub",
                                    scope=tbe_platform.scope_ubuf)
        self.get_ub_size(self.mid_vec_align_len, self.mid_dtype_size)
    else:
        self.src_mid_input_ub = self.src_ub
        self.src_mid_input_range_ub = self.range_src_ub
    self.reg = self.ir_builder.allocate(self.mid_dtype, (7, ),
                                        name="range_data",
                                        scope=cce_params.scope_reg)
    self.mask = \
        self.ir_builder.allocate("uint64", (4,), name="mask",
                                 scope=cce_params.scope_reg)
    self.set_mask_list = \
        self.ir_builder.allocate("uint64", (64,), name="set_mask_list",
                                 scope=cce_params.scope_reg)
    # for preprocess
    self.range0_ub = kernel_api.ib_new_alloc(self.ir_builder,
                                             self.mid_dtype,
                                             [self.mid_vec_align_len * 2],
                                             "range0_ub",
                                             scope=tbe_platform.scope_ubuf)
    self.get_ub_size(self.mid_vec_align_len * 2, self.mid_dtype_size)
    # get output num per core
    self.device_core_num = \
        tbe_platform.cce_conf.get_soc_spec(tbe_platform.cce_conf.CORE_NUM)
    self.ub_total_size = \
        tbe_platform.cce_conf.get_soc_spec(tbe_platform.cce_conf.UB_SIZE)
    self.max_output_size = (self.ub_total_size * 0.8 -
                            self.ub_size) // self.mid_dtype_size // 6
    self.tmp_out_num_per_core = \
        ((self.nbins - 1) // self.device_core_num) + 1
    self.out_num_per_core = \
        ((self.tmp_out_num_per_core - 1) // self.output_align_len +
         1)*self.output_align_len
    if self.out_num_per_core >= self.max_output_size:
        self.out_num_per_core = \
            int((self.max_output_size //
                 self.mid_vec_align_len)*self.mid_vec_align_len)
    self.is_same = 0 if self.nbins % self.out_num_per_core == 0 else 1
    self.core_num = self.nbins // self.out_num_per_core + self.is_same
    # offset: output in src ub
    _shape = \
        (((self.out_num_per_core + 1 - 1) // self.mid_vec_align_len) + 1) \
        * self.mid_vec_align_len
    self.src_output_ub = kernel_api.ib_new_alloc(
        self.ir_builder, self.mid_dtype, [_shape], "src_output_ub",
        scope=tbe_platform.scope_ubuf)
    self.src_output_ub_p1 = kernel_api.ib_new_alloc(
        self.ir_builder, self.mid_dtype, [_shape], "src_output_ub_p1",
        scope=tbe_platform.scope_ubuf)
    self.get_ub_size(_shape + _shape, self.mid_dtype_size)
    _shape = (((self.out_num_per_core - 1) // self.output_align_len) +
              1) * self.output_align_len
    self.des_output_ub = kernel_api.ib_new_alloc(
        self.ir_builder, self.output_dtype, [_shape], "des_output_ub",
        scope=tbe_platform.scope_ubuf)
    # offset: tmp output in des ub for casting to the des dtype
    self.des_tmp_output_ub = kernel_api.ib_new_alloc(
        self.ir_builder, self.output_dtype, [_shape], "des_tmp_output_ub",
        scope=tbe_platform.scope_ubuf)
    self.get_ub_size(_shape + _shape, self.output_dtype_size)
    if self.compile_plat in ("Ascend310", ):
        self.src_fp16_output_ub = kernel_api.ib_new_alloc(
            self.ir_builder, "float16", [_shape], "src_fp16_output_ub",
            scope=tbe_platform.scope_ubuf)
        self.index_ub = kernel_api.ib_new_alloc(
            self.ir_builder, "float32", [self.mid_align_len], "index_ub",
            scope=tbe_platform.scope_ubuf)
        self.get_ub_size(_shape, 2)
        self.get_ub_size(self.mid_align_len, self.mid_dtype_size)
    # bind blockIdx.x
    self.block = tvm.thread_axis("blockIdx.x")
    self.ir_builder.scope_attr(self.block, "thread_extent", self.core_num)
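
# A minimal standalone version of the per-core output partition computed
# in __init__ above: the bins are divided across cores, the chunk is
# rounded up to the output block alignment, and the core count is then
# recomputed from the aligned chunk. The constants are example values
# (32 cores, 8 fp32 elements per 32B block), not the real soc spec.
def _bins_per_core(nbins, core_num=32, align_len=8):
    per_core = (nbins - 1) // core_num + 1
    per_core = ((per_core - 1) // align_len + 1) * align_len
    cores = nbins // per_core + (0 if nbins % per_core == 0 else 1)
    return per_core, cores

# 1000 bins -> 32 aligned bins per core on 32 cores (the last core partial)
print(_bins_per_core(1000))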
def bn_training_reduce_schedule_nd(res, core_num=None):
    """bn_training_reduce schedule method"""
    cce_emitinsn_params.cceEmitParamsIns.clear_param()
    # Prepare extra tensors
    # Step 1: Get two output tensors
    # Step 2: Merge two output tensors into Dummy
    # Step 3: Move UB data to GM tensor
    output_first = res[0]  # Square Sum
    output_second = res[1]  # Sum
    final_output = tvm.compute(
        output_first.shape,
        lambda *indices: output_first(*indices) + output_second(*indices),
        name="DummyYummySweety")
    is_cast = False
    if "cast" in output_second.op.input_tensors[0].name:
        is_cast = True
    # Calculate block split factor by axis_n_size and core_num
    axis_n_size = int(res[0].shape[1])
    if not core_num:
        core_num = int(cceconf.get_soc_spec("CORE_NUM"))
    # Multi core kernel requires aligned output
    element_size = cce_util.get_align_factor(output_first.dtype)[1]
    block_element_num = \
        te.platform.cce_intrin_md.ALIGNMENT_BYTES // element_size
    estimate_block_split_factor = max(axis_n_size // core_num, 8)
    nearest_aligned_factor = \
        estimate_block_split_factor % block_element_num
    # Decrease core_num for aligned output
    if estimate_block_split_factor < block_element_num and core_num > 1:
        return bn_training_reduce_schedule_nd(res, core_num - 1)
    # Round down to the nearest aligned factor
    block_split_factor = estimate_block_split_factor - \
        nearest_aligned_factor
    # Calculate UB split
    ub_size = te.platform.CceProductParams().getParams("Unified_Buffer") // 2
    reduce_data_num = 1
    reduce_data_factor = 2
    if is_cast:
        reduce_data_factor = 3
    for reduce_axis in output_first.op.reduce_axis:
        reduce_data_num *= int(reduce_axis.dom.extent)
    reduce_data_num *= reduce_data_factor
    max_possible_loop = ub_size // (element_size * reduce_data_num)
    actual_loop = 1
    for loop in range(max_possible_loop - 1, 0, -1):
        if block_split_factor % loop == 0:
            actual_loop = loop
            break
    # Force aligned if multi-core is enabled
    if actual_loop < block_element_num and \
            actual_loop < block_split_factor and core_num > 1:
        actual_loop = block_element_num
    # Find all tensors
    if is_cast:
        # With cast: the first compute tensor is cast_tensor
        mul_tensor = output_first.op.input_tensors[0]
        cast_tensor = mul_tensor.op.input_tensors[0]
        res_input = cast_tensor.op.input_tensors[0]
        input_tensor_next = [cast_tensor]
        ub_tensors = [cast_tensor, mul_tensor, output_first, output_second]
    else:
        # Without cast: the first compute tensors are mul_tensor
        # and output_second
        cast_tensor = None
        mul_tensor = output_first.op.input_tensors[0]
        res_input = mul_tensor.op.input_tensors[0]
        input_tensor_next = [mul_tensor, output_second]
        ub_tensors = [mul_tensor, output_first, output_second]
    # Create original schedule
    sch = tvm.create_schedule(final_output.op)
    # ////////////////////////////////////
    # ///////// DataFlow Control /////////
    # ////////////////////////////////////
    # Read input in
    input_tensor_ub = sch.cache_read(res_input, cce_params.scope_ubuf,
                                     input_tensor_next)
    ub_tensors.append(input_tensor_ub)
    # Compute procedure in ubuf
    for ub_tens in ub_tensors:
        sch[ub_tens].set_scope(cce_params.scope_ubuf)
    # ////////////////////////////////////
    # //////// Split axis Control ////////
    # ////////////////////////////////////
    outer, inner = \
        sch[final_output].split(sch[final_output].op.axis[1],
                                factor=block_split_factor)
    ub_outer, ub_inner = sch[final_output].split(inner, factor=actual_loop)
    sch[final_output].bind(outer, tvm.thread_axis("blockIdx.x"))
    # ////////////////////////////////////
    # ///////// Compute Control //////////
    # ////////////////////////////////////
    compute_at_axis = ub_outer
    for ub_tens in ub_tensors:
        sch[ub_tens].compute_at(sch[final_output], compute_at_axis)
    # ////////////////////////////////////
    # //////////// EmitInsn //////////////
    # ////////////////////////////////////
    def emit_on_self(tensor, axisnum=0, op='dma_copy'):
        """Do emit insn"""
        sch[tensor].emit_insn(sch[tensor].op.axis[axisnum], op)

    def emit_on_self_ex(tensor, axis, op='dma_copy'):
        """Do emit insn"""
        sch[tensor].emit_insn(axis, op)

    # Emit instructions (final_output itself is only a fake result
    # that merges the two sums)
    emit_on_self(input_tensor_ub, 0)
    if is_cast:
        emit_on_self(cast_tensor, 0, cast_tensor.op.tag.split('|')[0])
    emit_on_self(mul_tensor, 0, mul_tensor.op.tag)
    sch[output_first].pragma(sch[output_first].op.axis[1],
                             "emit_insn", "bn_reduce_sum")
    sch[output_second].pragma(sch[output_second].op.axis[1],
                              "emit_insn", "bn_reduce_sum")
    sch[output_first].double_buffer()
    sch[output_second].double_buffer()
    emit_on_self_ex(final_output, ub_inner, "binary_reduce_output_reversed")

    def new_alloc(dtype, shape, name):
        """Alloc mem"""
        new_buffer = tvm.decl_buffer(shape, dtype, name=name, scope="",
                                     data=None)
        return new_buffer

    out_buffer_sec = new_alloc(final_output.dtype, (block_split_factor, ),
                               "reduce_sec_output_gm")
    cce_emitinsn_params.cceEmitParamsIns.insert_param(
        "binary_reduce_output_buffer", out_buffer_sec)
    tensor_list = [res_input, final_output, out_buffer_sec]
    return sch, tensor_list
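
# A minimal standalone sketch of the block-split policy used above:
# estimate a per-core factor, round it down to a whole number of output
# blocks, and retry with fewer cores when even one block per core does
# not fit. The constant is illustrative (8 fp32 elements per 32B block);
# the helper name is hypothetical.
def _block_split(axis_n, core_num, block_elems=8):
    estimate = max(axis_n // core_num, 8)
    if estimate < block_elems and core_num > 1:
        return _block_split(axis_n, core_num - 1, block_elems)
    return estimate - estimate % block_elems

# 1000 elements on 32 cores -> 24 block-aligned elements per core
print(_block_split(1000, 32))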