def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correction_mul"):
    """CorrectionMul op"""
    shape = x.get("shape")
    data_format = x.get("format")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if inp_dtype not in check_list:
        raise RuntimeError("Dtype of input only supports float16, float32")
    # shape = util.shape_refine(shape)

    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
    shape_c = [1] * len(shape)
    shape_c[channel] = batch_std.get("ori_shape")[0]
    if data_format == "NC1HWC0" and channel == 1:
        shape_c = batch_std.get("shape")
    batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
    running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
    res = correction_mul_compute(x_t, batch_std_t, running_std_t, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": [x_t, batch_std_t, running_std_t, res]}
    te.lang.cce.cce_build_code(sch, config)
def custom_l2_loss(shape, dtype, kernel_name="cce_tf_l2_loss", need_build=False, need_print=False):
    """
    Computes half the L2 norm of a tensor without the sqrt:

        output = sum(t ** 2) / 2

    Parameters
    ----------
    shape : shape of data
    dtype : source data type, only support float16, float32
    kernel_name : cce kernel name, default value is "cce_tf_l2_loss"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    util.check_reduce_shape_rule(shape)
    check_list = ["float16", "float32"]
    if dtype.lower() not in check_list:
        raise RuntimeError("tf_l2_loss_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    shape, axis = util.simplify_axis_shape(shape, range(len(shape)))
    inp_dtype = dtype.lower()
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)

    # scale by 1/sqrt(2) before squaring, so the squared sum already carries the 1/2 factor
    coeff_sqrt = tvm.const(1.0 / (2 ** 0.5), dtype=inp_dtype)
    data_mul = te.lang.cce.vmuls(data_input, coeff_sqrt)
    data_sqr = te.lang.cce.vmul(data_mul, data_mul)
    res = te.lang.cce.sum(data_sqr, axis)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"print_ir": need_print,
              "need_build": need_build,
              "name": kernel_name,
              "tensor_list": [data_input, res]}
    te.lang.cce.cce_build_code(sch, config)
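# Reference sketch (illustrative only, not part of the op): the kernel above folds the
# 1/2 factor into a pre-scale, since sum((x / sqrt(2)) ** 2) == sum(x ** 2) / 2,
# which saves a separate multiply-by-0.5 pass. A minimal NumPy check of the same math:
def _l2_loss_numpy_reference(x):
    import numpy as np
    coeff_sqrt = 1.0 / np.sqrt(2.0)
    scaled = x * coeff_sqrt            # mirrors te.lang.cce.vmuls
    return np.sum(scaled * scaled)     # equals np.sum(x ** 2) / 2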
def custom_sign(shape, dtype, kernel_name="cce_custom_sign", need_build=False, need_print=False):
    """
    algorithm:
                              x * 32768
        sign = round( ------------------------- )
                      2 ** (-15) + |x * 32768|

    calculating data type is float16

    Parameters
    ----------
    shape : shape of data
    dtype : the data type, assume src_dtype equals dst_dtype,
            only support float16, float32, int32
    kernel_name : cce kernel name, default value is "cce_custom_sign"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32", "int32"]
    if dtype.lower() not in check_list:
        raise RuntimeError("custom_sign_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    shape = util.shape_refine(shape)
    inp_dtype = dtype.lower()
    data = tvm.placeholder(shape, name="data", dtype=inp_dtype)

    with tvm.target.cce():
        res = custom_sign_compute([data], shape, dtype, kernel_name, need_build, need_print)
        sch = generic.auto_schedule(res)

    config = {"print_ir": need_print,
              "need_build": need_build,
              "name": kernel_name,
              "tensor_list": [data, res]}
    te.lang.cce.cce_build_code(sch, config)
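# Reference sketch (illustrative only): the docstring formula approximates sign(x) by
# scaling x so that any non-zero input dominates the tiny 2**(-15) term in the
# denominator, pushing the quotient toward +/-1 before rounding. In NumPy:
def _sign_numpy_reference(x):
    import numpy as np
    scaled = x * 32768.0
    return np.round(scaled / (2.0 ** (-15) + np.abs(scaled)))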
def correction_mul_grad(dout, x, batch_std, running_std, dx, mul_dx, channel, kernel_name="correction_mul_grad"):
    """CorrectionMulGrad op"""
    shape_dout = dout.get("shape")
    shape_x = x.get("shape")
    dtype_dout = dout.get("dtype")
    dtype_x = x.get("dtype")
    dtype_batch_std = batch_std.get("dtype")
    dtype_running_std = running_std.get("dtype")
    inp_dtype_dout = dtype_dout.lower()
    inp_dtype_x = dtype_x.lower()
    inp_dtype_batch_std = dtype_batch_std.lower()
    inp_dtype_running_std = dtype_running_std.lower()
    util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
    util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
    util.check_dtype_rule(inp_dtype_batch_std, ("float16", "float32"))
    util.check_dtype_rule(inp_dtype_running_std, ("float16", "float32"))
    util.compare_tensor_dict_key(dout, x, "dtype")
    util.compare_tensor_dict_key(dout, x, "shape")
    util.compare_tensor_dict_key(dx, x, "shape")
    util.compare_tensor_dict_key(batch_std, running_std, "shape")
    util.compare_tensor_dict_key(dx, mul_dx, "shape")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    data_format = dout.get("format")
    ori_format = dout.get("ori_format")
    if data_format.upper() not in ("NC1HWC0", "NCHW"):
        raise RuntimeError("Unsupported data format {}".format(data_format))
    if data_format.upper() == "NCHW" and ori_format != "NCHW":
        raise RuntimeError("data_format(NCHW) must be the same as ori_format")

    shape_c = [1] * len(shape_x)
    shape_c[channel] = batch_std.get("ori_shape")[0]
    if data_format == "NC1HWC0" and channel == 1:
        shape_c = batch_std.get("shape")

    dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
    x_t = tvm.placeholder(shape_x, name="x", dtype=inp_dtype_x)
    batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype_batch_std)
    running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype_running_std)
    res_list = correction_mul_grad_compute(dout_t, x_t, batch_std_t, running_std_t,
                                           channel, data_format, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)
    tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + res_list
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(sch, config)
def custom_logical_not(shape, dtype, kernel_name="cce_tf_logical_not", need_build=False, need_print=False):
    """
    logical not for the input tensor

    Parameters
    ----------
    shape : input shape of data
    dtype : the data type, support bool
    kernel_name : cce kernel name, default value is "cce_tf_logical_not"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    check_list = ["bool"]
    if dtype.lower() not in check_list:
        raise RuntimeError("logical_not_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    inp_dtype = dtype.lower()
    data = tvm.placeholder(shape, name="data", dtype=inp_dtype)

    with tvm.target.cce():
        # use a tvm equality expression as the condition: a Python "is True" check on a
        # tvm Expr is evaluated at trace time and would make the select condition constant
        result = tvm.compute(
            shape,
            lambda *i: tvm.select(data[i] == True, False, True),
            name="result")
        schedule = tvm.create_schedule(result.op)

    if need_print:
        with build_config:
            print(tvm.lower(schedule, [data, result], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(schedule, [data, result], "cce", name=kernel_name)
def segment_min(input_tensor, segment_ids, output_y, kernel_name="segment_min"):
    """
    calculating data

    Parameters
    ----------
    input_tensor : dict
        shape and dtype of input
    segment_ids : list int
        the list of segment_ids
    output_y : dict
        shape and dtype of output
    kernel_name : str
        kernel name, default value is "segment_min"

    Returns
    -------
    None
    """
    shape_tensor = input_tensor.get("shape")
    dtype_tensor = input_tensor.get("dtype")
    input_tensor_dtype = dtype_tensor.lower()

    # validate the segment ids
    length_ids = len(segment_ids)
    if length_ids != shape_tensor[0]:
        raise RuntimeError("length of ids must equal shape[0] of input_tensor!")
    ids_is_1d_and_sorted(segment_ids)

    check_tuple_tensor = ("float16", "float32", "int32", "int8", "uint8")
    util.check_dtype_rule(dtype_tensor, check_tuple_tensor)
    util.check_shape_size(shape_tensor, SHAPE_SIZE_LIMIT)
    util.check_shape_rule(shape_tensor)  # check the axis

    if dtype_tensor == "int8":
        # compute in float16 for int8 input, then cast the result back
        data_input = tvm.placeholder(shape_tensor, name="data_input", dtype=input_tensor_dtype)
        data_input1 = te.lang.cce.cast_to(data_input, "float16")
        res1 = segment_min_compute(data_input1, segment_ids, output_y, kernel_name)
        res = te.lang.cce.cast_to(res1, "int8")
    else:
        data_input = tvm.placeholder(shape_tensor, name="data_input", dtype=input_tensor_dtype)
        res = segment_min_compute(data_input, segment_ids, output_y, kernel_name)

    with tvm.target.cce():
        schedule = generic.auto_schedule(res)
    config = {"name": kernel_name, "tensor_list": [data_input, res]}
    te.lang.cce.cce_build_code(schedule, config)
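# Reference sketch (illustrative only): for sorted segment_ids, each segment is a
# contiguous run of rows reduced with min. A NumPy equivalent using reduceat, assuming
# the ids cover 0..max(segment_ids) without gaps:
def _segment_min_numpy_reference(data, segment_ids):
    import numpy as np
    ids = np.asarray(segment_ids)
    # index of the first row of each segment (run starts in the sorted id list)
    starts = np.flatnonzero(np.r_[True, ids[1:] != ids[:-1]])
    return np.minimum.reduceat(data, starts, axis=0)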
def batchnorm_fold2(x, beta, gamma, batch_std, batch_mean, running_std, y, kernel_name="batchnorm_fold2"):
    """_BatchNormFold2 op"""
    shape = x.get("shape")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if inp_dtype not in check_list:
        raise RuntimeError("Dtype of input only supports float16, float32")
    data_format = x.get("format")
    ori_format = x.get("ori_format")
    if data_format.upper() not in ("NC1HWC0", "NCHW"):
        raise RuntimeError("Unsupported data format {}".format(data_format))
    if data_format.upper() == "NCHW" and ori_format != "NCHW":
        raise RuntimeError("data_format(NCHW) must be the same as ori_format")

    shape_c = gamma.get("shape")
    if gamma.get("format").upper() == "NCHW":
        shape_c = 1, gamma.get("shape")[0], 1, 1

    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
    beta_t = tvm.placeholder(shape_c, name="beta", dtype=inp_dtype)
    gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
    batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
    batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
    running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
    res = batchnorm_fold2_compute(x_t, beta_t, gamma_t, batch_std_t, batch_mean_t,
                                  running_std_t, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": [x_t, beta_t, gamma_t, batch_std_t, batch_mean_t, running_std_t, res]}
    te.lang.cce.cce_build_code(sch, config)
def segment_max_d(x, y, segment_ids, kernel_name="segment_max_d"):
    """
    Operation and Schedule for segment_max

    Parameters
    ----------
    x : dict
        shape and dtype of input
    y : dict
        shape and dtype of output
    segment_ids : list
        should be the size of the first dimension
    kernel_name : str
        kernel name, default value is "segment_max_d"

    Returns
    -------
    None
    """
    shape = x.get("shape")
    dtype = x.get("dtype")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32", "int32"]
    if dtype.lower() not in check_list:
        raise RuntimeError("segment_max only supports float16, float32, int32")

    # when shape[0] > first_dim_size_threshold, the default stack space
    # may not be enough, so prompt the user
    if shape[0] > FIRST_DIM_SIZE_THRESHOLD:
        print("Default stack space may not be enough. "
              "You should increase the stack space.")

    dtype = dtype.lower()
    _check_segment_ids(shape, segment_ids)
    input_data = tvm.placeholder(shape, name="input_data", dtype=dtype)

    with tvm.target.cce():
        res = segment_max_d_compute(input_data, y, segment_ids, kernel_name)
        sch = generic.auto_schedule(res)
    config = {"name": kernel_name, "tensor_list": [input_data, res]}
    te.lang.cce.cce_build_code(sch, config)
def correction_mul_grad_reduce(mul_dx, d_batch_std, channel, kernel_name="correction_mul_grad_reduce"):
    """CorrectionMulGradReduce op"""
    shape_dout = mul_dx.get("shape")
    shape_x = mul_dx.get("shape")
    dtype_dout = mul_dx.get("dtype")
    inp_dtype_dout = dtype_dout.lower()
    util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    data_format = mul_dx.get("format")
    ori_format = mul_dx.get("ori_format")
    if data_format.upper() not in ("NC1HWC0", "NCHW"):
        raise RuntimeError("Unsupported data format {}".format(data_format))
    if data_format.upper() == "NCHW" and ori_format != "NCHW":
        raise RuntimeError("data_format(NCHW) must be the same as ori_format")

    shape_c = [1] * len(shape_x)
    shape_c[channel] = d_batch_std.get("ori_shape")[0]
    if data_format == "NC1HWC0" and channel == 1:
        shape_c = d_batch_std.get("shape")

    dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
    res = correction_mul_grad_reduce_compute(dout_t, channel, data_format, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    tensor_list = [dout_t, res]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(sch, config)
def strided_slice_two_turn_one(input_x, output_x, kernel_name):
    """
    StridedSlice on the last dimension.

    Returns
    -------
    the result of StridedSliceLastDim.strided_slice_compute()
    """
    input_shape = input_x.get("shape")
    input_dtype = input_x.get("dtype").lower()
    check_list = ("float16", "float32")
    util.check_dtype_rule(input_dtype, check_list)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_size(input_shape, SHAPE_SIZE_LIMIT)
    ss_last_dim = StridedSliceLastDim(input_x, output_x, kernel_name)
    return ss_last_dim.strided_slice_compute()
def custom_negative(shape, dtype, kernel_name="cce_custom_negative", need_build=False, need_print=False):
    """
    calculate y = -x, calculating data type is float16

    Parameters
    ----------
    shape : shape of data
    dtype : the data type, assume src_dtype equals dst_dtype,
            only support float16, float32, int32
    kernel_name : cce kernel name, default value is "cce_custom_negative"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32", "int32"]
    if dtype.lower() not in check_list:
        raise RuntimeError("custom_negative_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    caffe2_negative.caffe2_negative_cce(shape, dtype, kernel_name=kernel_name,
                                        need_build=need_build, need_print=need_print)
def mul_no_nan_compute(input_x1, input_x2, output_y, kernel_name="mul_no_nan"):
    """
    calculating mul_no_nan: x * y, but 0 wherever y == 0, i.e.
    np.where(np.equal(y, 0.), np.zeros((), dtype=dtype), np.multiply(x, y))

    Parameters
    ----------
    input_x1 : TVM tensor
        the placeholder of input_x1
    input_x2 : TVM tensor
        the placeholder of input_x2
    output_y : dict
        dict of output_y, include keys(shape and dtype)
    kernel_name : str
        kernel name, default value is "mul_no_nan"

    Returns
    -------
    output tensor
    """
    src_dtype = input_x1.dtype.lower()
    shape_x1 = te.lang.cce.util.shape_to_list(input_x1.shape)
    shape_x2 = te.lang.cce.util.shape_to_list(input_x2.shape)
    shape_x1, shape_x2, shape_max = util.produce_shapes(shape_x1, shape_x2)
    util.check_shape_size(shape_max, SHAPE_SIZE_LIMIT)
    input_x1 = te.lang.cce.broadcast(input_x1, shape_max)
    input_x2 = te.lang.cce.broadcast(input_x2, shape_max)

    mul_res = te.lang.cce.vmul(input_x1, input_x2)
    zero = tvm.const(0, dtype=src_dtype)
    zeros = te.lang.cce.broadcast(zero, shape_max)
    # select 0 where input_x2 == 0, otherwise the product
    res = te.lang.cce.vcmpsel(input_x2, zeros, operation='eq', slhs=zeros, srhs=mul_res)
    return res
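# Reference semantics (illustrative only): the vcmpsel above implements the NumPy
# expression quoted in the docstring, returning 0 wherever y == 0 even if x is inf/NaN:
def _mul_no_nan_numpy_reference(x, y):
    import numpy as np
    return np.where(np.equal(y, 0.0), np.zeros_like(y), np.multiply(x, y))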
def custom_subtract(shape_x, shape_y, dtype, kernel_name="cce_subtract", need_build=True, need_print=True):
    """
    do element-wise subtract operation between two input tensors

    Parameters:
    ----------
    shape_x : shape of input data1
    shape_y : shape of input data2
    dtype : source data type, support float16, float32, int32
    kernel_name : cce kernel name, default value is "cce_subtract"
    need_build : if need to build CCEC kernel, default value is True
    need_print : if need to print the ir, default value is True

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_rule(shape_y)
    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_y, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32", "int32"]
    dtype = dtype.lower()
    if dtype not in check_list:
        raise RuntimeError("tf_subtract_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    shape_x, shape_y, shape_max = util.produce_shapes(shape_x, shape_y)
    util.check_shape_size(shape_max, SHAPE_SIZE_LIMIT)
    data1 = tvm.placeholder(shape_x, dtype=dtype, name="data1")
    data2 = tvm.placeholder(shape_y, dtype=dtype, name="data2")

    with tvm.target.cce():
        data1_tmp1 = te.lang.cce.broadcast(data1, shape_max)
        data2_tmp1 = te.lang.cce.broadcast(data2, shape_max)
        res = te.lang.cce.vsub(data1_tmp1, data2_tmp1)
        sch = generic.auto_schedule(res)

    config = {"print_ir": need_print,
              "need_build": need_build,
              "name": kernel_name,
              "tensor_list": [data1, data2, res]}
    te.lang.cce.cce_build_code(sch, config)
def custom_equal(shape_x, shape_y, dtype, kernel_name="cce_tf_equal", need_build=False, need_print=False):
    """
    do element-wise equal operation between two input tensors

    Parameters:
    ----------
    shape_x : shape of input x
    shape_y : shape of input y
    dtype : source data type, support float16, float32, int32, int8, uint8, bool
    kernel_name : cce kernel name, default value is "cce_tf_equal"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_rule(shape_y)
    check_list = ["float16", "float32", "int32", "int8", "uint8", "bool"]
    dtype = dtype.lower()
    if dtype not in check_list:
        raise RuntimeError("tf_equal_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_y, SHAPE_SIZE_LIMIT)
    shape_x, shape_y, shape_max = util.produce_shapes(shape_x, shape_y)
    util.check_shape_size(shape_max, SHAPE_SIZE_LIMIT)

    x = tvm.placeholder(shape_x, dtype=dtype, name="x")
    y = tvm.placeholder(shape_y, dtype=dtype, name="y")
    x_tmp = te.lang.cce.broadcast(x, shape_max)
    y_tmp = te.lang.cce.broadcast(y, shape_max)
    res = tvm.compute(shape_max, lambda *i: x_tmp(*i) == y_tmp(*i), name='res')
    sch = tvm.create_schedule(res.op)

    if need_print:
        with build_config:
            print(tvm.lower(sch, [x, y, res], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(sch, [x, y, res], "cce", name=kernel_name)
def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name="batchnorm_fold2_grad_reduce"):
    """_BatchNormFold2GradReduce op"""
    shape = x.get("shape")
    x_format = x.get("format")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if inp_dtype not in check_list:
        raise RuntimeError("Dtype of input only supports float16, float32")

    dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
    res_list = batchnorm_fold2_grad_reduce_compute(dout_t, x_t, dout, kernel_name)

    if x_format == "NC1HWC0":
        with tvm.target.cce():
            sch = generic.auto_schedule(res_list)
        tensor_list = [dout_t, x_t] + list(res_list)
        config = {"print_ir": False,
                  "name": kernel_name,
                  "tensor_list": tensor_list}
        te.lang.cce.cce_build_code(sch, config)
        return

    # other formats go through the ND training-reduce schedule
    from impl.bn_training_reduce import bn_training_reduce_schedule_nd
    sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
    with build_config:
        tvm.build(sch, tensor_list, "cce", name=kernel_name)
def check_param_common(self):
    """
    Check parameters

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    util.check_kernel_name(self.kernel_name)
    util.check_shape_rule(self.indices_shape)
    util.check_shape_rule(self.grad_shape)
    util.check_shape_size(self.indices_shape, SHAPE_SIZE_LIMIT)
    util.check_shape_size(self.grad_shape, SHAPE_SIZE_LIMIT)
    check_list_indices_dtype = ("int32", "int64")
    util.check_dtype_rule(self.indices_dtype, check_list_indices_dtype)
    # note the trailing comma: ("float32") would be a plain string, not a tuple
    util.check_dtype_rule(self.grad_dtype, ("float32",))

    if self.grad_shape[1:] != self.var_shape[1:]:
        raise RuntimeError("grad's shape must be the same as var's shape"
                           " except first dimension")
    if len(self.indices_shape) != 1:
        raise RuntimeError("indices must be one-dimensional")
    if self.grad_shape[0] != self.indices_shape[0]:
        raise RuntimeError("grad must be the same shape as indices in "
                           "first dimension")
def addcdiv(x1, x2, x3, y=None, alpha=1.0, kernel_name="addcdiv"):
    check_list = ("float16", "float32")
    shape_x1 = x1.get("shape")
    dtype_x1 = x1.get("dtype").lower()
    shape_x2 = x2.get("shape")
    dtype_x2 = x2.get("dtype").lower()
    shape_x3 = x3.get("shape")
    dtype_x3 = x3.get("dtype").lower()

    util.check_shape_rule(shape_x1)  # check the shape: the rank must be within [1, 8]
    util.check_shape_size(shape_x1, SHAPE_SIZE_LIMIT)  # check the size of the first input shape
    util.check_dtype_rule(dtype_x1, check_list)  # check the input data type
    util.check_shape_rule(shape_x2)
    util.check_shape_size(shape_x2, SHAPE_SIZE_LIMIT)
    util.check_dtype_rule(dtype_x2, check_list)
    util.check_shape_rule(shape_x3)
    util.check_shape_size(shape_x3, SHAPE_SIZE_LIMIT)
    util.check_dtype_rule(dtype_x3, check_list)

    if dtype_x1 != dtype_x2 or dtype_x1 != dtype_x3:
        raise RuntimeError("the type of x1, x2, x3 must be the same!")
    util.check_kernel_name(kernel_name)  # check the kernel name

    # take the larger extent of each dimension across shape_x1, shape_x2, shape_x3 as shape_max
    shape_x2, shape_x3, shape_max = broadcast_shapes(shape_x2, shape_x3)
    util.check_tensor_shape_size(shape_max)  # validate shape_max
    shape_x1, _, shape_max = broadcast_shapes(shape_x1, shape_max)
    util.check_tensor_shape_size(shape_max)  # validate shape_max
    shape_x2, _, _ = broadcast_shapes(shape_x2, shape_max)  # broadcast shape_x2 to shape_max
    shape_x3, _, _ = broadcast_shapes(shape_x3, shape_max)  # broadcast shape_x3 to shape_max

    data_x1 = tvm.placeholder(shape_x1, name="data_x1", dtype=dtype_x1)
    data_x2 = tvm.placeholder(shape_x2, name="data_x2", dtype=dtype_x2)
    data_x3 = tvm.placeholder(shape_x3, name="data_x3", dtype=dtype_x3)
    res = addcdiv_compute(data_x1, data_x2, data_x3, shape_max, alpha, kernel_name)

    with tvm.target.cce():
        schedule = generic.auto_schedule(res)
    config = {"name": kernel_name,
              "tensor_list": [data_x1, data_x2, data_x3, res]}
    te.lang.cce.cce_build_code(schedule, config)
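# Reference sketch (assumption: addcdiv_compute, defined elsewhere, follows the
# conventional addcdiv definition y = x1 + alpha * x2 / x3, with all inputs broadcast
# to shape_max; the compute body is not shown in this file, so this is illustrative):
def _addcdiv_numpy_reference(x1, x2, x3, alpha=1.0):
    import numpy as np
    return x1 + alpha * np.divide(x2, x3)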
def fused_minimum_or_maximum_grad_cce(shape_dz, shape_x, shape_y, grad_x=True, grad_y=True,
                                      cmp_type="LE", dtype="float32",
                                      kernel_name="cce_fused_minimum_or_maximum_grad",
                                      need_build=False, need_print=False):
    """
    algorithm:
    calculating minimum or maximum_grad of the two input data

    Parameters
    ----------
    shape_dz: list or tuple.
        shape of data_inputdz
    shape_x: list or tuple.
        shape of data_inputx
    shape_y: list or tuple.
        shape of data_inputy
    grad_x: bool
        if grad_x is true, output needs to return dx
    grad_y: bool
        if grad_y is true, output needs to return dy
    cmp_type: str
        LessEqual or GreaterEqual
    dtype: str
        the data type, assume src_dtype equals dst_dtype,
        only support float16, float32, int32
    kernel_name: str
        cce kernel name, default value is "cce_fused_minimum_or_maximum_grad"
    need_build: bool
        if need to build CCEC kernel, default value is False
    need_print: bool
        if need to print the ir, default value is False

    Returns:
    -------
    none.
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_rule(shape_y)
    shape_x, shape_y, shape_max = util.produce_shapes(shape_x, shape_y)
    util.check_shape_rule(shape_max)
    util.check_shape_size(shape_max, SHAPE_SIZE_LIMIT)
    if list(shape_dz) != list(shape_max):
        raise RuntimeError("fused_minimum_or_maximum_grad_cce shape_dz != shape_max")

    dtype = dtype.lower()
    if dtype not in ["float16", "float32", "int32"]:
        raise RuntimeError("fused_minimum_or_maximum_grad_cce only supports"
                           " float16, float32, int32")
    if (grad_x, grad_y) == (False, False):
        raise RuntimeError("at least one of grad_x and grad_y must be true")

    placeholders = []
    placeholders.append(tvm.placeholder(shape_dz, name="input_dz", dtype=dtype))
    placeholders.append(tvm.placeholder(shape_x, name="input_x", dtype=dtype))
    placeholders.append(tvm.placeholder(shape_y, name="input_y", dtype=dtype))
    outs = fused_minimum_or_maximum_grad_compute(placeholders, shape_x, shape_y,
                                                 shape_dz, cmp_type, dtype)

    with tvm.target.cce():
        if (grad_x, grad_y) == (True, False):
            sch = generic.auto_schedule(outs[0])
            outs = [outs[0]]
        if (grad_x, grad_y) == (False, True):
            sch = generic.auto_schedule(outs[1])
            outs = [outs[1]]
        if (grad_x, grad_y) == (True, True):
            sch = generic.auto_schedule(outs)

    config = {"print_ir": need_print,
              "need_build": need_build,
              "name": kernel_name,
              "tensor_list": placeholders + outs}
    te.lang.cce.cce_build_code(sch, config)
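# Reference sketch (assumption: the compute routes the incoming gradient dz to the
# input selected by the comparison, the usual minimum/maximum backward rule, with
# cmp_type "LE" for minimum_grad and "GE" for maximum_grad; illustrative only):
def _min_or_max_grad_numpy_reference(dz, x, y, cmp_type="LE"):
    import numpy as np
    mask = (x <= y) if cmp_type == "LE" else (x >= y)
    dx = dz * mask       # gradient flows to x where x "wins" the comparison
    dy = dz * (~mask)    # the remaining gradient flows to y
    return dx, dy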
def custom_round(shape, dtype, kernel_name="cce_round", need_build=False, need_print=False):
    """
    doing round operations, calculating data type is float16 or float32 or int32

    Parameters
    ----------
    shape : shape of data
    dtype : the data type, assume src_dtype equals dst_dtype
    kernel_name : cce kernel name, default value is "cce_round"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    check_list = ["float16", "float32", "int32"]
    device_api_map = {"float16": "cc_device_round_float16",
                      "float32": "cc_device_round_float",
                      "int32": "cc_device_round_int32"}

    max_dim = 8
    shape_len = len(shape)
    if shape_len > max_dim:
        raise RuntimeError(
            "round_cce only supports up to %d dimensions while the shape's dimension is %d"
            % (max_dim, shape_len))

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    if dtype.lower() not in check_list:
        raise RuntimeError("round_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)

    device_api = device_api_map[inp_dtype]
    block_num = "block_num"
    block_idx = "block_idx"
    v_ndim = tvm.const(len(shape), "int32")
    pad_c0 = tvm.const(0, "int32")
    p_shape = util.create_param_ptr(shape, "int32", "p_shape")

    output = tvm.extern(
        shape, [data_input, p_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t", device_api,
            block_num,
            block_idx,
            v_ndim,
            ins[1].access_ptr("r"),  # shape
            pad_c0,
            ins[0].access_ptr("r"),  # input x
            outs[0].access_ptr("w")),
        name="output", dtype=inp_dtype)

    s = tvm.create_schedule(output.op)
    if need_print:
        with build_config:
            print(tvm.lower(s, [data_input, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(s, [data_input, output], "cce", name=kernel_name)
def custom_pow(shape, shape_y, dtype, kernel_name="cce_tf_pow", need_build=False, need_print=False):
    """
    calculate x^y, calculating data type is float16 or float32 or int32;
    when x < 0, the output is a meaningless value.

    Parameters
    ----------
    shape : shape of data x
    shape_y : shape of data y
    dtype : the data type, assume src_dtype equals dst_dtype,
            only support float16, float32, int32
    kernel_name : cce kernel name, default value is "cce_tf_pow"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    supported_dtypes = ["float16", "float32", "int32"]
    device_api = "cc_device_pow"

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    if dtype.lower() not in supported_dtypes:
        raise RuntimeError("tf_pow_cce only supports %s while dtype is %s"
                           % (",".join(supported_dtypes), dtype))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_lhs = tvm.placeholder(shape, name="data_lhs", dtype=inp_dtype)
    data_rhs = tvm.placeholder(shape, name="data_rhs", dtype=inp_dtype)

    v_datatype = util.get_device_api_dtype(inp_dtype)
    v_ndim = len(shape)
    block_num = "block_num"
    block_idx = "block_idx"
    pad_c0 = 0
    p_scale = util.create_param_ptr([0], inp_dtype, "p_scale")
    p_shift = util.create_param_ptr([0], inp_dtype, "p_shift")
    p_power = util.create_param_ptr([0], inp_dtype, "p_power")
    p_shape = util.create_param_ptr(shape, "int32", "p_shape")

    output = tvm.extern(
        shape, [data_lhs, data_rhs, p_scale, p_shift, p_power, p_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t", device_api,
            block_num,
            block_idx,
            v_datatype,
            ins[2].access_ptr("r"),  # scale
            ins[3].access_ptr("r"),  # shift
            ins[4].access_ptr("r"),  # power
            v_ndim,
            ins[5].access_ptr("r"),  # shape
            pad_c0,
            ins[0].access_ptr("r"),  # input x
            v_ndim,
            v_ndim,
            ins[5].access_ptr("r"),  # shape
            pad_c0,
            ins[1].access_ptr("r"),  # input y
            outs[0].access_ptr("w")),
        name="output", dtype=inp_dtype)

    schedule = tvm.create_schedule(output.op)
    if need_print:
        with build_config:
            print(tvm.lower(schedule, [data_lhs, data_rhs, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(schedule, [data_lhs, data_rhs, output], "cce", name=kernel_name)
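# Reference semantics (illustrative only): the device API computes an element-wise
# power; as the docstring notes, results for x < 0 are not meaningful:
def _pow_numpy_reference(x, y):
    import numpy as np
    return np.power(x, y)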
def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
                           kernel_name="matmulcube"):
    """
    calculating matrix multiplication with bias, C = A*B + bias,
    support input data with fractal format.

    Parameters:
    shape_a: list or tuple
        Shape of the first tensor a with rank > 1
    shape_b: list or tuple
        Shape of the second tensor b with the same type as a,
        and shape_a, shape_b must be 2 dims
    src_dtype: str
        The data type of input, support "float32", "float16"
    dst_dtype: str
        The data type of output, support "float32", "float16"
    trans_a: bool
        If True, shape_a is transposed before multiplication
    trans_b: bool
        If True, shape_b is transposed before multiplication
    is_fractal: bool
        If True, the input data format of a and b must be fractal format
    shape_bias: list or tuple
        Shape of bias, only support the input data format with ND

    Returns
    -------
    None
    """
    shape_a = input_x1.get("ori_shape")
    shape_b = input_x2.get("ori_shape")
    shape_output = output_y.get("ori_shape")
    print(input_x1.get("format"), input_x2.get("format"))
    print(shape_a, shape_b)

    if input_x2.get("format") == "FRACTAL_Z":
        n, c, h, w = shape_b
        c0 = 16
        c1 = c // c0
        if c1 == 0:
            c1 = 1
        shape_b = [n, c1 * h * w * c0]
        shape_a = [n, n]
    if input_x1.get("format") == "FRACTAL_Z":
        n, c, h, w = shape_a
        c0 = 16
        c1 = c // c0
        if c1 == 0:
            c1 = 1
        shape_a = [n, c1 * h * w * c0]
        shape_b = [c1 * h * w * c0, c1 * h * w * c0]
    if input_x2.get("format") == "FRACTAL_NZ":
        shape_a = [shape_b[0], shape_b[0]]
    if input_x1.get("format") == "FRACTAL_NZ":
        shape_b = [shape_a[1], shape_a[1]]

    shape_a = list(shape_a)
    shape_b = list(shape_b)
    shape_a = _get_input_shape(shape_a)
    shape_b = _get_input_shape(shape_b)

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_a)
    util.check_shape_rule(shape_b)
    util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)

    shape_a = [shape_a[1], shape_a[0]]
    trans_a = bool(1 - trans_a)
    shape_b = [shape_b[1], shape_b[0]]
    trans_b = bool(1 - trans_b)

    shape_bias = ()
    if bias is not None and bool(bias):
        shape_bias = bias.get("shape")
        shape_bias = list(shape_bias)
        shape_bias = _get_bias(shape_bias)

    src_dtype = input_x1.get("dtype").lower()
    dst_dtype = output_y.get("dtype").lower()
    _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)

    m_shape = shape_a[len(shape_a) - 2]
    km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_a) - 2]
    n_shape = shape_b[len(shape_a) - 1]

    if src_dtype == "float16":
        block_reduce = cce.BLOCK_REDUCE

    block_in = cce.BLOCK_IN
    block_out = cce.BLOCK_OUT
    if trans_a and km_shape == 1:
        block_in = cce.BLOCK_VECTOR
    if not trans_a and m_shape == 1:
        block_in = cce.BLOCK_VECTOR
    if trans_b and kn_shape == 1:
        block_out = cce.BLOCK_VECTOR
    if not trans_b and n_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if trans_a:
        shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in)
    else:
        shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce)
    if trans_b:
        shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out)
    else:
        shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce)

    shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
    format_a = "FRACTAL_NZ"
    shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
    format_b = "FRACTAL_NZ"
    print(shape_a_temp, shape_b_temp)
    print(format_a, format_b)

    tensor_bias = None
    tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', dtype=src_dtype)
    tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', dtype=src_dtype)
    if shape_bias:
        tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias', dtype=dst_dtype)

    if shape_a_temp[0] == 63 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 63:
        if util.get_product_version() == util.VERSION_MINI:
            tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
        else:
            tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))

        input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm)
        input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm)
        resMatmul = tik_instance.Tensor("float16", shape_output, name="output", scope=tik.scope_gm)
        with tik_instance.for_range(0, 32, block_num=32) as block_index:
            resMatmul_local_UB = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_ubuf,
                                                     name="resMatmul_local_UB")
            resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (128 * 256,), scope=tik.scope_cc,
                                                               name="resMatmul_local_UB")
            input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_ca,
                                                             name="input_1_local_L1_local_L0A")
            input_2_local_L1 = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cbuf,
                                                   name="input_2_local_L1")
            input_1_local_L1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cbuf,
                                                   name="input_1_local_L1")
            input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cb,
                                                             name="input_2_local_L1_local_L0B")
            core_m_idx = block_index % 8
            core_n_idx = block_index // 8
            with tik_instance.if_scope(core_m_idx != 7):
                tik_instance.data_move(input_1_local_L1,
                                       input_x1[core_m_idx * (8 * 256 + 128 * 1008)],
                                       0, 8, 128, 55 * 16, 0)
                tik_instance.data_move(input_2_local_L1,
                                       input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008],
                                       0, 32, 128, 55 * 16, 0)
                with tik_instance.for_range(0, 8) as cc12:
                    tik_instance.load2dv1(input_1_local_L1_local_L0A[cc12 * 2048],
                                          input_1_local_L1[cc12 * 256], 0, 8, 8, 0, False)
                with tik_instance.for_range(0, 2) as cc6:
                    with tik_instance.for_range(0, 8) as cc121:
                        tik_instance.load2dv1(input_2_local_L1_local_L0B[cc121 * 4096],
                                              input_2_local_L1[cc6 * 32768 + cc121 * 256],
                                              0, 16, 8, 0, True)
                    tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A,
                                      input_2_local_L1_local_L0B, 128, 128, 256, 0)
                    tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C,
                                           0, 1, 128, 0, 0, 1)
                    tik_instance.data_move(resMatmul[cc6 * 256 * 1008 + core_m_idx * 8 * 256 +
                                                     core_n_idx * 512 * 1008],
                                           resMatmul_local_UB, 0, 16, 256 // 2, 0, 55 * 16 * 2 // 2)
            with tik_instance.else_scope():
                tik_instance.data_move(input_1_local_L1,
                                       input_x1[core_m_idx * (8 * 256 + 128 * 1008)],
                                       0, 7, 112, 56 * 16, 0)
                tik_instance.data_move(input_2_local_L1,
                                       input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008],
                                       0, 32, 112, 56 * 16, 0)
                with tik_instance.for_range(0, 7) as cc10:
                    tik_instance.load2dv1(input_1_local_L1_local_L0A[cc10 * 1792],
                                          input_1_local_L1[cc10 * 256], 0, 7, 7, 0, False)
                with tik_instance.for_range(0, 2) as cc5:
                    with tik_instance.for_range(0, 7) as cc101:
                        tik_instance.load2dv1(input_2_local_L1_local_L0B[cc101 * 4096],
                                              input_2_local_L1[cc5 * 28672 + cc101 * 256],
                                              0, 16, 7, 0, True)
                    tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A,
                                      input_2_local_L1_local_L0B, 112, 112, 256, 0)
                    tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C,
                                           0, 1, 112, 0, 0, 1)
                    tik_instance.data_move(resMatmul[cc5 * 256 * 1008 + core_m_idx * 8 * 256 +
                                                     core_n_idx * 512 * 1008],
                                           resMatmul_local_UB, 0, 16, 224 // 2, 0, 56 * 16 * 2 // 2)
        tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[resMatmul])
        return tik_instance

    # shape does not match the hand-written TIK kernel: fall back to the generic TBE matmul
    result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a,
                                format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias)
    with tvm.target.cce():
        schedule = generic.auto_schedule(result)
    tensor_list = [tensor_a, tensor_b, result]
    if shape_bias:
        tensor_list = [tensor_a, tensor_b, tensor_bias, result]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(schedule, config)
def custom_Upsample(shape, dtype, scale, data_format="channels_last", kernel_name="cce_darknet_upsample",
                    need_build=False, need_print=False):
    """
    Parameters
    ----------
    shape : input tensor's shape
    dtype : input tensor's dtype, support: float16, float32, int32, int8, uint8
    scale : the upsampling factor
    data_format : "channels_last" or "channels_first"
    kernel_name : kernel name, default value is "cce_darknet_upsample"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    inp_dtype = dtype.lower()
    check_list = ["float16", "float32", "int32", "int8", "uint8"]
    if inp_dtype not in check_list:
        raise RuntimeError("upsample only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)

    size = (scale, scale)
    shape_size = len(shape)
    if shape_size not in (4, 5):
        raise RuntimeError("upsample only supports 4D or 5D while len(shape):%d" % len(shape))

    input_tensor = tvm.placeholder(shape, name="input_tensor", dtype=inp_dtype)
    res = None
    if shape_size == 5:
        # 5-D special case (N, C1, H, W, C0)
        output_shape = (shape[0], shape[1], shape[2] * size[0], shape[3] * size[1], shape[4])
        res = tvm.compute(
            output_shape,
            lambda n, c1, h, w, c0: input_tensor[n, c1, h // size[0], w // size[1], c0])
    else:
        if data_format == "channels_last":
            output_shape = (shape[0], shape[1] * size[0], shape[2] * size[1], shape[3])
            res = tvm.compute(
                output_shape,
                lambda n, h, w, c: input_tensor[n, h // size[0], w // size[1], c])
        elif data_format == "channels_first":
            output_shape = (shape[0], shape[1], shape[2] * size[0], shape[3] * size[1])
            res = tvm.compute(
                output_shape,
                lambda n, c, h, w: input_tensor[n, c, h // size[0], w // size[1]])
        else:
            raise RuntimeError("upsample only supports channels_last|channels_first "
                               "while input type %s" % data_format)

    schedule = tvm.create_schedule(res.op)
    if need_print:
        with build_config:
            print(tvm.lower(schedule, [input_tensor, res], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(schedule, [input_tensor, res], "cce", name=kernel_name)
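# Reference sketch (illustrative only): indexing the input at (h // scale, w // scale)
# is nearest-neighbour upsampling by an integer factor, which NumPy expresses with
# repeat along the spatial axes. Shown for the 4-D channels_first layout:
def _upsample_numpy_reference(x, scale):
    import numpy as np
    out = np.repeat(x, scale, axis=2)      # repeat rows (H)
    return np.repeat(out, scale, axis=3)   # repeat columns (W)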
def custom_expm1(shape, dtype, kernel_name="cce_tf_expm1", need_build=False, need_print=False):
    """
    algorithm: expm1
    calculating data's expm1, y = (e ** x) - 1, dtype is float16 or float32.

    Parameters
    ----------
    shape : shape of data.
    dtype : the data type, assume src_dtype equals dst_dtype,
            only support float16, float32.
    kernel_name : cce kernel name, default value is "cce_tf_expm1".
    need_build : if need to build CCEC kernel, default value is False.
    need_print : if need to print the ir, default value is False.

    Returns
    -------
    None
    """
    # [aicpu] int32_t cc_device_exp(uint32_t blockNum, uint32_t blockIdx, int32_t dataType,
    #     const void *scale, const void *shift, const void *base, int32_t dimCnt,
    #     int32_t *shape, uint32_t padC0, const void *x, void *y);
    supported_dtypes = ["float16", "float32"]

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    if dtype.lower() not in supported_dtypes:
        raise RuntimeError("tf_expm1_cce only supports %s while dtype is %s"
                           % (",".join(supported_dtypes), dtype))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)

    # step 1. calculate y = exp ** x by aicpu api
    device_api = "DeviceExp"
    v_datatype = util.get_device_api_dtype(inp_dtype)
    v_ndim = len(shape)
    block_num = "block_num"
    block_idx = "block_idx"
    pad_c0 = 0
    p_scale = util.create_param_ptr([1], inp_dtype, "p_scale")
    p_shift = util.create_param_ptr([0], inp_dtype, "p_shift")
    p_base = util.create_param_ptr([-1], inp_dtype, "p_base")
    p_shape = util.create_param_ptr(shape, "int32", "p_shape")

    output_exp = tvm.extern(
        shape, [data_input, p_scale, p_shift, p_base, p_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t", device_api,
            block_num,
            block_idx,
            v_datatype,
            ins[1].access_ptr("r"),  # scale
            ins[2].access_ptr("r"),  # shift
            ins[3].access_ptr("r"),  # base
            v_ndim,
            ins[4].access_ptr("r"),  # shape
            pad_c0,
            ins[0].access_ptr("r"),  # input x
            outs[0].access_ptr("w")),
        name="output_exp", dtype=inp_dtype)

    offset = tvm.const(-1, dtype=inp_dtype)

    # step 2. calculate y = exp ** x - 1 by tvm
    output = tvm.compute(
        shape,
        lambda *indice: output_exp(*indice) + offset.astype(inp_dtype),
        name="output")

    # step 3. schedule the computation by tvm
    s = tvm.create_schedule(output.op)

    # step 4. build by tvm
    if need_print:
        with build_config:
            print(tvm.lower(s, [data_input, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(s, [data_input, output], "cce", name=kernel_name)
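# Reference check (illustrative only): the two-step computation above (DeviceExp
# followed by adding the constant -1) matches:
def _expm1_numpy_reference(x):
    import numpy as np
    return np.exp(x) - 1.0   # equivalently np.expm1(x), which is more accurate near 0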
def decode_bbox(box_predictions, anchors, decoded_boxes, decode_clip, kernel_name="decode_bbox"):
    """
    calculating data

    Parameters
    ----------
    box_predictions : shape and dtype of input
    anchors : shape and dtype of input
    decoded_boxes : shape and dtype of output,
        should be the same shape and type as input
    decode_clip : decode_clip
    kernel_name : kernel name, default value is "decode_bbox"

    Returns
    -------
    None
    """
    # check params & data
    shape_box_predictions = box_predictions.get("shape")
    shape_anchors = anchors.get("shape")
    shape_decoded_boxes = decoded_boxes.get("shape")
    util.check_kernel_name(kernel_name)
    format_box_predictions = box_predictions.get("format")
    format_anchors = anchors.get("format")
    format_decoded_boxes = decoded_boxes.get("format")
    check_format_shape(format_box_predictions, format_anchors, format_decoded_boxes)
    util.check_shape_rule(shape_box_predictions, CONFIG_THREE, CONFIG_FOUR, None)
    util.check_shape_rule(shape_anchors, CONFIG_THREE, CONFIG_FOUR, None)
    util.check_shape_rule(shape_decoded_boxes, CONFIG_TWO, CONFIG_TWO, None)
    util.check_shape_size(shape_box_predictions, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_anchors, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_decoded_boxes, SHAPE_SIZE_LIMIT)
    util.check_dtype_rule(box_predictions.get("dtype").lower(), ("float16",))
    util.check_dtype_rule(anchors.get("dtype").lower(), ("float16",))
    util.check_dtype_rule(decoded_boxes.get("dtype").lower(), ("float16",))
    if shape_box_predictions != shape_anchors:
        raise RuntimeError("the input shapes of box_predictions and anchors must be the same")
    if (reduce(lambda x, y: x * y, shape_box_predictions[:])) \
            != (reduce(lambda x, y: x * y, shape_decoded_boxes[:])):
        raise RuntimeError("the input shape (box_predictions and anchors)"
                           " is not equal to the output shape (decoded_boxes)")
    if (shape_box_predictions[-1] == CONFIG_FOUR
            and len(shape_box_predictions) == CONFIG_THREE):
        if shape_decoded_boxes[1] != CONFIG_FOUR:
            raise RuntimeError("the output shape_decoded_boxes must be 4")
    else:
        if (shape_box_predictions[0] == CONFIG_FOUR
                and len(shape_box_predictions) == CONFIG_FOUR):
            if shape_decoded_boxes[0] != CONFIG_FOUR:
                raise RuntimeError("the output shape_decoded_boxes must be 4")
        else:
            raise RuntimeError("the input shape not in {(4,C,H,W), (H,W,4)}")
    if not isinstance(decode_clip, (float, int)):
        raise RuntimeError("input param type of decode_clip should be Float")
    if decode_clip < 0 or decode_clip > 10:
        raise RuntimeError("input param decode_clip can't be negative and should be in [0, 10]!")

    # init the tiling shape
    print("shape_box_predictions", shape_box_predictions)
    shape = TilingFunc(shape_box_predictions)

    # calculate the decode_bbox
    tik_instance = tik.Tik(tik.Dprofile())
    data_tensor = InitTensor(tik_instance, shape)
    if shape.input_shape[-1] == CONFIG_FOUR \
            and len(shape.input_shape) == CONFIG_THREE:
        decode_bbox_compute(tik_instance, shape, data_tensor, decode_clip, kernel_name)
    if shape.input_shape[0] == CONFIG_FOUR \
            and len(shape.input_shape) == CONFIG_FOUR:
        decode_bbox_compute_transpose(tik_instance, shape, data_tensor, decode_clip, kernel_name)
    return tik_instance
def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
                    kernel_name="matmulcube"):
    """check_supported"""
    shape_a = input_x1.get("shape")
    shape_b = input_x2.get("shape")
    print("shape_a: ", shape_a)
    print("shape_b: ", shape_b)
    src_dtype = input_x1.get("dtype")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_a)
    util.check_shape_rule(shape_b)
    util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
    try:
        trans_a_f = bool(1 - trans_a)
        if src_dtype in ("float32", "int32"):
            if len(shape_a) != 2 and len(shape_b) != 2:
                return False
            if trans_b:
                if shape_b[0] == 1:
                    return False
            else:
                if shape_b[1] == 1:
                    return False
            if trans_a:
                if trans_b:
                    if shape_a[0] != shape_b[1]:
                        return False
                elif shape_a[0] != shape_b[0]:
                    return False
            elif trans_b:
                if shape_a[1] != shape_b[1]:
                    return False
            elif shape_a[1] != shape_b[0]:
                return False
            if trans_a_f and trans_b and shape_b[1] == 1:
                return False
        if src_dtype == "float16":
            if len(shape_a) != 2 and len(shape_b) != 2:
                return False
            if trans_a:
                m_shape = shape_a[1]
                k_shape = shape_a[0]
            else:
                m_shape = shape_a[0]
                k_shape = shape_a[1]
            if trans_b:
                n_shape = shape_b[0]
                k_b_shape = shape_b[1]
            else:
                n_shape = shape_b[1]
                k_b_shape = shape_b[0]
            if k_shape != k_b_shape:
                return False
            if m_shape == 1 or n_shape == 1:
                if k_shape % 256 != 0:
                    return False
    except RuntimeError as e:
        print(e)
        return False
    return True
def custom_truncatemod(shape1, shape2, dtype, kernel_name="cce_tf_truncatemod",
                       need_build=False, need_print=False):
    """
    do element-wise truncatemod operation between two input tensors

    Parameters:
    ----------
    shape1 : shape of input data1
    shape2 : shape of input data2
    dtype : source data type, support float16, float32, int32
    kernel_name : cce kernel name, default value is "cce_tf_truncatemod"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    max_dim = 8
    shape1_len = len(shape1)
    shape2_len = len(shape2)
    if shape1_len > max_dim or shape2_len > max_dim:
        raise RuntimeError(
            "mod_cce only supports up to %d dimensions while the shapes'"
            " dimensions are %d, %d" % (max_dim, shape1_len, shape2_len))

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape1)
    util.check_shape_rule(shape2)
    util.check_shape_size(shape1, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape2, SHAPE_SIZE_LIMIT)

    check_list = ["float16", "float32", "int32"]
    device_api_map = {"float16": "cc_device_truncatemod_float16",
                      "float32": "cc_device_truncatemod_float",
                      "int32": "cc_device_truncatemod_int32"}
    dtype = dtype.lower()
    if dtype not in check_list:
        raise RuntimeError("tf_truncatemod_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    shape1, shape2, shape_out = util.produce_shapes(shape1, shape2)
    util.check_shape_size(shape_out, SHAPE_SIZE_LIMIT)
    inp_dtype = dtype.lower()
    device_api = device_api_map[inp_dtype]

    # block
    block_num = "block_num"
    block_idx = "block_idx"
    # x param
    v_xndim_cnt = tvm.const(len(shape1), "int32")
    p_xshape = util.create_param_ptr(shape1, "int32", "p_xshape")
    xpad_c0 = tvm.const(0, "int32")
    data_input_x = tvm.placeholder(shape1, name="data_input_x", dtype=inp_dtype)
    # y param
    v_yndim_cnt = tvm.const(len(shape2), "int32")
    p_yshape = util.create_param_ptr(shape2, "int32", "p_yshape")
    ypad_c0 = tvm.const(0, "int32")
    data_input_y = tvm.placeholder(shape2, name="data_input_y", dtype=inp_dtype)
    # output
    v_out_ndim_cnt = tvm.const(len(shape_out), "int32")
    p_out_shape = util.create_param_ptr(shape_out, "int32", "p_out_shape")
    out_padc0 = tvm.const(0, "int32")

    output = tvm.extern(
        shape_out,
        [p_xshape, data_input_x, p_yshape, data_input_y, p_out_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t", device_api,
            block_num,
            block_idx,
            v_xndim_cnt,
            ins[0].access_ptr("r"),  # shape x
            xpad_c0,
            ins[1].access_ptr("r"),  # input x
            v_yndim_cnt,
            ins[2].access_ptr("r"),  # shape y
            ypad_c0,
            ins[3].access_ptr("r"),  # input y
            v_out_ndim_cnt,
            ins[4].access_ptr("r"),  # shape out
            out_padc0,
            outs[0].access_ptr("w")),
        name="output", dtype=inp_dtype)

    schedule = tvm.create_schedule(output.op)
    # print IR
    if need_print:
        with build_config:
            print(tvm.lower(schedule, [data_input_x, data_input_y, output], simple_mode=True))
    # compile to generate the cce file
    if need_build:
        with build_config:
            tvm.build(schedule, [data_input_x, data_input_y, output], "cce", name=kernel_name)
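# Reference semantics (assumption: the device API follows TensorFlow's truncatemod,
# i.e. the remainder after division truncated toward zero, matching C's fmod for
# floating-point inputs; illustrative only):
def _truncatemod_numpy_reference(x, y):
    import numpy as np
    return x - np.trunc(np.divide(x, y)) * y   # result has the same sign as x, like np.fmod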
def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
                  kernel_name="matmulcube"):
    """
    calculating matrix multiplication with bias, C = A*B + bias,
    support input data with fractal format.

    Parameters:
    shape_a: list or tuple
        Shape of the first tensor a with rank > 1
    shape_b: list or tuple
        Shape of the second tensor b with the same type as a,
        and shape_a, shape_b must be 2 dims
    src_dtype: str
        The data type of input, support "float32", "float16"
    dst_dtype: str
        The data type of output, support "float32", "float16"
    trans_a: bool
        If True, shape_a is transposed before multiplication
    trans_b: bool
        If True, shape_b is transposed before multiplication
    is_fractal: bool
        If True, the input data format of a and b must be fractal format
    shape_bias: list or tuple
        Shape of bias, only support the input data format with ND

    Returns
    -------
    None
    """
    shape_a = input_x1.get("ori_shape")
    shape_b = input_x2.get("ori_shape")
    if shape_a is not None:
        if len(shape_a) < 2:
            shape_a = input_x1.get("shape")
    if shape_b is not None:
        if len(shape_b) < 2:
            shape_b = input_x2.get("shape")
    shape_a = list(shape_a)
    shape_b = list(shape_b)
    if input_x1.get("format") == "FRACTAL_NZ":
        shape_a = _get_input_shape(shape_a)
        shape_b = _get_input_shape(shape_b)

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_a)
    util.check_shape_rule(shape_b)
    util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)

    if input_x1.get("format") == "FRACTAL_NZ":
        shape_a = [shape_a[1], shape_a[0]]
        trans_a = bool(1 - trans_a)
    if input_x2.get("format") == "FRACTAL_NZ":
        shape_b = [shape_b[1], shape_b[0]]
        trans_b = bool(1 - trans_b)

    shape_bias = ()
    if bias is not None and bool(bias):
        shape_bias = bias.get("shape")
        shape_bias = list(shape_bias)
        shape_bias = _get_bias(shape_bias)

    src_dtype = input_x1.get("dtype").lower()
    dst_dtype = output_y.get("dtype").lower()
    if src_dtype in ("float32", "int32"):
        matmul_vector_cce(shape_a, shape_b, src_dtype, trans_a, trans_b, shape_bias, kernel_name)
        return
    _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)

    m_shape = shape_a[len(shape_a) - 2]
    km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_a) - 2]
    n_shape = shape_b[len(shape_a) - 1]

    if src_dtype == "float16":
        block_reduce = cce.BLOCK_REDUCE

    block_in = cce.BLOCK_IN
    block_out = cce.BLOCK_OUT
    if trans_a and km_shape == 1:
        block_in = cce.BLOCK_VECTOR
    if not trans_a and m_shape == 1:
        block_in = cce.BLOCK_VECTOR
    if trans_b and kn_shape == 1:
        block_out = cce.BLOCK_VECTOR
    if not trans_b and n_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if trans_a:
        shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in)
    else:
        shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce)
    if trans_b:
        shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out)
    else:
        shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce)

    if input_x1.get("format") == "FORMAT_FRACTAL_Z":
        shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
        format_a = "fractal"
    elif input_x1.get("format") == "FRACTAL_NZ":
        shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3])
        format_a = "FRACTAL_NZ"
    else:
        shape_a_temp = (shape_a[len(shape_a) - 2], shape_a[len(shape_a) - 1])
        format_a = "ND"

    if input_x2.get("format") == "FORMAT_FRACTAL_Z":
        shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
        format_b = "fractal"
    elif input_x2.get("format") == "FRACTAL_NZ":
        shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
        format_b = "FRACTAL_NZ"
    else:
        shape_b_temp = (shape_b[len(shape_b) - 2], shape_b[len(shape_b) - 1])
        format_b = "ND"

    tensor_bias = None
    tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', dtype=src_dtype)
    tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', dtype=src_dtype)
    if shape_bias:
        tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias', dtype=dst_dtype)

    result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a,
                                format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias)
    with tvm.target.cce():
        schedule = generic.auto_schedule(result)
    tensor_list = [tensor_a, tensor_b, result]
    if shape_bias:
        tensor_list = [tensor_a, tensor_b, tensor_bias, result]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(schedule, config)
def custom_Reduction(shape, dtype, axis, op, coeff, kernel_name="cce_reductionLayer",
                     need_build=False, need_print=False):
    """
    Reduce a tensor on a certain axis, and scale output with coeff

    Parameters
    ----------
    shape : shape of data
    dtype : source data type, only support float16, float32, int8, uint8
    axis : the first axis to reduce, may be negative to index from the end
           (e.g., -1 for the last axis).
           If axis == 0, the output Blob always has the empty shape (count 1),
           performing reduction across the entire input.
    op : can only be one of "SUM, ASUM (sum of abs), SUMSQ (sum of sqr), MEAN"
    coeff : scale for output
    kernel_name : cce kernel name, default value is "cce_reductionLayer"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    check_list = ["float16", "float32", "int8", "uint8"]
    if dtype.lower() not in check_list:
        raise RuntimeError("reductionLayer_cce only supports %s while dtype is %s"
                           % (",".join(check_list), dtype))

    reduction_op = ("SUM", "ASUM", "SUMSQ", "MEAN")
    if not isinstance(axis, int):
        raise RuntimeError("type of axis value should be int")
    if op not in reduction_op:
        raise RuntimeError("op can only be one of SUM, ASUM, SUMSQ, MEAN")
    if not isinstance(coeff, int) and not isinstance(coeff, float):
        raise RuntimeError("coeff must be a numeric value")

    axis_origin = axis
    shape_origin = shape
    axis = util.axis_check(len(shape), axis)
    util.check_reduce_shape_rule(shape)
    shape = list(shape)
    # flatten all axes from `axis` to the end into one reduce axis
    shape1 = shape[:axis] + [functools_reduce(lambda x, y: x * y, shape[axis:])]
    shape1, axis = util.shape_refine(shape1, axis)
    if not axis:
        axis = [0]
        shape1 = [1] + shape1

    inp_dtype = dtype.lower()
    data = tvm.placeholder(shape1, name="data_input", dtype=inp_dtype)
    with tvm.target.cce():
        res = caffe_reduction_layer_compute([data], shape_origin, dtype, axis_origin,
                                            op, coeff, kernel_name, need_build, need_print)

    if op == "MEAN" and (inp_dtype == "int8" or inp_dtype == "uint8"):
        util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
        res = te.lang.cce.cast_to(res, inp_dtype)
        schedule = tvm.create_schedule(res.op)
        if need_print:
            with build_config:
                print(tvm.lower(schedule, [data, res], simple_mode=True))
        if need_build:
            with build_config:
                tvm.build(schedule, [data, res], "cce", name=kernel_name)
    else:
        with tvm.target.cce():
            sch = generic.auto_schedule(res)
        config = {"print_ir": need_print,
                  "need_build": need_build,
                  "name": kernel_name,
                  "tensor_list": [data, res]}
        te.lang.cce.cce_build_code(sch, config)
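# Reference sketch (illustrative only): the four reduction modes above, applied to the
# flattened tail axes and scaled by coeff; assumes a non-negative axis as produced by
# util.axis_check:
def _reduction_numpy_reference(x, axis, op, coeff):
    import numpy as np
    flat = x.reshape(x.shape[:axis] + (-1,)) if axis else x.reshape(-1)
    if op == "SUM":
        out = np.sum(flat, axis=-1)
    elif op == "ASUM":
        out = np.sum(np.abs(flat), axis=-1)   # sum of absolute values
    elif op == "SUMSQ":
        out = np.sum(flat * flat, axis=-1)    # sum of squares
    else:  # "MEAN"
        out = np.mean(flat, axis=-1)
    return coeff * out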
def custom_Exp(shape, dtype, gamma, alpha, beta, kernel_name="cce_exp", need_build=False, need_print=False):
    """
    calculate gamma ** (alpha * data + beta),
    i.e. exp(log(gamma) * alpha * data) * (gamma ** beta)

    Parameters
    ----------
    shape : shape of data
    dtype : the data type, assume src_dtype equals dst_dtype,
            only support float16, float32
    gamma : the data type must be same with dtype parameter
        base in gamma ** (alpha * data + beta)
    alpha : the data type must be same with dtype parameter
        scale in gamma ** (alpha * data + beta)
    beta : the data type must be same with dtype parameter
        shift in gamma ** (alpha * data + beta)
    kernel_name : cce kernel name, default value is "cce_exp"
    need_build : if need to build CCEC kernel, default value is False
    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    supported_dtypes = ["float16", "float32"]
    device_api = "DeviceExp"

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    if dtype.lower() not in supported_dtypes:
        raise RuntimeError("caffe_exp_layer_cce only supports %s while dtype is %s"
                           % (",".join(supported_dtypes), dtype))

    if gamma != -1 and gamma <= 0:
        # the api cc_device_exp_c handles gamma == -1 as base e
        raise ValueError("please ensure gamma is greater than 0, where gamma = %s" % str(gamma))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)

    v_datatype = util.get_device_api_dtype(inp_dtype)
    v_ndim = len(shape)
    block_num = "block_num"
    block_idx = "block_idx"
    pad_c0 = 0
    p_scale = util.create_param_ptr([alpha], inp_dtype, "p_scale")
    p_shift = util.create_param_ptr([beta], inp_dtype, "p_shift")
    p_base = util.create_param_ptr([gamma], inp_dtype, "p_base")
    p_shape = util.create_param_ptr(shape, "int32", "p_shape")

    # scale --> alpha, shift --> beta, base --> gamma
    output = tvm.extern(
        shape, [data_input, p_scale, p_shift, p_base, p_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t", device_api,
            block_num,
            block_idx,
            v_datatype,
            ins[1].access_ptr("r"),  # scale
            ins[2].access_ptr("r"),  # shift
            ins[3].access_ptr("r"),  # base
            v_ndim,
            ins[4].access_ptr("r"),  # shape
            pad_c0,
            ins[0].access_ptr("r"),  # input x
            outs[0].access_ptr("w")),
        name="output", dtype=inp_dtype)

    schedule = tvm.create_schedule(output.op)
    if need_print:
        with build_config:
            print(tvm.lower(schedule, [data_input, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(schedule, [data_input, output], "cce", name=kernel_name)
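# Reference check (illustrative only): the Caffe-style Exp layer above computes
# y = gamma ** (alpha * x + beta), with gamma == -1 treated as base e by the device API:
def _exp_layer_numpy_reference(x, gamma, alpha, beta):
    import numpy as np
    if gamma == -1:
        return np.exp(alpha * x + beta)
    return np.power(gamma, alpha * x + beta)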
def custom_batch_matmul(shape_x,
                        shape_y,
                        dtype,
                        trans_a=False,
                        trans_b=False,
                        kernel_name="cce_tf_batch_matmul",
                        need_build=False,
                        need_print=False):
    """
    Multiplies slices of two tensors in batches (each slice can be viewed
    as an element of a batch), and the output has the same batch size.
    Each of the individual slices can optionally be transposed before
    multiplication by setting the trans_a or trans_b flag to True, both of
    which default to False. The input tensors are 2-D or higher with the
    shapes [..., r_x, c_x] and [..., r_y, c_y]. The output tensor is 2-D or
    higher with the shape [..., r_o, c_o], where
    r_o = c_x if trans_a else r_x
    c_o = r_y if trans_b else c_y

    Parameters
    ----------
    shape_x : shape of the first tensor x with rank > 1

    shape_y : shape of the second tensor y with the same type and rank as x

    dtype : the data type, support int8, uint8, float16, float32, int32

    kernel_name : cce kernel name, default value is "cce_batch_matmul"

    trans_a : if True, shape_x is transposed before multiplication

    trans_b : if True, shape_y is transposed before multiplication

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_rule(shape_y)
    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_y, SHAPE_SIZE_LIMIT)

    data_dtype = dtype.lower()
    check_list = ["int8", "uint8", "float16", "float32", "int32"]
    if data_dtype not in check_list:
        raise RuntimeError(
            "batch_matmul_cce only supports %s while dtype is %s" %
            (",".join(check_list), dtype))

    def transpose_tensor(shape, size):
        """Transpose the last two dimensions of a shape, e.g., the shape
        [..., r_x, c_x] is transposed to [..., c_x, r_x].

        Parameters
        ----------
        shape : shape of a tensor

        size : length of the shape

        Returns
        -------
        shape_ori : the transposed shape
        """
        shape_ori = ()
        if size == 1:
            shape_ori = shape_ori + tuple(shape)
        elif size == 2:
            shape_ori = shape_ori + (shape[1], ) + (shape[0], )
        else:
            shape_ori = shape_ori + tuple(shape[:(size - 2)]) + (
                shape[size - 1], ) + (shape[size - 2], )
        return shape_ori

    def check_matmul(shape_x, shape_y):
        """Check whether batch_matmul is supported or not.

        Parameters
        ----------
        shape_x : shape of the first tensor x

        shape_y : shape of the second tensor y with the same type and rank
            as x

        Returns
        -------
        None
        """
        len_x = len(shape_x)
        len_y = len(shape_y)
        if (len_x < 2) or (len_y < 2):
            raise RuntimeError("Only tensors of rank >= 2 are supported!")
        if shape_x[len_x - 1] != shape_y[len_y - 2]:
            raise RuntimeError(
                "Invalid matrix multiplication for the inner 2 dimensions!")
        if (len_x == len_y) and (len_x > 2):
            for i in range(len_x - 2):
                if shape_x[i] != shape_y[i]:
                    raise RuntimeError("Outer dimensions do not match!")
            return
        elif (len_x == len_y) and (len_x == 2):
            return
        else:
            raise RuntimeError("The input tensors are not with the same rank!")

    def _compute(output_shape, x, y, K, trans_a, trans_b, *indices):
        """matmul computation in terms of the output shape and the transposes

        Parameters
        ----------
        output_shape : the final output shape, e.g., shape_x = (2, 6),
            shape_y = (8, 2), trans_a = True, trans_b = True, then
            output_shape = (6, 8).

        x : the first input tensor according to shape_x.

        y : the second input tensor according to shape_y.

        K : the number of the axis for sum, in the above example, K = 2.

        trans_a : if True, x needs to be transposed.

        trans_b : if True, y needs to be transposed.

        *indices : the output shape space for tvm.compute.
        Returns
        -------
        tvm.Tensor
        """
        n_len = len(output_shape)
        k = tvm.reduce_axis((0, K), 'k')
        if trans_a is True and trans_b is False:
            # For example, A: (6, 7, 8), B: (6, 7, 9), so the length is n = 3
            # C = A' * B : (6, 8, 9), A' means the transpose of A
            # indices means the space of (6, 8, 9), k = 7
            # x_indices = indices[:1] + (k, ) + indices[1:2] -> (6, 7, 8)
            # y_indices = indices[:1] + (k, ) + indices[2:] -> (6, 7, 9)
            x_indices = indices[:(n_len - 2)] + (k, ) + indices[(n_len - 2):
                                                               (n_len - 1)]
            y_indices = indices[:(n_len - 2)] + (k, ) + indices[(n_len - 1):]
            return tvm.sum(x(*x_indices) * y(*y_indices), axis=k)
        elif not trans_a and trans_b:
            # For example, A: (6, 7, 8), B: (6, 9, 8), C = A * B' : (6, 7, 9)
            # indices means the space of (6, 7, 9), n = 3, k = 8
            # x_indices = indices[:2] + (k, ) -> (6, 7, 8)
            # y_indices = indices[:1] + indices[2:] + (k, ) -> (6, 9, 8)
            x_indices = indices[:(n_len - 1)] + (k, )
            y_indices = indices[:(n_len - 2)] + indices[(n_len - 1):] + (k, )
            return tvm.sum(x(*x_indices) * y(*y_indices), axis=k)
        elif trans_a and trans_b:
            # For example, A: (6, 8, 10), B: (6, 12, 8),
            # C = A' * B' : (6, 10, 12)
            # indices means the space of (6, 10, 12), n = 3, k = 8
            # x_indices = indices[:1] + (k, ) + indices[1:2] -> (6, 8, 10)
            # y_indices = indices[:1] + indices[2:] + (k, ) -> (6, 12, 8)
            x_indices = indices[:(n_len - 2)] + (k, ) + indices[(n_len - 2):
                                                               (n_len - 1)]
            y_indices = indices[:(n_len - 2)] + indices[(n_len - 1):] + (k, )
            return tvm.sum(x(*x_indices) * y(*y_indices), axis=k)
        else:
            # For example, A: (6, 15, 16), B: (6, 16, 18),
            # C = A * B : (6, 15, 18)
            # indices means the space of (6, 15, 18), n = 3, k = 16
            # x_indices = indices[:2] + (k, ) -> (6, 15, 16)
            # y_indices = indices[:1] + (k, ) + indices[2:] -> (6, 16, 18)
            x_indices = indices[:(n_len - 1)] + (k, )
            y_indices = indices[:(n_len - 2)] + (k, ) + indices[(n_len - 1):]
            return tvm.sum(x(*x_indices) * y(*y_indices), axis=k)

    def check_supportted_shape_size(shape_x, shape_y, limit, trans_a,
                                    trans_b):
        """Check whether the shape size is within the limit, i.e., whether
        the shape is too large for the computation to finish without a
        timeout.

        Parameters
        ----------
        shape_x : shape of the first tensor x

        shape_y : shape of the second tensor y

        limit : limit of the product of all dimensions

        trans_a : if True, shape_x is transposed before multiplication

        trans_b : if True, shape_y is transposed before multiplication

        Returns
        -------
        None
        """
        # shape_x = (a, b, c, d, e, k), shape_y = (a, b, c, d, k, f)
        # t_1 : time consumed by each addition operation
        # t_2 : time consumed by each multiplication operation
        # t_all : time consumed by a complete calculation
        # t_all is approximately equal to (a*b*c*d) * (e*k*f) * (t_1 + t_2)
        # As (t_1 + t_2) is a constant, t_all is proportional to
        # (a * b * c * d * e * k * f)
        len_x = len(shape_x)
        len_y = len(shape_y)
        if (len_x < 2) or (len_y < 2):
            raise RuntimeError("Only tensors of rank >= 2 are supported!")
        shape_x = list(shape_x)
        shape_y = list(shape_y)
        tmp_shape_x = shape_x[:]
        if trans_a:
            tmp_shape_x = shape_x[:-2] + [shape_x[-1], shape_x[-2]]
        tmp_shape_y = shape_y[:]
        if trans_b:
            tmp_shape_y = shape_y[:-2] + [shape_y[-1], shape_y[-2]]
        # (a, b, c, d, e, k) plus (f, ) covers every factor of the product
        union_shape = tmp_shape_x + [tmp_shape_y[-1]]
        union_size = functools_reduce(lambda i, j: i * j, union_shape)
        if union_size > limit:
            raise RuntimeError("the shape is too large to calculate")

    if data_dtype in ["float16", "float32", "int32"]:
        type_shape_map = {
            'float16': SHAPE_SIZE_FP16_LIMIT,
            'float32': SHAPE_SIZE_FP32_LIMIT,
            'int32': SHAPE_SIZE_INT32_LIMIT
        }
        check_supportted_shape_size(shape_x, shape_y,
                                    type_shape_map[data_dtype], trans_a,
                                    trans_b)

    x_size = len(shape_x)
    y_size = len(shape_y)
    shape_a = shape_x
    shape_b = shape_y
    if trans_a is True:
        shape_x = transpose_tensor(shape_x, x_size)

    if trans_b is True:
        shape_y = transpose_tensor(shape_y, y_size)

    check_matmul(shape_x, shape_y)
    last_axis = shape_x[x_size - 1]

    x_temp = tvm.placeholder(shape_a, name="input_1", dtype=data_dtype)
    y_temp = tvm.placeholder(shape_b, name="input_2", dtype=data_dtype)

    # output shape
    output_shape = ()
    for i in range(x_size - 1):
        output_shape = output_shape + (shape_x[i], )
    output_shape = output_shape + (shape_y[x_size - 1], )

    result = tvm.compute(
        output_shape,
        lambda *indices: _compute(output_shape, x_temp, y_temp, last_axis,
                                  trans_a, trans_b, *indices),
        name="result")
    schedule = tvm.create_schedule(result.op)

    if need_print:
        with build_config:
            print(
                tvm.lower(schedule, [x_temp, y_temp, result],
                          simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(schedule, [x_temp, y_temp, result],
                      "cce",
                      name=kernel_name)
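# A NumPy reference sketch (illustrative only, assuming NumPy is available;
# `_batch_matmul_reference` is a name we introduce and is not part of the
# kernel) for the semantics above: optionally transpose the last two axes of
# each operand, then batch-multiply, reproducing
# r_o = c_x if trans_a else r_x and c_o = r_y if trans_b else c_y.
def _batch_matmul_reference(x, y, trans_a=False, trans_b=False):
    """Hypothetical reference helper, not called by the kernel."""
    import numpy as np  # assumption: NumPy is available for checking
    if trans_a:
        x = np.swapaxes(x, -1, -2)  # [..., r_x, c_x] -> [..., c_x, r_x]
    if trans_b:
        y = np.swapaxes(y, -1, -2)  # [..., r_y, c_y] -> [..., c_y, r_y]
    return np.matmul(x, y)  # broadcasts over the leading batch dimensions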