def __init__(self, tree_param: DecisionTreeModelParam, valid_feature: dict,
             epoch_idx: int, tree_idx: int, flow_id: int):
    """Set up the arbiter state for one homo decision tree.

    Args:
        tree_param: forwarded unchanged to the parent constructor.
        valid_feature: stored as ``self.valid_features``.
        epoch_idx: stored as ``self.epoch_idx``.
        tree_idx: stored as ``self.tree_idx``.
        flow_id: passed to ``self.set_flowid``.
    """
    super(HomoDecisionTreeArbiter, self).__init__(tree_param)

    # The splitter is built from the criterion/threshold attributes that are
    # read from ``self`` right after the parent constructor has run.
    self.splitter = Splitter(
        self.criterion_method,
        self.criterion_params,
        self.min_impurity_split,
        self.min_sample_split,
        self.min_leaf_node,
    )
    self.transfer_inst = HomoDecisionTreeTransferVariable()

    # Identity of this participant and of the tree being built.
    self.sitename = consts.ARBITER
    self.runtime_idx = 0
    self.epoch_idx = epoch_idx
    self.tree_idx = tree_idx
    self.valid_features = valid_feature

    # Tree bookkeeping; starts from the root node.
    self.tree_node = []
    self.tree_node_num = 0
    self.cur_layer_node = []

    # The flow id has to be set after ``transfer_inst`` exists; the secure
    # aggregator is created afterwards.
    self.set_flowid(flow_id)
    self.aggregator = DecisionTreeArbiterAggregator(verbose=False)

    # Stored histograms for faster computation: {node_id: histogram_bag}.
    self.stored_histograms = {}
def __init__(self, tree_param):
    """Initialise the host side of the hetero decision tree.

    Every data/state attribute starts out empty (``None``, ``{}`` or ``0``);
    this constructor only builds the splitter and the transfer-variable
    holder.

    Args:
        tree_param: forwarded unchanged to the parent constructor.
    """
    # This is the host constructor; the message previously said "guest",
    # which made host and guest logs indistinguishable.
    LOGGER.info("hetero decision tree host init!")
    super(HeteroDecisionTreeHost, self).__init__(tree_param)

    # Built from the criterion/threshold attributes read from ``self`` after
    # the parent constructor has run.
    self.splitter = Splitter(self.criterion_method, self.criterion_params,
                             self.min_impurity_split, self.min_sample_split,
                             self.min_leaf_node)

    # Input data and derived tables (unset at construction time).
    self.data_bin = None
    self.data_bin_with_position = None
    self.grad_and_hess = None
    self.bin_split_points = None
    self.bin_sparse_points = None
    self.infos = None
    self.valid_features = None

    # Key material and encrypted inputs (unset at construction time).
    self.pubkey = None
    self.privakey = None
    self.tree_id = None
    self.encrypted_grad_and_hess = None

    self.transfer_inst = HeteroDecisionTreeTransferVariable()

    # Tree-building state.
    self.tree_node_queue = None
    self.cur_split_nodes = None
    self.split_maskdict = {}
    self.missing_dir_maskdict = {}
    self.tree_ = None

    # Identity of this participant.
    self.runtime_idx = 0
    self.sitename = consts.HOST
def __init__(self, tree_param):
    """Initialise the guest side of the hetero decision tree.

    Args:
        tree_param: forwarded unchanged to the parent constructor.
    """
    LOGGER.info("hetero decision tree guest init!")
    super(HeteroDecisionTreeGuest, self).__init__(tree_param)

    # Built from the criterion/threshold attributes read from ``self`` after
    # the parent constructor has run.
    self.splitter = Splitter(
        self.criterion_method,
        self.criterion_params,
        self.min_impurity_split,
        self.min_sample_split,
        self.min_leaf_node,
    )
    self.transfer_inst = HeteroDecisionTreeTransferVariable()

    # Input data and binning information (unset at construction time).
    self.data_bin = None
    self.grad_and_hess = None
    self.bin_split_points = None
    self.bin_sparse_points = None
    self.valid_features = None
    self.infos = None

    # Sample-to-node dispatch tables (unset at construction time).
    self.node_dispatch = None
    self.data_bin_with_node_dispatch = None

    # Encryption helpers (unset at construction time).
    self.encrypter = None
    self.encrypted_mode_calculator = None

    # Split search state.
    self.best_splitinfo_guest = None
    self.tree_node_queue = None
    self.cur_split_nodes = None
    self.split_maskdict = {}

    # The tree itself and values derived from it.
    self.tree_ = []
    self.tree_node_num = 0
    self.predict_weights = None
    self.feature_importances_ = {}

    self.runtime_idx = 0
