def fused_bn_reduce(data, layout, out_dtype):
    """
    input:
        data: 4-D Tensor
        layout: input layout, only 'NCHW', 'NHWC' supported
        out_dtype: "float16" or "float32"
    output:
        out1_sum: 1-D tensor (C), per-channel sum of the input, reduced over N, H, W
        out2_squared_sum: 1-D tensor (C), per-channel sum of the squared input, reduced over N, H, W
    """
    if layout == "NCHW":
        data = topi.transpose(data, axes=(0, 2, 3, 1))
    elif layout != "NHWC":
        raise NotImplementedError('Layout not supported {}'.format(layout))

    # Accumulate in float32 for numerical stability, cast back at the end.
    inter_dtype = 'float32'
    data_cast = topi.cast(data, inter_dtype)

    out1_sum = topi.sum(data_cast, axis=(0, 1, 2))
    out1_sum = topi.cast(out1_sum, out_dtype)

    squared = topi.multiply(data_cast, data_cast)
    out2_squared_sum = topi.sum(squared, axis=(0, 1, 2))
    out2_squared_sum = topi.cast(out2_squared_sum, out_dtype)

    return [out1_sum, out2_squared_sum]
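# A minimal usage sketch for fused_bn_reduce; tvm.placeholder is assumed to
# be the pre-te API that matches the tvm/topi style of this module, and the
# shapes and names below are illustrative only.
def _example_fused_bn_reduce():
    data = tvm.placeholder((32, 7, 7, 64), dtype="float16", name="data")
    out_sum, out_sqsum = fused_bn_reduce(data, layout="NHWC", out_dtype="float32")
    # Both outputs have shape (64,): one value per channel C, reduced over N, H, W.
    return out_sum, out_sqsum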
def fake_quant_with_min_max_vars_per_channel_gradient_compute(input_gradients,
                                                              inputs_data,
                                                              min_broadcast,
                                                              max_broadcast,
                                                              num_bits=8,
                                                              narrow_range=False):
    """Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation."""
    shape = get_shape(inputs_data)
    sum_axis = list(range(len(shape) - 1))

    nudged_min, nudged_max, _ = nudged_min_max_compute(
        min_broadcast, max_broadcast, num_bits, narrow_range)
    # both zero yields zero
    bool_both_zero_value = bool_both_zero_compute(min_broadcast, max_broadcast)
    bool_both_zero_negate = _bool_negate(bool_both_zero_value)

    bool_less_equal_nudged_max = _less_equal_compare_float32(inputs_data, nudged_max)
    bool_more_equal_nudged_min = _less_equal_compare_float32(nudged_min, inputs_data)
    bool_between_nudged_min_max = topi.multiply(bool_less_equal_nudged_max,
                                                bool_more_equal_nudged_min)
    # gradient is 1 if input in [min, max] else 0
    backprops_input_tmp = topi.multiply(bool_between_nudged_min_max, input_gradients)
    backprops_bool_both_zero = topi.multiply(backprops_input_tmp, bool_both_zero_value)
    # if min and max are both zero, the gradient is input_gradients
    input_gradients_both_zero = topi.multiply(input_gradients, bool_both_zero_negate)
    backprops_input = topi.add(backprops_bool_both_zero, input_gradients_both_zero)

    # gradient for min is input_gradients if inputs_data < nudged_min else 0
    bool_less_nudged_min = _bool_negate(bool_more_equal_nudged_min)
    output_backprop_min_tmp = topi.multiply(bool_less_nudged_min, input_gradients)
    # gradient for min is 0 if min and max are both 0
    output_backprop_min_bool = topi.multiply(output_backprop_min_tmp, bool_both_zero_value)
    if not sum_axis:
        output_backprop_min = output_backprop_min_bool
    else:
        output_backprop_min = topi.sum(output_backprop_min_bool, sum_axis)

    # gradient for max is input_gradients if inputs_data > nudged_max else 0
    bool_more_nudged_max = _bool_negate(bool_less_equal_nudged_max)
    output_backprop_max_tmp = topi.multiply(bool_more_nudged_max, input_gradients)
    # gradient for max is 0 if min and max are both 0
    output_backprop_max_bool = topi.multiply(output_backprop_max_tmp, bool_both_zero_value)
    if not sum_axis:
        output_backprop_max = output_backprop_max_bool
    else:
        output_backprop_max = topi.sum(output_backprop_max_bool, sum_axis)

    return backprops_input, output_backprop_min, output_backprop_max
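# Reference semantics of the gating above, in plain NumPy: a sketch for the
# non-degenerate case where min and max are not both zero; nudged_min and
# nudged_max are taken as given, and all names here are illustrative, not
# part of this module.
import numpy as np

def _fake_quant_grad_reference(dy, x, nudged_min, nudged_max):
    # d_input: pass the gradient through only inside [nudged_min, nudged_max].
    inside = (x >= nudged_min) & (x <= nudged_max)
    d_input = np.where(inside, dy, 0.0)
    # d_min / d_max: gradient mass from elements clamped at either end,
    # reduced over every axis except the trailing per-channel axis.
    axes = tuple(range(x.ndim - 1))
    d_min = np.where(x < nudged_min, dy, 0.0).sum(axis=axes)
    d_max = np.where(x > nudged_max, dy, 0.0).sum(axis=axes)
    return d_input, d_min, d_max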
def fused_relu_grad_bn_double_update_grad(data_1, data_2, data_3, data_4,
                                          data_5, data_6, data_7, layout='NHWC'):
    if layout == "NCHW":
        # Transpose every 4-D input to NHWC and reassign the results;
        # rebinding the loop variable alone would not update the tensors.
        data_2, data_4, data_5, data_6, data_7 = [
            topi.transpose(t, axes=(0, 2, 3, 1))
            for t in (data_2, data_4, data_5, data_6, data_7)]
    elif layout != "NHWC":
        raise NotImplementedError('Layout not supported {}'.format(layout))

    # ReLU gradient: pass data_5 + data_6 through only where data_7 > 0.
    data_tmp1 = topi.full_like(data_7, 0.0)
    data_tmp2 = topi.greater(data_7, data_tmp1)
    data_tmp3 = topi.add(data_5, data_6)
    data_tmp4 = topi.where(data_tmp2, data_tmp3, data_tmp1)
    data_tmp5 = topi.cast(data_tmp4, 'float32')
    data_tmp7 = topi.sum(data_tmp5, axis=(0, 1, 2))

    n, h, w, c = data_7.shape

    # dgamma-style reduction for the first batch-norm branch.
    data_tmp8 = topi.cast(data_2, 'float32')
    data_tmp9 = topi.full_like(data_tmp7, 1.0 / (n * h * w))
    data_tmp10 = topi.multiply(data_1, data_tmp9)
    data_tmp11 = topi.broadcast_to(data_tmp10, data_tmp8.shape)
    data_tmp12 = topi.subtract(data_tmp8, data_tmp11)
    data_tmp13 = topi.multiply(data_tmp5, data_tmp12)
    data_tmp15 = topi.sum(data_tmp13, axis=(0, 1, 2))

    # dgamma-style reduction for the second batch-norm branch.
    data_tmp16 = topi.cast(data_4, 'float32')
    data_tmp17 = topi.multiply(data_3, data_tmp9)
    data_tmp18 = topi.broadcast_to(data_tmp17, data_tmp16.shape)
    data_tmp19 = topi.subtract(data_tmp16, data_tmp18)
    data_tmp20 = topi.multiply(data_tmp5, data_tmp19)
    data_tmp22 = topi.sum(data_tmp20, axis=(0, 1, 2))

    return [data_tmp7, data_tmp15, data_tmp22]
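# What the fusion above computes, as a NumPy reference (a sketch, not the
# kernel; argument names are illustrative): mask the two incoming gradients
# by the ReLU output, then reduce to a per-channel dbeta plus one
# dgamma-style sum for each batch-norm branch.
import numpy as np

def _relu_grad_bn_double_update_grad_reference(sum1, x1, sum2, x2,
                                               dy1, dy2, relu_out):
    n, h, w, _ = relu_out.shape
    dgrad = np.where(relu_out > 0, dy1 + dy2, 0.0).astype(np.float32)
    dbeta = dgrad.sum(axis=(0, 1, 2))                       # (C,)
    mean1 = sum1 / (n * h * w)                              # per-channel mean
    dgamma1 = (dgrad * (x1.astype(np.float32) - mean1)).sum(axis=(0, 1, 2))
    mean2 = sum2 / (n * h * w)
    dgamma2 = (dgrad * (x2.astype(np.float32) - mean2)).sum(axis=(0, 1, 2))
    return dbeta, dgamma1, dgamma2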
def matrix_diag_part_compute(input_diagonal, input_help):
    """matrix_diag_part compute implementation"""
    shape_input_diagonal = get_shape(input_diagonal)
    dtype_input_diagonal = input_diagonal.dtype

    # Cast integer inputs to a float dtype before the multiply/reduce;
    # on mini targets the int32 path goes through float16 first.
    if dtype_input_diagonal in ("int8", "uint8"):
        input_diagonal = topi.cast(input_diagonal, "float16")
        input_help = topi.cast(input_help, "float16")

    if dtype_input_diagonal == "int32" and product_is_mini():
        input_diagonal = topi.cast(input_diagonal, "float16")
        input_help = topi.cast(input_help, "float16")
        input_diagonal = topi.cast(input_diagonal, "float32")
        input_help = topi.cast(input_help, "float32")

    if dtype_input_diagonal == "int32" and not product_is_mini():
        input_diagonal = topi.cast(input_diagonal, "float32")
        input_help = topi.cast(input_help, "float32")

    res_vmul = topi.multiply(input_help, input_diagonal)

    # Reduce along the shorter of the two trailing axes to pick out the diagonal.
    if shape_input_diagonal[-2] < shape_input_diagonal[-1]:
        res = topi.sum(res_vmul, -1)
    else:
        res = topi.sum(res_vmul, -2)

    if dtype_input_diagonal == "int32" and product_is_mini():
        res = topi.cast(res, "float16")

    res = topi.cast(res, dtype_input_diagonal)

    return res
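# The help tensor is expected to carry ones on the diagonal of the two
# trailing axes; multiply-then-reduce along the shorter axis extracts the
# diagonal. A small NumPy illustration of the idea (not the kernel itself):
import numpy as np

def _matrix_diag_part_illustration():
    x = np.arange(12.0).reshape(3, 4)      # rows < cols, so reduce axis -1
    help_mask = np.eye(3, 4)               # ones on the main diagonal
    diag = (x * help_mask).sum(axis=-1)    # array([0., 5., 10.])
    assert np.allclose(diag, np.diagonal(x))
    return diag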
def fused_bn_follow_relu_avgpool(data0, data1, data2, data3, data4, data5,
                                 layout='NHWC', out_dtype='float16', target=utils.CUDA):
    """
    input:
        data: length is 6
        data0: tensor1 after bn_double_relu
        data1-data5: bn parameters for conv2d tensor2
        layout: only (N, H, W, C), (N, C, H, W) supported
        out_dtype: float16
    output:
        avg-pooling( max(batch-normalized tensor1 + batch-normalized tensor2, 0) )
    """
    if layout == 'NCHW':
        data0 = topi.transpose(data0, (0, 2, 3, 1))
        data5 = topi.transpose(data5, (0, 2, 3, 1))
    elif layout != 'NHWC':
        raise NotImplementedError('Layout not supported {}'.format(layout))

    n, h, w, c = data0.shape
    inter_dtype = 'float32'
    add0 = fused_bn_follow(data1, data2, data3, data4, data5)
    add0 = topi.cast(add0, data0.dtype)
    add1 = topi.add(data0, add0)
    output = topi.maximum(add1, 0)

    # Global average pool over the spatial axes H and W, accumulated in float32.
    output = topi.cast(output, inter_dtype)
    output = topi.sum(output, axis=(1, 2))
    output = topi.divide(output, h * w)
    output = topi.cast(output, out_dtype)

    return output
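# In effect this fusion is a global average pool over H and W of the ReLU'd
# sum of the two branches; a NumPy semantics sketch, with bn2 standing in
# for the (assumed) result of fused_bn_follow:
import numpy as np

def _bn_follow_relu_avgpool_reference(x1, bn2):
    # x1, bn2: (N, H, W, C) arrays.
    _, h, w, _ = x1.shape
    out = np.maximum(x1 + bn2, 0.0).astype(np.float32)
    return (out.sum(axis=(1, 2)) / (h * w)).astype(np.float16)   # (N, C)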
def _mean(data, axis, cof, shape):
    # Fold the reduction size into the coefficient, then scale and sum.
    size = 1
    for ax in axis:
        size = size * shape[ax]
    cof = cof / tvm.const(size, "float32")
    tmp = topi.multiply(data, cof)
    res = topi.sum(tmp, axis)
    return res
def bn_beta_grad(head, layout='NHWC'):
    if layout == "NCHW":
        head = topi.transpose(head, (0, 2, 3, 1))
    # dbeta is the per-channel sum of the incoming gradient.
    return topi.sum(head, axis=(0, 1, 2))
def bn_gamma_grad(head, in_data, data_sum, layout="NHWC"):
    if layout == "NCHW":
        head = topi.transpose(head, (0, 2, 3, 1))
    n, h, w, c = head.shape
    n, h, w = n.value, h.value, w.value
    # dgamma = sum(head * (x - mean)) over N, H, W.
    scale = tvm.const(n * h * w, head.dtype)
    mean = topi.divide(data_sum, scale)
    x_hat = topi.subtract(in_data, mean)
    x_hat_mul = topi.multiply(x_hat, head)
    return topi.sum(x_hat_mul, axis=(0, 1, 2))
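# Both reductions above are the standard batch-norm parameter gradients;
# a NumPy reference sketch (illustrative names, not part of this module):
import numpy as np

def _bn_param_grads_reference(head, x, x_sum):
    n, h, w, _ = head.shape
    dbeta = head.sum(axis=(0, 1, 2))                    # bn_beta_grad
    mean = x_sum / (n * h * w)                          # per-channel mean of x
    dgamma = (head * (x - mean)).sum(axis=(0, 1, 2))    # bn_gamma_grad
    return dbeta, dgamma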
def _sum(data, axis, cof):
    # Scale by the coefficient, then reduce.
    tmp = topi.multiply(data, cof)
    res = topi.sum(tmp, axis)
    return res
def _sumsq(data, axis, cof):
    # Square, scale by the coefficient, then reduce.
    squared = topi.multiply(data, data)
    tmp = topi.multiply(squared, cof)
    res = topi.sum(tmp, axis)
    return res
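# The three helpers are scaled reductions; their NumPy equivalents for a
# scalar cof (a sketch, not part of this module):
import numpy as np

def _mean_reference(data, axis, cof, shape):
    size = np.prod([shape[a] for a in axis])
    return (data * (cof / size)).sum(axis=tuple(axis))

def _sum_reference(data, axis, cof):
    return (data * cof).sum(axis=tuple(axis))

def _sumsq_reference(data, axis, cof):
    return (data * data * cof).sum(axis=tuple(axis))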