def fit(self):
    """
    Fit one tree on the distributed backend, one layer at a time.

    Each splitting layer processes the current nodes in batches of
    ``self.max_split_nodes``: a local histogram is computed and synced per
    batch, then the best splits are fetched, the tree is updated and the
    instances are assigned to the new nodes. On the last layer every
    remaining node is marked as a leaf. Leaf positions are accumulated in
    ``self.sample_leaf_pos`` along the way.
    """
    LOGGER.info('begin to fit h**o decision tree, epoch {}, tree idx {},'
                'running on distributed backend'.format(self.epoch_idx, self.tree_idx))

    self.init_root_node_and_gh_sum()
    LOGGER.debug('assign samples to root node')
    self.inst2node_idx = self.assign_instance_to_root_node(self.data_bin, 0)

    # max_depth splitting layers followed by one layer made of leaves only
    tree_height = self.max_depth + 1
    leaf_depth = tree_height - 1

    for dep in range(tree_height):

        if dep == leaf_depth:
            # final layer: whatever is left becomes a leaf
            for leaf in self.cur_layer_node:
                leaf.is_leaf = True
                self.tree_node.append(leaf)

            remaining_leaf_pos = self.inst2node_idx.mapValues(lambda x: x[1])
            self.sample_leaf_pos = (
                remaining_leaf_pos
                if self.sample_leaf_pos is None
                else self.sample_leaf_pos.union(remaining_leaf_pos)
            )
            # stop fitting
            break

        LOGGER.debug('start to fit layer {}'.format(dep))

        table_with_assignment = self.update_instances_node_positions()

        # tell the other side how many nodes the current layer holds
        layer_suffix = (dep, self.epoch_idx, self.tree_idx)
        self.sync_cur_layer_node_num(len(self.cur_layer_node), suffix=layer_suffix)

        split_info = []
        agg_histograms = []
        batch_starts = range(0, len(self.cur_layer_node), self.max_split_nodes)
        for batch_id, start in enumerate(batch_starts):
            cur_to_split = self.cur_layer_node[start:start + self.max_split_nodes]

            node_map = self.get_node_map(nodes=cur_to_split)
            LOGGER.debug('node map is {}'.format(node_map))
            LOGGER.debug('computing histogram for batch{} at depth{}'.format(batch_id, dep))
            local_histogram = self.get_left_node_local_histogram(
                cur_nodes=cur_to_split,
                tree=self.tree_node,
                g_h=self.g_h,
                table_with_assign=table_with_assignment,
                split_points=self.bin_split_points,
                sparse_point=self.bin_sparse_points,
                valid_feature=self.valid_features,
            )

            LOGGER.debug('federated finding best splits for batch{} at layer {}'.format(batch_id, dep))
            batch_suffix = (batch_id, dep, self.epoch_idx, self.tree_idx)
            self.sync_local_node_histogram(local_histogram, suffix=batch_suffix)

            agg_histograms.extend(local_histogram)

        split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
        LOGGER.debug('got best splits from arbiter')

        self.cur_layer_node = self.update_tree(self.cur_layer_node, split_info)

        self.inst2node_idx, leaf_val = self.assign_instances_to_new_node(table_with_assignment,
                                                                         self.tree_node)

        # record leaf val
        self.sample_leaf_pos = (
            leaf_val
            if self.sample_leaf_pos is None
            else self.sample_leaf_pos.union(leaf_val)
        )

        LOGGER.debug('assigning instance to new nodes done')

    self.convert_bin_to_real()
    self.sample_weights_post_process()
    LOGGER.debug('fitting tree done')
