コード例 #1
0
 def construct(self, gradients):
     """Apply one momentum step to ``self.params``.

     Args:
         gradients: tuple of gradients, presumably aligned by position with
             ``self.params`` and ``self.moments`` — TODO confirm.

     Returns:
         Result of mapping ``momentum_opt`` over (gradient, param, moment).
     """
     params = self.params
     moments = self.moments
     # Optional weight decay; self.decay_tf is mapped alongside params
     # (presumably per-parameter "apply decay" flags — confirm).
     if self.weight_decay > 0:
         gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
     # Rescale gradients unless reciprocal_scale is exactly 1.0
     # (presumably undoing a loss scale applied upstream).
     if self.reciprocal_scale != 1.0:
         gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
     if self.dynamic_lr:
         # Pick this step's learning rate out of the schedule, then order the
         # global_step increment after that read via control_depend.
         lr = self.gather(self.learning_rate, self.global_step, self.axis)
         F.control_depend(lr, self.assignadd(self.global_step, self.one))
     else:
         lr = self.learning_rate
     success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
     return success
コード例 #2
0
    def construct(self, grads):
        """Run one FTRL update over all parameters.

        When ``self.weight_decay`` is positive, weight decay is applied first.
        The gradients then pass through ``self.scale_grad``. Finally
        ``_ftrl_opt`` is mapped over (linear, grad, param, moment) tuples.
        """
        weights = self.parameters
        if self.weight_decay > 0.0:
            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay),
                                   self.decay_tf, weights, grads)
        grads = self.scale_grad(grads)
        return self.map_(
            F.partial(_ftrl_opt, self.opt, self.sparse_opt, self.learning_rate,
                      self.l1, self.l2, self.lr_power),
            self.linear, grads, weights, self.moments)
コード例 #3
0
 def construct(self,
               input_ids,
               input_mask,
               token_type_id,
               next_sentence_labels,
               masked_lm_positions,
               masked_lm_ids,
               masked_lm_weights,
               sens=None):
     """Defines the computation performed.

     Runs one loss-scaled training step:
     - forward pass;
     - backward pass seeded with the loss scale;
     - gradient reduction, rescaling and clipping;
     - overflow detection;
     - an optimizer update, skipped on overflow.

     Args:
         input_ids, input_mask, token_type_id, next_sentence_labels,
         masked_lm_positions, masked_lm_ids, masked_lm_weights: forwarded
             unchanged to ``self.network`` (forward and backward).
         sens: optional loss scale. When None, ``self.loss_scale`` is used and
             ``self.loss_scaling_manager`` decides whether the step overflowed.

     Returns:
         ``(loss, cond, scaling_sens)``, where ``cond`` is the raw overflow
         flag. The tuple is made to depend on the optimizer result.
     """
     weights = self.weights
     loss = self.network(input_ids, input_mask, token_type_id,
                         next_sentence_labels, masked_lm_positions,
                         masked_lm_ids, masked_lm_weights)
     if sens is None:
         scaling_sens = self.loss_scale
     else:
         scaling_sens = sens
     # alloc status and clear should be right before gradoperation
     init = self.alloc_status()
     self.clear_before_grad(init)
     # Backward pass; the loss scale (cast to float32) is the sensitivity.
     grads = self.grad(self.network,
                       weights)(input_ids, input_mask, token_type_id,
                                next_sentence_labels, masked_lm_positions,
                                masked_lm_ids, masked_lm_weights,
                                self.cast(scaling_sens, mstype.float32))
     # apply grad reducer on grads
     grads = self.grad_reducer(grads)
     # grad_scale (defined elsewhere) receives scaling_sens * self.degree —
     # presumably it unscales by the loss scale times the device count; confirm.
     grads = self.hyper_map(
         F.partial(grad_scale, scaling_sens * self.degree), grads)
     grads = self.hyper_map(
         F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
         grads)
     # Presumably fills init with the float-status flags, which are then summed.
     self.get_status(init)
     flag_sum = self.reduce_sum(init, (0, ))
     if self.is_distributed:
         # sum overflow flag over devices
         flag_reduce = self.allreduce(flag_sum)
         cond = self.less_equal(self.base, flag_reduce)
     else:
         cond = self.less_equal(self.base, flag_sum)
     # cond is True when the summed status is >= self.base.
     overflow = cond
     if sens is None:
         overflow = self.loss_scaling_manager(self.loss_scale, cond)
     # Skip the parameter update entirely on overflow.
     if overflow:
         succ = False
     else:
         succ = self.optimizer(grads)
     ret = (loss, cond, scaling_sens)
     return F.depend(ret, succ)