class HeteroDecisionTreeHost(DecisionTree):
    """Host-side participant of the hetero decision tree.

    All ``sync_*`` methods either fetch an object with ``federation.get``
    (always from index 0) or push one to ``consts.GUEST`` with
    ``federation.remote``. Transfer tags are generated by
    ``self.transfer_inst.generate_transferid``.
    """

    def __init__(self, tree_param):
        """Build the splitter and reset every data/state attribute.

        Args:
            tree_param: forwarded unchanged to the parent constructor.
        """
        # This is the host constructor; the message previously said "guest".
        LOGGER.info("hetero decision tree host init!")
        super(HeteroDecisionTreeHost, self).__init__(tree_param)

        self.splitter = Splitter(self.criterion_method, self.criterion_params,
                                 self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)

        self.data_bin = None
        self.data_bin_with_position = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.infos = None
        self.valid_features = None
        self.pubkey = None
        self.privakey = None
        self.tree_id = None
        self.encrypted_grad_and_hess = None
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.tree_node_queue = None
        self.cur_split_nodes = None
        # node id -> masked split value, filled by encode("feature_val", ...)
        self.split_maskdict = {}
        self.tree_ = None

    def set_flowid(self, flowid=0):
        """Forward ``flowid`` to the transfer-variable holder."""
        LOGGER.info("set flowid, flowid is {}".format(flowid))
        self.transfer_inst.set_flowid(flowid)

    def set_inputinfo(self, data_bin=None, grad_and_hess=None,
                      bin_split_points=None, bin_sparse_points=None):
        """Store the binned data and binning metadata on the instance."""
        LOGGER.info("set input info")
        self.data_bin = data_bin
        self.grad_and_hess = grad_and_hess
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_points

    def set_valid_features(self, valid_features=None):
        """Store ``valid_features`` on the instance."""
        LOGGER.info("set valid features")
        self.valid_features = valid_features

    def encode(self, etype="feature_idx", val=None, nid=None):
        """Mask a split attribute before it leaves the host.

        Args:
            etype: ``"feature_idx"`` or ``"feature_val"``.
            val: the value to encode.
            nid: node id; used as the key when ``etype == "feature_val"``.

        Returns:
            ``val`` unchanged for ``"feature_idx"``. ``None`` for
            ``"feature_val"``: the value is kept in ``self.split_maskdict``
            under ``nid`` instead of being returned.

        Raises:
            TypeError: if ``etype`` is neither of the supported kinds.
        """
        if etype == "feature_idx":
            return val

        if etype == "feature_val":
            self.split_maskdict[nid] = val
            return None

        raise TypeError("encode type %s is not support!" % (str(etype)))

    @staticmethod
    def decode(dtype="feature_idx", val=None, nid=None, split_maskdict=None):
        """Inverse of :meth:`encode`.

        Args:
            dtype: ``"feature_idx"`` or ``"feature_val"``.
            val: the encoded value (returned as-is for ``"feature_idx"``).
            nid: node id to look up when ``dtype == "feature_val"``.
            split_maskdict: mapping of node id to stored split value.

        Returns:
            ``val`` for ``"feature_idx"``; ``split_maskdict[nid]`` for
            ``"feature_val"``.

        Raises:
            ValueError: if ``nid`` is not present in ``split_maskdict``.
            TypeError: if ``dtype`` is neither of the supported kinds.
        """
        if dtype == "feature_idx":
            return val

        if dtype == "feature_val":
            if nid in split_maskdict:
                return split_maskdict[nid]
            raise ValueError(
                "decode val %s cause error, can't reconize it!" % (str(val)))

        # Previously this exception object was *returned*, so an unsupported
        # dtype silently yielded a TypeError instance as the decoded value.
        raise TypeError("decode type %s is not support!" % (str(dtype)))

    def sync_encrypted_grad_and_hess(self):
        """Fetch the encrypted grad/hess and store it in ``self.grad_and_hess``."""
        LOGGER.info("get encrypted grad and hess")
        self.grad_and_hess = federation.get(
            name=self.transfer_inst.encrypted_grad_and_hess.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.encrypted_grad_and_hess),
            idx=0)

    def sync_node_positions(self, dep=-1):
        """Fetch and return the node positions for depth ``dep``."""
        # The message previously said "tree node queue", copy-pasted from
        # sync_tree_node_queue.
        LOGGER.info("get node positions of depth {}".format(dep))
        node_positions = federation.get(
            name=self.transfer_inst.node_positions.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.node_positions, dep),
            idx=0)
        return node_positions

    def sync_tree_node_queue(self, dep=-1):
        """Fetch the tree node queue for depth ``dep`` into ``self.tree_node_queue``."""
        LOGGER.info("get tree node queue of depth {}".format(dep))
        self.tree_node_queue = federation.get(
            name=self.transfer_inst.tree_node_queue.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.tree_node_queue, dep),
            idx=0)

    def get_histograms(self, node_map=None):
        """Compute and accumulate the histograms for the nodes in ``node_map``.

        Args:
            node_map: mapping of node id to its index in the current batch.
                ``None`` (the default) is treated as an empty mapping.

        Returns:
            The result of ``FeatureHistogram.accumulate_histogram``.
        """
        # Avoid a shared mutable default argument.
        if node_map is None:
            node_map = {}

        LOGGER.info("start to get node histograms")
        histograms = FeatureHistogram.calculate_histogram(
            self.data_bin_with_position, self.grad_and_hess,
            self.bin_split_points, self.bin_sparse_points,
            self.valid_features, node_map)

        LOGGER.info("begin to accumulate histograms")
        acc_histograms = FeatureHistogram.accumulate_histogram(histograms)
        LOGGER.info("acc histogram shape is {}".format(len(acc_histograms)))
        return acc_histograms

    def sync_encrypted_splitinfo_host(self, encrypted_splitinfo_host,
                                      dep=-1, batch=-1):
        """Send the encrypted split candidates of this batch to the guest."""
        LOGGER.info("send encrypted splitinfo of depth {}, batch {}".format(
            dep, batch))
        federation.remote(
            obj=encrypted_splitinfo_host,
            name=self.transfer_inst.encrypted_splitinfo_host.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.encrypted_splitinfo_host, dep, batch),
            role=consts.GUEST,
            idx=0)

    def sync_federated_best_splitinfo_host(self, dep=-1, batch=-1):
        """Fetch and return the federated best split info for this batch."""
        LOGGER.info(
            "get federated best splitinfo of depth {}, batch {}".format(
                dep, batch))
        federated_best_splitinfo_host = federation.get(
            name=self.transfer_inst.federated_best_splitinfo_host.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.federated_best_splitinfo_host, dep, batch),
            idx=0)
        return federated_best_splitinfo_host

    def sync_final_splitinfo_host(self, splitinfo_host,
                                  federated_best_splitinfo_host,
                                  dep=-1, batch=-1):
        """Build the final (masked) split info per node and send it to the guest.

        Args:
            splitinfo_host: per-node list of candidate split infos.
            federated_best_splitinfo_host: per-node ``(best_idx, best_gain)``
                pairs; ``best_idx == -1`` means no candidate was selected.
            dep: current depth, used in the transfer tag.
            batch: current batch, used in the transfer tag.
        """
        LOGGER.info("send host final splitinfo of depth {}, batch {}".format(
            dep, batch))

        final_splitinfos = []
        for i, candidates in enumerate(splitinfo_host):
            best_idx, best_gain = federated_best_splitinfo_host[i]
            if best_idx != -1:
                splitinfo = candidates[best_idx]
                assert splitinfo.sitename == consts.HOST
                splitinfo.best_fid = self.encode("feature_idx",
                                                 splitinfo.best_fid)
                assert splitinfo.best_fid is not None
                # encode() keeps the bin id in split_maskdict and returns
                # None, so the bin id itself is not sent.
                splitinfo.best_bid = self.encode("feature_val",
                                                 splitinfo.best_bid,
                                                 self.cur_split_nodes[i].id)
                splitinfo.gain = best_gain
            else:
                splitinfo = SplitInfo(sitename=consts.HOST, best_fid=-1,
                                      best_bid=-1, gain=best_gain)

            final_splitinfos.append(splitinfo)

        federation.remote(
            obj=final_splitinfos,
            name=self.transfer_inst.final_splitinfo_host.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.final_splitinfo_host, dep, batch),
            role=consts.GUEST,
            idx=0)

    def sync_dispatch_node_host(self, dep):
        """Fetch and return the samples the host must dispatch at depth ``dep``."""
        LOGGER.info("get node from host to dispath, depth is {}".format(dep))
        dispatch_node_host = federation.get(
            name=self.transfer_inst.dispatch_node_host.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.dispatch_node_host, dep),
            idx=0)
        return dispatch_node_host

    @staticmethod
    def dispatch_node(value1, value2, decoder=None, split_maskdict=None,
                      bin_sparse_points=None):
        """Send one sample to the left or right child of its current node.

        Args:
            value1: either a short tuple (length <= 2), returned unchanged,
                or ``(unleaf_state, fid, bid, nodeid, left_nodeid,
                right_nodeid)``.
            value2: the data instance; its ``features.get_data`` is queried.
            decoder: callable with the signature of :meth:`decode`.
            split_maskdict: mapping passed through to ``decoder``.
            bin_sparse_points: per-feature default passed to ``get_data``.

        Returns:
            ``value1`` unchanged, or ``(unleaf_state, child_nodeid)``.
        """
        if len(value1) <= 2:
            return value1

        unleaf_state, fid, bid, nodeid, left_nodeid, right_nodeid = value1
        fid = decoder("feature_idx", fid, split_maskdict=split_maskdict)
        bid = decoder("feature_val", bid, nodeid,
                      split_maskdict=split_maskdict)

        if value2.features.get_data(fid, bin_sparse_points[fid]) <= bid:
            return unleaf_state, left_nodeid
        return unleaf_state, right_nodeid

    def sync_dispatch_node_host_result(self, dispatch_node_host_result,
                                       dep=-1):
        """Send the host's dispatch result for depth ``dep`` to the guest."""
        LOGGER.info("send host dispatch result, depth is {}".format(dep))
        federation.remote(
            obj=dispatch_node_host_result,
            name=self.transfer_inst.dispatch_node_host_result.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.dispatch_node_host_result, dep),
            role=consts.GUEST,
            idx=0)

    def find_dispatch(self, dispatch_node_host, dep=-1):
        """Dispatch the given samples with :meth:`dispatch_node` and send the result."""
        LOGGER.info("start to find host dispath of depth {}".format(dep))
        dispatch_node_method = functools.partial(
            self.dispatch_node,
            decoder=self.decode,
            split_maskdict=self.split_maskdict,
            bin_sparse_points=self.bin_sparse_points)
        dispatch_node_host_result = dispatch_node_host.join(
            self.data_bin, dispatch_node_method)
        self.sync_dispatch_node_host_result(dispatch_node_host_result, dep)

    def sync_tree(self):
        """Fetch the tree and store it in ``self.tree_``."""
        LOGGER.info("sync tree from guest")
        self.tree_ = federation.get(
            name=self.transfer_inst.tree.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.tree),
            idx=0)

    def remove_duplicated_split_nodes(self, split_nid_used):
        """Drop every ``split_maskdict`` entry whose node id is not in ``split_nid_used``."""
        LOGGER.info("remove duplicated nodes from split mask dict")
        duplicated_nodes = set(self.split_maskdict) - set(split_nid_used)
        for nid in duplicated_nodes:
            del self.split_maskdict[nid]

    def convert_bin_to_real(self):
        """Replace stored bin ids of host-owned split nodes with real split values.

        The real value is written to ``self.split_maskdict`` (via
        :meth:`encode`); the node's ``bid`` is set to whatever ``encode``
        returns, which is ``None`` for ``"feature_val"``. Mask entries of
        nodes that are not host-owned split nodes of the tree are removed
        afterwards.
        """
        LOGGER.info("convert tree node bins to real value")
        split_nid_used = []
        for node in self.tree_:
            if node.is_leaf is True:
                continue

            if node.sitename == consts.HOST:
                fid = self.decode("feature_idx", node.fid,
                                  split_maskdict=self.split_maskdict)
                bid = self.decode("feature_val", node.bid, node.id,
                                  self.split_maskdict)
                real_splitval = self.encode("feature_val",
                                            self.bin_split_points[fid][bid],
                                            node.id)
                node.bid = real_splitval

                split_nid_used.append(node.id)

        self.remove_duplicated_split_nodes(split_nid_used)

    @staticmethod
    def traverse_tree(predict_state, data_inst, tree_=None, decoder=None,
                      split_maskdict=None):
        """Walk one sample down the tree until a guest-owned node is reached.

        Args:
            predict_state: ``(tag, nid)``; a ``tag`` of 0 is returned as-is.
            data_inst: the data instance; its ``features.get_data`` is queried
                with a default of 0.
            tree_: indexable collection of nodes, indexed by node id.
            decoder: callable with the signature of :meth:`decode`.
            split_maskdict: mapping passed through to ``decoder``.

        Returns:
            ``(tag, nid)`` unchanged when ``tag == 0``, else ``(1, nid)`` with
            the id of the first node whose sitename is ``consts.GUEST``.
        """
        tag, nid = predict_state
        if tag == 0:
            return (tag, nid)

        while tree_[nid].sitename != consts.GUEST:
            node = tree_[nid]
            fid = decoder("feature_idx", node.fid,
                          split_maskdict=split_maskdict)
            bid = decoder("feature_val", node.bid, nid, split_maskdict)

            if data_inst.features.get_data(fid, 0) <= bid:
                nid = node.left_nodeid
            else:
                nid = node.right_nodeid

        return (1, nid)

    def sync_predict_finish_tag(self, recv_times):
        """Fetch and return the ``recv_times``-th predict finish tag."""
        LOGGER.info(
            "get the {}-th predict finish tag from guest".format(recv_times))
        finish_tag = federation.get(
            name=self.transfer_inst.predict_finish_tag.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.predict_finish_tag, recv_times),
            idx=0)
        return finish_tag

    def sync_predict_data(self, recv_times):
        """Fetch and return the ``recv_times``-th batch of predict data."""
        LOGGER.info(
            "srecv predict data to host, recv times is {}".format(recv_times))
        predict_data = federation.get(
            name=self.transfer_inst.predict_data.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.predict_data, recv_times),
            idx=0)
        return predict_data

    def sync_data_predicted_by_host(self, predict_data, send_times):
        """Send the host-traversed predict data back to the guest."""
        LOGGER.info(
            "send predicted data by host, send times is {}".format(send_times))
        federation.remote(
            obj=predict_data,
            name=self.transfer_inst.predict_data_by_host.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.predict_data_by_host, send_times),
            role=consts.GUEST,
            idx=0)

    def fit(self):
        """Run the host side of tree fitting.

        For each depth: fetch the node queue (stop if it is empty), fetch the
        node positions, process the queue in batches of
        ``self.max_split_nodes`` (histograms -> split candidates -> exchange
        with the guest), then dispatch samples. Finally fetch the tree and
        convert host split bins to real values.
        """
        LOGGER.info("begin to fit host decision tree")
        self.sync_encrypted_grad_and_hess()

        for dep in range(self.max_depth):
            self.sync_tree_node_queue(dep)
            if len(self.tree_node_queue) == 0:
                break

            node_positions = self.sync_node_positions(dep)
            self.data_bin_with_position = self.data_bin.join(
                node_positions, lambda v1, v2: (v1, v2))

            batch_starts = range(0, len(self.tree_node_queue),
                                 self.max_split_nodes)
            for batch, start in enumerate(batch_starts):
                self.cur_split_nodes = self.tree_node_queue[
                    start:start + self.max_split_nodes]

                # node id -> position of the node inside the current batch
                node_map = {
                    tree_node.id: node_num
                    for node_num, tree_node in enumerate(self.cur_split_nodes)
                }

                acc_histograms = self.get_histograms(node_map=node_map)
                splitinfo_host, encrypted_splitinfo_host = \
                    self.splitter.find_split_host(acc_histograms,
                                                  self.valid_features,
                                                  self.data_bin._partitions)

                self.sync_encrypted_splitinfo_host(encrypted_splitinfo_host,
                                                   dep, batch)
                federated_best_splitinfo_host = \
                    self.sync_federated_best_splitinfo_host(dep, batch)
                self.sync_final_splitinfo_host(splitinfo_host,
                                               federated_best_splitinfo_host,
                                               dep, batch)

            dispatch_node_host = self.sync_dispatch_node_host(dep)
            self.find_dispatch(dispatch_node_host, dep)

        self.sync_tree()
        self.convert_bin_to_real()
        # The message previously said "guest" although this is the host.
        LOGGER.info("end to fit host decision tree")

    def predict(self, data_inst):
        """Serve predict rounds until a finish tag that ``is True`` is received.

        Each round fetches predict data, advances every sample with
        :meth:`traverse_tree` and sends the result back.

        Args:
            data_inst: table joined against the received predict data.
        """
        LOGGER.info("start to predict!")
        site_guest_send_times = 0
        while True:
            finish_tag = self.sync_predict_finish_tag(site_guest_send_times)
            if finish_tag is True:
                break

            predict_data = self.sync_predict_data(site_guest_send_times)

            traverse_tree = functools.partial(
                self.traverse_tree,
                tree_=self.tree_,
                decoder=self.decode,
                split_maskdict=self.split_maskdict)
            predict_data = predict_data.join(data_inst, traverse_tree)

            self.sync_data_predicted_by_host(predict_data,
                                             site_guest_send_times)
            site_guest_send_times += 1

        LOGGER.info("predict finish!")

    def get_model_meta(self):
        """Return a ``DecisionTreeModelMeta`` filled from the tree hyper-parameters."""
        model_meta = DecisionTreeModelMeta()
        model_meta.max_depth = self.max_depth
        model_meta.min_sample_split = self.min_sample_split
        model_meta.min_impurity_split = self.min_impurity_split
        model_meta.min_leaf_node = self.min_leaf_node
        return model_meta

    def set_model_meta(self, model_meta):
        """Copy the tree hyper-parameters from ``model_meta`` onto the instance."""
        self.max_depth = model_meta.max_depth
        self.min_sample_split = model_meta.min_sample_split
        self.min_impurity_split = model_meta.min_impurity_split
        self.min_leaf_node = model_meta.min_leaf_node

    def get_model_param(self):
        """Return a ``DecisionTreeModelParam`` holding the tree nodes and split masks."""
        model_param = DecisionTreeModelParam()
        for node in self.tree_:
            model_param.tree_.add(id=node.id,
                                  sitename=node.sitename,
                                  fid=node.fid,
                                  bid=node.bid,
                                  weight=node.weight,
                                  is_leaf=node.is_leaf,
                                  left_nodeid=node.left_nodeid,
                                  right_nodeid=node.right_nodeid)

        model_param.split_maskdict.update(self.split_maskdict)
        return model_param

    def set_model_param(self, model_param):
        """Rebuild ``self.tree_`` and ``self.split_maskdict`` from ``model_param``."""
        self.tree_ = [
            Node(id=node_param.id,
                 sitename=node_param.sitename,
                 fid=node_param.fid,
                 bid=node_param.bid,
                 weight=node_param.weight,
                 is_leaf=node_param.is_leaf,
                 left_nodeid=node_param.left_nodeid,
                 right_nodeid=node_param.right_nodeid)
            for node_param in model_param.tree_
        ]
        self.split_maskdict = dict(model_param.split_maskdict)

    def get_model(self):
        """Return ``(model_meta, model_param)``."""
        model_meta = self.get_model_meta()
        model_param = self.get_model_param()
        return model_meta, model_param

    def load_model(self, model_meta=None, model_param=None):
        """Restore hyper-parameters and tree from a saved meta/param pair."""
        LOGGER.info("load tree model")
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)
class HomoDecisionTreeArbiter(DecisionTree):
    """Arbiter-side participant of the homo decision tree.

    The arbiter aggregates the histograms sent by the other parties (via
    ``self.aggregator``), finds the best split for every node and sends the
    results back (via ``self.transfer_inst``).
    """

    def __init__(self, tree_param: DecisionTreeModelParam, valid_feature: dict,
                 epoch_idx: int, tree_idx: int, flow_id: int):
        """Set up the arbiter state for one tree.

        Args:
            tree_param: forwarded unchanged to the parent constructor.
            valid_feature: stored as ``self.valid_features``.
            epoch_idx: stored as ``self.epoch_idx``; part of transfer suffixes.
            tree_idx: stored as ``self.tree_idx``; part of transfer suffixes.
            flow_id: passed to :meth:`set_flowid`.
        """
        super(HomoDecisionTreeArbiter, self).__init__(tree_param)

        self.splitter = Splitter(
            self.criterion_method,
            self.criterion_params,
            self.min_impurity_split,
            self.min_sample_split,
            self.min_leaf_node,
        )
        self.transfer_inst = HomoDecisionTreeTransferVariable()

        self.valid_features = valid_feature

        self.tree_node = []  # start from root node
        self.tree_node_num = 0
        self.cur_layer_node = []

        self.runtime_idx = 0
        self.sitename = consts.ARBITER
        self.epoch_idx = epoch_idx
        self.tree_idx = tree_idx

        # secure aggregator
        self.set_flowid(flow_id)
        self.aggregator = DecisionTreeArbiterAggregator(verbose=False)

        # stored histogram for faster computation {node_id:histogram_bag}
        self.stored_histograms = {}

    def set_flowid(self, flowid=0):
        """Forward ``flowid`` to the transfer-variable holder."""
        LOGGER.info("set flowid, flowid is {}".format(flowid))
        self.transfer_inst.set_flowid(flowid)

    def sync_node_sample_numbers(self, suffix):
        """Fetch the current-layer node count from all parties.

        Args:
            suffix: transfer suffix identifying this exchange.

        Returns:
            The node count reported by the first party.

        Raises:
            AssertionError: if the parties report different counts.
        """
        cur_layer_node_num = self.transfer_inst.cur_layer_node_num.get(
            -1, suffix=suffix)
        first_num = cur_layer_node_num[0]
        for num in cur_layer_node_num[1:]:
            assert num == first_num
        return first_num

    def federated_find_best_split(self, node_histograms,
                                  parallel_partitions=10) -> List[SplitInfo]:
        """Find the best split for every node histogram.

        Args:
            node_histograms: aggregated histograms, one entry per node, shaped
                like ``[[HistogramBag, ...], [HistogramBag, ...], ...]``.
            parallel_partitions: forwarded to ``self.splitter.find_split``.

        Returns:
            The value returned by ``self.splitter.find_split``.
        """
        LOGGER.debug(
            'federated finding best splits,histograms from {} guest received'.
            format(len(node_histograms)))
        LOGGER.debug('aggregating histograms .....')
        acc_histogram = node_histograms
        best_splits = self.splitter.find_split(acc_histogram,
                                               self.valid_features,
                                               parallel_partitions,
                                               self.sitename,
                                               self.use_missing,
                                               self.zero_as_missing)
        return best_splits

    def sync_best_splits(self, split_info, suffix):
        """Send ``split_info`` to all parties (``idx=-1``)."""
        LOGGER.debug('sending best split points')
        self.transfer_inst.best_split_points.remote(split_info, idx=-1,
                                                    suffix=suffix)

    def sync_local_histogram(self, suffix) -> List[HistogramBag]:
        """Return the histograms aggregated by ``self.aggregator`` for ``suffix``."""
        LOGGER.debug('get local histograms')
        node_local_histogram = self.aggregator.aggregate_histogram(
            suffix=suffix)
        LOGGER.debug('num of histograms {}'.format(len(node_local_histogram)))
        return node_local_histogram

    def histogram_subtraction(self, left_node_histogram, stored_histograms):
        """Derive right-child histograms as ``parent - left``.

        Every histogram in ``left_node_histogram`` is kept. For each one whose
        ``hid`` is not 0 (the root), the sibling is computed from the parent
        histogram in ``stored_histograms`` and appended right after it with
        ``hid = left.hid + 1``.

        Args:
            left_node_histogram: histograms received for this batch.
            stored_histograms: mapping of ``hid`` to the histograms of the
                previous layer.

        Returns:
            List of the received histograms interleaved with derived siblings.
        """
        all_histograms = []
        for left_hist in left_node_histogram:
            all_histograms.append(left_hist)

            # root node hist has no sibling
            if left_hist.hid == 0:
                continue

            right_hist = stored_histograms[left_hist.p_hid] - left_hist
            # NOTE(review): the p_hid assignment is a self-assignment and was
            # possibly meant to be ``left_hist.p_hid``; kept as-is because the
            # result of the subtraction is not visible here - confirm.
            right_hist.hid, right_hist.p_hid = \
                left_hist.hid + 1, right_hist.p_hid
            all_histograms.append(right_hist)

        return all_histograms

    def fit(self):
        """Run the arbiter side of tree fitting.

        Aggregates and broadcasts the root gradient/hessian sums, then for
        every depth except the last one: fetches the node count, aggregates
        the histograms batch by batch, derives sibling histograms, finds the
        best splits and sends them out.
        """
        # The message previously contained a garbled 'h**o' instead of 'homo'.
        LOGGER.info(
            'begin to fit homo decision tree, epoch {}, tree idx {}'.format(
                self.epoch_idx, self.tree_idx))

        g_sum, h_sum = self.aggregator.aggregate_root_node_info(
            suffix=('root_node_sync1', self.epoch_idx))
        LOGGER.debug('g_sum is {},h_sum is {}'.format(g_sum, h_sum))
        self.aggregator.broadcast_root_info(
            g_sum, h_sum, suffix=('root_node_sync2', self.epoch_idx))

        # Histogram subtraction works on (left, right) pairs, so an odd batch
        # size is bumped to the next even value.
        if self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1:
            self.max_split_nodes += 1
            LOGGER.warning(
                'an even max_split_nodes value is suggested when using histogram-subtraction, max_split_nodes reset to {}'
                .format(self.max_split_nodes))

        # The deepest layer is never split, so it is not visited.
        for dep in range(self.max_depth - 1):
            LOGGER.debug('at dep {}'.format(dep))

            split_info = []
            # get cur layer node num:
            cur_layer_node_num = self.sync_node_sample_numbers(
                suffix=(dep, self.epoch_idx, self.tree_idx))
            LOGGER.debug(
                '{} nodes to split at this layer'.format(cur_layer_node_num))

            layer_stored_hist = {}

            batch_starts = range(0, cur_layer_node_num, self.max_split_nodes)
            for batch_id, _ in enumerate(batch_starts):
                LOGGER.debug('cur batch id is {}'.format(batch_id))

                left_node_histogram = self.sync_local_histogram(
                    suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))
                all_histograms = self.histogram_subtraction(
                    left_node_histogram, self.stored_histograms)

                # store histogram
                for hist in all_histograms:
                    layer_stored_hist[hist.hid] = hist

                # FIXME stable parallel_partitions
                best_splits = self.federated_find_best_split(
                    all_histograms, parallel_partitions=10)
                split_info += best_splits

            self.stored_histograms = layer_stored_hist

            # NOTE(review): unlike the other suffixes in this loop, this one
            # does not include tree_idx; it must match the receiving side,
            # which is not visible here - confirm.
            self.sync_best_splits(split_info, suffix=(dep, self.epoch_idx))
            LOGGER.debug('best_splits_sent')

    def predict(self, data_inst=None):
        """Do nothing apart from logging; ``data_inst`` is ignored."""
        LOGGER.debug('start predicting')
