예제 #1
0
def _apply_adagrad_compute(var, accum, lr, grad, update_slots):
    """Build the Adagrad update for ``var`` and its accumulator.

    float16 inputs are promoted to float32 for the arithmetic and both
    results are converted back to float16 before they are returned.

    Args:
        var: Tensor holding the variable to update.
        accum: Tensor holding the accumulated squared gradients.
        lr: Learning-rate tensor; only element 0 is read.
        grad: Gradient tensor.
        update_slots: The accumulator is advanced by ``grad * grad`` only
            when this is exactly ``True``.

    Returns:
        Tuple ``(out_var, accum)`` with the updated variable and accumulator.
    """
    orig_dtype = var.dtype
    is_half = orig_dtype == "float16"
    if is_half:
        var, accum, lr, grad = [
            akg.lang.ascend.cast_to(tensor, "float32")
            for tensor in (var, accum, lr, grad)
        ]

    if update_slots is True:
        # accum <- accum + grad * grad
        accum = akg.lang.ascend.vadd(accum, akg.lang.ascend.vmul(grad, grad))
    elif orig_dtype == 'float32':
        # Adding a zero constant keeps the values as they are.
        # NOTE(review): presumably done so that accum is returned as a
        # computed tensor rather than the raw input - confirm.
        accum = akg.lang.ascend.vadds(accum, akg.tvm.const(0, "float32"))

    # var <- var - (lr * grad) * rsqrt(accum)
    scaled_grad = akg.tvm.compute(
        grad.shape,
        lambda *idx: grad(*idx) * lr[0],
        tag='elewise_single_VS_mul')
    inv_sqrt_accum = rsqrt(accum, target=utils.CCE)
    step = akg.lang.ascend.vmul(scaled_grad, inv_sqrt_accum)
    out_var = akg.lang.ascend.vsub(var, step)

    if is_half:
        # Return results in the caller's original precision.
        out_var = akg.lang.ascend.cast_to(out_var, "float16")
        accum = akg.lang.ascend.cast_to(accum, "float16")

    return out_var, accum
예제 #2
0
def _apply_proximal_adagrad_compute(var, accum, lr, l1, l2, grad):
    """Build the proximal (FOBOS) update using an Adagrad-scaled learning rate.

    The accumulator is advanced by ``grad * grad``; the learning rate is then
    multiplied by the reciprocal square root of the new accumulator and handed
    to ``apply_proximal_gradient_descent_impl`` together with ``l1``/``l2``.

    Args:
        var: Tensor holding the variable to update.
        accum: Tensor holding the accumulated squared gradients.
        lr: Learning-rate tensor.
        l1: L1 regularisation tensor, forwarded to the proximal step.
        l2: L2 regularisation tensor, forwarded to the proximal step.
        grad: Gradient tensor.

    Returns:
        Tuple ``(var_new, accum_new)`` in the dtype ``var`` had on entry.
    """
    out_dtype = var.dtype
    if out_dtype == "float16":
        # Do the arithmetic in float32 for higher accuracy.
        wide_dtype = "float32"
        var = akg.topi.cast(var, wide_dtype)
        accum = akg.topi.cast(accum, wide_dtype)
        lr = akg.topi.cast(lr, wide_dtype)
        l1 = akg.topi.cast(l1, wide_dtype)
        l2 = akg.topi.cast(l2, wide_dtype)
        grad = akg.topi.cast(grad, wide_dtype)

    # accum_new <- accum + grad * grad
    accum_new = akg.tvm.compute(
        var.shape,
        lambda *idx: accum(*idx) + grad(*idx) * grad(*idx),
        name="accum_new")

    # Adagrad learning rate: lr * rsqrt(accum_new)
    inv_sqrt_accum = rsqrt(accum_new, target="cce")
    scaled_lr = akg.topi.multiply(lr, inv_sqrt_accum)

    var_new = apply_proximal_gradient_descent_impl(var, scaled_lr, l1, l2, grad)

    # Restore the original dtype where the computation widened it.
    if var_new.dtype != out_dtype:
        var_new = akg.topi.cast(var_new, out_dtype)
    if accum_new.dtype != out_dtype:
        accum_new = akg.topi.cast(accum_new, out_dtype)
    return var_new, accum_new