コード例 #4
0
 def construct(self, grads):
     """Apply one AdaGrad step to ``self.parameters``.

     The gradients go through three stages in this order:
     ``decay_weight``, then ``scale_grad``, then
     ``gradients_centralization``. After that, ``_ada_grad_opt`` is mapped
     over (param, accum, grad). With group learning rates, the per-parameter
     ``lr`` is mapped as a leading argument. Otherwise a single ``lr`` is
     bound into the partial.
     """
     grads = self.gradients_centralization(self.scale_grad(self.decay_weight(grads)))
     step_lr = self.get_lr()
     weights = self.parameters
     accumulators = self.accum
     if not self.is_group_lr:
         return self.map_(F.partial(_ada_grad_opt, self.opt, step_lr),
                          weights, accumulators, grads)
     return self.map_(F.partial(_ada_grad_opt, self.opt),
                      step_lr, weights, accumulators, grads)
コード例 #5
0
ファイル: utils.py プロジェクト: xyg320/mindspore
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        """Run one loss-scaled fine-tuning step.

        Args:
            input_ids, input_mask, token_type_id, label_ids: forwarded to
                ``self.network`` for the forward and backward passes.
            sens: optional loss scale. When it is None, ``self.loss_scale`` is
                used and ``self.loss_scaling_manager`` also decides whether
                the step counts as an overflow.

        Returns:
            ``(loss, cond)``, where ``cond`` is the raw overflow flag. The
            tuple is made to depend on the optimizer result.
        """
        weights = self.weights
        init = self.alloc_status()
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # Backward pass seeded with the loss scale cast to float32.
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # Ordering here comes from explicit dependencies, not statement order.
        # - control_depend puts init after the forward loss.
        # - depend_parameter_use (defined elsewhere) ties the status-clear op
        #   to scaling_sens, presumably so the clear happens before the
        #   backward pass consumes it; confirm.
        clear_before_grad = self.clear_before_grad(init)
        F.control_depend(loss, init)
        self.depend_parameter_use(clear_before_grad, scaling_sens)
        # Note: unscaling and clipping happen before the optional grad reducer.
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        flag = self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # Presumably sums the overflow flags across devices.
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        # Read the status only after grads are final, and sum only after the read.
        F.control_depend(grads, flag)
        F.control_depend(flag, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        # Skip the parameter update entirely on overflow.
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return F.depend(ret, succ)
コード例 #6
0
 def construct(self, gradients):
     """Apply one preconditioned momentum step to ``self.params``.

     Each of the first 54 layer gradients is reshaped to 2-D and
     preconditioned as ``matrix_G[i] @ g @ matrix_A[i]``. Weight decay and
     the ``_momentum_opt`` update follow.

     Args:
         gradients: flat tuple of gradients. Index ``i * 3`` is treated as
             layer ``i``'s gradient to precondition. For i < 53, indices
             ``i * 3 + 1`` and ``i * 3 + 2`` are passed through unchanged.

     Returns:
         Result of mapping ``_momentum_opt`` over (gradient, param, moment).
     """
     params = self.params
     moments = self.moments
     gradients = self.scale_grad(gradients)
     new_grads = ()
     # NOTE(review): 54 layers and the last index 53 are hard-coded, presumably
     # the layer count of one specific network; confirm. Only gradients[0..159]
     # are read, so any later entries are not carried into new_grads.
     if self.skfac:
         for i in range(54):
             g = gradients[i * 3]
             g_shape = self.shape(g)
             g = self.reshape(g, (g_shape[0], -1))
             matrix_A = self.matrix_A[i]
             matrix_G = self.matrix_G[i]
             g = self.matmul(self.matmul(matrix_G, g), matrix_A)
             # Self-assignments whose results are used only through F.depend,
             # presumably to pin ordering relative to the matrix reads.
             fake_A = self.assign(self.matrix_A[i], matrix_A)
             fake_G = self.assign(self.matrix_G[i], matrix_G)
             g = F.depend(g, fake_A)
             g = F.depend(g, fake_G)
             if i == 53:
                 # Last layer: appended in its 2-D form, not reshaped back,
                 # and without the two pass-through companions.
                 new_grads = new_grads + (g, )
             else:
                 g = self.reshape(g, g_shape)
                 new_grads = new_grads + (g, gradients[i * 3 + 1],
                                          gradients[i * 3 + 2])
     else:
         for i in range(54):
             g = gradients[i * 3]
             g_shape = self.shape(g)
             g = self.reshape(g, (g_shape[0], -1))
             matrix_A = self.matrix_A[i]
             matrix_G = self.matrix_G[i]
             # Make the matrix reads depend on g.
             matrix_A = F.depend(matrix_A, g)
             matrix_G = F.depend(matrix_G, g)
             g = self.matmul(self.matmul(matrix_G, g), matrix_A)
             if i == 53:
                 # Last layer: appended in its 2-D form, not reshaped back.
                 new_grads = new_grads + (g, )
             else:
                 g = self.reshape(g, g_shape)
                 new_grads = new_grads + (g, gradients[i * 3 + 1],
                                          gradients[i * 3 + 2])
     gradients = new_grads
     if self.weight_decay > 0:
         gradients = self.hyper_map(
             F.partial(apply_decay, self.weight_decay), self.decay_flags,
             params, gradients)
     lr = self.get_lr()
     success = self.hyper_map(
         F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients,
         params, moments)
     return success
コード例 #7
0
ファイル: lars.py プロジェクト: Xylonwang/mindspore
    def construct(self, gradients):
        """Compute LARS-adjusted gradients and hand them to ``self.opt``.

        Args:
            gradients: tuple of gradients, presumably aligned with
                ``self.parameters`` — TODO confirm.

        Returns:
            Whatever ``self.opt`` returns for the adjusted gradients.
        """
        params = self.parameters
        if self.dynamic_lr:
            # Read this step's lr from the schedule, then order the
            # global_step increment after that read.
            lr = self.gather(self.learning_rate, self.global_step, self.axis)
            F.control_depend(lr, self.assignadd(self.global_step, 1))
        else:
            lr = self.learning_rate
        # Rescale unless reciprocal_scale is exactly 1.0.
        if self.reciprocal_scale != 1.0:
            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)

        # lars_opt is applied per parameter together with its decay and lars flags.
        grad_t = self.hyper_map(F.partial(lars_opt, self.lars, self.weight_decay, lr),
                                gradients, params, self.decay_flag, self.lars_flag)
        success = self.opt(grad_t)

        return success