class HeteroDecisionTreeGuest(DecisionTree): def __init__(self, tree_param): LOGGER.info("hetero decision tree guest init!") super(HeteroDecisionTreeGuest, self).__init__(tree_param) self.splitter = Splitter(self.criterion_method, self.criterion_params, self.min_impurity_split, self.min_sample_split, self.min_leaf_node) self.data_bin = None self.grad_and_hess = None self.bin_split_points = None self.bin_sparse_points = None self.data_bin_with_node_dispatch = None self.node_dispatch = None self.infos = None self.valid_features = None self.encrypter = None self.encrypted_mode_calculator = None self.best_splitinfo_guest = None self.tree_node_queue = None self.cur_split_nodes = None self.tree_ = [] self.tree_node_num = 0 self.split_maskdict = {} self.transfer_inst = HeteroDecisionTreeTransferVariable() self.predict_weights = None self.runtime_idx = 0 self.feature_importances_ = {} def set_flowid(self, flowid=0): LOGGER.info("set flowid, flowid is {}".format(flowid)) self.transfer_inst.set_flowid(flowid) def set_runtime_idx(self, runtime_idx): self.runtime_idx = runtime_idx def set_inputinfo(self, data_bin=None, grad_and_hess=None, bin_split_points=None, bin_sparse_points=None): LOGGER.info("set input info") self.data_bin = data_bin self.grad_and_hess = grad_and_hess self.bin_split_points = bin_split_points self.bin_sparse_points = bin_sparse_points def set_encrypter(self, encrypter): LOGGER.info("set encrypter") self.encrypter = encrypter def set_encrypted_mode_calculator(self, encrypted_mode_calculator): self.encrypted_mode_calculator = encrypted_mode_calculator def encrypt(self, val): return self.encrypter.encrypt(val) def decrypt(self, val): return self.encrypter.decrypt(val) def encode(self, etype="feature_idx", val=None, nid=None): if etype == "feature_idx": return val if etype == "feature_val": self.split_maskdict[nid] = val return None raise TypeError("encode type %s is not support!" 
% (str(etype))) @staticmethod def decode(dtype="feature_idx", val=None, nid=None, split_maskdict=None): if dtype == "feature_idx": return val if dtype == "feature_val": if nid in split_maskdict: return split_maskdict[nid] else: raise ValueError( "decode val %s cause error, can't reconize it!" % (str(val))) return TypeError("decode type %s is not support!" % (str(dtype))) def set_valid_features(self, valid_features=None): LOGGER.info("set valid features") self.valid_features = valid_features def sync_encrypted_grad_and_hess(self): LOGGER.info("send encrypted grad and hess to host") encrypted_grad_and_hess = self.encrypt_grad_and_hess() federation.remote(obj=encrypted_grad_and_hess, name=self.transfer_inst.encrypted_grad_and_hess.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.encrypted_grad_and_hess), role=consts.HOST, idx=-1) def encrypt_grad_and_hess(self): LOGGER.info("start to encrypt grad and hess") encrypted_grad_and_hess = self.encrypted_mode_calculator.encrypt( self.grad_and_hess) return encrypted_grad_and_hess def get_grad_hess_sum(self, grad_and_hess_table): LOGGER.info("calculate the sum of grad and hess") grad, hess = grad_and_hess_table.reduce(lambda value1, value2: (value1[ 0] + value2[0], value1[1] + value2[1])) return grad, hess def dispatch_all_node_to_root(self, root_id=0): LOGGER.info("dispatch all node to root") self.node_dispatch = self.data_bin.mapValues(lambda data_inst: (1, root_id)) def get_histograms(self, node_map={}): LOGGER.info("start to get node histograms") histograms = FeatureHistogram.calculate_histogram( self.data_bin_with_node_dispatch, self.grad_and_hess, self.bin_split_points, self.bin_sparse_points, self.valid_features, node_map) acc_histograms = FeatureHistogram.accumulate_histogram(histograms) return acc_histograms def sync_tree_node_queue(self, tree_node_queue, dep=-1): LOGGER.info("send tree node queue of depth {}".format(dep)) mask_tree_node_queue = copy.deepcopy(tree_node_queue) for i in 
range(len(mask_tree_node_queue)): mask_tree_node_queue[i] = Node(id=mask_tree_node_queue[i].id) federation.remote(obj=mask_tree_node_queue, name=self.transfer_inst.tree_node_queue.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.tree_node_queue, dep), role=consts.HOST, idx=-1) def sync_node_positions(self, dep): LOGGER.info("send node positions of depth {}".format(dep)) federation.remote(obj=self.node_dispatch, name=self.transfer_inst.node_positions.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.node_positions, dep), role=consts.HOST, idx=-1) def sync_encrypted_splitinfo_host(self, dep=-1, batch=-1): LOGGER.info("get encrypted splitinfo of depth {}, batch {}".format( dep, batch)) encrypted_splitinfo_host = federation.get( name=self.transfer_inst.encrypted_splitinfo_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.encrypted_splitinfo_host, dep, batch), idx=-1) return encrypted_splitinfo_host def sync_federated_best_splitinfo_host(self, federated_best_splitinfo_host, dep=-1, batch=-1, idx=-1): LOGGER.info( "send federated best splitinfo of depth {}, batch {}".format( dep, batch)) federation.remote( obj=federated_best_splitinfo_host, name=self.transfer_inst.federated_best_splitinfo_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.federated_best_splitinfo_host, dep, batch), role=consts.HOST, idx=idx) def find_host_split(self, value): cur_split_node, encrypted_splitinfo_host = value sum_grad = cur_split_node.sum_grad sum_hess = cur_split_node.sum_hess best_gain = self.min_impurity_split - consts.FLOAT_ZERO best_idx = -1 for i in range(len(encrypted_splitinfo_host)): sum_grad_l, sum_hess_l = encrypted_splitinfo_host[i] sum_grad_l = self.decrypt(sum_grad_l) sum_hess_l = self.decrypt(sum_hess_l) sum_grad_r = sum_grad - sum_grad_l sum_hess_r = sum_hess - sum_hess_l gain = self.splitter.split_gain(sum_grad, sum_hess, sum_grad_l, sum_hess_l, sum_grad_r, sum_hess_r) if gain > 
self.min_impurity_split and gain > best_gain: best_gain = gain best_idx = i best_gain = self.encrypt(best_gain) return best_idx, best_gain def federated_find_split(self, dep=-1, batch=-1): LOGGER.info("federated find split of depth {}, batch {}".format( dep, batch)) encrypted_splitinfo_host = self.sync_encrypted_splitinfo_host( dep, batch) for i in range(len(encrypted_splitinfo_host)): encrypted_splitinfo_host_table = eggroll.parallelize( zip(self.cur_split_nodes, encrypted_splitinfo_host[i]), include_key=False, partition=self.data_bin._partitions) splitinfos = encrypted_splitinfo_host_table.mapValues( self.find_host_split).collect() best_splitinfo_host = [splitinfo[1] for splitinfo in splitinfos] self.sync_federated_best_splitinfo_host(best_splitinfo_host, dep, batch, i) def sync_final_split_host(self, dep=-1, batch=-1): LOGGER.info("get host final splitinfo of depth {}, batch {}".format( dep, batch)) final_splitinfo_host = federation.get( name=self.transfer_inst.final_splitinfo_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.final_splitinfo_host, dep, batch), idx=-1) return final_splitinfo_host def find_best_split_guest_and_host(self, splitinfo_guest_host): best_gain_host = self.decrypt(splitinfo_guest_host[1].gain) best_gain_host_idx = 1 for i in range(1, len(splitinfo_guest_host)): gain_host_i = self.decrypt(splitinfo_guest_host[i].gain) if best_gain_host < gain_host_i: best_gain_host = gain_host_i best_gain_host_idx = i if splitinfo_guest_host[0].gain >= best_gain_host - consts.FLOAT_ZERO: best_splitinfo = splitinfo_guest_host[0] else: best_splitinfo = splitinfo_guest_host[best_gain_host_idx] best_splitinfo.sum_grad = self.decrypt(best_splitinfo.sum_grad) best_splitinfo.sum_hess = self.decrypt(best_splitinfo.sum_hess) best_splitinfo.gain = best_gain_host return best_splitinfo def merge_splitinfo(self, splitinfo_guest, splitinfo_host): LOGGER.info("merge splitinfo") merge_infos = [] for i in range(len(splitinfo_guest)): splitinfo = 
[splitinfo_guest[i]] for j in range(len(splitinfo_host)): splitinfo.append(splitinfo_host[j][i]) merge_infos.append(splitinfo) splitinfo_guest_host_table = eggroll.parallelize( merge_infos, include_key=False, partition=self.data_bin._partitions) best_splitinfo_table = splitinfo_guest_host_table.mapValues( self.find_best_split_guest_and_host) best_splitinfos = [ best_splitinfo[1] for best_splitinfo in best_splitinfo_table.collect() ] return best_splitinfos def update_feature_importance(self, splitinfo): if self.feature_importance_type == "split": inc = 1 elif self.feature_importance_type == "gain": inc = splitinfo.gain else: raise ValueError( "feature importance type {} not support yet".format( self.feature_importance_type)) sitename = splitinfo.sitename fid = splitinfo.best_fid if (sitename, fid) not in self.feature_importances_: self.feature_importances_[(sitename, fid)] = 0 self.feature_importances_[(sitename, fid)] += inc def update_tree_node_queue(self, splitinfos, max_depth_reach): LOGGER.info( "update tree node, splitlist length is {}, tree node queue size is" .format(len(splitinfos), len(self.tree_node_queue))) new_tree_node_queue = [] for i in range(len(self.tree_node_queue)): sum_grad = self.tree_node_queue[i].sum_grad sum_hess = self.tree_node_queue[i].sum_hess if max_depth_reach or splitinfos[i].gain <= \ self.min_impurity_split + consts.FLOAT_ZERO: self.tree_node_queue[i].is_leaf = True else: self.tree_node_queue[i].left_nodeid = self.tree_node_num + 1 self.tree_node_queue[i].right_nodeid = self.tree_node_num + 2 self.tree_node_num += 2 left_node = Node(id=self.tree_node_queue[i].left_nodeid, sitename=consts.GUEST, sum_grad=splitinfos[i].sum_grad, sum_hess=splitinfos[i].sum_hess, weight=self.splitter.node_weight( splitinfos[i].sum_grad, splitinfos[i].sum_hess)) right_node = Node(id=self.tree_node_queue[i].right_nodeid, sitename=consts.GUEST, sum_grad=sum_grad - splitinfos[i].sum_grad, sum_hess=sum_hess - splitinfos[i].sum_hess, 
weight=self.splitter.node_weight( \ sum_grad - splitinfos[i].sum_grad, sum_hess - splitinfos[i].sum_hess)) new_tree_node_queue.append(left_node) new_tree_node_queue.append(right_node) self.tree_node_queue[i].sitename = splitinfos[i].sitename if self.tree_node_queue[i].sitename == consts.GUEST: self.tree_node_queue[i].fid = self.encode( "feature_idx", splitinfos[i].best_fid) self.tree_node_queue[i].bid = self.encode( "feature_val", splitinfos[i].best_bid, self.tree_node_queue[i].id) else: self.tree_node_queue[i].fid = splitinfos[i].best_fid self.tree_node_queue[i].bid = splitinfos[i].best_bid self.update_feature_importance(splitinfos[i]) self.tree_.append(self.tree_node_queue[i]) self.tree_node_queue = new_tree_node_queue @staticmethod def dispatch_node(value, tree_=None, decoder=None, split_maskdict=None, bin_sparse_points=None): unleaf_state, nodeid = value[1] if tree_[nodeid].is_leaf is True: return tree_[nodeid].weight else: if tree_[nodeid].sitename == consts.GUEST: fid = decoder("feature_idx", tree_[nodeid].fid, split_maskdict=split_maskdict) bid = decoder("feature_val", tree_[nodeid].bid, nodeid, split_maskdict) if value[0].features.get_data(fid, bin_sparse_points[fid]) <= bid: return (1, tree_[nodeid].left_nodeid) else: return (1, tree_[nodeid].right_nodeid) else: return (1, tree_[nodeid].fid, tree_[nodeid].bid, tree_[nodeid].sitename, nodeid, tree_[nodeid].left_nodeid, tree_[nodeid].right_nodeid) def sync_dispatch_node_host(self, dispatch_guest_data, dep=-1): LOGGER.info("send node to host to dispath, depth is {}".format(dep)) federation.remote(obj=dispatch_guest_data, name=self.transfer_inst.dispatch_node_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.dispatch_node_host, dep), role=consts.HOST, idx=-1) def sync_dispatch_node_host_result(self, dep=-1): LOGGER.info("get host dispatch result, depth is {}".format(dep)) dispatch_node_host_result = federation.get( name=self.transfer_inst.dispatch_node_host_result.name, 
tag=self.transfer_inst.generate_transferid( self.transfer_inst.dispatch_node_host_result, dep), idx=-1) return dispatch_node_host_result def redispatch_node(self, dep=-1): LOGGER.info("redispatch node of depth {}".format(dep)) dispatch_node_method = functools.partial( self.dispatch_node, tree_=self.tree_, decoder=self.decode, split_maskdict=self.split_maskdict, bin_sparse_points=self.bin_sparse_points) dispatch_guest_result = self.data_bin_with_node_dispatch.mapValues( dispatch_node_method) tree_node_num = self.tree_node_num LOGGER.info("remask dispatch node result of depth {}".format(dep)) dispatch_to_host_result = dispatch_guest_result.filter( lambda key, value: isinstance(value, tuple) and len(value) > 2) dispatch_guest_result = dispatch_guest_result.subtractByKey( dispatch_to_host_result) leaf = dispatch_guest_result.filter( lambda key, value: isinstance(value, tuple) is False) if self.predict_weights is None: self.predict_weights = leaf else: self.predict_weights = self.predict_weights.union(leaf) dispatch_guest_result = dispatch_guest_result.subtractByKey(leaf) self.sync_dispatch_node_host(dispatch_to_host_result, dep) dispatch_node_host_result = self.sync_dispatch_node_host_result(dep) self.node_dispatch = None for idx in range(len(dispatch_node_host_result)): if self.node_dispatch is None: self.node_dispatch = dispatch_node_host_result[idx] else: self.node_dispatch = self.node_dispatch.join(dispatch_node_host_result[idx], \ lambda unleaf_state_nodeid1, unleaf_state_nodeid2: \ unleaf_state_nodeid1 if len( unleaf_state_nodeid1) == 2 else unleaf_state_nodeid2) self.node_dispatch = self.node_dispatch.union(dispatch_guest_result) def sync_tree(self): LOGGER.info("sync tree to host") federation.remote(obj=self.tree_, name=self.transfer_inst.tree.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.tree), role=consts.HOST, idx=-1) def convert_bin_to_real(self): LOGGER.info("convert tree node bins to real value") for i in range(len(self.tree_)): if 
self.tree_[i].is_leaf is True: continue if self.tree_[i].sitename == consts.GUEST: fid = self.decode("feature_idx", self.tree_[i].fid, split_maskdict=self.split_maskdict) bid = self.decode("feature_val", self.tree_[i].bid, self.tree_[i].id, self.split_maskdict) real_splitval = self.encode("feature_val", self.bin_split_points[fid][bid], self.tree_[i].id) self.tree_[i].bid = real_splitval def fit(self): LOGGER.info("begin to fit guest decision tree") self.sync_encrypted_grad_and_hess() root_sum_grad, root_sum_hess = self.get_grad_hess_sum( self.grad_and_hess) root_node = Node(id=0, sitename=consts.GUEST, sum_grad=root_sum_grad, sum_hess=root_sum_hess, weight=self.splitter.node_weight( root_sum_grad, root_sum_hess)) self.tree_node_queue = [root_node] self.dispatch_all_node_to_root() for dep in range(self.max_depth): LOGGER.info( "start to fit depth {}, tree node queue size is {}".format( dep, len(self.tree_node_queue))) self.sync_tree_node_queue(self.tree_node_queue, dep) if len(self.tree_node_queue) == 0: break self.sync_node_positions(dep) self.data_bin_with_node_dispatch = self.data_bin.join( self.node_dispatch, lambda data_inst, dispatch_info: (data_inst, dispatch_info)) batch = 0 splitinfos = [] for i in range(0, len(self.tree_node_queue), self.max_split_nodes): self.cur_split_nodes = self.tree_node_queue[i:i + self. 
max_split_nodes] node_map = {} node_num = 0 for tree_node in self.cur_split_nodes: node_map[tree_node.id] = node_num node_num += 1 acc_histograms = self.get_histograms(node_map=node_map) self.best_splitinfo_guest = self.splitter.find_split( acc_histograms, self.valid_features, self.data_bin._partitions) self.federated_find_split(dep, batch) final_splitinfo_host = self.sync_final_split_host(dep, batch) cur_splitinfos = self.merge_splitinfo( self.best_splitinfo_guest, final_splitinfo_host) splitinfos.extend(cur_splitinfos) batch += 1 max_depth_reach = True if dep + 1 == self.max_depth else False self.update_tree_node_queue(splitinfos, max_depth_reach) self.redispatch_node(dep) self.sync_tree() self.convert_bin_to_real() tree_ = self.tree_ LOGGER.info("tree node num is %d" % len(tree_)) LOGGER.info("end to fit guest decision tree") @staticmethod def traverse_tree(predict_state, data_inst, tree_=None, decoder=None, split_maskdict=None): nid, tag = predict_state while tree_[nid].sitename == consts.GUEST: if tree_[nid].is_leaf is True: return tree_[nid].weight fid = decoder("feature_idx", tree_[nid].fid, split_maskdict=split_maskdict) bid = decoder("feature_val", tree_[nid].bid, nid, split_maskdict) if data_inst.features.get_data(fid, 0) <= bid: nid = tree_[nid].left_nodeid else: nid = tree_[nid].right_nodeid return nid, 1 def sync_predict_finish_tag(self, finish_tag, send_times): LOGGER.info("send the {}-th predict finish tag {} to host".format( finish_tag, send_times)) federation.remote(obj=finish_tag, name=self.transfer_inst.predict_finish_tag.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.predict_finish_tag, send_times), role=consts.HOST, idx=-1) def sync_predict_data(self, predict_data, send_times): LOGGER.info("send predict data to host, sending times is {}".format( send_times)) federation.remote(obj=predict_data, name=self.transfer_inst.predict_data.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.predict_data, 
send_times), role=consts.HOST, idx=-1) def sync_data_predicted_by_host(self, send_times): LOGGER.info( "get predicted data by host, recv times is {}".format(send_times)) predict_data = federation.get( name=self.transfer_inst.predict_data_by_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.predict_data_by_host, send_times), idx=-1) return predict_data def predict(self, data_inst): LOGGER.info("start to predict!") predict_data = data_inst.mapValues(lambda data_inst: (0, 1)) site_host_send_times = 0 predict_result = None while True: traverse_tree = functools.partial( self.traverse_tree, tree_=self.tree_, decoder=self.decode, split_maskdict=self.split_maskdict) predict_data = predict_data.join(data_inst, traverse_tree) predict_leaf = predict_data.filter( lambda key, value: isinstance(value, tuple) is False) if predict_result is None: predict_result = predict_leaf else: predict_result = predict_result.union(predict_leaf) predict_data = predict_data.subtractByKey(predict_leaf) unleaf_node_count = predict_data.count() if unleaf_node_count == 0: self.sync_predict_finish_tag(True, site_host_send_times) break self.sync_predict_finish_tag(False, site_host_send_times) self.sync_predict_data(predict_data, site_host_send_times) predict_data_host = self.sync_data_predicted_by_host( site_host_send_times) for i in range(len(predict_data_host)): predict_data = predict_data.join( predict_data_host[i], lambda state1_nodeid1, state2_nodeid2: state1_nodeid1 if state1_nodeid1[1] == 0 else state2_nodeid2) site_host_send_times += 1 LOGGER.info("predict finish!") return predict_result def get_model_meta(self): model_meta = DecisionTreeModelMeta() model_meta.criterion_meta.CopyFrom( CriterionMeta(criterion_method=self.criterion_method, criterion_param=self.criterion_params)) model_meta.max_depth = self.max_depth model_meta.min_sample_split = self.min_sample_split model_meta.min_impurity_split = self.min_impurity_split model_meta.min_leaf_node = self.min_leaf_node 
return model_meta def set_model_meta(self, model_meta): self.max_depth = model_meta.max_depth self.min_sample_split = model_meta.min_sample_split self.min_impurity_split = model_meta.min_impurity_split self.min_leaf_node = model_meta.min_leaf_node self.criterion_method = model_meta.criterion_meta.criterion_method self.criterion_params = list(model_meta.criterion_meta.criterion_param) def get_model_param(self): model_param = DecisionTreeModelParam() for node in self.tree_: model_param.tree_.add(id=node.id, sitename=node.sitename, fid=node.fid, bid=node.bid, weight=node.weight, is_leaf=node.is_leaf, left_nodeid=node.left_nodeid, right_nodeid=node.right_nodeid) model_param.split_maskdict.update(self.split_maskdict) return model_param def set_model_param(self, model_param): self.tree_ = [] for node_param in model_param.tree_: _node = Node(id=node_param.id, sitename=node_param.sitename, fid=node_param.fid, bid=node_param.bid, weight=node_param.weight, is_leaf=node_param.is_leaf, left_nodeid=node_param.left_nodeid, right_nodeid=node_param.right_nodeid) self.tree_.append(_node) self.split_maskdict = dict(model_param.split_maskdict) def get_model(self): model_meta = self.get_model_meta() model_param = self.get_model_param() return model_meta, model_param def load_model(self, model_meta=None, model_param=None): LOGGER.info("load tree model") self.set_model_meta(model_meta) self.set_model_param(model_param) def get_feature_importance(self): return self.feature_importances_
def __init__(self, tree_param: DecisionTreeParam, data_bin=None, bin_split_points: np.array = None,
             bin_sparse_point=None, g_h=None, valid_feature: dict = None, epoch_idx: int = None,
             role: str = None, tree_idx: int = None, flow_id: int = None, mode='train'):
    """Set up the client-side state of one homo decision tree.

    Parameters
    ----------
    tree_param: decision tree parameter object
    data_bin: binned data instances
    bin_split_points: data split points
    bin_sparse_point: sparse data point
    g_h: computed g val and h val of instances
    valid_feature: dict pointing out valid features {valid:true,invalid:false}
    epoch_idx: current epoch index
    role: host or guest
    tree_idx: index of this tree
    flow_id: flow id
    mode: train / predict
    """
    super(HomoDecisionTreeClient, self).__init__(tree_param)
    self.splitter = Splitter(
        self.criterion_method,
        self.criterion_params,
        self.min_impurity_split,
        self.min_sample_split,
        self.min_leaf_node,
    )

    # inputs handed over by the caller
    self.data_bin, self.g_h = data_bin, g_h
    self.bin_split_points, self.bin_sparse_points = bin_split_points, bin_sparse_point
    self.epoch_idx, self.tree_idx = epoch_idx, tree_idx

    # check max_split_nodes: a non-zero odd value is bumped to the next even one
    odd_batch_size = self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1
    if odd_batch_size:
        self.max_split_nodes = self.max_split_nodes + 1
        LOGGER.warning('an even max_split_nodes value is suggested when using histogram-subtraction, max_split_nodes reset to {}'.format(self.max_split_nodes))

    self.transfer_inst = HomoDecisionTreeTransferVariable()

    # initializing here
    self.valid_features = valid_feature
    self.tree_node = []  # start from root node
    self.tree_node_num = 0
    self.cur_layer_node = []
    self.runtime_idx = 0
    self.sitename = consts.GUEST
    self.feature_importance = {}
    self.inst2node_idx = None

    # record weights of samples
    self.sample_weights = None

    # secure aggregator, class SecureBoostClientAggregator
    if mode == 'train':
        self.role = role
        self.set_flowid(flow_id)
        self.aggregator = DecisionTreeClientAggregator(verbose=False)
    elif mode == 'predict':
        # prediction is purely local: no role, no aggregator
        self.role = self.aggregator = None
