Esempio n. 1
0
    def add_summary(self, new_key, new_value):
        """
        Add a key:value pair to the model summary.

        If ``new_key`` is already stored with a value other than ``None``,
        a warning is logged and the stored value is overwritten.

        Parameters
        ----------
        new_key: str
            Name of the summary entry.
        new_value: object
            Value to store under ``new_key``.

        Returns
        -------
        None
        """

        original_value = self._summary.get(new_key, None)
        if original_value is not None:
            # The trailing space keeps the two message fragments from being
            # glued together ("summary.Corresponding").
            LOGGER.warning(
                f"{new_key} already exists in model summary. "
                f"Corresponding value {original_value} will be replaced by {new_value}"
            )
        self._summary[new_key] = new_value
        LOGGER.debug(f"{new_key}: {new_value} added to summary.")
Esempio n. 2
0
    def fit_buckets(self, bucket_table, sample_count):
        """
        Run optimal binning on every column of ``bucket_table`` and record the results.

        ``self.merge_optimal_binning`` is used when the configured metric method
        is one of 'iv', 'gini' or 'chi_square'; otherwise
        ``self.split_optimal_binning`` is used. For each column the resulting
        split points are stored in ``self.bin_results`` and warnings are logged
        for every constraint (mixture, small size, max bin) that is not met.

        Parameters
        ----------
        bucket_table: table whose values are mapped through the binning method
        sample_count: forwarded to the binning method as ``sample_count``

        Returns
        -------
        The table produced by ``bucket_table.mapValues``; each value is a
        ``(bucket_list, non_mixture_num, small_size_num)`` triple.
        """
        if self.optimal_param.metric_method in ['iv', 'gini', 'chi_square']:
            binning_func = self.merge_optimal_binning
        else:
            binning_func = self.split_optimal_binning
        optimal_binning_method = functools.partial(binning_func,
                                                   optimal_param=self.optimal_param,
                                                   sample_count=sample_count)
        result_bucket = bucket_table.mapValues(optimal_binning_method)
        for col_name, (bucket_list, non_mixture_num, small_size_num) in result_bucket.collect():
            split_points = np.unique([bucket.right_bound for bucket in bucket_list]).tolist()

            self.bin_results.put_col_split_points(col_name, split_points)
            self.__cal_single_col_result(col_name, bucket_list)
            if self.optimal_param.mixture and non_mixture_num > 0:
                LOGGER.warning("col: {}, non_mixture_num is: {}, cannot meet mixture condition".format(
                    col_name, non_mixture_num
                ))
            if small_size_num > 0:
                LOGGER.warning("col: {}, small_size_num is: {}, cannot meet small size condition".format(
                    col_name, small_size_num
                ))
            bin_num = len(bucket_list)
            if bin_num > self.optimal_param.max_bin:
                # Report the actual number of bins, not small_size_num.
                LOGGER.warning("col: {}, bin_num is: {}, cannot meet max-bin condition".format(
                    col_name, bin_num
                ))
        return result_bucket
Esempio n. 3
0
    def __share_info(self, data):
        """
        Exchange the shared-information column between guest and host.

        The party whose role equals ``model_param.info_owner`` extracts the
        column named ``consts.SHARE_INFO_COL_NAME`` for the intersected ids and
        sends it to the other party. The other party receives the table and
        stores it as ``self.intersect_ids``.

        Parameters
        ----------
        data: table of (id, value) pairs. Values are either ``Instance``
            objects (the column is read from ``features``) or indexable rows.
            On the owner side ``data.schema`` must contain a 'header'.

        Returns
        -------
        ``self.intersect_ids`` (replaced by the received table on the
        non-owner side).

        Raises
        ------
        ValueError
            If this party is the info owner and ``data.schema`` has no header.
        """
        LOGGER.info("Start to share information with another role")
        info_share = self.transfer_variable.info_share_from_guest if self.model_param.info_owner == consts.GUEST else \
            self.transfer_variable.info_share_from_host
        party_role = consts.GUEST if self.model_param.info_owner == consts.HOST else consts.HOST

        if self.role == self.model_param.info_owner:
            header = data.schema.get('header')
            if header is None:
                raise ValueError(
                    "'allow_info_share' is true, and 'info_owner' is {}, but can not get header in data, information sharing not done".format(
                        self.model_param.info_owner))

            try:
                share_info_col_idx = header.index(consts.SHARE_INFO_COL_NAME)

                one_data = data.first()
                if isinstance(one_data[1], Instance):
                    share_data = data.join(self.intersect_ids, lambda d, i: [d.features[share_info_col_idx]])
                else:
                    share_data = data.join(self.intersect_ids, lambda d, i: [d[share_info_col_idx]])

                info_share.remote(share_data,
                                  role=party_role,
                                  idx=-1)
                LOGGER.info("Remote share information to {}".format(party_role))

            except Exception as e:
                # Deliberate best effort: the peer is waiting for a table, so a
                # placeholder is sent instead of failing (e.g. when the share
                # column is missing from the header).
                LOGGER.warning("Something unexpected:{}, share a empty information to {}".format(e, party_role))
                share_data = self.intersect_ids.mapValues(lambda v: ['null'])
                info_share.remote(share_data,
                                  role=party_role,
                                  idx=-1)
        else:
            self.intersect_ids = info_share.get(idx=0)
            self.intersect_ids.schema['header'] = [consts.SHARE_INFO_COL_NAME]
            # Log the header itself rather than the table object.
            LOGGER.info(
                "Get share information from {}, header:{}".format(self.model_param.info_owner,
                                                                  self.intersect_ids.schema['header']))

        return self.intersect_ids
Esempio n. 4
0
    def fit(self, data):
        """
        Fit the configured scaler on the input data and return the scaled data.

        A scaler object is created for the min-max or standard scale method;
        any other method only logs a warning. When ``self.scale_obj`` is falsy
        the input data is returned unchanged.

        Parameters
        ----------
        data: data_instance, input data

        Returns
        ----------
        fit_data: data_instance, data after scale (or the input data itself
            when no scaler is set)
        """
        LOGGER.info("Start scale data fit ...")

        scale_method = self.model_param.method
        if scale_method == consts.MINMAXSCALE:
            self.scale_obj = MinMaxScale(self.model_param)
        elif scale_method == consts.STANDARDSCALE:
            self.scale_obj = StandardScale(self.model_param)
        else:
            LOGGER.warning("Scale method is {}, do nothing and return!".format(scale_method))

        # Nothing to fit: hand the input back untouched.
        if not self.scale_obj:
            LOGGER.info("End fit data ...")
            return data

        fit_data = self.scale_obj.fit(data)
        fit_data.schema = data.schema

        scale_meta = MetricMeta(name="scale", metric_type="SCALE",
                                extra_metas={"method": self.model_param.method})
        self.callback_meta(metric_name="scale", metric_namespace="train",
                           metric_meta=scale_meta)

        LOGGER.info("start to get model summary ...")
        model_summary = self.scale_obj.get_model_summary()
        self.set_summary(model_summary)
        LOGGER.info("Finish getting model summary.")

        LOGGER.info("End fit data ...")
        return fit_data
Esempio n. 5
0
    def __generate_id_map(self, data) -> dict:
        """
        Group the keys of ``data`` by the first element of their values.

        Only runs when ``self.repeated_id_owner`` is truthy; otherwise a
        warning is logged and an empty dict is returned. For ``Instance``
        values the first element of ``features`` is used, otherwise the first
        element of the value itself.

        Returns
        -------
        dict mapping ``str(first element)`` to the list of keys sharing it,
        restricted to entries that have at least two keys.
        """
        if not self.repeated_id_owner:
            LOGGER.warning("Not a repeated id owner, will not generate id map")
            return {}

        first_record = data.first()
        if isinstance(first_record[1], Instance):
            first_values = data.mapValues(lambda inst: inst.features[0])
        else:
            first_values = data.mapValues(lambda row: row[0])

        # Collect locally and bucket keys by the stringified first element.
        keys_by_value = defaultdict(list)
        for key, first_value in first_values.collect():
            keys_by_value[str(first_value)].append(key)

        return {
            value: keys
            for value, keys in keys_by_value.items()
            if len(keys) >= 2
        }
Esempio n. 6
0
 def check(self):
     """
     Validate the secure information retrieval parameters in place.

     Checks ``security_level``, normalizes the three protocol options to
     lower case (each accepts a single value), handles the deprecated
     ``key_size`` and ``raw_retrieval`` params, validates ``dh_params`` and
     makes sure ``target_cols`` is a list of strings.
     """
     descr = "secure information retrieval param's "
     self.check_decimal_float(self.security_level, descr + "security_level")

     # (attribute name, the only accepted value) for each protocol option.
     single_choice_params = (
         ("oblivious_transfer_protocol", consts.OT_HAUCK),
         ("commutative_encryption", consts.CE_PH),
         ("non_committing_encryption", consts.AES),
     )
     for param_name, allowed in single_choice_params:
         checked = self.check_and_change_lower(getattr(self, param_name),
                                               [allowed.lower()],
                                               descr + param_name)
         setattr(self, param_name, checked)

     key_size_given = self._warn_to_deprecate_param("key_size", descr, "dh_param's key_length")
     if key_size_given:
         self.dh_params.key_length = self.key_size
     self.dh_params.check()

     raw_retrieval_given = self._warn_to_deprecate_param("raw_retrieval", descr,
                                                         "dh_param's security_level = 0")
     if raw_retrieval_given:
         self.check_boolean(self.raw_retrieval, descr)

     # A single column may be given as a bare value; wrap it into a list.
     if not isinstance(self.target_cols, list):
         self.target_cols = [self.target_cols]
     for col in self.target_cols:
         self.check_string(col, descr + "target_cols")
     if not self.target_cols:
         LOGGER.warning("Both 'target_cols' and 'target_indexes' are empty. Label will be retrieved.")
Esempio n. 7
0
    def transform(self, data_instances):
        """
        Apply the fitted sample-weight setting to ``data_instances``.

        When the weight mode is "sample weight name", the weight column is
        located by ``self.sample_weight_name`` and removed from the copied
        schema header. If the column cannot be found, a warning is logged and
        the input is returned unchanged.

        Parameters
        ----------
        data_instances: table of instances with a ``schema`` attribute

        Returns
        -------
        The weighted instances with the updated schema, or ``data_instances``
        itself when the weight column is missing.
        """
        LOGGER.info(f"Enter Sample Weight Transform")
        new_schema = copy.deepcopy(data_instances.schema)
        new_schema["sample_weight"] = "weight"
        weight_loc = None
        if self.weight_mode == "sample weight name":
            weight_loc = SampleWeight.get_weight_loc(data_instances,
                                                     self.sample_weight_name)
            if weight_loc is not None:
                new_schema["header"].pop(weight_loc)
            else:
                # The space after the period keeps the two sentences apart.
                LOGGER.warning(
                    f"Cannot find weight column of given sample_weight_name '{self.sample_weight_name}'. "
                    f"Original input data returned")
                return data_instances
        result_instances = self.transform_weighted_instance(
            data_instances, weight_loc)
        result_instances.schema = new_schema

        self.callback_info()
        if result_instances.mapPartitions(check_negative_sample_weight).reduce(
                lambda x, y: x or y):
            LOGGER.warning(f"Negative weight found in weighted instances.")
        return result_instances
