Esempio n. 1
0
    def __init__(self):
        """Create an arbiter with empty training bookkeeping.

        After the parent initializer runs, this sets the role to
        ``consts.ARBITER``, builds a new ``aggregator.Arbiter`` and
        initializes the loss history, convergence flag, model weights and
        the list of collected host prediction results.

        ``re_encrypt_times`` is not initialized here; ``fit`` assigns it.
        """
        super(HomoFMArbiter, self).__init__()

        # Identity of this party and the object used to aggregate host models.
        self.role = consts.ARBITER
        self.aggregator = aggregator.Arbiter()

        # Per-training state, filled in while fitting / predicting.
        self.loss_history, self.is_converged = [], False
        self.model_weights = None
        self.host_predict_results = []
Esempio n. 2
0
    def fit(self, data_instances=None, validate_data=None):
        """Run the arbiter side of the iterative training loop.

        Setup: a new ``aggregator.Arbiter`` is registered with
        ``self.transfer_variable``, ``_server_check_data()`` is called, and a
        dict of per-host ciphers is obtained from
        ``self.cipher.paillier_keygen``.  Hosts whose cipher is ``None`` are
        collected and handed to ``aggregate_loss``.

        Loop: iterations run from the current ``self.n_iter_`` up to
        ``max_iter`` inclusive.  An iteration is an aggregation round when
        ``(n_iter_ + 1)`` is a multiple of ``self.aggregate_iters`` or it is
        the ``max_iter`` iteration.  In such a round the model is aggregated
        and broadcast, the loss is aggregated, recorded and reported, the
        convergence status is computed and sent, and ``self.model_weights``
        (plus a default ``self.header``) is set from the merged model.  The
        loop breaks after an aggregation round that converged or that was
        the ``max_iter`` iteration; otherwise ``self.cipher.re_cipher`` is
        called and ``self.n_iter_`` is incremented.  On a break,
        ``self.n_iter_`` is left at the index of that last iteration.

        Args:
            data_instances: not read by this method.
            validate_data: not read by this method.
        """
        self.aggregator = aggregator.Arbiter()
        self.aggregator.register_aggregator(self.transfer_variable)
        self._server_check_data()

        key_length = self.model_param.encrypt_param.key_length
        ciphers_by_host = self.cipher.paillier_keygen(key_length=key_length,
                                                      suffix=('fit', ))
        cipherless_host_ids = [
            host_id for host_id, host_cipher in ciphers_by_host.items()
            if host_cipher is None
        ]
        self.re_encrypt_times = self.cipher.set_re_cipher_time(
            ciphers_by_host)
        max_iter = self.max_iter

        while self.n_iter_ < max_iter + 1:
            iter_suffix = (self.n_iter_, )
            is_aggregation_round = (
                (self.n_iter_ + 1) % self.aggregate_iters == 0
                or self.n_iter_ == max_iter)

            if is_aggregation_round:
                global_model = self.aggregator.aggregate_and_broadcast(
                    ciphers_dict=ciphers_by_host, suffix=iter_suffix)
                round_loss = self.aggregator.aggregate_loss(
                    cipherless_host_ids, iter_suffix)
                self.callback_loss(self.n_iter_, round_loss)
                self.loss_history.append(round_loss)

                # converge_func.is_converge receives either the aggregated
                # loss or the merged model's ``unboxed`` value as an array.
                convergence_input = (round_loss if self.use_loss else
                                     np.array(global_model.unboxed))
                self.is_converged = self.aggregator.send_converge_status(
                    self.converge_func.is_converge, (convergence_input, ),
                    suffix=iter_suffix)
                LOGGER.info(
                    "n_iters: {}, total_loss: {}, converge flag is :{}".format(
                        self.n_iter_, round_loss, self.is_converged))

                self.model_weights = LogisticRegressionWeights(
                    global_model.unboxed,
                    self.model_param.init_param.fit_intercept)
                if self.header is None:
                    n_coef = len(self.model_weights.coef_)
                    self.header = ['x{}'.format(idx) for idx in range(n_coef)]

                if self.is_converged or self.n_iter_ == max_iter:
                    break

            self.cipher.re_cipher(iter_num=self.n_iter_,
                                  re_encrypt_times=self.re_encrypt_times,
                                  host_ciphers_dict=ciphers_by_host,
                                  re_encrypt_batches=self.re_encrypt_batches)
            self.n_iter_ += 1

        LOGGER.info("Finish Training task, total iters: {}".format(
            self.n_iter_))