Example #1
0
    def __init__(self, logistic_params):
        super(HeteroLRGuest, self).__init__(logistic_params)
        self.transfer_variable = HeteroLRTransferVariable()
        self.data_batch_count = []

        self.wx = None
        self.guest_forward = None
Example #2
0
 def __init__(self, logistic_params):
     # LogisticParamChecker.check_param(logistic_params)
     super(HeteroLRHost, self).__init__(logistic_params)
     self.transfer_variable = HeteroLRTransferVariable()
     self.batch_num = None
     self.batch_index_list = []
Example #3
0
class HeteroLRHost(BaseLogisticRegression):
    def __init__(self, logistic_params):
        # LogisticParamChecker.check_param(logistic_params)
        super(HeteroLRHost, self).__init__(logistic_params)
        self.transfer_variable = HeteroLRTransferVariable()
        self.batch_num = None
        self.batch_index_list = []

    def compute_forward(self, data_instances, coef_, intercept_):
        wx = self.compute_wx(data_instances, coef_, intercept_)
        encrypt_operator = self.encrypt_operator
        host_forward = wx.mapValues(lambda v: (encrypt_operator.encrypt(
            v), encrypt_operator.encrypt(np.square(v))))
        return host_forward

    def fit(self, data_instances):
        LOGGER.info("Enter hetero_lr host")
        self._abnormal_detection(data_instances)

        self.header = data_instances.schema.get("header")
        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)

        LOGGER.info("Get public_key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        batch_info = federation.get(
            name=self.transfer_variable.batch_info.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.batch_info),
            idx=0)
        LOGGER.info("Get batch_info from guest:" + str(batch_info))
        self.batch_size = batch_info["batch_size"]
        self.batch_num = batch_info["batch_num"]

        LOGGER.info("Start initialize model.")
        model_shape = data_overview.get_features_shape(data_instances)

        if self.init_param_obj.fit_intercept:
            self.init_param_obj.fit_intercept = False

        if self.fit_intercept:
            self.fit_intercept = False

        self.coef_ = self.initializer.init_model(
            model_shape, init_params=self.init_param_obj)

        self.n_iter_ = 0
        index_data_inst_map = {}

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:" + str(self.n_iter_))
            batch_index = 0
            while batch_index < self.batch_num:
                LOGGER.info("batch:{}".format(batch_index))
                # set batch_data
                if len(self.batch_index_list) < self.batch_num:
                    batch_data_index = federation.get(
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        idx=0)
                    LOGGER.info("Get batch_index from Guest")

                    batch_size = batch_data_index.count()
                    if batch_size < consts.MIN_BATCH_SIZE and batch_size != -1:
                        raise ValueError(
                            "Batch size get from guest should not less than 10, except -1, batch_size is {}"
                            .format(batch_size))

                    self.batch_index_list.append(batch_data_index)
                else:
                    batch_data_index = self.batch_index_list[batch_index]

                # Get mini-batch train data
                if len(index_data_inst_map) < self.batch_num:
                    batch_data_inst = batch_data_index.join(
                        data_instances, lambda g, d: d)
                    index_data_inst_map[batch_index] = batch_data_inst
                else:
                    batch_data_inst = index_data_inst_map[batch_index]

                LOGGER.info("batch_data_inst size:{}".format(
                    batch_data_inst.count()))
                # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
                batch_feat_inst = self.transform(batch_data_inst)

                # compute forward
                host_forward = self.compute_forward(batch_feat_inst,
                                                    self.coef_,
                                                    self.intercept_)
                federation.remote(
                    host_forward,
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    role=consts.GUEST,
                    idx=0)
                LOGGER.info("Remote host_forward to guest")

                # compute host gradient
                fore_gradient = federation.get(
                    name=self.transfer_variable.fore_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.fore_gradient, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info("Get fore_gradient from guest")
                if self.gradient_operator is None:
                    self.gradient_operator = HeteroLogisticGradient(
                        self.encrypt_operator)
                host_gradient = self.gradient_operator.compute_gradient(
                    batch_feat_inst, fore_gradient, fit_intercept=False)
                # regulation if necessary
                if self.updater is not None:
                    loss_regular = self.updater.loss_norm(self.coef_)
                    en_loss_regular = self.encrypt_operator.encrypt(
                        loss_regular)
                    federation.remote(
                        en_loss_regular,
                        name=self.transfer_variable.host_loss_regular.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_loss_regular,
                            self.n_iter_, batch_index),
                        role=consts.GUEST,
                        idx=0)
                    LOGGER.info("Remote host_loss_regular to guest")

                federation.remote(
                    host_gradient,
                    name=self.transfer_variable.host_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote host_gradient to arbiter")

                # Get optimize host gradient and update model
                optim_host_gradient = federation.get(
                    name=self.transfer_variable.host_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_host_gradient from arbiter")

                LOGGER.info("update_model")
                self.update_model(optim_host_gradient)

                # update local model that transforms features of raw input 'batch_data_inst'
                training_info = {
                    "iteration": self.n_iter_,
                    "batch_index": batch_index
                }
                self.update_local_model(fore_gradient, batch_data_inst,
                                        self.coef_, **training_info)

                # is converge

                batch_index += 1
                # if is_stopped:
                #    break

            is_stopped = federation.get(
                name=self.transfer_variable.is_stopped.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.is_stopped, self.n_iter_,
                    batch_index),
                idx=0)
            LOGGER.info("Get is_stop flag from arbiter:{}".format(is_stopped))

            self.n_iter_ += 1
            if is_stopped:
                LOGGER.info(
                    "Get stop signal from arbiter, model is converged, iter:{}"
                    .format(self.n_iter_))
                break

        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))

    def predict(self, data_instances, predict_param=None):
        LOGGER.info("Start predict ...")

        data_features = self.transform(data_instances)

        prob_host = self.compute_wx(data_features, self.coef_, self.intercept_)
        federation.remote(prob_host,
                          name=self.transfer_variable.host_prob.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.host_prob),
                          role=consts.GUEST,
                          idx=0)
        LOGGER.info("Remote probability to Host")