def merge_optimal_binning(bucket_list, optimal_param: OptimalBinningParam, sample_count):
    """
    Greedily merge adjacent buckets until the binning constraints are met.

    Candidate merges of neighbouring buckets are kept in a min-heap and
    applied in several passes:

    1. (only when ``optimal_param.mixture`` is truthy) 'mixture' then
       'single_mixture' passes, driven by the number of non-mixed buckets;
    2. 'small_size' then 'single_small_size' passes, driven by the number of
       buckets whose ``total_count`` is below ``min_item_num``;
    3. an unconstrained pass that runs while the bucket count exceeds
       ``optimal_param.max_bin`` and the heap is not empty.

    No pass ever proposes a merge whose combined ``total_count`` exceeds
    ``max_item_num``.

    Parameters
    ----------
    bucket_list: iterable of buckets; they are keyed by enumeration order
    optimal_param: supplies ``max_bin_pct``, ``min_bin_pct``, ``max_bin``,
        ``mixture`` and ``adjustment_factor``
    sample_count: number used to turn the two percentages into item counts

    Returns
    -------
    tuple ``(bucket_res, non_mixture_num, small_size_num)``: the remaining
    buckets sorted by ``left_bound``, how many of them are not mixed, and how
    many have ``total_count`` below ``min_item_num``.
    """
    t0 = time.time()
    # absolute size limits of a bucket, derived from the configured percentages
    max_item_num = math.floor(optimal_param.max_bin_pct * sample_count)
    min_item_num = math.ceil(optimal_param.min_bin_pct * sample_count)
    bucket_dict = {idx: bucket for idx, bucket in enumerate(bucket_list)}
    final_max_bin = optimal_param.max_bin

    LOGGER.debug("Get in merge optimal binning, sample_count: {}, max_item_num: {}, min_item_num: {},"
                 "final_max_bin: {}".format(sample_count, max_item_num, min_item_num, final_max_bin))

    min_heap = heap.MinHeap()

    def _add_heap_nodes(constraint=None):
        """
        Push every admissible (bucket, right neighbour) pair onto the heap.

        Also counts, over all buckets, how many are not mixed and how many
        are smaller than ``min_item_num``. Returns
        ``(min_heap, non_mixture_count, small_size_count)``.
        """
        LOGGER.debug("Add heap nodes, constraint: {}, dict_length: {}".format(constraint, len(bucket_dict)))
        this_non_mixture_num = 0
        this_small_size_num = 0
        # Indexing with range(len(...)) requires the keys to be exactly
        # 0..len-1; _update_bucket_info renumbers them that way.
        for i in range(len(bucket_dict)):
            left_bucket = bucket_dict[i]
            right_bucket = bucket_dict.get(left_bucket.right_neighbor_idx)
            if left_bucket.right_neighbor_idx == i:
                raise RuntimeError("left_bucket's right neighbor == itself")

            # the two counters are updated before any `continue` below,
            # so they cover every bucket, not only the ones pushed to the heap
            if not left_bucket.is_mixed:
                this_non_mixture_num += 1

            if left_bucket.total_count < min_item_num:
                this_small_size_num += 1

            # no right neighbour found in the dict: nothing to pair with
            if right_bucket is None:
                continue
            # merging would violate the maximum items constraint
            if left_bucket.total_count + right_bucket.total_count > max_item_num:
                continue

            if constraint == 'mixture':
                # keep only pairs where neither bucket is mixed
                if left_bucket.is_mixed or right_bucket.is_mixed:
                    continue
            elif constraint == 'single_mixture':
                # keep pairs where at least one bucket is not mixed
                if left_bucket.is_mixed and right_bucket.is_mixed:
                    continue
            elif constraint == 'small_size':
                # keep only pairs where both buckets are below min_item_num
                if left_bucket.total_count >= min_item_num or right_bucket.total_count >= min_item_num:
                    continue
            elif constraint == 'single_small_size':
                # keep pairs where at least one bucket is below min_item_num
                if left_bucket.total_count >= min_item_num and right_bucket.total_count >= min_item_num:
                    continue
            heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket, right_bucket=right_bucket)
            min_heap.insert(heap_node)
        return min_heap, this_non_mixture_num, this_small_size_num

    def _update_bucket_info(b_dict):
        """
        Renumber the buckets 0..n-1 in ascending ``left_bound`` order.

        Each moved bucket gets its ``idx`` rewritten, and afterwards every
        bucket's left/right neighbour is reset to the adjacent key (``None``
        at both ends). The dict is modified in place and returned.
        """
        order_dict = dict()
        for bucket_idx, item in b_dict.items():
            order_dict[bucket_idx] = item.left_bound

        sorted_order_dict = sorted(order_dict.items(), key=operator.itemgetter(1))

        start_idx = 0
        for item in sorted_order_dict:
            bucket_idx = item[0]
            # already stored under its target key
            if start_idx == bucket_idx:
                start_idx += 1
                continue

            b_dict[start_idx] = b_dict[bucket_idx]
            b_dict[start_idx].idx = start_idx
            start_idx += 1
            del b_dict[bucket_idx]

        bucket_num = len(b_dict)
        for i in range(bucket_num):
            if i == 0:
                b_dict[i].set_left_neighbor(None)
                b_dict[i].set_right_neighbor(i + 1)
            else:
                b_dict[i].set_left_neighbor(i - 1)
                b_dict[i].set_right_neighbor(i + 1)
        # NOTE(review): this lookup raises KeyError when b_dict is empty
        b_dict[bucket_num - 1].set_right_neighbor(None)
        return b_dict

    def _merge_heap(constraint=None, aim_var=0):
        """
        Pop and apply merges while ``aim_var`` is positive and the heap is not empty.

        Returns ``(min_heap, aim_var)`` with the updated counter.
        """
        # merged buckets get fresh keys above every key currently in use
        next_id = max(bucket_dict.keys()) + 1
        while aim_var > 0 and not min_heap.is_empty:
            min_node = min_heap.pop()
            left_bucket = min_node.left_bucket
            right_bucket = min_node.right_bucket

            # Some buckets may be already merged
            if left_bucket.idx not in bucket_dict or right_bucket.idx not in bucket_dict:
                continue
            new_bucket = bucket_info.Bucket(idx=next_id, adjustment_factor=optimal_param.adjustment_factor)
            new_bucket = _init_new_bucket(new_bucket, min_node)
            bucket_dict[next_id] = new_bucket
            del bucket_dict[left_bucket.idx]
            del bucket_dict[right_bucket.idx]
            min_heap.remove_empty_node(left_bucket.idx)
            min_heap.remove_empty_node(right_bucket.idx)

            aim_var = _aim_vars_decrease(constraint, new_bucket, left_bucket, right_bucket, aim_var)
            # the merged bucket may itself be mergeable with its neighbours
            _add_node_from_new_bucket(new_bucket, constraint)
            next_id += 1
        return min_heap, aim_var

    def _add_node_from_new_bucket(new_bucket: bucket_info.Bucket, constraint):
        """
        Push the pairs (left neighbour, new_bucket) and (new_bucket, right neighbour).

        A pair is pushed only when its neighbour exists, the combined
        ``total_count`` does not exceed ``max_item_num`` and the pair meets
        ``constraint`` (same rules as in ``_add_heap_nodes``).
        """
        left_bucket = bucket_dict.get(new_bucket.left_neighbor_idx)
        right_bucket = bucket_dict.get(new_bucket.right_neighbor_idx)
        if constraint == 'mixture':
            # both buckets of the pair must be non-mixed
            if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                if not left_bucket.is_mixed and not new_bucket.is_mixed:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                       right_bucket=new_bucket)
                    min_heap.insert(heap_node)
            if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                if not right_bucket.is_mixed and not new_bucket.is_mixed:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                       right_bucket=right_bucket)
                    min_heap.insert(heap_node)
        elif constraint == 'single_mixture':
            # at least one bucket of the pair must be non-mixed
            if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                if not (left_bucket.is_mixed and new_bucket.is_mixed):
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                       right_bucket=new_bucket)
                    min_heap.insert(heap_node)
            if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                if not (right_bucket.is_mixed and new_bucket.is_mixed):
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                       right_bucket=right_bucket)
                    min_heap.insert(heap_node)

        elif constraint == 'small_size':
            # both buckets of the pair must be below min_item_num
            if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                if left_bucket.total_count < min_item_num and new_bucket.total_count < min_item_num:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                       right_bucket=new_bucket)
                    min_heap.insert(heap_node)
            if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                if right_bucket.total_count < min_item_num and new_bucket.total_count < min_item_num:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                       right_bucket=right_bucket)
                    min_heap.insert(heap_node)

        elif constraint == 'single_small_size':
            # at least one bucket of the pair must be below min_item_num
            if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                if left_bucket.total_count < min_item_num or new_bucket.total_count < min_item_num:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                       right_bucket=new_bucket)
                    min_heap.insert(heap_node)
            if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                if right_bucket.total_count < min_item_num or new_bucket.total_count < min_item_num:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                       right_bucket=right_bucket)
                    min_heap.insert(heap_node)
        else:
            # unconstrained pass: only the size limit applies
            if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                   right_bucket=new_bucket)
                min_heap.insert(heap_node)
            if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                   right_bucket=right_bucket)
                min_heap.insert(heap_node)

    def _init_new_bucket(new_bucket: bucket_info.Bucket, min_node: heap.HeapNode):
        """
        Fill ``new_bucket`` from the two buckets of ``min_node`` and relink neighbours.

        Bounds and neighbour indices come from the outer edges of the pair,
        the event / non-event counts are summed, and the totals are copied
        from the left bucket. Surviving neighbours are pointed at the new
        bucket's idx. Returns ``new_bucket``.
        """
        new_bucket.left_bound = min_node.left_bucket.left_bound
        new_bucket.right_bound = min_node.right_bucket.right_bound
        new_bucket.left_neighbor_idx = min_node.left_bucket.left_neighbor_idx
        new_bucket.right_neighbor_idx = min_node.right_bucket.right_neighbor_idx
        new_bucket.event_count = min_node.left_bucket.event_count + min_node.right_bucket.event_count
        new_bucket.non_event_count = min_node.left_bucket.non_event_count + min_node.right_bucket.non_event_count
        new_bucket.event_total = min_node.left_bucket.event_total
        new_bucket.non_event_total = min_node.left_bucket.non_event_total

        left_neighbor_bucket = bucket_dict.get(new_bucket.left_neighbor_idx)
        if left_neighbor_bucket is not None:
            left_neighbor_bucket.right_neighbor_idx = new_bucket.idx

        right_neighbor_bucket = bucket_dict.get(new_bucket.right_neighbor_idx)
        if right_neighbor_bucket is not None:
            right_neighbor_bucket.left_neighbor_idx = new_bucket.idx
        return new_bucket

    def _aim_vars_decrease(constraint, new_bucket: bucket_info.Bucket, left_bucket, right_bucket, aim_var):
        """
        Return the pass counter after merging ``left_bucket`` and ``right_bucket``.

        For the mixture / small-size passes the counter drops by one for
        each merged bucket that violated the condition and rises by one if
        the merged bucket still violates it. Without a constraint it is
        recomputed as the number of buckets above ``final_max_bin``.
        """
        if constraint in ['mixture', 'single_mixture']:
            if not left_bucket.is_mixed:
                aim_var -= 1
            if not right_bucket.is_mixed:
                aim_var -= 1
            if not new_bucket.is_mixed:
                aim_var += 1
        elif constraint in ['small_size', 'single_small_size']:
            if left_bucket.total_count < min_item_num:
                aim_var -= 1
            if right_bucket.total_count < min_item_num:
                aim_var -= 1
            if new_bucket.total_count < min_item_num:
                aim_var += 1
        else:
            aim_var = len(bucket_dict) - final_max_bin
        return aim_var

    # NOTE(review): the extent of this `if` was reconstructed as covering the
    # two mixture passes only - confirm against the upstream file.
    if optimal_param.mixture:
        LOGGER.debug("Before mixture add, dick length: {}".format(len(bucket_dict)))
        min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='mixture')
        min_heap, non_mixture_num = _merge_heap(constraint='mixture', aim_var=non_mixture_num)
        # keys must be contiguous again before the next _add_heap_nodes call
        bucket_dict = _update_bucket_info(bucket_dict)

        min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='single_mixture')
        min_heap, non_mixture_num = _merge_heap(constraint='single_mixture', aim_var=non_mixture_num)
        LOGGER.debug("After mixture merge, min_heap size: {}, non_mixture_num: {}".format(min_heap.size,
                                                                                          non_mixture_num))
        bucket_dict = _update_bucket_info(bucket_dict)

    LOGGER.debug("Before small_size add, dick length: {}".format(len(bucket_dict)))
    min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='small_size')
    min_heap, small_size_num = _merge_heap(constraint='small_size', aim_var=small_size_num)
    bucket_dict = _update_bucket_info(bucket_dict)

    min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='single_small_size')
    min_heap, small_size_num = _merge_heap(constraint='single_small_size', aim_var=small_size_num)
    bucket_dict = _update_bucket_info(bucket_dict)

    # final pass: merge until at most final_max_bin buckets remain (or the heap runs dry)
    LOGGER.debug("Before add, dick length: {}".format(len(bucket_dict)))
    min_heap, non_mixture_num, small_size_num = _add_heap_nodes()
    LOGGER.debug("After normal add, small_size: {}, min_heap size: {}".format(small_size_num, min_heap.size))
    min_heap, total_bucket_num = _merge_heap(aim_var=len(bucket_dict) - final_max_bin)
    LOGGER.debug("After normal merge, min_heap size: {}".format(min_heap.size))

    # recount both statistics on the final set of buckets
    non_mixture_num = 0
    small_size_num = 0
    for i, bucket in bucket_dict.items():
        if not bucket.is_mixed:
            non_mixture_num += 1
        if bucket.total_count < min_item_num:
            small_size_num += 1
    bucket_res = list(bucket_dict.values())
    bucket_res = sorted(bucket_res, key=lambda bucket: bucket.left_bound)
    LOGGER.debug("Before return merge_optimal_binning, non_mixture_num: {}, small_size_num: {},"
                 "min_heap size: {}".format(non_mixture_num, small_size_num, min_heap.size))

    LOGGER.debug("Before return, dick length: {}".format(len(bucket_dict)))
    LOGGER.info(f"Consume time: {time.time() - t0}")
    return bucket_res, non_mixture_num, small_size_num
