def softmax_cross_entropy_with_logits(labels, logits, axis, reduction="mean", scale=1.0):
    """Compute softmax cross-entropy loss and its gradient with respect to logits."""
    max_logits = reduce_max(logits, axis, keepdims=True, target=utils.CCE)
    data_sub = sub(logits, max_logits, target=utils.CCE)
    akg.register_variables("minus_max", [logits], data_sub)
    data_exp = Exp(data_sub, target=utils.CCE)
    data_expsum = sum(data_exp, axis, keepdims=True, target=utils.CCE)
    data_expsum_log = log(data_expsum, target=utils.CCE)
    sub_value = sub(data_sub, data_expsum_log, target=utils.CCE)
    neg_labels = neg(labels, target=utils.CCE)
    cross_entropy = mul(neg_labels, sub_value, target=utils.CCE)
    # backprop: prob - labels, where prob = softmax(logits)
    prob = Exp(sub_value, target=utils.CCE)
    backprop = sub(prob, labels, target=utils.CCE)
    if reduction.lower() == "none":
        loss = sum_v2(cross_entropy, axis, keepdims=True)
    elif reduction.lower() == "mean":
        loss = sum_v2(cross_entropy, axis=None)
        factor = logits.shape[0].value
        loss = loss * akg.tvm.const(1 / factor, logits.dtype)
        backprop = backprop * akg.tvm.const(1 / factor, logits.dtype)
    elif reduction.lower() == "sum":
        loss = sum_v2(cross_entropy, axis=None)
    else:
        raise ValueError(
            "reduction method {0} is not supported".format(reduction))
    backprop = akg.topi.multiply(backprop, akg.tvm.const(scale, backprop.dtype))
    return loss, backprop
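
# Illustrative reference only (not part of the AKG kernel above): a minimal NumPy
# sketch of the same math, assuming 2-D labels/logits with the class axis last.
# The helper name `_softmax_xent_reference` is hypothetical.
import numpy as np

def _softmax_xent_reference(labels, logits, reduction="mean", scale=1.0):
    # numerically stable log-softmax: log p = (x - max) - log(sum(exp(x - max)))
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    cross_entropy = -labels * log_prob
    # gradient w.r.t. logits: softmax(logits) - labels
    backprop = np.exp(log_prob) - labels
    if reduction == "none":
        loss = cross_entropy.sum(axis=-1, keepdims=True)
    elif reduction == "mean":
        loss = cross_entropy.sum() / logits.shape[0]
        backprop = backprop / logits.shape[0]
    elif reduction == "sum":
        loss = cross_entropy.sum()
    else:
        raise ValueError("reduction method {0} is not supported".format(reduction))
    return loss, backprop * scale
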
def _before_res_compute(abs_data):
    """
    Compute bessel_i1e for abs value of data less than or equal to 3.75.

    Algorithm:
        t = x / 3.75
        I1(x) = e^-|x| * x * (0.5 + 0.87890594t^2 + 0.51498869t^4 + 0.15084934t^6
                + 0.02658773t^8 + 0.00301532t^10 + 0.00032411t^12)
    """
    data = topi.multiply(abs_data, 1.0 / CONST_LIMIT)
    data_square = mul(data, data)
    before_res = topi.multiply(data_square, ITR_BEFORE[LEN_BEFORE - 1])
    before_res = topi.add(before_res, ITR_BEFORE[LEN_BEFORE - 2])
    for iter_number in ITR_BEFORE[LEN_BEFORE - 3::-1]:
        before_res = mul(before_res, data_square)
        before_res = topi.add(before_res, iter_number)
    exp_value = exp(neg(abs_data))
    before_res = mul(before_res, exp_value)
    before_res = mul(before_res, abs_data)
    return before_res
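
# Illustrative reference only (not the kernel above): the same Horner evaluation in
# NumPy, assuming ITR_BEFORE holds the docstring coefficients in ascending power order
# and CONST_LIMIT == 3.75. The helper name `_bessel_i1e_small_reference` and the local
# coefficient list `_COEFFS` are hypothetical.
import numpy as np

_COEFFS = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.02658773, 0.00301532, 0.00032411]

def _bessel_i1e_small_reference(abs_x):
    t_sq = (abs_x / 3.75) ** 2
    # Horner's scheme over t^2, starting from the highest-order coefficient
    res = _COEFFS[-1]
    for coeff in reversed(_COEFFS[:-1]):
        res = res * t_sq + coeff
    # I1e(x) = e^-|x| * I1(x) = e^-|x| * |x| * polynomial(t^2)
    return res * np.exp(-abs_x) * abs_x
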
def LambApplyOptimizerAssign(grad, input_v, input_m, input_param, beta_1, one_minus_beta_1,
                             beta_2, one_minus_beta_2, epsilon, steps, do_use_weight,
                             weight_decay_rate):
    """Compute the Lamb optimizer update together with the next first and second moments."""
    # compute next_v
    square_grad = topi.multiply(grad, grad)
    # mul_3
    mul_3_result = topi.multiply(square_grad, one_minus_beta_2)
    # mul_2
    mul_2_result = topi.multiply(input_v, beta_2)
    # compute: next_v = (multiply(self.beta_2, v) + multiply(1.0 - self.beta_2, square(grad)))
    next_v = topi.add(mul_2_result, mul_3_result)

    # compute next_m
    mul_0_result = topi.multiply(input_m, beta_1)
    # mul_1
    mul_1_result = topi.multiply(grad, one_minus_beta_1)
    # compute: next_m = (multiply(self.beta_1, m) + multiply(1.0 - self.beta_1, grad))
    next_m = topi.add(mul_0_result, mul_1_result)

    const_one = akg.tvm.const(1.0, input_v.dtype)
    # compute: beta1_correction = (1 - self.beta_1 ** steps)
    beta_1_steps = pow_compute(beta_1, steps, grad)
    neg_beta_1_step = neg(beta_1_steps, utils.CCE)
    beta1_correction = topi.add(neg_beta_1_step, const_one)
    # compute: beta2_correction = (1 - self.beta_2 ** steps)
    beta_2_steps = pow_compute(beta_2, steps, grad)
    neg_beta_2_step = neg(beta_2_steps, utils.CCE)
    beta2_correction = topi.add(neg_beta_2_step, const_one)

    # compute: next_m_unbiased = next_m / beta1_correction
    next_m_unbiased = Divide(next_m, beta1_correction, utils.CCE)
    # compute: next_v_unbiased = next_v / beta2_correction
    next_v_unbiased = Divide(next_v, beta2_correction, utils.CCE)

    # compute update
    sqrt_next_v = topi.sqrt(next_v_unbiased)
    # add_2
    add_2_result = topi.add(sqrt_next_v, epsilon)
    # compute: update = next_m / (sqrt(next_v) + self.epsilon)
    update = Divide(next_m_unbiased, add_2_result, utils.CCE)

    # compute do_use_weight_decay
    do_use_weight_mul = topi.multiply(input_param, weight_decay_rate)
    do_use_weight_decay = topi.multiply(do_use_weight_mul, do_use_weight)
    update = topi.add(do_use_weight_decay, update)

    attrs = {'enable_auto_inline': False}
    dim_info, _ = lamb_apply_optimizer_assign_set_dim_func(grad)
    if dim_info != "":
        attrs["dim"] = dim_info
    return update, next_v, next_m, attrs
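
# Illustrative reference only (not the AKG kernel above): a minimal NumPy sketch of the
# same Lamb update math on plain arrays. The helper name `_lamb_update_reference` is
# hypothetical.
import numpy as np

def _lamb_update_reference(grad, v, m, param, beta_1, beta_2, epsilon, steps,
                           do_use_weight, weight_decay_rate):
    # moment estimates
    next_v = beta_2 * v + (1.0 - beta_2) * np.square(grad)
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    # bias correction
    next_m_unbiased = next_m / (1.0 - beta_1 ** steps)
    next_v_unbiased = next_v / (1.0 - beta_2 ** steps)
    # update direction, plus weight decay gated by do_use_weight
    update = next_m_unbiased / (np.sqrt(next_v_unbiased) + epsilon)
    update = update + do_use_weight * weight_decay_rate * param
    return update, next_v, next_m
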
def Neg(x):
    """neg"""
    return neg.neg(x)