Example #4
0
class HeteroLRGuest(BaseLogisticRegression):
    def __init__(self, logistic_params):
        super(HeteroLRGuest, self).__init__(logistic_params)
        self.transfer_variable = HeteroLRTransferVariable()
        self.data_batch_count = []

        self.encrypted_calculator = None

        self.guest_forward = None

    def compute_forward(self,
                        data_instances,
                        coef_,
                        intercept_,
                        batch_index=-1):
        """
        Compute W * X + b and (W * X + b)^2, where X is the input data, W is the coefficient of lr,
        and b is the interception
        Parameters
        ----------
        data_instance: DTable of Instance, input data
        coef_: list, coefficient of lr
        intercept_: float, the interception of lr
        """
        wx = self.compute_wx(data_instances, coef_, intercept_)

        en_wx = self.encrypted_calculator[batch_index].encrypt(wx)
        wx_square = wx.mapValues(lambda v: np.square(v))
        en_wx_square = self.encrypted_calculator[batch_index].encrypt(
            wx_square)

        en_wx_join_en_wx_square = en_wx.join(
            en_wx_square, lambda wx, wx_square: (wx, wx_square))
        self.guest_forward = en_wx_join_en_wx_square.join(
            wx, lambda e, wx: (e[0], e[1], wx))

        # temporary resource recovery and will be removed in the future
        rubbish_list = [
            wx, en_wx, wx_square, en_wx_square, en_wx_join_en_wx_square
        ]
        rubbish_clear(rubbish_list)

    def aggregate_forward(self, host_forward):
        """
        Compute (en_wx_g + en_wx_h)^2 = en_wx_g^2 + en_wx_h^2 + 2 * wx_g * en_wx_h , where en_wx_g is the encrypted W * X + b of guest, wx_g is unencrypted W * X + b,
        and en_wx_h is the encrypted W * X + b of host.
        #因为是在guest段aggregate的,所以这里的 wx_g不用加密
        Parameters
        ----------
        host_forward: DTable, include encrypted W * X and (W * X)^2

        Returns
        ----------
        aggregate_forward_res
        list
            include W * X and (W * X)^2 federate with guest and host
        """
        aggregate_forward_res = self.guest_forward.join(
            host_forward, lambda g, h:
            (g[0] + h[0], g[1] + h[1] + 2 * g[2] * h[0]))
        return aggregate_forward_res

    @staticmethod
    def load_data(data_instance):
        """
        set the negative label to -1
        Parameters
        ----------
        data_instance: DTable of Instance, input data
        """
        # 这里要将样本label=0的设为-1,方便计算logistic loss
        if data_instance.label != 1:
            data_instance.label = -1
        return data_instance

    def fit(self, data_instances):
        """
        Train lr model of role guest
        Parameters
        ----------
        data_instances: DTable of Instance, input data
        """

        LOGGER.info("Enter hetero_lr_guest fit")
        self._abnormal_detection(data_instances)

        self.header = self.get_header(data_instances)
        data_instances = data_instances.mapValues(HeteroLRGuest.load_data)

        # 获得密钥
        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)
        LOGGER.info("Get public_key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        LOGGER.info("Generate mini-batch from input data")
        mini_batch_obj = MiniBatch(data_instances, batch_size=self.batch_size)
        batch_num = mini_batch_obj.batch_nums
        if self.batch_size == -1:
            LOGGER.info(
                "batch size is -1, set it to the number of data in data_instances"
            )
            self.batch_size = data_instances.count()

        batch_info = {"batch_size": self.batch_size, "batch_num": batch_num}
        LOGGER.info("batch_info:{}".format(batch_info))
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.HOST,
                          idx=0)
        LOGGER.info("Remote batch_info to Host")
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.ARBITER,
                          idx=0)
        LOGGER.info("Remote batch_info to Arbiter")

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.encrypt_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(batch_num)
        ]

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(
            self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        weight = self.initializer.init_model(model_shape,
                                             init_params=self.init_param_obj)
        if self.init_param_obj.fit_intercept is True:
            self.coef_ = weight[:-1]
            self.intercept_ = weight[-1]
        else:
            self.coef_ = weight

        is_send_all_batch_index = False
        self.n_iter_ = 0
        index_data_inst_map = {}

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))
            # each iter will get the same batch_data_generator
            batch_data_generator = mini_batch_obj.mini_batch_data_generator(
                result='index')

            batch_index = 0
            for batch_data_index in batch_data_generator:
                LOGGER.info("batch:{}".format(batch_index))
                if not is_send_all_batch_index:
                    LOGGER.info("remote mini-batch index to Host")
                    federation.remote(
                        batch_data_index,
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        role=consts.HOST,
                        idx=0)
                    if batch_index >= mini_batch_obj.batch_nums - 1:
                        is_send_all_batch_index = True

                # Get mini-batch train data
                if len(index_data_inst_map) < batch_num:
                    batch_data_inst = data_instances.join(
                        batch_data_index, lambda data_inst, index: data_inst)
                    index_data_inst_map[batch_index] = batch_data_inst
                else:
                    batch_data_inst = index_data_inst_map[batch_index]

                # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
                batch_feat_inst = self.transform(batch_data_inst)

                # guest/host forward
                self.compute_forward(batch_feat_inst, self.coef_,
                                     self.intercept_, batch_index)
                host_forward = federation.get(
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info("Get host_forward from host")
                aggregate_forward_res = self.aggregate_forward(host_forward)
                en_aggregate_wx = aggregate_forward_res.mapValues(
                    lambda v: v[0])
                en_aggregate_wx_square = aggregate_forward_res.mapValues(
                    lambda v: v[1])

                # compute [[d]]
                if self.gradient_operator is None:
                    self.gradient_operator = HeteroLogisticGradient(
                        self.encrypt_operator)
                fore_gradient = self.gradient_operator.compute_fore_gradient(
                    batch_feat_inst, en_aggregate_wx)
                federation.remote(
                    fore_gradient,
                    name=self.transfer_variable.fore_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.fore_gradient, self.n_iter_,
                        batch_index),
                    role=consts.HOST,
                    idx=0)

                LOGGER.info("Remote fore_gradient to Host")
                # compute guest gradient and loss
                guest_gradient, loss = self.gradient_operator.compute_gradient_and_loss(
                    batch_feat_inst, fore_gradient, en_aggregate_wx,
                    en_aggregate_wx_square, self.fit_intercept)

                # loss regulation if necessary
                if self.updater is not None:
                    guest_loss_regular = self.updater.loss_norm(self.coef_)
                    loss += self.encrypt_operator.encrypt(guest_loss_regular)

                federation.remote(
                    guest_gradient,
                    name=self.transfer_variable.guest_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote guest_gradient to arbiter")

                optim_guest_gradient = federation.get(
                    name=self.transfer_variable.guest_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_guest_gradient from arbiter")

                # update model
                LOGGER.info("update_model")
                self.update_model(optim_guest_gradient)

                # update local model that transforms features of raw input 'batch_data_inst'
                training_info = {
                    "iteration": self.n_iter_,
                    "batch_index": batch_index
                }
                self.update_local_model(fore_gradient, batch_data_inst,
                                        self.coef_, **training_info)

                # Get loss regulation from Host if regulation is set
                if self.updater is not None:
                    en_host_loss_regular = federation.get(
                        name=self.transfer_variable.host_loss_regular.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_loss_regular,
                            self.n_iter_, batch_index),
                        idx=0)
                    LOGGER.info("Get host_loss_regular from Host")
                    loss += en_host_loss_regular

                federation.remote(
                    loss,
                    name=self.transfer_variable.loss.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.loss, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote loss to arbiter")

                # is converge of loss in arbiter
                batch_index += 1

                # temporary resource recovery and will be removed in the future
                rubbish_list = [
                    host_forward, aggregate_forward_res, en_aggregate_wx,
                    en_aggregate_wx_square, fore_gradient, self.guest_forward
                ]
                rubbish_clear(rubbish_list)

            is_stopped = federation.get(
                name=self.transfer_variable.is_stopped.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.is_stopped, self.n_iter_,
                    batch_index),
                idx=0)
            LOGGER.info("Get is_stop flag from arbiter:{}".format(is_stopped))

            self.n_iter_ += 1
            if is_stopped:
                LOGGER.info(
                    "Get stop signal from arbiter, model is converged, iter:{}"
                    .format(self.n_iter_))
                break

        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))

    def predict(self, data_instances, predict_param):
        """
        Prediction of lr
        Parameters
        ----------
        data_instance:DTable of Instance, input data
        predict_param: PredictParam, the setting of prediction.

        Returns
        ----------
        DTable
            include input data label, predict probably, label
        """
        LOGGER.info("Start predict ...")

        data_features = self.transform(data_instances)

        prob_guest = self.compute_wx(data_features, self.coef_,
                                     self.intercept_)
        prob_host = federation.get(
            name=self.transfer_variable.host_prob.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.host_prob),
            idx=0)
        LOGGER.info("Get probability from Host")

        # guest probability
        pred_prob = prob_guest.join(prob_host,
                                    lambda g, h: activation.sigmoid(g + h))
        pred_label = self.classified(pred_prob, predict_param.threshold)
        if predict_param.with_proba:
            labels = data_instances.mapValues(lambda v: v.label)
            predict_result = labels.join(pred_prob, lambda label, prob:
                                         (label, prob))
        else:
            predict_result = data_instances.mapValues(lambda v:
                                                      (v.label, None))

        predict_result = predict_result.join(pred_label, lambda r, p:
                                             (r[0], r[1], p))
        return predict_result
