class TestXgboostCriterion(unittest.TestCase):
    """Unit tests comparing XgboostCriterion results with hand-computed values."""

    def setUp(self):
        # Shared fixture: a criterion constructed with a known reg_lambda.
        self.reg_lambda = 0.3
        self.criterion = XgboostCriterion(reg_lambda=self.reg_lambda)

    def _assert_close(self, actual, expected):
        """Pass when |actual - expected| is strictly below consts.FLOAT_ZERO."""
        self.assertTrue(np.fabs(actual - expected) < consts.FLOAT_ZERO)

    def test_init(self):
        """The constructor stores the reg_lambda it was given."""
        self._assert_close(self.criterion.reg_lambda, self.reg_lambda)

    def test_split_gain(self):
        """split_gain equals left score + right score - parent score."""
        lam = self.reg_lambda
        parent = [0.5, 0.6]
        lhs = [0.1, 0.2]
        rhs = [0.4, 0.4]

        def score(grad_hess):
            # grad^2 / (hess + reg_lambda) for a [grad, hess] pair.
            return grad_hess[0] * grad_hess[0] / (grad_hess[1] + lam)

        expected = score(lhs) + score(rhs) - score(parent)
        self._assert_close(self.criterion.split_gain(parent, lhs, rhs), expected)

    def test_node_gain(self):
        """node_gain equals grad^2 / (hess + reg_lambda)."""
        grad, hess = 0.5, 6
        expected = grad * grad / (hess + self.reg_lambda)
        self._assert_close(self.criterion.node_gain(grad, hess), expected)

    def test_node_weight(self):
        """node_weight equals -grad / (hess + reg_lambda)."""
        grad, hess = 0.5, 6
        expected = -grad / (hess + self.reg_lambda)
        self._assert_close(self.criterion.node_weight(grad, hess), expected)
