def acos_grad(x, dy):
    r"""
    Gradient for acos.

    .. math::
        dx = \frac{-1}{\sqrt{1 - x^2}} \cdot dy

    Args:
        x (tvm.tensor.Tensor): tensor of type float16, float32.
        dy (tvm.tensor.Tensor): tensor of type float16, float32.

    Returns:
        tvm.tensor.Tensor, same type and shape as x.
    """
    dtype = x.dtype
    vc_util.ops_dtype_check(x.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.ops_dtype_check(dy.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.check_shape(x.shape)
    vc_util.check_shape(dy.shape)

    one = akg.tvm.const(1.0, dtype=dtype)
    mid_square = akg.tvm.compute(x.shape, lambda *i: (one - x(*i) * x(*i)), name="mid_square")
    rsq = rsqrt.rsqrt(mid_square)
    dx = akg.tvm.compute(x.shape, lambda *i: -rsq(*i) * dy(*i), name="dx")

    return dx
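# A minimal standalone NumPy sketch (not part of this module) of the same formula,
# dx = -dy / sqrt(1 - x^2), for checking the compute graph above. The name
# `acos_grad_ref` and its plain-ndarray inputs are illustrative assumptions.
import numpy as np

def acos_grad_ref(x, dy):
    """Reference: elementwise gradient of acos, dx = -dy / sqrt(1 - x^2)."""
    return -dy / np.sqrt(1.0 - x * x)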
def _apply_adagrad_compute(var, accum, lr, grad, update_slots):
    """Compute apply_adagrad"""
    input_dtype = var.dtype
    if input_dtype == "float16":
        var = akg.lang.cce.cast_to(var, "float32")
        accum = akg.lang.cce.cast_to(accum, "float32")
        lr = akg.lang.cce.cast_to(lr, "float32")
        grad = akg.lang.cce.cast_to(grad, "float32")

    if update_slots is True:
        # accum += grad ** 2
        grad_square = akg.lang.cce.vmul(grad, grad)
        accum = akg.lang.cce.vadd(accum, grad_square)
    elif input_dtype == 'float32':
        accum = akg.lang.cce.vadds(accum, akg.tvm.const(0, "float32"))

    # var -= lr * grad / accum.sqrt()
    lr_grad = akg.tvm.compute(grad.shape, lambda *indices: grad(*indices) * lr[0],
                              tag='elewise_single_VS_mul')
    rsqrt_accum = rsqrt(accum)
    update = akg.lang.cce.vmul(lr_grad, rsqrt_accum)
    out_var = akg.lang.cce.vsub(var, update)

    if input_dtype == "float16":
        out_var = akg.lang.cce.cast_to(out_var, "float16")
        accum = akg.lang.cce.cast_to(accum, "float16")

    return out_var, accum
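# A minimal standalone NumPy sketch (not part of this module) of the Adagrad update
# assembled above: accum += grad**2 (when slots are updated) and
# var -= lr * grad / sqrt(accum). The name `apply_adagrad_ref` is illustrative only.
import numpy as np

def apply_adagrad_ref(var, accum, lr, grad, update_slots=True):
    if update_slots:
        accum = accum + grad * grad
    var = var - lr * grad / np.sqrt(accum)
    return var, accum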
def l2normalize(data):
    """L2-normalize `data` along the last axis: data * rsqrt(sum(data^2, axis=-1))."""
    vc_util.ops_dtype_check(data.dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.check_shape(data.shape)

    square_res = akg.lang.cce.vmul(data, data)
    reduce_sum, _ = sum_value(square_res, -1, keepdims=True)
    one_of_square = rsqrt(reduce_sum)
    broad_cast = akg.lang.cce.broadcast(one_of_square, data.shape)
    res = akg.lang.cce.vmul(data, broad_cast)
    attrs = {"pragma_modshift": 1}

    return res, attrs
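# A minimal standalone NumPy sketch (not part of this module) of the same computation:
# divide by the L2 norm over the last axis. The name `l2normalize_ref` is illustrative only.
import numpy as np

def l2normalize_ref(data):
    square_sum = np.sum(data * data, axis=-1, keepdims=True)
    return data / np.sqrt(square_sum)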
def _apply_rms_prop_compute(var, ms, mom, grad, lr, momentum, rho, epsilon):
    """Compute apply_rms_prop"""
    compute_dtype = "float32"
    dtype = var.dtype
    if dtype != compute_dtype:
        var, ms, mom, grad, lr, momentum, rho = [
            topi.cast(t, compute_dtype) for t in [var, ms, mom, grad, lr, momentum, rho]]

    shape = get_shape(var)
    cons_eps = akg.tvm.const(epsilon, dtype=compute_dtype)
    one_minus_rho = akg.tvm.compute(
        (1,), lambda *indice: akg.tvm.const(1.0, compute_dtype) - rho[0], name="one_minus_rho")

    # var_update = var - (momentum * mom + lr * grad / sqrt(rho * ms + (1 - rho) * grad * grad + epsilon))
    mom_1 = akg.tvm.compute(shape, lambda *indice: momentum[0] * mom(*indice), name="mom_1")
    lr_grad = akg.tvm.compute(shape, lambda *indice: grad(*indice) * lr[0], name="lr_grad")
    rho_ms = akg.tvm.compute(shape, lambda *indice: ms(*indice) * rho[0], name="rho_ms")
    rho_grad2 = akg.tvm.compute(
        shape, lambda *indice: grad(*indice) * grad(*indice) * one_minus_rho[0], name="rho_grad2")
    ms_update = akg.tvm.compute(
        shape, lambda *indice: rho_ms(*indice) + rho_grad2(*indice), name="ms_update")
    ms_eps = akg.tvm.compute(shape, lambda *indice: ms_update(*indice) + cons_eps, name="ms_eps")
    rsq = rsqrt(ms_eps)
    mom_2 = akg.tvm.compute(shape, lambda *indice: lr_grad(*indice) * rsq(*indice), name="mom_2")
    mom_update = akg.tvm.compute(
        shape, lambda *indice: mom_1(*indice) + mom_2(*indice), name="mom_update")
    var_update = akg.tvm.compute(
        shape, lambda *indice: var(*indice) - mom_update(*indice), name="var_update")

    if var_update.dtype != dtype:
        var_update, ms_update, mom_update = [
            topi.cast(t, dtype) for t in [var_update, ms_update, mom_update]]

    return var_update, ms_update, mom_update
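# A minimal standalone NumPy sketch (not part of this module) of the RMSProp update built above:
#   ms  <- rho * ms + (1 - rho) * grad^2
#   mom <- momentum * mom + lr * grad / sqrt(ms + epsilon)
#   var <- var - mom
# The name `apply_rms_prop_ref` and its ndarray inputs are illustrative assumptions.
import numpy as np

def apply_rms_prop_ref(var, ms, mom, grad, lr, momentum, rho, epsilon):
    ms = rho * ms + (1.0 - rho) * grad * grad
    mom = momentum * mom + lr * grad / np.sqrt(ms + epsilon)
    var = var - mom
    return var, ms, mom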
def _apply_adadelta_compute(var, accum, accum_update, grad, lr, rho, epsilon):
    """Compute apply_adadelta"""
    dtype = var.dtype
    if dtype == "float16":
        var = topi.cast(var, "float32")
        accum = topi.cast(accum, "float32")
        accum_update = topi.cast(accum_update, "float32")
        lr = topi.cast(lr, "float32")
        rho = topi.cast(rho, "float32")
        grad = topi.cast(grad, "float32")

    epsilon = tvm.const(epsilon, "float32")
    tensor_one = akg.lang.cce.broadcast(tvm.const(1, "float32"), var.shape)
    tensor_rho = topi.broadcast_to(rho, var.shape)
    tensor_rho_gs = topi.subtract(tensor_one, tensor_rho)
    tensor_epsilon = akg.lang.cce.broadcast(epsilon, var.shape)

    # accum = accum * rho + grad ** 2 * (1 - rho)
    rhs = topi.multiply(accum, tensor_rho)
    lhs = topi.multiply(grad, grad)
    lhs = topi.multiply(lhs, tensor_rho_gs)
    accum_res = akg.lang.cce.vadd(lhs, rhs)

    # update = (accum_update + epsilon).sqrt * (accum + epsilon).rsqrt * grad
    rhs = topi.add(accum_update, tensor_epsilon)
    rhs = sqrt(rhs)
    lhs = topi.add(accum_res, tensor_epsilon)
    lhs = rsqrt(lhs)
    lhs = topi.multiply(grad, lhs)
    update = topi.multiply(lhs, rhs)

    # var -= update * lr
    var_res = topi.broadcast_to(lr, var.shape)
    var_res = topi.multiply(update, var_res)
    var_res = topi.subtract(var, var_res)

    # accum_update = rho * accum_update + (1 - rho) * update.square
    rhs = topi.multiply(accum_update, tensor_rho)
    lhs = topi.multiply(update, update)
    lhs = topi.multiply(lhs, tensor_rho_gs)
    accum_update_res = akg.lang.cce.vadd(lhs, rhs)

    if dtype == "float16":
        var_res = topi.cast(var_res, "float16")
        accum_res = topi.cast(accum_res, "float16")
        accum_update_res = topi.cast(accum_update_res, "float16")

    return var_res, accum_res, accum_update_res
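# A minimal standalone NumPy sketch (not part of this module) of the Adadelta update mirrored above:
#   accum        <- rho * accum + (1 - rho) * grad^2
#   update       <- sqrt(accum_update + eps) / sqrt(accum + eps) * grad
#   var          <- var - lr * update
#   accum_update <- rho * accum_update + (1 - rho) * update^2
# The name `apply_adadelta_ref` is illustrative only.
import numpy as np

def apply_adadelta_ref(var, accum, accum_update, grad, lr, rho, epsilon):
    accum = rho * accum + (1.0 - rho) * grad * grad
    update = np.sqrt(accum_update + epsilon) / np.sqrt(accum + epsilon) * grad
    var = var - lr * update
    accum_update = rho * accum_update + (1.0 - rho) * update * update
    return var, accum, accum_update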
def _apply_proximal_adagrad_compute(var, accum, lr, l1, l2, grad):
    """compute the FOBOS algorithm with adagrad learning rate"""
    dtype = var.dtype
    if dtype == "float16":
        # cast to float32 for higher accuracy
        compute_type = "float32"
        var, accum, lr, l1, l2, grad = [akg.topi.cast(t, compute_type)
                                        for t in [var, accum, lr, l1, l2, grad]]

    shape = var.shape

    # accum_new = accum + grad * grad
    accum_new = akg.tvm.compute(shape,
                                lambda *indice: accum(*indice) + grad(*indice) * grad(*indice),
                                name="accum_new")

    accum_new_rsqrt = rsqrt(accum_new)
    ada_lr = akg.topi.multiply(lr, accum_new_rsqrt)

    var_new = apply_proximal_gradient_descent_impl(var, ada_lr, l1, l2, grad)

    # cast back to the original dtype
    var_new, accum_new = [akg.topi.cast(t, dtype) if t.dtype != dtype else t
                          for t in [var_new, accum_new]]

    return var_new, accum_new
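# A minimal standalone NumPy sketch (not part of this module) of the adaptive learning rate
# computed above: accum_new = accum + grad^2 and ada_lr = lr / sqrt(accum_new). The proximal
# (FOBOS) step itself is delegated to apply_proximal_gradient_descent_impl and is not
# reproduced here. The name `adagrad_lr_ref` is illustrative only.
import numpy as np

def adagrad_lr_ref(accum, lr, grad):
    accum_new = accum + grad * grad
    ada_lr = lr / np.sqrt(accum_new)
    return ada_lr, accum_new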
def _after_res_compute(abs_data):
    """
    Compute bessel_i1e for abs values of data greater than or equal to 3.75.

    Algorithm:
        t = 3.75 / x
        I1(x) = (1 / sqrt(x)) * (0.39894228 - 0.03988024t - 0.00362018t^2
                + 0.00163801t^3 - 0.01031555t^4 + 0.02282967t^5
                - 0.02895312t^6 + 0.01787654t^7 - 0.00420059t^8)
    """
    broad_const_limit = akg.lang.cce.broadcast(
        akg.tvm.const(CONST_LIMIT, abs_data.dtype), abs_data.shape)
    data = div(broad_const_limit, abs_data)
    after_res = topi.multiply(data, ITR_AFTER[LEN_AFTER - 1])
    after_res = topi.add(after_res, ITR_AFTER[LEN_AFTER - 2])
    for iter_number in ITR_AFTER[LEN_AFTER - 3::-1]:
        after_res = mul(after_res, data)
        after_res = topi.add(after_res, iter_number)
    abs_data_rsqrt = rsqrt(abs_data)
    after_res = mul(after_res, abs_data_rsqrt)
    return after_res
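# A minimal standalone NumPy sketch (not part of this module) of the polynomial evaluation
# described in the docstring, using Horner's rule in t = 3.75 / x and the coefficients listed
# there. The name `bessel_i1e_after_ref` is illustrative only; the on-device code uses the
# module constants CONST_LIMIT / ITR_AFTER rather than the literals below.
import numpy as np

def bessel_i1e_after_ref(x):
    coeffs = [0.39894228, -0.03988024, -0.00362018, 0.00163801,
              -0.01031555, 0.02282967, -0.02895312, 0.01787654, -0.00420059]
    t = 3.75 / x
    poly = np.zeros_like(x)
    for c in reversed(coeffs):
        poly = poly * t + c
    return poly / np.sqrt(x)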
def fused_bn3(data, mean, variance, gamma, beta, eps=1e-3):
    """
    The third part of fused batch norm, calculating the normalized result.

    Read fused_bn1 docs for details.

    Note:
        This part is also the reference implementation for fused_batch_norm!

    Args:
        data (tvm.tensor.Tensor): Tensor of type float16 or float32 with "NC1HWC0" format.
        mean (tvm.tensor.Tensor): Tensor of type float32, data's mean.
        variance (tvm.tensor.Tensor): Tensor of type float32, data's variance.
        gamma (tvm.tensor.Tensor): Tensor of type float32 for scaling.
        beta (tvm.tensor.Tensor): Tensor of type float32 for bias.
        eps (float): small float value to avoid dividing by zero.

    Returns:
        Tensor as normalized, scaled, shifted data.
    """
    bn3_check(data, mean, variance, gamma, beta)
    dim_info, _ = bn3_set_dim_func(data, mean, variance, gamma, beta, eps)
    attrs = {**DEFAULT_ATTR_MAP_BN3}

    ori_dtype = data.dtype

    # calculate batch norm result
    rsd = rsqrt(
        akg.tvm.compute(
            variance.shape,
            lambda *i: variance(*i) + akg.tvm.const(eps, dtype=variance.dtype),
            name="var_eps"))

    hat_gamma = akg.tvm.compute(gamma.shape, lambda *i: gamma(*i) * rsd(*i),
                                name="hat_gamma", attrs={'no_inline': 1})
    hat_beta = akg.tvm.compute(gamma.shape, lambda *i: beta(*i) - hat_gamma(*i) * mean(*i),
                               name="hat_beta", attrs={'no_inline': 1})

    hat_gamma_bc = akg.lang.cce.broadcast(hat_gamma, data.shape)
    hat_beta_bc = akg.lang.cce.broadcast(hat_beta, data.shape)

    data_fp32 = akg.tvm.compute(data.shape, lambda *i: data(*i).astype("float32"),
                                name="data_fp32")
    bn_res_fp32 = akg.tvm.compute(
        data.shape,
        lambda *i: akg.lang.cce.vmadd(data_fp32(*i), hat_gamma_bc(*i), hat_beta_bc(*i)),
        name="bn_res_fp32")
    res = akg.tvm.compute(bn_res_fp32.shape, lambda *i: bn_res_fp32(*i).astype(ori_dtype),
                          name="bn_res")

    if dim_info != "":
        attrs["dim"] = dim_info
    return res, attrs
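# A minimal standalone NumPy sketch (not part of this module) of the folded normalization used
# above: hat_gamma = gamma / sqrt(var + eps), hat_beta = beta - hat_gamma * mean, and
# result = data * hat_gamma + hat_beta, which equals gamma * (data - mean) / sqrt(var + eps) + beta.
# The name `bn3_ref` is illustrative only and ignores the NC1HWC0 layout/broadcast details.
import numpy as np

def bn3_ref(data, mean, variance, gamma, beta, eps=1e-3):
    hat_gamma = gamma / np.sqrt(variance + eps)
    hat_beta = beta - hat_gamma * mean
    return data * hat_gamma + hat_beta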
def rsqrt_ad(head, a):
    """Differentiate rsqrt(a) with respect to `a`, weighted by `head`."""
    b = rsqrt.rsqrt(a)
    _jacs = list(akg.differentiate(b, [a], head))
    return _jacs[0]
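# A minimal standalone NumPy sketch (not part of this module) of the derivative being taken
# above: d/da (1 / sqrt(a)) = -0.5 * a^(-3/2), scaled elementwise by `head`.
# The name `rsqrt_ad_ref` is illustrative only.
import numpy as np

def rsqrt_ad_ref(head, a):
    return head * (-0.5) * np.power(a, -1.5)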
def fused_batch_norm(inputs, attrs):
    r"""
    Batch normalization.

    See Source:
    <a href="https://arxiv.org/abs/1502.03167">
    Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift;
    S. Ioffe, C. Szegedy.
    </a>

    .. math::
        \begin{array}{ll} \\
            \mu = \frac{1}{m} \sum^m_{i=1}{x_i} \\
            \sigma^2 = \frac{1}{m} \sum^m_{i=1}{(x_i-\mu)^2} \\
            \hat{x_i} = \frac{x_i - \mu}{ \sqrt{\sigma^2 + \epsilon} } \\
            y_i = \gamma \hat{x_i} + \beta \equiv BN_{\gamma, \beta}(x_i)
        \end{array}

    This momentum argument is different from the one used in optimizer classes and
    from the conventional notion of momentum. Mathematically, the update rule for
    running statistics here is

    .. math::
        \hat{z_{new}} = momentum \cdot \hat{z} + (1 - momentum) \cdot z_t

    where :math:`\hat{z}` is the estimated statistic and :math:`z_t` is the new observed value.

    Note:
        When data_format is "NC1HWC0", the `gamma`, `beta`, `moving_mean` and `moving_variance`
        should be 5D tensors of shape `(1, C1, 1, 1, C0)`, otherwise, they should be 1D tensors
        of shape `(C,)`.

    Args:
        inputs:
            data (tvm.tensor.Tensor): Tensor of type float16, float32 (:math:`x_i`).
            gamma (tvm.tensor.Tensor): Tensor for scaling (:math:`\gamma`).
            beta (tvm.tensor.Tensor): Tensor for bias (:math:`\beta`).
            moving_mean (tvm.tensor.Tensor): Tensor for population mean used for inference.
            moving_variance (tvm.tensor.Tensor): Tensor for population variance used for inference.
        attrs:
            momentum (float): A float number used for the moving_mean and moving_variance computation.
            eps (float): A small float added to variance to avoid dividing by zero.
            is_training (bool): A bool value to specify if the operation is used for training or inference.
            data_format (str): Supported formats: "DefaultFormat", "NCHW", "NHWC" or "NC1HWC0".
            axis (Union[int, list, tuple]): Integer to specify the channel axis when data_format is
                "DefaultFormat". List or tuple for "NC1HWC0". When format is "NCHW" or "NHWC", it has
                no effect. Must be in the range [-rank(data), rank(data)).
            single_sum (bool): whether to use "mul_axis_sum".

    Returns:
        outs (tvm.tensor.Tensor): Tensor for normalized, scaled, shifted data.
        new_moving_mean (tvm.tensor.Tensor): Tensor of same type and shape as `moving_mean`.
            The `moving_mean` updated by data. Only returned when `is_training` is True.
        new_moving_variance (tvm.tensor.Tensor): Tensor of same type and shape as `moving_variance`.
            The `moving_variance` updated by data. Only returned when `is_training` is True.
        sample_mean (tvm.tensor.Tensor): Tensor of same type and shape as `moving_mean`.
            The mean of `data`. Only returned when `is_training` is True.
        sample_var (tvm.tensor.Tensor): Tensor of same type and shape as `moving_variance`.
            The variance of `data`. Only returned when `is_training` is True.
    """
    if len(inputs) != 5:
        raise ValueError(
            "Input tensors number should be 5, but got %s." % len(inputs))

    data_format = attrs.get("data_format", "DefaultFormat")
    params = check_inputs(inputs, data_format, attrs.get("axis", 1))

    data = inputs[0]
    gamma = inputs[1]
    beta = inputs[2]
    moving_mean = inputs[3]
    moving_variance = inputs[4]

    ori_dtype = data.dtype
    shape = get_shape(data)
    axes = params.get("axes", (0,))
    keepdims = params.get("is_special5d", False)
    mid_shape = params.get("mid_shape", [1, ])

    data = akg.tvm.compute(data.shape, lambda *i: data(*i), "batchnorm_" + data_format)

    ori_moving_mean = moving_mean
    ori_moving_variance = moving_variance
    if ori_dtype != DTYPE_FLOAT32:
        data = akg.topi.cast(data, DTYPE_FLOAT32)
        gamma = akg.topi.cast(gamma, DTYPE_FLOAT32)
        beta = akg.topi.cast(beta, DTYPE_FLOAT32)
        moving_mean = akg.topi.cast(moving_mean, DTYPE_FLOAT32)
        moving_variance = akg.topi.cast(moving_variance, DTYPE_FLOAT32)

    ######## following is dsl ########
    is_training = attrs.get("is_training", True)
    if is_training:
        value_num = 1
        for index in axes:
            value_num *= shape[index]
        avg_num = round(float(1) / float(value_num), 12)

        data_square = akg.tvm.compute(data.shape, lambda *i: data(*i) * data(*i),
                                      name="data_square")
        # cal mean
        data_mean = akg.lang.ascend.vmuls(
            sum_data(data, axes, keepdims, attrs.get("single_sum", False)), avg_num)
        data_square_mean = akg.lang.ascend.vmuls(
            sum_data(data_square, axes, keepdims, attrs.get("single_sum", False)), avg_num)
        data_mean_square = akg.tvm.compute(
            data_mean.shape, lambda *i: data_mean(*i) * data_mean(*i), name="data_mean_square")

        data_variance = akg.tvm.compute(
            data_mean.shape,
            lambda *i: data_square_mean(*i) - data_mean_square(*i),
            name="data_variance")

        mean_new = update_by_moving_average(
            moving_mean, data_mean, attrs.get("momentum", 0.99))
        variance_new = update_by_moving_average(
            moving_variance, data_variance, attrs.get("momentum", 0.99))
    else:
        # no_bc version
        data_variance = moving_variance
        data_mean = moving_mean

    rsveps = akg.lang.ascend.vadds(
        data_variance, akg.tvm.const(attrs.get("eps", 1e-3), dtype=DTYPE_FLOAT32))
    rsveps = rsqrt(rsveps, utils.CCE)
    rsveps = akg.lang.ascend.broadcast(rsveps, shape)

    mean2 = akg.lang.ascend.vmuls(data_mean, akg.tvm.const(-1, data.dtype))
    mean2 = akg.lang.ascend.broadcast(mean2, shape)

    dmean = akg.tvm.compute(shape, lambda *i: data(*i) + mean2(*i), name="dmean")
    dmsve = akg.tvm.compute(shape, lambda *i: dmean(*i) * rsveps(*i), name="dmsve")

    if not keepdims:
        gamma = akg.topi.reshape(gamma, mid_shape)
        beta = akg.topi.reshape(beta, mid_shape)
    gamma_bc = akg.lang.ascend.broadcast(gamma, shape)
    beta_bc = akg.lang.ascend.broadcast(beta, shape)
    dmsveg = akg.tvm.compute(shape, lambda *i: dmsve(*i) * gamma_bc(*i), name="dmsveg")
    outs = akg.tvm.compute(shape, lambda *i: dmsveg(*i) + beta_bc(*i), name="output")
    out_attrs = get_attrs(outs)

    if is_training:
        if ori_dtype != DTYPE_FLOAT32:
            outs = akg.topi.cast(outs, ori_dtype)
            mean_new = akg.topi.cast(mean_new, ori_dtype)
            variance_new = akg.topi.cast(variance_new, ori_dtype)
            data_mean = akg.topi.cast(data_mean, ori_dtype)
            data_variance = akg.topi.cast(data_variance, ori_dtype)

        mean_new, binds_info_mean = TensorUtils.inplace_set(
            ori_moving_mean, mean_new, buffer_name="mean_buf")
        variance_new, binds_info_var = TensorUtils.inplace_set(
            ori_moving_variance, variance_new, buffer_name="var_buf")
        binds_info_all = binds_info_mean
        binds_info_all.update(binds_info_var)
        out_attrs[BINDS] = binds_info_all

        # the new moving_mean and moving_var are updated inplace in
        # inputs (moving_mean and moving_var). But Mindspore needs these
        # two fake outputs even though it never uses them.
        fake_moving_mean = akg.tvm.compute(
            mean_new.shape, lambda *indices: mean_new(*indices), "fake_moving_mean")
        fake_moving_var = akg.tvm.compute(
            mean_new.shape, lambda *indices: variance_new(*indices), "fake_moving_var")
        out_tensors = (outs, fake_moving_mean, fake_moving_var, data_mean, data_variance,
                       mean_new, variance_new,)
    else:
        if ori_dtype != DTYPE_FLOAT32:
            outs = akg.topi.cast(outs, ori_dtype)
        out_tensors = (outs,)

    out_tensors = list(out_tensors) if isinstance(out_tensors, tuple) else out_tensors

    if shape_is_dynamic(out_tensors):
        out_attrs["custom_tiling"] = batch_norm_tiling_strategy_dynamic(outs)
    else:
        out_attrs["custom_tiling"] = batch_norm_tiling_strategy(outs, data_format)

    out_tensors.append(out_attrs)

    return out_tensors
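# A minimal standalone NumPy sketch (not part of this module) of the training-mode math assembled
# above: batch statistics over the reduction axes, normalization with gamma/beta, and the
# moving-average update new = momentum * old + (1 - momentum) * batch_stat. The name
# `fused_batch_norm_ref` is illustrative only and ignores layout handling, casting and
# inplace buffer binding.
import numpy as np

def fused_batch_norm_ref(data, gamma, beta, moving_mean, moving_var,
                         momentum=0.99, eps=1e-3, axes=(0,)):
    mean = np.mean(data, axis=axes, keepdims=True)
    var = np.mean(data * data, axis=axes, keepdims=True) - mean * mean
    out = gamma * (data - mean) / np.sqrt(var + eps) + beta
    new_mean = momentum * moving_mean + (1.0 - momentum) * mean
    new_var = momentum * moving_var + (1.0 - momentum) * var
    return out, new_mean, new_var, mean, var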
def _apply_centered_rms_prop_compute(var, mg, ms, mom, grad, lr, momentum, rho, epsilon):
    """Compute apply_centered_rms_prop"""
    inp_dtype = var.dtype
    if inp_dtype == "float16":
        var = akg.lang.cce.cast_to(var, "float32")
        mg = akg.lang.cce.cast_to(mg, "float32")
        ms = akg.lang.cce.cast_to(ms, "float32")
        mom = akg.lang.cce.cast_to(mom, "float32")
        lr = akg.lang.cce.cast_to(lr, "float32")
        rho = akg.lang.cce.cast_to(rho, "float32")
        momentum = akg.lang.cce.cast_to(momentum, "float32")
        grad = akg.lang.cce.cast_to(grad, "float32")

    epsilon = akg.tvm.const(epsilon, var.dtype)

    tensor_one_rho = akg.tvm.compute(
        rho.shape,
        lambda *indices: rho(*indices) * akg.tvm.const(-1, rho.dtype),
        tag='elewise_single_VS_mul')
    tensor_one_rho = akg.tvm.compute(
        tensor_one_rho.shape,
        lambda *indices: tensor_one_rho(*indices) + akg.tvm.const(1, rho.dtype),
        tag='elewise_single_VS_add')

    # out_mg <- rho * mg + (1 - rho) * grad
    mg_rho = akg.tvm.compute(mg.shape, lambda *indices: mg(*indices) * rho[0],
                             tag='elewise_single_VS_mul')
    rhs = akg.tvm.compute(grad.shape, lambda *indices: grad(*indices) * tensor_one_rho[0],
                          tag='elewise_single_VS_mul')
    out_mg = akg.lang.cce.vadd(mg_rho, rhs)

    # out_ms <- rho * ms + (1 - rho) * grad * grad
    ms_rho = akg.tvm.compute(ms.shape, lambda *indices: ms(*indices) * rho[0],
                             tag='elewise_single_VS_mul')
    rhs = akg.lang.cce.vmul(grad, grad)
    rhs = akg.tvm.compute(rhs.shape, lambda *indices: rhs(*indices) * tensor_one_rho[0],
                          tag='elewise_single_VS_mul')
    out_ms = akg.lang.cce.vadd(ms_rho, rhs)

    # out_mom <- momentum * mom + lr * grad / sqrt(out_ms - out_mg * out_mg + epsilon)
    lhs_mom = akg.tvm.compute(mom.shape, lambda *indices: mom(*indices) * momentum[0],
                              tag='elewise_single_VS_mul')
    lr_grad = akg.tvm.compute(grad.shape, lambda *indices: grad(*indices) * lr[0],
                              tag='elewise_single_VS_mul')
    rhs = akg.lang.cce.vmul(out_mg, out_mg)
    rhs = akg.lang.cce.vsub(out_ms, rhs)
    rhs_eps = akg.tvm.compute(rhs.shape, lambda *indices: rhs(*indices) + epsilon,
                              tag='elewise_single_VS_add')
    rhs_eps = rsqrt(rhs_eps)
    rhs_eps = akg.lang.cce.vmul(lr_grad, rhs_eps)
    out_mom = akg.lang.cce.vadd(lhs_mom, rhs_eps)

    # out_var <- var - out_mom
    out_var = akg.lang.cce.vsub(var, out_mom)

    if inp_dtype == "float16":
        out_var = akg.lang.cce.cast_to(out_var, "float16")
        out_mg = akg.lang.cce.cast_to(out_mg, "float16")
        out_ms = akg.lang.cce.cast_to(out_ms, "float16")
        out_mom = akg.lang.cce.cast_to(out_mom, "float16")

    return out_var, out_mg, out_ms, out_mom
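# A minimal standalone NumPy sketch (not part of this module) of the centered RMSProp update
# built above:
#   mg  <- rho * mg + (1 - rho) * grad
#   ms  <- rho * ms + (1 - rho) * grad^2
#   mom <- momentum * mom + lr * grad / sqrt(ms - mg^2 + epsilon)
#   var <- var - mom
# The name `apply_centered_rms_prop_ref` is illustrative only.
import numpy as np

def apply_centered_rms_prop_ref(var, mg, ms, mom, grad, lr, momentum, rho, epsilon):
    mg = rho * mg + (1.0 - rho) * grad
    ms = rho * ms + (1.0 - rho) * grad * grad
    mom = momentum * mom + lr * grad / np.sqrt(ms - mg * mg + epsilon)
    var = var - mom
    return var, mg, ms, mom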