def add_summary(self, new_key, new_value):
    """
    Add key:value pair to model summary.

    If ``new_key`` is already present with a non-None value, a warning is logged
    and the old value is overwritten.

    Parameters
    ----------
    new_key: str
    new_value: object

    Returns
    -------
    None
    """
    original_value = self._summary.get(new_key, None)
    if original_value is not None:
        LOGGER.warning(
            f"{new_key} already exists in model summary. "
            f"Corresponding value {original_value} will be replaced by {new_value}"
        )
    self._summary[new_key] = new_value
    LOGGER.debug(f"{new_key}: {new_value} added to summary.")
def fit_buckets(self, bucket_table, sample_count):
    """
    Run optimal binning on every column and record the resulting split points.

    Merge-based binning (``merge_optimal_binning``) is used for the 'iv', 'gini'
    and 'chi_square' metrics; split-based binning (``split_optimal_binning``)
    is used for any other metric.

    Parameters
    ----------
    bucket_table: Table
        Keyed by column name; each value is handed to the chosen binning method.
        NOTE(review): value presumably is the column's initial bucket list — confirm.
    sample_count: int
        Forwarded to the binning method.

    Returns
    -------
    Table mapping column name to ``(bucket_list, non_mixture_num, small_size_num)``.
    """
    if self.optimal_param.metric_method in ['iv', 'gini', 'chi_square']:
        binning_func = self.merge_optimal_binning
    else:
        binning_func = self.split_optimal_binning
    optimal_binning_method = functools.partial(binning_func,
                                               optimal_param=self.optimal_param,
                                               sample_count=sample_count)

    result_bucket = bucket_table.mapValues(optimal_binning_method)
    for col_name, (bucket_list, non_mixture_num, small_size_num) in result_bucket.collect():
        split_points = np.unique([bucket.right_bound for bucket in bucket_list]).tolist()
        self.bin_results.put_col_split_points(col_name, split_points)
        self.__cal_single_col_result(col_name, bucket_list)

        if self.optimal_param.mixture and non_mixture_num > 0:
            LOGGER.warning("col: {}, non_mixture_num is: {}, cannot meet mixture condition".format(
                col_name, non_mixture_num
            ))
        if small_size_num > 0:
            LOGGER.warning("col: {}, small_size_num is: {}, cannot meet small size condition".format(
                col_name, small_size_num
            ))
        bin_num = len(bucket_list)
        if bin_num > self.optimal_param.max_bin:
            # report the actual number of bins, not small_size_num
            LOGGER.warning("col: {}, bin_num is: {}, cannot meet max-bin condition".format(
                col_name, bin_num
            ))
    return result_bucket
def __share_info(self, data):
    """
    Exchange the ``consts.SHARE_INFO_COL_NAME`` column of the intersected ids.

    The party configured as ``model_param.info_owner`` sends, for every id in
    ``self.intersect_ids``, the value of its share-info column. The other party
    receives that table and stores it as ``self.intersect_ids``.

    Parameters
    ----------
    data: Table
        Local input data; only read on the info-owner side.

    Returns
    -------
    Table: ``self.intersect_ids`` (replaced by the received table on the receiving side).

    Raises
    ------
    ValueError: if this party is the info owner but ``data`` has no header.
    """
    LOGGER.info("Start to share information with another role")
    if self.model_param.info_owner == consts.GUEST:
        info_share = self.transfer_variable.info_share_from_guest
    else:
        info_share = self.transfer_variable.info_share_from_host
    party_role = consts.GUEST if self.model_param.info_owner == consts.HOST else consts.HOST

    if self.role == self.model_param.info_owner:
        header = data.schema.get('header')
        if header is None:
            raise ValueError(
                "'allow_info_share' is true, and 'info_owner' is {}, but can not get header in data, information sharing not done".format(
                    self.model_param.info_owner))
        try:
            share_info_col_idx = header.index(consts.SHARE_INFO_COL_NAME)
            one_data = data.first()
            if isinstance(one_data[1], Instance):
                share_data = data.join(self.intersect_ids, lambda d, i: [d.features[share_info_col_idx]])
            else:
                share_data = data.join(self.intersect_ids, lambda d, i: [d[share_info_col_idx]])
            info_share.remote(share_data, role=party_role, idx=-1)
            LOGGER.info("Remote share information to {}".format(party_role))
        except Exception as e:
            # Best-effort fallback: still send a placeholder ('null' per id) so the
            # receiving side's get() has a table to receive.
            LOGGER.warning("Something unexpected:{}, share a empty information to {}".format(e, party_role))
            share_data = self.intersect_ids.mapValues(lambda v: ['null'])
            info_share.remote(share_data, role=party_role, idx=-1)
    else:
        self.intersect_ids = info_share.get(idx=0)
        self.intersect_ids.schema['header'] = [consts.SHARE_INFO_COL_NAME]
        # log the header, not the whole received table object
        LOGGER.info(
            "Get share information from {}, header:{}".format(self.model_param.info_owner,
                                                              self.intersect_ids.schema['header']))

    return self.intersect_ids
def fit(self, data):
    """
    Fit the configured scaler on the input data and return the scaled data.

    ``consts.MINMAXSCALE`` and ``consts.STANDARDSCALE`` create a fresh scaler.
    Any other method only logs a warning; ``self.scale_obj`` is then used as-is,
    so the input is returned unchanged unless a scaler was set elsewhere.

    Parameters
    ----------
    data: data_instance, input data

    Returns
    ----------
    data_instance: scaled data carrying the input schema, or the original input
    """
    LOGGER.info("Start scale data fit ...")
    method = self.model_param.method
    if method == consts.MINMAXSCALE:
        self.scale_obj = MinMaxScale(self.model_param)
    elif method == consts.STANDARDSCALE:
        self.scale_obj = StandardScale(self.model_param)
    else:
        LOGGER.warning("Scale method is {}, do nothing and return!".format(method))

    if not self.scale_obj:
        LOGGER.info("End fit data ...")
        return data

    scaled = self.scale_obj.fit(data)
    scaled.schema = data.schema
    scale_meta = MetricMeta(name="scale", metric_type="SCALE",
                            extra_metas={"method": self.model_param.method})
    self.callback_meta(metric_name="scale", metric_namespace="train", metric_meta=scale_meta)

    LOGGER.info("start to get model summary ...")
    self.set_summary(self.scale_obj.get_model_summary())
    LOGGER.info("Finish getting model summary.")

    LOGGER.info("End fit data ...")
    return scaled
def __generate_id_map(self, data) -> dict:
    """
    Map each repeated id value to the keys that share it.

    Only the repeated-id owner builds the map; other parties get ``{}``.
    The first feature of every row (``Instance.features[0]`` or ``value[0]``)
    is treated as the id and stringified. Only ids seen under two or more
    keys are kept.

    Returns
    -------
    dict: ``{str(id_value): [key, ...]}`` for ids with at least two keys.
    """
    if not self.repeated_id_owner:
        LOGGER.warning("Not a repeated id owner, will not generate id map")
        return {}

    sample_value = data.first()[1]
    if isinstance(sample_value, Instance):
        id_table = data.mapValues(lambda inst: inst.features[0])
    else:
        id_table = data.mapValues(lambda row: row[0])

    keys_by_id = defaultdict(list)
    for pair in id_table.collect():
        keys_by_id[str(pair[1])].append(pair[0])

    return {id_str: keys for id_str, keys in keys_by_id.items() if len(keys) >= 2}
def check(self):
    """
    Validate secure information retrieval params and normalize protocol names.

    Each protocol option is lower-cased via ``check_and_change_lower`` against its
    only allowed value. Deprecated ``key_size`` is copied into
    ``dh_params.key_length``. ``target_cols`` is coerced to a list of strings.
    """
    descr = "secure information retrieval param's "
    self.check_decimal_float(self.security_level, descr + "security_level")

    self.oblivious_transfer_protocol = self.check_and_change_lower(
        self.oblivious_transfer_protocol,
        [consts.OT_HAUCK.lower()],
        descr + "oblivious_transfer_protocol")
    self.commutative_encryption = self.check_and_change_lower(
        self.commutative_encryption,
        [consts.CE_PH.lower()],
        descr + "commutative_encryption")
    self.non_committing_encryption = self.check_and_change_lower(
        self.non_committing_encryption,
        [consts.AES.lower()],
        descr + "non_committing_encryption")

    # deprecated 'key_size' still wins over dh_params.key_length when provided
    if self._warn_to_deprecate_param("key_size", descr, "dh_param's key_length"):
        self.dh_params.key_length = self.key_size
    self.dh_params.check()

    if self._warn_to_deprecate_param("raw_retrieval", descr, "dh_param's security_level = 0"):
        self.check_boolean(self.raw_retrieval, descr)

    # a single column name is accepted and wrapped into a list
    if not isinstance(self.target_cols, list):
        self.target_cols = [self.target_cols]
    for col_name in self.target_cols:
        self.check_string(col_name, descr + "target_cols")

    if not self.target_cols:
        LOGGER.warning(f"Both 'target_cols' and 'target_indexes' are empty. Label will be retrieved.")
def transform(self, data_instances):
    """
    Attach sample weights to ``data_instances`` using the mode chosen during fit.

    In "sample weight name" mode, the weight column is located and removed from
    the output header. If the column cannot be found, a warning is logged and
    the input is returned unchanged. In any other mode,
    ``transform_weighted_instance`` is called with ``weight_loc=None``.

    Returns
    -------
    Table of weighted instances whose schema has ``sample_weight = "weight"``,
    or the original ``data_instances``.
    """
    LOGGER.info("Enter Sample Weight Transform")
    new_schema = copy.deepcopy(data_instances.schema)
    new_schema["sample_weight"] = "weight"

    weight_loc = None
    if self.weight_mode == "sample weight name":
        weight_loc = SampleWeight.get_weight_loc(data_instances, self.sample_weight_name)
        if weight_loc is None:
            LOGGER.warning(
                f"Cannot find weight column of given sample_weight_name '{self.sample_weight_name}'. "
                "Original input data returned")
            return data_instances
        new_schema["header"].pop(weight_loc)

    result_instances = self.transform_weighted_instance(data_instances, weight_loc)
    result_instances.schema = new_schema
    self.callback_info()

    has_negative = result_instances.mapPartitions(check_negative_sample_weight).reduce(
        lambda x, y: x or y)
    if has_negative:
        LOGGER.warning("Negative weight found in weighted instances.")
    return result_instances