class HeteroLRHost(BaseLogisticRegression):
    def __init__(self, logistic_params):
        # LogisticParamChecker.check_param(logistic_params)
        super(HeteroLRHost, self).__init__(logistic_params)
        self.transfer_variable = HeteroLRTransferVariable()
        self.batch_num = None
        self.batch_index_list = []

    def compute_forward(self,
                        data_instances,
                        coef_,
                        intercept_,
                        batch_index=-1):
        """
        Compute W * X + b and (W * X + b)^2, where X is the input data, W is the coefficient of lr,
        and b is the interception
        Parameters
        ----------
        data_instance: DTable of Instance, input data
        coef_: list, coefficient of lr
        intercept_: float, the interception of lr
        """
        wx = self.compute_wx(data_instances, coef_, intercept_)
        en_wx = self.encrypted_calculator[batch_index].encrypt(wx)
        wx_square = wx.mapValues(lambda v: np.square(v))
        en_wx_square = self.encrypted_calculator[batch_index].encrypt(
            wx_square)

        host_forward = en_wx.join(en_wx_square, lambda wx, wx_square:
                                  (wx, wx_square))

        # temporary resource recovery and will be removed in the future
        rubbish_list = [wx, en_wx, wx_square, en_wx_square]
        rubbish_clear(rubbish_list)

        return host_forward

    def fit(self, data_instances):
        """
        Train lr model of role host
        Parameters
        ----------
        data_instances: DTable of Instance, input data
        """

        LOGGER.info("Enter hetero_lr host")
        self._abnormal_detection(data_instances)

        self.header = self.get_header(data_instances)
        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)

        LOGGER.info("Get public_key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        batch_info = federation.get(
            name=self.transfer_variable.batch_info.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.batch_info),
            idx=0)
        LOGGER.info("Get batch_info from guest:" + str(batch_info))
        self.batch_size = batch_info["batch_size"]
        self.batch_num = batch_info["batch_num"]
        if self.batch_size < consts.MIN_BATCH_SIZE and self.batch_size != -1:
            raise ValueError(
                "Batch size get from guest should not less than 10, except -1, batch_size is {}"
                .format(self.batch_size))

        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.encrypt_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate)
            for _ in range(self.batch_num)
        ]

        LOGGER.info("Start initialize model.")
        model_shape = self.get_features_shape(data_instances)

        if self.init_param_obj.fit_intercept:
            self.init_param_obj.fit_intercept = False

        if self.fit_intercept:
            self.fit_intercept = False

        self.coef_ = self.initializer.init_model(
            model_shape, init_params=self.init_param_obj)

        self.n_iter_ = 0
        index_data_inst_map = {}

        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:" + str(self.n_iter_))
            batch_index = 0
            while batch_index < self.batch_num:
                LOGGER.info("batch:{}".format(batch_index))
                # set batch_data
                if len(self.batch_index_list) < self.batch_num:
                    batch_data_index = federation.get(
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        idx=0)
                    LOGGER.info("Get batch_index from Guest")
                    self.batch_index_list.append(batch_data_index)
                else:
                    batch_data_index = self.batch_index_list[batch_index]

                # Get mini-batch train data
                if len(index_data_inst_map) < self.batch_num:
                    batch_data_inst = batch_data_index.join(
                        data_instances, lambda g, d: d)
                    index_data_inst_map[batch_index] = batch_data_inst
                else:
                    batch_data_inst = index_data_inst_map[batch_index]

                LOGGER.info("batch_data_inst size:{}".format(
                    batch_data_inst.count()))
                # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
                batch_feat_inst = self.transform(batch_data_inst)

                # compute forward
                host_forward = self.compute_forward(batch_feat_inst,
                                                    self.coef_,
                                                    self.intercept_,
                                                    batch_index)
                federation.remote(
                    host_forward,
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    role=consts.GUEST,
                    idx=0)
                LOGGER.info("Remote host_forward to guest")

                # compute host gradient
                fore_gradient = federation.get(
                    name=self.transfer_variable.fore_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.fore_gradient, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info("Get fore_gradient from guest")
                if self.gradient_operator is None:
                    self.gradient_operator = HeteroLogisticGradient(
                        self.encrypt_operator)
                host_gradient = self.gradient_operator.compute_gradient(
                    batch_feat_inst, fore_gradient, fit_intercept=False)
                # regulation if necessary
                if self.updater is not None:
                    loss_regular = self.updater.loss_norm(self.coef_)
                    en_loss_regular = self.encrypt_operator.encrypt(
                        loss_regular)
                    federation.remote(
                        en_loss_regular,
                        name=self.transfer_variable.host_loss_regular.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_loss_regular,
                            self.n_iter_, batch_index),
                        role=consts.GUEST,
                        idx=0)
                    LOGGER.info("Remote host_loss_regular to guest")

                federation.remote(
                    host_gradient,
                    name=self.transfer_variable.host_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote host_gradient to arbiter")

                # Get optimize host gradient and update model
                optim_host_gradient = federation.get(
                    name=self.transfer_variable.host_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_host_gradient from arbiter")

                LOGGER.info("update_model")
                self.update_model(optim_host_gradient)

                # update local model that transforms features of raw input 'batch_data_inst'
                training_info = {
                    "iteration": self.n_iter_,
                    "batch_index": batch_index
                }
                self.update_local_model(fore_gradient, batch_data_inst,
                                        self.coef_, **training_info)

                batch_index += 1

                # temporary resource recovery and will be removed in the future
                rubbish_list = [host_forward, fore_gradient]
                rubbish_clear(rubbish_list)

            is_stopped = federation.get(
                name=self.transfer_variable.is_stopped.name,
                tag=self.transfer_variable.generate_transferid(
                    self.transfer_variable.is_stopped, self.n_iter_,
                    batch_index),
                idx=0)
            LOGGER.info("Get is_stop flag from arbiter:{}".format(is_stopped))

            self.n_iter_ += 1
            if is_stopped:
                LOGGER.info(
                    "Get stop signal from arbiter, model is converged, iter:{}"
                    .format(self.n_iter_))
                break

        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))

    def predict(self, data_instances, predict_param=None):
        """
        Prediction of lr
        Parameters
        ----------
        data_instance:DTable of Instance, input data
        predict_param: PredictParam, the setting of prediction. Host may not have predict_param
        """
        LOGGER.info("Start predict ...")

        data_features = self.transform(data_instances)

        prob_host = self.compute_wx(data_features, self.coef_, self.intercept_)
        federation.remote(prob_host,
                          name=self.transfer_variable.host_prob.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.host_prob),
                          role=consts.GUEST,
                          idx=0)
        LOGGER.info("Remote probability to Guest")