Esempio n. 8
0
    def check(self):
        """
        Validate the encode parameters.

        ``salt`` must be a str, ``encode_method`` must be one of the supported
        hash method names (it is normalized to lower case) and ``base64`` must
        be a bool. A deprecation warning about the param's naming is logged on
        success.

        Returns
        -------
        True when all checks pass.

        Raises
        ------
        ValueError
            If ``salt`` is not a str or ``base64`` is not a bool.
        """
        if not isinstance(self.salt, str):
            raise ValueError(
                "encode param's salt {} not supported, should be str type".
                format(self.salt))

        descr = "encode param's "

        self.encode_method = self.check_and_change_lower(
            self.encode_method, [
                "none", consts.MD5, consts.SHA1, consts.SHA224, consts.SHA256,
                consts.SHA384, consts.SHA512, consts.SM3
            ], descr)

        if not isinstance(self.base64, bool):
            # Message prefix aligned with the other messages of this param.
            raise ValueError(
                "encode param's base64 {} not supported, should be bool type".
                format(self.base64))

        LOGGER.debug("Finish EncodeParam check!")
        LOGGER.warning(
            f"'EncodeParam' will be replaced by 'RAWParam' in future release. "
            f"Please do not rely on current param naming in application.")
        return True
Esempio n. 9
0
    def fit(self, data_instances):
        """
        Attach sample weights to ``data_instances``.

        Returns the input unchanged when neither ``sample_weight_name`` nor
        ``class_weight`` is set. If both are given, the weight mode ends up as
        "sample weight name" and a warning is logged. The weight column, when
        used, is removed from the copied schema header.

        Returns
        -------
        The weighted instances with the updated schema.

        Raises
        ------
        ValueError
            If ``sample_weight_name`` is set but no matching column is found.
        """
        if self.sample_weight_name is None and self.class_weight is None:
            return data_instances

        self.header = data_overview.get_header(data_instances)

        if self.class_weight:
            self.weight_mode = "class weight"
            if self.sample_weight_name:
                LOGGER.warning(
                    f"Both 'sample_weight_name' and 'class_weight' provided. "
                    f"Only weight from 'sample_weight_name' is used.")

        new_schema = copy.deepcopy(data_instances.schema)
        new_schema["sample_weight"] = "weight"

        weight_loc = None
        if self.sample_weight_name:
            # A named weight column takes precedence over class weights.
            self.weight_mode = "sample weight name"
            weight_loc = SampleWeight.get_weight_loc(data_instances,
                                                     self.sample_weight_name)
            if weight_loc is None:
                raise ValueError(
                    f"Cannot find weight column of given sample_weight_name '{self.sample_weight_name}'."
                )
            new_schema["header"].pop(weight_loc)

        result_instances = self.transform_weighted_instance(data_instances,
                                                            weight_loc)
        result_instances.schema = new_schema

        self.callback_info()
        has_negative_weight = result_instances.mapPartitions(
            check_negative_sample_weight).reduce(lambda x, y: x or y)
        if has_negative_weight:
            LOGGER.warning(f"Negative weight found in weighted instances.")
        return result_instances
Esempio n. 10
0
    def _evaluate_clustering_metrics(self, mode, data):
        """
        Compute the configured clustering metrics for ``data``.

        Parameters
        ----------
        mode: value stored as the first element of every metric's result list
        data: input handed to ``self._clustering_extract``

        Returns
        -------
        defaultdict(list) mapping metric name to ``[mode, result]``. It is
        empty when the extracted input cannot be evaluated (both extracted
        results are None, or no label is present).
        """

        eval_result = defaultdict(list)
        rs0, rs1, run_outer_metric = self._clustering_extract(data)
        if rs0 is None and rs1 is None:  # skip evaluation computation if get this input format
            LOGGER.debug('skip computing, this clustering format is not for metric computation')
            return eval_result

        if not run_outer_metric:
            # Element-wise comparison against None is intentional: `is None`
            # would test the array object itself.
            no_label = (np.array(rs0) == None).all()  # noqa: E711
            if no_label:
                LOGGER.debug('no label found in clustering result, skip metric computation')
                return eval_result

        for eval_metric in self.metrics:

            # A metric is computed only when its membership in
            # clustering_intra_metric_list agrees with run_outer_metric.
            is_intra_metric = eval_metric in self.clustering_intra_metric_list
            if is_intra_metric != bool(run_outer_metric):
                LOGGER.warning('input data format does not match current clustering metric: {}'.format(eval_metric))
                continue

            LOGGER.debug('clustering_metrics is {}'.format(eval_metric))

            metric_func = getattr(self.metric_interface, eval_metric)
            if not run_outer_metric:
                res = metric_func(rs0, rs1)
            elif eval_metric == consts.DISTANCE_MEASURE:
                res = metric_func(rs0['avg_dist'], rs1, rs0['max_radius'])
            else:
                res = metric_func(rs0['avg_dist'], rs1)
            eval_result[eval_metric].append(mode)
            eval_result[eval_metric].append(res)

        return eval_result
Esempio n. 11
0
def get_batch_generator(data_size, batch_size, batch_strategy, masked_rate, shuffle):
    """
    Build the batch data generator matching the requested strategy.

    - ``batch_size >= data_size``: a single unshuffled full batch of
      ``data_size`` samples, regardless of the other settings.
    - ``batch_strategy == "full"``: a full-batch generator; shuffle is forced
      on when ``masked_rate > 0``.
    - any other strategy: a random-batch generator (``shuffle`` is ignored).
    """
    if batch_size >= data_size:
        LOGGER.warning("As batch_size >= data size, all batch strategy will be disabled")
        return FullBatchDataGenerator(data_size, data_size, shuffle=False)

    if batch_strategy != "full":
        if shuffle:
            LOGGER.warning("if use random select batch strategy, shuffle will not work")
        return RandomBatchDataGenerator(data_size, batch_size, masked_rate)

    if masked_rate > 0:
        LOGGER.warning("If using full batch strategy and masked rate > 0, shuffle will always be true")
        shuffle = True
    return FullBatchDataGenerator(data_size, batch_size, shuffle=shuffle, masked_rate=masked_rate)
Esempio n. 12
0
    def validate(self, model, epoch,
                 use_precomputed_train=False,
                 use_precomputed_validate=False,
                 train_scores=None,
                 validate_scores=None):

        """
        Evaluate the model on train/validate data for the given epoch and
        update the early-stopping bookkeeping.

        :param model: model instance, which has predict function
        :param epoch: int, epoch idx for generating flow id
        :param use_precomputed_train: bool, use precomputed train scores or not, if True, check train_scores
        :param use_precomputed_validate: bool, use precomputed validate scores or not, if True, check validate_scores
        :param train_scores: dtable, key is sample id, value is a list contains precomputed predict scores.
                                             once offered, skip calling
                                             model.predict(self.train_data) and use this as train_predicts
        :param validate_scores: dtable, key is sample id, value is a list contains precomputed predict scores.
                                             once offered, skip calling
                                             model.predict(self.validate_data) and use this as validate_predicts
        :return: None
        :raises ValueError: if early stopping is enabled but evaluation
                            produced no result
        """

        LOGGER.debug("begin to check validate status, need_run_validation is {}".format(self.need_run_validation(epoch)))

        if not self.need_run_validation(epoch):
            return

        # Restored from the corrupted `consts.H**O`, which parses as
        # `consts.H ** O` and cannot run.
        if self.mode == consts.HOMO and self.role == consts.ARBITER:
            return

        if not use_precomputed_train:  # call model.predict()
            train_predicts = self.get_predict_result(model, epoch, self.train_data, "train")
        else:  # use precomputed scores
            train_predicts = self.handle_precompute_scores(train_scores, 'train')

        if not use_precomputed_validate:  # call model.predict()
            validate_predicts = self.get_predict_result(model, epoch, self.validate_data, "validate")
        else:  # use precomputed scores
            validate_predicts = self.handle_precompute_scores(validate_scores, 'validate')

        if train_predicts is not None or validate_predicts is not None:

            if train_predicts is None:
                # Only validate predictions exist; calling union on None
                # would raise AttributeError.
                predicts = validate_predicts
            elif validate_predicts:
                predicts = train_predicts.union(validate_predicts)
            else:
                predicts = train_predicts

            # running evaluation
            eval_result_dict = self.evaluate(predicts, model, epoch)
            LOGGER.debug('showing eval_result_dict here')
            LOGGER.debug(eval_result_dict)

            if self.early_stopping_rounds:

                if len(eval_result_dict) == 0:
                    raise ValueError("eval_result len is 0, no single value metric detected for early stopping checking")

                if self.use_first_metric_only:
                    if self.first_metric:
                        eval_result_dict = {self.first_metric: eval_result_dict[self.first_metric]}
                    else:
                        LOGGER.warning('use first metric only but no single metric found in metric list')

                self.performance_recorder.update(eval_result_dict)

        if self.sync_status:
            self.sync_performance_recorder(epoch)

        if self.early_stopping_rounds and self.mode == consts.HETERO:
            self.update_early_stopping_status(epoch, model)
Esempio n. 13
0
 def query_split_points(self, col_name):
     """
     Return the split points recorded for ``col_name``.

     Logs a warning and returns None when no result is stored for the column.
     """
     col_results = self.all_cols_results.get(col_name)
     if col_results is not None:
         return col_results.split_points

     LOGGER.warning("Querying non-exist split_points")
     return None