def memory_fit(self):
    """
    Fit one tree on the memory backend.

    Works on ``self.arr_bin_data`` with per-node instance index arrays:
    every splitting layer syncs the node count and the per-batch node
    histograms, fetches the best splits, updates the tree and distributes
    each split node's instances to its two children. Afterwards the split
    bin ids of inner nodes are replaced through ``self.bin_split_points``,
    and ``self.leaf_count`` and ``self.sample_weights`` are filled from the
    leaves.
    """
    LOGGER.info('begin to fit h**o decision tree, epoch {}, tree idx {},'
                'running on memory backend'.format(self.epoch_idx, self.tree_idx))

    self.init_root_node_and_gh_sum()
    g, h = self.get_g_h_arr()
    # last missing bin
    hist_bin_num = self.bin_num + self.use_missing
    self.init_memory_hist_builder(g, h, self.arr_bin_data, hist_bin_num)
    # the root node starts with the index set built for all rows
    self.cur_layer_node[0].inst_indices = self.init_node2index(len(self.arr_bin_data))

    # non-leaf node height + 1 layer leaf
    tree_height = self.max_depth + 1
    leaf_depth = tree_height - 1

    for dep in range(tree_height):

        if dep == leaf_depth:
            for leaf in self.cur_layer_node:
                leaf.is_leaf = True
                self.tree_node.append(leaf)
            break

        self.sync_cur_layer_node_num(len(self.cur_layer_node),
                                     suffix=(dep, self.epoch_idx, self.tree_idx))

        node_map = self.get_node_map(self.cur_layer_node)
        batch_starts = range(0, len(self.cur_layer_node), self.max_split_nodes)
        for batch_id, start in enumerate(batch_starts):
            # histograms of this batch only; a fresh list per batch
            node_hists = []
            for node in self.cur_layer_node[start:start + self.max_split_nodes]:
                if node.id not in node_map:
                    continue
                hist = self.sklearn_compute_agg_hist(node.inst_indices)
                hist_bag = HistogramBag(hist, tensor_type='array')
                hist_bag.hid = node.id
                hist_bag.p_hid = node.parent_nodeid
                node_hists.append(hist_bag)

            self.sync_local_node_histogram(node_hists,
                                           suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))

        split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
        new_layer_node = self.update_tree(self.cur_layer_node, split_info)

        # two index sets (left, right) per node that was actually split
        child_indices = []
        for node in self.cur_layer_node:
            if not node.is_leaf:
                left_idx, right_idx = self.assign_arr_inst(node, self.arr_bin_data, node.inst_indices,
                                                           missing_bin_index=self.bin_num)
                child_indices.extend((left_idx, right_idx))
        assert len(child_indices) == len(new_layer_node)

        for child, indices in zip(new_layer_node, child_indices):
            child.inst_indices = indices
        self.cur_layer_node = new_layer_node

    sample_indices = []
    weights = []
    for node in self.tree_node:
        if node.is_leaf:
            sample_indices.extend(node.inst_indices)
            weights.extend([node.weight] * len(node.inst_indices))
        else:
            # replace the stored bin id through the split-point table
            node.bid = self.bin_split_points[node.fid][int(node.bid)]

    # post-processing of memory backend fit
    sample_id = self.sample_id_arr[sample_indices]
    self.leaf_count = {node.id: len(node.inst_indices) for node in self.tree_node if node.is_leaf}
    LOGGER.debug('leaf count is {}'.format(self.leaf_count))

    # cast the ids to the key type used by the g_h table
    sample_id_type = type(self.g_h.take(1)[0][0])
    weight_pairs = [(sample_id_type(id_), weight) for id_, weight in zip(sample_id, weights)]
    self.sample_weights = session.parallelize(weight_pairs,
                                              include_key=True,
                                              partition=self.data_bin.partitions)