예제 #3
0
def l2normalize(data, target=utils.CCE):
    """Scale ``data`` by the reciprocal square root of its squared sum.

    The squared sum is reduced over the last axis (``-1``) with the reduced
    dimension kept, then broadcast back to ``data.shape``.

    Args:
        data: Input tensor; its dtype is checked against
            ``utils.DtypeForDavinci.ALL_FLOAT`` and its shape is validated.
        target: Forwarded to the reduction and to ``rsqrt``.

    Returns:
        Tuple ``(res, attrs)``: the scaled tensor and the attribute dict
        ``{"pragma_modshift": 1}``.
    """
    utils.ops_dtype_check(data.dtype, utils.DtypeForDavinci.ALL_FLOAT)
    utils.check_shape(data.shape)

    squared = akg.lang.ascend.vmul(data, data)
    # `sum` takes keepdims/target keywords here, so it is a reduction operator
    # brought into scope by the module, not Python's builtin.
    squared_sum = sum(squared, -1, keepdims=True, target=target)
    inv_norm = rsqrt(squared_sum, target=target)
    inv_norm_full = akg.lang.ascend.broadcast(inv_norm, data.shape)
    normalized = akg.lang.ascend.vmul(data, inv_norm_full)
    return normalized, {"pragma_modshift": 1}
예제 #4
0
def _apply_rms_prop_compute(var, ms, mom, grad, lr, momentum, rho, epsilon):
    """Build the RMSProp update for ``var``, ``ms`` and ``mom``.

    All arithmetic is done in float32; tensor inputs of another dtype are cast
    on entry and the three results are cast back before returning.

    Update rule:
        ms_update  = rho * ms + (1 - rho) * grad * grad
        mom_update = momentum * mom + lr * grad / sqrt(ms_update + epsilon)
        var_update = var - mom_update

    Args:
        var: Tensor holding the variable to update.
        ms: Tensor holding the mean-square accumulator.
        mom: Tensor holding the momentum accumulator.
        grad: Gradient tensor.
        lr: Learning-rate tensor; only element 0 is read.
        momentum: Momentum tensor; only element 0 is read.
        rho: Decay tensor; only element 0 is read.
        epsilon: Scalar wrapped into a float32 constant.

    Returns:
        Tuple ``(var_update, ms_update, mom_update)``.
    """
    work_dtype = "float32"
    orig_dtype = var.dtype
    if orig_dtype != work_dtype:
        var = topi.cast(var, work_dtype)
        ms = topi.cast(ms, work_dtype)
        mom = topi.cast(mom, work_dtype)
        grad = topi.cast(grad, work_dtype)
        lr = topi.cast(lr, work_dtype)
        momentum = topi.cast(momentum, work_dtype)
        rho = topi.cast(rho, work_dtype)

    out_shape = get_shape(var)
    eps_const = akg.tvm.const(epsilon, dtype=work_dtype)

    def _per_element(fcompute, label):
        """Create a named compute stage over the shape of ``var``."""
        return akg.tvm.compute(out_shape, fcompute, name=label)

    # Single-element tensor holding (1 - rho).
    rho_complement = akg.tvm.compute(
        (1, ),
        lambda *idx: akg.tvm.const(1.0, work_dtype) - rho[0],
        name="one_minus_rho")

    decayed_mom = _per_element(lambda *idx: momentum[0] * mom(*idx), "mom_1")
    scaled_grad = _per_element(lambda *idx: grad(*idx) * lr[0], "lr_grad")
    decayed_ms = _per_element(lambda *idx: ms(*idx) * rho[0], "rho_ms")
    weighted_grad_sq = _per_element(
        lambda *idx: grad(*idx) * grad(*idx) * rho_complement[0],
        "rho_grad2")
    ms_update = _per_element(
        lambda *idx: decayed_ms(*idx) + weighted_grad_sq(*idx),
        "ms_update")

    # epsilon is added before the reciprocal square root.
    ms_plus_eps = _per_element(lambda *idx: ms_update(*idx) + eps_const, "ms_eps")
    inv_sqrt_ms = rsqrt(ms_plus_eps, target="cce")

    grad_step = _per_element(
        lambda *idx: scaled_grad(*idx) * inv_sqrt_ms(*idx),
        "mom_2")
    mom_update = _per_element(
        lambda *idx: decayed_mom(*idx) + grad_step(*idx),
        "mom_update")
    var_update = _per_element(
        lambda *idx: var(*idx) - mom_update(*idx),
        "var_update")

    if var_update.dtype != orig_dtype:
        # Hand results back in the dtype the caller supplied.
        var_update = topi.cast(var_update, orig_dtype)
        ms_update = topi.cast(ms_update, orig_dtype)
        mom_update = topi.cast(mom_update, orig_dtype)

    return var_update, ms_update, mom_update
