def __init__(self, tree_param):
        """Build the guest-side hetero decision tree and reset its state.

        Args:
            tree_param: tree hyper-parameters handed to the parent
                constructor; the splitter is then configured from the
                criterion and stopping attributes read off ``self``
                (presumably set by the parent constructor).
        """
        LOGGER.info("hetero decision tree guest init!")
        super(HeteroDecisionTreeGuest, self).__init__(tree_param)

        self.splitter = Splitter(
            self.criterion_method,
            self.criterion_params,
            self.min_impurity_split,
            self.min_sample_split,
            self.min_leaf_node,
        )

        # federation channel and party identity
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.sitename = consts.GUEST
        self.runtime_idx = 0
        self.host_party_idlist = []

        # training inputs (all immutable None, so chaining is safe)
        self.data_bin = self.grad_and_hess = None
        self.bin_split_points = self.bin_sparse_points = None
        self.valid_features = self.infos = None

        # per-depth bookkeeping
        self.data_bin_with_node_dispatch = self.node_dispatch = None
        self.tree_node_queue = self.cur_split_nodes = None
        self.best_splitinfo_guest = None

        # encryption helpers
        self.encrypter = self.encrypted_mode_calculator = None

        # the tree, its masks and derived outputs (separate mutable objects)
        self.tree_ = []
        self.tree_node_num = 0
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}
        self.predict_weights = None
        self.feature_importances_ = {}
示例#2
0
    def __init__(self, tree_param):
        """Initialise the host-side hetero decision tree.

        Args:
            tree_param: tree hyper-parameters forwarded to the parent
                constructor; the splitter built below reads criterion and
                stopping settings from ``self`` (presumably set by the parent).
        """
        # this is the HOST tree; the message previously said "guest"
        LOGGER.info("hetero decision tree host init!")
        super(HeteroDecisionTreeHost, self).__init__(tree_param)

        self.splitter = Splitter(self.criterion_method, self.criterion_params,
                                 self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)

        self.data_bin = None
        self.data_bin_with_position = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.infos = None
        self.valid_features = None
        self.pubkey = None
        self.privakey = None
        self.tree_id = None
        self.encrypted_grad_and_hess = None
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.tree_node_queue = None
        self.cur_split_nodes = None
        # node id -> locally kept split value
        self.split_maskdict = {}
        # node id -> locally kept missing-value direction
        self.missing_dir_maskdict = {}
        self.tree_ = None
        self.runtime_idx = 0
        self.sitename = consts.HOST
    def __init__(self, tree_param):
        """Set up guest-side state on top of the base tree initialisation.

        Args:
            tree_param: tree hyper-parameters forwarded unchanged to the
                parent constructor.
        """
        super(HeteroDecisionTreeGuest, self).__init__(tree_param)

        # federation / encryption plumbing
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.encrypter = self.encrypted_mode_calculator = None
        self.compressor = None

        # party identity; the site name is refined later by set_runtime_idx()
        self.sitename = consts.GUEST
        self.host_party_idlist = []
        self.complete_secure_tree = False

        # split-value and missing-direction masks
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}

        # GOSS subsampling and its sampling rates
        self.run_goss = False
        self.top_rate = 0.2
        self.other_rate = 0.1

        # cipher compressing settings
        self.run_cipher_compressing = False
        self.cipher_encoder = self.cipher_decompressor = None
        self.key_length = None
        self.round_decimal = 7
        self.max_sample_weight = 1

        # code version switch
        self.new_ver = True
    def __init__(self, tree_param):
        """Set up guest-side state after the base tree initialisation.

        Args:
            tree_param: tree hyper-parameters forwarded unchanged to the
                parent constructor.
        """
        super(HeteroDecisionTreeGuest, self).__init__(tree_param)

        # Since FATE-1.8 the feature importance type is reset to 'split';
        # this must stay after the parent constructor so it overrides it.
        self.feature_importance_type = 'split'

        # federation / encryption plumbing
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.encrypter = self.compressor = None

        # party identity; the site name is refined later by set_runtime_idx()
        self.sitename = consts.GUEST
        self.host_party_idlist = []
        self.complete_secure_tree = False

        # per-node masks: split values and missing directions
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}

        # optional training features
        self.run_goss = False
        self.run_cipher_compressing = True
        self.task_type = self.packer = None
        self.max_sample_weight = 1
        self.new_ver = True

        # multi-output tree settings
        self.mo_tree = False
        self.class_num = 1
    def __init__(self, tree_param):
        """Initialise host-side state on top of the base tree set-up.

        Args:
            tree_param: tree hyper-parameters forwarded unchanged to the
                parent constructor.
        """
        super(HeteroDecisionTreeHost, self).__init__(tree_param)

        # party identity; the site name is refined later by set_runtime_idx()
        self.runtime_idx = 0
        self.sitename = consts.HOST
        self.host_party_idlist = []
        self.complete_secure_tree = False

        # federation channel and the gradient/hessian data received over it
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.encrypted_grad_and_hess = None

        # feature shuffling plus split-value / missing-direction masking
        self.feature_num = -1
        self.bin_num = None
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}
        self.missing_dir_mask_left = {}
        self.missing_dir_mask_right = {}
        self.fid_bid_random_mapping = {}
        self.inverse_fid_bid_random_mapping = {}

        # optional features and version switch
        self.run_goss = False
        self.run_cipher_compressing = True
        self.cipher_compressor = None
        self.new_ver = True
示例#6
0
    def __init__(self, tree_param):
        """Minimal guest-side initialisation after the parent constructor.

        Args:
            tree_param: tree hyper-parameters forwarded unchanged to the
                parent constructor.
        """
        super(HeteroDecisionTreeGuest, self).__init__(tree_param)

        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.encrypter = self.encrypted_mode_calculator = None

        # the site name is refined later by set_runtime_idx()
        self.sitename = consts.GUEST
        self.complete_secure_tree = False
        self.host_party_idlist = []

        # independent mask dictionaries (must not be shared objects)
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}
示例#7
0
    def __init__(self, tree_param):
        """Initialise host-side state, including fast-histogram buffers.

        Args:
            tree_param: tree hyper-parameters forwarded unchanged to the
                parent constructor.
        """
        super(HeteroDecisionTreeHost, self).__init__(tree_param)

        # party identity; the site name is refined later by set_runtime_idx()
        self.runtime_idx = 0
        self.sitename = consts.HOST
        self.host_party_idlist = []
        self.complete_secure_tree = False

        # received gradients/hessians and the local masks
        self.encrypted_grad_and_hess = None
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}

        # fast-histogram state
        self.run_sparse_opt = False
        self.bin_num = None
        self.data_bin_dense = self.data_bin_dense_with_position = None

        self.transfer_inst = HeteroDecisionTreeTransferVariable()
