def find_split_single_histogram_host(self, histogram, valid_features, sitename, use_missing=False, zero_as_missing=False):
    """Enumerate host-side split candidates for one node.

    The code reads ``histogram[fid][bid]`` as cumulative
    ``[sum_grad, sum_hess, count, ...]`` statistics up to bin ``bid``; the
    last entry supplies the node count. With ``use_missing`` the last entry
    also includes the missing values, so
    ``histogram[fid][-1] - histogram[fid][-2]`` gives the missing-value
    statistics.

    No gain is computed here: every candidate whose two children both satisfy
    ``min_leaf_node`` is returned together with its left-child sums.

    Args:
        histogram: per-feature cumulative histograms of the node.
        valid_features: indexable by feature id; falsy entries are skipped.
        sitename: site name stored on every returned SplitInfo.
        use_missing: if true, histograms carry a trailing missing-value entry
            and, per split point, an extra candidate that sends missing values
            to the left child (``missing_dir=-1``) is considered.
        zero_as_missing: unused in this body.

    Returns:
        (node_splitinfo, node_grad_hess): parallel lists; ``node_grad_hess[k]``
        is the ``(sum_grad_l, sum_hess_l)`` pair of ``node_splitinfo[k]``.
    """
    node_splitinfo = []
    node_grad_hess = []

    missing_bin = 1 if use_missing else 0

    for fid in range(len(histogram)):
        # ``not`` (rather than ``is False``) also skips numpy.bool_ False values.
        if not valid_features[fid]:
            continue
        feature_hist = histogram[fid]
        bin_num = len(feature_hist)
        # Nothing to split unless there is at least one entry besides the
        # (optional) missing-value entry.
        if bin_num <= missing_bin:
            continue

        node_cnt = feature_hist[bin_num - 1][2]
        if node_cnt < self.min_sample_split:
            break

        if use_missing:
            # Missing-bin statistics do not depend on the split point, so they
            # are computed once per feature instead of once per bin.
            missing_grad = feature_hist[-1][0] - feature_hist[-2][0]
            missing_hess = feature_hist[-1][1] - feature_hist[-2][1]
            missing_cnt = feature_hist[-1][2] - feature_hist[-2][2]

        # The last regular bin and the trailing missing entry (if any) are
        # never used as split points.
        for bid in range(bin_num - missing_bin - 1):
            sum_grad_l = feature_hist[bid][0]
            sum_hess_l = feature_hist[bid][1]
            node_cnt_l = feature_hist[bid][2]
            node_cnt_r = node_cnt - node_cnt_l

            # Candidate 1: missing values (if any) go to the right child.
            if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                node_splitinfo.append(SplitInfo(sitename=sitename, best_fid=fid, best_bid=bid,
                                                sum_grad=sum_grad_l, sum_hess=sum_hess_l,
                                                missing_dir=1))
                node_grad_hess.append((sum_grad_l, sum_hess_l))

            if not use_missing:
                continue

            # Candidate 2: missing values go to the left child. Move the
            # missing-bin statistics from the right side to the left side and
            # re-check the leaf-size limit against the adjusted counts.
            # New names (not ``+=``) so the values already stored in candidate 1
            # are never touched by an in-place operator.
            miss_l_grad = sum_grad_l + missing_grad
            miss_l_hess = sum_hess_l + missing_hess
            miss_l_cnt = node_cnt_l + missing_cnt
            miss_r_cnt = node_cnt_r - missing_cnt
            if miss_l_cnt >= self.min_leaf_node and miss_r_cnt >= self.min_leaf_node:
                node_splitinfo.append(SplitInfo(sitename=sitename, best_fid=fid, best_bid=bid,
                                                sum_grad=miss_l_grad, sum_hess=miss_l_hess,
                                                missing_dir=-1))
                node_grad_hess.append((miss_l_grad, miss_l_hess))

    return node_splitinfo, node_grad_hess