class HomoDecisionTreeClient(DecisionTree):
    """Client-side decision tree used in homo (horizontally partitioned) boosting.

    The client computes local histograms for the nodes of the current layer,
    sends them through its aggregator, receives the best split points back
    (see ``sync_best_splits``) and grows the tree layer by layer.
    """

    def __init__(self, tree_param: DecisionTreeParam, data_bin=None, bin_split_points: np.array = None,
                 bin_sparse_point=None, g_h=None, valid_feature: dict = None, epoch_idx: int = None,
                 role: str = None, tree_idx: int = None, flow_id: int = None, mode='train'):
        """
        Parameters
        ----------
        tree_param: decision tree parameter object
        data_bin: binned data instances
        bin_split_points: data split points
        bin_sparse_point: sparse data point
        g_h: computed g val and h val of instances
        valid_feature: dict pointing out valid features {valid:true,invalid:false}
        epoch_idx: current epoch index
        role: host or guest
        tree_idx: index of this tree
        flow_id: flow id
        mode: train / predict
        """
        super().__init__(tree_param)
        self.splitter = Splitter(self.criterion_method, self.criterion_params, self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)
        self.data_bin = data_bin
        self.g_h = g_h
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_point
        self.epoch_idx = epoch_idx
        self.tree_idx = tree_idx

        # check max_split_nodes: a non-zero odd value is bumped to the next even one
        if self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1:
            self.max_split_nodes += 1
            LOGGER.warning('an even max_split_nodes value is suggested when using histogram-subtraction, '
                           'max_split_nodes reset to {}'.format(self.max_split_nodes))

        self.transfer_inst = HomoDecisionTreeTransferVariable()

        self.valid_features = valid_feature
        self.tree_node = []  # start from root node
        self.tree_node_num = 0
        self.cur_layer_node = []
        self.runtime_idx = 0
        self.sitename = consts.GUEST
        self.feature_importance = {}
        self.inst2node_idx = None

        # record weights of samples
        self.sample_weights = None

        # Defaults first, so both attributes always exist even if an
        # unexpected ``mode`` value is passed in.
        self.role, self.aggregator = None, None
        if mode == 'train':
            # secure aggregator, class SecureBoostClientAggregator
            self.role = role
            self.set_flowid(flow_id)
            self.aggregator = DecisionTreeClientAggregator(verbose=False)

    def set_flowid(self, flowid=0):
        """Forward the flow id to the transfer variable."""
        LOGGER.info("set flowid, flowid is {}".format(flowid))
        self.transfer_inst.set_flowid(flowid)

    def get_grad_hess_sum(self, grad_and_hess_table):
        """Return the element-wise sum ``(grad, hess)`` over all table values."""
        LOGGER.info("calculate the sum of grad and hess")
        grad, hess = grad_and_hess_table.reduce(
            lambda value1, value2: (value1[0] + value2[0], value1[1] + value2[1]))
        return grad, hess

    def update_feature_importance(self, split_info: List[SplitInfo]):
        """Accumulate feature importance from the splits of one layer.

        ``feature_importance_type`` selects the increment: ``"split"`` adds 1
        per split, ``"gain"`` adds the split gain.

        Raises
        ------
        ValueError: if ``feature_importance_type`` is neither of the above.
        """
        for splitinfo in split_info:
            if self.feature_importance_type == "split":
                inc = 1
            elif self.feature_importance_type == "gain":
                inc = splitinfo.gain
            else:
                raise ValueError("feature importance type {} not support yet".format(self.feature_importance_type))

            fid = splitinfo.best_fid
            if fid is None:
                # No split was found for this node (update_tree turns it into
                # a leaf), so there is no feature to credit.
                continue
            self.feature_importance[fid] = self.feature_importance.get(fid, 0) + inc

    def sync_local_node_histogram(self, acc_histogram: List[HistogramBag], suffix):
        """Send the local histograms through the aggregator."""
        self.aggregator.send_histogram(acc_histogram, suffix=suffix)
        LOGGER.debug('local histogram sent at layer {}'.format(suffix[0]))

    def get_node_map(self, nodes: List[Node], left_node_only=True):
        """Map node id -> consecutive index.

        With ``left_node_only`` set, only the root (id 0) and left children
        receive an index; right children are left out of the map.
        """
        node_map = {}
        idx = 0
        for node in nodes:
            if node.id != 0 and (not node.is_left_node and left_node_only):
                continue
            node_map[node.id] = idx
            idx += 1
        return node_map

    def _calculate_histogram_bags(self, node_map, g_h, table_with_assign, split_points, sparse_point,
                                  valid_feature):
        """Compute histograms for the nodes in ``node_map`` and wrap each in a HistogramBag."""
        LOGGER.info("start to get node histograms")
        histograms = FeatureHistogram.calculate_histogram(
            table_with_assign, g_h,
            split_points, sparse_point,
            valid_feature, node_map,
            self.use_missing, self.zero_as_missing)
        return [HistogramBag(hist_list) for hist_list in histograms]

    def get_local_histogram(self, cur_to_split: List[Node], g_h, table_with_assign,
                            split_points, sparse_point, valid_feature):
        """Return one HistogramBag per node mapped by ``get_node_map(cur_to_split)``."""
        node_map = self.get_node_map(nodes=cur_to_split)
        return self._calculate_histogram_bags(node_map, g_h, table_with_assign, split_points,
                                              sparse_point, valid_feature)

    def get_left_node_local_histogram(self, cur_nodes: List[Node], tree: List[Node], g_h, table_with_assign,
                                      split_points, sparse_point, valid_feature):
        """Return HistogramBags for the root/left nodes of ``cur_nodes``.

        Each bag is tagged with its node id (``hid``) and the id of the node's
        parent (``p_hid``). ``tree`` is accepted for interface compatibility
        and is not read here.
        """
        node_map = self.get_node_map(cur_nodes, left_node_only=True)
        hist_bags = self._calculate_histogram_bags(node_map, g_h, table_with_assign, split_points,
                                                   sparse_point, valid_feature)

        left_nodes = [node for node in cur_nodes if node.is_left_node or node.id == 0]

        # set histogram id and parent histogram id
        for node, hist_bag in zip(left_nodes, hist_bags):
            hist_bag.hid = node.id
            hist_bag.p_hid = node.parent_nodeid

        return hist_bags

    def update_tree(self, cur_to_split: List[Node], split_info: List[SplitInfo]):
        """Apply the split decisions of one layer to the tree.

        A node becomes a leaf when no feature was selected or when the gain
        does not exceed ``min_impurity_split + consts.FLOAT_ZERO``; otherwise
        two children are created with the next two free node ids.

        Parameters
        ----------
        cur_to_split: nodes of the current layer
        split_info: one SplitInfo per node, in the same order

        Returns
        -------
        list of the newly created child nodes (the next layer)
        """
        LOGGER.debug('updating tree_node, cur layer has {} node'.format(len(cur_to_split)))
        next_layer_node = []
        assert len(cur_to_split) == len(split_info)

        for node, split in zip(cur_to_split, split_info):
            if split.best_fid is None or split.gain <= self.min_impurity_split + consts.FLOAT_ZERO:
                node.is_leaf = True
                self.tree_node.append(node)
                continue

            node.fid = split.best_fid
            node.bid = split.best_bid
            node.missing_dir = split.missing_dir

            p_id = node.id
            l_id, r_id = self.tree_node_num + 1, self.tree_node_num + 2
            node.left_nodeid, node.right_nodeid = l_id, r_id
            self.tree_node_num += 2

            # The split info carries the left child's sums; the right child
            # gets the remainder of the parent's sums.
            l_g, l_h = split.sum_grad, split.sum_hess
            r_g, r_h = node.sum_grad - l_g, node.sum_hess - l_h

            # create new left node and new right node
            left_node = Node(id=l_id,
                             sitename=self.sitename,
                             sum_grad=l_g,
                             sum_hess=l_h,
                             weight=self.splitter.node_weight(l_g, l_h),
                             parent_nodeid=p_id,
                             sibling_nodeid=r_id,
                             is_left_node=True)
            right_node = Node(id=r_id,
                              sitename=self.sitename,
                              sum_grad=r_g,
                              sum_hess=r_h,
                              weight=self.splitter.node_weight(r_g, r_h),
                              parent_nodeid=p_id,
                              sibling_nodeid=l_id,
                              is_left_node=False)

            next_layer_node.append(left_node)
            LOGGER.debug('append left,cur tree has {} node'.format(len(self.tree_node)))
            next_layer_node.append(right_node)
            LOGGER.debug('append right,cur tree has {} node'.format(len(self.tree_node)))
            self.tree_node.append(node)

        return next_layer_node

    def convert_bin_to_val(self):
        """Convert the bin id stored in every non-leaf node to its real split value."""
        for node in self.tree_node:
            if not node.is_leaf:
                node.bid = self.bin_split_points[node.fid][node.bid]

    def assign_instance_to_root_node(self, data_bin, root_node_id):
        """Assign every instance the state ``(1, root_node_id)``."""
        return data_bin.mapValues(lambda inst: (1, root_node_id))

    @staticmethod
    def assign_a_instance(row, tree: List[Node], bin_sparse_point, use_missing, use_zero_as_missing):
        """Move one instance one level down the tree.

        ``row`` is ``(instance, (leaf_status, nodeid))``. Returns the node
        weight when the current node is a leaf, otherwise ``(1, child_id)``.
        """
        leaf_status, nodeid = row[1]
        node = tree[nodeid]
        if node.is_leaf:
            return node.weight

        fid = node.fid
        bid = node.bid
        missing_dir = node.missing_dir

        missing_val = False
        if use_zero_as_missing:
            if row[0].features.get_data(fid, None) is None or \
                    row[0].features.get_data(fid) == NoneType():
                missing_val = True
        elif use_missing and row[0].features.get_data(fid) == NoneType():
            missing_val = True

        if missing_val:
            # missing_dir == 1 sends missing values to the right child
            if missing_dir == 1:
                return 1, tree[nodeid].right_nodeid
            return 1, tree[nodeid].left_nodeid

        if row[0].features.get_data(fid, bin_sparse_point[fid]) <= bid:
            return 1, tree[nodeid].left_nodeid
        return 1, tree[nodeid].right_nodeid

    def assign_instance_to_new_node(self, table_with_assignment, tree_node: List[Node]):
        """Re-assign instances after a layer was split.

        Returns
        -------
        (assign_result, leaf_val): instances still inside the tree with their
        new ``(1, nodeid)`` state, and instances that reached a leaf together
        with the leaf weight.
        """
        LOGGER.debug('re-assign instance to new nodes')
        assign_method = functools.partial(self.assign_a_instance,
                                          tree=tree_node,
                                          bin_sparse_point=self.bin_sparse_points,
                                          use_missing=self.use_missing,
                                          use_zero_as_missing=self.zero_as_missing)
        # FIXME (carried over from the original code; its intent is not documented there)
        assign_result = table_with_assignment.mapValues(assign_method)
        # leaves are returned as bare weights, inner assignments as tuples
        leaf_val = assign_result.filter(lambda key, value: isinstance(value, tuple) is False)
        assign_result = assign_result.subtractByKey(leaf_val)

        return assign_result, leaf_val

    @staticmethod
    def get_node_sample_weights(inst2node, tree_node: List[Node]):
        """Get the weight of the node each sample is currently assigned to."""
        func = functools.partial(lambda inst, nodes: nodes[inst[1]].weight, nodes=tree_node)
        return inst2node.mapValues(func)

    def get_feature_importance(self):
        """Return the accumulated feature-importance dict."""
        return self.feature_importance

    def sync_tree(self):
        """No-op in this class."""
        pass

    def sync_cur_layer_node_num(self, node_num, suffix):
        """Send the node count of the current layer to the arbiter."""
        self.transfer_inst.cur_layer_node_num.remote(node_num, role=consts.ARBITER, idx=-1, suffix=suffix)

    def sync_best_splits(self, suffix) -> List[SplitInfo]:
        """Receive the best split points for the current layer."""
        best_splits = self.transfer_inst.best_split_points.get(idx=0, suffix=suffix)
        return best_splits

    def fit(self):
        """Grow the tree layer by layer up to ``max_depth``."""
        LOGGER.info('begin to fit homo decision tree, epoch {}, tree idx {}'.format(self.epoch_idx, self.tree_idx))

        # compute local g_sum and h_sum
        g_sum, h_sum = self.get_grad_hess_sum(self.g_h)

        # get aggregated root info
        self.aggregator.send_local_root_node_info(g_sum, h_sum, suffix=('root_node_sync1', self.epoch_idx))
        g_h_dict = self.aggregator.get_aggregated_root_info(suffix=('root_node_sync2', self.epoch_idx))
        global_g_sum, global_h_sum = g_h_dict['g_sum'], g_h_dict['h_sum']

        # initialize node
        root_node = Node(id=0, sitename=consts.GUEST, sum_grad=global_g_sum, sum_hess=global_h_sum,
                         weight=self.splitter.node_weight(global_g_sum, global_h_sum))

        self.cur_layer_node = [root_node]
        LOGGER.debug('assign samples to root node')
        self.inst2node_idx = self.assign_instance_to_root_node(self.data_bin, 0)

        for dep in range(self.max_depth):

            if dep + 1 == self.max_depth:
                # last layer: every remaining node becomes a leaf
                for node in self.cur_layer_node:
                    node.is_leaf = True
                    self.tree_node.append(node)

                rest_sample_weights = self.get_node_sample_weights(self.inst2node_idx, self.tree_node)
                if self.sample_weights is None:
                    self.sample_weights = rest_sample_weights
                else:
                    self.sample_weights = self.sample_weights.union(rest_sample_weights)

                # stop fitting
                break

            LOGGER.debug('start to fit layer {}'.format(dep))

            table_with_assignment = self.data_bin.join(self.inst2node_idx,
                                                       lambda inst, assignment: (inst, assignment))

            # send current layer node number:
            self.sync_cur_layer_node_num(len(self.cur_layer_node), suffix=(dep, self.epoch_idx, self.tree_idx))

            # process the layer in batches of at most max_split_nodes nodes
            for batch_id, i in enumerate(range(0, len(self.cur_layer_node), self.max_split_nodes)):
                cur_to_split = self.cur_layer_node[i:i + self.max_split_nodes]

                node_map = self.get_node_map(nodes=cur_to_split)
                LOGGER.debug('node map is {}'.format(node_map))
                LOGGER.debug('computing histogram for batch{} at depth{}'.format(batch_id, dep))
                local_histogram = self.get_left_node_local_histogram(
                    cur_nodes=cur_to_split,
                    tree=self.tree_node,
                    g_h=self.g_h,
                    table_with_assign=table_with_assignment,
                    split_points=self.bin_split_points,
                    sparse_point=self.bin_sparse_points,
                    valid_feature=self.valid_features
                )

                LOGGER.debug('federated finding best splits for batch{} at layer {}'.format(batch_id, dep))
                self.sync_local_node_histogram(local_histogram,
                                               suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))

            split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
            LOGGER.debug('got best splits from arbiter')

            new_layer_node = self.update_tree(self.cur_layer_node, split_info)
            self.cur_layer_node = new_layer_node
            self.update_feature_importance(split_info)

            self.inst2node_idx, leaf_val = self.assign_instance_to_new_node(table_with_assignment, self.tree_node)

            # record leaf val
            if self.sample_weights is None:
                self.sample_weights = leaf_val
            else:
                self.sample_weights = self.sample_weights.union(leaf_val)

            LOGGER.debug('assigning instance to new nodes done')

        self.convert_bin_to_val()
        LOGGER.debug('fitting tree done')
        LOGGER.debug('tree node num is {}'.format(len(self.tree_node)))

    def traverse_tree(self, data_inst: Instance, tree: List[Node], use_missing=True, zero_as_missing=True):
        """Walk ``tree`` from the root (node id 0) and return the weight of the reached leaf."""
        nid = 0  # root node id
        while True:

            if tree[nid].is_leaf:
                return tree[nid].weight

            cur_node = tree[nid]
            fid, bid = cur_node.fid, cur_node.bid
            missing_dir = cur_node.missing_dir

            if use_missing and zero_as_missing:

                if data_inst.features.get_data(fid) == NoneType() or \
                        data_inst.features.get_data(fid, None) is None:
                    nid = tree[nid].right_nodeid if missing_dir == 1 else tree[nid].left_nodeid
                elif data_inst.features.get_data(fid) <= bid:
                    nid = tree[nid].left_nodeid
                else:
                    nid = tree[nid].right_nodeid

            elif data_inst.features.get_data(fid) == NoneType():
                nid = tree[nid].right_nodeid if missing_dir == 1 else tree[nid].left_nodeid
            elif data_inst.features.get_data(fid, 0) <= bid:
                nid = tree[nid].left_nodeid
            else:
                nid = tree[nid].right_nodeid

    def predict(self, data_inst):
        """Return a table with the predicted leaf weight of every instance."""
        LOGGER.debug('tree start to predict')

        traverse_tree = functools.partial(self.traverse_tree,
                                          tree=self.tree_node,
                                          use_missing=self.use_missing,
                                          zero_as_missing=self.zero_as_missing)

        predicted_weights = data_inst.mapValues(traverse_tree)

        return predicted_weights

    def get_model_meta(self):
        """Export the tree hyper-parameters as a DecisionTreeModelMeta."""
        model_meta = DecisionTreeModelMeta()
        model_meta.criterion_meta.CopyFrom(CriterionMeta(criterion_method=self.criterion_method,
                                                         criterion_param=self.criterion_params))

        model_meta.max_depth = self.max_depth
        model_meta.min_sample_split = self.min_sample_split
        model_meta.min_impurity_split = self.min_impurity_split
        model_meta.min_leaf_node = self.min_leaf_node
        model_meta.use_missing = self.use_missing
        model_meta.zero_as_missing = self.zero_as_missing

        return model_meta

    def set_model_meta(self, model_meta):
        """Restore the tree hyper-parameters from ``model_meta``."""
        self.max_depth = model_meta.max_depth
        self.min_sample_split = model_meta.min_sample_split
        self.min_impurity_split = model_meta.min_impurity_split
        self.min_leaf_node = model_meta.min_leaf_node
        self.criterion_method = model_meta.criterion_meta.criterion_method
        self.criterion_params = list(model_meta.criterion_meta.criterion_param)
        self.use_missing = model_meta.use_missing
        self.zero_as_missing = model_meta.zero_as_missing

    def get_model_param(self):
        """Export the tree nodes as a DecisionTreeModelParam.

        Every exported node carries ``self.role`` as its sitename.
        """
        model_param = DecisionTreeModelParam()
        for node in self.tree_node:
            model_param.tree_.add(id=node.id,
                                  sitename=self.role,
                                  fid=node.fid,
                                  bid=node.bid,
                                  weight=node.weight,
                                  is_leaf=node.is_leaf,
                                  left_nodeid=node.left_nodeid,
                                  right_nodeid=node.right_nodeid,
                                  missing_dir=node.missing_dir)
        LOGGER.debug('output tree: epoch_idx:{} tree_idx:{}'.format(self.epoch_idx, self.tree_idx))
        return model_param

    def set_model_param(self, model_param):
        """Rebuild ``self.tree_node`` from the nodes stored in ``model_param``."""
        self.tree_node = [
            Node(id=node_param.id,
                 sitename=node_param.sitename,
                 fid=node_param.fid,
                 bid=node_param.bid,
                 weight=node_param.weight,
                 is_leaf=node_param.is_leaf,
                 left_nodeid=node_param.left_nodeid,
                 right_nodeid=node_param.right_nodeid,
                 missing_dir=node_param.missing_dir)
            for node_param in model_param.tree_
        ]

    def get_model(self):
        """Return ``(model_meta, model_param)`` for this tree."""
        model_meta = self.get_model_meta()
        model_param = self.get_model_param()
        return model_meta, model_param

    def load_model(self, model_meta=None, model_param=None):
        """Load hyper-parameters and tree nodes from the given model objects."""
        LOGGER.info("load tree model")
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)

    # ---- For debug ----

    def print_leafs(self):
        """Log every node stored in the tree."""
        LOGGER.debug('printing tree')
        for node in self.tree_node:
            LOGGER.debug(node)

    @staticmethod
    def print_split(split_infos: List[SplitInfo]):
        """Log every split info in ``split_infos``."""
        LOGGER.debug('printing split info')
        for info in split_infos:
            LOGGER.debug(info)

    @staticmethod
    def print_hist(hist_list: List[HistogramBag]):
        """Log every histogram bag in ``hist_list``."""
        LOGGER.debug('printing histogramBag')
        for bag in hist_list:
            LOGGER.debug(bag)