Esempio n. 14
0
    def run(self, component_parameters, data_inst, original_model,
            host_do_evaluate):
        """
        Run k-fold cross validation of ``original_model`` on ``data_inst``.

        For every fold a deep copy of ``original_model`` is fitted on the
        train split, then used to predict on both the train and the validate
        split. Predictions are evaluated when this party is the guest or
        ``host_do_evaluate`` is truthy. Per-fold summaries are written back to
        ``original_model`` via ``set_summary``.

        Parameters
        ----------
        component_parameters: passed to ``self._init_model``
        data_inst: input table; when None, ``self._arbiter_run`` is executed
            instead and nothing is returned
        original_model: model that is deep-copied for every fold
        host_do_evaluate: whether a non-guest party evaluates predictions too

        Returns
        -------
        ``self.fold_history`` when ``self.output_fold_history`` is set,
        otherwise ``data_inst``; None when ``data_inst`` is None.

        Raises
        ------
        EnvironmentError
            If a fold's train and test counts do not add up to the total count.
        """
        self._init_model(component_parameters)

        if data_inst is None:
            self._arbiter_run(original_model)
            return
        total_data_count = data_inst.count()
        LOGGER.debug("data_inst count: {}".format(total_data_count))
        if self.output_fold_history:
            if total_data_count * self.n_splits > consts.MAX_SAMPLE_OUTPUT_LIMIT:
                LOGGER.warning(
                    f"max sample output limit {consts.MAX_SAMPLE_OUTPUT_LIMIT} exceeded with n_splits ({self.n_splits}) * instance_count ({total_data_count})"
                )
        # Restored from the corrupted `consts.H**O`, which parses as
        # `consts.H ** O` and cannot run.
        if self.mode == consts.HOMO or self.role == consts.GUEST:
            data_generator = self.split(data_inst)
        else:
            data_generator = [(data_inst, data_inst)] * self.n_splits
        fold_num = 0

        summary_res = {}
        for train_data, test_data in data_generator:
            model = copy.deepcopy(original_model)
            LOGGER.debug("In CV, set_flowid flowid is : {}".format(fold_num))
            model.set_flowid(fold_num)
            model.set_cv_fold(fold_num)

            LOGGER.info("KFold fold_num is: {}".format(fold_num))
            if self.mode == consts.HETERO:
                train_data = self._align_data_index(train_data, model.flowid,
                                                    consts.TRAIN_DATA)
                LOGGER.info("Train data Synchronized")
                test_data = self._align_data_index(test_data, model.flowid,
                                                   consts.TEST_DATA)
                LOGGER.info("Test data Synchronized")

            # Count each split once; the values are reused for the log, the
            # consistency check and the error message.
            train_count = train_data.count()
            LOGGER.debug("train_data count: {}".format(train_count))
            test_count = test_data.count()
            if train_count + test_count != total_data_count:
                raise EnvironmentError(
                    "In cv fold: {}, train count: {}, test count: {}, original data count: {}."
                    "Thus, 'train count + test count = total count' condition is not satisfied"
                    .format(fold_num, train_count, test_count,
                            total_data_count))
            this_flowid = 'train.' + str(fold_num)
            LOGGER.debug(
                "In CV, set_flowid flowid is : {}".format(this_flowid))
            model.set_flowid(this_flowid)
            model.fit(train_data, test_data)

            this_flowid = 'predict_train.' + str(fold_num)
            LOGGER.debug(
                "In CV, set_flowid flowid is : {}".format(this_flowid))
            model.set_flowid(this_flowid)
            train_pred_res = model.predict(train_data)

            if self.role == consts.GUEST or host_do_evaluate:
                fold_name = "_".join(['train', 'fold', str(fold_num)])
                f = functools.partial(self._append_name, name='train')
                train_pred_res = train_pred_res.mapValues(f)
                train_pred_res = model.set_predict_data_schema(
                    train_pred_res, train_data.schema)
                self.evaluate(train_pred_res, fold_name, model)

            this_flowid = 'predict_validate.' + str(fold_num)
            LOGGER.debug(
                "In CV, set_flowid flowid is : {}".format(this_flowid))
            model.set_flowid(this_flowid)
            test_pred_res = model.predict(test_data)

            if self.role == consts.GUEST or host_do_evaluate:
                fold_name = "_".join(['validate', 'fold', str(fold_num)])
                f = functools.partial(self._append_name, name='validate')
                test_pred_res = test_pred_res.mapValues(f)
                test_pred_res = model.set_predict_data_schema(
                    test_pred_res, test_data.schema)
                self.evaluate(test_pred_res, fold_name, model)
            LOGGER.debug("Finish fold: {}".format(fold_num))

            if self.output_fold_history:
                LOGGER.debug(f"generating fold history for fold {fold_num}")
                fold_train_data = self.transform_history_data(
                    train_data, train_pred_res, fold_num, "train")
                fold_validate_data = self.transform_history_data(
                    test_data, test_pred_res, fold_num, "validate")

                fold_history_data = fold_train_data.union(fold_validate_data)
                fold_history_data.schema = fold_train_data.schema
                if self.fold_history is None:
                    self.fold_history = fold_history_data
                else:
                    # Accumulate this fold's history onto the previous folds.
                    new_fold_history = self.fold_history.union(
                        fold_history_data)
                    new_fold_history.schema = fold_history_data.schema
                    self.fold_history = new_fold_history

            summary_res[f"fold_{fold_num}"] = model.summary()
            fold_num += 1
        summary_res['fold_num'] = fold_num
        LOGGER.debug("Finish all fold running")
        original_model.set_summary(summary_res)
        if self.output_fold_history:
            LOGGER.debug(f"output data schema: {self.fold_history.schema}")
            LOGGER.debug(f"output data is: {self.fold_history}")
            return self.fold_history
        else:
            return data_inst
Esempio n. 15
0
    def intersect_online_process(self, data_inst, caches):
        """
        Intersect ``data_inst`` against previously cached intersection data.

        Steps, in order: load the intersect meta and key from the first entry
        of ``caches``; optionally run the match-id process on the input; check
        with the peer party (or parties) that the cache ids agree; run the
        cache intersection; compute the intersect statistics; shape the output
        according to ``only_output_key`` and the join method.

        Parameters
        ----------
        data_inst: input table with a ``schema`` attribute
        caches: dict whose first value is a ``(cache_data, cache_meta)`` pair

        Returns
        -------
        The intersection result table.

        Raises
        ------
        ValueError
            On an inconsistent sample-id-generator / join-role configuration,
            a cache id mismatch between parties, or an unsupported role.
        """
        cache_data, cache_meta = list(caches.values())[0]
        intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
        self.callback_cache_meta(intersect_meta)
        self.load_intersect_meta(intersect_meta)
        self.init_intersect_method()
        self.intersection_obj.load_intersect_key(cache_meta)

        if data_overview.check_with_inst_id(data_inst):
            self.use_match_id_process = True
            LOGGER.info(f"use match_id_process")
        intersect_data = data_inst
        if self.use_match_id_process:
            if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
                raise ValueError("While multi-host, sample_id_generator should be guest.")
            if self.model_param.intersect_method == consts.RAW:
                if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                    # Space after the comma keeps the message fragments apart.
                    raise ValueError(f"When using raw intersect with match id process, "
                                     f"'join_role' should be same role as 'sample_id_generator'")
            else:
                if not self.model_param.sync_intersect_ids:
                    if self.model_param.sample_id_generator != consts.GUEST:
                        self.model_param.sample_id_generator = consts.GUEST
                        LOGGER.warning(f"when not sync_intersect_ids with match id process, "
                                       f"sample_id_generator is set to Guest")

            proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
            proc_obj.new_sample_id = self.model_param.new_sample_id
            if data_overview.check_with_inst_id(data_inst) or self.model_param.with_sample_id:
                proc_obj.use_sample_id()
            match_data = proc_obj.recover(data=data_inst)
            intersect_data = match_data

        # Both sides exchange their cache id and verify that it matches.
        if self.role == consts.HOST:
            cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
            self.transfer_variable.cache_id.remote(cache_id, role=consts.GUEST, idx=0)
            guest_cache_id = self.transfer_variable.cache_id.get(role=consts.GUEST, idx=0)
            if guest_cache_id != cache_id:
                raise ValueError(f"cache_id check failed. cache_id from host & guest must match.")
        elif self.role == consts.GUEST:
            for i, party_id in enumerate(self.host_party_id_list):
                cache_id = cache_meta[str(party_id)].get("cache_id")
                self.transfer_variable.cache_id.remote(cache_id,
                                                       role=consts.HOST,
                                                       idx=i)
                host_cache_id = self.transfer_variable.cache_id.get(role=consts.HOST, idx=i)
                if host_cache_id != cache_id:
                    raise ValueError(f"cache_id check failed. cache_id from host & guest must match.")
        else:
            raise ValueError(f"Role {self.role} cannot run intersection transform.")

        self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
        if self.use_match_id_process:
            # proc_obj and match_data were created in the match-id branch above.
            if not self.model_param.sync_intersect_ids:
                self.intersect_ids = proc_obj.expand(self.intersect_ids,
                                                     match_data=match_data,
                                                     owner_only=True)
            else:
                self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
            if self.intersect_ids and self.model_param.only_output_key:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
                                             "sid_name": data_inst.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            data_count = data_inst.count()
            self.intersect_num = self.intersect_ids.count()
            self.intersect_rate = self.intersect_num / data_count
            self.unmatched_num = data_count - self.intersect_num
            self.unmatched_rate = 1 - self.intersect_rate

        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process:
            if not self.intersection_obj.only_output_key and result_data:
                result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
                self.intersect_ids.schema = result_data.schema
                LOGGER.debug(f"not only_output_key, restore value called")
            if self.intersection_obj.only_output_key and result_data:
                schema = {"sid_name": data_inst.schema["sid_name"]}
                result_data = result_data.mapValues(lambda v: 1)
                result_data.schema = schema
                self.intersect_ids.schema = schema

        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data_inst, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema

        return result_data
Esempio n. 16
0
    def fit_binary(self, data_instances, validate_data=None):
        LOGGER.info("Enter hetero_lr_guest fit")
        self.header = self.get_header(data_instances)

        self.callback_list.on_train_begin(data_instances, validate_data)

        data_instances = data_instances.mapValues(HeteroLRGuest.load_data)
        LOGGER.debug(
            f"MODEL_STEP After load data, data count: {data_instances.count()}"
        )
        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        self.batch_generator.initialize_batch_generator(
            data_instances,
            self.batch_size,
            batch_strategy=self.batch_strategy,
            masked_rate=self.masked_rate,
            shuffle=self.shuffle)
        if self.batch_generator.batch_masked:
            self.batch_generator.verify_batch_legality()

        self.gradient_loss_operator.set_total_batch_nums(
            self.batch_generator.batch_nums)

        use_async = False
        if with_weight(data_instances):
            if self.model_param.early_stop == "diff":
                LOGGER.warning(
                    "input data with weight, please use 'weight_diff' for 'early_stop'."
                )
            # data_instances = scale_sample_weight(data_instances)
            # self.gradient_loss_operator.set_use_sample_weight()
            # LOGGER.debug(f"data_instances after scale: {[v[1].weight for v in list(data_instances.collect())]}")
        elif len(self.component_properties.host_party_idlist
                 ) == 1 and not self.batch_generator.batch_masked:
            LOGGER.debug(f"set_use_async")
            self.gradient_loss_operator.set_use_async()
            use_async = True
        self.transfer_variable.use_async.remote(use_async)

        LOGGER.info("Generate mini-batch from input data")

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(
            self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        if not self.component_properties.is_warm_start:
            w = self.initializer.init_model(model_shape,
                                            init_params=self.init_param_obj)
            self.model_weights = LinearModelWeights(
                w, fit_intercept=self.fit_intercept)
        else:
            self.callback_warm_start_init_iter(self.n_iter_)

        while self.n_iter_ < self.max_iter:
            self.callback_list.on_epoch_begin(self.n_iter_)
            LOGGER.info("iter: {}".format(self.n_iter_))
            batch_data_generator = self.batch_generator.generate_batch_data(
                suffix=(self.n_iter_, ), with_index=True)
            self.optimizer.set_iters(self.n_iter_)
            batch_index = 0
            for batch_data, index_data in batch_data_generator:
                batch_feat_inst = batch_data
                if not self.batch_generator.batch_masked:
                    index_data = None

                # Start gradient procedure
                LOGGER.debug(
                    "iter: {}, batch: {}, before compute gradient, data count: {}"
                    .format(self.n_iter_, batch_index,
                            batch_feat_inst.count()))

                optim_guest_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_feat_inst,
                    self.cipher_operator,
                    self.model_weights,
                    self.optimizer,
                    self.n_iter_,
                    batch_index,
                    masked_index=index_data)

                loss_norm = self.optimizer.loss_norm(self.model_weights)
                self.gradient_loss_operator.compute_loss(
                    batch_feat_inst,
                    self.model_weights,
                    self.n_iter_,
                    batch_index,
                    loss_norm,
                    batch_masked=self.batch_generator.batch_masked)

                self.model_weights = self.optimizer.update_model(
                    self.model_weights, optim_guest_gradient)
                batch_index += 1

            self.is_converged = self.converge_procedure.sync_converge_info(
                suffix=(self.n_iter_, ))
            LOGGER.info("iter: {},  is_converged: {}".format(
                self.n_iter_, self.is_converged))

            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1

            if self.stop_training:
                break

            if self.is_converged:
                break
        self.callback_list.on_train_end()

        self.set_summary(self.get_model_summary())
