def __init__(self, tree_param: DecisionTreeModelParam, valid_feature: dict,
             epoch_idx: int, tree_idx: int, flow_id: int):
        """Initialise the arbiter-side state of a homo decision tree.

        Args:
            tree_param: forwarded unchanged to the parent constructor.
            valid_feature: stored as ``self.valid_features``.
            epoch_idx: stored as ``self.epoch_idx``.
            tree_idx: stored as ``self.tree_idx``.
            flow_id: passed to ``set_flowid`` once the transfer variable exists.
        """
        super(HomoDecisionTreeArbiter, self).__init__(tree_param)

        # The splitter is configured from attributes read off ``self`` right
        # after the parent constructor has run.
        self.splitter = Splitter(self.criterion_method, self.criterion_params,
                                 self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)

        # Identity of this participant and of the tree being fitted.
        self.sitename = consts.ARBITER
        self.runtime_idx = 0
        self.epoch_idx = epoch_idx
        self.tree_idx = tree_idx

        # initializing here
        self.valid_features = valid_feature

        # Tree bookkeeping.
        self.tree_node = []  # start from root node
        self.tree_node_num = 0
        self.cur_layer_node = []

        # stored histogram for faster computation {node_id:histogram_bag}
        self.stored_histograms = {}

        # secure aggregator; the transfer variable must exist before
        # set_flowid is called, because set_flowid is handed the flow id
        # that belongs to that transfer variable.
        self.transfer_inst = HomoDecisionTreeTransferVariable()
        self.set_flowid(flow_id)
        self.aggregator = DecisionTreeArbiterAggregator(verbose=False)
# Example #2
# 0
    def __init__(self, tree_param):
        """Initialise the host-side state of a hetero decision tree.

        Args:
            tree_param: forwarded unchanged to the parent constructor.
        """
        # This constructor calls super(HeteroDecisionTreeHost, ...) and sets
        # sitename to consts.HOST, so the log line names the host (it used to
        # say "guest").
        LOGGER.info("hetero decision tree host init!")
        super(HeteroDecisionTreeHost, self).__init__(tree_param)

        # Configured from attributes read off ``self`` after the parent
        # constructor has run.
        self.splitter = Splitter(self.criterion_method, self.criterion_params,
                                 self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)

        # Inputs and intermediate data; all unset until supplied later.
        self.data_bin = None
        self.data_bin_with_position = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.infos = None
        self.valid_features = None
        self.pubkey = None
        self.privakey = None
        self.tree_id = None
        self.encrypted_grad_and_hess = None
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.tree_node_queue = None
        self.cur_split_nodes = None
        # Per-node lookup tables, both starting empty.
        self.split_maskdict = {}
        self.missing_dir_maskdict = {}
        self.tree_ = None
        self.runtime_idx = 0
        self.sitename = consts.HOST
    def __init__(self, tree_param):
        """Initialise the guest-side state of a hetero decision tree.

        Args:
            tree_param: forwarded unchanged to the parent constructor.
        """
        LOGGER.info("hetero decision tree guest init!")
        super(HeteroDecisionTreeGuest, self).__init__(tree_param)

        # Configured from attributes read off ``self`` after the parent
        # constructor has run.
        self.splitter = Splitter(
            self.criterion_method,
            self.criterion_params,
            self.min_impurity_split,
            self.min_sample_split,
            self.min_leaf_node,
        )
        self.transfer_inst = HeteroDecisionTreeTransferVariable()

        # Input data and derived tables; unset until supplied later.
        self.data_bin = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.data_bin_with_node_dispatch = None
        self.node_dispatch = None
        self.infos = None
        self.valid_features = None

        # Encryption helpers; unset until supplied later.
        self.encrypter = None
        self.encrypted_mode_calculator = None

        # Tree-building state.
        self.best_splitinfo_guest = None
        self.tree_node_queue = None
        self.cur_split_nodes = None
        self.tree_ = []
        self.tree_node_num = 0
        self.split_maskdict = {}

        # Prediction / reporting state.
        self.predict_weights = None
        self.runtime_idx = 0
        self.feature_importances_ = {}