class Splitter(object):
    """Search per-feature histograms for tree-node split candidates.

    Every histogram handled here is indexed as ``histogram[fid][bid]`` and
    each bin is read as ``[sum_grad, sum_hess, sample_count]``.  The code
    takes the node totals from the last bin and the left-child statistics
    from bin ``bid``, i.e. it treats the bins as cumulative sums.
    """

    def __init__(self, criterion_method, criterion_params=(0, 1), min_impurity_split=1e-2,
                 min_sample_split=2, min_leaf_node=1):
        """Configure the split criterion and the stopping thresholds.

        Args:
            criterion_method: name of the criterion; only "xgboost" installs
                ``self.criterion``.
            criterion_params: sequence whose first item is converted to float
                and passed to XgboostCriterion; an empty/falsy value selects
                the criterion's own defaults.
            min_impurity_split: a split must have a gain above this value.
            min_sample_split: nodes with fewer samples are not split.
            min_leaf_node: minimum sample count required in each child.

        Raises:
            TypeError: if ``criterion_method`` is not a ``str``.
        """
        LOGGER.info("splitter init!")
        if not isinstance(criterion_method, str):
            raise TypeError(
                "criterion_method type should be str, but %s find" % (type(criterion_method).__name__))

        # NOTE: any other method name leaves self.criterion unset; the
        # gain/weight helpers will then fail with AttributeError when used.
        if criterion_method == "xgboost":
            if not criterion_params:
                self.criterion = XgboostCriterion()
            else:
                try:
                    reg_lambda = float(criterion_params[0])
                    self.criterion = XgboostCriterion(reg_lambda)
                except Exception:
                    # Best effort: fall back to the default criterion, but do
                    # not swallow KeyboardInterrupt / SystemExit as a bare
                    # ``except:`` would.
                    warnings.warn(
                        "criterion_params' first criterion_params should be numeric"
                    )
                    self.criterion = XgboostCriterion()

        self.min_impurity_split = min_impurity_split
        self.min_sample_split = min_sample_split
        self.min_leaf_node = min_leaf_node

    def find_split_single_histogram_guest(self, histogram, valid_features):
        """Return the best split of one node as a guest-side SplitInfo.

        Args:
            histogram: per-feature bins for a single node (see class docs).
            valid_features: indexable by feature id; a feature is skipped
                only when its entry is exactly ``False``.

        Returns:
            SplitInfo with ``sitename=consts.GUEST``.  When no split beats
            the thresholds, ``best_fid``/``best_bid`` and the sums are None.
        """
        best_fid = None
        best_gain = self.min_impurity_split - consts.FLOAT_ZERO
        best_bid = None
        best_sum_grad_l = None
        best_sum_hess_l = None
        min_leaf = self.min_leaf_node

        for fid, bins in enumerate(histogram):
            if valid_features[fid] is False:
                continue
            bin_num = len(bins)
            if bin_num == 0:
                continue

            # Node totals are read from the last bin.
            totals = bins[bin_num - 1]
            sum_grad = totals[0]
            sum_hess = totals[1]
            node_cnt = totals[2]

            # Too few samples: stop scanning the remaining features as well
            # (presumably the count is the same for every feature of a node).
            if node_cnt < self.min_sample_split:
                break

            for bid in range(bin_num):
                sum_grad_l = bins[bid][0]
                sum_hess_l = bins[bid][1]
                node_cnt_l = bins[bid][2]

                sum_grad_r = sum_grad - sum_grad_l
                sum_hess_r = sum_hess - sum_hess_l
                node_cnt_r = node_cnt - node_cnt_l

                if node_cnt_l >= min_leaf and node_cnt_r >= min_leaf:
                    gain = self.criterion.split_gain([sum_grad, sum_hess],
                                                     [sum_grad_l, sum_hess_l],
                                                     [sum_grad_r, sum_hess_r])
                    if gain > self.min_impurity_split and gain > best_gain:
                        best_gain = gain
                        best_fid = fid
                        best_bid = bid
                        best_sum_grad_l = sum_grad_l
                        best_sum_hess_l = sum_hess_l

        return SplitInfo(sitename=consts.GUEST, best_fid=best_fid, best_bid=best_bid,
                         gain=best_gain, sum_grad=best_sum_grad_l,
                         sum_hess=best_sum_hess_l)

    def find_split(self, histograms, valid_features, partitions=1):
        """Compute the best guest split for every node histogram.

        The histograms are distributed with ``eggroll.parallelize`` and the
        values of the collected result are returned in collection order.
        """
        LOGGER.info("splitter find split of raw data")
        histogram_table = eggroll.parallelize(histograms,
                                              include_key=False,
                                              partition=partitions)
        splitinfo_table = histogram_table.mapValues(
            lambda sub_hist: self.find_split_single_histogram_guest(sub_hist, valid_features))
        return [pair[1] for pair in splitinfo_table.collect()]

    def find_split_single_histogram_host(self, histogram, valid_features):
        """List every admissible split candidate of one node on the host.

        No gain is computed here; each candidate only has to satisfy the
        ``min_leaf_node`` constraint on both children.

        Returns:
            ``(node_splitinfo, node_grad_hess)``: the SplitInfo candidates
            and the matching ``(sum_grad_l, sum_hess_l)`` tuples.
        """
        node_splitinfo = []
        node_grad_hess = []
        min_leaf = self.min_leaf_node

        for fid, bins in enumerate(histogram):
            if valid_features[fid] is False:
                continue
            bin_num = len(bins)
            if bin_num == 0:
                continue

            node_cnt = bins[bin_num - 1][2]
            if node_cnt < self.min_sample_split:
                break

            for bid in range(bin_num):
                sum_grad_l = bins[bid][0]
                sum_hess_l = bins[bid][1]
                node_cnt_l = bins[bid][2]
                node_cnt_r = node_cnt - node_cnt_l

                if node_cnt_l >= min_leaf and node_cnt_r >= min_leaf:
                    node_splitinfo.append(
                        SplitInfo(sitename=consts.HOST, best_fid=fid, best_bid=bid,
                                  sum_grad=sum_grad_l, sum_hess=sum_hess_l))
                    node_grad_hess.append((sum_grad_l, sum_hess_l))

        return node_splitinfo, node_grad_hess

    def find_split_host(self, histograms, valid_features, partitions=1):
        """Collect the host split candidates of every node histogram.

        Returns:
            ``(tree_node_splitinfo, encrypted_node_grad_hess)``: two lists
            with one entry per collected node.
        """
        LOGGER.info("splitter find split of host")
        histogram_table = eggroll.parallelize(histograms,
                                              include_key=False,
                                              partition=partitions)
        host_splitinfo_table = histogram_table.mapValues(
            lambda hist: self.find_split_single_histogram_host(hist, valid_features))

        tree_node_splitinfo = []
        encrypted_node_grad_hess = []
        for _, splitinfo in host_splitinfo_table.collect():
            tree_node_splitinfo.append(splitinfo[0])
            encrypted_node_grad_hess.append(splitinfo[1])

        return tree_node_splitinfo, encrypted_node_grad_hess

    def node_gain(self, grad, hess):
        """Delegate to ``self.criterion.node_gain``."""
        return self.criterion.node_gain(grad, hess)

    def node_weight(self, grad, hess):
        """Delegate to ``self.criterion.node_weight``."""
        return self.criterion.node_weight(grad, hess)

    def split_gain(self, sum_grad, sum_hess, sum_grad_l, sum_hess_l, sum_grad_r, sum_hess_r):
        """Delegate to ``self.criterion.split_gain`` with [grad, hess] pairs."""
        return self.criterion.split_gain([sum_grad, sum_hess],
                                         [sum_grad_l, sum_hess_l],
                                         [sum_grad_r, sum_hess_r])