Example #6
0
class HeteroLRHost(BaseLogisticRegression):
    def __init__(self, logistic_params):
        super(HeteroLRHost, self).__init__(logistic_params)
        self.transfer_variable = HeteroLRTransferVariable()
        self.batch_num = None
        self.batch_index_list = []

    def compute_forward(self, data_instances, coef_, intercept_):
        wx = self.compute_wx(data_instances, coef_, intercept_)
        encrypt_operator = self.encrypt_operator
        host_forward = wx.mapValues(lambda v: (encrypt_operator.encrypt(
            v), encrypt_operator.encrypt(np.square(v))))
        return host_forward

    def fit(self, data_instances):
        LOGGER.info("Enter hetero_lr host")
        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)

        LOGGER.info("Get public_key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        batch_info = federation.get(
            name=self.transfer_variable.batch_info.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.batch_info),
            idx=0)
        LOGGER.info("Get batch_info from guest:" + str(batch_info))
        self.batch_size = batch_info["batch_size"]
        self.batch_num = batch_info["batch_num"]

        LOGGER.info("Start initialize model.")
        model_shape = self.get_features_shape(data_instances)

        if self.init_param_obj.fit_intercept:
            self.init_param_obj.fit_intercept = False

        if self.fit_intercept:
            self.fit_intercept = False

        self.coef_ = self.initializer.init_model(
            model_shape, init_params=self.init_param_obj)

        is_stopped = False
        self.n_iter_ = 0
        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:" + str(self.n_iter_))
            batch_index = 0
            while batch_index < self.batch_num:
                # set batch_data
                if len(self.batch_index_list) < self.batch_num:
                    batch_data_index = federation.get(
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        idx=0)
                    LOGGER.info("Get batch_index from Guest")
                    self.batch_index_list.append(batch_data_index)
                else:
                    batch_data_index = self.batch_index_list[batch_index]

                # Get mini-batch train data
                batch_data_inst = batch_data_index.join(
                    data_instances, lambda g, d: d)

                # compute forward
                host_forward = self.compute_forward(batch_data_inst,
                                                    self.coef_,
                                                    self.intercept_)
                federation.remote(
                    host_forward,
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    role=consts.GUEST,
                    idx=0)
                LOGGER.info("Remote host_forward to guest")

                # compute host gradient
                fore_gradient = federation.get(
                    name=self.transfer_variable.fore_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.fore_gradient, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info("Get fore_gradient from guest")
                if self.gradient_operator is None:
                    self.gradient_operator = HeteroLogisticGradient(
                        self.encrypt_operator)
                host_gradient = self.gradient_operator.compute_gradient(
                    data_instances, fore_gradient, fit_intercept=False)
                # regulation if necessary
                if self.updater is not None:
                    loss_regular = self.updater.loss_norm(self.coef_)
                    en_loss_regular = self.encrypt_operator.encrypt(
                        loss_regular)
                    federation.remote(
                        en_loss_regular,
                        name=self.transfer_variable.host_loss_regular.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_loss_regular,
                            self.n_iter_, batch_index),
                        role=consts.GUEST,
                        idx=0)
                    LOGGER.info("Remote host_loss_regular to guest")

                federation.remote(
                    host_gradient,
                    name=self.transfer_variable.host_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote host_gradient to arbiter")

                # Get optimize host gradient and update model
                optim_host_gradient = federation.get(
                    name=self.transfer_variable.host_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_host_gradient from arbiter")

                LOGGER.info("update_model")
                self.update_model(optim_host_gradient)

                # is converge
                is_stopped = federation.get(
                    name=self.transfer_variable.is_stopped.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.is_stopped, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info(
                    "Get is_stop flag from arbiter:{}".format(is_stopped))

                batch_index += 1
                if is_stopped:
                    LOGGER.info(
                        "Get stop signal from arbiter, model is converged, iter:{}"
                        .format(self.n_iter_))
                    break

            self.n_iter_ += 1
            if is_stopped:
                break

        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))

    def predict(self, data_instances, predict_param=None):
        LOGGER.info("Start predict ...")
        prob_host = self.compute_wx(data_instances, self.coef_,
                                    self.intercept_)
        federation.remote(prob_host,
                          name=self.transfer_variable.host_prob.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.host_prob),
                          role=consts.GUEST,
                          idx=0)
        LOGGER.info("Remote probability to Host")
