Example #1
0
def acos_grad(x, dy):
    """
    Gradient for acos.

    .. math::
        dx = \\frac{-1}{\\sqrt{1 - x^2}} \\cdot dy

    Args:
        x (tvm.tensor.Tensor): tensor of type float16, float32.
        dy (tvm.tensor.Tensor): tensor of type float16, float32.

    Returns:
        tvm.tensor.Tensor, same type and shape as x.
    """
    dtype = x.dtype
    # Only float16/float32 are supported by this backend.
    vc_util.ops_dtype_check(x.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.ops_dtype_check(dy.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.check_shape(x.shape)
    vc_util.check_shape(dy.shape)

    one = akg.tvm.const(1.0, dtype=dtype)
    # 1 - x^2: the quantity under the square root in d(acos)/dx.
    mid_square = akg.tvm.compute(x.shape,
                                 lambda *i: (one - x(*i) * x(*i)),
                                 name="mid_square")
    # rsqrt yields 1/sqrt(1 - x^2); negate and chain with the incoming grad dy.
    rsq = rsqrt.rsqrt(mid_square)
    dx = akg.tvm.compute(x.shape, lambda *i: -rsq(*i) * dy(*i), name="dx")

    return dx
Example #2
0
def _apply_adagrad_compute(var, accum, lr, grad, update_slots):
    """Compute apply_adagrad"""
    input_dtype = var.dtype
    # Promote float16 inputs to float32 for better numeric accuracy.
    if input_dtype == "float16":
        var, accum, lr, grad = (akg.lang.cce.cast_to(t, "float32")
                                for t in (var, accum, lr, grad))

    if update_slots is True:
        # accum += grad ** 2
        accum = akg.lang.cce.vadd(accum, akg.lang.cce.vmul(grad, grad))
    elif input_dtype == 'float32':
        # Slot update disabled: add zero so a fresh output tensor is still built.
        accum = akg.lang.cce.vadds(accum, akg.tvm.const(0, "float32"))

    # var -= lr * grad / accum.sqrt()
    scaled_grad = akg.tvm.compute(grad.shape,
                                  lambda *indices: grad(*indices) * lr[0],
                                  tag='elewise_single_VS_mul')
    accum_rsqrt = rsqrt(accum)
    delta = akg.lang.cce.vmul(scaled_grad, accum_rsqrt)
    out_var = akg.lang.cce.vsub(var, delta)

    # Cast results back to the caller's dtype.
    if input_dtype == "float16":
        out_var = akg.lang.cce.cast_to(out_var, "float16")
        accum = akg.lang.cce.cast_to(accum, "float16")

    return out_var, accum
Example #3
0
def l2normalize(data):
    """L2-normalize `data` along its last axis."""
    vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.check_shape(data.shape)
    # 1 / sqrt(sum(x^2)) over the last axis, kept-dims for broadcasting.
    squared = akg.lang.cce.vmul(data, data)
    sq_sum, _ = sum_value(squared, -1, keepdims=True)
    inv_norm = rsqrt(sq_sum)
    inv_norm_bc = akg.lang.cce.broadcast(inv_norm, data.shape)
    res = akg.lang.cce.vmul(data, inv_norm_bc)
    attrs = {"pragma_modshift": 1}
    return res, attrs
Example #4
0
def _apply_rms_prop_compute(var, ms, mom, grad, lr, momentum, rho, epsilon):
    """Compute apply_rms_prop"""
    compute_dtype = "float32"
    orig_dtype = var.dtype
    # Work in float32 internally; cast back at the end if inputs differ.
    if orig_dtype != compute_dtype:
        var, ms, mom, grad, lr, momentum, rho = (
            topi.cast(t, compute_dtype)
            for t in (var, ms, mom, grad, lr, momentum, rho))
    shape = get_shape(var)
    eps_const = akg.tvm.const(epsilon, dtype=compute_dtype)
    # Scalar (1 - rho), computed once and broadcast via indexing below.
    one_minus_rho = akg.tvm.compute(
        (1, ),
        lambda *indice: akg.tvm.const(1.0, compute_dtype) - rho[0],
        name="one_minus_rho")

    # ms_update  = rho * ms + (1 - rho) * grad^2
    # mom_update = momentum * mom + lr * grad * rsqrt(ms_update + epsilon)
    # var_update = var - mom_update
    mom_1 = akg.tvm.compute(shape,
                            lambda *indice: momentum[0] * mom(*indice),
                            name="mom_1")
    lr_grad = akg.tvm.compute(shape,
                              lambda *indice: grad(*indice) * lr[0],
                              name="lr_grad")
    rho_ms = akg.tvm.compute(shape,
                             lambda *indice: ms(*indice) * rho[0],
                             name="rho_ms")
    rho_grad2 = akg.tvm.compute(
        shape,
        lambda *indice: grad(*indice) * grad(*indice) * one_minus_rho[0],
        name="rho_grad2")
    ms_update = akg.tvm.compute(
        shape,
        lambda *indice: rho_ms(*indice) + rho_grad2(*indice),
        name="ms_update")
    ms_eps = akg.tvm.compute(shape,
                             lambda *indice: ms_update(*indice) + eps_const,
                             name="ms_eps")
    inv_sqrt = rsqrt(ms_eps)
    mom_2 = akg.tvm.compute(
        shape,
        lambda *indice: lr_grad(*indice) * inv_sqrt(*indice),
        name="mom_2")
    mom_update = akg.tvm.compute(
        shape,
        lambda *indice: mom_1(*indice) + mom_2(*indice),
        name="mom_update")
    var_update = akg.tvm.compute(
        shape,
        lambda *indice: var(*indice) - mom_update(*indice),
        name="var_update")
    # Restore the caller's dtype on all three outputs.
    if var_update.dtype != orig_dtype:
        var_update, ms_update, mom_update = (
            topi.cast(t, orig_dtype)
            for t in (var_update, ms_update, mom_update))

    return var_update, ms_update, mom_update
Example #5
0
def _apply_adadelta_compute(var, accum, accum_update, grad, lr, rho, epsilon):
    """Compute apply_adadelta"""
    dtype = var.dtype
    # Promote float16 inputs to float32 for the internal math.
    if dtype == "float16":
        var, accum, accum_update, lr, rho, grad = [
            topi.cast(t, "float32")
            for t in (var, accum, accum_update, lr, rho, grad)
        ]

    epsilon = tvm.const(epsilon, "float32")
    tensor_one = akg.lang.cce.broadcast(tvm.const(1, "float32"), var.shape)
    tensor_rho = topi.broadcast_to(rho, var.shape)
    tensor_rho_gs = topi.subtract(tensor_one, tensor_rho)
    tensor_epsilon = akg.lang.cce.broadcast(epsilon, var.shape)

    # accum = accum * rho + grad ** 2 * (1 - rho)
    accum_res = akg.lang.cce.vadd(
        topi.multiply(topi.multiply(grad, grad), tensor_rho_gs),
        topi.multiply(accum, tensor_rho))

    # update = (accum_update + epsilon).sqrt * (accum + epsilon).rsqrt * grad
    sqrt_term = sqrt(topi.add(accum_update, tensor_epsilon))
    rsqrt_term = rsqrt(topi.add(accum_res, tensor_epsilon))
    update = topi.multiply(topi.multiply(grad, rsqrt_term), sqrt_term)

    # var -= update * lr
    lr_bc = topi.broadcast_to(lr, var.shape)
    var_res = topi.subtract(var, topi.multiply(update, lr_bc))

    # accum_update = rho * accum_update + (1 - rho) * update.square
    accum_update_res = akg.lang.cce.vadd(
        topi.multiply(topi.multiply(update, update), tensor_rho_gs),
        topi.multiply(accum_update, tensor_rho))

    # Restore the caller's dtype on all three outputs.
    if dtype == "float16":
        var_res = topi.cast(var_res, "float16")
        accum_res = topi.cast(accum_res, "float16")
        accum_update_res = topi.cast(accum_update_res, "float16")

    return var_res, accum_res, accum_update_res
Example #6
0
def _apply_proximal_adagrad_compute(var, accum, lr, l1, l2, grad):
    """compute the FOBOS algorithm with adagrad learning rate"""

    dtype = var.dtype
    if dtype == "float16":
        # cast to float32 for higher accuracy
        compute_type = "float32"
        var, accum, lr, l1, l2, grad = (
            akg.topi.cast(t, compute_type)
            for t in (var, accum, l1, l2, grad).__class__((var, accum, lr, l1, l2, grad)))

    # accum_new = accum + grad^2
    accum_new = akg.tvm.compute(
        var.shape,
        lambda *indice: accum(*indice) + grad(*indice) * grad(*indice),
        name="accum_new")

    # Adagrad-style adaptive learning rate: lr / sqrt(accum_new).
    ada_lr = akg.topi.multiply(lr, rsqrt(accum_new))

    var_new = apply_proximal_gradient_descent_impl(var, ada_lr, l1, l2, grad)

    # cast to origin dtype
    var_new, accum_new = (akg.topi.cast(t, dtype) if t.dtype != dtype else t
                          for t in (var_new, accum_new))
    return var_new, accum_new
Example #7
0
def _after_res_compute(abs_data):
    """
    compute bessel_i1e for abs value of data greater than or equal to 3.75

    Algorithm:
    t = 3.75 / x
    I1(x) = (1 / sqrt(x))*(0.39894228 - 0.03988024t - 0.00362018t^2
                           + 0.00163801t^3 - 0.01031555t^4 + 0.02282967t^5
                           - 0.02895312t^6 + 0.01787654t^7 - 0.00420059t^8)
    """
    # t = CONST_LIMIT / x, the expansion variable of the asymptotic series.
    limit_bc = akg.lang.cce.broadcast(
        akg.tvm.const(CONST_LIMIT, abs_data.dtype), abs_data.shape)
    t = div(limit_bc, abs_data)
    # Horner evaluation of the polynomial in t, highest coefficient first.
    poly = topi.add(topi.multiply(t, ITR_AFTER[LEN_AFTER - 1]),
                    ITR_AFTER[LEN_AFTER - 2])
    for coeff in ITR_AFTER[LEN_AFTER - 3::-1]:
        poly = topi.add(mul(poly, t), coeff)
    # Final factor 1 / sqrt(x).
    return mul(poly, rsqrt(abs_data))
Example #8
0
def fused_bn3(data, mean, variance, gamma, beta, eps=1e-3):
    """
    The third part of fused batch norm, calculate the normalized result.

    Read fused_bn1 docs for details.

    Note:
        This part is also the reference implement for fused_batch_norm!

    Args:
        data (tvm.tensor.Tensor): Tensor of type float16 or float32 with
                                  \"NC1HWC0\" format.
        mean (tvm.tensor.Tensor): Tensor of type float32, data's mean.
        variance (tvm.tensor.Tensor): Tensor of type float32, data's variance.
        gamma (tvm.tensor.Tensor): Tensor of type float32 for scaling.
        beta (tvm.tensor.Tensor): Tensor of type float32 for bias.
        eps (float): small float value to avoid dividing zero.

    Returns:
        Tensor as normalized, scaled, shifted data.
    """
    bn3_check(data, mean, variance, gamma, beta)
    dim_info, _ = bn3_set_dim_func(data, mean, variance, gamma, beta, eps)
    attrs = {**DEFAULT_ATTR_MAP_BN3}

    ori_dtype = data.dtype

    # calculate batch norm result
    # rsd = 1 / sqrt(variance + eps)
    rsd = rsqrt(
        akg.tvm.compute(
            variance.shape,
            lambda *i: variance(*i) + akg.tvm.const(eps, dtype=variance.dtype),
            name="var_eps"))

    # Fold mean/variance into an affine transform:
    # hat_gamma = gamma / sqrt(var + eps); hat_beta = beta - hat_gamma * mean,
    # so the per-element work below is a single multiply-add.
    hat_gamma = akg.tvm.compute(gamma.shape,
                                lambda *i: gamma(*i) * rsd(*i),
                                name="hat_gamma",
                                attrs={'no_inline': 1})

    hat_beta = akg.tvm.compute(gamma.shape,
                               lambda *i: beta(*i) - hat_gamma(*i) * mean(*i),
                               name="hat_beta",
                               attrs={'no_inline': 1})

    hat_gamma_bc = akg.lang.cce.broadcast(hat_gamma, data.shape)
    hat_beta_bc = akg.lang.cce.broadcast(hat_beta, data.shape)

    # Compute in float32 even for float16 input; cast back at the end.
    data_fp32 = akg.tvm.compute(data.shape,
                                lambda *i: data(*i).astype("float32"),
                                name="data_fp32")

    # vmadd: data * hat_gamma + hat_beta, i.e. the normalized+scaled+shifted value.
    bn_res_fp32 = akg.tvm.compute(
        data.shape,
        lambda *i: akg.lang.cce.vmadd(data_fp32(*i), hat_gamma_bc(*i),
                                      hat_beta_bc(*i)),
        name="bn_res_fp32")
    res = akg.tvm.compute(bn_res_fp32.shape,
                          lambda *i: bn_res_fp32(*i).astype(ori_dtype),
                          name="bn_res")
    # Propagate the tiling dim info (when set) to the build attrs.
    if dim_info != "":
        attrs["dim"] = dim_info
    return res, attrs
Example #9
0
def rsqrt_ad(head, a):
    """Differentiate rsqrt(a) and chain the result with `head`."""
    forward = rsqrt.rsqrt(a)
    jacobians = list(akg.differentiate(forward, [a], head))
    return jacobians[0]
Example #10
0
def fused_batch_norm(inputs, attrs):
    r"""
    Batch normalization.

    See Source:
    <a href="https://arxiv.org/abs/1502.03167">
        Batch Normalization: Accelerating Deep Network Training by Reducing
        Internal Covariate Shift; S. Ioffe, C. Szegedy.
    </a>

    .. math::
        \begin{array}{ll} \\
            \mu = \frac{1}{m} \sum^m_{i=1}{x_i} \\
            \sigma^2 = \frac{1}{m} \sum^m_{i=1}{(x_i-\mu)^2} \\
            \hat{x_i} = \frac{x_i - \mu}{ \sqrt{\sigma^2 + \epsilon} } \\
            y_i = \gamma \hat{x_i} + \beta \equiv BN_{\gamma, \beta}(x_i)
        \end{array}

    This momentum argument is different from one used in optimizer classes and
    the conventional notion of momentum. Mathematically, the update rule for
    running statistics here is

    .. math::
        \hat{z_{new}} = momentum \cdot \hat{z} + (1-momentum) \cdot z_t

    where :math:`\hat{z}` is the estimated statistic and :math:`z_t` is the
    new observed value.

    Note:
        When data_format is \"NC1HWC0\", the `gamma`, `beta`, `moving_mean`
        and `moving_variance` should be 5D tensors of shape
        `(1, C1, 1, 1, C0)`, otherwise, they should be 1D tensors
        of shape `(C,)`.

    Args:
        inputs:
            data (tvm.tensor.Tensor): Tensor of type float16, float32. (:math:`x_i`)
            gamma (tvm.tensor.Tensor): Tensor for scaling (:math:`\gamma`).
            beta (tvm.tensor.Tensor): Tensor for bias (:math:`\beta`).
            moving_mean (tvm.tensor.Tensor): Tensor for population mean used for
                                            inference.
            moving_variance (tvm.tensor.Tensor): Tensor for population variance used
                                             for inference.
        attrs:
            momentum (float): A float number used for the moving_mean and
                            moving_variance computation.
            eps (float): A small float added to variance to avoid dividing by zero.
            is_training (bool): A bool value to specify if the operation is used for
                                training or inference.
            data_format (str): Support format, \"DefaultFormat\", \"NCHW\", \"NHWC\"
                            or \"NC1HWC0\".
            axis (Union[int, list, tuple]): Integer to specify the channel axis when
                                            data_format is \"DefaultFormat\". List
                                            or tuple for \"NC1HWC0\". When format is
                                            \"NCHW\" or \"NHWC\", it's not work.
                                            Must be in the range
                                            [-rank(data), rank(data)).
            single_sum (bool): whether use "mul_axis_sum".

    Returns:
        outs (tvm.tensor.Tensor): Tensor for normalized, scaled, shifted data.
        new_moving_mean (tvm.tensor.Tensor): Tensor of same type and shape as
                                             `moving_mean`. The `moving_mean`
                                             updated by data. Only returns when
                                             `is_training` is True.
        new_moving_variance (tvm.tensor.Tensor): Tensor of same type and shape as
                                                 `moving_variance`. The
                                                 `moving_variance` updated by
                                                 data. Only returns when
                                                 `is_training` is True.
        sample_mean (tvm.tensor.Tensor): Tensor of same type and shape as
                                         `moving_mean`. The mean of `data`. Only
                                         returns when `is_training` is True.
        sample_var (tvm.tensor.Tensor): Tensor of same type and shape as
                                        `moving_variance`. The variance of `data`.
                                        Only returns when `is_training` is True.
    """
    if len(inputs) != 5:
        raise ValueError(
            "Input tensors number should be 5, but get %s." % len(inputs))
    data_format = attrs.get("data_format", "DefaultFormat")
    params = check_inputs(inputs, data_format, attrs.get("axis", 1))

    # Unpack inputs: data, scale, shift, and the running statistics.
    data = inputs[0]
    gamma = inputs[1]
    beta = inputs[2]
    moving_mean = inputs[3]
    moving_variance = inputs[4]
    ori_dtype = data.dtype
    shape = get_shape(data)
    # Reduction axes and reshape metadata come from check_inputs.
    axes = params.get("axes", (0,))
    keepdims = params.get("is_special5d", False)
    mid_shape = params.get("mid_shape", [1, ])
    # Identity copy that tags the op with the data format in its name.
    data = akg.tvm.compute(data.shape, lambda *i: data(*i),
                           "batchnorm_" + data_format)
    # Keep handles to the original moving stats for the inplace binds below.
    ori_moving_mean = moving_mean
    ori_moving_variance = moving_variance
    # Compute in float32 regardless of the input dtype; cast back at the end.
    if ori_dtype != DTYPE_FLOAT32:
        data = akg.topi.cast(data, DTYPE_FLOAT32)
        gamma = akg.topi.cast(gamma, DTYPE_FLOAT32)
        beta = akg.topi.cast(beta, DTYPE_FLOAT32)
        moving_mean = akg.topi.cast(moving_mean, DTYPE_FLOAT32)
        moving_variance = akg.topi.cast(moving_variance, DTYPE_FLOAT32)

    ######## following is dsl ########
    is_training = attrs.get("is_training", True)
    if is_training:
        # Number of elements reduced over, for the 1/m factor.
        value_num = 1
        for index in axes:
            value_num *= shape[index]

        avg_num = round(float(1) / float(value_num), 12)

        data_square = akg.tvm.compute(data.shape,
                                      lambda *i: data(*i) * data(*i),
                                      name="data_square")
        # cal mean
        data_mean = akg.lang.ascend.vmuls(
            sum_data(data, axes, keepdims, attrs.get("single_sum", False)), avg_num)
        # Variance via E[x^2] - E[x]^2 (single pass over the data).
        data_square_mean = akg.lang.ascend.vmuls(sum_data(data_square, axes, keepdims, attrs.get("single_sum", False)),
                                                 avg_num)
        data_mean_square = akg.tvm.compute(data_mean.shape,
                                           lambda *i: data_mean(*i) *
                                           data_mean(*i),
                                           name="data_mean_square")

        data_variance = akg.tvm.compute(data_mean.shape,
                                        lambda *i:
                                        data_square_mean(
                                            *i) - data_mean_square(*i),
                                        name="data_variance")

        # Update the running stats with the batch stats (EMA with `momentum`).
        mean_new = update_by_moving_average(
            moving_mean, data_mean, attrs.get("momentum", 0.99))
        variance_new = update_by_moving_average(moving_variance,
                                                data_variance, attrs.get("momentum", 0.99))
    else:
        # no_bc version
        # Inference: normalize with the stored population statistics.
        data_variance = moving_variance
        data_mean = moving_mean

    # rsveps = 1 / sqrt(variance + eps), broadcast to the data shape.
    rsveps = akg.lang.ascend.vadds(data_variance, akg.tvm.const(
        attrs.get("eps", 1e-3), dtype=DTYPE_FLOAT32))
    rsveps = rsqrt(rsveps, utils.CCE)
    rsveps = akg.lang.ascend.broadcast(rsveps, shape)

    # Subtract the mean as an add of (-mean), broadcast to the data shape.
    mean2 = akg.lang.ascend.vmuls(data_mean, akg.tvm.const(-1, data.dtype))
    mean2 = akg.lang.ascend.broadcast(mean2, shape)

    dmean = akg.tvm.compute(
        shape, lambda *i: data(*i) + mean2(*i), name="dmean")
    dmsve = akg.tvm.compute(shape, lambda *i: dmean(*i)
                            * rsveps(*i), name="dmsve")

    # For non-5D formats gamma/beta arrive as (C,); reshape for broadcast.
    if not keepdims:
        gamma = akg.topi.reshape(gamma, mid_shape)
        beta = akg.topi.reshape(beta, mid_shape)
    gamma_bc = akg.lang.ascend.broadcast(gamma, shape)
    beta_bc = akg.lang.ascend.broadcast(beta, shape)
    # outs = normalized * gamma + beta
    dmsveg = akg.tvm.compute(shape, lambda *i: dmsve(*i) * gamma_bc(*i),
                             name="dmsveg")
    outs = akg.tvm.compute(shape, lambda *i: dmsveg(*i) + beta_bc(*i),
                           name="output")
    out_attrs = get_attrs(outs)

    if is_training:
        if ori_dtype != DTYPE_FLOAT32:
            outs = akg.topi.cast(outs, ori_dtype)
            mean_new = akg.topi.cast(mean_new, ori_dtype)
            variance_new = akg.topi.cast(variance_new, ori_dtype)
            data_mean = akg.topi.cast(data_mean, ori_dtype)
            data_variance = akg.topi.cast(data_variance, ori_dtype)

        # Bind the updated stats to the input buffers so they update in place.
        mean_new, binds_info_mean = TensorUtils.inplace_set(
            ori_moving_mean, mean_new, buffer_name="mean_buf")
        variance_new, binds_info_var = TensorUtils.inplace_set(
            ori_moving_variance, variance_new, buffer_name="var_buf")
        binds_info_all = binds_info_mean
        binds_info_all.update(binds_info_var)
        out_attrs[BINDS] = binds_info_all

        # the new moving_mean and moving_var are updated inplace in
        # inputs(moving_mean and moving_var). But Mindspore needs
        # These two fake outputs though it never uses them
        fake_moving_mean = akg.tvm.compute(mean_new.shape,
                                           lambda *indices: mean_new(*indices),
                                           "fake_moving_mean")
        fake_moving_var = akg.tvm.compute(mean_new.shape,
                                          lambda *indices: variance_new(
                                              *indices),
                                          "fake_moving_var")
        out_tensors = (outs, fake_moving_mean, fake_moving_var, data_mean,
                       data_variance, mean_new, variance_new,)
    else:
        if ori_dtype != DTYPE_FLOAT32:
            outs = akg.topi.cast(outs, ori_dtype)
        out_tensors = (outs,)
    out_tensors = list(out_tensors) if isinstance(
        out_tensors, tuple) else out_tensors
    # Pick the tiling strategy depending on whether shapes are dynamic.
    if shape_is_dynamic(out_tensors):
        out_attrs["custom_tiling"] = batch_norm_tiling_strategy_dynamic(outs)
    else:
        out_attrs["custom_tiling"] = batch_norm_tiling_strategy(
            outs, data_format)
    out_tensors.append(out_attrs)

    return out_tensors
Example #11
0
def _apply_centered_rms_prop_compute(var, mg, ms, mom, grad, lr, momentum, rho, epsilon):
    """Compute apply_centered_rms_prop"""
    inp_dtype = var.dtype
    # Promote every float16 input to float32 for the internal math.
    if inp_dtype == "float16":
        var, mg, ms, mom, lr, rho, momentum, grad = [
            akg.lang.cce.cast_to(t, "float32")
            for t in (var, mg, ms, mom, lr, rho, momentum, grad)
        ]
    epsilon = akg.tvm.const(epsilon, var.dtype)

    # Build (1 - rho) as a tensor: multiply by -1, then add 1.
    one_minus_rho = akg.tvm.compute(
        rho.shape,
        lambda *indices: rho(*indices) * akg.tvm.const(-1, rho.dtype),
        tag='elewise_single_VS_mul')
    one_minus_rho = akg.tvm.compute(
        one_minus_rho.shape,
        lambda *indices: one_minus_rho(*indices) + akg.tvm.const(1, rho.dtype),
        tag='elewise_single_VS_add')

    # out_mg <- rho * mg + (1-rho) * grad
    mg_rho = akg.tvm.compute(mg.shape,
                             lambda *indices: mg(*indices) * rho[0],
                             tag='elewise_single_VS_mul')
    grad_scaled = akg.tvm.compute(
        grad.shape,
        lambda *indices: grad(*indices) * one_minus_rho[0],
        tag='elewise_single_VS_mul')
    out_mg = akg.lang.cce.vadd(mg_rho, grad_scaled)

    # out_ms <- rho * ms + (1-rho) * grad * grad
    ms_rho = akg.tvm.compute(ms.shape,
                             lambda *indices: ms(*indices) * rho[0],
                             tag='elewise_single_VS_mul')
    grad2 = akg.lang.cce.vmul(grad, grad)
    grad2_scaled = akg.tvm.compute(
        grad2.shape,
        lambda *indices: grad2(*indices) * one_minus_rho[0],
        tag='elewise_single_VS_mul')
    out_ms = akg.lang.cce.vadd(ms_rho, grad2_scaled)

    # out_mom <- momentum * mom + lr * grad / sqrt(out_ms - out_mg * out_mg + epsilon)
    mom_scaled = akg.tvm.compute(mom.shape,
                                 lambda *indices: mom(*indices) * momentum[0],
                                 tag='elewise_single_VS_mul')
    lr_grad = akg.tvm.compute(grad.shape,
                              lambda *indices: grad(*indices) * lr[0],
                              tag='elewise_single_VS_mul')
    # Centered second moment: out_ms - out_mg^2, stabilized with epsilon.
    centered = akg.lang.cce.vsub(out_ms, akg.lang.cce.vmul(out_mg, out_mg))
    denom = akg.tvm.compute(centered.shape,
                            lambda *indices: centered(*indices) + epsilon,
                            tag='elewise_single_VS_add')
    step = akg.lang.cce.vmul(lr_grad, rsqrt(denom))
    out_mom = akg.lang.cce.vadd(mom_scaled, step)

    # out_var <- var - out_mom
    out_var = akg.lang.cce.vsub(var, out_mom)

    # Restore the caller's dtype on all four outputs.
    if inp_dtype == "float16":
        out_var = akg.lang.cce.cast_to(out_var, "float16")
        out_mg = akg.lang.cce.cast_to(out_mg, "float16")
        out_ms = akg.lang.cce.cast_to(out_ms, "float16")
        out_mom = akg.lang.cce.cast_to(out_mom, "float16")

    return out_var, out_mg, out_ms, out_mom