示例#8
0
class HeteroDecisionTreeHost(DecisionTree):
    """Host-side party of the hetero (vertically partitioned) decision tree.

    The host holds its own binned feature columns (``data_bin``).  While
    fitting, it receives the gradient/hessian table from the guest and builds
    histograms over its local features.  It then exchanges split information
    with the guest through ``self.transfer_inst``.  For splits on host
    features, the chosen bin id and missing direction are kept locally in
    ``split_maskdict`` / ``missing_dir_maskdict`` (keyed by node id).  In the
    split info sent back to the guest they are replaced by ``None``.
    """

    def __init__(self, tree_param):
        """Initialise the host tree.

        Args:
            tree_param: tree hyper-parameters forwarded to the parent
                constructor; the splitter below reads criterion and stopping
                settings from ``self`` (presumably set by the parent).
        """
        LOGGER.info("hetero decision tree host init!")
        super(HeteroDecisionTreeHost, self).__init__(tree_param)

        self.splitter = Splitter(self.criterion_method, self.criterion_params,
                                 self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)

        self.data_bin = None
        self.data_bin_with_position = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.infos = None
        self.valid_features = None
        self.pubkey = None
        self.privakey = None
        self.tree_id = None
        self.encrypted_grad_and_hess = None
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.tree_node_queue = None
        self.cur_split_nodes = None
        # node id -> locally kept split bin id / real split value
        self.split_maskdict = {}
        # node id -> locally kept missing-value direction
        self.missing_dir_maskdict = {}
        self.tree_ = None
        self.runtime_idx = 0
        self.sitename = consts.HOST

    def set_flowid(self, flowid=0):
        """Set the flow id used by ``self.transfer_inst``."""
        LOGGER.info("set flowid, flowid is {}".format(flowid))
        self.transfer_inst.set_flowid(flowid)

    def set_inputinfo(self,
                      data_bin=None,
                      grad_and_hess=None,
                      bin_split_points=None,
                      bin_sparse_points=None):
        """Store the binned data and binning metadata used by ``fit``.

        ``grad_and_hess`` is stored as given; ``fit`` later overwrites it
        with the table received from the guest.
        """
        LOGGER.info("set input info")
        self.data_bin = data_bin
        self.grad_and_hess = grad_and_hess
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_points

    def set_valid_features(self, valid_features=None):
        """Store the feature mask passed to histogram and split finding."""
        LOGGER.info("set valid features")
        self.valid_features = valid_features

    def encode(self, etype="feature_idx", val=None, nid=None):
        """Mask one component of a host split before it leaves the host.

        Args:
            etype: ``"feature_idx"``, ``"feature_val"`` or ``"missing_dir"``.
            val: the value to encode.
            nid: node id; used as the key for ``"feature_val"`` and
                ``"missing_dir"``.

        Returns:
            ``val`` unchanged for ``"feature_idx"``.  ``None`` for the other
            two types, whose value is stored locally under ``nid`` instead.

        Raises:
            TypeError: for an unknown ``etype``.
        """
        if etype == "feature_idx":
            return val

        if etype == "feature_val":
            self.split_maskdict[nid] = val
            return None

        if etype == "missing_dir":
            self.missing_dir_maskdict[nid] = val
            return None

        raise TypeError("encode type %s is not support!" % (str(etype)))

    @staticmethod
    def decode(dtype="feature_idx",
               val=None,
               nid=None,
               split_maskdict=None,
               missing_dir_maskdict=None):
        """Inverse of ``encode``: recover a split component for node ``nid``.

        Returns:
            ``val`` for ``"feature_idx"``.  For ``"feature_val"`` /
            ``"missing_dir"``, the value stored under ``nid`` in
            ``split_maskdict`` / ``missing_dir_maskdict``.

        Raises:
            ValueError: if ``nid`` has no stored value.
            TypeError: for an unknown ``dtype``.
        """
        if dtype == "feature_idx":
            return val

        if dtype == "feature_val":
            if nid in split_maskdict:
                return split_maskdict[nid]
            raise ValueError(
                "decode val %s cause error, can't reconize it!" % (str(val)))

        if dtype == "missing_dir":
            if nid in missing_dir_maskdict:
                return missing_dir_maskdict[nid]
            raise ValueError(
                "decode val %s cause error, can't reconize it!" % (str(val)))

        # raise (not return): a returned exception object would silently be
        # used as a decoded value by the caller
        raise TypeError("decode type %s is not support!" % (str(dtype)))

    def sync_encrypted_grad_and_hess(self):
        """Receive the gradient/hessian table from the guest into ``self.grad_and_hess``."""
        LOGGER.info("get encrypted grad and hess")
        self.grad_and_hess = self.transfer_inst.encrypted_grad_and_hess.get(
            idx=0)

    def sync_node_positions(self, dep=-1):
        """Receive and return the node-position table for depth ``dep``."""
        LOGGER.info("get node positions of depth {}".format(dep))
        return self.transfer_inst.node_positions.get(idx=0, suffix=(dep, ))

    def sync_tree_node_queue(self, dep=-1):
        """Receive the queue of nodes to split at depth ``dep`` into ``self.tree_node_queue``."""
        LOGGER.info("get tree node queue of depth {}".format(dep))
        self.tree_node_queue = self.transfer_inst.tree_node_queue.get(
            idx=0, suffix=(dep, ))

    def get_histograms(self, node_map=None):
        """Build and accumulate feature histograms for the current nodes.

        Args:
            node_map: mapping node id -> position within the current batch
                of split nodes (built in ``fit``).  Defaults to an empty
                mapping.

        Returns:
            The result of ``FeatureHistogram.accumulate_histogram``.
        """
        if node_map is None:
            # avoid a mutable default argument shared between calls
            node_map = {}
        LOGGER.info("start to get node histograms")
        histograms = FeatureHistogram.calculate_histogram(
            self.data_bin_with_position, self.grad_and_hess,
            self.bin_split_points, self.bin_sparse_points, self.valid_features,
            node_map, self.use_missing, self.zero_as_missing)
        LOGGER.info("begin to accumulate histograms")
        acc_histograms = FeatureHistogram.accumulate_histogram(histograms)
        LOGGER.info("acc histogram shape is {}".format(len(acc_histograms)))
        return acc_histograms

    def sync_encrypted_splitinfo_host(self,
                                      encrypted_splitinfo_host,
                                      dep=-1,
                                      batch=-1):
        """Send the encrypted candidate split infos for (``dep``, ``batch``) to the guest."""
        LOGGER.info("send encrypted splitinfo of depth {}, batch {}".format(
            dep, batch))
        self.transfer_inst.encrypted_splitinfo_host.remote(
            encrypted_splitinfo_host,
            role=consts.GUEST,
            idx=-1,
            suffix=(dep, batch))

    def sync_federated_best_splitinfo_host(self, dep=-1, batch=-1):
        """Receive the guest's choice of best host split for (``dep``, ``batch``).

        Returns:
            One ``(best_idx, best_gain)`` pair per node, as consumed by
            ``sync_final_splitinfo_host``; ``best_idx == -1`` means no host
            split was chosen.
        """
        LOGGER.info(
            "get federated best splitinfo of depth {}, batch {}".format(
                dep, batch))
        return self.transfer_inst.federated_best_splitinfo_host.get(
            idx=0, suffix=(dep, batch))

    def sync_final_splitinfo_host(self,
                                  splitinfo_host,
                                  federated_best_splitinfo_host,
                                  dep=-1,
                                  batch=-1):
        """Send the masked final host split info of every node to the guest.

        For each node whose best split is on the host, the bin id and
        missing direction are stored locally via ``encode``.  The copies
        sent to the guest carry ``None`` in their place.
        """
        LOGGER.info("send host final splitinfo of depth {}, batch {}".format(
            dep, batch))
        final_splitinfos = []
        for i, node_splitinfos in enumerate(splitinfo_host):
            best_idx, best_gain = federated_best_splitinfo_host[i]
            if best_idx != -1:
                splitinfo = node_splitinfos[best_idx]
                assert splitinfo.sitename == self.sitename
                node_id = self.cur_split_nodes[i].id
                splitinfo.best_fid = self.encode("feature_idx",
                                                 splitinfo.best_fid)
                assert splitinfo.best_fid is not None
                splitinfo.best_bid = self.encode("feature_val",
                                                 splitinfo.best_bid, node_id)
                splitinfo.missing_dir = self.encode("missing_dir",
                                                    splitinfo.missing_dir,
                                                    node_id)
                splitinfo.gain = best_gain
            else:
                splitinfo = SplitInfo(sitename=self.sitename,
                                      best_fid=-1,
                                      best_bid=-1,
                                      gain=best_gain)

            final_splitinfos.append(splitinfo)

        self.transfer_inst.final_splitinfo_host.remote(final_splitinfos,
                                                       role=consts.GUEST,
                                                       idx=-1,
                                                       suffix=(dep, batch))

    def sync_dispatch_node_host(self, dep):
        """Receive and return the per-sample node table to dispatch at depth ``dep``."""
        LOGGER.info("get node from host to dispath, depth is {}".format(dep))
        return self.transfer_inst.dispatch_node_host.get(idx=0, suffix=(dep, ))

    @staticmethod
    def dispatch_node(value1,
                      value2,
                      sitename=None,
                      decoder=None,
                      split_maskdict=None,
                      bin_sparse_points=None,
                      use_missing=False,
                      zero_as_missing=False,
                      missing_dir_maskdict=None):
        """Move one sample to a child node if its current node is a local split.

        Used as the join function in ``find_dispatch``.

        Args:
            value1: ``(unleaf_state, fid, bid, sitename, nodeid, left_nodeid,
                right_nodeid)`` describing the sample's current node.
            value2: the sample's binned instance; ``value2.features.get_data``
                is used to read the split feature.

        Returns:
            ``value1`` unchanged when the node belongs to another site,
            otherwise ``(unleaf_state, child_nodeid)``.
        """
        unleaf_state, fid, bid, node_sitename, nodeid, left_nodeid, right_nodeid = value1
        if node_sitename != sitename:
            return value1

        fid = decoder("feature_idx", fid, split_maskdict=split_maskdict)
        bid = decoder("feature_val",
                      bid,
                      nodeid,
                      split_maskdict=split_maskdict)

        if use_missing:
            missing_dir = decoder("missing_dir",
                                  1,
                                  nodeid,
                                  missing_dir_maskdict=missing_dir_maskdict)
            if zero_as_missing:
                # an absent sparse entry also counts as missing
                missing_val = bool(
                    value2.features.get_data(fid, None) is None
                    or value2.features.get_data(fid) == NoneType())
            else:
                missing_val = bool(value2.features.get_data(fid) == NoneType())

            if missing_val:
                if missing_dir == 1:
                    return unleaf_state, right_nodeid
                return unleaf_state, left_nodeid

        if value2.features.get_data(fid, bin_sparse_points[fid]) <= bid:
            return unleaf_state, left_nodeid
        return unleaf_state, right_nodeid

    def sync_dispatch_node_host_result(self,
                                       dispatch_node_host_result,
                                       dep=-1):
        """Send the host's dispatch result for depth ``dep`` to the guest."""
        LOGGER.info("send host dispatch result, depth is {}".format(dep))

        self.transfer_inst.dispatch_node_host_result.remote(
            dispatch_node_host_result,
            role=consts.GUEST,
            idx=-1,
            suffix=(dep, ))

    def find_dispatch(self, dispatch_node_host, dep=-1):
        """Dispatch samples on host-owned splits and send the result to the guest."""
        LOGGER.info("start to find host dispath of depth {}".format(dep))
        dispatch_node_method = functools.partial(
            self.dispatch_node,
            sitename=self.sitename,
            decoder=self.decode,
            split_maskdict=self.split_maskdict,
            bin_sparse_points=self.bin_sparse_points,
            use_missing=self.use_missing,
            zero_as_missing=self.zero_as_missing,
            missing_dir_maskdict=self.missing_dir_maskdict)
        dispatch_node_host_result = dispatch_node_host.join(
            self.data_bin, dispatch_node_method)
        self.sync_dispatch_node_host_result(dispatch_node_host_result, dep)

    def sync_tree(self):
        """Receive the finished tree structure from the guest into ``self.tree_``."""
        LOGGER.info("sync tree from guest")
        self.tree_ = self.transfer_inst.tree.get(idx=0)

    def remove_duplicated_split_nodes(self, split_nid_used):
        """Drop ``split_maskdict`` entries whose node id is not in ``split_nid_used``.

        The dict is modified in place.
        NOTE(review): ``missing_dir_maskdict`` is not pruned here; confirm
        whether that is intended.
        """
        LOGGER.info("remove duplicated nodes from split mask dict")
        used = set(split_nid_used)
        unused = [nid for nid in self.split_maskdict if nid not in used]
        for nid in unused:
            del self.split_maskdict[nid]

    def convert_bin_to_real(self):
        """Replace stored bin ids of local split nodes with real split values.

        For every non-leaf node owned by this host, the bin id in
        ``split_maskdict`` is mapped through ``bin_split_points``.  The
        result is re-stored via ``encode``, which returns ``None``; that
        ``None`` becomes the node's ``bid``.  Unused mask entries are then
        removed.
        """
        LOGGER.info("convert tree node bins to real value")
        split_nid_used = []
        for node in self.tree_:
            if node.is_leaf is True:
                continue
            if node.sitename != self.sitename:
                continue

            fid = self.decode("feature_idx",
                              node.fid,
                              split_maskdict=self.split_maskdict)
            bid = self.decode("feature_val", node.bid, node.id,
                              self.split_maskdict)
            LOGGER.debug("shape of bin_split_points is {}".format(
                len(self.bin_split_points[fid])))
            node.bid = self.encode("feature_val",
                                   self.bin_split_points[fid][bid],
                                   node.id)
            split_nid_used.append(node.id)

        self.remove_duplicated_split_nodes(split_nid_used)

    @staticmethod
    def traverse_tree(predict_state,
                      data_inst,
                      tree_=None,
                      decoder=None,
                      split_maskdict=None,
                      sitename=consts.HOST,
                      use_missing=False,
                      zero_as_missing=False,
                      missing_dir_maskdict=None):
        """Walk a sample down consecutive host-owned nodes of ``tree_``.

        Args:
            predict_state: ``(nid, _)`` with the sample's current node id.
            data_inst: the sample; ``data_inst.features.get_data`` is used.

        Returns:
            ``predict_state`` unchanged when the current node is not owned
            by ``sitename``; otherwise ``(nid, 0)``, where ``nid`` is the
            first node reached that belongs to another site.
        """
        nid, _ = predict_state
        if tree_[nid].sitename != sitename:
            return predict_state

        while tree_[nid].sitename == sitename:
            node = tree_[nid]
            fid = decoder("feature_idx",
                          node.fid,
                          split_maskdict=split_maskdict)
            bid = decoder("feature_val", node.bid, nid, split_maskdict)

            if use_missing:
                missing_dir = decoder(
                    "missing_dir",
                    1,
                    nid,
                    missing_dir_maskdict=missing_dir_maskdict)
            else:
                missing_dir = 1
            missing_child = node.right_nodeid if missing_dir == 1 else node.left_nodeid

            if use_missing and zero_as_missing:
                if data_inst.features.get_data(fid) == NoneType(
                ) or data_inst.features.get_data(fid, None) is None:
                    nid = missing_child
                elif data_inst.features.get_data(fid) <= bid:
                    nid = node.left_nodeid
                else:
                    nid = node.right_nodeid
            elif data_inst.features.get_data(fid) == NoneType():
                nid = missing_child
            elif data_inst.features.get_data(fid, 0) <= bid:
                nid = node.left_nodeid
            else:
                nid = node.right_nodeid

        return nid, 0

    def sync_predict_finish_tag(self, recv_times):
        """Receive and return the guest's prediction-finished flag for round ``recv_times``."""
        LOGGER.info(
            "get the {}-th predict finish tag from guest".format(recv_times))
        return self.transfer_inst.predict_finish_tag.get(
            idx=0, suffix=(recv_times, ))

    def sync_predict_data(self, recv_times):
        """Receive and return the prediction state table for round ``recv_times``."""
        LOGGER.info(
            "srecv predict data to host, recv times is {}".format(recv_times))
        return self.transfer_inst.predict_data.get(
            idx=0, suffix=(recv_times, ))

    def sync_data_predicted_by_host(self, predict_data, send_times):
        """Send the host-updated prediction states for round ``send_times`` to the guest."""
        LOGGER.info(
            "send predicted data by host, send times is {}".format(send_times))

        self.transfer_inst.predict_data_by_host.remote(predict_data,
                                                       role=consts.GUEST,
                                                       idx=0,
                                                       suffix=(send_times, ))

    def fit(self):
        """Run the host side of tree training, depth by depth, with the guest.

        For each depth: receive the node queue and node positions.  For each
        batch of at most ``max_split_nodes`` nodes, build histograms and find
        candidate splits, then exchange split info.  Finally, dispatch
        samples on host splits.  Afterwards, receive the finished tree and
        convert local bin ids into real split values.
        """
        LOGGER.info("begin to fit host decision tree")
        self.sync_encrypted_grad_and_hess()

        for dep in range(self.max_depth):
            self.sync_tree_node_queue(dep)
            if len(self.tree_node_queue) == 0:
                break

            node_positions = self.sync_node_positions(dep)
            self.data_bin_with_position = self.data_bin.join(
                node_positions, lambda v1, v2: (v1, v2))

            batch_starts = range(0, len(self.tree_node_queue),
                                 self.max_split_nodes)
            for batch, start in enumerate(batch_starts):
                self.cur_split_nodes = self.tree_node_queue[
                    start:start + self.max_split_nodes]
                # node id -> position of the node inside this batch
                node_map = {
                    tree_node.id: pos
                    for pos, tree_node in enumerate(self.cur_split_nodes)
                }

                acc_histograms = self.get_histograms(node_map=node_map)

                splitinfo_host, encrypted_splitinfo_host = self.splitter.find_split_host(
                    acc_histograms, self.valid_features,
                    self.data_bin._partitions, self.sitename, self.use_missing,
                    self.zero_as_missing)

                self.sync_encrypted_splitinfo_host(encrypted_splitinfo_host,
                                                   dep, batch)
                federated_best_splitinfo_host = self.sync_federated_best_splitinfo_host(
                    dep, batch)
                self.sync_final_splitinfo_host(splitinfo_host,
                                               federated_best_splitinfo_host,
                                               dep, batch)

            dispatch_node_host = self.sync_dispatch_node_host(dep)
            self.find_dispatch(dispatch_node_host, dep)

        self.sync_tree()
        self.convert_bin_to_real()

        LOGGER.info("end to fit host decision tree")

    def predict(self, data_inst):
        """Serve the guest's prediction rounds until it sends a finish tag.

        Each round receives the prediction states and advances them through
        host-owned nodes with ``traverse_tree``.  The updated states are
        then sent back to the guest.
        """
        LOGGER.info("start to predict!")
        # the traversal function does not change between rounds
        traverse_tree = functools.partial(
            self.traverse_tree,
            tree_=self.tree_,
            decoder=self.decode,
            split_maskdict=self.split_maskdict,
            sitename=self.sitename,
            use_missing=self.use_missing,
            zero_as_missing=self.zero_as_missing,
            missing_dir_maskdict=self.missing_dir_maskdict)

        site_guest_send_times = 0
        while True:
            finish_tag = self.sync_predict_finish_tag(site_guest_send_times)
            if finish_tag is True:
                break

            predict_data = self.sync_predict_data(site_guest_send_times)
            predict_data = predict_data.join(data_inst, traverse_tree)

            self.sync_data_predicted_by_host(predict_data,
                                             site_guest_send_times)

            site_guest_send_times += 1

        LOGGER.info("predict finish!")

    def get_model_meta(self):
        """Return a ``DecisionTreeModelMeta`` holding this tree's hyper-parameters."""
        model_meta = DecisionTreeModelMeta()

        model_meta.max_depth = self.max_depth
        model_meta.min_sample_split = self.min_sample_split
        model_meta.min_impurity_split = self.min_impurity_split
        model_meta.min_leaf_node = self.min_leaf_node
        model_meta.use_missing = self.use_missing
        model_meta.zero_as_missing = self.zero_as_missing

        return model_meta

    def set_model_meta(self, model_meta):
        """Restore hyper-parameters from a ``DecisionTreeModelMeta``."""
        self.max_depth = model_meta.max_depth
        self.min_sample_split = model_meta.min_sample_split
        self.min_impurity_split = model_meta.min_impurity_split
        self.min_leaf_node = model_meta.min_leaf_node
        self.use_missing = model_meta.use_missing
        self.zero_as_missing = model_meta.zero_as_missing

    def get_model_param(self):
        """Return a ``DecisionTreeModelParam`` with the nodes and local mask dicts."""
        model_param = DecisionTreeModelParam()
        for node in self.tree_:
            model_param.tree_.add(id=node.id,
                                  sitename=node.sitename,
                                  fid=node.fid,
                                  bid=node.bid,
                                  weight=node.weight,
                                  is_leaf=node.is_leaf,
                                  left_nodeid=node.left_nodeid,
                                  right_nodeid=node.right_nodeid,
                                  missing_dir=node.missing_dir)

        model_param.split_maskdict.update(self.split_maskdict)
        model_param.missing_dir_maskdict.update(self.missing_dir_maskdict)

        return model_param

    def set_model_param(self, model_param):
        """Rebuild ``self.tree_`` and the mask dicts from a ``DecisionTreeModelParam``."""
        self.tree_ = [
            Node(id=node_param.id,
                 sitename=node_param.sitename,
                 fid=node_param.fid,
                 bid=node_param.bid,
                 weight=node_param.weight,
                 is_leaf=node_param.is_leaf,
                 left_nodeid=node_param.left_nodeid,
                 right_nodeid=node_param.right_nodeid,
                 missing_dir=node_param.missing_dir)
            for node_param in model_param.tree_
        ]

        self.split_maskdict = dict(model_param.split_maskdict)
        self.missing_dir_maskdict = dict(model_param.missing_dir_maskdict)

    def get_model(self):
        """Return ``(model_meta, model_param)`` for serialisation."""
        model_meta = self.get_model_meta()
        model_param = self.get_model_param()

        return model_meta, model_param

    def load_model(self, model_meta=None, model_param=None):
        """Restore hyper-parameters and tree structure from saved protobufs."""
        LOGGER.info("load tree model")
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)
class HeteroDecisionTreeGuest(DecisionTree):
    def __init__(self, tree_param):
        """Initialise guest-side tree state on top of the ``DecisionTree`` base.

        The base class consumes ``tree_param``; the splitter is then built from
        the criterion and stopping parameters the base class exposes. All other
        attributes start empty and are filled by the ``set_*`` methods or
        during ``fit`` / ``predict``.
        """
        LOGGER.info("hetero decision tree guest init!")
        super(HeteroDecisionTreeGuest, self).__init__(tree_param)
        self.splitter = Splitter(self.criterion_method, self.criterion_params,
                                 self.min_impurity_split, self.min_sample_split,
                                 self.min_leaf_node)

        # training inputs (see set_inputinfo / set_valid_features)
        self.data_bin = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.valid_features = None
        self.infos = None

        # per-depth dispatch and split-search state
        self.data_bin_with_node_dispatch = None
        self.node_dispatch = None
        self.tree_node_queue = None
        self.cur_split_nodes = None
        self.best_splitinfo_guest = None

        # encryption helpers (see set_encrypter / set_encrypted_mode_calculator)
        self.encrypter = None
        self.encrypted_mode_calculator = None

        # tree structure plus the guest-private split masks
        self.tree_ = []
        self.tree_node_num = 0
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}

        # federation, prediction output and bookkeeping
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.predict_weights = None
        self.host_party_idlist = []
        self.runtime_idx = 0
        self.sitename = consts.GUEST
        self.feature_importances_ = {}

    def set_flowid(self, flowid=0):
        """Forward ``flowid`` to the transfer variables so federation tags are scoped by it."""
        message = "set flowid, flowid is {}".format(flowid)
        LOGGER.info(message)
        self.transfer_inst.set_flowid(flowid)

    def set_host_party_idlist(self, host_party_idlist):
        """Remember the party ids of the participating hosts."""
        self.host_party_idlist = host_party_idlist

    def set_inputinfo(self,
                      data_bin=None,
                      grad_and_hess=None,
                      bin_split_points=None,
                      bin_sparse_points=None):
        """Attach the training inputs consumed by ``fit``.

        Args:
            data_bin: binned instance table.
            grad_and_hess: table of per-instance ``(grad, hess)`` pairs.
            bin_split_points: per-feature split points, indexed ``[fid][bid]``.
            bin_sparse_points: per-feature default bin for sparse lookups.
        """
        LOGGER.info("set input info")
        (self.data_bin, self.grad_and_hess,
         self.bin_split_points, self.bin_sparse_points) = (
            data_bin, grad_and_hess, bin_split_points, bin_sparse_points)

    def set_encrypter(self, encrypter):
        """Install the object whose ``encrypt`` / ``decrypt`` this tree delegates to."""
        LOGGER.info("set encrypter")
        self.encrypter = encrypter

    def set_encrypted_mode_calculator(self, encrypted_mode_calculator):
        """Install the calculator used to encrypt the gradient/hessian table."""
        self.encrypted_mode_calculator = encrypted_mode_calculator

    def encrypt(self, val):
        """Encrypt ``val`` with the installed encrypter."""
        encrypter = self.encrypter
        return encrypter.encrypt(val)

    def decrypt(self, val):
        """Decrypt ``val`` with the installed encrypter."""
        encrypter = self.encrypter
        return encrypter.decrypt(val)

    def encode(self, etype="feature_idx", val=None, nid=None):
        """Record a split attribute of node ``nid`` outside the node object.

        ``"feature_idx"`` values are returned unchanged. ``"feature_val"`` and
        ``"missing_dir"`` values are stored in ``split_maskdict`` /
        ``missing_dir_maskdict`` under ``nid`` and ``None`` is returned, so the
        node itself carries no split value.

        Raises:
            TypeError: unsupported ``etype``.
        """
        if etype == "feature_idx":
            return val

        if etype == "feature_val" or etype == "missing_dir":
            mask_store = (self.split_maskdict if etype == "feature_val"
                          else self.missing_dir_maskdict)
            mask_store[nid] = val
            return None

        raise TypeError("encode type %s is not support!" % (str(etype)))

    @staticmethod
    def decode(dtype="feature_idx",
               val=None,
               nid=None,
               split_maskdict=None,
               missing_dir_maskdict=None):
        """Look up a split attribute previously stored by ``encode``.

        Args:
            dtype: ``"feature_idx"``, ``"feature_val"`` or ``"missing_dir"``.
            val: value carried on the node. Returned unchanged for
                ``"feature_idx"``; only used in error messages otherwise.
            nid: node id used as the key into the mask dictionaries.
            split_maskdict: node id -> split value (for ``"feature_val"``).
            missing_dir_maskdict: node id -> missing direction (for
                ``"missing_dir"``).

        Returns:
            The feature index, split value or missing direction.

        Raises:
            ValueError: ``nid`` is not in the relevant mask dictionary.
            TypeError: ``dtype`` is not a supported kind.
        """
        if dtype == "feature_idx":
            return val

        if dtype == "feature_val":
            if nid in split_maskdict:
                return split_maskdict[nid]
            raise ValueError(
                "decode val %s cause error, can't reconize it!" %
                (str(val)))

        if dtype == "missing_dir":
            if nid in missing_dir_maskdict:
                return missing_dir_maskdict[nid]
            raise ValueError(
                "decode val %s cause error, can't reconize it!" %
                (str(val)))

        # Raise (not return) so an unsupported dtype fails loudly instead of
        # handing an exception object back to the caller as a decoded value.
        raise TypeError("decode type %s is not support!" % (str(dtype)))

    def set_valid_features(self, valid_features=None):
        """Restrict histogram building and split finding to ``valid_features``."""
        LOGGER.info("set valid features")
        self.valid_features = valid_features

    def sync_encrypted_grad_and_hess(self):
        """Encrypt the gradient/hessian table and send it to every host."""
        LOGGER.info("send encrypted grad and hess to host")
        payload = self.encrypt_grad_and_hess()
        self.transfer_inst.encrypted_grad_and_hess.remote(payload,
                                                          role=consts.HOST,
                                                          idx=-1)

    def encrypt_grad_and_hess(self):
        """Return ``self.grad_and_hess`` encrypted by the mode calculator."""
        LOGGER.info("start to encrypt grad and hess")
        return self.encrypted_mode_calculator.encrypt(self.grad_and_hess)

    def get_grad_hess_sum(self, grad_and_hess_table):
        """Return ``(sum_grad, sum_hess)`` over a table of ``(grad, hess)`` values."""
        LOGGER.info("calculate the sum of grad and hess")

        def _add_pairs(left, right):
            return left[0] + right[0], left[1] + right[1]

        total_grad, total_hess = grad_and_hess_table.reduce(_add_pairs)
        return total_grad, total_hess

    def dispatch_all_node_to_root(self, root_id=0):
        """Place every instance of ``data_bin`` on node ``root_id`` as ``(1, root_id)``."""
        LOGGER.info("dispatch all node to root")

        def _to_root(_instance):
            return 1, root_id

        self.node_dispatch = self.data_bin.mapValues(_to_root)

    def get_histograms(self, node_map=None):
        """Compute accumulated feature histograms for the current batch of nodes.

        Args:
            node_map: mapping node id -> position of that node in the current
                batch (as built in ``fit``). ``None`` means an empty mapping.

        Returns:
            The output of ``FeatureHistogram.accumulate_histogram`` applied to
            the per-node histograms.
        """
        if node_map is None:
            # Fresh dict per call: a shared ``{}`` default would carry state
            # across calls if anything downstream mutated it.
            node_map = {}
        LOGGER.info("start to get node histograms")
        histograms = FeatureHistogram.calculate_histogram(
            self.data_bin_with_node_dispatch, self.grad_and_hess,
            self.bin_split_points, self.bin_sparse_points, self.valid_features,
            node_map, self.use_missing, self.zero_as_missing)
        return FeatureHistogram.accumulate_histogram(histograms)

    def sync_tree_node_queue(self, tree_node_queue, dep=-1):
        """Send the node queue of depth ``dep`` to all hosts, exposing only node ids.

        Each queued node is replaced by a fresh ``Node(id=...)``, so gradient
        sums, weights and split details stay on the guest. The caller's nodes
        are not modified.
        """
        LOGGER.info("send tree node queue of depth {}".format(dep))
        # Build the id-only nodes directly. Deep-copying the full queue first is
        # unnecessary, because every copied node would immediately be replaced.
        mask_tree_node_queue = [Node(id=node.id) for node in tree_node_queue]

        self.transfer_inst.tree_node_queue.remote(mask_tree_node_queue,
                                                  role=consts.HOST,
                                                  idx=-1,
                                                  suffix=(dep, ))

    def sync_node_positions(self, dep):
        """Send the current instance -> node assignment (``node_dispatch``) to all hosts."""
        LOGGER.info("send node positions of depth {}".format(dep))
        channel = self.transfer_inst.node_positions
        channel.remote(self.node_dispatch,
                       role=consts.HOST,
                       idx=-1,
                       suffix=(dep, ))

    def sync_encrypted_splitinfo_host(self, dep=-1, batch=-1):
        """Receive every host's encrypted split candidates for ``(dep, batch)``.

        Returns:
            One entry per host, as delivered by the transfer variable.
        """
        LOGGER.info("get encrypted splitinfo of depth {}, batch {}".format(
            dep, batch))
        channel = self.transfer_inst.encrypted_splitinfo_host
        return channel.get(idx=-1, suffix=(dep, batch))

    def sync_federated_best_splitinfo_host(self,
                                           federated_best_splitinfo_host,
                                           dep=-1,
                                           batch=-1,
                                           idx=-1):
        """Send the guest's choice of best split per node back to host ``idx``.

        ``federated_find_split`` passes the host's position in the list it
        received as ``idx``.
        """
        LOGGER.info(
            "send federated best splitinfo of depth {}, batch {}".format(
                dep, batch))
        channel = self.transfer_inst.federated_best_splitinfo_host
        channel.remote(federated_best_splitinfo_host,
                       role=consts.HOST,
                       idx=idx,
                       suffix=(dep, batch))

    def find_host_split(self, value):
        """Pick the best of one host's split candidates for a single node.

        Args:
            value: ``(node, candidates)``, where ``candidates`` is a sequence of
                encrypted ``(sum_grad_left, sum_hess_left)`` pairs.

        Returns:
            ``(best_idx, encrypted_best_gain)``. ``best_idx`` is ``-1`` when no
            candidate's gain exceeds ``min_impurity_split``.
        """
        node, host_candidates = value
        total_grad = node.sum_grad
        total_hess = node.sum_hess
        best_gain = self.min_impurity_split - consts.FLOAT_ZERO
        best_idx = -1

        for candidate_idx, (enc_grad_l, enc_hess_l) in enumerate(host_candidates):
            grad_l = self.decrypt(enc_grad_l)
            hess_l = self.decrypt(enc_hess_l)
            gain = self.splitter.split_gain(total_grad, total_hess,
                                            grad_l, hess_l,
                                            total_grad - grad_l,
                                            total_hess - hess_l)
            if gain > self.min_impurity_split and gain > best_gain:
                best_idx, best_gain = candidate_idx, gain

        return best_idx, self.encrypt(best_gain)

    def federated_find_split(self, dep=-1, batch=-1):
        """Evaluate each host's encrypted candidates and tell each host its winners.

        For every host, the candidates are paired with ``cur_split_nodes``,
        scored in parallel with ``find_host_split``, and the per-node results
        ``(best_idx, encrypted_gain)`` are sent back to that host.
        """
        LOGGER.info("federated find split of depth {}, batch {}".format(
            dep, batch))
        host_splitinfos = self.sync_encrypted_splitinfo_host(dep, batch)

        for host_idx, splitinfo_of_host in enumerate(host_splitinfos):
            candidates_table = session.parallelize(
                zip(self.cur_split_nodes, splitinfo_of_host),
                include_key=False,
                partition=self.data_bin._partitions)

            collected = candidates_table.mapValues(
                self.find_host_split).collect()
            best_per_node = [result for _key, result in collected]

            self.sync_federated_best_splitinfo_host(best_per_node, dep,
                                                    batch, host_idx)

    def sync_final_split_host(self, dep=-1, batch=-1):
        """Receive every host's final split info for ``(dep, batch)``."""
        LOGGER.info("get host final splitinfo of depth {}, batch {}".format(
            dep, batch))
        channel = self.transfer_inst.final_splitinfo_host
        return channel.get(idx=-1, suffix=(dep, batch))

    def find_best_split_guest_and_host(self, splitinfo_guest_host):
        """Choose between the guest's split and the hosts' best splits for one node.

        Args:
            splitinfo_guest_host: ``[guest_splitinfo, host1_splitinfo, ...]``
                as built by ``merge_splitinfo``. Host entries carry encrypted
                ``gain``, ``sum_grad`` and ``sum_hess``.

        Returns:
            The guest split info when its gain is at least the best host gain
            (within ``consts.FLOAT_ZERO``), or when there is no host entry.
            Otherwise the best host split info, with ``gain``, ``sum_grad`` and
            ``sum_hess`` decrypted in place. Ties between hosts keep the first.
        """
        guest_splitinfo = splitinfo_guest_host[0]

        # Decrypt each host gain exactly once (decrypt() may be costly).
        best_gain_host = None
        best_gain_host_idx = -1
        for idx in range(1, len(splitinfo_guest_host)):
            gain_host = self.decrypt(splitinfo_guest_host[idx].gain)
            if best_gain_host is None or best_gain_host < gain_host:
                best_gain_host = gain_host
                best_gain_host_idx = idx

        if best_gain_host is None:
            # No host candidates: the guest split is the only option.
            return guest_splitinfo

        if guest_splitinfo.gain >= best_gain_host - consts.FLOAT_ZERO:
            return guest_splitinfo

        best_splitinfo = splitinfo_guest_host[best_gain_host_idx]
        best_splitinfo.sum_grad = self.decrypt(best_splitinfo.sum_grad)
        best_splitinfo.sum_hess = self.decrypt(best_splitinfo.sum_hess)
        best_splitinfo.gain = best_gain_host
        return best_splitinfo

    def merge_splitinfo(self, splitinfo_guest, splitinfo_host):
        """Pick the overall best split for each node of the current batch.

        Args:
            splitinfo_guest: guest split info per node.
            splitinfo_host: one list per host, aligned with ``splitinfo_guest``.

        Returns:
            One best split info per node, computed in parallel by
            ``find_best_split_guest_and_host``.
        """
        LOGGER.info("merge splitinfo")
        merge_infos = [
            [guest_info] + [host_infos[node_pos] for host_infos in splitinfo_host]
            for node_pos, guest_info in enumerate(splitinfo_guest)
        ]

        candidates_table = session.parallelize(
            merge_infos,
            include_key=False,
            partition=self.data_bin._partitions)
        best_table = candidates_table.mapValues(
            self.find_best_split_guest_and_host)
        return [entry[1] for entry in best_table.collect()]

    def update_feature_importance(self, splitinfo):
        """Credit the split feature of ``splitinfo`` in ``feature_importances_``.

        With ``feature_importance_type == "split"`` each split adds 1; with
        ``"gain"`` it adds the split's gain. Keys are ``(sitename, fid)``.

        Raises:
            ValueError: unsupported ``feature_importance_type``.
        """
        importance_type = self.feature_importance_type
        if importance_type == "split":
            increment = 1
        elif importance_type == "gain":
            increment = splitinfo.gain
        else:
            raise ValueError(
                "feature importance type {} not support yet".format(
                    importance_type))

        key = (splitinfo.sitename, splitinfo.best_fid)
        self.feature_importances_[key] = (
            self.feature_importances_.get(key, 0) + increment)

    def update_tree_node_queue(self, splitinfos, max_depth_reach):
        """Finalise the current node queue and build the next level's queue.

        For each queued node (aligned by position with ``splitinfos``):

        * if ``max_depth_reach`` is set, or its best gain does not exceed
          ``min_impurity_split`` (plus ``consts.FLOAT_ZERO``), it becomes a leaf;
        * otherwise it gets two new child ids, both children are queued for the
          next depth, and the split is recorded. For guest-owned splits, ``bid``
          and ``missing_dir`` go through ``encode`` (stored in the mask
          dictionaries, ``None`` on the node); host splits are copied as-is.

        Every processed node is appended to ``self.tree_``.

        Args:
            splitinfos: best split info per queued node.
            max_depth_reach: when True, every queued node becomes a leaf.
        """
        LOGGER.info(
            "update tree node, splitlist length is {}, tree node queue size is {}"
            .format(len(splitinfos), len(self.tree_node_queue)))
        new_tree_node_queue = []
        for i, node in enumerate(self.tree_node_queue):
            sum_grad = node.sum_grad
            sum_hess = node.sum_hess
            # splitinfos[i] is only read when max_depth_reach is False.
            if max_depth_reach or (splitinfos[i].gain <=
                                   self.min_impurity_split + consts.FLOAT_ZERO):
                node.is_leaf = True
            else:
                splitinfo = splitinfos[i]
                node.left_nodeid = self.tree_node_num + 1
                node.right_nodeid = self.tree_node_num + 2
                self.tree_node_num += 2

                left_grad, left_hess = splitinfo.sum_grad, splitinfo.sum_hess
                right_grad, right_hess = sum_grad - left_grad, sum_hess - left_hess
                left_node = Node(id=node.left_nodeid,
                                 sitename=self.sitename,
                                 sum_grad=left_grad,
                                 sum_hess=left_hess,
                                 weight=self.splitter.node_weight(
                                     left_grad, left_hess))
                right_node = Node(id=node.right_nodeid,
                                  sitename=self.sitename,
                                  sum_grad=right_grad,
                                  sum_hess=right_hess,
                                  weight=self.splitter.node_weight(
                                      right_grad, right_hess))

                new_tree_node_queue.append(left_node)
                new_tree_node_queue.append(right_node)

                node.sitename = splitinfo.sitename
                if node.sitename == self.sitename:
                    node.fid = self.encode("feature_idx", splitinfo.best_fid)
                    node.bid = self.encode("feature_val", splitinfo.best_bid,
                                           node.id)
                    node.missing_dir = self.encode("missing_dir",
                                                   splitinfo.missing_dir,
                                                   node.id)
                else:
                    node.fid = splitinfo.best_fid
                    node.bid = splitinfo.best_bid

                self.update_feature_importance(splitinfo)
            self.tree_.append(node)

        self.tree_node_queue = new_tree_node_queue

    @staticmethod
    def dispatch_node(value,
                      tree_=None,
                      decoder=None,
                      sitename=consts.GUEST,
                      split_maskdict=None,
                      bin_sparse_points=None,
                      use_missing=False,
                      zero_as_missing=False,
                      missing_dir_maskdict=None):
        """Route one training instance one level down the tree (guest side).

        Applied through ``functools.partial`` + ``mapValues`` in
        ``redispatch_node``.

        Args:
            value: ``(data_inst, (unleaf_state, nodeid))`` -- the joined value
                built in ``fit`` from ``data_bin`` and ``node_dispatch``.
            tree_: list of nodes; the code relies on ``tree_[nodeid]`` lookup.
            decoder: unmasking callable (``self.decode`` in redispatch_node).
            sitename: this party's site name; splits owned by another site are
                not evaluated here.
            split_maskdict: mask dictionary passed to ``decoder`` for bids.
            bin_sparse_points: passed as the default to ``features.get_data``;
                presumably the bin of an absent sparse feature -- confirm.
            use_missing: route missing values via the node's missing direction.
            zero_as_missing: also treat absent (sparse) features as missing.
            missing_dir_maskdict: mask dictionary passed to ``decoder`` for
                missing directions.

        Returns:
            * the leaf ``weight`` when ``nodeid`` is a leaf;
            * ``(1, child_nodeid)`` when this site owns the split;
            * ``(1, fid, bid, sitename, nodeid, left_nodeid, right_nodeid)``
              when another site owns it -- ``redispatch_node`` forwards these
              (``len > 2``) to the hosts.
        """
        # NOTE(review): unleaf_state is unpacked but never used below.
        unleaf_state, nodeid = value[1]

        if tree_[nodeid].is_leaf is True:
            return tree_[nodeid].weight
        else:
            if tree_[nodeid].sitename == sitename:
                fid = decoder("feature_idx",
                              tree_[nodeid].fid,
                              split_maskdict=split_maskdict)
                bid = decoder("feature_val",
                              tree_[nodeid].bid,
                              nodeid,
                              split_maskdict=split_maskdict)
                if not use_missing:
                    # Binned value <= split bin goes left, otherwise right.
                    if value[0].features.get_data(
                            fid, bin_sparse_points[fid]) <= bid:
                        return 1, tree_[nodeid].left_nodeid
                    else:
                        return 1, tree_[nodeid].right_nodeid
                else:
                    missing_dir = decoder(
                        "missing_dir",
                        tree_[nodeid].missing_dir,
                        nodeid,
                        missing_dir_maskdict=missing_dir_maskdict)

                    # NOTE(review): NoneType is not defined in this chunk; it
                    # presumably is a project sentinel for an explicit missing
                    # value -- confirm against its import. The `use_missing and`
                    # test below is always True inside this branch.
                    missing_val = False
                    if zero_as_missing:
                        if value[0].features.get_data(fid, None) is None or \
                                value[0].features.get_data(fid) == NoneType():
                            missing_val = True
                    elif use_missing and value[0].features.get_data(
                            fid) == NoneType():
                        missing_val = True

                    # missing_dir == 1 sends missing values right, anything
                    # else sends them left.
                    if missing_val:
                        if missing_dir == 1:
                            return 1, tree_[nodeid].right_nodeid
                        else:
                            return 1, tree_[nodeid].left_nodeid
                    else:
                        # NOTE(review): this debug log runs once per
                        # non-missing instance and may be noisy on large data.
                        LOGGER.debug(
                            "fid is {}, bid is {}, sitename is {}".format(
                                fid, bid, sitename))
                        if value[0].features.get_data(
                                fid, bin_sparse_points[fid]) <= bid:
                            return 1, tree_[nodeid].left_nodeid
                        else:
                            return 1, tree_[nodeid].right_nodeid
            else:
                # Host-owned split: return everything the host needs to route
                # the instance itself.
                return (1, tree_[nodeid].fid, tree_[nodeid].bid,
                        tree_[nodeid].sitename, nodeid,
                        tree_[nodeid].left_nodeid, tree_[nodeid].right_nodeid)

    def sync_dispatch_node_host(self, dispatch_guest_data, dep=-1):
        """Send the instances whose current split is host-owned to all hosts."""
        LOGGER.info("send node to host to dispath, depth is {}".format(dep))
        channel = self.transfer_inst.dispatch_node_host
        channel.remote(dispatch_guest_data,
                       role=consts.HOST,
                       idx=-1,
                       suffix=(dep, ))

    def sync_dispatch_node_host_result(self, dep=-1):
        """Receive every host's routing result for depth ``dep`` (one entry per host)."""
        LOGGER.info("get host dispatch result, depth is {}".format(dep))
        channel = self.transfer_inst.dispatch_node_host_result
        return channel.get(idx=-1, suffix=(dep, ))

    def redispatch_node(self, dep=-1):
        """Move every still-active instance one level down after depth ``dep``.

        Guest-owned splits are resolved locally with ``dispatch_node``.
        Instances whose split belongs to a host are sent to the hosts and their
        answers are merged back. Instances that reached a leaf are added to
        ``self.predict_weights`` and dropped from the new
        ``self.node_dispatch``.
        """
        LOGGER.info("redispatch node of depth {}".format(dep))
        dispatch_node_method = functools.partial(
            self.dispatch_node,
            tree_=self.tree_,
            decoder=self.decode,
            sitename=self.sitename,
            split_maskdict=self.split_maskdict,
            bin_sparse_points=self.bin_sparse_points,
            use_missing=self.use_missing,
            zero_as_missing=self.zero_as_missing,
            missing_dir_maskdict=self.missing_dir_maskdict)
        dispatch_guest_result = self.data_bin_with_node_dispatch.mapValues(
            dispatch_node_method)
        LOGGER.info("remask dispatch node result of depth {}".format(dep))

        # Host-owned splits come back from dispatch_node as 7-tuples.
        dispatch_to_host_result = dispatch_guest_result.filter(
            lambda key, value: isinstance(value, tuple) and len(value) > 2)
        dispatch_guest_result = dispatch_guest_result.subtractByKey(
            dispatch_to_host_result)

        # Leaves come back as a bare weight rather than a tuple.
        leaf = dispatch_guest_result.filter(
            lambda key, value: not isinstance(value, tuple))
        if self.predict_weights is None:
            self.predict_weights = leaf
        else:
            self.predict_weights = self.predict_weights.union(leaf)

        dispatch_guest_result = dispatch_guest_result.subtractByKey(leaf)

        self.sync_dispatch_node_host(dispatch_to_host_result, dep)
        dispatch_node_host_result = self.sync_dispatch_node_host_result(dep)

        # Fold the per-host answers: for each key keep the left value when it
        # is a 2-element (state, nodeid) pair, otherwise take the right one.
        node_dispatch = None
        for host_result in dispatch_node_host_result:
            if node_dispatch is None:
                node_dispatch = host_result
            else:
                node_dispatch = node_dispatch.join(
                    host_result,
                    lambda position1, position2:
                    position1 if len(position1) == 2 else position2)

        if node_dispatch is None:
            # No host result arrived: only guest-resolved positions remain
            # (previously this crashed on None.union).
            self.node_dispatch = dispatch_guest_result
        else:
            self.node_dispatch = node_dispatch.union(dispatch_guest_result)

    def sync_tree(self):
        """Send the grown tree (``self.tree_``) to all hosts."""
        LOGGER.info("sync tree to host")
        channel = self.transfer_inst.tree
        channel.remote(self.tree_, role=consts.HOST, idx=-1)

    def convert_bin_to_real(self):
        """Replace the bin indices of guest-owned splits with real split values.

        For every non-leaf node owned by this site, the bin id is looked up via
        ``decode`` and mapped to ``bin_split_points[fid][bid]``. The real value
        is stored back through ``encode``, which keeps it in ``split_maskdict``
        and returns ``None``, so the node's ``bid`` becomes ``None``.
        """
        LOGGER.info("convert tree node bins to real value")
        for node in self.tree_:
            if node.is_leaf is True or node.sitename != self.sitename:
                continue
            fid = self.decode("feature_idx",
                              node.fid,
                              split_maskdict=self.split_maskdict)
            bid = self.decode("feature_val", node.bid, node.id,
                              self.split_maskdict)
            node.bid = self.encode("feature_val",
                                   self.bin_split_points[fid][bid],
                                   node.id)

    def fit(self):
        """Grow the guest-side tree jointly with the hosts, one depth at a time.

        Sends the encrypted gradients/hessians to the hosts and creates the root
        from their totals. Then, for each depth below ``max_depth``, it shares
        the masked node queue and instance positions, finds the best guest/host
        split of every queued node, expands or finalises those nodes, and
        re-dispatches the instances. Finally it sends the tree to the hosts and
        converts guest split bins to real split values.
        """
        LOGGER.info("begin to fit guest decision tree")
        self.sync_encrypted_grad_and_hess()

        root_sum_grad, root_sum_hess = self.get_grad_hess_sum(
            self.grad_and_hess)
        root_node = Node(id=0,
                         sitename=self.sitename,
                         sum_grad=root_sum_grad,
                         sum_hess=root_sum_hess,
                         weight=self.splitter.node_weight(
                             root_sum_grad, root_sum_hess))
        self.tree_node_queue = [root_node]

        self.dispatch_all_node_to_root()

        for dep in range(self.max_depth):
            LOGGER.info(
                "start to fit depth {}, tree node queue size is {}".format(
                    dep, len(self.tree_node_queue)))

            # The queue is sent before the emptiness check, presumably so the
            # hosts see the empty queue and stop at the same depth.
            self.sync_tree_node_queue(self.tree_node_queue, dep)
            if not self.tree_node_queue:
                break

            self.sync_node_positions(dep)

            self.data_bin_with_node_dispatch = self.data_bin.join(
                self.node_dispatch,
                lambda data_inst, dispatch_info: (data_inst, dispatch_info))

            splitinfos = self._find_splitinfos_of_depth(dep)

            max_depth_reach = dep + 1 == self.max_depth
            self.update_tree_node_queue(splitinfos, max_depth_reach)

            self.redispatch_node(dep)

        self.sync_tree()
        self.convert_bin_to_real()
        LOGGER.info("tree node num is %d" % len(self.tree_))
        LOGGER.info("end to fit guest decision tree")

    def _find_splitinfos_of_depth(self, dep):
        """Return the best guest-vs-host split for every node queued at ``dep``.

        Nodes are processed in batches of ``max_split_nodes``. The batch index
        is used together with ``dep`` as the federation suffix.
        """
        splitinfos = []
        batch_starts = range(0, len(self.tree_node_queue), self.max_split_nodes)
        for batch, start in enumerate(batch_starts):
            self.cur_split_nodes = self.tree_node_queue[
                start:start + self.max_split_nodes]

            # node id -> position of the node inside this batch
            node_map = {
                tree_node.id: position
                for position, tree_node in enumerate(self.cur_split_nodes)
            }
            acc_histograms = self.get_histograms(node_map=node_map)

            self.best_splitinfo_guest = self.splitter.find_split(
                acc_histograms, self.valid_features,
                self.data_bin._partitions, self.sitename, self.use_missing,
                self.zero_as_missing)
            self.federated_find_split(dep, batch)
            final_splitinfo_host = self.sync_final_split_host(dep, batch)

            splitinfos.extend(
                self.merge_splitinfo(self.best_splitinfo_guest,
                                     final_splitinfo_host))
        return splitinfos

    @staticmethod
    def traverse_tree(predict_state,
                      data_inst,
                      tree_=None,
                      decoder=None,
                      sitename=consts.GUEST,
                      split_maskdict=None,
                      use_missing=None,
                      zero_as_missing=None,
                      missing_dir_maskdict=None):
        """Walk one instance through the guest-owned part of the tree (predict).

        Used via ``functools.partial`` as the join function in ``predict``.

        Args:
            predict_state: ``(nid, tag)`` -- the node to resume from and a state
                tag; the tag is unpacked but not used here.
            data_inst: instance being predicted; values come from
                ``data_inst.features.get_data``.
            tree_: list of nodes; the code relies on ``tree_[nid]`` lookup.
            decoder: unmasking callable (``self.decode`` in ``predict``).
            sitename: this party's site name; the walk stops at the first node
                owned by another site.
            split_maskdict: mask dictionary for split values. For a tree grown
                by ``fit``, ``convert_bin_to_real`` has stored real thresholds
                there, so ``bid`` is compared with raw feature values below.
            use_missing: route missing values via the stored missing direction.
            zero_as_missing: also treat absent features as missing (only when
                ``use_missing`` is set).
            missing_dir_maskdict: mask dictionary for missing directions.

        Returns:
            The leaf ``weight`` once a leaf is reached, otherwise ``(nid, 1)``
            with ``nid`` the first node owned by another site.
        """
        nid, tag = predict_state

        while tree_[nid].sitename == sitename:
            if tree_[nid].is_leaf is True:
                return tree_[nid].weight

            fid = decoder("feature_idx",
                          tree_[nid].fid,
                          split_maskdict=split_maskdict)
            bid = decoder("feature_val",
                          tree_[nid].bid,
                          nid,
                          split_maskdict=split_maskdict)
            # With self.decode as decoder, "missing_dir" is looked up by nid;
            # the literal 1 passed as the value is ignored.
            if use_missing:
                missing_dir = decoder(
                    "missing_dir",
                    1,
                    nid,
                    missing_dir_maskdict=missing_dir_maskdict)
            else:
                missing_dir = 1

            # NOTE(review): this branch re-decodes missing_dir, which the block
            # above already did; redundant but harmless. NoneType is not defined
            # in this chunk -- presumably a project sentinel for an explicit
            # missing value; confirm against its import.
            if use_missing and zero_as_missing:
                missing_dir = decoder(
                    "missing_dir",
                    1,
                    nid,
                    missing_dir_maskdict=missing_dir_maskdict)
                if data_inst.features.get_data(fid) == NoneType(
                ) or data_inst.features.get_data(fid, None) is None:
                    if missing_dir == 1:
                        nid = tree_[nid].right_nodeid
                    else:
                        nid = tree_[nid].left_nodeid
                elif data_inst.features.get_data(fid) <= bid:
                    nid = tree_[nid].left_nodeid
                else:
                    nid = tree_[nid].right_nodeid
            # Reached also when use_missing is False; missing_dir is then 1, so
            # an explicit missing value goes right.
            elif data_inst.features.get_data(fid) == NoneType():
                if missing_dir == 1:
                    nid = tree_[nid].right_nodeid
                else:
                    nid = tree_[nid].left_nodeid
            # An absent feature is read as 0 here.
            elif data_inst.features.get_data(fid, 0) <= bid:
                nid = tree_[nid].left_nodeid
            else:
                nid = tree_[nid].right_nodeid

        return nid, 1

    def sync_predict_finish_tag(self, finish_tag, send_times):
        """Send the finish tag of prediction round ``send_times`` to all hosts.

        Args:
            finish_tag: ``True`` once no instance is left unfinished on the
                guest side, ``False`` otherwise (see ``predict``).
            send_times: round counter, used as the federation suffix.
        """
        # Arguments were previously swapped, logging e.g.
        # "send the True-th predict finish tag 0 to host".
        LOGGER.info("send the {}-th predict finish tag {} to host".format(
            send_times, finish_tag))

        self.transfer_inst.predict_finish_tag.remote(finish_tag,
                                                     role=consts.HOST,
                                                     idx=-1,
                                                     suffix=(send_times, ))

    def sync_predict_data(self, predict_data, send_times):
        """Send the still-unfinished prediction states of round ``send_times`` to all hosts."""
        LOGGER.info("send predict data to host, sending times is {}".format(
            send_times))
        channel = self.transfer_inst.predict_data
        channel.remote(predict_data,
                       role=consts.HOST,
                       idx=-1,
                       suffix=(send_times, ))

    def sync_data_predicted_by_host(self, send_times):
        """Receive every host's updated prediction states for round ``send_times``."""
        LOGGER.info(
            "get predicted data by host, recv times is {}".format(send_times))
        channel = self.transfer_inst.predict_data_by_host
        return channel.get(idx=-1, suffix=(send_times, ))

    def predict(self, data_inst):
        """Predict leaf weights for ``data_inst`` jointly with the hosts.

        Each round, every unfinished instance is walked through guest-owned
        nodes with ``traverse_tree``. Instances that reach a leaf are collected.
        If none remain, a ``True`` finish tag is sent and the loop ends.
        Otherwise a ``False`` tag and the remaining states are sent to the
        hosts, their answers are merged back, and the next round starts.

        Args:
            data_inst: table of instances to predict.

        Returns:
            Table mapping each instance key to its leaf weight.
        """
        LOGGER.info("start to predict!")
        # Every instance starts at root node 0 with state tag 1.
        predict_data = data_inst.mapValues(lambda _instance: (0, 1))
        site_host_send_times = 0
        predict_result = None

        # The tree and mask dictionaries are not modified inside the loop, so
        # the traversal function only needs to be built once.
        traverse_tree = functools.partial(
            self.traverse_tree,
            tree_=self.tree_,
            decoder=self.decode,
            sitename=self.sitename,
            split_maskdict=self.split_maskdict,
            use_missing=self.use_missing,
            zero_as_missing=self.zero_as_missing,
            missing_dir_maskdict=self.missing_dir_maskdict)

        while True:
            predict_data = predict_data.join(data_inst, traverse_tree)
            # traverse_tree returns a bare weight (not a tuple) at a leaf.
            predict_leaf = predict_data.filter(
                lambda key, value: not isinstance(value, tuple))
            if predict_result is None:
                predict_result = predict_leaf
            else:
                predict_result = predict_result.union(predict_leaf)

            predict_data = predict_data.subtractByKey(predict_leaf)

            unleaf_node_count = predict_data.count()

            if unleaf_node_count == 0:
                self.sync_predict_finish_tag(True, site_host_send_times)
                break

            self.sync_predict_finish_tag(False, site_host_send_times)
            self.sync_predict_data(predict_data, site_host_send_times)

            predict_data_host = self.sync_data_predicted_by_host(
                site_host_send_times)
            for host_predict_data in predict_data_host:
                # Keep the current state when its tag is 0, else take the host's.
                predict_data = predict_data.join(
                    host_predict_data,
                    lambda state1_nodeid1, state2_nodeid2: state1_nodeid1
                    if state1_nodeid1[1] == 0 else state2_nodeid2)

            site_host_send_times += 1

        LOGGER.info("predict finish!")
        return predict_result

    def get_model_meta(self):
        """Package the tree's hyper-parameters into a ``DecisionTreeModelMeta``.

        Returns:
            A ``DecisionTreeModelMeta`` carrying the criterion settings, the
            depth/split/leaf limits and the two missing-value flags.
        """
        meta = DecisionTreeModelMeta()
        meta.criterion_meta.CopyFrom(
            CriterionMeta(criterion_method=self.criterion_method,
                          criterion_param=self.criterion_params))

        # Scalar fields share their names with the tree's attributes.
        for field in ("max_depth", "min_sample_split", "min_impurity_split",
                      "min_leaf_node", "use_missing", "zero_as_missing"):
            setattr(meta, field, getattr(self, field))

        return meta

    def set_model_meta(self, model_meta):
        """Restore the tree's hyper-parameters from a model-meta message.

        Args:
            model_meta: message shaped like the one ``get_model_meta``
                builds (scalar limits, ``criterion_meta``, missing flags).
        """
        for field in ("max_depth", "min_sample_split", "min_impurity_split",
                      "min_leaf_node"):
            setattr(self, field, getattr(model_meta, field))

        criterion = model_meta.criterion_meta
        self.criterion_method = criterion.criterion_method
        # Copy the (presumably repeated protobuf) field into a plain list.
        self.criterion_params = list(criterion.criterion_param)

        for field in ("use_missing", "zero_as_missing"):
            setattr(self, field, getattr(model_meta, field))

    def get_model_param(self):
        """Serialize the tree structure into a ``DecisionTreeModelParam``.

        Every node of ``self.tree_`` is appended to the message's ``tree_``
        field. Both mask dictionaries are then merged into the message.

        Returns:
            The populated ``DecisionTreeModelParam``.
        """
        node_fields = ("id", "sitename", "fid", "bid", "weight", "is_leaf",
                       "left_nodeid", "right_nodeid", "missing_dir")
        param = DecisionTreeModelParam()

        for tree_node in self.tree_:
            param.tree_.add(**{field: getattr(tree_node, field)
                               for field in node_fields})
            LOGGER.debug(
                "missing_dir is {}, sitename is {}, is_leaf is {}".format(
                    tree_node.missing_dir, tree_node.sitename,
                    tree_node.is_leaf))

        param.split_maskdict.update(self.split_maskdict)
        param.missing_dir_maskdict.update(self.missing_dir_maskdict)
        return param

    def set_model_param(self, model_param):
        """Rebuild ``self.tree_`` and the mask dictionaries from a saved message.

        Args:
            model_param: message whose ``tree_`` entries expose the same
                field names that ``Node`` accepts, plus ``split_maskdict``
                and ``missing_dir_maskdict`` mappings.
        """
        node_fields = ("id", "sitename", "fid", "bid", "weight", "is_leaf",
                       "left_nodeid", "right_nodeid", "missing_dir")
        self.tree_ = []
        self.tree_.extend(
            Node(**{field: getattr(node_param, field) for field in node_fields})
            for node_param in model_param.tree_)

        self.split_maskdict = dict(model_param.split_maskdict)
        self.missing_dir_maskdict = dict(model_param.missing_dir_maskdict)

    def get_model(self):
        """Return the ``(model_meta, model_param)`` pair describing this tree.

        The meta message is built before the param message.
        """
        meta = self.get_model_meta()
        return meta, self.get_model_param()

    def load_model(self, model_meta=None, model_param=None):
        """Restore hyper-parameters, then tree structure, from saved messages.

        Args:
            model_meta: passed to ``set_model_meta``.
            model_param: passed to ``set_model_param``.

        NOTE(review): both default to ``None``, but the setters read
        attributes from their argument, so in practice both must be supplied.
        """
        LOGGER.info("load tree model")
        for apply_message, message in ((self.set_model_meta, model_meta),
                                       (self.set_model_param, model_param)):
            apply_message(message)

    def get_feature_importance(self):
        """Return the feature-importance dict stored on this tree.

        The stored object itself is returned, not a copy.
        """
        importances = self.feature_importances_
        return importances