def set_transfer_variable(self):
    """Pass ``self.flowid`` on to the transfer variable, if there is one."""
    transfer_variable = self.transfer_variable
    if transfer_variable is None:
        return

    LOGGER.debug(
        "set flowid to transfer_variable, flowid: {}".format(self.flowid)
    )
    transfer_variable.set_flowid(self.flowid)
def init_bucket(self, data_instances):
    """
    Build the initial buckets of every column and fill them from the data.

    Split points come from a quantile or bucket binning object configured
    with ``self.optimal_param.init_bin_nums`` bins. For each column one
    bucket per split point is created. The data is then converted into
    per-feature bucket lists through an eggroll-version specific
    map/reduce, and the result is returned as a table created with
    ``session.parallelize``.
    """
    header = data_overview.get_header(data_instances)
    self._default_setting(header)

    init_bucket_param = copy.deepcopy(self.params)
    init_bucket_param.bin_num = self.optimal_param.init_bin_nums
    use_quantile = self.optimal_param.init_bucket_method == consts.QUANTILE
    init_binning_obj = (
        QuantileBinningTool(param_obj=init_bucket_param, allow_duplicate=False)
        if use_quantile
        else BucketBinning(params=init_bucket_param)
    )
    init_binning_obj.set_bin_inner_param(self.bin_inner_param)
    init_split_points = init_binning_obj.fit_split_points(data_instances)
    is_sparse = data_overview.is_sparse_data(data_instances)

    bucket_dict = {}
    for col_name, sps in init_split_points.items():
        column_buckets = []
        for idx, sp in enumerate(sps):
            bucket = bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
            if idx > 0:
                # a bucket starts where the previous split point ended
                bucket.left_bound = sps[idx - 1]
            else:
                bucket.left_bound = -math.inf
                bucket.set_left_neighbor(None)
            bucket.event_total = self.event_total
            bucket.non_event_total = self.non_event_total
            column_buckets.append(bucket)
        column_buckets[-1].set_right_neighbor(None)
        bucket_dict[col_name] = column_buckets
        LOGGER.debug(f"col_name: {col_name}, length of sps: {len(sps)}, "
                     f"length of list: {len(column_buckets)}")

    def _bind_converter(converter):
        # both eggroll branches bind the same arguments, only the converter differs
        return functools.partial(converter,
                                 split_points=init_split_points,
                                 headers=self.header,
                                 bucket_dict=copy.deepcopy(bucket_dict),
                                 is_sparse=is_sparse,
                                 get_bin_num_func=self.get_bin_num)

    from fate_arch.common.versions import get_eggroll_version
    version = get_eggroll_version()
    if version.startswith('2.0'):
        convert_func = _bind_converter(self.convert_data_to_bucket_old)
        summary_dict = data_instances.mapPartitions(convert_func, use_previous_behavior=False)
        from federatedml.util.reduce_by_key import reduce
        bucket_table = reduce(summary_dict, self.merge_bucket_list, key_func=lambda key: key[1])
    elif version.startswith('2.2'):
        convert_func = _bind_converter(self.convert_data_to_bucket)
        bucket_table = data_instances.mapReducePartitions(convert_func, self.merge_bucket_list)
        bucket_table = dict(bucket_table.collect())
    else:
        raise RuntimeError(f"Cannot recognized eggroll version: {version}")

    for feature, feature_buckets in bucket_table.items():
        LOGGER.debug(f"[feature] {feature}, length of list: {len(feature_buckets)}")

    table_log = "bucket_table: {}, length: {}"
    LOGGER.debug(table_log.format(type(bucket_table), len(bucket_table)))
    bucket_table = [(feature, feature_buckets) for feature, feature_buckets in bucket_table.items()]
    LOGGER.debug(table_log.format(type(bucket_table), len(bucket_table)))

    return session.parallelize(bucket_table, include_key=True, partition=data_instances.partitions)
def fit_binary(self, data_instances, validate_data):
    """
    Run the host side of the binary training loop.

    Checks the input data, prepares the cipher, the batch generator and one
    encrypted calculator per batch, initialises the weights (or resumes on
    warm start) and then iterates: per batch the gradient procedure and the
    loss computation are run and the weights updated; per iteration the
    convergence flag is synced. The loop ends on ``max_iter``,
    ``self.stop_training`` or convergence.
    """
    self._abnormal_detection(data_instances)
    self.check_abnormal_values(data_instances)
    self.check_abnormal_values(validate_data)
    self.callback_list.on_train_begin(data_instances, validate_data)

    LOGGER.debug(f"MODEL_STEP Start fin_binary, data count: {data_instances.count()}")

    self.header = self.get_header(data_instances)
    self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

    use_async = self.transfer_variable.use_async.get(idx=0)
    if use_async:
        LOGGER.debug("set_use_async")
        self.gradient_loss_operator.set_use_async()

    self.batch_generator.initialize_batch_generator(data_instances)
    self.gradient_loss_operator.set_total_batch_nums(self.batch_generator.batch_nums)

    # one calculator per batch
    self.encrypted_calculator = [
        EncryptModeCalculator(self.cipher_operator,
                              self.encrypted_mode_calculator_param.mode,
                              self.encrypted_mode_calculator_param.re_encrypted_rate)
        for _ in range(self.batch_generator.batch_nums)
    ]

    LOGGER.info("Start initialize model.")
    model_shape = self.get_features_shape(data_instances)
    # the weights are always initialised without an intercept here
    if self.init_param_obj.fit_intercept:
        self.init_param_obj.fit_intercept = False

    if self.component_properties.is_warm_start:
        self.callback_warm_start_init_iter(self.n_iter_)
    else:
        w = self.initializer.init_model(model_shape, init_params=self.init_param_obj)
        self.model_weights = LinearModelWeights(w, fit_intercept=self.init_param_obj.fit_intercept)

    while self.n_iter_ < self.max_iter:
        self.callback_list.on_epoch_begin(self.n_iter_)
        LOGGER.info("iter:" + str(self.n_iter_))
        batch_data_generator = self.batch_generator.generate_batch_data()
        self.optimizer.set_iters(self.n_iter_)

        for batch_index, batch_feat_inst in enumerate(batch_data_generator):
            optim_host_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                batch_feat_inst,
                self.encrypted_calculator,
                self.model_weights,
                self.optimizer,
                self.n_iter_,
                batch_index)

            self.gradient_loss_operator.compute_loss(self.model_weights,
                                                     self.optimizer,
                                                     self.n_iter_,
                                                     batch_index,
                                                     self.cipher_operator)

            self.model_weights = self.optimizer.update_model(self.model_weights, optim_host_gradient)

        self.is_converged = self.converge_procedure.sync_converge_info(suffix=(self.n_iter_,))

        LOGGER.info("Get is_converged flag from arbiter:{}".format(self.is_converged))
        LOGGER.info("iter: {}, is_converged: {}".format(self.n_iter_, self.is_converged))
        LOGGER.debug(f"flowid: {self.flowid}, step_index: {self.n_iter_}")

        self.callback_list.on_epoch_end(self.n_iter_)
        self.n_iter_ += 1

        if self.stop_training or self.is_converged:
            break

    self.callback_list.on_train_end()
    self.set_summary(self.get_model_summary())