def check(self):
    """
    Validate EncodeParam and normalize ``encode_method`` to lower case.

    Returns
    -------
    True when all checks pass.

    Raises
    ------
    ValueError: if ``salt`` is not a str or ``base64`` is not a bool (unsupported
    ``encode_method`` values are rejected by ``check_and_change_lower``).
    """
    if not isinstance(self.salt, str):
        raise ValueError(
            "encode param's salt {} not supported, should be str type".format(self.salt))

    descr = "encode param's "
    self.encode_method = self.check_and_change_lower(
        self.encode_method, [
            "none", consts.MD5, consts.SHA1, consts.SHA224, consts.SHA256,
            consts.SHA384, consts.SHA512, consts.SM3
        ], descr)

    if not isinstance(self.base64, bool):
        raise ValueError(
            "encode param's base64 {} not supported, should be bool type".format(self.base64))

    LOGGER.debug("Finish EncodeParam check!")
    LOGGER.warning(
        "'EncodeParam' will be replaced by 'RAWParam' in future release. "
        "Please do not rely on current param naming in application.")
    return True
def fit(self, data_instances):
    """
    Build weighted instances from a weight column or from class weights.

    Returns the input unchanged when neither ``sample_weight_name`` nor
    ``class_weight`` is configured. When both are set, the weight column wins
    and a warning is logged.

    Raises
    ------
    ValueError: if ``sample_weight_name`` is set but the column cannot be found.
    """
    nothing_configured = self.sample_weight_name is None and self.class_weight is None
    if nothing_configured:
        return data_instances

    self.header = data_overview.get_header(data_instances)

    if self.class_weight:
        self.weight_mode = "class weight"

    if self.sample_weight_name and self.class_weight:
        LOGGER.warning(
            f"Both 'sample_weight_name' and 'class_weight' provided. "
            f"Only weight from 'sample_weight_name' is used.")

    out_schema = copy.deepcopy(data_instances.schema)
    out_schema["sample_weight"] = "weight"

    weight_loc = None
    if self.sample_weight_name:
        # the weight column takes priority over class weights
        self.weight_mode = "sample weight name"
        weight_loc = SampleWeight.get_weight_loc(data_instances, self.sample_weight_name)
        if weight_loc is None:
            raise ValueError(
                f"Cannot find weight column of given sample_weight_name '{self.sample_weight_name}'."
            )
        out_schema["header"].pop(weight_loc)

    weighted = self.transform_weighted_instance(data_instances, weight_loc)
    weighted.schema = out_schema
    self.callback_info()

    found_negative = weighted.mapPartitions(check_negative_sample_weight).reduce(
        lambda x, y: x or y)
    if found_negative:
        LOGGER.warning(f"Negative weight found in weighted instances.")
    return weighted
def _evaluate_clustering_metrics(self, mode, data):
    """
    Compute every configured clustering metric that matches the input format.

    ``self._clustering_extract(data)`` yields ``(rs0, rs1, run_outer_metric)``.
    A metric listed in ``self.clustering_intra_metric_list`` is computed only
    when ``run_outer_metric`` is truthy; any other metric only when it is falsy.
    Mismatches are skipped with a warning.

    Returns
    -------
    defaultdict(list): metric name -> ``[mode, result]``; empty when nothing can be computed.
    """
    eval_result = defaultdict(list)
    rs0, rs1, run_outer_metric = self._clustering_extract(data)

    # skip evaluation computation if get this input format
    if rs0 is None and rs1 is None:
        LOGGER.debug('skip computing, this clustering format is not for metric computation')
        return eval_result

    # element-wise numpy comparison: "no label" means every entry of rs0 is None
    if not run_outer_metric and (np.array(rs0) == None).all():  # noqa: E711
        LOGGER.debug('no label found in clustering result, skip metric computation')
        return eval_result

    for metric_name in self.metrics:
        is_intra = metric_name in self.clustering_intra_metric_list
        # XNOR: metric family and input format must agree
        if is_intra != bool(run_outer_metric):
            LOGGER.warning('input data format does not match current clustering metric: {}'.format(metric_name))
            continue

        LOGGER.debug('clustering_metrics is {}'.format(metric_name))
        metric_func = getattr(self.metric_interface, metric_name)
        if not run_outer_metric:
            res = metric_func(rs0, rs1)
        elif metric_name == consts.DISTANCE_MEASURE:
            res = metric_func(rs0['avg_dist'], rs1, rs0['max_radius'])
        else:
            res = metric_func(rs0['avg_dist'], rs1)

        eval_result[metric_name].append(mode)
        eval_result[metric_name].append(res)

    return eval_result
def get_batch_generator(data_size, batch_size, batch_strategy, masked_rate, shuffle):
    """
    Pick a batch data generator for the given sizes and strategy.

    - ``batch_size >= data_size``: one full, unshuffled, unmasked batch.
    - ``batch_strategy == "full"``: ``FullBatchDataGenerator``; shuffling is forced
      on when ``masked_rate > 0``.
    - any other strategy: ``RandomBatchDataGenerator``; the ``shuffle`` flag is ignored.
    """
    if batch_size >= data_size:
        LOGGER.warning("As batch_size >= data size, all batch strategy will be disabled")
        return FullBatchDataGenerator(data_size, data_size, shuffle=False)

    if batch_strategy != "full":
        if shuffle:
            LOGGER.warning("if use random select batch strategy, shuffle will not work")
        return RandomBatchDataGenerator(data_size, batch_size, masked_rate)

    if masked_rate > 0:
        LOGGER.warning("If using full batch strategy and masked rate > 0, shuffle will always be true")
        shuffle = True
    return FullBatchDataGenerator(data_size, batch_size, shuffle=shuffle, masked_rate=masked_rate)
def validate(self, model, epoch,
             use_precomputed_train=False,
             use_precomputed_validate=False,
             train_scores=None,
             validate_scores=None):
    """
    Run validation for ``epoch`` when it is due and update early-stopping state.

    :param model: model instance, which has predict function
    :param epoch: int, epoch idx for generating flow id
    :param use_precomputed_train: bool, use precomputed train scores or not; if True, check train_scores
    :param use_precomputed_validate: bool, use precomputed validate scores or not; if True, check validate_scores
    :param train_scores: dtable, key is sample id, value is a list contains precomputed predict scores.
                         once offered, skip calling model.predict(self.train_data) and use this as train_predicts
    :param validate_scores: dtable, key is sample id, value is a list contains precomputed predict scores.
                            once offered, skip calling model.predict(self.validate_data) and use this as
                            validate_predicts
    :return: None
    :raises ValueError: if early stopping is enabled but evaluation produced no metric
    """
    need_validation = self.need_run_validation(epoch)
    LOGGER.debug("begin to check validate status, need_run_validation is {}".format(need_validation))
    if not need_validation:
        return

    # the homo arbiter skips validation entirely
    if self.mode == consts.HOMO and self.role == consts.ARBITER:
        return

    if use_precomputed_train:
        # use precomputed scores
        train_predicts = self.handle_precompute_scores(train_scores, 'train')
    else:
        # call model.predict()
        train_predicts = self.get_predict_result(model, epoch, self.train_data, "train")

    if use_precomputed_validate:
        # use precomputed scores
        validate_predicts = self.handle_precompute_scores(validate_scores, 'validate')
    else:
        # call model.predict()
        validate_predicts = self.get_predict_result(model, epoch, self.validate_data, "validate")

    if train_predicts is not None or validate_predicts is not None:
        # either side may be missing; union only when both exist
        if train_predicts is None:
            predicts = validate_predicts
        elif validate_predicts is None:
            predicts = train_predicts
        else:
            predicts = train_predicts.union(validate_predicts)

        # running evaluation
        eval_result_dict = self.evaluate(predicts, model, epoch)
        LOGGER.debug('showing eval_result_dict here')
        LOGGER.debug(eval_result_dict)

        if self.early_stopping_rounds:
            if len(eval_result_dict) == 0:
                raise ValueError(
                    "eval_result len is 0, no single value metric detected for early stopping checking")

            if self.use_first_metric_only:
                if self.first_metric:
                    eval_result_dict = {self.first_metric: eval_result_dict[self.first_metric]}
                else:
                    LOGGER.warning('use first metric only but no single metric found in metric list')

            self.performance_recorder.update(eval_result_dict)

    if self.sync_status:
        self.sync_performance_recorder(epoch)

    if self.early_stopping_rounds and self.mode == consts.HETERO:
        self.update_early_stopping_status(epoch, model)
def query_split_points(self, col_name):
    """
    Return the split points recorded for ``col_name``.

    Returns None, after logging a warning, when no result exists for the column.
    """
    col_results = self.all_cols_results.get(col_name)
    if col_results is not None:
        return col_results.split_points
    LOGGER.warning("Querying non-exist split_points")
    return None
def run(self, component_parameters, data_inst, original_model, host_do_evaluate):
    """
    Run k-fold cross validation of ``original_model`` on ``data_inst``.

    For each fold, a deep copy of ``original_model`` is fitted on the train split
    and used to predict both splits. Predictions are evaluated when this party is
    the guest or ``host_do_evaluate`` is set. Per-fold model summaries are stored
    as ``original_model``'s summary.

    Parameters
    ----------
    component_parameters: passed to ``self._init_model``.
    data_inst: Table or None
        Input data; ``None`` runs ``self._arbiter_run`` instead and returns None.
    original_model: model copied for every fold.
    host_do_evaluate: bool
        Whether non-guest parties also evaluate their predictions.

    Returns
    -------
    The accumulated fold history if ``output_fold_history`` is set, else ``data_inst``.

    Raises
    ------
    EnvironmentError: if a fold's train and test counts do not add up to the total count.
    """
    self._init_model(component_parameters)

    if data_inst is None:
        self._arbiter_run(original_model)
        return

    total_data_count = data_inst.count()
    LOGGER.debug("data_inst count: {}".format(total_data_count))
    if self.output_fold_history:
        if total_data_count * self.n_splits > consts.MAX_SAMPLE_OUTPUT_LIMIT:
            LOGGER.warning(
                f"max sample output limit {consts.MAX_SAMPLE_OUTPUT_LIMIT} exceeded with n_splits ({self.n_splits}) * instance_count ({total_data_count})"
            )

    if self.mode == consts.HOMO or self.role == consts.GUEST:
        data_generator = self.split(data_inst)
    else:
        # no local split here; fold membership comes from _align_data_index below
        data_generator = [(data_inst, data_inst)] * self.n_splits

    def _predict_and_evaluate(model, split_data, split_name, fold_num):
        # predict one split under its own flow id and evaluate when this party should
        this_flowid = 'predict_{}.{}'.format(split_name, fold_num)
        LOGGER.debug("In CV, set_flowid flowid is : {}".format(this_flowid))
        model.set_flowid(this_flowid)
        pred_res = model.predict(split_data)
        if self.role == consts.GUEST or host_do_evaluate:
            fold_name = "_".join([split_name, 'fold', str(fold_num)])
            f = functools.partial(self._append_name, name=split_name)
            pred_res = pred_res.mapValues(f)
            pred_res = model.set_predict_data_schema(pred_res, split_data.schema)
            self.evaluate(pred_res, fold_name, model)
        return pred_res

    fold_num = 0
    summary_res = {}
    for train_data, test_data in data_generator:
        model = copy.deepcopy(original_model)
        LOGGER.debug("In CV, set_flowid flowid is : {}".format(fold_num))
        model.set_flowid(fold_num)
        model.set_cv_fold(fold_num)
        LOGGER.info("KFold fold_num is: {}".format(fold_num))
        if self.mode == consts.HETERO:
            train_data = self._align_data_index(train_data, model.flowid, consts.TRAIN_DATA)
            LOGGER.info("Train data Synchronized")
            test_data = self._align_data_index(test_data, model.flowid, consts.TEST_DATA)
            LOGGER.info("Test data Synchronized")

        train_count = train_data.count()
        test_count = test_data.count()
        LOGGER.debug("train_data count: {}".format(train_count))
        if train_count + test_count != total_data_count:
            raise EnvironmentError(
                "In cv fold: {}, train count: {}, test count: {}, original data count: {}. "
                "Thus, 'train count + test count = total count' condition is not satisfied"
                .format(fold_num, train_count, test_count, total_data_count))

        this_flowid = 'train.' + str(fold_num)
        LOGGER.debug("In CV, set_flowid flowid is : {}".format(this_flowid))
        model.set_flowid(this_flowid)
        model.fit(train_data, test_data)

        train_pred_res = _predict_and_evaluate(model, train_data, 'train', fold_num)
        test_pred_res = _predict_and_evaluate(model, test_data, 'validate', fold_num)
        LOGGER.debug("Finish fold: {}".format(fold_num))

        if self.output_fold_history:
            LOGGER.debug(f"generating fold history for fold {fold_num}")
            fold_train_data = self.transform_history_data(
                train_data, train_pred_res, fold_num, "train")
            fold_validate_data = self.transform_history_data(
                test_data, test_pred_res, fold_num, "validate")

            fold_history_data = fold_train_data.union(fold_validate_data)
            fold_history_data.schema = fold_train_data.schema
            if self.fold_history is None:
                self.fold_history = fold_history_data
            else:
                new_fold_history = self.fold_history.union(fold_history_data)
                new_fold_history.schema = fold_history_data.schema
                self.fold_history = new_fold_history

        summary_res[f"fold_{fold_num}"] = model.summary()
        fold_num += 1

    summary_res['fold_num'] = fold_num
    LOGGER.debug("Finish all fold running")
    original_model.set_summary(summary_res)
    if self.output_fold_history:
        LOGGER.debug(f"output data schema: {self.fold_history.schema}")
        LOGGER.debug(f"output data is: {self.fold_history}")
        return self.fold_history
    return data_inst
