예제 #1
0
    def _split_next(self):
        """Apply the next split announced over the bridge.

        Receives a ``SplitInfo`` under the ``'split_info'`` key, attaches two
        new children (weights initialised to NaN) to the node it names, and
        fills in the children's sample ids:

        * ``feature_id >= 0``: ``_set_node_partition`` is called for the node
          and the children's sample ids are sent back under
          ``'follower_split_info'``.
        * otherwise: the node is marked ``is_owner = False`` and the sample
          ids carried by the received message are copied onto the children.

        Returns:
            Tuple ``(left_child, right_child, split_info)`` where
            ``split_info`` is the message that was received.
        """
        bridge = self._bridge
        bridge.start(bridge.new_iter_id())

        split_info = tree_pb2.SplitInfo()
        packed = bridge.receive_proto(bridge.current_iter_id, 'split_info')
        packed.Unpack(split_info)

        parent = self._nodes[split_info.node_id]

        # Children are created left first, then right, so that _add_node
        # is invoked in the same order as before.
        left_id = self._add_node(parent.node_id)
        parent.left_child = left_id
        left_child = self._nodes[left_id]
        left_child.weight = float('nan')

        right_id = self._add_node(parent.node_id)
        parent.right_child = right_id
        right_child = self._nodes[right_id]
        right_child.weight = float('nan')

        # Splitting turns one leaf into two, so the leaf count grows by one.
        self._num_leaves += 1

        if split_info.feature_id < 0:
            # Negative feature id: the sample partition arrives in the
            # message itself.
            parent.is_owner = False
            left_child.sample_ids = list(split_info.left_samples)
            right_child.sample_ids = list(split_info.right_samples)
        else:
            # presumably _set_node_partition fills the children's
            # sample_ids from local features -- confirm against the helper
            self._set_node_partition(parent, split_info)
            reply = tree_pb2.SplitInfo(
                left_samples=left_child.sample_ids,
                right_samples=right_child.sample_ids)
            bridge.send_proto(
                bridge.current_iter_id, 'follower_split_info', reply)

        bridge.commit()
        return left_child, right_child, split_info
예제 #2
0
    def _find_split_and_push(self, node):
        """Evaluate every candidate split of ``node`` and queue the result.

        For each feature's gradient/hessian histogram the last bin is kept
        aside (named ``nan_*`` here) and every bin index up to, but not
        including, the second-to-last one is offered to ``_compare_split``
        twice: once with the last bin's sums added to the left side
        (flag ``True``) and once with them left on the right side
        (flag ``False``).

        The candidate message starts with ``gain=-1``; ``_compare_split`` is
        handed that message and is presumably what updates it in place
        (helper not visible here -- confirm).

        Returns:
            Tuple ``(gain, split_info)`` read from the candidate after all
            comparisons; the same message is also put on
            ``self._split_candidates`` keyed by ``-gain``.
        """
        best = tree_pb2.SplitInfo(node_id=node.node_id, gain=-1)

        per_feature = zip(node.grad_hists, node.hess_hists)
        for fid, (grad_hist, hess_hist) in enumerate(per_feature):
            total_g = sum(grad_hist)
            total_h = sum(hess_hist)
            nan_g = grad_hist[-1]
            nan_h = hess_hist[-1]

            # Running sums over bins [0, i], excluding the last bin.
            acc_g = 0.0
            acc_h = 0.0
            for i in range(len(grad_hist) - 2):
                acc_g += grad_hist[i]
                acc_h += hess_hist[i]

                # Last bin counted on the left side.
                self._compare_split(
                    best, True, fid, i,
                    acc_g + nan_g, acc_h + nan_h,
                    total_g - acc_g - nan_g, total_h - acc_h - nan_h)

                # Last bin counted on the right side.
                self._compare_split(
                    best, False, fid, i,
                    acc_g, acc_h,
                    total_g - acc_g, total_h - acc_h)

        # NOTE(review): if _split_candidates is a queue.PriorityQueue, equal
        # gains fall through to comparing the SplitInfo messages -- confirm
        # that ties are handled.
        self._split_candidates.put((-best.gain, best))

        return best.gain, best
