def _compute_mini(data_input, shape):
    """
    Compute arctanh(x) along a float16 path (judging by the name, the mini-product path).

    Both ``_compute_taylor`` and ``_compute_log`` are evaluated on |x|.
    The Taylor result is chosen where |x| < 0.5 and the log-based result elsewhere.
    The sign of x is then re-applied using the odd symmetry
    arctanh(-|x|) = -arctanh(|x|).

    Args:
        data_input (tvm.tensor.Tensor): input tensor x.
        shape: output shape handed to ``tvm.compute``.

    Returns:
        tvm.tensor.Tensor: float16 tensor holding arctanh(x).
    """
    magnitude = topi.abs(data_input)
    log_branch = _compute_log(magnitude)
    taylor_branch = _compute_taylor(magnitude)

    # Every operand of the selects below is brought to float16 first.
    magnitude = topi.cast(magnitude, "float16")
    signed_input = topi.cast(data_input, "float16")
    taylor_branch = topi.cast(taylor_branch, "float16")
    log_branch = topi.cast(log_branch, "float16")

    def _pick_branch(*idx):
        # Taylor computation for |x| < 0.5, log() formula for the rest.
        return akg.tvm.expr.Select(magnitude(*idx) < dc.half_const("float16"),
                                   taylor_branch(*idx),
                                   log_branch(*idx))

    unsigned_res = tvm.compute(shape, _pick_branch, name="le")
    negated_res = topi.multiply(unsigned_res, dc.neg_one_const("float16"))

    def _restore_sign(*idx):
        # arctanh is odd, so take -arctanh(|x|) wherever x < 0.
        return akg.tvm.expr.Select(signed_input(*idx) < dc.zero_const("float16"),
                                   negated_res(*idx),
                                   unsigned_res(*idx))

    return tvm.compute(shape, _restore_sign, name="neg")
def _compute_log(data_input, target=utils.CCE):
    """
    Evaluate atanh(x) with the closed form 0.5 * log((1 + x) / (1 - x)).

    Args:
        data_input (tvm.tensor.Tensor): input tensor x.
        target: backend target forwarded to ``log`` (defaults to utils.CCE).

    Returns:
        tvm.tensor.Tensor: element-wise 0.5 * log((1 + x) / (1 - x)).
    """
    dtype = data_input.dtype
    numerator = topi.add(data_input, dc.one_const(dtype))
    # 1 - x is built as (-1 * x) + 1.
    denominator = topi.add(topi.multiply(data_input, dc.neg_one_const(dtype)),
                           dc.one_const(dtype))
    log_ratio = log(numerator / denominator, target)
    return topi.multiply(log_ratio, dc.half_const(dtype))
def _compute_log(data_input):
    """
    Atanh(x) through its logarithmic identity.

    atanh(x) = 0.5 * log((1 + x) / (1 - x)); the logarithm comes from ``log.log``.

    Args:
        data_input (tvm.tensor.Tensor): input tensor x.

    Returns:
        tvm.tensor.Tensor: element-wise 0.5 * log((1 + x) / (1 - x)).
    """
    dt = data_input.dtype
    one_plus_x = topi.add(data_input, dc.one_const(dt))
    one_minus_x = topi.add(topi.multiply(data_input, dc.neg_one_const(dt)), dc.one_const(dt))
    return topi.multiply(log.log(one_plus_x / one_minus_x), dc.half_const(dt))