def intersect_online_process(self, data_inst, caches):
    """
    Intersect ``data_inst`` against a previously generated intersection cache.

    Steps:
    1. Load the cached intersect meta and keys.
    2. Optionally recover match ids.
    3. Cross-check ``cache_id`` with the other parties via ``transfer_variable``.
    4. Run the cached intersection and post-process the ids/schema as in ``fit``.

    Parameters
    ----------
    data_inst: Table
        Input data to intersect.
    caches: dict
        Only its first value is used, unpacked as ``(cache_data, cache_meta)``.

    Returns
    -------
    Table: intersection result (left-joined with the input when ``join_method`` is LEFT_JOIN).

    Raises
    ------
    ValueError: on incompatible match-id settings, a ``cache_id`` mismatch,
    or a role other than host/guest.
    """
    cache_data, cache_meta = list(caches.values())[0]
    intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
    self.callback_cache_meta(intersect_meta)
    self.load_intersect_meta(intersect_meta)
    self.init_intersect_method()
    self.intersection_obj.load_intersect_key(cache_meta)

    if data_overview.check_with_inst_id(data_inst):
        self.use_match_id_process = True
        LOGGER.info("use match_id_process")
    intersect_data = data_inst
    if self.use_match_id_process:
        if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
            raise ValueError("While multi-host, sample_id_generator should be guest.")
        if self.model_param.intersect_method == consts.RAW:
            if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                raise ValueError("When using raw intersect with match id process, "
                                 "'join_role' should be same role as 'sample_id_generator'")
        else:
            if not self.model_param.sync_intersect_ids:
                if self.model_param.sample_id_generator != consts.GUEST:
                    self.model_param.sample_id_generator = consts.GUEST
                    LOGGER.warning("when not sync_intersect_ids with match id process, "
                                   "sample_id_generator is set to Guest")

        proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
        proc_obj.new_sample_id = self.model_param.new_sample_id
        if data_overview.check_with_inst_id(data_inst) or self.model_param.with_sample_id:
            proc_obj.use_sample_id()
        match_data = proc_obj.recover(data=data_inst)
        intersect_data = match_data

    # every party must be replaying the same cache
    if self.role == consts.HOST:
        cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
        self.transfer_variable.cache_id.remote(cache_id, role=consts.GUEST, idx=0)
        guest_cache_id = self.transfer_variable.cache_id.get(role=consts.GUEST, idx=0)
        if guest_cache_id != cache_id:
            raise ValueError("cache_id check failed. cache_id from host & guest must match.")
    elif self.role == consts.GUEST:
        for i, party_id in enumerate(self.host_party_id_list):
            cache_id = cache_meta[str(party_id)].get("cache_id")
            self.transfer_variable.cache_id.remote(cache_id, role=consts.HOST, idx=i)
            host_cache_id = self.transfer_variable.cache_id.get(role=consts.HOST, idx=i)
            if host_cache_id != cache_id:
                raise ValueError("cache_id check failed. cache_id from host & guest must match.")
    else:
        raise ValueError(f"Role {self.role} cannot run intersection transform.")

    self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
    if self.use_match_id_process:
        if not self.model_param.sync_intersect_ids:
            self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data, owner_only=True)
        else:
            self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
        if self.intersect_ids and self.model_param.only_output_key:
            self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
            self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
                                         "sid_name": data_inst.schema["sid_name"]}

    LOGGER.info("Finish intersection")

    if self.intersect_ids:
        data_count = data_inst.count()
        self.intersect_num = self.intersect_ids.count()
        self.intersect_rate = self.intersect_num / data_count
        self.unmatched_num = data_count - self.intersect_num
        self.unmatched_rate = 1 - self.intersect_rate

    self.set_summary(self.get_model_summary())
    self.callback()

    result_data = self.intersect_ids
    if not self.use_match_id_process:
        if not self.intersection_obj.only_output_key and result_data:
            result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
            self.intersect_ids.schema = result_data.schema
            LOGGER.debug("not only_output_key, restore value called")
        if self.intersection_obj.only_output_key and result_data:
            schema = {"sid_name": data_inst.schema["sid_name"]}
            result_data = result_data.mapValues(lambda v: 1)
            result_data.schema = schema
            self.intersect_ids.schema = schema

    if self.model_param.join_method == consts.LEFT_JOIN:
        result_data = self.__sync_join_id(data_inst, self.intersect_ids)
        result_data.schema = self.intersect_ids.schema

    return result_data
def fit_binary(self, data_instances, validate_data=None):
    """
    Train a binary hetero logistic regression model as the guest party.

    Steps:
    1. Load the data and build the Paillier cipher operator and the batch generator.
    2. Send the ``use_async`` flag through ``transfer_variable.use_async``.
    3. Initialize (or warm-start) the model weights.
    4. Run mini-batch epochs until ``max_iter`` is reached, convergence is
       reported, or ``stop_training`` is set.

    Parameters
    ----------
    data_instances: Table
        Training data; every value is passed through ``HeteroLRGuest.load_data``.
    validate_data: Table, optional
        Only forwarded to ``callback_list.on_train_begin`` in this method.
    """
    LOGGER.info("Enter hetero_lr_guest fit")
    self.header = self.get_header(data_instances)
    self.callback_list.on_train_begin(data_instances, validate_data)

    data_instances = data_instances.mapValues(HeteroLRGuest.load_data)
    # NOTE(review): the f-string, including count(), is evaluated even when DEBUG logging is off
    LOGGER.debug(
        f"MODEL_STEP After load data, data count: {data_instances.count()}"
    )
    self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

    self.batch_generator.initialize_batch_generator(
        data_instances,
        self.batch_size,
        batch_strategy=self.batch_strategy,
        masked_rate=self.masked_rate,
        shuffle=self.shuffle)
    if self.batch_generator.batch_masked:
        self.batch_generator.verify_batch_legality()

    self.gradient_loss_operator.set_total_batch_nums(
        self.batch_generator.batch_nums)

    # Async gradient computation is only enabled for unweighted data with exactly
    # one host and unmasked batches; the decision is sent to the other side below.
    use_async = False
    if with_weight(data_instances):
        if self.model_param.early_stop == "diff":
            LOGGER.warning(
                "input data with weight, please use 'weight_diff' for 'early_stop'."
            )
        # data_instances = scale_sample_weight(data_instances)
        # self.gradient_loss_operator.set_use_sample_weight()
        # LOGGER.debug(f"data_instances after scale: {[v[1].weight for v in list(data_instances.collect())]}")
    elif len(self.component_properties.host_party_idlist
             ) == 1 and not self.batch_generator.batch_masked:
        LOGGER.debug(f"set_use_async")
        self.gradient_loss_operator.set_use_async()
        use_async = True
    self.transfer_variable.use_async.remote(use_async)

    # NOTE(review): this log is emitted after the batch generator was already initialized above
    LOGGER.info("Generate mini-batch from input data")

    LOGGER.info("Start initialize model.")
    LOGGER.info("fit_intercept:{}".format(
        self.init_param_obj.fit_intercept))
    model_shape = self.get_features_shape(data_instances)
    if not self.component_properties.is_warm_start:
        w = self.initializer.init_model(model_shape,
                                        init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(
            w, fit_intercept=self.fit_intercept)
    else:
        # warm start: keep existing weights and resume from the stored iteration
        self.callback_warm_start_init_iter(self.n_iter_)

    while self.n_iter_ < self.max_iter:
        self.callback_list.on_epoch_begin(self.n_iter_)
        LOGGER.info("iter: {}".format(self.n_iter_))
        batch_data_generator = self.batch_generator.generate_batch_data(
            suffix=(self.n_iter_, ), with_index=True)
        self.optimizer.set_iters(self.n_iter_)
        batch_index = 0
        for batch_data, index_data in batch_data_generator:
            batch_feat_inst = batch_data
            # index data is only meaningful for masked batches
            if not self.batch_generator.batch_masked:
                index_data = None

            # Start gradient procedure
            LOGGER.debug(
                "iter: {}, batch: {}, before compute gradient, data count: {}"
                .format(self.n_iter_, batch_index, batch_feat_inst.count()))

            optim_guest_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                batch_feat_inst,
                self.cipher_operator,
                self.model_weights,
                self.optimizer,
                self.n_iter_,
                batch_index,
                masked_index=index_data)

            loss_norm = self.optimizer.loss_norm(self.model_weights)
            self.gradient_loss_operator.compute_loss(
                batch_feat_inst,
                self.model_weights,
                self.n_iter_,
                batch_index,
                loss_norm,
                batch_masked=self.batch_generator.batch_masked)

            self.model_weights = self.optimizer.update_model(
                self.model_weights, optim_guest_gradient)
            batch_index += 1

        # convergence status is synchronized once per epoch
        self.is_converged = self.converge_procedure.sync_converge_info(
            suffix=(self.n_iter_, ))
        LOGGER.info("iter: {}, is_converged: {}".format(
            self.n_iter_, self.is_converged))

        self.callback_list.on_epoch_end(self.n_iter_)
        self.n_iter_ += 1

        if self.stop_training:
            break

        if self.is_converged:
            break
    self.callback_list.on_train_end()
    self.set_summary(self.get_model_summary())