Esempio n. 17
0
    def fit(self, data_instances, validate_data=None):
        """
        Train poisson model of role guest

        Parameters
        ----------
        data_instances: Table of Instance, input data
        validate_data: optional validation data; in this method it is only
            forwarded to `prepare_fit` and `callback_list.on_train_begin`

        Notes
        -----
        Runs epochs until `n_iter_` reaches `max_iter`, or until
        `stop_training` or the synchronised convergence flag becomes true.
        The trained weights are stored in `self.model_weights` and the model
        summary is published through `set_summary`; nothing is returned.
        """

        LOGGER.info("Enter hetero_poisson_guest fit")
        # self._abnormal_detection(data_instances)
        # self.header = copy.deepcopy(self.get_header(data_instances))
        self.prepare_fit(data_instances, validate_data)
        self.callback_list.on_train_begin(data_instances, validate_data)

        # Weighted input only triggers a warning; no weight-specific handling
        # happens anywhere else in this method.
        if with_weight(data_instances):
            LOGGER.warning(
                "input data with weight. Poisson regression does not support weighted training."
            )

        self.exposure_index = self.get_exposure_index(self.header,
                                                      self.exposure_colname)
        exposure_index = self.exposure_index
        # An index greater than -1 is treated as "guest data carries the
        # exposure column"; that column is then dropped from the header.
        if exposure_index > -1:
            self.header.pop(exposure_index)
            LOGGER.info("Guest provides exposure value.")
        # Build a per-instance exposure table and rewrite the instances via the
        # HeteroPoissonBase helpers.
        # NOTE(review): presumably load_instance strips the exposure column
        # from the features - confirm against HeteroPoissonBase.
        exposure = data_instances.mapValues(
            lambda v: HeteroPoissonBase.load_exposure(v, exposure_index))
        data_instances = data_instances.mapValues(
            lambda v: HeteroPoissonBase.load_instance(v, exposure_index))

        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        LOGGER.info("Generate mini-batch from input data")
        self.batch_generator.initialize_batch_generator(
            data_instances, self.batch_size)

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(
            self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        # Fresh training initialises new weights; on warm start
        # `self.model_weights` is not re-initialised here and only the
        # callbacks are told the starting iteration.
        if not self.component_properties.is_warm_start:
            w = self.initializer.init_model(model_shape,
                                            init_params=self.init_param_obj)
            self.model_weights = LinearModelWeights(
                w,
                fit_intercept=self.fit_intercept,
                raise_overflow_error=False)
        else:
            self.callback_warm_start_init_iter(self.n_iter_)

        while self.n_iter_ < self.max_iter:
            self.callback_list.on_epoch_begin(self.n_iter_)
            LOGGER.info("iter:{}".format(self.n_iter_))
            # each iter will get the same batch_data_generator
            batch_data_generator = self.batch_generator.generate_batch_data()
            self.optimizer.set_iters(self.n_iter_)
            batch_index = 0
            for batch_data in batch_data_generator:
                # compute offset of this batch: safe_log of the exposure for
                # the entries produced by joining `exposure` with the batch
                batch_offset = exposure.join(
                    batch_data, lambda ei, d: HeteroPoissonBase.safe_log(ei))

                # Start gradient procedure
                optimized_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_data, self.cipher_operator, self.model_weights,
                    self.optimizer, self.n_iter_, batch_index, batch_offset)
                # LOGGER.debug("iteration:{} Guest's gradient: {}".format(self.n_iter_, optimized_gradient))
                # Loss is computed with the weights from before this batch's
                # update; the update is applied right after.
                loss_norm = self.optimizer.loss_norm(self.model_weights)
                self.gradient_loss_operator.compute_loss(
                    batch_data, self.model_weights, self.n_iter_, batch_index,
                    batch_offset, loss_norm)

                self.model_weights = self.optimizer.update_model(
                    self.model_weights, optimized_gradient)

                batch_index += 1

            # Convergence is decided once per epoch and obtained through the
            # converge procedure, keyed by the iteration number.
            self.is_converged = self.converge_procedure.sync_converge_info(
                suffix=(self.n_iter_, ))
            LOGGER.info("iter: {},  is_converged: {}".format(
                self.n_iter_, self.is_converged))

            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1

            # `stop_training` is checked after the epoch-end callbacks have run
            if self.stop_training:
                break

            if self.is_converged:
                break
        self.callback_list.on_train_end()
        self.set_summary(self.get_model_summary())
Esempio n. 18
0
    def fit(self, data):
        """
        Run the configured intersection on `data`.

        Parameters
        ----------
        data: input table to intersect

        Returns
        -------
        - the result of `intersect_online_process` when the component was
          given caches;
        - `data` unchanged when only a cache is generated (`run_cache`) or
          when only the cardinality is computed (`cardinality_only`);
        - otherwise the intersection result table (`self.intersect_ids`,
          optionally with values restored from `data`, or the left-join
          result when `join_method` is LEFT_JOIN).

        Raises
        ------
        ValueError
            If the match-id process is used with an unsupported
            `sample_id_generator` / `join_role` combination.
        """
        if self.component_properties.caches:
            return self.intersect_online_process(data, self.component_properties.caches)
        self.init_intersect_method()

        # Probe the table once; the result is reused below when deciding
        # whether the match-id processor should use sample ids.
        has_inst_id = data_overview.check_with_inst_id(data)
        if has_inst_id:
            self.use_match_id_process = True
            LOGGER.info("use match_id_process")

        # Table actually fed to the intersection: the recovered match-id table
        # when the match-id process is active, the raw input otherwise.
        match_data = None
        source_data = data
        if self.use_match_id_process:
            if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
                raise ValueError("While multi-host, sample_id_generator should be guest.")
            if self.model_param.intersect_method == consts.RAW:
                if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                    raise ValueError("When using raw intersect with match id process, "
                                     "'join_role' should be same role as 'sample_id_generator'")
            elif not self.model_param.sync_intersect_ids:
                # Without id sync only the guest can generate sample ids, so
                # the setting is forced rather than rejected.
                if self.model_param.sample_id_generator != consts.GUEST:
                    self.model_param.sample_id_generator = consts.GUEST
                    LOGGER.warning("when not sync_intersect_ids with match id process, "
                                   "sample_id_generator is set to Guest")

            self.proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
            self.proc_obj.new_sample_id = self.model_param.new_sample_id
            if has_inst_id or self.model_param.with_sample_id:
                self.proc_obj.use_sample_id()
            match_data = self.proc_obj.recover(data=data)
            source_data = match_data

        if self.intersection_obj.run_cache:
            # Cache-only run: publish the cache meta and hand the input back.
            self.cache_output = self.intersection_obj.generate_cache(source_data)
            intersect_meta = self.intersection_obj.get_intersect_method_meta()
            self.callback_cache_meta(intersect_meta)
            return data

        if self.intersection_obj.cardinality_only:
            self.intersection_obj.run_cardinality(source_data)
        else:
            intersect_data = source_data
            if self.model_param.run_preprocess:
                intersect_data = self.run_preprocess(source_data)
            self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)

        if self.intersection_obj.cardinality_only:
            # Statistics are only filled in when the intersection object
            # reports a count.
            if self.intersection_obj.intersect_num is not None:
                data_count = data.count()
                self.intersect_num = self.intersection_obj.intersect_num
                self.intersect_rate = self.intersect_num / data_count
                self.unmatched_num = data_count - self.intersect_num
                self.unmatched_rate = 1 - self.intersect_rate
            self.set_summary(self.get_model_summary())
            self.callback()
            return data

        if self.use_match_id_process:
            if self.model_param.sync_intersect_ids:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
            else:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids,
                                                          match_data=match_data,
                                                          owner_only=True)
            if self.model_param.only_output_key and self.intersect_ids:
                # Keep only the instance id and carry over the id column names.
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
                                             "sid_name": data.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            data_count = data.count()
            self.intersect_num = self.intersect_ids.count()
            self.intersect_rate = self.intersect_num / data_count
            self.unmatched_num = data_count - self.intersect_num
            self.unmatched_rate = 1 - self.intersect_rate

        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process and not self.intersection_obj.only_output_key and result_data:
            result_data = self.intersection_obj.get_value_from_data(result_data, data)
            LOGGER.debug("not only_output_key, restore value called")

        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema

        return result_data
