def update_tree(self, split_info, reach_max_depth):
        """Apply one layer of split decisions to ``self.cur_layer_nodes``.

        Each node in the current layer either becomes a leaf or is split into
        a new left/right child pair:

        * it becomes a leaf when ``reach_max_depth`` is true, or when its
          split gain is not above ``self.min_impurity_split`` (with a
          ``consts.FLOAT_ZERO`` tolerance);
        * otherwise two child ids are allocated from ``self.tree_node_num``,
          the children get their grad/hess sums and weights, and the parent
          records the split. When the split belongs to this site
          (``sitename == self.sitename``), fid/bid/missing_dir are stored
          via ``self.encode``; otherwise fid/bid are stored as received.

        Every current-layer node is appended to ``self.tree_node``. The newly
        created children become the next ``self.cur_layer_nodes``.

        Args:
            split_info: per-node split results, aligned with
                ``self.cur_layer_nodes``. Not read when ``reach_max_depth``
                is true, so it may be empty in that case.
            reach_max_depth: if true, every current node is turned into a leaf.
        """
        # Both counts need a placeholder; previously the queue size was dropped.
        LOGGER.info(
            "update tree node, splitlist length is {}, tree node queue size is {}"
            .format(len(split_info), len(self.cur_layer_nodes)))
        new_tree_node_queue = []
        for i, node in enumerate(self.cur_layer_nodes):
            sum_grad = node.sum_grad
            sum_hess = node.sum_hess
            # If max depth is reached, or the split gain is too small, only
            # convert the node to a leaf (split_info[i] is not touched when
            # reach_max_depth is true thanks to short-circuiting).
            if reach_max_depth or split_info[i].gain <= \
                    self.min_impurity_split + consts.FLOAT_ZERO:
                node.is_leaf = True
            else:
                info = split_info[i]
                pid = node.id
                node.left_nodeid = self.tree_node_num + 1
                node.right_nodeid = self.tree_node_num + 2
                self.tree_node_num += 2

                left_grad = info.sum_grad
                left_hess = info.sum_hess
                left_node = Node(id=node.left_nodeid,
                                 sitename=self.sitename,
                                 sum_grad=left_grad,
                                 sum_hess=left_hess,
                                 weight=self.splitter.node_weight(
                                     left_grad, left_hess),
                                 is_left_node=True,
                                 parent_nodeid=pid)

                # The right child holds the parent's sums minus the left
                # child's. Compute them once; they may be encrypted values.
                right_grad = sum_grad - left_grad
                right_hess = sum_hess - left_hess
                right_node = Node(id=node.right_nodeid,
                                  sitename=self.sitename,
                                  sum_grad=right_grad,
                                  sum_hess=right_hess,
                                  weight=self.splitter.node_weight(
                                      right_grad, right_hess),
                                  is_left_node=False,
                                  parent_nodeid=pid)

                new_tree_node_queue.append(left_node)
                new_tree_node_queue.append(right_node)

                node.sitename = info.sitename
                if node.sitename == self.sitename:
                    # Split found on this site: store encoded split values.
                    node.fid = self.encode("feature_idx", info.best_fid)
                    node.bid = self.encode("feature_val", info.best_bid,
                                           node.id)
                    node.missing_dir = self.encode("missing_dir",
                                                   info.missing_dir, node.id)
                else:
                    # Split found on another site: keep fid/bid as received.
                    node.fid = info.best_fid
                    node.bid = info.best_bid

                self.update_feature_importance(info)

            self.tree_node.append(node)

        self.cur_layer_nodes = new_tree_node_queue
