Example 1
    def __init__(self):
        super(HomoSecureBoostingTreeArbiter, self).__init__()

        self.mode = consts.HOMO
        self.feature_num = 0
        self.role = consts.ARBITER
        self.transfer_inst = HomoSecureBoostingTreeTransferVariable()
        self.check_convergence_func = None
        self.tree_dim = None
        self.aggregator = SecureBoostArbiterAggregator()
        self.global_loss_history = []

        # federated_binning obj
        self.binning_obj = HomoFeatureBinningServer()
Example 2
    def fit(self, data_inst, validate_data=None):

        # init binning obj
        self.aggregator = HomoBoostArbiterAggregator()
        self.binning_obj = HomoFeatureBinningServer()

        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = HomoLabelEncoderArbiter().label_alignment()
            LOGGER.info('label mapping is {}'.format(label_mapping))
            self.booster_dim = len(
                label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        # sync start round and end round
        self.sync_start_round_and_end_round()

        LOGGER.info('begin to fit a boosting tree')
        self.preprocess()
        for epoch_idx in range(self.start_round, self.boosting_round):

            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

            for class_idx in range(self.booster_dim):
                model = self.fit_a_learner(epoch_idx, class_idx)

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx, ))
            self.history_loss.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss", "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(
                    self.check_convergence, (global_loss, ),
                    suffix=(epoch_idx, ))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"Best": min(self.history_loss)}))
        self.postprocess()
        self.callback_list.on_train_end()
        self.set_summary(self.generate_summary())
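
For reference, the "diff" convergence check supplied by converge_func_factory("diff", self.tol) compares successive loss values against the tolerance. The minimal sketch below mimics that behaviour; DiffConverge and its internals are illustrative assumptions, not FATE's actual implementation.

# Minimal sketch of a "diff"-style convergence check; DiffConverge is a
# hypothetical stand-in for the object converge_func_factory returns.
class DiffConverge:
    def __init__(self, tol=1e-4):
        self.tol = tol
        self.pre_loss = None

    def is_converge(self, loss):
        # converged once the absolute change in loss drops below tol
        if self.pre_loss is None:
            self.pre_loss = loss
            return False
        converged = abs(self.pre_loss - loss) < self.tol
        self.pre_loss = loss
        return converged

# usage: mirrors self.check_convergence(cur_loss) in the arbiter
checker = DiffConverge(tol=1e-4)
for loss in [0.9, 0.5, 0.49995]:
    print(checker.is_converge(loss))  # False, False, True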
Example 3
class HomoBoostingArbiter(Boosting, ABC):
    def __init__(self):
        super(HomoBoostingArbiter, self).__init__()
        self.aggregator = HomoBoostArbiterAggregator()
        self.transfer_inst = HomoBoostingTransferVariable()
        self.check_convergence_func = None
        self.binning_obj = HomoFeatureBinningServer()

    def federated_binning(self):

        binning_param = HomoFeatureBinningParam(method=consts.RECURSIVE_QUERY,
                                                bin_num=self.bin_num,
                                                error=self.binning_error)

        if self.use_missing:
            self.binning_obj = recursive_query_binning.Server(
                binning_param, abnormal_list=[NoneType()])
        else:
            self.binning_obj = recursive_query_binning.Server(binning_param,
                                                              abnormal_list=[])

        self.binning_obj.fit_split_points(None)

    def sync_feature_num(self):
        feature_num_list = self.transfer_inst.feature_number.get(
            idx=-1, suffix=('feat_num', ))
        for num in feature_num_list[1:]:
            assert feature_num_list[0] == num
        return feature_num_list[0]

    def check_label(self):
        pass

    def fit(self, data_inst, validate_data=None):

        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = HomoLabelEncoderArbiter().label_alignment()
            LOGGER.info('label mapping is {}'.format(label_mapping))
            self.booster_dim = len(
                label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        LOGGER.info('begin to fit a boosting tree')
        for epoch_idx in range(self.boosting_round):

            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

            for class_idx in range(self.booster_dim):
                model = self.fit_a_booster(epoch_idx, class_idx)

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx, ))
            self.history_loss.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss", "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(
                    self.check_convergence, (global_loss, ),
                    suffix=(epoch_idx, ))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"Best": min(self.history_loss)}))

        self.set_summary(self.generate_summary())

    def predict(self, data_inst=None):
        LOGGER.debug('arbiter skip prediction')

    @abc.abstractmethod
    def fit_a_booster(self, epoch_idx: int, booster_dim: int):
        raise NotImplementedError()

    @abc.abstractmethod
    def load_booster(self, model_meta, model_param, epoch_idx, booster_idx):
        raise NotImplementedError()
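
A concrete arbiter only needs to supply the two hooks above. The skeleton below is a hypothetical illustration of the shape such a subclass takes; HomoFakeBoostingArbiter is not a FATE class.

# Hypothetical skeleton showing how the two abstract hooks are filled in;
# the class name and method bodies are illustrative only.
class HomoFakeBoostingArbiter(HomoBoostingArbiter):

    def fit_a_booster(self, epoch_idx: int, booster_dim: int):
        # run one arbiter-side boosting round for one class dimension
        # and return the fitted booster's model
        ...

    def load_booster(self, model_meta, model_param, epoch_idx, booster_idx):
        # rebuild a single booster from its serialized meta and param
        ...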


class HomoSecureBoostingTreeArbiter(BoostingTree):
    def __init__(self):
        super(HomoSecureBoostingTreeArbiter, self).__init__()

        self.mode = consts.HOMO
        self.feature_num = 0
        self.role = consts.ARBITER
        self.transfer_inst = HomoSecureBoostingTreeTransferVariable()
        self.check_convergence_func = None
        self.tree_dim = None
        self.aggregator = SecureBoostArbiterAggregator()
        self.global_loss_history = []

        # federated_binning obj
        self.binning_obj = HomoFeatureBinningServer()

    def sample_valid_feature(self):

        # `random` here is numpy's random module (the stdlib random.choice
        # accepts neither a size argument nor replace=False)
        chosen_feature = random.choice(
            range(0, self.feature_num),
            max(1, int(self.subsample_feature_rate * self.feature_num)),
            replace=False)
        valid_features = [False for i in range(self.feature_num)]
        for fid in chosen_feature:
            valid_features[fid] = True

        return valid_features
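
Stripped of the class context, the column subsampling above is a plain numpy draw without replacement; a standalone sketch with made-up numbers:

# Standalone sketch of the column subsampling, naming numpy explicitly;
# feature_num and subsample_feature_rate are illustration values.
import numpy as np

feature_num = 10
subsample_feature_rate = 0.3
chosen = np.random.choice(range(feature_num),
                          max(1, int(subsample_feature_rate * feature_num)),
                          replace=False)
valid_features = [i in set(chosen) for i in range(feature_num)]
print(valid_features)  # three True entries among ten, positions random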

    def sync_feature_num(self):
        feature_num_list = self.transfer_inst.feature_number.get(
            idx=-1, suffix=('feat_num', ))
        for num in feature_num_list[1:]:
            assert feature_num_list[0] == num
        return feature_num_list[0]

    def sync_stop_flag(self, stop_flag, suffix):
        self.transfer_inst.stop_flag.remote(stop_flag, idx=-1, suffix=suffix)

    def sync_current_loss(self, suffix):
        loss_status_list = self.transfer_inst.loss_status.get(idx=-1,
                                                              suffix=suffix)
        total_loss, total_num = 0, 0
        for l_ in loss_status_list:
            total_loss += l_['cur_loss'] * l_['sample_num']
            total_num += l_['sample_num']
        LOGGER.debug(
            'loss status received, total_loss {}, total_num {}'.format(
                total_loss, total_num))
        return total_loss / total_num
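
The aggregation above is a sample-weighted mean of the per-party losses; a minimal local version with invented party figures:

# Sample-weighted global loss, mirroring sync_current_loss; the party
# losses and sample counts are made-up illustration values.
loss_status_list = [
    {'cur_loss': 0.50, 'sample_num': 100},
    {'cur_loss': 0.30, 'sample_num': 300},
]
total_loss = sum(l['cur_loss'] * l['sample_num'] for l in loss_status_list)
total_num = sum(l['sample_num'] for l in loss_status_list)
print(total_loss / total_num)  # (0.5*100 + 0.3*300) / 400 = 0.35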

    def sync_tree_dim(self):
        tree_dims = self.transfer_inst.tree_dim.get(idx=-1,
                                                    suffix=('tree_dim', ))
        dim0 = tree_dims[0]
        for dim in tree_dims[1:]:
            assert dim0 == dim
        return dim0

    def check_convergence(self, cur_loss):
        LOGGER.debug('checking convergence')
        return self.check_convergence_func.is_converge(cur_loss)

    def generate_flowid(self, round_num, tree_num):
        LOGGER.info("generate flowid, flowid {}".format(self.flowid))
        return ".".join(map(str, [self.flowid, round_num, tree_num]))

    def label_alignment(self) -> dict:
        labels = self.transfer_inst.local_labels.get(idx=-1,
                                                     suffix=('label_align', ))
        label_set = set()
        for local_label in labels:
            label_set.update(local_label)
        global_label = list(label_set)
        global_label = sorted(global_label)
        label_mapping = {v: k for k, v in enumerate(global_label)}
        self.transfer_inst.label_mapping.remote(label_mapping,
                                                idx=-1,
                                                suffix=('label_mapping', ))
        return label_mapping
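
Without the transfer variables, label alignment reduces to a sorted union of every party's labels mapped to consecutive indices; a local sketch with made-up labels:

# Local sketch of the label alignment logic; the party label lists
# are made-up examples.
party_labels = [[0, 2], [1, 2], [0, 1]]
label_set = set()
for local_label in party_labels:
    label_set.update(local_label)
label_mapping = {v: k for k, v in enumerate(sorted(label_set))}
print(label_mapping)  # {0: 0, 1: 1, 2: 2}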

    def federated_binning(self):
        self.binning_obj.average_run()

    def send_valid_features(self, valid_features, epoch_idx, t_idx):
        self.transfer_inst.valid_features.remote(valid_features,
                                                 idx=-1,
                                                 suffix=('valid_features',
                                                         epoch_idx, t_idx))

    def fit(self, data_inst, valid_inst=None):

        self.federated_binning()
        # initializing
        self.feature_num = self.sync_feature_num()
        self.tree_dim = 1

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = self.label_alignment()
            LOGGER.debug('label mapping is {}'.format(label_mapping))
            self.tree_dim = len(label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory(
                "diff", self.tol)

        LOGGER.debug('begin to fit a boosting tree')
        for epoch_idx in range(self.num_trees):

            for t_idx in range(self.tree_dim):
                valid_feature = self.sample_valid_feature()
                self.send_valid_features(valid_feature, epoch_idx, t_idx)
                flow_id = self.generate_flowid(epoch_idx, t_idx)
                new_tree = HomoDecisionTreeArbiter(self.tree_param,
                                                   valid_feature=valid_feature,
                                                   epoch_idx=epoch_idx,
                                                   flow_id=flow_id,
                                                   tree_idx=t_idx)
                new_tree.fit()

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx, ))
            self.global_loss_history.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss", "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(
                    self.check_convergence, (global_loss, ),
                    suffix=(epoch_idx, ))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        self.callback_meta(
            "loss", "train",
            MetricMeta(name="train",
                       metric_type="LOSS",
                       extra_metas={"Best": min(self.global_loss_history)}))

        LOGGER.debug('fitting homo decision tree done')

    def predict(self, data_inst):

        LOGGER.debug('start predicting')