Esempio n. 19
0
    def check(self):
        """
        Validate the secure-boost parameters and migrate deprecated ones.

        Besides the type/range checks delegated to the inherited `check_*`
        helpers, this method
        - copies `work_mode` into `boosting_strategy` when `work_mode` is set,
        - enables the "PerformanceEvaluate" callback when a deprecated
          validation parameter was set, and
        - copies deprecated validation parameters into `callback_param`.

        Returns
        -------
        bool
            True when every check passes.

        Raises
        ------
        ValueError
            On an invalid flag type, unsupported multi-classification mode or
            boosting strategy, a deprecated parameter set together with
            `callback_param`, or `top_rate + other_rate >= 1`.
        """

        super(HeteroSecureBoostParam, self).check()
        self.tree_param.check()
        if not isinstance(self.use_missing, bool):
            raise ValueError('use missing should be bool type')
        if not isinstance(self.zero_as_missing, bool):
            raise ValueError('zero as missing should be bool type')
        self.check_boolean(self.complete_secure, 'complete_secure')
        self.check_boolean(self.run_goss, 'run goss')
        self.check_decimal_float(self.top_rate, 'top rate')
        self.check_decimal_float(self.other_rate, 'other rate')
        self.check_positive_number(self.other_rate, 'other_rate')
        self.check_positive_number(self.top_rate, 'top_rate')
        self.check_boolean(self.new_ver, 'code version switcher')
        self.check_boolean(self.cipher_compress, 'cipher compress')
        self.check_boolean(self.EINI_inference, 'eini inference')
        self.check_boolean(self.EINI_random_mask, 'eini random mask')
        self.check_boolean(self.EINI_complexity_check, 'eini complexity check')

        if self.EINI_inference and self.EINI_random_mask:
            LOGGER.warning(
                'To protect the inference decision path, notice that current setting will multiply'
                ' predict result by a random number, hence SecureBoost will return confused predict scores'
                ' that is not the same as the original predict scores')

        # `work_mode`, when given, overrides `boosting_strategy`. Resolve it
        # first so the EINI warning below also fires when mix-tree mode was
        # selected through `boosting_strategy` directly.
        if self.work_mode is not None:
            self.boosting_strategy = self.work_mode

        if self.boosting_strategy == consts.MIX_TREE and self.EINI_inference:
            LOGGER.warning(
                'Mix tree mode does not support EINI, use default predict setting'
            )

        if self.multi_mode not in [consts.SINGLE_OUTPUT, consts.MULTI_OUTPUT]:
            raise ValueError('unsupported multi-classification mode')
        if self.multi_mode == consts.MULTI_OUTPUT:
            if self.boosting_strategy != consts.STD_TREE:
                raise ValueError(
                    'MO trees only works when boosting strategy is std tree')
            if not self.cipher_compress:
                raise ValueError(
                    'Mo trees only works when cipher compress is enabled')

        if self.boosting_strategy not in [
                consts.STD_TREE, consts.LAYERED_TREE, consts.MIX_TREE
        ]:
            raise ValueError('unknown sbt boosting strategy {}'.format(
                self.boosting_strategy))

        # The first deprecated validation parameter found decides the outcome:
        # it either conflicts with an explicit callback_param or switches on
        # the evaluation callback.
        for p in [
                "early_stopping_rounds", "validation_freqs", "metrics",
                "use_first_metric_only"
        ]:
            if self._deprecated_params_set.get(p):
                if "callback_param" in self.get_user_feeded():
                    raise ValueError(
                        f"{p} and callback param should not be set simultaneously, "
                        f"{self._deprecated_params_set}, {self.get_user_feeded()}"
                    )
                self.callback_param.callbacks = ["PerformanceEvaluate"]
                break

        descr = "boosting_param's"

        # Copy each deprecated parameter onto the attribute of the same name
        # in callback_param when the deprecation helper reports it.
        for name in ("validation_freqs", "early_stopping_rounds", "metrics",
                     "use_first_metric_only"):
            if self._warn_to_deprecate_param(name, descr,
                                             f"callback_param's '{name}'"):
                setattr(self.callback_param, name, getattr(self, name))

        if self.top_rate + self.other_rate >= 1:
            raise ValueError(
                'sum of top rate and other rate should be smaller than 1')

        return True
Esempio n. 20
0
    def __init__(self,
                 tree_param: DecisionTreeParam,
                 data_bin=None,
                 bin_split_points: np.array = None,
                 bin_sparse_point=None,
                 g_h=None,
                 valid_feature: dict = None,
                 epoch_idx: int = None,
                 role: str = None,
                 tree_idx: int = None,
                 flow_id: int = None,
                 mode='train'):
        """
        Set up a homo decision tree client.

        Parameters
        ----------
        tree_param: decision tree parameter object
        data_bin: binned data instances
        bin_split_points: data split points
        bin_sparse_point: sparse data point
        g_h: computed g val and h val of instances
        valid_feature: dict points out valid features {valid:true,invalid:false}
        epoch_idx: current epoch index
        role: host or guest; only stored when mode is 'train'
        tree_idx: index of this tree
        flow_id: flow id; only applied when mode is 'train'
        mode: train / predict
        """

        super(HomoDecisionTreeClient, self).__init__(tree_param)
        self.splitter = Splitter(
            self.criterion_method,
            self.criterion_params,
            self.min_impurity_split,
            self.min_sample_split,
            self.min_leaf_node,
        )

        # inputs handed over by the caller
        self.data_bin = data_bin
        self.g_h = g_h
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_point
        self.epoch_idx = epoch_idx
        self.tree_idx = tree_idx

        # check max_split_nodes: a non-zero odd value is bumped to the next
        # even number
        is_odd_limit = self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1
        if is_odd_limit:
            self.max_split_nodes += 1
            LOGGER.warning(
                'an even max_split_nodes value is suggested '
                'when using histogram-subtraction, max_split_nodes reset to {}'
                .format(self.max_split_nodes))

        self.transfer_inst = HomoDecisionTreeTransferVariable()

        # initial tree-building state
        self.valid_features = valid_feature
        self.tree_node = []  # start from root node
        self.tree_node_num = 0
        self.cur_layer_node = []
        self.runtime_idx = 0
        self.sitename = consts.GUEST
        self.feature_importance = {}

        # secure aggregator, class SecureBoostClientAggregator
        if mode == 'train':
            self.role = role
            self.set_flowid(flow_id)
            self.aggregator = DecisionTreeClientAggregator(verbose=False)
        elif mode == 'predict':
            # prediction needs neither a role nor an aggregator
            self.role = None
            self.aggregator = None
Esempio n. 21
0
    def fit(self, data_instances, validate_data=None):
        """
        Train linR model of role guest

        Parameters
        ----------
        data_instances: Table of Instance, input data
        validate_data: optional validation data; in this method it is only
            forwarded to `callback_list.on_train_begin`

        Notes
        -----
        Runs epochs until `n_iter_` reaches `max_iter`, or until
        `stop_training` or the synchronised convergence flag becomes true.
        The trained weights are stored in `self.model_weights` and the model
        summary is published through `set_summary`; nothing is returned.
        """

        LOGGER.info("Enter hetero_linR_guest fit")
        self._abnormal_detection(data_instances)
        self.header = self.get_header(data_instances)
        self.callback_list.on_train_begin(data_instances, validate_data)
        # self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)

        self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

        # Weighted data and async mode are mutually exclusive here: weighted
        # input scales the weights and switches the loss operator to weighted
        # mode; otherwise async is enabled only with exactly one host.
        use_async = False
        if with_weight(data_instances):
            if self.model_param.early_stop == "diff":
                LOGGER.warning("input data with weight, please use 'weight_diff' for 'early_stop'.")
            data_instances = scale_sample_weight(data_instances)
            self.gradient_loss_operator.set_use_sample_weight()
            LOGGER.debug(f"instance weight scaled; use weighted gradient loss operator")
            # LOGGER.debug(f"data_instances after scale: {[v[1].weight for v in list(data_instances.collect())]}")
        elif len(self.component_properties.host_party_idlist) == 1:
            LOGGER.debug(f"set_use_async")
            self.gradient_loss_operator.set_use_async()
            use_async = True
        # Send the chosen mode through the transfer variable.
        self.transfer_variable.use_async.remote(use_async)

        LOGGER.info("Generate mini-batch from input data")
        self.batch_generator.initialize_batch_generator(data_instances, self.batch_size)
        self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)

        # One EncryptModeCalculator per batch, all built from the same cipher
        # operator and encrypt-mode parameters.
        self.encrypted_calculator = [EncryptModeCalculator(self.cipher_operator,
                                                           self.encrypted_mode_calculator_param.mode,
                                                           self.encrypted_mode_calculator_param.re_encrypted_rate) for _
                                     in range(self.batch_generator.batch_nums)]

        LOGGER.info("Start initialize model.")
        LOGGER.info("fit_intercept:{}".format(self.init_param_obj.fit_intercept))
        model_shape = self.get_features_shape(data_instances)
        # Fresh training initialises new weights; on warm start
        # `self.model_weights` is not re-initialised here and only the
        # callbacks are told the starting iteration.
        if not self.component_properties.is_warm_start:
            w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
            self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept, raise_overflow_error=False)
        else:
            self.callback_warm_start_init_iter(self.n_iter_)

        while self.n_iter_ < self.max_iter:
            self.callback_list.on_epoch_begin(self.n_iter_)
            LOGGER.info("iter:{}".format(self.n_iter_))
            # each iter will get the same batch_data_generator
            batch_data_generator = self.batch_generator.generate_batch_data()
            self.optimizer.set_iters(self.n_iter_)
            batch_index = 0
            for batch_data in batch_data_generator:
                # Start gradient procedure
                optim_guest_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                    batch_data,
                    self.encrypted_calculator,
                    self.model_weights,
                    self.optimizer,
                    self.n_iter_,
                    batch_index
                )

                # Loss is computed before this batch's weight update is
                # applied.
                loss_norm = self.optimizer.loss_norm(self.model_weights)
                self.gradient_loss_operator.compute_loss(batch_data, self.n_iter_, batch_index, loss_norm)

                self.model_weights = self.optimizer.update_model(self.model_weights, optim_guest_gradient)
                batch_index += 1

            # Convergence is decided once per epoch and obtained through the
            # converge procedure, keyed by the iteration number.
            self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
            LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))

            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1
            # `stop_training` is checked after the epoch-end callbacks have run
            if self.stop_training:
                break

            if self.is_converged:
                break
        self.callback_list.on_train_end()

        self.set_summary(self.get_model_summary())
Esempio n. 22
0
    def fit(self, data_instances):
        """
        Fit all configured filters on `data_instances` and return the
        transformed table.

        When no column is selected, one column is picked at random so the
        filters can still be fitted; the host reports this situation to the
        guest, and the guest logs a warning per affected host. Filters that
        carry a list of metrics are run once per metric. A summary of the
        remaining columns is recorded under the key "all".
        """
        LOGGER.info("Start Hetero Selection Fit and transform.")
        self._abnormal_detection(data_instances)
        self._init_select_params(data_instances)

        original_col_nums = len(
            self.curt_select_properties.last_left_col_names)

        empty_cols = len(self.curt_select_properties.select_col_indexes) == 0
        if empty_cols:
            LOGGER.warning(
                "None of columns has been set to select, "
                "will randomly select one column to participate in fitting filter(s). "
                "All columns will be kept, "
                "but be aware that this may lead to unexpected behavior.")
            header = data_instances.schema.get("header")
            select_idx = random.choice(range(len(header)))
            self.curt_select_properties.select_col_indexes = [select_idx]
            self.curt_select_properties.select_col_names = [header[select_idx]]

        suffix = self.filter_methods
        if self.role == consts.HOST:
            # host side: tell the guest whether the selection was empty
            self.transfer_variable.host_empty_cols.remote(empty_cols,
                                                          role=consts.GUEST,
                                                          idx=0,
                                                          suffix=suffix)
        else:
            # guest side: collect the flags of all hosts and warn per host
            host_empty_cols_list = self.transfer_variable.host_empty_cols.get(
                idx=-1, suffix=suffix)
            host_list = self.component_properties.host_party_idlist
            for idx, res in enumerate(host_empty_cols_list):
                if not res:
                    continue
                LOGGER.warning(
                    f"Host {host_list[idx]}'s select columns are empty;"
                    f"host {host_list[idx]} will randomly select one "
                    f"column to participate in fitting filter(s). "
                    f"All columns from this host will be kept, "
                    f"but be aware that this may lead to unexpected behavior."
                )

        # Filter methods that are run once per metric, paired with the
        # model_param attribute that holds their metric list.
        metric_param_attrs = (
            (consts.STATISTIC_FILTER, "statistic_param"),
            (consts.IV_FILTER, "iv_param"),
            (consts.PSI_FILTER, "psi_param"),
            (consts.HETERO_SBT_FILTER, "sbt_param"),
            (consts.HOMO_SBT_FILTER, "sbt_param"),
            (consts.HETERO_FAST_SBT_FILTER, "sbt_param"),
            (consts.VIF_FILTER, "vif_param"),
        )
        multi_metric_methods = [name for name, _ in metric_param_attrs]

        for filter_idx, method in enumerate(self.filter_methods):
            if method not in multi_metric_methods:
                # single-shot filter
                self._filter(data_instances, method, suffix=str(filter_idx))
                continue

            param_attr = next(
                (attr for name, attr in metric_param_attrs if method == name),
                None)
            if param_attr is None:
                raise ValueError(f"method: {method} is not supported")
            metrics = getattr(self.model_param, param_attr).metrics

            for idx, _ in enumerate(metrics):
                self._filter(data_instances,
                             method,
                             suffix=(str(filter_idx), str(idx)),
                             idx=idx)

        left_col_names = self.curt_select_properties.last_left_col_names
        summary = {
            "last_col_nums": original_col_nums,
            "left_col_nums": len(left_col_names),
            "left_col_names": left_col_names
        }
        self.add_summary("all", summary)

        new_data = self._transfer_data(data_instances)
        LOGGER.info("Finish Hetero Selection Fit and transform.")
        return new_data