Пример #2
0
    def update_host_side_tree(self, split_info, reach_max_depth):
        """Apply one layer of host-side split results to ``self.cur_layer_nodes``.

        A node becomes a leaf when ``reach_max_depth`` is true, or when its
        split has ``best_fid == -1`` (the host could not split it further).
        Otherwise the node gets two children whose ids come from
        ``self.tree_node_num``. The node then records the split's sitename,
        fid, bid and missing_dir exactly as given; no encoding happens here.
        When ``self.feature_importance_type`` is ``'split'``, feature
        importance is updated without recording the site name.

        All current nodes are appended to ``self.tree_node``. The new children
        become the next ``self.cur_layer_nodes``.

        Args:
            split_info: per-node split results aligned with
                ``self.cur_layer_nodes``; not read when ``reach_max_depth``
                is true.
            reach_max_depth: if true, every current node is turned into a leaf.
        """
        LOGGER.info(
            "update tree node, splitlist length is {}, tree node queue size is {}"
            .format(len(split_info), len(self.cur_layer_nodes)))

        next_layer = []
        for pos, node in enumerate(self.cur_layer_nodes):
            parent_grad = node.sum_grad
            parent_hess = node.sum_hess

            # when host node can not be further split, fid/bid is set to -1
            if reach_max_depth or split_info[pos].best_fid == -1:
                node.is_leaf = True
                self.tree_node.append(node)
                continue

            info = split_info[pos]
            node.left_nodeid = self.tree_node_num + 1
            node.right_nodeid = self.tree_node_num + 2
            self.tree_node_num += 2

            left_child = Node(id=node.left_nodeid,
                              sitename=self.sitename,
                              sum_grad=info.sum_grad,
                              sum_hess=info.sum_hess)
            right_child = Node(id=node.right_nodeid,
                               sitename=self.sitename,
                               sum_grad=parent_grad - info.sum_grad,
                               sum_hess=parent_hess - info.sum_hess)
            next_layer.extend((left_child, right_child))

            # Split parameters are stored verbatim on the host side.
            node.sitename = info.sitename
            node.fid = info.best_fid
            node.bid = info.best_bid
            node.missing_dir = info.missing_dir

            if self.feature_importance_type == 'split':
                self.update_feature_importance(info, record_site_name=False)

            self.tree_node.append(node)

        self.cur_layer_nodes = next_layer
Пример #3
0
    def sync_cur_to_split_nodes(self, cur_to_split_node, dep=-1, idx=-1):
        """Send id-only copies of the nodes about to be split to the host(s).

        Each outgoing ``Node`` is built with only ``id`` set. Every other
        field (including the grad/hess sums) is left at the ``Node``
        constructor defaults.

        Args:
            cur_to_split_node: iterable of nodes to announce, in split order.
            dep: tree depth; logged and used as the transfer suffix.
            idx: host index forwarded to ``remote`` (presumably ``-1`` means
                all hosts — TODO confirm against the transfer API).
        """
        LOGGER.info("send tree node queue of depth {}".format(dep))
        # Build the masked nodes directly. Deep-copying the full node list
        # (grad/hess sums included) only to overwrite every element
        # afterwards was wasted work.
        mask_tree_node_queue = [Node(id=node.id) for node in cur_to_split_node]

        self.transfer_inst.tree_node_queue.remote(mask_tree_node_queue,
                                                  role=consts.HOST,
                                                  idx=idx,
                                                  suffix=(dep,))
 def initialize_root_node(self):
     """Build the root ``Node`` (id 0) from the summed gradients/hessians.

     Returns:
         A ``Node`` with ``sitename=self.sitename``. Its sums come from
         ``self.get_grad_hess_sum(self.grad_and_hess)`` and its weight is
         ``self.splitter.node_weight`` applied to those sums.
     """
     total_grad, total_hess = self.get_grad_hess_sum(self.grad_and_hess)
     root_weight = self.splitter.node_weight(total_grad, total_hess)
     return Node(
         id=0,
         sitename=self.sitename,
         sum_grad=total_grad,
         sum_hess=total_hess,
         weight=root_weight,
     )
Пример #5
0
 def initialize_root_node(self):
     """Log, then build the root ``Node`` (id 0) from grad/hess totals.

     Returns:
         A ``Node`` owned by ``self.sitename``. Its ``sum_grad``/``sum_hess``
         are the totals from ``self.get_grad_hess_sum(self.grad_and_hess)``;
         its weight is ``self.splitter.node_weight`` of those totals.
     """
     LOGGER.info('initializing root node')
     g_total, h_total = self.get_grad_hess_sum(self.grad_and_hess)
     node_kwargs = dict(id=0,
                        sitename=self.sitename,
                        sum_grad=g_total,
                        sum_hess=h_total,
                        weight=self.splitter.node_weight(g_total, h_total))
     return Node(**node_kwargs)
