Example #1
 def construct(self, x):
     out_conv = self.conv(x, self.weight)
     # BN fold1
     batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(
         out_conv, self.moving_mean, self.moving_variance, self.step)
     # fake weight
     weight = self.correct_mul(self.weight, self.gamma, running_std)
     if self.fake:
         weight = self.fake_quant_weight(weight)
     out = self.conv(x, weight)
     # BN fold2
     if self.is_gpu:
         if self.training:
             out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
                                              batch_std, batch_mean,
                                              running_std, running_mean,
                                              self.step)
             F.control_depend(out, self.assignadd(self.step, self.one))
         else:
             out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                              batch_std, batch_mean,
                                              running_std, running_mean,
                                              self.step)
     else:
         if self.training:
             out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
                                              batch_std, batch_mean,
                                              running_std)
             F.control_depend(out, self.assignadd(self.step, self.one))
         else:
             out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                              running_std, running_mean,
                                              running_std)
     return out
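Every example on this page uses F.control_depend(a, b) to force graph node a to execute before node b, typically so that a value is read before an in-place update such as AssignAdd runs. Below is a minimal, self-contained sketch of that pattern; it is not taken from the MindSpore sources, the cell and its names are hypothetical, and it assumes an older MindSpore release in which F.control_depend is still available (newer releases drop it in favour of Depend).

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class LrStepCounter(nn.Cell):
    """Hypothetical cell: scale the stored lr, then advance a step counter."""

    def __init__(self, learning_rate):
        super(LrStepCounter, self).__init__()
        self.learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name="lr")
        self.global_step = Parameter(Tensor(0.0, mstype.float32), name="global_step")
        self.one = Tensor(1.0, mstype.float32)
        self.assignadd = P.AssignAdd()

    def construct(self, scale):
        lr = self.learning_rate * scale
        # control_depend(a, b): node a must execute before node b, so the lr
        # value is read before AssignAdd mutates global_step in place.
        F.control_depend(lr, self.assignadd(self.global_step, self.one))
        return lr


# Usage (hypothetical):
#   net = LrStepCounter(0.01)
#   current_lr = net(Tensor(1.0, mstype.float32))
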
Example #2
    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Returns:
            bool, the status flag.
        """
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[
                self.param_rank[i]] + (self.parameters[i], )
            key = P.MakeRefKey(self.param_names[i])()
            key_group[
                self.param_rank[i]] = key_group[self.param_rank[i]] + (key, )
        new_param_group = []
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
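        # make sure the optimizer update has executed before the first broadcast output is consumed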
        status = F.control_depend(optim_result, new_param_group[0][0])
        for i in range(self.dev_num - 1):
            status = F.depend(
                F.control_depend(new_param_group[i],
                                 new_param_group[i + 1][0]), status)

        return status
Example #3
 def construct(self, gradients):
     params = self.parameters
     if self.weight_decay > 0:
         gradients = self.hyper_map(
             F.partial(apply_decay, self.weight_decay), self.decay_tf,
             params, gradients)
     if self.reciprocal_scale != 1.0:
         gradients = self.hyper_map(
             F.partial(grad_scale, self.reciprocal_scale), gradients)
     if self.dynamic_lr:
         lr = self.gather(self.learning_rate, self.global_step, self.axis)
         F.control_depend(lr, self.assignadd(self.global_step, self.one))
     else:
         lr = self.learning_rate
     if self.centered:
         success = self.hyper_map(
             F.partial(centered_rmsprop_opt, self.opt, lr, self.decay,
                       self.epsilon, self.momentum), params, self.mg,
             self.ms, self.moment, gradients)
     else:
         success = self.hyper_map(
             F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
                       self.momentum), params, self.ms, self.moment,
             gradients)
     return success
Example #4
    def construct(self, gradients):
        step = self.min(self.global_step, self.decay_steps)
        p = step / self.decay_steps
        lr = self.diff_learning_rate * \
            self.pow(self.one - p, self.power) + self.end_learning_rate
        if self.warmup_flag:
            warmup_percent = self.global_step / self.warmup_steps
            warmup_lr = self.start_learning_rate * warmup_percent
            is_warmup = self.cast(self.greater(
                self.warmup_steps, self.global_step), mstype.float32)
            lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
        if self.enable_graph_kernel:
            optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
                                                    self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay_tensor, self.global_step),
                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        else:
            optim_result = self.hyper_map(F.partial(_lamb_opt,
                                                    self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay_tensor, self.global_step),
                                          self.params, self.moments1, self.moments2, gradients,
                                          self.decay_flag, self.optim_filter)
        if self.use_parallel:
            optim_result = self.broadcast_params(optim_result)

        added_global_step = self.global_step + self.one
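        # lr is computed from the old global_step, so order it before the counter is advanced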
        F.control_depend(lr, added_global_step)
        self.global_step = added_global_step

        return optim_result
Example #5
    def construct(self, gradients):
        params = self.parameters
        moment1 = self.moment1
        moment2 = self.moment2
        if self.weight_decay > 0:
            gradients = self.hyper_map(
                F.partial(apply_decay, self.weight_decay), self.decay_tf,
                params, gradients)
        if self.reciprocal_scale != 1.0:
            gradients = self.hyper_map(
                F.partial(grad_scale, self.reciprocal_scale), gradients)

        lr = self.learning_rate
        if self.dynamic_lr:
            lr = self.gather(self.learning_rate, self.global_step, self.axis)
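            # gather this step's lr before AssignAdd moves global_step forward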
            F.control_depend(lr, self.assignadd(self.global_step, self.one))

        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power
        success = self.hyper_map(
            F.partial(adam_opt, self.opt, lr, beta1_power, beta2_power,
                      self.beta1, self.beta2, self.eps), gradients, params,
            moment1, moment2)

        return success
Example #6
File: quant.py  Project: lyc4614/mindspore
 def construct(self, x):
     if self.training:
         beta = self.beta
         gamma = self.gamma
         gmean = self.moving_mean
         gvar = self.moving_variance
         step = self.step
         out_conv = self.conv(x, self.weight)
         batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train(
             out_conv, gmean, gvar, step)
         # BN fold1
         weight = self.correct_mul(self.weight, gamma, running_std)
         if self.fake:
             weight = self.fake_quant_weight(weight)
         out = self.conv(x, weight)
         # BN fold2
         out = self.batchnorm_fold2(out, beta, gamma, batch_std, batch_mean,
                                    running_std, running_mean, step)
         F.control_depend(out, self.assignadd(self.step, self.one))
     else:
         step = self.step
         batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
             x, self.moving_mean, self.moving_variance, step)
         weight = self.correct_mul(self.weight, self.gamma, running_std)
         if self.fake:
             weight = self.fake_quant_weight(weight)
         out = self.conv(x, weight)
         out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                          batch_std, batch_mean,
                                          running_std, running_mean, step)
     return out
Example #7
def tensor_grad_scale(scale, grad, accu_grad):
    #mul = P.Mul()
    new_grad = accu_grad * reciprocal(scale)
    zeros = F.tensor_mul(accu_grad, 0.0)
    clear = F.assign(accu_grad, zeros)
    F.control_depend(new_grad, clear)
    F.control_depend(grad, new_grad)
    return new_grad
Example #8
    def construct(self, x, b, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x, b)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))

        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
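        # F.depend ties the accumulation assigns to mean_loss so they are not pruned from the graph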
        mean_loss = F.depend(mean_loss, accu_succ)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        is_accu_step = self.reshape(is_accu_step, (()))

        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            grads = ClipByGlobalNorm()(grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
Example #9
    def get_lr(self):
        """
        Get the learning rate of current step.

        Returns:
            float, the learning rate of current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
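            # look up the lr for the current step before the step counter is advanced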
            lr = self.gather(self.learning_rate, self.global_step, 0)
            F.control_depend(lr, self.assignadd(self.global_step, 1))

        return lr
Example #10
    def construct(self, gradients):
        step = self.min(self.global_step, self.decay_steps)
        p = step / self.decay_steps
        lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
        updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay_tensor),
                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)

        added_global_step = self.global_step + self.one
        F.control_depend(lr, added_global_step)
        self.global_step = added_global_step

        return updated_velocity
Example #11
    def construct(self,
                  input_ids,
                  input_position,
                  attention_mask,
                  past=None,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids, input_position, attention_mask)
        if sens is None:
            scaling_sens = self.loss_scale
            scaling_sens = self.reshape(scaling_sens, (1,))
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        status_clear = self.clear_before_grad(init)
        #clear_depend = self.control(status_clear, self.weights)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_position,
                                                 attention_mask,
                                                 self.cast(scaling_sens / self.micro_size,
                                                           mstype.float32))
        get_status = self.get_status(init)
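        # read the overflow status only after every gradient has been produced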
        get_status_depend = F.control_depend(grads, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        flag_sum_depend = F.control_depend(get_status, flag_sum)
        loss = F.depend(loss, status_clear)
        loss = F.depend(loss, get_status_depend)
        loss = F.depend(loss, flag_sum_depend)
        # apply grad reducer on grads
        accu_grads = self.grad_reducer(self.accu_grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)

        grads, global_norms = self.clip(grads)
        global_norm = P.Reshape()(global_norms, (()))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, overflow, scaling_sens, global_norm)
        return F.depend(ret, succ)
Example #12
    def construct(self, gradients):
        lr = self.get_lr()
        if self.enable_graph_kernel:
            if self.is_group:
                if self.is_group_lr:
                    optim_result = self.hyper_map(
                        F.partial(lamb_opt_graph_kernel, self.beta1,
                                  self.beta2, self.eps, self.global_step), lr,
                        self.weight_decay, self.params, self.moments1,
                        self.moments2, gradients, self.decay_flags)
                else:
                    optim_result = self.hyper_map(
                        F.partial(lamb_opt_graph_kernel, self.beta1,
                                  self.beta2, self.eps, self.global_step, lr),
                        self.weight_decay, self.params, self.moments1,
                        self.moments2, gradients, self.decay_flags)
            else:
                optim_result = self.hyper_map(
                    F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2,
                              self.eps, self.global_step, lr,
                              self.weight_decay), self.params, self.moments1,
                    self.moments2, gradients, self.decay_flags)
        else:
            if self.is_group:
                if self.is_group_lr:
                    optim_result = self.hyper_map(
                        F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
                                  self.global_step), lr, self.weight_decay,
                        self.params, self.moments1, self.moments2, gradients,
                        self.decay_flags, self.optim_filter)
                else:
                    optim_result = self.hyper_map(
                        F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
                                  self.global_step, lr), self.weight_decay,
                        self.params, self.moments1, self.moments2, gradients,
                        self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(
                    F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
                              self.global_step, lr, self.weight_decay),
                    self.params, self.moments1, self.moments2, gradients,
                    self.decay_flags, self.optim_filter)

        if self.use_parallel:
            self.broadcast_params(optim_result)

        if not self.dynamic_lr:
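            # with a static lr the step counter still has to advance once lr has been read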
            F.control_depend(lr, self.assignadd(self.global_step, 1))

        return optim_result
Example #13
File: lars.py  Project: Xylonwang/mindspore
    def construct(self, gradients):
        params = self.parameters
        if self.dynamic_lr:
            lr = self.gather(self.learning_rate, self.global_step, self.axis)
            F.control_depend(lr, self.assignadd(self.global_step, 1))
        else:
            lr = self.learning_rate
        if self.reciprocal_scale != 1.0:
            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)

        grad_t = self.hyper_map(F.partial(lars_opt, self.lars, self.weight_decay, lr),
                                gradients, params, self.decay_flag, self.lars_flag)
        success = self.opt(grad_t)

        return success
Example #14
File: sgd.py  Project: zky001/mindspore
 def construct(self, gradients):
     params = self.params
     accum = self.accum
     stat = self.stat
     if self.reciprocal_scale != 1.0:
         gradients = self.hyper_map(
             F.partial(grad_scale, self.reciprocal_scale), gradients)
     if self.dynamic_lr:
         lr = self.gather(self.learning_rate, self.global_step, self.axis)
         F.control_depend(lr, self.assignadd(self.global_step, 1))
     else:
         lr = self.learning_rate
     success = self.hyper_map(
         F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params,
         accum, stat)
     return success
Example #15
def _update_run_op_graph_kernel(beta1, beta2, eps, global_step, lr,
                                weight_decay, param, m, v, gradient,
                                decay_flag):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        global_step (Tensor): Global step.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Specifies whether param update with weight decay.

    Returns:
        Tensor, the new value of v after updating.
    """
    op_mul = P.Mul()
    op_square = P.Square()
    op_cast = P.Cast()
    op_shape = P.Shape()
    op_pow = P.Pow()
    op_norm = layer.Norm()
    op_fill = P.Fill()
    op_dtype = P.DType()

    param_fp32 = op_cast(param, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    i6_ex = op_cast(global_step + num_one, mstype.float32)
    i9 = op_cast(num_one, mstype.float32) - beta1
    x1 = op_cast(num_one, mstype.float32) - beta2
    i6 = op_cast(num_one, mstype.float32) - op_pow(beta1, i6_ex)
    i3 = op_cast(num_one, mstype.float32) - op_pow(beta2, i6_ex)
    i1 = op_square(gradient_fp32)
    add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1, i9,
                                  beta2, x1, weight_decay, eps)

    if decay_flag:
        update = update + op_mul(weight_decay, param_fp32)

    w_norm = op_norm(param_fp32)
    g_norm = op_norm(gradient_fp32)
    g_norm_hat = op_norm(add3)

    zeros = F.zeros_like(w_norm)
    ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
    tens = op_fill(op_dtype(w_norm), op_shape(w_norm), 10.0)

    next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update,
                                      param, zeros, ones, tens)
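    # returning the control_depend output forces add3 to be evaluated before next_param is applied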
    next_v = F.control_depend(add3, next_param)
    return next_v
Example #16
    def construct(self, gradients):
        step = self.min(self.global_step, self.decay_steps)
        p = step / self.decay_steps
        lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
        if self.warmup_flag:
            warmup_percent = self.global_step / self.warmup_steps
            warmup_lr = self.start_learning_rate * warmup_percent
            is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
            lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
        updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay_tensor),
                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)

        added_global_step = self.global_step + self.one
        F.control_depend(lr, added_global_step)
        self.global_step = added_global_step

        return updated_velocity
Example #17
    def get_lr(self):
        """
        Get the learning rate of current step.

        Returns:
            float, the learning rate of current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step)

            F.control_depend(lr, self.assignadd(self.global_step, self.global_step_increase_tensor))
        return lr
Example #18
    def get_lr(self):
        """
        Get the learning rate of current step.

        Returns:
            float, the learning rate of current step.
        """
        if self.is_group_lr:
            lr = self.learning_rate
            if self.dynamic_lr:
                lr = ()
                for i in range(self.param_length):
                    current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0)
                    lr += (current_dynamic_lr,)
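                # advance global_step only after every group's lr has been gathered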
                F.control_depend(lr, self.assignadd(self.global_step, 1))

        else:
            lr = self.learning_rate
            if self.dynamic_lr:
                lr = self.gather(self.learning_rate, self.global_step, 0)
                F.control_depend(lr, self.assignadd(self.global_step, 1))
        return lr
Example #19
 def construct(self,
               input_ids,
               input_mask,
               token_type_id,
               start_position,
               end_position,
               unique_id,
               is_impossible,
               sens=None):
     """BertSquad"""
     weights = self.weights
     init = self.alloc_status()
     loss = self.network(input_ids,
                         input_mask,
                         token_type_id,
                         start_position,
                         end_position,
                         unique_id,
                         is_impossible)
     if sens is None:
         scaling_sens = self.loss_scale
     else:
         scaling_sens = sens
     grads = self.grad(self.network, weights)(input_ids,
                                              input_mask,
                                              token_type_id,
                                              start_position,
                                              end_position,
                                              unique_id,
                                              is_impossible,
                                              self.cast(scaling_sens,
                                                        mstype.float32))
     clear_before_grad = self.clear_before_grad(init)
     F.control_depend(loss, init)
     self.depend_parameter_use(clear_before_grad, scaling_sens)
     grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
     grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
     if self.reducer_flag:
         grads = self.grad_reducer(grads)
     flag = self.get_status(init)
     flag_sum = self.reduce_sum(init, (0,))
     if self.is_distributed:
         flag_reduce = self.allreduce(flag_sum)
         cond = self.less_equal(self.base, flag_reduce)
     else:
         cond = self.less_equal(self.base, flag_sum)
     F.control_depend(grads, flag)
     F.control_depend(flag, flag_sum)
     overflow = cond
     if sens is None:
         overflow = self.loss_scaling_manager(self.loss_scale, cond)
     if overflow:
         succ = False
     else:
         succ = self.optimizer(grads)
     ret = (loss, cond)
     return F.depend(ret, succ)
Example #20
    def construct(self,
                  input_ids,
                  input_mask,
                  label_ids,
                  sens=None):
        """Bert Finetune"""

        weights = self.weights
        init = False
        loss = self.network(input_ids,
                            input_mask,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        if not self.gpu_target:
            init = self.alloc_status()
            clear_before_grad = self.clear_before_grad(init)
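            # the forward loss is computed before the overflow status buffer is allocated and cleared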
            F.control_depend(loss, init)
            self.depend_parameter_use(clear_before_grad, scaling_sens)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        if not self.gpu_target:
            flag = self.get_status(init)
            flag_sum = self.reduce_sum(init, (0,))
            F.control_depend(grads, flag)
            F.control_depend(flag, flag_sum)
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            flag_sum = self.reshape(flag_sum, (()))
        if self.is_distributed:
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return F.depend(ret, succ)
Example #21
 def construct(self,
               input_ids,
               input_mask,
               token_type_id,
               label_ids,
               sens=None):
     """Defines the computation performed."""
     weights = self.weights
     saved = ()
     for i in range(self.length):
         saved = saved + (F.assign(self.saved_params[i], weights[i]), )
     assign_embedding = ()
     for i in range(self.quant_embedding_list_length):
         quant_embedding = self.quantize_embedding(
             weights[self.quant_embedding_list[i]])
         assign_embedding = assign_embedding + (F.assign(
             weights[self.quant_embedding_list[i]], quant_embedding), )
         F.control_depend(saved, assign_embedding[i])
     assign_weight = ()
     for i in range(self.quant_weight_list_length):
         quant_weight = self.quantize_weight(
             weights[self.quant_weight_list[i]])
         assign_weight = assign_weight + (F.assign(
             weights[self.quant_weight_list[i]], quant_weight), )
         F.control_depend(saved, assign_weight[i])
     for i in range(self.quant_embedding_list_length):
         F.control_depend(assign_embedding[i], input_ids)
     for i in range(self.quant_weight_list_length):
         F.control_depend(assign_weight[i], input_ids)
     if sens is None:
         scaling_sens = self.loss_scale
     else:
         scaling_sens = sens
     # alloc status and clear should be right before grad operation
     init = self.alloc_status()
     self.clear_before_grad(init)
     grads = self.grad(self.network,
                       weights)(input_ids, input_mask, token_type_id,
                                label_ids,
                                self.cast(scaling_sens, mstype.float32))
     F.control_depend(input_ids, grads)
     # apply grad reducer on grads
     grads = self.grad_reducer(grads)
     grads = self.hyper_map(
         F.partial(grad_scale, scaling_sens * self.degree), grads)
     grads = self.hyper_map(
         F.partial(clip_grad, gradient_cfg.clip_type,
                   gradient_cfg.clip_value), grads)
     restore = ()
     for i in range(self.length):
         restore = restore + (F.assign(weights[i], self.saved_params[i]), )
         F.control_depend(grads, restore[i])
     self.get_status(init)
     flag_sum = self.reduce_sum(init, (0, ))
     if self.is_distributed:
         # sum overflow flag over devices
         flag_reduce = self.allreduce(flag_sum)
         cond = self.less_equal(self.base, flag_reduce)
     else:
         cond = self.less_equal(self.base, flag_sum)
     overflow = cond
     if sens is None:
         overflow = self.loss_scaling_manager(self.loss_scale, cond)
     if overflow:
         succ = False
     else:
         succ = self.optimizer(grads)
     for i in range(self.length):
         F.control_depend(restore[i], succ)
     return succ
Example #22
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids, input_mask, token_type_id,
                            next_sentence_labels, masked_lm_positions,
                            masked_lm_ids, masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one,
                                      self.one)
        self.loss = self.select(is_accu_step, self.loss + loss, loss)
        mean_loss = self.loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network,
                          weights)(input_ids, input_mask, token_type_id,
                                   next_sentence_labels, masked_lm_positions,
                                   masked_lm_ids, masked_lm_weights,
                                   self.cast(scaling_sens, mstype.float32))

        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0, ))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(
            self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow,
                                         self.zero)

        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            if self.enable_global_norm:
                grads = C.clip_by_global_norm(grads, 1.0, None)
            else:
                grads = self.hyper_map(
                    F.partial(clip_grad, GRADIENT_CLIP_TYPE,
                              GRADIENT_CLIP_VALUE), grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
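            # gradients must be fully computed before the reduced overflow flag is used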
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
Example #23
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking,
                         use_nesterov, target, beta1_power, beta2_power, beta1,
                         beta2, eps, lr, gradient, param, m, v, ps_parameter):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter:
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr),
                  op_shape(beta1), op_shape(beta2), op_shape(eps),
                  op_shape(values), op_shape(indices))
        success = F.depend(
            success,
            pull(
                push((beta1_power, beta2_power, lr, beta1, beta2, eps, values,
                      indices), shapes), param))
        return success

    if not target:
        success = F.depend(
            success,
            sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                       eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        assign_m = F.assign(m, op_mul(beta1, m))
        assign_v = F.assign(v, op_mul(beta2, v))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(
            m, grad_indices,
            op_mul(F.tuple_to_array((1.0, )) - beta1, grad_value))

        next_v = scatter_add(
            v, grad_indices,
            op_mul(F.tuple_to_array((1.0, )) - beta2, op_square(grad_value)))

        if use_nesterov:
            m_temp = next_m * _scaler_ten
            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(
                m, op_mul(grad_indices, _scaler_one),
                op_mul(F.tuple_to_array((1.0, )) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)

            m_recover = F.assign(m, m_temp / _scaler_ten)

            F.control_depend(m_temp, assign_m_nesterov)
            F.control_depend(assign_m_nesterov, div_value)
            F.control_depend(param_update, m_recover)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        next_param = param - lr_t * param_update

        F.control_depend(assign_m, next_m)
        F.control_depend(assign_v, next_v)

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
Example #24
 def construct(self, input_ids, input_mask, token_type_id, label_ids):
     """Defines the computation performed."""
     weights = self.weights
     saved = ()
     for i in range(self.length):
         saved = saved + (F.assign(self.saved_params[i], weights[i]), )
     assign_embedding = ()
     for i in range(self.quant_embedding_list_length):
         quant_embedding = self.quantize_embedding(
             weights[self.quant_embedding_list[i]])
         assign_embedding = assign_embedding + (F.assign(
             weights[self.quant_embedding_list[i]], quant_embedding), )
         F.control_depend(saved, assign_embedding[i])
     assign_weight = ()
     for i in range(self.quant_weight_list_length):
         quant_weight = self.quantize_weight(
             weights[self.quant_weight_list[i]])
         assign_weight = assign_weight + (F.assign(
             weights[self.quant_weight_list[i]], quant_weight), )
         F.control_depend(saved, assign_weight[i])
     for i in range(self.quant_embedding_list_length):
         F.control_depend(assign_embedding[i], input_ids)
     for i in range(self.quant_weight_list_length):
         F.control_depend(assign_weight[i], input_ids)
     grads = self.grad(self.network,
                       weights)(input_ids, input_mask, token_type_id,
                                label_ids,
                                self.cast(F.tuple_to_array((self.sens, )),
                                          mstype.float32))
     F.control_depend(input_ids, grads)
     # apply grad reducer on grads
     grads = self.grad_reducer(grads)
     grads = self.hyper_map(
         F.partial(clip_grad, gradient_cfg.clip_type,
                   gradient_cfg.clip_value), grads)
     restore = ()
     for i in range(self.length):
         restore = restore + (F.assign(weights[i], self.saved_params[i]), )
         F.control_depend(grads, restore[i])
     succ = self.optimizer(grads)
     for i in range(self.length):
         F.control_depend(restore[i], succ)
     return succ