def fake_quant_with_min_max_vars_per_channel_compute(input_data, input_min, input_max, num_bits=8,
                                                      narrow_range=False):
    """
    Build the per-channel fake-quantization compute graph.

    Each element of ``input_data`` is processed in four steps:
    1. Clamp it into [nudged_min, nudged_max].
    2. Snap it to the nearest multiple of ``scale`` measured from nudged_min (round half up).
    3. Map it back to the real range.
    4. Multiply by the tensor returned by ``bool_both_zero_compute``
       (presumably a 0/1 mask for channels whose min and max are both zero; confirm).

    Args:
        input_data (tvm.tensor.Tensor): tensor to fake-quantize.
        input_min (tvm.tensor.Tensor): per-channel minimum, broadcast to input_data's shape.
        input_max (tvm.tensor.Tensor): per-channel maximum, broadcast to input_data's shape.
        num_bits (int): quantization bit width, forwarded to ``nudged_min_max_compute``.
        narrow_range (bool): forwarded to ``nudged_min_max_compute``.

    Returns:
        tvm.tensor.Tensor: the fake-quantized result.
    """
    target_shape = get_shape(input_data.shape)
    dtype = input_data.dtype
    min_full = akg.lang.ascend.broadcast(input_min, target_shape, dtype)
    max_full = akg.lang.ascend.broadcast(input_max, target_shape, dtype)

    # nudged_min_max_compute returns [nudged_min, nudged_max, scale].
    nudged = nudged_min_max_compute(min_full, max_full, num_bits, narrow_range)
    nudged_min, nudged_max, scale = nudged[0], nudged[1], nudged[2]

    # Clamp: apply the upper bound first, then the lower bound.
    clamped = topi.maximum(topi.minimum(input_data, nudged_max), nudged_min)

    # Distance from nudged_min, measured in units of scale.
    offset = topi.subtract(clamped, nudged_min)
    if product_is_mini():
        # On mini products the division is expressed as a multiply by the reciprocal.
        steps = mul(offset, reciprocal(scale), target=utils.CCE)
    else:
        steps = Divide(offset, scale, target=utils.CCE)

    # Round half up: floor(steps + 0.5).
    rounded = akg.lang.ascend.floor(topi.add(steps, dc.half_const(dtype)))
    # Return to float32; on mini this goes through float16 first
    # (presumably floor yields an integer dtype there; confirm).
    if product_is_mini():
        rounded = topi.cast(rounded, "float16")
    rounded = topi.cast(rounded, "float32")

    dequantized = topi.add(topi.multiply(rounded, scale), nudged_min)

    zero_mask = bool_both_zero_compute(min_full, max_full)
    return topi.multiply(dequantized, zero_mask)