def fit(self, data_instances, validate_data=None):
    """
    Train poisson regression model of role host

    Parameters
    ----------
    data_instances: DTable of Instance, input data
    validate_data: optional validation data, passed to
        ``init_validation_strategy``

    Notes
    -----
    The weights are initialised with ``init_param_obj.fit_intercept`` forced
    to False, so ``LinearModelWeights`` is built with that same flag; using
    a different flag would make the weight vector and the flag disagree.
    """
    LOGGER.info("Enter hetero_poisson host")
    self._abnormal_detection(data_instances)

    self.validation_strategy = self.init_validation_strategy(
        data_instances, validate_data)

    self.header = self.get_header(data_instances)
    self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

    self.batch_generator.initialize_batch_generator(data_instances)

    # one calculator per batch
    self.encrypted_calculator = [
        EncryptModeCalculator(
            self.cipher_operator,
            self.encrypted_mode_calculator_param.mode,
            self.encrypted_mode_calculator_param.re_encrypted_rate)
        for _ in range(self.batch_generator.batch_nums)
    ]

    LOGGER.info("Start initialize model.")
    model_shape = self.get_features_shape(data_instances)
    # the weights are always initialised without an intercept here
    if self.init_param_obj.fit_intercept:
        self.init_param_obj.fit_intercept = False
    w = self.initializer.init_model(model_shape,
                                    init_params=self.init_param_obj)
    # keep the flag consistent with how `w` was just initialised
    self.model_weights = LinearModelWeights(
        w, fit_intercept=self.init_param_obj.fit_intercept)

    while self.n_iter_ < self.max_iter:
        LOGGER.info("iter:" + str(self.n_iter_))

        batch_data_generator = self.batch_generator.generate_batch_data()
        self.optimizer.set_iters(self.n_iter_)

        batch_index = 0
        for batch_data in batch_data_generator:
            batch_feat_inst = self.transform(batch_data)
            optim_host_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                batch_feat_inst, self.encrypted_calculator,
                self.model_weights, self.optimizer, self.n_iter_,
                batch_index)

            self.gradient_loss_operator.compute_loss(
                batch_feat_inst, self.model_weights,
                self.encrypted_calculator, self.optimizer, self.n_iter_,
                batch_index, self.cipher_operator)

            self.model_weights = self.optimizer.update_model(
                self.model_weights, optim_host_gradient)
            batch_index += 1

        self.is_converged = self.converge_procedure.sync_converge_info(
            suffix=(self.n_iter_, ))

        LOGGER.info("Get is_converged flag from arbiter:{}".format(
            self.is_converged))

        if self.validation_strategy:
            LOGGER.debug('Poisson host running validation')
            self.validation_strategy.validate(self, self.n_iter_)
            if self.validation_strategy.need_stop():
                LOGGER.debug('early stopping triggered')
                break

        # log before incrementing so the message names the iteration that just ran
        LOGGER.info("iter: {}, is_converged: {}".format(
            self.n_iter_, self.is_converged))
        self.n_iter_ += 1
        if self.is_converged:
            break

    if not self.is_converged:
        LOGGER.info("Reach max iter {}, train model finish!".format(
            self.max_iter))

    # fall back to the best model kept by the validation strategy, if any
    if self.validation_strategy and self.validation_strategy.has_saved_best_model():
        self.load_model(self.validation_strategy.cur_best_model)
    self.set_summary(self.get_model_summary())
def set_model_param(self, model_param):
    """Restore the network from ``model_param.model_bytes`` and store a list copy of its header."""
    serialized_model = model_param.model_bytes
    self.nn.restore_model(serialized_model)

    self.store_header = [*model_param.header]
    LOGGER.debug('stored header load, is {}'.format(self.store_header))
def debug_data_inst(data_inst):
    """Log each (key, instance) pair from ``data_inst.collect()`` at debug level."""
    rows = [*data_inst.collect()]
    LOGGER.debug('showing DTable')

    row_template = 'key {} id {}, features {} label {}'
    for row in rows:
        key, inst = row[0], row[1]
        LOGGER.debug(row_template.format(key, inst.inst_id, inst.features, inst.label))
def predict(self, data_inst, ret_format='std'):
    """
    Predict with the boosters that are not yet covered by the prediction cache.

    Parameters
    ----------
    data_inst: input data table
    ret_format: 'std' (standard format), 'leaf' (leaf indices) or
        'raw' (raw score)

    Returns
    -------
    For 'leaf' and 'raw' the output of ``boosting_fast_predict``, otherwise
    the output of ``score_to_predict_result``. Non-leaf results are added to
    ``self.predict_data_cache`` first.
    """
    # standard format, leaf indices, raw score
    assert ret_format in ['std', 'leaf', 'raw'], 'illegal ret format'
    want_leaf = ret_format == 'leaf'

    LOGGER.info('running prediction')
    cache = self.predict_data_cache
    cache_dataset_key = cache.get_data_key(data_inst)

    processed_data = self.data_and_header_alignment(data_inst)

    last_round = cache.predict_data_last_round(cache_dataset_key)
    self.sync_predict_round(last_round)

    rounds = len(self.boosting_model_list) // self.booster_dim
    LOGGER.debug(
        'round involved in prediction {}, last round is {}, data key {}'.
        format(list(range(last_round, rounds)), last_round, cache_dataset_key))

    # load only the boosters of the rounds after the cached one
    trees = [
        self.load_booster(self.booster_meta,
                          self.boosting_model_list[round_idx * self.booster_dim + dim_idx],
                          round_idx,
                          dim_idx)
        for round_idx in range(last_round, rounds)
        for dim_idx in range(self.booster_dim)
    ]
    tree_num = len(trees)

    predict_cache = None
    if last_round != 0:
        cached_round = min(rounds, last_round)
        predict_cache = cache.predict_data_at(cache_dataset_key, cached_round)
        LOGGER.info('load predict cache of round {}'.format(cached_round))

    # nothing new to evaluate: the cached scores already are the answer
    if tree_num == 0 and predict_cache is not None and not want_leaf:
        return self.score_to_predict_result(data_inst, predict_cache)

    predict_rs = self.boosting_fast_predict(processed_data,
                                            trees=trees,
                                            predict_cache=predict_cache,
                                            pred_leaf=want_leaf)

    if want_leaf:
        # predict result is leaf position
        return predict_rs

    cache.add_data(cache_dataset_key, predict_rs, cur_boosting_round=rounds)
    LOGGER.debug('adding predict rs {}'.format(predict_rs))
    LOGGER.debug('last round is {}'.format(cache.predict_data_last_round(cache_dataset_key)))

    if ret_format == 'raw':
        return predict_rs
    return self.score_to_predict_result(data_inst, predict_rs)
