from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P


def construct(self, gradients, overflow):
    """AdamWeightDecayForBert"""
    lr = self.get_lr()
    # Broadcast the scalar overflow flag to the shape of beta1 and cast it
    # to bool: cond is all-True when a loss-scale overflow occurred.
    cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *
                        self.op_reshape(overflow, (())), mstype.bool_)
    # On overflow, force beta1/beta2 to 1.0 so a moment update of the form
    # beta * m + (1 - beta) * g leaves the moments unchanged.
    beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
    beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
    if self.is_group:
        if self.is_group_lr:
            # Per-group learning rates: lr and weight_decay are tuples that
            # hyper_map walks alongside the parameter lists.
            optim_result = self.hyper_map(
                F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                lr, self.weight_decay, self.parameters, self.moments1,
                self.moments2, gradients, self.decay_flags, self.optim_filter)
        else:
            # Shared lr, per-group weight decay; the overflow flag is passed
            # through so the update itself can be skipped on overflow.
            optim_result = self.hyper_map(
                F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
                self.weight_decay, self.parameters, self.moments1, self.moments2,
                gradients, self.decay_flags, self.optim_filter)
    else:
        optim_result = self.hyper_map(
            F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
            self.parameters, self.moments1, self.moments2,
            gradients, self.decay_flags, self.optim_filter)
    if self.use_parallel:
        self.broadcast_params(optim_result)
    return optim_result
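# --- Usage sketch (illustrative, not part of the source) --------------------
# A minimal example of how this two-argument `construct` entry point might be
# driven. The toy nn.Dense network, the zero gradients, and the constructor
# signature (assumed to mirror the stock AdamWeightDecay) are assumptions; in
# the BERT scripts the overflow flag comes from the loss-scaling train cell.
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

net = nn.Dense(4, 2)  # stand-in for the BERT network
optimizer = AdamWeightDecayForBert(net.trainable_params(),
                                   learning_rate=1e-3, weight_decay=0.01)
# Gradients normally come from the train-step cell; zeros just exercise the call.
grads = tuple(Tensor(np.zeros(p.shape, np.float32)) for p in net.trainable_params())
optimizer(grads, Tensor(0.0, ms.float32))  # no overflow: parameters are updated
optimizer(grads, Tensor(1.0, ms.float32))  # overflow: the update is skipped
# -----------------------------------------------------------------------------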
# In the full source, `_update_run_op` is registered on the `_adam_opt`
# MultitypeFuncGraph that `construct` above dispatches through via hyper_map.
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v,
                   gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations.
            Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations.
            Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability.
            Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the new parameter value after updating (the gradient is
        returned unchanged when `optim_filter` is False).
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        # Do the arithmetic in float32 even if the parameter is float16.
        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # Broadcast the scalar overflow flag to the shape of the moments.
        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) *
                       op_reshape(overflow, (())), mstype.bool_)
        one = op_cast(F.tuple_to_array((1.0,)), mstype.float32)

        # Adam moment updates; when overflow occurred the previous moments are
        # kept as-is instead of folding in the invalid gradients.
        next_m = op_select(cond, m_fp32,
                           op_mul(beta1, m_fp32) + op_mul(one - beta1, gradient_fp32))
        next_v = op_select(cond, v_fp32,
                           op_mul(beta2, v_fp32) + op_mul(one - beta2, op_square(gradient_fp32)))

        # No bias correction, matching BERT-style AdamWeightDecay.
        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            # Decoupled weight decay, applied before scaling by the lr.
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        # Skip the parameter update entirely when overflow occurred.
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(
            cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        # Write the new values back; F.depend orders the assigns before the return.
        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
        return op_cast(next_param, F.dtype(param))
    return gradient
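# --- Reference math (illustrative, not part of the source) ------------------
# Stripped of the graph-mode plumbing, the update above is ordinary AdamW
# without bias correction. This NumPy restatement is a sketch; the function
# name and default values are invented for illustration.
import numpy as np

def adam_weight_decay_step(param, m, v, grad, lr=1e-3, beta1=0.9, beta2=0.999,
                           eps=1e-6, weight_decay=0.01, overflow=False):
    """One AdamW step for a single tensor, mirroring _update_run_op."""
    if overflow:
        # On overflow the step is the identity, as the Select ops implement.
        return param, m, v
    m = beta1 * m + (1.0 - beta1) * grad               # 1st moment
    v = beta2 * v + (1.0 - beta2) * np.square(grad)    # 2nd moment
    update = m / (eps + np.sqrt(v))                    # no bias correction
    update += weight_decay * param                     # decoupled weight decay
    return param - lr * update, m, v
# -----------------------------------------------------------------------------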