Example #1
    def predict(self, data_instances):

        LOGGER.info('Start predict task')
        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        suffix = ('predict',)
        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=suffix)
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        if self.use_encrypt:
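            # Encrypted mode: pull the aggregated (still encrypted) model from
            # the arbiter, compute wx on ciphertext, and let the arbiter
            # decrypt and classify.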
            final_model = self.transfer_variable.aggregated_model.get(idx=0, suffix=suffix)
            model_weights = LogisticRegressionWeights(final_model.unboxed, self.fit_intercept)
            wx = self.compute_wx(data_instances, model_weights.coef_, model_weights.intercept_)
            self.transfer_variable.predict_wx.remote(wx, consts.ARBITER, 0, suffix=suffix)
            predict_result = self.transfer_variable.predict_result.get(idx=0, suffix=suffix)
            predict_result = predict_result.join(
                data_instances,
                lambda p, d: [d.label, p, None, {"0": None, "1": None}])

        else:
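            # Plaintext mode: score locally and threshold the probability.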
            predict_wx = self.compute_wx(data_instances, self.model_weights.coef_, self.model_weights.intercept_)
            pred_table = self.classify(predict_wx, self.model_param.predict_param.threshold)
            predict_result = data_instances.mapValues(lambda x: x.label)
            predict_result = pred_table.join(predict_result, lambda x, y: [y, x[1], x[0],
                                                                           {"1": x[0], "0": 1 - x[0]}])
        return predict_result
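
A minimal standalone sketch of what the plaintext branch computes, in plain NumPy with hypothetical names (`compute_wx` plus `classify` in the snippet above apply the same math over distributed tables):

    import numpy as np

    def score_and_classify(features, coef, intercept, threshold=0.5):
        # sigmoid(w . x + b) gives the positive-class probability
        prob = 1.0 / (1.0 + np.exp(-(np.dot(features, coef) + intercept)))
        return int(prob > threshold), prob

    # example: score_and_classify(np.array([0.3, -1.2]), np.array([0.8, 0.1]), 0.05)
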
Example #2
    def fit(self, data_instances=None, validate_data=None):
        self._server_check_data()

        host_ciphers = self.cipher.paillier_keygen(
            key_length=self.model_param.encrypt_param.key_length,
            suffix=('fit', ))
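        # Hosts that opted out of encryption have no cipher; track them so
        # their losses are collected in plaintext.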
        host_has_no_cipher_ids = [
            idx for idx, cipher in host_ciphers.items() if cipher is None
        ]
        self.re_encrypt_times = self.cipher.set_re_cipher_time(host_ciphers)
        max_iter = self.max_iter
        # validation_strategy = self.init_validation_strategy()

        while self.n_iter_ < max_iter + 1:
            suffix = (self.n_iter_, )

            if ((self.n_iter_ + 1) % self.aggregate_iters == 0
                    or self.n_iter_ == max_iter):
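                # Aggregation round: merge the client models, collect losses,
                # and broadcast convergence status to all parties.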
                merged_model = self.aggregator.aggregate_and_broadcast(
                    ciphers_dict=host_ciphers, suffix=suffix)
                total_loss = self.aggregator.aggregate_loss(
                    host_has_no_cipher_ids, suffix)
                self.callback_loss(self.n_iter_, total_loss)
                self.loss_history.append(total_loss)
                if self.use_loss:
                    converge_var = total_loss
                else:
                    converge_var = np.array(merged_model.unboxed)

                self.is_converged = self.aggregator.send_converge_status(
                    self.converge_func.is_converge, (converge_var, ),
                    suffix=(self.n_iter_, ))
                LOGGER.info(
                    "n_iters: {}, total_loss: {}, converge flag is: {}".format(
                        self.n_iter_, total_loss, self.is_converged))

                self.model_weights = LogisticRegressionWeights(
                    merged_model.unboxed,
                    self.model_param.init_param.fit_intercept)
                if self.header is None:
                    self.header = [
                        'x' + str(i)
                        for i in range(len(self.model_weights.coef_))
                    ]

                if self.is_converged or self.n_iter_ == max_iter:
                    break

            self.cipher.re_cipher(iter_num=self.n_iter_,
                                  re_encrypt_times=self.re_encrypt_times,
                                  host_ciphers_dict=host_ciphers,
                                  re_encrypt_batches=self.re_encrypt_batches)

            # validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1

        LOGGER.info("Finish Training task, total iters: {}".format(
            self.n_iter_))
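
The aggregation schedule above fires every `aggregate_iters` rounds and unconditionally on the last round. A tiny standalone check of that arithmetic (hypothetical constants):

    max_iter, aggregate_iters = 10, 3
    rounds = [i for i in range(max_iter + 1)
              if (i + 1) % aggregate_iters == 0 or i == max_iter]
    print(rounds)  # [2, 5, 8, 10]: the n_iter_ values that trigger aggregation
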
Example #3
    def load_single_model(self, single_model_obj):
        LOGGER.info("It's a binary task, start to load single model")
        feature_shape = len(self.header)
        tmp_vars = np.zeros(feature_shape)
        weight_dict = dict(single_model_obj.weight)

        for idx, header_name in enumerate(self.header):
            # default to 0.0 so a header absent from the saved model cannot
            # write None into the float ndarray
            tmp_vars[idx] = weight_dict.get(header_name, 0.0)

        if self.fit_intercept:
            tmp_vars = np.append(tmp_vars, single_model_obj.intercept)
        self.model_weights = LogisticRegressionWeights(
            tmp_vars, fit_intercept=self.fit_intercept)
        return self
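
The same restore logic in isolation, with hypothetical data: order the coefficients by header and append the intercept as the trailing entry, mirroring load_single_model above:

    import numpy as np

    header = ['x0', 'x1', 'x2']
    weight_dict = {'x0': 0.4, 'x1': -1.1, 'x2': 0.0}
    coef = np.array([weight_dict.get(name, 0.0) for name in header])
    weights = np.append(coef, 0.25)  # 0.25 standing in for the stored intercept
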
Example #4
    def predict(self, data_instances):

        LOGGER.info('Start predict task')
        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        data_instances = self.align_data_header(data_instances, self.header)
        suffix = ('predict', )
        if self.component_properties.has_arbiter:
            pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt,
                                                     suffix=suffix)
        else:
            if self.use_encrypt:
                raise ValueError(f"In use_encrypt case, arbiter should be set")
            pubkey = None
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)
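            # The host only ever sees ciphertext here: fetch the aggregated
            # encrypted model, send encrypted wx to the arbiter, and receive
            # the decrypted predictions back.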

            final_model = self.transfer_variable.aggregated_model.get(
                idx=0, suffix=suffix)
            model_weights = LogisticRegressionWeights(final_model.unboxed,
                                                      self.fit_intercept)
            wx = self.compute_wx(data_instances, model_weights.coef_,
                                 model_weights.intercept_)
            self.transfer_variable.predict_wx.remote(wx,
                                                     consts.ARBITER,
                                                     0,
                                                     suffix=suffix)
            predict_result = self.transfer_variable.predict_result.get(
                idx=0, suffix=suffix)
            # predict_result = predict_result.join(data_instances, lambda p, d: [d.label, p, None,
            #                                                                    {"0": None, "1": None}])
            predict_result = predict_result.join(
                data_instances,
                lambda p, d: Instance(
                    features=[d.label, p, None, {"0": None, "1": None}],
                    inst_id=d.inst_id))
        else:
            pred_prob = data_instances.mapValues(
                lambda v: activation.sigmoid(
                    vec_dot(v.features, self.model_weights.coef_)
                    + self.model_weights.intercept_))
            predict_result = self.predict_score_to_output(
                data_instances,
                pred_prob,
                classes=[0, 1],
                threshold=self.model_param.predict_param.threshold)

        return predict_result
Example #5
    def fit(self, data_instances, validate_data=None):
        LOGGER.debug("Start data count: {}".format(data_instances.count()))

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        # validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt,
                                                 suffix=('fit', ))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        self.model_weights = self._init_model_variables(data_instances)
        w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
        self.model_weights = LogisticRegressionWeights(
            w, self.model_weights.fit_intercept)

        LOGGER.debug("After init, model_weights: {}".format(
            self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
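            # ceil(total_batch_num / re_encrypt_batches): one ciphertext
            # refresh fires on every re_encrypt_batches-th batch, starting at 0.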
            re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1
            LOGGER.debug(
                "re_encrypt_times: {}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}"
                .format(re_encrypt_times, self.batch_size, total_batch_num,
                        self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        total_data_num = data_instances.count()
        LOGGER.debug("Current data count: {}".format(total_data_num))

        model_weights = self.model_weights
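        # degree accumulates the sample count since the last aggregation; the
        # arbiter uses it to weight this party's contribution.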
        degree = 0
        while self.n_iter_ < self.max_iter + 1:
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            if ((self.n_iter_ + 1) % self.aggregate_iters == 0
                    or self.n_iter_ == self.max_iter):
                weight = self.aggregator.aggregate_then_get(
                    model_weights, degree=degree, suffix=self.n_iter_)
                # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format(
                #     model_weights.unboxed / degree,
                #     degree,
                #     weight.unboxed))
                self.model_weights = LogisticRegressionWeights(
                    weight.unboxed, self.fit_intercept)
                if not self.use_encrypt:
                    loss = self._compute_loss(data_instances)
                    self.aggregator.send_loss(loss,
                                              degree=degree,
                                              suffix=(self.n_iter_, ))
                    LOGGER.info("n_iters: {}, loss: {}".format(
                        self.n_iter_, loss))
                degree = 0
                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info("n_iters: {}, is_converge: {}".format(
                    self.n_iter_, self.is_converged))
                if self.is_converged or self.n_iter_ == self.max_iter:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                degree += n
                LOGGER.debug('before compute_gradient')
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.mapPartitions(f).reduce(
                    fate_operator.reduce_add)
                grad /= n
                model_weights = self.optimizer.update_model(model_weights,
                                                            grad,
                                                            has_applied=False)

                if self.use_encrypt and batch_num % self.re_encrypt_batches == 0:
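                    # Have the arbiter refresh (re-encrypt) the weight
                    # ciphertexts before further encrypted updates accumulate.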
                    LOGGER.debug(
                        "Before accept re_encrypted_model, batch_iter_num: {}".
                        format(batch_num))
                    w = self.cipher.re_cipher(w=model_weights.unboxed,
                                              iter_num=self.n_iter_,
                                              batch_iter_num=batch_num)
                    model_weights = LogisticRegressionWeights(
                        w, self.fit_intercept)
                batch_num += 1

            # validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1

        LOGGER.info("Finish Training task, total iters: {}".format(
            self.n_iter_))
Example #6
    def fit(self, data_instances, validate_data=None):
        LOGGER.debug("Start data count: {}".format(data_instances.count()))

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)
        validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        self.model_weights = self._init_model_variables(data_instances)
        w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
        self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)

        LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
            # ceiling division, matching the other variants; total // k + 1
            # would overcount when k divides total and leave the arbiter
            # expecting a re-encrypt round that never comes
            re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1
            LOGGER.debug("re_encrypt_times: {}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
                re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        total_data_num = data_instances.count()
        LOGGER.debug("Current data count: {}".format(total_data_num))

        model_weights = self.model_weights
        degree = 0

        self.__synchronize_encryption()
        self.zcl_idx, self.zcl_num_party = self.transfer_variable.num_party.get(idx=0, suffix=('train',))
        LOGGER.debug("party num:" + str(self.zcl_num_party))
        self.__init_model()

        self.train_loss_results = []
        self.train_accuracy_results = []
        self.test_loss_results = []
        self.test_accuracy_results = []

        for iter_num in range(self.max_iter):
            # mini-batch
            LOGGER.debug("In iter: {}".format(iter_num))
            # batch_data_generator = self.mini_batch_obj.mini_batch_data_generator()
            batch_num = 0
            total_loss = 0
            epoch_train_loss_avg = tfe.metrics.Mean()
            epoch_train_accuracy = tfe.metrics.Accuracy()

            for train_x, train_y in self.zcl_dataset:
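                # Per batch: compute plaintext grads/loss, encrypt them, send
                # to the arbiter, then decrypt the aggregated values it returns.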
                LOGGER.info("Staring batch {}".format(batch_num))
                start_t = time.time()
                loss_value, grads = self.__grad(self.zcl_model, train_x, train_y)
                loss_value = loss_value.numpy()
                grads = [x.numpy() for x in grads]
                LOGGER.info("Start encrypting")
                loss_value = batch_encryption.encrypt(self.zcl_encrypt_operator.get_public_key(), loss_value)
                grads = [batch_encryption.encrypt_matrix(self.zcl_encrypt_operator.get_public_key(), x) for x in grads]
                LOGGER.info("Finish encrypting")
                grads = Gradients(grads)
                self.transfer_variable.host_grad.remote(obj=grads.for_remote(), role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Sent grads")
                self.transfer_variable.host_loss.remote(obj=loss_value, role=consts.ARBITER, idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Sent loss")

                sum_grads = self.transfer_variable.aggregated_grad.get(idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got grads")
                sum_loss = self.transfer_variable.aggregated_loss.get(idx=0, suffix=(iter_num, batch_num))
                LOGGER.info("Got loss")

                sum_loss = batch_encryption.decrypt(self.zcl_encrypt_operator.get_privacy_key(), sum_loss)
                sum_grads = [
                    batch_encryption.decrypt_matrix(self.zcl_encrypt_operator.get_privacy_key(), x).astype(np.float32)
                    for x
                    in sum_grads.unboxed]
                LOGGER.info("Finish decrypting")

                # sum_grads = np.array(sum_grads) / self.zcl_num_party

                self.zcl_optimizer.apply_gradients(zip(sum_grads, self.zcl_model.trainable_variables),
                                                   self.zcl_global_step)

                elapsed_time = time.time() - start_t
                # epoch_train_loss_avg(loss_value)
                # epoch_train_accuracy(tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32),
                #                      train_y)
                self.train_loss_results.append(sum_loss)
                train_accuracy_v = accuracy_score(train_y,
                                                  tf.argmax(self.zcl_model(train_x), axis=1, output_type=tf.int32))
                self.train_accuracy_results.append(train_accuracy_v)
                test_loss_v = self.__loss(self.zcl_model, self.zcl_x_test, self.zcl_y_test)
                self.test_loss_results.append(test_loss_v)
                test_accuracy_v = accuracy_score(self.zcl_y_test,
                                                 tf.argmax(self.zcl_model(self.zcl_x_test), axis=1,
                                                           output_type=tf.int32))
                self.test_accuracy_results.append(test_accuracy_v)

                LOGGER.info(
                    "Epoch {:03d}, iteration {:03d}: train_loss: {:.3f}, train_accuracy: {:.3%}, test_loss: {:.3f}, "
                    "test_accuracy: {:.3%}, elapsed_time: {:.4f}".format(
                        iter_num,
                        batch_num,
                        sum_loss,
                        train_accuracy_v,
                        test_loss_v,
                        test_accuracy_v,
                        elapsed_time)
                )

                batch_num += 1

                if batch_num >= self.zcl_early_stop_batch:
                    return

            self.n_iter_ = iter_num
Example #7
    def fit_binary(self, data_instances, validate_data=None):
        self.aggregator = aggregator.Host()
        self.aggregator.register_aggregator(self.transfer_variable)

        self._client_check_data(data_instances)
        self.callback_list.on_train_begin(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        if not self.component_properties.is_warm_start:
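            # Fresh start: initialise weights here; warm start (else branch)
            # keeps the weights loaded from a previous run.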
            self.model_weights = self._init_model_variables(data_instances)
            if self.use_encrypt:
                w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
            else:
                w = list(self.model_weights.unboxed)
            self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)
        else:
            self.callback_warm_start_init_iter(self.n_iter_)

        # LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
            re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1
            # LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
            #     re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        # total_data_num = data_instances.count()
        # LOGGER.debug("Current data count: {}".format(total_data_num))

        model_weights = self.model_weights
        self.prev_round_weights = copy.deepcopy(model_weights)
        degree = 0
        while self.n_iter_ < self.max_iter + 1:
            self.callback_list.on_epoch_begin(self.n_iter_)

            batch_data_generator = mini_batch_obj.mini_batch_data_generator()
            self.optimizer.set_iters(self.n_iter_)

            if ((self.n_iter_ + 1) % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter:
                weight = self.aggregator.aggregate_then_get(model_weights, degree=degree,
                                                            suffix=self.n_iter_)
                # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format(
                #     model_weights.unboxed / degree,
                #     degree,
                #     weight.unboxed))
                self.model_weights = LogisticRegressionWeights(weight.unboxed, self.fit_intercept)
                if not self.use_encrypt:
                    loss = self._compute_loss(data_instances, self.prev_round_weights)
                    self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_,))
                    LOGGER.info("n_iters: {}, loss: {}".format(self.n_iter_, loss))
                degree = 0
                self.is_converged = self.aggregator.get_converge_status(suffix=(self.n_iter_,))
                LOGGER.info("n_iters: {}, is_converge: {}".format(self.n_iter_, self.is_converged))
                if self.is_converged or self.n_iter_ == self.max_iter:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                degree += n
                LOGGER.debug('before compute_gradient')
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.applyPartitions(f).reduce(fate_operator.reduce_add)
                grad /= n
                if self.use_proximal:  # use additional proximal term
                    model_weights = self.optimizer.update_model(model_weights,
                                                                grad=grad,
                                                                has_applied=False,
                                                                prev_round_weights=self.prev_round_weights)
                else:
                    model_weights = self.optimizer.update_model(model_weights,
                                                                grad=grad,
                                                                has_applied=False)

                if self.use_encrypt and batch_num % self.re_encrypt_batches == 0:
                    LOGGER.debug("Before accept re_encrypted_model, batch_iter_num: {}".format(batch_num))
                    w = self.cipher.re_cipher(w=model_weights.unboxed,
                                              iter_num=self.n_iter_,
                                              batch_iter_num=batch_num)
                    model_weights = LogisticRegressionWeights(w, self.fit_intercept)
                batch_num += 1

            # validation_strategy.validate(self, self.n_iter_)
            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1
            if self.stop_training:
                break

        self.set_summary(self.get_model_summary())
        LOGGER.info("Finish Training task, total iters: {}".format(self.n_iter_))
Example #8
    def fit(self, data_instances, validate_data=None):

        self._abnormal_detection(data_instances)
        self.check_abnormal_values(data_instances)
        self.init_schema(data_instances)

        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        self.model_weights = self._init_model_variables(data_instances)

        max_iter = self.max_iter
        # total_data_num = data_instances.count()
        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        model_weights = self.model_weights

        degree = 0
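        # Snapshot the last aggregated weights for the optional proximal term
        # applied below when use_proximal is set.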
        self.prev_round_weights = copy.deepcopy(model_weights)

        while self.n_iter_ < max_iter + 1:
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            self.optimizer.set_iters(self.n_iter_)
            if ((self.n_iter_ + 1) % self.aggregate_iters == 0
                    or self.n_iter_ == max_iter):
                weight = self.aggregator.aggregate_then_get(
                    model_weights, degree=degree, suffix=self.n_iter_)

                self.model_weights = LogisticRegressionWeights(
                    weight.unboxed, self.fit_intercept)

                # store prev_round_weights after aggregation
                self.prev_round_weights = copy.deepcopy(self.model_weights)
                # send loss to arbiter
                loss = self._compute_loss(data_instances,
                                          self.prev_round_weights)
                self.aggregator.send_loss(loss,
                                          degree=degree,
                                          suffix=(self.n_iter_, ))
                degree = 0

                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info(
                    "n_iters: {}, loss: {}, converge flag is: {}".format(
                        self.n_iter_, loss, self.is_converged))
                if self.is_converged or self.n_iter_ == max_iter:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                # LOGGER.debug("In each batch, lr_weight: {}, batch_data count: {}".format(model_weights.unboxed, n))
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.applyPartitions(f).reduce(
                    fate_operator.reduce_add)
                grad /= n
                # LOGGER.debug('iter: {}, batch_index: {}, grad: {}, n: {}'.format(
                #     self.n_iter_, batch_num, grad, n))

                if self.use_proximal:  # use proximal term
                    model_weights = self.optimizer.update_model(
                        model_weights,
                        grad=grad,
                        has_applied=False,
                        prev_round_weights=self.prev_round_weights)
                else:
                    model_weights = self.optimizer.update_model(
                        model_weights, grad=grad, has_applied=False)

                batch_num += 1
                degree += n

            validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1
        self.set_summary(self.get_model_summary())
Example #9
    def fit(self, data_instances, validate_data=None):

        self._abnormal_detection(data_instances)
        self.init_schema(data_instances)

        validation_strategy = self.init_validation_strategy(
            data_instances, validate_data)
        self.model_weights = self._init_model_variables(data_instances)

        max_iter = self.max_iter
        # total_data_num = data_instances.count()
        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        model_weights = self.model_weights

        degree = 0
        while self.n_iter_ < max_iter:
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            self.optimizer.set_iters(self.n_iter_)
            if self.n_iter_ > 0 and self.n_iter_ % self.aggregate_iters == 0:
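                # Aggregation round (skipped at iteration 0): send the weighted
                # local model, fetch the merged weights, and check convergence.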
                weight = self.aggregator.aggregate_then_get(
                    model_weights, degree=degree, suffix=self.n_iter_)
                LOGGER.debug(
                    "Before aggregate: {}, degree: {} after aggregated: {}".
                    format(model_weights.unboxed / degree, degree,
                           weight.unboxed))

                self.model_weights = LogisticRegressionWeights(
                    weight.unboxed, self.fit_intercept)
                loss = self._compute_loss(data_instances)
                self.aggregator.send_loss(loss,
                                          degree=degree,
                                          suffix=(self.n_iter_, ))
                degree = 0

                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info(
                    "n_iters: {}, loss: {}, converge flag is: {}".format(
                        self.n_iter_, loss, self.is_converged))
                if self.is_converged:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                LOGGER.debug(
                    "In each batch, lr_weight: {}, batch_data count: {}".
                    format(model_weights.unboxed, n))
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.mapPartitions(f).reduce(
                    fate_operator.reduce_add)
                grad /= n
                LOGGER.debug(
                    'iter: {}, batch_index: {}, grad: {}, n: {}'.format(
                        self.n_iter_, batch_num, grad, n))
                model_weights = self.optimizer.update_model(model_weights,
                                                            grad,
                                                            has_applied=False)
                batch_num += 1
                degree += n

            validation_strategy.validate(self, self.n_iter_)
            self.n_iter_ += 1