Пример #6
0
    def mix_mode_fit(self):
        """Host-side tree fitting in mix mode.

        Does nothing beyond logging when either of these holds:

        * the tree type is ``plan.tree_type_dict['guest_feat_only']``;
        * this host is not the target host.

        Otherwise it:

        1. syncs the encrypted grad/hess and places every instance in root
           node 0;
        2. grows the tree for at most ``self.max_depth`` layers, processing
           each layer in batches of ``self.max_split_nodes`` nodes;
        3. marks whatever layer remains as leaves;
        4. converts bin ids to real values and syncs leaf nodes and sample
           leaf positions.
        """
        LOGGER.info('running mix mode')

        if self.tree_type == plan.tree_type_dict['guest_feat_only']:
            LOGGER.debug('this tree uses guest feature only, skip')
            return
        if self.self_host_id != self.target_host_id:
            LOGGER.debug('not selected host, skip')
            return

        LOGGER.debug('use local host feature to build tree')

        self.sync_encrypted_grad_and_hess()
        root_grad, root_hess = self.get_grad_hess_sum(self.grad_and_hess)
        # every instance starts in the root node, whose id is 0
        self.inst2node_idx = self.assign_instance_to_root_node(self.data_bin,
                                                               root_node_id=0)
        root = Node(id=0, sitename=self.sitename,
                    sum_grad=root_grad, sum_hess=root_hess)
        self.cur_layer_nodes = [root]

        for depth in range(self.max_depth):
            tree_action, layer_target_host_id = self.get_node_plan(depth)

            self.sync_cur_layer_nodes(self.cur_layer_nodes, depth)
            if len(self.cur_layer_nodes) == 0:
                break

            self.update_instances_node_positions()
            split_info = []
            batch_starts = range(0, len(self.cur_layer_nodes),
                                 self.max_split_nodes)
            for batch_idx, start in enumerate(batch_starts):
                self.cur_to_split_nodes = \
                    self.cur_layer_nodes[start: start + self.max_split_nodes]
                node_map = self.get_node_map(self.cur_to_split_nodes)
                split_info.extend(self.compute_best_splits_with_node_plan(
                    tree_action, layer_target_host_id,
                    node_map=node_map,
                    dep=depth, batch_idx=batch_idx,
                    mode=consts.MIX_TREE))

            self.update_host_side_tree(split_info, reach_max_depth=False)
            self.inst2node_idx = self.host_local_assign_instances_to_new_node()

        if self.cur_layer_nodes:
            # nodes left after the depth limit become leaves
            self.update_host_side_tree([], reach_max_depth=True)
            self.update_instances_node_positions()
            # NOTE(review): unlike inside the loop, the return value is not
            # stored in self.inst2node_idx here — confirm this is intended.
            self.host_local_assign_instances_to_new_node()

        self.convert_bin_to_real2()  # convert bin num to val
        self.sync_leaf_nodes()  # send leaf nodes to guest
        self.process_leaves_info()  # remove encrypted g/h
        self.sync_sample_leaf_pos(self.sample_leaf_pos)  # sync sample final leaf positions
 def handle_leaf_nodes(self, nodes):
     """Decrypt leaf statistics and lay the nodes out in a list indexed by id.

     For every node in ``nodes``, this method modifies the node in place:

     * decrypts ``sum_hess`` and ``sum_grad``;
     * recomputes ``weight`` with ``self.splitter.node_weight``;
     * sets ``sitename`` to ``self.sitename``.

     Returns:
         A list of length ``max(node.id) + 1`` in which position ``k`` holds
         the input node with id ``k``. Positions with no matching input node
         hold a default-constructed ``Node()``. An empty input yields ``[]``.
     """
     highest_id = -1
     for leaf in nodes:
         leaf.sum_hess = self.decrypt(leaf.sum_hess)
         leaf.sum_grad = self.decrypt(leaf.sum_grad)
         leaf.weight = self.splitter.node_weight(leaf.sum_grad, leaf.sum_hess)
         leaf.sitename = self.sitename
         highest_id = max(highest_id, leaf.id)

     by_id = [Node() for _ in range(highest_id + 1)]
     for leaf in nodes:
         by_id[leaf.id] = leaf
     return by_id
Пример #8
0
    def set_model_param(self, model_param):
        """Restore the tree's nodes and mask dictionaries from ``model_param``.

        ``self.tree_node`` is rebuilt with one ``Node`` per entry of
        ``model_param.tree_``. Each node copies the entry's id, sitename,
        fid, bid, weight, is_leaf, left/right child ids and missing_dir.
        ``split_maskdict`` and ``missing_dir_maskdict`` are copied into
        plain dicts.
        """
        self.tree_node = []
        self.tree_node.extend(
            Node(id=param.id,
                 sitename=param.sitename,
                 fid=param.fid,
                 bid=param.bid,
                 weight=param.weight,
                 is_leaf=param.is_leaf,
                 left_nodeid=param.left_nodeid,
                 right_nodeid=param.right_nodeid,
                 missing_dir=param.missing_dir)
            for param in model_param.tree_)

        self.split_maskdict = dict(model_param.split_maskdict)
        self.missing_dir_maskdict = dict(model_param.missing_dir_maskdict)