コード例 #8
0
 def construct(self, grads):
     """Apply one proximal AdaGrad step to ``self.parameters``.

     The gradients are weight-decayed and then scaled. After that,
     ``_proximal_ada_grad_opt`` is mapped over (grad, param, accum). With
     group learning rates, the per-parameter ``lr`` is mapped as a leading
     argument.
     """
     grads = self.scale_grad(self.decay_weight(grads))
     step_lr = self.get_lr()
     weights = self.parameters
     accumulators = self.accum
     if self.is_group_lr:
         update = self.map_(
             F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt,
                       self.l1, self.l2),
             step_lr, grads, weights, accumulators)
     else:
         update = self.map_(
             F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt,
                       self.l1, self.l2, step_lr),
             grads, weights, accumulators)
     return update
コード例 #9
0
 def construct(self, grads):
     """Apply one FTRL step to ``self.parameters``.

     Two optional adjustments are made to the gradients first:
     - weight decay, when ``self.weight_decay`` is positive;
     - rescaling by ``self.reciprocal_scale``, skipped when it is exactly 1.0.

     Then ``ftrl_opt`` is mapped over (linear, grad, param, moment).
     """
     weights = self.parameters
     if self.weight_decay > 0.0:
         grads = self.hyper_map(F.partial(apply_decay, self.weight_decay),
                                self.decay_tf, weights, grads)
     if self.reciprocal_scale != 1.0:
         grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale),
                                grads)
     return self.hyper_map(
         F.partial(ftrl_opt, self.opt, self.learning_rate, self.l1, self.l2,
                   self.lr_power),
         self.linear, grads, weights, self.moments)