def nudged_min_max_compute(min_broadcast, max_broadcast, num_bits, narrow_range):
    """
    Calculate the nudged minimum, nudged maximum and scale of the quantization per channel.

    Notes:
        The scale of each channel is
            scale[i] = (max_broadcast[i] - min_broadcast[i]) / (quant_max - quant_min).
        The zero point is first estimated from the minimum:
            zero_point_from_min = quant_min_float - min_broadcast / scale
        It is then nudged into [quant_min, quant_max]:
            nudged_zero_point = floor(between_min_max_float + 0.5)
                                + less_quant_min_float + more_quant_max_float
        where:
            - less_quant_min_float is quant_min when zero_point_from_min < quant_min, else 0;
            - more_quant_max_float is quant_max when zero_point_from_min > quant_max, else 0;
            - between_min_max_float is zero_point_from_min when
              quant_min <= zero_point_from_min <= quant_max, else 0
              (outside that interval floor(0 + 0.5) contributes 0).
        Finally, scale and nudged_zero_point give:
            nudged_min = (quant_min - nudged_zero_point) * scale
            nudged_max = (quant_max - nudged_zero_point) * scale

    Args:
        min_broadcast (tvm.tensor.Tensor): minimum value to be quantized for each channel.
        max_broadcast (tvm.tensor.Tensor): maximum value to be quantized for each channel.
        num_bits (int): bit width of the quantization, range [2, 16].
        narrow_range (bool): if True, each channel is quantized into [1, 2^num_bits - 1];
            otherwise into [0, 2^num_bits - 1].

    Returns:
        list: [nudged_min, nudged_max, scale], each a tvm.tensor.Tensor with the same
        type and shape as min_broadcast / max_broadcast.
    """
    dtype = min_broadcast.dtype
    quant_min = 1 if narrow_range else 0
    quant_max = (2**num_bits) - 1

    # Every channel is computed element-wise, so quant_min and quant_max are broadcast
    # to full tensors.
    quant_min_float = topi.full(min_broadcast.shape, dtype, tvm.const(quant_min, dtype))
    quant_max_float = topi.full(min_broadcast.shape, dtype, tvm.const(quant_max, dtype))

    # Calculate each channel's max - min difference.
    max_sub_min = topi.subtract(max_broadcast, min_broadcast)
    quant_max_sub_quant_min = topi.subtract(quant_max_float, quant_min_float)

    # Compute scale = (max_broadcast - min_broadcast) / (quant_max - quant_min)
    # and min_div_scale = min_broadcast / scale.
    if product_is_mini():
        # On mini products the divisions are expressed as multiply-by-reciprocal.
        # NOTE(review): this branch calls both `mul` and `Mul`; confirm both names are
        # imported in this module (one may be a leftover from an API rename).
        scale = mul(max_sub_min, reciprocal(quant_max_sub_quant_min), target=utils.CCE)
        min_div_scale = Mul(min_broadcast, reciprocal(scale), target=utils.CCE)
    else:
        scale = Divide(max_sub_min, quant_max_sub_quant_min, target=utils.CCE)
        min_div_scale = Divide(min_broadcast, scale, target=utils.CCE)

    # zero_point_from_min = quant_min_float - min_broadcast / scale
    zero_point_from_min = topi.subtract(quant_min_float, min_div_scale)
    # If zero_point_from_min < quant_min_float, bool_less_quant_min_float = 1 else 0.
    bool_less_quant_min_float = less_compare_float32(zero_point_from_min, quant_min_float)
    # If quant_max_float < zero_point_from_min, bool_more_quant_max_float = 1 else 0.
    bool_more_quant_max_float = less_compare_float32(quant_max_float, zero_point_from_min)

    # Use the flags above to keep only the effective clamp value.
    less_quant_min_float = topi.multiply(quant_min_float, bool_less_quant_min_float)
    more_quant_max_float = topi.multiply(quant_max_float, bool_more_quant_max_float)

    # Flag values that are neither below quant_min_float nor above quant_max_float.
    tensor_one = topi.full(min_broadcast.shape, dtype, dc.one_const(dtype))
    bool_not_less_quant_min_float = topi.subtract(tensor_one, bool_less_quant_min_float)
    bool_not_more_quant_max_float = topi.subtract(tensor_one, bool_more_quant_max_float)
    bool_between_min_max = topi.multiply(bool_not_less_quant_min_float, bool_not_more_quant_max_float)
    between_min_max_float = topi.multiply(zero_point_from_min, bool_between_min_max)

    # Round values with quant_min <= value <= quant_max: add 0.5, then floor.
    between_min_max_add_half_one = topi.add(between_min_max_float, dc.half_const(dtype))
    between_min_max_round = akg.lang.ascend.floor(between_min_max_add_half_one)
    if product_is_mini():
        between_min_max_round = topi.cast(between_min_max_round, "float16")
    between_min_max_round = topi.cast(between_min_max_round, "float32")

    # Calculate the nudged zero point, then the nudged minimum and maximum.
    nudged_zero_point_tmp = topi.add(less_quant_min_float, more_quant_max_float)
    nudged_zero_point = topi.add(nudged_zero_point_tmp, between_min_max_round)

    nudged_min_tmp = topi.subtract(quant_min_float, nudged_zero_point)
    nudged_max_tmp = topi.subtract(quant_max_float, nudged_zero_point)
    nudged_min = topi.multiply(nudged_min_tmp, scale)
    nudged_max = topi.multiply(nudged_max_tmp, scale)

    res = [nudged_min, nudged_max, scale]
    return res