def sync_final_splitinfo_host(self, splitinfo_host, federated_best_splitinfo_host, dep=-1, batch=-1):
    """Send the host's final split choice for every node to the guest.

    For node ``i``, ``federated_best_splitinfo_host[i]`` is unpacked as
    ``(best_idx, best_gain)``. When ``best_idx`` is -1 a placeholder SplitInfo
    (``best_fid``/``best_bid`` = -1) carrying ``best_gain`` is sent. Otherwise
    the candidate ``splitinfo_host[i][best_idx]`` is modified in place: its
    feature index and bin are replaced by ``self.encode`` results and its
    ``gain`` is set to ``best_gain``.

    The resulting list goes to the guest (``idx=0``) via ``federation.remote``
    with a transfer id built from ``dep`` and ``batch``.
    """
    LOGGER.info("send host final splitinfo of depth {}, batch {}".format(
        dep, batch))

    final_splitinfos = []
    for node_pos, node_candidates in enumerate(splitinfo_host):
        best_idx, best_gain = federated_best_splitinfo_host[node_pos]

        if best_idx == -1:
            # No host split selected for this node: send a placeholder.
            final_splitinfos.append(SplitInfo(sitename=consts.HOST, best_fid=-1, best_bid=-1, gain=best_gain))
            continue

        chosen = node_candidates[best_idx]
        assert chosen.sitename == consts.HOST
        # Replace raw feature id / bin id with their encoded forms before sending.
        chosen.best_fid = self.encode("feature_idx", chosen.best_fid)
        assert chosen.best_fid is not None
        chosen.best_bid = self.encode("feature_val", chosen.best_bid, self.cur_split_nodes[node_pos].id)
        chosen.gain = best_gain
        final_splitinfos.append(chosen)

    federation.remote(obj=final_splitinfos,
                      name=self.transfer_inst.final_splitinfo_host.name,
                      tag=self.transfer_inst.generate_transferid(self.transfer_inst.final_splitinfo_host,
                                                                 dep,
                                                                 batch),
                      role=consts.GUEST,
                      idx=0)
def find_split_single_histogram_host(self, histogram, valid_features):
    """Enumerate host-side split candidates for one node.

    The code reads ``histogram[fid][bid]`` as cumulative
    ``[sum_grad, sum_hess, count, ...]`` statistics up to bin ``bid``; the
    last entry supplies the node count. No gain is computed here: every split
    point whose two children both satisfy ``min_leaf_node`` is returned with
    its left-child sums.

    Args:
        histogram: per-feature cumulative histograms of the node.
        valid_features: indexable by feature id; falsy entries are skipped.

    Returns:
        (node_splitinfo, node_grad_hess): parallel lists; ``node_grad_hess[k]``
        is the ``(sum_grad_l, sum_hess_l)`` pair of ``node_splitinfo[k]``.
    """
    node_splitinfo = []
    node_grad_hess = []
    for fid in range(len(histogram)):
        # ``not`` (rather than ``is False``) also skips numpy.bool_ False values.
        if not valid_features[fid]:
            continue
        feature_hist = histogram[fid]
        bin_num = len(feature_hist)
        if bin_num == 0:
            continue

        node_cnt = feature_hist[bin_num - 1][2]
        if node_cnt < self.min_sample_split:
            break

        # The last entry covers the whole node, so splitting there always
        # leaves the right child empty (node_cnt_r == 0); skip it.
        for bid in range(bin_num - 1):
            sum_grad_l = feature_hist[bid][0]
            sum_hess_l = feature_hist[bid][1]
            node_cnt_l = feature_hist[bid][2]
            node_cnt_r = node_cnt - node_cnt_l

            if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                splitinfo = SplitInfo(sitename=consts.HOST, best_fid=fid, best_bid=bid,
                                      sum_grad=sum_grad_l, sum_hess=sum_hess_l)
                node_splitinfo.append(splitinfo)
                node_grad_hess.append((sum_grad_l, sum_hess_l))
    return node_splitinfo, node_grad_hess
def test_splitinfo(self):
    """Check that SplitInfo exposes each constructor keyword argument as a same-named attribute."""
    param_dict = {"sitename": "testsplitinfo", "best_fid": 23, "best_bid": 233,
                  "sum_grad": 2333, "sum_hess": 23333, "gain": 233333}
    # Construct from the same dict so the arguments and the expectations cannot drift apart.
    splitinfo = SplitInfo(**param_dict)
    for key, expected in param_dict.items():
        # assertEqual reports both values on failure; msg identifies the attribute.
        self.assertEqual(expected, getattr(splitinfo, key), msg=key)