def fit(self, data_inst, validate_data=None):
    """
    Train the model epoch by epoch until convergence, early stop or ``self.epochs``.

    Every batch is trained and then evaluated; the mean batch loss of the
    epoch is reported through ``callback_metric``, appended to
    ``self.history_loss`` and fed to the convergence function, whose result
    is sent to the host. When the validation strategy has saved a best
    model it is loaded before the summary is set.
    """
    self.validation_strategy = self.init_validation_strategy(data_inst, validate_data)
    self._build_model()
    self.prepare_batch_data(self.batch_generator, data_inst)
    if not self.input_shape:
        self.model.set_empty()

    self._set_loss_callback_info()

    cur_epoch = 0
    while cur_epoch < self.epochs:
        LOGGER.debug("cur epoch is {}".format(cur_epoch))

        epoch_loss = 0
        batch_count = len(self.data_x)
        for batch_idx in range(batch_count):
            self.model.train(self.data_x[batch_idx], self.data_y[batch_idx],
                             cur_epoch, batch_idx)

            # evaluation runs between reset_flowid() and recovery_flowid()
            self.reset_flowid()
            metrics = self.model.evaluate(self.data_x[batch_idx], self.data_y[batch_idx],
                                          cur_epoch, batch_idx)
            self.recovery_flowid()

            LOGGER.debug("metrics is {}".format(metrics))
            epoch_loss += metrics["loss"]

        epoch_loss /= len(self.data_x)

        LOGGER.debug("epoch {}' loss is {}".format(cur_epoch, epoch_loss))

        self.callback_metric("loss", "train", [Metric(cur_epoch, epoch_loss)])
        self.history_loss.append(epoch_loss)

        if self.validation_strategy:
            self.validation_strategy.validate(self, cur_epoch)
            if self.validation_strategy.need_stop():
                LOGGER.debug('early stopping triggered')
                break

        is_converge = self.converge_func.is_converge(epoch_loss)
        self._summary_buf["is_converged"] = is_converge
        self.transfer_variable.is_converge.remote(is_converge,
                                                  role=consts.HOST,
                                                  idx=0,
                                                  suffix=(cur_epoch, ))

        if is_converge:
            LOGGER.debug("Training process is converged in epoch {}".format(cur_epoch))
            break

        cur_epoch += 1

    if cur_epoch == self.epochs:
        LOGGER.debug(
            "Training process reach max training epochs {} and not converged"
            .format(self.epochs))

    strategy = self.validation_strategy
    if strategy and strategy.has_saved_best_model():
        self.load_model(strategy.cur_best_model)

    self.set_summary(self._get_model_summary())
def fit(self, data_instances):
    """
    Fit split points locally, then compute binning statistics for guest and hosts.

    Returns ``self.data_output`` right after the transform when
    ``skip_static`` is set. More than two label values raise ``ValueError``.
    With ``local_only`` only the local iv is computed. Otherwise the labels
    are encrypted and sent to the hosts, each host's encrypted bin sums are
    decrypted and turned into a host binning object (optimal binning when
    the host reports that method), and all summaries are merged.
    """
    LOGGER.info("Start feature binning fit and transform")
    self._abnormal_detection(data_instances)

    self._setup_bin_inner_param(data_instances, self.model_param)

    self.binning_obj.fit_split_points(data_instances)
    if self.model_param.skip_static:
        self.transform(data_instances)
        return self.data_output

    label_counts = data_overview.count_labels(data_instances)
    if label_counts > 2:
        raise ValueError(
            "Iv calculation support binary-data only in this version.")

    data_instances = data_instances.mapValues(self.load_data)
    self.set_schema(data_instances)
    label_table = data_instances.mapValues(lambda x: x.label)

    if self.model_param.local_only:
        LOGGER.info("This is a local only binning fit")
        self.binning_obj.cal_local_iv(data_instances, label_table=label_table)
        self.transform(data_instances)
        self.set_summary(self.binning_obj.bin_results.summary())
        LOGGER.debug(f"Summary is: {self.summary()}")
        return self.data_output

    cipher = PaillierEncrypt()
    cipher.generate_key()

    encrypt_label = functools.partial(self.encrypt, cipher=cipher)
    encrypted_label_table = label_table.mapValues(encrypt_label)

    self.transfer_variable.encrypted_label.remote(encrypted_label_table,
                                                  role=consts.HOST,
                                                  idx=-1)
    LOGGER.info("Sent encrypted_label_table to host")

    self.binning_obj.cal_local_iv(data_instances, label_table=label_table)

    encrypted_bin_infos = self.transfer_variable.encrypted_bin_sum.get(idx=-1)
    total_summary = self.binning_obj.bin_results.summary()
    adjustment_factor_log = "Received host {} result, length of buckets: {}"

    LOGGER.info("Get encrypted_bin_sum from host")
    for host_idx, host_info in enumerate(encrypted_bin_infos):
        host_party_id = self.component_properties.host_party_idlist[host_idx]
        encrypted_bin_sum = host_info['encrypted_bin_sum']
        host_bin_methods = host_info['bin_method']
        category_names = host_info['category_names']

        result_counts = self.__decrypt_bin_sum(encrypted_bin_sum, cipher)
        LOGGER.debug(adjustment_factor_log.format(host_idx, len(result_counts)))
        LOGGER.debug("category_name: {}, host_bin_methods: {}".format(
            category_names, host_bin_methods))

        if host_bin_methods != consts.OPTIMAL:
            host_binning_obj = BaseBinning()
            host_binning_obj.cal_iv_woe(result_counts,
                                        self.model_param.adjustment_factor)
        else:
            optimal_binning_params = host_info['optimal_params']

            # guest parameters, overridden with what the host reported
            host_model_params = copy.deepcopy(self.model_param)
            host_model_params.bin_num = optimal_binning_params.get('bin_num')
            for key in ('metric_method', 'mixture', 'max_bin_pct', 'min_bin_pct'):
                setattr(host_model_params.optimal_binning_param, key,
                        optimal_binning_params.get(key))

            self.binning_obj.event_total, self.binning_obj.non_event_total = self.get_histogram(
                data_instances)

            # non-category columns go through optimal binning ...
            optimal_binning_cols = {
                col: counts
                for col, counts in result_counts.items()
                if col not in category_names
            }
            host_binning_obj = self.optimal_binning_sync(optimal_binning_cols,
                                                         data_instances.count(),
                                                         data_instances.partitions,
                                                         host_idx,
                                                         host_model_params)

            # ... category columns are scored directly
            category_bins = {
                col: counts
                for col, counts in result_counts.items()
                if col in category_names
            }
            host_binning_obj.cal_iv_woe(category_bins,
                                        self.model_param.adjustment_factor)

        host_binning_obj.set_role_party(role=consts.HOST, party_id=host_party_id)
        total_summary = self._merge_summary(total_summary,
                                            host_binning_obj.bin_results.summary())
        self.host_results.append(host_binning_obj)

    self.set_schema(data_instances)
    self.transform(data_instances)
    LOGGER.info("Finish feature binning fit and transform")
    self.set_summary(total_summary)
    LOGGER.debug(f"Summary is: {self.summary()}")
    return self.data_output