예제 #3
0
    def _split_next(self):
        """Take the next queued split, apply it, and sync it over the bridge.

        Pulls one ``(key, split_info)`` entry from ``self._split_candidates``,
        attaches two new children to the node it names (weights copied from
        ``left_weight`` / ``right_weight``), then:

        * ``feature_id`` below ``self._binned.features.shape[1]``:
          ``_set_node_partition`` is called and a ``'split_info'`` message
          with ``feature_id=-1`` plus the children's sample ids is sent.
        * otherwise: the node is marked ``is_owner = False``, a
          ``'split_info'`` message carrying the feature id offset by the
          local feature count, the split point and ``default_left`` is sent,
          ``split_info.feature_id`` is overwritten with ``-1``, and the
          children's sample ids are read from the ``'follower_split_info'``
          reply.

        Returns:
            Tuple ``(left_child, right_child, split_info)``.
        """
        bridge = self._bridge
        bridge.start(bridge.new_iter_id())

        _priority, split_info = self._split_candidates.get()
        parent = self._nodes[split_info.node_id]

        # Left child first, then right, keeping the _add_node call order.
        left_id = self._add_node(parent.node_id)
        parent.left_child = left_id
        left_child = self._nodes[left_id]
        left_child.weight = split_info.left_weight

        right_id = self._add_node(parent.node_id)
        parent.right_child = right_id
        right_child = self._nodes[right_id]
        right_child.weight = split_info.right_weight

        # Splitting turns one leaf into two, so the leaf count grows by one.
        self._num_leaves += 1

        num_local_features = self._binned.features.shape[1]
        if split_info.feature_id >= num_local_features:
            # Feature index lies past the local columns: ask the other side
            # to partition using its own (re-based) feature index.
            parent.is_owner = False
            remote_fid = split_info.feature_id - num_local_features
            request = tree_pb2.SplitInfo(
                node_id=split_info.node_id,
                feature_id=remote_fid,
                split_point=split_info.split_point,
                default_left=split_info.default_left)
            bridge.send_proto(bridge.current_iter_id, 'split_info', request)

            # Must happen after remote_fid was derived from the old value.
            split_info.feature_id = -1

            reply = tree_pb2.SplitInfo()
            packed = bridge.receive_proto(
                bridge.current_iter_id, 'follower_split_info')
            packed.Unpack(reply)
            left_child.sample_ids = list(reply.left_samples)
            right_child.sample_ids = list(reply.right_samples)
        else:
            # presumably _set_node_partition fills the children's
            # sample_ids from local features -- confirm against the helper
            self._set_node_partition(parent, split_info)
            announcement = tree_pb2.SplitInfo(
                node_id=split_info.node_id,
                feature_id=-1,
                left_samples=left_child.sample_ids,
                right_samples=right_child.sample_ids)
            bridge.send_proto(
                bridge.current_iter_id, 'split_info', announcement)

        bridge.commit()
        return left_child, right_child, split_info
예제 #4
0
파일: tree.py 프로젝트: feiga/fedlearner
    def _find_split_and_push(self, node):
        """Find the highest-gain split for ``node`` and queue it.

        For every feature, the gradient/hessian histograms are scanned with
        the last bin excluded from both the totals and the split points
        (NOTE(review): presumably the last bin holds missing values --
        confirm). For a split after bin ``i`` the gain is

            G_l^2 / (H_l + lam) + G_r^2 / (H_r + lam) - G^2 / (H + lam)

        with ``lam = self._l2_regularization``. Leaf weights are
        ``-G / (H + lam)`` scaled by ``self._learning_rate``.

        If no candidate has a gain above -1 (no histograms, fewer than three
        bins per histogram, or only non-improving / NaN gains), a placeholder
        ``SplitInfo`` carrying just ``node_id`` and ``gain=-1`` is queued
        instead of failing on the unset weights.

        Args:
            node: Tree node exposing ``node_id``, ``grad_hists`` and
                ``hess_hists`` (one histogram per feature).

        Returns:
            Tuple ``(max_gain, split_info)``; the same message is also put
            on ``self._split_candidates`` keyed by ``-max_gain``.
        """
        max_gain = -1
        max_fid = None
        split_point = None
        left_weight = None
        right_weight = None
        lam = self._l2_regularization
        for fid, (grad_hist, hess_hist) in \
                enumerate(zip(node.grad_hists, node.hess_hists)):
            sum_g = sum(grad_hist[:-1])
            sum_h = sum(hess_hist[:-1])
            left_g = 0.0
            left_h = 0.0
            # Skip the last bin, and do not split after the final remaining
            # bin: that would leave the right side empty.
            for i in range(len(grad_hist) - 2):
                left_g += grad_hist[i]
                left_h += hess_hist[i]
                right_g = sum_g - left_g
                right_h = sum_h - left_h
                gain = left_g*left_g/(left_h + lam) + \
                    right_g*right_g/(right_h + lam) - \
                    sum_g*sum_g/(sum_h + lam)
                if gain > max_gain:
                    max_gain = gain
                    max_fid = fid
                    split_point = i
                    left_weight = -left_g / (left_h + lam)
                    right_weight = -right_g / (right_h + lam)

        if max_fid is None:
            # No usable split: the weights were never assigned, so scaling
            # them by the learning rate would raise TypeError. Queue a
            # gain=-1 placeholder instead.
            split_info = tree_pb2.SplitInfo(
                node_id=node.node_id,
                gain=max_gain)
        else:
            split_info = tree_pb2.SplitInfo(
                node_id=node.node_id,
                gain=max_gain,
                feature_id=max_fid,
                split_point=split_point,
                left_weight=left_weight * self._learning_rate,
                right_weight=right_weight * self._learning_rate)

        # NOTE(review): if _split_candidates is a queue.PriorityQueue, equal
        # gains fall through to comparing the SplitInfo messages -- confirm
        # that ties are handled.
        self._split_candidates.put((-max_gain, split_info))

        return max_gain, split_info