예제 #5
0
def _apply_adadelta_compute(var, accum, accum_update, grad, lr, rho, epsilon):
    """Build the Adadelta update for ``var``, ``accum`` and ``accum_update``.

    float16 inputs are promoted to float32 for the arithmetic and the three
    results are cast back to float16 before returning.

    Update rule:
        accum_res        = rho * accum + (1 - rho) * grad^2
        update           = sqrt(accum_update + eps) * rsqrt(accum_res + eps) * grad
        var_res          = var - lr * update
        accum_update_res = rho * accum_update + (1 - rho) * update^2

    Args:
        var: Tensor holding the variable to update.
        accum: Tensor accumulating squared gradients.
        accum_update: Tensor accumulating squared updates.
        grad: Gradient tensor.
        lr: Learning-rate tensor, broadcast to ``var.shape``.
        rho: Decay tensor, broadcast to ``var.shape``.
        epsilon: Scalar wrapped into a float32 constant.

    Returns:
        Tuple ``(var_res, accum_res, accum_update_res)``.
    """
    orig_dtype = var.dtype
    is_half = orig_dtype == "float16"
    if is_half:
        var, accum, accum_update, lr, rho, grad = [
            topi.cast(tensor, "float32")
            for tensor in (var, accum, accum_update, lr, rho, grad)
        ]

    full_shape = var.shape
    eps_const = tvm.const(epsilon, "float32")
    ones = akg.lang.ascend.broadcast(tvm.const(1, "float32"), full_shape)
    rho_full = topi.broadcast_to(rho, full_shape)
    rho_complement = topi.subtract(ones, rho_full)
    eps_full = akg.lang.ascend.broadcast(eps_const, full_shape)

    # accum_res = accum * rho + grad^2 * (1 - rho)
    decayed_accum = topi.multiply(accum, rho_full)
    grad_sq = topi.multiply(grad, grad)
    weighted_grad_sq = topi.multiply(grad_sq, rho_complement)
    accum_res = akg.lang.ascend.vadd(weighted_grad_sq, decayed_accum)

    # update = sqrt(accum_update + eps) * rsqrt(accum_res + eps) * grad
    rms_update = topi.add(accum_update, eps_full)
    rms_update = sqrt(rms_update, target=utils.CCE)
    inv_rms_grad = topi.add(accum_res, eps_full)
    inv_rms_grad = rsqrt(inv_rms_grad, target=utils.CCE)
    normalized_grad = topi.multiply(grad, inv_rms_grad)
    update = topi.multiply(normalized_grad, rms_update)

    # var_res = var - update * lr
    lr_full = topi.broadcast_to(lr, full_shape)
    scaled_update = topi.multiply(update, lr_full)
    var_res = topi.subtract(var, scaled_update)

    # accum_update_res = rho * accum_update + (1 - rho) * update^2
    decayed_accum_update = topi.multiply(accum_update, rho_full)
    update_sq = topi.multiply(update, update)
    weighted_update_sq = topi.multiply(update_sq, rho_complement)
    accum_update_res = akg.lang.ascend.vadd(weighted_update_sq,
                                            decayed_accum_update)

    if is_half:
        # Return results in the caller's original precision.
        var_res = topi.cast(var_res, "float16")
        accum_res = topi.cast(accum_res, "float16")
        accum_update_res = topi.cast(accum_update_res, "float16")

    return var_res, accum_res, accum_update_res
