def update_tree(self, split_info, reach_max_depth):
    """
    Grow the guest-side tree by one layer.

    Each node in ``self.cur_layer_nodes`` is handled in one of two ways:
    - It is marked as a leaf.
    - It receives the split described by the matching ``split_info`` entry,
      plus a left and a right child node.

    The children become the new ``self.cur_layer_nodes``. Every node of the
    processed layer is appended to ``self.tree_node``.

    Args:
        split_info: best-split records aligned index-by-index with
            ``self.cur_layer_nodes``. An entry is only read when
            ``reach_max_depth`` is False.
        reach_max_depth: when True, every node of the layer becomes a leaf.
    """
    # Both counts are logged; the original format string was missing the
    # second "{}" placeholder, so the queue size was silently dropped.
    LOGGER.info(
        "update tree node, splitlist length is {}, tree node queue size is {}"
        .format(len(split_info), len(self.cur_layer_nodes)))
    new_tree_node_queue = []
    for i, node in enumerate(self.cur_layer_nodes):
        # Short-circuit: split_info[i] is not touched once max depth is reached.
        if reach_max_depth or split_info[i].gain <= \
                self.min_impurity_split + consts.FLOAT_ZERO:
            # At max depth, or the best gain does not exceed
            # min_impurity_split (within FLOAT_ZERO): keep the node as a leaf.
            node.is_leaf = True
        else:
            split = split_info[i]
            sum_grad = node.sum_grad
            sum_hess = node.sum_hess
            pid = node.id
            node.left_nodeid = self.tree_node_num + 1
            node.right_nodeid = self.tree_node_num + 2
            self.tree_node_num += 2

            # The left child takes the sums recorded in the split; the right
            # child gets the remainder of the parent's totals.
            right_grad = sum_grad - split.sum_grad
            right_hess = sum_hess - split.sum_hess
            left_node = Node(id=node.left_nodeid,
                             sitename=self.sitename,
                             sum_grad=split.sum_grad,
                             sum_hess=split.sum_hess,
                             weight=self.splitter.node_weight(
                                 split.sum_grad, split.sum_hess),
                             is_left_node=True,
                             parent_nodeid=pid)
            right_node = Node(id=node.right_nodeid,
                              sitename=self.sitename,
                              sum_grad=right_grad,
                              sum_hess=right_hess,
                              weight=self.splitter.node_weight(
                                  right_grad, right_hess),
                              is_left_node=False,
                              parent_nodeid=pid)
            new_tree_node_queue.append(left_node)
            new_tree_node_queue.append(right_node)

            node.sitename = split.sitename
            if node.sitename == self.sitename:
                # The split belongs to this party: store fid/bid/missing_dir
                # through self.encode.
                node.fid = self.encode("feature_idx", split.best_fid)
                node.bid = self.encode("feature_val", split.best_bid, node.id)
                node.missing_dir = self.encode(
                    "missing_dir", split.missing_dir, node.id)
            else:
                # The split belongs to another party: keep fid/bid as received.
                node.fid = split.best_fid
                node.bid = split.best_bid
            self.update_feature_importance(split)
        self.tree_node.append(node)
    self.cur_layer_nodes = new_tree_node_queue
def update_host_side_tree(self, split_info, reach_max_depth):
    """
    Apply one layer of split results to the host's copy of the tree.

    Each node of ``self.cur_layer_nodes`` is handled in one of two ways:
    - It is turned into a leaf.
    - It is given a split (site name, fid, bid, missing direction) and two
      child nodes.

    The children form the next ``self.cur_layer_nodes``. Every visited node
    is appended to ``self.tree_node``.

    Args:
        split_info: split records aligned by index with
            ``self.cur_layer_nodes``. They are not read when
            ``reach_max_depth`` is True, so an empty list may be passed then.
        reach_max_depth: when True, the whole layer becomes leaves.
    """
    LOGGER.info(
        "update tree node, splitlist length is {}, tree node queue size is {}"
        .format(len(split_info), len(self.cur_layer_nodes)))
    next_layer = []
    for pos, node in enumerate(self.cur_layer_nodes):
        parent_grad = node.sum_grad
        parent_hess = node.sum_hess

        # A best_fid of -1 marks a node the host cannot split any further.
        if reach_max_depth or split_info[pos].best_fid == -1:
            node.is_leaf = True
            self.tree_node.append(node)
            continue

        split = split_info[pos]
        node.left_nodeid = self.tree_node_num + 1
        node.right_nodeid = self.tree_node_num + 2
        self.tree_node_num += 2

        left_child = Node(id=node.left_nodeid,
                          sitename=self.sitename,
                          sum_grad=split.sum_grad,
                          sum_hess=split.sum_hess)
        # The right child carries whatever part of the parent's sums the left
        # child does not.
        right_child = Node(id=node.right_nodeid,
                           sitename=self.sitename,
                           sum_grad=parent_grad - split.sum_grad,
                           sum_hess=parent_hess - split.sum_hess)
        next_layer.extend((left_child, right_child))

        node.sitename = split.sitename
        node.fid = split.best_fid
        node.bid = split.best_bid
        node.missing_dir = split.missing_dir
        if self.feature_importance_type == 'split':
            self.update_feature_importance(split, record_site_name=False)
        self.tree_node.append(node)

    self.cur_layer_nodes = next_layer