class HeteroDecisionTreeHost(DecisionTree):
    """Host-side decision tree of the hetero tree protocol.

    The host pulls data from its peer with ``federation.get`` and pushes
    results to ``consts.GUEST`` with ``federation.remote``; every message is
    tagged through ``self.transfer_inst.generate_transferid``.

    Split thresholds chosen on host features are never written into the
    ``SplitInfo``/tree node that leaves this object: ``encode("feature_val")``
    stores the value in ``self.split_maskdict`` (keyed by node id) and returns
    ``None``; ``decode("feature_val")`` looks it up again.
    """

    def __init__(self, tree_param):
        """Create an empty host tree.

        Args:
            tree_param: forwarded unchanged to ``DecisionTree.__init__``.
        """
        # This is the host class; the message used to say "guest".
        LOGGER.info("hetero decision tree host init!")
        super(HeteroDecisionTreeHost, self).__init__(tree_param)

        # Configured from attributes read off ``self`` after the parent
        # constructor has run.
        self.splitter = Splitter(self.criterion_method, self.criterion_params,
                                 self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)

        self.data_bin = None
        self.data_bin_with_position = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.infos = None
        self.valid_features = None
        self.pubkey = None
        self.privakey = None
        self.tree_id = None
        self.encrypted_grad_and_hess = None
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.tree_node_queue = None
        self.cur_split_nodes = None
        # node id -> locally kept split value (see encode/decode).
        self.split_maskdict = {}
        self.tree_ = None

    def set_flowid(self, flowid=0):
        """Forward ``flowid`` to the transfer variable."""
        LOGGER.info("set flowid, flowid is {}".format(flowid))
        self.transfer_inst.set_flowid(flowid)

    def set_inputinfo(self,
                      data_bin=None,
                      grad_and_hess=None,
                      bin_split_points=None,
                      bin_sparse_points=None):
        """Store the four training inputs on the instance, as given."""
        LOGGER.info("set input info")
        self.data_bin = data_bin
        self.grad_and_hess = grad_and_hess
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_points

    def set_valid_features(self, valid_features=None):
        """Store ``valid_features`` on the instance."""
        LOGGER.info("set valid features")
        self.valid_features = valid_features

    def encode(self, etype="feature_idx", val=None, nid=None):
        """Encode a split attribute before it leaves the host.

        Args:
            etype: ``"feature_idx"`` or ``"feature_val"``.
            val: the value to encode.
            nid: node id; used as the ``split_maskdict`` key for
                ``"feature_val"``.

        Returns:
            ``val`` unchanged for ``"feature_idx"``. ``None`` for
            ``"feature_val"`` (the value is recorded in
            ``self.split_maskdict[nid]`` instead).

        Raises:
            TypeError: for any other ``etype``.
        """
        if etype == "feature_idx":
            return val

        if etype == "feature_val":
            self.split_maskdict[nid] = val
            return None

        raise TypeError("encode type %s is not support!" % (str(etype)))

    @staticmethod
    def decode(dtype="feature_idx", val=None, nid=None, split_maskdict=None):
        """Reverse :meth:`encode`.

        Args:
            dtype: ``"feature_idx"`` or ``"feature_val"``.
            val: the encoded value (returned as-is for ``"feature_idx"``).
            nid: node id to look up for ``"feature_val"``.
            split_maskdict: mapping of node id to stored split value.

        Returns:
            ``val`` for ``"feature_idx"``; ``split_maskdict[nid]`` for
            ``"feature_val"``.

        Raises:
            ValueError: ``"feature_val"`` requested but ``nid`` is not in
                ``split_maskdict``.
            TypeError: unsupported ``dtype``.
        """
        if dtype == "feature_idx":
            return val

        if dtype == "feature_val":
            if nid in split_maskdict:
                return split_maskdict[nid]
            raise ValueError(
                "decode val %s cause error, can't reconize it!" %
                (str(val)))

        # The exception used to be *returned*, so callers silently received a
        # TypeError object as the decoded value. Raise it, as encode() does.
        raise TypeError("decode type %s is not support!" % (str(dtype)))

    def sync_encrypted_grad_and_hess(self):
        """Fetch ``encrypted_grad_and_hess`` and store it in ``self.grad_and_hess``."""
        LOGGER.info("get encrypted grad and hess")
        self.grad_and_hess = federation.get(
            name=self.transfer_inst.encrypted_grad_and_hess.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.encrypted_grad_and_hess),
            idx=0)

    def sync_node_positions(self, dep=-1):
        """Fetch and return the ``node_positions`` object for depth ``dep``."""
        # The message used to read "get tree node queue ...", copied from
        # sync_tree_node_queue.
        LOGGER.info("get node positions of depth {}".format(dep))
        node_positions = federation.get(
            name=self.transfer_inst.node_positions.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.node_positions, dep),
            idx=0)
        return node_positions

    def sync_tree_node_queue(self, dep=-1):
        """Fetch the node queue for depth ``dep`` into ``self.tree_node_queue``."""
        LOGGER.info("get tree node queue of depth {}".format(dep))
        self.tree_node_queue = federation.get(
            name=self.transfer_inst.tree_node_queue.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.tree_node_queue, dep),
            idx=0)

    def get_histograms(self, node_map=None):
        """Compute and accumulate histograms for the nodes in ``node_map``.

        Args:
            node_map: mapping passed through to
                ``FeatureHistogram.calculate_histogram``; ``fit`` builds it as
                ``{node id: position in the current batch}``. Defaults to an
                empty dict.

        Returns:
            The result of ``FeatureHistogram.accumulate_histogram``.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # across calls.
        if node_map is None:
            node_map = {}

        LOGGER.info("start to get node histograms")
        histograms = FeatureHistogram.calculate_histogram(
            self.data_bin_with_position, self.grad_and_hess,
            self.bin_split_points, self.bin_sparse_points, self.valid_features,
            node_map)
        LOGGER.info("begin to accumulate histograms")
        acc_histograms = FeatureHistogram.accumulate_histogram(histograms)
        LOGGER.info("acc histogram shape is {}".format(len(acc_histograms)))
        return acc_histograms

    def sync_encrypted_splitinfo_host(self,
                                      encrypted_splitinfo_host,
                                      dep=-1,
                                      batch=-1):
        """Send ``encrypted_splitinfo_host`` to the guest for ``(dep, batch)``."""
        LOGGER.info("send encrypted splitinfo of depth {}, batch {}".format(
            dep, batch))
        federation.remote(
            obj=encrypted_splitinfo_host,
            name=self.transfer_inst.encrypted_splitinfo_host.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.encrypted_splitinfo_host, dep, batch),
            role=consts.GUEST,
            idx=0)

    def sync_federated_best_splitinfo_host(self, dep=-1, batch=-1):
        """Fetch and return ``federated_best_splitinfo_host`` for ``(dep, batch)``."""
        LOGGER.info(
            "get federated best splitinfo of depth {}, batch {}".format(
                dep, batch))
        federated_best_splitinfo_host = federation.get(
            name=self.transfer_inst.federated_best_splitinfo_host.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.federated_best_splitinfo_host, dep, batch),
            idx=0)
        return federated_best_splitinfo_host

    def sync_final_splitinfo_host(self,
                                  splitinfo_host,
                                  federated_best_splitinfo_host,
                                  dep=-1,
                                  batch=-1):
        """Build one final split per node of the batch and send the list to the guest.

        Args:
            splitinfo_host: per node, an indexable collection of candidate
                split infos.
            federated_best_splitinfo_host: per node, a ``(best_idx, best_gain)``
                pair; ``best_idx == -1`` means no host candidate was chosen.
            dep: depth used in the transfer tag.
            batch: batch number used in the transfer tag.
        """
        LOGGER.info("send host final splitinfo of depth {}, batch {}".format(
            dep, batch))
        final_splitinfos = []
        for i, candidates in enumerate(splitinfo_host):
            best_idx, best_gain = federated_best_splitinfo_host[i]
            if best_idx != -1:
                splitinfo = candidates[best_idx]
                assert splitinfo.sitename == consts.HOST
                splitinfo.best_fid = self.encode("feature_idx",
                                                 splitinfo.best_fid)
                assert splitinfo.best_fid is not None
                # encode("feature_val") stashes the bin id in split_maskdict
                # under this node's id and returns None, so best_bid is
                # blanked in the object that is sent.
                splitinfo.best_bid = self.encode("feature_val",
                                                 splitinfo.best_bid,
                                                 self.cur_split_nodes[i].id)
                splitinfo.gain = best_gain
            else:
                # Placeholder carrying only the gain.
                splitinfo = SplitInfo(sitename=consts.HOST,
                                      best_fid=-1,
                                      best_bid=-1,
                                      gain=best_gain)

            final_splitinfos.append(splitinfo)

        federation.remote(obj=final_splitinfos,
                          name=self.transfer_inst.final_splitinfo_host.name,
                          tag=self.transfer_inst.generate_transferid(
                              self.transfer_inst.final_splitinfo_host, dep,
                              batch),
                          role=consts.GUEST,
                          idx=0)

    def sync_dispatch_node_host(self, dep):
        """Fetch and return ``dispatch_node_host`` for depth ``dep``."""
        LOGGER.info("get node from host to dispath, depth is {}".format(dep))
        dispatch_node_host = federation.get(
            name=self.transfer_inst.dispatch_node_host.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.dispatch_node_host, dep),
            idx=0)
        return dispatch_node_host

    @staticmethod
    def dispatch_node(value1,
                      value2,
                      decoder=None,
                      split_maskdict=None,
                      bin_sparse_points=None):
        """Route one sample to the left or right child of its current node.

        Args:
            value1: either a sequence of at most two items, which is returned
                untouched, or a 6-tuple ``(unleaf_state, fid, bid, nodeid,
                left_nodeid, right_nodeid)``.
            value2: sample whose ``features.get_data(fid, default)`` is
                compared with the decoded split value.
            decoder: callable with the signature of :meth:`decode`.
            split_maskdict: passed through to ``decoder``.
            bin_sparse_points: indexed by feature id; supplies the second
                argument of ``get_data``.

        Returns:
            ``value1`` itself, or ``(unleaf_state, child_node_id)``.
        """
        if len(value1) <= 2:
            return value1

        unleaf_state, fid, bid, nodeid, left_nodeid, right_nodeid = value1
        fid = decoder("feature_idx", fid, split_maskdict=split_maskdict)
        bid = decoder("feature_val",
                      bid,
                      nodeid,
                      split_maskdict=split_maskdict)
        if value2.features.get_data(fid, bin_sparse_points[fid]) <= bid:
            return unleaf_state, left_nodeid
        return unleaf_state, right_nodeid

    def sync_dispatch_node_host_result(self,
                                       dispatch_node_host_result,
                                       dep=-1):
        """Send ``dispatch_node_host_result`` to the guest for depth ``dep``."""
        LOGGER.info("send host dispatch result, depth is {}".format(dep))
        federation.remote(
            obj=dispatch_node_host_result,
            name=self.transfer_inst.dispatch_node_host_result.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.dispatch_node_host_result, dep),
            role=consts.GUEST,
            idx=0)

    def find_dispatch(self, dispatch_node_host, dep=-1):
        """Apply :meth:`dispatch_node` to ``dispatch_node_host`` joined with
        ``self.data_bin`` and send the result to the guest."""
        LOGGER.info("start to find host dispath of depth {}".format(dep))
        dispatch_node_method = functools.partial(
            self.dispatch_node,
            decoder=self.decode,
            split_maskdict=self.split_maskdict,
            bin_sparse_points=self.bin_sparse_points)
        dispatch_node_host_result = dispatch_node_host.join(
            self.data_bin, dispatch_node_method)
        self.sync_dispatch_node_host_result(dispatch_node_host_result, dep)

    def sync_tree(self):
        """Fetch the ``tree`` transfer variable into ``self.tree_``."""
        LOGGER.info("sync tree from guest")
        self.tree_ = federation.get(name=self.transfer_inst.tree.name,
                                    tag=self.transfer_inst.generate_transferid(
                                        self.transfer_inst.tree),
                                    idx=0)

    def remove_duplicated_split_nodes(self, split_nid_used):
        """Drop every ``split_maskdict`` entry whose key is not in ``split_nid_used``."""
        LOGGER.info("remove duplicated nodes from split mask dict")
        duplicated_nodes = set(self.split_maskdict) - set(split_nid_used)
        for nid in duplicated_nodes:
            del self.split_maskdict[nid]

    def convert_bin_to_real(self):
        """Replace stored bin ids by real split values for host-owned nodes.

        For every non-leaf node of ``self.tree_`` whose ``sitename`` is
        ``consts.HOST``, the stored bin id is decoded, the matching entry of
        ``self.bin_split_points`` is written back into ``split_maskdict`` and
        the node's ``bid`` is set to the return value of ``encode`` (``None``).
        Entries of ``split_maskdict`` for nodes not visited this way are
        removed afterwards.
        """
        LOGGER.info("convert tree node bins to real value")
        split_nid_used = []
        for node in self.tree_:
            if node.is_leaf is True:
                continue

            if node.sitename == consts.HOST:
                fid = self.decode("feature_idx",
                                  node.fid,
                                  split_maskdict=self.split_maskdict)
                bid = self.decode("feature_val", node.bid, node.id,
                                  self.split_maskdict)
                real_splitval = self.encode("feature_val",
                                            self.bin_split_points[fid][bid],
                                            node.id)
                node.bid = real_splitval

                split_nid_used.append(node.id)

        self.remove_duplicated_split_nodes(split_nid_used)

    @staticmethod
    def traverse_tree(predict_state,
                      data_inst,
                      tree_=None,
                      decoder=None,
                      split_maskdict=None):
        """Walk one sample down the tree until a guest-owned node is reached.

        Args:
            predict_state: ``(tag, nid)`` pair; a ``tag`` of 0 is returned
                unchanged.
            data_inst: sample providing ``features.get_data(fid, 0)``.
            tree_: node container indexed by node id.
            decoder: callable with the signature of :meth:`decode`.
            split_maskdict: passed through to ``decoder``.

        Returns:
            ``(0, nid)`` when ``tag`` was 0, otherwise ``(1, nid)`` with the
            id of the first node whose ``sitename`` is ``consts.GUEST``.
        """
        tag, nid = predict_state
        if tag == 0:
            return (tag, nid)

        while tree_[nid].sitename != consts.GUEST:
            fid = decoder("feature_idx",
                          tree_[nid].fid,
                          split_maskdict=split_maskdict)
            bid = decoder("feature_val", tree_[nid].bid, nid, split_maskdict)

            if data_inst.features.get_data(fid, 0) <= bid:
                nid = tree_[nid].left_nodeid
            else:
                nid = tree_[nid].right_nodeid

        return (1, nid)

    def sync_predict_finish_tag(self, recv_times):
        """Fetch and return the ``predict_finish_tag`` for round ``recv_times``."""
        LOGGER.info(
            "get the {}-th predict finish tag from guest".format(recv_times))
        finish_tag = federation.get(
            name=self.transfer_inst.predict_finish_tag.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.predict_finish_tag, recv_times),
            idx=0)
        return finish_tag

    def sync_predict_data(self, recv_times):
        """Fetch and return the ``predict_data`` for round ``recv_times``."""
        LOGGER.info(
            "srecv predict data to host, recv times is {}".format(recv_times))
        predict_data = federation.get(
            name=self.transfer_inst.predict_data.name,
            tag=self.transfer_inst.generate_transferid(
                self.transfer_inst.predict_data, recv_times),
            idx=0)

        return predict_data

    def sync_data_predicted_by_host(self, predict_data, send_times):
        """Send ``predict_data`` to the guest for round ``send_times``."""
        LOGGER.info(
            "send predicted data by host, send times is {}".format(send_times))
        federation.remote(obj=predict_data,
                          name=self.transfer_inst.predict_data_by_host.name,
                          tag=self.transfer_inst.generate_transferid(
                              self.transfer_inst.predict_data_by_host,
                              send_times),
                          role=consts.GUEST,
                          idx=0)

    def fit(self):
        """Run the host side of tree construction.

        For each depth below ``self.max_depth`` the node queue is fetched
        (stopping early when it is empty), and the queue is processed in
        batches of ``self.max_split_nodes`` nodes: histograms are computed,
        candidate splits are exchanged with the guest and the final host
        splits are sent. After each depth the host answers the dispatch
        request. Finally the tree is fetched and bin ids are converted by
        :meth:`convert_bin_to_real`.
        """
        LOGGER.info("begin to fit host decision tree")
        self.sync_encrypted_grad_and_hess()

        for dep in range(self.max_depth):
            self.sync_tree_node_queue(dep)
            if len(self.tree_node_queue) == 0:
                break

            node_positions = self.sync_node_positions(dep)
            self.data_bin_with_position = self.data_bin.join(
                node_positions, lambda v1, v2: (v1, v2))

            batch_starts = range(0, len(self.tree_node_queue),
                                 self.max_split_nodes)
            for batch, start in enumerate(batch_starts):
                self.cur_split_nodes = self.tree_node_queue[
                    start:start + self.max_split_nodes]
                # node id -> position of the node inside the current batch.
                node_map = {
                    tree_node.id: slot
                    for slot, tree_node in enumerate(self.cur_split_nodes)
                }

                acc_histograms = self.get_histograms(node_map=node_map)

                splitinfo_host, encrypted_splitinfo_host = self.splitter.find_split_host(
                    acc_histograms, self.valid_features,
                    self.data_bin._partitions)
                self.sync_encrypted_splitinfo_host(encrypted_splitinfo_host,
                                                   dep, batch)
                federated_best_splitinfo_host = self.sync_federated_best_splitinfo_host(
                    dep, batch)
                self.sync_final_splitinfo_host(splitinfo_host,
                                               federated_best_splitinfo_host,
                                               dep, batch)

            dispatch_node_host = self.sync_dispatch_node_host(dep)
            self.find_dispatch(dispatch_node_host, dep)

        self.sync_tree()
        self.convert_bin_to_real()

        # This is the host's fit; the message used to say "guest".
        LOGGER.info("end to fit host decision tree")

    def predict(self, data_inst):
        """Serve prediction rounds until a finish tag of ``True`` is received.

        Each round fetches ``predict_data``, joins it with ``data_inst``
        through :meth:`traverse_tree` and sends the result to the guest.

        Args:
            data_inst: table joined with the received predict data.
        """
        LOGGER.info("start to predict!")
        # None of the partial's inputs are reassigned inside the loop, so
        # build it once.
        traverse_tree = functools.partial(
            self.traverse_tree,
            tree_=self.tree_,
            decoder=self.decode,
            split_maskdict=self.split_maskdict)

        site_guest_send_times = 0
        while True:
            finish_tag = self.sync_predict_finish_tag(site_guest_send_times)
            if finish_tag is True:
                break

            predict_data = self.sync_predict_data(site_guest_send_times)
            predict_data = predict_data.join(data_inst, traverse_tree)

            self.sync_data_predicted_by_host(predict_data,
                                             site_guest_send_times)

            site_guest_send_times += 1

        LOGGER.info("predict finish!")

    def get_model_meta(self):
        """Return a ``DecisionTreeModelMeta`` filled with four hyper-parameters."""
        model_meta = DecisionTreeModelMeta()

        model_meta.max_depth = self.max_depth
        model_meta.min_sample_split = self.min_sample_split
        model_meta.min_impurity_split = self.min_impurity_split
        model_meta.min_leaf_node = self.min_leaf_node

        return model_meta

    def set_model_meta(self, model_meta):
        """Copy the four hyper-parameters from ``model_meta`` onto ``self``."""
        self.max_depth = model_meta.max_depth
        self.min_sample_split = model_meta.min_sample_split
        self.min_impurity_split = model_meta.min_impurity_split
        self.min_leaf_node = model_meta.min_leaf_node

    def get_model_param(self):
        """Return a ``DecisionTreeModelParam`` holding the tree nodes and
        a copy of ``split_maskdict``'s entries."""
        model_param = DecisionTreeModelParam()
        for node in self.tree_:
            model_param.tree_.add(id=node.id,
                                  sitename=node.sitename,
                                  fid=node.fid,
                                  bid=node.bid,
                                  weight=node.weight,
                                  is_leaf=node.is_leaf,
                                  left_nodeid=node.left_nodeid,
                                  right_nodeid=node.right_nodeid)

        model_param.split_maskdict.update(self.split_maskdict)

        return model_param

    def set_model_param(self, model_param):
        """Rebuild ``self.tree_`` and ``self.split_maskdict`` from ``model_param``."""
        self.tree_ = []
        for node_param in model_param.tree_:
            _node = Node(id=node_param.id,
                         sitename=node_param.sitename,
                         fid=node_param.fid,
                         bid=node_param.bid,
                         weight=node_param.weight,
                         is_leaf=node_param.is_leaf,
                         left_nodeid=node_param.left_nodeid,
                         right_nodeid=node_param.right_nodeid)

            self.tree_.append(_node)

        self.split_maskdict = dict(model_param.split_maskdict)

    def get_model(self):
        """Return ``(model_meta, model_param)``."""
        model_meta = self.get_model_meta()
        model_param = self.get_model_param()

        return model_meta, model_param

    def load_model(self, model_meta=None, model_param=None):
        """Restore hyper-parameters and tree from the given meta/param objects."""
        LOGGER.info("load tree model")
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)
class HomoDecisionTreeArbiter(DecisionTree):
    """Arbiter-side participant of a homo decision tree.

    The arbiter holds no training data in this class: it aggregates root
    statistics and node histograms through ``self.aggregator``, derives the
    sibling histograms by subtraction from the previous layer's stored
    histograms, finds the best splits and sends them out through
    ``self.transfer_inst``.
    """

    def __init__(self, tree_param: DecisionTreeModelParam, valid_feature: dict,
                 epoch_idx: int, tree_idx: int, flow_id: int):
        """Create the arbiter state for one tree.

        Args:
            tree_param: forwarded unchanged to ``DecisionTree.__init__``.
            valid_feature: stored as ``self.valid_features``.
            epoch_idx: stored; used in transfer suffixes and log messages.
            tree_idx: stored; used in transfer suffixes and log messages.
            flow_id: passed to :meth:`set_flowid`.
        """
        super(HomoDecisionTreeArbiter, self).__init__(tree_param)
        self.splitter = Splitter(
            self.criterion_method,
            self.criterion_params,
            self.min_impurity_split,
            self.min_sample_split,
            self.min_leaf_node,
        )

        self.transfer_inst = HomoDecisionTreeTransferVariable()
        # initializing here
        self.valid_features = valid_feature

        self.tree_node = []  # start from root node
        self.tree_node_num = 0
        self.cur_layer_node = []

        self.runtime_idx = 0
        self.sitename = consts.ARBITER
        self.epoch_idx = epoch_idx
        self.tree_idx = tree_idx

        # secure aggregator; set_flowid needs self.transfer_inst to exist.
        self.set_flowid(flow_id)
        self.aggregator = DecisionTreeArbiterAggregator(verbose=False)

        # stored histogram for faster computation {node_id:histogram_bag}
        self.stored_histograms = {}

    def set_flowid(self, flowid=0):
        """Forward ``flowid`` to the transfer variable."""
        LOGGER.info("set flowid, flowid is {}".format(flowid))
        self.transfer_inst.set_flowid(flowid)

    def sync_node_sample_numbers(self, suffix):
        """Fetch the node count of the current layer.

        ``cur_layer_node_num.get(-1, suffix=...)`` returns a sequence of
        counts; all entries must be equal.

        Returns:
            The first (and therefore common) count.

        Raises:
            AssertionError: if the received counts differ.
        """
        cur_layer_node_num = self.transfer_inst.cur_layer_node_num.get(
            -1, suffix=suffix)
        for num in cur_layer_node_num[1:]:
            assert num == cur_layer_node_num[0]
        return cur_layer_node_num[0]

    def federated_find_best_split(self,
                                  node_histograms,
                                  parallel_partitions=10) -> List[SplitInfo]:
        """Find the best splits for the given histograms.

        Args:
            node_histograms: histograms handed unchanged to
                ``self.splitter.find_split``.
            parallel_partitions: forwarded to ``find_split``.

        Returns:
            Whatever ``self.splitter.find_split`` returns.
        """
        # node histograms [[HistogramBag,HistogramBag,...],[HistogramBag,HistogramBag,....],..]
        LOGGER.debug(
            'federated finding best splits,histograms from {} guest received'.
            format(len(node_histograms)))
        LOGGER.debug('aggregating histograms .....')
        best_splits = self.splitter.find_split(node_histograms,
                                               self.valid_features,
                                               parallel_partitions,
                                               self.sitename, self.use_missing,
                                               self.zero_as_missing)
        return best_splits

    def sync_best_splits(self, split_info, suffix):
        """Send ``split_info`` through ``best_split_points`` with ``idx=-1``."""
        LOGGER.debug('sending best split points')
        self.transfer_inst.best_split_points.remote(split_info,
                                                    idx=-1,
                                                    suffix=suffix)

    def sync_local_histogram(self, suffix) -> List[HistogramBag]:
        """Return the histograms aggregated by ``self.aggregator`` for ``suffix``."""
        LOGGER.debug('get local histograms')
        node_local_histogram = self.aggregator.aggregate_histogram(
            suffix=suffix)
        LOGGER.debug('num of histograms {}'.format(len(node_local_histogram)))
        return node_local_histogram

    def histogram_subtraction(self, left_node_histogram, stored_histograms):
        """Derive the sibling histogram of every received histogram.

        For each received histogram whose ``hid`` is not 0, a second
        histogram is computed as ``stored_histograms[p_hid] - left_hist`` and
        given ``hid = left_hist.hid + 1``.

        Args:
            left_node_histogram: iterable of histograms exposing ``hid`` and
                ``p_hid``.
            stored_histograms: mapping from ``hid`` to the histogram stored
                for the previous layer.

        Returns:
            A list with every received histogram, each followed by its
            derived sibling (except for ``hid == 0``).
        """
        all_histograms = []
        for left_hist in left_node_histogram:
            all_histograms.append(left_hist)
            # root node hist: nothing to subtract from
            if left_hist.hid == 0:
                continue
            right_hist = stored_histograms[left_hist.p_hid] - left_hist
            # NOTE(review): the p_hid half of this assignment is a
            # self-assignment, so p_hid is whatever ``__sub__`` produced.
            # Presumably it should equal left_hist.p_hid; confirm against the
            # histogram type's ``__sub__``.
            right_hist.hid, right_hist.p_hid = left_hist.hid + 1, right_hist.p_hid
            all_histograms.append(right_hist)

        return all_histograms

    def fit(self):
        """Run the arbiter side of tree construction.

        Root statistics are aggregated and broadcast first. Then, for every
        depth except the last one, the layer's node count is fetched, the
        histograms are processed in batches of ``self.max_split_nodes``, and
        the best splits of the whole layer are sent out.
        """
        # The message used to contain the garbled text 'h**o'.
        LOGGER.info(
            'begin to fit homo decision tree, epoch {}, tree idx {}'.format(
                self.epoch_idx, self.tree_idx))

        g_sum, h_sum = self.aggregator.aggregate_root_node_info(
            suffix=('root_node_sync1', self.epoch_idx))
        LOGGER.debug('g_sum is {},h_sum is {}'.format(g_sum, h_sum))
        self.aggregator.broadcast_root_info(g_sum,
                                            h_sum,
                                            suffix=('root_node_sync2',
                                                    self.epoch_idx))

        if self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1:
            self.max_split_nodes += 1
            LOGGER.warning(
                'an even max_split_nodes value is suggested when using histogram-subtraction, max_split_nodes reset to {}'
                .format(self.max_split_nodes))

        for dep in range(self.max_depth):

            # The last depth is never split.
            if dep + 1 == self.max_depth:
                break

            LOGGER.debug('at dep {}'.format(dep))

            split_info = []
            # get cur layer node num:
            cur_layer_node_num = self.sync_node_sample_numbers(
                suffix=(dep, self.epoch_idx, self.tree_idx))
            LOGGER.debug(
                '{} nodes to split at this layer'.format(cur_layer_node_num))

            layer_stored_hist = {}

            # Only the batch number is needed; the range just fixes how many
            # batches there are.
            for batch_id, _ in enumerate(
                    range(0, cur_layer_node_num, self.max_split_nodes)):

                LOGGER.debug('cur batch id is {}'.format(batch_id))

                left_node_histogram = self.sync_local_histogram(
                    suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))

                all_histograms = self.histogram_subtraction(
                    left_node_histogram, self.stored_histograms)

                # store histogram
                for hist in all_histograms:
                    layer_stored_hist[hist.hid] = hist

                # FIXME stable parallel_partitions
                best_splits = self.federated_find_best_split(
                    all_histograms, parallel_partitions=10)
                split_info += best_splits

            # The next layer subtracts from this layer's histograms.
            self.stored_histograms = layer_stored_hist

            self.sync_best_splits(split_info, suffix=(dep, self.epoch_idx))
            LOGGER.debug('best_splits_sent')

    def predict(self, data_inst=None):
        """
        Do nothing
        """
        LOGGER.debug('start predicting')