class Splitter(object):
    """Search per-feature histograms for tree-node split candidates.

    Each bin is read as ``[sum_grad, sum_hess, sample_count]``.  Node totals
    are taken from the last bin and left-child statistics from bin ``bid``,
    i.e. the bins are treated as cumulative sums.  When ``use_missing`` is
    set, the difference between the last two bins is used as the statistics
    of the samples with a missing value, and every split point is tried with
    those samples sent to the right (``missing_dir=1``) and to the left
    (``missing_dir=-1``).
    """

    def __init__(self, criterion_method, criterion_params=(0, 1), min_impurity_split=1e-2,
                 min_sample_split=2, min_leaf_node=1):
        """Configure the split criterion and the stopping thresholds.

        Args:
            criterion_method: name of the criterion; only "xgboost" installs
                ``self.criterion``.
            criterion_params: sequence whose first item is converted to float
                and passed to XgboostCriterion; an empty/falsy value selects
                the criterion's own defaults.
            min_impurity_split: a split must have a gain above this value.
            min_sample_split: nodes with fewer samples are not split.
            min_leaf_node: minimum sample count required in each child.

        Raises:
            TypeError: if ``criterion_method`` is not a ``str``.
        """
        LOGGER.info("splitter init!")
        if not isinstance(criterion_method, str):
            raise TypeError("criterion_method type should be str, but %s find" % (type(criterion_method).__name__))

        # NOTE: any other method name leaves self.criterion unset; the
        # gain/weight helpers will then fail with AttributeError when used.
        if criterion_method == "xgboost":
            if not criterion_params:
                self.criterion = XgboostCriterion()
            else:
                try:
                    reg_lambda = float(criterion_params[0])
                    self.criterion = XgboostCriterion(reg_lambda)
                except Exception:
                    # Best effort: fall back to the default criterion, but do
                    # not swallow KeyboardInterrupt / SystemExit as a bare
                    # ``except:`` would.
                    warnings.warn("criterion_params' first criterion_params should be numeric")
                    self.criterion = XgboostCriterion()

        self.min_impurity_split = min_impurity_split
        self.min_sample_split = min_sample_split
        self.min_leaf_node = min_leaf_node

    def find_split_single_histogram_guest(self, histogram, valid_features, sitename, use_missing, zero_as_missing):
        """Return the best split of one node, including the missing direction.

        Args:
            histogram: per-feature bins for a single node (see class docs).
            valid_features: indexable by feature id; a feature is skipped
                only when its entry is exactly ``False``.
            sitename: stored in the returned SplitInfo.
            use_missing: when true, the last bin is excluded from the split
                points and both missing-value directions are evaluated.
            zero_as_missing: accepted for interface compatibility; not read
                by this method.

        Returns:
            SplitInfo describing the best split.  When no split beats the
            thresholds, ``best_fid``/``best_bid`` and the sums are None and
            ``missing_dir`` is 1.
        """
        best_fid = None
        best_gain = self.min_impurity_split - consts.FLOAT_ZERO
        best_bid = None
        best_sum_grad_l = None
        best_sum_hess_l = None
        missing_bin = 1 if use_missing else 0
        min_leaf = self.min_leaf_node

        # By default, missing values go to the right child.
        missing_dir = 1

        for fid, bins in enumerate(histogram):
            if valid_features[fid] is False:
                continue
            bin_num = len(bins)
            if bin_num == 0 + missing_bin:
                continue

            # Node totals are read from the last bin.
            sum_grad = bins[bin_num - 1][0]
            sum_hess = bins[bin_num - 1][1]
            node_cnt = bins[bin_num - 1][2]

            # Too few samples: stop scanning the remaining features as well
            # (presumably the count is the same for every feature of a node).
            if node_cnt < self.min_sample_split:
                break

            if use_missing:
                # Statistics of the missing-value samples; identical for every
                # split point of this feature, so compute them once.
                # bin_num >= 2 is guaranteed by the ``continue`` above.
                miss_grad = bins[-1][0] - bins[-2][0]
                miss_hess = bins[-1][1] - bins[-2][1]
                miss_cnt = bins[-1][2] - bins[-2][2]

            for bid in range(bin_num - missing_bin - 1):
                sum_grad_l = bins[bid][0]
                sum_hess_l = bins[bid][1]
                node_cnt_l = bins[bid][2]

                sum_grad_r = sum_grad - sum_grad_l
                sum_hess_r = sum_hess - sum_hess_l
                node_cnt_r = node_cnt - node_cnt_l

                # Candidate 1: missing values stay in the right child.
                if node_cnt_l >= min_leaf and node_cnt_r >= min_leaf:
                    gain = self.criterion.split_gain([sum_grad, sum_hess],
                                                     [sum_grad_l, sum_hess_l],
                                                     [sum_grad_r, sum_hess_r])
                    if gain > self.min_impurity_split and gain > best_gain:
                        best_gain = gain
                        best_fid = fid
                        best_bid = bid
                        best_sum_grad_l = sum_grad_l
                        best_sum_hess_l = sum_hess_l
                        missing_dir = 1

                # Candidate 2: missing values are dispatched to the left child.
                if use_missing:
                    sum_grad_l = sum_grad_l + miss_grad
                    sum_hess_l = sum_hess_l + miss_hess
                    node_cnt_l = node_cnt_l + miss_cnt
                    sum_grad_r = sum_grad_r - miss_grad
                    sum_hess_r = sum_hess_r - miss_hess
                    node_cnt_r = node_cnt_r - miss_cnt

                    if node_cnt_l >= min_leaf and node_cnt_r >= min_leaf:
                        gain = self.criterion.split_gain([sum_grad, sum_hess],
                                                         [sum_grad_l, sum_hess_l],
                                                         [sum_grad_r, sum_hess_r])
                        if gain > self.min_impurity_split and gain > best_gain:
                            best_gain = gain
                            best_fid = fid
                            best_bid = bid
                            best_sum_grad_l = sum_grad_l
                            best_sum_hess_l = sum_hess_l
                            missing_dir = -1

        return SplitInfo(sitename=sitename, best_fid=best_fid, best_bid=best_bid,
                         gain=best_gain, sum_grad=best_sum_grad_l, sum_hess=best_sum_hess_l,
                         missing_dir=missing_dir)

    def find_split(self, histograms, valid_features, partitions=1, sitename=consts.GUEST,
                   use_missing=False, zero_as_missing=False):
        """Compute the best split for every node histogram.

        The histograms are distributed with ``session.parallelize``; each
        collected result is stored at the index given by its key, so the
        returned list has one slot per input histogram.
        """
        LOGGER.info("splitter find split of raw data")
        histogram_table = session.parallelize(histograms, include_key=False, partition=partitions)
        splitinfo_table = histogram_table.mapValues(
            lambda sub_hist: self.find_split_single_histogram_guest(sub_hist, valid_features, sitename,
                                                                    use_missing, zero_as_missing))

        tree_node_splitinfo = [None] * len(histograms)
        for node_idx, splitinfo in splitinfo_table.collect():
            tree_node_splitinfo[node_idx] = splitinfo

        return tree_node_splitinfo

    def find_split_single_histogram_host(self, fid_with_histogram, valid_features, sitename,
                                         use_missing=False, zero_as_missing=False):
        """List the admissible split candidates of one feature on the host.

        No gain is computed here.  A candidate is emitted when both children
        hold at least ``min_leaf_node`` samples; with ``use_missing`` the
        "missing goes left" variant is checked against its own adjusted
        child counts, mirroring the guest-side search.

        Args:
            fid_with_histogram: ``(fid, histogram)`` pair for one feature.
            valid_features: indexable by feature id; the feature is skipped
                only when its entry is exactly ``False``.
            sitename: stored in every emitted SplitInfo.
            use_missing: when true, the last bin is excluded from the split
                points and both missing-value directions are considered.
            zero_as_missing: accepted for interface compatibility; not read
                by this method.

        Returns:
            ``(node_splitinfo, node_grad_hess)``: the SplitInfo candidates
            and the matching ``(sum_grad_l, sum_hess_l)`` tuples.
        """
        node_splitinfo = []
        node_grad_hess = []
        missing_bin = 1 if use_missing else 0
        min_leaf = self.min_leaf_node

        fid, histogram = fid_with_histogram
        if valid_features[fid] is False:
            return [], []

        bin_num = len(histogram)
        # With only the missing bin present there is no split point to try.
        if bin_num <= missing_bin:
            return [], []

        node_cnt = histogram[bin_num - 1][2]
        if node_cnt < self.min_sample_split:
            return [], []

        if use_missing:
            # Statistics of the missing-value samples; identical for every
            # split point, so compute them once (bin_num >= 2 holds here).
            miss_grad = histogram[-1][0] - histogram[-2][0]
            miss_hess = histogram[-1][1] - histogram[-2][1]
            miss_cnt = histogram[-1][2] - histogram[-2][2]

        for bid in range(bin_num - missing_bin - 1):
            sum_grad_l = histogram[bid][0]
            sum_hess_l = histogram[bid][1]
            node_cnt_l = histogram[bid][2]
            node_cnt_r = node_cnt - node_cnt_l

            # Candidate 1: missing values stay in the right child.
            if node_cnt_l >= min_leaf and node_cnt_r >= min_leaf:
                node_splitinfo.append(
                    SplitInfo(sitename=sitename, best_fid=fid, best_bid=bid,
                              sum_grad=sum_grad_l, sum_hess=sum_hess_l, missing_dir=1))
                node_grad_hess.append((sum_grad_l, sum_hess_l))

            # Candidate 2: missing values are dispatched to the left child.
            if use_missing:
                sum_grad_l = sum_grad_l + miss_grad
                sum_hess_l = sum_hess_l + miss_hess
                node_cnt_l = node_cnt_l + miss_cnt
                node_cnt_r = node_cnt_r - miss_cnt

                # Enforce the leaf-size constraint on the adjusted counts so
                # that undersized children are never proposed.
                if node_cnt_l >= min_leaf and node_cnt_r >= min_leaf:
                    node_splitinfo.append(
                        SplitInfo(sitename=sitename, best_fid=fid, best_bid=bid,
                                  sum_grad=sum_grad_l, sum_hess=sum_hess_l, missing_dir=-1))
                    node_grad_hess.append((sum_grad_l, sum_hess_l))

        return node_splitinfo, node_grad_hess

    def find_split_host(self, histograms, valid_features, node_map, sitename=consts.HOST,
                        use_missing=False, zero_as_missing=False):
        """Gather the host split candidates of every node.

        ``histograms`` must provide ``mapValues``/``collect`` and be keyed by
        ``(nid, fid)``; the candidates of all features of a node are
        concatenated into that node's slot.

        Returns:
            ``(tree_node_splitinfo, BigObjectTransfer(grad_hess))`` where
            both underlying lists have ``len(node_map)`` entries.
        """
        LOGGER.info("splitter find split of host")
        node_num = len(node_map)
        tree_node_splitinfo = [[] for _ in range(node_num)]
        encrypted_node_grad_hess = [[] for _ in range(node_num)]

        host_splitinfo_table = histograms.mapValues(
            lambda fid_with_hist: self.find_split_single_histogram_host(fid_with_hist, valid_features, sitename,
                                                                        use_missing, zero_as_missing))

        for (nid, _fid), splitinfo in host_splitinfo_table.collect():
            tree_node_splitinfo[nid].extend(splitinfo[0])
            encrypted_node_grad_hess[nid].extend(splitinfo[1])

        return tree_node_splitinfo, BigObjectTransfer(encrypted_node_grad_hess)

    def node_gain(self, grad, hess):
        """Delegate to ``self.criterion.node_gain``."""
        return self.criterion.node_gain(grad, hess)

    def node_weight(self, grad, hess):
        """Delegate to ``self.criterion.node_weight``."""
        return self.criterion.node_weight(grad, hess)

    def split_gain(self, sum_grad, sum_hess, sum_grad_l, sum_hess_l, sum_grad_r, sum_hess_r):
        """Delegate to ``self.criterion.split_gain`` with [grad, hess] pairs."""
        return self.criterion.split_gain([sum_grad, sum_hess],
                                         [sum_grad_l, sum_hess_l],
                                         [sum_grad_r, sum_hess_r])