Example #1
    def __init__(self):
        super(HomoFMHost, self).__init__()
        self.gradient_operator = None
        self.loss_history = []
        self.is_converged = False
        self.role = consts.HOST
        self.aggregator = aggregator.Host()
        self.model_weights = None
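
HomoFMHost's constructor mirrors the HomoLRHost constructor in Example #2: both register the party as consts.HOST and attach a host-side secure-aggregation client via aggregator.Host(); the LR host additionally holds the Paillier handles it needs for encrypted training.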
Example #2
    def __init__(self):
        super(HomoLRHost, self).__init__()
        self.gradient_operator = None
        self.loss_history = []
        self.is_converged = False
        self.role = consts.HOST
        self.aggregator = aggregator.Host()
        self.model_weights = None
        self.cipher = paillier_cipher.Host()

        self.zcl_encrypt_operator = PaillierEncrypt()
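
Example #2 adds two Paillier handles on top of the shared pattern: paillier_cipher.Host() drives the arbiter-coordinated public-key exchange and re-encryption protocol, while PaillierEncrypt() performs the local encryption operations. The additive homomorphism that lets a host update weights it cannot read can be sketched with the standalone phe (python-paillier) package; this is an illustration of the primitive with toy values, not FATE's API:

from phe import paillier

# The key pair lives at the arbiter; a host only ever receives the public key.
public_key, private_key = paillier.generate_paillier_keypair(n_length=1024)

weights = [0.5, -1.2, 3.0]
enc_weights = [public_key.encrypt(w) for w in weights]  # ciphertext weights

# Additive homomorphism: a plaintext gradient step w - lr * g applies
# directly to the ciphertexts, so the host never sees the plaintext weights.
lr, grad = 0.1, [0.2, -0.4, 1.0]
enc_weights = [ew - lr * g for ew, g in zip(enc_weights, grad)]

# Only the private-key holder (the arbiter) can recover the result.
assert abs(private_key.decrypt(enc_weights[0]) - (0.5 - 0.1 * 0.2)) < 1e-9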
Example #3
    def fit_binary(self, data_instances, validate_data=None):
        self.aggregator = aggregator.Host()
        self.aggregator.register_aggregator(self.transfer_variable)

        self._client_check_data(data_instances)
        self.callback_list.on_train_begin(data_instances, validate_data)

        pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
        if self.use_encrypt:
            self.cipher_operator.set_public_key(pubkey)

        if not self.component_properties.is_warm_start:
            self.model_weights = self._init_model_variables(data_instances)
            if self.use_encrypt:
                w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
            else:
                w = list(self.model_weights.unboxed)
            self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)
        else:
            self.callback_warm_start_init_iter(self.n_iter_)

        # LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

        mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)

        total_batch_num = mini_batch_obj.batch_nums

        if self.use_encrypt:
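            # Ceiling division: schedule one re-encryption round for every
            # re_encrypt_batches mini-batches in the epoch.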
            re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1
            # LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
            #     re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
            self.cipher.set_re_cipher_time(re_encrypt_times)

        model_weights = self.model_weights
        self.prev_round_weights = copy.deepcopy(model_weights)
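        # degree counts the samples consumed since the last aggregation; the
        # arbiter uses it to weight this party's contribution to the average.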
        degree = 0
        while self.n_iter_ < self.max_iter + 1:
            self.callback_list.on_epoch_begin(self.n_iter_)

            batch_data_generator = mini_batch_obj.mini_batch_data_generator()
            self.optimizer.set_iters(self.n_iter_)

            if ((self.n_iter_ + 1) % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter:
                weight = self.aggregator.aggregate_then_get(model_weights, degree=degree,
                                                            suffix=self.n_iter_)
                # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format(
                #     model_weights.unboxed / degree,
                #     degree,
                #     weight.unboxed))
                self.model_weights = LogisticRegressionWeights(weight.unboxed, self.fit_intercept)
                if not self.use_encrypt:
                    loss = self._compute_loss(data_instances, self.prev_round_weights)
                    self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_,))
                    LOGGER.info("n_iters: {}, loss: {}".format(self.n_iter_, loss))
                degree = 0
                self.is_converged = self.aggregator.get_converge_status(suffix=(self.n_iter_,))
                LOGGER.info("n_iters: {}, is_converge: {}".format(self.n_iter_, self.is_converged))
                if self.is_converged or self.n_iter_ == self.max_iter:
                    break
                model_weights = self.model_weights

            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                degree += n
                LOGGER.debug('before compute_gradient')
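                # Evaluate the gradient partition-by-partition, sum the partial
                # results with reduce_add, then normalize by the batch size.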
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.applyPartitions(f).reduce(fate_operator.reduce_add)
                grad /= n
                if self.use_proximal:  # use additional proximal term
                    model_weights = self.optimizer.update_model(model_weights,
                                                                grad=grad,
                                                                has_applied=False,
                                                                prev_round_weights=self.prev_round_weights)
                else:
                    model_weights = self.optimizer.update_model(model_weights,
                                                                grad=grad,
                                                                has_applied=False)

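                # With encryption on, the weights are Paillier ciphertexts; every
                # re_encrypt_batches batches they go back to the arbiter to be
                # decrypted and freshly re-encrypted (re-ciphered).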
                if self.use_encrypt and batch_num % self.re_encrypt_batches == 0:
                    LOGGER.debug("Before accept re_encrypted_model, batch_iter_num: {}".format(batch_num))
                    w = self.cipher.re_cipher(w=model_weights.unboxed,
                                              iter_num=self.n_iter_,
                                              batch_iter_num=batch_num)
                    model_weights = LogisticRegressionWeights(w, self.fit_intercept)
                batch_num += 1

            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1
            if self.stop_training:
                break

        self.set_summary(self.get_model_summary())
        LOGGER.info("Finish Training task, total iters: {}".format(self.n_iter_))