Example #1
0
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay,
                          param, m, v, gradient, decay_flag, optim_filter):
    """
    Apply one LAMB optimizer step to a single parameter when the device target is Ascend.

    When ``optim_filter`` is true:
    - The update is computed with the fused ``P.LambApplyOptimizerAssign`` operator.
    - The norms of the parameter and of that update are taken.
    - These are handed, together with ``lr`` and ``param``, to ``P.LambApplyWeightAssign``.

    When ``optim_filter`` is false, nothing is computed and ``gradient`` is returned unchanged.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Global step. ``num_one`` (defined outside this function) is added to it and
            the sum is cast to float32 before being passed to the optimizer operator.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Specifies whether param update with weight decay.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor. If ``optim_filter`` is true, the ``update`` tensor produced by
        ``LambApplyOptimizerAssign``, wrapped with ``F.depend`` on the
        ``LambApplyWeightAssign`` call. Otherwise, the unmodified ``gradient``.
    """
    if optim_filter:
        # Operator instances are created afresh on each call of this function.
        op_cast = P.Cast()
        op_norm = layer.Norm()
        op_lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
        op_lamb_apply_weight_assign = P.LambApplyWeightAssign()

        # Feed the fused operator float32 copies of the parameter and gradient.
        param_fp32 = op_cast(param, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)
        # NOTE(review): num_one is defined outside this function; presumably the
        # constant 1, so the operator sees the incremented step count — confirm.
        new_global_step = op_cast(global_step + num_one, mstype.float32)
        # The boolean decay flag is passed to the operator as a float32 value.
        weight_decay_flag = op_cast(decay_flag, mstype.float32)

        # Fused LAMB computation. Argument notes:
        # - v is passed before m.
        # - (1 - beta1) and (1 - beta2) are passed alongside beta1 and beta2.
        # Only the first of the three outputs is used. Presumably the operator
        # writes the new m and v back in place ("Assign") — confirm against the
        # P.LambApplyOptimizerAssign docs.
        update, _, _ = op_lamb_apply_optimizer_assign(
            gradient_fp32, v, m, param_fp32, beta1, 1.0 - beta1, beta2,
            1.0 - beta2, eps, new_global_step, weight_decay_flag, weight_decay)
        # Norm of the float32 parameter and norm of the computed update;
        # presumably used by the weight-assign operator as the LAMB trust ratio.
        w_norm = op_norm(param_fp32)
        g_norm = op_norm(update)
        # The weight-assign call receives the original `param` (not param_fp32);
        # presumably it updates that parameter in place. F.depend attaches the
        # call to the returned value, presumably so the compiled graph keeps it
        # and runs it before `update` is consumed.
        update = F.depend(
            update,
            op_lamb_apply_weight_assign(w_norm, g_norm, lr, update, param))
        return update
    # Filtered-out parameter: no optimizer step, pass the gradient through.
    return gradient
 def __init__(self):
     """Initialize the cell and create its ``P.LambApplyWeightAssign`` operator.

     The operator instance is stored as ``self.lamb_apply_weight_assign``.
     """
     super(Net, self).__init__()
     self.lamb_apply_weight_assign = P.LambApplyWeightAssign()