class HeteroDecisionTreeGuest(DecisionTree):
    def __init__(self, tree_param):
        """
        Create an untrained guest-side hetero decision tree.

        Parameters
        ----------
        tree_param: parameter object forwarded to the base initializer
        """
        LOGGER.info("hetero decision tree guest init!")
        super(HeteroDecisionTreeGuest, self).__init__(tree_param)

        # collaborators built at construction time
        self.splitter = Splitter(
            self.criterion_method,
            self.criterion_params,
            self.min_impurity_split,
            self.min_sample_split,
            self.min_leaf_node,
        )
        self.transfer_inst = HeteroDecisionTreeTransferVariable()

        # training inputs, filled in later through the set_* methods
        self.data_bin = None
        self.grad_and_hess = None
        self.bin_split_points = None
        self.bin_sparse_points = None
        self.valid_features = None
        self.infos = None

        # encryption helpers
        self.encrypter = None
        self.encrypted_mode_calculator = None

        # per-sample node assignment state
        self.node_dispatch = None
        self.data_bin_with_node_dispatch = None

        # tree-growing state
        self.tree_ = []
        self.tree_node_num = 0
        self.tree_node_queue = None
        self.cur_split_nodes = None
        self.best_splitinfo_guest = None
        self.split_maskdict = {}

        # results and bookkeeping
        self.predict_weights = None
        self.feature_importances_ = {}
        self.runtime_idx = 0

    def set_flowid(self, flowid=0):
        """
        Log the flow id and forward it to the transfer-variable object.

        Parameters
        ----------
        flowid: value passed to ``self.transfer_inst.set_flowid``
        """
        log_line = "set flowid, flowid is {}".format(flowid)
        LOGGER.info(log_line)
        self.transfer_inst.set_flowid(flowid)

    def set_runtime_idx(self, runtime_idx):
        self.runtime_idx = runtime_idx

    def set_inputinfo(self,
                      data_bin=None,
                      grad_and_hess=None,
                      bin_split_points=None,
                      bin_sparse_points=None):
        """
        Attach the training inputs to this tree.

        Parameters
        ----------
        data_bin: stored as ``self.data_bin``
        grad_and_hess: stored as ``self.grad_and_hess``
        bin_split_points: stored as ``self.bin_split_points``
        bin_sparse_points: stored as ``self.bin_sparse_points``
        """
        LOGGER.info("set input info")
        self.data_bin, self.grad_and_hess = data_bin, grad_and_hess
        self.bin_split_points, self.bin_sparse_points = (bin_split_points,
                                                         bin_sparse_points)

    def set_encrypter(self, encrypter):
        """
        Store the object used by ``encrypt`` and ``decrypt``.

        Parameters
        ----------
        encrypter: object kept in ``self.encrypter``
        """
        message = "set encrypter"
        LOGGER.info(message)
        self.encrypter = encrypter

    def set_encrypted_mode_calculator(self, encrypted_mode_calculator):
        self.encrypted_mode_calculator = encrypted_mode_calculator

    def encrypt(self, val):
        return self.encrypter.encrypt(val)

    def decrypt(self, val):
        return self.encrypter.decrypt(val)

    def encode(self, etype="feature_idx", val=None, nid=None):
        if etype == "feature_idx":
            return val

        if etype == "feature_val":
            self.split_maskdict[nid] = val
            return None

        raise TypeError("encode type %s is not support!" % (str(etype)))

    @staticmethod
    def decode(dtype="feature_idx", val=None, nid=None, split_maskdict=None):
        if dtype == "feature_idx":
            return val

        if dtype == "feature_val":
            if nid in split_maskdict:
                return split_maskdict[nid]
            else:
                raise ValueError(
                    "decode val %s cause error, can't reconize it!" %
                    (str(val)))

        return TypeError("decode type %s is not support!" % (str(dtype)))

    def set_valid_features(self, valid_features=None):
        """
        Store the valid-feature description used when building histograms
        and searching for splits.

        Parameters
        ----------
        valid_features: object kept in ``self.valid_features``
        """
        message = "set valid features"
        LOGGER.info(message)
        self.valid_features = valid_features

    def sync_encrypted_grad_and_hess(self):
        """
        Encrypt the gradients/hessians and send them to the host role
        (``idx=-1``).
        """
        LOGGER.info("send encrypted grad and hess to host")
        payload = self.encrypt_grad_and_hess()

        variable = self.transfer_inst.encrypted_grad_and_hess
        transfer_tag = self.transfer_inst.generate_transferid(variable)
        federation.remote(obj=payload,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=-1)

    def encrypt_grad_and_hess(self):
        """
        Encrypt ``self.grad_and_hess`` with the encrypted-mode calculator.

        Returns
        -------
        result of ``self.encrypted_mode_calculator.encrypt``
        """
        LOGGER.info("start to encrypt grad and hess")
        calculator = self.encrypted_mode_calculator
        return calculator.encrypt(self.grad_and_hess)

    def get_grad_hess_sum(self, grad_and_hess_table):
        """
        Sum the (grad, hess) pairs of a table element-wise.

        Parameters
        ----------
        grad_and_hess_table: object exposing ``reduce``; its values are
            indexed as pairs

        Returns
        -------
        (grad, hess) tuple holding the two totals
        """
        LOGGER.info("calculate the sum of grad and hess")

        def add_pairs(left, right):
            return (left[0] + right[0], left[1] + right[1])

        total_grad, total_hess = grad_and_hess_table.reduce(add_pairs)
        return total_grad, total_hess

    def dispatch_all_node_to_root(self, root_id=0):
        """
        Assign every instance of ``self.data_bin`` to the root node.

        Parameters
        ----------
        root_id: node id written into each ``(1, root_id)`` dispatch value
        """
        LOGGER.info("dispatch all node to root")

        def to_root(_data_inst):
            return (1, root_id)

        self.node_dispatch = self.data_bin.mapValues(to_root)

    def get_histograms(self, node_map=None):
        """
        Compute accumulated feature histograms for the nodes in ``node_map``.

        Parameters
        ----------
        node_map: dict forwarded to ``FeatureHistogram.calculate_histogram``
            (``fit`` builds it as node id -> position in the current batch);
            an empty dict is used when omitted

        Returns
        -------
        result of ``FeatureHistogram.accumulate_histogram``
        """
        # a fresh dict per call instead of a shared mutable default argument
        if node_map is None:
            node_map = {}

        LOGGER.info("start to get node histograms")
        histograms = FeatureHistogram.calculate_histogram(
            self.data_bin_with_node_dispatch, self.grad_and_hess,
            self.bin_split_points, self.bin_sparse_points, self.valid_features,
            node_map)
        acc_histograms = FeatureHistogram.accumulate_histogram(histograms)
        return acc_histograms

    def sync_tree_node_queue(self, tree_node_queue, dep=-1):
        """
        Send the current node queue to the host, reduced to node ids.

        Parameters
        ----------
        tree_node_queue: nodes of the current layer
        dep: depth used in the log line and the transfer tag
        """
        LOGGER.info("send tree node queue of depth {}".format(dep))

        # work on a copy; each entry is replaced by a Node built from its id
        # only, so no other node attribute is sent
        masked_queue = copy.deepcopy(tree_node_queue)
        for pos, node in enumerate(masked_queue):
            masked_queue[pos] = Node(id=node.id)

        variable = self.transfer_inst.tree_node_queue
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        federation.remote(obj=masked_queue,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=-1)

    def sync_node_positions(self, dep):
        """
        Send ``self.node_dispatch`` to the host role.

        Parameters
        ----------
        dep: depth used in the log line and the transfer tag
        """
        LOGGER.info("send node positions of depth {}".format(dep))

        variable = self.transfer_inst.node_positions
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        federation.remote(obj=self.node_dispatch,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=-1)

    def sync_encrypted_splitinfo_host(self, dep=-1, batch=-1):
        """
        Receive the encrypted split information via ``federation.get``
        (``idx=-1``).

        Parameters
        ----------
        dep: depth used in the log line and the transfer tag
        batch: batch number used in the log line and the transfer tag

        Returns
        -------
        object returned by ``federation.get``
        """
        LOGGER.info("get encrypted splitinfo of depth {}, batch {}".format(
            dep, batch))

        variable = self.transfer_inst.encrypted_splitinfo_host
        transfer_tag = self.transfer_inst.generate_transferid(
            variable, dep, batch)
        return federation.get(name=variable.name, tag=transfer_tag, idx=-1)

    def sync_federated_best_splitinfo_host(self,
                                           federated_best_splitinfo_host,
                                           dep=-1,
                                           batch=-1,
                                           idx=-1):
        """
        Send the best split candidates chosen for a host back to that host.

        Parameters
        ----------
        federated_best_splitinfo_host: payload to send
        dep: depth used in the log line and the transfer tag
        batch: batch number used in the log line and the transfer tag
        idx: destination index passed to ``federation.remote``
        """
        LOGGER.info(
            "send federated best splitinfo of depth {}, batch {}".format(
                dep, batch))

        variable = self.transfer_inst.federated_best_splitinfo_host
        transfer_tag = self.transfer_inst.generate_transferid(
            variable, dep, batch)
        federation.remote(obj=federated_best_splitinfo_host,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=idx)

    def find_host_split(self, value):
        """
        Pick the best candidate split among one host's encrypted candidates
        for a single node.

        Parameters
        ----------
        value: pair ``(cur_split_node, encrypted_splitinfo_host)``; each
            candidate is an encrypted ``(sum_grad_left, sum_hess_left)`` pair

        Returns
        -------
        ``(best_idx, encrypted_best_gain)``; best_idx is -1 when no candidate
        beats the minimum impurity threshold
        """
        cur_split_node, encrypted_splitinfo_host = value
        node_grad = cur_split_node.sum_grad
        node_hess = cur_split_node.sum_hess

        best_idx = -1
        best_gain = self.min_impurity_split - consts.FLOAT_ZERO

        for idx, candidate in enumerate(encrypted_splitinfo_host):
            enc_grad_l, enc_hess_l = candidate
            left_grad = self.decrypt(enc_grad_l)
            left_hess = self.decrypt(enc_hess_l)
            # right-side statistics are the node total minus the left side
            right_grad = node_grad - left_grad
            right_hess = node_hess - left_hess
            gain = self.splitter.split_gain(node_grad, node_hess, left_grad,
                                            left_hess, right_grad, right_hess)

            if gain > self.min_impurity_split and gain > best_gain:
                best_idx, best_gain = idx, gain

        return best_idx, self.encrypt(best_gain)

    def federated_find_split(self, dep=-1, batch=-1):
        """
        For every host, evaluate its encrypted split candidates against the
        current batch of nodes and send the per-node results back.

        Parameters
        ----------
        dep: depth used for logging and transfer tags
        batch: batch number used for logging and transfer tags
        """
        LOGGER.info("federated find split of depth {}, batch {}".format(
            dep, batch))
        per_host_splitinfo = self.sync_encrypted_splitinfo_host(dep, batch)

        for host_idx, host_splitinfo in enumerate(per_host_splitinfo):
            # pair every node of the batch with that host's candidate list
            node_candidate_table = eggroll.parallelize(
                zip(self.cur_split_nodes, host_splitinfo),
                include_key=False,
                partition=self.data_bin._partitions)

            evaluated = node_candidate_table.mapValues(
                self.find_host_split).collect()
            best_splitinfo_host = [key_value[1] for key_value in evaluated]

            self.sync_federated_best_splitinfo_host(best_splitinfo_host, dep,
                                                    batch, host_idx)

    def sync_final_split_host(self, dep=-1, batch=-1):
        """
        Receive the final split information via ``federation.get``
        (``idx=-1``).

        Parameters
        ----------
        dep: depth used in the log line and the transfer tag
        batch: batch number used in the log line and the transfer tag

        Returns
        -------
        object returned by ``federation.get``
        """
        LOGGER.info("get host final splitinfo of depth {}, batch {}".format(
            dep, batch))

        variable = self.transfer_inst.final_splitinfo_host
        transfer_tag = self.transfer_inst.generate_transferid(
            variable, dep, batch)
        return federation.get(name=variable.name, tag=transfer_tag, idx=-1)

    def find_best_split_guest_and_host(self, splitinfo_guest_host):
        best_gain_host = self.decrypt(splitinfo_guest_host[1].gain)
        best_gain_host_idx = 1
        for i in range(1, len(splitinfo_guest_host)):
            gain_host_i = self.decrypt(splitinfo_guest_host[i].gain)
            if best_gain_host < gain_host_i:
                best_gain_host = gain_host_i
                best_gain_host_idx = i

        if splitinfo_guest_host[0].gain >= best_gain_host - consts.FLOAT_ZERO:
            best_splitinfo = splitinfo_guest_host[0]
        else:
            best_splitinfo = splitinfo_guest_host[best_gain_host_idx]
            best_splitinfo.sum_grad = self.decrypt(best_splitinfo.sum_grad)
            best_splitinfo.sum_hess = self.decrypt(best_splitinfo.sum_hess)
            best_splitinfo.gain = best_gain_host

        return best_splitinfo

    def merge_splitinfo(self, splitinfo_guest, splitinfo_host):
        """
        Combine guest and host proposals per node and keep the best of each.

        Parameters
        ----------
        splitinfo_guest: guest split infos, one per node
        splitinfo_host: one list per host, indexed by the same node order

        Returns
        -------
        list with the winning split info for every node
        """
        LOGGER.info("merge splitinfo")

        # for node i: [guest_i, host0_i, host1_i, ...]
        merge_infos = [
            [guest_info] + [host_infos[node_pos] for host_infos in splitinfo_host]
            for node_pos, guest_info in enumerate(splitinfo_guest)
        ]

        candidates_table = eggroll.parallelize(
            merge_infos,
            include_key=False,
            partition=self.data_bin._partitions)
        winners_table = candidates_table.mapValues(
            self.find_best_split_guest_and_host)

        return [key_value[1] for key_value in winners_table.collect()]

    def update_feature_importance(self, splitinfo):
        if self.feature_importance_type == "split":
            inc = 1
        elif self.feature_importance_type == "gain":
            inc = splitinfo.gain
        else:
            raise ValueError(
                "feature importance type {} not support yet".format(
                    self.feature_importance_type))

        sitename = splitinfo.sitename
        fid = splitinfo.best_fid

        if (sitename, fid) not in self.feature_importances_:
            self.feature_importances_[(sitename, fid)] = 0

        self.feature_importances_[(sitename, fid)] += inc

    def update_tree_node_queue(self, splitinfos, max_depth_reach):
        """
        Apply the chosen splits to the current layer and build the next one.

        Every node of ``self.tree_node_queue`` is either marked as a leaf or
        given two children, then appended to ``self.tree_``. The children
        become the new ``self.tree_node_queue``.

        Parameters
        ----------
        splitinfos: chosen split info per queued node, in queue order
        max_depth_reach: when true every queued node becomes a leaf and
            ``splitinfos`` is not consulted
        """
        # the format string needs one placeholder per argument, otherwise
        # the queue size is silently dropped from the message
        LOGGER.info(
            "update tree node, splitlist length is {}, tree node queue size is {}"
            .format(len(splitinfos), len(self.tree_node_queue)))

        new_tree_node_queue = []
        for i, node in enumerate(self.tree_node_queue):
            if max_depth_reach or splitinfos[i].gain <= \
                    self.min_impurity_split + consts.FLOAT_ZERO:
                node.is_leaf = True
            else:
                splitinfo = splitinfos[i]

                # children take the next two unused node ids
                node.left_nodeid = self.tree_node_num + 1
                node.right_nodeid = self.tree_node_num + 2
                self.tree_node_num += 2

                # splitinfo carries the left-side statistics; the right side
                # is the parent total minus the left side
                left_grad = splitinfo.sum_grad
                left_hess = splitinfo.sum_hess
                right_grad = node.sum_grad - left_grad
                right_hess = node.sum_hess - left_hess

                left_node = Node(id=node.left_nodeid,
                                 sitename=consts.GUEST,
                                 sum_grad=left_grad,
                                 sum_hess=left_hess,
                                 weight=self.splitter.node_weight(
                                     left_grad, left_hess))
                right_node = Node(id=node.right_nodeid,
                                  sitename=consts.GUEST,
                                  sum_grad=right_grad,
                                  sum_hess=right_hess,
                                  weight=self.splitter.node_weight(
                                      right_grad, right_hess))

                new_tree_node_queue.append(left_node)
                new_tree_node_queue.append(right_node)

                node.sitename = splitinfo.sitename
                if node.sitename == consts.GUEST:
                    # guest splits go through encode(); the split value is
                    # kept in split_maskdict and bid is set to None
                    node.fid = self.encode("feature_idx", splitinfo.best_fid)
                    node.bid = self.encode("feature_val", splitinfo.best_bid,
                                           node.id)
                else:
                    node.fid = splitinfo.best_fid
                    node.bid = splitinfo.best_bid

                self.update_feature_importance(splitinfo)
            self.tree_.append(node)

        self.tree_node_queue = new_tree_node_queue

    @staticmethod
    def dispatch_node(value,
                      tree_=None,
                      decoder=None,
                      split_maskdict=None,
                      bin_sparse_points=None):
        unleaf_state, nodeid = value[1]

        if tree_[nodeid].is_leaf is True:
            return tree_[nodeid].weight
        else:
            if tree_[nodeid].sitename == consts.GUEST:
                fid = decoder("feature_idx",
                              tree_[nodeid].fid,
                              split_maskdict=split_maskdict)
                bid = decoder("feature_val", tree_[nodeid].bid, nodeid,
                              split_maskdict)
                if value[0].features.get_data(fid,
                                              bin_sparse_points[fid]) <= bid:
                    return (1, tree_[nodeid].left_nodeid)
                else:
                    return (1, tree_[nodeid].right_nodeid)
            else:
                return (1, tree_[nodeid].fid, tree_[nodeid].bid,
                        tree_[nodeid].sitename, nodeid,
                        tree_[nodeid].left_nodeid, tree_[nodeid].right_nodeid)

    def sync_dispatch_node_host(self, dispatch_guest_data, dep=-1):
        """
        Send the instances the guest could not route to the host role.

        Parameters
        ----------
        dispatch_guest_data: payload to send
        dep: depth used in the log line and the transfer tag
        """
        LOGGER.info("send node to host to dispath, depth is {}".format(dep))

        variable = self.transfer_inst.dispatch_node_host
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        federation.remote(obj=dispatch_guest_data,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=-1)

    def sync_dispatch_node_host_result(self, dep=-1):
        """
        Receive the dispatch results via ``federation.get`` (``idx=-1``).

        Parameters
        ----------
        dep: depth used in the log line and the transfer tag

        Returns
        -------
        object returned by ``federation.get``
        """
        LOGGER.info("get host dispatch result, depth is {}".format(dep))

        variable = self.transfer_inst.dispatch_node_host_result
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        return federation.get(name=variable.name, tag=transfer_tag, idx=-1)

    def redispatch_node(self, dep=-1):
        """
        Move every instance from its current node to the next layer.

        The guest routes what it can with ``dispatch_node``, collects leaf
        weights into ``self.predict_weights``, sends the remaining instances
        to the host and merges the returned positions into
        ``self.node_dispatch``.

        Parameters
        ----------
        dep: depth used for logging and transfer tags
        """
        LOGGER.info("redispatch node of depth {}".format(dep))
        dispatch_node_method = functools.partial(
            self.dispatch_node,
            tree_=self.tree_,
            decoder=self.decode,
            split_maskdict=self.split_maskdict,
            bin_sparse_points=self.bin_sparse_points)
        dispatch_guest_result = self.data_bin_with_node_dispatch.mapValues(
            dispatch_node_method)
        # NOTE(review): tree_node_num is assigned here but not read again in
        # this method
        tree_node_num = self.tree_node_num
        LOGGER.info("remask dispatch node result of depth {}".format(dep))

        # tuples longer than 2 are the records dispatch_node emits for nodes
        # that are not split by the guest
        dispatch_to_host_result = dispatch_guest_result.filter(
            lambda key, value: isinstance(value, tuple) and len(value) > 2)

        dispatch_guest_result = dispatch_guest_result.subtractByKey(
            dispatch_to_host_result)
        # non-tuple values are leaf weights returned by dispatch_node
        leaf = dispatch_guest_result.filter(
            lambda key, value: isinstance(value, tuple) is False)
        if self.predict_weights is None:
            self.predict_weights = leaf
        else:
            self.predict_weights = self.predict_weights.union(leaf)

        # what remains are the (1, child_id) pairs resolved by the guest
        dispatch_guest_result = dispatch_guest_result.subtractByKey(leaf)

        self.sync_dispatch_node_host(dispatch_to_host_result, dep)
        dispatch_node_host_result = self.sync_dispatch_node_host_result(dep)

        # merge the per-host answers: on each join the left value is kept if
        # it has length 2, otherwise the right value is taken
        # NOTE(review): if dispatch_node_host_result is empty,
        # self.node_dispatch is still None at the final union call
        self.node_dispatch = None
        for idx in range(len(dispatch_node_host_result)):
            if self.node_dispatch is None:
                self.node_dispatch = dispatch_node_host_result[idx]
            else:
                self.node_dispatch = self.node_dispatch.join(dispatch_node_host_result[idx], \
                                                                 lambda unleaf_state_nodeid1, unleaf_state_nodeid2: \
                                                                 unleaf_state_nodeid1 if len(
                                                                 unleaf_state_nodeid1) == 2 else unleaf_state_nodeid2)
        self.node_dispatch = self.node_dispatch.union(dispatch_guest_result)

    def sync_tree(self):
        """
        Send ``self.tree_`` to the host role (``idx=-1``).
        """
        LOGGER.info("sync tree to host")

        variable = self.transfer_inst.tree
        transfer_tag = self.transfer_inst.generate_transferid(variable)
        federation.remote(obj=self.tree_,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=-1)

    def convert_bin_to_real(self):
        """
        Replace the bin index of every guest split by the real split value.

        The real value is looked up in ``self.bin_split_points`` and stored
        through ``encode("feature_val", ...)``; the node's ``bid`` is then
        set to that call's return value.
        """
        LOGGER.info("convert tree node bins to real value")
        for node in self.tree_:
            # only internal nodes split by the guest are converted
            if node.is_leaf is True or node.sitename != consts.GUEST:
                continue

            fid = self.decode("feature_idx",
                              node.fid,
                              split_maskdict=self.split_maskdict)
            bid = self.decode("feature_val", node.bid, node.id,
                              self.split_maskdict)
            node.bid = self.encode("feature_val",
                                   self.bin_split_points[fid][bid],
                                   node.id)

    def fit(self):
        """
        Grow the tree layer by layer on the guest side.

        Per layer: send the node queue and sample positions to the host,
        build histograms in batches of ``self.max_split_nodes`` nodes, merge
        guest and host split proposals, update the tree and re-dispatch the
        samples. Afterwards the tree is sent to the host and guest bin
        indices are converted to real split values.
        """
        LOGGER.info("begin to fit guest decision tree")
        self.sync_encrypted_grad_and_hess()

        # the root node aggregates the statistics of all samples
        root_sum_grad, root_sum_hess = self.get_grad_hess_sum(
            self.grad_and_hess)
        root_node = Node(id=0,
                         sitename=consts.GUEST,
                         sum_grad=root_sum_grad,
                         sum_hess=root_sum_hess,
                         weight=self.splitter.node_weight(
                             root_sum_grad, root_sum_hess))
        self.tree_node_queue = [root_node]
        self.dispatch_all_node_to_root()

        for dep in range(self.max_depth):
            LOGGER.info(
                "start to fit depth {}, tree node queue size is {}".format(
                    dep, len(self.tree_node_queue)))

            # the queue is sent even when empty, before the loop stops
            self.sync_tree_node_queue(self.tree_node_queue, dep)
            if not self.tree_node_queue:
                break

            self.sync_node_positions(dep)

            self.data_bin_with_node_dispatch = self.data_bin.join(
                self.node_dispatch,
                lambda inst, position: (inst, position))

            splitinfos = []
            batch_starts = range(0, len(self.tree_node_queue),
                                 self.max_split_nodes)
            for batch, start in enumerate(batch_starts):
                stop = start + self.max_split_nodes
                self.cur_split_nodes = self.tree_node_queue[start:stop]

                # node id -> position of the node inside this batch
                node_map = {
                    tree_node.id: position
                    for position, tree_node in enumerate(self.cur_split_nodes)
                }

                acc_histograms = self.get_histograms(node_map=node_map)

                self.best_splitinfo_guest = self.splitter.find_split(
                    acc_histograms, self.valid_features,
                    self.data_bin._partitions)
                self.federated_find_split(dep, batch)
                final_splitinfo_host = self.sync_final_split_host(dep, batch)

                splitinfos.extend(
                    self.merge_splitinfo(self.best_splitinfo_guest,
                                         final_splitinfo_host))

            max_depth_reach = (dep + 1 == self.max_depth)
            self.update_tree_node_queue(splitinfos, max_depth_reach)

            self.redispatch_node(dep)

        self.sync_tree()
        self.convert_bin_to_real()
        node_count = len(self.tree_)
        LOGGER.info("tree node num is %d" % node_count)
        LOGGER.info("end to fit guest decision tree")

    @staticmethod
    def traverse_tree(predict_state,
                      data_inst,
                      tree_=None,
                      decoder=None,
                      split_maskdict=None):
        """
        Walk one instance down the tree for as long as the guest owns the
        nodes on its path.

        Parameters
        ----------
        predict_state: pair ``(nid, tag)``; only ``nid`` is used
        data_inst: instance whose ``features.get_data`` is queried
        tree_: sequence of nodes indexed by node id
        decoder: callable with the signature of ``decode``
        split_maskdict: node id -> split value, forwarded to ``decoder``

        Returns
        -------
        the leaf weight when a guest leaf is reached, otherwise ``(nid, 1)``
        for the first node that does not belong to the guest
        """
        nid, _ = predict_state

        while True:
            node = tree_[nid]
            if node.sitename != consts.GUEST:
                # the walk cannot continue on the guest side
                return nid, 1

            if node.is_leaf is True:
                return node.weight

            fid = decoder("feature_idx",
                          node.fid,
                          split_maskdict=split_maskdict)
            bid = decoder("feature_val", node.bid, nid, split_maskdict)

            goes_left = data_inst.features.get_data(fid, 0) <= bid
            nid = node.left_nodeid if goes_left else node.right_nodeid

    def sync_predict_finish_tag(self, finish_tag, send_times):
        """
        Tell the host whether the guest has finished predicting.

        Parameters
        ----------
        finish_tag: flag sent to the host (``predict`` passes True when no
            unfinished instance remains)
        send_times: round counter used in the log line and the transfer tag
        """
        # argument order follows the placeholders: round number first, then
        # the flag
        LOGGER.info("send the {}-th predict finish tag {} to host".format(
            send_times, finish_tag))
        federation.remote(obj=finish_tag,
                          name=self.transfer_inst.predict_finish_tag.name,
                          tag=self.transfer_inst.generate_transferid(
                              self.transfer_inst.predict_finish_tag,
                              send_times),
                          role=consts.HOST,
                          idx=-1)

    def sync_predict_data(self, predict_data, send_times):
        """
        Send the still-unresolved prediction states to the host role.

        Parameters
        ----------
        predict_data: payload to send
        send_times: round counter used in the log line and the transfer tag
        """
        LOGGER.info("send predict data to host, sending times is {}".format(
            send_times))

        variable = self.transfer_inst.predict_data
        transfer_tag = self.transfer_inst.generate_transferid(
            variable, send_times)
        federation.remote(obj=predict_data,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=-1)

    def sync_data_predicted_by_host(self, send_times):
        """
        Receive the prediction states via ``federation.get`` (``idx=-1``).

        Parameters
        ----------
        send_times: round counter used in the log line and the transfer tag

        Returns
        -------
        object returned by ``federation.get``
        """
        LOGGER.info(
            "get predicted data by host, recv times is {}".format(send_times))

        variable = self.transfer_inst.predict_data_by_host
        transfer_tag = self.transfer_inst.generate_transferid(
            variable, send_times)
        return federation.get(name=variable.name, tag=transfer_tag, idx=-1)

    def predict(self, data_inst):
        """
        Predict leaf weights for every instance, alternating with the host.

        Each round the guest walks every unresolved instance as far as it can
        with ``traverse_tree``, collects the instances that reached a leaf,
        and hands the rest to the host. The loop ends when no unresolved
        instance is left.

        Parameters
        ----------
        data_inst: table of instances to predict

        Returns
        -------
        table of the values ``traverse_tree`` returned for finished instances
        """
        LOGGER.info("start to predict!")

        walk = functools.partial(self.traverse_tree,
                                 tree_=self.tree_,
                                 decoder=self.decode,
                                 split_maskdict=self.split_maskdict)

        # every instance starts at node 0 with state tag 1
        pending = data_inst.mapValues(lambda _inst: (0, 1))
        predict_result = None
        round_idx = 0

        while True:
            pending = pending.join(data_inst, walk)

            # non-tuple values are leaf weights, i.e. finished instances
            finished = pending.filter(
                lambda _key, state: isinstance(state, tuple) is False)
            if predict_result is None:
                predict_result = finished
            else:
                predict_result = predict_result.union(finished)

            pending = pending.subtractByKey(finished)

            if pending.count() == 0:
                self.sync_predict_finish_tag(True, round_idx)
                break

            self.sync_predict_finish_tag(False, round_idx)
            self.sync_predict_data(pending, round_idx)

            # the current state is kept when its tag is 0, otherwise the
            # state received from that host is taken
            for host_states in self.sync_data_predicted_by_host(round_idx):
                pending = pending.join(
                    host_states,
                    lambda current, from_host: current
                    if current[1] == 0 else from_host)

            round_idx += 1

        LOGGER.info("predict finish!")
        return predict_result

    def get_model_meta(self):
        """
        Export the tree hyper-parameters as a ``DecisionTreeModelMeta``.

        Returns
        -------
        meta object carrying the criterion settings, max_depth,
        min_sample_split, min_impurity_split and min_leaf_node
        """
        model_meta = DecisionTreeModelMeta()

        criterion = CriterionMeta(criterion_method=self.criterion_method,
                                  criterion_param=self.criterion_params)
        model_meta.criterion_meta.CopyFrom(criterion)

        for field in ("max_depth", "min_sample_split", "min_impurity_split",
                      "min_leaf_node"):
            setattr(model_meta, field, getattr(self, field))

        return model_meta

    def set_model_meta(self, model_meta):
        self.max_depth = model_meta.max_depth
        self.min_sample_split = model_meta.min_sample_split
        self.min_impurity_split = model_meta.min_impurity_split
        self.min_leaf_node = model_meta.min_leaf_node
        self.criterion_method = model_meta.criterion_meta.criterion_method
        self.criterion_params = list(model_meta.criterion_meta.criterion_param)

    def get_model_param(self):
        """
        Export the fitted tree as a ``DecisionTreeModelParam``.

        Returns
        -------
        param object with one ``tree_`` entry per node of ``self.tree_`` and
        a copy of ``self.split_maskdict``
        """
        exported_fields = ("id", "sitename", "fid", "bid", "weight",
                           "is_leaf", "left_nodeid", "right_nodeid")

        model_param = DecisionTreeModelParam()
        for node in self.tree_:
            node_kwargs = {
                field: getattr(node, field)
                for field in exported_fields
            }
            model_param.tree_.add(**node_kwargs)

        model_param.split_maskdict.update(self.split_maskdict)

        return model_param

    def set_model_param(self, model_param):
        """
        Rebuild ``self.tree_`` and ``self.split_maskdict`` from a param
        object.

        Parameters
        ----------
        model_param: object exposing an iterable ``tree_`` of node records
            and a mapping ``split_maskdict``
        """
        self.tree_ = []
        for record in model_param.tree_:
            self.tree_.append(
                Node(id=record.id,
                     sitename=record.sitename,
                     fid=record.fid,
                     bid=record.bid,
                     weight=record.weight,
                     is_leaf=record.is_leaf,
                     left_nodeid=record.left_nodeid,
                     right_nodeid=record.right_nodeid))

        self.split_maskdict = dict(model_param.split_maskdict)

    def get_model(self):
        model_meta = self.get_model_meta()
        model_param = self.get_model_param()

        return model_meta, model_param

    def load_model(self, model_meta=None, model_param=None):
        """
        Restore the model from its exported meta and param objects.

        Parameters
        ----------
        model_meta: passed to ``set_model_meta``
        model_param: passed to ``set_model_param``
        """
        message = "load tree model"
        LOGGER.info(message)
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)

    def get_feature_importance(self):
        return self.feature_importances_
    def __init__(self, tree_param: DecisionTreeParam, data_bin = None, bin_split_points: np.array = None,
                 bin_sparse_point=None, g_h = None, valid_feature: dict = None, epoch_idx: int = None,
                 role: str = None, tree_idx: int = None, flow_id: int = None, mode='train'):

        """
        Parameters
        ----------
        tree_param: decision tree parameter object
        data_bin: binned data instance
        bin_split_points: data split points
        bin_sparse_point: sparse data point
        g_h: computed g val and h val of instances
        valid_feature: dict points out valid features {valid:true,invalid:false}
        epoch_idx: current epoch index
        role: host or guest
        tree_idx: stored as self.tree_idx
        flow_id: flow id
        mode: train / predict
        """

        # NOTE(review): this initializer calls super(HomoDecisionTreeClient, ...)
        # but appears before the `class HomoDecisionTreeClient` header, at the
        # method indentation of the preceding class. It looks like a misplaced
        # duplicate of HomoDecisionTreeClient.__init__ - confirm and remove.
        super(HomoDecisionTreeClient, self).__init__(tree_param)
        self.splitter = Splitter(self.criterion_method, self.criterion_params, self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)
        self.data_bin = data_bin
        self.g_h = g_h
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_point
        self.epoch_idx = epoch_idx
        self.tree_idx = tree_idx

        # check max_split_nodes
        # an odd, non-zero value is bumped to the next even number
        if self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1:
            self.max_split_nodes += 1
            LOGGER.warning('an even max_split_nodes value is suggested when using histogram-subtraction, max_split_nodes reset to {}'.format(self.max_split_nodes))

        self.transfer_inst = HomoDecisionTreeTransferVariable()

        """
        initializing here
        """
        self.valid_features = valid_feature

        self.tree_node = []  # start from root node
        self.tree_node_num = 0
        self.cur_layer_node = []

        self.runtime_idx = 0
        self.sitename = consts.GUEST
        self.feature_importance = {}

        self.inst2node_idx = None

        # record weights of samples
        self.sample_weights = None

        # secure aggregator, class SecureBoostClientAggregator
        # for any mode other than 'train' or 'predict', neither self.role nor
        # self.aggregator is assigned here
        if mode == 'train':
            self.role = role
            self.set_flowid(flow_id)
            self.aggregator = DecisionTreeClientAggregator(verbose=False)

        elif mode == 'predict':
            self.role, self.aggregator = None, None