コード例 #10
0
 def construct(self, data, label):
     """Run one training step with norm-based gradient rescaling.

     Steps:
     - The loss is back-propagated with a sensitivity tensor filled with
       ``self.sens``.
     - The per-gradient ``compute_norm`` results are concatenated and
       reduced by ``self.norm``.
     - Every gradient is passed to ``grad_div`` together with a divisor: the
       norm when it exceeds ``self.ten``, otherwise ``self.ten``.
     - Gradients are optionally reduced before the optimizer runs.

     Returns:
         The loss, made to depend on the optimizer update.
     """
     loss = self.network(data, label)
     sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
     grads = self.grad(self.network, self.weights)(data, label, sens)
     per_grad_norms = self.concat(self.hyper_map(F.partial(compute_norm), grads))
     total_norm = self.norm(per_grad_norms)
     threshold = self.cast(self.ten, self.dtype(total_norm))
     divisor = self.select(self.greater(total_norm, threshold), total_norm, threshold)
     grads = self.hyper_map(F.partial(grad_div, divisor), grads)
     if self.reducer_flag:
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
     return F.depend(loss, self.optimizer(grads))
コード例 #11
0
 def construct(self, gradients):
     """Apply one momentum step to ``self.params``.

     The gradients go through ``decay_weight`` and then ``scale_grad``.
     ``_momentum_opt`` is then mapped over (gradient, param, moment, ps
     entry); ``self.ps_parameters`` is presumably per-parameter
     parameter-server flags. With group learning rates, the per-parameter
     ``lr`` is mapped as a leading argument.
     """
     gradients = self.scale_grad(self.decay_weight(gradients))
     step_lr = self.get_lr()
     if not self.is_group_lr:
         return self.hyper_map(
             F.partial(_momentum_opt, self.opt, self.momentum, step_lr),
             gradients, self.params, self.moments, self.ps_parameters)
     return self.hyper_map(
         F.partial(_momentum_opt, self.opt, self.momentum),
         step_lr, gradients, self.params, self.moments, self.ps_parameters)
コード例 #12
0
    def construct(self, grads):
        """Reduce ``grads`` across devices and restore their original dtypes.

        Grads may mix float16 and float32, which makes the AllReduce result
        unreliable. So every grad is cast to float32 for the reduction and
        then cast back to the dtype it arrived with.
        """
        original_dtypes = self.hyper_map(F.partial(_get_datatype), grads)
        fp32_grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)

        if not self.mean:
            reduced = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, fp32_grads)
        else:
            reduced = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), fp32_grads)

        return self.hyper_map(F.partial(_cast_datatype), original_dtypes, reduced)
コード例 #13
0
 def construct(self, gradients):
     """Apply one SGD step to ``self.parameters``.

     The gradients are first passed through ``scale_grad``. Then ``_sgd_opt``
     is mapped over (gradient, param, accum, stat). The learning rate is
     either bound into the partial or, with group learning rates, mapped per
     parameter.
     """
     gradients = self.scale_grad(gradients)
     step_lr = self.get_lr()
     weights = self.parameters
     accumulators = self.accum
     stats = self.stat
     if not self.is_group_lr:
         return self.hyper_map(
             F.partial(_sgd_opt, self.opt, self.momentum, step_lr),
             gradients, weights, accumulators, stats)
     return self.hyper_map(
         F.partial(_sgd_opt, self.opt, self.momentum),
         step_lr, gradients, weights, accumulators, stats)
コード例 #14
0
 def construct(self, x):
     """Clip every tensor in ``x`` using the global norm of ``x``.

     The global norm is the square root of the AddN of the per-tensor
     ``get_square_sum`` results. ``apply_global_norm`` receives that norm
     when it is at least ``self.clip_norm``; otherwise it receives
     ``self.clip_norm``.
     """
     squares = self.hyper_map(get_square_sum, x)
     norm = F.sqrt(F.addn(squares))
     norm = F.select(self.greater_equal(norm, self.clip_norm), norm, self.clip_norm)
     return self.hyper_map(F.partial(apply_global_norm, self.clip_norm, norm), x)
