def update_tree(self, cur_to_split: List[Node], split_info: List[SplitInfo]):
    """
    Apply the split decisions of the current layer to the tree.

    Nodes without a usable split (no best feature, or gain not above
    ``self.min_impurity_split``) are marked as leaves. Every other node gets
    a left and a right child. Every node of ``cur_to_split`` is appended to
    ``self.tree_node``.

    Parameters
    ----------
    cur_to_split: list of Node, nodes of the current layer
    split_info: list of SplitInfo, best split per node, aligned with ``cur_to_split``

    Returns
    -------
    list of Node: the newly created child nodes, which form the next layer
    """
    LOGGER.debug('updating tree_node, cur layer has {} node'.format(len(cur_to_split)))
    next_layer_node = []
    assert len(cur_to_split) == len(split_info)

    for node, info in zip(cur_to_split, split_info):
        # no valid split, or not enough gain: the node becomes a leaf
        if info.best_fid is None or info.gain <= self.min_impurity_split + consts.FLOAT_ZERO:
            node.is_leaf = True
            self.tree_node.append(node)
            continue

        sum_grad, sum_hess = node.sum_grad, node.sum_hess
        node.fid = info.best_fid
        node.bid = info.best_bid
        node.missing_dir = info.missing_dir

        p_id = node.id
        l_id, r_id = self.tree_node_num + 1, self.tree_node_num + 2
        node.left_nodeid, node.right_nodeid = l_id, r_id
        self.tree_node_num += 2

        # the split info carries the left child's sums; the right child gets the remainder
        l_g, l_h = info.sum_grad, info.sum_hess
        r_g, r_h = sum_grad - l_g, sum_hess - l_h

        left_node = Node(id=l_id, sitename=self.sitename, sum_grad=l_g, sum_hess=l_h,
                         weight=self.splitter.node_weight(l_g, l_h),
                         parent_nodeid=p_id, sibling_nodeid=r_id, is_left_node=True)
        right_node = Node(id=r_id, sitename=self.sitename, sum_grad=r_g, sum_hess=r_h,
                          weight=self.splitter.node_weight(r_g, r_h),
                          parent_nodeid=p_id, sibling_nodeid=l_id, is_left_node=False)
        next_layer_node.append(left_node)
        next_layer_node.append(right_node)
        LOGGER.debug('split node {} into left {} and right {}, cur tree has {} node'.format(
            p_id, l_id, r_id, len(self.tree_node)))

        self.tree_node.append(node)

    return next_layer_node
def update_tree_node_queue(self, splitinfos, max_depth_reach):
    """
    Apply the split decisions of the current layer and build the next node queue.

    Every node of ``self.tree_node_queue`` is appended to ``self.tree_``. A node
    becomes a leaf when ``max_depth_reach`` is set or its split gain does not
    exceed ``self.min_impurity_split``. Otherwise it gets two child nodes, which
    are queued for the next layer, and its split fields are filled in.

    Parameters
    ----------
    splitinfos: list, best split per node, aligned with ``self.tree_node_queue``
    max_depth_reach: bool, True when the current layer is the last one
    """
    LOGGER.info(
        "update tree node, splitlist length is {}, tree node queue size is {}"
        .format(len(splitinfos), len(self.tree_node_queue)))
    new_tree_node_queue = []
    for i, node in enumerate(self.tree_node_queue):
        if max_depth_reach or splitinfos[i].gain <= \
                self.min_impurity_split + consts.FLOAT_ZERO:
            node.is_leaf = True
        else:
            info = splitinfos[i]
            sum_grad, sum_hess = node.sum_grad, node.sum_hess
            node.left_nodeid = self.tree_node_num + 1
            node.right_nodeid = self.tree_node_num + 2
            self.tree_node_num += 2

            # the split info carries the left child's sums; the right child gets the remainder
            right_sum_grad = sum_grad - info.sum_grad
            right_sum_hess = sum_hess - info.sum_hess
            left_node = Node(id=node.left_nodeid,
                             sitename=self.sitename,
                             sum_grad=info.sum_grad,
                             sum_hess=info.sum_hess,
                             weight=self.splitter.node_weight(info.sum_grad, info.sum_hess))
            right_node = Node(id=node.right_nodeid,
                              sitename=self.sitename,
                              sum_grad=right_sum_grad,
                              sum_hess=right_sum_hess,
                              weight=self.splitter.node_weight(right_sum_grad, right_sum_hess))
            new_tree_node_queue.append(left_node)
            new_tree_node_queue.append(right_node)

            node.sitename = info.sitename
            if node.sitename == self.sitename:
                # the split's sitename matches this party: store values through self.encode
                node.fid = self.encode("feature_idx", info.best_fid)
                node.bid = self.encode("feature_val", info.best_bid, node.id)
                node.missing_dir = self.encode("missing_dir", info.missing_dir, node.id)
            else:
                node.fid = info.best_fid
                node.bid = info.best_bid

            self.update_feature_importance(info)

        self.tree_.append(node)

    self.tree_node_queue = new_tree_node_queue