class HomoDecisionTreeClient(DecisionTree):

    def __init__(self, tree_param: DecisionTreeParam, data_bin = None, bin_split_points: np.array = None,
                 bin_sparse_point=None, g_h = None, valid_feature: dict = None, epoch_idx: int = None,
                 role: str = None, tree_idx: int = None, flow_id: int = None, mode='train'):

        """
        Set up the client side of one homo decision tree.

        Parameters
        ----------
        tree_param: decision tree parameter object
        data_bin binned: data instance
        bin_split_points: data split points
        bin_sparse_point: sparse data point
        g_h computed: g val and h val of instances
        valid_feature: dict points out valid features {valid:true,invalid:false}
        epoch_idx: current epoch index
        role: host or guest
        tree_idx: index of the current tree
        flow_id: flow id
        mode: train / predict; in train mode the flow id is set and a client
            aggregator is created, in any other mode role and aggregator are None
        """

        super(HomoDecisionTreeClient, self).__init__(tree_param)
        self.splitter = Splitter(self.criterion_method, self.criterion_params, self.min_impurity_split,
                                 self.min_sample_split, self.min_leaf_node)
        self.data_bin = data_bin
        self.g_h = g_h
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_point
        self.epoch_idx = epoch_idx
        self.tree_idx = tree_idx

        # check max_split_nodes: an odd, non-zero value is rounded up to the next even one
        if self.max_split_nodes != 0 and self.max_split_nodes % 2 == 1:
            self.max_split_nodes += 1
            LOGGER.warning('an even max_split_nodes value is suggested when using histogram-subtraction, max_split_nodes reset to {}'.format(self.max_split_nodes))

        self.transfer_inst = HomoDecisionTreeTransferVariable()

        # tree state, initialized empty
        self.valid_features = valid_feature

        self.tree_node = []  # start from root node
        self.tree_node_num = 0
        self.cur_layer_node = []

        self.runtime_idx = 0
        self.sitename = consts.GUEST
        self.feature_importance = {}

        self.inst2node_idx = None

        # record weights of samples
        self.sample_weights = None

        # secure aggregator, class SecureBoostClientAggregator
        # role/aggregator always exist, so an unexpected mode value no longer
        # leaves the instance without these attributes
        self.role, self.aggregator = None, None
        if mode == 'train':
            self.role = role
            self.set_flowid(flow_id)
            self.aggregator = DecisionTreeClientAggregator(verbose=False)

    def set_flowid(self, flowid=0):
        """Log the flow id and forward it to the transfer-variable object."""
        message = "set flowid, flowid is {}".format(flowid)
        LOGGER.info(message)
        self.transfer_inst.set_flowid(flowid)

    def get_grad_hess_sum(self, grad_and_hess_table):
        """
        Reduce a table of (grad, hess) pairs to their element-wise sums.

        Returns the pair (sum of grads, sum of hessians).
        """
        LOGGER.info("calculate the sum of grad and hess")

        def add_pairs(left, right):
            # element-wise addition of two (grad, hess) pairs
            return left[0] + right[0], left[1] + right[1]

        grad, hess = grad_and_hess_table.reduce(add_pairs)
        return grad, hess

    def update_feature_importance(self, split_info: List[SplitInfo]):

        for splitinfo in split_info:

            if self.feature_importance_type == "split":
                inc = 1
            elif self.feature_importance_type == "gain":
                inc = splitinfo.gain
            else:
                raise ValueError("feature importance type {} not support yet".format(self.feature_importance_type))

            fid = splitinfo.best_fid

            if fid not in self.feature_importance:
                self.feature_importance[fid] = 0

            self.feature_importance[fid] += inc

    def sync_local_node_histogram(self, acc_histogram: List[HistogramBag], suffix):
        """Send the local histograms through the aggregator under the given suffix."""
        self.aggregator.send_histogram(acc_histogram, suffix=suffix)
        # NOTE(review): fit() passes (batch_id, dep, epoch_idx, tree_idx), so the
        # value logged as "layer" is the batch id - confirm which one is wanted
        first_tag = suffix[0]
        LOGGER.debug('local histogram sent at layer {}'.format(first_tag))

    def get_node_map(self, nodes: List[Node], left_node_only=True):
        node_map = {}
        idx = 0
        for node in nodes:
            if node.id != 0 and (not node.is_left_node and left_node_only):
                continue
            node_map[node.id] = idx
            idx += 1
        return node_map

    def get_local_histogram(self, cur_to_split: List[Node], g_h, table_with_assign,
                            split_points, sparse_point, valid_feature):
        """
        Compute the local histograms of the given nodes.

        The node map is built with ``get_node_map`` defaults; every histogram
        list returned by ``FeatureHistogram.calculate_histogram`` is wrapped
        into a ``HistogramBag``.
        """
        LOGGER.info("start to get node histograms")
        node_index = self.get_node_map(nodes=cur_to_split)
        raw_histograms = FeatureHistogram.calculate_histogram(
            table_with_assign, g_h,
            split_points, sparse_point,
            valid_feature, node_index,
            self.use_missing, self.zero_as_missing)

        return [HistogramBag(raw) for raw in raw_histograms]

    def get_left_node_local_histogram(self, cur_nodes: List[Node], tree: List[Node], g_h, table_with_assign,
                            split_points, sparse_point, valid_feature):
        """
        Compute local histograms for the left nodes (and the id-0 node) of ``cur_nodes``.

        Each resulting ``HistogramBag`` is tagged with the id of its node
        (``hid``) and the id of that node's parent (``p_hid``). ``tree`` is
        accepted but not used here.
        """
        node_index = self.get_node_map(cur_nodes, left_node_only=True)

        LOGGER.info("start to get node histograms")
        raw_histograms = FeatureHistogram.calculate_histogram(
            table_with_assign, g_h,
            split_points, sparse_point,
            valid_feature, node_index,
            self.use_missing, self.zero_as_missing)
        bags = [HistogramBag(raw) for raw in raw_histograms]

        # same selection rule as get_node_map(left_node_only=True)
        tracked_nodes = [node for node in cur_nodes if node.is_left_node or node.id == 0]

        # set histogram id and parent histogram id
        for tracked, bag in zip(tracked_nodes, bags):
            bag.hid = tracked.id
            bag.p_hid = tracked.parent_nodeid

        return bags

    def update_tree(self, cur_to_split: List[Node], split_info: List[SplitInfo]):
        """
        Update the current tree structure with the splits of one layer.

        A node becomes a leaf when its split info has no feature or its gain
        does not exceed ``min_impurity_split`` (plus the float tolerance);
        otherwise the split is stored on the node and two children are created.
        Every processed node is appended to ``self.tree_node``.

        Parameters
        ----------
        cur_to_split: nodes of the current layer
        split_info: one split info per node, in the same order

        Returns
        -------
        list of the newly created child nodes (left, right, left, right, ...)
        """
        LOGGER.debug('updating tree_node, cur layer has {} node'.format(len(cur_to_split)))
        next_layer_node = []
        assert len(cur_to_split) == len(split_info)

        for node, info in zip(cur_to_split, split_info):
            if info.best_fid is None or info.gain <= self.min_impurity_split + consts.FLOAT_ZERO:
                node.is_leaf = True
                self.tree_node.append(node)
                continue

            node.fid = info.best_fid
            node.bid = info.best_bid
            node.missing_dir = info.missing_dir

            p_id = node.id
            l_id, r_id = self.tree_node_num + 1, self.tree_node_num + 2
            node.left_nodeid, node.right_nodeid = l_id, r_id
            self.tree_node_num += 2

            # the split info carries the left sums; the right sums are the remainder
            l_g, l_h = info.sum_grad, info.sum_hess
            r_g, r_h = node.sum_grad - l_g, node.sum_hess - l_h

            # create new left node and new right node
            left_node = Node(id=l_id,
                             sitename=self.sitename,
                             sum_grad=l_g,
                             sum_hess=l_h,
                             weight=self.splitter.node_weight(l_g, l_h),
                             parent_nodeid=p_id,
                             sibling_nodeid=r_id,
                             is_left_node=True)
            right_node = Node(id=r_id,
                              sitename=self.sitename,
                              sum_grad=r_g,
                              sum_hess=r_h,
                              weight=self.splitter.node_weight(r_g, r_h),
                              parent_nodeid=p_id,
                              sibling_nodeid=l_id,
                              is_left_node=False)

            # debug output goes through the logger like everywhere else in this class
            next_layer_node.append(left_node)
            LOGGER.debug('append left,cur tree has {} node'.format(len(self.tree_node)))
            next_layer_node.append(right_node)
            LOGGER.debug('append right,cur tree has {} node'.format(len(self.tree_node)))
            self.tree_node.append(node)

        return next_layer_node

    def convert_bin_to_val(self):
        """
        convert current bid in tree nodes to real value
        """
        for node in self.tree_node:
            if not node.is_leaf:
                node.bid = self.bin_split_points[node.fid][node.bid]

    def assign_instance_to_root_node(self, data_bin, root_node_id):
        """Map every value of ``data_bin`` to the assignment ``(1, root_node_id)``."""
        def to_root(inst):
            # the instance itself is ignored; every key gets the same assignment
            return 1, root_node_id

        return data_bin.mapValues(to_root)

    @staticmethod
    def assign_a_instance(row, tree: List[Node], bin_sparse_point, use_missing, use_zero_as_missing):

        leaf_status, nodeid = row[1]
        node = tree[nodeid]
        if node.is_leaf:
            return node.weight

        fid = node.fid
        bid = node.bid

        missing_dir = node.missing_dir

        missing_val = False
        if use_zero_as_missing:
            if row[0].features.get_data(fid, None) is None or \
                    row[0].features.get_data(fid) == NoneType():
                missing_val = True
        elif use_missing and row[0].features.get_data(fid) == NoneType():
            missing_val = True

        if missing_val:
            if missing_dir == 1:
                return 1, tree[nodeid].right_nodeid
            else:
                return 1, tree[nodeid].left_nodeid
        else:
            if row[0].features.get_data(fid, bin_sparse_point[fid]) <= bid:
                return 1, tree[nodeid].left_nodeid
            else:
                return 1, tree[nodeid].right_nodeid

    def assign_instance_to_new_node(self, table_with_assignment, tree_node: List[Node]):
        """
        Re-assign every instance to its node in the next layer.

        Returns ``(assignments, leaf_val)``: ``leaf_val`` holds the entries whose
        mapped value is not a tuple (leaf weights returned by
        ``assign_a_instance``), ``assignments`` the remaining entries.
        """
        LOGGER.debug('re-assign instance to new nodes')
        router = functools.partial(self.assign_a_instance,
                                   tree=tree_node,
                                   bin_sparse_point=self.bin_sparse_points,
                                   use_missing=self.use_missing,
                                   use_zero_as_missing=self.zero_as_missing)
        # FIXME (kept from the original; what needs fixing is not stated)
        routed = table_with_assignment.mapValues(router)
        leaf_val = routed.filter(lambda key, value: not isinstance(value, tuple))

        still_routing = routed.subtractByKey(leaf_val)

        return still_routing, leaf_val

    @staticmethod
    def get_node_sample_weights(inst2node, tree_node: List[Node]):
        """
        get samples' weights which correspond to its node assignment
        """
        func = functools.partial(lambda inst, nodes: nodes[inst[1]].weight, nodes=tree_node)
        return inst2node.mapValues(func)

    def get_feature_importance(self):
        return self.feature_importance

    def sync_tree(self,):
        """Placeholder: does nothing and returns None."""
        return None

    def sync_cur_layer_node_num(self, node_num, suffix):
        """Send the node count of the current layer to the arbiter (idx=-1)."""
        channel = self.transfer_inst.cur_layer_node_num
        channel.remote(node_num, role=consts.ARBITER, idx=-1, suffix=suffix)

    def sync_best_splits(self, suffix) -> List[SplitInfo]:

        best_splits = self.transfer_inst.best_split_points.get(idx=0, suffix=suffix)
        return best_splits

    def fit(self):
        """
        Fit this tree layer by layer.

        The local gradient/hessian sums are sent to the aggregator and the
        aggregated sums are used to build the root node. For every layer but
        the last, the left-node histograms are computed and sent batch by
        batch (``max_split_nodes`` nodes per batch), the best splits are
        fetched, the tree is grown and the instances are re-assigned to the
        new nodes. The nodes of the last layer become leaves without a split
        search. Weights of instances that reached a leaf are collected in
        ``self.sample_weights``; finally the bin ids stored in the nodes are
        converted to split values.
        """
        LOGGER.info('begin to fit homo decision tree, epoch {}, tree idx {}'.format(self.epoch_idx, self.tree_idx))

        # compute local g_sum and h_sum
        g_sum, h_sum = self.get_grad_hess_sum(self.g_h)

        # get aggregated root info
        # NOTE(review): these suffixes (and the best-split suffix below) carry no
        # tree_idx, unlike the node-number and histogram suffixes - confirm that
        # this matches the other side of the transfer
        self.aggregator.send_local_root_node_info(g_sum, h_sum, suffix=('root_node_sync1', self.epoch_idx))
        g_h_dict = self.aggregator.get_aggregated_root_info(suffix=('root_node_sync2', self.epoch_idx))
        global_g_sum, global_h_sum = g_h_dict['g_sum'], g_h_dict['h_sum']

        # initialize node
        root_node = Node(id=0, sitename=consts.GUEST, sum_grad=global_g_sum, sum_hess=global_h_sum,
                         weight=self.splitter.node_weight(global_g_sum, global_h_sum))

        self.cur_layer_node = [root_node]
        LOGGER.debug('assign samples to root node')
        self.inst2node_idx = self.assign_instance_to_root_node(self.data_bin, 0)

        for dep in range(self.max_depth):

            if dep + 1 == self.max_depth:
                # last layer: every remaining node becomes a leaf
                for node in self.cur_layer_node:
                    node.is_leaf = True
                    self.tree_node.append(node)
                rest_sample_weights = self.get_node_sample_weights(self.inst2node_idx, self.tree_node)
                if self.sample_weights is None:
                    self.sample_weights = rest_sample_weights
                else:
                    self.sample_weights = self.sample_weights.union(rest_sample_weights)

                # stop fitting
                break

            LOGGER.debug('start to fit layer {}'.format(dep))

            table_with_assignment = self.data_bin.join(self.inst2node_idx, lambda inst, assignment: (inst, assignment))

            # send current layer node number:
            self.sync_cur_layer_node_num(len(self.cur_layer_node), suffix=(dep, self.epoch_idx, self.tree_idx))

            # NOTE(review): a max_split_nodes of 0 would make range() raise ValueError here
            for batch_id, i in enumerate(range(0, len(self.cur_layer_node), self.max_split_nodes)):
                cur_to_split = self.cur_layer_node[i:i+self.max_split_nodes]

                # computed for the debug log only
                node_map = self.get_node_map(nodes=cur_to_split)
                LOGGER.debug('node map is {}'.format(node_map))
                LOGGER.debug('computing histogram for batch{} at depth{}'.format(batch_id, dep))
                local_histogram = self.get_left_node_local_histogram(
                    cur_nodes=cur_to_split,
                    tree=self.tree_node,
                    g_h=self.g_h,
                    table_with_assign=table_with_assignment,
                    split_points=self.bin_split_points,
                    sparse_point=self.bin_sparse_points,
                    valid_feature=self.valid_features
                )

                LOGGER.debug('federated finding best splits for batch{} at layer {}'.format(batch_id, dep))
                self.sync_local_node_histogram(local_histogram, suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))

            split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
            LOGGER.debug('got best splits from arbiter')

            new_layer_node = self.update_tree(self.cur_layer_node, split_info)
            self.cur_layer_node = new_layer_node
            self.update_feature_importance(split_info)

            self.inst2node_idx, leaf_val = self.assign_instance_to_new_node(table_with_assignment, self.tree_node)

            # record leaf val
            if self.sample_weights is None:
                self.sample_weights = leaf_val
            else:
                self.sample_weights = self.sample_weights.union(leaf_val)

            LOGGER.debug('assigning instance to new nodes done')

        self.convert_bin_to_val()
        LOGGER.debug('fitting tree done')
        LOGGER.debug('tree node num is {}'.format(len(self.tree_node)))

    def traverse_tree(self, data_inst: Instance, tree: List[Node], use_missing=True, zero_as_missing=True):

        nid = 0# root node id
        while True:

            if tree[nid].is_leaf:
                return tree[nid].weight

            cur_node = tree[nid]
            fid,bid = cur_node.fid,cur_node.bid
            missing_dir = cur_node.missing_dir

            if use_missing and zero_as_missing:

                if data_inst.features.get_data(fid) == NoneType() or data_inst.features.get_data(fid, None) is None:

                    nid = tree[nid].right_nodeid if missing_dir == 1 else tree[nid].left_nodeid

                elif data_inst.features.get_data(fid) <= bid:
                    nid = tree[nid].left_nodeid
                else:
                    nid = tree[nid].right_nodeid

            elif data_inst.features.get_data(fid) == NoneType():

                nid = tree[nid].right_nodeid if missing_dir == 1 else tree[nid].left_nodeid

            elif data_inst.features.get_data(fid, 0) <= bid:
                nid = tree[nid].left_nodeid
            else:
                nid = tree[nid].right_nodeid

    def predict(self, data_inst):
        """Map every instance of ``data_inst`` to the leaf weight found by ``traverse_tree``."""
        LOGGER.debug('tree start to predict')

        walker = functools.partial(
            self.traverse_tree,
            tree=self.tree_node,
            use_missing=self.use_missing,
            zero_as_missing=self.zero_as_missing,
        )
        return data_inst.mapValues(walker)

    def get_model_meta(self):
        """Build a ``DecisionTreeModelMeta`` from the criterion and tree hyper-parameters of this tree."""
        model_meta = DecisionTreeModelMeta()
        criterion = CriterionMeta(criterion_method=self.criterion_method,
                                  criterion_param=self.criterion_params)
        model_meta.criterion_meta.CopyFrom(criterion)

        # plain fields are copied one-to-one under the same name
        for field in ('max_depth', 'min_sample_split', 'min_impurity_split',
                      'min_leaf_node', 'use_missing', 'zero_as_missing'):
            setattr(model_meta, field, getattr(self, field))

        return model_meta

    def set_model_meta(self, model_meta):
        self.max_depth = model_meta.max_depth
        self.min_sample_split = model_meta.min_sample_split
        self.min_impurity_split = model_meta.min_impurity_split
        self.min_leaf_node = model_meta.min_leaf_node
        self.criterion_method = model_meta.criterion_meta.criterion_method
        self.criterion_params = list(model_meta.criterion_meta.criterion_param)
        self.use_missing = model_meta.use_missing
        self.zero_as_missing = model_meta.zero_as_missing

    def get_model_param(self):
        """
        Build a ``DecisionTreeModelParam`` with one ``tree_`` entry per node.

        The sitename of every entry is this tree's ``role``, not the node's own sitename.
        """
        model_param = DecisionTreeModelParam()
        for tree_node in self.tree_node:
            entry = dict(id=tree_node.id,
                         sitename=self.role,
                         fid=tree_node.fid,
                         bid=tree_node.bid,
                         weight=tree_node.weight,
                         is_leaf=tree_node.is_leaf,
                         left_nodeid=tree_node.left_nodeid,
                         right_nodeid=tree_node.right_nodeid,
                         missing_dir=tree_node.missing_dir)
            model_param.tree_.add(**entry)

        LOGGER.debug('output tree: epoch_idx:{} tree_idx:{}'.format(self.epoch_idx, self.tree_idx))
        return model_param

    def set_model_param(self, model_param):
        """Rebuild ``self.tree_node`` from the ``tree_`` entries of ``model_param``."""
        copied_fields = ('id', 'sitename', 'fid', 'bid', 'weight', 'is_leaf',
                         'left_nodeid', 'right_nodeid', 'missing_dir')
        self.tree_node = []
        for stored in model_param.tree_:
            node_kwargs = {field: getattr(stored, field) for field in copied_fields}
            self.tree_node.append(Node(**node_kwargs))

    def get_model(self):
        model_meta = self.get_model_meta()
        model_param = self.get_model_param()
        return model_meta, model_param

    def load_model(self, model_meta=None, model_param=None):
        """Restore this tree: first the meta settings, then the nodes."""
        LOGGER.info("load tree model")
        for restore, payload in ((self.set_model_meta, model_meta),
                                 (self.set_model_param, model_param)):
            restore(payload)

    """
    For debug
    """
    def print_leafs(self):
        """Write every node of ``self.tree_node`` (not only leaves) to the debug log."""
        LOGGER.debug('printing tree')
        for tree_node in self.tree_node:
            LOGGER.debug(tree_node)

    @staticmethod
    def print_split(split_infos: [SplitInfo]):
        """Write every split info of the list to the debug log."""
        LOGGER.debug('printing split info')
        for split_info in split_infos:
            LOGGER.debug(split_info)

    @staticmethod
    def print_hist(hist_list: [HistogramBag]):
        """Write every histogram bag of the list to the debug log."""
        LOGGER.debug('printing histogramBag')
        for hist_bag in hist_list:
            LOGGER.debug(hist_bag)