コード例 #15
0
    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        """Defines the computation performed.

        One training step in which gradients are taken with selected weights
        quantized:
        - snapshot the weights;
        - overwrite the listed weights with their quantized values;
        - compute, reduce and clip the gradients;
        - restore the snapshot;
        - run the optimizer on the restored weights.

        Returns:
            The optimizer result (no loss is returned by this method).
        """
        weights = self.weights
        # Snapshot every weight into saved_params before quantizing.
        for i in range(self.length):
            F.assign(self.saved_params[i], weights[i])

        # Overwrite the weights listed in quant_embedding_list with their
        # quantize_embedding values.
        for i in range(self.quant_embedding_list_length):
            quant_embedding = self.quantize_embedding(
                weights[self.quant_embedding_list[i]])
            F.assign(weights[self.quant_embedding_list[i]], quant_embedding)

        # Same for the weights listed in quant_weight_list, via quantize_weight.
        for i in range(self.quant_weight_list_length):
            quant_weight = self.quantize_weight(
                weights[self.quant_weight_list[i]])
            F.assign(weights[self.quant_weight_list[i]], quant_weight)

        # Gradients are taken while the quantized values are in place. The
        # sensitivity is the constant self.sens as a float32 tensor.
        grads = self.grad(self.network,
                          weights)(input_ids, input_mask, token_type_id,
                                   label_ids,
                                   self.cast(F.tuple_to_array((self.sens, )),
                                             mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(
            F.partial(clip_grad, self.clip_type, self.clip_value), grads)

        # Restore the pre-quantization snapshot. F.depend orders each restore
        # after the gradients are available.
        for i in range(self.length):
            param = F.depend(self.saved_params[i], grads)
            F.assign(weights[i], param)

        succ = self.optimizer(grads)
        return succ
コード例 #16
0
    def construct(self, gradients):
        """Accumulate gradients and apply them once every ``apply_period`` calls.

        Each call increments ``self.acc_step`` and adds ``gradients`` into
        ``self.accu_grads``. When ``mod_op(acc_step, apply_period)`` is 0
        (presumably a modulo):
        - the accumulated grads go through ``grad_scale`` with
          ``apply_period``, presumably averaging them;
        - they are applied via the parent class's ``construct``;
        - ``reset_accu_grads`` is mapped over the accumulators.

        Returns:
            ``gradients`` unchanged, made to depend on the accumulate/apply ops.
        """
        # TODO: perform all_reduce
        #     gradients = self._map(self._all_reduce, gradients)
        self.acc_step = self.acc_step + 1
        q = self.mod_op(self.acc_step, self.apply_period)

        # Debug switch: set True to map log_tensor over the incoming gradients.
        # log_gradient = True
        log_gradient = False
        if log_gradient:
            gradients = self.hyper_map(log_tensor, gradients)

        accu_grads = self.hyper_map(add_grads, self.accu_grads, gradients)
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads,
                                   accu_grads)

        if q == 0:
            mean_grads = self.hyper_map(
                F.partial(grad_scale, self.apply_period), accu_grads)
            apply_succ = super(CumulativeSGDOptimizer,
                               self).construct(mean_grads)
            reset_succ = self.hyper_map(reset_accu_grads, self.accu_grads)

            # Order: reset result depends on the apply result.
            succ = F.depend(reset_succ, apply_succ)
        else:
            succ = True
            succ = F.depend(succ, accu_succ)

        return F.depend(gradients, succ)
コード例 #17
0
ファイル: test_row_tensor.py プロジェクト: dongkcs/mindspore
 def construct(self, gradients):
     """Apply one Adam update through ``adam_opt_for_map``.

     Returns:
         Result of ``self.map`` over (param, moment1, moment2, gradient,
         decay flag), with beta1/beta2/eps, the current lr and the
         weight-decay tensor bound into the partial.
     """
     step_lr = self.get_lr()
     return self.map(
         F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps,
                   step_lr, self.weight_decay_tensor),
         self.params, self.moments1, self.moments2, gradients, self.decay_flag)
