def sum_data(data, axes, keepdims, single_sum=False):
    """Sum ``data`` over ``axes`` using one of two reduction strategies.

    Args:
        data: Tensor to reduce.
        axes: Axes to sum over.
        keepdims: Forwarded unchanged to whichever reduction is chosen.
        single_sum: When truthy, reduce all ``axes`` with a single
            ``akg.topi.sum`` call; otherwise delegate to ``mul_axis_sum``.

    Returns:
        The reduced tensor produced by the selected strategy.
    """
    if not single_sum:
        return mul_axis_sum(data, axes, keepdims)
    return akg.topi.sum(data, axis=axes, keepdims=keepdims)
def fused_bn1(data):
    r"""
    Fused_batch_norm is divided into 3 parts for better performance.

    First part:

    .. math::
        \begin{array}{ll} \\
            m = N \times H \times W \\
            \mu_{tmp} = \sum_{n, h, w}{\frac{x}{m}} \\
            \sigma^2_{tmp} = \sum_{n, h, w}{\frac{x^2}{m}}
        \end{array}

    Second part:

    .. math::
        \begin{array}{ll} \\
            \sigma^2 = \sigma^2_{tmp} - \mu^2 \\
            \mu_{r} = momentum \cdot \mu_{r} + (1-momentum) \cdot \mu \\
            \sigma^2_{r} = momentum \cdot \sigma^2_{r} + (1-momentum) \cdot \sigma^2
        \end{array}

    Third part:

    .. math::
        \begin{array}{ll} \\
            \hat{\gamma} = \gamma \cdot \frac{1}{\sqrt{\sigma^2 + \epsilon}} \\
            \hat{\beta} = \beta - \hat{\gamma} \cdot \mu \\
            res = \hat{\gamma} \cdot x + \hat{\beta}
        \end{array}

    This function implements the first part of fused batch norm. It first
    reduces the H and W axes, scales the partial sums by 1/m, and then
    reduces the N axis.

    Args:
        data (tvm.tensor.Tensor): Tensor of type float16 or float32 with
            shape (N,C1,H,W,C0).

    Returns:
        mean (tvm.tensor.Tensor): Tensor of type float32 with shape
            (1,C1,1,1,C0).
        var_part (tvm.tensor.Tensor): Tensor of type float32 with shape
            (1,C1,1,1,C0), the mean of the squared input.
        attrs (dict): Build attributes, copied from ``DEFAULT_ATTR_MAP_BN1``,
            with "custom_tiling" set and, when a non-empty dim string is
            produced, "dim" set.
    """
    bn1_check(data)
    dim_info, _ = bn1_set_dim_func(data)
    # Copy so the module-level default map is never mutated.
    attrs = {**DEFAULT_ATTR_MAP_BN1}

    # m = N * H * W for an (N, C1, H, W, C0) input; multiply by 1/m later
    # instead of dividing.
    shape = get_shape(data)
    num = shape[0] * shape[2] * shape[3]
    avg_num = 1.0 / float(num)
    # Tiling is derived from the original (pre-cast) input.
    attrs["custom_tiling"] = bn1_tiling_strategy(data)

    # Do all arithmetic in float32 regardless of the input dtype.
    data = data.astype("float32")
    square = akg.tvm.compute(data.shape, lambda *i: data[i] * data[i], name="square")

    # Reduce H and W (axes 2 and 3) first, keeping the reduced dims.
    axes = [2, 3]
    mean_tmp = mul_axis_sum(data, axes, True)
    var_part_tmp = mul_axis_sum(square, axes, True)

    # Scale by 1/m before the final reduction over N.
    mean_tmp_div_num = akg.lang.cce.vmuls(mean_tmp, avg_num)
    var_tmp_div_num = akg.lang.cce.vmuls(var_part_tmp, avg_num)

    # Reduce N (axis 0). NOTE(review): the 'atomic_add' attr presumably lets
    # the backend accumulate across N with atomic adds -- confirm in mul_axis_sum.
    mean = mul_axis_sum(mean_tmp_div_num, [0], True, name="mean", attrs={'atomic_add': "mean"})
    var_part = mul_axis_sum(var_tmp_div_num, [0], True, name="var_part", attrs={'atomic_add': "var_part"})

    if dim_info != "":
        attrs["dim"] = dim_info
    return mean, var_part, attrs