def find_split_single_histogram_guest(self, histogram, valid_features):
    """Find the best guest-side split over one node's per-feature histograms.

    The code reads ``histogram[fid][bid]`` as cumulative
    ``[sum_grad, sum_hess, count, ...]`` statistics up to bin ``bid``: the
    last entry supplies the node totals, and the right child is computed as
    totals minus the left prefix.

    Args:
        histogram: per-feature cumulative histograms of the node.
        valid_features: indexable by feature id; falsy entries are skipped.

    Returns:
        SplitInfo (sitename ``consts.GUEST``) with the best feature/bin, its
        gain and left-child sums. ``best_fid``, ``best_bid`` and the sums stay
        None when no candidate's gain exceeds ``min_impurity_split``.
    """
    best_fid = None
    best_gain = self.min_impurity_split - consts.FLOAT_ZERO
    best_bid = None
    best_sum_grad_l = None
    best_sum_hess_l = None

    for fid in range(len(histogram)):
        # ``not`` (rather than ``is False``) also skips numpy.bool_ False values.
        if not valid_features[fid]:
            continue
        feature_hist = histogram[fid]
        bin_num = len(feature_hist)
        if bin_num == 0:
            continue
        sum_grad = feature_hist[bin_num - 1][0]
        sum_hess = feature_hist[bin_num - 1][1]
        node_cnt = feature_hist[bin_num - 1][2]

        if node_cnt < self.min_sample_split:
            break

        # The last entry equals the node totals, so splitting there always
        # leaves the right child empty; it is never a real candidate.
        for bid in range(bin_num - 1):
            sum_grad_l = feature_hist[bid][0]
            sum_hess_l = feature_hist[bid][1]
            node_cnt_l = feature_hist[bid][2]

            sum_grad_r = sum_grad - sum_grad_l
            sum_hess_r = sum_hess - sum_hess_l
            node_cnt_r = node_cnt - node_cnt_l

            if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                gain = self.criterion.split_gain([sum_grad, sum_hess],
                                                 [sum_grad_l, sum_hess_l],
                                                 [sum_grad_r, sum_hess_r])

                if gain > self.min_impurity_split and gain > best_gain:
                    best_gain = gain
                    best_fid = fid
                    best_bid = bid
                    best_sum_grad_l = sum_grad_l
                    best_sum_hess_l = sum_hess_l

    splitinfo = SplitInfo(sitename=consts.GUEST, best_fid=best_fid, best_bid=best_bid,
                          gain=best_gain, sum_grad=best_sum_grad_l, sum_hess=best_sum_hess_l)
    return splitinfo
