Example no. 1
def sum_data(data, axes, keepdims, single_sum=False):
    """Reduce ``data`` by summing over multiple axes.

    Args:
        data: input tensor to reduce.
        axes: axes to sum over.
        keepdims: whether reduced axes are kept with size 1.
        single_sum: when True, delegate to ``akg.topi.sum`` in one shot;
            otherwise reduce axis-by-axis via ``mul_axis_sum``.

    Returns:
        The reduced tensor.
    """
    if not single_sum:
        return mul_axis_sum(data, axes, keepdims)
    return akg.topi.sum(data, axis=axes, keepdims=keepdims)
Example no. 2
def fused_bn1(data):
    """First of the three stages fused batch norm is split into.

    Stage one reduces the input over N, H and W to produce the running
    statistics the later stages consume:

    .. math::
        \\begin{array}{ll} \\\\
            m = N \\times H \\times W \\\\
            \\mu_{tmp} = \\sum_{n, h, w}{\\frac{x}{m}} \\\\
            \\sigma^2_{tmp} = \\sum_{n, h, w}{\\frac{x^2}{m}}
        \\end{array}

    Stage two then derives the variance and updates the running stats:

    .. math::
        \\begin{array}{ll} \\\\
            \\sigma^2 = \\sigma^2_{tmp} - \\mu^2 \\\\
            \\mu_{r} = momentum \\cdot \\mu_{r} + (1-momentum) \\cdot \\mu \\\\
            \\sigma^2_{r} = momentum \\cdot \\sigma^2_{r}
                + (1-momentum) \\cdot \\sigma^2
        \\end{array}

    Stage three normalizes the input:

    .. math::
        \\begin{array}{ll} \\\\
            \\hat{\\gamma} =
                \\gamma \\cdot \\frac{1}{\\sqrt{\\sigma^2 + \\epsilon}} \\\\
            \\hat{\\beta} = \\beta - \\hat{\\gamma} \\cdot \\mu \\\\
            res = \\hat{\\gamma} \\cdot x + \\hat{\\beta}
        \\end{array}

    This function implements stage one only: it squares the input, sums
    both the input and its square over H and W, rescales by 1/m, then
    finishes the N-axis reduction with atomic-add sums.

    Args:
        data (tvm.tensor.Tensor): Tensor of type float16 or float32 with
            shape (N,C1,H,W,C0).

    Returns:
        mean (tvm.tensor.Tensor): float32 tensor of shape (1,C1,1,1,C0).
        var_part (tvm.tensor.Tensor): float32 tensor of shape (1,C1,1,1,C0).
        attrs (dict): build attributes (tiling strategy and, when
            available, the dim-info string) for the schedule.
    """
    bn1_check(data)
    dim_info, _ = bn1_set_dim_func(data)
    attrs = {**DEFAULT_ATTR_MAP_BN1}
    shape = get_shape(data)

    # m = W * H * N — every element that contributes to one (C1, C0) slot.
    elem_count = reduce(lambda acc, dim: acc * dim,
                        [shape[axis] for axis in [3, 2, 0]])
    inv_count = float(1) / float(elem_count)

    attrs["custom_tiling"] = bn1_tiling_strategy(data)
    # Accumulate in float32 regardless of the input dtype.
    data = data.astype("float32")

    squared = akg.tvm.compute(
        data.shape, lambda *idx: data[idx] * data[idx], name="square")

    # First collapse the spatial axes (H, W), keeping dims for broadcasting.
    spatial_axes = [2, 3]
    partial_mean = mul_axis_sum(data, spatial_axes, True)
    partial_var = mul_axis_sum(squared, spatial_axes, True)
    scaled_mean = akg.lang.cce.vmuls(partial_mean, inv_count)
    scaled_var = akg.lang.cce.vmuls(partial_var, inv_count)

    # Then collapse the batch axis with atomic adds so the reduction can
    # be parallelized across blocks.
    mean = mul_axis_sum(scaled_mean, [0], True,
                        name="mean",
                        attrs={'atomic_add': "mean"})
    var_part = mul_axis_sum(scaled_var, [0], True,
                            name="var_part",
                            attrs={'atomic_add': "var_part"})

    if dim_info != "":
        attrs["dim"] = dim_info
    return mean, var_part, attrs