def __init__(self, tree_param):
    """Set up guest-side hetero decision tree state from *tree_param*."""
    LOGGER.info("hetero decision tree guest init!")
    super(HeteroDecisionTreeGuest, self).__init__(tree_param)

    # Split-finding helper configured from the tree hyper-parameters.
    self.splitter = Splitter(
        self.criterion_method,
        self.criterion_params,
        self.min_impurity_split,
        self.min_sample_split,
        self.min_leaf_node,
    )

    # Binned input data and gradient/hessian tables; supplied later.
    self.data_bin = self.grad_and_hess = None
    self.bin_split_points = self.bin_sparse_points = None
    self.data_bin_with_node_dispatch = self.node_dispatch = None
    self.infos = self.valid_features = None

    # Encryption helpers; set via dedicated setters.
    self.encrypter = self.encrypted_mode_calculator = None

    # Per-depth split bookkeeping and the growing tree itself.
    self.best_splitinfo_guest = None
    self.tree_node_queue = self.cur_split_nodes = None
    self.tree_ = []
    self.tree_node_num = 0
    self.split_maskdict = {}
    self.missing_dir_maskdict = {}

    # Federation channel, party bookkeeping and outputs.
    self.transfer_inst = HeteroDecisionTreeTransferVariable()
    self.predict_weights = None
    self.host_party_idlist = []
    self.runtime_idx = 0
    self.sitename = consts.GUEST
    self.feature_importances_ = {}
def __init__(self, tree_param):
    """Set up host-side hetero decision tree state from *tree_param*."""
    # Fixed: message previously said "guest init!" in this host initializer.
    LOGGER.info("hetero decision tree host init!")
    super(HeteroDecisionTreeHost, self).__init__(tree_param)

    # Split-finding helper configured from the tree hyper-parameters.
    self.splitter = Splitter(
        self.criterion_method,
        self.criterion_params,
        self.min_impurity_split,
        self.min_sample_split,
        self.min_leaf_node,
    )

    # Binned input data; supplied later.
    self.data_bin = None
    self.data_bin_with_position = None
    self.grad_and_hess = None
    self.bin_split_points = None
    self.bin_sparse_points = None
    self.infos = None
    self.valid_features = None

    # Key material. NOTE(review): "privakey" is a typo for "privatekey";
    # the attribute name is kept as-is because external code may read it.
    self.pubkey = None
    self.privakey = None

    self.tree_id = None
    self.encrypted_grad_and_hess = None
    self.transfer_inst = HeteroDecisionTreeTransferVariable()

    # Per-depth split bookkeeping and the learned tree.
    self.tree_node_queue = None
    self.cur_split_nodes = None
    self.split_maskdict = {}
    self.missing_dir_maskdict = {}
    self.tree_ = None

    self.runtime_idx = 0
    self.sitename = consts.HOST
def __init__(self, tree_param):
    """Guest-side hetero tree init: encryption, GOSS and cipher-compressing state."""
    super(HeteroDecisionTreeGuest, self).__init__(tree_param)

    self.encrypter = self.encrypted_mode_calculator = None
    self.transfer_inst = HeteroDecisionTreeTransferVariable()
    self.sitename = consts.GUEST  # will be modified in self.set_runtime_idx()
    self.complete_secure_tree = False
    self.split_maskdict = {}
    self.missing_dir_maskdict = {}
    self.host_party_idlist = []
    self.compressor = None

    # goss subsample
    self.run_goss = False
    # goss sampling rate
    self.top_rate = 0.2
    self.other_rate = 0.1

    # cipher compressing
    self.cipher_encoder = None
    self.cipher_decompressor = None
    self.run_cipher_compressing = False
    self.key_length = None
    self.round_decimal = 7
    self.max_sample_weight = 1

    # code version control
    self.new_ver = True
def __init__(self, tree_param):
    """Guest-side hetero tree init (FATE-1.8 variant): split-based importance,
    cipher compressing on by default, optional multi-output tree flags."""
    super(HeteroDecisionTreeGuest, self).__init__(tree_param)

    # In FATE-1.8 reset feature importance to 'split'
    self.feature_importance_type = 'split'

    self.encrypter = None
    self.transfer_inst = HeteroDecisionTreeTransferVariable()
    self.sitename = consts.GUEST  # will be modified in self.set_runtime_idx()
    self.complete_secure_tree = False
    self.split_maskdict = {}          # save split value
    self.missing_dir_maskdict = {}    # save missing dir
    self.host_party_idlist = []
    self.compressor = None

    # goss subsample
    self.run_goss = False

    # cipher compressing
    self.task_type = None
    self.run_cipher_compressing = True
    self.packer = None
    self.max_sample_weight = 1

    # code version control
    self.new_ver = True

    # mo tree
    self.mo_tree = False
    self.class_num = 1
def __init__(self, tree_param):
    """Host-side hetero tree init: feature shuffling / missing-dir masking state."""
    super(HeteroDecisionTreeHost, self).__init__(tree_param)

    self.encrypted_grad_and_hess = None
    self.runtime_idx = 0
    self.sitename = consts.HOST  # will be modified in self.set_runtime_idx()
    self.complete_secure_tree = False
    self.host_party_idlist = []

    # feature shuffling / missing_dir masking
    self.feature_num = -1
    self.missing_dir_mask_left = {}   # mask for left direction
    self.missing_dir_mask_right = {}  # mask for right direction
    self.split_maskdict = {}          # mask for split value
    self.missing_dir_maskdict = {}
    self.fid_bid_random_mapping = {}
    self.inverse_fid_bid_random_mapping = {}
    self.bin_num = None

    # goss subsample
    self.run_goss = False

    # transfer variable
    self.transfer_inst = HeteroDecisionTreeTransferVariable()

    # cipher compressing
    self.cipher_compressor = None
    self.run_cipher_compressing = True

    # code version control
    self.new_ver = True
def __init__(self, tree_param):
    """Minimal guest-side hetero tree init: encryption and mask-dict state only."""
    super(HeteroDecisionTreeGuest, self).__init__(tree_param)
    self.encrypter = self.encrypted_mode_calculator = None
    self.transfer_inst = HeteroDecisionTreeTransferVariable()
    self.sitename = consts.GUEST  # will be modified in self.set_runtime_idx()
    self.complete_secure_tree = False
    self.split_maskdict = {}
    self.missing_dir_maskdict = {}
    self.host_party_idlist = []
def __init__(self, tree_param):
    """Host-side hetero tree init with fast-histogram (sparse-opt) state."""
    super(HeteroDecisionTreeHost, self).__init__(tree_param)

    self.encrypted_grad_and_hess = None
    self.split_maskdict = {}
    self.missing_dir_maskdict = {}
    self.runtime_idx = 0
    self.sitename = consts.HOST  # will be modified in self.set_runtime_idx()
    self.complete_secure_tree = False
    self.host_party_idlist = []

    # For fast histogram
    self.run_sparse_opt = False
    self.bin_num = None
    self.data_bin_dense = None
    self.data_bin_dense_with_position = None

    self.transfer_inst = HeteroDecisionTreeTransferVariable()
