Code Example #1
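# Context (abbreviated from mindspore/nn/optim/lamb.py): P is mindspore.ops.operations,
# F is mindspore.ops.functional, mstype is mindspore.common.dtype, layer is
# mindspore.nn.layer, G provides the fused Lamb graph-kernel ops, and
# num_one = Tensor(np.ones([1]), mstype.float32) is a module-level constant.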
def _update_run_op_graph_kernel(beta1, beta2, eps, global_step, lr,
                                weight_decay, param, m, v, gradient,
                                decay_flag):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Global step.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Specifies whether the parameter is updated with weight decay.

    Returns:
        Tensor, the updated parameter, carrying a control dependency between add3 and next_param.
    """
    op_mul = P.Mul()
    op_square = P.Square()
    op_cast = P.Cast()
    op_shape = P.Shape()
    op_pow = P.Pow()
    op_norm = layer.Norm()
    op_fill = P.Fill()
    op_dtype = P.DType()

    param_fp32 = op_cast(param, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

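    # The terse names below feed the fused op; decoded, with t = global_step + 1:
    #   i6_ex = t                               (bias-correction step count)
    #   i9 = 1 - beta1,     x1 = 1 - beta2
    #   i6 = 1 - beta1**t,  i3 = 1 - beta2**t   (bias-correction denominators)
    #   i1 = gradient**2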
    i6_ex = op_cast(global_step + num_one, mstype.float32)
    i9 = op_cast(num_one, mstype.float32) - beta1
    x1 = op_cast(num_one, mstype.float32) - beta2
    i6 = op_cast(num_one, mstype.float32) - op_pow(beta1, i6_ex)
    i3 = op_cast(num_one, mstype.float32) - op_pow(beta2, i6_ex)
    i1 = op_square(gradient_fp32)
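    # LambNextMV is a fused Ascend graph-kernel op. Judging from the unfused
    # version in Code Example #3, update ~ m_hat / (sqrt(v_hat) + eps) and
    # add3 ~ m_hat / sqrt(v_hat + eps) + weight_decay * param.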
    add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1, i9,
                                  beta2, x1, weight_decay, eps)

    if decay_flag:
        update = update + op_mul(weight_decay, param_fp32)

    w_norm = op_norm(param_fp32)
    g_norm = op_norm(gradient_fp32)
    g_norm_hat = op_norm(add3)

    zeros = F.zeros_like(w_norm)
    ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
    tens = op_fill(op_dtype(w_norm), op_shape(w_norm), 10.0)

    next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update,
                                      param, zeros, ones, tens)
    next_v = F.control_depend(add3, next_param)
    return next_v
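
The fused LambNextMV/LambUpdateWithLR ops hide the arithmetic of a LAMB step. The following plain-NumPy sketch of one update was written for this page as a reference; the function name and default hyperparameters are hypothetical, and the trust-ratio guard is simplified relative to the select logic in Code Example #3.

import numpy as np

def lamb_step_reference(param, m, v, grad, step, lr=1e-3, beta1=0.9,
                        beta2=0.999, eps=1e-6, weight_decay=0.01):
    # Adam moment updates.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    # Bias correction with step count t = step + 1 (cf. global_step + num_one).
    t = step + 1
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # Update direction with decoupled weight decay.
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param
    # Trust ratio ||w|| / ||update||, defaulting to 1 and clipped to [0, 10].
    w_norm = np.linalg.norm(param)
    u_norm = np.linalg.norm(update)
    trust_ratio = min(w_norm / u_norm, 10.0) if w_norm > 0 and u_norm > 0 else 1.0
    param = param - lr * trust_ratio * update
    return param, m, v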
Code Example #2
File: loss.py Project: dongkcs/mindspore
    def construct(self, x1, x2, y):
        F.same_type_shape(x1, x2)
        _check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name)
        # where y == 1: loss is 1 - cosine(x1, x2)
        # where y == -1: loss is max(0, cosine(x1, x2) - margin)
        prod_sum = self.reduce_sum(x1 * x2, (1,))
        square1 = self.reduce_sum(F.square(x1), (1,))
        square2 = self.reduce_sum(F.square(x2), (1,))
        denom = F.sqrt(square1 * square2)
        cosine = prod_sum / denom

        pos_value = 1.0 - cosine
        neg_value = self.maximum(cosine - self.margin, 0.0)
        zeros = F.zeros_like(cosine)
        pos_part = F.select(y == 1, pos_value, zeros)
        neg_part = F.select(y == -1, neg_value, zeros)
        output_unreduced = pos_part + neg_part

        return self.get_loss(output_unreduced)
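
For reference, the same loss in plain NumPy (a hypothetical sketch written for this page, not code from the project):

import numpy as np

def cosine_embedding_loss_reference(x1, x2, y, margin=0.0):
    # Row-wise cosine similarity between the two batches.
    prod_sum = np.sum(x1 * x2, axis=1)
    denom = np.sqrt(np.sum(x1 ** 2, axis=1) * np.sum(x2 ** 2, axis=1))
    cosine = prod_sum / denom
    pos = 1.0 - cosine                      # applied where y == 1
    neg = np.maximum(cosine - margin, 0.0)  # applied where y == -1
    loss = np.where(y == 1, pos, 0.0) + np.where(y == -1, neg, 0.0)
    return loss.mean()

# A similar pair (y = 1) contributes a small loss; a dissimilar one (y = -1)
# is penalized only when its cosine exceeds the margin.
x1 = np.array([[1.0, 0.0], [1.0, 0.0]])
x2 = np.array([[0.9, 0.1], [0.0, 1.0]])
y = np.array([1, -1])
print(cosine_embedding_loss_reference(x1, x2, y))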
Code Example #3
def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m,
                   v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Global step.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Specifies whether the parameter is updated with weight decay.
        optim_filter (bool): Whether to apply the parameter update.

    Returns:
        Tensor, the updated parameter if optim_filter is True; otherwise the input gradient, unchanged.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_sqrt = P.Sqrt()
        op_rsqrt = P.Rsqrt()
        op_square = P.Square()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_pow = P.Pow()
        op_norm = layer.Norm()
        op_select = P.Select()
        op_greater = P.Greater()
        op_fill = P.Fill()
        op_dtype = P.DType()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

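        # Adam moment updates: m <- beta1*m + (1-beta1)*g and
        # v <- beta2*v + (1-beta2)*g**2 (num_one is the module-level constant
        # noted in Code Example #1).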
        next_m = op_mul(beta1, m_fp32) + op_mul(
            op_cast(num_one, mstype.float32) - beta1, gradient_fp32)

        next_v = op_mul(beta2, v_fp32) + op_mul(
            op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))

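        # Bias-corrected moments; the step count is global_step + 1.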
        next_mm = next_m / (op_cast(num_one, mstype.float32) - op_pow(
            beta1, op_cast(global_step + num_one, mstype.float32)))
        next_vv = next_v / (op_cast(num_one, mstype.float32) - op_pow(
            beta2, op_cast(global_step + num_one, mstype.float32)))
        w_norm = op_norm(param_fp32)
        g_norm = op_norm(gradient_fp32)

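        # Norm of the update direction with decoupled weight decay,
        # used as the trust-ratio denominator.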
        g_norm_hat = op_norm(
            op_mul(next_mm, op_rsqrt(next_vv + eps)) +
            weight_decay * param_fp32)
        zeros = F.zeros_like(w_norm)
        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
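        # LAMB trust ratio w_norm / g_norm_hat; falls back to 1 when either
        # norm is zero, and is clipped to [0, 10] below.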
        trust_ratio = op_select(
            op_greater(w_norm, zeros),
            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
            ones)
        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
        update = next_mm / (op_sqrt(next_vv) + eps)

        if decay_flag:
            update = update + op_mul(weight_decay, param_fp32)

        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

        next_param = param_fp32 - op_reshape(update_with_lr,
                                             op_shape(param_fp32))

        next_param = F.depend(
            next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param,
                              F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param,
                              F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
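
In the optimizer that calls it, a per-parameter function like _update_run_op is typically registered on a MultitypeFuncGraph and mapped over every parameter with HyperMap. A sketch of that wiring follows; the graph name, type-signature strings, and argument grouping are illustrative assumptions, not copied from lamb.py.

from mindspore.ops import composite as C
from mindspore.ops import functional as F

_lamb_opt = C.MultitypeFuncGraph("lamb_opt")  # hypothetical name

@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt(beta1, beta2, eps, global_step, lr, weight_decay,
             param, m, v, gradient, decay_flag, optim_filter):
    return _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay,
                          param, m, v, gradient, decay_flag, optim_filter)

# Inside the optimizer's construct(), roughly:
#   success = C.HyperMap()(
#       F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, self.global_step, lr),
#       weight_decay, self.params, self.moments1, self.moments2,
#       gradients, self.decay_flags, self.optim_filter)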