def sync_final_splitinfo_host(self, splitinfo_host, federated_best_splitinfo_host, dep=-1, batch=-1): LOGGER.info("send host final splitinfo of depth {}, batch {}".format( dep, batch)) final_splitinfos = [] for i in range(len(splitinfo_host)): best_idx, best_gain = federated_best_splitinfo_host[i] if best_idx != -1: assert splitinfo_host[i][best_idx].sitename == self.sitename splitinfo = splitinfo_host[i][best_idx] splitinfo.best_fid = self.encode("feature_idx", splitinfo.best_fid) assert splitinfo.best_fid is not None splitinfo.best_bid = self.encode("feature_val", splitinfo.best_bid, self.cur_split_nodes[i].id) splitinfo.missing_dir = self.encode("missing_dir", splitinfo.missing_dir, self.cur_split_nodes[i].id) splitinfo.gain = best_gain else: splitinfo = SplitInfo(sitename=self.sitename, best_fid=-1, best_bid=-1, gain=best_gain) final_splitinfos.append(splitinfo) self.transfer_inst.final_splitinfo_host.remote(final_splitinfos, role=consts.GUEST, idx=-1, suffix=( dep, batch, )) """
def find_split_host(self, histograms, valid_features):
    """Enumerate host-side split candidates for every node.

    Args:
        histograms: one entry per node; ``histograms[i][fid][bid]`` is read as
            cumulative ``[sum_grad, sum_hess, count, ...]`` statistics up to
            bin ``bid``, the last entry supplying the node count.
        valid_features: indexable by feature id; falsy entries are skipped.

    Returns:
        (tree_node_splitinfo, encrypted_node_grad_hess): one list per node of
        SplitInfo candidates, and the matching ``(sum_grad_l, sum_hess_l)``
        pairs in the same order.
    """
    LOGGER.info("splitter find split of host")
    tree_node_splitinfo = []
    encrypted_node_grad_hess = []
    for node_histogram in histograms:
        node_splitinfo = []
        node_grad_hess = []
        for fid in range(len(node_histogram)):
            # ``not`` (rather than ``is False``) also skips numpy.bool_ False values.
            if not valid_features[fid]:
                continue
            feature_hist = node_histogram[fid]
            bin_num = len(feature_hist)
            if bin_num == 0:
                continue

            node_cnt = feature_hist[bin_num - 1][2]
            if node_cnt < self.min_sample_split:
                break

            # The last entry covers the whole node, so splitting there always
            # leaves the right child empty (node_cnt_r == 0); skip it.
            for bid in range(bin_num - 1):
                sum_grad_l = feature_hist[bid][0]
                sum_hess_l = feature_hist[bid][1]
                node_cnt_l = feature_hist[bid][2]
                node_cnt_r = node_cnt - node_cnt_l

                if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                    splitinfo = SplitInfo(sitename=consts.HOST, best_fid=fid, best_bid=bid,
                                          sum_grad=sum_grad_l, sum_hess=sum_hess_l)
                    node_splitinfo.append(splitinfo)
                    node_grad_hess.append((sum_grad_l, sum_hess_l))

        tree_node_splitinfo.append(node_splitinfo)
        encrypted_node_grad_hess.append(node_grad_hess)

    return tree_node_splitinfo, encrypted_node_grad_hess
def find_split_single_histogram_guest(self, histogram, valid_features, sitename, use_missing, zero_as_missing):
    """Find the best guest-side split over one node's per-feature histograms.

    The code reads ``histogram[fid][bid]`` as cumulative
    ``[sum_grad, sum_hess, count, ...]`` statistics up to bin ``bid``: the
    last entry supplies the node totals, and the right child is computed as
    totals minus the left prefix. With ``use_missing`` the last entry also
    includes the missing values, so ``histogram[fid][-1] - histogram[fid][-2]``
    gives the missing-value statistics.

    Args:
        histogram: per-feature cumulative histograms of the node.
        valid_features: indexable by feature id; falsy entries are skipped.
        sitename: site name stored on the returned SplitInfo.
        use_missing: if true, histograms carry a trailing missing-value entry,
            and for every split point sending missing values to the left child
            is evaluated as well.
        zero_as_missing: unused in this body.

    Returns:
        SplitInfo for the best split found. ``best_fid``, ``best_bid`` and the
        sums stay None when no candidate's gain exceeds ``min_impurity_split``.
        ``missing_dir`` is 1 when missing values go right (the default) and
        -1 when they go left.
    """
    best_fid = None
    best_gain = self.min_impurity_split - consts.FLOAT_ZERO
    best_bid = None
    best_sum_grad_l = None
    best_sum_hess_l = None
    # By default, missing values go to the right child.
    missing_dir = 1

    missing_bin = 1 if use_missing else 0

    for fid in range(len(histogram)):
        # ``not`` (rather than ``is False``) also skips numpy.bool_ False values.
        if not valid_features[fid]:
            continue
        feature_hist = histogram[fid]
        bin_num = len(feature_hist)
        # ``<=`` rather than ``==``: an empty histogram with use_missing must
        # also be skipped, otherwise feature_hist[bin_num - 1] raises IndexError.
        if bin_num <= missing_bin:
            continue
        sum_grad = feature_hist[bin_num - 1][0]
        sum_hess = feature_hist[bin_num - 1][1]
        node_cnt = feature_hist[bin_num - 1][2]

        if node_cnt < self.min_sample_split:
            break

        if use_missing:
            # Missing-bin statistics do not depend on the split point, so they
            # are computed once per feature instead of once per bin.
            missing_grad = feature_hist[-1][0] - feature_hist[-2][0]
            missing_hess = feature_hist[-1][1] - feature_hist[-2][1]
            missing_cnt = feature_hist[-1][2] - feature_hist[-2][2]

        # The last regular bin and the trailing missing entry (if any) are
        # never used as split points.
        for bid in range(bin_num - missing_bin - 1):
            sum_grad_l = feature_hist[bid][0]
            sum_hess_l = feature_hist[bid][1]
            node_cnt_l = feature_hist[bid][2]

            sum_grad_r = sum_grad - sum_grad_l
            sum_hess_r = sum_hess - sum_hess_l
            node_cnt_r = node_cnt - node_cnt_l

            # Candidate 1: missing values (if any) go to the right child.
            if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                gain = self.criterion.split_gain([sum_grad, sum_hess],
                                                 [sum_grad_l, sum_hess_l],
                                                 [sum_grad_r, sum_hess_r])
                if gain > self.min_impurity_split and gain > best_gain:
                    best_gain = gain
                    best_fid = fid
                    best_bid = bid
                    best_sum_grad_l = sum_grad_l
                    best_sum_hess_l = sum_hess_l
                    missing_dir = 1

            if not use_missing:
                continue

            # Candidate 2: missing values go to the left child, i.e. the
            # missing-bin statistics move from the right side to the left side.
            # Rebinding (not ``+=``/``-=``) so values possibly referenced by
            # best_sum_* are never modified by an in-place operator.
            sum_grad_l = sum_grad_l + missing_grad
            sum_hess_l = sum_hess_l + missing_hess
            node_cnt_l = node_cnt_l + missing_cnt
            sum_grad_r = sum_grad_r - missing_grad
            sum_hess_r = sum_hess_r - missing_hess
            node_cnt_r = node_cnt_r - missing_cnt

            if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                gain = self.criterion.split_gain([sum_grad, sum_hess],
                                                 [sum_grad_l, sum_hess_l],
                                                 [sum_grad_r, sum_hess_r])
                if gain > self.min_impurity_split and gain > best_gain:
                    best_gain = gain
                    best_fid = fid
                    best_bid = bid
                    best_sum_grad_l = sum_grad_l
                    best_sum_hess_l = sum_hess_l
                    missing_dir = -1

    splitinfo = SplitInfo(sitename=sitename, best_fid=best_fid, best_bid=best_bid,
                          gain=best_gain, sum_grad=best_sum_grad_l, sum_hess=best_sum_hess_l,
                          missing_dir=missing_dir)
    return splitinfo