class HeteroDecisionTreeGuest(DecisionTree): def __init__(self, tree_param): LOGGER.info("hetero decision tree guest init!") super(HeteroDecisionTreeGuest, self).__init__(tree_param) self.splitter = Splitter(self.criterion_method, self.criterion_params, self.min_impurity_split, self.min_sample_split, self.min_leaf_node) self.data_bin = None self.grad_and_hess = None self.bin_split_points = None self.bin_sparse_points = None self.data_bin_with_node_dispatch = None self.node_dispatch = None self.infos = None self.valid_features = None self.encrypter = None self.node_positions = None self.best_splitinfo_guest = None self.tree_node_queue = None self.tree_ = [] self.tree_node_num = 0 self.split_maskdict = {} self.transfer_inst = HeteroDecisionTreeTransferVariable() self.predict_weights = None def set_flowid(self, flowid=0): LOGGER.info("set flowid, flowid is {}".format(flowid)) self.transfer_inst.set_flowid(flowid) def set_inputinfo(self, data_bin=None, grad_and_hess=None, bin_split_points=None, bin_sparse_points=None): LOGGER.info("set input info") self.data_bin = data_bin self.grad_and_hess = grad_and_hess self.bin_split_points = bin_split_points self.bin_sparse_points = bin_sparse_points def set_encrypter(self, encrypter): LOGGER.info("set encrypter") self.encrypter = encrypter def encrypt(self, val): return self.encrypter.encrypt(val) def decrypt(self, val): return self.encrypter.decrypt(val) def encode(self, etype="feature_idx", val=None, nid=None): if etype == "feature_idx": return val if etype == "feature_val": self.split_maskdict[nid] = val return None raise TypeError("encode type %s is not support!" % (str(etype))) @staticmethod def decode(dtype="feature_idx", val=None, nid=None, split_maskdict=None): if dtype == "feature_idx": return val if dtype == "feature_val": if nid in split_maskdict: return split_maskdict[nid] else: raise ValueError( "decode val %s cause error, can't reconize it!" % (str(val))) return TypeError("decode type %s is not support!" 
% (str(dtype))) def set_valid_features(self, valid_features=None): LOGGER.info("set valid features") self.valid_features = valid_features def sync_encrypted_grad_and_hess(self): LOGGER.info("send encrypted grad and hess to host") encrypted_grad_and_hess = self.encrypt_grad_and_hess() federation.remote(obj=encrypted_grad_and_hess, name=self.transfer_inst.encrypted_grad_and_hess.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.encrypted_grad_and_hess), role=consts.HOST, idx=0) def encrypt_grad_and_hess(self): LOGGER.info("start to encrypt grad and hess") encrypter = self.encrypter encrypted_grad_and_hess = self.grad_and_hess.mapValues( lambda grad_hess: (encrypter.encrypt(grad_hess[0]), encrypter.encrypt(grad_hess[1]))) LOGGER.info("finish to encrypt grad and hess") return encrypted_grad_and_hess def get_grad_hess_sum(self, grad_and_hess_table): LOGGER.info("calculate the sum of grad and hess") grad, hess = grad_and_hess_table.reduce(lambda value1, value2: (value1[ 0] + value2[0], value1[1] + value2[1])) return grad, hess def dispatch_all_node_to_root(self, root_id=0): LOGGER.info("dispatch all node to root") self.node_dispatch = self.data_bin.mapValues(lambda data_inst: (1, root_id)) def get_histograms(self, node_map={}): LOGGER.info("start to get node histograms") histograms = FeatureHistogram.calculate_histogram( self.data_bin_with_node_dispatch, self.grad_and_hess, self.bin_split_points, self.bin_sparse_points, self.valid_features, node_map) acc_histograms = FeatureHistogram.accumulate_histogram(histograms) LOGGER.info("acc histogram shape is {}".format(len(acc_histograms))) return acc_histograms def sync_tree_node_queue(self, tree_node_queue, dep=-1): LOGGER.info("send tree node queue of depth {}".format(dep)) mask_tree_node_queue = tree_node_queue.copy() for i in range(len(mask_tree_node_queue)): mask_tree_node_queue[i] = Node(id=mask_tree_node_queue[i].id) federation.remote(obj=mask_tree_node_queue, 
name=self.transfer_inst.tree_node_queue.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.tree_node_queue, dep), role=consts.HOST, idx=0) def sync_node_positions(self, dep): LOGGER.info("send node positions of depth {}".format(dep)) federation.remote(obj=self.node_dispatch, name=self.transfer_inst.node_positions.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.node_positions, dep), role=consts.HOST, idx=0) def sync_encrypted_splitinfo_host(self, dep=-1): LOGGER.info("get encrypted splitinfo of depth {}".format(dep)) encrypted_splitinfo_host = federation.get( name=self.transfer_inst.encrypted_splitinfo_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.encrypted_splitinfo_host, dep), idx=0) return encrypted_splitinfo_host def sync_federated_best_splitinfo_host(self, federated_best_splitinfo_host, dep=-1): LOGGER.info("send federated best splitinfo of depth {}".format(dep)) federation.remote( obj=federated_best_splitinfo_host, name=self.transfer_inst.federated_best_splitinfo_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.federated_best_splitinfo_host, dep), role=consts.HOST, idx=0) def federated_find_split(self, dep=-1): LOGGER.info("federated find split of depth {}".format(dep)) encrypted_splitinfo_host = self.sync_encrypted_splitinfo_host(dep) best_splitinfo_host = [] for i in range(len(encrypted_splitinfo_host)): sum_grad = self.tree_node_queue[i].sum_grad sum_hess = self.tree_node_queue[i].sum_hess best_gain = self.min_impurity_split - consts.FLOAT_ZERO best_idx = -1 for j in range(len(encrypted_splitinfo_host[i])): sum_grad_l, sum_hess_l = encrypted_splitinfo_host[i][j] sum_grad_l = self.decrypt(sum_grad_l) sum_hess_l = self.decrypt(sum_hess_l) sum_grad_r = sum_grad - sum_grad_l sum_hess_r = sum_hess - sum_hess_l gain = self.splitter.split_gain(sum_grad, sum_hess, sum_grad_l, sum_hess_l, sum_grad_r, sum_hess_r) if gain > self.min_impurity_split and gain > best_gain: 
best_gain = gain best_idx = j best_gain = self.encrypt(best_gain) best_splitinfo_host.append([best_idx, best_gain]) self.sync_federated_best_splitinfo_host(best_splitinfo_host, dep) def sync_final_split_host(self, dep=-1): LOGGER.info("get host final splitinfo of depth {}".format(dep)) final_splitinfo_host = federation.get( name=self.transfer_inst.final_splitinfo_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.final_splitinfo_host, dep), idx=0) return final_splitinfo_host def merge_splitinfo(self, splitinfo_guest, splitinfo_host): LOGGER.info("merge splitinfo") splitinfos = [] for i in range(len(splitinfo_guest)): splitinfo = None gain_host = self.decrypt(splitinfo_host[i].gain) if splitinfo_guest[i].gain >= gain_host - consts.FLOAT_ZERO: splitinfo = splitinfo_guest[i] else: splitinfo = splitinfo_host[i] splitinfo.sum_grad = self.decrypt(splitinfo.sum_grad) splitinfo.sum_hess = self.decrypt(splitinfo.sum_hess) splitinfo.gain = gain_host splitinfos.append(splitinfo) return splitinfos def update_tree_node_queue(self, splitinfos, max_depth_reach): LOGGER.info( "update tree node, splitlist length is {}, tree node queue size is" .format(len(splitinfos), len(self.tree_node_queue))) new_tree_node_queue = [] for i in range(len(self.tree_node_queue)): sum_grad = self.tree_node_queue[i].sum_grad sum_hess = self.tree_node_queue[i].sum_hess if max_depth_reach or splitinfos[i].gain <= \ self.min_impurity_split + consts.FLOAT_ZERO: self.tree_node_queue[i].is_leaf = True else: self.tree_node_queue[i].left_nodeid = self.tree_node_num + 1 self.tree_node_queue[i].right_nodeid = self.tree_node_num + 2 self.tree_node_num += 2 left_node = Node(id=self.tree_node_queue[i].left_nodeid, sitename=consts.GUEST, sum_grad=splitinfos[i].sum_grad, sum_hess=splitinfos[i].sum_hess, weight=self.splitter.node_weight( splitinfos[i].sum_grad, splitinfos[i].sum_hess)) right_node = Node(id=self.tree_node_queue[i].right_nodeid, sitename=consts.GUEST, sum_grad=sum_grad - 
splitinfos[i].sum_grad, sum_hess=sum_hess - splitinfos[i].sum_hess, weight=self.splitter.node_weight( \ sum_grad - splitinfos[i].sum_grad, sum_hess - splitinfos[i].sum_hess)) new_tree_node_queue.append(left_node) new_tree_node_queue.append(right_node) LOGGER.info("tree_node_queue {} split!!!".format( self.tree_node_queue[i].id)) self.tree_node_queue[i].sitename = splitinfos[i].sitename if self.tree_node_queue[i].sitename == consts.GUEST: self.tree_node_queue[i].fid = self.encode( "feature_idx", splitinfos[i].best_fid) self.tree_node_queue[i].bid = self.encode( "feature_val", splitinfos[i].best_bid, self.tree_node_queue[i].id) else: self.tree_node_queue[i].fid = splitinfos[i].best_fid self.tree_node_queue[i].bid = splitinfos[i].best_bid self.tree_.append(self.tree_node_queue[i]) self.tree_node_queue = new_tree_node_queue @staticmethod def dispatch_node(value, tree_=None, decoder=None, split_maskdict=None, bin_sparse_points=None): unleaf_state, nodeid = value[1] if unleaf_state == 0: return value[1] if tree_[nodeid].is_leaf is True: return (0, nodeid) else: if tree_[nodeid].sitename == consts.GUEST: fid = decoder("feature_idx", tree_[nodeid].fid, split_maskdict=split_maskdict) bid = decoder("feature_val", tree_[nodeid].bid, nodeid, split_maskdict) if value[0].features.get_data(fid, bin_sparse_points[fid]) <= bid: return (1, tree_[nodeid].left_nodeid) else: return (1, tree_[nodeid].right_nodeid) else: return (1, tree_[nodeid].fid, tree_[nodeid].bid, \ nodeid, tree_[nodeid].left_nodeid, tree_[nodeid].right_nodeid) def sync_dispatch_node_host(self, dispatch_guest_data, dep=-1): LOGGER.info("send node to host to dispath, depth is {}".format(dep)) federation.remote(obj=dispatch_guest_data, name=self.transfer_inst.dispatch_node_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.dispatch_node_host, dep), role=consts.HOST, idx=0) def sync_dispatch_node_host_result(self, dep=-1): LOGGER.info("get host dispatch result, depth is {}".format(dep)) 
dispatch_node_host_result = federation.get( name=self.transfer_inst.dispatch_node_host_result.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.dispatch_node_host_result, dep), idx=0) return dispatch_node_host_result def redispatch_node(self, dep=-1): LOGGER.info("redispatch node of depth {}".format(dep)) dispatch_node_method = functools.partial( self.dispatch_node, tree_=self.tree_, decoder=self.decode, split_maskdict=self.split_maskdict, bin_sparse_points=self.bin_sparse_points) dispatch_guest_result = self.data_bin_with_node_dispatch.mapValues( dispatch_node_method) tree_node_num = self.tree_node_num LOGGER.info("rmask edispatch node result of depth {}".format(dep)) dispatch_node_mask = dispatch_guest_result.mapValues( lambda state_nodeid: (state_nodeid[0], random.randint(0, tree_node_num - 1)) if len(state_nodeid) == 2 else state_nodeid) self.sync_dispatch_node_host(dispatch_node_mask, dep) dispatch_node_host_result = self.sync_dispatch_node_host_result(dep) self.node_dispatch = dispatch_guest_result.join(dispatch_node_host_result, \ lambda unleaf_state_nodeid1, unleaf_state_nodeid2: \ unleaf_state_nodeid1 if len( unleaf_state_nodeid1) == 2 else unleaf_state_nodeid2) def sync_tree(self): LOGGER.info("sync tree to host") federation.remote(obj=self.tree_, name=self.transfer_inst.tree.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.tree), role=consts.HOST, idx=0) def convert_bin_to_real(self): LOGGER.info("convert tree node bins to real value") for i in range(len(self.tree_)): if self.tree_[i].is_leaf is True: continue if self.tree_[i].sitename == consts.GUEST: fid = self.decode("feature_idx", self.tree_[i].fid, split_maskdict=self.split_maskdict) bid = self.decode("feature_val", self.tree_[i].bid, self.tree_[i].id, self.split_maskdict) real_splitval = self.encode("feature_val", self.bin_split_points[fid][bid], self.tree_[i].id) self.tree_[i].bid = real_splitval def fit(self): LOGGER.info("begin to fit guest decision tree") 
self.sync_encrypted_grad_and_hess() root_sum_grad, root_sum_hess = self.get_grad_hess_sum( self.grad_and_hess) root_node = Node(id=0, sitename=consts.GUEST, sum_grad=root_sum_grad, sum_hess=root_sum_hess, weight=self.splitter.node_weight( root_sum_grad, root_sum_hess)) self.tree_node_queue = [root_node] self.dispatch_all_node_to_root() for dep in range(self.max_depth): LOGGER.info( "start to fit depth {}, tree node queue size is {}".format( dep, len(self.tree_node_queue))) self.sync_tree_node_queue(self.tree_node_queue, dep) if len(self.tree_node_queue) == 0: break self.sync_node_positions(dep) node_map = {} node_num = 0 for tree_node in self.tree_node_queue: node_map[tree_node.id] = node_num node_num += 1 self.data_bin_with_node_dispatch = self.data_bin.join( self.node_dispatch, lambda data_inst, dispatch_info: (data_inst, dispatch_info)) acc_histograms = self.get_histograms(node_map=node_map) self.best_splitinfo_guest = self.splitter.find_split( acc_histograms, self.valid_features) self.federated_find_split(dep) final_splitinfo_host = self.sync_final_split_host(dep) splitinfos = self.merge_splitinfo(self.best_splitinfo_guest, final_splitinfo_host) max_depth_reach = True if dep + 1 == self.max_depth else False self.update_tree_node_queue(splitinfos, max_depth_reach) self.redispatch_node(dep) self.sync_tree() self.convert_bin_to_real() tree_ = self.tree_ LOGGER.info("tree node num is %d" % len(tree_)) self.predict_weights = self.node_dispatch.mapValues( lambda unleaf_state_nodeid: tree_[unleaf_state_nodeid[1]].weight) LOGGER.info("end to fit guest decision tree") @staticmethod def traverse_tree(predict_state, data_inst, tree_=None, decoder=None, split_maskdict=None): tag, nid = predict_state if tag == 0: return (tag, nid) while tree_[nid].sitename != consts.HOST: if tree_[nid].is_leaf is True: return (0, nid) fid = decoder("feature_idx", tree_[nid].fid, split_maskdict=split_maskdict) bid = decoder("feature_val", tree_[nid].bid, nid, split_maskdict) if 
data_inst.features.get_data(fid, 0) <= bid: nid = tree_[nid].left_nodeid else: nid = tree_[nid].right_nodeid return (1, nid) def sync_predict_finish_tag(self, finish_tag, send_times): LOGGER.info("send the {}-th predict finish tag {} to host".format( finish_tag, send_times)) federation.remote(obj=finish_tag, name=self.transfer_inst.predict_finish_tag.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.predict_finish_tag, send_times), role=consts.HOST, idx=0) def sync_predict_data(self, predict_data, send_times): LOGGER.info("send predict data to host, sending times is {}".format( send_times)) federation.remote(obj=predict_data, name=self.transfer_inst.predict_data.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.predict_data, send_times), role=consts.HOST, idx=0) def sync_data_predicted_by_host(self, send_times): LOGGER.info( "get predicted data by host, recv times is {}".format(send_times)) predict_data = federation.get( name=self.transfer_inst.predict_data_by_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.predict_data_by_host, send_times), idx=0) return predict_data def predict(self, data_inst): LOGGER.info("start to predict!") predict_data = data_inst.mapValues(lambda data_inst: (1, 0)) site_host_send_times = 0 while True: traverse_tree = functools.partial( self.traverse_tree, tree_=self.tree_, decoder=self.decode, split_maskdict=self.split_maskdict) predict_data = predict_data.join(data_inst, traverse_tree) unleaf_node_count = predict_data.reduce( lambda value1, value2: (value1[0] + value2[0], 0))[0] if unleaf_node_count == 0: self.sync_predict_finish_tag(True, site_host_send_times) break predict_data_mask = predict_data.mapValues(lambda state_nodeid: ( state_nodeid[0], random.randint(0, len(self.tree_) - 1) ) if state_nodeid[0] == 0 else state_nodeid) self.sync_predict_finish_tag(False, site_host_send_times) self.sync_predict_data(predict_data_mask, site_host_send_times) predict_data_host = 
self.sync_data_predicted_by_host( site_host_send_times) predict_data = predict_data.join(predict_data_host, \ lambda unleaf_state1_nodeid1, unleaf_state2_nodeid2: \ unleaf_state1_nodeid1 if unleaf_state1_nodeid1[ 0] == 0 else unleaf_state2_nodeid2) site_host_send_times += 1 predict_data = predict_data.mapValues( lambda tag_nid: self.tree_[tag_nid[1]].weight) LOGGER.info("predict finish!") return predict_data def get_tree_model(self): LOGGER.info("get tree model") tree_model = DecisionTreeModelMeta() tree_model.tree_ = self.tree_ tree_model.split_maskdict = self.split_maskdict return tree_model def set_tree_model(self, tree_model): LOGGER.info("set tree model") self.tree_ = tree_model.tree_ self.split_maskdict = tree_model.split_maskdict