Esempio n. 23
0
 def _warn_deprecated_param(self, param_name, descr):
     if self._deprecated_params_set.get(param_name):
         LOGGER.warning(
             f"{descr} {param_name} is deprecated and ignored in this version."
         )
Esempio n. 24
0
File: union.py Progetto: zpskt/FATE
    def fit(self, data):
        """
        Union all tables in `data` into one table and report row counts.

        Parameters
        ----------
        data: dict mapping a table name to a table

        Returns
        -------
        The combined table, carrying the schema of the first table seen, or
        None when `data` contains no table at all.

        Raises
        ------
        ValueError
            If `data` is not a dict, or if DataInstance and non-DataInstance
            tables are mixed.

        Notes
        -----
        Per-table counts and the total count are recorded both as metrics and
        in the model summary. Empty tables are counted and skipped; an empty
        table is only kept as the combined table while nothing else has been
        seen.
        """
        LOGGER.debug(f"fit receives data is {data}")
        if not isinstance(data, dict):
            raise ValueError(
                "Union module must receive more than one table as input.")
        empty_count = 0
        combined_table = None
        combined_schema = None
        metrics = []
        # Whether a non-empty table has been processed yet. `combined_table`
        # cannot be used for this: a leading empty table already occupies it,
        # which previously left `self.is_data_instance` stale for the first
        # non-empty table and could trigger a spurious type-mismatch error.
        seen_non_empty = False

        for (key, local_table) in data.items():
            LOGGER.debug("table to combine name: {}".format(key))
            num_data = local_table.count()
            LOGGER.debug("table count: {}".format(num_data))
            metrics.append(Metric(key, num_data))
            self.add_summary(key, num_data)

            if num_data == 0:
                LOGGER.warning("Table {} is empty.".format(key))
                if combined_table is None:
                    combined_table = local_table
                    combined_schema = local_table.schema
                empty_count += 1
                continue

            local_is_data_instance = self.check_is_data_instance(local_table)
            if not seen_non_empty:
                # the first non-empty table decides the expected content type
                self.is_data_instance = local_is_data_instance
                seen_non_empty = True
            LOGGER.debug(f"self.is_data_instance is {self.is_data_instance}, "
                         f"local_is_data_instance is {local_is_data_instance}")
            if self.is_data_instance != local_is_data_instance:
                raise ValueError(
                    "Cannot combine DataInstance and non-DataInstance object. Union aborted."
                )

            if self.is_data_instance:
                self.is_empty_feature = data_overview.is_empty_feature(
                    local_table)
                if self.is_empty_feature:
                    LOGGER.warning("Table {} has empty feature.".format(key))
                else:
                    self.check_schema_content(local_table.schema)

            if combined_table is None:
                # first table to combine
                combined_table = local_table
                combined_schema = local_table.schema
            else:
                self.check_id(local_table, combined_table)
                self.check_label_name(local_table, combined_table)
                self.check_header(local_table, combined_table)
                if self.keep_duplicate:
                    # ids present in both tables are collected and stored on
                    # self, together with the current table name, before the
                    # incoming table is passed through `_renew_id`
                    repeated_ids = combined_table.join(local_table,
                                                       lambda v1, v2: 1)
                    self.repeated_ids = [v[0] for v in repeated_ids.collect()]
                    self.key = key
                    local_table = local_table.flatMap(self._renew_id)

                combined_table = combined_table.union(local_table,
                                                      self._keep_first)

                combined_table.schema = combined_schema

            # only check feature length if not empty
            if self.is_data_instance and not self.is_empty_feature:
                self.feature_count = len(combined_schema.get("header"))
                LOGGER.debug("feature count: {}".format(self.feature_count))
                # NOTE(review): the mapped table is discarded, so the check
                # only has an effect if mapValues evaluates eagerly - confirm
                # for the computing backend in use.
                combined_table.mapValues(self.check_feature_length)

        if combined_table is None:
            num_data = 0
            LOGGER.warning(
                "All tables provided are empty or have empty features.")
        else:
            num_data = combined_table.count()
        metrics.append(Metric("Total", num_data))
        self.add_summary("Total", num_data)

        self.callback_metric(metric_name=self.metric_name,
                             metric_namespace=self.metric_namespace,
                             metric_data=metrics)
        self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                     metric_name=self.metric_name,
                                     metric_meta=MetricMeta(
                                         name=self.metric_name,
                                         metric_type=self.metric_type))

        LOGGER.info(
            "Union operation finished. Total {} empty tables encountered.".
            format(empty_count))

        return combined_table
Esempio n. 25
0
    def __save_pr_curve(self, precision_and_recall, data_name):
        """
        Save the precision curve and the recall curve of one data set.

        Parameters
        ----------
        precision_and_recall: dict
            Results stored under ``consts.PRECISION`` and ``consts.RECALL``.
            Each result is read as ``res[0]`` (used as metric namespace) and
            ``res[1]``, which holds scores, cuts and, optionally, thresholds.
        data_name: str
            Prefix of both metric names; also used as curve name and pair type.

        Returns
        -------
        None
            Nothing is saved when the precision and recall namespaces differ.
        """

        precision_res = precision_and_recall[consts.PRECISION]
        recall_res = precision_and_recall[consts.RECALL]

        if precision_res[0] != recall_res[0]:
            LOGGER.warning(
                "precision mode:{} is not equal to recall mode:{}".format(
                    precision_res[0], recall_res[0]))
            return

        metric_namespace = precision_res[0]
        metric_name_precision = '_'.join([data_name, "precision"])
        metric_name_recall = '_'.join([data_name, "recall"])

        pos_precision_score = precision_res[1][0]
        precision_cuts = precision_res[1][1]
        # thresholds are optional: a result may only carry (scores, cuts)
        precision_thresholds = precision_res[1][2] if len(precision_res[1]) >= 3 else None

        pos_recall_score = recall_res[1][0]
        recall_cuts = recall_res[1][1]
        recall_thresholds = recall_res[1][2] if len(recall_res[1]) >= 3 else None

        precision_curve_name = data_name
        recall_curve_name = data_name

        if self.eval_type == consts.BINARY:
            pos_precision_score = [score[1] for score in pos_precision_score]
            pos_recall_score = [score[1] for score in pos_recall_score]

            pos_recall_score, pos_precision_score, idx_list = self.__filt_override_unit_ordinate_coordinate(
                pos_recall_score, pos_precision_score)

            precision_cuts = [precision_cuts[idx] for idx in idx_list]
            recall_cuts = [recall_cuts[idx] for idx in idx_list]

            # The index of the last precision threshold is dropped before the
            # thresholds are selected. Missing thresholds (None) and an empty
            # index list are tolerated instead of raising TypeError/IndexError.
            threshold_idx = idx_list
            if precision_thresholds is not None and len(idx_list) > 0 \
                    and idx_list[-1] == len(precision_thresholds) - 1:
                threshold_idx = idx_list[:-1]
            if precision_thresholds is not None:
                precision_thresholds = [
                    precision_thresholds[idx] for idx in threshold_idx
                ]
            if recall_thresholds is not None:
                recall_thresholds = [
                    recall_thresholds[idx] for idx in threshold_idx
                ]

        elif self.eval_type == consts.MULTY:

            pos_recall_score, recall_cuts = self.__multi_class_label_padding(
                pos_recall_score, recall_cuts)
            pos_precision_score, precision_cuts = self.__multi_class_label_padding(
                pos_precision_score, precision_cuts)

        self.__save_curve_data(precision_cuts, pos_precision_score,
                               metric_name_precision, metric_namespace)
        self.__save_curve_meta(
            metric_name_precision,
            metric_namespace,
            "_".join([consts.PRECISION.upper(),
                      self.eval_type.upper()]),
            unit_name="",
            ordinate_name="Precision",
            curve_name=precision_curve_name,
            pair_type=data_name,
            thresholds=precision_thresholds)

        self.__save_curve_data(recall_cuts, pos_recall_score,
                               metric_name_recall, metric_namespace)
        self.__save_curve_meta(
            metric_name_recall,
            metric_namespace,
            "_".join([consts.RECALL.upper(),
                      self.eval_type.upper()]),
            unit_name="",
            ordinate_name="Recall",
            curve_name=recall_curve_name,
            pair_type=data_name,
            thresholds=recall_thresholds)