def fit(self, data_instances, validate_data=None):
    """
    Train poisson model of role guest.

    The exposure column (if ``exposure_colname`` is found in the header) is
    removed from the header and split off the instances. Its ``safe_log`` is
    used as a per-batch offset in the gradient and loss computation. Sample
    weights are not supported; a warning is logged if present.

    Parameters
    ----------
    data_instances: Table of Instance, input data
    validate_data: Table of Instance, optional
        Forwarded to ``prepare_fit`` and ``callback_list.on_train_begin``.
    """
    LOGGER.info("Enter hetero_poisson_guest fit")
    # self._abnormal_detection(data_instances)
    # self.header = copy.deepcopy(self.get_header(data_instances))
    self.prepare_fit(data_instances, validate_data)
    self.callback_list.on_train_begin(data_instances, validate_data)

    if with_weight(data_instances):
        LOGGER.warning(
            "input data with weight. Poisson regression does not support weighted training."
        )

    # exposure_index > -1 means the exposure column was found in the header
    self.exposure_index = self.get_exposure_index(self.header, self.exposure_colname)
    exposure_index = self.exposure_index
    if exposure_index > -1:
        self.header.pop(exposure_index)
        LOGGER.info("Guest provides exposure value.")
    # split each row into its exposure value and the remaining instance
    exposure = data_instances.mapValues(
        lambda v: HeteroPoissonBase.load_exposure(v, exposure_index))
    data_instances = data_instances.mapValues(
        lambda v: HeteroPoissonBase.load_instance(v, exposure_index))

    self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

    LOGGER.info("Generate mini-batch from input data")
    self.batch_generator.initialize_batch_generator(
        data_instances, self.batch_size)

    LOGGER.info("Start initialize model.")
    LOGGER.info("fit_intercept:{}".format(
        self.init_param_obj.fit_intercept))
    model_shape = self.get_features_shape(data_instances)
    if not self.component_properties.is_warm_start:
        w = self.initializer.init_model(model_shape,
                                        init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(
            w, fit_intercept=self.fit_intercept, raise_overflow_error=False)
    else:
        # warm start: keep existing weights and resume from the stored iteration
        self.callback_warm_start_init_iter(self.n_iter_)

    while self.n_iter_ < self.max_iter:
        self.callback_list.on_epoch_begin(self.n_iter_)
        LOGGER.info("iter:{}".format(self.n_iter_))
        # each iter will get the same batch_data_generator
        batch_data_generator = self.batch_generator.generate_batch_data()
        self.optimizer.set_iters(self.n_iter_)
        batch_index = 0
        for batch_data in batch_data_generator:
            # compute offset of this batch: safe_log of the exposure of the batch's rows
            batch_offset = exposure.join(
                batch_data,
                lambda ei, d: HeteroPoissonBase.safe_log(ei))

            # Start gradient procedure
            optimized_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                batch_data,
                self.cipher_operator,
                self.model_weights,
                self.optimizer,
                self.n_iter_,
                batch_index,
                batch_offset)
            # LOGGER.debug("iteration:{} Guest's gradient: {}".format(self.n_iter_, optimized_gradient))
            loss_norm = self.optimizer.loss_norm(self.model_weights)
            self.gradient_loss_operator.compute_loss(
                batch_data, self.model_weights, self.n_iter_, batch_index,
                batch_offset, loss_norm)

            self.model_weights = self.optimizer.update_model(
                self.model_weights, optimized_gradient)

            batch_index += 1

        # convergence status is synchronized once per epoch
        self.is_converged = self.converge_procedure.sync_converge_info(
            suffix=(self.n_iter_, ))
        LOGGER.info("iter: {}, is_converged: {}".format(
            self.n_iter_, self.is_converged))

        self.callback_list.on_epoch_end(self.n_iter_)
        self.n_iter_ += 1

        if self.stop_training:
            break

        if self.is_converged:
            break
    self.callback_list.on_train_end()
    self.set_summary(self.get_model_summary())
def fit(self, data):
    """
    Run private set intersection on ``data``, or replay a cached intersection.

    Flow:
    1. If component caches are present, delegate to ``intersect_online_process``.
    2. With match-id processing (data carries inst_id, or ``use_match_id_process``
       is already set), ids are first recovered via ``MatchIDIntersect``.
    3. With ``run_cache`` on, only the cache is generated and ``data`` is returned.
    4. With ``cardinality_only``, only intersect statistics are recorded and
       ``data`` is returned.
    5. Otherwise the intersected ids are expanded/restored and optionally left-joined.

    Returns
    -------
    Table: intersection result, or ``data`` in the cache / cardinality-only cases.

    Raises
    ------
    ValueError: on incompatible match-id settings.
    """
    if self.component_properties.caches:
        return self.intersect_online_process(data, self.component_properties.caches)
    self.init_intersect_method()
    if data_overview.check_with_inst_id(data):
        self.use_match_id_process = True
        LOGGER.info("use match_id_process")

    if self.use_match_id_process:
        if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
            raise ValueError("While multi-host, sample_id_generator should be guest.")
        if self.model_param.intersect_method == consts.RAW:
            if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                raise ValueError("When using raw intersect with match id process, "
                                 "'join_role' should be same role as 'sample_id_generator'")
        else:
            if not self.model_param.sync_intersect_ids:
                if self.model_param.sample_id_generator != consts.GUEST:
                    self.model_param.sample_id_generator = consts.GUEST
                    LOGGER.warning("when not sync_intersect_ids with match id process, "
                                   "sample_id_generator is set to Guest")

        self.proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
        self.proc_obj.new_sample_id = self.model_param.new_sample_id
        if data_overview.check_with_inst_id(data) or self.model_param.with_sample_id:
            self.proc_obj.use_sample_id()
        match_data = self.proc_obj.recover(data=data)
        source_data = match_data
    else:
        source_data = data

    # the cache / cardinality / intersect steps are identical for raw and match-id data
    if self.intersection_obj.run_cache:
        self.cache_output = self.intersection_obj.generate_cache(source_data)
        intersect_meta = self.intersection_obj.get_intersect_method_meta()
        self.callback_cache_meta(intersect_meta)
        return data
    if self.intersection_obj.cardinality_only:
        self.intersection_obj.run_cardinality(source_data)
    else:
        intersect_data = source_data
        if self.model_param.run_preprocess:
            intersect_data = self.run_preprocess(source_data)
        self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)

    if self.intersection_obj.cardinality_only:
        if self.intersection_obj.intersect_num is not None:
            data_count = data.count()
            self.intersect_num = self.intersection_obj.intersect_num
            self.intersect_rate = self.intersect_num / data_count
            self.unmatched_num = data_count - self.intersect_num
            self.unmatched_rate = 1 - self.intersect_rate
        self.set_summary(self.get_model_summary())
        self.callback()
        return data

    if self.use_match_id_process:
        if self.model_param.sync_intersect_ids:
            self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
        else:
            self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data, owner_only=True)
        if self.model_param.only_output_key and self.intersect_ids:
            self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
            self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
                                         "sid_name": data.schema["sid_name"]}

    LOGGER.info("Finish intersection")

    if self.intersect_ids:
        data_count = data.count()
        self.intersect_num = self.intersect_ids.count()
        self.intersect_rate = self.intersect_num / data_count
        self.unmatched_num = data_count - self.intersect_num
        self.unmatched_rate = 1 - self.intersect_rate

    self.set_summary(self.get_model_summary())
    self.callback()

    result_data = self.intersect_ids
    if not self.use_match_id_process and not self.intersection_obj.only_output_key and result_data:
        result_data = self.intersection_obj.get_value_from_data(result_data, data)
        LOGGER.debug("not only_output_key, restore value called")

    if self.model_param.join_method == consts.LEFT_JOIN:
        result_data = self.__sync_join_id(data, self.intersect_ids)
        result_data.schema = self.intersect_ids.schema

    return result_data
def check(self):
    """
    Validate the hetero SecureBoost parameter set.

    Runs the parent and tree-parameter checks, type/range checks on the
    boosting options, logs warnings for certain EINI combinations, lets a
    non-None ``work_mode`` override ``boosting_strategy``, and forwards the
    deprecated evaluation options to ``callback_param``.

    Returns
    -------
    bool
        True when all checks pass; failing checks raise instead.
    """
    super(HeteroSecureBoostParam, self).check()
    self.tree_param.check()

    if not isinstance(self.use_missing, bool):
        raise ValueError('use missing should be bool type')
    if not isinstance(self.zero_as_missing, bool):
        raise ValueError('zero as missing should be bool type')

    self.check_boolean(self.complete_secure, 'complete_secure')
    self.check_boolean(self.run_goss, 'run goss')

    # top_rate / other_rate: decimal-float checks first, then positivity checks
    for rate_attr, rate_label in (('top_rate', 'top rate'), ('other_rate', 'other rate')):
        self.check_decimal_float(getattr(self, rate_attr), rate_label)
    for rate_attr in ('other_rate', 'top_rate'):
        self.check_positive_number(getattr(self, rate_attr), rate_attr)

    boolean_switches = (
        ('new_ver', 'code version switcher'),
        ('cipher_compress', 'cipher compress'),
        ('EINI_inference', 'eini inference'),
        ('EINI_random_mask', 'eini random mask'),
        ('EINI_complexity_check', 'eini complexity check'),
    )
    for attr_name, label in boolean_switches:
        self.check_boolean(getattr(self, attr_name), label)

    if self.EINI_inference and self.EINI_random_mask:
        LOGGER.warning(
            'To protect the inference decision path, notice that current setting will multiply'
            ' predict result by a random number, hence SecureBoost will return confused predict scores'
            ' that is not the same as the original predict scores')

    if self.work_mode == consts.MIX_TREE and self.EINI_inference:
        LOGGER.warning('Mix tree mode does not support EINI, use default predict setting')

    # a non-None work_mode takes precedence over boosting_strategy
    if self.work_mode is not None:
        self.boosting_strategy = self.work_mode

    if self.multi_mode not in [consts.SINGLE_OUTPUT, consts.MULTI_OUTPUT]:
        raise ValueError('unsupported multi-classification mode')
    if self.multi_mode == consts.MULTI_OUTPUT:
        if self.boosting_strategy != consts.STD_TREE:
            raise ValueError('MO trees only works when boosting strategy is std tree')
        if not self.cipher_compress:
            raise ValueError('Mo trees only works when cipher compress is enabled')

    if self.boosting_strategy not in [consts.STD_TREE, consts.LAYERED_TREE, consts.MIX_TREE]:
        raise ValueError('unknown sbt boosting strategy{}'.format(self.boosting_strategy))

    # If any deprecated evaluation option was given, it must not be combined
    # with an explicit callback_param; otherwise enable PerformanceEvaluate.
    evaluation_params = ("early_stopping_rounds", "validation_freqs", "metrics", "use_first_metric_only")
    flagged = next((name for name in evaluation_params if self._deprecated_params_set.get(name)), None)
    if flagged is not None:
        if "callback_param" in self.get_user_feeded():
            raise ValueError(
                f"{flagged} and callback param should not be set simultaneously,"
                f"{self._deprecated_params_set}, {self.get_user_feeded()}"
            )
        self.callback_param.callbacks = ["PerformanceEvaluate"]

    # copy each deprecated value onto the same-named callback_param attribute
    descr = "boosting_param's"
    for name in ("validation_freqs", "early_stopping_rounds", "metrics", "use_first_metric_only"):
        if self._warn_to_deprecate_param(name, descr, f"callback_param's '{name}'"):
            setattr(self.callback_param, name, getattr(self, name))

    if self.top_rate + self.other_rate >= 1:
        raise ValueError('sum of top rate and other rate should be smaller than 1')

    return True