def test_node(self):
    """Node exposes every constructor keyword argument as an attribute with the same value."""
    param_dict = {
        "id": 5,
        "sitename": "test",
        "fid": 55,
        "bid": 555,
        "weight": -1,
        "is_leaf": True,
        "sum_grad": 2,
        "sum_hess": 3,
        "left_nodeid": 6,
        "right_nodeid": 7,
    }
    # build the node from the same dict that holds the expectations, so the two cannot drift apart
    node = Node(**param_dict)
    for key, expected in param_dict.items():
        # assertEqual reports both values on failure, unlike assertTrue(a == b)
        self.assertEqual(expected, getattr(node, key), msg=key)
def fit(self):
    """
    Fit the guest decision tree layer by layer.

    For each depth: sync the node queue, compute histograms and guest best
    splits in batches of at most ``self.max_split_nodes`` nodes, merge them
    with the host splits, then grow the tree through
    ``update_tree_node_queue``. Stops early once the node queue is empty.
    """
    LOGGER.info("begin to fit guest decision tree")
    self.sync_encrypted_grad_and_hess()

    root_sum_grad, root_sum_hess = self.get_grad_hess_sum(self.grad_and_hess)
    root_node = Node(id=0, sitename=self.sitename, sum_grad=root_sum_grad, sum_hess=root_sum_hess,
                     weight=self.splitter.node_weight(root_sum_grad, root_sum_hess))
    self.tree_node_queue = [root_node]
    self.dispatch_all_node_to_root()

    for dep in range(self.max_depth):
        LOGGER.info("start to fit depth {}, tree node queue size is {}".format(dep, len(self.tree_node_queue)))
        self.sync_tree_node_queue(self.tree_node_queue, dep)
        if not self.tree_node_queue:
            break

        self.sync_node_positions(dep)
        self.data_bin_with_node_dispatch = self.data_bin.join(
            self.node_dispatch, lambda data_inst, dispatch_info: (data_inst, dispatch_info))

        splitinfos = []
        for batch, start in enumerate(range(0, len(self.tree_node_queue), self.max_split_nodes)):
            self.cur_split_nodes = self.tree_node_queue[start: start + self.max_split_nodes]
            # position of each node id inside the current batch
            node_map = {tree_node.id: pos for pos, tree_node in enumerate(self.cur_split_nodes)}

            acc_histograms = self.get_histograms(node_map=node_map)
            self.best_splitinfo_guest = self.splitter.find_split(acc_histograms, self.valid_features,
                                                                 self.data_bin._partitions, self.sitename,
                                                                 self.use_missing, self.zero_as_missing)
            self.federated_find_split(dep, batch)
            final_splitinfo_host = self.sync_final_split_host(dep, batch)
            cur_splitinfos = self.merge_splitinfo(self.best_splitinfo_guest, final_splitinfo_host)
            splitinfos.extend(cur_splitinfos)

        max_depth_reach = dep + 1 == self.max_depth
        self.update_tree_node_queue(splitinfos, max_depth_reach)
        self.redispatch_node(dep)

    self.sync_tree()
    self.convert_bin_to_real()
    LOGGER.info("tree node num is %d" % len(self.tree_))
    LOGGER.info("end to fit guest decision tree")