def fit(self, data_instances=None, validate_data=None):
    """
    Train linear model of role arbiter

    For every iteration the arbiter walks the batch generator, sums the
    per-batch gradients, accumulates the per-batch loss when exactly one
    loss is returned, decides convergence, syncs the convergence flag and
    optionally runs the validation strategy.

    Parameters
    ----------
    data_instances: DTable of Instance, input data
    validate_data: data handed to ``init_validation_strategy`` together
        with ``data_instances``

    Raises
    ------
    ValueError
        When ``early_stop`` is not 'weight_diff' and no loss was
        accumulated for the iteration.
    """
    LOGGER.info("Enter hetero linear model arbiter fit")

    self.cipher_operator = self.cipher.paillier_keygen(
        self.model_param.encrypt_param.key_length)
    self.batch_generator.initialize_batch_generator()
    self.gradient_loss_operator.set_total_batch_nums(
        self.batch_generator.batch_num)
    self.validation_strategy = self.init_validation_strategy(
        data_instances, validate_data)

    while self.n_iter_ < self.max_iter:
        iter_loss = None
        batch_data_generator = self.batch_generator.generate_batch_data()
        total_gradient = None
        self.optimizer.set_iters(self.n_iter_)
        for batch_index in batch_data_generator:
            # Compute and Transfer gradient info
            gradient = self.gradient_loss_operator.compute_gradient_procedure(
                self.cipher_operator, self.optimizer, self.n_iter_,
                batch_index)
            if total_gradient is None:
                total_gradient = gradient
            else:
                total_gradient = total_gradient + gradient

            training_info = {
                "iteration": self.n_iter_,
                "batch_index": batch_index
            }
            self.perform_subtasks(**training_info)

            loss_list = self.gradient_loss_operator.compute_loss(
                self.cipher_operator, self.n_iter_, batch_index)

            # Only a single returned loss is accumulated; with any other
            # count iter_loss stays None and loss-based early stop is
            # rejected further down.
            if len(loss_list) == 1:
                if iter_loss is None:
                    iter_loss = loss_list[0]
                else:
                    iter_loss += loss_list[0]

        # Average the accumulated loss over the batches of this iteration.
        if iter_loss is not None:
            iter_loss /= self.batch_generator.batch_num
            if self.need_call_back_loss:
                self.callback_loss(self.n_iter_, iter_loss)
            self.loss_history.append(iter_loss)

        if self.model_param.early_stop == 'weight_diff':
            # Converged once the norm of the summed gradient drops below tol.
            weight_diff = fate_operator.norm(total_gradient)
            if weight_diff < self.model_param.tol:
                self.is_converged = True
        else:
            if iter_loss is None:
                # The two literals are concatenated, so the first one must
                # end with a space to keep the sentences apart.
                raise ValueError(
                    "Multiple host situation, loss early stop function is not available. "
                    "You should use 'weight_diff' instead")
            self.is_converged = self.converge_func.is_converge(iter_loss)
            LOGGER.info("iter: {}, loss:{}, is_converged: {}".format(
                self.n_iter_, iter_loss, self.is_converged))

        self.converge_procedure.sync_converge_info(self.is_converged,
                                                   suffix=(self.n_iter_, ))

        if self.validation_strategy:
            LOGGER.debug('Linear Arbiter running validation')
            self.validation_strategy.validate(self, self.n_iter_)
            if self.validation_strategy.need_stop():
                LOGGER.debug('early stopping triggered')
                self.best_iteration = self.n_iter_
                break

        self.n_iter_ += 1
        if self.is_converged:
            break

    summary = {
        "loss_history": self.loss_history,
        "is_converged": self.is_converged,
        "best_iteration": self.best_iteration
    }
    if self.validation_strategy and self.validation_strategy.has_saved_best_model():
        self.load_model(self.validation_strategy.cur_best_model)
    # Covers both "no history object" and "empty history".
    if self.loss_history:
        summary["best_iter_loss"] = self.loss_history[self.best_iteration]

    self.set_summary(summary)
    LOGGER.debug("finish running linear model arbiter")
def fit_binary(self, data_instances, validate_data):
    """
    Run the binary training loop on ``data_instances``.

    Checks the input data, builds the cipher operator, one encrypted
    calculator per batch and the initial weights, then iterates: for each
    batch the features are transformed, gradients are computed, the local
    model is updated, the loss procedure is run and the weights are
    updated. After each iteration the convergence flag is fetched via
    ``converge_procedure.sync_converge_info`` and validation may stop the
    loop early.

    Parameters
    ----------
    data_instances: training data; must expose ``count()``
    validate_data: data handed to ``init_validation_strategy``
    """
    self._abnormal_detection(data_instances)
    self.check_abnormal_values(data_instances)
    self.check_abnormal_values(validate_data)
    self.validation_strategy = self.init_validation_strategy(
        data_instances, validate_data)

    LOGGER.debug(
        f"MODEL_STEP Start fit_binary, data count: {data_instances.count()}"
    )

    self.header = self.get_header(data_instances)
    self.cipher_operator = self.cipher.gen_paillier_cipher_operator()

    self.batch_generator.initialize_batch_generator(data_instances)
    self.gradient_loss_operator.set_total_batch_nums(
        self.batch_generator.batch_nums)

    # One calculator per batch, indexed by batch inside the gradient procedure.
    self.encrypted_calculator = [
        EncryptModeCalculator(
            self.cipher_operator,
            self.encrypted_mode_calculator_param.mode,
            self.encrypted_mode_calculator_param.re_encrypted_rate)
        for _ in range(self.batch_generator.batch_nums)
    ]

    LOGGER.info("Start initialize model.")
    model_shape = self.get_features_shape(data_instances)
    # The intercept is switched off before the weights are initialised.
    if self.init_param_obj.fit_intercept:
        self.init_param_obj.fit_intercept = False
    w = self.initializer.init_model(model_shape,
                                    init_params=self.init_param_obj)
    self.model_weights = LinearModelWeights(
        w, fit_intercept=self.init_param_obj.fit_intercept)

    while self.n_iter_ < self.max_iter:
        LOGGER.info("iter:" + str(self.n_iter_))
        batch_data_generator = self.batch_generator.generate_batch_data()
        self.optimizer.set_iters(self.n_iter_)

        for batch_index, batch_data in enumerate(batch_data_generator):
            # transforms features of raw input 'batch_data' into more
            # representative features 'batch_feat_inst'
            batch_feat_inst = self.transform(batch_data)
            LOGGER.debug(
                f"MODEL_STEP In Batch {batch_index}, batch data count: {batch_feat_inst.count()}"
            )

            optim_host_gradient, fore_gradient = self.gradient_loss_operator.compute_gradient_procedure(
                batch_feat_inst, self.encrypted_calculator,
                self.model_weights, self.optimizer, self.n_iter_,
                batch_index)

            training_info = {
                "iteration": self.n_iter_,
                "batch_index": batch_index
            }
            self.update_local_model(fore_gradient, data_instances,
                                    self.model_weights.coef_,
                                    **training_info)

            # The loss procedure runs against the weights from before this
            # batch's update, so it must precede update_model.
            self.gradient_loss_operator.compute_loss(
                self.model_weights, self.optimizer, self.n_iter_,
                batch_index, self.cipher_operator)

            self.model_weights = self.optimizer.update_model(
                self.model_weights, optim_host_gradient)

        self.is_converged = self.converge_procedure.sync_converge_info(
            suffix=(self.n_iter_, ))
        LOGGER.info("Get is_converged flag from arbiter:{}".format(
            self.is_converged))

        if self.validation_strategy:
            LOGGER.debug('LR host running validation')
            self.validation_strategy.validate(self, self.n_iter_)
            if self.validation_strategy.need_stop():
                LOGGER.debug('early stopping triggered')
                break

        self.n_iter_ += 1
        LOGGER.info("iter: {}, is_converged: {}".format(
            self.n_iter_, self.is_converged))
        if self.is_converged:
            break

    if self.validation_strategy and self.validation_strategy.has_saved_best_model():
        self.load_model(self.validation_strategy.cur_best_model)
    self.set_summary(self.get_model_summary())
