def __init__(self):
    """Initialise arbiter-side state for homogeneous FM training.

    Besides the parent-class setup, this records the ARBITER role, creates an
    ``aggregator.Arbiter`` and resets the bookkeeping that ``fit`` fills in
    (loss history, convergence flag, merged model weights).
    """
    super(HomoFMArbiter, self).__init__()

    # Role / communication helpers.
    self.role = consts.ARBITER
    self.aggregator = aggregator.Arbiter()

    # Training bookkeeping; fit() appends one entry to loss_history per
    # aggregation round and overwrites is_converged / model_weights.
    # (re_encrypt_times is not initialised here; fit() assigns it.)
    self.loss_history = []
    self.is_converged = False
    self.model_weights = None

    # Prediction-side state.
    self.host_predict_results = []
def fit(self, data_instances=None, validate_data=None):
    """Drive the arbiter side of federated homogeneous training.

    Steps:
      1. Create a fresh ``aggregator.Arbiter``, register it on
         ``self.transfer_variable`` and run ``self._server_check_data()``.
      2. Generate Paillier ciphers per host via ``self.cipher.paillier_keygen``
         and derive ``self.re_encrypt_times`` from them.
      3. Iterate while ``self.n_iter_ < self.max_iter + 1``. On every
         ``aggregate_iters``-th iteration, and always when
         ``n_iter_ == max_iter``, aggregate and broadcast the model, aggregate
         the loss, then compute and broadcast the convergence flag. Stop once
         converged or after the ``max_iter`` round. Every iteration that does
         not stop triggers ``self.cipher.re_cipher`` and advances ``n_iter_``.

    ``n_iter_`` is not incremented on the stopping iteration, so the final
    log line reports the index of the last iteration run.

    Args:
        data_instances: not used by the arbiter; kept for interface parity.
        validate_data: not used (validation is currently disabled here).
    """
    # A new aggregator per fit() call, wired to this model's transfer variables.
    self.aggregator = aggregator.Arbiter()
    self.aggregator.register_aggregator(self.transfer_variable)
    self._server_check_data()

    ciphers_by_host = self.cipher.paillier_keygen(
        key_length=self.model_param.encrypt_param.key_length,
        suffix=('fit', ))
    # Hosts whose cipher entry is None; handed to aggregate_loss
    # (presumably so their losses are treated as plaintext — confirm).
    plain_host_ids = [
        host_id for host_id, host_cipher in ciphers_by_host.items()
        if host_cipher is None
    ]
    self.re_encrypt_times = self.cipher.set_re_cipher_time(ciphers_by_host)

    last_iter = self.max_iter

    def _aggregate_round(round_suffix):
        """Run one aggregation round; return a truthy value when training must stop."""
        global_model = self.aggregator.aggregate_and_broadcast(
            ciphers_dict=ciphers_by_host, suffix=round_suffix)
        round_loss = self.aggregator.aggregate_loss(plain_host_ids, round_suffix)
        self.callback_loss(self.n_iter_, round_loss)
        self.loss_history.append(round_loss)

        # Converge on the loss itself, or on the merged model's raw weights.
        convergence_input = (round_loss if self.use_loss
                             else np.array(global_model.unboxed))
        self.is_converged = self.aggregator.send_converge_status(
            self.converge_func.is_converge, (convergence_input, ),
            suffix=(self.n_iter_, ))
        LOGGER.info(
            "n_iters: {}, total_loss: {}, converge flag is :{}".format(
                self.n_iter_, round_loss, self.is_converged))

        # NOTE(review): this is the FM arbiter, yet the merged model is wrapped
        # in LogisticRegressionWeights (looks copied from Homo LR) — confirm an
        # FM-specific weights class is not intended here.
        self.model_weights = LogisticRegressionWeights(
            global_model.unboxed, self.model_param.init_param.fit_intercept)
        if self.header is None:
            # Synthesise feature names x0, x1, ... when none were provided.
            feature_count = len(self.model_weights.coef_)
            self.header = ['x' + str(pos) for pos in range(feature_count)]

        return self.is_converged or self.n_iter_ == last_iter

    while self.n_iter_ < last_iter + 1:
        round_suffix = (self.n_iter_, )
        is_aggregation_round = ((self.n_iter_ + 1) % self.aggregate_iters == 0
                                or self.n_iter_ == last_iter)
        if is_aggregation_round and _aggregate_round(round_suffix):
            break

        self.cipher.re_cipher(iter_num=self.n_iter_,
                              re_encrypt_times=self.re_encrypt_times,
                              host_ciphers_dict=ciphers_by_host,
                              re_encrypt_batches=self.re_encrypt_batches)
        # Per-iteration validation is disabled on the arbiter for now.
        self.n_iter_ += 1

    LOGGER.info("Finish Training task, total iters: {}".format(
        self.n_iter_))