def fit(self):
    """
    Fit the guest decision tree layer by layer, one histogram pass per depth.

    After fitting, ``self.predict_weights`` is derived from
    ``self.node_dispatch`` by mapping each entry to a node weight of the
    fitted tree.
    """
    LOGGER.info("begin to fit guest decision tree")
    self.sync_encrypted_grad_and_hess()

    root_sum_grad, root_sum_hess = self.get_grad_hess_sum(self.grad_and_hess)
    root_node = Node(id=0, sitename=consts.GUEST, sum_grad=root_sum_grad, sum_hess=root_sum_hess,
                     weight=self.splitter.node_weight(root_sum_grad, root_sum_hess))
    self.tree_node_queue = [root_node]
    self.dispatch_all_node_to_root()

    for dep in range(self.max_depth):
        LOGGER.info("start to fit depth {}, tree node queue size is {}".format(dep, len(self.tree_node_queue)))
        self.sync_tree_node_queue(self.tree_node_queue, dep)
        if not self.tree_node_queue:
            break

        self.sync_node_positions(dep)
        # position of each node id inside the current queue
        node_map = {tree_node.id: pos for pos, tree_node in enumerate(self.tree_node_queue)}
        self.data_bin_with_node_dispatch = self.data_bin.join(
            self.node_dispatch, lambda data_inst, dispatch_info: (data_inst, dispatch_info))

        acc_histograms = self.get_histograms(node_map=node_map)
        self.best_splitinfo_guest = self.splitter.find_split(acc_histograms, self.valid_features)
        self.federated_find_split(dep)
        final_splitinfo_host = self.sync_final_split_host(dep)
        splitinfos = self.merge_splitinfo(self.best_splitinfo_guest, final_splitinfo_host)

        max_depth_reach = dep + 1 == self.max_depth
        self.update_tree_node_queue(splitinfos, max_depth_reach)
        self.redispatch_node(dep)

    self.sync_tree()
    self.convert_bin_to_real()
    tree_ = self.tree_
    LOGGER.info("tree node num is %d" % len(tree_))
    # NOTE(review): assumes the second element of each node_dispatch value is an
    # index into self.tree_ (i.e. node ids equal list positions) -- confirm
    self.predict_weights = self.node_dispatch.mapValues(
        lambda unleaf_state_nodeid: tree_[unleaf_state_nodeid[1]].weight)
    LOGGER.info("end to fit guest decision tree")
def sync_tree_node_queue(self, tree_node_queue, dep=-1):
    """
    Send a masked copy of the current node queue to the HOST role (idx=-1).

    Each node is replaced by a fresh ``Node`` built from its id alone, so the
    other fields of the nodes in ``tree_node_queue`` are not sent. The input
    list is not modified.

    Parameters
    ----------
    tree_node_queue: list of Node, nodes of the current layer
    dep: int, current depth, used as the transfer suffix
    """
    LOGGER.info("send tree node queue of depth {}".format(dep))
    # No deep copy is needed: every element is replaced by an id-only Node, so
    # copying the original nodes (and their sums) would only be wasted work.
    mask_tree_node_queue = [Node(id=node.id) for node in tree_node_queue]
    self.transfer_inst.tree_node_queue.remote(mask_tree_node_queue,
                                              role=consts.HOST,
                                              idx=-1,
                                              suffix=(dep, ))
def sync_tree_node_queue(self, tree_node_queue, dep=-1):
    """
    Send a masked copy of the current node queue to the HOST role (idx=0).

    Each node is replaced by a fresh ``Node`` built from its id alone, so the
    other fields of the nodes in ``tree_node_queue`` are not sent. The input
    list is not modified.

    Parameters
    ----------
    tree_node_queue: list of Node, nodes of the current layer
    dep: int, current depth, used to generate the transfer tag
    """
    LOGGER.info("send tree node queue of depth {}".format(dep))
    mask_tree_node_queue = [Node(id=node.id) for node in tree_node_queue]
    federation.remote(obj=mask_tree_node_queue,
                      name=self.transfer_inst.tree_node_queue.name,
                      tag=self.transfer_inst.generate_transferid(self.transfer_inst.tree_node_queue, dep),
                      role=consts.HOST,
                      idx=0)
def set_model_param(self, model_param):
    """
    Rebuild ``self.tree_node`` from the serialized nodes in ``model_param.tree_``.

    The list is reset first and filled in the order the serialized nodes
    appear. Each node restores its id, sitename, split fields (fid, bid,
    missing_dir), weight, leaf flag and child ids.
    """
    def _to_node(serialized):
        # one serialized node -> Node
        return Node(id=serialized.id,
                    sitename=serialized.sitename,
                    fid=serialized.fid,
                    bid=serialized.bid,
                    weight=serialized.weight,
                    is_leaf=serialized.is_leaf,
                    left_nodeid=serialized.left_nodeid,
                    right_nodeid=serialized.right_nodeid,
                    missing_dir=serialized.missing_dir)

    self.tree_node = []
    for serialized in model_param.tree_:
        self.tree_node.append(_to_node(serialized))
def set_model_param(self, model_param):
    """
    Rebuild ``self.tree_`` and ``self.split_maskdict`` from a serialized model.

    ``self.tree_`` is reset and filled in the order of ``model_param.tree_``.
    ``self.split_maskdict`` becomes a plain-dict copy of
    ``model_param.split_maskdict``.
    """
    # NOTE(review): missing_dir is not passed to Node here; confirm whether
    # model_param nodes carry it in this model version.
    def _to_node(serialized):
        # one serialized node -> Node
        return Node(id=serialized.id,
                    sitename=serialized.sitename,
                    fid=serialized.fid,
                    bid=serialized.bid,
                    weight=serialized.weight,
                    is_leaf=serialized.is_leaf,
                    left_nodeid=serialized.left_nodeid,
                    right_nodeid=serialized.right_nodeid)

    self.tree_ = []
    for serialized in model_param.tree_:
        self.tree_.append(_to_node(serialized))

    self.split_maskdict = dict(model_param.split_maskdict)