Esempio n. 26
0
    def split_optimal_binning(bucket_list, optimal_param: OptimalBinningParam, sample_count):
        """
        Build bins by repeatedly splitting ``bucket_list`` at its best-KS position.

        Parameters
        ----------
        bucket_list: list
            Ordered initial buckets. Each bucket is read through
            ``event_count``, ``non_event_count``, ``total_count``,
            ``is_mixed``, ``idx`` and ``merge``.
        optimal_param: OptimalBinningParam
            Supplies ``min_bin_pct``, ``max_bin`` and ``mixture``.
        sample_count: int
            Sample total; ``ceil(min_bin_pct * sample_count)`` is the minimum
            number of items a bin should hold.

        Returns
        -------
        tuple
            ``(bucket_res, non_mixture_num, small_size_num)``: the merged
            buckets, how many of them are not mixed, and how many hold fewer
            items than the minimum.
        """
        min_item_num = math.ceil(optimal_param.min_bin_pct * sample_count)
        final_max_bin = optimal_param.max_bin

        def _compute_ks(start_idx, end_idx):
            # Cumulative event / non-event counts over buckets [start_idx, end_idx).
            acc_event = []
            acc_non_event = []
            curt_event_total = 0
            curt_non_event_total = 0
            for bucket in bucket_list[start_idx: end_idx]:
                curt_event_total += bucket.event_count
                curt_non_event_total += bucket.non_event_count
                acc_event.append(curt_event_total)
                acc_non_event.append(curt_non_event_total)

            # KS is undefined when one of the two classes is absent.
            if curt_event_total == 0 or curt_non_event_total == 0:
                return None, None, None

            acc_event_rate = [x / curt_event_total for x in acc_event]
            acc_non_event_rate = [x / curt_non_event_total for x in acc_non_event]
            ks_list = [math.fabs(eve - non_eve) for eve, non_eve in zip(acc_event_rate, acc_non_event_rate)]

            def _side_stats(index):
                # Counts on each side when splitting right after position ``index``.
                left_event = acc_event[index]
                right_event = curt_event_total - left_event
                left_non_event = acc_non_event[index]
                right_non_event = curt_non_event_total - left_non_event
                return {
                    'left_event': left_event,
                    'right_event': right_event,
                    'left_non_event': left_non_event,
                    'right_non_event': right_non_event,
                    'left_total': left_event + left_non_event,
                    'right_total': right_event + right_non_event,
                    'left_is_mixed': left_event > 0 and left_non_event > 0,
                    'right_is_mixed': right_event > 0 and right_non_event > 0
                }

            middle_index = len(ks_list) // 2
            max_ks = max(ks_list)
            best_index = middle_index if max_ks == 0 else ks_list.index(max_ks)
            res_dict = _side_stats(best_index)

            # Fall back to the middle position when the best-KS split leaves a
            # side that is too small.
            if res_dict['left_total'] < min_item_num or res_dict['right_total'] < min_item_num:
                best_index = middle_index
                res_dict = _side_stats(best_index)

            # Use the function's own argument for the offset; the previous code
            # read the enclosing loop variable ``start`` instead.
            return ks_list[best_index], start_idx + best_index, res_dict

        def _merge_buckets(start_idx, end_idx, bucket_idx):
            # Deep copy so the caller's buckets are not modified by merge().
            res_bucket = copy.deepcopy(bucket_list[start_idx])
            res_bucket.idx = bucket_idx

            for bucket in bucket_list[start_idx + 1: end_idx]:
                res_bucket = res_bucket.merge(bucket)
            return res_bucket

        res_split_index = []
        to_split_pair = [(0, len(bucket_list))]

        # iteratively split until no range is left or max_bin is reached
        while to_split_pair and len(res_split_index) < final_max_bin - 1:
            start, end = to_split_pair.pop(0)
            if start >= end:
                continue
            best_ks, best_index, res_dict = _compute_ks(start, end)
            if best_ks is None:
                continue
            split_at = best_index + 1
            if split_at >= end:
                # Right side would contain no bucket; accepting it would only
                # add a duplicate split index.
                continue
            if optimal_param.mixture:
                if not (res_dict['left_is_mixed'] and res_dict['right_is_mixed']):
                    continue
            if res_dict['left_total'] < min_item_num or res_dict['right_total'] < min_item_num:
                continue
            res_split_index.append(split_at)

            # The larger side is queued first.
            if res_dict['right_total'] > res_dict['left_total']:
                to_split_pair.append((split_at, end))
                to_split_pair.append((start, split_at))
            else:
                to_split_pair.append((start, split_at))
                to_split_pair.append((split_at, end))

        if not res_split_index:
            LOGGER.warning("Best ks optimal binning fail to split. Take middle split point instead")
            middle_split = len(bucket_list) // 2
            # A split at position 0 describes an empty left bin; _merge_buckets
            # would still copy bucket 0 into it and count that bucket twice.
            if middle_split > 0:
                res_split_index.append(middle_split)
        res_split_index = sorted(res_split_index)
        res_split_index.append(len(bucket_list))
        start = 0
        bucket_res = []
        non_mixture_num = 0
        small_size_num = 0
        for bucket_idx, end in enumerate(res_split_index):
            new_bucket = _merge_buckets(start, end, bucket_idx)
            bucket_res.append(new_bucket)
            if not new_bucket.is_mixed:
                non_mixture_num += 1
            if new_bucket.total_count < min_item_num:
                small_size_num += 1
            start = end
        return bucket_res, non_mixture_num, small_size_num
Esempio n. 27
0
    def fit(self):
        """
        Run the layer-by-layer split search for the current tree.

        The root node sums are aggregated and broadcast first. Then, for every
        non-leaf layer, local histograms are synchronised batch by batch,
        completed through histogram subtraction against the histograms stored
        from the previous layer, and the best splits of the layer are sent out.

        Returns
        -------
        None
        """
        LOGGER.info(
            'begin to fit h**o decision tree, epoch {}, tree idx {}'.format(
                self.epoch_idx, self.tree_idx))

        g_sum, h_sum = self.aggregator.aggregate_root_node_info(
            suffix=('root_node_sync1', self.epoch_idx))
        self.aggregator.broadcast_root_info(
            g_sum, h_sum, suffix=('root_node_sync2', self.epoch_idx))

        # an odd, non-zero batch size is bumped to the next even number
        if self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1:
            self.max_split_nodes += 1
            LOGGER.warning(
                'an even max_split_nodes value is suggested when using histogram-subtraction, '
                'max_split_nodes reset to {}'.format(self.max_split_nodes))

        # non-leaf node height + 1 layer leaf
        tree_height = self.max_depth + 1

        # the last layer only holds leaves, so it is never split
        for dep in range(tree_height - 1):
            LOGGER.debug('current dep is {}'.format(dep))

            # get cur layer node num:
            cur_layer_node_num = self.sync_node_sample_numbers(
                suffix=(dep, self.epoch_idx, self.tree_idx))
            LOGGER.debug(
                '{} nodes to split at this layer'.format(cur_layer_node_num))

            split_info = []
            layer_stored_hist = {}
            batch_starts = range(0, cur_layer_node_num, self.max_split_nodes)

            for batch_id, _ in enumerate(batch_starts):
                LOGGER.debug('cur batch id is {}'.format(batch_id))

                left_node_histogram = self.sync_local_histogram(
                    suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))
                all_histograms = self.histogram_subtraction(
                    left_node_histogram, self.stored_histograms)

                # store histogram
                layer_stored_hist.update(
                    {hist.hid: hist for hist in all_histograms})

                # FIXME stable parallel_partitions
                best_splits = self.federated_find_best_split(
                    all_histograms, parallel_partitions=10)
                split_info.extend(best_splits)

            self.stored_histograms = layer_stored_hist

            self.sync_best_splits(split_info, suffix=(dep, self.epoch_idx))
            LOGGER.debug('best_splits_sent')
Esempio n. 28
0
    def check(self):
        """
        Validate the intersect parameters and map deprecated ones to their replacements.

        Deprecated top-level params (``random_bit``, ``join_role``,
        ``with_encode``, ``encode_params``) are copied into ``rsa_params`` /
        ``raw_params`` when ``_warn_to_deprecate_param`` reports them. The
        nested param objects are then checked and the combinations of method,
        cardinality-only, preprocessing and cache modes are validated.

        Returns
        -------
        bool
            True when every check passes.

        Raises
        ------
        ValueError
            If a deprecated param is set together with its replacement, or if
            an unsupported combination of method / mode flags is configured.
        """
        descr = "intersect param's "

        self.intersect_method = self.check_and_change_lower(
            self.intersect_method, [consts.RSA, consts.RAW, consts.DH],
            f"{descr}intersect_method")

        if self._warn_to_deprecate_param("random_bit", descr,
                                         "rsa_params' 'random_bit'"):
            if "rsa_params.random_bit" in self.get_user_feeded():
                raise ValueError(
                    "random_bit and rsa_params.random_bit should not be set simultaneously"
                )
            self.rsa_params.random_bit = self.random_bit

        # the label names the attribute that is actually checked
        self.check_boolean(self.sync_intersect_ids,
                           f"{descr}sync_intersect_ids")

        if self._warn_to_deprecate_param("encode_param", "", ""):
            if "raw_params" in self.get_user_feeded():
                raise ValueError(
                    "encode_param and raw_params should not be set simultaneously"
                )
            else:
                # NOTE(review): assigning callback_param inside the intersect
                # check looks unrelated to this parameter set - confirm that
                # this attribute exists here and that the assignment is intended.
                self.callback_param.callbacks = ["PerformanceEvaluate"]

        if self._warn_to_deprecate_param("join_role", descr,
                                         "raw_params' 'join_role'"):
            if "raw_params.join_role" in self.get_user_feeded():
                raise ValueError(
                    "join_role and raw_params.join_role should not be set simultaneously"
                )
            self.raw_params.join_role = self.join_role

        self.check_boolean(self.only_output_key, f"{descr}only_output_key")

        self.join_method = self.check_and_change_lower(
            self.join_method, [consts.INNER_JOIN, consts.LEFT_JOIN],
            f"{descr}join_method")
        self.check_boolean(self.new_sample_id, f"{descr}new_sample_id")
        self.sample_id_generator = self.check_and_change_lower(
            self.sample_id_generator, [consts.GUEST, consts.HOST],
            f"{descr}sample_id_generator")

        if self.join_method == consts.LEFT_JOIN:
            if not self.sync_intersect_ids:
                raise ValueError(
                    "Cannot perform left join without sync intersect ids")

        self.check_boolean(self.run_cache, f"{descr}run_cache")

        # `or` short-circuits: the second deprecation check only runs when the
        # first one returns a falsy value.
        if (self._warn_to_deprecate_param("encode_params", descr, "raw_params")
                or self._warn_to_deprecate_param("with_encode", descr, "raw_params' 'use_hash'")):
            # self.encode_params.check()
            if "with_encode" in self.get_user_feeded(
            ) and "raw_params.use_hash" in self.get_user_feeded():
                # the message names the two settings that actually conflict
                raise ValueError(
                    "'with_encode' and 'raw_params.use_hash' should not be set simultaneously."
                )
            if "raw_params" in self.get_user_feeded(
            ) and "encode_params" in self.get_user_feeded():
                raise ValueError(
                    "'raw_params' and 'encode_params' should not be set simultaneously."
                )
            LOGGER.warning(
                "Param values from 'encode_params' will override 'raw_params' settings."
            )
            self.raw_params.use_hash = self.with_encode
            self.raw_params.hash_method = self.encode_params.encode_method
            self.raw_params.salt = self.encode_params.salt
            self.raw_params.base64 = self.encode_params.base64

        self.raw_params.check()
        self.rsa_params.check()
        self.dh_params.check()
        # self.intersect_cache_param.check()
        self.check_boolean(self.cardinality_only, f"{descr}cardinality_only")
        self.check_boolean(self.sync_cardinality, f"{descr}sync_cardinality")
        self.check_boolean(self.run_preprocess, f"{descr}run_preprocess")
        self.intersect_preprocess_params.check()
        if self.cardinality_only:
            if self.intersect_method not in [consts.RSA]:
                raise ValueError("cardinality-only mode only support rsa.")
            # the method is known to be RSA past the check above
            if self.rsa_params.split_calculation:
                raise ValueError(
                    "cardinality-only mode only supports unified calculation."
                )
        if self.run_preprocess:
            if self.intersect_preprocess_params.false_positive_rate < 0.01:
                raise ValueError(
                    "for preprocessing ids, false_positive_rate must be no less than 0.01"
                )
            if self.cardinality_only:
                raise ValueError(
                    "cardinality_only mode cannot run preprocessing.")
        if self.run_cache:
            if self.intersect_method not in [consts.RSA, consts.DH]:
                raise ValueError("Only rsa or dh method supports cache.")
            if self.intersect_method == consts.RSA and self.rsa_params.split_calculation:
                raise ValueError(
                    "RSA split_calculation does not support cache.")
            if self.cardinality_only:
                raise ValueError(
                    "cache is not available for cardinality_only mode.")
            if self.run_preprocess:
                raise ValueError("Preprocessing does not support cache.")

        deprecated_param_list = [
            "repeated_id_process", "repeated_id_owner",
            "intersect_cache_param", "allow_info_share", "info_owner",
            "with_sample_id"
        ]
        for param in deprecated_param_list:
            self._warn_deprecated_param(param, descr)

        LOGGER.debug("Finish intersect parameter check!")
        return True