예제 #6
0
def _bessel_i0e_compute(input_data):
    """Build the bessel_i0e computation from two polynomial approximations.

    One polynomial is evaluated on ``|x|`` clamped to ``CONST_LIMIT`` and the
    other on ``CONST_LIMIT / |x|``; the element-wise minimum of the two
    results is returned. float16 input is computed in float32 and cast back.

    Args:
        input_data: Input tensor.

    Returns:
        Tensor with the same dtype as ``input_data``.
    """
    in_shape = input_data.shape
    in_dtype = input_data.dtype
    is_half = in_dtype == "float16"

    def _horner(x, coeffs, count):
        """Evaluate the polynomial in ``x`` from the highest coefficient down."""
        acc = topi.multiply(x, coeffs[count - 1])
        acc = topi.add(acc, coeffs[count - 2])
        for coeff in coeffs[count - 3::-1]:
            acc = mul(acc, x, target=utils.CCE)
            acc = topi.add(acc, coeff)
        return acc

    # Widen float16 input before any arithmetic.
    if is_half:
        input_data = Cast(input_data, "float32", target=utils.CCE)
    abs_x = Abs(input_data, target=utils.CCE)

    # Branch for |x| <= 3.75 (per the original notes), with t = |x| / 3.75:
    # I0e = e^-|x| * (1 + 3.5156229t^2 + 3.0899424t^4 + 1.2067492t^6
    #                 + 0.2659732t^8 + 0.0360768t^10 + 0.0045813t^12)
    limit_full = akg.lang.ascend.broadcast(
        akg.tvm.const(CONST_LIMIT, "float32"), in_shape)
    clamped_abs = minimum(abs_x, limit_full)
    t_small = topi.multiply(clamped_abs, 1.0 / CONST_LIMIT)
    t_small_sq = mul(t_small, t_small, target=utils.CCE)
    small_res = _horner(t_small_sq, ITR_BEFORE, LEN_BEFORE)
    exp_neg_abs = Exp(neg(clamped_abs, target=utils.CCE), target=utils.CCE)
    small_res = mul(small_res, exp_neg_abs, target=utils.CCE)

    # Branch for |x| >= 3.75 (per the original notes), with t = |x| / 3.75:
    # I0e(x) = (1 / sqrt(|x|)) * (0.39894228 + 0.01328592t^-1
    #          + 0.00225319t^-2 - 0.00157565t^-3 + 0.00916281t^-4
    #          - 0.02057706t^-5 + 0.02635537t^-6 - 0.01647633t^-7
    #          + 0.00392377t^-8)
    t_large_inv = Divide(limit_full, abs_x, target=utils.CCE)
    large_res = _horner(t_large_inv, ITR_AFTER, LEN_AFTER)
    inv_sqrt_abs = rsqrt(abs_x, target=utils.CCE)
    large_res = mul(large_res, inv_sqrt_abs, target=utils.CCE)

    # The final value is the element-wise minimum of both branches.
    result = minimum(small_res, large_res, target=utils.CCE)

    # Narrow back to the caller's dtype.
    if is_half:
        result = Cast(result, "float16", target=utils.CCE)

    return result
예제 #7
0
def _after_res_compute(abs_data):
    """
    Compute bessel_i1e for absolute input values greater than or equal to 3.75.

    Algorithm:
    t = 3.75 / x
    I1(x) = (1 / sqrt(x))*(0.39894228 - 0.03988024t - 0.00362018t^2
                           + 0.00163801t^3 - 0.01031555t^4 + 0.02282967t^5
                           - 0.02895312t^6 + 0.01787654t^7 - 0.00420059t^8)

    The polynomial coefficients are read from ``ITR_AFTER`` (``LEN_AFTER``
    entries) and the limit from ``CONST_LIMIT``.

    Args:
        abs_data: Tensor of absolute input values.

    Returns:
        Tensor holding the polynomial value multiplied by ``rsqrt(abs_data)``.
    """
    limit_full = akg.lang.ascend.broadcast(
        akg.tvm.const(CONST_LIMIT, abs_data.dtype), abs_data.shape)
    ratio = Divide(limit_full, abs_data, target=utils.CCE)

    # Horner evaluation, starting from the highest-order coefficient.
    poly = topi.add(
        topi.multiply(ratio, ITR_AFTER[LEN_AFTER - 1]),
        ITR_AFTER[LEN_AFTER - 2])
    for coeff in ITR_AFTER[LEN_AFTER - 3::-1]:
        poly = topi.add(mul(poly, ratio, target=utils.CCE), coeff)

    inv_sqrt_abs = rsqrt(abs_data, target=utils.CCE)
    return mul(poly, inv_sqrt_abs, target=utils.CCE)