def fit(self):
    """
    Fit one h**o decision tree layer by layer.

    The root node's gradient/hessian sums are aggregated through
    ``self.aggregator``. Then, for each depth, local histograms are computed
    and sent batch by batch (at most ``self.max_split_nodes`` nodes per batch).
    The best splits received back are used to grow the next layer. On the last
    layer, all remaining nodes become leaves. Leaf values are accumulated into
    ``self.sample_weights``.
    """
    LOGGER.info('begin to fit h**o decision tree, epoch {}, tree idx {}'.format(self.epoch_idx, self.tree_idx))

    # compute local g_sum and h_sum
    g_sum, h_sum = self.get_grad_hess_sum(self.g_h)

    # get aggregated root info
    self.aggregator.send_local_root_node_info(g_sum, h_sum, suffix=('root_node_sync1', self.epoch_idx))
    g_h_dict = self.aggregator.get_aggregated_root_info(suffix=('root_node_sync2', self.epoch_idx))
    global_g_sum, global_h_sum = g_h_dict['g_sum'], g_h_dict['h_sum']

    # initialize node
    root_node = Node(id=0, sitename=consts.GUEST, sum_grad=global_g_sum, sum_hess=global_h_sum,
                     weight=self.splitter.node_weight(global_g_sum, global_h_sum))

    self.cur_layer_node = [root_node]
    LOGGER.debug('assign samples to root node')
    self.inst2node_idx = self.assign_instance_to_root_node(self.data_bin, 0)

    for dep in range(self.max_depth):

        if dep + 1 == self.max_depth:
            # last layer: every remaining node becomes a leaf
            for node in self.cur_layer_node:
                node.is_leaf = True
                self.tree_node.append(node)
            rest_sample_weights = self.get_node_sample_weights(self.inst2node_idx, self.tree_node)
            if self.sample_weights is None:
                self.sample_weights = rest_sample_weights
            else:
                self.sample_weights = self.sample_weights.union(rest_sample_weights)

            # stop fitting
            break

        LOGGER.debug('start to fit layer {}'.format(dep))

        table_with_assignment = self.data_bin.join(self.inst2node_idx, lambda inst, assignment: (inst, assignment))

        # send current layer node number:
        self.sync_cur_layer_node_num(len(self.cur_layer_node), suffix=(dep, self.epoch_idx, self.tree_idx))

        for batch_id, i in enumerate(range(0, len(self.cur_layer_node), self.max_split_nodes)):
            cur_to_split = self.cur_layer_node[i:i + self.max_split_nodes]

            node_map = self.get_node_map(nodes=cur_to_split)
            LOGGER.debug('node map is {}'.format(node_map))
            LOGGER.debug('computing histogram for batch{} at depth{}'.format(batch_id, dep))
            local_histogram = self.get_left_node_local_histogram(
                cur_nodes=cur_to_split,
                tree=self.tree_node,
                g_h=self.g_h,
                table_with_assign=table_with_assignment,
                split_points=self.bin_split_points,
                sparse_point=self.bin_sparse_points,
                valid_feature=self.valid_features
            )

            LOGGER.debug('federated finding best splits for batch{} at layer {}'.format(batch_id, dep))
            self.sync_local_node_histogram(local_histogram, suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))

        # NOTE(review): this suffix (and the root-sync suffixes above) omit
        # self.tree_idx, while sync_cur_layer_node_num / sync_local_node_histogram
        # include it. The receiving side must use the same suffix, so confirm
        # against the counterpart before changing it.
        split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
        LOGGER.debug('got best splits from arbiter')

        new_layer_node = self.update_tree(self.cur_layer_node, split_info)
        self.cur_layer_node = new_layer_node
        self.update_feature_importance(split_info)

        self.inst2node_idx, leaf_val = self.assign_instance_to_new_node(table_with_assignment, self.tree_node)

        # record leaf val
        if self.sample_weights is None:
            self.sample_weights = leaf_val
        else:
            self.sample_weights = self.sample_weights.union(leaf_val)

        LOGGER.debug('assigning instance to new nodes done')

    self.convert_bin_to_val()
    LOGGER.debug('fitting tree done')
    LOGGER.debug('tree node num is {}'.format(len(self.tree_node)))