def check_host_number(self, tree_type):
    host_num = len(self.component_properties.host_party_idlist)
    LOGGER.info('host number is {}'.format(host_num))
    if tree_type == plan.tree_type_dict['layered_tree']:
        assert host_num == 1, 'only 1 host party is allowed in layered mode'

def remove_redundant_splitinfo_in_split_maskdict(self, split_nid_used):
    LOGGER.info("remove redundant split info from split mask dict")
    redundant_nodes = set(self.split_maskdict.keys()) - set(split_nid_used)
    for nid in redundant_nodes:
        del self.split_maskdict[nid]

def fit(self):
    """
    start to fit
    """
    LOGGER.info('begin to fit homo decision tree, epoch {}, tree idx {}'.format(
        self.epoch_idx, self.tree_idx))

    # compute local g_sum and h_sum
    g_sum, h_sum = self.get_grad_hess_sum(self.g_h)

    # get aggregated root info
    self.aggregator.send_local_root_node_info(g_sum, h_sum,
                                              suffix=('root_node_sync1', self.epoch_idx))
    g_h_dict = self.aggregator.get_aggregated_root_info(
        suffix=('root_node_sync2', self.epoch_idx))
    global_g_sum, global_h_sum = g_h_dict['g_sum'], g_h_dict['h_sum']

    # initialize node
    root_node = Node(id=0,
                     sitename=consts.GUEST,
                     sum_grad=global_g_sum,
                     sum_hess=global_h_sum,
                     weight=self.splitter.node_weight(global_g_sum, global_h_sum))

    self.cur_layer_node = [root_node]
    LOGGER.debug('assign samples to root node')
    self.inst2node_idx = self.assign_instance_to_root_node(self.data_bin, 0)

    tree_height = self.max_depth + 1  # non-leaf node height + 1 layer leaf

    for dep in range(tree_height):

        if dep + 1 == tree_height:
            for node in self.cur_layer_node:
                node.is_leaf = True
                self.tree_node.append(node)

            rest_sample_weights = self.get_node_sample_weights(self.inst2node_idx, self.tree_node)
            if self.sample_weights is None:
                self.sample_weights = rest_sample_weights
            else:
                self.sample_weights = self.sample_weights.union(rest_sample_weights)

            # stop fitting
            break

        LOGGER.debug('start to fit layer {}'.format(dep))

        table_with_assignment = self.update_instances_node_positions()

        # send current layer node number:
        self.sync_cur_layer_node_num(len(self.cur_layer_node),
                                     suffix=(dep, self.epoch_idx, self.tree_idx))

        split_info, agg_histograms = [], []
        for batch_id, i in enumerate(range(0, len(self.cur_layer_node), self.max_split_nodes)):
            cur_to_split = self.cur_layer_node[i:i + self.max_split_nodes]

            node_map = self.get_node_map(nodes=cur_to_split)
            LOGGER.debug('node map is {}'.format(node_map))
            LOGGER.debug('computing histogram for batch{} at depth{}'.format(batch_id, dep))
            local_histogram = self.get_left_node_local_histogram(
                cur_nodes=cur_to_split,
                tree=self.tree_node,
                g_h=self.g_h,
                table_with_assign=table_with_assignment,
                split_points=self.bin_split_points,
                sparse_point=self.bin_sparse_points,
                valid_feature=self.valid_features
            )

            LOGGER.debug('federated finding best splits for batch{} at layer {}'.format(batch_id, dep))
            self.sync_local_node_histogram(local_histogram,
                                           suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))

            agg_histograms += local_histogram

        split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
        LOGGER.debug('got best splits from arbiter')

        new_layer_node = self.update_tree(self.cur_layer_node, split_info)
        self.cur_layer_node = new_layer_node

        self.inst2node_idx, leaf_val = self.assign_instances_to_new_node(
            table_with_assignment, self.tree_node)

        # record leaf val
        if self.sample_weights is None:
            self.sample_weights = leaf_val
        else:
            self.sample_weights = self.sample_weights.union(leaf_val)

        LOGGER.debug('assigning instance to new nodes done')

    self.convert_bin_to_real()
    LOGGER.debug('fitting tree done')

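# Illustrative sketch, not part of the FATE source: in the homo setting every
# party holds a disjoint share of the samples, so the aggregated root statistics
# fetched above are just the element-wise sum of each party's local
# (g_sum, h_sum). The helper below is a hypothetical stand-in for what the
# aggregator computes after secure aggregation.
def aggregate_root_info(local_sums):
    """local_sums: list of (g_sum, h_sum) pairs, one per party."""
    g_sum = sum(g for g, _ in local_sums)
    h_sum = sum(h for _, h in local_sums)
    return {'g_sum': g_sum, 'h_sum': h_sum}

# e.g. aggregate_root_info([(0.2, 1.0), (0.3, 1.0), (0.4, 1.0)])
# -> {'g_sum': 0.9, 'h_sum': 3.0}
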
def sync_dispatch_node_host(self, dep):
    LOGGER.info("get node from host to dispatch, depth is {}".format(dep))
    dispatch_node_host = self.transfer_inst.dispatch_node_host.get(idx=0, suffix=(dep,))
    return dispatch_node_host

def sync_predict_finish_tag(self, recv_times):
    LOGGER.info("get the {}-th predict finish tag from guest".format(recv_times))
    finish_tag = self.transfer_inst.predict_finish_tag.get(idx=0, suffix=(recv_times,))
    return finish_tag

def set_encrypter(self, encrypter):
    LOGGER.info("set encrypter")
    self.encrypter = encrypter

def sync_node_positions(self, dep=-1):
    LOGGER.info("get node positions of depth {}".format(dep))
    node_positions = self.transfer_inst.node_positions.get(idx=0, suffix=(dep,))
    return node_positions

def fit_binary(self, data_instances, validate_data=None):
    LOGGER.info("Enter hetero_lr_guest fit")
    self.header = self.get_header(data_instances)

    self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)
    data_instances = data_instances.mapValues(HeteroLRGuest.load_data)
    LOGGER.debug(f"MODEL_STEP After load data, data count: {data_instances.count()}")

    self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

    LOGGER.info("Generate mini-batch from input data")
    self.batch_generator.initialize_batch_generator(data_instances, self.batch_size)
    self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)

    self.encrypted_calculator = [EncryptModeCalculator(
        self.cipher_operator,
        self.encrypted_mode_calculator_param.mode,
        self.encrypted_mode_calculator_param.re_encrypted_rate)
        for _ in range(self.batch_generator.batch_nums)]

    LOGGER.info("Start initialize model.")
    LOGGER.info("fit_intercept:{}".format(self.init_param_obj.fit_intercept))
    model_shape = self.get_features_shape(data_instances)
    w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
    self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept)

    while self.n_iter_ < self.max_iter:
        LOGGER.info("iter:{}".format(self.n_iter_))
        batch_data_generator = self.batch_generator.generate_batch_data()
        self.optimizer.set_iters(self.n_iter_)
        batch_index = 0
        for batch_data in batch_data_generator:
            # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
            batch_feat_inst = batch_data
            # LOGGER.debug(f"MODEL_STEP In Batch {batch_index}, batch data count: {batch_feat_inst.count()}")

            # Start gradient procedure
            LOGGER.debug("iter: {}, before compute gradient, data count: {}".format(
                self.n_iter_, batch_feat_inst.count()))
            optim_guest_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                batch_feat_inst,
                self.encrypted_calculator,
                self.model_weights,
                self.optimizer,
                self.n_iter_,
                batch_index)
            # LOGGER.debug('optim_guest_gradient: {}'.format(optim_guest_gradient))
            # training_info = {"iteration": self.n_iter_, "batch_index": batch_index}
            # self.update_local_model(fore_gradient, data_instances, self.model_weights.coef_, **training_info)

            loss_norm = self.optimizer.loss_norm(self.model_weights)
            self.gradient_loss_operator.compute_loss(data_instances, self.model_weights,
                                                     self.n_iter_, batch_index, loss_norm)

            self.model_weights = self.optimizer.update_model(self.model_weights, optim_guest_gradient)
            batch_index += 1
            # LOGGER.debug("lr_weight, iters: {}, update_model: {}".format(self.n_iter_, self.model_weights.unboxed))

        self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
        LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))

        if self.validation_strategy:
            LOGGER.debug('LR guest running validation')
            self.validation_strategy.validate(self, self.n_iter_)
            if self.validation_strategy.need_stop():
                LOGGER.debug('early stopping triggered')
                break

        self.n_iter_ += 1
        if self.is_converged:
            break

    if self.validation_strategy and self.validation_strategy.has_saved_best_model():
        self.load_model(self.validation_strategy.cur_best_model)
    self.set_summary(self.get_model_summary())

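# Illustrative sketch, not FATE source: the guest-side loop above follows a
# plain mini-batch gradient-descent skeleton; the federated version only swaps
# the gradient/loss computation for Paillier-encrypted, multi-party variants.
# Everything below is a hypothetical, plaintext stand-in for that skeleton.
import numpy as np

def fit_plaintext(X, y, lr=0.1, batch_size=32, max_iter=100, tol=1e-6):
    w = np.zeros(X.shape[1])
    for _ in range(max_iter):
        prev = w.copy()
        for start in range(0, len(X), batch_size):
            xb, yb = X[start:start + batch_size], y[start:start + batch_size]
            pred = 1.0 / (1.0 + np.exp(-xb @ w))   # sigmoid
            grad = xb.T @ (pred - yb) / len(xb)    # logistic-loss gradient
            w -= lr * grad
        if np.linalg.norm(w - prev) < tol:         # converge check per epoch
            break
    return w
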
def unified_calculation_process(self, data_instances):
    LOGGER.info("RSA intersect using unified calculation.")
    # generate rsa keys
    # self.e, self.d, self.n = self.generate_protocol_key()
    self.generate_protocol_key()
    LOGGER.info("Generate protocol key!")
    public_key = {"e": self.e, "n": self.n}

    # sends public key e & n to guest
    self.transfer_variable.host_pubkey.remote(public_key,
                                              role=consts.GUEST,
                                              idx=0)
    LOGGER.info("Remote public key to Guest.")

    # hash host ids
    prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(
        data_instances,
        self.d, self.n, self.p, self.q, self.cp, self.cq,
        self.first_hash_operator)

    prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: 1)
    self.transfer_variable.host_prvkey_ids.remote(prvkey_ids_process,
                                                  role=consts.GUEST,
                                                  idx=0)
    LOGGER.info("Remote host_ids_process to Guest.")

    # Recv guest ids
    guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
    LOGGER.info("Get guest_pubkey_ids from guest")

    # Process(signs) guest ids and return to guest
    host_sign_guest_ids = guest_pubkey_ids.map(
        lambda k, v: (k, self.sign_id(k, self.d, self.n, self.p, self.q, self.cp, self.cq)))
    self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
                                                      role=consts.GUEST,
                                                      idx=0)
    LOGGER.info("Remote host_sign_guest_ids_process to Guest.")

    # recv intersect ids
    intersect_ids = None
    if self.sync_intersect_ids:
        encrypt_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
        intersect_ids_pair = encrypt_intersect_ids.join(prvkey_ids_process_pair, lambda e, h: h)
        intersect_ids = intersect_ids_pair.map(lambda k, v: (v, "id"))
        LOGGER.info("Get intersect ids from Guest")
    return intersect_ids

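# Illustrative sketch, not FATE source: the exchange above is an RSA
# blind-signature private set intersection. The guest blinds H(id) with a
# random r before the host signs it, so the host never sees raw guest ids;
# unblinding leaves H(id)^d mod n, which both sides can compare after a second
# hash. Toy single-process version with a throwaway key (names hypothetical):
import hashlib

def h_int(x, n):
    return int(hashlib.sha256(str(x).encode()).hexdigest(), 16) % n

def rsa_psi(host_ids, guest_ids, e, d, n):
    # host signs its own hashed ids: H(v)^d mod n, then hashes once more
    host_signed = {hashlib.sha256(str(pow(h_int(v, n), d, n)).encode()).hexdigest()
                   for v in host_ids}
    matches = []
    for u in guest_ids:
        r = 12345                                   # toy blinding factor; random in practice
        blinded = (h_int(u, n) * pow(r, e, n)) % n  # guest -> host
        signed = pow(blinded, d, n)                 # host signs blindly: H(u)^d * r mod n
        unblinded = (signed * pow(r, -1, n)) % n    # guest removes r: H(u)^d mod n
        if hashlib.sha256(str(unblinded).encode()).hexdigest() in host_signed:
            matches.append(u)
    return matches

# toy key p=61, q=53 -> n=3233, e=17, d=2753:
# rsa_psi(['a', 'b'], ['b', 'c'], e=17, d=2753, n=3233) -> ['b']
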
def fit(self, data_instances, validate_data=None):
    """
    Train linR model of role guest
    Parameters
    ----------
    data_instances: DTable of Instance, input data
    """
    LOGGER.info("Enter hetero_linR_guest fit")
    self._abnormal_detection(data_instances)
    self.header = self.get_header(data_instances)
    self.validation_strategy = self.init_validation_strategy(data_instances, validate_data)

    self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

    LOGGER.info("Generate mini-batch from input data")
    self.batch_generator.initialize_batch_generator(data_instances, self.batch_size)
    self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)

    self.encrypted_calculator = [EncryptModeCalculator(
        self.cipher_operator,
        self.encrypted_mode_calculator_param.mode,
        self.encrypted_mode_calculator_param.re_encrypted_rate)
        for _ in range(self.batch_generator.batch_nums)]

    LOGGER.info("Start initialize model.")
    LOGGER.info("fit_intercept:{}".format(self.init_param_obj.fit_intercept))
    model_shape = self.get_features_shape(data_instances)
    w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
    self.model_weights = LinearModelWeights(w, fit_intercept=self.fit_intercept)

    while self.n_iter_ < self.max_iter:
        LOGGER.info("iter:{}".format(self.n_iter_))
        # each iter will get the same batch_data_generator
        batch_data_generator = self.batch_generator.generate_batch_data()
        self.optimizer.set_iters(self.n_iter_)
        batch_index = 0
        for batch_data in batch_data_generator:
            # transforms features of raw input 'batch_data_inst' into more representative features 'batch_feat_inst'
            batch_feat_inst = self.transform(batch_data)

            # Start gradient procedure
            optim_guest_gradient, _, _ = self.gradient_loss_operator.compute_gradient_procedure(
                batch_feat_inst,
                self.encrypted_calculator,
                self.model_weights,
                self.optimizer,
                self.n_iter_,
                batch_index)

            loss_norm = self.optimizer.loss_norm(self.model_weights)
            self.gradient_loss_operator.compute_loss(data_instances, self.n_iter_,
                                                     batch_index, loss_norm)

            self.model_weights = self.optimizer.update_model(self.model_weights, optim_guest_gradient)
            batch_index += 1
            # LOGGER.debug(
            #     "model_weights, iters: {}, update_model: {}".format(self.n_iter_, self.model_weights.unboxed))

        self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))
        LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
        # LOGGER.debug("model weights is {}".format(self.model_weights.coef_))

        if self.validation_strategy:
            LOGGER.debug('LinR guest running validation')
            self.validation_strategy.validate(self, self.n_iter_)
            if self.validation_strategy.need_stop():
                LOGGER.debug('early stopping triggered')
                break

        self.n_iter_ += 1
        if self.is_converged:
            break

    if self.validation_strategy and self.validation_strategy.has_saved_best_model():
        self.load_model(self.validation_strategy.cur_best_model)
    self.set_summary(self.get_model_summary())

def get_grad_and_hess(g_h, dim=0):
    LOGGER.info("get grad and hess of tree {}".format(dim))
    grad_and_hess_subtree = g_h.mapValues(
        lambda grad_and_hess: (grad_and_hess[0][dim], grad_and_hess[1][dim]))
    return grad_and_hess_subtree

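# Illustrative sketch, not FATE source: for multi-class boosting each record
# carries one (grad, hess) vector per class, and dim selects the slice for a
# single one-vs-rest tree. Plain-dict stand-in for the distributed table:
g_h_example = {'id1': ((0.1, -0.4, 0.3), (0.25, 0.25, 0.25))}
dim = 1
subtree = {k: (g[dim], h[dim]) for k, (g, h) in g_h_example.items()}
# subtree == {'id1': (-0.4, 0.25)}
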
def intersect_online_process(self, data_inst, caches):
    # LOGGER.debug(f"caches is: {caches}")
    cache_data, cache_meta = list(caches.values())[0]
    intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
    # LOGGER.debug(f"intersect_meta is: {intersect_meta}")
    self.callback_cache_meta(intersect_meta)
    self.load_intersect_meta(intersect_meta)
    self.init_intersect_method()
    self.intersection_obj.load_intersect_key(cache_meta)

    if data_overview.check_with_inst_id(data_inst):
        self.use_match_id_process = True
        LOGGER.info(f"use match_id_process")
    intersect_data = data_inst
    if self.use_match_id_process:
        if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
            raise ValueError("While multi-host, sample_id_generator should be guest.")
        if self.model_param.intersect_method == consts.RAW:
            if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                raise ValueError(f"When using raw intersect with match id process,"
                                 f"'join_role' should be same role as 'sample_id_generator'")
        else:
            if not self.model_param.sync_intersect_ids:
                if self.model_param.sample_id_generator != consts.GUEST:
                    self.model_param.sample_id_generator = consts.GUEST
                    LOGGER.warning(f"when not sync_intersect_ids with match id process,"
                                   f"sample_id_generator is set to Guest")
        proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
        proc_obj.new_sample_id = self.model_param.new_sample_id
        if data_overview.check_with_inst_id(data_inst) or self.model_param.with_sample_id:
            proc_obj.use_sample_id()
        match_data = proc_obj.recover(data=data_inst)
        intersect_data = match_data

    if self.role == consts.HOST:
        cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
        self.transfer_variable.cache_id.remote(cache_id, role=consts.GUEST, idx=0)
        guest_cache_id = self.transfer_variable.cache_id.get(role=consts.GUEST, idx=0)
        if guest_cache_id != cache_id:
            raise ValueError(f"cache_id check failed. cache_id from host & guest must match.")
    elif self.role == consts.GUEST:
        for i, party_id in enumerate(self.host_party_id_list):
            cache_id = cache_meta[str(party_id)].get("cache_id")
            self.transfer_variable.cache_id.remote(cache_id, role=consts.HOST, idx=i)
            host_cache_id = self.transfer_variable.cache_id.get(role=consts.HOST, idx=i)
            if host_cache_id != cache_id:
                raise ValueError(f"cache_id check failed. cache_id from host & guest must match.")
    else:
        raise ValueError(f"Role {self.role} cannot run intersection transform.")

    self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
    if self.use_match_id_process:
        if not self.model_param.sync_intersect_ids:
            self.intersect_ids = proc_obj.expand(self.intersect_ids,
                                                 match_data=match_data,
                                                 owner_only=True)
        else:
            self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
        if self.intersect_ids and self.model_param.only_output_key:
            self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
            self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
                                         "sid_name": data_inst.schema["sid_name"]}

    LOGGER.info("Finish intersection")

    if self.intersect_ids:
        data_count = data_inst.count()
        self.intersect_num = self.intersect_ids.count()
        self.intersect_rate = self.intersect_num / data_count
        self.unmatched_num = data_count - self.intersect_num
        self.unmatched_rate = 1 - self.intersect_rate

    self.set_summary(self.get_model_summary())
    self.callback()

    result_data = self.intersect_ids
    if not self.use_match_id_process:
        if not self.intersection_obj.only_output_key and result_data:
            result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
            self.intersect_ids.schema = result_data.schema
            LOGGER.debug(f"not only_output_key, restore value called")
        if self.intersection_obj.only_output_key and result_data:
            schema = {"sid_name": data_inst.schema["sid_name"]}
            result_data = result_data.mapValues(lambda v: 1)
            result_data.schema = schema
            self.intersect_ids.schema = schema

    if self.model_param.join_method == consts.LEFT_JOIN:
        result_data = self.__sync_join_id(data_inst, self.intersect_ids)
        result_data.schema = self.intersect_ids.schema

    return result_data

def fit(self, data):
    if self.component_properties.caches:
        return self.intersect_online_process(data, self.component_properties.caches)
    self.init_intersect_method()
    if data_overview.check_with_inst_id(data):
        self.use_match_id_process = True
        LOGGER.info(f"use match_id_process")

    if self.use_match_id_process:
        if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
            raise ValueError("While multi-host, sample_id_generator should be guest.")
        if self.model_param.intersect_method == consts.RAW:
            if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                raise ValueError(f"When using raw intersect with match id process,"
                                 f"'join_role' should be same role as 'sample_id_generator'")
        else:
            if not self.model_param.sync_intersect_ids:
                if self.model_param.sample_id_generator != consts.GUEST:
                    self.model_param.sample_id_generator = consts.GUEST
                    LOGGER.warning(f"when not sync_intersect_ids with match id process,"
                                   f"sample_id_generator is set to Guest")
        self.proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
        self.proc_obj.new_sample_id = self.model_param.new_sample_id
        if data_overview.check_with_inst_id(data) or self.model_param.with_sample_id:
            self.proc_obj.use_sample_id()
        match_data = self.proc_obj.recover(data=data)
        if self.intersection_obj.run_cache:
            self.cache_output = self.intersection_obj.generate_cache(match_data)
            intersect_meta = self.intersection_obj.get_intersect_method_meta()
            self.callback_cache_meta(intersect_meta)
            return data
        if self.intersection_obj.cardinality_only:
            self.intersection_obj.run_cardinality(match_data)
        else:
            intersect_data = match_data
            if self.model_param.run_preprocess:
                intersect_data = self.run_preprocess(match_data)
            self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)
    else:
        if self.intersection_obj.run_cache:
            self.cache_output = self.intersection_obj.generate_cache(data)
            intersect_meta = self.intersection_obj.get_intersect_method_meta()
            # LOGGER.debug(f"callback intersect meta is: {intersect_meta}")
            self.callback_cache_meta(intersect_meta)
            return data
        if self.intersection_obj.cardinality_only:
            self.intersection_obj.run_cardinality(data)
        else:
            intersect_data = data
            if self.model_param.run_preprocess:
                intersect_data = self.run_preprocess(data)
            self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)

    if self.intersection_obj.cardinality_only:
        if self.intersection_obj.intersect_num is not None:
            data_count = data.count()
            self.intersect_num = self.intersection_obj.intersect_num
            self.intersect_rate = self.intersect_num / data_count
            self.unmatched_num = data_count - self.intersect_num
            self.unmatched_rate = 1 - self.intersect_rate
        # self.model = self.intersection_obj.get_model()
        self.set_summary(self.get_model_summary())
        self.callback()
        return data

    if self.use_match_id_process:
        if self.model_param.sync_intersect_ids:
            self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
        else:
            # self.intersect_ids = match_data
            self.intersect_ids = self.proc_obj.expand(self.intersect_ids,
                                                      match_data=match_data,
                                                      owner_only=True)
        if self.model_param.only_output_key and self.intersect_ids:
            self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
            self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
                                         "sid_name": data.schema["sid_name"]}

    LOGGER.info("Finish intersection")

    if self.intersect_ids:
        data_count = data.count()
        self.intersect_num = self.intersect_ids.count()
        self.intersect_rate = self.intersect_num / data_count
        self.unmatched_num = data_count - self.intersect_num
        self.unmatched_rate = 1 - self.intersect_rate

    self.set_summary(self.get_model_summary())
    self.callback()

    result_data = self.intersect_ids
    if not self.use_match_id_process and not self.intersection_obj.only_output_key and result_data:
        result_data = self.intersection_obj.get_value_from_data(result_data, data)
        LOGGER.debug(f"not only_output_key, restore value called")

    if self.model_param.join_method == consts.LEFT_JOIN:
        result_data = self.__sync_join_id(data, self.intersect_ids)
        result_data.schema = self.intersect_ids.schema

    return result_data

def predict(self, data_inst, ret_format='std'):
    # standard format, leaf indices, raw score
    assert ret_format in ['std', 'leaf', 'raw'], 'illegal ret format'

    LOGGER.info('running prediction')
    cache_dataset_key = self.predict_data_cache.get_data_key(data_inst)

    processed_data = self.data_and_header_alignment(data_inst)

    last_round = self.predict_data_cache.predict_data_last_round(cache_dataset_key)

    self.sync_predict_round(last_round)

    rounds = len(self.boosting_model_list) // self.booster_dim
    trees = []
    LOGGER.debug('round involved in prediction {}, last round is {}, data key {}'.format(
        list(range(last_round, rounds)), last_round, cache_dataset_key))

    for idx in range(last_round, rounds):
        for booster_idx in range(self.booster_dim):
            tree = self.load_learner(self.booster_meta,
                                     self.boosting_model_list[idx * self.booster_dim + booster_idx],
                                     idx, booster_idx)
            trees.append(tree)

    predict_cache = None
    tree_num = len(trees)

    if last_round != 0:
        predict_cache = self.predict_data_cache.predict_data_at(cache_dataset_key, min(rounds, last_round))
        LOGGER.info('load predict cache of round {}'.format(min(rounds, last_round)))

    if tree_num == 0 and predict_cache is not None and not (ret_format == 'leaf'):
        return self.score_to_predict_result(data_inst, predict_cache)

    if self.boosting_strategy == consts.MIX_TREE:
        predict_rs = mix_sbt_guest_predict(
            processed_data, self.hetero_sbt_transfer_variable, trees,
            self.learning_rate, self.init_score, self.booster_dim,
            predict_cache, pred_leaf=(ret_format == 'leaf'))
    else:
        if self.EINI_inference and not self.on_training:  # EINI is for inference stage
            sitename = self.role + ':' + str(self.component_properties.local_partyid)
            predict_rs = EINI_guest_predict(
                processed_data, trees, self.learning_rate, self.init_score,
                self.booster_dim, self.encrypt_param.key_length,
                self.hetero_sbt_transfer_variable, sitename,
                self.component_properties.host_party_idlist, predict_cache, False)
        else:
            predict_rs = sbt_guest_predict(
                processed_data, self.hetero_sbt_transfer_variable, trees,
                self.learning_rate, self.init_score, self.booster_dim,
                predict_cache, pred_leaf=(ret_format == 'leaf'))

    if ret_format == 'leaf':
        return predict_rs  # predict result is leaf position

    self.predict_data_cache.add_data(cache_dataset_key, predict_rs, cur_boosting_round=rounds)
    LOGGER.debug('adding predict rs {}'.format(predict_rs))
    LOGGER.debug('last round is {}'.format(
        self.predict_data_cache.predict_data_last_round(cache_dataset_key)))

    if ret_format == 'raw':
        return predict_rs
    else:
        return self.score_to_predict_result(data_inst, predict_rs)

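# Illustrative sketch, not FATE source: the predict cache lets later calls
# evaluate only the trees added since the cached round. A hypothetical scalar
# version of the incremental score update the cache enables:
def incremental_score(cached_score, new_tree_outputs, learning_rate, init_score=0.0):
    score = init_score if cached_score is None else cached_score
    for out in new_tree_outputs:  # one output per newly-loaded tree
        score += learning_rate * out
    return score

# first call:  incremental_score(None, [0.2, -0.1], 0.3)  -> 0.03
# later call:  incremental_score(0.03, [0.05], 0.3)       -> 0.045
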
def sync_predict_data(self, predict_data, send_times):
    LOGGER.info("send predict data to host, sending times is {}".format(send_times))
    self.transfer_inst.predict_data.remote(predict_data,
                                           role=consts.HOST,
                                           idx=-1,
                                           suffix=(send_times,))

def run_cardinality(self, data_instances):
    LOGGER.info(f"run cardinality_only with RSA")
    # generate rsa keys
    self.generate_protocol_key()
    LOGGER.info("Generate protocol key!")
    public_key = {"e": self.e, "n": self.n}

    # sends public key e & n to guest
    self.transfer_variable.host_pubkey.remote(public_key,
                                              role=consts.GUEST,
                                              idx=0)
    LOGGER.info("Remote public key to Guest.")

    # hash host ids
    prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(
        data_instances,
        self.d, self.n, self.p, self.q, self.cp, self.cq,
        self.first_hash_operator)

    filter = self.construct_filter(
        prvkey_ids_process_pair,
        false_positive_rate=self.intersect_preprocess_params.false_positive_rate,
        hash_method=self.intersect_preprocess_params.hash_method,
        random_state=self.intersect_preprocess_params.random_state)
    self.filter = filter
    self.transfer_variable.host_filter.remote(filter,
                                              role=consts.GUEST,
                                              idx=0)
    LOGGER.info("Remote host_filter to Guest.")

    # Recv guest ids
    guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
    LOGGER.info("Get guest_pubkey_ids from guest")

    # Process(signs) guest ids and return to guest
    host_sign_guest_ids = guest_pubkey_ids.map(
        lambda k, v: (k, self.sign_id(k, self.d, self.n, self.p, self.q, self.cp, self.cq)))
    self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
                                                      role=consts.GUEST,
                                                      idx=0)
    LOGGER.info("Remote host_sign_guest_ids_process to Guest.")

    if self.sync_cardinality:
        self.intersect_num = self.transfer_variable.cardinality.get(idx=0)
        LOGGER.info("Got intersect cardinality from guest.")

    return data_instances

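# Illustrative sketch, not FATE source: a Bloom filter lets the guest test
# membership of its signed ids against the host's set without receiving the
# ids themselves; false positives inflate the count by roughly the configured
# false_positive_rate. Minimal stand-in implementation (names hypothetical):
import hashlib

class ToyBloomFilter:
    def __init__(self, m=1024, k=3):
        self.m, self.k, self.bits = m, k, bytearray(m)

    def _positions(self, item):
        for i in range(self.k):
            digest = hashlib.sha256(f"{i}:{item}".encode()).hexdigest()
            yield int(digest, 16) % self.m

    def add(self, item):
        for pos in self._positions(item):
            self.bits[pos] = 1

    def check(self, item):
        return all(self.bits[pos] for pos in self._positions(item))

# host side:  f = ToyBloomFilter(); then f.add(signed_id) for each host id
# guest side: cardinality ~= sum(f.check(sid) for sid in guest_signed_ids)
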
def sync_data_predicted_by_host(self, send_times):
    LOGGER.info("get predicted data by host, recv times is {}".format(send_times))
    predict_data = self.transfer_inst.predict_data_by_host.get(idx=-1,
                                                               suffix=(send_times,))
    return predict_data

def split_calculation_process(self, data_instances):
    LOGGER.info("RSA intersect using split calculation.")
    # split data
    sid_hash_odd = data_instances.filter(lambda k, v: k & 1)
    sid_hash_even = data_instances.filter(lambda k, v: not k & 1)
    # LOGGER.debug(f"sid_hash_odd count: {sid_hash_odd.count()},"
    #              f"odd fraction: {sid_hash_odd.count()/data_instances.count()}")

    # generate rsa keys
    # self.e, self.d, self.n = self.generate_protocol_key()
    self.generate_protocol_key()
    LOGGER.info("Generate host protocol key!")
    public_key = {"e": self.e, "n": self.n}

    # sends public key e & n to guest
    self.transfer_variable.host_pubkey.remote(public_key,
                                              role=consts.GUEST,
                                              idx=0)
    LOGGER.info("Remote public key to Guest.")

    # generate ri for even ids
    # count = sid_hash_even.count()
    # self.r = self.generate_r_base(self.random_bit, count, self.random_base_fraction)
    # LOGGER.info(f"Generate {len(self.r)} r values.")

    # receive guest key for even ids
    guest_public_key = self.transfer_variable.guest_pubkey.get(idx=0)
    # LOGGER.debug("Get guest_public_key:{} from Guest".format(guest_public_key))
    LOGGER.info(f"Get guest_public_key from Guest")

    self.rcv_e = int(guest_public_key["e"])
    self.rcv_n = int(guest_public_key["n"])

    # encrypt & send guest pubkey-encrypted odd ids
    pubkey_ids_process = self.pubkey_id_process(sid_hash_even,
                                                fraction=self.random_base_fraction,
                                                random_bit=self.random_bit,
                                                rsa_e=self.rcv_e,
                                                rsa_n=self.rcv_n)
    LOGGER.info(f"Finish pubkey_ids_process")
    mask_host_id = pubkey_ids_process.mapValues(lambda v: 1)
    self.transfer_variable.host_pubkey_ids.remote(mask_host_id,
                                                  role=consts.GUEST,
                                                  idx=0)
    LOGGER.info("Remote host_pubkey_ids to Guest")

    # encrypt & send prvkey-encrypted host odd ids to guest
    prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(sid_hash_odd,
                                                               self.d, self.n, self.p,
                                                               self.q, self.cp, self.cq)
    prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: 1)
    self.transfer_variable.host_prvkey_ids.remote(prvkey_ids_process,
                                                  role=consts.GUEST,
                                                  idx=0)
    LOGGER.info("Remote host_prvkey_ids to Guest.")

    # get & sign guest pubkey-encrypted odd ids
    guest_pubkey_ids = self.transfer_variable.guest_pubkey_ids.get(idx=0)
    LOGGER.info(f"Get guest_pubkey_ids from guest")
    host_sign_guest_ids = guest_pubkey_ids.map(
        lambda k, v: (k, self.sign_id(k, self.d, self.n, self.p, self.q, self.cp, self.cq)))
    LOGGER.debug(f"host sign guest_pubkey_ids")

    # send signed guest odd ids
    self.transfer_variable.host_sign_guest_ids.remote(host_sign_guest_ids,
                                                      role=consts.GUEST,
                                                      idx=0)
    LOGGER.info("Remote host_sign_guest_ids_process to Guest.")

    # recv guest privkey-encrypted even ids
    guest_prvkey_ids = self.transfer_variable.guest_prvkey_ids.get(idx=0)
    LOGGER.info("Get guest_prvkey_ids")

    # receive guest-signed host even ids
    recv_guest_sign_host_ids = self.transfer_variable.guest_sign_host_ids.get(idx=0)
    LOGGER.info(f"Get guest_sign_host_ids from Guest.")
    guest_sign_host_ids = pubkey_ids_process.join(
        recv_guest_sign_host_ids,
        lambda g, r: (g[0], RsaIntersectionHost.hash(gmpy2.divm(int(r), int(g[1]), self.rcv_n),
                                                     self.final_hash_operator,
                                                     self.rsa_params.salt)))
    sid_guest_sign_host_ids = guest_sign_host_ids.map(lambda k, v: (v[1], v[0]))

    encrypt_intersect_even_ids = sid_guest_sign_host_ids.join(guest_prvkey_ids, lambda sid, h: sid)

    # filter & send intersect even ids
    intersect_even_ids = self.filter_intersect_ids([encrypt_intersect_even_ids])

    remote_intersect_even_ids = encrypt_intersect_even_ids.mapValues(lambda v: 1)
    self.transfer_variable.host_intersect_ids.remote(remote_intersect_even_ids,
                                                     role=consts.GUEST,
                                                     idx=0)
    LOGGER.info(f"Remote host intersect ids to Guest")

    # recv intersect ids
    intersect_ids = None
    if self.sync_intersect_ids:
        encrypt_intersect_odd_ids = self.transfer_variable.intersect_ids.get(idx=0)
        intersect_odd_ids_pair = encrypt_intersect_odd_ids.join(prvkey_ids_process_pair, lambda e, h: h)
        intersect_odd_ids = intersect_odd_ids_pair.map(lambda k, v: (v, 1))
        intersect_ids = intersect_odd_ids.union(intersect_even_ids)
        LOGGER.info("Get intersect ids from Guest")
    return intersect_ids

def key_derivation(self, target_num):
    """
    Derive a list of keys for encryption and transmission
    :param target_num: N in k-N OT
    :return: List[ObliviousTransferKey]
    """
    LOGGER.info("enter sender key derivation phase for target num = {}".format(target_num))
    # 1. Choose a random scalar (y) from Z^q, calculate S and T to verify its legality
    y, s, t = self._gen_legal_y_s_t()

    # 2. Send S to the receiver, if it is illegal addressed by the receiver, regenerate y, S, T
    attempt_count = 0
    while True:
        self.transfer_variable.s.remote(s,
                                        suffix=(attempt_count,),
                                        role=consts.GUEST,
                                        idx=0)
        # federation.remote(obj=s,
        #                   name=self.transfer_variable.s.name,
        #                   tag=self.transfer_variable.generate_transferid(self.transfer_variable.s, attempt_count),
        #                   role=consts.GUEST,
        #                   idx=0)
        LOGGER.info("sent S to guest for the {}-th time".format(attempt_count))

        s_legal = self.transfer_variable.s_legal.get(idx=0, suffix=(attempt_count,))
        # s_legal = federation.get(name=self.transfer_variable.s_legal.name,
        #                          tag=self.transfer_variable.generate_transferid(self.transfer_variable.s_legal,
        #                                                                         attempt_count),
        #                          idx=0)
        if s_legal:
            LOGGER.info("receiver confirms the legality of S at {} attempt, will proceed".format(attempt_count))
            break
        else:
            LOGGER.info("receiver rejects this S at {} attempt, will regenerate S".format(attempt_count))
            y, s, t = self._gen_legal_y_s_t()
            attempt_count += 1

    # 3. Wait for the receiver to hash S to get T
    LOGGER.info("waiting for the receiver to hash S to get T")

    # 4. Get R = cT + xG from the receiver, also init the MAC
    r = self.transfer_variable.r.get(idx=0)
    # r = federation.get(name=self.transfer_variable.r.name,
    #                    tag=self.transfer_variable.generate_transferid(self.transfer_variable.r),
    #                    idx=0)
    LOGGER.info("got from guest R = " + r.output())
    self._init_mac(s, r)

    # 5. MAC and output the key list
    key_list = []
    yt = self.tec_arithmetic.mul(scalar=y, a=t)  # yT
    yr = self.tec_arithmetic.mul(scalar=y, a=r)  # yR
    for i in range(target_num):
        iyt = self.tec_arithmetic.mul(scalar=i, a=yt)  # iyT
        diff = self.tec_arithmetic.sub(a=yr, b=iyt)  # yR - iyT
        key = self._mac_tec_element(diff)
        LOGGER.info("{}-th key generated".format(i))
        LOGGER.info("key before MAC = " + diff.output())
        LOGGER.info("key = {}".format(key))
        key_list.append(ObliviousTransferKey(i, key))
    LOGGER.info("all keys successfully generated")

    return key_list

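# Illustrative sketch, not FATE source: why only the chosen key matches.
# With S = yG and R = cT + xG, the sender's i-th key input is
# yR - iyT = (c - i)yT + xS, which collapses to the receiver's value xS
# exactly when i == c. Toy version over the additive group Z_q (insecure,
# for the algebra only; the real code works on a twisted Edwards curve):
q = 2**61 - 1  # a prime modulus
G = 1          # "generator" of the additive group Z_q

y, x, c = 123456789, 987654321, 2  # sender scalar, receiver scalar, choice index
S = (y * G) % q                    # sender -> receiver
T = 42 % q                         # receiver derives T by hashing S; any element works here

R = (c * T + x * G) % q            # receiver -> sender

sender_keys = [(y * R - i * y * T) % q for i in range(5)]  # key input before MAC
receiver_key = (x * S) % q

assert sender_keys[c] == receiver_key  # match only at the chosen index c
assert all(k != receiver_key for i, k in enumerate(sender_keys) if i != c)
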
def fit(self, data_instances, validate_data=None):
    LOGGER.debug("Start data count: {}".format(data_instances.count()))

    self._abnormal_detection(data_instances)
    self.check_abnormal_values(data_instances)
    self.init_schema(data_instances)
    # validation_strategy = self.init_validation_strategy(data_instances, validate_data)

    pubkey = self.cipher.gen_paillier_pubkey(enable=self.use_encrypt, suffix=('fit',))
    if self.use_encrypt:
        self.cipher_operator.set_public_key(pubkey)

    self.model_weights = self._init_model_variables(data_instances)
    w = self.cipher_operator.encrypt_list(self.model_weights.unboxed)
    self.model_weights = LogisticRegressionWeights(w, self.model_weights.fit_intercept)

    # LOGGER.debug("After init, model_weights: {}".format(self.model_weights.unboxed))

    mini_batch_obj = MiniBatch(data_inst=data_instances, batch_size=self.batch_size)
    total_batch_num = mini_batch_obj.batch_nums

    if self.use_encrypt:
        re_encrypt_times = (total_batch_num - 1) // self.re_encrypt_batches + 1
        LOGGER.debug("re_encrypt_times is :{}, batch_size: {}, total_batch_num: {}, re_encrypt_batches: {}".format(
            re_encrypt_times, self.batch_size, total_batch_num, self.re_encrypt_batches))
        self.cipher.set_re_cipher_time(re_encrypt_times)

    total_data_num = data_instances.count()
    LOGGER.debug("Current data count: {}".format(total_data_num))

    model_weights = self.model_weights
    self.prev_round_weights = copy.deepcopy(model_weights)
    degree = 0
    while self.n_iter_ < self.max_iter + 1:
        batch_data_generator = mini_batch_obj.mini_batch_data_generator()

        if ((self.n_iter_ + 1) % self.aggregate_iters == 0) or self.n_iter_ == self.max_iter:
            weight = self.aggregator.aggregate_then_get(model_weights,
                                                        degree=degree,
                                                        suffix=self.n_iter_)
            # LOGGER.debug("Before aggregate: {}, degree: {} after aggregated: {}".format(
            #     model_weights.unboxed / degree,
            #     degree,
            #     weight.unboxed))

            self.model_weights = LogisticRegressionWeights(weight.unboxed, self.fit_intercept)
            if not self.use_encrypt:
                loss = self._compute_loss(data_instances, self.prev_round_weights)
                self.aggregator.send_loss(loss, degree=degree, suffix=(self.n_iter_,))
                LOGGER.info("n_iters: {}, loss: {}".format(self.n_iter_, loss))
            degree = 0

            self.is_converged = self.aggregator.get_converge_status(suffix=(self.n_iter_,))
            LOGGER.info("n_iters: {}, is_converge: {}".format(self.n_iter_, self.is_converged))
            if self.is_converged or self.n_iter_ == self.max_iter:
                break
            model_weights = self.model_weights

        batch_num = 0
        for batch_data in batch_data_generator:
            n = batch_data.count()
            degree += n
            LOGGER.debug('before compute_gradient')
            f = functools.partial(self.gradient_operator.compute_gradient,
                                  coef=model_weights.coef_,
                                  intercept=model_weights.intercept_,
                                  fit_intercept=self.fit_intercept)
            grad = batch_data.applyPartitions(f).reduce(fate_operator.reduce_add)
            grad /= n
            if self.use_proximal:  # use additional proximal term
                model_weights = self.optimizer.update_model(model_weights,
                                                            grad=grad,
                                                            has_applied=False,
                                                            prev_round_weights=self.prev_round_weights)
            else:
                model_weights = self.optimizer.update_model(model_weights,
                                                            grad=grad,
                                                            has_applied=False)

            if self.use_encrypt and batch_num % self.re_encrypt_batches == 0:
                LOGGER.debug("Before accept re_encrypted_model, batch_iter_num: {}".format(batch_num))
                w = self.cipher.re_cipher(w=model_weights.unboxed,
                                          iter_num=self.n_iter_,
                                          batch_iter_num=batch_num)
                model_weights = LogisticRegressionWeights(w, self.fit_intercept)
            batch_num += 1
            # validation_strategy.validate(self, self.n_iter_)

        self.n_iter_ += 1

    self.set_summary(self.get_model_summary())
    LOGGER.info("Finish Training task, total iters: {}".format(self.n_iter_))

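# Illustrative sketch, not FATE source: `degree` above accumulates how many
# samples contributed to the local weights since the last sync, so the
# arbiter-side aggregation reduces to a sample-count-weighted average
# (FedAvg-style). Hypothetical stand-in for that aggregation:
import numpy as np

def aggregate(party_weights, party_degrees):
    """party_weights: list of 1-D arrays; party_degrees: samples seen per party."""
    total = sum(party_degrees)
    stacked = np.stack([w * d for w, d in zip(party_weights, party_degrees)])
    return stacked.sum(axis=0) / total

# aggregate([np.array([1.0, 0.0]), np.array([0.0, 1.0])], [30, 10])
# -> array([0.75, 0.25])
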
def sync_tree_node_queue(self, dep=-1):
    LOGGER.info("get tree node queue of depth {}".format(dep))
    self.cur_layer_nodes = self.transfer_inst.tree_node_queue.get(idx=0, suffix=(dep,))

def sync_node_positions(self, dep, idx=-1):
    LOGGER.info("send node positions of depth {}".format(dep))
    self.transfer_inst.node_positions.remote(self.inst2node_idx,
                                             role=consts.HOST,
                                             idx=idx,
                                             suffix=(dep,))

def sync_tree(self):
    LOGGER.info("sync tree from guest")
    self.tree_node = self.transfer_inst.tree.get(idx=0)

def sync_federated_best_splitinfo_host(self, federated_best_splitinfo_host, dep=-1, batch=-1, idx=-1):
    LOGGER.info("send federated best splitinfo of depth {}, batch {}".format(dep, batch))
    self.transfer_inst.federated_best_splitinfo_host.remote(federated_best_splitinfo_host,
                                                            role=consts.HOST,
                                                            idx=idx,
                                                            suffix=(dep, batch,))

def sync_predict_data(self, recv_times):
    LOGGER.info("recv predict data from guest, recv times is {}".format(recv_times))
    predict_data = self.transfer_inst.predict_data.get(idx=0, suffix=(recv_times,))
    return predict_data

def sync_final_split_host(self, dep=-1, batch=-1, idx=-1):
    LOGGER.info("get host final splitinfo of depth {}, batch {}".format(dep, batch))
    final_splitinfo_host = self.transfer_inst.final_splitinfo_host.get(idx=idx,
                                                                       suffix=(dep, batch,))
    return final_splitinfo_host if idx == -1 else [final_splitinfo_host]

def get_grad_hess_sum(self, grad_and_hess_table):
    LOGGER.info("calculate the sum of grad and hess")
    grad, hess = grad_and_hess_table.reduce(
        lambda value1, value2: (value1[0] + value2[0], value1[1] + value2[1]))
    return grad, hess

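# Illustrative sketch, not FATE source: the table reduce above is an ordinary
# pairwise sum over (grad, hess) values; plain-Python equivalent:
from functools import reduce

pairs = [(0.5, 1.0), (-0.2, 1.0), (0.1, 1.0)]
g_sum, h_sum = reduce(lambda a, b: (a[0] + b[0], a[1] + b[1]), pairs)
# g_sum -> 0.4, h_sum -> 3.0
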
def sync_tree(self, idx=-1):
    LOGGER.info("sync tree to host")
    tree_nodes = self.remove_sensitive_info()
    self.transfer_inst.tree.remote(tree_nodes,
                                   role=consts.HOST,
                                   idx=idx)

def load_model(self, model_meta=None, model_param=None):
    LOGGER.info("load tree model")
    self.set_model_meta(model_meta)
    self.set_model_param(model_param)

def fit(self, data_instances=None, validate_data=None):
    """
    Fit OneVsRest model
    Parameters
    ----------
    data_instances: Table of instances
    """
    if self.mode == consts.HOMO:
        raise ValueError("Currently, One vs Rest is not supported for homo algorithm")

    LOGGER.info("mode is {}, role is {}, start to one_vs_rest fit".format(self.mode, self.role))
    LOGGER.info("Total classes:{}".format(self.classes))
    self.classifier.callback_one_vs_rest = True
    current_flow_id = self.classifier.flowid
    summary_dict = {}
    for label_index, label in enumerate(self.classes):
        LOGGER.info("Start to train OneVsRest with label_index:{}, label:{}".format(label_index, label))
        classifier = copy.deepcopy(self.classifier)
        classifier.need_one_vs_rest = False
        classifier.set_flowid(".".join([current_flow_id, "model_" + str(label_index)]))
        if self.has_label:
            header = data_instances.schema.get("header")
            data_instances_mask_label = self._mask_data_label(data_instances, label=label)
            data_instances_mask_label.schema['header'] = header

            if validate_data is not None:
                validate_mask_label_data = self._mask_data_label(validate_data, label=label)
                validate_mask_label_data.schema['header'] = header
            else:
                validate_mask_label_data = validate_data
            LOGGER.info("finish mask label:{}".format(label))

            LOGGER.info("start classifier fit")
            classifier.fit_binary(data_instances_mask_label, validate_data=validate_mask_label_data)
        else:
            LOGGER.info("start classifier fit")
            classifier.fit_binary(data_instances, validate_data=validate_data)
        _summary = classifier.summary()
        _summary['one_vs_rest'] = True
        summary_dict[label] = _summary
        self.models.append(classifier)
        LOGGER.info("Finish model_{} training!".format(label_index))
    self.classifier.set_summary(summary_dict)

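# Illustrative sketch, not FATE source: a minimal stand-in for the relabeling
# that _mask_data_label performs above. One-vs-rest reduces a K-class problem
# to K binary ones by marking the current target class as 1 and all others 0:
def mask_labels(labeled_rows, target_label):
    """labeled_rows: iterable of (id, label); returns a binary relabeling."""
    return [(rid, 1 if lbl == target_label else 0) for rid, lbl in labeled_rows]

# mask_labels([('a', 0), ('b', 2), ('c', 2)], target_label=2)
# -> [('a', 0), ('b', 1), ('c', 1)]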