def sync_local_node_histogram(self, acc_histogram: List[HistogramBag], suffix):
    """
    Send the local node histograms through the aggregator.

    Parameters
    ----------
    acc_histogram: list of HistogramBag to send
    suffix: transfer suffix forwarded unchanged to ``send_histogram``;
        ``fit`` passes ``(batch_id, dep, epoch_idx, tree_idx)``
    """
    # sending local histogram
    self.aggregator.send_histogram(acc_histogram, suffix=suffix)
    # Log the whole suffix: its first element is the batch id in fit's call,
    # not the layer, so labelling suffix[0] as the layer would be misleading.
    LOGGER.debug('local histogram sent, suffix {}'.format(suffix))
def one_vs_rest_fit(self, train_data=None, validate_data=None):
    """
    Hand training over to ``self.one_vs_rest_obj``.

    Parameters
    ----------
    train_data: forwarded as ``data_instances``
    validate_data: forwarded as ``validate_data``
    """
    LOGGER.debug("Class num larger than 2, need to do one_vs_rest")
    fit_kwargs = {
        "data_instances": train_data,
        "validate_data": validate_data,
    }
    self.one_vs_rest_obj.fit(**fit_kwargs)
def _sync_class_host(self, class_set):
    """
    Fetch the aggregated class count and rebuild ``self.classes``.

    Parameters
    ----------
    class_set: not read by this method; kept so the signature stays
        compatible with existing callers
    """
    LOGGER.debug("Start to get aggregate classes")
    class_nums = self.transfer_variable.aggregate_classes.get(idx=0)
    # Classes are the consecutive integers 0 .. class_nums - 1.
    self.classes = list(range(class_nums))
def __add__(self, other):
    """
    Implement ``self + other`` through ``binary_op`` with ``operator.add``.

    The operation is requested with ``inplace=False``; the result of
    ``binary_op`` is returned as is.
    """
    debug_message = "In binary_op0, _w: {}".format(self._weights)
    LOGGER.debug(debug_message)
    summed = self.binary_op(other, operator.add, inplace=False)
    return summed
def fit(self, data_inst, validate_data=None):
    """
    Fit the FTL model on ``data_inst``.

    Prepares the data loader and the network, then for each epoch exchanges
    components, runs ``self.local_round`` local rounds of gradient
    computation and weight updates, records the loss of the first local
    round, refreshes the intermediate components and checks the two stop
    conditions (validation strategy, convergence check).

    Parameters
    ----------
    data_inst: training data; must expose ``partitions``
    validate_data: data handed to ``init_validation_strategy``
    """
    LOGGER.debug('in training, partitions is {}'.format(
        data_inst.partitions))
    LOGGER.info('start to fit a ftl model, '
                'run mode is {},'
                'communication efficient mode is {}'.format(
                    self.mode, self.comm_eff))

    self.check_host_number()

    intersect_obj = self.init_intersect_obj()
    prepared = self.prepare_data(intersect_obj, data_inst, guest_side=True)
    data_loader, self.x_shape, self.data_num, self.overlap_num = prepared
    self.input_dim = self.x_shape[0]

    # keep the loader around so validation can reuse it
    dataset_key = self.get_dataset_key(data_inst)
    self.cache_dataloader[dataset_key] = data_loader

    self.partitions = data_inst.partitions
    LOGGER.debug('self partitions is {}'.format(self.partitions))

    self.initialize_nn(input_shape=self.x_shape)
    self.feat_dim = self.nn._model.output_shape[1]
    self.constant_k = 1 / self.feat_dim
    self.validation_strategy = self.init_validation_strategy(
        data_inst, validate_data)

    self.callback_meta(
        "loss", "train",
        MetricMeta(name="train",
                   metric_type="LOSS",
                   extra_metas={"unit_name": "iters"}))

    # intermediate results needed by the first epoch
    first_components = self.batch_compute_components(data_loader)
    self.phi, self.phi_product, self.overlap_ua, self.send_components = first_components

    for epoch_idx in range(self.epochs):
        LOGGER.debug('fitting epoch {}'.format(epoch_idx))

        received_components = self.exchange_components(
            self.send_components, epoch_idx=epoch_idx)

        loss = None
        for round_idx in range(self.local_round):
            if self.comm_eff:
                LOGGER.debug('running local iter {}'.format(round_idx))

            grads = self.compute_backward_gradients(
                received_components,
                data_loader,
                epoch_idx=epoch_idx,
                local_round=round_idx)
            self.update_nn_weights(grads,
                                   data_loader,
                                   epoch_idx,
                                   decay=self.comm_eff)

            # the loss is only computed in the first local round
            if round_idx == 0:
                overlap_count = len(data_loader.get_overlap_indexes())
                loss = self.compute_loss(received_components, epoch_idx,
                                         overlap_count)

            # every round but the final one refreshes phi / overlap_ua
            is_final_round = round_idx + 1 == self.local_round
            if not is_final_round:
                self.phi, self.overlap_ua = self.compute_phi_and_overlap_ua(
                    data_loader)

        self.callback_metric("loss", "train", [Metric(epoch_idx, loss)])
        self.history_loss.append(loss)

        # updating variables for next epochs
        is_last_epoch = epoch_idx + 1 == self.epochs
        if is_last_epoch:
            # only need to update phi in last epochs
            self.phi, _ = self.compute_phi_and_overlap_ua(data_loader)
        else:
            # compute phi, phi_product, overlap_ua etc. for next epoch
            next_components = self.batch_compute_components(data_loader)
            self.phi, self.phi_product, self.overlap_ua, self.send_components = next_components

        # check early_stopping_rounds
        if self.validation_strategy is not None:
            self.validation_strategy.validate(self, epoch_idx)
            if self.validation_strategy.need_stop():
                LOGGER.debug('early stopping triggered')
                break

        # check n_iter_no_change
        if self.n_iter_no_change is True:
            if self.check_convergence(loss):
                self.sync_stop_flag(epoch_idx, stop_flag=True)
                break
            self.sync_stop_flag(epoch_idx, stop_flag=False)

        LOGGER.debug('fitting epoch {} done, loss is {}'.format(
            epoch_idx, loss))

    self.callback_meta(
        "loss", "train",
        MetricMeta(name="train",
                   metric_type="LOSS",
                   extra_metas={"Best": min(self.history_loss)}))
    self.set_summary(self.generate_summary())
    LOGGER.debug('fitting ftl model done')
def preprocess(self):
    """
    Reset ``self.booster_dim`` to 1 when the multi mode is MULTI_OUTPUT.

    Any other ``multi_mode`` leaves the instance untouched.
    """
    if self.multi_mode != consts.MULTI_OUTPUT:
        return
    self.booster_dim = 1
    LOGGER.debug('multi mode tree dim reset to 1')