Example #7
0
class HeteroLRGuest(BaseLogisticRegression):
    def __init__(self, logistic_params):
        super(HeteroLRGuest, self).__init__(logistic_params)
        self.transfer_variable = HeteroLRTransferVariable()
        self.data_batch_count = []

        self.wx = None
        self.guest_forward = None

    def compute_forward(self, data_instances, coef_, intercept_):
        self.wx = self.compute_wx(data_instances, coef_, intercept_)
        encrypt_operator = self.encrypt_operator
        self.guest_forward = self.wx.mapValues(
            lambda v: (encrypt_operator.encrypt(v),
                       encrypt_operator.encrypt(np.square(v)), v))

    def aggregate_forward(self, host_forward):
        aggregate_forward_res = self.guest_forward.join(
            host_forward, lambda g, h:
            (g[0] + h[0], g[1] + h[1] + 2 * g[2] * h[0]))
        return aggregate_forward_res

    @staticmethod
    def load_data(data_instance):
        if data_instance.label != 1:
            data_instance.label = -1
        return data_instance

    def fit(self, data_instances):
        LOGGER.info("Enter hetero_lr_guest fit")
        data_instances = data_instances.mapValues(HeteroLRGuest.load_data)

        public_key = federation.get(
            name=self.transfer_variable.paillier_pubkey.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.paillier_pubkey),
            idx=0)
        LOGGER.info("Get public_key from arbiter:{}".format(public_key))
        self.encrypt_operator.set_public_key(public_key)

        LOGGER.info("Generate mini-batch from input data")
        mini_batch_obj = MiniBatch(data_instances, batch_size=self.batch_size)
        batch_info = {
            "batch_size": self.batch_size,
            "batch_num": mini_batch_obj.batch_nums
        }
        LOGGER.info("batch_info:" + str(batch_info))
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.HOST,
                          idx=0)
        LOGGER.info("Remote batch_info to Host")
        federation.remote(batch_info,
                          name=self.transfer_variable.batch_info.name,
                          tag=self.transfer_variable.generate_transferid(
                              self.transfer_variable.batch_info),
                          role=consts.ARBITER,
                          idx=0)
        LOGGER.info("Remote batch_info to Arbiter")

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(
            self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        weight = self.initializer.init_model(model_shape,
                                             init_params=self.init_param_obj)
        if self.init_param_obj.fit_intercept is True:
            self.coef_ = weight[:-1]
            self.intercept_ = weight[-1]
        else:
            self.coef_ = weight

        is_stopped = False
        is_send_all_batch_index = False
        self.n_iter_ = 0
        while self.n_iter_ < self.max_iter:
            LOGGER.info("iter:{}".format(self.n_iter_))
            batch_data_generator = mini_batch_obj.mini_batch_index_generator(
                data_inst=data_instances, batch_size=self.batch_size)
            batch_index = 0
            for batch_data_index in batch_data_generator:
                LOGGER.info("batch:{}".format(batch_index))
                if not is_send_all_batch_index:
                    LOGGER.info("remote mini-batch index to Host")
                    federation.remote(
                        batch_data_index,
                        name=self.transfer_variable.batch_data_index.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.batch_data_index,
                            self.n_iter_, batch_index),
                        role=consts.HOST,
                        idx=0)
                    if batch_index >= mini_batch_obj.batch_nums - 1:
                        is_send_all_batch_index = True

                # Get mini-batch train data
                batch_data_inst = data_instances.join(
                    batch_data_index, lambda data_inst, index: data_inst)

                # guest/host forward
                self.compute_forward(batch_data_inst, self.coef_,
                                     self.intercept_)
                host_forward = federation.get(
                    name=self.transfer_variable.host_forward_dict.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.host_forward_dict, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info("Get host_forward from host")
                aggregate_forward_res = self.aggregate_forward(host_forward)
                en_aggregate_wx = aggregate_forward_res.mapValues(
                    lambda v: v[0])
                en_aggregate_wx_square = aggregate_forward_res.mapValues(
                    lambda v: v[1])

                # compute [[d]]
                if self.gradient_operator is None:
                    self.gradient_operator = HeteroLogisticGradient(
                        self.encrypt_operator)
                fore_gradient = self.gradient_operator.compute_fore_gradient(
                    batch_data_inst, en_aggregate_wx)
                federation.remote(
                    fore_gradient,
                    name=self.transfer_variable.fore_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.fore_gradient, self.n_iter_,
                        batch_index),
                    role=consts.HOST,
                    idx=0)
                LOGGER.info("Remote fore_gradient to Host")
                # compute guest gradient and loss
                guest_gradient, loss = self.gradient_operator.compute_gradient_and_loss(
                    batch_data_inst, fore_gradient, en_aggregate_wx,
                    en_aggregate_wx_square, self.fit_intercept)

                # loss regulation if necessary
                if self.updater is not None:
                    guest_loss_regular = self.updater.loss_norm(self.coef_)
                    loss += self.encrypt_operator.encrypt(guest_loss_regular)

                federation.remote(
                    guest_gradient,
                    name=self.transfer_variable.guest_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_gradient, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote guest_gradient to arbiter")

                optim_guest_gradient = federation.get(
                    name=self.transfer_variable.guest_optim_gradient.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.guest_optim_gradient,
                        self.n_iter_, batch_index),
                    idx=0)
                LOGGER.info("Get optim_guest_gradient from arbiter")

                # update model
                LOGGER.info("update_model")
                self.update_model(optim_guest_gradient)

                # Get loss regulation from Host if regulation is set
                if self.updater is not None:
                    en_host_loss_regular = federation.get(
                        name=self.transfer_variable.host_loss_regular.name,
                        tag=self.transfer_variable.generate_transferid(
                            self.transfer_variable.host_loss_regular,
                            self.n_iter_, batch_index),
                        idx=0)
                    LOGGER.info("Get host_loss_regular from Host")
                    loss += en_host_loss_regular

                federation.remote(
                    loss,
                    name=self.transfer_variable.loss.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.loss, self.n_iter_,
                        batch_index),
                    role=consts.ARBITER,
                    idx=0)
                LOGGER.info("Remote loss to arbiter")

                # is converge of loss in arbiter
                is_stopped = federation.get(
                    name=self.transfer_variable.is_stopped.name,
                    tag=self.transfer_variable.generate_transferid(
                        self.transfer_variable.is_stopped, self.n_iter_,
                        batch_index),
                    idx=0)
                LOGGER.info(
                    "Get is_stop flag from arbiter:{}".format(is_stopped))
                batch_index += 1
                if is_stopped:
                    LOGGER.info(
                        "Get stop signal from arbiter, model is converged, iter:{}"
                        .format(self.n_iter_))
                    break

            self.n_iter_ += 1
            if is_stopped:
                break
        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))

    def predict(self, data_instances, predict_param):
        LOGGER.info("Start predict ...")
        prob_guest = self.compute_wx(data_instances, self.coef_,
                                     self.intercept_)
        prob_host = federation.get(
            name=self.transfer_variable.host_prob.name,
            tag=self.transfer_variable.generate_transferid(
                self.transfer_variable.host_prob),
            idx=0)
        LOGGER.info("Get probability from Host")

        # guest probability
        pred_prob = prob_guest.join(prob_host,
                                    lambda g, h: activation.sigmoid(g + h))
        pred_label = self.classified(pred_prob, predict_param.threshold)
        if predict_param.with_proba:
            labels = data_instances.mapValues(lambda v: v.label)
            predict_result = labels.join(pred_prob, lambda label, prob:
                                         (label, prob))
        else:
            predict_result = data_instances.mapValues(lambda v:
                                                      (v.label, None))

        predict_result = predict_result.join(pred_label, lambda r, p:
                                             (r[0], r[1], p))
        return predict_result