def __init__(self, tree_param: DecisionTreeParam, data_bin=None, bin_split_points: np.array = None,
             bin_sparse_point=None, g_h=None, valid_feature: dict = None, epoch_idx: int = None,
             role: str = None, tree_idx: int = None, flow_id: int = None, mode='train'):
    """
    Build a homo decision tree on the client side.

    Parameters
    ----------
    tree_param: DecisionTreeParam
        Decision tree parameter object, handed to the base class.
    data_bin: optional
        Binned data instances.
    bin_split_points: np.array, optional
        Split points of the binned features.
    bin_sparse_point: optional
        Sparse points of the binned data (stored as ``bin_sparse_points``).
    g_h: optional
        g and h values of the instances.
    valid_feature: dict, optional
        Marks each feature valid (True) or invalid (False).
    epoch_idx: int, optional
        Current epoch index.
    role: str, optional
        Party role; stored only in 'train' mode.
    tree_idx: int, optional
        Index of this tree.
    flow_id: int, optional
        Flow id, applied through ``set_flowid`` in 'train' mode.
    mode: str
        'train' stores ``role``, sets the flow id and creates a
        ``DecisionTreeClientAggregator``; 'predict' sets ``role`` and
        ``aggregator`` to None.
    """
    super(HomoDecisionTreeClient, self).__init__(tree_param)
    self.splitter = Splitter(self.criterion_method, self.criterion_params, self.min_impurity_split,
                             self.min_sample_split, self.min_leaf_node)

    # inputs of this tree
    self.data_bin = data_bin
    self.g_h = g_h
    self.bin_split_points = bin_split_points
    self.bin_sparse_points = bin_sparse_point
    self.epoch_idx = epoch_idx
    self.tree_idx = tree_idx

    # histogram subtraction prefers an even max_split_nodes (see warning text)
    odd_split_limit = self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1
    if odd_split_limit:
        self.max_split_nodes += 1
        LOGGER.warning(
            'an even max_split_nodes value is suggested '
            'when using histogram-subtraction, max_split_nodes reset to {}'
            .format(self.max_split_nodes))

    self.transfer_inst = HomoDecisionTreeTransferVariable()

    # runtime state, starting from the root node
    self.valid_features = valid_feature
    self.tree_node = []
    self.tree_node_num = 0
    self.cur_layer_node = []
    self.runtime_idx = 0
    self.sitename = consts.GUEST
    self.feature_importance = {}

    if mode == 'train':
        self.role = role
        self.set_flowid(flow_id)
        self.aggregator = DecisionTreeClientAggregator(verbose=False)
    elif mode == 'predict':
        self.role = None
        self.aggregator = None
    # NOTE(review): any other mode leaves role/aggregator unset here
def fit(self, data_instances, validate_data=None):
    """
    Train the hetero linear regression model as the guest party.

    Parameters
    ----------
    data_instances: Table of Instance
        Training data.
    validate_data: Table of Instance, optional
        Validation data, passed to ``callback_list.on_train_begin``.
    """
    LOGGER.info("Enter hetero_linR_guest fit")
    self._abnormal_detection(data_instances)
    self.header = self.get_header(data_instances)
    self.callback_list.on_train_begin(data_instances, validate_data)

    self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

    # Sample weights take priority; async mode is only switched on for
    # unweighted data with exactly one host party.
    if with_weight(data_instances):
        if self.model_param.early_stop == "diff":
            LOGGER.warning("input data with weight, please use 'weight_diff' for 'early_stop'.")
        data_instances = scale_sample_weight(data_instances)
        self.gradient_loss_operator.set_use_sample_weight()
        LOGGER.debug("instance weight scaled; use weighted gradient loss operator")
        use_async = False
    elif len(self.component_properties.host_party_idlist) == 1:
        LOGGER.debug("set_use_async")
        self.gradient_loss_operator.set_use_async()
        use_async = True
    else:
        use_async = False
    self.transfer_variable.use_async.remote(use_async)

    LOGGER.info("Generate mini-batch from input data")
    self.batch_generator.initialize_batch_generator(data_instances, self.batch_size)
    batch_num = self.batch_generator.batch_nums
    self.gradient_loss_operator.set_total_batch_nums(batch_num)
    # one encrypt-mode calculator per mini-batch
    self.encrypted_calculator = [
        EncryptModeCalculator(self.cipher_operator,
                              self.encrypted_mode_calculator_param.mode,
                              self.encrypted_mode_calculator_param.re_encrypted_rate)
        for _ in range(batch_num)
    ]

    LOGGER.info("Start initialize model.")
    LOGGER.info("fit_intercept:{}".format(self.init_param_obj.fit_intercept))
    model_shape = self.get_features_shape(data_instances)
    if self.component_properties.is_warm_start:
        self.callback_warm_start_init_iter(self.n_iter_)
    else:
        w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept,
                                                raise_overflow_error=False)

    while self.n_iter_ < self.max_iter:
        self.callback_list.on_epoch_begin(self.n_iter_)
        LOGGER.info("iter:{}".format(self.n_iter_))
        # a fresh batch generator is requested from batch_generator every epoch
        batches = self.batch_generator.generate_batch_data()
        self.optimizer.set_iters(self.n_iter_)
        for batch_index, batch_data in enumerate(batches):
            guest_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                batch_data, self.encrypted_calculator, self.model_weights, self.optimizer,
                self.n_iter_, batch_index)
            loss_norm = self.optimizer.loss_norm(self.model_weights)
            self.gradient_loss_operator.compute_loss(batch_data, self.n_iter_, batch_index, loss_norm)
            self.model_weights = self.optimizer.update_model(self.model_weights, guest_gradient)

        self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
        LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
        self.callback_list.on_epoch_end(self.n_iter_)
        self.n_iter_ += 1

        if self.stop_training or self.is_converged:
            break

    self.callback_list.on_train_end()
    self.set_summary(self.get_model_summary())
def fit(self, data_instances):
    """
    Run every configured filter method on ``data_instances`` and return the
    transformed data.

    When no column is selected locally, one column is picked at random so the
    filters have something to fit on. Hosts report this situation to the
    guest, which logs a warning per affected host. A summary entry "all" with
    the column counts is added before transforming.

    Parameters
    ----------
    data_instances: Table
        Input data.

    Returns
    -------
    The output of ``self._transfer_data(data_instances)``.
    """
    LOGGER.info("Start Hetero Selection Fit and transform.")
    self._abnormal_detection(data_instances)
    self._init_select_params(data_instances)

    original_col_nums = len(self.curt_select_properties.last_left_col_names)

    empty_cols = len(self.curt_select_properties.select_col_indexes) == 0
    if empty_cols:
        LOGGER.warning(
            "None of columns has been set to select, "
            "will randomly select one column to participate in fitting filter(s). "
            "All columns will be kept, "
            "but be aware that this may lead to unexpected behavior.")
        header = data_instances.schema.get("header")
        picked = random.choice(range(len(header)))
        self.curt_select_properties.select_col_indexes = [picked]
        self.curt_select_properties.select_col_names = [header[picked]]

    suffix = self.filter_methods
    if self.role == consts.HOST:
        self.transfer_variable.host_empty_cols.remote(empty_cols, role=consts.GUEST, idx=0, suffix=suffix)
    else:
        host_flags = self.transfer_variable.host_empty_cols.get(idx=-1, suffix=suffix)
        host_list = self.component_properties.host_party_idlist
        for pos, host_is_empty in enumerate(host_flags):
            if not host_is_empty:
                continue
            host_id = host_list[pos]
            LOGGER.warning(
                f"Host {host_id}'s select columns are empty;"
                f"host {host_id} will randomly select one "
                f"column to participate in fitting filter(s). "
                f"All columns from this host will be kept, "
                f"but be aware that this may lead to unexpected behavior."
            )

    # filter method -> name of the model_param attribute holding its metric list;
    # these filters run once per metric, all others run once
    metric_param_of = {
        consts.STATISTIC_FILTER: "statistic_param",
        consts.IV_FILTER: "iv_param",
        consts.PSI_FILTER: "psi_param",
        consts.HETERO_SBT_FILTER: "sbt_param",
        consts.HOMO_SBT_FILTER: "sbt_param",
        consts.HETERO_FAST_SBT_FILTER: "sbt_param",
        consts.VIF_FILTER: "vif_param",
    }
    for filter_idx, method in enumerate(self.filter_methods):
        param_name = metric_param_of.get(method)
        if param_name is None:
            self._filter(data_instances, method, suffix=str(filter_idx))
            continue
        metrics = getattr(self.model_param, param_name).metrics
        for metric_idx, _ in enumerate(metrics):
            self._filter(data_instances, method, suffix=(str(filter_idx), str(metric_idx)), idx=metric_idx)

    left_col_names = self.curt_select_properties.last_left_col_names
    self.add_summary("all", {
        "last_col_nums": original_col_nums,
        "left_col_nums": len(left_col_names),
        "left_col_names": left_col_names,
    })

    new_data = self._transfer_data(data_instances)
    LOGGER.info("Finish Hetero Selection Fit and transform.")
    return new_data