Esempio n. 29
0
    def fit_single_model(self, data_instances, validate_data=None):
        """
        Train one linear model inside an SPDZ session.

        Weights are initialised (or taken from the existing model on warm
        start), every mini-batch is encoded once as a fixed-point tensor, and
        epochs of forward / backward / loss / update are run until
        ``max_iter`` is reached, ``stop_training`` is set or the convergence
        check succeeds. When ``reveal_every_iter`` is false the weights stay
        split into a local and a remote share and are only reconstructed after
        the last epoch.

        Parameters
        ----------
        data_instances:
            Training table; it is counted, batched and mapped with ``mapValues``.
        validate_data:
            Only forwarded to ``callback_list.on_train_begin``.

        Raises
        ------
        ValueError
            If more than one host party is configured, or if
            ``converge_func_name`` is not "diff", "abs" or "weight_diff".
        """
        LOGGER.info(f"Start to train single {self.model_name}")
        if len(self.component_properties.host_party_idlist) > 1:
            raise ValueError(f"Hetero SSHE Model does not support multi-host training.")
        self.callback_list.on_train_begin(data_instances, validate_data)

        model_shape = self.get_features_shape(data_instances)
        instances_count = data_instances.count()

        # Fresh start: build new weights. Warm start: continue from the
        # weights already held in self.model_weights.
        if not self.component_properties.is_warm_start:
            w = self._init_weights(model_shape)
            self.model_weights = LinearModelWeights(l=w,
                                                    fit_intercept=self.model_param.init_param.fit_intercept)
            last_models = copy.deepcopy(self.model_weights)
        else:
            last_models = copy.deepcopy(self.model_weights)
            w = last_models.unboxed
            self.callback_warm_start_init_iter(self.n_iter_)

        # Sample weights are only inspected and scaled on the guest side.
        if self.role == consts.GUEST:
            if with_weight(data_instances):
                LOGGER.info(f"data with sample weight, use sample weight.")
                if self.model_param.early_stop == "diff":
                    LOGGER.warning("input data with weight, please use 'weight_diff' for 'early_stop'.")
                data_instances = scale_sample_weight(data_instances)
        self.batch_generator.initialize_batch_generator(data_instances, batch_size=self.batch_size)

        with SPDZ(
            "hetero_sshe",
            local_party=self.local_party,
            all_parties=self.parties,
            q_field=self.q_field,
            use_mix_rand=self.model_param.use_mix_rand,
        ) as spdz:
            spdz.set_flowid(self.flowid)
            self.secure_matrix_obj.set_flowid(self.flowid)
            # not sharing the model when reveal_every_iter
            if not self.reveal_every_iter:
                w_self, w_remote = self.share_model(w, suffix="init")
                # NOTE(review): last_w_* reference the same objects as w_*;
                # whether the later `-=` mutates them in place depends on the
                # share type - confirm.
                last_w_self, last_w_remote = w_self, w_remote
                LOGGER.debug(f"first_w_self shape: {w_self.shape}, w_remote_shape: {w_remote.shape}")
            batch_data_generator = self.batch_generator.generate_batch_data()

            # Per-batch inputs prepared once here and reused by every epoch below.
            encoded_batch_data = []
            batch_labels_list = []
            batch_weight_list = []

            for batch_data in batch_data_generator:
                if self.fit_intercept:
                    # append a constant 1.0 to each feature vector
                    batch_features = batch_data.mapValues(lambda x: np.hstack((x.features, 1.0)))
                else:
                    batch_features = batch_data.mapValues(lambda x: x.features)
                # labels and weights are only collected on the guest side
                if self.role == consts.GUEST:
                    batch_labels = batch_data.mapValues(lambda x: np.array([x.label], dtype=self.label_type))
                    batch_labels_list.append(batch_labels)
                    if self.weight:
                        batch_weight = batch_data.mapValues(lambda x: np.array([x.weight], dtype=float))
                        batch_weight_list.append(batch_weight)
                    else:
                        batch_weight_list.append(None)

                # batch size, used later to weight the batch loss
                self.batch_num.append(batch_data.count())

                encoded_batch_data.append(
                    fixedpoint_table.FixedPointTensor(self.fixedpoint_encoder.encode(batch_features),
                                                      q_field=self.fixedpoint_encoder.n,
                                                      endec=self.fixedpoint_encoder))

            while self.n_iter_ < self.max_iter:
                self.callback_list.on_epoch_begin(self.n_iter_)
                LOGGER.info(f"start to n_iter: {self.n_iter_}")

                loss_list = []

                self.optimizer.set_iters(self.n_iter_)
                if not self.reveal_every_iter:
                    self.self_optimizer.set_iters(self.n_iter_)
                    self.remote_optimizer.set_iters(self.n_iter_)

                for batch_idx, batch_data in enumerate(encoded_batch_data):
                    # (iteration, batch) pair that tags every call of this step
                    current_suffix = (str(self.n_iter_), str(batch_idx))
                    if self.role == consts.GUEST:
                        batch_labels = batch_labels_list[batch_idx]
                        batch_weight = batch_weight_list[batch_idx]
                    else:
                        batch_labels = None
                        batch_weight = None

                    # forward pass: plain model weights when revealing every
                    # iteration, otherwise the pair of weight shares
                    if self.reveal_every_iter:
                        y = self.forward(weights=self.model_weights,
                                         features=batch_data,
                                         labels=batch_labels,
                                         suffix=current_suffix,
                                         cipher=self.cipher,
                                         batch_weight=batch_weight)
                    else:
                        y = self.forward(weights=(w_self, w_remote),
                                         features=batch_data,
                                         labels=batch_labels,
                                         suffix=current_suffix,
                                         cipher=self.cipher,
                                         batch_weight=batch_weight)

                    # the guest subtracts the (optionally weighted) labels;
                    # other roles pass y on unchanged
                    if self.role == consts.GUEST:
                        if self.weight:
                            error = y - batch_labels.join(batch_weight, lambda y, b: y * b)
                        else:
                            error = y - batch_labels

                        self_g, remote_g = self.backward(error=error,
                                                         features=batch_data,
                                                         suffix=current_suffix,
                                                         cipher=self.cipher)
                    else:
                        self_g, remote_g = self.backward(error=y,
                                                         features=batch_data,
                                                         suffix=current_suffix,
                                                         cipher=self.cipher)

                    # loss computing;
                    suffix = ("loss",) + current_suffix
                    if self.reveal_every_iter:
                        batch_loss = self.compute_loss(weights=self.model_weights,
                                                       labels=batch_labels,
                                                       suffix=suffix,
                                                       cipher=self.cipher)
                    else:
                        batch_loss = self.compute_loss(weights=(w_self, w_remote),
                                                       labels=batch_labels,
                                                       suffix=suffix,
                                                       cipher=self.cipher)

                    # scale by the batch size; the epoch loss below divides the
                    # sum by instances_count
                    if batch_loss is not None:
                        batch_loss = batch_loss * self.batch_num[batch_idx]
                    loss_list.append(batch_loss)

                    if self.reveal_every_iter:
                        # LOGGER.debug(f"before reveal: self_g shape: {self_g.shape}, remote_g_shape: {remote_g},"
                        #              f"self_g: {self_g}")

                        new_g = self.reveal_models(self_g, remote_g, suffix=current_suffix)

                        # LOGGER.debug(f"after reveal: new_g shape: {new_g.shape}, new_g: {new_g}"
                        #              f"self.model_param.reveal_strategy: {self.model_param.reveal_strategy}")

                        if new_g is not None:
                            self.model_weights = self.optimizer.update_model(self.model_weights, new_g,
                                                                             has_applied=False)

                        else:
                            # no gradient was revealed here: weights are
                            # replaced by zeros of the gradient share's shape
                            self.model_weights = LinearModelWeights(
                                l=np.zeros(self_g.shape),
                                fit_intercept=self.model_param.init_param.fit_intercept)
                    else:
                        # add the L2 term (alpha * w) to each gradient share
                        if self.optimizer.penalty == consts.L2_PENALTY:
                            self_g = self_g + self.self_optimizer.alpha * w_self
                            remote_g = remote_g + self.remote_optimizer.alpha * w_remote

                        # LOGGER.debug(f"before optimizer: {self_g}, {remote_g}")

                        self_g = self.self_optimizer.apply_gradients(self_g)
                        remote_g = self.remote_optimizer.apply_gradients(remote_g)

                        # LOGGER.debug(f"after optimizer: {self_g}, {remote_g}")
                        w_self -= self_g
                        w_remote -= remote_g

                        LOGGER.debug(f"w_self shape: {w_self.shape}, w_remote_shape: {w_remote.shape}")

                # only the guest turns the batch losses into an epoch loss
                if self.role == consts.GUEST:
                    loss = np.sum(loss_list) / instances_count
                    self.loss_history.append(loss)
                    if self.need_call_back_loss:
                        self.callback_loss(self.n_iter_, loss)
                else:
                    loss = None

                # convergence check, by loss or by weight difference
                if self.converge_func_name in ["diff", "abs"]:
                    self.is_converged = self.check_converge_by_loss(loss, suffix=(str(self.n_iter_),))
                elif self.converge_func_name == "weight_diff":
                    if self.reveal_every_iter:
                        self.is_converged = self.check_converge_by_weights(
                            last_w=last_models.unboxed,
                            new_w=self.model_weights.unboxed,
                            suffix=(str(self.n_iter_),))
                        last_models = copy.deepcopy(self.model_weights)
                    else:
                        self.is_converged = self.check_converge_by_weights(
                            last_w=(last_w_self, last_w_remote),
                            new_w=(w_self, w_remote),
                            suffix=(str(self.n_iter_),))
                        last_w_self, last_w_remote = copy.deepcopy(w_self), copy.deepcopy(w_remote)
                else:
                    raise ValueError(f"Cannot recognize early_stop function: {self.converge_func_name}")

                LOGGER.info("iter: {},  is_converged: {}".format(self.n_iter_, self.is_converged))
                self.callback_list.on_epoch_end(self.n_iter_)
                self.n_iter_ += 1

                if self.stop_training:
                    break

                if self.is_converged:
                    break

            # Finally reconstruct
            if not self.reveal_every_iter:
                new_w = self.reveal_models(w_self, w_remote, suffix=("final",))
                if new_w is not None:
                    self.model_weights = LinearModelWeights(
                        l=new_w,
                        fit_intercept=self.model_param.init_param.fit_intercept)

        LOGGER.debug(f"loss_history: {self.loss_history}")
        self.set_summary(self.get_model_summary())