コード例 #18
0
    def construct(self, data, label):
        """
        construct a compute flow.

        The batch is split into ``self._micro_batches`` micro-batches. For
        each micro-batch, the gradient is computed and clipped via
        ``_clip_by_global_norm``, and the results are summed. The returned
        loss is the summed loss divided by ``self._micro_float``. When
        ``self._mech`` is set, noise from it is added to the summed grads.
        The grads are then rescaled by ``_micro_float``.

        Returns:
            The averaged loss, made to depend on the optimizer update.
        """
        weights = self.weights
        record_datas = self._split(data)
        record_labels = self._split(label)
        # First micro-batch initialises the running grads/loss.
        loss = self.network(record_datas[0], record_labels[0])
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens)
        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
        grads = record_grad
        total_loss = loss
        # Remaining micro-batches are accumulated element-wise.
        for i in range(1, self._micro_batches):
            loss = self.network(record_datas[i], record_labels[i])
            sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
            record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens)
            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.TensorAdd()(total_loss, loss)
        loss = P.Div()(total_loss, self._micro_float)

        # NOTE(review): the _grad_scale rescaling by _micro_float only runs
        # when _mech is set. Without a mechanism, the summed (not rescaled)
        # grads reach the optimizer while the loss is averaged. Confirm this
        # is intended.
        if self._mech is not None:
            grad_noise = self._hyper_map(self._mech, grads)
            grads = self._tuple_add(grads, grad_noise)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)

        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
コード例 #19
0
ファイル: rmsprop.py プロジェクト: pkuliuliu/mindspore
 def construct(self, gradients):
     """Apply one RMSProp step (centered or plain) to ``self.parameters``.

     Args:
         gradients: tuple of gradients, presumably aligned with
             ``self.parameters`` — TODO confirm.

     Returns:
         Result of mapping ``centered_rmsprop_opt`` (when ``self.centered``)
         or ``rmsprop_opt`` over the parameters and their state.
     """
     params = self.parameters
     # Rescale unless reciprocal_scale is exactly 1.0.
     if self.reciprocal_scale != 1.0:
         gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
     if self.dynamic_lr:
         # Read this step's lr from the schedule, then order the
         # global_step increment after that read.
         lr = self.gather(self.learning_rate, self.global_step, self.axis)
         F.control_depend(lr, self.assignadd(self.global_step, self.one))
     else:
         lr = self.learning_rate
     if self.centered:
         # The centered variant additionally maps self.mg.
         success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
                                            self.momentum), params, self.mg, self.ms, self.moment, gradients)
     else:
         success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
                                            self.momentum), params, self.ms, self.moment, gradients)
     return success
コード例 #20
0
ファイル: sgd.py プロジェクト: zky001/mindspore
 def construct(self, gradients):
     """Apply one SGD step to ``self.params``.

     Args:
         gradients: tuple of gradients, presumably aligned with
             ``self.params`` — TODO confirm.

     Returns:
         Result of mapping ``sgd_opt`` over (gradient, param, accum, stat).
     """
     params = self.params
     accum = self.accum
     stat = self.stat
     # Rescale unless reciprocal_scale is exactly 1.0.
     if self.reciprocal_scale != 1.0:
         gradients = self.hyper_map(
             F.partial(grad_scale, self.reciprocal_scale), gradients)
     if self.dynamic_lr:
         # Read this step's lr from the schedule, then order the
         # global_step increment after that read.
         lr = self.gather(self.learning_rate, self.global_step, self.axis)
         F.control_depend(lr, self.assignadd(self.global_step, 1))
     else:
         lr = self.learning_rate
     success = self.hyper_map(
         F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params,
         accum, stat)
     return success