def _warn_deprecated_param(self, param_name, descr): if self._deprecated_params_set.get(param_name): LOGGER.warning( f"{descr} {param_name} is deprecated and ignored in this version." )
def fit(self, data):
    """
    Union all input tables into a single table.

    Tables are processed in ``data`` iteration order. Empty tables are counted
    and skipped. The first non-empty table fixes the expected table type
    (DataInstance or not) and the schema of the result. Later non-empty tables
    are checked (id, label name, header) and unioned into it.

    Parameters
    ----------
    data: dict
        Maps an input name to a table.

    Returns
    -------
    The combined table. If every input table is empty, the first (empty)
    table is returned. If ``data`` is an empty dict, None is returned.

    Raises
    ------
    ValueError
        If ``data`` is not a dict, or DataInstance and non-DataInstance
        tables are mixed.
    """
    LOGGER.debug(f"fit receives data is {data}")
    if not isinstance(data, dict):
        raise ValueError("Union module must receive more than one table as input.")

    empty_count = 0
    combined_table = None
    combined_schema = None
    metrics = []
    # True once a non-empty table has become the union base. Empty tables seen
    # before that are only kept as a placeholder result.
    has_data = False

    for key, local_table in data.items():
        LOGGER.debug("table to combine name: {}".format(key))
        num_data = local_table.count()
        LOGGER.debug("table count: {}".format(num_data))
        metrics.append(Metric(key, num_data))
        self.add_summary(key, num_data)

        if num_data == 0:
            LOGGER.warning("Table {} is empty.".format(key))
            if combined_table is None:
                # placeholder result in case every input turns out to be empty
                combined_table = local_table
                combined_schema = local_table.schema
            empty_count += 1
            continue

        local_is_data_instance = self.check_is_data_instance(local_table)
        if not has_data:
            # The first non-empty table decides the expected type. Keying this on
            # `combined_table is None` would skip it whenever an empty table came
            # first, leaving self.is_data_instance at a stale value.
            self.is_data_instance = local_is_data_instance
        LOGGER.debug(f"self.is_data_instance is {self.is_data_instance}, "
                     f"local_is_data_instance is {local_is_data_instance}")
        if self.is_data_instance != local_is_data_instance:
            raise ValueError("Cannot combine DataInstance and non-DataInstance object. Union aborted.")

        if self.is_data_instance:
            self.is_empty_feature = data_overview.is_empty_feature(local_table)
            if self.is_empty_feature:
                LOGGER.warning("Table {} has empty feature.".format(key))
            else:
                self.check_schema_content(local_table.schema)

        if not has_data:
            # First non-empty table: it replaces any empty placeholder, so ids,
            # label and header of later tables are checked against real data.
            combined_table = local_table
            combined_schema = local_table.schema
            has_data = True
        else:
            self.check_id(local_table, combined_table)
            self.check_label_name(local_table, combined_table)
            self.check_header(local_table, combined_table)
            if self.keep_duplicate:
                # ids present in both tables get renewed by self._renew_id
                repeated_ids = combined_table.join(local_table, lambda v1, v2: 1)
                self.repeated_ids = [v[0] for v in repeated_ids.collect()]
                self.key = key
                local_table = local_table.flatMap(self._renew_id)

            combined_table = combined_table.union(local_table, self._keep_first)
            combined_table.schema = combined_schema

        # only check feature length if not empty
        if self.is_data_instance and not self.is_empty_feature:
            self.feature_count = len(combined_schema.get("header"))
            LOGGER.debug("feature count: {}".format(self.feature_count))
            combined_table.mapValues(self.check_feature_length)

    if combined_table is None:
        num_data = 0
        LOGGER.warning("All tables provided are empty or have empty features.")
    else:
        num_data = combined_table.count()
    metrics.append(Metric("Total", num_data))
    self.add_summary("Total", num_data)

    self.callback_metric(metric_name=self.metric_name,
                         metric_namespace=self.metric_namespace,
                         metric_data=metrics)
    self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                 metric_name=self.metric_name,
                                 metric_meta=MetricMeta(name=self.metric_name, metric_type=self.metric_type))
    LOGGER.info("Union operation finished. Total {} empty tables encountered.".format(empty_count))
    return combined_table
def __save_pr_curve(self, precision_and_recall, data_name):
    """
    Save precision and recall curves (data and meta) for one dataset.

    Parameters
    ----------
    precision_and_recall: dict
        Indexed by ``consts.PRECISION`` and ``consts.RECALL``. Each value is
        used as ``(mode, values)`` where ``values[0]`` are scores,
        ``values[1]`` are cuts and the optional ``values[2]`` are thresholds.
        ``mode`` becomes the metric namespace.
    data_name: str
        Prefix of the metric names; also used as curve name and pair type.

    Nothing is saved (a warning is logged) if precision and recall modes differ.
    """
    precision_res = precision_and_recall[consts.PRECISION]
    recall_res = precision_and_recall[consts.RECALL]

    if precision_res[0] != recall_res[0]:
        LOGGER.warning(
            "precision mode:{} is not equal to recall mode:{}".format(
                precision_res[0], recall_res[0]))
        return

    metric_namespace = precision_res[0]
    metric_name_precision = '_'.join([data_name, "precision"])
    metric_name_recall = '_'.join([data_name, "recall"])

    pos_precision_score = precision_res[1][0]
    precision_cuts = precision_res[1][1]
    # thresholds are an optional third element
    precision_thresholds = precision_res[1][2] if len(precision_res[1]) >= 3 else None

    pos_recall_score = recall_res[1][0]
    recall_cuts = recall_res[1][1]
    recall_thresholds = recall_res[1][2] if len(recall_res[1]) >= 3 else None

    precision_curve_name = data_name
    recall_curve_name = data_name

    if self.eval_type == consts.BINARY:
        pos_precision_score = [score[1] for score in pos_precision_score]
        pos_recall_score = [score[1] for score in pos_recall_score]

        pos_recall_score, pos_precision_score, idx_list = self.__filt_override_unit_ordinate_coordinate(
            pos_recall_score, pos_precision_score)

        precision_cuts = [precision_cuts[idx] for idx in idx_list]
        recall_cuts = [recall_cuts[idx] for idx in idx_list]

        # Thresholds may be absent (see above). Re-index them only when present,
        # instead of failing on len(None) / indexing None.
        if precision_thresholds is not None:
            edge_idx = idx_list[-1]
            if edge_idx == len(precision_thresholds) - 1:
                idx_list = idx_list[:-1]
            precision_thresholds = [precision_thresholds[idx] for idx in idx_list]
        if recall_thresholds is not None:
            recall_thresholds = [recall_thresholds[idx] for idx in idx_list]

    elif self.eval_type == consts.MULTY:
        pos_recall_score, recall_cuts = self.__multi_class_label_padding(
            pos_recall_score, recall_cuts)
        pos_precision_score, precision_cuts = self.__multi_class_label_padding(
            pos_precision_score, precision_cuts)

    self.__save_curve_data(precision_cuts, pos_precision_score, metric_name_precision,
                           metric_namespace)
    self.__save_curve_meta(metric_name_precision, metric_namespace,
                           "_".join([consts.PRECISION.upper(), self.eval_type.upper()]),
                           unit_name="", ordinate_name="Precision",
                           curve_name=precision_curve_name,
                           pair_type=data_name, thresholds=precision_thresholds)

    self.__save_curve_data(recall_cuts, pos_recall_score, metric_name_recall,
                           metric_namespace)
    self.__save_curve_meta(metric_name_recall, metric_namespace,
                           "_".join([consts.RECALL.upper(), self.eval_type.upper()]),
                           unit_name="", ordinate_name="Recall", curve_name=recall_curve_name,
                           pair_type=data_name, thresholds=recall_thresholds)