def sync_cur_to_split_nodes(self, cur_to_split_node, dep=-1, idx=-1):
    """
    Send the host side an id-only view of the nodes about to be split.

    Each node is replaced by a bare ``Node`` that carries only its ``id``.
    None of the node's other attributes are included in what is sent.

    Args:
        cur_to_split_node: nodes that will be split in this round.
        dep: current tree depth. Used in the log line and as the transfer
            suffix.
        idx: forwarded unchanged as ``idx`` to the remote call. Presumably it
            selects the receiving host (-1 by default); verify against the
            transfer variable's ``remote`` API.
    """
    LOGGER.info("send tree node queue of depth {}".format(dep))
    # Build the id-only nodes directly. Previously the full node objects were
    # deep-copied first, and each copy was immediately discarded in favour of
    # Node(id=...), so the deep copy was wasted work.
    mask_tree_node_queue = [Node(id=node.id) for node in cur_to_split_node]
    self.transfer_inst.tree_node_queue.remote(mask_tree_node_queue,
                                              role=consts.HOST,
                                              idx=idx,
                                              suffix=(dep,))
def initialize_root_node(self):
    """
    Create the root node (id 0) of a new tree.

    The root's gradient/hessian sums are whatever
    ``self.get_grad_hess_sum(self.grad_and_hess)`` returns. Its weight is
    derived from those sums with ``self.splitter.node_weight``.

    NOTE(review): another ``initialize_root_node`` definition immediately
    follows this one in the file. If both live in the same class, the later
    one wins and this one is dead code. Confirm.

    Returns:
        Node: the newly built root node.
    """
    total_grad, total_hess = self.get_grad_hess_sum(self.grad_and_hess)
    root_weight = self.splitter.node_weight(total_grad, total_hess)
    return Node(id=0,
                sitename=self.sitename,
                sum_grad=total_grad,
                sum_hess=total_hess,
                weight=root_weight)
def initialize_root_node(self):
    """
    Build and return the root node (id 0) of the tree being grown.

    Logs the start. The gradient/hessian totals come from
    ``self.get_grad_hess_sum(self.grad_and_hess)``, and the root weight is
    derived from them with ``self.splitter.node_weight``.

    Returns:
        Node: root node carrying this party's site name, the totals and the
        computed weight.
    """
    LOGGER.info('initializing root node')
    grad_total, hess_total = self.get_grad_hess_sum(self.grad_and_hess)
    return Node(
        id=0,
        sitename=self.sitename,
        sum_grad=grad_total,
        sum_hess=hess_total,
        weight=self.splitter.node_weight(grad_total, hess_total),
    )