class HeteroDecisionTreeHost(DecisionTree):
    """Host side of a hetero (vertically partitioned) federated decision tree.

    The host receives encrypted gradients/hessians from the guest, computes
    local feature histograms and candidate splits, and exchanges masked split
    information and node-dispatch results with the guest through
    ``HeteroDecisionTreeTransferVariable`` channels. Split feature values and
    missing directions are never sent in the clear: they are stored locally in
    ``split_maskdict`` / ``missing_dir_maskdict`` keyed by node id.

    Fixes applied in review:
    - ``__init__`` / ``fit`` log messages previously said "guest".
    - ``decode`` returned a ``TypeError`` instance instead of raising it.
    - ``get_histograms`` used a mutable default argument.
    - removed dead commented-out federation code.
    """

    def __init__(self, tree_param):
        """Initialize host-side tree state from *tree_param*."""
        LOGGER.info("hetero decision tree host init!")
        super(HeteroDecisionTreeHost, self).__init__(tree_param)

        # Split-finding helper configured from the tree hyper-parameters.
        self.splitter = Splitter(self.criterion_method, self.criterion_params,
                                 self.min_impurity_split, self.min_sample_split,
                                 self.min_leaf_node)

        # Binned input data; supplied later via set_inputinfo().
        self.data_bin = None
        self.data_bin_with_position = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.infos = None
        self.valid_features = None

        # Key material. NOTE(review): "privakey" is a typo for "privatekey";
        # kept as-is because external code may reference the attribute.
        self.pubkey = None
        self.privakey = None

        self.tree_id = None
        self.encrypted_grad_and_hess = None
        self.transfer_inst = HeteroDecisionTreeTransferVariable()

        # Per-depth split bookkeeping and the learned tree.
        self.tree_node_queue = None
        self.cur_split_nodes = None
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}
        self.tree_ = None

        self.runtime_idx = 0
        self.sitename = consts.HOST

    def set_flowid(self, flowid=0):
        """Propagate *flowid* to the transfer variables (tags federation traffic)."""
        LOGGER.info("set flowid, flowid is {}".format(flowid))
        self.transfer_inst.set_flowid(flowid)

    def set_inputinfo(self, data_bin=None, grad_and_hess=None,
                      bin_split_points=None, bin_sparse_points=None):
        """Attach binned data, (encrypted) grad/hess and bin metadata."""
        LOGGER.info("set input info")
        self.data_bin = data_bin
        self.grad_and_hess = grad_and_hess
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_points

    def set_valid_features(self, valid_features=None):
        """Set the feature-validity mask used during histogram computation."""
        LOGGER.info("set valid features")
        self.valid_features = valid_features

    def encode(self, etype="feature_idx", val=None, nid=None):
        """Mask a value before sending it to the guest.

        Feature indices pass through unchanged; split values and missing
        directions are stored locally keyed by node id and replaced by None.
        Raises TypeError for an unknown *etype*.
        """
        if etype == "feature_idx":
            return val
        if etype == "feature_val":
            self.split_maskdict[nid] = val
            return None
        if etype == "missing_dir":
            self.missing_dir_maskdict[nid] = val
            return None
        raise TypeError("encode type %s is not support!" % (str(etype)))

    @staticmethod
    def decode(dtype="feature_idx", val=None, nid=None, split_maskdict=None,
               missing_dir_maskdict=None):
        """Inverse of encode(): recover a masked value from the local dicts.

        Raises ValueError when *nid* is unknown and TypeError for an unknown
        *dtype*.
        """
        if dtype == "feature_idx":
            return val
        if dtype == "feature_val":
            if nid in split_maskdict:
                return split_maskdict[nid]
            else:
                raise ValueError(
                    "decode val %s cause error, can't reconize it!" % (str(val)))
        if dtype == "missing_dir":
            if nid in missing_dir_maskdict:
                return missing_dir_maskdict[nid]
            else:
                raise ValueError(
                    "decode val %s cause error, can't reconize it!" % (str(val)))
        # Fixed: was `return TypeError(...)`, which silently handed the
        # exception object back to the caller instead of raising it.
        raise TypeError("decode type %s is not support!" % (str(dtype)))

    def sync_encrypted_grad_and_hess(self):
        """Receive the guest's encrypted gradient/hessian table."""
        LOGGER.info("get encrypted grad and hess")
        self.grad_and_hess = self.transfer_inst.encrypted_grad_and_hess.get(
            idx=0)

    def sync_node_positions(self, dep=-1):
        """Receive the guest's node-position table for depth *dep*."""
        # NOTE(review): message looks copy-pasted from sync_tree_node_queue;
        # kept byte-identical to avoid breaking log consumers.
        LOGGER.info("get tree node queue of depth {}".format(dep))
        node_positions = self.transfer_inst.node_positions.get(idx=0,
                                                               suffix=(dep, ))
        return node_positions

    def sync_tree_node_queue(self, dep=-1):
        """Receive the guest's (masked) tree node queue for depth *dep*."""
        LOGGER.info("get tree node queue of depth {}".format(dep))
        self.tree_node_queue = self.transfer_inst.tree_node_queue.get(
            idx=0, suffix=(dep, ))

    def get_histograms(self, node_map=None):
        """Compute and accumulate per-node feature histograms.

        Fixed: *node_map* previously defaulted to a mutable ``{}``.
        """
        if node_map is None:
            node_map = {}
        LOGGER.info("start to get node histograms")
        histograms = FeatureHistogram.calculate_histogram(
            self.data_bin_with_position, self.grad_and_hess,
            self.bin_split_points, self.bin_sparse_points,
            self.valid_features, node_map,
            self.use_missing, self.zero_as_missing)
        LOGGER.info("begin to accumulate histograms")
        acc_histograms = FeatureHistogram.accumulate_histogram(histograms)
        LOGGER.info("acc histogram shape is {}".format(len(acc_histograms)))
        return acc_histograms

    def sync_encrypted_splitinfo_host(self, encrypted_splitinfo_host, dep=-1,
                                      batch=-1):
        """Send encrypted candidate split statistics to the guest."""
        LOGGER.info("send encrypted splitinfo of depth {}, batch {}".format(
            dep, batch))
        self.transfer_inst.encrypted_splitinfo_host.remote(
            encrypted_splitinfo_host,
            role=consts.GUEST,
            idx=-1,
            suffix=(dep, batch, ))

    def sync_federated_best_splitinfo_host(self, dep=-1, batch=-1):
        """Receive the guest's choice of best host split per node."""
        LOGGER.info(
            "get federated best splitinfo of depth {}, batch {}".format(
                dep, batch))
        federated_best_splitinfo_host = \
            self.transfer_inst.federated_best_splitinfo_host.get(
                idx=0, suffix=(dep, batch, ))
        return federated_best_splitinfo_host

    def sync_final_splitinfo_host(self, splitinfo_host,
                                  federated_best_splitinfo_host,
                                  dep=-1, batch=-1):
        """Mask the guest-selected splits and send the final split infos.

        For each node the guest returns (best_idx, best_gain); best_idx == -1
        means no valid host split, in which case a placeholder SplitInfo is
        sent.
        """
        LOGGER.info("send host final splitinfo of depth {}, batch {}".format(
            dep, batch))
        final_splitinfos = []
        for i in range(len(splitinfo_host)):
            best_idx, best_gain = federated_best_splitinfo_host[i]
            if best_idx != -1:
                assert splitinfo_host[i][best_idx].sitename == self.sitename
                splitinfo = splitinfo_host[i][best_idx]
                splitinfo.best_fid = self.encode("feature_idx",
                                                 splitinfo.best_fid)
                assert splitinfo.best_fid is not None
                splitinfo.best_bid = self.encode("feature_val",
                                                 splitinfo.best_bid,
                                                 self.cur_split_nodes[i].id)
                splitinfo.missing_dir = self.encode("missing_dir",
                                                    splitinfo.missing_dir,
                                                    self.cur_split_nodes[i].id)
                splitinfo.gain = best_gain
            else:
                splitinfo = SplitInfo(sitename=self.sitename,
                                      best_fid=-1,
                                      best_bid=-1,
                                      gain=best_gain)
            final_splitinfos.append(splitinfo)
        self.transfer_inst.final_splitinfo_host.remote(final_splitinfos,
                                                       role=consts.GUEST,
                                                       idx=-1,
                                                       suffix=(dep, batch, ))

    def sync_dispatch_node_host(self, dep):
        """Receive the samples the guest could not dispatch (host-owned nodes)."""
        LOGGER.info("get node from host to dispath, depth is {}".format(dep))
        dispatch_node_host = self.transfer_inst.dispatch_node_host.get(
            idx=0, suffix=(dep, ))
        return dispatch_node_host

    @staticmethod
    def dispatch_node(value1, value2, sitename=None, decoder=None,
                      split_maskdict=None, bin_sparse_points=None,
                      use_missing=False, zero_as_missing=False,
                      missing_dir_maskdict=None):
        """Route one sample left/right at a host-owned node.

        *value1* is the 7-tuple sent by the guest; *value2* is the sample's
        binned feature instance. Returns *value1* unchanged if the node does
        not belong to this site, otherwise (unleaf_state, next_node_id).
        """
        unleaf_state, fid, bid, node_sitename, nodeid, left_nodeid, \
            right_nodeid = value1
        if node_sitename != sitename:
            return value1
        fid = decoder("feature_idx", fid, split_maskdict=split_maskdict)
        bid = decoder("feature_val", bid, nodeid,
                      split_maskdict=split_maskdict)
        if not use_missing:
            if value2.features.get_data(fid, bin_sparse_points[fid]) <= bid:
                return unleaf_state, left_nodeid
            else:
                return unleaf_state, right_nodeid
        else:
            missing_dir = decoder("missing_dir", 1, nodeid,
                                  missing_dir_maskdict=missing_dir_maskdict)
            missing_val = False
            if zero_as_missing:
                if value2.features.get_data(fid, None) is None or \
                        value2.features.get_data(fid) == NoneType():
                    missing_val = True
            elif use_missing and value2.features.get_data(fid) == NoneType():
                missing_val = True
            if missing_val:
                # missing_dir == 1 sends missing values to the right child.
                if missing_dir == 1:
                    return unleaf_state, right_nodeid
                else:
                    return unleaf_state, left_nodeid
            else:
                if value2.features.get_data(fid,
                                            bin_sparse_points[fid]) <= bid:
                    return unleaf_state, left_nodeid
                else:
                    return unleaf_state, right_nodeid

    def sync_dispatch_node_host_result(self, dispatch_node_host_result,
                                       dep=-1):
        """Send the host's dispatch decisions back to the guest."""
        LOGGER.info("send host dispatch result, depth is {}".format(dep))
        self.transfer_inst.dispatch_node_host_result.remote(
            dispatch_node_host_result,
            role=consts.GUEST,
            idx=-1,
            suffix=(dep, ))

    def find_dispatch(self, dispatch_node_host, dep=-1):
        """Dispatch guest-forwarded samples through host-owned nodes and reply."""
        LOGGER.info("start to find host dispath of depth {}".format(dep))
        dispatch_node_method = functools.partial(
            self.dispatch_node,
            sitename=self.sitename,
            decoder=self.decode,
            split_maskdict=self.split_maskdict,
            bin_sparse_points=self.bin_sparse_points,
            use_missing=self.use_missing,
            zero_as_missing=self.zero_as_missing,
            missing_dir_maskdict=self.missing_dir_maskdict)
        dispatch_node_host_result = dispatch_node_host.join(
            self.data_bin, dispatch_node_method)
        self.sync_dispatch_node_host_result(dispatch_node_host_result, dep)

    def sync_tree(self):
        """Receive the final (masked) tree structure from the guest."""
        LOGGER.info("sync tree from guest")
        self.tree_ = self.transfer_inst.tree.get(idx=0)

    def remove_duplicated_split_nodes(self, split_nid_used):
        """Drop split-mask entries for node ids that did not end up in the tree."""
        LOGGER.info("remove duplicated nodes from split mask dict")
        duplicated_nodes = set(
            self.split_maskdict.keys()) - set(split_nid_used)
        for nid in duplicated_nodes:
            del self.split_maskdict[nid]

    def convert_bin_to_real(self):
        """Replace bin indices with real split values for host-owned nodes."""
        LOGGER.info("convert tree node bins to real value")
        split_nid_used = []
        for i in range(len(self.tree_)):
            if self.tree_[i].is_leaf is True:
                continue
            if self.tree_[i].sitename == self.sitename:
                fid = self.decode("feature_idx", self.tree_[i].fid,
                                  split_maskdict=self.split_maskdict)
                bid = self.decode("feature_val", self.tree_[i].bid,
                                  self.tree_[i].id, self.split_maskdict)
                LOGGER.debug("shape of bin_split_points is {}".format(
                    len(self.bin_split_points[fid])))
                real_splitval = self.encode("feature_val",
                                            self.bin_split_points[fid][bid],
                                            self.tree_[i].id)
                self.tree_[i].bid = real_splitval
                split_nid_used.append(self.tree_[i].id)
        self.remove_duplicated_split_nodes(split_nid_used)

    @staticmethod
    def traverse_tree(predict_state, data_inst, tree_=None, decoder=None,
                      split_maskdict=None, sitename=consts.HOST,
                      use_missing=False, zero_as_missing=False,
                      missing_dir_maskdict=None):
        """Walk a sample down consecutive host-owned nodes during prediction.

        Returns the incoming *predict_state* untouched if the current node is
        not host-owned; otherwise descends until a non-host node is reached
        and returns (node_id, 0).
        """
        nid, _ = predict_state
        if tree_[nid].sitename != sitename:
            return predict_state
        while tree_[nid].sitename == sitename:
            fid = decoder("feature_idx", tree_[nid].fid,
                          split_maskdict=split_maskdict)
            bid = decoder("feature_val", tree_[nid].bid, nid, split_maskdict)
            if use_missing:
                missing_dir = decoder(
                    "missing_dir", 1, nid,
                    missing_dir_maskdict=missing_dir_maskdict)
            else:
                missing_dir = 1
            if use_missing and zero_as_missing:
                missing_dir = decoder(
                    "missing_dir", 1, nid,
                    missing_dir_maskdict=missing_dir_maskdict)
                if data_inst.features.get_data(fid) == NoneType(
                ) or data_inst.features.get_data(fid, None) is None:
                    if missing_dir == 1:
                        nid = tree_[nid].right_nodeid
                    else:
                        nid = tree_[nid].left_nodeid
                elif data_inst.features.get_data(fid) <= bid:
                    nid = tree_[nid].left_nodeid
                else:
                    nid = tree_[nid].right_nodeid
            elif data_inst.features.get_data(fid) == NoneType():
                if missing_dir == 1:
                    nid = tree_[nid].right_nodeid
                else:
                    nid = tree_[nid].left_nodeid
            elif data_inst.features.get_data(fid, 0) <= bid:
                nid = tree_[nid].left_nodeid
            else:
                nid = tree_[nid].right_nodeid
        return nid, 0

    def sync_predict_finish_tag(self, recv_times):
        """Receive the guest's flag telling whether prediction has finished."""
        LOGGER.info(
            "get the {}-th predict finish tag from guest".format(recv_times))
        finish_tag = self.transfer_inst.predict_finish_tag.get(
            idx=0, suffix=(recv_times, ))
        return finish_tag

    def sync_predict_data(self, recv_times):
        """Receive the guest's pending predict-state table."""
        # Fixed: message previously read "srecv".
        LOGGER.info(
            "recv predict data to host, recv times is {}".format(recv_times))
        predict_data = self.transfer_inst.predict_data.get(
            idx=0, suffix=(recv_times, ))
        return predict_data

    def sync_data_predicted_by_host(self, predict_data, send_times):
        """Send the host-updated predict states back to the guest."""
        LOGGER.info(
            "send predicted data by host, send times is {}".format(send_times))
        self.transfer_inst.predict_data_by_host.remote(predict_data,
                                                       role=consts.GUEST,
                                                       idx=0,
                                                       suffix=(send_times, ))

    def fit(self):
        """Grow the tree with the guest: per depth, compute histograms,
        exchange split infos in batches, then dispatch samples."""
        LOGGER.info("begin to fit host decision tree")
        self.sync_encrypted_grad_and_hess()
        for dep in range(self.max_depth):
            self.sync_tree_node_queue(dep)
            if len(self.tree_node_queue) == 0:
                break
            node_positions = self.sync_node_positions(dep)
            self.data_bin_with_position = self.data_bin.join(
                node_positions, lambda v1, v2: (v1, v2))
            batch = 0
            for i in range(0, len(self.tree_node_queue),
                           self.max_split_nodes):
                self.cur_split_nodes = \
                    self.tree_node_queue[i:i + self.max_split_nodes]
                node_map = {}
                node_num = 0
                for tree_node in self.cur_split_nodes:
                    node_map[tree_node.id] = node_num
                    node_num += 1
                acc_histograms = self.get_histograms(node_map=node_map)
                splitinfo_host, encrypted_splitinfo_host = \
                    self.splitter.find_split_host(
                        acc_histograms, self.valid_features,
                        self.data_bin._partitions, self.sitename,
                        self.use_missing, self.zero_as_missing)
                self.sync_encrypted_splitinfo_host(encrypted_splitinfo_host,
                                                   dep, batch)
                federated_best_splitinfo_host = \
                    self.sync_federated_best_splitinfo_host(dep, batch)
                self.sync_final_splitinfo_host(splitinfo_host,
                                               federated_best_splitinfo_host,
                                               dep, batch)
                batch += 1
            dispatch_node_host = self.sync_dispatch_node_host(dep)
            self.find_dispatch(dispatch_node_host, dep)
        self.sync_tree()
        self.convert_bin_to_real()
        # Fixed: message previously said "guest".
        LOGGER.info("end to fit host decision tree")

    def predict(self, data_inst):
        """Interactive prediction loop: keep traversing host-owned subtrees
        until the guest signals completion."""
        LOGGER.info("start to predict!")
        site_guest_send_times = 0
        while True:
            finish_tag = self.sync_predict_finish_tag(site_guest_send_times)
            if finish_tag is True:
                break
            predict_data = self.sync_predict_data(site_guest_send_times)
            traverse_tree = functools.partial(
                self.traverse_tree,
                tree_=self.tree_,
                decoder=self.decode,
                split_maskdict=self.split_maskdict,
                sitename=self.sitename,
                use_missing=self.use_missing,
                zero_as_missing=self.zero_as_missing,
                missing_dir_maskdict=self.missing_dir_maskdict)
            predict_data = predict_data.join(data_inst, traverse_tree)
            self.sync_data_predicted_by_host(predict_data,
                                             site_guest_send_times)
            site_guest_send_times += 1
        LOGGER.info("predict finish!")

    def get_model_meta(self):
        """Export tree hyper-parameters into a DecisionTreeModelMeta."""
        model_meta = DecisionTreeModelMeta()
        model_meta.max_depth = self.max_depth
        model_meta.min_sample_split = self.min_sample_split
        model_meta.min_impurity_split = self.min_impurity_split
        model_meta.min_leaf_node = self.min_leaf_node
        model_meta.use_missing = self.use_missing
        model_meta.zero_as_missing = self.zero_as_missing
        return model_meta

    def set_model_meta(self, model_meta):
        """Restore tree hyper-parameters from a DecisionTreeModelMeta."""
        self.max_depth = model_meta.max_depth
        self.min_sample_split = model_meta.min_sample_split
        self.min_impurity_split = model_meta.min_impurity_split
        self.min_leaf_node = model_meta.min_leaf_node
        self.use_missing = model_meta.use_missing
        self.zero_as_missing = model_meta.zero_as_missing

    def get_model_param(self):
        """Export tree nodes and mask dicts into a DecisionTreeModelParam."""
        model_param = DecisionTreeModelParam()
        for node in self.tree_:
            model_param.tree_.add(id=node.id,
                                  sitename=node.sitename,
                                  fid=node.fid,
                                  bid=node.bid,
                                  weight=node.weight,
                                  is_leaf=node.is_leaf,
                                  left_nodeid=node.left_nodeid,
                                  right_nodeid=node.right_nodeid,
                                  missing_dir=node.missing_dir)
        model_param.split_maskdict.update(self.split_maskdict)
        model_param.missing_dir_maskdict.update(self.missing_dir_maskdict)
        return model_param

    def set_model_param(self, model_param):
        """Restore tree nodes and mask dicts from a DecisionTreeModelParam."""
        self.tree_ = []
        for node_param in model_param.tree_:
            _node = Node(id=node_param.id,
                         sitename=node_param.sitename,
                         fid=node_param.fid,
                         bid=node_param.bid,
                         weight=node_param.weight,
                         is_leaf=node_param.is_leaf,
                         left_nodeid=node_param.left_nodeid,
                         right_nodeid=node_param.right_nodeid,
                         missing_dir=node_param.missing_dir)
            self.tree_.append(_node)
        self.split_maskdict = dict(model_param.split_maskdict)
        self.missing_dir_maskdict = dict(model_param.missing_dir_maskdict)

    def get_model(self):
        """Return (meta, param) protobuf pair describing this tree."""
        model_meta = self.get_model_meta()
        model_param = self.get_model_param()
        return model_meta, model_param

    def load_model(self, model_meta=None, model_param=None):
        """Rebuild the tree from previously exported meta/param objects."""
        LOGGER.info("load tree model")
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)
class HeteroDecisionTreeGuest(DecisionTree): def __init__(self, tree_param): LOGGER.info("hetero decision tree guest init!") super(HeteroDecisionTreeGuest, self).__init__(tree_param) self.splitter = Splitter(self.criterion_method, self.criterion_params, self.min_impurity_split, self.min_sample_split, self.min_leaf_node) self.data_bin = None self.grad_and_hess = None self.bin_split_points = None self.bin_sparse_points = None self.data_bin_with_node_dispatch = None self.node_dispatch = None self.infos = None self.valid_features = None self.encrypter = None self.encrypted_mode_calculator = None self.best_splitinfo_guest = None self.tree_node_queue = None self.cur_split_nodes = None self.tree_ = [] self.tree_node_num = 0 self.split_maskdict = {} self.missing_dir_maskdict = {} self.transfer_inst = HeteroDecisionTreeTransferVariable() self.predict_weights = None self.host_party_idlist = [] self.runtime_idx = 0 self.sitename = consts.GUEST self.feature_importances_ = {} def set_flowid(self, flowid=0): LOGGER.info("set flowid, flowid is {}".format(flowid)) self.transfer_inst.set_flowid(flowid) def set_host_party_idlist(self, host_party_idlist): self.host_party_idlist = host_party_idlist def set_inputinfo(self, data_bin=None, grad_and_hess=None, bin_split_points=None, bin_sparse_points=None): LOGGER.info("set input info") self.data_bin = data_bin self.grad_and_hess = grad_and_hess self.bin_split_points = bin_split_points self.bin_sparse_points = bin_sparse_points def set_encrypter(self, encrypter): LOGGER.info("set encrypter") self.encrypter = encrypter def set_encrypted_mode_calculator(self, encrypted_mode_calculator): self.encrypted_mode_calculator = encrypted_mode_calculator def encrypt(self, val): return self.encrypter.encrypt(val) def decrypt(self, val): return self.encrypter.decrypt(val) def encode(self, etype="feature_idx", val=None, nid=None): if etype == "feature_idx": return val if etype == "feature_val": self.split_maskdict[nid] = val return None if etype == 
"missing_dir": self.missing_dir_maskdict[nid] = val return None raise TypeError("encode type %s is not support!" % (str(etype))) @staticmethod def decode(dtype="feature_idx", val=None, nid=None, split_maskdict=None, missing_dir_maskdict=None): if dtype == "feature_idx": return val if dtype == "feature_val": if nid in split_maskdict: return split_maskdict[nid] else: raise ValueError( "decode val %s cause error, can't reconize it!" % (str(val))) if dtype == "missing_dir": if nid in missing_dir_maskdict: return missing_dir_maskdict[nid] else: raise ValueError( "decode val %s cause error, can't reconize it!" % (str(val))) return TypeError("decode type %s is not support!" % (str(dtype))) def set_valid_features(self, valid_features=None): LOGGER.info("set valid features") self.valid_features = valid_features def sync_encrypted_grad_and_hess(self): LOGGER.info("send encrypted grad and hess to host") encrypted_grad_and_hess = self.encrypt_grad_and_hess() # LOGGER.debug("encrypted_grad_and_hess is {}".format(list(encrypted_grad_and_hess.collect()))) self.transfer_inst.encrypted_grad_and_hess.remote( encrypted_grad_and_hess, role=consts.HOST, idx=-1) """ federation.remote(obj=encrypted_grad_and_hess, name=self.transfer_inst.encrypted_grad_and_hess.name, tag=self.transfer_inst.generate_transferid(self.transfer_inst.encrypted_grad_and_hess), role=consts.HOST, idx=-1) """ def encrypt_grad_and_hess(self): LOGGER.info("start to encrypt grad and hess") encrypted_grad_and_hess = self.encrypted_mode_calculator.encrypt( self.grad_and_hess) return encrypted_grad_and_hess def get_grad_hess_sum(self, grad_and_hess_table): LOGGER.info("calculate the sum of grad and hess") grad, hess = grad_and_hess_table.reduce(lambda value1, value2: (value1[ 0] + value2[0], value1[1] + value2[1])) return grad, hess def dispatch_all_node_to_root(self, root_id=0): LOGGER.info("dispatch all node to root") self.node_dispatch = self.data_bin.mapValues(lambda data_inst: (1, root_id)) def get_histograms(self, 
node_map={}): LOGGER.info("start to get node histograms") histograms = FeatureHistogram.calculate_histogram( self.data_bin_with_node_dispatch, self.grad_and_hess, self.bin_split_points, self.bin_sparse_points, self.valid_features, node_map, self.use_missing, self.zero_as_missing) acc_histograms = FeatureHistogram.accumulate_histogram(histograms) return acc_histograms def sync_tree_node_queue(self, tree_node_queue, dep=-1): LOGGER.info("send tree node queue of depth {}".format(dep)) mask_tree_node_queue = copy.deepcopy(tree_node_queue) for i in range(len(mask_tree_node_queue)): mask_tree_node_queue[i] = Node(id=mask_tree_node_queue[i].id) self.transfer_inst.tree_node_queue.remote(mask_tree_node_queue, role=consts.HOST, idx=-1, suffix=(dep, )) """ federation.remote(obj=mask_tree_node_queue, name=self.transfer_inst.tree_node_queue.name, tag=self.transfer_inst.generate_transferid(self.transfer_inst.tree_node_queue, dep), role=consts.HOST, idx=-1) """ def sync_node_positions(self, dep): LOGGER.info("send node positions of depth {}".format(dep)) self.transfer_inst.node_positions.remote(self.node_dispatch, role=consts.HOST, idx=-1, suffix=(dep, )) """ federation.remote(obj=self.node_dispatch, name=self.transfer_inst.node_positions.name, tag=self.transfer_inst.generate_transferid(self.transfer_inst.node_positions, dep), role=consts.HOST, idx=-1) """ def sync_encrypted_splitinfo_host(self, dep=-1, batch=-1): LOGGER.info("get encrypted splitinfo of depth {}, batch {}".format( dep, batch)) encrypted_splitinfo_host = self.transfer_inst.encrypted_splitinfo_host.get( idx=-1, suffix=( dep, batch, )) """ encrypted_splitinfo_host = federation.get(name=self.transfer_inst.encrypted_splitinfo_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.encrypted_splitinfo_host, dep, batch), idx=-1) """ return encrypted_splitinfo_host def sync_federated_best_splitinfo_host(self, federated_best_splitinfo_host, dep=-1, batch=-1, idx=-1): LOGGER.info( "send federated best 
splitinfo of depth {}, batch {}".format( dep, batch)) self.transfer_inst.federated_best_splitinfo_host.remote( federated_best_splitinfo_host, role=consts.HOST, idx=idx, suffix=( dep, batch, )) """ federation.remote(obj=federated_best_splitinfo_host, name=self.transfer_inst.federated_best_splitinfo_host.name, tag=self.transfer_inst.generate_transferid(self.transfer_inst.federated_best_splitinfo_host, dep, batch), role=consts.HOST, idx=idx) """ def find_host_split(self, value): cur_split_node, encrypted_splitinfo_host = value sum_grad = cur_split_node.sum_grad sum_hess = cur_split_node.sum_hess best_gain = self.min_impurity_split - consts.FLOAT_ZERO best_idx = -1 for i in range(len(encrypted_splitinfo_host)): sum_grad_l, sum_hess_l = encrypted_splitinfo_host[i] sum_grad_l = self.decrypt(sum_grad_l) sum_hess_l = self.decrypt(sum_hess_l) sum_grad_r = sum_grad - sum_grad_l sum_hess_r = sum_hess - sum_hess_l gain = self.splitter.split_gain(sum_grad, sum_hess, sum_grad_l, sum_hess_l, sum_grad_r, sum_hess_r) if gain > self.min_impurity_split and gain > best_gain: best_gain = gain best_idx = i best_gain = self.encrypt(best_gain) return best_idx, best_gain def federated_find_split(self, dep=-1, batch=-1): LOGGER.info("federated find split of depth {}, batch {}".format( dep, batch)) encrypted_splitinfo_host = self.sync_encrypted_splitinfo_host( dep, batch) for i in range(len(encrypted_splitinfo_host)): encrypted_splitinfo_host_table = session.parallelize( zip(self.cur_split_nodes, encrypted_splitinfo_host[i]), include_key=False, partition=self.data_bin._partitions) splitinfos = encrypted_splitinfo_host_table.mapValues( self.find_host_split).collect() best_splitinfo_host = [splitinfo[1] for splitinfo in splitinfos] self.sync_federated_best_splitinfo_host(best_splitinfo_host, dep, batch, i) def sync_final_split_host(self, dep=-1, batch=-1): LOGGER.info("get host final splitinfo of depth {}, batch {}".format( dep, batch)) final_splitinfo_host = 
self.transfer_inst.final_splitinfo_host.get( idx=-1, suffix=( dep, batch, )) """ final_splitinfo_host = federation.get(name=self.transfer_inst.final_splitinfo_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.final_splitinfo_host, dep, batch), idx=-1) """ return final_splitinfo_host def find_best_split_guest_and_host(self, splitinfo_guest_host): best_gain_host = self.decrypt(splitinfo_guest_host[1].gain) best_gain_host_idx = 1 for i in range(1, len(splitinfo_guest_host)): gain_host_i = self.decrypt(splitinfo_guest_host[i].gain) if best_gain_host < gain_host_i: best_gain_host = gain_host_i best_gain_host_idx = i if splitinfo_guest_host[0].gain >= best_gain_host - consts.FLOAT_ZERO: best_splitinfo = splitinfo_guest_host[0] else: best_splitinfo = splitinfo_guest_host[best_gain_host_idx] best_splitinfo.sum_grad = self.decrypt(best_splitinfo.sum_grad) best_splitinfo.sum_hess = self.decrypt(best_splitinfo.sum_hess) best_splitinfo.gain = best_gain_host return best_splitinfo def merge_splitinfo(self, splitinfo_guest, splitinfo_host): LOGGER.info("merge splitinfo") merge_infos = [] for i in range(len(splitinfo_guest)): splitinfo = [splitinfo_guest[i]] for j in range(len(splitinfo_host)): splitinfo.append(splitinfo_host[j][i]) merge_infos.append(splitinfo) splitinfo_guest_host_table = session.parallelize( merge_infos, include_key=False, partition=self.data_bin._partitions) best_splitinfo_table = splitinfo_guest_host_table.mapValues( self.find_best_split_guest_and_host) best_splitinfos = [ best_splitinfo[1] for best_splitinfo in best_splitinfo_table.collect() ] return best_splitinfos def update_feature_importance(self, splitinfo): if self.feature_importance_type == "split": inc = 1 elif self.feature_importance_type == "gain": inc = splitinfo.gain else: raise ValueError( "feature importance type {} not support yet".format( self.feature_importance_type)) sitename = splitinfo.sitename fid = splitinfo.best_fid if (sitename, fid) not in 
self.feature_importances_: self.feature_importances_[(sitename, fid)] = 0 self.feature_importances_[(sitename, fid)] += inc def update_tree_node_queue(self, splitinfos, max_depth_reach): LOGGER.info( "update tree node, splitlist length is {}, tree node queue size is" .format(len(splitinfos), len(self.tree_node_queue))) new_tree_node_queue = [] for i in range(len(self.tree_node_queue)): sum_grad = self.tree_node_queue[i].sum_grad sum_hess = self.tree_node_queue[i].sum_hess if max_depth_reach or splitinfos[i].gain <= \ self.min_impurity_split + consts.FLOAT_ZERO: self.tree_node_queue[i].is_leaf = True else: self.tree_node_queue[i].left_nodeid = self.tree_node_num + 1 self.tree_node_queue[i].right_nodeid = self.tree_node_num + 2 self.tree_node_num += 2 left_node = Node(id=self.tree_node_queue[i].left_nodeid, sitename=self.sitename, sum_grad=splitinfos[i].sum_grad, sum_hess=splitinfos[i].sum_hess, weight=self.splitter.node_weight( splitinfos[i].sum_grad, splitinfos[i].sum_hess)) right_node = Node(id=self.tree_node_queue[i].right_nodeid, sitename=self.sitename, sum_grad=sum_grad - splitinfos[i].sum_grad, sum_hess=sum_hess - splitinfos[i].sum_hess, weight=self.splitter.node_weight( \ sum_grad - splitinfos[i].sum_grad, sum_hess - splitinfos[i].sum_hess)) new_tree_node_queue.append(left_node) new_tree_node_queue.append(right_node) self.tree_node_queue[i].sitename = splitinfos[i].sitename if self.tree_node_queue[i].sitename == self.sitename: self.tree_node_queue[i].fid = self.encode( "feature_idx", splitinfos[i].best_fid) self.tree_node_queue[i].bid = self.encode( "feature_val", splitinfos[i].best_bid, self.tree_node_queue[i].id) self.tree_node_queue[i].missing_dir = self.encode( "missing_dir", splitinfos[i].missing_dir, self.tree_node_queue[i].id) else: self.tree_node_queue[i].fid = splitinfos[i].best_fid self.tree_node_queue[i].bid = splitinfos[i].best_bid self.update_feature_importance(splitinfos[i]) self.tree_.append(self.tree_node_queue[i]) self.tree_node_queue = 
new_tree_node_queue @staticmethod def dispatch_node(value, tree_=None, decoder=None, sitename=consts.GUEST, split_maskdict=None, bin_sparse_points=None, use_missing=False, zero_as_missing=False, missing_dir_maskdict=None): unleaf_state, nodeid = value[1] if tree_[nodeid].is_leaf is True: return tree_[nodeid].weight else: if tree_[nodeid].sitename == sitename: fid = decoder("feature_idx", tree_[nodeid].fid, split_maskdict=split_maskdict) bid = decoder("feature_val", tree_[nodeid].bid, nodeid, split_maskdict=split_maskdict) if not use_missing: if value[0].features.get_data( fid, bin_sparse_points[fid]) <= bid: return 1, tree_[nodeid].left_nodeid else: return 1, tree_[nodeid].right_nodeid else: missing_dir = decoder( "missing_dir", tree_[nodeid].missing_dir, nodeid, missing_dir_maskdict=missing_dir_maskdict) missing_val = False if zero_as_missing: if value[0].features.get_data(fid, None) is None or \ value[0].features.get_data(fid) == NoneType(): missing_val = True elif use_missing and value[0].features.get_data( fid) == NoneType(): missing_val = True if missing_val: if missing_dir == 1: return 1, tree_[nodeid].right_nodeid else: return 1, tree_[nodeid].left_nodeid else: LOGGER.debug( "fid is {}, bid is {}, sitename is {}".format( fid, bid, sitename)) if value[0].features.get_data( fid, bin_sparse_points[fid]) <= bid: return 1, tree_[nodeid].left_nodeid else: return 1, tree_[nodeid].right_nodeid else: return (1, tree_[nodeid].fid, tree_[nodeid].bid, tree_[nodeid].sitename, nodeid, tree_[nodeid].left_nodeid, tree_[nodeid].right_nodeid) def sync_dispatch_node_host(self, dispatch_guest_data, dep=-1): LOGGER.info("send node to host to dispath, depth is {}".format(dep)) self.transfer_inst.dispatch_node_host.remote(dispatch_guest_data, role=consts.HOST, idx=-1, suffix=(dep, )) """ federation.remote(obj=dispatch_guest_data, name=self.transfer_inst.dispatch_node_host.name, tag=self.transfer_inst.generate_transferid(self.transfer_inst.dispatch_node_host, dep), 
role=consts.HOST, idx=-1) """ def sync_dispatch_node_host_result(self, dep=-1): LOGGER.info("get host dispatch result, depth is {}".format(dep)) dispatch_node_host_result = self.transfer_inst.dispatch_node_host_result.get( idx=-1, suffix=(dep, )) """ dispatch_node_host_result = federation.get(name=self.transfer_inst.dispatch_node_host_result.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.dispatch_node_host_result, dep), idx=-1) """ return dispatch_node_host_result def redispatch_node(self, dep=-1): LOGGER.info("redispatch node of depth {}".format(dep)) dispatch_node_method = functools.partial( self.dispatch_node, tree_=self.tree_, decoder=self.decode, sitename=self.sitename, split_maskdict=self.split_maskdict, bin_sparse_points=self.bin_sparse_points, use_missing=self.use_missing, zero_as_missing=self.zero_as_missing, missing_dir_maskdict=self.missing_dir_maskdict) dispatch_guest_result = self.data_bin_with_node_dispatch.mapValues( dispatch_node_method) tree_node_num = self.tree_node_num LOGGER.info("remask dispatch node result of depth {}".format(dep)) dispatch_to_host_result = dispatch_guest_result.filter( lambda key, value: isinstance(value, tuple) and len(value) > 2) dispatch_guest_result = dispatch_guest_result.subtractByKey( dispatch_to_host_result) leaf = dispatch_guest_result.filter( lambda key, value: isinstance(value, tuple) is False) if self.predict_weights is None: self.predict_weights = leaf else: self.predict_weights = self.predict_weights.union(leaf) dispatch_guest_result = dispatch_guest_result.subtractByKey(leaf) self.sync_dispatch_node_host(dispatch_to_host_result, dep) dispatch_node_host_result = self.sync_dispatch_node_host_result(dep) self.node_dispatch = None for idx in range(len(dispatch_node_host_result)): if self.node_dispatch is None: self.node_dispatch = dispatch_node_host_result[idx] else: self.node_dispatch = self.node_dispatch.join(dispatch_node_host_result[idx], \ lambda unleaf_state_nodeid1, 
unleaf_state_nodeid2: \ unleaf_state_nodeid1 if len( unleaf_state_nodeid1) == 2 else unleaf_state_nodeid2) self.node_dispatch = self.node_dispatch.union(dispatch_guest_result) def sync_tree(self): LOGGER.info("sync tree to host") self.transfer_inst.tree.remote(self.tree_, role=consts.HOST, idx=-1) """ federation.remote(obj=self.tree_, name=self.transfer_inst.tree.name, tag=self.transfer_inst.generate_transferid(self.transfer_inst.tree), role=consts.HOST, idx=-1) """ def convert_bin_to_real(self): LOGGER.info("convert tree node bins to real value") for i in range(len(self.tree_)): if self.tree_[i].is_leaf is True: continue if self.tree_[i].sitename == self.sitename: fid = self.decode("feature_idx", self.tree_[i].fid, split_maskdict=self.split_maskdict) bid = self.decode("feature_val", self.tree_[i].bid, self.tree_[i].id, self.split_maskdict) real_splitval = self.encode("feature_val", self.bin_split_points[fid][bid], self.tree_[i].id) self.tree_[i].bid = real_splitval def fit(self): LOGGER.info("begin to fit guest decision tree") self.sync_encrypted_grad_and_hess() # LOGGER.debug("self.grad and hess is {}".format(list(self.grad_and_hess.collect()))) root_sum_grad, root_sum_hess = self.get_grad_hess_sum( self.grad_and_hess) root_node = Node(id=0, sitename=self.sitename, sum_grad=root_sum_grad, sum_hess=root_sum_hess, weight=self.splitter.node_weight( root_sum_grad, root_sum_hess)) self.tree_node_queue = [root_node] self.dispatch_all_node_to_root() for dep in range(self.max_depth): LOGGER.info( "start to fit depth {}, tree node queue size is {}".format( dep, len(self.tree_node_queue))) self.sync_tree_node_queue(self.tree_node_queue, dep) if len(self.tree_node_queue) == 0: break self.sync_node_positions(dep) self.data_bin_with_node_dispatch = self.data_bin.join( self.node_dispatch, lambda data_inst, dispatch_info: (data_inst, dispatch_info)) batch = 0 splitinfos = [] for i in range(0, len(self.tree_node_queue), self.max_split_nodes): self.cur_split_nodes = 
self.tree_node_queue[i:i + self. max_split_nodes] node_map = {} node_num = 0 for tree_node in self.cur_split_nodes: node_map[tree_node.id] = node_num node_num += 1 acc_histograms = self.get_histograms(node_map=node_map) self.best_splitinfo_guest = self.splitter.find_split( acc_histograms, self.valid_features, self.data_bin._partitions, self.sitename, self.use_missing, self.zero_as_missing) self.federated_find_split(dep, batch) final_splitinfo_host = self.sync_final_split_host(dep, batch) cur_splitinfos = self.merge_splitinfo( self.best_splitinfo_guest, final_splitinfo_host) splitinfos.extend(cur_splitinfos) batch += 1 max_depth_reach = True if dep + 1 == self.max_depth else False self.update_tree_node_queue(splitinfos, max_depth_reach) self.redispatch_node(dep) self.sync_tree() self.convert_bin_to_real() tree_ = self.tree_ LOGGER.info("tree node num is %d" % len(tree_)) LOGGER.info("end to fit guest decision tree") @staticmethod def traverse_tree(predict_state, data_inst, tree_=None, decoder=None, sitename=consts.GUEST, split_maskdict=None, use_missing=None, zero_as_missing=None, missing_dir_maskdict=None): nid, tag = predict_state while tree_[nid].sitename == sitename: if tree_[nid].is_leaf is True: return tree_[nid].weight fid = decoder("feature_idx", tree_[nid].fid, split_maskdict=split_maskdict) bid = decoder("feature_val", tree_[nid].bid, nid, split_maskdict=split_maskdict) if use_missing: missing_dir = decoder( "missing_dir", 1, nid, missing_dir_maskdict=missing_dir_maskdict) else: missing_dir = 1 if use_missing and zero_as_missing: missing_dir = decoder( "missing_dir", 1, nid, missing_dir_maskdict=missing_dir_maskdict) if data_inst.features.get_data(fid) == NoneType( ) or data_inst.features.get_data(fid, None) is None: if missing_dir == 1: nid = tree_[nid].right_nodeid else: nid = tree_[nid].left_nodeid elif data_inst.features.get_data(fid) <= bid: nid = tree_[nid].left_nodeid else: nid = tree_[nid].right_nodeid elif data_inst.features.get_data(fid) == 
NoneType(): if missing_dir == 1: nid = tree_[nid].right_nodeid else: nid = tree_[nid].left_nodeid elif data_inst.features.get_data(fid, 0) <= bid: nid = tree_[nid].left_nodeid else: nid = tree_[nid].right_nodeid return nid, 1 def sync_predict_finish_tag(self, finish_tag, send_times): LOGGER.info("send the {}-th predict finish tag {} to host".format( finish_tag, send_times)) self.transfer_inst.predict_finish_tag.remote(finish_tag, role=consts.HOST, idx=-1, suffix=(send_times, )) """ federation.remote(obj=finish_tag, name=self.transfer_inst.predict_finish_tag.name, tag=self.transfer_inst.generate_transferid(self.transfer_inst.predict_finish_tag, send_times), role=consts.HOST, idx=-1) """ def sync_predict_data(self, predict_data, send_times): LOGGER.info("send predict data to host, sending times is {}".format( send_times)) self.transfer_inst.predict_data.remote(predict_data, role=consts.HOST, idx=-1, suffix=(send_times, )) """ federation.remote(obj=predict_data, name=self.transfer_inst.predict_data.name, tag=self.transfer_inst.generate_transferid(self.transfer_inst.predict_data, send_times), role=consts.HOST, idx=-1) """ def sync_data_predicted_by_host(self, send_times): LOGGER.info( "get predicted data by host, recv times is {}".format(send_times)) predict_data = self.transfer_inst.predict_data_by_host.get( idx=-1, suffix=(send_times, )) """ predict_data = federation.get(name=self.transfer_inst.predict_data_by_host.name, tag=self.transfer_inst.generate_transferid( self.transfer_inst.predict_data_by_host, send_times), idx=-1) """ return predict_data def predict(self, data_inst): LOGGER.info("start to predict!") predict_data = data_inst.mapValues(lambda data_inst: (0, 1)) site_host_send_times = 0 predict_result = None while True: traverse_tree = functools.partial( self.traverse_tree, tree_=self.tree_, decoder=self.decode, sitename=self.sitename, split_maskdict=self.split_maskdict, use_missing=self.use_missing, zero_as_missing=self.zero_as_missing, 
missing_dir_maskdict=self.missing_dir_maskdict) predict_data = predict_data.join(data_inst, traverse_tree) predict_leaf = predict_data.filter( lambda key, value: isinstance(value, tuple) is False) if predict_result is None: predict_result = predict_leaf else: predict_result = predict_result.union(predict_leaf) predict_data = predict_data.subtractByKey(predict_leaf) unleaf_node_count = predict_data.count() if unleaf_node_count == 0: self.sync_predict_finish_tag(True, site_host_send_times) break self.sync_predict_finish_tag(False, site_host_send_times) self.sync_predict_data(predict_data, site_host_send_times) predict_data_host = self.sync_data_predicted_by_host( site_host_send_times) for i in range(len(predict_data_host)): predict_data = predict_data.join( predict_data_host[i], lambda state1_nodeid1, state2_nodeid2: state1_nodeid1 if state1_nodeid1[1] == 0 else state2_nodeid2) site_host_send_times += 1 LOGGER.info("predict finish!") return predict_result def get_model_meta(self): model_meta = DecisionTreeModelMeta() model_meta.criterion_meta.CopyFrom( CriterionMeta(criterion_method=self.criterion_method, criterion_param=self.criterion_params)) model_meta.max_depth = self.max_depth model_meta.min_sample_split = self.min_sample_split model_meta.min_impurity_split = self.min_impurity_split model_meta.min_leaf_node = self.min_leaf_node model_meta.use_missing = self.use_missing model_meta.zero_as_missing = self.zero_as_missing return model_meta def set_model_meta(self, model_meta): self.max_depth = model_meta.max_depth self.min_sample_split = model_meta.min_sample_split self.min_impurity_split = model_meta.min_impurity_split self.min_leaf_node = model_meta.min_leaf_node self.criterion_method = model_meta.criterion_meta.criterion_method self.criterion_params = list(model_meta.criterion_meta.criterion_param) self.use_missing = model_meta.use_missing self.zero_as_missing = model_meta.zero_as_missing def get_model_param(self): model_param = DecisionTreeModelParam() for node 
in self.tree_: model_param.tree_.add(id=node.id, sitename=node.sitename, fid=node.fid, bid=node.bid, weight=node.weight, is_leaf=node.is_leaf, left_nodeid=node.left_nodeid, right_nodeid=node.right_nodeid, missing_dir=node.missing_dir) LOGGER.debug( "missing_dir is {}, sitename is {}, is_leaf is {}".format( node.missing_dir, node.sitename, node.is_leaf)) model_param.split_maskdict.update(self.split_maskdict) model_param.missing_dir_maskdict.update(self.missing_dir_maskdict) return model_param def set_model_param(self, model_param): self.tree_ = [] for node_param in model_param.tree_: _node = Node(id=node_param.id, sitename=node_param.sitename, fid=node_param.fid, bid=node_param.bid, weight=node_param.weight, is_leaf=node_param.is_leaf, left_nodeid=node_param.left_nodeid, right_nodeid=node_param.right_nodeid, missing_dir=node_param.missing_dir) self.tree_.append(_node) self.split_maskdict = dict(model_param.split_maskdict) self.missing_dir_maskdict = dict(model_param.missing_dir_maskdict) def get_model(self): model_meta = self.get_model_meta() model_param = self.get_model_param() return model_meta, model_param def load_model(self, model_meta=None, model_param=None): LOGGER.info("load tree model") self.set_model_meta(model_meta) self.set_model_param(model_param) def get_feature_importance(self): return self.feature_importances_