예제 #8
0
def _apply_centered_rms_prop_compute(var, mg, ms, mom, grad, lr, momentum, rho, epsilon):
    """Build the centered RMSProp update.

    float16 inputs are promoted to float32 for the arithmetic and the four
    results are cast back to float16 before returning.

    Update rule:
        out_mg  = rho * mg + (1 - rho) * grad
        out_ms  = rho * ms + (1 - rho) * grad * grad
        out_mom = momentum * mom
                  + lr * grad / sqrt(out_ms - out_mg * out_mg + epsilon)
        out_var = var - out_mom

    Args:
        var: Tensor holding the variable to update.
        mg: Tensor holding the running mean of the gradient.
        ms: Tensor holding the running mean of the squared gradient.
        mom: Tensor holding the momentum accumulator.
        grad: Gradient tensor.
        lr: Learning-rate tensor; only element 0 is read.
        momentum: Momentum tensor; only element 0 is read.
        rho: Decay tensor; element 0 is used for scaling.
        epsilon: Scalar wrapped into a constant of the working dtype.

    Returns:
        Tuple ``(out_var, out_mg, out_ms, out_mom)``.
    """
    orig_dtype = var.dtype
    is_half = orig_dtype == "float16"
    if is_half:
        var, mg, ms, mom, lr, rho, momentum, grad = [
            akg.lang.ascend.cast_to(tensor, "float32")
            for tensor in (var, mg, ms, mom, lr, rho, momentum, grad)
        ]
    eps_const = akg.tvm.const(epsilon, var.dtype)

    def _times_scalar(tensor, scalar):
        """Multiply every element of ``tensor`` by ``scalar[0]``."""
        return akg.tvm.compute(tensor.shape,
                               lambda *idx: tensor(*idx) * scalar[0],
                               tag='elewise_single_VS_mul')

    # (1 - rho), built as rho * (-1) followed by + 1.
    neg_rho = akg.tvm.compute(
        rho.shape,
        lambda *idx: rho(*idx) * akg.tvm.const(-1, rho.dtype),
        tag='elewise_single_VS_mul')
    rho_complement = akg.tvm.compute(
        neg_rho.shape,
        lambda *idx: neg_rho(*idx) + akg.tvm.const(1, rho.dtype),
        tag='elewise_single_VS_add')

    # out_mg <- rho * mg + (1 - rho) * grad
    decayed_mg = _times_scalar(mg, rho)
    weighted_grad = _times_scalar(grad, rho_complement)
    out_mg = akg.lang.ascend.vadd(decayed_mg, weighted_grad)

    # out_ms <- rho * ms + (1 - rho) * grad * grad
    decayed_ms = _times_scalar(ms, rho)
    grad_sq = akg.lang.ascend.vmul(grad, grad)
    weighted_grad_sq = _times_scalar(grad_sq, rho_complement)
    out_ms = akg.lang.ascend.vadd(decayed_ms, weighted_grad_sq)

    # out_mom <- momentum * mom
    #            + lr * grad / sqrt(out_ms - out_mg * out_mg + epsilon)
    decayed_mom = _times_scalar(mom, momentum)
    scaled_grad = _times_scalar(grad, lr)
    mg_sq = akg.lang.ascend.vmul(out_mg, out_mg)
    centered_ms = akg.lang.ascend.vsub(out_ms, mg_sq)
    centered_ms_eps = akg.tvm.compute(
        centered_ms.shape,
        lambda *idx: centered_ms(*idx) + eps_const,
        tag='elewise_single_VS_add')
    inv_sqrt_centered = rsqrt(centered_ms_eps, target=utils.CCE)
    grad_step = akg.lang.ascend.vmul(scaled_grad, inv_sqrt_centered)
    out_mom = akg.lang.ascend.vadd(decayed_mom, grad_step)

    # out_var <- var - out_mom
    out_var = akg.lang.ascend.vsub(var, out_mom)

    if is_half:
        # Return results in the caller's original precision.
        out_var = akg.lang.ascend.cast_to(out_var, "float16")
        out_mg = akg.lang.ascend.cast_to(out_mg, "float16")
        out_ms = akg.lang.ascend.cast_to(out_ms, "float16")
        out_mom = akg.lang.ascend.cast_to(out_mom, "float16")

    return out_var, out_mg, out_ms, out_mom
예제 #9
0
def rsqrt_ad(head, a, target="cce"):
    """Differentiate ``rsqrt(a)`` with respect to ``a``.

    Args:
        head: Tensor passed to ``akg.differentiate`` as the head adjoint.
        a: Input tensor of the forward ``rsqrt``.
        target: Accepted but not used; the forward op is always built with
            ``target='cce'``.
            NOTE(review): possibly intentional for an Ascend-only autodiff
            path - confirm before forwarding ``target``.

    Returns:
        The first (and only requested) result of ``akg.differentiate``.
    """
    forward_out = rsqrt(a, target='cce')
    jacobians = list(akg.differentiate(forward_out, [a], head))
    return jacobians[0]