コード例 #21
0
 def construct(self, gradients):
     """Apply one RMSProp step (centered or plain) to ``self.parameters``.

     The gradients go through ``decay_weight`` and then ``scale_grad``. The
     centered variant also maps ``self.mg`` alongside ``self.ms`` and
     ``self.moment``.
     """
     gradients = self.scale_grad(self.decay_weight(gradients))
     step_lr = self.get_lr()
     weights = self.parameters
     if not self.centered:
         return self.hyper_map(
             F.partial(rmsprop_opt, self.opt, step_lr, self.decay,
                       self.epsilon, self.momentum),
             weights, self.ms, self.moment, gradients)
     return self.hyper_map(
         F.partial(centered_rmsprop_opt, self.opt, step_lr, self.decay,
                   self.epsilon, self.momentum),
         weights, self.mg, self.ms, self.moment, gradients)
コード例 #22
0
    def construct(self, img, label_indices, text, sequence_length, lab_len):
        """Run one training step and return the loss with the loss scale.

        Args:
            img: input
            label_indices: get from the data generator
            text: label got from the data generator
            sequence_length: get from the data generator
            lab_len: get from the data generator (not read by this method)

        Returns:
            ``(loss, self.scale_sense)``, made to depend on the optimizer update.
        """
        loss = self.network(img, label_indices, text, sequence_length)
        scale = self.scale_sense
        sens_fp32 = self.cast(scale, mstype.float32)
        grads = self.grad(self.network, self.weights)(img, label_indices, text,
                                                      sequence_length, sens_fp32)
        unscaled = self.hyper_map(F.partial(grad_scale, scale), grads)
        grads = self.clip_gradients(unscaled, GRADIENT_CLIP_MIN, GRADIENT_CLIP_MAX)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend((loss, scale), self.optimizer(grads))
コード例 #23
0
 def construct(self, *inputs):
     """Forward ``inputs``, back-propagate, and fold grads into ``self.grad_sum``.

     ``_sum_op`` is mapped over (grad_sum entry, grad), presumably adding
     each new gradient into its running sum.

     Returns:
         The loss, made to depend on that accumulation.
     """
     loss = self.network(*inputs)
     sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
     grads = self.grad(self.network, self.weights)(*inputs, sens)
     accumulated = self.hyper_map(F.partial(_sum_op), self.grad_sum, grads)
     return F.depend(loss, accumulated)
コード例 #24
0
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """Defines the computation performed.

        Steps:
        - forward pass;
        - backward pass with the constant sensitivity ``self.sens`` as a
          float32 tensor;
        - gradient clipping;
        - gradient reduction;
        - optimizer update.

        Returns:
            The loss, made to depend on the optimizer update.
        """
        loss = self.network(input_ids, input_mask, token_type_id,
                            next_sentence_labels, masked_lm_positions,
                            masked_lm_ids, masked_lm_weights)
        sens_tensor = self.cast(F.tuple_to_array((self.sens,)), mstype.float32)
        grads = self.grad(self.network, self.weights)(input_ids, input_mask, token_type_id,
                                                      next_sentence_labels, masked_lm_positions,
                                                      masked_lm_ids, masked_lm_weights,
                                                      sens_tensor)
        clipped = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        reduced = self.grad_reducer(clipped)
        return F.depend(loss, self.optimizer(reduced))
コード例 #25
0
ファイル: adam.py プロジェクト: chexueji/mindspore
    def construct(self, gradients):
        """Apply one Adam update with the fixed learning rate ``self.lr``.

        Returns:
            Result of mapping ``adam_opt`` over (param, moment1, moment2, gradient).
        """
        return self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps,
                                        self.lr, self.weight_decay_tensor),
                              self.params, self.moments1, self.moments2,
                              gradients)
コード例 #26
0
ファイル: utils.py プロジェクト: pingping1122/mindspore
 def construct(self, grads):
     """Clip ``grads`` by their global norm.

     ``apply_global_norm`` receives the global norm (from
     ``self.global_norm``) when it is at least ``self.clip_norm``;
     otherwise it receives ``self.clip_norm``.
     """
     norm = self.global_norm(grads)
     norm = F.select(P.GreaterEqual()(norm, self.clip_norm), norm, self.clip_norm)
     return self.hyper_map(F.partial(apply_global_norm, self.clip_norm, norm), grads)
