def fake_learned_scale_quant_perchannel_grad_d_reduce(
        dout_alpha, dalpha, channel_axis,
        kernel_name="fake_learned_scale_quant_perchannel_grad_d_reduce"):
    """FakeLearnedScaleQuantPerChannelGradDReduce"""
    dout_alpha_shape = dout_alpha.get("shape")
    dout_alpha_dtype = dout_alpha.get("dtype")

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(dout_alpha_shape)
    util.check_tensor_shape_size(dout_alpha_shape)

    check_list = ["float32", "float16"]
    dout_alpha_dtype = dout_alpha_dtype.lower()
    util.check_dtype_rule(dout_alpha_dtype, check_list)

    dout_alpha_data = tvm.placeholder(dout_alpha_shape, name="dout_alpha",
                                      dtype=dout_alpha_dtype)
    res = fake_learned_scale_quant_perchannel_grad_d_reduce_compute(
        dout_alpha_data, dout_alpha, channel_axis, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [dout_alpha_data, res]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}

    te.lang.cce.cce_build_code(sch, config)
def __init__(self, src, dst):
    """
    init CopyOnly parameters

    Parameters
    ----------
    src : dict
        shape and dtype of input
    dst : dict
        shape and dtype of output, should be same shape and type as input

    Returns
    -------
    None
    """
    self.src_shape = src.get("shape")
    self.src_dtype = src.get("dtype").lower()
    self.dst_shape = dst.get("shape")
    self.dst_dtype = dst.get("dtype").lower()
    if self.dst_dtype == "bool":
        self.dst_dtype = "int8"
    self.data_size = util.check_tensor_shape_size(list(self.src_shape))
    if len(self.dst_shape) == 0:
        self.data_dst_size = 1
        self.dst_shape = [1]
    else:
        self.data_dst_size = \
            util.check_tensor_shape_size(list(self.dst_shape))
    if self.data_size != self.data_dst_size:
        raise RuntimeError("The sizes of src and dst are not equal,"
                           " so CopyOnly cannot be used")
    # get dtype size, float16 size = 2 bytes / float32 size = 4 bytes
    self.dtype_size = \
        tbe_platform.cce_intrin.get_bit_len(self.src_dtype) // 8
    # get one block data size, block align len
    # the number of elements in one block = 16 for fp16 and 8 for fp32
    self.data_len_one_block = 32 // self.dtype_size
    self.data_len_one_vector = self.data_len_one_block * 8
    self.ub_availble = \
        tbe_platform.cce_conf.get_soc_spec(
            tbe_platform.cce_conf.UB_SIZE) - 8 * 1024
    self.ub_max_data = self.ub_availble // self.dtype_size
    self.tik_instance = tik.Tik()
    self.core_num = \
        tbe_platform.cce_conf.get_soc_spec(tbe_platform.cce_conf.CORE_NUM)
    # input and output tensors in gm
    self.src_gm = self.tik_instance.Tensor(self.src_dtype,
                                           self.src_shape,
                                           name="src_gm",
                                           scope=tik.scope_gm)
    self.dst_gm = self.tik_instance.Tensor(self.dst_dtype,
                                           self.dst_shape,
                                           name="dst_gm",
                                           scope=tik.scope_gm)
    self.data_ub = None
def sqrt(input_x, output_y, kernel_name="sqrt"):
    """
    calculating data

    Parameters
    ----------
    input_x : dict
        shape and dtype of input
    output_y : dict
        shape and dtype of output, should be same shape and type as input
    kernel_name : str
        kernel name, default value is "sqrt"

    Returns
    -------
    None
    """
    print("================= When you see this message, this custom sqrt "
          "operator has been executed =================")

    # operator check
    shape = input_x.get("shape")
    dtype = input_x.get("dtype")
    input_dtype = dtype.lower()
    util.check_shape_rule(shape)
    util.check_tensor_shape_size(shape)
    util.check_kernel_name(kernel_name)

    # operator compute, invoke sqrt_compute
    data_input = tvm.placeholder(shape, name="data_input", dtype=input_dtype)
    res = sqrt_compute(data_input, output_y, kernel_name)

    # auto schedule
    with tvm.target.cce():
        schedule = generic.auto_schedule(res)

    # operator build
    config = {"name": kernel_name,
              "tensor_list": [data_input, res]}

    te.lang.cce.cce_build_code(schedule, config)
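# A minimal invocation sketch for the sqrt entry point above. The dict-style
# tensor descriptions follow the interface documented in the docstring; the
# concrete shape, dtype, and kernel name here are illustrative assumptions,
# and actually building the kernel requires the Ascend TBE toolchain (te/tvm).
def _sqrt_usage_example():
    input_x = {"shape": (16, 16), "dtype": "float16"}
    output_y = {"shape": (16, 16), "dtype": "float16"}
    # compiles an elementwise square-root kernel named "sqrt_demo"
    sqrt(input_x, output_y, kernel_name="sqrt_demo")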
def _check_shape(shape_x, shape_sum, shape_square_sum, shape_scale,
                 shape_offset):
    """
    Function to check if the shape is in line with norms.

    Parameters
    ----------
    shape_x: list or tuple
        x's data shape
    shape_sum: list or tuple
        sum's data shape
    shape_square_sum: list or tuple
        square_sum's data shape
    shape_scale: list or tuple
        scale's data shape
    shape_offset: list or tuple
        offset's data shape

    Returns
    -------
    None
    """
    util.check_shape_rule(shape_x)
    util.check_tensor_shape_size(shape_x)

    util.check_shape_rule(shape_sum)
    util.check_tensor_shape_size(shape_sum)

    util.check_shape_rule(shape_square_sum)
    util.check_tensor_shape_size(shape_square_sum)

    util.check_shape_rule(shape_scale)
    util.check_tensor_shape_size(shape_scale)

    util.check_shape_rule(shape_offset)
    util.check_tensor_shape_size(shape_offset)

    if len(shape_x) != 5 or len(shape_sum) != 5 \
            or len(shape_square_sum) != 5 or len(shape_scale) != 5 \
            or len(shape_offset) != 5:
        raise RuntimeError("The data format is 5HD, "
                           "but some input's shape length is not 5")

    dim_c1 = shape_x[1]
    dim_c0 = shape_x[4]

    if shape_sum[1] != dim_c1 or shape_sum[4] != dim_c0:
        raise RuntimeError("Dimension C must be equal, but %s and %s"
                           % (str(shape_x), str(shape_sum)))
    if shape_square_sum[1] != dim_c1 or shape_square_sum[4] != dim_c0:
        raise RuntimeError("Dimension C must be equal, but %s and %s"
                           % (str(shape_x), str(shape_square_sum)))
    if shape_scale[1] != dim_c1 or shape_scale[4] != dim_c0:
        raise RuntimeError("Dimension C must be equal, but %s and %s"
                           % (str(shape_x), str(shape_scale)))
    if shape_offset[1] != dim_c1 or shape_offset[4] != dim_c0:
        raise RuntimeError("Dimension C must be equal, but %s and %s"
                           % (str(shape_x), str(shape_offset)))
def hwcn_2_fractal_z_c04(src, dst, src_format, dst_format,
                         kernel_name="hwcn_2_fractal_z_c04"):
    """
    algorithm: hwcn_2_fractal_z_c04

    Parameters
    ----------
    src: dict
        dict with keys(shape, dtype) of src
    dst: dict
        dict with keys(shape, dtype) of dst
    src_format: str
        data format of src
    dst_format: str
        data format of dst
    kernel_name: str
        kernel name, default value is "hwcn_2_fractal_z_c04"

    Returns
    -------
    tik_instance: tik_instance
    """
    src_shape = src.get("shape")
    src_dtype = src.get("dtype").lower()
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(src_shape)
    util.check_tensor_shape_size(src_shape)
    check_list = ("float16",)
    util.check_dtype_rule(src_dtype, check_list)
    if len(src_shape) != 4:
        raise RuntimeError("hwcn_2_fractal_z_c04 only support 4D "
                           "while src shape is %s" % str(src_shape))

    if src_shape[2] > 4:
        raise RuntimeError("hwcn_2_fractal_z_c04 only support C <= 4 "
                           "while src shape is %s" % str(src_shape))

    if src_format.upper() != "HWCN":
        raise RuntimeError("hwcn_2_fractal_z_c04 only support %s "
                           "while src format is %s" % ("HWCN", src_format))

    if dst_format.upper() != "FRACTAL_Z_C04":
        raise RuntimeError("hwcn_2_fractal_z_c04 only support %s "
                           "while dst format is %s"
                           % ("FRACTAL_Z_C04", dst_format))

    src_shape = list(src_shape)

    hwcn_2_fractal_z_c04_template = HWCN2FRACTALZC04Compute(src_shape,
                                                            src_dtype,
                                                            kernel_name)
    return hwcn_2_fractal_z_c04_template.get_tik_instance()
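# A usage sketch for hwcn_2_fractal_z_c04, assuming an HWCN weight tensor
# with C <= 4 as the checks above require. All shapes are illustrative
# assumptions (the dst descriptor is hypothetical and is not validated by
# the entry function); building requires the Ascend TBE toolchain.
def _hwcn_2_fractal_z_c04_example():
    src = {"shape": (3, 3, 4, 32), "dtype": "float16"}
    dst = {"shape": (3, 2, 16, 16), "dtype": "float16"}  # hypothetical dst descriptor
    return hwcn_2_fractal_z_c04(src, dst, "HWCN", "FRACTAL_Z_C04",
                                kernel_name="hwcn_2_fractal_z_c04_demo")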
def check_output_dim_with_ksize_stride(padding, input_gard_shape, y_shape,
                                       ksize, strides, dilation, ceil_mode):
    """
    The common check rule for output dim and ksize and strides
    """
    util.check_tensor_shape_size(ksize)
    util.check_tensor_shape_size(strides)
    if len(ksize) < ATTR_SHAPE_MIN or len(strides) < ATTR_SHAPE_MIN:
        raise RuntimeError(
            "The shape length of ksize or strides must be at least 4")
    if ksize[0] != 1 or ksize[3] != 1:
        raise RuntimeError(
            "MaxPoolGradWithArgmax only supports pooling across width/height,"
            " and other ksize dimensions should be one")
    if strides[0] != 1 or strides[3] != 1:
        raise RuntimeError(
            "MaxPoolGradWithArgmax only supports pooling across width/height,"
            " and other strides dimensions should be one")
    if ksize[1] * ksize[2] > 255:
        raise RuntimeError(
            "invalid window params, window_h*window_w should be <= 255")
    input_height = y_shape[2]
    input_width = y_shape[3]
    input_batch = y_shape[0]
    xc1 = y_shape[1]
    xc0 = y_shape[4]
    output_height = input_gard_shape[2]
    output_width = input_gard_shape[3]
    windowh = ksize[1]
    windoww = ksize[2]
    dyn = input_gard_shape[0]
    dyc1 = input_gard_shape[1]
    dyc0 = input_gard_shape[4]
    pad_h = padding[1]
    pad_w = padding[2]
    stride_h = strides[1]
    stride_w = strides[2]
    dilation_h = dilation[1]
    dilation_w = dilation[2]
    dyh = _pooling_output_shape(input_height, windowh, pad_h, stride_h,
                                dilation_h, ceil_mode)
    dyw = _pooling_output_shape(input_width, windoww, pad_w, stride_w,
                                dilation_w, ceil_mode)

    if ksize[1] >= input_height or ksize[2] >= input_width:
        raise RuntimeError("can not support global pooling now")

    if dyh != output_height or dyw != output_width or \
            input_batch != dyn or xc1 != dyc1 or xc0 != dyc0:
        raise RuntimeError("dimensions of dx, dy, padMode, window and "
                           "stride do not match, please check!")
def addcdiv(x1, x2, x3, y=None, alpha=1.0, kernel_name="addcdiv"):

    check_list = ("float16", "float32")
    shape_x1 = x1.get("shape")
    dtype_x1 = x1.get("dtype").lower()
    shape_x2 = x2.get("shape")
    dtype_x2 = x2.get("dtype").lower()
    shape_x3 = x3.get("shape")
    dtype_x3 = x3.get("dtype").lower()

    util.check_shape_rule(shape_x1)  # check the shape: 1 to 8 dimensions are allowed
    util.check_shape_size(shape_x1, SHAPE_SIZE_LIMIT)  # check the shape size of the first input
    util.check_dtype_rule(dtype_x1, check_list)  # check the input data type

    util.check_shape_rule(shape_x2)
    util.check_shape_size(shape_x2, SHAPE_SIZE_LIMIT)
    util.check_dtype_rule(dtype_x2, check_list)

    util.check_shape_rule(shape_x3)
    util.check_shape_size(shape_x3, SHAPE_SIZE_LIMIT)
    util.check_dtype_rule(dtype_x3, check_list)

    if dtype_x1 != dtype_x2 or dtype_x1 != dtype_x3:
        raise RuntimeError("the type of x1, x2, x3 must be the same!")

    util.check_kernel_name(kernel_name)  # check the kernel_name

    # shape_max takes the larger value of each dimension
    # among shape_x1, shape_x2 and shape_x3
    shape_x2, shape_x3, shape_max = broadcast_shapes(shape_x2, shape_x3)
    util.check_tensor_shape_size(shape_max)  # check shape_max
    shape_x1, _, shape_max = broadcast_shapes(shape_x1, shape_max)
    util.check_tensor_shape_size(shape_max)  # check shape_max
    shape_x2, _, _ = broadcast_shapes(shape_x2, shape_max)  # broadcast shape_x2 to shape_max
    shape_x3, _, _ = broadcast_shapes(shape_x3, shape_max)  # broadcast shape_x3 to shape_max

    data_x1 = tvm.placeholder(shape_x1, name="data_x1", dtype=dtype_x1)
    data_x2 = tvm.placeholder(shape_x2, name="data_x2", dtype=dtype_x2)
    data_x3 = tvm.placeholder(shape_x3, name="data_x3", dtype=dtype_x3)

    res = addcdiv_compute(data_x1, data_x2, data_x3, shape_max, alpha,
                          kernel_name)

    with tvm.target.cce():
        schedule = generic.auto_schedule(res)

    config = {"name": kernel_name,
              "tensor_list": [data_x1, data_x2, data_x3, res]}

    te.lang.cce.cce_build_code(schedule, config)
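# A usage sketch for addcdiv with broadcastable inputs, matching the checks
# above: all three inputs share one dtype and broadcast to a common
# shape_max. The shapes and alpha are illustrative assumptions; by the usual
# addcdiv convention (e.g. torch.addcdiv) the result is x1 + alpha * x2 / x3,
# though the compute body is not shown here. Building requires the Ascend
# TBE toolchain.
def _addcdiv_usage_example():
    x1 = {"shape": (8, 1, 16), "dtype": "float32"}
    x2 = {"shape": (1, 4, 16), "dtype": "float32"}  # numerator term
    x3 = {"shape": (8, 4, 1), "dtype": "float32"}   # denominator term
    addcdiv(x1, x2, x3, alpha=0.5, kernel_name="addcdiv_demo")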
def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std,
                        dx, epsilon=1e-5, is_training=True, freeze_bn=0,
                        kernel_name="batchnorm_fold_grad"):
    """batchnorm_fold_grad op"""
    util.check_kernel_name(kernel_name)
    for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
        util.check_shape_rule(iv.get("shape"))
        util.check_tensor_shape_size(iv.get("shape"))
    check_tuple = ("float16", "float32")
    for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
        util.check_dtype_rule(iv.get("dtype").lower(), check_tuple)

    shape_x = x.get("shape")
    dtype_x = x.get("dtype")
    format_data = x.get("format").upper()
    if format_data not in ("NCHW", "NC1HWC0"):
        raise RuntimeError("Format of input only support 4D and 5HD")

    shape_mean = d_batch_mean.get("shape")
    dtype_mean = d_batch_mean.get("dtype").lower()
    if format_data == "NC1HWC0":
        if len(shape_x) != 5:
            raise RuntimeError("batchnorm_fold only support shape 5D "
                               "when input format is NC1HWC0")
        shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
    elif format_data == "NCHW":
        if len(shape_x) < 2 or len(shape_x) > 4:
            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
        if shape_x[1] != shape_mean[0]:
            raise RuntimeError("data_format is NCHW, shape_bias must "
                               "be equal to the second axis of shape_x")
        shape_mean = (1, shape_x[1],)
        for _ in range(2, len(shape_x)):
            shape_mean = shape_mean + (1,)

    d_batch_mean = tvm.placeholder(shape_mean, name="d_batch_mean",
                                   dtype=dtype_mean)
    d_batch_std = tvm.placeholder(shape_mean, name="d_batch_std",
                                  dtype=dtype_mean)
    data_x = tvm.placeholder(shape_x, name="data_x", dtype=dtype_x.lower())
    batch_mean = tvm.placeholder(shape_mean, name="batch_mean",
                                 dtype=dtype_mean)
    batch_std = tvm.placeholder(shape_mean, name="batch_std",
                                dtype=dtype_mean)

    res = _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x,
                                       batch_mean, batch_std)
    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    tensor_list = [d_batch_mean, d_batch_std, data_x, batch_mean,
                   batch_std] + res
    config = {"name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(sch, config)
def check_param(self):
    """
    check whether the parameters are valid

    Returns
    -------
    None
    """
    var_out_shape = self.var_out.get("shape")
    var_out_dtype = self.var_out.get("dtype").lower()
    if var_out_dtype == "bool":
        var_out_dtype = "int8"
    util.check_kernel_name(self.kernel_name)
    util.check_shape_rule(self.var_shape)
    util.check_shape_rule(self.indices_shape)
    util.check_shape_rule(self.updates_shape)
    util.check_shape_rule(var_out_shape)

    util.check_tensor_shape_size(self.var_shape)
    util.check_tensor_shape_size(self.indices_shape)
    util.check_tensor_shape_size(self.updates_shape)
    util.check_tensor_shape_size(var_out_shape)

    check_list_var = ("float16", "float32", "int32", "int8", "uint8")
    check_list_indices = "int32"
    util.check_dtype_rule(self.var_dtype, check_list_var)
    util.check_dtype_rule(self.indices_dtype, check_list_indices)
    util.check_dtype_rule(self.updates_dtype, check_list_var)
    util.check_dtype_rule(var_out_dtype, check_list_var)

    if var_out_shape != self.var_shape:
        raise RuntimeError(
            "var_out's shape must be the same as var's shape")

    if (self.updates_dtype != self.var_dtype or
            var_out_dtype != self.var_dtype):
        raise RuntimeError(
            "updates's datatype and var_out's datatype must be the"
            " same as var's datatype")

    if self.nd_flag:
        if len(self.indices_shape) < 2:
            raise RuntimeError(
                "the length of indices_shape must be at least 2")
        k = self.indices_shape[-1]
        updates_len = len(self.indices_shape) - 1 + len(self.var_shape) - k
        if k > len(self.var_shape):
            raise RuntimeError(
                "indices_shape[-1] cannot be greater than var's rank")
        if len(self.updates_shape) != updates_len:
            raise RuntimeError("the length of updates must be len(indices_"
                               "shape)-1+len(var_shape)-indices_shape[-1]")
        updates_true_shape = self.indices_shape[:-1] + self.var_shape[k:]
    else:
        updates_true_shape = self.var_shape[:self.axis] + \
                             self.indices_shape + \
                             self.var_shape[self.axis + 1:]

    if self.updates_shape != updates_true_shape:
        raise RuntimeError("updates's shape is illegal")
def _check_parameters(src, dst, src_format, dst_format, kernel_name):
    """
    check the parameters including src_shape, dst_shape,
    src_format, dst_format, dtype and kernel_name
    """
    src_shape = src.get("shape")
    dst_shape = dst.get("shape")
    dtype = src.get("dtype")
    dtype_dst = dst.get("dtype")

    if src_format.lower() != "ndhwc":
        raise RuntimeError("src_format must be NDHWC !")

    if dst_format.lower() != "ndc1hwc0":
        raise RuntimeError("dst_format must be NDC1HWC0!")

    util.check_kernel_name(kernel_name)
    check_list = ("float16",)
    util.check_dtype_rule(dtype, check_list)
    if dtype != dtype_dst:
        raise RuntimeError("dtype of src and dst are different !")

    util.check_shape_rule(src_shape, 5, 5)
    util.check_shape_rule(dst_shape, 6, 6)
    util.check_tensor_shape_size(src_shape)
    util.check_tensor_shape_size(dst_shape)

    if dst_shape[5] != 16:
        raise RuntimeError(
            "the last dimension of dst_shape is not 16, c0 must be 16 !")

    if dst_shape[0] != src_shape[0] \
            or (dst_shape[1] != src_shape[1]
                and dst_shape[1] != src_shape[1] + 2) \
            or dst_shape[3] != src_shape[2] or dst_shape[4] != src_shape[3]:
        raise RuntimeError("the shape of src and dst not match, "
                           "the 1st,2nd,4th,5th dimension of dst_shape and "
                           "the 1st,2nd,3rd,4th dimension of src_shape "
                           "must be the same !")

    c_dst = src_shape[4]
    c_1 = dst_shape[2]
    c_0 = dst_shape[5]
    if not ((c_dst <= c_1 * c_0) and (c_dst > (c_1 - 1) * c_0)):
        raise RuntimeError("c must be less than or equal to c1*c0, "
                           "and greater than (c1 - 1)*c0 !")
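# A worked example of shapes that satisfy _check_parameters above, assuming
# C0 = 16: src is NDHWC with C = 17, so C1 = ceil(17 / 16) = 2 and the
# NDC1HWC0 dst keeps N, D, H, W while splitting C into (C1, C0). The
# concrete numbers are illustrative assumptions.
def _ndhwc_to_ndc1hwc0_shape_example():
    src = {"shape": (2, 3, 8, 8, 17), "dtype": "float16"}
    dst = {"shape": (2, 3, 2, 8, 8, 16), "dtype": "float16"}
    # passes the c1/c0 bound: 17 <= 2 * 16 and 17 > (2 - 1) * 16
    _check_parameters(src, dst, "NDHWC", "NDC1HWC0", "trans_demo")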
def permute_tik(x, y, order=(0,), kernel_name="permute_tik"):
    """
    only support nchw->nhwc

    Parameters
    ----------
    x : dict
        shape and dtype of input
    y : dict
        shape and dtype of output, should be same shape and type as input
    order : tuple, list
        axis transformation order
    kernel_name : str
        kernel name, default value is "permute_tik"

    Returns
    -------
    None
    """
    shape = x.get("shape")
    dtype = y.get("dtype")
    input_dtype = dtype.lower()
    supported_dtype = ["float16"]
    input_format = x.get("format")
    check_pass = False
    if input_format == 'NCHW':
        if len(order) == 4 and order[0] == 0 \
                and order[1] == 2 and order[2] == 3 and order[3] == 1:
            check_pass = True
    if not check_pass:
        raise RuntimeError("only support nchw->nhwc")
    util.check_dtype_rule(input_dtype, supported_dtype)
    util.check_dtype_rule(dtype, supported_dtype)
    util.check_shape_rule(shape)
    util.check_tensor_shape_size(shape)
    util.check_kernel_name(kernel_name)
    input_dict = {
        "x": x,
        "y": y,
        "order": order
    }
    permute_process = Permute(input_dict)
    permute_process.permute_compute()
    permute_process.instance.BuildCCE(kernel_name=kernel_name,
                                      inputs=permute_process.x_gm,
                                      outputs=permute_process.y_gm)
    return permute_process.instance
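# A usage sketch for permute_tik: the only supported transposition is
# NCHW -> NHWC, i.e. order=(0, 2, 3, 1). The shapes are illustrative
# assumptions; the returned tik instance holds the built kernel.
def _permute_tik_usage_example():
    x = {"shape": (1, 3, 32, 32), "dtype": "float16", "format": "NCHW"}
    y = {"shape": (1, 32, 32, 3), "dtype": "float16", "format": "NHWC"}
    return permute_tik(x, y, order=(0, 2, 3, 1), kernel_name="permute_demo")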
def __init__(self, src, dst, src_format, dst_format):
    """
    init parameters

    Parameters
    ----------
    src : dict
        shape and dtype of input
    dst : dict
        shape and dtype of output
    src_format : str
        data format of src
    dst_format : str
        data format of dst

    Returns
    -------
    None
    """
    self.src_shape = src.get("shape")
    self.src_dtype = src.get("dtype").lower()
    self.src_format = src_format
    self.dst_shape = dst.get("shape")
    self.dst_dtype = dst.get("dtype").lower()
    if self.dst_dtype == "bool":
        self.dst_dtype = "int8"
    self.dst_format = dst_format
    self.data_size = util.check_tensor_shape_size(list(self.dst_shape))
    # get dtype size, float16 size = 2 bytes / float32 size = 4 bytes
    self.dtype_size = \
        cce.cce_intrin.get_bit_len(self.src_dtype) // 8
    # get one block data size, block align len,
    # 1 block = 16 fp16 or 8 fp32 elements
    self.data_len_one_bloack = 32 // self.dtype_size
    self.data_len_one_vector = self.data_len_one_bloack * 8
    self.ub_availble = \
        cce.CceProductParams().getParams("Unified_Buffer") - 8 * 1024
    self.ub_max_data = self.ub_availble // self.dtype_size
    profile = tik.Dprofile()
    self.tik_instance = tik.Tik(profile)
    self.core_num = profile.get_aicore_num()
    # input and output tensors in gm
    self.src_gm = self.tik_instance.Tensor(self.src_dtype,
                                           self.src_shape,
                                           name="src_gm",
                                           scope=tik.scope_gm)
    self.dst_gm = self.tik_instance.Tensor(self.dst_dtype,
                                           self.dst_shape,
                                           name="dst_gm",
                                           scope=tik.scope_gm)
    self.data_ub = None
def stn_compute(input_x, input_theta, input_offset, output_y,
                size=(-1, -1, -1, -1), align_corners=False,
                kernel_name="stn_compute"):
    """
    spatial transformer by theta

    Parameters
    ----------
    input_x : dict
        shape and dtype of input
    input_theta : dict
        shape and dtype of the auxiliary coefficients
    input_offset : dict
        shape and dtype of the auxiliary offset
    size : tuple
        output size
    align_corners : bool
        default value is False
    output_y : dict
        shape and dtype of output, should be same shape and type as input
    kernel_name : str
        kernel name, default value is "stn_compute"

    Returns
    -------
    None
    """
    shape = input_x.get("shape")
    util.check_shape_rule(shape)
    util.check_tensor_shape_size(shape)
    util.check_kernel_name(kernel_name)

    stn_instance = SpatialTransformer(input_x, input_theta, input_offset,
                                      kernel_name)
    stn_instance.spatial_transformer_compute()
    return stn_instance
def _check_parameters(src, dst, src_format, dst_format, kernel_name):
    """
    check the parameters including src_shape, dst_shape,
    src_format, dst_format, dtype and kernel_name
    """
    src_shape = src.get("shape")
    dst_shape = dst.get("shape")
    dtype = src.get("dtype")
    dtype_dst = dst.get("dtype")

    if src_format.lower() != "fractal_zn" \
            and src_format.lower() != "fractal_z":
        raise RuntimeError("src_format must be FRACTAL_ZN or FRACTAL_Z !")

    if dst_format.lower() != "hwcn":
        raise RuntimeError("dst_format must be HWCN !")

    util.check_kernel_name(kernel_name)
    check_list = ("float16", "float32")
    util.check_dtype_rule(dtype, check_list)
    if dtype != dtype_dst:
        raise RuntimeError("dtype of src and dst are different !")

    util.check_shape_rule(src_shape, 4, 4)
    util.check_shape_rule(dst_shape, 4, 4)
    util.check_tensor_shape_size(src_shape)
    util.check_tensor_shape_size(dst_shape)

    if src_shape[2] != 16 or src_shape[3] != 16:
        raise RuntimeError("ni and c0 must be 16 !")

    h_i, w_i, c_i, n_i = dst_shape

    c_0 = 16
    c_1 = _ceil_div(c_i, c_0)
    src_one = c_1 * h_i * w_i

    n_ni = 16
    n_no = _ceil_div(n_i, n_ni)

    if list(src_shape) != [src_one, n_no, 16, 16]:
        raise RuntimeError("src_shape is wrong !")
def fake_quant_per_layer(x, min_val, max_val, y,
                         symmetric, narrow_range, num_bits,
                         kernel_name="fake_quant_per_layer"):
    """FakeQuantPerLayer"""
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    min_shape = util.scalar2tensor_one(min_shape)
    max_shape = util.scalar2tensor_one(max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(min_shape, 1, 1, 1)
    util.check_shape_rule(max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

    quant_min = 0
    quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
                                       quant_min, quant_max, symmetric,
                                       kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [input_data, min_data, max_data, res]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}

    te.lang.cce.cce_build_code(sch, config)
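# A usage sketch for fake_quant_per_layer. min_val/max_val are per-layer
# scalars (shape (1,)); with num_bits=8 and narrow_range=False the integer
# range used internally is [0, 255]. The concrete shapes are illustrative
# assumptions; building requires the Ascend TBE toolchain.
def _fake_quant_per_layer_example():
    x = {"shape": (4, 16), "dtype": "float32"}
    min_val = {"ori_shape": (1,), "dtype": "float32"}
    max_val = {"ori_shape": (1,), "dtype": "float32"}
    y = {"shape": (4, 16), "dtype": "float32"}
    fake_quant_per_layer(x, min_val, max_val, y, symmetric=False,
                         narrow_range=False, num_bits=8,
                         kernel_name="fake_quant_per_layer_demo")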
def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
                             ema, ema_decay, channel_axis,
                             kernel_name="minmax_update_perchannel"):
    """MinMaxUpdatePerChannel op"""
    x_shape = x.get("ori_shape")
    x_format = x.get("format")
    x_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")
    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1],
    # channel_axis_ need change to 1.
    if channel_axis == 0 and x_shape[0] != min_shape[0] \
            and x_shape[1] == min_shape[0]:
        channel_axis_ = 1
    else:
        channel_axis_ = channel_axis
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(x_shape)
    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
    util.check_tensor_shape_size(x_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = x_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    if channel_axis_ == 0:
        shape_c = min_val.get("ori_shape")
    else:
        shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
    input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
    res_list = minmax_update_perchannel_compute(input_data, min_data,
                                                max_data, ema, ema_decay,
                                                channel_axis_)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)

    tensor_list = [input_data, min_data, max_data] + list(res_list)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}

    te.lang.cce.cce_build_code(sch, config)
def fake_learned_scale_quant_perlayer(
        input_x, alpha, quant_max, out, neg_trunc,
        kernel_name="fake_learned_scale_quant_perlayer"):
    """FakeLearnedScaleQuantPerLayer"""
    input_shape = input_x.get("shape")
    input_dtype = input_x.get("dtype")
    alpha_shape = alpha.get("ori_shape")
    alpha_dtype = alpha.get("dtype")
    quant_max_shape = quant_max.get("ori_shape")
    quant_max_dtype = quant_max.get("dtype")

    alpha_shape = util.scalar2tensor_one(alpha_shape)
    quant_max_shape = util.scalar2tensor_one(quant_max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(alpha_shape, 1, 1, 1)
    util.check_shape_rule(quant_max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(alpha_shape)
    util.check_tensor_shape_size(quant_max_shape)

    check_list = ["float32", "float16"]
    input_dtype = input_dtype.lower()
    alpha_dtype = alpha_dtype.lower()
    quant_max_dtype = quant_max_dtype.lower()
    util.check_dtype_rule(input_dtype, check_list)
    util.check_dtype_rule(alpha_dtype, check_list)
    util.check_dtype_rule(quant_max_dtype, check_list)

    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)

    input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype)
    alpha_data = tvm.placeholder(alpha_shape, name="alpha_data",
                                 dtype=alpha_dtype)
    quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data",
                                     dtype=quant_max_dtype)
    res = fake_learned_scale_quant_perlayer_compute(input_data, alpha_data,
                                                    quant_max_data,
                                                    neg_trunc, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [input_data, alpha_data, quant_max_data, res]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list,
              "bool_storage_as_1bit": False}

    te.lang.cce.cce_build_code(sch, config)
def _param_check(shape_x, dtype_x, axis, kernel_name):
    """check param

    Parameters
    ----------
    shape_x: list
        input shape
    dtype_x: str
        input dtype
    axis: int
        axis int num
    kernel_name: str
        kernel_name string

    Returns
    -------
    None
    """
    util.check_shape_rule(shape_x, max_dim=8)
    util.check_tensor_shape_size(shape_x)
    check_list = ("int32", "float32")
    util.check_dtype_rule(dtype_x.lower(), check_list)
    util.check_kernel_name(kernel_name)
def __init__(self, shape, dtype, split_dim, num_split, size_splits):
    """init SplitLastDim parameters
    """
    self.src_shape = shape
    self.src_dtype = dtype
    self.data_size = util.check_tensor_shape_size(list(self.src_shape))
    self.split_dim = split_dim
    self.num_split = num_split
    self.split_dim_size = self.src_shape[self.split_dim]
    self.data_size_first_dim = self.data_size // self.split_dim_size
    self.split_output_dim_size = \
        self.src_shape[self.split_dim] // self.num_split
    self.output_size = \
        self.split_output_dim_size * self.data_size_first_dim
    # get dtype size, float16 size = 2 bytes / float32 size = 4 bytes
    self.dtype_size = \
        tbe_platform.cce_intrin.get_bit_len(self.src_dtype) // 8
    # get one block data size, block align len
    # the number of elements in one block = 16 for fp16 and 8 for fp32
    self.data_len_one_block = 32 // self.dtype_size
    self.data_len_one_vector = self.data_len_one_block * 8
    self.ub_availble = tbe_platform.cce_conf.get_soc_spec(
        tbe_platform.cce_conf.UB_SIZE) - 8 * 1024
    self.ub_max_data = self.ub_availble // self.dtype_size
    self.tik_instance = tik.Tik()
    self.core_num = tbe_platform.cce_conf.get_soc_spec(
        tbe_platform.cce_conf.CORE_NUM)
    self.max_dims = 1
    self.segment_len = 1
    self.out_ub = None
    self.out_ub_1 = None
    self.index_reg = None
    self.index_reg_1 = None
    # input and output tensors in gm
    self.src_gm = self.tik_instance.Tensor(
        self.src_dtype, [self.data_size_first_dim, self.split_dim_size],
        name="src_gm", scope=tik.scope_gm)
    self.dst_gm_list = []
    for i in range(num_split):
        dst_gm = self.tik_instance.Tensor(
            self.src_dtype,
            [self.data_size_first_dim, self.split_output_dim_size],
            name="dst_gm_" + str(i), scope=tik.scope_gm)
        self.dst_gm_list.append(dst_gm)
def __init__(self, src_shape, dtype, kernel_name):
    """
    initialize some properties
    """
    self.src_shape = src_shape
    self.dtype = dtype
    self.kernel_name = kernel_name
    self.dst_shape = [
        (self.src_shape[0] * self.src_shape[1] * C0 + CUBE_SIZE - 1) //
        CUBE_SIZE,
        (self.src_shape[3] + CUBE_SIZE - 1) // CUBE_SIZE,
        CUBE_SIZE, CUBE_SIZE
    ]
    self.num_byte = SIZE_TWO_BYTES
    self.mask = MAX_MASK
    # the number of data that can be moved in each data_move
    self.num_data = DATA_MOVE_MIN_UNIT // self.num_byte
    util.check_shape_rule(self.dst_shape)
    util.check_tensor_shape_size(self.dst_shape)
    # the number of data that UB can put in
    self.ub_memory = min(TOTAL_UB_MEMORY, 252 * 1024) // self.num_byte // 2
    self.src_gm = None
    self.dst_gm = None
def leaky_relu_grad(g, x, y, negative_slope=0, kernel_name="leaky_relu_grad"):
    """
    calculate the backpropagation of leaky_relu operation
    y = gradients(x>0) or negative_slope*gradients(x<=0).
    support dtype: float16, float32

    Parameters
    ----------
    g : dict
        the backpropagated gradients to the corresponding leaky_relu operation
    x : dict
        the x passed as output of leaky_relu operation
    y : dict
        the output of leaky_relu back propagation
    negative_slope : float or int
        allow non-zero slope for negative inputs to speed up optimization
    kernel_name : str
        kernel name, default value is "leaky_relu_grad"

    Returns
    -------
    None
    """
    shape_g = g.get("shape")
    shape_x = x.get("shape")
    dtype_g = g.get("dtype").lower()
    dtype_x = x.get("dtype").lower()

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_g)
    util.check_shape_rule(shape_x)
    util.check_tensor_shape_size(shape_g)
    util.check_tensor_shape_size(shape_x)

    shape_list = util.produce_shapes(shape_g, shape_x)
    util.check_tensor_shape_size(shape_list[2])

    # check input tensor data type
    check_list = ["float16", "float32"]
    util.check_dtype_rule(dtype_g, check_list)
    util.check_dtype_rule(dtype_x, check_list)
    util.compare_tensor_dict_key(g, x, "dtype")

    shape_g, shape_x = refine_shapes_for_broadcast(shape_list[0],
                                                   shape_list[1])
    data_g = tvm.placeholder(shape_g, name="data_g", dtype=dtype_g)
    data_x = tvm.placeholder(shape_x, name="data_x", dtype=dtype_g)

    res = leaky_relu_grad_compute(data_g, data_x, y, negative_slope,
                                  kernel_name)
    with tvm.target.cce():
        schedule = generic.auto_schedule(res)

    config = {"name": kernel_name,
              "tensor_list": [data_g, data_x, res]}
    te.lang.cce.cce_build_code(schedule, config)
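# A usage sketch for leaky_relu_grad: g and x must share a dtype and be
# broadcast-compatible; the gradient passes through where x > 0 and is
# scaled by negative_slope elsewhere. Shapes and the slope value are
# illustrative assumptions; building requires the Ascend TBE toolchain.
def _leaky_relu_grad_example():
    g = {"shape": (8, 128), "dtype": "float16"}
    x = {"shape": (8, 128), "dtype": "float16"}
    y = {"shape": (8, 128), "dtype": "float16"}
    leaky_relu_grad(g, x, y, negative_slope=0.01,
                    kernel_name="leaky_relu_grad_demo")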
def fake_quant_perchannel(x, min_val, max_val, y,
                          symmetric, narrow_range, num_bits, channel_axis,
                          kernel_name="fake_quant_perchannel"):
    """FakeQuantPerChannel"""
    x_shape = x.get("shape")
    x_shape_ = x.get("ori_shape")
    x_format = x.get("format")
    x_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")
    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1],
    # channel_axis_ need change to 1.
    if channel_axis == 0 and x_shape_[0] != min_shape[0] \
            and x_shape_[1] == min_shape[0]:
        channel_axis_ = 1
    else:
        channel_axis_ = channel_axis
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(x_shape)
    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
    util.check_tensor_shape_size(x_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = x_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    quant_min = 0
    quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
    res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
                                        quant_min, quant_max, symmetric,
                                        kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [input_data, min_data, max_data, res]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}

    te.lang.cce.cce_build_code(sch, config)
def check_param(input_x, axis, kernel_name):
    """
    check whether the parameters are valid; if one is invalid, raise an error

    Parameters
    ----------
    input_x: dict
        shape and datatype
    axis: int
        cumulative axis
    kernel_name: str
        kernel_name

    Returns
    -------
    None
    """
    input_shape = input_x.get("shape")
    input_dtype = input_x.get("dtype").lower()

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_tensor_shape_size(input_shape)
    check_dtype(input_dtype, ("float16", "float32"))

    if axis < len(input_shape) * (-1) or axis >= len(input_shape):
        raise RuntimeError("axis must be in the range [%d, %d), but is %d"
                           % (len(input_shape) * (-1), len(input_shape),
                              axis))
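# A quick worked example of the axis bound enforced above: for a 3D input
# the valid range is [-3, 3), so axis values -3..2 pass and axis=3 raises.
# The shape is an illustrative assumption.
def _axis_check_example():
    input_x = {"shape": (2, 3, 4), "dtype": "float16"}
    check_param(input_x, -1, "cum_demo")       # ok: -1 is within [-3, 3)
    # check_param(input_x, 3, "cum_demo")      # would raise RuntimeError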
def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
                           ema, ema_decay,
                           kernel_name="minmax_update_perlayer"):
    """MinMaxUpdatePerLayer op"""
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    min_shape = util.scalar2tensor_one(min_shape)
    max_shape = util.scalar2tensor_one(max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(min_shape, 1, 1, 1)
    util.check_shape_rule(max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res_list = minmax_update_perlayer_compute(input_data, min_data, max_data,
                                              ema, ema_decay)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)

    tensor_list = [input_data, min_data, max_data] + list(res_list)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}

    te.lang.cce.cce_build_code(sch, config)
def __init__(self, input_value, output_data, split_dim, num_split,
             kernel_name):
    """init SplitWith5HD parameters
    """
    self.data_dtype = input_value.get("dtype").lower()
    self.src_shape = input_value.get("shape")
    self.src_ori_shape = input_value.get("ori_shape")
    self.format = input_value.get("format")
    self.ori_format = input_value.get("ori_format")
    self.output_data = output_data
    self.src_size = util.check_tensor_shape_size(list(self.src_shape))
    self.dst_size = self.src_size // num_split
    self.split_dim = split_dim
    self.num_split = num_split
    self.kernel_name = kernel_name
    self.split_dim_size = self.src_shape[self.split_dim]
    self.split_output_dim_size = \
        self.src_shape[self.split_dim] // self.num_split
    self.data_size_first_dim = self.src_size // self.split_dim_size
    self.output_size = \
        self.split_output_dim_size * self.data_size_first_dim
    # get dtype size, float16 size = 2 bytes / float32 size = 4 bytes
    self.dtype_size = \
        tbe_platform.cce_intrin.get_bit_len(self.data_dtype) // 8
    # get one block data size, block align len
    # the number of elements in one block = 16 for fp16 and 8 for fp32
    self.data_len_one_block = 32 // self.dtype_size
    self.data_len_one_vector = self.data_len_one_block * 8
    self.gm_out = []
    self.gm_in = None
    self.tik_instance = None
    self.core_num = 0
    self.input_c0 = 16
    self.input_n = 0
    self.input_c = 0
    self.input_h = 0
    self.input_w = 0
    self.src_c1 = 0
    self.des_c1 = 0
    self.last_ori_dim = 0
    self.split_out_dim = 0
def _shape_and_dtype_check(x, y_grad, target, weight, total_weight,
                           reduction, kernel_name):
    x_shape = x.get("shape")
    x_dtype = x.get("dtype").lower()
    y_grad_shape = y_grad.get("shape")
    y_grad_dtype = y_grad.get("dtype").lower()
    target_shape = target.get("shape")
    target_dtype = target.get("dtype").lower()
    total_weight_shape = total_weight.get("shape")
    total_weight_dtype = total_weight.get("dtype").lower()
    weight_shape = weight.get("shape")
    weight_dtype = weight.get("dtype").lower()
    util.check_tensor_shape_size(weight_shape)
    util.check_shape_rule(weight_shape)
    util.check_shape_rule(x_shape)
    util.check_shape_rule(y_grad_shape)
    util.check_shape_rule(target_shape)
    util.check_tensor_shape_size(y_grad_shape)
    util.check_tensor_shape_size(target_shape)
    util.check_kernel_name(kernel_name)
    util.check_dtype_rule(x_dtype, "float32")
    util.check_dtype_rule(y_grad_dtype, "float32")
    util.check_dtype_rule(target_dtype, "int32")
    util.check_dtype_rule(weight_dtype, "float32")
    util.check_dtype_rule(total_weight_dtype, "float32")

    if reduction in ("mean", "sum") and y_grad_shape[0] != 1:
        raise RuntimeError("The shape of y_grad must be (1,),"
                           " while reduction is mean or sum.")
    if len(x_shape) == 1 and y_grad_shape[0] != 1:
        raise RuntimeError("The shape of y_grad must be (1,),"
                           " while input x is 1D.")
    if len(x_shape) > DIM2:
        raise RuntimeError("The dimension of x should be equal to "
                           "or less than two.")
    if len(x_shape) == DIM2 and x_shape[0] != target_shape[0]:
        raise RuntimeError("The first dimension of x and"
                           " target should be equal")
    if x_shape[-1] != weight_shape[0]:
        raise RuntimeError("The last dimension of x and the first dimension"
                           " of weight should be equal")
    if len(y_grad_shape) != 1:
        raise RuntimeError("The dimension of y_grad should be 1D.")
    if len(weight_shape) != 1:
        raise RuntimeError("The dimension of weight should be 1D.")
    if len(target_shape) != 1:
        raise RuntimeError("The dimension of target should be 1D.")
    if total_weight_shape[0] != 1:
        raise RuntimeError("The shape of total_weight must be (1,)")
def check_param(x, grad, argmax, y, ksize, strides, padding, dtype,
                dilation, ceil_mode, kernel_name):
    """
    check whether the parameters are valid; if one is invalid, raise an error

    Parameters
    ----------
    x: dict
        shape and datatype
    grad: dict
        shape and datatype
    argmax: dict
        shape and datatype
    y: dict
        shape and datatype
    ksize: kernel or window size, minimum length is 4,
           just like [1, poolingWindowH, poolingWindowW, 1]
    strides: stride, minimum length is 4,
             just like [1, poolingStrideH, poolingStrideW, 1]
    padding: pad mode

    Returns
    -------
    None
    """
    y_shape = x.get("shape")
    y_dtype = x.get("dtype").lower()
    y_dtype_arg = y.get("dtype").lower()
    input_gard_shape = grad.get("shape")
    grad_dtype = grad.get("dtype").lower()
    argmax_shape = argmax.get("shape")
    argmax_dtype = argmax.get("dtype").lower()
    util.check_shape_rule(y_shape)
    util.check_shape_rule(input_gard_shape)
    util.check_shape_rule(argmax_shape)
    util.check_kernel_name(kernel_name)
    check_shape_5hd(y_shape)
    check_shape_5hd(input_gard_shape)
    util.check_tensor_shape_size(input_gard_shape)
    util.check_tensor_shape_size(argmax_shape)
    util.check_tensor_shape_size(y_shape)
    util.check_dtype_rule(grad_dtype, ("float16", "float32", "int32"))
    util.check_dtype_rule(argmax_dtype, ("uint16",))
    util.check_dtype_rule(y_dtype, ("float16", "float32", "int32"))
    if y_dtype != grad_dtype or y_dtype_arg != y_dtype:
        raise RuntimeError("The dtype of tensor must be same")

    if dtype != DT_INT32 and dtype != DT_INT64:
        raise RuntimeError(
            "The dtype of input max indice must be int32 or int64")

    check_output_dim_with_ksize_stride(padding, input_gard_shape, y_shape,
                                       ksize, strides, dilation, ceil_mode)
def _shape_check(shape_x1, shape_x2, shape_tgt):
    # check whether the shapes meet the broadcast requirements,
    # and output the broadcast shape
    try:
        _, _, x_shape = util.produce_shapes(shape_x1, shape_x2)
    except RuntimeError:
        raise RuntimeError("x1 and x2 can't be broadcast")

    x_shape_reduce = x_shape[:]
    x_shape_reduce.pop(1)
    try:
        _, _, tgt_shape = util.produce_shapes(x_shape_reduce, shape_tgt)
    except RuntimeError:
        raise RuntimeError("x and target can't be broadcast")

    min_dim = min(len(shape_x1), len(shape_x2), len(shape_tgt))
    if min_dim >= 3:
        reduce_dim = -1
        for i in range(-1, -min_dim, -1):
            if (shape_x1[i] == shape_x2[i]) or (shape_x1[i] == shape_tgt[i]):
                reduce_dim = i
            else:
                break
        if reduce_dim != -1:
            shape_x1 = list(shape_x1[:reduce_dim]) + [
                reduce(lambda x, y: x * y, shape_x1[reduce_dim:])
            ]
            shape_x2 = list(shape_x2[:reduce_dim]) + [
                reduce(lambda x, y: x * y, shape_x2[reduce_dim:])
            ]
            shape_tgt = list(shape_tgt[:reduce_dim]) + [
                reduce(lambda x, y: x * y, shape_tgt[reduce_dim:])
            ]
            x_shape = list(x_shape[:reduce_dim]) + [
                reduce(lambda x, y: x * y, x_shape[reduce_dim:])
            ]
            tgt_shape = list(tgt_shape[:reduce_dim]) + [
                reduce(lambda x, y: x * y, tgt_shape[reduce_dim:])
            ]

    util.check_shape_rule(shape_x1)
    util.check_shape_rule(shape_x2)
    util.check_shape_rule(shape_tgt)
    util.check_tensor_shape_size(shape_x1)
    util.check_tensor_shape_size(shape_x2)
    util.check_tensor_shape_size(shape_tgt)

    return x_shape, tgt_shape, shape_x1, shape_x2, shape_tgt
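# A worked example of the trailing-dimension folding in _shape_check, under
# the assumption that util.produce_shapes implements standard broadcasting.
# With the shapes below the last two axes agree across all inputs, so they
# are folded into one axis before the final shape-rule checks run.
def _shape_check_example():
    shape_x1 = (2, 3, 4, 5)
    shape_x2 = (2, 3, 4, 5)
    shape_tgt = (2, 4, 5)  # x's axis 1 is reduced away before the target check
    # expected folded results: x1/x2 -> [2, 3, 20], tgt -> [2, 20]
    return _shape_check(shape_x1, shape_x2, shape_tgt)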
def histogram_fixed_width_d(x, range, y, nbins, dtype="int32",
                            kernel_name='histogram_fixed_width_d'):
    """this operation returns a rank 1 histogram counting
    the number of entries in `values` that fell into every bin.
    The bins are equal width and determined by the arguments
    `value_range` and `nbins`.

    Parameters
    ----------
    x: dict
        dict info of input value, must include the keys(shape and dtype).
    range: dict
        dict info of input value_range, must include the keys(shape and
        dtype). the shape must be (2,) or [2]
    y: dict
        dict info of output
    nbins: int
        number of histogram bins.
    dtype: str
        data type for returned histogram.
    kernel_name: str
        cce kernel name, default value is "histogram_fixed_width"

    Returns
    -------
    None
    """
    input_shape_list = [x.get("shape"), range.get("shape")]
    input_dtype = x.get("dtype")
    dtype_input = input_dtype.lower()

    check_shape(input_shape_list[0], param_name="x")
    check_shape(input_shape_list[1], param_name="range")
    util.compare_tensor_dict_key(x, range, "dtype")

    data_shape_size = util.check_tensor_shape_size(list(input_shape_list[0]))
    data_range_shape_size = util.check_tensor_shape_size(
        list(input_shape_list[1]))

    check_dtype(dtype_input, ("float16", "float32", "int32"), param_name="x")

    if data_range_shape_size != 2:
        raise RuntimeError("the shape of range must be (2,) or [2]")

    if nbins <= 0:
        raise RuntimeError("the nbins must be > 0")

    data = tvm.placeholder([data_shape_size], dtype=dtype_input,
                           name="input_data")
    range_data = tvm.placeholder([data_range_shape_size], dtype=dtype_input,
                                 name="input_range_data")

    res = histogram_fixed_width_d_compute(data, range_data, y, nbins,
                                          kernel_name)
    sch = tvm.create_schedule(res.op)
    with build_config:
        tvm.build(sch, [data, range_data, res], "cce", name=kernel_name)
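# A usage sketch for histogram_fixed_width_d: `range` holds the two-element
# [lo, hi] value range and nbins fixes the bin count. The shapes and bin
# count are illustrative assumptions; building requires the Ascend TBE
# toolchain.
def _histogram_fixed_width_example():
    x = {"shape": (1024,), "dtype": "float32"}
    value_range = {"shape": (2,), "dtype": "float32"}
    y = {"shape": (10,), "dtype": "int32"}
    histogram_fixed_width_d(x, value_range, y, nbins=10,
                            kernel_name="histogram_demo")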
def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
                   y, batch_mean, batch_std, running_mean, running_std,
                   mean_updated, variance_updated,
                   momentum=0.9, epsilon=1e-5, is_training=True,
                   freeze_bn=0, data_format="NCHW",
                   kernel_name="batchnorm_fold"):
    """batchnorm_fold TBE op"""
    momentum = 1.0 - momentum
    util.check_kernel_name(kernel_name)
    data_format = data_format.upper()
    if data_format != "NCHW":
        raise RuntimeError("The data_format only support NCHW")

    shape_x = x.get("shape")
    shape_mean = mean.get("shape")
    shape_variance = variance.get("shape")
    dtype_x = x.get("dtype")
    dtype_mean = mean.get("dtype")
    dtype_variance = variance.get("dtype")
    for shape in (shape_x, shape_mean, shape_variance):
        util.check_shape_rule(shape)
        util.check_tensor_shape_size(shape)
    check_tuple = ("float16", "float32")
    for dtype in (dtype_x, dtype_mean, dtype_variance):
        util.check_dtype_rule(dtype.lower(), check_tuple)

    format_data = x.get("format").upper()
    if format_data not in ("NCHW", "NC1HWC0"):
        raise RuntimeError("Format of input only support 4D and 5HD")

    if format_data == "NC1HWC0":
        if len(shape_x) != 5:
            raise RuntimeError("batchnorm_fold only support shape 5D "
                               "when input format is NC1HWC0")
        shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
    elif format_data == "NCHW":
        if len(shape_x) < 2 or len(shape_x) > 4:
            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
        if shape_x[1] != shape_mean[0]:
            raise RuntimeError("data_format is NCHW, shape_bias must "
                               "be equal to the second axis of shape_x")
        shape_mean = (1, shape_x[1],)
        for _ in range(2, len(shape_x)):
            shape_mean = shape_mean + (1,)

    x_input = tvm.placeholder(shape_x, name="x_input", dtype=dtype_x.lower())
    x_sum = tvm.placeholder(shape_mean, name="x_sum", dtype=dtype_x.lower())
    x_square_sum = tvm.placeholder(shape_mean, name="x_square_sum",
                                   dtype=dtype_x.lower())
    mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
    variance = tvm.placeholder(shape_mean, name="variance",
                               dtype=dtype_variance.lower())

    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
    num = shape_x[0] * shape_x[2] * shape_x[3]
    num_rec = 1.0 / num

    # compute the mean of x
    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)

    # compute the variance of x
    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)

    if num == 1:
        batch_var_scaler = 0.0
    else:
        batch_var_scaler = float(num) / (num - 1)
    batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))

    factor = 1.0 - momentum
    factor_reverse = momentum
    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)

    var_mul = te.lang.cce.vmuls(batch_variance, factor)
    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)

    y = te.lang.cce.vadds(x_input, 0.0)
    running_mean = te.lang.cce.vadds(mean, 0.0)
    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
    res = [y, batch_mean, batch_std, running_mean, running_std,
           mean_updated, variance_updated]

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"name": kernel_name,
              "tensor_list": [x_input, x_sum, x_square_sum, mean,
                              variance] + res}
    te.lang.cce.cce_build_code(sch, config)
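# A NumPy reference sketch of the statistics batchnorm_fold computes above,
# useful for checking the math on the host for the NCHW case. It mirrors the
# kernel: per-channel mean, biased variance as E[x^2] - E[x]^2, Bessel
# correction by N / (N - 1), and the running update with the flipped momentum
# (the kernel sets momentum = 1 - momentum on entry, so the batch statistic
# ends up weighted by the original momentum). numpy is an assumption here;
# the TBE kernel itself does not use it.
import numpy as np

def batchnorm_fold_reference(x, mean, variance, momentum=0.9, epsilon=1e-5):
    # x: NCHW array; mean/variance: per-channel running statistics
    num = x.shape[0] * x.shape[2] * x.shape[3]
    batch_mean = x.sum(axis=(0, 2, 3)) / num
    batch_var_biased = (x ** 2).sum(axis=(0, 2, 3)) / num - batch_mean ** 2
    scaler = float(num) / (num - 1) if num > 1 else 0.0
    batch_variance = batch_var_biased * scaler
    batch_std = np.sqrt(batch_variance + epsilon)
    mean_updated = momentum * batch_mean + (1.0 - momentum) * mean
    variance_updated = momentum * batch_variance + (1.0 - momentum) * variance
    running_std = np.sqrt(variance + epsilon)
    return (batch_mean, batch_std, mean, running_std,
            mean_updated, variance_updated)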