# Example #9
# 0
class HeteroDecisionTreeGuest(DecisionTree):
    def __init__(self, tree_param):
        """
        Create the guest side of a hetero decision tree.

        Only the splitter and the transfer variable are built here; inputs,
        encrypter and valid features start as None and are provided later
        through the ``set_*`` methods.
        """
        LOGGER.info("hetero decision tree guest init!")
        super(HeteroDecisionTreeGuest, self).__init__(tree_param)
        self.splitter = Splitter(self.criterion_method,
                                 self.criterion_params,
                                 self.min_impurity_split,
                                 self.min_sample_split,
                                 self.min_leaf_node)

        # training inputs
        self.data_bin = self.grad_and_hess = None
        self.bin_split_points = self.bin_sparse_points = None

        # per-layer working state
        self.data_bin_with_node_dispatch = self.node_dispatch = None
        self.infos = self.valid_features = None
        self.encrypter = None
        self.node_positions = self.best_splitinfo_guest = self.tree_node_queue = None

        # tree under construction
        self.tree_, self.tree_node_num, self.split_maskdict = [], 0, {}
        self.transfer_inst = HeteroDecisionTreeTransferVariable()
        self.predict_weights = None

    def set_flowid(self, flowid=0):
        """Log the flow id and forward it to the transfer-variable object."""
        message = "set flowid, flowid is {}".format(flowid)
        LOGGER.info(message)
        self.transfer_inst.set_flowid(flowid)

    def set_inputinfo(self,
                      data_bin=None,
                      grad_and_hess=None,
                      bin_split_points=None,
                      bin_sparse_points=None):
        """Store the binned data, the grad/hess table, the split points and the sparse points."""
        LOGGER.info("set input info")
        self.data_bin, self.grad_and_hess = data_bin, grad_and_hess
        self.bin_split_points, self.bin_sparse_points = bin_split_points, bin_sparse_points

    def set_encrypter(self, encrypter):
        """Store the encrypter used by ``encrypt`` and ``decrypt``."""
        LOGGER.info("set encrypter")
        setattr(self, "encrypter", encrypter)

    def encrypt(self, val):
        return self.encrypter.encrypt(val)

    def decrypt(self, val):
        return self.encrypter.decrypt(val)

    def encode(self, etype="feature_idx", val=None, nid=None):
        if etype == "feature_idx":
            return val

        if etype == "feature_val":
            self.split_maskdict[nid] = val
            return None

        raise TypeError("encode type %s is not support!" % (str(etype)))

    @staticmethod
    def decode(dtype="feature_idx", val=None, nid=None, split_maskdict=None):
        if dtype == "feature_idx":
            return val

        if dtype == "feature_val":
            if nid in split_maskdict:
                return split_maskdict[nid]
            else:
                raise ValueError(
                    "decode val %s cause error, can't reconize it!" %
                    (str(val)))

        return TypeError("decode type %s is not support!" % (str(dtype)))

    def set_valid_features(self, valid_features=None):
        """Store the valid-feature information passed to the histogram computation."""
        LOGGER.info("set valid features")
        setattr(self, "valid_features", valid_features)

    def sync_encrypted_grad_and_hess(self):
        """Encrypt the grad/hess table and send it to the host with index 0."""
        LOGGER.info("send encrypted grad and hess to host")
        payload = self.encrypt_grad_and_hess()
        variable = self.transfer_inst.encrypted_grad_and_hess
        transfer_tag = self.transfer_inst.generate_transferid(variable)
        federation.remote(obj=payload,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=0)

    def encrypt_grad_and_hess(self):
        """Return a table in which both entries of every (grad, hess) pair are encrypted."""
        LOGGER.info("start to encrypt grad and hess")
        # bound to a local so the mapped function does not reference self
        cipher = self.encrypter

        def encrypt_pair(grad_hess):
            return cipher.encrypt(grad_hess[0]), cipher.encrypt(grad_hess[1])

        encrypted_table = self.grad_and_hess.mapValues(encrypt_pair)
        LOGGER.info("finish to encrypt grad and hess")
        return encrypted_table

    def get_grad_hess_sum(self, grad_and_hess_table):
        """
        Reduce a table of (grad, hess) pairs to their element-wise sums.

        Returns the pair (sum of grads, sum of hessians).
        """
        LOGGER.info("calculate the sum of grad and hess")

        def add_pairs(left, right):
            # element-wise addition of two (grad, hess) pairs
            return left[0] + right[0], left[1] + right[1]

        grad, hess = grad_and_hess_table.reduce(add_pairs)
        return grad, hess

    def dispatch_all_node_to_root(self, root_id=0):
        """Set ``self.node_dispatch`` so that every instance is assigned ``(1, root_id)``."""
        LOGGER.info("dispatch all node to root")

        def to_root(data_inst):
            # the instance itself is ignored; every key gets the same dispatch
            return 1, root_id

        self.node_dispatch = self.data_bin.mapValues(to_root)

    def get_histograms(self, node_map=None):
        """
        Compute the accumulated histograms of the nodes in ``node_map``.

        Parameters
        ----------
        node_map: mapping passed on to ``FeatureHistogram.calculate_histogram``;
            None (the default) stands for an empty dict, which avoids a
            default dict shared between calls

        Returns
        -------
        the result of ``FeatureHistogram.accumulate_histogram``
        """
        if node_map is None:
            node_map = {}

        LOGGER.info("start to get node histograms")
        histograms = FeatureHistogram.calculate_histogram(
            self.data_bin_with_node_dispatch, self.grad_and_hess,
            self.bin_split_points, self.bin_sparse_points, self.valid_features,
            node_map)
        acc_histograms = FeatureHistogram.accumulate_histogram(histograms)
        LOGGER.info("acc histogram shape is {}".format(len(acc_histograms)))
        return acc_histograms

    def sync_tree_node_queue(self, tree_node_queue, dep=-1):
        """
        Send the node queue of one depth to the host with index 0.

        Every node is replaced by a fresh ``Node`` built from its id only;
        the queue passed in is left untouched.
        """
        LOGGER.info("send tree node queue of depth {}".format(dep))
        masked_queue = [Node(id=queued.id) for queued in tree_node_queue]

        variable = self.transfer_inst.tree_node_queue
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        federation.remote(obj=masked_queue,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=0)

    def sync_node_positions(self, dep):
        """Send the current node dispatch table to the host with index 0, tagged with the depth."""
        LOGGER.info("send node positions of depth {}".format(dep))
        variable = self.transfer_inst.node_positions
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        federation.remote(obj=self.node_dispatch,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=0)

    def sync_encrypted_splitinfo_host(self, dep=-1):
        """Fetch the encrypted split info of the given depth from party index 0."""
        LOGGER.info("get encrypted splitinfo of depth {}".format(dep))
        variable = self.transfer_inst.encrypted_splitinfo_host
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        return federation.get(name=variable.name, tag=transfer_tag, idx=0)

    def sync_federated_best_splitinfo_host(self,
                                           federated_best_splitinfo_host,
                                           dep=-1):
        """Send the best split selection of the given depth to the host with index 0."""
        LOGGER.info("send federated best splitinfo of depth {}".format(dep))
        variable = self.transfer_inst.federated_best_splitinfo_host
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        federation.remote(obj=federated_best_splitinfo_host,
                          name=variable.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=0)

    def federated_find_split(self, dep=-1):
        """
        Pick, for every queued node, the best of the host's candidate splits.

        The encrypted left-side sums received for each candidate are
        decrypted, the right-side sums are derived from the node totals and
        the split gain is computed. The index of the best candidate (-1 when
        none beats ``min_impurity_split``) and the encrypted best gain are
        sent back for every node.
        """
        LOGGER.info("federated find split of depth {}".format(dep))
        encrypted_splitinfo_host = self.sync_encrypted_splitinfo_host(dep)

        best_splitinfo_host = []
        for node_pos, candidates in enumerate(encrypted_splitinfo_host):
            queued = self.tree_node_queue[node_pos]
            sum_grad, sum_hess = queued.sum_grad, queued.sum_hess

            best_idx = -1
            best_gain = self.min_impurity_split - consts.FLOAT_ZERO
            for cand_pos, (enc_grad_l, enc_hess_l) in enumerate(candidates):
                sum_grad_l = self.decrypt(enc_grad_l)
                sum_hess_l = self.decrypt(enc_hess_l)
                gain = self.splitter.split_gain(sum_grad, sum_hess,
                                                sum_grad_l, sum_hess_l,
                                                sum_grad - sum_grad_l,
                                                sum_hess - sum_hess_l)

                if gain > self.min_impurity_split and gain > best_gain:
                    best_gain, best_idx = gain, cand_pos

            best_splitinfo_host.append([best_idx, self.encrypt(best_gain)])

        self.sync_federated_best_splitinfo_host(best_splitinfo_host, dep)

    def sync_final_split_host(self, dep=-1):
        """Fetch the host's final split info of the given depth from party index 0."""
        LOGGER.info("get host final splitinfo of depth {}".format(dep))
        variable = self.transfer_inst.final_splitinfo_host
        transfer_tag = self.transfer_inst.generate_transferid(variable, dep)
        return federation.get(name=variable.name, tag=transfer_tag, idx=0)

    def merge_splitinfo(self, splitinfo_guest, splitinfo_host):
        """
        Choose, node by node, between the guest's and the host's best split.

        The guest split wins unless the decrypted host gain is larger by more
        than the float tolerance. A winning host split info is updated in
        place with its decrypted sums and gain.
        """
        LOGGER.info("merge splitinfo")
        merged = []
        for pos, guest_info in enumerate(splitinfo_guest):
            host_info = splitinfo_host[pos]
            gain_host = self.decrypt(host_info.gain)

            if guest_info.gain >= gain_host - consts.FLOAT_ZERO:
                merged.append(guest_info)
                continue

            host_info.sum_grad = self.decrypt(host_info.sum_grad)
            host_info.sum_hess = self.decrypt(host_info.sum_hess)
            host_info.gain = gain_host
            merged.append(host_info)
        return merged

    def update_tree_node_queue(self, splitinfos, max_depth_reach):
        """
        Grow the tree by one layer from the merged split infos.

        A queued node becomes a leaf when the maximum depth is reached or its
        gain does not exceed ``min_impurity_split`` (plus the float tolerance).
        Otherwise two children are created and queued, and the split is stored
        on the node: guest splits go through ``encode``, splits of other sites
        are stored as received. Every queued node is appended to ``self.tree_``
        and the queue is replaced by the new children.

        Parameters
        ----------
        splitinfos: one split info per queued node (not read when max_depth_reach is set)
        max_depth_reach: whether the current layer is the last one
        """
        # the format string used to lack the placeholder for the queue size
        LOGGER.info(
            "update tree node, splitlist length is {}, tree node queue size is {}"
            .format(len(splitinfos), len(self.tree_node_queue)))
        new_tree_node_queue = []
        for i, node in enumerate(self.tree_node_queue):
            sum_grad = node.sum_grad
            sum_hess = node.sum_hess
            if max_depth_reach or splitinfos[i].gain <= \
                    self.min_impurity_split + consts.FLOAT_ZERO:
                node.is_leaf = True
            else:
                info = splitinfos[i]
                node.left_nodeid = self.tree_node_num + 1
                node.right_nodeid = self.tree_node_num + 2
                self.tree_node_num += 2

                # the split info carries the left sums; the right sums are the remainder
                left_node = Node(id=node.left_nodeid,
                                 sitename=consts.GUEST,
                                 sum_grad=info.sum_grad,
                                 sum_hess=info.sum_hess,
                                 weight=self.splitter.node_weight(
                                     info.sum_grad,
                                     info.sum_hess))
                right_node = Node(id=node.right_nodeid,
                                  sitename=consts.GUEST,
                                  sum_grad=sum_grad - info.sum_grad,
                                  sum_hess=sum_hess - info.sum_hess,
                                  weight=self.splitter.node_weight(
                                      sum_grad - info.sum_grad,
                                      sum_hess - info.sum_hess))

                new_tree_node_queue.append(left_node)
                new_tree_node_queue.append(right_node)
                LOGGER.info("tree_node_queue {} split!!!".format(node.id))

                node.sitename = info.sitename
                if node.sitename == consts.GUEST:
                    node.fid = self.encode("feature_idx", info.best_fid)
                    # encode() keeps the value in split_maskdict and returns None
                    node.bid = self.encode("feature_val", info.best_bid, node.id)
                else:
                    node.fid = info.best_fid
                    node.bid = info.best_bid

            self.tree_.append(node)

        self.tree_node_queue = new_tree_node_queue

    @staticmethod
    def dispatch_node(value,
                      tree_=None,
                      decoder=None,
                      split_maskdict=None,
                      bin_sparse_points=None):
        unleaf_state, nodeid = value[1]
        if unleaf_state == 0:
            return value[1]

        if tree_[nodeid].is_leaf is True:
            return (0, nodeid)
        else:
            if tree_[nodeid].sitename == consts.GUEST:
                fid = decoder("feature_idx",
                              tree_[nodeid].fid,
                              split_maskdict=split_maskdict)
                bid = decoder("feature_val", tree_[nodeid].bid, nodeid,
                              split_maskdict)
                if value[0].features.get_data(fid,
                                              bin_sparse_points[fid]) <= bid:
                    return (1, tree_[nodeid].left_nodeid)
                else:
                    return (1, tree_[nodeid].right_nodeid)
            else:
                return (1, tree_[nodeid].fid, tree_[nodeid].bid, \
                        nodeid, tree_[nodeid].left_nodeid, tree_[nodeid].right_nodeid)

    def sync_dispatch_node_host(self, dispatch_guest_data, dep=-1):
        """Send ``dispatch_guest_data`` to the host for depth ``dep``.

        The payload goes out through ``federation.remote`` under the
        ``dispatch_node_host`` transfer variable, tagged with ``dep``.
        """
        LOGGER.info("send node to host to dispath, depth is {}".format(dep))
        transfer_var = self.transfer_inst.dispatch_node_host
        transfer_tag = self.transfer_inst.generate_transferid(transfer_var, dep)
        federation.remote(obj=dispatch_guest_data,
                          name=transfer_var.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=0)

    def sync_dispatch_node_host_result(self, dep=-1):
        """Fetch and return the host's dispatch result for depth ``dep``.

        Reads the ``dispatch_node_host_result`` transfer variable through
        ``federation.get`` with ``idx=0``.
        """
        LOGGER.info("get host dispatch result, depth is {}".format(dep))
        transfer_var = self.transfer_inst.dispatch_node_host_result
        transfer_tag = self.transfer_inst.generate_transferid(transfer_var, dep)
        return federation.get(name=transfer_var.name,
                              tag=transfer_tag,
                              idx=0)

    def redispatch_node(self, dep=-1):
        """Reassign every sample to its node after the splits of depth ``dep``.

        Steps:
          1. Run ``dispatch_node`` over ``self.data_bin_with_node_dispatch``.
             It yields 2-tuples for samples resolved locally and longer
             tuples for samples whose split is not the guest's.
          2. Build a masked copy in which the node id of every 2-tuple is
             replaced by a random id, send it to the host, and fetch the
             host's result.
          3. Join both results into ``self.node_dispatch``: keep the local
             (unmasked) 2-tuple where there is one, else take the host's
             value.
        """
        LOGGER.info("redispatch node of depth {}".format(dep))
        dispatch_node_method = functools.partial(
            self.dispatch_node,
            tree_=self.tree_,
            decoder=self.decode,
            split_maskdict=self.split_maskdict,
            bin_sparse_points=self.bin_sparse_points)
        dispatch_guest_result = self.data_bin_with_node_dispatch.mapValues(
            dispatch_node_method)

        # Bound to a local so the lambda below does not close over ``self``.
        # Clamped at 0: random.randint(0, -1) raises ValueError, which is
        # what the unclamped ``tree_node_num - 1`` gives when no node has
        # been split yet (tree_node_num == 0).
        mask_upper = max(self.tree_node_num - 1, 0)
        LOGGER.info("rmask edispatch node result of depth {}".format(dep))
        dispatch_node_mask = dispatch_guest_result.mapValues(
            lambda state_nodeid:
            (state_nodeid[0], random.randint(0, mask_upper))
            if len(state_nodeid) == 2 else state_nodeid)
        self.sync_dispatch_node_host(dispatch_node_mask, dep)
        dispatch_node_host_result = self.sync_dispatch_node_host_result(dep)

        self.node_dispatch = dispatch_guest_result.join(
            dispatch_node_host_result,
            lambda guest_value, host_value:
            guest_value if len(guest_value) == 2 else host_value)

    def sync_tree(self):
        """Send ``self.tree_`` to the host via the ``tree`` transfer variable."""
        LOGGER.info("sync tree to host")
        transfer_var = self.transfer_inst.tree
        transfer_tag = self.transfer_inst.generate_transferid(transfer_var)
        federation.remote(obj=self.tree_,
                          name=transfer_var.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=0)

    def convert_bin_to_real(self):
        """Replace guest split bin ids in ``self.tree_`` with split values.

        For every non-leaf node whose ``sitename`` is the guest, the stored
        ``fid``/``bid`` are decoded, the matching entry of
        ``self.bin_split_points`` is looked up, re-encoded as a
        "feature_val" and written back to ``bid``. Leaf nodes and nodes of
        other sites are left untouched.
        """
        LOGGER.info("convert tree node bins to real value")
        for node in self.tree_:
            if node.is_leaf is True or node.sitename != consts.GUEST:
                continue

            fid = self.decode("feature_idx",
                              node.fid,
                              split_maskdict=self.split_maskdict)
            bid = self.decode("feature_val", node.bid, node.id,
                              self.split_maskdict)
            split_value = self.bin_split_points[fid][bid]
            node.bid = self.encode("feature_val", split_value, node.id)

    def fit(self):
        """Grow the tree on the guest side, one depth level per iteration.

        Starts from a single root node (id 0) built from the summed
        gradients/hessians, then for each depth up to ``self.max_depth``:
        syncs the current node queue and node positions, builds histograms,
        finds the guest's best splits, runs the federated split search,
        merges guest and host split infos, updates the node queue and
        redispatches the samples. The loop stops early once the node queue
        is empty (after that empty queue has been synced).

        Afterwards the tree is sent to the host, guest bin ids are converted
        with ``convert_bin_to_real`` and ``self.predict_weights`` is set to
        the weight of the node each sample ended up in.
        """
        LOGGER.info("begin to fit guest decision tree")
        self.sync_encrypted_grad_and_hess()

        root_sum_grad, root_sum_hess = self.get_grad_hess_sum(
            self.grad_and_hess)
        self.tree_node_queue = [
            Node(id=0,
                 sitename=consts.GUEST,
                 sum_grad=root_sum_grad,
                 sum_hess=root_sum_hess,
                 weight=self.splitter.node_weight(root_sum_grad,
                                                  root_sum_hess))
        ]

        self.dispatch_all_node_to_root()

        for dep in range(self.max_depth):
            LOGGER.info(
                "start to fit depth {}, tree node queue size is {}".format(
                    dep, len(self.tree_node_queue)))
            # The queue is synced before the emptiness check, so an empty
            # queue is still sent once.
            self.sync_tree_node_queue(self.tree_node_queue, dep)
            if not self.tree_node_queue:
                break

            self.sync_node_positions(dep)

            # Node id -> position of that node in the current queue.
            node_map = {
                queued_node.id: position
                for position, queued_node in enumerate(self.tree_node_queue)
            }

            self.data_bin_with_node_dispatch = self.data_bin.join(
                self.node_dispatch,
                lambda data_inst, dispatch_info: (data_inst, dispatch_info))

            acc_histograms = self.get_histograms(node_map=node_map)
            self.best_splitinfo_guest = self.splitter.find_split(
                acc_histograms, self.valid_features)

            self.federated_find_split(dep)

            final_splitinfo_host = self.sync_final_split_host(dep)

            splitinfos = self.merge_splitinfo(self.best_splitinfo_guest,
                                              final_splitinfo_host)
            max_depth_reach = (dep + 1 == self.max_depth)
            self.update_tree_node_queue(splitinfos, max_depth_reach)
            self.redispatch_node(dep)

        self.sync_tree()
        self.convert_bin_to_real()
        # Local alias so the lambda below does not close over ``self``.
        tree_ = self.tree_
        LOGGER.info("tree node num is %d" % len(tree_))
        self.predict_weights = self.node_dispatch.mapValues(
            lambda unleaf_state_nodeid: tree_[unleaf_state_nodeid[1]].weight)

        LOGGER.info("end to fit guest decision tree")

    @staticmethod
    def traverse_tree(predict_state,
                      data_inst,
                      tree_=None,
                      decoder=None,
                      split_maskdict=None):
        tag, nid = predict_state
        if tag == 0:
            return (tag, nid)

        while tree_[nid].sitename != consts.HOST:
            if tree_[nid].is_leaf is True:
                return (0, nid)

            fid = decoder("feature_idx",
                          tree_[nid].fid,
                          split_maskdict=split_maskdict)
            bid = decoder("feature_val", tree_[nid].bid, nid, split_maskdict)

            if data_inst.features.get_data(fid, 0) <= bid:
                nid = tree_[nid].left_nodeid
            else:
                nid = tree_[nid].right_nodeid

        return (1, nid)

    def sync_predict_finish_tag(self, finish_tag, send_times):
        """Tell the host whether prediction has finished.

        ``finish_tag`` is the value sent; ``send_times`` is the round
        counter used to build the transfer tag.
        """
        # The ordinal slot takes the round counter and the second slot the
        # tag value; the two arguments were previously swapped.
        LOGGER.info("send the {}-th predict finish tag {} to host".format(
            send_times, finish_tag))
        federation.remote(obj=finish_tag,
                          name=self.transfer_inst.predict_finish_tag.name,
                          tag=self.transfer_inst.generate_transferid(
                              self.transfer_inst.predict_finish_tag,
                              send_times),
                          role=consts.HOST,
                          idx=0)

    def sync_predict_data(self, predict_data, send_times):
        """Send ``predict_data`` to the host for round ``send_times``.

        Uses the ``predict_data`` transfer variable, tagged with
        ``send_times``.
        """
        LOGGER.info("send predict data to host, sending times is {}".format(
            send_times))
        transfer_var = self.transfer_inst.predict_data
        transfer_tag = self.transfer_inst.generate_transferid(
            transfer_var, send_times)
        federation.remote(obj=predict_data,
                          name=transfer_var.name,
                          tag=transfer_tag,
                          role=consts.HOST,
                          idx=0)

    def sync_data_predicted_by_host(self, send_times):
        """Fetch and return the host's prediction state for ``send_times``.

        Reads the ``predict_data_by_host`` transfer variable through
        ``federation.get`` with ``idx=0``.
        """
        LOGGER.info(
            "get predicted data by host, recv times is {}".format(send_times))
        transfer_var = self.transfer_inst.predict_data_by_host
        transfer_tag = self.transfer_inst.generate_transferid(
            transfer_var, send_times)
        return federation.get(name=transfer_var.name,
                              tag=transfer_tag,
                              idx=0)

    def predict(self, data_inst):
        """Predict a node weight for every sample in ``data_inst``.

        Every sample starts in state ``(1, 0)`` (unfinished, at node 0).
        Each round:
          1. ``traverse_tree`` advances every sample as far as the guest
             can.
          2. If no sample is unfinished, a finish tag of ``True`` is sent
             and the loop ends.
          3. Otherwise the node id of every finished sample is replaced by
             a random one, the masked table is sent to the host, and the
             host's result is merged back. Finished samples keep their
             local (unmasked) state.

        Returns a table mapping each sample to ``tree_[nid].weight`` of the
        node it finished in.
        """
        LOGGER.info("start to predict!")

        # Bound to locals so the closures below do not capture ``self``,
        # matching how fit() and redispatch_node() build their lambdas.
        tree_ = self.tree_
        max_node_idx = len(tree_) - 1
        # Loop-invariant, so built once rather than on every round.
        traverse_tree = functools.partial(
            self.traverse_tree,
            tree_=tree_,
            decoder=self.decode,
            split_maskdict=self.split_maskdict)

        predict_data = data_inst.mapValues(lambda _: (1, 0))
        site_host_send_times = 0
        while True:
            predict_data = predict_data.join(data_inst, traverse_tree)

            # Sum of the state flags == number of samples still unfinished.
            unleaf_node_count = predict_data.reduce(
                lambda value1, value2: (value1[0] + value2[0], 0))[0]

            if unleaf_node_count == 0:
                self.sync_predict_finish_tag(True, site_host_send_times)
                break

            predict_data_mask = predict_data.mapValues(
                lambda state_nodeid:
                (state_nodeid[0], random.randint(0, max_node_idx))
                if state_nodeid[0] == 0 else state_nodeid)

            self.sync_predict_finish_tag(False, site_host_send_times)
            self.sync_predict_data(predict_data_mask, site_host_send_times)

            predict_data_host = self.sync_data_predicted_by_host(
                site_host_send_times)
            predict_data = predict_data.join(
                predict_data_host,
                lambda guest_state, host_state:
                guest_state if guest_state[0] == 0 else host_state)

            site_host_send_times += 1

        predict_data = predict_data.mapValues(
            lambda tag_nid: tree_[tag_nid[1]].weight)

        LOGGER.info("predict finish!")
        return predict_data

    def get_tree_model(self):
        """Return a ``DecisionTreeModelMeta`` holding this tree's state.

        The returned object carries ``self.tree_`` and
        ``self.split_maskdict`` (the same objects, not copies).
        """
        LOGGER.info("get tree model")
        model_meta = DecisionTreeModelMeta()
        model_meta.tree_ = self.tree_
        model_meta.split_maskdict = self.split_maskdict
        return model_meta

    def set_tree_model(self, tree_model):
        """Load tree state from ``tree_model``.

        Assigns ``tree_model.tree_`` and ``tree_model.split_maskdict`` to
        the attributes of the same name on this instance (no copy is made).
        """
        LOGGER.info("set tree model")
        self.tree_ = tree_model.tree_
        self.split_maskdict = tree_model.split_maskdict