예제 #1
0
    def fit(self, data_instances, validate_data=None):
        """
        Fit the guest-side hetero LR model.

        Data with at most two label classes is trained directly through
        ``fit_binary``; with more than two classes the work is delegated to
        ``one_vs_rest_fit`` and loss call-back is switched off.

        Parameters
        ----------
        data_instances: Table of Instance, input data
        validate_data: optional validation data, passed on to ``prepare_fit``
            and to whichever fit routine is selected
        """
        LOGGER.info("Enter hetero_lr_guest fit")
        self.prepare_fit(data_instances, validate_data)

        # Classes are collected before any sample-weight rescaling happens.
        label_classes = self.one_vs_rest_obj.get_data_classes(data_instances)

        if with_weight(data_instances):
            data_instances = scale_sample_weight(data_instances)
            self.gradient_loss_operator.set_use_sample_weight()
            LOGGER.debug("instance weight scaled; use weighted gradient loss operator")

        multi_class = len(label_classes) > 2
        self.need_one_vs_rest = multi_class
        if not multi_class:
            self.fit_binary(data_instances, validate_data)
            return

        self.need_call_back_loss = False
        self.one_vs_rest_fit(train_data=data_instances, validate_data=validate_data)
예제 #2
0
    def fit(self, data_instances, validate_data=None):
        """
        Fit the guest-side hetero linear regression model.

        Runs mini-batch training until ``max_iter`` is reached, a callback
        sets ``stop_training``, or the synchronised convergence flag is set.

        Parameters
        ----------
        data_instances: Table of Instance, input data
        validate_data: optional validation data, handed to the train-begin
            callback
        """
        LOGGER.info("Enter hetero_linR_guest fit")
        self._abnormal_detection(data_instances)
        self.header = self.get_header(data_instances)
        self.callback_list.on_train_begin(data_instances, validate_data)

        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        # Sample weights take precedence; async mode is only considered for
        # unweighted data with exactly one host party.
        weighted = with_weight(data_instances)
        use_async = (not weighted) and len(self.component_properties.host_party_idlist) == 1
        if weighted:
            if self.model_param.early_stop == "diff":
                LOGGER.warning("input data with weight, please use 'weight_diff' for 'early_stop'.")
            data_instances = scale_sample_weight(data_instances)
            self.gradient_loss_operator.set_use_sample_weight()
            LOGGER.debug("instance weight scaled; use weighted gradient loss operator")
        elif use_async:
            LOGGER.debug("set_use_async")
            self.gradient_loss_operator.set_use_async()
        # The chosen mode is always sent out, whichever branch ran.
        self.transfer_variable.use_async.remote(use_async)

        LOGGER.info("Generate mini-batch from input data")
        self.batch_generator.initialize_batch_generator(data_instances, self.batch_size)
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)

        # One encrypt-mode calculator per mini-batch.
        self.encrypted_calculator = [
            EncryptModeCalculator(
                self.cipher_operator,
                self.encrypted_mode_calculator_param.mode,
                self.encrypted_mode_calculator_param.re_encrypted_rate,
            )
            for _ in range(self.batch_generator.batch_nums)
        ]

        LOGGER.info("Start initialize model.")
        LOGGER.info(f"fit_intercept:{self.init_param_obj.fit_intercept}")
        model_shape = self.get_features_shape(data_instances)
        if self.component_properties.is_warm_start:
            # Warm start: weights are not re-initialised, only the callbacks
            # are told which iteration training resumes from.
            self.callback_warm_start_init_iter(self.n_iter_)
        else:
            initial_w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
            self.model_weights = LinearModelWeights(
                initial_w,
                fit_intercept=self.fit_intercept,
                raise_overflow_error=False,
            )

        while self.n_iter_ < self.max_iter:
            self.callback_list.on_epoch_begin(self.n_iter_)
            LOGGER.info(f"iter:{self.n_iter_}")
            # each iter will get the same batch_data_generator
            epoch_batches = self.batch_generator.generate_batch_data()
            self.optimizer.set_iters(self.n_iter_)

            for batch_index, batch_data in enumerate(epoch_batches):
                # Gradient first, then loss, then the weight update.
                guest_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_data,
                    self.encrypted_calculator,
                    self.model_weights,
                    self.optimizer,
                    self.n_iter_,
                    batch_index,
                )

                loss_norm = self.optimizer.loss_norm(self.model_weights)
                self.gradient_loss_operator.compute_loss(batch_data, self.n_iter_, batch_index, loss_norm)

                self.model_weights = self.optimizer.update_model(self.model_weights, guest_gradient)

            self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
            LOGGER.info(f"iter: {self.n_iter_}, is_converged: {self.is_converged}")

            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1
            if self.stop_training or self.is_converged:
                break

        self.callback_list.on_train_end()
        self.set_summary(self.get_model_summary())