def split_optimal_binning(bucket_list, optimal_param: "OptimalBinningParam", sample_count):
    """
    Divisive (top-down) optimal binning driven by the KS statistic.

    Segments of ``bucket_list`` are split at the position where the absolute
    difference between the accumulated event rate and the accumulated
    non-event rate is largest. Splitting continues until
    ``optimal_param.max_bin - 1`` split points are found or no segment can be
    split further. The buckets between split points are then merged.

    Parameters
    ----------
    bucket_list: list
        Ordered buckets of one column. Each bucket must expose
        ``event_count``, ``non_event_count``, ``total_count``, ``is_mixed``,
        a writable ``idx`` and ``merge(other)`` returning the merged bucket.
    optimal_param: OptimalBinningParam
        Only ``min_bin_pct``, ``max_bin`` and ``mixture`` are read.
    sample_count: int
        Total sample count. ``ceil(min_bin_pct * sample_count)`` is the
        minimum size of each side of a split.

    Returns
    -------
    tuple
        ``(merged_buckets, non_mixture_num, small_size_num)``. The counters
        give how many merged buckets have a falsy ``is_mixed`` and how many
        have ``total_count`` below the minimum size.
    """
    min_item_num = math.ceil(optimal_param.min_bin_pct * sample_count)
    final_max_bin = optimal_param.max_bin

    if not bucket_list:
        # nothing to split or merge
        return [], 0, 0

    def _middle_index(n):
        # Last index of the left half for a segment of n >= 2 buckets. Never the
        # final index, so the right half always keeps at least one bucket.
        return max(0, min(n // 2, n - 2))

    def _compute_ks(start_idx, end_idx):
        # Returns (best_ks, global index of the last bucket of the left side,
        # side statistics), or (None, None, None) when the segment has no
        # events or no non-events.
        acc_event = []
        acc_non_event = []
        curt_event_total = 0
        curt_non_event_total = 0
        for bucket in bucket_list[start_idx: end_idx]:
            acc_event.append(bucket.event_count + curt_event_total)
            curt_event_total += bucket.event_count
            acc_non_event.append(bucket.non_event_count + curt_non_event_total)
            curt_non_event_total += bucket.non_event_count

        if curt_event_total == 0 or curt_non_event_total == 0:
            return None, None, None

        acc_event_rate = [x / curt_event_total for x in acc_event]
        acc_non_event_rate = [x / curt_non_event_total for x in acc_non_event]
        ks_list = [math.fabs(eve - non_eve) for eve, non_eve in zip(acc_event_rate, acc_non_event_rate)]
        max_ks = max(ks_list)
        if max_ks == 0:
            best_index = _middle_index(len(ks_list))
        else:
            best_index = ks_list.index(max_ks)

        left_event = acc_event[best_index]
        right_event = curt_event_total - left_event
        left_non_event = acc_non_event[best_index]
        right_non_event = curt_non_event_total - left_non_event
        left_total = left_event + left_non_event
        right_total = right_event + right_non_event

        if left_total < min_item_num or right_total < min_item_num:
            # the KS split violates the minimum size: fall back to the middle
            best_index = _middle_index(len(ks_list))
            left_event = acc_event[best_index]
            right_event = curt_event_total - left_event
            left_non_event = acc_non_event[best_index]
            right_non_event = curt_non_event_total - left_non_event
            left_total = left_event + left_non_event
            right_total = right_event + right_non_event

        best_ks = ks_list[best_index]
        res_dict = {
            'left_event': left_event,
            'right_event': right_event,
            'left_non_event': left_non_event,
            'right_non_event': right_non_event,
            'left_total': left_total,
            'right_total': right_total,
            'left_is_mixed': left_event > 0 and left_non_event > 0,
            'right_is_mixed': right_event > 0 and right_non_event > 0
        }
        # best_index is relative to the segment; convert to a global index
        return best_ks, start_idx + best_index, res_dict

    def _merge_buckets(start_idx, end_idx, bucket_idx):
        # merge bucket_list[start_idx:end_idx] (non-empty) into one bucket
        res_bucket = copy.deepcopy(bucket_list[start_idx])
        res_bucket.idx = bucket_idx
        for bucket in bucket_list[start_idx + 1: end_idx]:
            res_bucket = res_bucket.merge(bucket)
        return res_bucket

    res_split_index = []
    to_split_pair = [(0, len(bucket_list))]

    # iteratively split; each split point is the exclusive end of its left side
    while to_split_pair and len(res_split_index) < final_max_bin - 1:
        start, end = to_split_pair.pop(0)
        # A split needs at least two buckets; otherwise one side would be empty
        # and the same segment would be re-queued forever.
        if end - start < 2:
            continue
        best_ks, best_index, res_dict = _compute_ks(start, end)
        if best_ks is None:
            continue
        if optimal_param.mixture:
            if not (res_dict.get('left_is_mixed') and res_dict.get('right_is_mixed')):
                continue
        if res_dict.get('left_total') < min_item_num or res_dict.get('right_total') < min_item_num:
            continue
        res_split_index.append(best_index + 1)

        # the larger side is queued first
        if res_dict.get('right_total') > res_dict.get('left_total'):
            to_split_pair.append((best_index + 1, end))
            to_split_pair.append((start, best_index + 1))
        else:
            to_split_pair.append((start, best_index + 1))
            to_split_pair.append((best_index + 1, end))

    # a single bucket cannot be split; a "middle" split of it would duplicate it
    if not res_split_index and len(bucket_list) > 1:
        LOGGER.warning("Best ks optimal binning fail to split. Take middle split point instead")
        res_split_index.append(len(bucket_list) // 2)

    res_split_index = sorted(res_split_index)
    res_split_index.append(len(bucket_list))

    bucket_res = []
    non_mixture_num = 0
    small_size_num = 0
    left = 0
    for bucket_idx, right in enumerate(res_split_index):
        new_bucket = _merge_buckets(left, right, bucket_idx)
        bucket_res.append(new_bucket)
        if not new_bucket.is_mixed:
            non_mixture_num += 1
        if new_bucket.total_count < min_item_num:
            small_size_num += 1
        left = right
    return bucket_res, non_mixture_num, small_size_num
def fit(self):
    """
    Coordinate the layer-by-layer split search of one homo decision tree.

    First aggregates the root g/h sums and broadcasts them. Then, for every
    non-leaf layer (``max_depth`` layers), it:

    - syncs the number of nodes;
    - receives the local histograms batch by batch (``max_split_nodes`` nodes
      per batch);
    - completes them with ``histogram_subtraction``;
    - finds the best splits and syncs them.

    The full histograms of a layer are kept in ``self.stored_histograms`` for
    the next layer's subtraction.
    """
    LOGGER.info(
        'begin to fit h**o decision tree, epoch {}, tree idx {}'.format(self.epoch_idx, self.tree_idx))

    root_g, root_h = self.aggregator.aggregate_root_node_info(suffix=('root_node_sync1', self.epoch_idx))
    self.aggregator.broadcast_root_info(root_g, root_h, suffix=('root_node_sync2', self.epoch_idx))

    odd_split_limit = self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1
    if odd_split_limit:
        self.max_split_nodes += 1
        LOGGER.warning(
            'an even max_split_nodes value is suggested when using histogram-subtraction, '
            'max_split_nodes reset to {}'.format(self.max_split_nodes))

    # layers 0 .. max_depth - 1 are split; the layer below them holds only leaves
    for dep in range(self.max_depth):
        LOGGER.debug('current dep is {}'.format(dep))

        node_num = self.sync_node_sample_numbers(suffix=(dep, self.epoch_idx, self.tree_idx))
        LOGGER.debug('{} nodes to split at this layer'.format(node_num))

        split_info = []
        layer_stored_hist = {}
        batch_starts = range(0, node_num, self.max_split_nodes)
        for batch_id, _ in enumerate(batch_starts):
            LOGGER.debug('cur batch id is {}'.format(batch_id))
            left_hist = self.sync_local_histogram(suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))
            all_histograms = self.histogram_subtraction(left_hist, self.stored_histograms)
            # keep this layer's histograms, keyed by hid
            layer_stored_hist.update((hist.hid, hist) for hist in all_histograms)
            # FIXME stable parallel_partitions
            split_info.extend(self.federated_find_best_split(all_histograms, parallel_partitions=10))

        self.stored_histograms = layer_stored_hist
        self.sync_best_splits(split_info, suffix=(dep, self.epoch_idx))
        LOGGER.debug('best_splits_sent')
def check(self):
    """
    Validate and normalise the intersection parameters.

    Lower-cases the string options, maps deprecated parameters onto their
    replacements (``random_bit`` -> ``rsa_params.random_bit``, ``join_role``
    -> ``raw_params.join_role``, ``with_encode``/``encode_params`` ->
    ``raw_params``), runs the nested parameter checks, rejects incompatible
    option combinations and warns about ignored deprecated parameters.

    Returns
    -------
    bool
        True when all checks pass. The explicit checks here raise ``ValueError``.
    """
    descr = "intersect param's "

    self.intersect_method = self.check_and_change_lower(
        self.intersect_method, [consts.RSA, consts.RAW, consts.DH], f"{descr}intersect_method")

    if self._warn_to_deprecate_param("random_bit", descr, "rsa_params' 'random_bit'"):
        if "rsa_params.random_bit" in self.get_user_feeded():
            raise ValueError("random_bit and rsa_params.random_bit should not be set simultaneously")
        self.rsa_params.random_bit = self.random_bit

    self.check_boolean(self.sync_intersect_ids, f"{descr}sync_intersect_ids")

    if self._warn_to_deprecate_param("join_role", descr, "raw_params' 'join_role'"):
        if "raw_params.join_role" in self.get_user_feeded():
            raise ValueError("join_role and raw_params.join_role should not be set simultaneously")
        self.raw_params.join_role = self.join_role

    self.check_boolean(self.only_output_key, f"{descr}only_output_key")

    self.join_method = self.check_and_change_lower(
        self.join_method, [consts.INNER_JOIN, consts.LEFT_JOIN], f"{descr}join_method")
    self.check_boolean(self.new_sample_id, f"{descr}new_sample_id")
    self.sample_id_generator = self.check_and_change_lower(
        self.sample_id_generator, [consts.GUEST, consts.HOST], f"{descr}sample_id_generator")

    if self.join_method == consts.LEFT_JOIN:
        if not self.sync_intersect_ids:
            raise ValueError("Cannot perform left join without sync intersect ids")

    self.check_boolean(self.run_cache, f"{descr}run_cache")

    # Deprecated 'encode_params' / 'with_encode' are mapped onto raw_params.
    if self._warn_to_deprecate_param("encode_params", descr, "raw_params") or \
            self._warn_to_deprecate_param("with_encode", descr, "raw_params' 'use_hash'"):
        if "with_encode" in self.get_user_feeded() and "raw_params.use_hash" in self.get_user_feeded():
            raise ValueError("'with_encode' and 'raw_params.use_hash' should not be set simultaneously.")
        if "raw_params" in self.get_user_feeded() and "encode_params" in self.get_user_feeded():
            raise ValueError("'raw_params' and 'encode_params' should not be set simultaneously.")
        LOGGER.warning("Param values from 'encode_params' will override 'raw_params' settings.")
        self.raw_params.use_hash = self.with_encode
        self.raw_params.hash_method = self.encode_params.encode_method
        self.raw_params.salt = self.encode_params.salt
        self.raw_params.base64 = self.encode_params.base64

    self.raw_params.check()
    self.rsa_params.check()
    self.dh_params.check()
    self.check_boolean(self.cardinality_only, f"{descr}cardinality_only")
    self.check_boolean(self.sync_cardinality, f"{descr}sync_cardinality")
    self.check_boolean(self.run_preprocess, f"{descr}run_preprocess")
    self.intersect_preprocess_params.check()

    if self.cardinality_only:
        if self.intersect_method not in [consts.RSA]:
            raise ValueError("cardinality-only mode only support rsa.")
        if self.intersect_method == consts.RSA and self.rsa_params.split_calculation:
            raise ValueError("cardinality-only mode only supports unified calculation.")

    if self.run_preprocess:
        if self.intersect_preprocess_params.false_positive_rate < 0.01:
            raise ValueError("for preprocessing ids, false_positive_rate must be no less than 0.01")
        if self.cardinality_only:
            raise ValueError("cardinality_only mode cannot run preprocessing.")

    if self.run_cache:
        if self.intersect_method not in [consts.RSA, consts.DH]:
            raise ValueError("Only rsa or dh method supports cache.")
        if self.intersect_method == consts.RSA and self.rsa_params.split_calculation:
            raise ValueError("RSA split_calculation does not support cache.")
        if self.cardinality_only:
            raise ValueError("cache is not available for cardinality_only mode.")
        if self.run_preprocess:
            raise ValueError("Preprocessing does not support cache.")

    deprecated_param_list = ["repeated_id_process", "repeated_id_owner", "intersect_cache_param",
                             "allow_info_share", "info_owner", "with_sample_id"]
    for param in deprecated_param_list:
        self._warn_deprecated_param(param, descr)

    LOGGER.debug("Finish intersect parameter check!")
    return True
def fit_single_model(self, data_instances, validate_data=None):
    """
    Train one hetero SSHE linear model on ``data_instances``.

    Runs on every party inside a single ``SPDZ("hetero_sshe", ...)`` session.
    Batch features (plus a constant ``1.0`` column when ``fit_intercept``) are
    fixed-point encoded once up front. Training then runs for up to ``max_iter``
    epochs over those pre-built batches.

    When ``reveal_every_iter`` is true, gradient shares are revealed after every
    batch and ``self.model_weights`` is updated directly. Otherwise the weights
    stay split into ``(w_self, w_remote)`` shares for the whole run and are
    revealed once, after the epoch loop, with suffix ``("final",)``.

    NOTE(review): every federated call (``share_model``, ``forward``,
    ``backward``, ``compute_loss``, ``reveal_models``, the convergence checks)
    carries a suffix built from the epoch / batch index. Presumably this is so
    the guest's and host's exchanges pair up, so keep their order in sync
    across roles when editing.

    Parameters
    ----------
    data_instances: table
        Training data. Once batched, rows are read via ``.features``. On the
        guest they are also read via ``.label`` and, when ``self.weight`` is
        set, ``.weight``.
    validate_data: table or None
        Only forwarded to ``self.callback_list.on_train_begin``.

    Returns
    -------
    None
        Results are left on the instance: ``self.model_weights``,
        ``self.loss_history`` (guest only), ``self.n_iter_``,
        ``self.is_converged``, and the summary set via ``self.set_summary``.

    Raises
    ------
    ValueError
        If more than one host party is configured, or if
        ``self.converge_func_name`` is not "diff", "abs" or "weight_diff".
    """
    LOGGER.info(f"Start to train single {self.model_name}")
    # This training path only handles a single host party.
    if len(self.component_properties.host_party_idlist) > 1:
        raise ValueError(f"Hetero SSHE Model does not support multi-host training.")

    self.callback_list.on_train_begin(data_instances, validate_data)

    model_shape = self.get_features_shape(data_instances)
    # Total sample count; the per-epoch loss below is divided by this.
    instances_count = data_instances.count()

    if not self.component_properties.is_warm_start:
        # Fresh run: initialise weights and snapshot them for weight_diff convergence.
        w = self._init_weights(model_shape)
        self.model_weights = LinearModelWeights(l=w,
                                                fit_intercept=self.model_param.init_param.fit_intercept)
        last_models = copy.deepcopy(self.model_weights)
    else:
        # Warm start: continue from the weights already held in self.model_weights
        # and let the callbacks resume from the current self.n_iter_.
        last_models = copy.deepcopy(self.model_weights)
        w = last_models.unboxed
        self.callback_warm_start_init_iter(self.n_iter_)

    if self.role == consts.GUEST:
        if with_weight(data_instances):
            LOGGER.info(f"data with sample weight, use sample weight.")
            if self.model_param.early_stop == "diff":
                LOGGER.warning("input data with weight, please use 'weight_diff' for 'early_stop'.")
            # NOTE(review): scale_sample_weight is defined elsewhere; presumably it
            # normalises the per-row weights -- confirm.
            data_instances = scale_sample_weight(data_instances)
    self.batch_generator.initialize_batch_generator(data_instances, batch_size=self.batch_size)

    with SPDZ(
            "hetero_sshe",
            local_party=self.local_party,
            all_parties=self.parties,
            q_field=self.q_field,
            use_mix_rand=self.model_param.use_mix_rand,
    ) as spdz:
        spdz.set_flowid(self.flowid)
        self.secure_matrix_obj.set_flowid(self.flowid)
        # The weights are only split into shares when they are NOT revealed every
        # iteration; with reveal_every_iter they stay in self.model_weights.
        if not self.reveal_every_iter:
            w_self, w_remote = self.share_model(w, suffix="init")
            # NOTE(review): no deepcopy here, unlike the per-epoch snapshot below.
            # This is only safe if `w_self -= ...` rebinds rather than mutating in
            # place; otherwise the first weight_diff check compares a tensor with
            # itself -- confirm.
            last_w_self, last_w_remote = w_self, w_remote
            LOGGER.debug(f"first_w_self shape: {w_self.shape}, w_remote_shape: {w_remote.shape}")
        batch_data_generator = self.batch_generator.generate_batch_data()

        # Built once and reused by every epoch: encoded features for all parties,
        # plus labels / optional weights on the guest.
        encoded_batch_data = []
        batch_labels_list = []
        batch_weight_list = []

        for batch_data in batch_data_generator:
            if self.fit_intercept:
                # Append a constant 1.0 column so the intercept is learned as a weight.
                batch_features = batch_data.mapValues(lambda x: np.hstack((x.features, 1.0)))
            else:
                batch_features = batch_data.mapValues(lambda x: x.features)
            if self.role == consts.GUEST:
                batch_labels = batch_data.mapValues(lambda x: np.array([x.label], dtype=self.label_type))
                batch_labels_list.append(batch_labels)
                # NOTE(review): self.weight is set outside this method; presumably it
                # is truthy when the data carries sample weights -- confirm.
                if self.weight:
                    batch_weight = batch_data.mapValues(lambda x: np.array([x.weight], dtype=float))
                    batch_weight_list.append(batch_weight)
                else:
                    # Keep the list index-aligned with batch_labels_list.
                    batch_weight_list.append(None)

            # Batch sizes, used below to turn per-batch losses into a sample-weighted sum.
            # NOTE(review): only appended to here but indexed by batch_idx later;
            # assumes self.batch_num is empty on entry -- confirm.
            self.batch_num.append(batch_data.count())

            encoded_batch_data.append(
                fixedpoint_table.FixedPointTensor(self.fixedpoint_encoder.encode(batch_features),
                                                  q_field=self.fixedpoint_encoder.n,
                                                  endec=self.fixedpoint_encoder))
        while self.n_iter_ < self.max_iter:
            self.callback_list.on_epoch_begin(self.n_iter_)
            LOGGER.info(f"start to n_iter: {self.n_iter_}")

            loss_list = []

            self.optimizer.set_iters(self.n_iter_)
            # With shared weights, each share has its own optimizer that must track the epoch too.
            if not self.reveal_every_iter:
                self.self_optimizer.set_iters(self.n_iter_)
                self.remote_optimizer.set_iters(self.n_iter_)

            for batch_idx, batch_data in enumerate(encoded_batch_data):
                # (epoch, batch) suffix that tags every exchange made for this batch.
                current_suffix = (str(self.n_iter_), str(batch_idx))
                if self.role == consts.GUEST:
                    batch_labels = batch_labels_list[batch_idx]
                    batch_weight = batch_weight_list[batch_idx]
                else:
                    # Only the guest holds labels and sample weights.
                    batch_labels = None
                    batch_weight = None

                if self.reveal_every_iter:
                    y = self.forward(weights=self.model_weights,
                                     features=batch_data,
                                     labels=batch_labels,
                                     suffix=current_suffix,
                                     cipher=self.cipher,
                                     batch_weight=batch_weight)
                else:
                    y = self.forward(weights=(w_self, w_remote),
                                     features=batch_data,
                                     labels=batch_labels,
                                     suffix=current_suffix,
                                     cipher=self.cipher,
                                     batch_weight=batch_weight)

                if self.role == consts.GUEST:
                    if self.weight:
                        # NOTE(review): subtracts label * weight from y. This yields
                        # weight * (pred - label) only if forward() already scaled its
                        # output by batch_weight -- confirm.
                        error = y - batch_labels.join(batch_weight, lambda y, b: y * b)
                    else:
                        error = y - batch_labels

                    self_g, remote_g = self.backward(error=error,
                                                     features=batch_data,
                                                     suffix=current_suffix,
                                                     cipher=self.cipher)
                else:
                    # The host has no labels; it passes the forward() result straight to backward().
                    self_g, remote_g = self.backward(error=y,
                                                     features=batch_data,
                                                     suffix=current_suffix,
                                                     cipher=self.cipher)

                # Loss exchanges use their own ("loss", epoch, batch) suffix.
                suffix = ("loss",) + current_suffix
                if self.reveal_every_iter:
                    batch_loss = self.compute_loss(weights=self.model_weights,
                                                   labels=batch_labels,
                                                   suffix=suffix,
                                                   cipher=self.cipher)
                else:
                    batch_loss = self.compute_loss(weights=(w_self, w_remote),
                                                   labels=batch_labels,
                                                   suffix=suffix,
                                                   cipher=self.cipher)

                # Scale by batch size so the epoch loss below is a per-sample average.
                # This assumes compute_loss returns a per-batch mean (TODO confirm);
                # a None result is kept as-is.
                if batch_loss is not None:
                    batch_loss = batch_loss * self.batch_num[batch_idx]
                loss_list.append(batch_loss)

                if self.reveal_every_iter:
                    # Combine the gradient shares into a plaintext gradient and apply the optimizer.
                    new_g = self.reveal_models(self_g, remote_g, suffix=current_suffix)
                    if new_g is not None:
                        self.model_weights = self.optimizer.update_model(self.model_weights, new_g,
                                                                         has_applied=False)
                    else:
                        # No revealed gradient on this party: hold a zero weight vector of the same shape.
                        self.model_weights = LinearModelWeights(
                            l=np.zeros(self_g.shape),
                            fit_intercept=self.model_param.init_param.fit_intercept)
                else:
                    # Shares are updated separately. The L2 term is added here by hand
                    # on each share before the per-share optimizers run.
                    if self.optimizer.penalty == consts.L2_PENALTY:
                        self_g = self_g + self.self_optimizer.alpha * w_self
                        remote_g = remote_g + self.remote_optimizer.alpha * w_remote

                    self_g = self.self_optimizer.apply_gradients(self_g)
                    remote_g = self.remote_optimizer.apply_gradients(remote_g)

                    w_self -= self_g
                    w_remote -= remote_g

                    LOGGER.debug(f"w_self shape: {w_self.shape}, w_remote_shape: {w_remote.shape}")

            if self.role == consts.GUEST:
                # Sum of size-weighted batch losses divided by the total sample count.
                loss = np.sum(loss_list) / instances_count
                self.loss_history.append(loss)
                if self.need_call_back_loss:
                    self.callback_loss(self.n_iter_, loss)
            else:
                loss = None

            if self.converge_func_name in ["diff", "abs"]:
                self.is_converged = self.check_converge_by_loss(loss, suffix=(str(self.n_iter_),))
            elif self.converge_func_name == "weight_diff":
                if self.reveal_every_iter:
                    self.is_converged = self.check_converge_by_weights(
                        last_w=last_models.unboxed,
                        new_w=self.model_weights.unboxed,
                        suffix=(str(self.n_iter_),))
                    # Snapshot the weights for the next epoch's comparison.
                    last_models = copy.deepcopy(self.model_weights)
                else:
                    self.is_converged = self.check_converge_by_weights(
                        last_w=(last_w_self, last_w_remote),
                        new_w=(w_self, w_remote),
                        suffix=(str(self.n_iter_),))
                    # Deep copies so that later `-=` updates cannot alias the snapshot.
                    last_w_self, last_w_remote = copy.deepcopy(w_self), copy.deepcopy(w_remote)
            else:
                raise ValueError(f"Cannot recognize early_stop function: {self.converge_func_name}")

            LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
            self.callback_list.on_epoch_end(self.n_iter_)
            self.n_iter_ += 1

            # Stop early if a callback asked to stop or convergence was reached.
            if self.stop_training:
                break

            if self.is_converged:
                break

        # Finally reconstruct: when weights were kept as shares, reveal them once.
        # Only a party that receives a non-None result replaces its model_weights.
        if not self.reveal_every_iter:
            new_w = self.reveal_models(w_self, w_remote, suffix=("final",))
            if new_w is not None:
                self.model_weights = LinearModelWeights(
                    l=new_w,
                    fit_intercept=self.model_param.init_param.fit_intercept)

    LOGGER.debug(f"loss_history: {self.loss_history}")
    self.set_summary(self.get_model_summary())