Example #1
def construct(self, gradients, overflow):
    """AdamWeightDecayForBert"""
    lr = self.get_lr()
    # cond is True when the loss-scale overflow flag is set for this step.
    cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *
                        self.op_reshape(overflow, (())), mstype.bool_)
    # On overflow, force both betas to 1.0 so the gradient contribution to
    # the moments is gated off in _adam_opt.
    beta1 = self.op_select(
        cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32),
        self.beta1)
    beta2 = self.op_select(
        cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32),
        self.beta2)
    if self.is_group:
        if self.is_group_lr:
            # Per-group learning rates: lr is mapped over the groups.
            optim_result = self.hyper_map(
                F.partial(_adam_opt, self.beta1, self.beta2,
                          self.eps), lr, self.weight_decay,
                self.parameters, self.moments1, self.moments2, gradients,
                self.decay_flags, self.optim_filter)
        else:
            # Shared learning rate, per-group weight decay; the overflow
            # flag is forwarded so the update itself can be skipped.
            optim_result = self.hyper_map(
                F.partial(_adam_opt, beta1, beta2, self.eps, lr,
                          overflow), self.weight_decay, self.parameters,
                self.moments1, self.moments2, gradients, self.decay_flags,
                self.optim_filter)
    else:
        optim_result = self.hyper_map(
            F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
                      self.weight_decay), self.parameters, self.moments1,
            self.moments2, gradients, self.decay_flags, self.optim_filter)
    if self.use_parallel:
        # Keep parameters consistent across devices in parallel training.
        self.broadcast_params(optim_result)
    return optim_result
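
The gating above is the loss-scaling guard: when the overflow flag is set, cond becomes true, op_select swaps both betas for 1.0, and _adam_opt (see Example #2) replaces the parameter update with zeros, so the step is skipped. A minimal NumPy sketch of the beta gating, using illustrative names that are not part of the MindSpore API:

import numpy as np

def gate_betas(beta1, beta2, overflow):
    # Hypothetical mirror of the op_select gating in construct(): on
    # overflow both betas become 1.0, so the (1 - beta) * gradient terms
    # in the moment updates vanish for this step.
    return (1.0, 1.0) if bool(overflow) else (beta1, beta2)

print(gate_betas(0.9, 0.999, overflow=1))  # (1.0, 1.0): gradient gated off
print(gate_betas(0.9, 0.999, overflow=0))  # (0.9, 0.999): normal Adam step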
Example #2
# MindSpore imports required by this snippet.
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P

def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v,
                   gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Whether to apply weight decay.
        optim_filter (bool): Whether to apply the parameter update.

    Returns:
        Tensor, the updated parameter value; the input gradient is returned unchanged when optim_filter is False.
    """
    if optim_filter:
        # Primitive ops used by the update below.
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        # Compute in float32 for numerical stability regardless of input dtype.
        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # cond is True when the loss-scale overflow flag is set for this step.
        cond = op_cast(
            F.fill(mstype.int32, op_shape(m_fp32), 1) *
            op_reshape(overflow, (())), mstype.bool_)
        # Moment updates; on overflow the select ops substitute the previous
        # moments for the gradient terms.
        next_m = op_mul(beta1, m_fp32) + op_select(
            cond, m_fp32,
            op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1,
                   gradient_fp32))

        next_v = op_mul(beta2, v_fp32) + op_select(
            cond, v_fp32,
            op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2,
                   op_square(gradient_fp32)))

        # Adam direction: first moment scaled by the stabilized RMS term.
        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            # Decoupled weight decay, added before the learning-rate scaling.
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        # On overflow, apply a zero update so the parameter is left unchanged.
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(
            cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        # F.depend forces the assignments to execute before next_param is
        # returned, writing results back in each tensor's original dtype.
        next_param = F.depend(
            next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param,
                              F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param,
                              F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
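
For reference, the arithmetic of the no-overflow path can be sketched in plain NumPy (a simplified reading of the function above, not a MindSpore API; the overflow branch, dtype casts, and graph-mode dependencies are omitted):

import numpy as np

def adamw_step_reference(param, m, v, grad, beta1=0.9, beta2=0.999,
                         eps=1e-6, lr=1e-3, weight_decay=0.01, decay_flag=True):
    # Plain-NumPy sketch of the no-overflow path of _update_run_op.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * np.square(grad)
    update = m / (eps + np.sqrt(v))
    if decay_flag:
        update = weight_decay * param + update
    return param - lr * update, m, v

p = np.ones(3, dtype=np.float32)
new_p, new_m, new_v = adamw_step_reference(p, np.zeros(3), np.zeros(3),
                                           np.full(3, 0.5))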