コード例 #27
0
 def construct(self, x, sens=None):
     """Defines the computation performed.

     One loss-scaled training step on input ``x``. The optimizer update is
     skipped when an overflow is reported.

     Args:
         x: forwarded to ``self.network``.
         sens: optional loss scale. When None, ``self.loss_scale`` is used and
             ``self.loss_scaling_manager`` decides whether the step overflowed.

     Returns:
         ``(loss, cond, scaling_sens)``; the tuple is made to depend on the
         optimizer result.
     """
     weights = self.weights
     loss = self.network(x)
     if sens is None:
         scaling_sens = self.loss_scale
     else:
         scaling_sens = sens
     # alloc status and clear should be right before gradoperation
     init = self.alloc_status()
     self.clear_before_grad(init)
     grads = self.grad(self.network,
                       weights)(x, self.cast(scaling_sens, mstype.float32))
     # apply grad reducer on grads
     grads = self.grad_reducer(grads)
     # NOTE(review): no unscaling by scaling_sens happens here before
     # clipping; confirm the optimizer accounts for the loss scale.
     grads = self.hyper_map(
         F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
         grads)
     # Presumably fills init with the float-status flags, which are then summed.
     self.get_status(init)
     flag_sum = self.reduce_sum(init, (0, ))
     cond = self.less_equal(self.base, flag_sum)
     overflow = cond
     if sens is None:
         overflow = self.loss_scaling_manager(self.loss_scale, cond)
     # Skip the parameter update entirely on overflow.
     if overflow:
         succ = False
     else:
         succ = self.optimizer(grads)
     ret = (loss, cond, scaling_sens)
     return F.depend(ret, succ)
コード例 #28
0
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed.

        One loss-scaled training step. Overflow detection goes through
        ``start_overflow_check`` / ``get_overflow_status``, and the optimizer
        update is skipped when an overflow is reported.

        Args:
            input_ids, input_mask, token_type_id, next_sentence_labels,
            masked_lm_positions, masked_lm_ids, masked_lm_weights: forwarded
                unchanged to ``self.network`` (forward and backward).
            sens: optional loss scale. When None, ``self.loss_scale`` is used and
                ``self.loss_scaling_manager`` decides whether the step overflowed.

        Returns:
            ``(loss, cond, scaling_sens)``, where ``scaling_sens`` is the value
            returned by ``start_overflow_check``. The tuple is made to depend
            on the optimizer result.
        """
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # Helper defined elsewhere. It returns a status handle and a (possibly
        # re-bound) scaling_sens; presumably it also orders the status setup
        # after the forward loss — confirm.
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        # grad_scale receives scaling_sens * self.degree — presumably it
        # unscales by the loss scale times the device count; confirm.
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

        cond = self.get_overflow_status(status, grads)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        # Skip the parameter update entirely on overflow.
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
コード例 #29
0
 def construct(self, grads):
     """Apply one ``run_opt`` step to every weight.

     ``run_opt`` is bound to ``self.opt``, ``self.iter``,
     ``self.learning_rate`` and ``self.momentum``, then mapped over
     (grad, weight, moment) triples.

     Returns:
         The result of the hyper-map.
     """
     success = self.hyper_map(
         F.partial(run_opt, self.opt, self.iter, self.learning_rate,
                   self.momentum), grads, self.weights, self.moments)
     return success
コード例 #30
0
ファイル: lazyadam.py プロジェクト: huxian123/mindspore
    def construct(self, gradients):
        """Apply one lazy-Adam step to ``self.parameters``.

        Steps:
        - The gradients go through ``decay_weight`` and then ``scale_grad``.
        - The stored ``beta1_power``/``beta2_power`` are multiplied by
          ``beta1``/``beta2`` on every call.
        - ``_lazy_adam_opt`` is mapped over (gradient, param, moment1,
          moment2, ps entry).

        With group learning rates, the per-parameter ``lr`` is mapped as a
        leading argument.

        Returns:
            Result of ``self.map_``.
        """
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()

        # Stateful: the running beta powers advance by one factor per call.
        self.beta1_power = self.beta1_power * self.beta1
        self.beta2_power = self.beta2_power * self.beta2

        if self.is_group_lr:
            success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
                                lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
        else:
            success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr),
                                gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
        return success