예제 #3
0
    def fit_single_model(self, data_instances, validate_data=None):
        """
        Train one model on secret-shared data inside a single SPDZ session.

        The method initialises the model weights (or reuses the existing ones
        on warm start), pre-encodes every mini-batch as a fixed-point tensor,
        and then repeats forward / backward / loss over all batches until
        ``max_iter`` is reached, ``stop_training`` is set, or the configured
        convergence check reports convergence.

        With ``reveal_every_iter`` the gradient is revealed after each batch
        and applied to ``self.model_weights``. Otherwise the weights are kept
        as a local share and a remote share, updated share-wise, and revealed
        only once after the loop.

        Parameters
        ----------
        data_instances: input training data; ``count()`` is called on it and
            it is handed to the batch generator
        validate_data: optional validation data, only forwarded to
            ``callback_list.on_train_begin``

        Raises
        ------
        ValueError
            If more than one host party is configured, or (at the end of the
            first iteration) if ``self.converge_func_name`` is none of
            "diff", "abs", "weight_diff".
        """
        LOGGER.info(f"Start to train single {self.model_name}")
        # Multi-host training is rejected up front.
        if len(self.component_properties.host_party_idlist) > 1:
            raise ValueError(f"Hetero SSHE Model does not support multi-host training.")
        self.callback_list.on_train_begin(data_instances, validate_data)

        model_shape = self.get_features_shape(data_instances)
        # Total number of instances; used below to average the summed loss.
        instances_count = data_instances.count()

        if not self.component_properties.is_warm_start:
            w = self._init_weights(model_shape)
            self.model_weights = LinearModelWeights(l=w,
                                                    fit_intercept=self.model_param.init_param.fit_intercept)
            # Snapshot of the weights, compared against in the "weight_diff" check.
            last_models = copy.deepcopy(self.model_weights)
        else:
            # Warm start: continue from the weights already on the instance.
            last_models = copy.deepcopy(self.model_weights)
            w = last_models.unboxed
            self.callback_warm_start_init_iter(self.n_iter_)

        # Only the guest rescales sample weights.
        if self.role == consts.GUEST:
            if with_weight(data_instances):
                LOGGER.info(f"data with sample weight, use sample weight.")
                if self.model_param.early_stop == "diff":
                    LOGGER.warning("input data with weight, please use 'weight_diff' for 'early_stop'.")
                data_instances = scale_sample_weight(data_instances)
        self.batch_generator.initialize_batch_generator(data_instances, batch_size=self.batch_size)

        with SPDZ(
            "hetero_sshe",
            local_party=self.local_party,
            all_parties=self.parties,
            q_field=self.q_field,
            use_mix_rand=self.model_param.use_mix_rand,
        ) as spdz:
            spdz.set_flowid(self.flowid)
            self.secure_matrix_obj.set_flowid(self.flowid)
            # not sharing the model when reveal_every_iter
            if not self.reveal_every_iter:
                w_self, w_remote = self.share_model(w, suffix="init")
                # NOTE(review): these alias the initial shares instead of copying them,
                # whereas later iterations use copy.deepcopy. If `-=` on the share type
                # mutates in place, the first "weight_diff" check would compare an object
                # with itself - confirm.
                last_w_self, last_w_remote = w_self, w_remote
                LOGGER.debug(f"first_w_self shape: {w_self.shape}, w_remote_shape: {w_remote.shape}")
            batch_data_generator = self.batch_generator.generate_batch_data()

            # Per-batch inputs are prepared once here and reused in every iteration.
            encoded_batch_data = []
            batch_labels_list = []
            batch_weight_list = []

            for batch_data in batch_data_generator:
                if self.fit_intercept:
                    # Append a constant 1.0 feature so the intercept is learned as a weight.
                    batch_features = batch_data.mapValues(lambda x: np.hstack((x.features, 1.0)))
                else:
                    batch_features = batch_data.mapValues(lambda x: x.features)
                if self.role == consts.GUEST:
                    batch_labels = batch_data.mapValues(lambda x: np.array([x.label], dtype=self.label_type))
                    batch_labels_list.append(batch_labels)
                    # NOTE(review): self.weight is not assigned anywhere in this method;
                    # presumably set elsewhere when sample weights are used - confirm.
                    if self.weight:
                        batch_weight = batch_data.mapValues(lambda x: np.array([x.weight], dtype=float))
                        batch_weight_list.append(batch_weight)
                    else:
                        # Placeholder keeps batch_weight_list index-aligned with the batches.
                        batch_weight_list.append(None)

                # NOTE(review): self.batch_num is appended to but never reset here; if this
                # method runs more than once on the same instance the list keeps growing
                # while lookups below still start at index 0 - confirm the caller resets it.
                self.batch_num.append(batch_data.count())

                encoded_batch_data.append(
                    fixedpoint_table.FixedPointTensor(self.fixedpoint_encoder.encode(batch_features),
                                                      q_field=self.fixedpoint_encoder.n,
                                                      endec=self.fixedpoint_encoder))

            while self.n_iter_ < self.max_iter:
                self.callback_list.on_epoch_begin(self.n_iter_)
                LOGGER.info(f"start to n_iter: {self.n_iter_}")

                # Size-scaled batch losses collected during this iteration.
                loss_list = []

                self.optimizer.set_iters(self.n_iter_)
                if not self.reveal_every_iter:
                    self.self_optimizer.set_iters(self.n_iter_)
                    self.remote_optimizer.set_iters(self.n_iter_)

                for batch_idx, batch_data in enumerate(encoded_batch_data):
                    # (iteration, batch) suffix passed to every call made for this step.
                    current_suffix = (str(self.n_iter_), str(batch_idx))
                    if self.role == consts.GUEST:
                        batch_labels = batch_labels_list[batch_idx]
                        batch_weight = batch_weight_list[batch_idx]
                    else:
                        # Only the guest prepared labels and weights above.
                        batch_labels = None
                        batch_weight = None

                    # Forward pass on plain weights or on the (local, remote) share pair.
                    if self.reveal_every_iter:
                        y = self.forward(weights=self.model_weights,
                                         features=batch_data,
                                         labels=batch_labels,
                                         suffix=current_suffix,
                                         cipher=self.cipher,
                                         batch_weight=batch_weight)
                    else:
                        y = self.forward(weights=(w_self, w_remote),
                                         features=batch_data,
                                         labels=batch_labels,
                                         suffix=current_suffix,
                                         cipher=self.cipher,
                                         batch_weight=batch_weight)

                    if self.role == consts.GUEST:
                        if self.weight:
                            # Subtract the weight-scaled labels; the lambda's `y` is its own
                            # parameter (a label value), not the forward output above.
                            error = y - batch_labels.join(batch_weight, lambda y, b: y * b)
                        else:
                            error = y - batch_labels

                        self_g, remote_g = self.backward(error=error,
                                                         features=batch_data,
                                                         suffix=current_suffix,
                                                         cipher=self.cipher)
                    else:
                        # Non-guest roles pass the forward output straight to backward.
                        self_g, remote_g = self.backward(error=y,
                                                         features=batch_data,
                                                         suffix=current_suffix,
                                                         cipher=self.cipher)

                    # loss computing;
                    suffix = ("loss",) + current_suffix
                    if self.reveal_every_iter:
                        batch_loss = self.compute_loss(weights=self.model_weights,
                                                       labels=batch_labels,
                                                       suffix=suffix,
                                                       cipher=self.cipher)
                    else:
                        batch_loss = self.compute_loss(weights=(w_self, w_remote),
                                                       labels=batch_labels,
                                                       suffix=suffix,
                                                       cipher=self.cipher)

                    # Scale by the batch's instance count; the sum is divided by
                    # instances_count after the batch loop (presumably compute_loss
                    # returns a per-batch mean - confirm).
                    if batch_loss is not None:
                        batch_loss = batch_loss * self.batch_num[batch_idx]
                    loss_list.append(batch_loss)

                    if self.reveal_every_iter:
                        # LOGGER.debug(f"before reveal: self_g shape: {self_g.shape}, remote_g_shape: {remote_g},"
                        #              f"self_g: {self_g}")

                        new_g = self.reveal_models(self_g, remote_g, suffix=current_suffix)

                        # LOGGER.debug(f"after reveal: new_g shape: {new_g.shape}, new_g: {new_g}"
                        #              f"self.model_param.reveal_strategy: {self.model_param.reveal_strategy}")

                        if new_g is not None:
                            self.model_weights = self.optimizer.update_model(self.model_weights, new_g,
                                                                             has_applied=False)

                        else:
                            # No revealed gradient came back (presumably this party is not a
                            # receiver under the reveal strategy - confirm); use zero weights.
                            self.model_weights = LinearModelWeights(
                                l=np.zeros(self_g.shape),
                                fit_intercept=self.model_param.init_param.fit_intercept)
                    else:
                        # The L2 term is added to each share's gradient before the
                        # per-share optimizers are applied.
                        if self.optimizer.penalty == consts.L2_PENALTY:
                            self_g = self_g + self.self_optimizer.alpha * w_self
                            remote_g = remote_g + self.remote_optimizer.alpha * w_remote

                        # LOGGER.debug(f"before optimizer: {self_g}, {remote_g}")

                        self_g = self.self_optimizer.apply_gradients(self_g)
                        remote_g = self.remote_optimizer.apply_gradients(remote_g)

                        # LOGGER.debug(f"after optimizer: {self_g}, {remote_g}")
                        w_self -= self_g
                        w_remote -= remote_g

                        LOGGER.debug(f"w_self shape: {w_self.shape}, w_remote_shape: {w_remote.shape}")

                # Only the guest aggregates the loss; other roles pass None on.
                if self.role == consts.GUEST:
                    loss = np.sum(loss_list) / instances_count
                    self.loss_history.append(loss)
                    if self.need_call_back_loss:
                        self.callback_loss(self.n_iter_, loss)
                else:
                    loss = None

                # Convergence is judged on the loss or on the change in weights.
                if self.converge_func_name in ["diff", "abs"]:
                    self.is_converged = self.check_converge_by_loss(loss, suffix=(str(self.n_iter_),))
                elif self.converge_func_name == "weight_diff":
                    if self.reveal_every_iter:
                        self.is_converged = self.check_converge_by_weights(
                            last_w=last_models.unboxed,
                            new_w=self.model_weights.unboxed,
                            suffix=(str(self.n_iter_),))
                        last_models = copy.deepcopy(self.model_weights)
                    else:
                        self.is_converged = self.check_converge_by_weights(
                            last_w=(last_w_self, last_w_remote),
                            new_w=(w_self, w_remote),
                            suffix=(str(self.n_iter_),))
                        last_w_self, last_w_remote = copy.deepcopy(w_self), copy.deepcopy(w_remote)
                else:
                    raise ValueError(f"Cannot recognize early_stop function: {self.converge_func_name}")

                LOGGER.info("iter: {},  is_converged: {}".format(self.n_iter_, self.is_converged))
                self.callback_list.on_epoch_end(self.n_iter_)
                # The counter advances before the stop checks, so n_iter_ ends up as the
                # number of completed iterations.
                self.n_iter_ += 1

                if self.stop_training:
                    break

                if self.is_converged:
                    break

            # Finally reconstruct
            if not self.reveal_every_iter:
                new_w = self.reveal_models(w_self, w_remote, suffix=("final",))
                # If reveal_models returns None, self.model_weights is left as it was.
                if new_w is not None:
                    self.model_weights = LinearModelWeights(
                        l=new_w,
                        fit_intercept=self.model_param.init_param.fit_intercept)

        LOGGER.debug(f"loss_history: {self.loss_history}")
        self.set_summary(self.get_model_summary())