def mix_mode_fit(self):
    """
    Grow one tree in mix mode using only this host's local features.

    Returns early in two cases:
    - the tree type is ``guest_feat_only``;
    - this host is not the selected one
      (``self.self_host_id != self.target_host_id``).

    Otherwise it does the following:
    1. Syncs the encrypted grad/hess and places every instance on root node 0.
    2. Grows the tree layer by layer for up to ``self.max_depth`` layers,
       computing best splits in batches of ``self.max_split_nodes``.
    3. Marks any remaining layer as leaves.
    4. Converts bins to real values and syncs the leaves and sample leaf
       positions.
    """
    LOGGER.info('running mix mode')

    if self.tree_type == plan.tree_type_dict['guest_feat_only']:
        LOGGER.debug('this tree uses guest feature only, skip')
        return
    if self.self_host_id != self.target_host_id:
        LOGGER.debug('not selected host, skip')
        return

    LOGGER.debug('use local host feature to build tree')

    self.sync_encrypted_grad_and_hess()
    root_grad, root_hess = self.get_grad_hess_sum(self.grad_and_hess)
    # Every instance starts on the root node, whose id is 0.
    self.inst2node_idx = self.assign_instance_to_root_node(self.data_bin,
                                                           root_node_id=0)
    self.cur_layer_nodes = [Node(id=0,
                                 sitename=self.sitename,
                                 sum_grad=root_grad,
                                 sum_hess=root_hess)]

    for dep in range(self.max_depth):
        tree_action, layer_target_host_id = self.get_node_plan(dep)

        self.sync_cur_layer_nodes(self.cur_layer_nodes, dep)
        if len(self.cur_layer_nodes) == 0:
            break

        self.update_instances_node_positions()

        layer_split_info = []
        batch_starts = range(0, len(self.cur_layer_nodes), self.max_split_nodes)
        for batch_idx, start in enumerate(batch_starts):
            self.cur_to_split_nodes = \
                self.cur_layer_nodes[start: start + self.max_split_nodes]
            node_map = self.get_node_map(self.cur_to_split_nodes)
            layer_split_info.extend(
                self.compute_best_splits_with_node_plan(
                    tree_action, layer_target_host_id,
                    node_map=node_map,
                    dep=dep,
                    batch_idx=batch_idx,
                    mode=consts.MIX_TREE))

        self.update_host_side_tree(layer_split_info, reach_max_depth=False)
        self.inst2node_idx = self.host_local_assign_instances_to_new_node()

    if self.cur_layer_nodes:
        # Nodes still pending after the loop become the final leaves.
        self.update_host_side_tree([], reach_max_depth=True)
        self.update_instances_node_positions()
        # NOTE(review): unlike inside the loop, this return value is not
        # stored in self.inst2node_idx. Confirm this is intentional.
        self.host_local_assign_instances_to_new_node()

    self.convert_bin_to_real2()  # convert bin num to val
    self.sync_leaf_nodes()  # send leaf nodes to guest
    self.process_leaves_info()  # remove encrypted g/h
    self.sync_sample_leaf_pos(self.sample_leaf_pos)  # sync final leaf positions
def handle_leaf_nodes(self, nodes):
    """
    Decrypt the grad/hess sums of the given leaf nodes and lay them out by id.

    Each node in ``nodes`` is updated in place:
    - ``sum_hess`` and ``sum_grad`` are replaced by ``self.decrypt`` of
      their current values;
    - ``weight`` is recomputed from them with ``self.splitter.node_weight``;
    - ``sitename`` is set to ``self.sitename``.

    Returns:
        list: a list whose length is one more than the largest node id (empty
        when ``nodes`` is empty). Position ``k`` holds the node with id ``k``.
        Positions with no matching node hold a default-constructed ``Node()``
        placeholder.
    """
    highest_id = -1
    for leaf in nodes:
        leaf.sum_hess = self.decrypt(leaf.sum_hess)
        leaf.sum_grad = self.decrypt(leaf.sum_grad)
        leaf.weight = self.splitter.node_weight(leaf.sum_grad, leaf.sum_hess)
        leaf.sitename = self.sitename
        highest_id = max(highest_id, leaf.id)

    by_id = [Node() for _ in range(highest_id + 1)]
    for leaf in nodes:
        by_id[leaf.id] = leaf
    return by_id
def set_model_param(self, model_param):
    """
    Restore the tree and its mask dictionaries from a model parameter object.

    ``self.tree_node`` is rebuilt from ``model_param.tree_``, one ``Node`` per
    entry. Each node copies id, sitename, fid, bid, weight, is_leaf,
    left_nodeid, right_nodeid and missing_dir. ``split_maskdict`` and
    ``missing_dir_maskdict`` are copied into plain ``dict`` objects.
    """
    self.tree_node = []
    self.tree_node.extend(
        Node(id=param.id,
             sitename=param.sitename,
             fid=param.fid,
             bid=param.bid,
             weight=param.weight,
             is_leaf=param.is_leaf,
             left_nodeid=param.left_nodeid,
             right_nodeid=param.right_nodeid,
             missing_dir=param.missing_dir)
        for param in model_param.tree_)
    self.split_maskdict = dict(model_param.split_maskdict)
    self.missing_dir_maskdict = dict(model_param.missing_dir_maskdict)