예제 #1
0
    def __init__(self):
        """Initialise guest-side state for homogeneous logistic regression.

        Sets the party role to ``consts.GUEST``, starts with an empty loss
        history, and creates the gradient operator, the guest aggregator and
        a Paillier encryption helper.
        """
        super().__init__()

        # Plain bookkeeping attributes.
        self.role = consts.GUEST
        self.loss_history = []

        # Collaborator objects (constructed in the original order).
        self.gradient_operator = LogisticGradient()
        self.aggregator = aggregator.Guest()
        self.zcl_encrypt_operator = PaillierEncrypt()
예제 #2
0
    def fit_binary(self, data_instances, validate_data=None):
        """Run the guest-side training loop with periodic weight aggregation.

        Each pass of the outer loop performs mini-batch gradient updates on a
        local copy of the weights. Every ``self.aggregate_iters`` iterations
        (and once more when ``self.n_iter_`` reaches ``self.max_iter``) the
        local weights are exchanged through ``self.aggregator``, the loss of
        the aggregated model is sent, and the convergence flag is fetched.

        Parameters
        ----------
        data_instances :
            Training data. It is passed to ``MiniBatch``, to
            ``self._compute_loss`` and, unless this is a warm start, to
            ``self._init_model_variables``.
        validate_data : optional
            Only forwarded to ``self.callback_list.on_train_begin``.

        Returns
        -------
        None
            Results are stored on the instance (``self.model_weights``,
            ``self.is_converged``, ``self.n_iter_``) and reported through
            ``self.set_summary``.
        """
        # A fresh aggregator is created for this run and bound to the
        # transfer variables before any exchange takes place.
        self.aggregator = aggregator.Guest()
        self.aggregator.register_aggregator(self.transfer_variable)

        self.callback_list.on_train_begin(data_instances, validate_data)

        # validation_strategy = self.init_validation_strategy(data_instances, validate_data)
        # Cold start: build initial weights from the data. Warm start: keep
        # the existing self.model_weights and let the callback hook resume
        # from the current iteration counter.
        if not self.component_properties.is_warm_start:
            self.model_weights = self._init_model_variables(data_instances)
        else:
            self.callback_warm_start_init_iter(self.n_iter_)

        max_iter = self.max_iter
        # total_data_num = data_instances.count()
        mini_batch_obj = MiniBatch(data_inst=data_instances,
                                   batch_size=self.batch_size)
        # Local working copy of the weights; self.model_weights itself is
        # only reassigned at aggregation points below.
        model_weights = self.model_weights

        # Number of samples processed since the last aggregation; it is sent
        # along with the weights and the loss as `degree`.
        # NOTE(review): presumably used as the aggregation weight - confirm
        # against the aggregator implementation.
        degree = 0
        # Snapshot used by the proximal-term update until the first
        # aggregation replaces it.
        self.prev_round_weights = copy.deepcopy(model_weights)

        # The bound is max_iter + 1 so that one extra pass can run the final
        # aggregation when n_iter_ == max_iter; that pass breaks before any
        # local training is done.
        while self.n_iter_ < max_iter + 1:
            self.callback_list.on_epoch_begin(self.n_iter_)
            batch_data_generator = mini_batch_obj.mini_batch_data_generator()

            self.optimizer.set_iters(self.n_iter_)
            # Aggregate every `aggregate_iters` iterations and on the final
            # pass.
            if ((self.n_iter_ + 1) % self.aggregate_iters
                    == 0) or self.n_iter_ == max_iter:
                weight = self.aggregator.aggregate_then_get(
                    model_weights, degree=degree, suffix=self.n_iter_)

                # Wrap the aggregated values back into a weights object.
                self.model_weights = LogisticRegressionWeights(
                    weight.unboxed, self.fit_intercept)

                # store prev_round_weights after aggregation
                self.prev_round_weights = copy.deepcopy(self.model_weights)
                # send loss to arbiter
                # The loss is evaluated on the full data_instances using the
                # freshly aggregated weights.
                loss = self._compute_loss(data_instances,
                                          self.prev_round_weights)
                self.aggregator.send_loss(loss,
                                          degree=degree,
                                          suffix=(self.n_iter_, ))
                # Start counting samples afresh for the next round.
                degree = 0

                self.is_converged = self.aggregator.get_converge_status(
                    suffix=(self.n_iter_, ))
                LOGGER.info(
                    "n_iters: {}, loss: {} converge flag is :{}".format(
                        self.n_iter_, loss, self.is_converged))
                # Stop right after aggregation when converged or when the
                # iteration budget is used up, skipping local training and
                # the epoch-end callback for this pass.
                if self.is_converged or self.n_iter_ == max_iter:
                    break
                # Continue local training from the aggregated weights.
                model_weights = self.model_weights

            # batch_num is only referenced by the commented-out debug log
            # below.
            batch_num = 0
            for batch_data in batch_data_generator:
                n = batch_data.count()
                # LOGGER.debug("In each batch, lr_weight: {}, batch_data count: {}".format(model_weights.unboxed, n))
                # Per-partition gradients are computed with the current local
                # weights and summed across partitions.
                f = functools.partial(self.gradient_operator.compute_gradient,
                                      coef=model_weights.coef_,
                                      intercept=model_weights.intercept_,
                                      fit_intercept=self.fit_intercept)
                grad = batch_data.applyPartitions(f).reduce(
                    fate_operator.reduce_add)
                # Turn the summed gradient into a per-sample average.
                grad /= n
                # LOGGER.debug('iter: {}, batch_index: {}, grad: {}, n: {}'.format(
                #     self.n_iter_, batch_num, grad, n))

                if self.use_proximal:  # use proximal term
                    # The proximal variant additionally receives the weights
                    # saved at the last aggregation.
                    model_weights = self.optimizer.update_model(
                        model_weights,
                        grad=grad,
                        has_applied=False,
                        prev_round_weights=self.prev_round_weights)
                else:
                    model_weights = self.optimizer.update_model(
                        model_weights, grad=grad, has_applied=False)

                batch_num += 1
                degree += n

            # validation_strategy.validate(self, self.n_iter_)
            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1

            # self.stop_training is checked only after the epoch-end
            # callbacks have run.
            # NOTE(review): presumably set by a callback (e.g. early
            # stopping) - confirm.
            if self.stop_training:
                break

        self.set_summary(self.get_model_summary())
예제 #3
0
 def __init__(self):
     """Initialise guest-side state for the homogeneous FM model.

     Sets the party role to ``consts.GUEST``, starts with an empty loss
     history, and creates the gradient operator and the guest aggregator.
     """
     super().__init__()

     # Plain bookkeeping attributes.
     self.role = consts.GUEST
     self.loss_history = []

     # Collaborator objects (constructed in the original order).
     self.gradient_operator = FactorizationGradient()
     self.aggregator = aggregator.Guest()